In [1]:
import json
import os
import re
import random
import time
import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.utils.data as data
from torch.nn.utils import clip_grad_norm_
from tensorboardX import SummaryWriter
import torch.utils
import torchtext
from torchtext.data import Example
from torch.utils.tensorboard import SummaryWriter

import torch.nn.functional as F

import torch_geometric
from torch_geometric.utils.convert import from_networkx
from torch_geometric.utils import dense_to_sparse, to_dense_adj, to_dense_batch
from torch_geometric.nn import GCNConv, GATConv
import torch_geometric.nn as pyg_nn
import networkx as nx
import shutil

import numpy as np
from PIL import Image
import h5py
from collections import defaultdict

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [2]:
ROOT_PATH = "/local/home/rchan31/SIGIR/detailed-experiments/data"

IMG_FTR_FOLDER = os.path.join(ROOT_PATH, "img_ftrs")
OBJ_EMBEDDING_FILE = os.path.join(ROOT_PATH, "gqa_50_singular_predicates_embeddings_300d")

TRAIN_BBOX_FOLDER = os.path.join(ROOT_PATH, "bbox_data_gqa_pytorch_normalized2")
TRAIN_SG_FILE = os.path.join(ROOT_PATH, "2gqa_rel_annotations_train_compatible_for_scene_graph.json")

In [3]:
# # toDO:
# 1. use unfiltered sg data, box features, pred features
# 2. use

In [4]:
GT_VAL_ROOT = "/local/home/rchan31/SIGIR/detailed-experiments/data/GT_VAL_DATA/"

VAL_GT_BBOX_FOLDER = os.path.join(GT_VAL_ROOT, "gt_val_bboxes")
VAL_GT_SG_FILE = os.path.join(GT_VAL_ROOT, "scene_graph_val_gt.json")

In [60]:
SAVE_DIR_ROOT = "/local/home/rchan31/SIGIR/detailed-experiments/SAVED_DIRS/"

SAVE_DIR = os.path.join(SAVE_DIR_ROOT, "saved-exp01")
writer = SummaryWriter(SAVE_DIR + '/runs_1/')

RESET_MODEL = False

if RESET_MODEL:
    shutil.rmtree(SAVE_DIR)
    
os.makedirs(SAVE_DIR, exist_ok=True)

# Load train data

In [7]:
TRAIN_QST_FILE = "/local/home/rchan31/SIGIR/detailed-experiments/data/QUESTIONS/train_balanced_questions.json"

train_qas_data = json.load( open(TRAIN_QST_FILE, 'r') )
print(len(train_qas_data))

943000


In [10]:
TRAIN_SG_DATA = json.load( open(TRAIN_SG_FILE, "r") )
print(len(TRAIN_SG_DATA) )

74942


# Load val data`

In [11]:
VAL_QST_FILE = "/local/home/rchan31/SIGIR/detailed-experiments/data/QUESTIONS/val_balanced_questions.json"

val_qas_data = json.load( open(VAL_QST_FILE, 'r') )
print(len(val_qas_data))

132062


In [12]:
VAL_GT_SG_DATA = json.load( open(VAL_GT_SG_FILE, "r") )
print(len(VAL_GT_SG_DATA) )

10696


# Data processing & build vocabs

In [13]:
def tokenise(text):
    # Replace annoying unicode with a space
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)
    # The following replacements are suggested in the paper
    # BidAF (Seo et al., 2016)
    text = text.replace("''", '" ')
    text = text.replace("``", '" ')

    # Space out punctuation
    space_list = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~"
    space_list = "!\"#$%&()*+,./:;<=>?@[\\]^_`{|}~"
    text = text.translate(str.maketrans({key: " {0} ".format(key) for key in space_list}))

    # space out singlequotes a bit better (and like stanford)
    text = text.replace("'", " '")
    
    # use any APIs
    text = text.replace('\t', '').replace('\n', '').lower().strip()
    return text

In [18]:
PAD_TOKEN = '<PAD>'
SOS_TOKEN = '<SOS>'
UNK_TOKEN = '<UNK>'
EOS_TOKEN = '<EOS>'

print("complete data (train + val) size : ", len(train_qas_data) + len(val_qas_data) )

MAX_QUESTION_LENGTH = 14
MAX_ANSWER_LEN = 1

QUE_TEXT = torchtext.data.Field(sequential=True, 
                                tokenize=lambda x: x.split(),
                                # init_token=SOS_WORD,
                                eos_token=EOS_TOKEN,
                                pad_token=PAD_TOKEN,
                                include_lengths=True,
                                batch_first=True,
                                fix_length=MAX_QUESTION_LENGTH, # + 1
                                lower=True)

ANS_TEXT = torchtext.data.Field(sequential=True, 
                                tokenize=lambda x: x.split(),                                
                                # eos_token=EOS_WORD,
                                init_token=SOS_TOKEN,
                                pad_token=PAD_TOKEN,
                                batch_first=True,
                                fix_length=MAX_ANSWER_LEN + 1, # +1 because of attached SOS token at the beginning
                                lower=True)

example_texts = []

FIELDS = [('ans_text', ANS_TEXT), ('que_text', QUE_TEXT)]

print("Processing train vocab")
for idx, (qas_id, qas_ins) in enumerate(train_qas_data.items()):
    q = qas_ins["question"]
    a = qas_ins["answer"]
    
    q = tokenise(q).replace('?', '')
    a = tokenise(a).replace('?', '')
    
    example_texts.append( Example.fromlist([a, q] , FIELDS ) )
    
    if idx % 100000 == 0:
        print("Finished ", idx)
        
print("Processing val vocab")
for idx, (qas_id, qas_ins) in enumerate(val_qas_data.items()):
    q = qas_ins["question"]
    a = qas_ins["answer"]
    
    q = tokenise(q).replace('?', '')
    a = tokenise(a).replace('?', '')
    
    example_texts.append( Example.fromlist([a, q] , FIELDS ) )
    
    if idx % 100000 == 0:
        print("Finished ", idx)

torchtext_dataset = torchtext.data.Dataset(example_texts, fields=FIELDS)

