In [3]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Switch IPython's pretty-printing of displayed values (reported as ON in the output).
%pprint on

Pretty printing has been turned ON


In [4]:
import numpy as np
from matplotlib import pyplot as plt
import utils
import os
from PIL import Image
import json
import torch
import json
from gensim.models.keyedvectors import KeyedVectors
from types import SimpleNamespace
from networks import TextEncoder, ImageEncoder, DiscourseClassifier
from datasets import CoherenceDataset, val_transform
from torch.utils.data import DataLoader
from tqdm import tqdm
import re
from PIL import Image
import requests
from io import BytesIO
from torchvision import transforms

# Explore Attention

In [3]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Settings shared by the dataset and the networks, built directly as a
# namespace so `args` never changes type between statements.
args = SimpleNamespace(
    data_source='recipe',        # also selects the word2vec file name
    img_dir='../data/RecipeQA/images-qa/train/images-qa',
    word2vec_dim=300,            # passed to TextEncoder as emb_dim
    rnn_hid_dim=300,             # passed to TextEncoder as hid_dim
    feature_dim=1024,            # feature size used by both encoders and the classifier
    max_len=200,                 # maximum number of caption tokens
    dataset_q=0,                 # experimental dataset option
    with_attention=2,            # attention variant passed to TextEncoder
    batch_size=64,
    workers=4,                   # DataLoader worker processes
)

# One label column per question, q2 through q8.
relations = ['q2_resp', 'q3_resp', 'q4_resp', 'q5_resp', 'q6_resp', 'q7_resp', 'q8_resp']

## Dataset

In [5]:
# Training split, preprocessed with the validation transform.
train_set = CoherenceDataset(
    part='train',
    datasource=args.data_source,
    word2vec_file=f'models/word2vec_{args.data_source}.bin',
    max_len=args.max_len,
    dataset_q=args.dataset_q,  # experimental option; can be ignored for now
    transform=val_transform,
)

# Unshuffled loader that keeps the final partial batch.
train_loader = DataLoader(
    train_set,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=args.workers,
    pin_memory=True,
)

# Number of samples, then number of batches.
print(f'train data: {len(train_set)} {len(train_loader)}')

vocab size = 8918
train data: 3439 54


In [6]:
# Test split, preprocessed with the validation transform.
test_set = CoherenceDataset(
    part='test',
    datasource=args.data_source,
    word2vec_file=f'models/word2vec_{args.data_source}.bin',
    max_len=args.max_len,
    dataset_q=args.dataset_q,  # experimental option; can be ignored for now
    transform=val_transform,
)

# Unshuffled loader, so row i of every collected array is test sample i.
test_loader = DataLoader(
    test_set,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=args.workers,
    pin_memory=True,
)

# Number of samples, then number of batches.
print(f'test data: {len(test_set)} {len(test_loader)}')

vocab size = 8918
test data: 860 14


In [7]:
# Display the dataset's `n2p` attribute; the output lists 7 values, matching
# len(relations). Presumably a negative-to-positive label ratio per relation —
# TODO confirm in datasets.py.
train_set.n2p

[0.21777620396600567, 7.684343434343434, 2.043362831858407, 4.171428571428572, 5.333333333333333, 0.7016328550222662, 2.196096654275093]

## Model

