In [1]:
%matplotlib inline
%pprint on

Pretty printing has been turned OFF


In [2]:
import numpy as np
from matplotlib import pyplot as plt
import utils
import os
from PIL import Image
import json
import torch
import json
from gensim.models.keyedvectors import KeyedVectors
from types import SimpleNamespace
from networks import TextEncoder, ImageEncoder, DiscourseClassifier
from datasets import CoherenceDataset, val_transform
from torch.utils.data import DataLoader
from tqdm import tqdm
import re
from PIL import Image
import requests
from io import BytesIO
from torchvision import transforms

# Explore Attention

In [3]:
device = 'cuda'
args = {
    'data_source': 'conceptual',
    'img_dir': '../data/conceptual/images/',
    'word2vec_dim': 300,
    'rnn_hid_dim': 300,
    'feature_dim': 1024,
    'max_len': 40,
    'dataset_q': 0,
    'with_attention': 2,
    'batch_size': 64,
    'workers': 4
}
args = SimpleNamespace(**args)
relations = relations = ['Visible', 'Subjective', 'Action', 'Story', 'Meta', 'Irrelevant']

## Dataset

In [4]:
test_set = CoherenceDataset(
            part='test',
            datasource=args.data_source,
            img_dir=args.img_dir,
            word2vec_file=f'models/word2vec_{args.data_source}.bin',
            max_len=args.max_len,
            dataset_q=args.dataset_q,  # experimental things, ignore it for now
            transform=val_transform)

test_loader = DataLoader(
            test_set, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True,
            drop_last=False)

print('test data:', len(test_set), len(test_loader))

vocab size = 5612
test data: 1512 24


In [5]:
test_set.n2p

[0.5074775672981057, 15.085106382978724, 5.406779661016949, 3.0427807486631018, 1.597938144329897, 9.5]

## Model

In [12]:
def load_model(path):
    text_encoder = TextEncoder(
        emb_dim=args.word2vec_dim,
        hid_dim=args.rnn_hid_dim,
        z_dim=args.feature_dim,
        max_len = args.max_len,
        word2vec_file=f'models/word2vec_{args.data_source}.bin',
        with_attention=args.with_attention).to(device)
    image_encoder = ImageEncoder(
        z_dim=args.feature_dim).to(device)
    discourse_class = DiscourseClassifier(
        len(relations), args.feature_dim).to(device)

    ckpt = torch.load(path)
    text_encoder.load_state_dict(ckpt['text_encoder'])
    image_encoder.load_state_dict(ckpt['image_encoder'])
    discourse_class.load_state_dict(ckpt['discourse_class'])
    return text_encoder, image_encoder, discourse_class

# path_vis = 'runs/samples6047_retrieval=1.00_classification=0.10_reweight=1000.00_weightDecay=0.0_withAttention=2_question=0_maxLen=40/e14.ckpt'

# path_sub = 'runs/samples6047_retrieval=1.00_classification=0.10_reweight=1000.00_weightDecay=0.0_withAttention=2_question=1_maxLen=40/e14.ckpt'

# path_act = 'runs/samples6047_retrieval=1.00_classification=0.10_reweight=1000.00_weightDecay=0.0_withAttention=2_question=2_maxLen=40/e14.ckpt'

# path_sto = 'runs/samples6047_retrieval=1.00_classification=0.10_reweight=1000.00_weightDecay=0.0_withAttention=2_question=3_maxLen=40/e14.ckpt'

path_met = 'runs/samples6047_retrieval=1.00_classification=0.10_reweight=1000.00_weightDecay=0.0_withAttention=2_question=4_maxLen=40/e14.ckpt'
path_irr = 'runs/samples6047_retrieval=1.00_classification=0.10_reweight=1000.00_weightDecay=0.0_withAttention=2_question=5_maxLen=40/e14.ckpt'

path_all = 'runs/samples6047_retrieval=1.00_classification=0.00_reweight=1000.00_weightDecay=0.0_withAttention=2_question=0,1,2,3,4,5_maxLen=40/e14.ckpt'

# t_vis, i_vis, d_vis = load_model(path_vis)
# t_sub, i_sub, d_sub = load_model(path_sub)
# t_act, i_act, d_act = load_model(path_act)
# t_sto, i_sto, d_sto = load_model(path_sto)

t_met, i_met, d_met = load_model(path_met)
t_irr, i_irr, d_irr = load_model(path_irr)

t_all, i_all, d_all = load_model(path_all)

