In [1]:
import torch
import pandas as pd
import os
import numpy as np
import json
import h5py
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

import sys
import pprint
from collections import Counter,defaultdict
from itertools import chain

class DialogDataset(Dataset):
    
    def __init__(self, json_data, image_features, img2feat, transform=None):
        
        with open(img2feat, 'r') as f:
            self.img2feat = json.load(f)['IR_imgid2id']
            
        self.img_features = np.asarray(h5py.File(image_features, 'r')['img_features'])
        self.json_data = pd.read_json(json_data, orient='index')
        self.corpus = self.get_words()
        self.vocab = list(set(self.corpus))
        self.w2i = {word : i for i, word in enumerate(self.vocab)}
        
    # collect all the words from dialogs and 
    # captions and use them to create embedding map
    def get_words(self):
        words = []
        for idx in range(len(self)):
            item = self.json_data.iloc[idx]

            # Flatten dialog and add caption into 1d array
            dialog = [word for line in item.dialog for word in line[0].split()]
            dialog.extend(item.caption.split(' '))

            words.append(dialog)
            
        return list(chain.from_iterable(words))
    
    def make_context_vector(self, context):
        idxs = [self.w2i[w] for w in context]
        tensor = torch.LongTensor(idxs)
        return autograd.Variable(tensor)

    def __len__(self):
        return len(self.json_data)

    def __getitem__(self, idx):
        item = self.json_data.iloc[idx]

        # Flatten dialog and add caption into 1d array
        dialog = [word for line in item.dialog for word in line[0].split()]
        dialog.extend(item.caption.split(' '))
        dialog = self.make_context_vector(dialog)

        img_ids = np.array(item.img_list)
        img_features = [self.img_features[idx] for idx in map(lambda x: self.img2feat[str(x)], img_ids)]
        img_features = np.array(img_features)
        img_features = Variable(torch.FloatTensor(img_features))
        
        target_idx = item.target
        
        return {'dialog' : dialog, 'img_features': img_features}, target_idx

In [2]:
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class CBOW(torch.nn.Module):

    def __init__(self, vocab_size, embedding_dim):
        super(CBOW, self).__init__()

        #out: 1 x emdedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, 128)
        self.activation_function1 = nn.ReLU()
        
        #out: 1 x vocab_size
        self.linear2 = nn.Linear(128, vocab_size)
        self.activation_function2 = nn.Sigmoid()
        

    def forward(self, inputs):
        # i believe .view() is useless here because the sum already produces a 1xEMB_DIM vector
        embeds = sum(self.embeddings(inputs)).view(1,-1)
        out = self.linear1(embeds)
        out = self.activation_function1(out)
        out = self.linear2(out)
        out = self.activation_function2(out)
        return out

In [3]:
class MaxEnt(torch.nn.Module):
    
    def __init__(self, text_module, vocab_size, img_size):
        super(MaxEnt, self).__init__()

        self.text_module = text_module
        self.linear = nn.Linear(vocab_size + img_size, 1)
        self.softmax = nn.LogSoftmax()
        
    def forward(self, inp):
        dialog = inp['dialog']
        all_img_features = inp['img_features']
        
        text_features = self.text_module(dialog)
        text_features = text_features.expand((all_img_features.size(0), text_features.size(1)))
    
        concat = torch.cat((all_img_features, text_features), 1)
        
        scores = self.linear(concat)
        scores = self.softmax(scores.transpose(0, 1))
        
        return scores

In [4]:
SAMPLE_EASY = ['Data', 'sample_easy.json']
TRAIN_EASY = ['Data', 'Easy', 'IR_train_easy.json']
VALID_EASY = ['Data', 'Easy', 'IR_val_easy.json']
IMG_FEATURES = ['Data', 'Features', 'IR_image_features.h5']
INDEX_MAP = ['Data', 'Features', 'IR_img_features2id.json']

