In [250]:
import torch
import pandas as pd
import os
import numpy as np
import json
import h5py
import heapq
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from pathlib import Path
import time

import sys
import pprint
from collections import Counter,defaultdict
from itertools import chain

PAD = "<pad>"

class DialogDataset(Dataset):
    
    def __init__(self, json_data, image_features, img2feat, transform=None):
        
        with open(img2feat, 'r') as f:
            self.img2feat = json.load(f)['IR_imgid2id']
            
        self.img_features = np.asarray(h5py.File(image_features, 'r')['img_features'])
        self.json_data = pd.read_json(json_data, orient='index')
        self.corpus = self.get_words()
        self.vocab = list(set(self.corpus))
        self.vocab.append(PAD)
        self.w2i = {word : i for i, word in enumerate(self.vocab)}
        
    # collect all the words from dialogs and 
    # captions and use them to create embedding map
    def get_words(self):
        words = []
        for idx in range(len(self)):
            item = self.json_data.iloc[idx]

            # Flatten dialog and add caption into 1d array
            dialog = [word for line in item.dialog for word in line[0].split()]
            dialog.extend(item.caption.split(' '))

            words.append(dialog)
            
        return list(chain.from_iterable(words))
    
    def make_context_vector(self, context):
        idxs = [self.w2i[w] for w in context]
        tensor = torch.LongTensor(idxs)
        return tensor
    
    def convert_to_idx(self, sequence):
        return [self.w2i[w] for w in sequence]
    
    def pad(self, dialog):
        # length of longest question/answer pair
        n = max(map(len, dialog))
        return [sentence + [PAD] * (n - len(sentence)) for sentence in dialog]
        

    def __len__(self):
        return len(self.json_data)

    def __getitem__(self, key):
        item = self.json_data.iloc[key]
        
        diag = item.dialog
        capt = [item.caption]

        # No appending required if it is already done before:
        if diag[-1] != capt:
            diag.append(capt)
        diag = [QA[0].split() for QA in diag]
        diag = self.pad(diag)
        
        try:
            diag = [self.convert_to_idx(QA) for QA in diag]
        except:
            print(diag)

        diag = torch.LongTensor(diag)

        img_ids = np.array(item.img_list)
        img_features = [self.img_features[idx] for idx in map(lambda x: self.img2feat[str(x)], img_ids)]
        img_features = np.array(img_features)
        img_features = torch.FloatTensor(img_features)

        target = item.target
        target = torch.LongTensor(np.array([target]))

        if torch.cuda.is_available():
            diag, img_features, target = diag.cuda(), img_features.cuda(), target.cuda()

        return diag, img_features, target

In [252]:
SAMPLE_EASY = ['Data', 'sample_easy.json']
TRAIN_EASY = ['Data', 'Easy', 'IR_train_easy.json']
EASY_1000 = ['Data', 'Easy', 'IR_train_easy_1000.json']
VAL_200 = ['Data', 'Easy', 'IR_val_easy_200.json']
VALID_EASY = ['Data', 'Easy', 'IR_val_easy.json']
IMG_FEATURES = ['Data', 'Features', 'IR_image_features.h5']
INDEX_MAP = ['Data', 'Features', 'IR_img_features2id.json']

IMG_SIZE = 2048
EMBEDDING_DIM = 5

torch.manual_seed(1)
# dialog_data = DialogDataset(os.path.join(*SAMPLE_EASY), os.path.join(*IMG_FEATURES), os.path.join(*INDEX_MAP))
dialog_data = DialogDataset(os.path.join(*EASY_1000), os.path.join(*IMG_FEATURES), os.path.join(*INDEX_MAP))
valid_data = DialogDataset(os.path.join(*VAL_200), os.path.join(*IMG_FEATURES), os.path.join(*INDEX_MAP))

vocab_size = len(dialog_data.vocab)
dialog_data[0]