In [8]:
def load_model(path):
    """Build the three networks and restore their weights from a checkpoint.

    Args:
        path: Checkpoint file whose content holds state dicts under the keys
            'text_encoder', 'image_encoder' and 'discourse_class'.

    Returns:
        Tuple ``(text_encoder, image_encoder, discourse_class)``; every
        module lives on ``device`` and is switched to evaluation mode.
    """
    text_encoder = TextEncoder(
        emb_dim=args.word2vec_dim,
        hid_dim=args.rnn_hid_dim,
        z_dim=args.feature_dim,
        max_len = args.max_len,
        word2vec_file=f'models/word2vec_{args.data_source}.bin',
        with_attention=args.with_attention).to(device)
    image_encoder = ImageEncoder(
        z_dim=args.feature_dim).to(device)
    discourse_class = DiscourseClassifier(
        len(relations), args.feature_dim).to(device)

    # map_location remaps the stored tensors onto the device in use, so a
    # checkpoint written on a GPU can also be restored on a CPU-only machine.
    ckpt = torch.load(path, map_location=device)
    text_encoder.load_state_dict(ckpt['text_encoder'])
    image_encoder.load_state_dict(ckpt['image_encoder'])
    discourse_class.load_state_dict(ckpt['discourse_class'])

    # The notebook only runs inference with these networks; evaluation mode
    # keeps layers such as dropout or batch normalisation (if the networks
    # contain any) deterministic and independent of batch composition.
    for net in (text_encoder, image_encoder, discourse_class):
        net.eval()
    return text_encoder, image_encoder, discourse_class

# Baseline checkpoint: its run name records classification=0.00 and questions 2-8.
path_base = 'runs/samples3439_retrieval=1.00_classification=0.00_reweight=1000.00_weightDecay=0.0_withAttention=2_question=2,3,4,5,6,7,8_maxLen=200/e19.ckpt'
# Run with classification=0.10 on questions 2-8; not loaded by any cell visible here.
path_all = 'runs/samples3439_retrieval=1.00_classification=0.10_reweight=1000.00_weightDecay=0.0_withAttention=2_question=2,3,4,5,6,7,8_maxLen=200/e19.ckpt'

# Text encoder, image encoder and discourse classifier of the baseline run.
t_base, i_base, d_base = load_model(path_base)

In [9]:
def generate_output(test_loader, text_encoder, image_encoder, discourse_class, valid_questions):
    """Run the three networks over a loader and collect their outputs as arrays.

    Every batch is moved to ``device`` and unpacked as
    ``(txt, txt_len, img, target)``. The text encoder yields a feature and an
    attention tensor, the image encoder a feature tensor, and the classifier
    output is passed through a sigmoid before the columns listed in
    ``valid_questions`` are selected (the same columns are taken from
    ``target``).

    Args:
        test_loader: Iterable of batches as described above.
        text_encoder: Callable ``(txt, txt_len) -> (feature, attention)``.
        image_encoder: Callable ``img -> feature``.
        discourse_class: Callable ``(txt_feature, img_feature) -> scores``.
        valid_questions: Column index tensor applied to scores and targets.

    Returns:
        Tuple ``(probs, labels, attns, txt_feats, img_feats)`` of numpy
        arrays, each concatenated over all batches along axis 0.
    """
    def to_numpy(chunks):
        # Join the per-batch CPU tensors along the sample axis.
        return torch.cat(chunks, dim=0).numpy()

    txt_chunks, img_chunks = [], []
    prob_chunks, label_chunks, attn_chunks = [], [], []

    for batch in tqdm(test_loader):
        txt, txt_len, img, target = [item.to(device) for item in batch]

        with torch.no_grad():
            txt_feat, attn = text_encoder(txt.long(), txt_len)
            img_feat = image_encoder(img)
            scores = discourse_class(txt_feat, img_feat)
            prob = torch.sigmoid(scores)[:, valid_questions]

        txt_chunks.append(txt_feat.detach().cpu())
        img_chunks.append(img_feat.detach().cpu())
        prob_chunks.append(prob.detach().cpu())
        attn_chunks.append(attn.detach().cpu())
        label_chunks.append(target[:, valid_questions].detach().cpu())

    return (
        to_numpy(prob_chunks),
        to_numpy(label_chunks),
        to_numpy(attn_chunks),
        to_numpy(txt_chunks),
        to_numpy(img_chunks),
    )

# Baseline: keep the classifier columns 0-6, i.e. all seven relations.
valid_questions_base = torch.arange(7, dtype=torch.long)  # base
(probs_base, labels_base, attns_base,
 txt_base, img_base) = generate_output(
    test_loader, t_base, i_base, d_base, valid_questions_base)