QUE_TEXT.build_vocab(torchtext_dataset, vectors='glove.6B.300d', vectors_cache='/local/home/rchan31/SIGIR/detailed-experiments/data/vocab_cache')
ANS_TEXT.build_vocab(torchtext_dataset, vectors='glove.6B.300d', vectors_cache='/local/home/rchan31/SIGIR/detailed-experiments/data/vocab_cache')

print("======= Question ========")
print(len(QUE_TEXT.vocab), len(ANS_TEXT.vocab))
print(QUE_TEXT.vocab.vectors.size(), ANS_TEXT.vocab.vectors.size())

complete data (train + val) size :  1075062
Processing train vocab
Finished  0
Finished  100000
Finished  200000
Finished  300000
Finished  400000
Finished  500000
Finished  600000
Finished  700000
Finished  800000
Finished  900000
Processing val vocab
Finished  0
Finished  100000
2874 1676
torch.Size([2874, 300]) torch.Size([1676, 300])


In [24]:
def train_get_qas_for_imgs_present(img_id):
    if os.path.isfile( os.path.join(IMG_FTR_FOLDER, "{}.json".format(img_id)) ) and \
        os.path.isfile( os.path.join(TRAIN_BBOX_FOLDER, "{}.json".format(img_id)) ):
        if (img_id in TRAIN_SG_DATA) and (len(TRAIN_SG_DATA[str(img_id)]) != 0):
            return True
    else: False

def train_decouple_q_and_a(qas_data):    
    structured_qas = {"img_id": [], "questions": [], "answers": [], "question_ids": []}
    for idx, (qas_id, qas_ins) in enumerate(qas_data.items()):
        img_id = qas_ins['imageId']        
        
        if not (train_get_qas_for_imgs_present(img_id)): continue
        
        q = qas_ins["question"]
        a = qas_ins["answer"]
        
        structured_qas['img_id'].append(img_id)
        structured_qas['questions'].append(q)
        structured_qas['answers'].append(a)
        structured_qas['question_ids'].append(qas_id)
    
    return structured_qas

train_structured_qas = train_decouple_q_and_a(train_qas_data)
print("Finished loading train_structured_qas ...")

print( len(train_structured_qas['answers']) )

Finished loading train_structured_qas ...
897111


In [28]:
def val_get_qas_for_imgs_present(img_id):
    if os.path.isfile( os.path.join(IMG_FTR_FOLDER, "{}.json".format(img_id)) ) and \
        os.path.isfile( os.path.join(VAL_GT_BBOX_FOLDER, "{}.json".format(img_id)) ):
        if (img_id in VAL_GT_SG_DATA) and (len(VAL_GT_SG_DATA[str(img_id)]) != 0):
            return True
    else: False

def val_decouple_q_and_a(qas_data):    
    structured_qas = {"img_id": [], "questions": [], "answers": [], "question_ids": []}
    for idx, (qas_id, qas_ins) in enumerate(qas_data.items()):
        img_id = qas_ins['imageId']        
        
        if not (val_get_qas_for_imgs_present(img_id)): continue
        
        q = qas_ins["question"]
        a = qas_ins["answer"]
        
        structured_qas['img_id'].append(img_id)
        structured_qas['questions'].append(q)
        structured_qas['answers'].append(a)
        structured_qas['question_ids'].append(qas_id)
    
    return structured_qas

valid_structured_qas = val_decouple_q_and_a(val_qas_data)
print("Finished loading valid_structured_qas ...")

print( len(valid_structured_qas['answers']) )

Finished loading valid_structured_qas ...
125726


# Dataset

In [29]:
PAD_token_idx = QUE_TEXT.vocab.stoi[PAD_TOKEN]
NUM_NODES = 40

def load_predicate_features_file():
    pred_embeddings = []
    with open(OBJ_EMBEDDING_FILE, 'r') as pred_file:
        all_pred_embeddings = [line.rstrip('\n') for line in pred_file]
        for embedding in all_pred_embeddings:
            pred_embeddings.append([float(i) for i in embedding[2:][:-2].split()])
    return np.asarray(pred_embeddings)

def load_object_visual_features(imageID, data_mode):
    if data_mode == 'train':
        with open(os.path.join(TRAIN_BBOX_FOLDER, "%s.json" % imageID), 'r') as f:
            obj_feat_set = json.loads(f.read())
        return obj_feat_set
    elif data_mode == 'gt_val':
        with open(os.path.join(VAL_GT_BBOX_FOLDER, "%s.json" % imageID), 'r') as f:
            obj_feat_set = json.loads(f.read())
        return obj_feat_set
    else:
        with open(os.path.join(VAL_INF_BBOX_FOLDER, "%s.json" % imageID), 'r') as f:
            obj_feat_set = json.loads(f.read())
        return obj_feat_set

def prepare_questions(questions_list):
    for q in questions_list:
        q = tokenise(q).replace('?', '')
        yield q.split()
        
def prepare_answers(answers_list):
    for a in answers_list:
        a = tokenise(a).replace('.', '')
        yield a.split()

        
