In [16]:
%load_ext autoreload
%autoreload 2

import math
import numpy as np
import copy
from tqdm import *
import matplotlib.pyplot as plt
%matplotlib inline
plt.ion()

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models
from torch.nn.utils.weight_norm import weight_norm
from torch import optim
from torch.optim import lr_scheduler

from language_model import WordEmbedding, QuestionEmbedding
from fc import FCNet
from utils import *
from modify_program import *
from dataset import *


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#device = torch.device('cpu')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def make_answer_prob(yes_prob):
    """Build a (1, len(answer_vocab)) answer distribution over 'yes'/'no' only.

    The 'yes' slot receives `yes_prob`, the 'no' slot receives `1 - yes_prob`,
    and every other answer stays zero. The tensor is allocated on `device`.
    Relies on the notebook globals `answer2id`, `answer_vocab` and `device`.
    """
    yes_idx, no_idx = answer2id['yes'], answer2id['no']
    distribution = torch.zeros((1, len(answer_vocab))).to(device)
    distribution[0, yes_idx] = yes_prob
    # The complement keeps the row summing to one.
    distribution[0, no_idx] = 1. - yes_prob
    return distribution

In [3]:
class And(nn.Module):
    """Soft logical AND of two yes/no answer distributions.

    The 'yes' probabilities of both inputs are multiplied, and the product is
    turned back into a yes/no distribution with `make_answer_prob`.
    """

    def __init__(self):
        super().__init__()

    def forward(self, v, p0, p1):
        # `v` is unused; it is accepted so every module shares the (v, *inputs) call shape.
        idx = answer2id['yes']
        both_yes = p0[0, idx] * p1[0, idx]
        return make_answer_prob(both_yes)
    
class Or(nn.Module):
    """Soft logical OR of two yes/no answer distributions.

    Combines the two 'yes' probabilities as a + b - a*b and returns the result
    as a yes/no distribution via `make_answer_prob`.
    """

    def __init__(self):
        super().__init__()

    def forward(self, v, p0, p1):
        # `v` is unused; it is accepted so every module shares the (v, *inputs) call shape.
        idx = answer2id['yes']
        a, b = p0[0, idx], p1[0, idx]
        # Inclusion-exclusion with the product standing in for P(A and B).
        either_yes = a + b - a * b
        return make_answer_prob(either_yes)
    
class AttentionAnd(nn.Module):
    """Intersection of two soft attention maps: the element-wise minimum."""

    def __init__(self):
        super().__init__()

    def forward(self, v, a1, a2):
        # `v` is unused; it is accepted so every module shares the (v, *inputs) call shape.
        return a1.min(a2)
    
class AttentionNot(nn.Module):
    """Complement of a soft attention map: every weight w becomes 1 - w."""

    def __init__(self):
        super().__init__()

    def forward(self, v, a):
        # `v` is unused; it is accepted so every module shares the (v, *inputs) call shape.
        return -a + 1.
    
class AttentionOr(nn.Module):
    """Union of two soft attention maps: the element-wise maximum."""

    def __init__(self):
        super().__init__()

    def forward(self, v, a1, a2):
        # `v` is unused; it is accepted so every module shares the (v, *inputs) call shape.
        return a1.max(a2)

class Exist(nn.Module):
    """Decide whether anything is attended, returning a yes/no distribution.

    The attention map of each sample is flattened and scored by a weight-normed
    linear layer; `att_size` must therefore equal the number of flattened
    attention entries per sample (the number of regions, presumably 100 — confirm
    against the feature extractor).
    """

    def __init__(self, att_size=100):
        super().__init__()
        self.linear = weight_norm(nn.Linear(att_size, 1), dim=None)

    def forward(self, v, att, arg):
        # `v` and `arg` are unused; they keep the shared module call shape.
        flat_att = att.view(att.size(0), -1)
        yes_prob = torch.sigmoid(self.linear(flat_att))
        return make_answer_prob(yes_prob)
    