100%|██████████| 14/14 [00:09<00:00,  1.52it/s]


# Which single-discourse model to load?

In [290]:
# Question number of the single-discourse model to inspect.
q_id = 2

# Short label used in output paths and report text.
relation = f'q{q_id}'
# Column of this question in the 7-way arrays: the questions start at 2.
relation_idx = q_id - 2
# Single-column selector for this question.
valid_questions_q = torch.tensor([relation_idx], dtype=torch.long)

# Checkpoint of the run trained on this question only.
path = f'runs/samples3439_retrieval=1.00_classification=0.10_reweight=1000.00_weightDecay=0.0_withAttention=2_question={q_id}_maxLen=200/e19.ckpt'

In [291]:
# Load the single-question checkpoint and run it over the test loader, keeping
# only the probability/label column selected by `valid_questions_q`.
t_q, i_q, d_q = load_model(path)
probs_q, labels_q, attns_q, txt_q, img_q = generate_output(test_loader, t_q, i_q, d_q, valid_questions_q)

100%|██████████| 14/14 [00:06<00:00,  2.03it/s]


## Compare baseline with single-discourse model

In [292]:
# Rankings computed from the text and image features of each model.
# NOTE(review): judging by how `get_pos` and the top-5 export read these rows,
# row i presumably lists image indices ordered by retrieval rank for text i —
# confirm against utils.compute_ranks.
ranks_base = utils.compute_ranks(txt_base, img_base)
ranks_q = utils.compute_ranks(txt_q, img_q)
ranks_base.shape

(860, 860)

In [293]:
# From here on the un-suffixed names refer to the single-discourse model.
probs, labels, attns, ranks = probs_q, labels_q, attns_q, ranks_q

In [294]:
def get_pos(ranks):
    """Locate each row's own index inside that row.

    Args:
        ranks: Iterable of array rows (each row must support ``tolist()``).

    Returns:
        numpy array whose ``i``-th entry is the first position at which the
        value ``i`` occurs in row ``i``.

    Raises:
        ValueError: If a row does not contain its own index.
    """
    return np.asarray([row.tolist().index(row_id) for row_id, row in enumerate(ranks)])

# Position of each sample's own index within its ranking row, for the
# single-discourse model and for the baseline.
positions = get_pos(ranks)
positions_base = get_pos(ranks_base)

# Shape check: the first axis should agree across all arrays of both models.
print(probs.shape, labels.shape, attns.shape, ranks.shape, positions.shape)
print(probs_base.shape, labels_base.shape, attns_base.shape, ranks_base.shape, positions_base.shape)

(860, 1) (860, 1) (860, 200) (860, 860) (860,)
(860, 7) (860, 7) (860, 200) (860, 860) (860,)


### Is the Discourse Relation True or False?

In [295]:
# Label value to analyse (1 = relation marked true, 0 = marked false).
T = 1
# Row indices of the samples whose label equals T.
t_indices = np.where(labels == T)[0]
t_positions = positions[t_indices]
# Order those samples by the position of their own index in the ranking,
# smallest first. No squeeze here: with exactly one matching sample it would
# produce a 0-d index and turn t_indices into a scalar, which the later
# zip(...) over t_indices cannot iterate.
sort_order = np.argsort(t_positions, axis=0)
t_indices = t_indices[sort_order]
t_indices.shape

(687,)

### Which sample to show?

In [296]:
imgheader = 'qimg'      # recipe field holding the image path relative to img_dir
capheader = 'stim_txt'  # recipe field holding the caption text
img_dir = '../data/RecipeQA/images-qa/train/images-qa'
# Output folder for this selection. The sibling folder
# f'outputs/cite/{relation}={T}/coherence_win' is the alternative target; it is
# not assigned here because a second assignment would silently override the first.
save_dir = f'outputs/cite/{relation}={T}/coherence_lose'
os.makedirs(save_dir, exist_ok=True)

