In [11]:
import os
import pandas as pd
import argparse
import shutil
import numpy as np
import copy
import random
import math
import time
import re
import json
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, TensorDataset, RandomSampler, SequentialSampler
from torchvision import transforms, utils
from skimage import io, transform
import torch.nn.functional as F
from PIL import Image
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate import bleu_score
from nltk.translate.bleu_score import corpus_bleu
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch
import torchvision.models as models
import torchvision
from transformers import BertTokenizer
from transformers import BertModel
import nltk
nltk.download('punkt')
from collections import defaultdict
import collections
import pickle
from tqdm import tqdm

import torchvision.transforms as transforms


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

ROOT_PATH = '/content/drive/MyDrive/VQA/Datasets/OVQA_publish/'

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
ANS_LABLE_DICT = {}
Q_TYPE_LABLE_DICT = {}
ANS_TYPE_LABLE_DICT = {}
IMG_ORGAN_LABLE_DICT = {}
with open(f"{ROOT_PATH}valset.json") as f_val:
    data_val = json.load(f_val)

with open(f"{ROOT_PATH}trainset.json") as f_train:
    data_train = json.load(f_train)

with open(f"{ROOT_PATH}testset.json") as f_test:
    data_test = json.load(f_test)

data = data_val + data_train + data_test
i = 0
j = 0
k = 0
l = 0
for elem in data:
  if elem["answer"] not in ANS_LABLE_DICT.keys():
    # remove special characters and make it all lower
    ANS_LABLE_DICT[elem["answer"]] = i
    i += 1

  if elem["question_type"] not in Q_TYPE_LABLE_DICT.keys():
    Q_TYPE_LABLE_DICT[elem["question_type"]] = j
    j += 1

  if elem["answer_type"] not in ANS_TYPE_LABLE_DICT.keys():
    ANS_TYPE_LABLE_DICT[elem["answer_type"]] = k
    k += 1

  if elem["image_organ"] not in IMG_ORGAN_LABLE_DICT.keys():
    IMG_ORGAN_LABLE_DICT[elem["image_organ"]] = l
    l += 1

print(len(ANS_LABLE_DICT))
print(len(Q_TYPE_LABLE_DICT))
print(len(ANS_TYPE_LABLE_DICT))
print(len(IMG_ORGAN_LABLE_DICT))

1067
6
2
5


In [4]:
ANS_TYPE_LABLE_DICT

{'OPEN': 0, 'CLOSED': 1}