class Choose(nn.Module):
    """Choose between two candidate answers using their soft 'yes' probabilities.

    forward(v, p0, p1, choice0, choice1):
    - p0, p1: answer distributions (row 0 is read) whose 'yes' entries score the
      two candidates.
    - choice0, choice1: candidate answer strings (raw program arguments, not
      embeddings); a few multi-word relation phrases are first mapped to the
      single-word answers used by the vocabulary.
    Returns a (1, len(answer_vocab)) tensor holding the two normalized scores at
    the candidates' answer ids. Candidates missing from `answer2id` fall back to
    id 0 (UNK).
    """

    # Relation phrases that appear as program arguments -> answer-vocabulary words.
    CHOICE_ALIASES = {'to the left of': 'left',
                      'to the right of': 'right',
                      'in front of': 'front',
                      'standing in front of': 'front'}

    # Lower bound on the normalizer so that two zero scores do not produce 0/0 = NaN.
    EPS = 1e-12

    def __init__(self):
        super(Choose, self).__init__()

    def forward(self, v, p0, p1, choice0, choice1):
        yes_id = answer2id['yes']
        yes0 = p0[0, yes_id]
        yes1 = p1[0, yes_id]
        # Both scores share ONE denominator and are computed out of place. Sequential
        # in-place division would divide the second score by the already-normalized
        # first one, and would also overwrite the upstream modules' output tensors
        # (indexing returns views).
        total = torch.clamp(yes0 + yes1, min=self.EPS)
        w0 = yes0 / total
        w1 = yes1 / total

        id0 = answer2id.get(self.CHOICE_ALIASES.get(choice0, choice0), 0)  # 0 -> UNK
        id1 = answer2id.get(self.CHOICE_ALIASES.get(choice1, choice1), 0)

        prob = torch.zeros((1, len(answer_vocab),)).to(device)
        prob[0, id0] = w0
        # Accumulate instead of overwriting, so two choices that map to the same id
        # (e.g. both unknown -> UNK) keep their combined mass.
        prob[0, id1] = prob[0, id1] + w1
        return prob
    

class Select(nn.Module):
    """Attend to regions that match a text argument.

    Projects region features and the argument embedding into a shared space,
    fuses them by element-wise product, and scores each region with a sigmoid.
    This is an independent per-region gate, not a softmax over regions.
    """

    def __init__(self, v_dim=2048, t_dim=512, num_hid=512, dropout=0.):
        super().__init__()

        self.v_proj = FCNet([v_dim, num_hid])
        self.t_proj = FCNet([t_dim, num_hid])
        self.dropout = nn.Dropout(dropout)
        self.linear = weight_norm(nn.Linear(num_hid, 1), dim=None)

    def forward(self, v, t):
        """
        v: [batch, k, v_dim] region features
        t: [batch, t_dim] argument embedding
        returns: [batch, k, 1] per-region weights in (0, 1)
        """
        _, num_regions, _ = v.size()
        region_feats = self.v_proj(v)
        # Tile the text projection so it lines up with every region.
        text_feats = self.t_proj(t).unsqueeze(1).repeat(1, num_regions, 1)
        fused = self.dropout(region_feats * text_feats)
        return torch.sigmoid(self.linear(fused))
        
class Relocate(nn.Module):
    """Shift attention to regions related (by a text relation) to the attended ones.

    Each region is scored from its own projected features, the projected
    relation text, and the projected attention-pooled feature of the currently
    attended regions. Scores go through a sigmoid (per-region gates).
    """

    def __init__(self, v_dim=2048, t_dim=512, num_hid=512, dropout=0.):
        super().__init__()

        self.v_proj = FCNet([v_dim, num_hid])
        self.t_proj = FCNet([t_dim, num_hid])
        self.av_proj = FCNet([v_dim, num_hid])
        self.dropout = nn.Dropout(dropout)
        self.linear = weight_norm(nn.Linear(num_hid, 1), dim=None)

    def forward(self, v, a, t, so):
        """
        v: [batch, k, v_dim] visual region features
        a: attention over the k regions; it is multiplied with v and summed over
           dim 1, so it must broadcast against v, presumably [batch, k, 1] as
           produced by Select/Relocate (TODO confirm)
        t: [batch, t_dim] relation text embedding
        so: [batch, t_dim] subject/object embedding; currently NOT used
        returns: [batch, k, 1] per-region weights in (0, 1)
        """
        _, num_regions, _ = v.size()

        def tile(x):
            # [batch, D] -> [batch, num_regions, D]
            return x.unsqueeze(1).repeat(1, num_regions, 1)

        region_feats = self.v_proj(v)
        relation_feats = tile(self.t_proj(t))

        attended = (a * v).sum(1)
        attended_feats = tile(self.av_proj(attended))

        fused = self.dropout(region_feats * relation_feats * attended_feats)
        return torch.sigmoid(self.linear(fused))
        
