In [87]:
import torch
import pandas as pd
import os
import numpy as np
import json
import h5py
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

import sys
import pprint
from collections import Counter,defaultdict
from itertools import chain

class DialogDataset(Dataset):
    def __init__(self, json_data, image_features, img2feat, transform=None):
        
        with open(img2feat, 'r') as f:
            self.img2feat = json.load(f)['IR_imgid2id']
            
        self.img_features = np.asarray(h5py.File(image_features, 'r')['img_features'])
        self.json_data = pd.read_json(json_data, orient='index')
        self.corpus = self.get_words()
        self.vocab = list(set(self.corpus))
        
    # collect all the words from dialogs and 
    # captions and use them to create embedding map
    def get_words(self):
        words = [datapoint['dialog'] for datapoint in self]
        return list(chain.from_iterable(words))

    def __len__(self):
        return len(self.json_data)

    def __getitem__(self, idx):
        item = self.json_data.iloc[idx]

        # Flatten dialog and add caption into 1d array
        dialog = [word for line in item.dialog for word in line[0].split()]
        dialog.extend(item.caption.split(' '))

        img_ids = np.array(item.img_list)
        img_features = [self.img_features[idx] for idx in map(lambda x: self.img2feat[str(x)], img_ids)]
        
        return {
            'dialog' : dialog, 
            'img_ids': item.img_list, 
            'img_features': img_features, 
            'target_idx' : item.target}

In [113]:
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class CBOW(torch.nn.Module):

    def __init__(self, vocab_size, embedding_dim):
        super(CBOW, self).__init__()

        #out: 1 x emdedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, 128)
        self.activation_function1 = nn.ReLU()
        
        #out: 1 x vocab_size
        self.linear2 = nn.Linear(128, vocab_size)
        self.activation_function2 = nn.Sigmoid()
        

    def forward(self, inputs):
        # i believe .view() is useless here because the sum already produces a 1xEMB_DIM vector
        embeds = sum(self.embeddings(inputs)).view(1,-1)
        out = self.linear1(embeds)
        out = self.activation_function1(out)
        out = self.linear2(out)
        out = self.activation_function2(out)
        return out

def make_context_vector(context, w2i):
    idxs = [w2i[w] for w in context]
    tensor = torch.LongTensor(idxs)
    return autograd.Variable(tensor)

In [107]:
SAMPLE_EASY = ['Data', 'sample_easy.json']
TRAIN_EASY = ['Data', 'Easy', 'IR_train_easy.json']
IMG_FEATURES = ['Data', 'Features', 'IR_image_features.h5']
INDEX_MAP = ['Data', 'Features', 'IR_img_features2id.json']

EMBEDDING_DIM = 5
CONTEXT_SIZE = 2
FREQ_THRESHOLD = 0

torch.manual_seed(1)
dialog_data = DialogDataset(os.path.join(*SAMPLE_EASY), os.path.join(*IMG_FEATURES), os.path.join(*INDEX_MAP))

w2i = {word : i for i, word in enumerate(dialog_data.vocab)}

In [115]:
cbow_model = CBOW(len(w2i), EMBEDDING_DIM)
inp = make_context_vector(dialog_data[0]['dialog'], w2i)
text_feat = torch.squeeze(cbow_model(inp), 0)
print("Sample text feature: \n\n", out)

sample_img_feat = Variable(torch.FloatTensor(dialog_data[0]['img_features'][0]))
print("Sample image feature: \n\n", sample_img_feat)

concat_features = torch.cat((out, sample_img_feat), 0)
print(concat_features)

Sample text feature: 

 Variable containing:
 0.2252
 0.3313
 0.0773
 0.2268
 0.2515
 0.0630
 0.8003
 0.9296
 0.9980
 0.9917
 0.5473
 0.0008
 0.0174
 0.0071
 0.9171
 0.8011
 0.5299
 0.0669
 0.1455
 0.9086
 0.0833
 0.4478
 0.1292
 0.2738
 0.7710
 0.0682
 0.0032
 0.9940
 0.0495
 0.9854
 0.8881
 0.0815
 0.2378
 0.0839
 0.6673
 0.0253
 0.1221
 0.9273
 0.9107
 0.0012
 0.2854
 0.9495
 0.3647
 0.9009
 0.9300
 0.4066
 0.8211
 0.9407
 0.8537
 0.8849
 0.0128
 0.3720
 0.7848
 0.1335
 0.8912
 0.7791
 0.0327
 0.0360
 0.0013
 0.9996
 0.9251
 0.1482
 0.3513
 0.0427
 0.5312
 0.8600
 0.7299
 0.0025
 0.0003
 0.0773
 0.8056
 0.9514
 0.0523
 0.8842
 0.0043
 0.0091
 0.5931
 0.0021
 0.0291
 0.0646
 0.2122
 0.6559
 0.2913
 0.9679
 0.9846
 0.9062
 0.9314
 0.7849
 0.8816
 0.1949
 0.3745
 0.9639
 0.3350
 0.1424
 0.9958
 0.0194
 0.5908
 0.9524
 0.5718
 0.9399
 0.9192
 0.8346
 0.9544
 0.8826
 0.0204
 0.5674
 0.9788
 0.5429
 0.0864
 0.0099
 0.8657
 0.0230
 0.0004
 0.3159
 0.9708
 0.9996
 0.3085
 0.0604
 0.9926
 0.