In [1]:
import pickle

import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import data
import vocab
import packed_sequence_utils as packed_util

In [2]:
class OracleDataset(Dataset):
    """In-memory dataset built from five row-aligned NumPy arrays.

    Row ``i`` of every array belongs to example ``i``; the constructor asserts
    that all arrays have the same size along axis 0 and wraps each one as a
    tensor via ``torch.from_numpy``.
    """

    def __init__(self, tokens, question_lengths, features, categories, answers):
        arrays = (tokens, question_lengths, features, categories, answers)
        num_examples = tokens.shape[0]
        # Every array must describe the same number of examples.
        assert all(array.shape[0] == num_examples for array in arrays)

        (self.tokens, self.question_lengths, self.features,
         self.categories, self.answers) = map(torch.from_numpy, arrays)

    def __len__(self):
        """Number of examples (size of the first axis of ``tokens``)."""
        return len(self.tokens)

    def __getitem__(self, i):
        """Return ``(tokens, question_length, features, category, answer)`` for example ``i``."""
        fields = (self.tokens, self.question_lengths, self.features,
                  self.categories, self.answers)
        return tuple(field[i] for field in fields)

In [3]:
# Intended number of examples per mini-batch.
BATCH_SIZE = 64

def load_dataset(split, small):
    """Unpickle and return the processed oracle data for ``split``.

    The file location comes from ``data.get_processed_file('oracle', split, small)``.

    NOTE: ``pickle.load`` can execute arbitrary code while loading, so this
    must only be pointed at files produced by a trusted preprocessing step.
    """
    processed_path = data.get_processed_file('oracle', split, small)
    with open(processed_path, 'rb') as pickle_file:
        dataset = pickle.load(pickle_file)
    return dataset

def get_data_loader(split, small, batch_size=None, shuffle=True, num_workers=1):
    """Build a ``DataLoader`` over the processed oracle data for ``split``.

    Args:
        split: forwarded to ``load_dataset`` to choose which data is read.
        small: forwarded to ``load_dataset`` unchanged.
        batch_size: examples per batch; ``None`` (the default) means use the
            notebook-level ``BATCH_SIZE`` constant instead of a duplicated
            literal.
        shuffle: whether the loader reshuffles the data every epoch.
        num_workers: number of loader worker processes.

    Returns:
        A ``DataLoader`` wrapping an ``OracleDataset`` built from the arrays
        returned by ``load_dataset(split, small)``.
    """
    if batch_size is None:
        # Resolved at call time so that editing BATCH_SIZE and re-running
        # this call picks up the new value.
        batch_size = BATCH_SIZE

    dataset = OracleDataset(*load_dataset(split, small))
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers
    )

# `small` is forwarded (through get_data_loader / load_dataset) to
# data.get_processed_file to select which processed files are read.
small = True
loader_train, loader_valid = tuple(
    get_data_loader(split_name, small) for split_name in ('train', 'valid')
)

# The vocabulary size is the default number of token embeddings in the model.
vocab_map = vocab.VocabMap()
vocab_size = vocab_map.vocab_size
print(vocab_size)

2701