class Compare(nn.Module):
    """Answer a yes/no comparison about the attended regions.

    The attention-pooled visual feature is fused with two projected text
    arguments (e.g. a relation like "same"/"different" and an attribute like
    "color"). The result is scored into a 'yes' probability and returned as a
    yes/no distribution.
    """

    def __init__(self, v_dim=2048, t_dim=512, num_hid=512, dropout=0.):
        super().__init__()

        self.v_proj = FCNet([v_dim, num_hid])
        self.t_proj = FCNet([t_dim, num_hid])
        self.dropout = nn.Dropout(dropout)
        self.linear = weight_norm(nn.Linear(num_hid, 1), dim=None)

    def forward(self, v, a, t1, t2):
        """
        v: [batch, k, v_dim]
        a: attention broadcastable against v (presumably [batch, k, 1] — TODO confirm)
        t1: [batch, t_dim] e.g., different, same, ...
        t2: [batch, t_dim] e.g., color, type, material, ...
        """
        _batch, _num_regions, _v_dim = v.size()  # unpacking enforces a 3-D v
        pooled = self.v_proj((a * v).sum(1))
        relation = self.t_proj(t1)
        attribute = self.t_proj(t2)  # both text arguments share one projection
        fused = self.dropout(pooled * relation * attribute)
        yes_prob = torch.sigmoid(self.linear(fused))
        return make_answer_prob(yes_prob)
    
class Common(nn.Module):
    """Predict what two attended region sets have in common.

    Region features are projected once, pooled separately under each attention
    map, fused by element-wise product, and classified over the full answer
    vocabulary with a softmax.
    """

    def __init__(self, v_dim=2048, t_dim=512, num_hid=512, dropout=0.):
        super().__init__()

        self.v_proj = FCNet([v_dim, num_hid])
        self.dropout = nn.Dropout(dropout)
        self.linear = weight_norm(nn.Linear(num_hid, len(answer_vocab)), dim=None)

    def forward(self, v, a1, a2):
        """
        v: [batch, k, v_dim]
        a1: first attention, broadcastable against the projected v
        a2: second attention, broadcastable against the projected v
        (presumably both [batch, k, 1] — TODO confirm)
        """
        _batch, _num_regions, _v_dim = v.size()  # unpacking enforces a 3-D v
        projected = self.v_proj(v)
        pooled_first = (a1 * projected).sum(1)
        pooled_second = (a2 * projected).sum(1)
        logits = self.linear(self.dropout(pooled_first * pooled_second))
        return F.softmax(logits, 1)
    
class Query(nn.Module):
    """Answer an open question (e.g. "what color") about the attended regions.

    Projected region features are pooled under the attention, fused with the
    projected query text, and classified over the full answer vocabulary with a
    softmax.
    """

    def __init__(self, v_dim=2048, t_dim=512, num_hid=512, dropout=0.):
        super().__init__()

        self.v_proj = FCNet([v_dim, num_hid])
        self.t_proj = FCNet([t_dim, num_hid])
        self.dropout = nn.Dropout(dropout)
        self.linear = weight_norm(nn.Linear(num_hid, len(answer_vocab)), dim=None)

    def forward(self, v, a, t):
        """
        v: [batch, k, v_dim]
        a: attention broadcastable against the projected v (presumably [batch, k, 1] — TODO confirm)
        t: [batch, t_dim]
        """
        _batch, _num_regions, _v_dim = v.size()  # unpacking enforces a 3-D v
        pooled = (a * self.v_proj(v)).sum(1)
        question = self.t_proj(t)
        logits = self.linear(self.dropout(pooled * question))
        return F.softmax(logits, 1)

