In [9]:
import torch
import pandas as pd
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

import sys
import pprint
from collections import Counter,defaultdict
from itertools import chain

class DialogDataset(Dataset):
    """Dataset of dialog records read from a JSON file.

    Every record must expose the fields read below: ``dialog`` (a sequence of
    lines whose first element is a string), ``caption`` (a string),
    ``img_list`` and ``target_img_id``.

    Attributes:
        json_data: DataFrame with one record per row (read with ``orient='index'``).
        transform: optional callable applied to each sample dict, or ``None``.
        corpus: flat list of every token of every record, duplicates included.
        vocab: sorted list of the distinct tokens found in ``corpus``.
    """

    def __init__(self, json_data, transform=None):
        """Load the records and build the corpus and vocabulary.

        Args:
            json_data: path or file-like object handed to ``pd.read_json``.
            transform: optional callable taking and returning a sample dict;
                applied in ``__getitem__`` only, never while building the corpus.
        """
        self.json_data = pd.read_json(json_data, orient='index')
        self.transform = transform
        self.corpus = self.get_words()
        # A bare set has a hash-dependent order, which would make every
        # word -> index mapping derived from the vocabulary differ between
        # runs; sorting keeps the indices reproducible.
        self.vocab = sorted(set(self.corpus))

    @staticmethod
    def _tokenize(item):
        """Flatten one record's dialog lines and caption into a token list."""
        tokens = [word for line in item.dialog for word in line[0].split()]
        # split() without an argument (as used for the dialog lines) collapses
        # repeated whitespace, so no empty-string tokens reach the vocabulary.
        tokens.extend(item.caption.split())
        return tokens

    # collect all the words from dialogs and
    # captions and use them to create embedding map
    def get_words(self):
        """Return every token of every record as one flat list."""
        rows = (self.json_data.iloc[i] for i in range(len(self)))
        return list(chain.from_iterable(self._tokenize(row) for row in rows))

    # TODO: frequency-based vocabulary filtering (e.g. keep only the most
    # common words of Counter(self.corpus)) is not implemented yet.

    def __len__(self):
        """Number of records in the dataset."""
        return len(self.json_data)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            dict with keys ``'dialog'`` (token list of the dialog followed by
            the caption), ``'img_ids'`` (the record's ``img_list``) and
            ``'target'`` (the record's ``target_img_id``), passed through
            ``transform`` when one was given.

        Raises:
            IndexError: if ``idx`` is out of range (raised by ``iloc``).
        """
        item = self.json_data.iloc[idx]
        sample = {
            'dialog': self._tokenize(item),
            'img_ids': item.img_list,
            'target': item.target_img_id,
        }
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

In [10]:
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Pair every position that has n neighbours on each side
# with those neighbours, yielding (context, target) examples
def scan_context(raw_text, n):
    """Build (context, target) pairs from a sliceable sequence.

    Args:
        raw_text: sequence supporting slicing and ``+`` (e.g. a list of words).
        n: number of neighbours taken on each side of the target.

    Returns:
        List of ``(context, target)`` tuples, where ``context`` is the n items
        before the target concatenated with the n items after it. Positions
        closer than n to either end are skipped.
    """
    centres = range(n, len(raw_text) - n)
    return [
        (raw_text[pos - n:pos] + raw_text[pos + 1:pos + n + 1], raw_text[pos])
        for pos in centres
    ]

class CBOW(torch.nn.Module):
    """Bag-of-words network over word indices.

    The embeddings of all input indices are summed, passed through one hidden
    ReLU layer and mapped to log-probabilities over the vocabulary.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim=128):
        """
        Args:
            vocab_size: number of embedding rows and of output classes.
            embedding_dim: size of each embedding vector.
            hidden_dim: width of the hidden layer (128 by default).
        """
        super(CBOW, self).__init__()

        # out: 1 x embedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, hidden_dim)
        self.activation_function1 = nn.ReLU()

        # out: 1 x vocab_size
        self.linear2 = nn.Linear(hidden_dim, vocab_size)
        # The output is (1, vocab_size), so normalise across dim 1 explicitly
        # instead of relying on the deprecated implicit-dimension choice.
        self.activation_function2 = nn.LogSoftmax(dim=1)

    def forward(self, inputs):
        """Return log-probabilities of shape (1, vocab_size).

        Args:
            inputs: 1-D tensor of word indices.
        """
        # Summing over dim 0 collapses the per-word embeddings into a single
        # 1-D vector; .view(1, -1) then adds the batch axis, so the result is
        # 1 x embedding_dim.
        embeds = self.embeddings(inputs).sum(dim=0).view(1, -1)
        out = self.linear1(embeds)
        out = self.activation_function1(out)
        out = self.linear2(out)
        out = self.activation_function2(out)
        return out

    def get_word_embedding(self, word, word_to_ix=None):
        """Return the embedding of ``word`` as a 1 x embedding_dim tensor.

        Args:
            word: key to look up in the word -> index mapping.
            word_to_ix: mapping from word to embedding row. When omitted, the
                notebook-level ``w2i`` mapping is used.

        Raises:
            KeyError: if ``word`` is missing from the mapping.
        """
        if word_to_ix is None:
            word_to_ix = w2i
        index = Variable(torch.LongTensor([word_to_ix[word]]))
        return self.embeddings(index).view(1, -1)

    # Original (misspelled) name kept so existing callers continue to work.
    get_word_emdedding = get_word_embedding


def make_context_vector(context, w2i):
    """Convert a sequence of words into a LongTensor of their indices.

    Args:
        context: iterable of words.
        w2i: mapping from word to integer index.

    Returns:
        1-D ``torch.LongTensor`` (wrapped in ``autograd.Variable``) holding
        the index of each word, in order.

    Raises:
        KeyError: if a word is missing from ``w2i``.
    """
    return autograd.Variable(torch.LongTensor([w2i[token] for token in context]))

def sample_weights(model, n):
    """Print the first entry of the model's n-th parameter (0-based).

    Nothing is printed when the model has no parameter at position ``n``.
    """
    for position, weights in enumerate(model.parameters()):
        if position != n:
            continue
        print(weights[0])
        

def train(model, data, epochs=50, lr=0.1, log_every=10, word_to_ix=None):
    """Fit ``model`` on ``data`` with plain SGD, one observation per step.

    Args:
        model: module called with a LongTensor of word indices and returning
            class scores of shape (1, n_classes).
        data: iterable of dicts with a ``'dialog'`` entry (sequence of words)
            and a ``'target'`` entry (integer class index).
        epochs: number of passes over ``data``.
        lr: SGD learning rate.
        log_every: print the summed loss and a weight sample every this many
            epochs; a falsy value disables logging.
        word_to_ix: mapping from word to index. When omitted, the
            notebook-level ``w2i`` mapping is used.

    Note:
        NOTE(review): the target is used as a class index over the model's
        output, so it presumably must be smaller than the number of output
        classes -- confirm against the data.
    """
    if word_to_ix is None:
        word_to_ix = w2i

    # CrossEntropyLoss applies log-softmax itself; on outputs that are already
    # log-probabilities that step is a no-op, so the loss value is unchanged.
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    for epoch in range(epochs):
        total_loss = 0
        for observation in data:
            dialog = observation['dialog']
            target = observation['target']

            # Converts input and target to tensors. The target arrives as a
            # plain/NumPy scalar rather than a tensor, so it is converted with
            # int() instead of reading a tensor-only ``.data`` attribute.
            context_vec = make_context_vector(dialog, word_to_ix)
            target = Variable(torch.LongTensor([int(target)]))

            # Zero out gradients
            model.zero_grad()

            # Forward pass
            log_probs = model(context_vec)

            # Calculate loss and update gradients
            loss = loss_function(log_probs, target)
            loss.backward()
            optimizer.step()

            # .data keeps the running total out of the autograd graph
            total_loss += loss.data

        if log_every and epoch % log_every == 0:
            print("Finished epoch", epoch, "loss:", total_loss)
            sample_weights(model, 1)

In [14]:
# Dataset locations, kept as path components and joined with os.path.join
SAMPLE_EASY = ['Data', 'sample_easy.json']
TRAIN_EASY = ['Data', 'Easy', 'IR_train_easy.json']

# Tunable constants (CONTEXT_SIZE and FREQ_THRESHOLD are not referenced
# by the cells visible in this notebook section)
EMBEDDING_DIM = 5
CONTEXT_SIZE = 2
FREQ_THRESHOLD = 0

# Fix the torch RNG seed before the model is created
torch.manual_seed(1)
dialog_data = DialogDataset(os.path.join(*SAMPLE_EASY))

# Map every vocabulary word to its position in dialog_data.vocab
w2i = dict(zip(dialog_data.vocab, range(len(dialog_data.vocab))))

In [20]:
# One output class and one embedding row per entry of the w2i mapping
model = CBOW(vocab_size=len(w2i), embedding_dim=EMBEDDING_DIM)

train(model=model, data=dialog_data)

Sample text feature: 

 Variable containing:

Columns 0 to 7 
-10.6338 -10.1004 -11.8775 -10.6243 -10.4888 -12.0977  -8.0102  -6.8183

Columns 8 to 15 
 -3.1825  -4.6175  -9.2084 -16.5549 -13.4291 -14.3406  -6.9950  -8.0047

Columns 16 to 23 
 -9.2783 -12.0327 -11.1683  -7.1018 -11.7965  -9.6078 -11.3064 -10.3735

Columns 24 to 31 
 -8.1839 -12.0130 -15.1411  -4.2859 -12.3538  -5.1868  -7.3269 -11.8197

Columns 32 to 39 
-10.5630 -11.7885  -8.7020 -13.0513 -11.3707  -6.8526  -7.0757 -16.1151

Columns 40 to 47 
-10.3158  -6.4633  -9.9530  -7.1912  -6.8112  -9.7762  -7.8742  -6.6350

Columns 48 to 55 
 -7.6338  -7.3580 -13.7431  -9.9217  -8.1045 -11.2687  -7.2950  -8.1376

Columns 56 to 63 
-12.7855 -12.6861 -16.0727  -1.6448  -6.8837 -11.1470 -10.0113 -12.5082

Columns 64 to 71 
 -9.2732  -7.5832  -8.4038 -15.3918 -17.4359 -11.8781  -7.9767  -6.4230

Columns 72 to 79 
-12.2951  -7.3652 -14.8451 -14.0839  -9.0215 -15.5635 -12.9066 -12.0712

Columns 80 to 87 
-10.7096  -8.7530 -10.2872  -