In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms

import time
import os
import copy
import json
import nltk

from tqdm import *
from sklearn.metrics import average_precision_score, precision_recall_curve

from torch.utils.data import Dataset
import skimage
from PIL import Image
from copy import deepcopy
from sklearn.utils.fixes import signature
import time
import pickle

from preprocess import tokenize
import programs
from modify_program import eliminate_obj_id
from utils import *

import matplotlib.pyplot as plt
%matplotlib inline
# Turn on matplotlib's interactive mode.
plt.ion()

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
def modify_program(questions):
    """Strip unused metadata from each question and rewrite its program, in place.

    For every question dict, the listed metadata keys are removed when present
    and the 'semantic' program is replaced by `eliminate_obj_id(program)`.
    Nothing is returned; the dicts in `questions` are mutated.
    """
    unused_keys = ('equivalent', 'entailed', 'isBalanced', 'groups',
                   'semanticStr', 'annotations', 'types', 'fullAnswer')
    for entry in tqdm(questions):
        for key in unused_keys:
            # Missing keys are simply skipped.
            entry.pop(key, None)

        entry['semantic'] = eliminate_obj_id(entry['semantic'])

def program_to_seq(program, mode):
    """Linearise `program` according to `mode`.

    'prefix' delegates to `programs.list_to_prefix`, 'postfix' delegates to
    `programs.list_to_postfix`; any other mode returns `program` unchanged.
    """
    if mode == 'prefix':
        return programs.list_to_prefix(program)
    if mode == 'postfix':
        return programs.list_to_postfix(program)
    return program

In [4]:
class GQA(Dataset):
    """GQA questions paired with their (prefix-order) programs, encoded as id sequences.

    Each item is a `(question, operations, arguments)` triple of LongTensors:
      * question:   `max_length + 2` token ids, left-padded with 0, wrapped in 1 ... 2
      * operations: `max_op + 2` operation ids, wrapped in 1 ... 2, right-padded with 0
      * arguments:  `(max_op + 2, max_arg)` argument ids; the rows aligned with the
                    START/END/padding operation slots are all zeros

    Id convention (from the vocabularies): 0 -> UNK (also used as padding),
    1 -> START, 2 -> END.

    NOTE: the vocabularies are read from the module-level globals `question_vocab`,
    `operation_vocab` and `argument_vocab`; they must be defined before constructing
    a dataset.
    """

    def __init__(self, split, dataroot='/home/qing/Desktop/Datasets/GQA/',
                 max_length=25, max_op=9, max_arg=4):
        """Load `<dataroot>/questions1.2/<split>_questions.json` and encode every sample.

        Args:
            split: name of the question split, e.g. 'train_balanced'.
            dataroot: directory containing the `questions1.2` folder.
            max_length: maximum number of question tokens kept (before START/END).
            max_op: maximum number of program steps kept (before START/END).
            max_arg: number of argument slots per program step.
        """
        super(GQA, self).__init__()

        self.word2vocab_id = dict(zip(question_vocab, range(len(question_vocab))))
        self.operation2id = dict(zip(operation_vocab, range(len(operation_vocab))))
        self.argument2id = dict(zip(argument_vocab, range(len(argument_vocab))))

        question_file = os.path.join(dataroot, "questions1.2/%s_questions.json" % split)
        with open(question_file) as f:  # close the handle instead of leaking it
            dataset = json.load(f)
        for k, v in dataset.items():
            v['qid'] = k
        dataset = list(dataset.values())
        modify_program(dataset)

        for sample in tqdm(dataset):
            # question: words missing from the vocabulary are dropped
            tokens = tokenize(sample['question'])
            tokens = [self.word2vocab_id[x] for x in tokens if x in self.word2vocab_id]
            sample['q_token'] = self._pad_question(tokens, max_length)

            # program: operations and arguments are encoded step by step and must
            # stay aligned, so both are truncated/padded to the same length
            program = program_to_seq(sample['semantic'], 'prefix')
            operations = [self.operation2id[x['operation']] for x in program]
            arguments = [[self.argument2id[x] for x in op['argument']] for op in program]
            sample['operations'] = self._pad_operations(operations, max_op)
            sample['arguments'] = self._pad_arguments(arguments, max_op, max_arg)

            # the raw fields are no longer needed once encoded
            del sample['question'], sample['semantic']

        self.dataset = dataset

    @staticmethod
    def _pad_question(tokens, max_length):
        """Truncate to `max_length`, wrap in START(1)/END(2) and left-pad with 0."""
        tokens = list(tokens[:max_length])
        # Note: padding goes in FRONT of the sentence.
        return [0] * (max_length - len(tokens)) + [1] + tokens + [2]

    @staticmethod
    def _pad_operations(operations, max_op):
        """Truncate to `max_op`, wrap in START(1)/END(2) and right-pad with 0."""
        operations = list(operations[:max_op])
        return [1] + operations + [2] + [0] * (max_op - len(operations))

    @staticmethod
    def _pad_arguments(arguments, max_op, max_arg):
        """Build a `(max_op + 2) x max_arg` grid aligned with `_pad_operations`.

        Steps beyond `max_op` and arguments beyond `max_arg` are dropped; short rows
        are right-padded with 0, and the START/END/padding rows are all zeros.
        """
        rows = []
        for arg in arguments[:max_op]:
            arg = list(arg[:max_arg])
            rows.append(arg + [0] * (max_arg - len(arg)))
        n_pad = max_op - len(rows)
        blank_rows = [[0] * max_arg for _ in range(n_pad + 2)]
        return blank_rows[:1] + rows + blank_rows[1:]

    def __getitem__(self, index):
        """Return the `(question, operations, arguments)` LongTensors of one sample."""
        entry = self.dataset[index]
        question = torch.LongTensor(np.array(entry['q_token']))
        operations = torch.LongTensor(np.array(entry['operations']))
        arguments = torch.LongTensor(np.array(entry['arguments']))
        return question, operations, arguments

    def __len__(self):
        """Number of encoded samples."""
        return len(self.dataset)