class VQA(torch.utils.data.Dataset):
    """ VQA dataset, open-ended """
    def __init__(self,
                structured_qas,
                question_field,
                answer_field,
                data_mode="train"):
        super(VQA, self).__init__()

        print("Step1 : Data loading")
        print("....... {} mode number of data samples : {}".format(data_mode, len(structured_qas['img_id'])))
        self.question_field = question_field
        self.answer_field = answer_field

        # ========================= Load Answer Vocab =================
        self.answers = list(prepare_answers( structured_qas['answers'] ))
        self.answers = answer_field.pad(self.answers)
        self.answers = answer_field.numericalize(self.answers)
        print("....... VG Answer data have been PREPARED & ENCODED ...")

        # ========================= Load Question Vocab =================
        self.questions_vocab = question_field.vocab
        self.questions = list(prepare_questions( structured_qas['questions'] ))
        self.questions = question_field.pad(self.questions)
        self.questions = question_field.numericalize(self.questions)
        
        print("....... VG Question data have been PREPARED & ENCODED ...")
        self.vg_ids = structured_qas['img_id']    
        self.question_ids = structured_qas['question_ids']    
        self.data_mode = data_mode
        
        if self.data_mode == 'train':
            self.sg_data = TRAIN_SG_DATA
        elif data_mode == 'gt_val':
            self.sg_data = VAL_GT_SG_DATA
        else:
            self.sg_data = VAL_INF_SG_DATA
            
        self.pred_embeddings = load_predicate_features_file()
        
        print("....... Data loading completed ...")
    
    @property
    def max_que_length(self):
        if not hasattr(self, '_ques_max_length'):
            self._ques_max_length = max(map(len, self.questions))
        return self._ques_max_length
    @property
    def max_ans_length(self):
        if not hasattr(self, '_ans_max_length'):
            self._ans_max_length = max(map(len, self.answers))
        return self._ans_max_length
    
    def get_all_img_ftrs(self, img_id):
        img_ftr_folder = os.path.join(IMG_FTR_FOLDER, img_id + '.json')
        with open(img_ftr_folder, "r") as f:
            img_ftr_data = json.loads(f.read())
            return np.asarray(img_ftr_data, np.float32)
        
    def __len__(self):
        return len(self.answers)
    
    def __getitem__(self, item):
        img_id = self.vg_ids[item]
        qst_id = self.question_ids[item]
        
        q = self.questions[0][item]
        q_length = self.questions[1][item]
        a = self.answers[item]
        
        v = self.get_all_img_ftrs(img_id)
    
        obj_feat_data_hf = load_object_visual_features(img_id, self.data_mode) 
        g = self._load_sgraph(img_id, obj_feat_data_hf)
            
        data = torch_geometric.data.Data(x=g.x.float(), 
                                         edge_index=g.edge_index.long(),
                                         x_sem=g.cat.long(),
                                         y=a, edge_attr=g.edge_x.float())
        data['v'] = v
        data['q'] = q
        data['q_length'] = q_length
        data['img_id'] = img_id
        data['qst_id'] = qst_id
        
        return data       

    
    def _load_sgraph(self, image_id, obj_feat_data_hf):
        G = nx.DiGraph()
        sg_data_per_img = self.sg_data[str(image_id)]

        for relation in sg_data_per_img:
            sub_id = "%s" % (relation["subject"]["bbox"])
            obj_id = "%s" % (relation["object"]["bbox"])
            pred_id = relation["predicate"] - 1 # -1 is required
            sub_cat = relation["subject"]["category"]
            obj_cat = relation["object"]["category"]   
            
            if G.number_of_nodes() < NUM_NODES - 1:
                
                G.add_node(sub_id, 
                           x=np.array(obj_feat_data_hf.get(sub_id , np.zeros(512, ))), 
                           cat=self.question_field.vocab.stoi[sub_cat])
                G.add_node(obj_id, 
                           x=np.array(obj_feat_data_hf.get(obj_id , np.zeros(512, ))), 
                           cat=self.question_field.vocab.stoi[obj_cat])         
            
            if (sub_id in list(G.nodes())) and (obj_id in list(G.nodes())):
                
                G.add_edge(sub_id, obj_id, edge_x=self.pred_embeddings[pred_id])
#                 if pred_id >= 50: # TODO:
#                     G.add_edge(sub_id, obj_id, edge_x=np.zeros(300, ) )
#                 else:
#                     G.add_edge(sub_id, obj_id, edge_x=self.pred_embeddings[pred_id])

        g = from_networkx(G)
        return g

In [31]:
train_vqa_dataset = VQA(train_structured_qas,
                        QUE_TEXT,
                        ANS_TEXT,
                        data_mode='train')

val_vqa_dataset = VQA(valid_structured_qas,
                      QUE_TEXT,
                      ANS_TEXT,
                      data_mode='gt_val')

train_data = train_vqa_dataset.__getitem__(1)
print ( train_vqa_dataset.__len__() )

print( val_vqa_dataset.__len__() )
val_data = val_vqa_dataset.__getitem__(1)

Step1 : Data loading
....... train mode number of data samples : 897111
....... VG Answer data have been PREPARED & ENCODED ...
....... VG Question data have been PREPARED & ENCODED ...
....... Data loading completed ...
Step1 : Data loading
....... gt_val mode number of data samples : 125726
....... VG Answer data have been PREPARED & ENCODED ...
....... VG Question data have been PREPARED & ENCODED ...
....... Data loading completed ...
897111
125726


In [32]:
BATCH_SIZE = 128

train_data_loader = torch_geometric.data.DataLoader(train_vqa_dataset,
                                                    batch_size=BATCH_SIZE)
val_data_loader = torch_geometric.data.DataLoader(val_vqa_dataset,
                                                  batch_size=BATCH_SIZE )

n_train_batches = len(train_data_loader)
print( n_train_batches )

n_val_batches = len(val_data_loader)
print( n_val_batches )

7009
983


In [59]:
# for x in train_data_loader:
#     print(x)
#     v = torch.tensor(x.v)
#     print(v.shape)
#     print()
#     print(x.img_id[:5])
#     print()
#     print(x.qst_id[:5])
#     break

In [58]:
# # FOR DEBUGGING
# sample_idx = 3
# batch_size = 128

# for x in train_data_loader:
    
#     print(x.qst_id)
#     print(x.img_id)
    
#     q = torch.tensor(x.q).view(batch_size, -1)[sample_idx]
#     a = torch.tensor(x.y).view(batch_size, -1)[sample_idx]
#     print(q.shape)
#     print(a.shape)
#     print()
    
#     q_str = []
#     for qq in q:
#         q_str.append(QUE_TEXT.vocab.itos[qq])
#     print(' '.join( q_str) )
#     print()
    
#     a_str = []
#     for aa in a:
#         a_str.append(ANS_TEXT.vocab.itos[aa])
#     print(' '.join( a_str) )
#     print()
    
#     break

# Model

## Attention-based Sequence Decoder

In [36]:
# http://www.adeveloperdiary.com/data-science/deep-learning/nlp/machine-translation-using-attention-with-pytorch/