In [297]:
# Select samples by comparing the two models' positions of the true image.
# Kept: the baseline position is below 5 and is at least 3 larger than the
# single-discourse model's position.
# (Alternative rule tried earlier: pos_q < 5 and pos_base - pos_q >= 5.)
good_indices = []
for pos_q, pos_base, idx in zip(positions[t_indices], positions_base[t_indices], t_indices):
    if pos_base >= 5 or pos_base - pos_q < 3:
        continue
    good_indices.append(idx)
len(good_indices)

0

In [298]:
# Write the captions of the first 20 selected samples, one per line, prefixed
# with the sample index. The encoding is explicit so captions with non-ASCII
# characters are written the same way regardless of the platform's locale.
with open(os.path.join(save_dir, 'captions.txt'), 'w', encoding='utf-8') as f:
    for idx in tqdm(good_indices[:20]):
        rcp = test_set.recipes.iloc[idx]
        cap = rcp[capheader]
        f.write(f'{idx:>5d} ' + cap + '\n')

0it [00:00, ?it/s]


In [299]:
# Preprocessing for exported images: Resize(512) followed by a 512x512 centre
# crop (with an int argument torchvision's Resize scales the shorter side).
my_trans = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512)
])

def show_img(rcp, title=None):
    """Load a recipe's image, apply ``my_trans`` and optionally save it as JPEG.

    Also turns off the axes of the current matplotlib axes.

    Args:
        rcp: Recipe record; ``rcp[imgheader]`` is the image path relative to
            ``img_dir``.
        title: Destination file path for the transformed image. Nothing is
            written when it is falsy.
    """
    img_path = os.path.join(img_dir, rcp[imgheader])
    # The context manager releases the file handle once the image is processed.
    with Image.open(img_path) as src:
        # JPEG cannot store palette or alpha images; convert those to RGB and
        # leave modes JPEG already supports untouched.
        if src.mode not in ('L', 'RGB', 'CMYK'):
            src = src.convert('RGB')
        img = my_trans(src)
        plt.axis('off')
        if title:
            img.save(title, 'JPEG')
        
# For each of the first 20 selected samples, write a per-sample folder holding
# an attention report and the top-5 retrieved images of both models.
for idx in tqdm(good_indices[:20]):
    sub_dir = os.path.join(save_dir, str(idx))
    os.makedirs(sub_dir, exist_ok=True)

    rcp = test_set.recipes.iloc[idx]
    cap = utils.clean_caption(rcp[capheader])
    # Split on spaces and on the literal two-character sequence backslash-n,
    # keeping at most max_len tokens.
    # NOTE(review): assumes these tokens line up with the attention weights — confirm.
    words = re.split(r'\\n| ', cap)[:args.max_len]

    # The context manager closes the report even if a later step raises; the
    # explicit encoding keeps non-ASCII tokens independent of the locale.
    with open(os.path.join(sub_dir, 'attentions.txt'), 'w', encoding='utf-8') as f:
        # Per-word attention of the single-discourse model.
        f.write(f'Model: {relation} is used\n')
        f.write(f'prob = {probs[idx][0]:.4f}\n')
        for w, attn in zip(words, attns[idx][:len(words)]):
            f.write(f'{w:>20s} = {attn:<.4f}' + '\n')

        # Per-word attention of the baseline model.
        f.write(f'\nModel: no relation is used\n')
        f.write(f'prob = {probs_base[idx][relation_idx]:.4f}\n')
        for w, attn in zip(words, attns_base[idx][:len(words)]):
            f.write(f'{w:>20s} = {attn:<.4f}' + '\n')

        # Top-5 retrieved images of the single-discourse model.
        rank = ranks[idx]
        fig = plt.figure(figsize=(12,6))
        pos = rank.tolist().index(idx)
        line = f'{relation} == {T}. Top 5 retrieved images, while true image is at pos={pos}'
        fig.suptitle(line, y=0.7)
        f.write(line+'\n')
        for i, idx_ in enumerate(rank[:5]):
            plt.subplot(151+i)
            show_img(test_set.recipes.iloc[idx_], title=os.path.join(sub_dir, f'{relation}_top{i}.jpg'))

        # Top-5 retrieved images of the baseline model.
        rank_base = ranks_base[idx]
        fig = plt.figure(figsize=(12,6))
        pos = rank_base.tolist().index(idx)
        line = f'No Relation. Top 5 retrieved images, while true image is at pos={pos}'
        fig.suptitle(line, y=0.7)
        f.write(line+'\n')
        for i, idx_ in enumerate(rank_base[:5]):
            plt.subplot(151+i)
            show_img(test_set.recipes.iloc[idx_], title=os.path.join(sub_dir, f'no_relation_top{i}.jpg'))