(
 
 Columns 0 to 12 
   119  2890   347  1752  3332    85  3757    85  4515  4515  4515  4515  4515
  4003  3332  4004  3757  4003  4515  4515  4515  4515  4515  4515  4515  4515
  1742  3516  1262  3332  4048  3757  1262  4515  4515  4515  4515  4515  4515
  1742  3516  2613  3077  2816  1778  3757  3105  3306  1605   119   347  1838
   119  2816  1778  2895  3332  2642  3757   890   119  2920  4515  4515  4515
  2532   516   119  2816  1838  3757  3920   560   695  4515  4515  4515  4515
  2532   516   119  2816  2920  3757  3473  3920  4515  4515  4515  4515  4515
  2532  1514   119  2816  1042  3757  2327  4515  4515  4515  4515  4515  4515
  1155  2816  1042  3945  4168   560  1817  3757  3436  4515  4515  4515  4515
  2532   516   119  2816  1042  3757  4098  4515  4515  4515  4515  4515  4515
   347   919  1818   119  2613  2941   162   347  1042  4515  4515  4515  4515
 
 Columns 13 to 18 
  4515  4515  4515  4515  4515  4515
  4515  4515  4515  4515  4515  4515
  4515  4515  

In [312]:
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class CBOW(torch.nn.Module):

    def __init__(self, vocab_size, embedding_dim, output_dim):
        super(CBOW, self).__init__()

        #out: 1 x emdedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, 128)
        self.activation_function1 = nn.ReLU()
        
        #out: 1 x vocab_size
        self.linear2 = nn.Linear(128, output_dim)
        self.activation_function2 = nn.ReLU()
        
    def forward(self, inputs, batch_size = 1):
        # i believe .view() is useless here because the sum already produces a 1xEMB_DIM vector
        embeds = self.embeddings(inputs)
        sum_dim = 1 if batch_size > 1 else 0
        embeds = torch.sum(embeds, sum_dim)
        out = self.linear1(embeds)
        out = self.activation_function1(out)
        out = self.linear2(out)
        out = self.activation_function2(out)
        return out

In [313]:
class MaxEnt(torch.nn.Module):
    
    def __init__(self, text_module, vocab_size, img_size):
        super(MaxEnt, self).__init__()

        self.text_module = text_module
        self.linear = nn.Linear(vocab_size + img_size, 1)
        self.softmax = nn.LogSoftmax()
        
    def prepare (self, dialog, imgFeatures):
        text_features = self.text_module(Variable(dialog))
        text_features = text_features.expand(imgFeatures.size(0), text_features.size(1))
        concat = torch.cat((imgFeatures, text_features.data), 1)
        return concat
    
    def prepareBatch (self, batch):
        inputs = []
        targets = []
        for dialog, imgFeatures, target in batch:
            inputs.append(self.prepare(dialog, imgFeatures))
            targets.append(target)
        inputs = torch.cat(inputs)
        targets = torch.cat(targets)
        return Variable(inputs), Variable(targets)
        
    def forward(self, inp, batch_size = 1):
        scores = self.linear(inp).view(batch_size, -1)
        scores = self.softmax(scores)
        return scores

In [370]:
class MemNet(nn.Module):
    """NB: output size of text_module should match memory_dim"""
    def __init__(self, text_module, memory_dim, output_dim):
        super(MemNet, self).__init__()
        
        self.output_dim = output_dim
        self.memory_dim = memory_dim
        
        self.text_module = text_module
        self.linear = nn.Linear(memory_dim, output_dim)
    
    def forward(self, dialog, img_features):
        
        scores = torch.FloatTensor(len(img_features)).zero_()
        
        for i, img_feature in enumerate(img_features):
            history = self.text_module(dialog, batch_size = len(dialog))

            # inner product of history and current image (normalized to prevent extreme values in softmax)
            memory = history @ img_feature
            memory = torch.div(memory, (torch.norm(memory)))
            
            weights = F.softmax(memory)

            # Take weighted sum of history
            history_vector = torch.mm(weights.unsqueeze(0), history).squeeze()
            
            print(history_vector)

            out = self.linear(history_vector + img_feature)
            scores[i] = F.sigmoid(out).data[0]
        
        scores = F.log_softmax(scores)
        
        return scores
        

In [371]:
# Instantiate CBOW here for sentence embeddings:
torch.manual_seed(1)

EMBEDDING_DIM = 5
IMG_SIZE = 2048
OUTPUT_DIM = IMG_SIZE # For simplicity...

cbow_model = CBOW(vocab_size, EMBEDDING_DIM, OUTPUT_DIM)
mem_net = MemNet(cbow_model, OUTPUT_DIM, 1)
max_ent = MaxEnt(cbow_model, vocab_size, IMG_SIZE)

In [374]:
# MN SEQUENCE:
# Best paper + implementation description I was able to find:
# https://arxiv.org/pdf/1503.08895.pdf
# Where 'Question' in our case is one of the 10 images (or all 10 in one go).

# NOTE: all variable names are in accordance with section 2.1 from paper URL.

# Trying simple forward prediction. Test if right probabilities are generated:

# Set the number of memory items (1 caption + 10 QA's ):
dialog, images, target = dialog_data[0]
x = len(dialog)

for i in images:
    # Create an empty history matrix, [11 x 2048]
    m = torch.FloatTensor(x, IMG_SIZE).zero_()


    # calculate 'm' features for each exchange, a matrix of size 11 x 2048
    for idx, sentence in enumerate(dialog):
        m_i = cbow_model(Variable(sentence))
        m[idx] = m_i.data

    # m = m.expand()
    # Preferably find 'u' img features, a matrix of [2048 x 10]. For now just take one image to 
    # resemble sect. 2.1 as good as possible. u = [2048 x 1]
    u = i

    # inner product of history matrix and img features, produces 11 x 10 matrix 
    # but for now 11 x 1 since u = [2048 x 1]
    p = m @ u
    p.div_(torch.norm(p))
    
    print(p)

    # softmax the to get proper proabilities
    p = F.softmax(p)
    
    o = torch.FloatTensor(x, IMG_SIZE).zero_()
    # Weighted sum 'o', representation of memory output:
    # p_i * c_i, where c_i is in our case equal to m_i
    for idx, p_i in enumerate(p):
        o[idx] = p_i.data * m[idx]

    o = Variable(torch.sum(o, dim=0))


    u = Variable(u)

    print(o)

    # create a "Weight matrix" W:
    # W = nn.Linear(2048, 1)

    # # Finally, push u + o:
    # a = F.logsigmoid(W(u + o))
    # print(a)


 0.3218
 0.4076
 0.3182
 0.1017
 0.2462
 0.2959
 0.3088
 0.3268
 0.2701
 0.3438
 0.2778
[torch.FloatTensor of size 11]

Variable containing:
  0.0000
  1.1168
  2.0008
   ⋮    
  0.2572
  3.3487
  0.0000
[torch.FloatTensor of size 2048]


 0.3218
 0.4063
 0.3194
 0.1040
 0.2482
 0.2976
 0.3083
 0.3269
 0.2684
 0.3448
 0.2748
[torch.FloatTensor of size 11]

Variable containing:
  0.0000
  1.1166
  2.0008
   ⋮    
  0.2572
  3.3482
  0.0000
[torch.FloatTensor of size 2048]


 0.3233
 0.4072
 0.3163
 0.0988
 0.2475
 0.2976
 0.3077
 0.3250
 0.2700
 0.3428
 0.2816
[torch.FloatTensor of size 11]

Variable containing:
  0.0000
  1.1170
  2.0010
   ⋮    
  0.2569
  3.3489
  0.0000
[torch.FloatTensor of size 2048]


 0.3224
 0.4085
 0.3176
 0.1037
 0.2462
 0.2961
 0.3085
 0.3257
 0.2687
 0.3436
 0.2788
[torch.FloatTensor of size 11]

Variable containing:
  0.0000
  1.1173
  2.0010
   ⋮    
  0.2571
  3.3487
  0.0000
[torch.FloatTensor of size 2048]


 0.3212
 0.4095
 0.3170
 0.1028
 0.2463
 0.

In [373]:
%time mem_net(Variable(dialog), Variable(images))

Variable containing:
 0.3218
 0.4076
 0.3182
 0.1017
 0.2462
 0.2959
 0.3088
 0.3268
 0.2701
 0.3438
 0.2778
[torch.FloatTensor of size 11]

Variable containing:
  0.0000
  1.1168
  2.0008
   ⋮    
  0.2572
  3.3487
  0.0000
[torch.FloatTensor of size 2048]

Variable containing:
 0.3218
 0.4063
 0.3194
 0.1040
 0.2482
 0.2976
 0.3083
 0.3269
 0.2684
 0.3448
 0.2748
[torch.FloatTensor of size 11]

Variable containing:
  0.0000
  1.1166
  2.0008
   ⋮    
  0.2572
  3.3482
  0.0000
[torch.FloatTensor of size 2048]

Variable containing:
 0.3233
 0.4072
 0.3163
 0.0988
 0.2475
 0.2976
 0.3077
 0.3250
 0.2700
 0.3428
 0.2816
[torch.FloatTensor of size 11]

Variable containing:
  0.0000
  1.1170
  2.0010
   ⋮    
  0.2569
  3.3489
  0.0000
[torch.FloatTensor of size 2048]

Variable containing:
 0.3224
 0.4085
 0.3176
 0.1037
 0.2462
 0.2961
 0.3085
 0.3257
 0.2687
 0.3436
 0.2788
[torch.FloatTensor of size 11]

Variable containing:
  0.0000
  1.1173
  2.0010
   ⋮    
  0.2571
  3.3487
  0.000

Variable containing:
-2.2970
-2.3557
-2.3056
-2.2864
-2.2960
-2.3080
-2.3067
-2.2879
-2.2928
-2.2916
[torch.FloatTensor of size 10]

In [218]:
def validate(model, data, loss_func):
    total_loss = 0
    
    for i, (dialog, images, target) in enumerate(data):
        
        pred = model(dialog, images).unsqueeze(0)
        target = Variable(target)
        
        loss = loss_func(pred, target)
        total_loss += loss.data[0]

    return total_loss / len(data)

def predict(model, data):
    correct_top1 = 0
    correct_top5 = 0
    
    for i, (dialog, images, target) in enumerate(data):
        
        # For top 1:
        pred = model(dialog, images).unsqueeze(0)
        target = Variable(target)
        
        img, idx = torch.max(pred, 1)

        if idx.data[0] == target.data[0]:
            correct_top1 += 1
        
        # For top 5:
        pred = pred.data.numpy().flatten()
        top_5 = heapq.nlargest(5, range(len(pred)), pred.__getitem__)
        if target.data[0] in top_5:
            correct_top5 += 1
    
    return correct_top1 / len(data), correct_top5 / len(data)

# validate(mem_net, valid_data, nn.NLLLoss())
predict(mem_net, valid_data)

(0.06965174129353234, 0.4626865671641791)

In [None]:
def log_to_console(i, n_epochs, batch_size, batch_per_epoch, errors, start_time):
    avgProcessingSpeed = (i*batch_size) / (time.time() - start_time)
    percentOfEpc = (i / batch_per_epoch) * 100
    print("{:.0f}s:\t epoch: {}\t batch:{} ({:.1f}%) \t training error: {:.6f}\t speed: {:.1f} dialogs/s"
          .format(time.time() - start_time, 
                  n_epochs, 
                  i, 
                  percentOfEpc, 
                  np.mean(errors[-100:]), 
                  avgProcessingSpeed))
    
def init_stats_log(label, training_portion, validation_portion, embeddings_dim, epochs, batch_count):
    timestr = time.strftime("%m-%d-%H-%M")
    filename = "{}-t_size_{}-v_size_{}-emb_{}-eps_{}-dt_{}-batch_{}.txt".format(label,
                                                                       training_portion,
                                                                       validation_portion,
                                                                       EMBEDDING_DIM,
                                                                       epochs,
                                                                       timestr,
                                                                       batch_count)

    target_path = ['Training_recordings', filename]
    stats_log = open(os.path.join(*target_path), 'w')
    stats_log.write("EPOCH|AVG_LOSS|TOT_LOSS|VAL_ERROR|CORRECT_TOP1|CORRECT_TOP5\n")
    print("Logging enabled in:", filename)
    
    return stats_log, filename
    

In [None]:
batchSize = 30
numEpochs = 5
learningRate = 1e-1
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=learningRate)