class Attention(nn.Module):
    def __init__(self, encoder_hidden_dim, decoder_hidden_dim):
        super().__init__()
 
        # The input dimension will the the concatenation of
        # encoder_hidden_dim (hidden) and  decoder_hidden_dim(encoder_outputs)
        self.attn_hidden_vector = nn.Linear(encoder_hidden_dim + decoder_hidden_dim, decoder_hidden_dim)
 
        # We need source len number of values for n batch as the dimension
        # of the attention weights. The attn_hidden_vector will have the
        # dimension of [source len, batch size, decoder hidden dim]
        # If we set the output dim of this Linear layer to 1 then the
        # effective output dimension will be [source len, batch size]
        self.attn_scoring_fn = nn.Linear(decoder_hidden_dim, 1, bias=False)
 
    def forward(self, hidden, encoder_outputs):
        # hidden = [1, batch size, decoder hidden dim]
        src_len = encoder_outputs.shape[0]
 
        # We need to calculate the attn_hidden for each source words.
        # Instead of repeating this using a loop, we can duplicate
        # hidden src_len number of times and perform the operations.
        hidden = hidden.repeat(src_len, 1, 1)
 
        # Calculate Attention Hidden values
        attn_hidden = torch.tanh(self.attn_hidden_vector(torch.cat((hidden, encoder_outputs), dim=2)))
 
        # Calculate the Scoring function. Remove 3rd dimension.
        attn_scoring_vector = self.attn_scoring_fn(attn_hidden).squeeze(2)
 
        # The attn_scoring_vector has dimension of [source len, batch size]
        # Since we need to calculate the softmax per record in the batch
        # we will switch the dimension to [batch size,source len]
        attn_scoring_vector = attn_scoring_vector.permute(1, 0)
 
        # Softmax function for normalizing the weights to
        # probability distribution
        return F.softmax(attn_scoring_vector, dim=1)
    
class OneStepDecoder(nn.Module):
    def __init__(self, output_dim, embedding_dim, encoder_hidden_dim, decoder_hidden_dim, dropout_prob=0.5):
        super().__init__()
 
        self.attention = Attention(encoder_hidden_dim, decoder_hidden_dim)
 
        # self.embedding = nn.Embedding(input_output_dim, embedding_dim)
        self.embedding = nn.Embedding.from_pretrained(ANS_TEXT.vocab.vectors, padding_idx=PAD_token_idx, freeze=True)
 
        # Add the encoder_hidden_dim and embedding_dim
        self.rnn = nn.GRU(encoder_hidden_dim + embedding_dim, decoder_hidden_dim)
        # Combine all the features for better prediction
        self.fc = nn.Linear(encoder_hidden_dim + decoder_hidden_dim + embedding_dim, output_dim)
 
        self.dropout = nn.Dropout(dropout_prob)
    
    def forward(self, input, hidden, encoder_outputs):
        # Add the source len dimension
        input = input.unsqueeze(0)
 
        embedded = self.dropout(self.embedding(input))
 
        # Calculate the attention weights
        a = self.attention(hidden, encoder_outputs).unsqueeze(1)
 
        # We need to perform the batch wise dot product.
        # Hence need to shift the batch dimension to the front.
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
 
        # Use PyTorch's bmm function to calculate the weight W.
        W = torch.bmm(a, encoder_outputs)
 
        # Revert the batch dimension.
        W = W.permute(1, 0, 2)
 
        # concatenate the previous output with W
        rnn_input = torch.cat((embedded, W), dim=2)
 
        output, hidden = self.rnn(rnn_input, hidden)
 
        # Remove the sentence length dimension and pass them to the Linear layer
        predicted_token = self.fc(torch.cat((output.squeeze(0), W.squeeze(0), embedded.squeeze(0)), dim=1))
 
        return predicted_token, hidden, a.squeeze(1)

class Decoder(nn.Module):
    def __init__(self, output_dim, embedding_dim=300, encoder_hidden_dim=200, decoder_hidden_dim=200, dropout_prob=0.1):
        super().__init__()
        self.one_step_decoder = OneStepDecoder(output_dim, embedding_dim, encoder_hidden_dim, decoder_hidden_dim, dropout_prob)
        self.device = device
        self.ans_vocab_size = len(ANS_TEXT.vocab)
 
    def forward(self, target, encoder_outputs, hidden, teacher_forcing_ratio=0.5):
        batch_size = target.shape[1]
        trg_len = target.shape[0]
        trg_vocab_size = self.ans_vocab_size
 
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size) # .to(self.device)
        input = target[0, :]
 
        for t in range(1, trg_len):
            # Pass the encoder_outputs. For the first time step the 
            # hidden state comes from the encoder model.
            output, hidden, a = self.one_step_decoder(input, hidden, encoder_outputs)
            outputs[t] = output
 
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
 
            input = target[t] if teacher_force else top1
 
        return outputs

## Encoder

In [37]:
n_classes = len(ANS_TEXT.vocab)
print("Number of classes : ", n_classes)

class GraphGATEncoder(torch.nn.Module):
    def __init__(self, input_dim, output_dim, dropout=0.0, heads=1):
        super(GraphGATEncoder, self).__init__()
        self.gat_conv = GATConv(in_channels=input_dim, out_channels=output_dim, heads=heads) # , dropout=self.dropout)       
        
    def forward(self, x_input, edge_index):
        gat_encoding = self.gat_conv(x_input, edge_index)
        return gat_encoding
    
class SequenceEncoder(torch.nn.Module):
    def __init__(self, input_size, 
                 padding_idx, seq_max_len, 
                 hidden_size=200, n_layers=1, dropout=0.1):
        super(SequenceEncoder, self).__init__()
        
        self.padding_idx = padding_idx
        self.seq_max_len = seq_max_len
        
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout

        
        self.gru = nn.GRU(input_size, self.hidden_size, self.n_layers,
                          batch_first=True, bidirectional=False,
                          dropout=(0 if self.n_layers == 1 else self.dropout) )
            
    def forward(self, seq, seq_len, hidden=None):

        packed = torch.nn.utils.rnn.pack_padded_sequence(seq, seq_len, batch_first=True, enforce_sorted=False)
        output, final = self.gru(packed, hidden)
        output, _ = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=True, padding_value=self.padding_idx, total_length=self.seq_max_len)

        return output, final
    