0it [00:00, ?it/s]


# Generate Samples for Human Evaluation

In [300]:
# Row indices of the samples whose label equals T.
t_indices = np.where(labels == T)[0]
t_positions = positions[t_indices]
# Order those samples by the position of their own index in the ranking,
# smallest first. No squeeze here: with exactly one matching sample it would
# produce a 0-d index and turn t_indices into a scalar, which the later
# zip(...) over t_indices cannot iterate.
sort_order = np.argsort(t_positions, axis=0)
t_indices = t_indices[sort_order]
t_indices.shape

(687,)

In [301]:
# Output folder for the human-evaluation material.
save_dir = f'outputs/cite/{relation}={T}/human_evaluation'
os.makedirs(save_dir, exist_ok=True)

# Keep every sample whose baseline position is strictly larger than the
# single-discourse model's position.
good_indices = []
for pos_q, pos_base, idx in zip(positions[t_indices], positions_base[t_indices], t_indices):
    if pos_base - pos_q <= 0:
        continue
    good_indices.append(idx)

print(save_dir)
len(good_indices)

outputs/cite/q2=1/human_evaluation


195

In [302]:
# Write the captions of the first 100 selected samples, one per line, prefixed
# with the sample index. The encoding is explicit so captions with non-ASCII
# characters are written the same way regardless of the platform's locale.
with open(os.path.join(save_dir, 'captions.txt'), 'w', encoding='utf-8') as f:
    for idx in tqdm(good_indices[:100]):
        rcp = test_set.recipes.iloc[idx]
        cap = rcp[capheader]
        f.write(f'{idx:>5d} ' + cap + '\n')

100%|██████████| 100/100 [00:00<00:00, 3723.90it/s]


In [303]:
# Preprocessing for exported images: Resize(512) followed by a 512x512 centre
# crop (with an int argument torchvision's Resize scales the shorter side).
my_trans = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512)
])


def show_img(rcp, title=None):
    """Load a recipe's image, apply ``my_trans`` and optionally save it as JPEG.

    Args:
        rcp: Recipe record; ``rcp[imgheader]`` is the image path relative to
            ``img_dir``.
        title: Destination file path for the transformed image. Nothing is
            written when it is falsy.
    """
    img_path = os.path.join(img_dir, rcp[imgheader])
    # The context manager releases the file handle once the image is processed.
    with Image.open(img_path) as src:
        # JPEG cannot store palette or alpha images; convert those to RGB and
        # leave modes JPEG already supports untouched.
        if src.mode not in ('L', 'RGB', 'CMYK'):
            src = src.convert('RGB')
        img = my_trans(src)
        if title:
            img.save(title, 'JPEG')


# For the first 100 selected samples, save the top-ranked image of each model
# under a file name that starts with the sample index.
for idx in tqdm(good_indices[:100]):
    # (ranking row, output file name) for the single-discourse model and the baseline.
    exports = (
        (ranks[idx], f'{idx}_cohaware.jpg'),
        (ranks_base[idx], f'{idx}_cohagnostic.jpg'),
    )
    for ranking, filename in exports:
        # The [:1] slice yields only the top-ranked entry and tolerates an empty row.
        for idx_ in ranking[:1]:
            show_img(test_set.recipes.iloc[idx_], title=os.path.join(save_dir, filename))

100%|██████████| 100/100 [00:04<00:00, 24.68it/s]
