In [2]:
%load_ext autoreload
%autoreload 2

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from data_loader import load_questions_answers, load_image_features
from tensorboardX import SummaryWriter
import progressbar as pb
import numpy as np
from attention_net import Attention_net
import copy

# Load QA Data

In [3]:
data_dir = 'data'
print("Reading QA DATA")
qa_data = load_questions_answers(token_type='word', version=2, data_dir=data_dir)

# Report split sizes, vocabulary sizes and the maximum question length,
# one "label value" line each.
qa_summary = [
    ("train questions", len(qa_data['training'])),
    ("val questions", len(qa_data['validation'])),
    ("answer vocab", len(qa_data['answer_vocab'])),
    ("question vocab", len(qa_data['question_vocab'])),
    ("max question length", qa_data['max_question_length']),
]
for label, value in qa_summary:
    print(label, value)

Reading QA DATA
train questions 412564
val questions 199148
answer vocab 3000
question vocab 15881
max question length 22


In [4]:
# Peek at one training record. In the recorded output it has an image_id, a
# question stored as a fixed-length array of token ids (zero-padded at the
# front in this example) and an integer answer index.
qa_data['training'][0]

{'image_id': 458752,
 'question': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 2.,
        3., 4., 5., 6., 7.]),
 'answer': 824}

# Load Image Data

In [52]:
print("Reading Image DATA")

# Train split: feature array plus the image ids, row-aligned with the features.
train_image_features, train_image_id_list = load_image_features(data_dir, 'train')
for label, arr in (("train image features", train_image_features),
                   ("train image_id_list", train_image_id_list)):
    print(label, arr.shape)

# Validation split, loaded the same way.
val_image_features, val_image_id_list = load_image_features(data_dir, 'val')
for label, arr in (("val image features", val_image_features),
                   ("val image_id_list", val_image_id_list)):
    print(label, arr.shape)

Reading Image DATA


  from ._conv import register_converters as _register_converters


train image features (82723, 1024, 7, 7)
train image_id_list (82723,)
val image features (40481, 1024, 7, 7)
val image_id_list (40481,)


# Reshape Image Features: (N, 1024, 7, 7) → (N, 49, 1024)

In [53]:
# Reshape the train features from (N, C, H, W) to (N, H*W, C): one C-dim
# feature vector per spatial location.
# Guarded so that re-running this cell is a no-op instead of failing in
# torch.from_numpy (the variable is already a tensor) or permuting twice.
if isinstance(train_image_features, np.ndarray):
    train_image_features = torch.from_numpy(train_image_features)
if train_image_features.dim() == 4:
    num_images, channels, height, width = train_image_features.size()
    train_image_features = (train_image_features
                            .permute(0, 2, 3, 1)
                            .reshape(num_images, height * width, channels))
train_image_features.size()

torch.Size([82723, 49, 1024])

In [54]:
# Reshape the val features from (N, C, H, W) to (N, H*W, C): one C-dim
# feature vector per spatial location.
# Guarded so that re-running this cell is a no-op instead of failing in
# torch.from_numpy (the variable is already a tensor) or permuting twice.
if isinstance(val_image_features, np.ndarray):
    val_image_features = torch.from_numpy(val_image_features)
if val_image_features.dim() == 4:
    num_images, channels, height, width = val_image_features.size()
    val_image_features = (val_image_features
                          .permute(0, 2, 3, 1)
                          .reshape(num_images, height * width, channels))
val_image_features.size()

torch.Size([40481, 49, 1024])

# Define Batch Sampler

In [55]:
def sample_batch(batch_no, batch_size, features, image_id_map, qa, split):
    """Assemble one batch of (image features, question tokens, answer labels).

    Batches are contiguous slices of ``qa``: the slice starts at
    ``(batch_no * batch_size) % len(qa)`` and is clipped at the end of ``qa``,
    so a batch near the end may be shorter than ``batch_size``.

    Args:
        batch_no: index of the batch within the pass.
        batch_size: maximum number of examples per batch.
        features: torch tensor whose first axis is the image row
            (``(num_images, 49, 1024)`` in this notebook).
        image_id_map: dict mapping an ``image_id`` to its row in ``features``.
        qa: sequence of dicts with keys ``'image_id'``, ``'question'``
            (equal-length array of token ids) and ``'answer'`` (class index).
        split: unused; kept so existing call sites keep working.

    Returns:
        ``(img, questions, answers)``:
        - ``img`` is a float32 tensor of shape ``(n,) + features.shape[1:]``.
        - ``questions`` is an int64 tensor of shape ``(n, question_length)``.
        - ``answers`` is an int64 tensor of shape ``(n,)`` holding class
          indices, as ``nn.CrossEntropyLoss`` expects.
    """
    start = (batch_no * batch_size) % len(qa)
    end = min(len(qa), start + batch_size)
    batch = [qa[i] for i in range(start, end)]

    # int64 explicitly: numpy's default `int` is 32-bit on some platforms, and
    # nn.Embedding / CrossEntropyLoss need LongTensors.
    questions = np.stack([np.asarray(item['question']) for item in batch]).astype(np.int64)
    # One class index per example (not a (n, num_answers) matrix with the
    # label broadcast across every column).
    answers = np.array([item['answer'] for item in batch], dtype=np.int64)

    # Gather all image rows in one indexing op; the result is a fresh tensor.
    rows = torch.tensor([image_id_map[item['image_id']] for item in batch], dtype=torch.long)
    img = features[rows].float()

    return img, torch.from_numpy(questions), torch.from_numpy(answers)