In [13]:
def generate_output(test_loader, text_encoder, image_encoder, discourse_class, valid_questions):
    txt_feats = []
    img_feats = []
    probs = []
    labels = []
    attns = []
    for batch in tqdm(test_loader):
        for i in range(len(batch)):
            batch[i] = batch[i].to(device)
        txt, txt_len, img, target = batch
        with torch.no_grad():
            txt_feat, attn = text_encoder(txt.long(), txt_len)
            img_feat = image_encoder(img)
            prob = torch.sigmoid(discourse_class(txt_feat, img_feat))[:,valid_questions]
            txt_feats.append(txt_feat.detach().cpu())
            img_feats.append(img_feat.detach().cpu())
            probs.append(prob.detach().cpu())
            attns.append(attn.detach().cpu())
            labels.append(target[:,valid_questions].detach().cpu())

    txt_feats = torch.cat(txt_feats, dim=0).numpy()
    img_feats = torch.cat(img_feats, dim=0).numpy()
    probs = torch.cat(probs, dim=0).numpy()
    labels = torch.cat(labels, dim=0).numpy()
    attns = torch.cat(attns, dim=0).numpy()
    return probs, labels, attns, txt_feats, img_feats

# valid_questions_vis = torch.tensor([0], dtype=torch.long) # Visible
# valid_questions_sub = torch.tensor([1], dtype=torch.long) # Subjective
# valid_questions_act = torch.tensor([2], dtype=torch.long) # Action
# valid_questions_sto = torch.tensor([3], dtype=torch.long) # Story

valid_questions_met = torch.tensor([4], dtype=torch.long) # Meta
valid_questions_irr = torch.tensor([5], dtype=torch.long) # Irrelavent

valid_questions_all = torch.tensor([0,1,2,3,4,5], dtype=torch.long) # all

In [69]:
probs_vis, labels_vis, attns_vis, txt_vis, img_vis = generate_output(test_loader, t_vis, i_vis, d_vis, valid_questions_vis)

100%|██████████| 24/24 [00:50<00:00,  2.11s/it]


In [70]:
probs_sub, labels_sub, attns_sub, txt_sub, img_sub = generate_output(test_loader, t_sub, i_sub, d_sub, valid_questions_sub)

100%|██████████| 24/24 [00:47<00:00,  1.97s/it]


In [71]:
probs_act, labels_act, attns_act, txt_act, img_act = generate_output(test_loader, t_act, i_act, d_act, valid_questions_act)

100%|██████████| 24/24 [00:47<00:00,  1.97s/it]


In [72]:
probs_sto, labels_sto, attns_sto, txt_sto, img_sto = generate_output(test_loader, t_sto, i_sto, d_sto, valid_questions_sto)

100%|██████████| 24/24 [00:46<00:00,  1.92s/it]


In [8]:
probs_met, labels_met, attns_met, txt_met, img_met = generate_output(test_loader, t_met, i_met, d_met, valid_questions_met)

100%|██████████| 24/24 [01:49<00:00,  4.56s/it]


In [9]:
probs_irr, labels_irr, attns_irr, txt_irr, img_irr = generate_output(test_loader, t_irr, i_irr, d_irr, valid_questions_irr)

100%|██████████| 24/24 [00:50<00:00,  2.08s/it]


In [14]:
probs_all, labels_all, attns_all, txt_all, img_all = generate_output(test_loader, t_all, i_all, d_all, valid_questions_all)

100%|██████████| 24/24 [00:57<00:00,  2.41s/it]


In [15]:
# ranks_vis = utils.compute_ranks(txt_vis, img_vis)
# ranks_sub = utils.compute_ranks(txt_sub, img_sub)
# ranks_act = utils.compute_ranks(txt_act, img_act)
# ranks_sto = utils.compute_ranks(txt_sto, img_sto)
ranks_met = utils.compute_ranks(txt_met, img_met)
ranks_irr = utils.compute_ranks(txt_irr, img_irr)
ranks_all = utils.compute_ranks(txt_all, img_all)
ranks_all.shape

(1512, 1512)

## Compare baseline with single-discourse model

### Which single-discourse model?

In [252]:
# relation = 'Visible'
# probs = probs_vis
# labels = labels_vis
# attns = attns_vis
# ranks = ranks_vis
# relation_idx = 0

# relation = 'Subjective'
# probs = probs_sub
# labels = labels_sub
# attns = attns_sub
# ranks = ranks_sub
# relation_idx = 1

# relation = 'Action'
# probs = probs_act
# labels = labels_act
# attns = attns_act
# ranks = ranks_act
# relation_idx = 2

# relation = 'Story'
# probs = probs_sto
# labels = labels_sto
# attns = attns_sto
# ranks = ranks_sto
# relation_idx = 3