In [5]:
# The vocabularies are read as module-level globals by GQA.__init__ (and later by
# ProgramGenerator), so they must be loaded BEFORE any dataset is constructed;
# otherwise this cell fails on a fresh kernel.
with open('question_vocab.json') as f:
    question_vocab = json.load(f)
with open('operation_vocab.json') as f:
    operation_vocab = json.load(f)
with open('argument_vocab.json') as f:
    argument_vocab = json.load(f)

splits = ['val_balanced', 'train_balanced']
# NOTE: this rebinds `datasets`, shadowing the `datasets` name imported from
# torchvision in the first cell. The name is kept because train_model looks up
# its splits in this dict.
datasets = {x: GQA(x) for x in splits}
dataset_sizes = {x: len(datasets[x]) for x in splits}
print(dataset_sizes)

100%|██████████| 132062/132062 [00:00<00:00, 148025.90it/s]
100%|██████████| 132062/132062 [00:03<00:00, 39127.78it/s]
100%|██████████| 943000/943000 [00:06<00:00, 136158.65it/s]
100%|██████████| 943000/943000 [00:25<00:00, 36931.03it/s]

{'val_balanced': 132062, 'train_balanced': 943000}





In [53]:
class DecoderRNN(nn.Module):
    """Thin wrapper around a batch-first `nn.GRU` / `nn.LSTM` used as the decoder."""

    def __init__(self, in_dim, num_hid, nlayers = 1, dropout = 0., rnn_type='GRU'):
        """Build the recurrent decoder.

        Args:
            in_dim: size of each input step.
            num_hid: size of the hidden state.
            nlayers: number of stacked recurrent layers.
            dropout: dropout passed to the recurrent module.
            rnn_type: 'GRU' or 'LSTM'.
        """
        super(DecoderRNN, self).__init__()
        assert rnn_type == 'LSTM' or rnn_type == 'GRU'
        rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU

        self.rnn = rnn_cls(
            in_dim, num_hid, nlayers,
            dropout=dropout,
            batch_first=True)

        self.in_dim = in_dim
        self.num_hid = num_hid
        self.nlayers = nlayers
        self.rnn_type = rnn_type

    def init_hidden(self, batch):
        """Return an all-zero initial hidden state for a batch of size `batch`.

        The state has shape `(nlayers, batch, num_hid)` and shares dtype/device with
        the module's parameters. For an LSTM a `(h0, c0)` tuple is returned.
        """
        # A parameter is used only as a dtype/device reference. `new_zeros`
        # replaces the deprecated `Variable(weight.new(...).zero_())` idiom, which
        # also relied on `Variable` being in scope.
        reference = next(self.parameters()).detach()
        hid_shape = (self.nlayers, batch, self.num_hid)
        if self.rnn_type == 'LSTM':
            return (reference.new_zeros(hid_shape),
                    reference.new_zeros(hid_shape))
        return reference.new_zeros(hid_shape)

    def forward(self, input, hidden):
        """Run the RNN over `input` ([batch, sequence, in_dim]) starting from `hidden`.

        Returns the `(output, hidden)` pair produced by the wrapped recurrent module.
        """
        self.rnn.flatten_parameters()
        output, hidden = self.rnn(input, hidden)
        return output, hidden