In [5]:
class OVQADataset(Dataset):
    """OVQA images and questions dataset."""

    def __init__(self, json_file, root_dir, phase, transform=None):
        """
        Arguments:
            json_file (string): Path to the json file with questions.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        with open(json_file) as f:
          self.question_data = json.load(f)
        self.root_dir = root_dir
        self.transform = transform
        self.phase = phase
        self.img_feat_vqa = np.load(f"{ROOT_PATH}features.pkl", allow_pickle=True )

    def __len__(self):
        return len(self.question_data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_feat_vqa = self.img_feat_vqa
        img_name = os.path.join(self.root_dir,
                                self.question_data[idx]["image_name"])

        image_feat = torch.Tensor(img_feat_vqa[img_name.split('/')[-1]])

        question = self.question_data[idx]["question"]
        answer =  self.question_data[idx]["answer"]
        question_type = self.question_data[idx]["question_type"]
        answer_type =  self.question_data[idx]["answer_type"]
        image_organ = self.question_data[idx]["image_organ"]
        qid = self.question_data[idx]["qid"]

        #print(image.size())

        sample = {'image': image_feat,
                  'question': question,
                  'answer_label':F.one_hot(torch.tensor([[ANS_LABLE_DICT[answer]]]), len(ANS_LABLE_DICT)),
                  'question_type_label':F.one_hot(torch.tensor([[Q_TYPE_LABLE_DICT[question_type]]]), len(Q_TYPE_LABLE_DICT)),
                  'answer_text': answer,
                  'answer_type_label':F.one_hot(torch.tensor([[ANS_TYPE_LABLE_DICT[answer_type]]]), len(ANS_TYPE_LABLE_DICT)),
                  'qid': qid,
                  'image_organ_label': F.one_hot(torch.tensor([[IMG_ORGAN_LABLE_DICT[image_organ]]]), len(IMG_ORGAN_LABLE_DICT))}

        return sample

In [6]:
def get_loader(batch_size, num_workers,size=228):
    '''
        Load our dataset with dataloader for the train and valid data
    '''

    vqa_dataset = {
        'train': OVQADataset(json_file=f'{ROOT_PATH}trainset.json',
                                    root_dir=f'{ROOT_PATH}img/', phase='train', transform=transform),
        'valid': OVQADataset(json_file=f'{ROOT_PATH}valset.json',
                                    root_dir=f'{ROOT_PATH}/img/', phase='valid', transform=transform)}


    data_loader = {
        phase: torch.utils.data.DataLoader(
            dataset=vqa_dataset[phase],
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            )
        for phase in ['train','valid']}

    return data_loader

In [7]:
import argparse

def parse_opt():
    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument('--SEED', type=int, default=97)
    parser.add_argument('--BATCH_SIZE', type=int, default=64)
    parser.add_argument('--VAL_BATCH_SIZE', type=int, default=64)
    parser.add_argument('--NUM_OUTPUT_UNITS', type=int, default=len(ANS_LABLE_DICT))
    parser.add_argument('--MAX_QUESTION_LEN', type=int, default=17) # double check this
    parser.add_argument('--IMAGE_CHANNEL', type=int, default=1472)
    parser.add_argument('--INIT_LERARNING_RATE', type=float, default=1e-4)
    parser.add_argument('--LAMNDA', type=float, default=0.0001)
    parser.add_argument('--MFB_FACTOR_NUM', type=int, default=5)
    parser.add_argument('--MFB_OUT_DIM', type=int, default=1024)
    parser.add_argument('--BERT_UNIT_NUM', type=int, default=768)
    parser.add_argument('--BERT_DROPOUT_RATIO', type=float, default=0.3)
    parser.add_argument('--MFB_DROPOUT_RATIO', type=float, default=0.1)
    parser.add_argument('--NUM_IMG_GLIMPSE', type=int, default=2)
    parser.add_argument('--NUM_QUESTION_GLIMPSE', type=int, default=2)
    parser.add_argument('--IMG_FEAT_SIZE', type=int, default=1)
    parser.add_argument('--IMG_INPUT_SIZE', type=int, default=224)
    parser.add_argument('--NUM_EPOCHS', type=int, default=200)
    args = parser.parse_args(args=[])
    return args

opt = parse_opt()

In [8]:
answer_classes = ANS_LABLE_DICT


class BERTokenizer():

    def __init__(self,opt):
        # Load the BERT tokenizer
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.opt = opt
    #pre-process the text data
    def text_preprocessing(self, text):

        # Remove trailing whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        return text


    # Create a function to tokenize a set of texts
    def preprocessing_for_bert(self, data):
        """Perform required preprocessing steps for pretrained BERT.
        @param    data (np.array): Array of texts to be processed.
        @return   input_ids (torch.Tensor): Tensor of token ids to be fed to a model.
        @return   attention_masks (torch.Tensor): Tensor of indices specifying which
                    tokens should be attended to by the model.
        """
        # Create empty lists to store outputs
        input_ids = []
        attention_masks = []
        MAX_LEN = self.opt.MAX_QUESTION_LEN
        # For every sentence...
        for sent in data:

            encoded_sent = self.tokenizer.encode_plus(
                text=self.text_preprocessing(sent),  # Preprocess sentence
                add_special_tokens=True,        # Add `[CLS]` and `[SEP]`
                max_length=MAX_LEN,                  # Max length to truncate/pad
                pad_to_max_length=True,         # Pad sentence to max length
                #return_tensors='pt',           # Return PyTorch tensor
                truncation=True,
                return_attention_mask=True      # Return attention mask
                )

            # Add the outputs to the lists
            input_ids.append(encoded_sent.get('input_ids'))
            attention_masks.append(encoded_sent.get('attention_mask'))

        # Convert lists to tensors
        input_ids = torch.tensor(input_ids)
        attention_masks = torch.tensor(attention_masks)

        return input_ids, attention_masks


# Create the question encoder base on  BertClassfier
class BertQstEncoder(nn.Module):
    """Bert Model for Classification Tasks.
    """
    def __init__(self, opt,freeze_bert=True):
        """
        @param    bert: a BertModel object
        @param    classifier: a torch.nn.Module classifier
        @param    freeze_bert (bool): Set `False` to fine-tune the BERT model
        @param    opt: for configuration parameter
        """
        super(BertQstEncoder, self).__init__()
        self.opt = opt
        # Specify hidden size of BERT, hidden size of our classifier, and number of labels
        D_in= self.opt.BERT_UNIT_NUM

        # Instantiate BERT model
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.bert.eval()
        self.bert_emb = self.bert.embeddings
        self.bert_encode_layer1 = (self.bert.encoder.layer)[0]
        self.bert_encode_layer2 = (self.bert.encoder.layer)[1]
        self.bert_encode_layer3 = (self.bert.encoder.layer)[2]
        self.bert_encode_layer4 = (self.bert.encoder.layer)[3]
        self.bert_encode_layer5 = (self.bert.encoder.layer)[4]
        self.bert_encode_layer6 = (self.bert.encoder.layer)[5]
        self.bert_encode_layer7 = (self.bert.encoder.layer)[6]
        self.bert_encode_layer8 = (self.bert.encoder.layer)[7]
        self.bert_encode_layer9 = (self.bert.encoder.layer)[8]
        self.bert_encode_layer10 = (self.bert.encoder.layer)[9]
        self.bert_encode_layer11 = (self.bert.encoder.layer)[10]
        self.bert_encode_layer12 = (self.bert.encoder.layer)[11]


    def forward(self, input_ids, attention_mask):
        """
        Feed input to BERT and the classifier to compute logits.
        @param    input_ids (torch.Tensor): an input tensor with shape (batch_size,
                      max_length)
        @param    attention_mask (torch.Tensor): a tensor that hold attention mask
                      information with shape (batch_size, max_length)
        @return   logits (torch.Tensor): an output tensor with shape (batch_size,
                      num_labels)
        """


        # Feed input to BERT

        with torch.no_grad():
            emb_out = self.bert_emb(input_ids, attention_mask)
            layer1= self.bert_encode_layer1(emb_out)
            layer2= self.bert_encode_layer2(layer1[0])
            layer3= self.bert_encode_layer3(layer2[0])
            layer4= self.bert_encode_layer4(layer3[0])
            layer5= self.bert_encode_layer5(layer4[0])
            layer6= self.bert_encode_layer6(layer5[0])
            layer7= self.bert_encode_layer7(layer6[0])
            layer8= self.bert_encode_layer8(layer7[0])
            layer9= self.bert_encode_layer9(layer8[0])
            layer10= self.bert_encode_layer10(layer9[0])
            layer11= self.bert_encode_layer11(layer10[0])
            layer12= self.bert_encode_layer12(layer11[0])

        word_question_representation = (layer11[0] +layer12[0])/2

        return word_question_representation


#Extract the question feature with co-attention
class QuestionFeatureExtractionAtt(nn.Module):
    '''
        Extract the question with co-attention, get from https://github.com/asdf0982/vqa-mfb.pytorch
    '''

    def __init__(self,opt):

        super(QuestionFeatureExtractionAtt, self).__init__()

        self.opt = opt
        self.NUM_QUESTION_GLIMPSE = self.opt.NUM_QUESTION_GLIMPSE

        self.JOINT_EMB_SIZE = self.opt.MFB_FACTOR_NUM * self.opt.MFB_OUT_DIM
        self.Softmax = nn.Softmax(dim=-1)

        self.Linear1_q_proj = nn.Linear(self.opt.BERT_UNIT_NUM* self.opt.NUM_QUESTION_GLIMPSE, self.JOINT_EMB_SIZE)
        self.Linear2_q_proj = nn.Linear(self.opt.BERT_UNIT_NUM*self.opt.NUM_QUESTION_GLIMPSE, self.JOINT_EMB_SIZE)

        self.Dropout_M = nn.Dropout(p=self.opt.MFB_DROPOUT_RATIO)
        self.dropout = nn.Dropout(self.opt.BERT_DROPOUT_RATIO)
        self.Conv1_Qatt = nn.Conv2d(self.opt.BERT_UNIT_NUM, self.opt.IMAGE_CHANNEL, 1)
        self.Conv2_Qatt = nn.Conv2d(self.opt.IMAGE_CHANNEL, self.opt.NUM_QUESTION_GLIMPSE, 1)

    def forward(self,qst_encoding):

        '''
        Question Attention
        '''
        self.batch_size = qst_encoding.shape[0]
        qst_encoding = self.dropout(qst_encoding)
        qst_encoding_resh =  torch.unsqueeze(qst_encoding, 3)       # N=4 x 768 x T=14 x 1
        qatt_conv1 = self.Conv1_Qatt(qst_encoding_resh)                   # N x 512 x T x 1
        qatt_relu = F.relu(qatt_conv1)
        qatt_conv2 = self.Conv2_Qatt(qatt_relu)                          # N x 2 x T x 1
        qatt_conv2 = qatt_conv2.contiguous().view(self.batch_size*2,-1)
        qatt_softmax = self.Softmax(qatt_conv2)
        qatt_softmax = qatt_softmax.view(self.batch_size, 2, -1, 1)
        qatt_feature_list = []
        for i in range(self.NUM_QUESTION_GLIMPSE):
            t_qatt_mask = qatt_softmax.narrow(1, i, 1)              # N x 1 x T x 1
            t_qatt_mask = t_qatt_mask * qst_encoding_resh           # N x 768 x T x 1
            t_qatt_mask = torch.sum(t_qatt_mask, 2, keepdim=True)   # N x 768 x 1 x 1
            qatt_feature_list.append(t_qatt_mask)
        qatt_feature_concat = torch.cat(qatt_feature_list, 1)       # N x 1536 x 1 x 1

        return qatt_feature_concat


#Extract the image feature with MFB and co-attention
class ImageFeatureExtractionAtt(nn.Module):

    '''
        Extract the image with co-attention, get from https://github.com/asdf0982/vqa-mfb.pytorch
    '''

    def __init__(self,opt):
        super(ImageFeatureExtractionAtt, self).__init__()
        self.opt = opt
        self.MFB_FACTOR_NUM = self.opt.MFB_FACTOR_NUM
        self.MFB_OUT_DIM = self.opt.MFB_OUT_DIM
        self.NUM_IMG_GLIMPSE =self.opt.NUM_IMG_GLIMPSE
        self.IMG_FEAT_SIZE = self.opt.IMG_FEAT_SIZE

        self.JOINT_EMB_SIZE = self.opt.MFB_FACTOR_NUM * self.opt.MFB_OUT_DIM
        self.Softmax = nn.Softmax(dim=-1)

        self.Linear1_q_proj = nn.Linear(self.opt.BERT_UNIT_NUM* self.opt.NUM_QUESTION_GLIMPSE, self.JOINT_EMB_SIZE)
        self.Linear_i_proj = nn.Linear(self.opt.IMAGE_CHANNEL*self.opt.NUM_IMG_GLIMPSE, self.JOINT_EMB_SIZE)
        self.Conv_i_proj = nn.Conv2d(self.opt.IMAGE_CHANNEL, self.JOINT_EMB_SIZE, 1)


        self.Dropout_M = nn.Dropout(p=self.opt.MFB_DROPOUT_RATIO)

        self.Conv1_Iatt = nn.Conv2d(self.opt.MFB_OUT_DIM, self.opt.IMAGE_CHANNEL, 1) # (1000, 512, 1)
        self.Conv2_Iatt = nn.Conv2d(self.opt.IMAGE_CHANNEL, self.NUM_IMG_GLIMPSE, 1)

    def forward(self, img_feature, qstatt_feature):

        '''
        Image Attention with MFB
        '''
        self.batch_size = img_feature.shape[0]
        q_feat_resh = torch.squeeze(qstatt_feature)                              # N x 1536
        i_feat_resh = img_feature.unsqueeze(3)                                   # N x 512 x 196 x 1
        #print(i_feat_resh.shape)
        iatt_q_proj = self.Linear1_q_proj(q_feat_resh)                                  # N x 5000
        iatt_q_resh = iatt_q_proj.view(self.batch_size, self.JOINT_EMB_SIZE, 1, 1)      # N x 5000 x 1 x 1
        iatt_i_conv = self.Conv_i_proj(i_feat_resh)                                     # N x 5000 x 196 x 1
        iatt_iq_eltwise = iatt_q_resh * iatt_i_conv
        iatt_iq_droped = self.Dropout_M(iatt_iq_eltwise)                                # N x 5000 x 196 x 1
        iatt_iq_permute1 = iatt_iq_droped.permute(0,2,1,3).contiguous()                 # N x 196 x 5000 x 1
        iatt_iq_resh = iatt_iq_permute1.view(self.batch_size, self.IMG_FEAT_SIZE, self.MFB_OUT_DIM, self.MFB_FACTOR_NUM)
        iatt_iq_sumpool = torch.sum(iatt_iq_resh, 3, keepdim=True)                      # N x 196 x 1000 x 1
        iatt_iq_permute2 = iatt_iq_sumpool.permute(0,2,1,3)                             # N x 1000 x 196 x 1
        iatt_iq_sqrt = torch.sqrt(F.relu(iatt_iq_permute2)) - torch.sqrt(F.relu(-iatt_iq_permute2))
        iatt_iq_sqrt = iatt_iq_sqrt.reshape(self.batch_size, -1)                           # N x 196000
        iatt_iq_l2 = F.normalize(iatt_iq_sqrt)
        iatt_iq_l2 = iatt_iq_l2.view(self.batch_size, self.MFB_OUT_DIM, self.IMG_FEAT_SIZE, 1)  # N x 1000 x 196 x 1

        iatt_conv1 = self.Conv1_Iatt(iatt_iq_l2)                    # N x 512 x 196 x 1
        iatt_relu = F.relu(iatt_conv1)
        iatt_conv2 = self.Conv2_Iatt(iatt_relu)                     # N x 2 x 196 x 1
        iatt_conv2 = iatt_conv2.view(self.batch_size*self.NUM_IMG_GLIMPSE, -1)
        iatt_softmax = self.Softmax(iatt_conv2)
        iatt_softmax = iatt_softmax.view(self.batch_size, self.NUM_IMG_GLIMPSE, -1, 1)
        iatt_feature_list = []
        for i in range(self.NUM_IMG_GLIMPSE):
            t_iatt_mask = iatt_softmax.narrow(1, i, 1)              # N x 1 x 196 x 1
            t_iatt_mask = t_iatt_mask * i_feat_resh                 # N x 512 x 196 x 1
            t_iatt_mask = torch.sum(t_iatt_mask, 2, keepdim=True)   # N x 512 x 1 x 1
            iatt_feature_list.append(t_iatt_mask)
        iatt_feature_concat = torch.cat(iatt_feature_list, 1)       # N x 1024 x 1 x 1
        iatt_feature_concat = torch.squeeze(iatt_feature_concat)    # N x 1024
        return iatt_feature_concat



class VqaClassifierModel(nn.Module):
    '''
        Fusion with MFB,  get from https://github.com/asdf0982/vqa-mfb.pytorch
    '''

    def __init__(self, opt):
        super(VqaClassifierModel, self).__init__()
        self.opt = opt

        self.JOINT_EMB_SIZE = self.opt.MFB_FACTOR_NUM * self.opt.MFB_OUT_DIM

        self.MFB_OUT_DIM = self.opt.MFB_OUT_DIM
        self.MFB_FACTOR_NUM = self.opt.MFB_FACTOR_NUM
        NUM_OUTPUT_UNITS = self.opt.NUM_OUTPUT_UNITS


        self.tokenizer = BERTokenizer(self.opt)
        self.bert_model = BertQstEncoder(self.opt)


        self.qst_feature_att = QuestionFeatureExtractionAtt(self.opt)
        self.img_feature_att = ImageFeatureExtractionAtt(self.opt)

        self.Linear2_q_proj = nn.Linear(self.opt.BERT_UNIT_NUM*self.opt.NUM_QUESTION_GLIMPSE, self.JOINT_EMB_SIZE)
        self.Linear_i_proj = nn.Linear(self.opt.IMAGE_CHANNEL*self.opt.NUM_IMG_GLIMPSE, self.JOINT_EMB_SIZE)

        self.Dropout_M = nn.Dropout(p=self.opt.MFB_DROPOUT_RATIO)

        self.Linear_predict_1 = nn.Linear(self.opt.MFB_OUT_DIM + len(Q_TYPE_LABLE_DICT) + len(ANS_TYPE_LABLE_DICT) + len(IMG_ORGAN_LABLE_DICT), NUM_OUTPUT_UNITS)
        self.Linear_predict_2 = nn.Linear(self.opt.MFB_OUT_DIM, len(Q_TYPE_LABLE_DICT))
        self.Linear_predict_3 = nn.Linear(self.opt.MFB_OUT_DIM, len(ANS_TYPE_LABLE_DICT))
        self.Linear_predict_4 = nn.Linear(self.opt.MFB_OUT_DIM, len(IMG_ORGAN_LABLE_DICT))


    def forward(self, img, qst):

        self.batch_size = img.shape[0]
        image_feature = img
        input_ids, attention_mask = self.tokenizer.preprocessing_for_bert(qst)
        question_feature = self.bert_model(input_ids.to(device), attention_mask.to(device))
        question_feature = question_feature.transpose(1, 2)      # N=4 x 768 x T=14

        q_featatt = self.qst_feature_att(question_feature)      # N x 1536

        iatt_feature_concat = self.img_feature_att(image_feature,q_featatt)          # N x 1024

        '''
        Fine-grained Image-Question MFB fusion
        '''
        q_feat_resh = torch.squeeze(q_featatt)
        mfb_q_proj = self.Linear2_q_proj(q_feat_resh)               # N x 5000
        mfb_i_proj = self.Linear_i_proj(iatt_feature_concat)        # N x 5000
        mfb_iq_eltwise = torch.mul(mfb_q_proj, mfb_i_proj)          # N x 5000
        mfb_iq_drop = self.Dropout_M(mfb_iq_eltwise)
        mfb_iq_resh = mfb_iq_drop.view(self.batch_size, 1, self.MFB_OUT_DIM, self.MFB_FACTOR_NUM)   # N x 1 x 1000 x 5
        mfb_iq_sumpool = torch.sum(mfb_iq_resh, 3, keepdim=True)    # N x 1 x 1000 x 1
        mfb_out = torch.squeeze(mfb_iq_sumpool)                     # N x 1000
        mfb_sign_sqrt = torch.sqrt(F.relu(mfb_out)) - torch.sqrt(F.relu(-mfb_out))
        mfb_l2 = F.normalize(mfb_sign_sqrt)
        prediction_2 = self.Linear_predict_2(mfb_l2)
        prediction_3 = self.Linear_predict_3(mfb_l2)
        prediction_4 = self.Linear_predict_4(mfb_l2)
        #print(f"first two predictions {torch.cat((prediction_2, prediction_3, mfb_l2), dim=1).shape}")
        #print(f"pred + image {(prediction_1 + prediction_2 + image_feature).shape}")
        prediction_1 = self.Linear_predict_1(torch.cat((prediction_2, prediction_3, prediction_4, mfb_l2), dim=1))
        prediction = [F.log_softmax(prediction_1, -1), F.log_softmax(prediction_2, -1), F.log_softmax(prediction_3, -1), F.log_softmax(prediction_4, -1) ]# N x num_class
        #prediction = [F.log_softmax(prediction_1, -1), F.log_softmax(prediction_2, -1)]
        return prediction

In [12]:
seed_value = 97
np.random.seed(seed_value)
random.seed(seed_value)
torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

def train_model():
    since = time.time()

    best_acc1 = 0.0
    best_acc5 = 0.0

    best_acc_q = 0.0
    best_acc_a = 0.0

    #best_acc_val = 0.0
    best_epoch = 0
    list_train_loss_per_epoch = []
    list_valid_loss_per_epoch = []

    list_train_acc1_per_epoch = []
    list_valid_acc1_per_epoch = []

    model = VqaClassifierModel(opt=opt).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0001)


    saved_dir = "/content/drive/MyDrive"
    #saved_dir = "/content"


    num_epochs = 200
    image_size = 224
    num_workers = 0
    batch_size = 32


    # Create the DataLoader for our dataset

    data_loader = get_loader(batch_size = batch_size,
            num_workers = num_workers,
            size = image_size )

    #alpha1 = 0.4
    #alpha2 = 0.3
    #alpha3 = 0.3

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            #accuracy = 0
            top1_acc = 0
            top5_acc = 0
            acc_test_f = 0

            acc_q = 0
            acc_a = 0

            bleu = 0
            batch_step_size = len(data_loader[phase].dataset) / batch_size

            # Iterate over data.
            for batch_idx, batch_sample in enumerate(data_loader[phase]):

                #image = batch_sample['image'].to(device)
                image = batch_sample['image'].to(device)
                #print(image.shape)
                questions = batch_sample['question']#.to(device)
                labels_answer = batch_sample['answer_label'].to(device)
                labels_q_type = batch_sample['question_type_label'].to(device)
                labels_a_type = batch_sample['answer_type_label'].to(device)
                label_answer_text = batch_sample['answer_text']#.to(device)
                label_organ_type = batch_sample['image_organ_label'].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    output= model(image, questions)

                    #_, preds = torch.max(output, 1)
                    """
                    print(labels_q_type.shape)
                    print(labels_a_type.shape)

                    print(labels_answer.shape)
                    print(output[0].shape)
                    print(output[1].shape)
                    print(output[2].shape)
                    """
                    labels_answer = labels_answer.squeeze(1)
                    labels_answer = labels_answer.squeeze(1).float()


                    labels_q_type = labels_q_type.squeeze(1)
                    labels_q_type = labels_q_type.squeeze(1).float()

                    labels_a_type = labels_a_type.squeeze(1)
                    labels_a_type = labels_a_type.squeeze(1).float()

                    label_organ_type = label_organ_type.squeeze(1)
                    label_organ_type = label_organ_type.squeeze(1).float()

                    #print(labels_answer.shape)

                    loss_0 = criterion(output[0], labels_answer)
                    loss_1 = criterion(output[1], labels_q_type)
                    loss_2 = criterion(output[2], labels_a_type)
                    loss_3 = criterion(output[3], label_organ_type)
                    #print(loss_0, loss_1, loss_2)
                    #loss = loss_0 * alpha1 + loss_1 * alpha2 + loss_2 * alpha3
                    loss = loss_0 + loss_1 + loss_2 + loss_3

                    #loss = criterion(output, labels_answer)
                    #print(loss)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item()

                #exact match score
                acc1, acc5 = accuracy(output[0].data, labels_answer.data, topk=(1, 5))
                #print("first")
                #print(output[0].shape)
                #print(labels_answer.shape)
                #print(output[0])
                #print(labels_answer)
                top1_acc += acc1
                top5_acc += acc5

                acc_quest, _ = accuracy(output[1].data, labels_q_type.data, topk=(1, 5))
                #print("second")
                #print(output[1].shape)
                #print(labels_q_type.shape)
                #print(output[1])
                #print(labels_q_type)
                acc_q += acc_quest


                acc_ans, _ = accuracy(output[2].data, labels_a_type.data, topk=(1, 2))
                acc_a += acc_ans

                #bleu score
                #b = get_bleu_score(preds, label_answer_text)
                #bleu += b


                if batch_idx % 10 == 0:
                    pass
                    # print('| {} SET | Epoch [{:02d}/{:02d}], Step[{:04d}/{:04d}], Loss: {:.4f}, Top 1 Acc: {:.4f}, Top 5 Acc: {:.4f}, Quest Acc: {:.4f} Ans Acc: {:.47}'.format(phase.upper(), epoch+1, num_epochs, batch_idx, int(batch_step_size), loss.item(), acc1, acc5, acc_quest, acc_ans))#Acc: {:.4f},Bleu: {:.4f},acc, b
                    #print('| {} SET | Epoch [{:02d}/{:02d}], Step[{:04d}/{:04d}], Loss: {:.4f}, Top 1 Acc: {:.4f}, Top 5 Acc: {:.4f}'.format(phase.upper(), epoch+1, num_epochs, batch_idx, int(batch_step_size), loss.item(), acc1, acc5))#Acc: {:.4f},Bleu: {:.4f},acc, b


            epoch_loss = running_loss/batch_step_size
            epoch_acc1 = top1_acc/batch_step_size
            epoch_acc5 = top5_acc/batch_step_size
            epoch_acc_q = acc_q/batch_step_size
            epoch_acc_a = acc_a/batch_step_size

            #epoch_blue = bleu/batch_step_size

            #save the loss and accuracy for train and valid
            if phase =='train':

                list_train_loss_per_epoch.append(epoch_loss)
                list_train_acc1_per_epoch.append(epoch_acc1)

            else:

                list_valid_loss_per_epoch.append(epoch_loss)
                list_valid_acc1_per_epoch.append(epoch_acc1)
                """
                alpha1 += 0.02
                total_weight = alpha1 + alpha2 + alpha3
                alpha1 /= total_weight
                alpha2 /= total_weight
                alpha3 /= total_weight
                """

            print('{} Loss: {:.4f} Top 1 Acc: {:.4f} Top 5 Acc: {:.4f} Acc Quest: {:.4f} Acc Ans: {:.4f}'.format(
                phase, epoch_loss, epoch_acc1, epoch_acc5, epoch_acc_q, epoch_acc_a))


            # deep copy the model
            if phase == 'valid' and epoch_acc1 > best_acc1: #or epoch_acc5 > best_acc5 ):
                best_acc1 = epoch_acc1
                best_acc5 = epoch_acc5
                best_epoch = epoch
                best_model_wts = copy.deepcopy(model.state_dict())



    history_loss = {'train':list_train_loss_per_epoch, 'valid':list_valid_loss_per_epoch}
    history_acc1 = {'train':list_train_acc1_per_epoch, 'valid':list_valid_acc1_per_epoch}

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Top 1 Acc: {:4f}, Top 5 Acc: {:4f}'.format(best_acc1,best_acc5))


    # load best model weights
    model.load_state_dict(best_model_wts)
    model.load_state_dict(best_model_wts)
    state = {'epoch': best_epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
                'loss':epoch_loss,'valid_accuracy': best_acc1}

    #full_model_path =saved_dir+'/ovqa_2_tasks_mainandquest_model_state_seed_97.tar'
    full_model_path =saved_dir+'/ovqa_multitask_cascade_200epc_1024mfb_out_batch32_swin_model_state_seed_97.tar'

    torch.save(state, full_model_path)
    return model


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""

    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()



    if target.dim() == 2: # multians option
        _, target = torch.max(target, 1)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append((correct_k / batch_size))

    return res


def get_bleu_score(predicted, true_ans_text):
    path_output_change = config.path_output_chd
    with open(path_output_change+'/answer_classes.json', 'r') as j:
        answer_classes_dict = json.load(j)
    score = 0.0
    assert (len(predicted) == len(true_ans_text))
    ans_keys = list(answer_classes_dict.keys())
    ans_values = list(answer_classes_dict.values())


    for pred, true_ans in zip(predicted, true_ans_text):
        index_ans = ans_values.index(pred)

        score += sentence_bleu([true_ans.split(' ')], ans_keys[index_ans].split(' '), smoothing_function=bleu_score.SmoothingFunction().method2)

    return score/len(true_ans_text)


def load_checkpoint(model, optimizer, filename=None):
    # Note: Input model & optimizer should be pre-defined. This routine only updates their states.
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})" .format(filename,
                                                            checkpoint['epoch']))
    else: print("=> no checkpoint found at '{}'".format(filename))
    return model, optimizer, start_epoch



def make_plot(history, epoch_max, path_output_chd, type_plot='loss'):
    train = history['train']
    valid = history['valid']
    fig, ax = plt.subplots()
    epochs = range(epoch_max)


    if type_plot=='loss':
        plt.plot(epochs, train, '-r', lw=2, label='Training loss')
        plt.plot(epochs, valid, '-b',lw=2, label='validation loss')
        plt.legend(borderaxespad=0.)
        plt.title('Training and Validation loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.savefig(path_output_chd+'/imgs/loss.png')

    elif type_plot == 'acc1':

        plt.plot(epochs, train, '-r', lw = 2, label='Training Top 1 Accuracy')
        plt.plot(epochs, valid, '-b', lw = 2, label='validation Top 1 Accuracy')
        plt.legend(borderaxespad=0.)
        plt.title('Training and Validation Top 1 Accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Top 1 Accuracy')
        plt.savefig(path_output_chd+'/imgs/acc1.png')

    elif type_plot == 'acc5':

        plt.plot(epochs, train, '-r', lw = 2, label='Training Top 5 Accuracy')
        plt.plot(epochs, valid, '-b', lw = 2, label='validation Top 5 Accuracy')
        plt.legend(borderaxespad=0.)
        plt.title('Training and Validation Top 5 Accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Top 5 Accuracy')
        plt.savefig(path_output_chd+'/imgs/acc5.png')
    else:
        plt.plot(epochs, train, '-r', lw = 2, label='Training blue')
        plt.plot(epochs, valid, '-b', lw = 2, label='validation blue')
        plt.legend(borderaxespad=0.)
        plt.title('Training and Validation blue')
        plt.xlabel('Epochs')
        plt.ylabel('Blue')
        plt.savefig(path_output_chd+'/imgs/blue.png')



    plt.show()


def main():
    train_model()



if __name__ == '__main__':
    main()

Epoch 0/199
----------
train Loss: 8.7577 Top 1 Acc: 0.3623 Top 5 Acc: 0.7089 Acc Quest: 0.8504 Acc Ans: 0.9106
valid Loss: 7.5285 Top 1 Acc: 0.3646 Top 5 Acc: 0.6128 Acc Quest: 0.9300 Acc Ans: 0.9446
Epoch 1/199
----------
train Loss: 5.7724 Top 1 Acc: 0.4462 Top 5 Acc: 0.7245 Acc Quest: 0.9372 Acc Ans: 0.9320
valid Loss: 5.5872 Top 1 Acc: 0.3544 Top 5 Acc: 0.6378 Acc Quest: 0.9401 Acc Ans: 0.9449
Epoch 2/199
----------
train Loss: 4.2677 Top 1 Acc: 0.4573 Top 5 Acc: 0.7522 Acc Quest: 0.9487 Acc Ans: 0.9905
valid Loss: 4.6429 Top 1 Acc: 0.3705 Top 5 Acc: 0.6582 Acc Quest: 0.9579 Acc Ans: 1.0095
Epoch 3/199
----------
train Loss: 3.4598 Top 1 Acc: 0.4902 Top 5 Acc: 0.7704 Acc Quest: 0.9639 Acc Ans: 1.0004
valid Loss: 4.0538 Top 1 Acc: 0.4229 Top 5 Acc: 0.6670 Acc Quest: 0.9795 Acc Ans: 1.0095
Epoch 4/199
----------
train Loss: 2.9281 Top 1 Acc: 0.5724 Top 5 Acc: 0.7798 Acc Quest: 0.9798 Acc Ans: 1.0009
valid Loss: 3.6345 Top 1 Acc: 0.4789 Top 5 Acc: 0.6751 Acc Quest: 0.9926 Acc Ans: 1.

In [None]:
from google.colab import files
files.download('/content/model_state_seed_97.tar')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [13]:
# Testing

import os
import argparse
import shutil
import numpy as np
import pandas as pd
import json
import copy
import random
import math
import time
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([transforms.Pad((0, 85), fill=0, padding_mode='constant'),
                                transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
                                ])

def get_test_loader(batch_size, num_workers):
    test_vqa_dataset = OVQADataset(
            json_file=f"{ROOT_PATH}testset.json",
            root_dir=f"{ROOT_PATH}img",
            phase = 'test',
            transform=transform)


    data_loader = torch.utils.data.DataLoader(dataset=test_vqa_dataset,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                num_workers=num_workers)
    return data_loader


def inference(model, test_loader, answer_classes_dict, path_change):
    since = time.time()
    model.eval()
    results = []
    print('Inferencing ...')
    # Iterate over data.
    for batch_idx, batch_sample in enumerate(test_loader):
        image = batch_sample['image'].to(device)
        qid = batch_sample['qid']
        questions = batch_sample['question']

        output = model(image, questions)
        preds = torch.argmax(output[0], dim=-1)
        preds = preds.cpu().detach().numpy()

        assert (len(preds) == len(qid))

        ans_keys = list(answer_classes_dict.keys())
        ans_values = list(answer_classes_dict.values())


        for pred, image_name in zip(preds, qid):
            index_ans = ans_values.index(pred)
            results.append({image_name+'|'+ans_keys[index_ans]})

    df = pd.DataFrame(results)

    df.columns =['qid-answer']
    df.to_csv(path_change+'/submission.csv', index=False)

    time_elapsed = time.time() - since
    print('Evaluation complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))


def main():

    batch_size = opt.BATCH_SIZE
    num_workers = 0
    image_size = opt.IMG_INPUT_SIZE

    # Create the DataLoader for our dataset
    test_data_loader = get_test_loader(
        batch_size = batch_size,
        num_workers = num_workers)

    model = VqaClassifierModel( opt=opt ).to(device)
    saved_dir = "/content"
    filename =saved_dir+'/drive/MyDrive/ovqa_multitask_cascade_200epc_1024mfb_out_batch32_swin_model_state_seed_97.tar'
    print("=> loading checkpoint '{}'".format(filename))
    checkpoint = torch.load(filename)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])

    print("=> loaded checkpoint '{}' (epoch {})" .format(filename, checkpoint['epoch']))
    inference(model=model, test_loader=test_data_loader, answer_classes_dict=answer_classes, path_change="/content")

if __name__ == '__main__':
    main()

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

=> loading checkpoint '/content/drive/MyDrive/ovqa_multitask_cascade_200epc_1024mfb_out_batch32_swin_model_state_seed_97.tar'
=> loaded checkpoint '/content/drive/MyDrive/ovqa_multitask_cascade_200epc_1024mfb_out_batch32_swin_model_state_seed_97.tar' (epoch 193)
Inferencing ...
Evaluation complete in 0m 4s


#0.5478443743427971 multitask v1
#0.555205047318612 (80 epochs)fara multitask 0.5636172450052577 cu 150 epochs
#0.5368033648790747 multitask weighted
#0.5494216614090431 multitask 2 taskuri
#0.5399579390115667 with changing weights
#0.5509989484752892 cu input de la taskurile de question date la taskul principal concatenate cu outputul de la fusion
#0.5720294426919033 ca mai sus da 150 epoci
#0.5736067297581493 200 epoci

In [14]:
test_answers_dict = {}
for elem in data_test:
  test_answers_dict[elem["qid"]] = elem["answer"]

In [15]:
def get_test_accuracy(test_answers, predictions_csv_file):
  df = pd.read_csv(predictions_csv_file)
  acc = 0
  for index, row in df.iterrows():
    qid, ans = row["qid-answer"].split("|")[0], row["qid-answer"].split("|")[1]
    if test_answers[qid] == ans:
      acc += 1
  print(acc/len(df))

get_test_accuracy(test_answers_dict, "/content/submission.csv")

0.5988433228180863


0.5215562565720294 first resnet 1024
0.5141 second resnet 20148


0.5804 vit 1024 32
0.5778 vit 1024 64

0.5930 swin 32 (trb reantrenat overwritten)
0.5830 swin 64