In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import seaborn as sns

from jinja2 import Template

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [3]:
from IPython.display import HTML, display

In [4]:
import models.questioner
import models.answerer

import misc.utilities as utils
import misc.dataloader

import misc.visualization as vis

In [5]:
def load(key, fname=None, epoch=60, agent='qbot', override=None):
    """Load a trained agent checkpoint and register it in `qbots` or `abots`.

    Args:
        key: experiment name. Used to build the default checkpoint path and as
            the dictionary key under which the loaded model is stored.
        fname: explicit checkpoint path. When None, the path
            'data/experiments/<key>/<agent>_ep_<epoch>.vd' is used and printed.
        epoch: checkpoint epoch used to build the default path.
        agent: 'qbot' or 'abot'; forwarded to `utils.loadModelFromFile`.
        override: currently unused; kept so existing call sites keep working.

    Raises:
        AssertionError: if `agent` is not 'qbot' or 'abot'.
        TypeError: if the loaded object is neither a Questioner nor an Answerer
            (previously such a model was silently discarded).
    """
    assert agent in ['abot', 'qbot']
    if fname is None:
        fname = 'data/experiments/{}/{}_ep_{}.vd'.format(key, agent, epoch)
        print(fname)

    # `utils` is the alias this notebook imports misc.utilities under.
    bot = utils.loadModelFromFile(fname, agent=agent)
    bot.eval()
    # Register by actual model type; `qbots`/`abots` are notebook-level dicts.
    if isinstance(bot, models.questioner.Questioner):
        qbots[key] = bot
    elif isinstance(bot, models.answerer.Answerer):
        abots[key] = bot
    else:
        raise TypeError('{} did not contain a Questioner or Answerer (got {})'.format(
            fname, type(bot).__name__))

In [6]:
# Registries of loaded agents keyed by experiment name; `load` fills them.
qbots = {}
abots = {}

# Active setup: pool "2 random", qbot and abot from the same experiment.
pool_size = 2
qexp = 'exp16.0.2.0.0'
aexp = qexp
for exp_name, agent_kind in [(qexp, 'qbot'), (aexp, 'abot')]:
    load(exp_name, epoch=20, agent=agent_kind)

# Other setups tried previously (swap in to compare):
# - load('exp13.0.1.1.1', epoch=19); load('exp6.0', epoch=65, agent='abot')
# - pool "2 random": pool_size = 2
#     load('exp13.0.1.1.1', epoch=28, agent='qbot'); load('exp6.0', epoch=65, agent='abot')
# - pool "4 random": pool_size = 4
#     load('exp14.4.1.2.0', epoch=19, agent='qbot'); load('exp14.4.1.2.0', epoch=19, agent='abot')

data/experiments/exp16.0.2.0.0/qbot_ep_20.vd
Loading qbot from data/experiments/exp16.0.2.0.0/qbot_ep_20.vd
data/experiments/exp16.0.2.0.0/abot_ep_20.vd
Loading abot from data/experiments/exp16.0.2.0.0/abot_ep_20.vd


In [7]:
# Load both splits; the image-pool size follows `pool_size` chosen with the agents above.
vqa_opts = dict(
    inputImg='data/img_bottom_up.h5',
    inputQues='data/v2_vqa_data.h5',
    inputJson='data/v2_vqa_info.json',
    poolType='random',  # e.g., contrast, random
    poolSize=pool_size,
    randQues=True,
)
dataset = misc.dataloader.VQADataset(vqa_opts, ['train', 'val'])


Dataloader loading json file: data/v2_vqa_info.json

Loading the pool
number of answer candidates: 3129

Dataloader loading Ques file: data/v2_vqa_data.h5
Vocab size with <START>, <END>: 5994
Dataloader loading h5 file: data/img_bottom_up.h5


In [8]:
# Select the validation split. The dataset was built with both 'train' and 'val';
# presumably this attribute controls which one the dataloader below reads.
dataset.split = 'val'

In [9]:
# Fixed-order batches of 20 in the main process (shuffle=False keeps the
# visualized examples the same from run to run).
loader_options = dict(
    batch_size=20,
    shuffle=False,
    num_workers=0,
    drop_last=True,
    pin_memory=False,
)
dataloader = DataLoader(dataset, **loader_options)

In [10]:
# Keep an explicit iterator so each `next(dliter)` call fetches the next batch.
dliter = iter(dataloader)

In [11]:
# Take the next batch and move every tensor-like value in it to the GPU.
batch = next(dliter)
batch = {key: (value.cuda() if hasattr(value, 'cuda') else value)
         for key, value in batch.items()}