class VQAEncoder(nn.Module):
    
    def __init__(self, q_word_emb_dim, question_length, hidden_size=200, n_layers=1, dropout=0.1):
        super(VQAEncoder, self).__init__()
        
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout
                
        self.padding_idx = PAD_token_idx
        self.q_len = question_length

        
        
        self.q_embedding = nn.Embedding.from_pretrained(QUE_TEXT.vocab.vectors, padding_idx=self.padding_idx, freeze=True)
        
        
        
        self.simple_q_encoder = SequenceEncoder(q_word_emb_dim, 
                                                self.padding_idx, self.q_len, 
                                                self.hidden_size, self.n_layers, self.dropout)
        self.img_q_encoder = SequenceEncoder(q_word_emb_dim + 512, 
                                             self.padding_idx, self.q_len, 
                                             self.hidden_size, self.n_layers, self.dropout)

        self.g_sem_emb = nn.Embedding.from_pretrained(QUE_TEXT.vocab.vectors, freeze=True)
        self.visual_graph_encoder = GraphGATEncoder(input_dim=self.hidden_size+512, output_dim=hidden_size, heads=1)
        self.semantic_graph_encoder = GraphGATEncoder(input_dim=self.hidden_size+300, output_dim=hidden_size, heads=1)
        
        self.pred_fc = nn.Linear(300, hidden_size, bias=False) # [128, 40*40, 300]-->[128, 40*40, 200]
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(self.dropout)
                
    def forward(self, x, x_sem, edge_index, edge_attr, batch_alloc, qst, qst_len, img, batch_size, encoder_hidden=None):        
        qst_embedding_feature = self.q_embedding(qst) # [b, 28, 300]
        x_sem_embedding_features = self.g_sem_emb(x_sem) # [N, 300]
           
        # ================== simple-question encoding ================================== 
        encoder_outputs, encoder_states = self.simple_q_encoder(qst_embedding_feature, qst_len)                      # [64, 400]
        
        # ================== question with vgg-img features ========================
        vgg_features = torch.unsqueeze(img, 1).repeat(1, self.q_len, 1)
        vgg_qst_features = torch.cat((vgg_features, qst_embedding_feature), axis=2) # [128, 14, 812]
        enc2_op, qn_encoding2 = self.img_q_encoder(vgg_qst_features, qst_len) # [1, 128, 200]
                 
        # ================== spread simple-question encoding with node embeddings ======
        encoder_states_unsqueezed = torch.squeeze(encoder_states, 0) # [128, 200] 
        reformatted_batch_qn_encoding = torch.stack([encoder_states_unsqueezed[b_idx] for b_idx in batch_alloc]) # [1605, 200]
    
        img_ftr_qn_encoding = torch.cat((x, reformatted_batch_qn_encoding), axis=1) # [1605, 200 + 512]
        obj_ftr_qn_encoding = torch.cat([x_sem_embedding_features, reformatted_batch_qn_encoding], axis=1) # [1605, 200 + 300]

        # ================== apply GAT over simple-question mixed scene-graph ==========
        # [1605, 712] --> [1605, 200] --> [128, 40, 200]
        img_node_embeddings = self.visual_graph_encoder(img_ftr_qn_encoding, edge_index) # [1605, 200+512] --> [128, 40, 200]
        # img_node_embeddings = self.convert_geometric_to_standard_batch(img_node_embeddings, batch_alloc, batch_size)
        img_node_embeddings, _ = to_dense_batch(img_node_embeddings, batch=batch_alloc, fill_value=0, max_num_nodes=NUM_NODES)
                
        obj_node_embeddings = self.semantic_graph_encoder(obj_ftr_qn_encoding, edge_index) # [1605, 200+300] --> [128, 40, 200]
        # obj_node_embeddings = self.convert_geometric_to_standard_batch(obj_node_embeddings, batch_alloc, batch_size)
        obj_node_embeddings, _ = to_dense_batch(obj_node_embeddings, batch=batch_alloc, fill_value=0, max_num_nodes=NUM_NODES)
        
        # ================== get predicate embedding =================================== # [744, 300]
        tmp_adj = to_dense_adj(edge_index=edge_index, edge_attr=edge_attr, batch=batch_alloc, max_num_nodes=NUM_NODES) # [128, 40, 40, 300]
        pred_embeddings = tmp_adj.view(batch_size, NUM_NODES * NUM_NODES, 300) # [128, 40*40, 300]
        pred_embeddings_feat = self.tanh( self.pred_fc(pred_embeddings) ) # [128, 40*40, 200]
        
        # ================== fused both graph-encodings ==================
        node_embeddings = torch.cat((img_node_embeddings, pred_embeddings_feat, obj_node_embeddings), 
                                    axis=1) # [128, 40 + 40*40 + 40, 200]
        
        return node_embeddings, qn_encoding2 # [128,1680, 200] [1, 128, 200]

Number of classes :  1676


## Building full model with Encoder + Decoder

In [38]:
qst_token_embedding_dim = 300
ans_token_embedding_dim = 300

HIDDEN_SIZE = 200
N_LAYERS = 1
DROPOUT = 0.1

class VQASeqToSeq(nn.Module):
    def __init__(self, qst_token_embedding_dim, ans_token_embedding_dim, hidden_size, n_layers, dropout):
        super(VQASeqToSeq, self).__init__()
    
        self.encoder = VQAEncoder(q_word_emb_dim=qst_token_embedding_dim,
                                  question_length=MAX_QUESTION_LENGTH, # CONFIG.MAX_QUESTION_LENGTH + 1,
                                  hidden_size=hidden_size, n_layers=n_layers, dropout=dropout)
        
        self.decoder = Decoder(output_dim=n_classes, 
                               embedding_dim=ans_token_embedding_dim, 
                               encoder_hidden_dim=hidden_size, decoder_hidden_dim=hidden_size, dropout_prob=dropout)
        
        
    def forward(self, x, x_sem, edge_index, edge_attr, batch_alloc, qst, qst_len, img, batch_size, 
                a, teacher_forcing_ratio=0.5, encoder_hidden=None):
        node_embeddings, qn_encoding2 = self.encoder(x, x_sem, edge_index, edge_attr, batch_alloc, qst, qst_len, img, batch_size)
        node_embeddings = torch.transpose(node_embeddings, 1, 0)
        
        decoder_output = self.decoder(a, node_embeddings, qn_encoding2, teacher_forcing_ratio)
        
        return decoder_output
    
    