IMG_SIZE = 2048
EMBEDDING_DIM = 5

torch.manual_seed(1)
# dialog_data = DialogDataset(os.path.join(*SAMPLE_EASY), os.path.join(*IMG_FEATURES), os.path.join(*INDEX_MAP))
dialog_data = DialogDataset(os.path.join(*TRAIN_EASY), os.path.join(*IMG_FEATURES), os.path.join(*INDEX_MAP))
valid_data = DialogDataset(os.path.join(*VALID_EASY), os.path.join(*IMG_FEATURES), os.path.join(*INDEX_MAP))

vocab_size = len(dialog_data.vocab)

In [5]:
cbow_model = CBOW(vocab_size, EMBEDDING_DIM)
max_ent = MaxEnt(cbow_model, vocab_size, IMG_SIZE)
max_ent(dialog_data[0][0])

Variable containing:
-2.2497 -2.2939 -2.3790 -2.2853 -2.2032 -2.3234 -2.2990 -2.3669 -2.2841 -2.3546
[torch.FloatTensor of size 1x10]

In [9]:
def validate(model, data, loss_func):
    total_loss = 0
    
    for i, (inp, target) in enumerate(data):
        pred = model(inp)
        target = Variable(torch.LongTensor(np.array([target])))
        
        loss = loss_func(pred, target)
        total_loss += loss
        
        if i == 20:
            break
    
    return total_loss / 20

def predict(model, data):
    correct = 0
    
    for i, (inp, target) in enumerate(data):
        pred = model(inp)
        img, idx = torch.max(pred, 1)
        if idx.data[0] == target:
            correct += 1
        
        if i == 20:
            break
    
    return correct

validate(max_ent, valid_data, nn.NLLLoss())

Variable containing:
 2.3480
[torch.FloatTensor of size 1]

In [12]:
EPOCHS = 10

cbow_model = CBOW(vocab_size, EMBEDDING_DIM)
max_ent = MaxEnt(cbow_model, vocab_size, IMG_SIZE)
loss_func = nn.NLLLoss()
optimizer = optim.Adam(max_ent.parameters(), lr=1e-05)
validation_errors = []

for epoch in range(1, EPOCHS + 1):
    total_loss = 0
    for i, (inp, target) in enumerate(dialog_data):
        
        pred = max_ent(inp)
        
        target = Variable(torch.LongTensor(np.array([target])))

        loss = loss_func(pred, target)
        total_loss += loss.data[0]
            
        max_ent.zero_grad()
        
        loss.backward()
        optimizer.step()
        
        if i == 500:
            break
    
    total_loss = total_loss / 500
    print("Epoch {}: {}".format(epoch, total_loss))
    print("Predicted {}/20 samples correctly".format(predict(max_ent, valid_data)))
    
    val = validate(max_ent, valid_data, loss_func)
    validation_errors.append(val.data[0])
        
    
print(validation_errors)

Epoch 1: 2.3020203099250796
Predicted 4/20 samples correctly
Epoch 2: 2.285960863113403
Predicted 8/20 samples correctly
Epoch 3: 2.272419318199158
Predicted 9/20 samples correctly
Epoch 4: 2.260829509735107
Predicted 10/20 samples correctly
Epoch 5: 2.250883805990219
Predicted 10/20 samples correctly
Epoch 6: 2.2423208281993867
Predicted 9/20 samples correctly
Epoch 7: 2.234920102357864
Predicted 10/20 samples correctly
Epoch 8: 2.2284959616661073
Predicted 10/20 samples correctly
Epoch 9: 2.2228928501605987
Predicted 10/20 samples correctly
Epoch 10: 2.217980490922928
Predicted 11/20 samples correctly
[2.3881301879882812, 2.366870880126953, 2.3486194610595703, 2.332937240600586, 2.3194475173950195, 2.3078274726867676, 2.297800302505493, 2.289130210876465, 2.281618356704712, 2.2750935554504395]