# Ground-truth question string per example, without the <START>/<END> markers.
# str.strip() takes a *set of characters*, not a prefix/suffix, so the former
# .strip('<START> ').strip(' <END>') could also remove characters belonging to
# real tokens that begin or end with any of '<>STARTEND'. Dropping the whole
# marker tokens avoids that.
gt_ques_str = [
    ' '.join(tok for tok in q[0].split() if tok not in ('<START>', '<END>'))
    for q in utils.old_idx_to_str(dataset.ind2word, batch['ques'], batch['ques_len'],
                                  batch['img_id_pool'], 0, [])
]

# Latent variable visualizations

In [18]:
# Jinja2 template for the latent-variable visualization page.
with open('templates/z_vis.html') as template_file:
    template_source = template_file.read()
template = Template(template_source)

In [21]:
# Generate questions for every (qbot, z source, inference) combination.
# qd[(qbot_key, z_source, z_inference, dec_inference)] holds what vis.get_questions
# returns; index 0 is used as the questions and index 1 as the latents downstream.
qd = {}
inferences = [('greedy', 'greedy'), ('greedy', 'sample'), ('sample', 'greedy'), ('sample', 'sample')]
# NOTE: manual/interpolate sources must come after None because they reuse the
# latents of the (None, 'greedy', 'greedy') run.
zsources = [None, 'policy']  # 'prior',
# To sweep single latent dims, add e.g.:
# zsources += ['manual{}'.format(dim) for dim in list(range(128))[:3]]
zsources += ['interpolate']

for qk, qbot in qbots.items():
    for zkind in zsources:
        for inference in inferences:
            # Work on copies so the loop variables are never mutated.
            z_source, run_inference = zkind, inference
            is_manual_dim = zkind is not None and zkind.startswith('manual')
            if is_manual_dim or zkind == 'interpolate':
                # Manually constructed latents are generated once per source,
                # in the ('sample', 'greedy') slot only.
                if inference != ('sample', 'greedy'):
                    continue
                base_latents = qd[(qk, None, 'greedy', 'greedy')][1]
                if is_manual_dim:
                    run_inference = ('manual', 'greedy')
                    newz = vis.single_dim_manual_z(base_latents, qbot, dim=int(zkind[len('manual'):]))
                else:
                    z_source = 'manual'
                    run_inference = ('interp', 'greedy')
                    newz = vis.interp_manual_z(base_latents)
                nsamples = len(newz)
            else:
                newz = None
                # A fully greedy rollout is deterministic, so one sample suffices.
                nsamples = 1 if inference == ('greedy', 'greedy') else 3
            print(qk, z_source, run_inference, qbot.vaeMode)
            qd[(qk, z_source) + run_inference] = vis.get_questions(
                batch, qbot, z_source, run_inference, dataset.ind2word,
                nsamples=nsamples, manualz=newz)

exp16.0.2.0.0 None ('greedy', 'greedy') gumbelst-vae
exp16.0.2.0.0 None ('greedy', 'sample') gumbelst-vae
exp16.0.2.0.0 None ('sample', 'greedy') gumbelst-vae
exp16.0.2.0.0 None ('sample', 'sample') gumbelst-vae
exp16.0.2.0.0 policy ('greedy', 'greedy') gumbelst-vae
exp16.0.2.0.0 policy ('greedy', 'sample') gumbelst-vae
exp16.0.2.0.0 policy ('sample', 'greedy') gumbelst-vae
exp16.0.2.0.0 policy ('sample', 'sample') gumbelst-vae
exp16.0.2.0.0 manual ('interp', 'greedy') gumbelst-vae