vqa_seq2seq_model = VQASeqToSeq(qst_token_embedding_dim, ans_token_embedding_dim,
                                HIDDEN_SIZE, N_LAYERS, DROPOUT)
vqa_seq2seq_model.to(device)
print('Models built and ready to go!')

Models built and ready to go!


# Define Loss, Optimizer, & utility functions

In [39]:
LEARNING_RATE = 0.001

optimizer = optim.Adam(filter(lambda p: p.requires_grad, vqa_seq2seq_model.parameters()), 
                       lr=LEARNING_RATE) #, weight_decay=1e-5)

instance_entropy_with_logits = nn.CrossEntropyLoss(ignore_index=PAD_token_idx)

def accuracy(pred, true):
    acc = (true == pred.argmax(-1)).float().detach().cpu().numpy()
    return float(100 * acc.sum() / len(acc))

def format_time(elapsed_time):
    """
    Takes a time in seconds and returns a string hh:mm:ss
    """
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed_time)))
    return str(datetime.timedelta(seconds=elapsed_rounded)) # Format as hh:mm:ss

# Training

In [41]:
resume_epoch = 0
num_epochs = 3
print_every = 10
save_every = 1000

print('Initializing ...')
print("Training...")

total_t0 = time.time()
for epoch in range(resume_epoch, num_epochs):
    total_loss = 0
    total_score = 0

    vqa_seq2seq_model.train() # IMPORTANT
    
    t0 = time.time()
    for batch_idx, batch_data in enumerate(train_data_loader):
        if (batch_idx == n_train_batches-1):
            break            
            
        vqa_seq2seq_model.zero_grad()  # IMPORTANT
        
        
        
        
        batch_alloc = batch_data.batch.to(device)
        q_len = batch_data.q_length
        batch_size = q_len.shape[0]

        q = torch.tensor(batch_data.q).view(batch_size, -1).to(device)
        a = torch.transpose( torch.tensor(batch_data.y).view(batch_size, -1), 1, 0).to(device)
        v = torch.tensor(batch_data.v).to(device)
        
        edge_attr = batch_data.edge_attr.to(device)
        edge_index = batch_data.edge_index.to(device)
        x = batch_data.x.to(device)
        x_sem = batch_data.x_sem.to(device)

        output = vqa_seq2seq_model(x, x_sem, edge_index, edge_attr, batch_alloc, q, q_len, v, batch_size,
                     a, 0.5)
        output_dim = output.shape[-1]
        output_for_loss = output[1:].view(-1, output_dim).to(device)
        trg = a[1:].reshape(-1)

        output_for_loss_softmaxed = nn.functional.log_softmax(output_for_loss, -1)
        loss = instance_entropy_with_logits(output_for_loss_softmaxed, trg) # LOSS

        loss.backward()  # IMPORTANT
        _ = nn.utils.clip_grad_norm_(vqa_seq2seq_model.parameters(), 0.25)
        optimizer.step()  # IMPORTANT
        
        batch_loss = loss.item()
        batch_score = accuracy(output_for_loss, trg) # ACCURACY
        
        total_loss += batch_loss
        total_score += batch_score
            
        with open(os.path.join(SAVE_DIR, 'train-log-epoch.%s.txt' % (SAVE_DIR.split('-')[-1]) ), 'a') as f:
                f.write(str(epoch+1) + '\t' + str(batch_idx+1) + '\t' + str(batch_loss) + '\t' + str(batch_score) + '\n')
        writer.add_scalar('training loss', loss.item(), epoch * n_train_batches + batch_idx)
        writer.add_scalar('training score', batch_score, epoch * n_train_batches + batch_idx)
                
        if batch_idx % print_every == 0: # Print progress
            total_loss_avg = total_loss / print_every 
            total_score_avg = total_score / print_every
            elapsed = format_time(time.time() - t0)
            print('| TRAIN SET | Epoch [{:02d}/{:02d}], Step [{:04d}/{:04d}], Loss: {:.4f} , Score: {:.4f} | Elapsed: {:}'
                          .format(epoch+1, num_epochs, batch_idx, int(n_train_batches), total_loss_avg, total_score_avg, elapsed))
            total_loss = 0
            total_score = 0                
        
        if ( (batch_idx == n_train_batches-2) or ((batch_idx+1) % save_every == 0) ): # Save checkpoint
            directory = os.path.join(SAVE_DIR, 'vqa-pytorch-model')
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save(vqa_seq2seq_model.state_dict(),
                       os.path.join(directory, 'epoch-{}.batch-{}.{}.pt'.format(epoch+1, batch_idx+1, 'checkpoint')))

training_epoch_time = format_time(time.time() - t0)
print(training_epoch_time)
        
print("Total training took {:} (h:mm:ss)".format(format_time(time.time() - total_t0)))

Initializing ...
Training...


  q = torch.tensor(batch_data.q).view(batch_size, -1).to(device)
  a = torch.transpose( torch.tensor(batch_data.y).view(batch_size, -1), 1, 0).to(device)


| TRAIN SET | Epoch [01/03], Step [0000/7009], Loss: 0.7144 , Score: 2.2656 | Elapsed: 0:00:01
0:00:01
Total training took 0:00:01 (h:mm:ss)


# Validation on Ground Truth

In [42]:
def correct_pred_num(true, pred):
    correct_pred = (true == pred.argmax(-1)).float().detach().cpu().numpy()
    return correct_pred.sum(), len(correct_pred)

In [53]:
val_print_every = 10

print('Initializing ...')
print("Validation ...")
print("Validation on ground truth scene-graph")