In [74]:
from language_model import WordEmbedding, QuestionEmbedding
from torch.nn.utils.weight_norm import weight_norm

class ProgramGenerator(nn.Module):
    """Seq2seq model that encodes a question and decodes a program.

    Every program step is one operation plus `N_ARG` argument slots. The decoder
    input at each step is the concatenation of the operation embedding and the
    `N_ARG` argument embeddings.

    NOTE: vocabulary sizes come from the module-level globals `question_vocab`,
    `operation_vocab` and `argument_vocab`.
    """

    EMB_DIM = 300   # embedding size shared by words, operations and arguments
    HID_DIM = 512   # hidden size of the question encoder and the program decoder
    N_ARG = 4       # argument slots per program step

    def __init__(self):
        super(ProgramGenerator, self).__init__()
        self.w_emb = WordEmbedding(len(question_vocab), self.EMB_DIM)
        self.w_emb.init_embedding('question_word_embedding_glove_init.npy')
        self.op_emb = WordEmbedding(len(operation_vocab), self.EMB_DIM)
        self.arg_emb = WordEmbedding(len(argument_vocab), self.EMB_DIM)

        self.q_encoder = QuestionEmbedding(self.EMB_DIM, self.HID_DIM)
        # decoder input = 1 operation embedding + N_ARG argument embeddings
        self.p_decoder = DecoderRNN(self.EMB_DIM * (1 + self.N_ARG), self.HID_DIM)
        self.op_predictor = weight_norm(
            nn.Linear(self.HID_DIM, len(operation_vocab)), dim=None)
        self.arg_predictor = weight_norm(
            nn.Linear(self.HID_DIM, self.N_ARG * len(argument_vocab)), dim=None)

    def forward(self, question, operations, arguments):
        """Predict the next operation/arguments for every program position.

        Args:
            question: [batch, question_len] token ids.
            operations: [batch, program_len] operation ids (decoder inputs).
            arguments: [batch, program_len, N_ARG] argument ids (decoder inputs).

        In training mode the ground-truth inputs are fed at every step (teacher
        forcing). In eval mode only the first step of `operations`/`arguments` is
        used and the model is fed its own greedy predictions afterwards.

        Returns:
            op_logit:  [batch, program_len, operation_vocab] logits (not probs)
            arg_logit: [batch, program_len, N_ARG, arg_vocab] logits (not probs)
        """
        batch_size = question.size(0)
        program_len = operations.size(1)

        w_emb = self.w_emb(question)
        q_enc = self.q_encoder(w_emb) # [batch, q_dim]
        # The question encoding initialises the (single-layer) decoder state.
        decoder_hidden = q_enc.view(1, batch_size, -1)

        operation_emb = self.op_emb(operations)
        argument_emb = self.arg_emb(arguments)

        argument_emb = argument_emb.view(batch_size, program_len, -1)
        op_arg_emb = torch.cat((operation_emb, argument_emb), -1)

        if self.training:
            # Teacher forcing: feed the target as the next input
            p_dec, _ = self.p_decoder(op_arg_emb, decoder_hidden)
            op_logit = self.op_predictor(p_dec) # [batch, program_len, operation_vocab]
            arg_logit = self.arg_predictor(p_dec)
            arg_logit = arg_logit.view(batch_size, program_len, self.N_ARG, -1)

        else:
            # Without teacher forcing: use its own predictions as the next input
            op_logit = []
            arg_logit = []
            decoder_input = op_arg_emb[:, 0].view(batch_size, 1, -1)
            for _ in range(program_len):
                decoder_output, decoder_hidden = self.p_decoder(
                    decoder_input, decoder_hidden)

                # Drop the length-1 sequence axis so the stacked logits have the
                # same rank as the teacher-forcing branch.
                one_op_logit = self.op_predictor(decoder_output).view(batch_size, -1)
                one_arg_logit = self.arg_predictor(decoder_output)
                one_arg_logit = one_arg_logit.view(batch_size, self.N_ARG, -1)

                # Greedy choice of the operation and of each argument slot.
                _, op_pred = one_op_logit.topk(1, dim=-1)
                _, arg_pred = one_arg_logit.topk(1, dim=-1)
                op_pred = op_pred.view(batch_size)
                arg_pred = arg_pred.view(batch_size, -1)
                step_op_emb = self.op_emb(op_pred)
                step_arg_emb = self.arg_emb(arg_pred).view(batch_size, -1)

                step_emb = torch.cat((step_op_emb, step_arg_emb), -1)
                decoder_input = step_emb.view(batch_size, 1, -1).detach()  # detach from history as input

                op_logit.append(one_op_logit)
                arg_logit.append(one_arg_logit)
            op_logit = torch.stack(op_logit, dim=1)    # [batch, program_len, operation_vocab]
            arg_logit = torch.stack(arg_logit, dim=1)  # [batch, program_len, N_ARG, arg_vocab]

        return op_logit, arg_logit