In [4]:
class OracleNet(nn.Module):
    """Scores three answer classes from a question, image features and a category.

    The question tokens are embedded and run through a single-layer GRU; the
    GRU output at each question's last real token is concatenated with the
    feature vector and a learned category embedding, and a single linear
    layer maps that to 3 scores (used as class logits by the notebook's
    ``nn.CrossEntropyLoss``).

    Args:
        image_spatial_dim: width of the per-example ``features`` vector.
        question_hidden_dim: GRU hidden size.
        token_embed_dim: token embedding size.
        vocab_size: number of rows in the token embedding table.
        question_max_len: stored on the module; not otherwise used here.
        num_categories: number of rows in the category embedding table.
        category_embed_dim: category embedding size.
    """

    def __init__(self, image_spatial_dim=4104,
                 question_hidden_dim=128, token_embed_dim=64,
                 vocab_size=vocab_size, question_max_len=data.MAX_TOKENS_PER_QUESTION,
                 num_categories=data.NUM_CATEGORIES, category_embed_dim=32):
        super(OracleNet, self).__init__()

        self.image_spatial_dim = image_spatial_dim
        self.question_max_len = question_max_len
        self.question_hidden_dim = question_hidden_dim

        self.token_embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=token_embed_dim
        )

        self.category_embedding = nn.Embedding(
            num_embeddings=num_categories,
            embedding_dim=category_embed_dim
        )

        self.question_encoder = nn.GRU(
            input_size=token_embed_dim,
            hidden_size=question_hidden_dim,
            num_layers=1,
            batch_first=True
        )

        fc1_in = image_spatial_dim + question_hidden_dim + category_embed_dim
        self.fc1 = nn.Linear(fc1_in, 3)

    def forward(self, tokens, question_lens, features, categories):
        """Return a ``(batch, 3)`` tensor of answer scores.

        Args:
            tokens: integer token ids, one padded row per question
                (assumed ``(batch, max_len)`` -- TODO confirm against the
                preprocessing code).
            question_lens: integer number of real tokens per question, one
                entry per example; every entry must be >= 1 because it is
                used as ``length - 1`` to index the GRU output.
            features: per-example feature vectors, ``(batch, image_spatial_dim)``.
            categories: integer category ids, one per example.
        """
        embed_tokens = self.token_embedding(tokens)
        # output is (batch, time, hidden) because the GRU is batch_first.
        output, _ = self.question_encoder(embed_tokens)

        # Pick, for every example, the GRU output at its last real token
        # (index length - 1) rather than at the padded final position.
        last_step = (question_lens - 1).view(-1, 1, 1).expand(
            -1, 1, self.question_hidden_dim)
        # squeeze(1) removes only the time axis. A bare squeeze() would also
        # drop the batch axis for a batch of one example, breaking the cat below.
        question_encodings = output.gather(dim=1, index=last_step).squeeze(1)

        embed_category = self.category_embedding(categories)

        fc1_in = torch.cat([question_encodings, features, embed_category], 1)
        return self.fc1(fc1_in)
        
        
# UNFINISHED DRAFT USING A PACKED SEQUENCE (commented out, not executed; the draft forward() below never returns a value)
#     def sort_by_len(self, embed_tokens, question_lens, features, categories):
#         question_lens, indices = question_lens.sort(descending=True)
#         print(question_lens.data.numpy())
#         print(question_lens.size())
#         embed_tokens = embed_tokens.index_select(dim=0, index=indices)
#         print(embed_tokens.size())
#         features = features.index_select(dim=0, index=indices)
#         categories = categories.index_select(dim=0, index=indices)
#         return (
#             nn.utils.rnn.pack_padded_sequence(
#                 embed_tokens,
#                 question_lens,
#                 batch_first=True
#             ),
#             features,
#             categories
#         )
        
#     def forward(self, tokens, question_lens, features, categories):
#         embed_category = self.category_embedding(categories)
#         embed_tokens = self.token_embedding(tokens)
        
#         questions, features, categories = self.sort_by_len(embed_tokens, question_lens, features, categories)
        
#         output, h_n = self.question_encoder(embed_tokens)
#         print(output.size())
#         print((question_lens-1).size())
#         question_encodings = output[:, question_lens-1, :]
#         print(question_encodings.size())

In [5]:
def check_accuracy(model, loader):
    """Print how many examples from ``loader`` the model classifies correctly.

    The model is switched to eval mode (and left there). Batches are moved to
    whatever device the model's parameters live on, so this works on both CPU
    and GPU. Evaluation runs under ``torch.no_grad()`` so no autograd graph
    is built.

    Args:
        model: callable as ``model(tokens, question_lens, features, categories)``
            returning a ``(batch, num_classes)`` score tensor.
        loader: iterable of ``(tokens, q_lens, features, cats, answers)`` batches.

    Prints:
        ``Got <correct> / <total> correct (<percent>)``; an empty loader
        reports ``0 / 0`` with 0.00 instead of dividing by zero.
    """
    num_correct = 0
    num_samples = 0
    model.eval()  # Put the model in test mode (the opposite of model.train(), essentially)

    # Follow the model's device instead of assuming CUDA is available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        for tokens, q_lens, features, cats, answers in loader:
            scores = model(tokens.to(device), q_lens.to(device),
                           features.to(device), cats.to(device))
            _, preds = scores.cpu().max(1)
            num_correct += int((preds == answers.cpu()).sum())
            num_samples += preds.size(0)

    acc = float(num_correct) / num_samples if num_samples else 0.0
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))

