In [1]:
import sys, os, re, json, time

import pandas as pd
import pickle
import h5py

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import plotting
from PIL import Image
from tqdm import tqdm
from utils import imread, img_data_2_mini_batch, imgs2batch

from sklearn import metrics

from naive import Enc, Dec, EncDec
from data_loader import VQADataSet

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
from torchvision import transforms

%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
N = 2000
dataset_filename = "./data/data_{}.pkl".format(N)
dataset = None
print(dataset_filename)
if (os.path.exists(dataset_filename)):
    with open(dataset_filename, 'rb') as handle:
        print("reading from " + dataset_filename)
        dataset = pickle.load(handle)
else:
    dataset = VQADataSet(Q=N)
    with open(dataset_filename, 'wb') as handle:
        print("writing to " + dataset_filename)
        pickle.dump(dataset, handle)

assert(dataset is not None)
def debug(v,q,a):
    print('\nV: {}\nQ: {}\nA: {}'.format(v.shape, q.shape, a.shape))


./data/data_2000.pkl
reading from ./data/data_2000.pkl


In [3]:
embed_size        = 128
hidden_size       = 128
batch_size        = 32
ques_vocab_size   = len(dataset.vocab['question'])
ans_vocab_size    = len(dataset.vocab['answer'])
rnn_layers        = 1
n_epochs          = 25
learning_rate     = 0.01
momentum          = 0.98

print(ques_vocab_size, ans_vocab_size)

1469 1282


In [4]:
def eval_model(data_loader, model, criterion, optimizer, batch_size, training=False,
              total_loss_over_epochs=[], scores_over_epochs=[]):
    running_loss = 0.
    final_labels, final_preds = [], []
    if data_loader is None:
        return
    
    if training:
        model.train()
    else:
        model.eval()
    
    for i, minibatch in enumerate(data_loader):
        # extract minibatch
        t0 = time.time()
        idxs, v, q, a, q_len = minibatch
        
        # convert torch's DataLoader output to proper format.
        # torch gives a List[Tensor_1, ... ] where tensor has been transposed. 
        # batchify transposes back.`
        v = v.to(device)
        q = VQADataSet.batchify_questions(q).to(device)
        a = a.to(device)

        logits = model(v, q, q_len)
        preds = torch.argmax(logits, dim=1)
        
        loss = F.nll_loss(logits, a)
        running_loss += loss.item()
        
        score = metrics.precision_recall_fscore_support(preds.tolist(),
                                                        a.tolist(),
                                                        average='weighted')
        
        total_loss_over_epochs['train_loss'].append(loss)
        scores_over_epochs['train_scores'].append(score)
        
        if training and optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        final_labels += a.tolist()
        final_preds  += preds.tolist()
        if True:#(i%20==0):
#             plotting.plot_score_over_n_epochs(scores_over_epochs, score_type='precision', fig_size=(7,3))
#             plotting.plot_loss_over_n_epochs(total_loss_over_epochs, fig_size=(7, 3), title="Loss")
            print("Loss: {} - score: {} - t: {}".format(loss, score, time.time()-t0))
            
    return running_loss, final_labels, final_preds

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = EncDec(embed_size, hidden_size, ques_vocab_size, ans_vocab_size, rnn_layers).to(device)

model = EncDec(embed_size, hidden_size, ques_vocab_size, ans_vocab_size, rnn_layers).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.get_parameters(), lr=learning_rate)


In [6]:

train_loader = dataset.build_data_loader(train=True, args={'batch_size': batch_size})
test_loader  = dataset.build_data_loader(test=True, args={'batch_size': batch_size})

best_score = 0

train_all_loss, train_all_labels, train_all_preds = [], [], []

total_loss_over_epochs, scores_over_epochs = plotting.get_empty_stat_over_n_epoch_dictionaries()

for epoch in tqdm(range(n_epochs)):
    t0= time.time()
    tr_loss, tr_labels, tr_preds = eval_model(data_loader = train_loader,
                                     model       = model,
                                     criterion   = criterion,
                                     optimizer   = optimizer,
                                     batch_size  = batch_size,
                                     training    = True,
                                     total_loss_over_epochs = total_loss_over_epochs,
                                     scores_over_epochs     = scores_over_epochs)
    
#     train_scores = metrics.precision_recall_fscore_support(tr_labels,
#                                                            tr_preds,
#                                                            average='weighted')
    
#     total_loss_over_epochs['train_loss'].append(tr_loss)
#     scores_over_epochs['train_scores'].append(train_scores)
    