In [22]:
def render_column(qbot, questions, latents, k, exi, inference, collapse=False):
    """Render one inference mode's generated questions for one example as HTML.

    Args:
        qbot: questioner; its ``num_vars`` is used to split each latent into
            per-variable groups when finding which variables changed.
        questions: indexable as [sample_i][exi] -> question string.
        latents: indexable as [sample_i][exi]; element 0 is a z string that is
            inserted as-is, element 1 a tensor reshaped to (qbot.num_vars, -1).
        k: (qbot_key, z_source) key; k[1] is the z label when inference[0] is
            'manual'.
        exi: example index within the batch.
        inference: (z_inference, dec_inference) pair shown in the header.
        collapse: if True, skip samples whose question equals the previously
            shown one (used for interpolation sweeps).

    Returns:
        HTML fragment: a bold header, then each shown question (HTML-escaped)
        followed by its z string. Every question after the first also lists
        the indices of latent variables whose argmax differs from the
        previously shown sample.
    """
    import html

    ztype = k[1] if inference[0] == 'manual' else inference[0]
    dectype = inference[1]
    parts = ['<br><b>z-{}, dec-{}:</b><br>'.format(ztype, dectype)]
    # None rather than '' so an empty first question is never collapsed away.
    prev_question = None
    prev_sample_i = 0
    for sample_i in range(len(questions)):
        new_question = questions[sample_i][exi]
        if collapse and new_question == prev_question:
            continue
        # Escape question text: tokens like '<UNK>' would otherwise be parsed as tags and vanish.
        parts.append(html.escape(new_question, quote=False) + '<br>')
        if sample_i > 0:
            curr_z = latents[sample_i][exi][1]
            prev_z = latents[prev_sample_i][exi][1]
            _, curr_idx = curr_z.reshape(qbot.num_vars, -1).max(dim=1)
            _, prev_idx = prev_z.reshape(qbot.num_vars, -1).max(dim=1)
            change_idxs = (curr_idx - prev_idx).nonzero().reshape(-1).tolist()
            parts.append('(zdiff=' + str(change_idxs) + ')<br>')
        parts.append('(z=' + latents[sample_i][exi][0] + ')<br>')
        prev_question = new_question
        prev_sample_i = sample_i
    return ''.join(parts)

In [25]:
# Column order used when rendering each key. Deliberately a different name from
# the generation cell's `inferences`: that cell rebinds `inferences`, so reading
# it here would make the rendered columns depend on cell execution order.
render_inferences = [('greedy', 'greedy'), ('sample', 'greedy'), ('manual', 'greedy'), ('interp', 'greedy')]


def render_partial_key(k, exi, inference_list=tuple(render_inferences)):
    """Concatenate the rendered HTML columns for one (qbot_key, z_source) key.

    Args:
        k: (qbot_key, z_source) tuple; extended with each inference pair to look up `qd`.
        exi: example index within the batch.
        inference_list: inference pairs to render, in order. Pairs that were not
            generated for `k` (absent from `qd`) are skipped.

    Returns:
        HTML string; empty if nothing was generated for `k`.
    """
    columns = []
    for inference in inference_list:
        if k + inference not in qd:
            continue
        qbot = qbots[k[0]]
        questions = qd[k + inference][0]
        latents = qd[k + inference][1]
        columns.append(render_column(qbot, questions, latents, k, exi, inference,
                                     collapse=(inference[0] == 'interp')))
    return ''.join(columns)


# One row per (qbot, z source) key, in display order.
keys = (
    # [(k, 'prior') for k in qbots] +
    [(k, 'policy') for k in qbots]
    + [(k, None) for k in qbots]
    + [(k, zs) for zs in zsources if zs is not None and 'manual' in zs for k in qbots]
    + [(k, 'manual') for zs in zsources if zs == 'interpolate' for k in qbots]
)

examples = []
display_data = {
    'keys': keys,
    'examples': examples,
}

pool_width = batch['img_id_pool'].shape[1]
for exi in range(len(batch['img_id_pool'])):
    img_paths = [vis.load_image(batch['img_id_pool'][exi][i].item())[1] for i in range(pool_width)]
    examples.append({
        'gt_ques': gt_ques_str[exi],
        'img_uris': [vis.img_to_datauri(img_path) for img_path in img_paths],
        'questions': [render_partial_key(k, exi) for k in keys],
    })

html = template.render(display_data)
with open('examples.html', 'w') as f:
    f.write(html)
#display(HTML(html), metadata=dict(isolated=True))

# Dialog visualizations

In [13]:
# Jinja2 template for the multi-round dialog visualization page.
with open('templates/dialog_viz.html') as dialog_template_file:
    dialog_template_source = dialog_template_file.read()
dialog_template = Template(dialog_template_source)

In [14]:
# Roll out a multi-round dialog between qBot and aBot on the current batch and
# record, per round, everything the dialog template visualizes.
wrap_period = 4  # pool images per row in the rendered HTML before wrapping
numRounds = 5
qBot = qbots[qexp]
aBot = abots[aexp] # or exp6.0
z_inference = 'sample'
dec_inference = 'sample'
batch_size = batch['img_pool'].shape[0]
pool_size = batch['img_id_pool'].shape[1]
ans2label = dataset.ans2label
label2ans = {label: ans for ans, label in ans2label.items()}
assert len(ans2label) == len(label2ans)  # labels must be unique for the inverse map

