In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import json
import h5py
from utils import img_data_2_mini_batch, imgs2batch
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as Data
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import sys

%matplotlib inline
%reload_ext autoreload

In [2]:
base_f = 'cocoqa_data_prepro_'
base_n = '93'
base_fn = base_f + base_n
transform = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))])


In [3]:
val_data_h5 = h5py.File(base_fn+'.h5', 'r')


In [4]:
val_data_json = json.load(open(base_fn+'.json', 'r'))
# pad fix
itow = val_data_json['ix_to_word']

wtoi = {iv: ik for ik,iv in itow.iteritems()}
old_pad = wtoi['<pad>']
wtoi['<pad_fix>'] = old_pad
wtoi['<pad>'] = '0'
itow[old_pad] = '<pad_fix>'
itow['0'] = '<pad>'


assert(wtoi['<pad>'] == '0')
assert(itow['0'] == '<pad>')


In [5]:
itoa = val_data_json['ix_to_ans']

unique_img_val = val_data_json['unique_img_val']
ques_val = val_data_h5['ques_val'][:]
ans_val = val_data_h5['ans_val'][:]
question_id_val = val_data_h5['question_id_val'][:]
img_pos_val = val_data_h5['img_pos_val'][:]
images = np.array(imgs2batch(unique_img_val, img_pos_val, transform=transform))
ques_val = np.array(ques_val)
ans_val = np.array(ans_val).reshape((-1, 1))

images = torch.from_numpy(images)
ques_val = torch.from_numpy(ques_val)
ans_val = torch.from_numpy(ans_val)

# for i in range(ques_val.size(1)):
#     i += 60
    
#     _img = images[i]
#     _img = _img.detach().numpy()
#     plt.figure()
#     plt.imshow(_img)
#     print 'Question: ' + ' '.join(filter(lambda kx: kx!='<pad>',(map(lambda wr: itow[str(wr)], ques_val[i].detach().numpy().tolist()))))
#     print 'Answer: ' + ' '.join(map(lambda wr: itoa[str(wr)], ans_val[i].detach().numpy().tolist()))
#     print 
#     break



In [6]:
ques_ans_val = torch.cat((ques_val, ans_val), dim=1)
BATCH_SIZE = 20
split_point = int(0.1 * ques_ans_val.size(0)) # split 10% for testing

ques_ans_splits = torch.split(ques_ans_val, split_point, dim=0)
images_splits = torch.split(images, split_point, dim=0)

ques_ans_test = ques_ans_splits[0]
ques_ans_train = torch.cat(ques_ans_splits[1:], dim=0)

images_test = images_splits[0]
images_train = torch.cat(images_splits[1:], dim=0)

# should be (torch.Size([TRAIN_SIZE, 3, 224, 224]), torch.Size([TRAIN_SIZE, MAX_LENGTH]))
print(images_train.size(), ques_ans_train.size()) 
# should be (torch.Size([TEST_SIZEZ, 3, 224, 224]), torch.Size([TEST_SIZE, MAX_LENGTH]))
print(images_test.size(), ques_ans_test.size())

train_dataset=Data.TensorDataset(images_train, ques_ans_train)
test_dataset=Data.TensorDataset(images_test, ques_ans_test)
train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True
    )

test_loader = Data.DataLoader(
        dataset=test_dataset,
        shuffle=False
    )


(torch.Size([84, 3, 224, 224]), torch.Size([84, 27]))
(torch.Size([9, 3, 224, 224]), torch.Size([9, 27]))


In [9]:
# from naive import Enc, Dec
from att_model import Enc, Dec
device = torch.device('cpu')
embed_size = 8
hidden_size = 8
ques_vocab_size = len(itow)
ans_vocab_size = len(itoa)+1
num_layers = 1

print 'embed',embed_size,'hidden',hidden_size,'ques_vocab',ques_vocab_size, 'ans_vocab',ans_vocab_size
encoder = Enc(embed_size).to(device)
decoder = Dec(embed_size, hidden_size, ques_vocab_size, ans_vocab_size, num_layers)
# encoder.double()


embed 8 hidden 8 ques_vocab 212 ans_vocab 39


In [10]:
print decoder

Dec(
  (language_model): LanguageModel(
    (embed): Embedding(212, 8)
    (lstm): LSTM(8, 8, batch_first=True)
    (linear): Linear(in_features=8, out_features=39, bias=True)
  )
  (attention_model): AttentionModel(
    (conv1): Conv2d(2048, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (fc1): Linear(in_features=8, out_features=64, bias=True)
    (conv2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    (drop1): Dropout(p=0.0)
    (relu): ReLU(inplace)
  )
  (classifier): Classifier(
    (drop1): Dropout(p=0.5)
    (lin1): Linear(in_features=4104, out_features=512, bias=True)
    (relu): ReLU()
    (drop2): Dropout(p=0.5)
    (lin2): Linear(in_features=512, out_features=39, bias=True)
  )
)


In [8]:
# optimizer and loss
params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
optimizer = torch.optim.Adam(params,lr=0.01)
encoder.train()
decoder.train()

# start your train
lossList = []
accList = []
for epoch in range(100):
    for i, (images, img_ans_val) in enumerate(train_loader):
        ques, ans = torch.split(img_ans_val, 26,dim=1)
        # images of shape [batch, 3, 256, 256]
        # ques of shape [batch, 26]
        # ans of shape [batch, 1]

        lengths = []
        for qix in ques:
            for iy in range(len(qix)):
                if (qix[iy]==0):
                    lengths.append(iy)
                    break;
        tups = []
        for ix in range(ques.size(0)):
            row = ques[ix,:]
            length = lengths[ix]
            image_i = images[ix,:]
            ans_i = ans[ix,:]
            tup = (row, length, image_i, ans_i)
            tups.append(tup)

        sorted_tuples = sorted(tups, key=lambda tup: tup[1], reverse=True)
        questions = torch.stack(list(map(lambda tup: tup[0], sorted_tuples)))
        images = torch.stack(list(map(lambda tup: tup[2], sorted_tuples)))
        answers = torch.stack(list(map(lambda tup: tup[3], sorted_tuples)))
        lengths = list(map(lambda tup: tup[1], sorted_tuples))
    
        images = images.to(device)
        questions = questions.to(device).long()
        raw_features, features = encoder(images)
        output = decoder(raw_features, features, questions, lengths)
        sys.exit()
#         answers = answers.reshape((-1)).long()
        
#         loss = F.nll_loss(output, answers)
        
#         # copy here
#         lossList.append(loss.item())
        
#         _, pred = torch.max(output, dim=1)
        
#         correct = pred.eq(answers.long().view_as(pred)).sum()
#         acc = float(correct) / float(BATCH_SIZE)
        
#         accList.append(acc)
        
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         print('epoch',epoch,'#', i, 'loss:', loss.item(), 'acc:', acc, 'correct:', correct)
#     break
        
     

  attention = F.softmax(attention)


SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
# plt.figure()
# plt.plot(range(len(lossList)), lossList, 'ro')
# plt.show()

# plt.figure()
# plt.plot(range(len(accList)), accList, 'ro')
# plt.show()