In [4]:
class ModuleNet(nn.Module):
    """Neural module network that executes one program per sample over image features.

    A program is a list of steps. Each step is a dict with:
    - 'operation': lower-cased name of one of the module classes below,
    - 'dependencies': indices of earlier steps whose outputs are fed in,
    - 'argument': argument strings. They are embedded with `arg_emb`, except for
      'choose', which receives the two raw strings.
    The output of each program's last step is the prediction for that sample,
    and the per-sample outputs are concatenated along dim 0.
    """

    def __init__(self):
        super(ModuleNet, self).__init__()
        # Lower-cased operation name -> module. The same modules are also registered
        # as named submodules below, so their parameters are trained and saved.
        self.function_modules = {}

        # Explicit class references instead of eval() on names. Each class's
        # __name__ is used as the submodule name, which keeps state_dict keys
        # identical to earlier checkpoints.
        module_classes = (And, Or, AttentionAnd, AttentionNot, AttentionOr, Exist, Choose,
                          Compare, Common, Query, Select, Relocate)
        for module_cls in module_classes:
            name = module_cls.__name__
            func_net = module_cls()
            self.add_module(name, func_net)
            self.function_modules[name.lower()] = func_net

        self.arg_emb = WordEmbedding(len(argument_vocab), 512)

    def forward(self, img_feats, program):
        """Run `program[i]` on `img_feats[i:i+1]` for every sample i and concatenate the results.

        Raises ValueError if a sample's program has no steps.
        """
        N = img_feats.size(0)
        final_module_outputs = []
        for i in range(N):
            steps = program[i]
            if len(steps) == 0:
                raise ValueError('program for sample {} is empty'.format(i))

            module_outputs = []
            for f in steps:
                operation = f['operation']
                module = self.function_modules[operation]
                # Every module receives the sample's visual features first (kept 3-D via slicing).
                module_inputs = [img_feats[i:i+1]]
                module_inputs.extend(module_outputs[dep] for dep in f['dependencies'])

                if operation == 'choose':
                    # Choose consumes the raw candidate strings, not their embeddings.
                    assert len(f['argument']) == 2
                    module_inputs.extend(f['argument'])
                else:
                    module_inputs.extend(
                        self.arg_emb(torch.tensor([argument2id[arg]], dtype=torch.long, device=device))
                        for arg in f['argument'])

                module_outputs.append(module(*module_inputs))
            final_module_outputs.append(module_outputs[-1])

        return torch.cat(final_module_outputs, 0)

In [5]:
splits = ['val_balanced', 'train_balanced']
# Load each GQA split once; train_model later looks them up by split name.
datasets = {split: GQA(split) for split in splits}
dataset_sizes = {split: len(ds) for split, ds in datasets.items()}
print(dataset_sizes)

{'val_balanced': 132062, 'train_balanced': 943000}


In [6]:
def compute_loss_acc(prob, gt, eps=1e-7):
    """Negative log-likelihood loss and accuracy for predicted answer distributions.

    Inputs:
    - prob: Tensor of shape (N, V) holding per-answer probabilities (not logits).
    - gt: LongTensor of shape (N,) with the ground-truth answer index per row.
    - eps: small constant added before the log so a zero probability yields a
      large but finite loss instead of inf.

    Returns:
    - loss: scalar tensor, the batch mean of -log(prob[n, gt[n]] + eps).
    - acc: scalar float tensor, fraction of rows whose highest-probability answer equals gt.
    """
    loss = F.nll_loss(torch.log(prob + eps), gt)

    pred = prob.max(dim=1)[1]
    acc = (pred == gt).float().mean()
    return loss, acc

In [7]:
def evaluate_model(model, dataloader):
    """Return the accuracy of `model` over every batch yielded by `dataloader`.

    The model is switched to eval mode (and left in it). Batches are expected as
    (img_feat, program, answer) triples. Returns a Python float in [0, 1], or
    0.0 if the dataloader yields no samples.
    """
    model.eval()
    correct = 0.0
    total_count = 0
    # Evaluation needs no gradients; without no_grad every batch would build
    # (and hold) an autograd graph.
    with torch.no_grad():
        for img_feat, program, answer in tqdm(dataloader):
            img_feat = img_feat.to(device)
            answer = answer.to(device)
            prob = model(img_feat, program)
            _, acc = compute_loss_acc(prob, answer)
            batch_size = img_feat.size(0)
            # acc is a batch mean; weight it by the batch size so a short final
            # batch is not over-counted.
            correct += acc.item() * batch_size
            total_count += batch_size
    if total_count == 0:
        return 0.0
    return correct / total_count

In [8]:
def _evaluate_splits(model, dataloaders, eval_splits):
    """Evaluate `model` on each split in order, print its accuracy, and return the last one (None if no splits)."""
    acc = None
    for eval_split in eval_splits:
        acc = evaluate_model(model, dataloaders[eval_split])
        print('(acc={1:.2f}) {0}'.format(eval_split, 100*acc))
    return acc