def train(model, loss_fn, optimizer, num_epochs=1, print_every=1000, loader=None):
    """Train ``model`` for ``num_epochs`` passes and report accuracy after each.

    Args:
        model: callable as ``model(tokens, question_lens, features, categories)``.
        loss_fn: callable as ``loss_fn(scores, answers)`` returning a scalar loss.
        optimizer: optimizer already bound to the model's parameters.
        num_epochs: number of full passes over the loader.
        print_every: the loss is printed on every batch whose index is a
            multiple of this value (so always on the first batch of an epoch).
        loader: iterable of ``(tokens, q_lens, features, cats, answers)``
            batches; ``None`` (the default) falls back to the notebook-level
            ``loader_train``, which is what the original code always used.

    After each epoch ``check_accuracy`` is run on the same loader.
    """
    if loader is None:
        loader = loader_train

    # Follow the model's device instead of assuming CUDA is available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(num_epochs):
        print('Starting epoch %d / %d' % (epoch + 1, num_epochs))
        # check_accuracy leaves the model in eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()

        for t, (tokens, q_lens, features, cats, answers) in enumerate(loader):
            scores = model(tokens.to(device), q_lens.to(device),
                           features.to(device), cats.to(device))

            loss = loss_fn(scores, answers.to(device))

            if t % print_every == 0:
                # .item() extracts the Python float from the 0-dim loss
                # tensor; indexing it with [0] raises on current PyTorch.
                print('t = %d, loss = %.4f' % (t + 1, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        check_accuracy(model, loader)

In [6]:
# Seed PyTorch's RNG before the model is built so that weight initialisation
# and the loaders' shuffling start from the same state on every
# "Restart Kernel -> Run All" (GPU kernels may still add small run-to-run
# differences).
SEED = 0
NUM_EPOCHS = 100
torch.manual_seed(SEED)

loss_fn = nn.CrossEntropyLoss()
oracle_net = OracleNet().cuda()
optimizer = torch.optim.Adam(oracle_net.parameters())

# Train (accuracy on the training loader is printed after each epoch), then
# report accuracy on the validation loader.
train(oracle_net, loss_fn, optimizer, num_epochs=NUM_EPOCHS)
check_accuracy(oracle_net, loader_valid)

Starting epoch 1 / 100
t = 1, loss = 1.1241
Got 2222 / 4262 correct (52.14)
Starting epoch 2 / 100
t = 1, loss = 0.6642
Got 2246 / 4262 correct (52.70)
Starting epoch 3 / 100
t = 1, loss = 0.7483
Got 2549 / 4262 correct (59.81)
Starting epoch 4 / 100
t = 1, loss = 0.7976
Got 2759 / 4262 correct (64.73)
Starting epoch 5 / 100
t = 1, loss = 0.6549
Got 2847 / 4262 correct (66.80)
Starting epoch 6 / 100
t = 1, loss = 0.6281
Got 2477 / 4262 correct (58.12)
Starting epoch 7 / 100
t = 1, loss = 0.7953
Got 2840 / 4262 correct (66.64)
Starting epoch 8 / 100
t = 1, loss = 0.5640
Got 3002 / 4262 correct (70.44)
Starting epoch 9 / 100
t = 1, loss = 0.5487
Got 3131 / 4262 correct (73.46)
Starting epoch 10 / 100
t = 1, loss = 0.5573
Got 3276 / 4262 correct (76.87)
Starting epoch 11 / 100
t = 1, loss = 0.4230
Got 3111 / 4262 correct (72.99)
Starting epoch 12 / 100
t = 1, loss = 0.4882
Got 3399 / 4262 correct (79.75)
Starting epoch 13 / 100
t = 1, loss = 0.4256
Got 3215 / 4262 correct (75.43)
Starting

In [None]:
from visualization import make_dot

In [None]:
# NOTE(review): the original cell passed `y`, which is never defined in this
# notebook and raises NameError on a fresh kernel. It is assumed to mean the
# network's output for one batch -- confirm this is the graph you want drawn.
sample_tokens, sample_q_lens, sample_features, sample_cats, _ = next(iter(loader_valid))
sample_device = next(oracle_net.parameters()).device
y = oracle_net(
    sample_tokens.to(sample_device),
    sample_q_lens.to(sample_device),
    sample_features.to(sample_device),
    sample_cats.to(sample_device),
)
make_dot(y)