startTime = time.time()
lastPrintTime = startTime

continueFromEpc = 0
continueFromI = 0
sampleCount = len(dialog_data)
batchCountPerEpc = int(sampleCount/batchSize)-1
remainderCount = sampleCount - batchCountPerEpc * batchSize
print("we have: {} dialogs, batch size of {} with {} as remainder to offset batches each epoch".format(sampleCount, batchSize, remainderCount))
offset = 0

logging = False

training_portion = len(dialog_data)
validation_portion = len(valid_data)

if logging == True:
    stats_log, filename = init_stats_log("test_top1_top2", 
                               training_portion,
                               validation_portion,
                               EMBEDDING_DIM,
                               numEpochs,
                               batchSize)

else:
    print("Logging disabled!")
    filename = ""

for t in range(numEpochs):
    lastPrintTime = time.time()
    epochStartTime = time.time()
    
    total_loss = 0
    updates = 0
    
    if t == 0 and continueFromI > 0:
        # continue where I crashed
        print("continuing")
        model.load_state_dict(torch.load('maxent_{}epc_{}iter.pt'.format(continueFromEpc, continueFromI+1)))
    
    for i in range(continueFromI, batchCountPerEpc):
        
        # In case of RNN, clear hidden state
        #model.hidden = steerNet.init_hidden(batchSize)
        
        batchBegin = offset + i * batchSize
        batchEnd = batchBegin + batchSize
        
        batch = dialog_data[batchBegin:batchEnd]
        inputs, targets = model.prepareBatch(batch)
        
        predictions = model(inputs, batchSize)
        
        loss = criterion(predictions.view(batchSize, -1), targets)
        training_errors.append(loss.data[0])
        total_loss += loss.data[0]
        
        model.zero_grad()
        loss.backward()
        optimizer.step()
        
        if time.time()  - lastPrintTime > 10:
            log_to_console(i, t, batchSize, batchCountPerEpc, training_errors, startTime)
            lastPrintTime = time.time()
    
    avg_loss = total_loss / training_portion
    top_1_score, top_5_score = predict(model, valid_data)
    validation_error = validate(model, valid_data, criterion)
    
    if logging == True:
        stats_log.write("{}|{}|{}|{}|{}|{}\n".format(epoch, avg_loss, total_loss, validation_error, top_1_score, top_5_score))
            
    epochsTrained += 1
    offset = (offset + 1) % remainderCount
    print("{:.1f}s:\t Finished epoch. Calculating test error..".format(time.time() - startTime))
    print("{:.1f}s:\t test error: {:.6f}".format(time.time() - startTime, validation_error))
    continueFromI = 0
    fileName = "maxent_{}batch_{}epc.pt".format(batchSize, epochsTrained)
    torch.save(model.state_dict(), fileName)
    print("saved\t", fileName)

