In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable
from PIL import Image

%matplotlib inline
# Notebook-wide plotting defaults: larger figures, and images rendered
# pixel-exact (no smoothing) with a grayscale colormap.
plt.rcParams.update({
    'figure.figsize': (10.0, 8.0),  # set default size of plots
    'image.interpolation': 'nearest',
    'image.cmap': 'gray',
})

import itertools
import json
import pickle

from models.questioner import QuestionerNet
from models.oracle import OracleNet
from models.guesser import GuesserNet
import data
from vocab import VocabTagger
from resnet_feature_extractor import ResnetFeatureExtractor

In [None]:
resnet_feature_extractor = ResnetFeatureExtractor()
vocab_tagger = VocabTagger()

# Checkpoint names passed to data.get_saved_model(); change these to try other runs.
ORACLE_CHECKPOINT = 'oracle_gru2_fc3_cat32_h128_we64'
GUESSER_CHECKPOINT = 'guesser_gru2_fc2_cat16_h256_we64'
QUESTIONER_CHECKPOINT = 'questioner_lstm1_fc2'

questioner_net = QuestionerNet().cuda()
oracle_net = OracleNet().cuda()
guesser_net = GuesserNet().cuda()

oracle_net.load_state_dict(
    torch.load(data.get_saved_model(ORACLE_CHECKPOINT)))
guesser_net.load_state_dict(
    torch.load(data.get_saved_model(GUESSER_CHECKPOINT)))
questioner_net.load_state_dict(
    torch.load(data.get_saved_model(QUESTIONER_CHECKPOINT)))

# This notebook only runs inference with the loaded weights. Put every network
# in evaluation mode so that any dropout layers are disabled and any batch-norm
# layers use their running statistics instead of per-batch ones.
for net in (questioner_net, oracle_net, guesser_net):
    net.eval()

In [None]:
split = 'train'
# Zero-based line index of the example to inspect (the 4th line of the file).
EXAMPLE_INDEX = 3

# The file holds one JSON-encoded example per line. Jump straight to the
# requested line instead of decoding every earlier example and opening its image.
with open(data.get_gw_file(split), 'r') as f:
    line = next(itertools.islice(f, EXAMPLE_INDEX, None), None)
if line is None:
    raise ValueError('split {!r} has fewer than {} examples'.format(
        split, EXAMPLE_INDEX + 1))
example = json.loads(line)

img_path = data.get_coco_file(example['image']['file_name'])
# convert() returns a new, fully loaded image, so the underlying file is closed
# as soon as the `with` exits. It also normalizes non-RGB images (e.g. grayscale)
# to 3-channel RGB before they are passed to the feature extractor.
with Image.open(img_path) as raw_img:
    img = raw_img.convert('RGB')

fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title('{} example {}: {}'.format(
    split, EXAMPLE_INDEX, example['image']['file_name']))
ax.axis('off');

In [None]:
# Run the image through the ResNet feature extractor and prepend a size-1
# leading axis (presumably the batch axis expected by the networks).
image_features = resnet_feature_extractor.get_image_features(img)
feature = torch.from_numpy(image_features).unsqueeze(0)
# In PyTorch < 0.4, volatile=True builds no autograd graph (inference only);
# newer releases ignore the flag with a warning (torch.no_grad() replaces it).
feature_var = Variable(feature.cuda(), volatile=True)

In [None]:
# Seed PyTorch's CPU and CUDA generators so the sampled question is the same on
# every Restart & Run All. NOTE(review): mode='sample' presumably draws tokens
# stochastically; confirm in QuestionerNet.sample.
SAMPLE_SEED = 0
torch.manual_seed(SAMPLE_SEED)
torch.cuda.manual_seed_all(SAMPLE_SEED)

utterance, h = questioner_net.sample(feature_var, mode='sample')
print(vocab_tagger.get_question_tokens(utterance))

In [None]:
# Load the small preprocessed oracle dataset for this split.
# NOTE: pickle.load can execute arbitrary code; only use files this project produced.
processed_path = data.get_processed_file('oracle', split, small=True)
with open(processed_path, 'rb') as fh:
    oracle_data = pickle.load(fh)
tokens, question_lengths, features, categories, answers = oracle_data

In [None]:
# Display the first entry of `features` unpacked from the processed oracle pickle.
# NOTE(review): its type/shape is not visible here (presumably a per-example
# feature); if it is a large array, prefer showing only its shape or a slice.
features[0]

In [None]:
# Sanity check of the vocabulary: look up the id assigned to the token 'can'.
# NOTE(review): handling of out-of-vocabulary tokens lives in VocabTagger and is not visible here.
vocab_tagger.vocab_map.get_id_from_token('can')