def train_model(model, num_epochs=5, train_splits=('train',),
                eval_splits=('val',), n_epochs_per_eval=1):
    """Train `model` with Adam on `train_splits` and keep the weights that score best on `eval_splits`.

    Args:
    - model: a ModuleNet-like module called as model(img_feat, program).
    - num_epochs: number of passes over all training splits.
    - train_splits / eval_splits: keys into the global `datasets` dict.
    - n_epochs_per_eval: evaluate (and consider for model selection) every this many epochs.

    The model is evaluated before training, periodically during training, and
    after the last epoch. Model selection uses the accuracy of the LAST split in
    `eval_splits`. On return, the best weights seen are loaded back into `model`.
    """
    import time  # the notebook's import cell does not import `time`

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # Decay LR by a factor of 0.1 every 100 epochs. The scheduler steps once per
    # epoch, so this only takes effect when num_epochs > 100.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    def make_loader(split, shuffle):
        return torch.utils.data.DataLoader(datasets[split], batch_size=32, shuffle=shuffle,
                                           num_workers=4, collate_fn=GQA_collate)

    # If a split appears in both lists, the (unshuffled) eval loader wins.
    dataloaders = {x: make_loader(x, shuffle=True) for x in train_splits}
    dataloaders.update({x: make_loader(x, shuffle=False) for x in eval_splits})

    ###########evaluate init model###########
    # best_model_wts starts as the initial weights, so best_acc must start as
    # their accuracy; otherwise a run that only degrades would still keep the
    # degraded weights.
    best_model_wts = copy.deepcopy(model.state_dict())
    init_acc = _evaluate_splits(model, dataloaders, eval_splits)
    best_acc = init_acc if init_acc is not None else 0.0
    print()
    #########################################

    for epoch in range(num_epochs):
        since = time.time()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # evaluate_model leaves the model in eval mode, so switch back once per epoch.
        model.train()
        batch_loss = batch_acc = None
        for train_split in train_splits:
            for img_feat, program, answer in tqdm(dataloaders[train_split]):
                img_feat = img_feat.to(device)
                answer = answer.to(device)
                prob = model(img_feat, program)
                batch_loss, batch_acc = compute_loss_acc(prob, answer)

                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

        # Since PyTorch 1.1 the LR scheduler must be stepped after optimizer.step().
        scheduler.step()

        if batch_loss is not None:
            # Accuracy and loss of the last training batch only (a rough progress signal).
            print(batch_acc.item(), batch_loss.item())

        if (epoch + 1) % n_epochs_per_eval == 0:
            acc = _evaluate_splits(model, dataloaders, eval_splits)
            if acc is not None and acc > best_acc:
                best_acc = acc
                best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('Epoch time: {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print(flush=True)

    ###########evaluate final model###########
    acc = _evaluate_splits(model, dataloaders, eval_splits)
    if acc is not None and acc > best_acc:
        best_acc = acc
        best_model_wts = copy.deepcopy(model.state_dict())
    #########################################

    print('Best val acc: {:.2f}'.format(100*best_acc))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return

In [None]:
# Train on the balanced GQA train split, evaluating on balanced val after every epoch.
train_splits, eval_splits = ['train_balanced'], ['val_balanced']
model = ModuleNet().to(device)
train_model(model,
            num_epochs=20,
            train_splits=train_splits,
            eval_splits=eval_splits,
            n_epochs_per_eval=1)


  0%|          | 0/4127 [00:00<?, ?it/s][A
  0%|          | 1/4127 [00:03<4:11:06,  3.65s/it][A
  0%|          | 2/4127 [00:03<2:58:10,  2.59s/it][A
  0%|          | 4/4127 [00:03<2:06:29,  1.84s/it][A
  0%|          | 5/4127 [00:04<1:34:51,  1.38s/it][A
  0%|          | 6/4127 [00:04<1:11:10,  1.04s/it][A
  0%|          | 8/4127 [00:04<51:39,  1.33it/s]  [A
  0%|          | 9/4127 [00:05<45:11,  1.52it/s][A
  0%|          | 10/4127 [00:05<35:09,  1.95it/s][A
  0%|          | 12/4127 [00:05<26:19,  2.61it/s][A
  0%|          | 13/4127 [00:05<28:40,  2.39it/s][A
  0%|          | 14/4127 [00:06<27:24,  2.50it/s][A
  0%|          | 16/4127 [00:06<20:51,  3.29it/s][A
  0%|          | 17/4127 [00:06<18:39,  3.67it/s][A
  0%|          | 18/4127 [00:07<26:52,  2.55it/s][A
  0%|          | 20/4127 [00:07<20:35,  3.32it/s][A
  1%|          | 22/4127 [00:07<18:00,  3.80it/s][A
  1%|          | 23/4127 [00:07<15:14,  4.49it/s][A
  1%|          | 24/4127 [00:08<13:35,  5.03it/s]