total_t0 = time.time()
total_loss = 0
total_score = 0
total_num_pred = 0
total_num_correct_pred = 0

vqa_seq2seq_model.eval() # IMPORTANT

prediction_file = os.path.join(SAVE_DIR, "gt_predictions_%s.json" % (SAVE_DIR.split('-')[-1]) )
prediction_dict = []

    
t0 = time.time()
for batch_idx, batch_data in enumerate(val_data_loader):
    if (batch_idx == n_val_batches-1):
            break
    
    batch_alloc = batch_data.batch.to(device)
    q_len = batch_data.q_length
    batch_size = q_len.shape[0]

    q = torch.tensor(batch_data.q).view(batch_size, -1).to(device)
    a = torch.transpose( torch.tensor(batch_data.y).view(batch_size, -1), 1, 0).to(device)
    v = torch.tensor(batch_data.v).to(device)
        
    edge_attr = batch_data.edge_attr.to(device)
    edge_index = batch_data.edge_index.to(device)
    x = batch_data.x.to(device)
    x_sem = batch_data.x_sem.to(device)


    output = vqa_seq2seq_model(x, x_sem, edge_index, edge_attr, batch_alloc, q, q_len, v, batch_size,
                     a, 0.0)
    output_dim = output.shape[-1]
    output_for_loss = output[1:].view(-1, output_dim).to(device)
    trg = a[1:].reshape(-1)
        
    output_for_loss_softmaxed = nn.functional.log_softmax(output_for_loss, -1)
    loss = instance_entropy_with_logits(output_for_loss_softmaxed, trg) # LOSS

    batch_loss = loss.item()
    batch_score = accuracy(output_for_loss, trg) # ACCURACY
    
    batch_num_correct_pred, batch_total_pred = correct_pred_num(trg, output_for_loss)
        
    total_loss += batch_loss
    total_score += batch_score
    total_num_correct_pred += batch_num_correct_pred
    total_num_pred += batch_total_pred

    prediction_labels = output_for_loss.argmax(-1)
    qst_ids_list = batch_data.qst_id
    for sample_idx in range(prediction_labels.shape[0]):
        prediction_dict.append( {"questionId": qst_ids_list[sample_idx], 
                                                    "prediction": ANS_TEXT.vocab.itos[prediction_labels[sample_idx]]} )

    with open(os.path.join(SAVE_DIR, 'gt-val-log-epoch.%s.txt' % (SAVE_DIR.split('-')[-1])), 'a') as f: # Log the loss and accuracy in an epoch.
            f.write(str(batch_idx+1) \
                    + '\t' + str(batch_loss) + '\t' + str(batch_score) \
                    + '\t' + str(total_num_correct_pred) + '\t' + str(total_num_pred) + '\n')

    if batch_idx % val_print_every == 0: # Print progress
        total_loss_avg = total_loss / val_print_every 
        total_score_avg = total_score / val_print_every
        perf = float(100 * total_num_correct_pred / total_num_pred)
        elapsed = format_time(time.time() - t0)
        print('| VAL SET | Step [{:04d}/{:04d}], Loss: {:.4f} , Score: {:.4f} | Perf: {:.4f} | Elapsed: {:}'
                .format(batch_idx, int(n_val_batches), total_loss_avg, total_score_avg, perf, elapsed))
        total_loss = 0
        total_score = 0
        
final_accuracy = float(100 * total_num_correct_pred / total_num_pred)
with open(os.path.join(SAVE_DIR, 'gt-val-log-epoch.%s.txt' % (SAVE_DIR.split('-')[-1])), 'a') as f:
    f.write(str(batch_idx+1) + '\t' + str(batch_loss) + '\t' + str(batch_score) \
            + '\t' + str(total_num_correct_pred) + '\t' + str(total_num_pred) + '\n')
    f.write("final accuracy : %s" % final_accuracy)

with open(prediction_file, 'w') as pred_write_file:
    json.dump(prediction_dict, pred_write_file, indent=4)
    
print("Final performance on ground truth scene-graph ...")
print("... final accuracy : ",  final_accuracy)
        
print("Total validation took {:} (h:mm:ss)".format(format_time(time.time() - total_t0)))

Initializing ...
Validation ...
| VAL SET | Step [0000/0983], Loss: 0.7047 , Score: 1.6406 | Perf: 16.4062 | Elapsed: 0:00:01
Final performance ...
... final accuracy :  16.40625
Total validation took 0:00:01 (h:mm:ss)


  q = torch.tensor(batch_data.q).view(batch_size, -1).to(device)
  a = torch.transpose( torch.tensor(batch_data.y).view(batch_size, -1), 1, 0).to(device)


# Validation on Inferred

In [48]:
INF_VAL_ROOT = "/local/home/rchan31/SIGIR/detailed-experiments/data/INF_VAL_DATA/"

VAL_INF_BBOX_FOLDER = os.path.join(INF_VAL_ROOT, "inferred_val_bboxes")
VAL_INF_SG_FILE = os.path.join(INF_VAL_ROOT, "inferred_val_gq_sg_new_sgg.json")

In [49]:
VAL_INF_SG_DATA = json.load( open(VAL_INF_SG_FILE, "r") )
print(len(VAL_INF_SG_DATA) )

10055


In [50]:
def val_inf_get_qas_for_imgs_present(img_id):
    if os.path.isfile( os.path.join(IMG_FTR_FOLDER, "{}.json".format(img_id)) ) and \
        os.path.isfile( os.path.join(VAL_INF_BBOX_FOLDER, "{}.json".format(img_id)) ):
        if (img_id in VAL_INF_SG_DATA) and (len(VAL_INF_SG_DATA[str(img_id)]) != 0):
            return True
    else: False

def val_inf_decouple_q_and_a(qas_data):    
    structured_qas = {"img_id": [], "questions": [], "answers": [], "question_ids": []}
    for idx, (qas_id, qas_ins) in enumerate(qas_data.items()):
        img_id = qas_ins['imageId']        
        
        if not (val_inf_get_qas_for_imgs_present(img_id)): continue
        
        q = qas_ins["question"]
        a = qas_ins["answer"]
        
        structured_qas['img_id'].append(img_id)
        structured_qas['questions'].append(q)
        structured_qas['answers'].append(a)
        structured_qas['question_ids'].append(qas_id)
    
    return structured_qas

