In [1]:
import sys, os, re, json

import pandas as pd
import pickle
import h5py

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from PIL import Image
from utils import imread, img_data_2_mini_batch, imgs2batch

from sklearn import metrics

from naive import Enc, Dec
from data_loader import VQADataSet

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
from torchvision import transforms

%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
# Build (or load from an on-disk cache) the VQA dataset restricted to N questions.
N = 2000
dataset_filename = "./data/data_{}.pkl".format(N)
dataset = None
print(dataset_filename)
if os.path.exists(dataset_filename):
    # NOTE(review): pickle.load can execute arbitrary code. Only load cache
    # files this notebook wrote itself; never point this at untrusted data.
    with open(dataset_filename, 'rb') as handle:
        print("reading from " + dataset_filename)
        dataset = pickle.load(handle)
else:
    dataset = VQADataSet(Q=N)
    # Ensure the cache directory exists; otherwise the freshly built dataset
    # would be lost to a FileNotFoundError when writing the cache below.
    os.makedirs(os.path.dirname(dataset_filename) or ".", exist_ok=True)
    # Dump to a side file first and atomically move it into place, so an
    # interrupted write never leaves a truncated pickle that the branch above
    # would then try (and fail) to load on every subsequent run.
    partial_filename = dataset_filename + ".partial"
    with open(partial_filename, 'wb') as handle:
        print("writing to " + dataset_filename)
        pickle.dump(dataset, handle)
    os.replace(partial_filename, dataset_filename)

assert dataset is not None

./data/data_2000.pkl
reading from ./data/data_2000.pkl


In [3]:
# --- Model / training hyper-parameters ---------------------------------------
embed_size  = 128   # passed to both Enc and Dec
hidden_size = 128   # passed to Dec
rnn_layers  = 1     # passed to Dec
batch_size  = 32    # minibatch size handed to dataset.build_data_loader
n_epochs    = 1     # intended number of passes over the training set

# --- Vocabulary sizes, derived from the loaded dataset -----------------------
ques_vocab_size, ans_vocab_size = (
    len(dataset.vocab[field]) for field in ('question', 'answer')
)


In [4]:

# Seed torch's global RNG so weight initialisation (and any torch-RNG-based
# shuffling in the data loader) is reproducible under Restart & Run All.
torch.manual_seed(0)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Constructing Enc downloads ResNet-152 weights on first use (see cell output).
encoder = Enc(embed_size).to(device)
decoder = Dec(embed_size, hidden_size, ques_vocab_size, ans_vocab_size, rnn_layers).to(device)

criterion = nn.CrossEntropyLoss()
# Only the decoder and the encoder's `linear` and `bn` sub-modules are handed
# to the optimizer; the remaining encoder parameters are not updated.
params = (
    list(decoder.parameters())
    + list(encoder.linear.parameters())
    + list(encoder.bn.parameters())
)
optimizer = torch.optim.Adam(params, lr=0.001)

print("device: {}".format(device))

Downloading: "https://download.pytorch.org/models/resnet152-b121ed2d.pth" to /home/nsimsiri/.torch/models/resnet152-b121ed2d.pth
100%|██████████| 241530880/241530880 [00:03<00:00, 77313050.98it/s] 


device: cuda:0


In [6]:
def eval_model(data_loader=None, batch_size=batch_size, epoch=1):
    """Run one pass over `data_loader`, updating the model after every minibatch.

    NOTE: despite its name this function *trains*: every batch calls
    ``loss.backward()`` and ``optimizer.step()``. It relies on the notebook
    globals ``encoder``, ``decoder``, ``criterion``, ``optimizer`` and
    ``device`` defined in earlier cells.

    Args:
        data_loader: iterable yielding ``(idxs, v, q, a, q_len)`` minibatches,
            e.g. the loader returned by ``dataset.build_data_loader``. If
            ``None`` the function returns immediately.
        batch_size: unused; kept only so existing callers keep working.
        epoch: epoch number, used only in the per-batch progress line.

    Returns:
        None.
    """
    if data_loader is None:
        return

    # NOTE(review): encoder/decoder mode (train vs. eval) is not set here; the
    # modules run in whatever mode they were last left in. Confirm whether
    # Enc deliberately keeps part of itself in eval mode before adding
    # encoder.train() / decoder.train().
    for i, minibatch in enumerate(data_loader):
        _idxs, v, q, a, q_len = minibatch

        # torch's DataLoader yields questions as a List[Tensor] whose layout is
        # transposed; batchify_questions transposes it back before moving it
        # to the device.
        v = v.to(device)
        q = VQADataSet.batchify_questions(q).to(device)
        a = a.to(device)
        print("")

        img_features = encoder(v)
        print("img_features", img_features.shape)

        pred = decoder(img_features, q, q_len)
        print("pred", pred.shape)

        loss = criterion(pred, a)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Per-batch progress line; this is the format present in the saved
        # output of the training cell, which the previous code never printed.
        print("epoch: {} # {} loss: {}".format(epoch, i, loss.item()))
            
    



In [7]:
train_loader = dataset.build_data_loader(train=True, args={'batch_size': batch_size})

# Honour n_epochs from the hyper-parameter cell (previously ignored: the cell
# always ran exactly one pass with eval_model's default epoch=1).
for epoch in range(1, n_epochs + 1):
    eval_model(data_loader=train_loader, batch_size=batch_size, epoch=epoch)

batch_size: 32 shuffle: True

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 0 loss: 7.162944316864014

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 1 loss: 7.1589789390563965

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 2 loss: 7.148438930511475

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 3 loss: 7.15594482421875

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 4 loss: 7.147489070892334

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 5 loss: 7.139472484588623

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 6 loss: 7.119172096252441

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 7 loss: 7.176149368286133

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 8 loss: 7.138587474822998

img_features torch.Size([32, 128])
pred torch.Size([32, 1282]

epoch: 1 # 80 loss: 6.863967418670654

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 81 loss: 6.5897650718688965

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 82 loss: 6.203444480895996

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 83 loss: 7.221890926361084

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 84 loss: 6.035122394561768

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 85 loss: 6.0488104820251465

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 86 loss: 6.487653732299805

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 87 loss: 6.602327823638916

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 88 loss: 5.422983169555664

img_features torch.Size([32, 128])
pred torch.Size([32, 1282])
epoch: 1 # 89 loss: 6.330343246459961

img_features torch.Size([32, 128])
pred t