In [75]:
def compute_loss_acc(prob, gt):
    """Cross-entropy loss and accuracy over the non-padding targets.

    Inputs:
    - prob: float tensor of shape (N, V) holding unnormalised logits
    - gt: LongTensor of shape (N,) holding target ids; id 0 marks padding and is ignored

    Returns:
    - loss: scalar tensor, mean cross-entropy over the non-padding positions
    - acc: scalar tensor, fraction of non-padding positions predicted correctly

    If every target is padding, both values are 0 instead of NaN.
    """
    mask = gt != 0 # pad with 0
    prob = prob[mask, :]
    gt = gt[mask]

    if gt.numel() == 0:
        # Nothing to score: a mean over zero elements would be NaN and poison
        # the optimiser step. `prob.sum() * 0` keeps the result on the graph.
        zero = prob.sum() * 0.0
        return zero, zero.detach()

    loss = F.cross_entropy(prob, gt)

    pred = prob.max(dim=1)[1]
    acc = (pred == gt).float().mean()
    return loss, acc

In [76]:
def evaluate_model(model, dataloader):
    """Evaluate `model` on `dataloader` and return `(op_acc, arg_acc)`.

    The model is switched to eval mode (and left in it). Each batch's
    non-padding accuracy from `compute_loss_acc` is weighted by the batch size
    and averaged over all samples, so the result is a batch-size-weighted mean
    of per-batch accuracies.
    """
    model.eval()

    op_correct = 0
    arg_correct = 0
    total_count = 0
    # No gradients are needed for evaluation; without this the step-by-step
    # decoding builds an autograd graph for every batch.
    with torch.no_grad():
        for question, operations, arguments in dataloader:
            question = question.to(device)
            operations = operations.to(device)
            arguments = arguments.to(device)
            # Inputs are the sequence without its last step; targets are the
            # sequence shifted by one.
            op_logit, arg_logit = model(question, operations[:, :-1], arguments[:, :-1])
            _, op_acc = compute_loss_acc(
                op_logit.view(-1, len(operation_vocab)),
                operations[:, 1:].contiguous().view(-1))
            _, arg_acc = compute_loss_acc(
                arg_logit.view(-1, len(argument_vocab)),
                arguments[:, 1:].contiguous().view(-1))

            batch_size = question.size(0)
            op_correct += op_acc * batch_size
            arg_correct += arg_acc * batch_size
            total_count += batch_size
    op_acc = op_correct / total_count
    arg_acc = arg_correct / total_count

    return op_acc, arg_acc