In [56]:
# Reverse lookups: image_id -> position of that id in the list returned by
# load_image_features (i.e. the matching row of the features tensor).
train_image_id_map = dict(zip(train_image_id_list, range(len(train_image_id_list))))
val_image_id_map = dict(zip(val_image_id_list, range(len(val_image_id_list))))

# Train 

In [5]:
# Model, loss and optimiser.
model = Attention_net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

# Total number of scalar parameters across all parameter tensors.
num_params = sum(param.numel() for param in model.parameters())
print("Num parameters {}".format(num_params))

# Use the first GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

Num parameters 21433694


In [6]:
# Display the module hierarchy (the nn.Module repr).
model

Attention_net(
  (img_emb): Linear(in_features=1024, out_features=512, bias=True)
  (que_emb): Embedding(15881, 512)
  (att0): Attention_layer(
    (att_layer): Attention_1(
      (fc): Linear(in_features=512, out_features=1, bias=True)
    )
    (nonlinear_1): Nonlinear_layer(
      (fc1): Linear(in_features=512, out_features=512, bias=True)
      (fc2): Linear(in_features=512, out_features=512, bias=True)
    )
    (nonlinear_2): Nonlinear_layer(
      (fc1): Linear(in_features=512, out_features=512, bias=True)
      (fc2): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (att1): Attention_layer(
    (att_layer): Attention_1(
      (fc): Linear(in_features=512, out_features=1, bias=True)
    )
    (nonlinear_1): Nonlinear_layer(
      (fc1): Linear(in_features=512, out_features=512, bias=True)
      (fc2): Linear(in_features=512, out_features=512, bias=True)
    )
    (nonlinear_2): Nonlinear_layer(
      (fc1): Linear(in_features=512, out_features=512, bias=True)
   

In [64]:
# Training hyper-parameters.
batch_size = 128
num_epoch = 100

# TensorBoard logger; no log directory is passed, so tensorboardX picks its default.
writer = SummaryWriter()

In [65]:
# Train with early stopping on the validation loss. The weights from the epoch
# with the lowest validation loss are restored and saved at the end.
#
# NOTE(review): the recorded run failed on its first batch with a CUDA
# "illegal memory access". CUDA errors are reported asynchronously, so rerun
# with CUDA_LAUNCH_BLOCKING=1 (or on CPU) to find the failing op. An
# out-of-range index (e.g. a token id >= the 15881-entry question vocab fed to
# the embedding) is one possible cause — TODO confirm.

patience = 5  # epochs without validation improvement before stopping


def run_epoch(qa, features, image_id_map, split, train):
    """Run one pass over ``qa`` in full batches (the trailing partial batch is dropped).

    Uses the notebook globals ``model``, ``criterion``, ``optimizer``,
    ``device``, ``batch_size`` and ``sample_batch``. When ``train`` is True the
    model is put in train mode and updated; otherwise it is put in eval mode
    and autograd is disabled.

    Returns:
        (mean loss per batch, accuracy over the examples seen) as Python floats.

    Raises:
        ValueError: if ``qa`` has fewer than ``batch_size`` examples.
    """
    num_batches = len(qa) // batch_size
    if num_batches == 0:
        raise ValueError("split '{}' has fewer than {} examples".format(split, batch_size))

    model.train(train)
    total_loss = 0.0
    total_correct = 0
    with torch.set_grad_enabled(train):
        for j in pb.ProgressBar()(range(num_batches)):
            img_features, que_features, answers = sample_batch(
                j, batch_size, features, image_id_map, qa, split)
            img_features = img_features.to(device)
            que_features = que_features.to(device)
            answers = answers.to(device)

            pred, _, _ = model(img_features, que_features)  # (pred, que_att, img_att)
            loss = criterion(pred, answers)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # .item() turns 0-dim tensors into Python numbers, so the
            # accumulators stay plain floats/ints and the division below is
            # true division.
            total_loss += loss.item()
            total_correct += pred.max(1)[1].eq(answers).sum().item()
    return total_loss / num_batches, total_correct / (num_batches * batch_size)


# Early-stopping state must live outside the epoch loop; resetting it every
# epoch would make every epoch look like an improvement.
best_val_loss = float('inf')
epochs_without_improvement = 0
best_model_wts = copy.deepcopy(model.state_dict())

for epoch in range(num_epoch):
    train_epoch_loss, train_acc = run_epoch(qa_data['training'], train_image_features,
                                            train_image_id_map, 'train', train=True)
    print("Train epoch {}, loss {}, acc {}".format(epoch, train_epoch_loss, train_acc))

    val_epoch_loss, val_acc = run_epoch(qa_data['validation'], val_image_features,
                                        val_image_id_map, 'val', train=False)
    print("Test epoch {}, loss {}, acc {}".format(epoch, val_epoch_loss, val_acc))

    # Log before the early-stopping check so the final epoch is recorded too.
    writer.add_scalars('att1_hard', {'train_loss': train_epoch_loss,
                                     'val_loss': val_epoch_loss}, epoch)

    if val_epoch_loss < best_val_loss:
        best_val_loss = val_epoch_loss
        best_model_wts = copy.deepcopy(model.state_dict())
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break

# load best model weights
model.load_state_dict(best_model_wts)
torch.save(model.state_dict(), './att1_hard.pth')

                                                                               N/A% (0 of 3223) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at /opt/conda/conda-bld/pytorch_1524584710464/work/aten/src/THC/THCCachingHostAllocator.cpp:257