#     if True:# or epoch%1 == 0:
#         print("#==#"*5 + "epoch: {}".format(epoch) + "#==#"*5)
#         print("time: {}".format(time.time()-t0))
#         print(train_scores)
#     plotting.plot_score_over_n_epochs(scores_over_epochs, score_type='precision', fig_size=(8,5))
#     plotting.plot_loss_over_n_epochs(total_loss_over_epochs, fig_size=(8, 5), title="Loss")
    
    
    
    

  0%|          | 0/25 [00:00<?, ?it/s]

batch_size: 32 shuffle: True
batch_size: 32 shuffle: False


  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


Loss: 7.14402961730957 - score: (0.0, 0.0, 0.0, None) - t: 0.26078081130981445
Loss: 6.908522605895996 - score: (0.0234375, 0.03125, 0.026785714285714288, None) - t: 0.23926687240600586
Loss: 6.672055721282959 - score: (0.59375, 0.125, 0.19949494949494945, None) - t: 0.2002110481262207
Loss: 6.0676422119140625 - score: (0.6774553571428571, 0.15625, 0.2153897849462365, None) - t: 0.19977331161499023
Loss: 5.618424415588379 - score: (0.44062500000000004, 0.15625, 0.23069221967963385, None) - t: 0.1987471580505371
Loss: 5.6899800300598145 - score: (0.6805555555555556, 0.25, 0.3654411764705882, None) - t: 0.19895529747009277
Loss: 6.164545059204102 - score: (1.0, 0.15625, 0.2702702702702703, None) - t: 0.19988417625427246
Loss: 7.592854022979736 - score: (1.0, 0.03125, 0.06060606060606061, None) - t: 0.1997828483581543
Loss: 7.332056999206543 - score: (1.0, 0.15625, 0.2702702702702703, None) - t: 0.20004916191101074
Loss: 5.667515754699707 - score: (1.0, 0.21875, 0.358974358974359, None) -

Loss: 5.263310432434082 - score: (0.328125, 0.1875, 0.22499999999999998, None) - t: 0.21109247207641602
Loss: 5.795498371124268 - score: (0.15625, 0.09375, 0.11458333333333333, None) - t: 0.1995549201965332
Loss: 6.6850972175598145 - score: (0.28125, 0.09375, 0.13125, None) - t: 0.20050811767578125
Loss: 6.179559707641602 - score: (0.03125, 0.03125, 0.03125, None) - t: 0.20031213760375977
Loss: 5.755373001098633 - score: (0.46875, 0.125, 0.19736842105263158, None) - t: 0.19981622695922852
Loss: 5.314488887786865 - score: (0.5, 0.09375, 0.15520833333333334, None) - t: 0.19871997833251953
Loss: 5.669483184814453 - score: (0.484375, 0.28125, 0.34374999999999994, None) - t: 0.1999819278717041
Loss: 5.829525470733643 - score: (0.2708333333333333, 0.0625, 0.1015625, None) - t: 0.24324822425842285
Loss: 3.9770607948303223 - score: (0.37053571428571425, 0.28125, 0.3189338235294118, None) - t: 0.2438960075378418
Loss: 4.803102016448975 - score: (0.140625, 0.09375, 0.10714285714285715, None) - t

  4%|▍         | 1/25 [00:58<23:27, 58.64s/it]

Loss: 4.994266986846924 - score: (0.65625, 0.15625, 0.24818840579710144, None) - t: 0.1993546485900879
Loss: 6.107772350311279 - score: (0.5, 0.1, 0.16666666666666669, None) - t: 0.06648540496826172
Loss: 5.90203332901001 - score: (0.0, 0.0, 0.0, None) - t: 0.19936466217041016
Loss: 4.563817501068115 - score: (0.42857142857142855, 0.1875, 0.26086956521739124, None) - t: 0.20107436180114746
Loss: 5.312638282775879 - score: (0.40625, 0.03125, 0.058035714285714295, None) - t: 0.20245051383972168
Loss: 6.0045881271362305 - score: (0.9375, 0.0625, 0.1171875, None) - t: 0.20288515090942383
Loss: 5.186038970947266 - score: (0.9375, 0.0625, 0.1171875, None) - t: 0.20037603378295898
Loss: 5.112027645111084 - score: (0.96875, 0.15625, 0.26909722222222215, None) - t: 0.21403098106384277
Loss: 5.427913665771484 - score: (0.96875, 0.0625, 0.11742424242424243, None) - t: 0.22737836837768555
Loss: 5.606229782104492 - score: (0.96875, 0.125, 0.22142857142857145, None) - t: 0.2007899284362793
Loss: 5.6

KeyboardInterrupt: 

In [None]:
print(type(tr_labels))
print(type(tr_preds))


In [None]:
print(tr_labels[0])
print(tr_preds[0])