rounds = []

# Visualization only: run without autograd so the model calls do not build
# gradient graphs that the stored per-round tensors would keep alive.
with torch.no_grad():
    qBot.reset()
    # observe the image
    qBot.observe(images=batch['img_pool'])
    qBot.tracking = True

    for Round in range(numRounds):
        print('Round {}'.format(Round))
        if Round == 0:
            # observe the initial (start) question and answer before the first round
            qBot.observe(start_question=True)
            qBot.observe(start_answer=True)

        # decode the question
        ques, ques_len, stop_logits, _ = qBot.forwardDecode(
            dec_inference=dec_inference, z_inference=z_inference, z_source='policy')

        # Attention maps stored on qBot's context coder (presumably from the decode above).
        region_attn = qBot.ctx_coder.region_attn.squeeze(3)
        region_uris = vis.region_att_images(batch['img_id_pool'], region_attn, batch['img_pool_spatial'])
        img_attn = qBot.ctx_coder.img_attn.squeeze(2)
        pool_uris = vis.pool_atten_uris(img_attn, wrap_period=wrap_period)

        # Predict the target image before this round's question/answer is observed.
        pool_logit = qBot.predictImage().squeeze(2)
        _, predict = torch.max(pool_logit, dim=1)
        predCorrect = (predict == batch['target_pool'].view(-1)).to(torch.float)

        predict_region_attn = qBot.predict_ctx.attn.squeeze(3)
        predict_region_uris = vis.region_att_images(batch['img_id_pool'], predict_region_attn, batch['img_pool_spatial'])
        predict_probs = F.softmax(pool_logit, dim=1)
        predict_uris = vis.pool_atten_uris(predict_probs, wrap_period=wrap_period)

        # observe the question here.
        qBot.observe(ques=ques, ques_len=ques_len, gt_ques=False)
        # answer the question here (aBot sees the target image).
        ans, rel_logit = aBot.forward(batch['target_image'], ques, ques_len, inference_mode=False)
        _, ansIdx = torch.max(ans, dim=1)
        qBot.observe(ans=ansIdx)

        # Column 1 of aBot's relevance softmax, computed once and reused.
        rel_softmax = F.softmax(rel_logit, dim=1)
        rel_probs = rel_softmax[:, 1].tolist()
        rel_uris = vis.pool_atten_uris(rel_softmax[:, 1:2], wrap_period=wrap_period)

        # Drop whole <START>/<END> marker tokens; str.strip() would treat its
        # argument as a character set and could eat characters of real tokens.
        gen_ques_str = utils.old_idx_to_str(dataset.ind2word, ques, ques_len, batch['img_id_pool'], 0, [])
        gen_ques_str = [' '.join(tok for tok in q[0].split() if tok not in ('<START>', '<END>'))
                        for q in gen_ques_str]

        rounds.append({
            'questions': gen_ques_str,
            'answers': [label2ans[i] for i in ansIdx.tolist()],
            'preds': predict.tolist(),
            'region_uris': region_uris,
            'pool_atten_uris': pool_uris,
            'predict_region_uris': predict_region_uris,
            'predict_uris': predict_uris,
            'is_rel_probs': rel_probs,
            'rel_uris': rel_uris,
        })

Round 0
Round 1
Round 2
Round 3
Round 4


In [16]:
# Per-example context for the dialog page: ground-truth question, index of the
# target image within its pool, and data URIs of every pool image.
target_idxs = batch['target_pool'].view(-1).tolist()
examples = []
for exi in range(batch_size):
    pool_row = batch['img_id_pool'][exi]
    img_paths = []
    for slot in range(pool_size):
        img_idx = pool_row[slot].item()
        assert img_idx != 0
        img_paths.append(vis.load_image(img_idx)[1])
    examples.append(dict(
        gt_ques=gt_ques_str[exi],
        target=target_idxs[exi],
        img_uris=[vis.img_to_datauri(path) for path in img_paths],
    ))

dialog_context = dict(
    title='qbot {}, abot {}'.format(qexp, aexp),
    examples=examples,
    rounds=rounds,
    wrap_period=wrap_period,
)
html = dialog_template.render(dialog_context)
with open('dialog_examples.html', 'w') as out_file:
    out_file.write(html)
#display(HTML(html), metadata=dict(isolated=True))
#display(HTML('<img src="{}"></img>'.format(uri)), metadata=dict(isolated=True))