val_inf_structured_qas = val_inf_decouple_q_and_a(val_qas_data)
print("Finished loading inferred valid_structured_qas ...")

print( len(val_inf_structured_qas['answers']) )

Finished loading inferred valid_structured_qas ...
122213


In [55]:
val_inf_vqa_dataset = VQA(val_inf_structured_qas,
                      QUE_TEXT,
                      ANS_TEXT,
                      data_mode='gt_inf')

print( val_inf_vqa_dataset.__len__() )
val_data = val_inf_vqa_dataset.__getitem__(1)

Step1 : Data loading
....... gt_inf mode number of data samples : 122213
....... VG Answer data have been PREPARED & ENCODED ...
....... VG Question data have been PREPARED & ENCODED ...
....... Data loading completed ...
122213


In [56]:
BATCH_SIZE = 128

val_inf_data_loader = torch_geometric.data.DataLoader(val_inf_vqa_dataset,
                                                      batch_size=BATCH_SIZE )
n_val_inf_batches = len(val_inf_data_loader)
print( n_val_inf_batches )

955


In [57]:
val_print_every = 10

print('Initializing ...')
print("Validation ...")
print("Validation on inferred scene-graph")

total_t0 = time.time()
total_loss = 0
total_score = 0
total_num_pred = 0
total_num_correct_pred = 0

vqa_seq2seq_model.eval() # IMPORTANT

prediction_file = os.path.join(SAVE_DIR, "inf_predictions_%s.json" % (SAVE_DIR.split('-')[-1]) )
prediction_dict = []

    
t0 = time.time()
for batch_idx, batch_data in enumerate(val_inf_data_loader):
    if (batch_idx == n_val_batches-1):
            break
    
    batch_alloc = batch_data.batch.to(device)
    q_len = batch_data.q_length
    batch_size = q_len.shape[0]

    q = torch.tensor(batch_data.q).view(batch_size, -1).to(device)
    a = torch.transpose( torch.tensor(batch_data.y).view(batch_size, -1), 1, 0).to(device)
    v = torch.tensor(batch_data.v).to(device)
        
    edge_attr = batch_data.edge_attr.to(device)
    edge_index = batch_data.edge_index.to(device)
    x = batch_data.x.to(device)
    x_sem = batch_data.x_sem.to(device)


    output = vqa_seq2seq_model(x, x_sem, edge_index, edge_attr, batch_alloc, q, q_len, v, batch_size,
                     a, 0.0)
    output_dim = output.shape[-1]
    output_for_loss = output[1:].view(-1, output_dim).to(device)
    trg = a[1:].reshape(-1)
        
    output_for_loss_softmaxed = nn.functional.log_softmax(output_for_loss, -1)
    loss = instance_entropy_with_logits(output_for_loss_softmaxed, trg) # LOSS

    batch_loss = loss.item()
    batch_score = accuracy(output_for_loss, trg) # ACCURACY
    
    batch_num_correct_pred, batch_total_pred = correct_pred_num(trg, output_for_loss)
        
    total_loss += batch_loss
    total_score += batch_score
    total_num_correct_pred += batch_num_correct_pred
    total_num_pred += batch_total_pred

    prediction_labels = output_for_loss.argmax(-1)
    qst_ids_list = batch_data.qst_id
    for sample_idx in range(prediction_labels.shape[0]):
        prediction_dict.append( {"questionId": qst_ids_list[sample_idx], 
                                                    "prediction": ANS_TEXT.vocab.itos[prediction_labels[sample_idx]]} )

    with open(os.path.join(SAVE_DIR, 'inf-val-log-epoch.%s.txt' % (SAVE_DIR.split('-')[-1])), 'a') as f: # Log the loss and accuracy in an epoch.
            f.write(str(batch_idx+1) \
                    + '\t' + str(batch_loss) + '\t' + str(batch_score) \
                    + '\t' + str(total_num_correct_pred) + '\t' + str(total_num_pred) + '\n')

    if batch_idx % val_print_every == 0: # Print progress
        total_loss_avg = total_loss / val_print_every 
        total_score_avg = total_score / val_print_every
        perf = float(100 * total_num_correct_pred / total_num_pred)
        elapsed = format_time(time.time() - t0)
        print('| VAL SET | Step [{:04d}/{:04d}], Loss: {:.4f} , Score: {:.4f} | Perf: {:.4f} | Elapsed: {:}'
                .format(batch_idx, int(n_val_batches), total_loss_avg, total_score_avg, perf, elapsed))
        total_loss = 0
        total_score = 0
        
final_accuracy = float(100 * total_num_correct_pred / total_num_pred)
with open(os.path.join(SAVE_DIR, 'inf-val-log-epoch.%s.txt' % (SAVE_DIR.split('-')[-1])), 'a') as f:
    f.write(str(batch_idx+1) + '\t' + str(batch_loss) + '\t' + str(batch_score) \
            + '\t' + str(total_num_correct_pred) + '\t' + str(total_num_pred) + '\n')
    f.write("final accuracy : %s" % final_accuracy)

with open(prediction_file, 'w') as pred_write_file:
    json.dump(prediction_dict, pred_write_file, indent=4)
    
print("Final performance on inferred scene-graph ...")
print("... final accuracy : ",  final_accuracy)
        
print("Total validation took {:} (h:mm:ss)".format(format_time(time.time() - total_t0)))

Initializing ...
Validation ...
| VAL SET | Step [0000/0983], Loss: 0.7045 , Score: 1.6406 | Perf: 16.4062 | Elapsed: 0:00:01
Final performance ...
... final accuracy :  16.40625
Total validation took 0:00:01 (h:mm:ss)


  q = torch.tensor(batch_data.q).view(batch_size, -1).to(device)
  a = torch.transpose( torch.tensor(batch_data.y).view(batch_size, -1), 1, 0).to(device)