if logging == True:
    stats_log.close()

In [None]:
def draw_graph(filename=trained_stats_file):
    
    
    # Read file and data
    with open("Training_recordings/" + filename, 'r') as f:
        data = [x.strip() for x in f.readlines()] 
    
    data = np.array([line.split("|") for line in data[1:]]).T
    
    epochs, avg_loss, total_loss, val_error, correct_top_1, correct_top_5 = data
    
    epochs = np.array(epochs, dtype=np.int8)
    
    plt.subplot(4, 1, 1)
    plt.plot(epochs, np.array(avg_loss, dtype=np.float32), '.-')
    plt.title('average loss, validation error and correct predictions')
    plt.ylabel('Average\nLoss')
    plt.xlabel('Epochs')

    plt.subplot(4, 1, 2)
    plt.plot(epochs, np.array(val_error, dtype=np.float32), '-')
    plt.ylabel('Validation\nLoss')
    plt.xlabel('Epochs')

    plt.subplot(4, 1, 3)
    plt.plot(epochs, np.array(correct_top_1, dtype=np.int8), '-')
    plt.ylabel('Correct\ntop 1')
    plt.xlabel('Epochs')

    plt.subplot(4, 1, 4)
    plt.plot(epochs, np.array(correct_top_5, dtype=np.int8), '-')
    plt.ylabel('Correct\ntop 5')
    plt.xlabel('Epochs')

    
    
    plt.show()

draw_graph()