relation = 'Meta'
probs = probs_met
labels = labels_met
attns = attns_met
ranks = ranks_met
relation_idx = 4


# relation = 'Irrelavant'
# probs = probs_irr
# labels = labels_irr
# attns = attns_irr
# ranks = ranks_irr
# relation_idx = 5


def get_pos(ranks):
    out = []
    for ii, rank in enumerate(ranks):
        pos = rank.tolist().index(ii)
        out.append(pos)
    return np.asarray(out)

positions = get_pos(ranks)
positions_all = get_pos(ranks_all)

print(probs.shape, labels.shape, attns.shape, ranks.shape, positions.shape)
print(probs_all.shape, labels_all.shape, attns_all.shape, ranks_all.shape, positions_all.shape)

(1512, 1) (1512, 1) (1512, 40) (1512, 1512) (1512,)
(1512, 6) (1512, 6) (1512, 40) (1512, 1512) (1512,)


### Discourse is True or False?

In [253]:
T = 1
t_indices = np.where(labels==T)[0]
t_positions = positions[t_indices]
tmp = np.argsort(t_positions, 0).squeeze()
t_indices = t_indices[tmp]
t_indices.shape

(1003,)

In [254]:
import json
img2id = json.load(open('./../data/conceptual/img2idxmap.json', 'r'))
len(img2id)

5348

In [255]:
imgheader = 'url'
capheader = 'caption'
img_dir = '../data/conceptual/images/'
save_dir = f'outputs/clue/{relation}={T}'

### Which sample to show?

In [256]:
save_dir = os.path.join(save_dir, 'coherence_win')
# save_dir = os.path.join(save_dir, 'coherence_lose')
os.makedirs(save_dir, exist_ok=True)

good_indices = []
for a,b,idx in zip(positions[t_indices], positions_all[t_indices], t_indices):
    if a<5 and b-a >= 5:
#     if b-a >0:
#     if b<5 and a-b >= 5:
        good_indices.append(idx)
    

print(save_dir)
len(good_indices)

outputs/clue/Visible=1/coherence_win


36

In [257]:
with open(os.path.join(save_dir, 'captions.txt'), 'w') as f:
    for idx in tqdm(good_indices[:20]):
        rcp = test_set.recipes.iloc[idx]
        cap = rcp[capheader]
        f.write(f'{idx:>5d} ' + cap + '\n')

100%|██████████| 20/20 [00:00<00:00, 2300.14it/s]


In [258]:
my_trans = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512)
])


def show_img(rcp, title=None):
    img_name = img2id[rcp[imgheader]]
    img_path = './../data/conceptual/images/{}.jpg'.format(img_name)
    img = my_trans(Image.open(img_path))
#     _ = plt.imshow(np.asarray(img))
#     plt.axis('off')
    if title:
        img.save(title, 'JPEG')

# idx = 96

for idx in tqdm(good_indices[:20]):
    sub_dir = os.path.join(save_dir, str(idx))
    os.makedirs(sub_dir, exist_ok=True)
    f = open(os.path.join(sub_dir, 'attentions.txt'), 'w')

    rcp = test_set.recipes.iloc[idx]
    cap = rcp[capheader]
    cap = utils.clean_caption(cap)
    words = re.split(r'\\n| ', cap)[:args.max_len]
#     print(cap)
#     print(words)

#     print(f'\n==> {relation} == {T}')
#     print(probs[idx])
    f.write(f'Model: {relation} is used\n')
    f.write(f'prob = {probs[idx][0]:.4f}\n')
    for w, attn in zip(words, attns[idx][:len(words)]):
        line = f'{w:>20s} = {attn:<.4f}'
#         print(line)
        f.write(line + '\n')


#     print('\n==> no relation')
#     print(probs_all[idx][relation_idx])
    f.write(f'\nModel: no relation is used\n')
    f.write(f'prob = {probs_all[idx][relation_idx]:.4f}\n')
    for w, attn in zip(words, attns_all[idx][:len(words)]):
        line = f'{w:>20s} = {attn:<.4f}'
#         print(line)
        f.write(line + '\n')


    show_img(rcp, title=os.path.join(sub_dir, f'real.jpg'))

    # ranks
    rank = ranks[idx]
#     fig = plt.figure(figsize=(12,6))
    pos = rank.tolist().index(idx)
    line = f'{relation} == {T}. Top 5 retrieved images, while true image is at pos={pos}'
#     fig.suptitle(line, y=0.7)
    f.write(line+'\n')
    i = 0
    for idx_ in rank[:5]:
#         plt.subplot(151+i)
        show_img(test_set.recipes.iloc[idx_], title=os.path.join(sub_dir, f'{relation}_top{i}.jpg'))
        i+=1


    rank_all = ranks_all[idx]
#     fig = plt.figure(figsize=(12,6))
    pos = rank_all.tolist().index(idx)
    line = f'No Relation. Top 5 retrieved images, while true image is at pos={pos}'
#     fig.suptitle(line, y=0.7)
    f.write(line+'\n')
    i = 0
    for idx_ in rank_all[:5]:
#         plt.subplot(151+i)
        show_img(test_set.recipes.iloc[idx_], title=os.path.join(sub_dir, f'no_relation_top{i}.jpg'))
        i+=1

    f.close()

100%|██████████| 20/20 [00:53<00:00,  2.67s/it]


# Generate samples for human evaluation

### Which single-discourse model?

In [33]:
# relation = 'Visible'
# probs = probs_vis
# labels = labels_vis
# attns = attns_vis
# ranks = ranks_vis
# relation_idx = 0

# relation = 'Subjective'
# probs = probs_sub
# labels = labels_sub
# attns = attns_sub
# ranks = ranks_sub
# relation_idx = 1

# relation = 'Action'
# probs = probs_act
# labels = labels_act
# attns = attns_act
# ranks = ranks_act
# relation_idx = 2

# relation = 'Story'
# probs = probs_sto
# labels = labels_sto
# attns = attns_sto
# ranks = ranks_sto
# relation_idx = 3

# relation = 'Meta'
# probs = probs_met
# labels = labels_met
# attns = attns_met
# ranks = ranks_met
# relation_idx = 4


relation = 'Irrelavant'
probs = probs_irr
labels = labels_irr
attns = attns_irr
ranks = ranks_irr
relation_idx = 5

def get_pos(ranks):
    out = []
    for ii, rank in enumerate(ranks):
        pos = rank.tolist().index(ii)
        out.append(pos)
    return np.asarray(out)

positions = get_pos(ranks)
positions_all = get_pos(ranks_all)

print(probs.shape, labels.shape, attns.shape, ranks.shape, positions.shape)
print(probs_all.shape, labels_all.shape, attns_all.shape, ranks_all.shape, positions_all.shape)

(1512, 1) (1512, 1) (1512, 40) (1512, 1512) (1512,)
(1512, 6) (1512, 6) (1512, 40) (1512, 1512) (1512,)


In [34]:
T = 1
t_indices = np.where(labels==T)[0]
t_positions = positions[t_indices]
tmp = np.argsort(t_positions, 0).squeeze()
t_indices = t_indices[tmp]
t_indices.shape

(144,)

In [35]:
save_dir = f'outputs/clue/{relation}={T}/human_evaluation'
os.makedirs(save_dir, exist_ok=True)

good_indices = []
for a,b,idx in zip(positions[t_indices], positions_all[t_indices], t_indices):
    if b-a >0:
        good_indices.append(idx)

print(save_dir)
len(good_indices)

outputs/clue/Irrelavant=1/human_evaluation


68

In [36]:
import json
img2id = json.load(open('./../data/conceptual/img2idxmap.json', 'r'))
len(img2id)

5348

In [37]:
imgheader = 'url'
capheader = 'caption'
img_dir = '../data/conceptual/images/'
save_dir = f'outputs/clue/{relation}={T}'

In [38]:
with open(os.path.join(save_dir, 'captions.txt'), 'w') as f:
    for idx in tqdm(good_indices[:100]):
        rcp = test_set.recipes.iloc[idx]
        cap = rcp[capheader]
        f.write(f'{idx:>5d} ' + cap + '\n')

100%|██████████| 68/68 [00:00<00:00, 3225.48it/s]


In [39]:
my_trans = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512)
])


def show_img(rcp, title=None):
    img_name = img2id[rcp[imgheader]]
    img_path = './../data/conceptual/images/{}.jpg'.format(img_name)
    img = my_trans(Image.open(img_path))
    if title:
        img.save(title, 'JPEG')


for idx in tqdm(good_indices[:100]):
    rcp = test_set.recipes.iloc[idx]

    rank = ranks[idx]
    for idx_ in rank[:1]:
        show_img(test_set.recipes.iloc[idx_], title=os.path.join(save_dir, f'{idx}_cohaware.jpg'))


    rank_all = ranks_all[idx]
    for idx_ in rank_all[:1]:
        show_img(test_set.recipes.iloc[idx_], title=os.path.join(save_dir, f'{idx}_cohagnostic.jpg'))

100%|██████████| 68/68 [00:32<00:00,  2.07it/s]