In [81]:
def train_model(model, num_epochs=5, train_splits=['train'], 
                eval_splits=['val'], n_epochs_per_eval = 1):
    """Train `model` in place and restore the weights with the best validation score.

    Args:
        model: the network to train; called as `model(question, operations, arguments)`.
        num_epochs: number of passes over the training splits.
        train_splits: keys of the global `datasets` dict to train on.
        eval_splits: keys of the global `datasets` dict to evaluate on.
        n_epochs_per_eval: evaluate every this many epochs.

    The model-selection score is `op_acc + arg_acc` of the LAST split in
    `eval_splits`. On return, `model` holds the best-scoring weights seen.
    Returns None.
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # Decay LR by a factor of 0.1 every 100 epochs
    scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    dataloaders = {}
    for x in train_splits:
        dataloaders[x] = torch.utils.data.DataLoader(
            datasets[x], batch_size=128, shuffle=True, num_workers=4)
    # Evaluation loaders are added last so they win if a split is in both lists.
    for x in eval_splits:
        dataloaders[x] = torch.utils.data.DataLoader(
            datasets[x], batch_size=128, shuffle=False, num_workers=4)

    def evaluate_splits():
        """Evaluate on every eval split, print the results, return the selection score."""
        score = None
        for eval_split in eval_splits:
            op_acc, arg_acc = evaluate_model(model, dataloaders[eval_split])
            print('(op_acc={1:.2f}, arg_acc={2:.2f}) {0}'.format(
                eval_split, 100*op_acc, 100*arg_acc))
            score = float(op_acc + arg_acc)
        return score

    ###########evaluate init model###########
    evaluate_splits()
    print()
    #########################################

    evaluated_last_epoch = False
    for epoch in range(num_epochs):
        since = time.time()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # evaluate_model leaves the model in eval mode; switch back once per epoch.
        model.train()
        last_batch = None

        # Iterate over data.
        for train_split in train_splits:
            for question, operations, arguments in dataloaders[train_split]:
                question = question.to(device)
                operations = operations.to(device)
                arguments = arguments.to(device)
                op_logit, arg_logit = model(question, operations[:, :-1], arguments[:, :-1])

                op_loss, op_acc = compute_loss_acc(
                    op_logit.view(-1, len(operation_vocab)),
                    operations[:, 1:].contiguous().view(-1))
                arg_loss, arg_acc = compute_loss_acc(
                    arg_logit.view(-1, len(argument_vocab)),
                    arguments[:, 1:].contiguous().view(-1))

                loss = op_loss + arg_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                last_batch = (op_acc, arg_acc, loss)

        # The scheduler is stepped AFTER the epoch's optimizer steps.
        scheduler.step()

        # Metrics of the last training batch only (a cheap progress indicator).
        if last_batch is not None:
            print(*(value.item() for value in last_batch))

        evaluated_last_epoch = False
        if (epoch+1) % n_epochs_per_eval == 0:
            acc = evaluate_splits()
            evaluated_last_epoch = True
            # deep copy the model
            if acc is not None and acc > best_acc:
                best_acc = acc
                best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('Epoch time: {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print(flush=True)

    ###########evaluate final model###########
    # Only needed when the last epoch was not already evaluated above.
    if not evaluated_last_epoch:
        acc = evaluate_splits()
        if acc is not None and acc > best_acc:
            best_acc = acc
            best_model_wts = copy.deepcopy(model.state_dict())
        print()
    #########################################

    print('Best val op_acc + arg_acc: {:2f}'.format(100*best_acc))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return

In [82]:
# Train on the balanced training split and evaluate on the balanced validation
# split after every epoch. train_model looks these names up in the `datasets` dict.
train_splits = ['train_balanced']
eval_splits = ['val_balanced']
# Fresh model, moved to the selected device.
model = ProgramGenerator().to(device)
# train_model updates `model` in place; its own return value is not used.
train_model(model, num_epochs=20, train_splits=train_splits, eval_splits=eval_splits, n_epochs_per_eval = 1)

(op_acc=0.05, arg_acc=0.00) val_balanced

Epoch 0/19
----------
0.9897959232330322 0.7685949802398682 1.0681862831115723
(op_acc=66.69, arg_acc=45.78) val_balanced
Epoch time: 2m 32s

Epoch 1/19
----------
1.0 0.8869564533233643 0.455322802066803
(op_acc=67.01, arg_acc=55.52) val_balanced
Epoch time: 2m 33s

Epoch 2/19
----------
1.0 0.8691588640213013 0.38715264201164246
(op_acc=60.86, arg_acc=58.14) val_balanced
Epoch time: 2m 37s

Epoch 3/19
----------
1.0 0.9711538553237915 0.09018203616142273
(op_acc=59.64, arg_acc=58.89) val_balanced
Epoch time: 2m 40s

Epoch 4/19
----------
0.9999999403953552 0.9385964870452881 0.1323327273130417
(op_acc=58.15, arg_acc=58.92) val_balanced
Epoch time: 2m 40s

Epoch 5/19
----------
0.9909909963607788 0.9752065539360046 0.13251852989196777
(op_acc=58.41, arg_acc=58.56) val_balanced
Epoch time: 2m 36s

Epoch 6/19
----------
0.9999999403953552 0.9523810148239136 0.08971542119979858
(op_acc=61.00, arg_acc=59.92) val_balanced
Epoch time: 2m 39s

Epoch 

KeyboardInterrupt: 