In [1]:
import sys, os, re, json, time

import pandas as pd
import pickle
import h5py

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import plotting
from PIL import Image
from tqdm import tqdm
from utils import imread, img_data_2_mini_batch, imgs2batch

from sklearn import metrics

from attention import Enc, Dec, EncDec
from data_loader import VQADataSet

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
from torchvision import transforms

%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
N = 2000
dataset_filename = "./data/data_{}.pkl".format(N)
dataset = None
print(dataset_filename)
if (os.path.exists(dataset_filename)):
    with open(dataset_filename, 'rb') as handle:
        print("reading from " + dataset_filename)
        dataset = pickle.load(handle)
else:
    dataset = VQADataSet(Q=N)
    with open(dataset_filename, 'wb') as handle:
        print("writing to " + dataset_filename)
        pickle.dump(dataset, handle)

assert(dataset is not None)
def debug(v,q,a):
    print('\nV: {}\nQ: {}\nA: {}'.format(v.shape, q.shape, a.shape))


./data/data_2000.pkl
reading from ./data/data_2000.pkl


In [3]:
embed_size        = 300
hidden_size       = 1024
batch_size        = 32
ques_vocab_size   = len(dataset.vocab['question'])
c                 = len(dataset.vocab['answer'])
num_layers        = 1
n_epochs          = 20
learning_rate     = 0.01
momentum          = 0.98
attention_size    = 512
debug             = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = EncDec(embed_size, hidden_size, attention_size, ques_vocab_size, c, num_layers, debug).to(device)

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


In [4]:
c

1282

In [5]:
def eval_model(data_loader, model, criterion, optimizer, batch_size, training=False,
              total_loss_over_epochs=[], scores_over_epochs=[]):
    running_loss = 0.
    final_labels, final_preds = [], []
    if data_loader is None:
        return
    
    if training:
        model.train()
    else:
        model.eval()
    
    for i, minibatch in enumerate(data_loader):
        # extract minibatch
        t0 = time.time()
        idxs, v, q, a, q_len = minibatch
        
        # convert torch's DataLoader output to proper format.
        # torch gives a List[Tensor_1, ... ] where tensor has been transposed. 
        # batchify transposes back.`
        v = v.to(device)
        q = VQADataSet.batchify_questions(q).to(device)
        a = a.to(device)
                


        logits = model(v, q, q_len)
        
        loss = F.nll_loss(logits, a)
    
        print(loss.item())
    
    
    
    


In [6]:

train_loader = dataset.build_data_loader(train=True, args={'batch_size': batch_size})
test_loader  = dataset.build_data_loader(test=True, args={'batch_size': batch_size})

best_score = 0

train_all_loss, train_all_labels, train_all_preds = [], [], []

total_loss_over_epochs, scores_over_epochs = plotting.get_empty_stat_over_n_epoch_dictionaries()

for epoch in tqdm(range(n_epochs)):
    t0= time.time()
    tr_loss, tr_labels, tr_preds = eval_model(data_loader = train_loader,
                                     model       = model,
                                     criterion   = criterion,
                                     optimizer   = optimizer,
                                     batch_size  = batch_size,
                                     training    = True,
                                     total_loss_over_epochs = total_loss_over_epochs,
                                     scores_over_epochs     = scores_over_epochs)
    
#     train_scores = metrics.precision_recall_fscore_support(tr_labels,
#                                                            tr_preds,
#                                                            average='weighted')
    
#     total_loss_over_epochs['train_loss'].append(tr_loss)
#     scores_over_epochs['train_scores'].append(train_scores)
    
#     if True:# or epoch%1 == 0:
#         print("#==#"*5 + "epoch: {}".format(epoch) + "#==#"*5)
#         print("time: {}".format(time.time()-t0))
#         print(train_scores)
#     plotting.plot_score_over_n_epochs(scores_over_epochs, score_type='precision', fig_size=(8,5))
#     plotting.plot_loss_over_n_epochs(total_loss_over_epochs, fig_size=(8, 5), title="Loss")
    
    
    
    

  0%|          | 0/20 [00:00<?, ?it/s]

batch_size: 32 shuffle: True
batch_size: 32 shuffle: False
7.1504316329956055
7.165352821350098
7.135380268096924
7.083408832550049
7.1949944496154785
7.144080638885498
7.249777317047119
7.112883567810059
7.164816379547119
7.1680588722229
7.194150447845459
7.140851020812988
7.109874248504639
7.156062602996826
7.206156253814697
7.152928829193115
7.139535903930664
7.182536602020264
7.058390140533447
7.139483451843262
7.156867027282715
7.161696434020996
7.250360488891602
7.154489040374756
7.185787677764893
7.0475382804870605
7.06500244140625
7.276723861694336
7.217381954193115
7.216460704803467
7.251987934112549
7.2334160804748535
7.085299968719482
7.169120788574219
7.238846778869629
7.13566255569458
7.288168430328369
7.174676418304443
7.103623867034912
7.223609447479248
7.115533351898193
7.102667808532715
7.230281352996826
7.127016544342041
7.176954746246338
7.182290077209473
7.189122200012207
7.247584342956543
7.126515865325928
7.12155294418335
7.224917411804199
7.234164237976074
7.1080

TypeError: 'NoneType' object is not iterable

In [None]:
# print(type(tr_labels))
# print(type(tr_preds))


In [None]:
# print(tr_labels[0])
# print(tr_preds[0])