# Utils

In [None]:
!pip install pytorch_lightning
!pip install transformers
!pip install pretrainedmodels

In [None]:
import os
import numpy as np
import pandas as pd
import random
import math
import cv2

import torch
from torchvision import transforms, models
from torch.cuda.amp import GradScaler
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
from nltk.translate.bleu_score import sentence_bleu
from tqdm import tqdm
from PIL import Image
from random import choice
import matplotlib.pyplot as plt

import pretrainedmodels

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED``, and puts cuDNN into deterministic mode.

    Args:
        seed: integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one. This call is silently
    # ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly different kernels
    # from run to run, which undoes the determinism requested above.
    torch.backends.cudnn.benchmark = False

# Get data

In [None]:
def make_df(file_path):
    """Load every pipe-separated annotation file in a directory into one frame.

    Each file is read with the columns ``img_id|question|answer``. Two extra
    columns are derived from the file name split on ``'_'``: ``category`` is
    the second field and ``mode`` is the third field with its last four
    characters (the extension) removed.

    Args:
        file_path: directory containing the annotation files.

    Returns:
        The concatenation of all per-file frames (original per-file indices
        are kept, as before).
    """
    frames = []

    # os.listdir returns entries in arbitrary order; sorting makes the row
    # order of the result reproducible across machines and runs.
    for fname in sorted(os.listdir(file_path)):
        df = pd.read_csv(os.path.join(file_path, fname), sep='|', names=['img_id', 'question', 'answer'])
        name_parts = fname.split('_')
        df['category'] = name_parts[1]
        df['mode'] = name_parts[2][:-4]
        frames.append(df)

    return pd.concat(frames)

In [None]:
def _prepare_clef_split(df, base_dir, image_folder, frac):
    """Normalise one ImageCLEF split and return a random sample of it.

    Turns ``img_id`` into a full ``.jpg`` path under ``base_dir/image_folder``,
    lower-cases ``category`` and ``answer``, then samples ``frac`` of the rows.
    The input frame is modified in place before sampling.
    """
    df['img_id'] = df['img_id'].apply(lambda x: os.path.join(base_dir, image_folder, x + '.jpg'))
    df['category'] = df['category'].str.lower()
    df['answer'] = df['answer'].str.lower()
    return df.sample(frac = frac)


def load_all_data(args, remove = None):
    """Load the ImageCLEF 2018/2019/2020 and VQA-RAD annotation tables.

    All locations are hard-coded relative paths under ``../input``.

    Args:
        args: object providing ``train_pct``, ``valid_pct`` and ``test_pct``,
            the fractions passed to ``DataFrame.sample`` for each split.
        remove: optional collection of 2019 training ``img_id`` values to drop
            before paths are expanded.

    Returns:
        ``(union_train, union_valid, testdf2019)``. The two unions are
        concatenated with a fresh index; ``testdf2019`` keeps the shuffled
        index produced by ``sample``.
    """
    #### 2019 ####
    clef2019train_path = '../input/clef2019/clef2019/ImageClef-2019-VQA-Med-Training'
    clef2019valid_path = '../input/clef2019/clef2019/ImageClef-2019-VQA-Med-Validation'
    clef2019test_path = '../input/clef2019/clef2019/ImageClef-2019-VQA-Med-Test'

    traindf2019 = pd.read_csv(os.path.join(clef2019train_path, 'traindf.csv'))
    print("traindf2019: ", len(traindf2019))
    valdf2019 = pd.read_csv(os.path.join(clef2019valid_path, 'valdf.csv'))
    print("valdf2019: ", len(valdf2019))
    testdf2019 = pd.read_csv(os.path.join(clef2019test_path, 'testdf.csv'))
    print("testdf2019: ", len(testdf2019))

    if remove is not None:
        traindf2019 = traindf2019[~traindf2019['img_id'].isin(remove)].reset_index(drop=True)

    # The order of the sample() calls below (2019 train/val/test, then 2020,
    # then 2018) is kept as-is: sampling without an explicit random_state
    # draws from the global NumPy RNG, so reordering would change which rows
    # are selected under a fixed seed.
    traindf2019 = _prepare_clef_split(traindf2019, clef2019train_path, 'Train_images', args.train_pct)
    valdf2019 = _prepare_clef_split(valdf2019, clef2019valid_path, 'Val_images', args.valid_pct)
    testdf2019 = _prepare_clef_split(testdf2019, clef2019test_path, 'Test_images', args.test_pct)

    #### 2020 ####
    clef2020train_path = '../input/clef2020/clef2020/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet'
    clef2020valid_path = '../input/clef2020/clef2020/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet'

    traindf2020 = pd.read_csv(os.path.join(clef2020train_path, 'clef2020_train_category.csv'))
    valdf2020 = pd.read_csv(os.path.join(clef2020valid_path, 'clef2020_valid_category.csv'))

    traindf2020 = _prepare_clef_split(traindf2020, clef2020train_path, 'VQAnswering_2020_Train_images', args.train_pct)
    valdf2020 = _prepare_clef_split(valdf2020, clef2020valid_path, 'VQAnswering_2020_Val_images', args.valid_pct)

    #### 2018 ####
    clef2018train_path = '../input/clef2018/clef2018/VQAMed2018Train/VQAMed2018Train'
    clef2018valid_path = '../input/clef2018/clef2018/VQAMed2018Valid/VQAMed2018Valid'

    traindf2018 = pd.read_csv(os.path.join(clef2018train_path, 'clef2018_train_category.csv'))
    valdf2018 = pd.read_csv(os.path.join(clef2018valid_path, 'clef2018_valid_category.csv'))

    traindf2018 = _prepare_clef_split(traindf2018, clef2018train_path, 'VQAMed2018Train-images', args.train_pct)
    valdf2018 = _prepare_clef_split(valdf2018, clef2018valid_path, 'VQAMed2018Valid-images', args.valid_pct)

    #### VQARAD ####
    vqaradtrain_path = '../input/vqarad'
    traindf_vqarad = pd.read_csv(os.path.join(vqaradtrain_path, 'vqa_rad.csv'))
    # NOTE(review): unlike the CLEF splits, VQA-RAD is neither lower-cased nor
    # subsampled by train_pct. Kept exactly as before; confirm it is intended.
    traindf_vqarad['img_id'] = traindf_vqarad['img_id'].apply(lambda x: os.path.join(vqaradtrain_path, 'VQA_RAD Image Folder', x + '.jpg'))

    #### union ####
    union_train = pd.concat([traindf2018, traindf2019, traindf2020, traindf_vqarad], ignore_index=True)
    union_valid = pd.concat([valdf2018, valdf2019, valdf2020], ignore_index=True)

    return union_train, union_valid, testdf2019

## Utility methods

In [None]:
def gelu(x):
    """Exact (erf-based) GELU activation: ``0.5 * x * (1 + erf(x / sqrt(2)))``."""
    half_x = x * 0.5
    gaussian_term = 1.0 + torch.erf(x / math.sqrt(2.0))
    return half_x * gaussian_term

In [None]:
def encode_text(caption, tokenizer, args):
    """Build fixed-length model inputs for one question.

    Layout: ``[CLS] 0 0 0 0 0 [SEP] <question ids> [SEP]`` followed by zero
    padding up to ``args.max_position_embeddings``. The five zero ids after
    ``[CLS]`` are placeholder positions.

    Args:
        caption: question text.
        tokenizer: object exposing ``encode``, ``cls_token_id`` and
            ``sep_token_id``.
        args: object exposing ``max_position_embeddings``.

    Returns:
        ``(tokens, segment_ids, input_mask)`` as Python lists. Segment ids are
        0 for the prefix, 1 for the question part and 0 for padding; the mask
        is 1 for real tokens and 0 for padding.
    """
    max_len = args.max_position_embeddings
    placeholder_ids = [0] * 5
    # encode() returns ids wrapped in [CLS] ... [SEP]; drop both, then keep at
    # most max_len - 8 ids so the 8 fixed positions always fit.
    question_ids = tokenizer.encode(caption)[1:-1][:max_len - 8]

    tokens = [tokenizer.cls_token_id, *placeholder_ids, tokenizer.sep_token_id,
              *question_ids, tokenizer.sep_token_id]
    segment_ids = [0] * (len(placeholder_ids) + 2) + [1] * (len(question_ids) + 1)
    input_mask = [1] * len(tokens)

    padding = [0] * (max_len - len(tokens))
    return tokens + padding, segment_ids + padding, input_mask + padding

In [None]:
def onehot(size, target):
    """Return a float32 vector of length ``size`` with 1.0 at ``target``."""
    encoded = torch.zeros(size, dtype=torch.float32)
    encoded[target] = 1.0
    return encoded

In [None]:
class LabelSmoothing(nn.Module):
    """Cross-entropy with label smoothing while training, plain CE otherwise.

    Args:
        smoothing: weight given to the uniform term; ``1 - smoothing`` goes
            to the target term.
    """

    def __init__(self, smoothing = 0.1):
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        """Compute the mean loss.

        Args:
            x: logits with classes on the last dimension.
            target: one-hot targets shaped like ``x``, or class indices with
                one dimension fewer than ``x``.

        Returns:
            Scalar loss tensor.
        """
        if self.training:
            x = x.float()
            if target.dim() == x.dim() - 1:
                # Class indices were supplied; expand so the element-wise
                # product below is well defined.
                target = torch.nn.functional.one_hot(target.long(), num_classes=x.size(-1))
            target = target.float()
            logprobs = torch.nn.functional.log_softmax(x, dim = -1)

            nll_loss = -logprobs * target
            nll_loss = nll_loss.sum(-1)

            smooth_loss = -logprobs.mean(dim=-1)

            loss = self.confidence * nll_loss + self.smoothing * smooth_loss

            return loss.mean()

        if target.dim() == x.dim() and not target.is_floating_point():
            # Integer one-hot targets are rejected by cross_entropy, which
            # treats integer tensors as class indices; reduce them first.
            # Floating-point targets are passed through unchanged.
            target = target.argmax(dim=-1)
        return torch.nn.functional.cross_entropy(x, target)

# Classes

In [None]:
class VQAMed(Dataset):
    """Dataset yielding (image, token ids, segment ids, mask, answer, path).

    Args:
        df: frame with ``img_id`` (image path), ``question`` and ``answer``
            columns.
        imgsize: stored as ``self.size``; not used inside this class.
        tfm: optional callable applied to the loaded image.
        args: object read for ``smoothing`` and ``num_classes`` here, and
            passed on to ``encode_text``.
        mode: stored as ``self.mode``.
    """

    def __init__(self, df, imgsize, tfm, args, mode = 'train'):
        self.df = df
        self.tfm = tfm
        self.size = imgsize
        self.args = args
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.mode = mode

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        # A Dataset is indexed by position (0..len-1). Use positional lookups
        # so frames whose index was shuffled or subsampled (e.g. after
        # DataFrame.sample without reset_index) are read in row order and
        # never fail on a missing label.
        path = self.df['img_id'].iloc[idx]
        question = self.df['question'].iloc[idx]
        answer = self.df['answer'].iloc[idx]

        if self.args.smoothing:
            answer = onehot(self.args.num_classes, answer)

        # cv2.imread signals failure by returning None instead of raising.
        # NOTE(review): OpenCV loads images in BGR channel order; it is left
        # unconverted here to preserve existing behaviour.
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError('could not read image: %s' % path)

        if self.tfm:
            img = self.tfm(img)

        tokens, segment_ids, input_mask = encode_text(question, self.tokenizer, self.args)

        return (
            img,
            torch.tensor(tokens, dtype = torch.long),
            torch.tensor(segment_ids, dtype = torch.long),
            torch.tensor(input_mask, dtype = torch.long),
            # as_tensor: the answer may already be a tensor (one-hot case);
            # avoids the copy-construct warning of torch.tensor(tensor).
            torch.as_tensor(answer, dtype = torch.long),
            path,
        )

In [None]:
class Model_Keyword(nn.Module):
    """Image classifier conditioned on a keyword id.

    An SE-ResNeXt50 backbone (its final linear layer replaced by identity)
    produces image features; they are concatenated with a learned embedding
    of the keyword (3 possible ids) and fed to one linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = pretrainedmodels.__dict__['se_resnext50_32x4d'](num_classes=1000, pretrained='imagenet')
        feature_dim = backbone.last_linear.in_features
        # Strip the ImageNet head so the backbone returns pooled features.
        backbone.last_linear = nn.Identity()

        self.model = backbone
        self.embed = nn.Embedding(3, feature_dim)
        self.last_layer = nn.Linear(2 * feature_dim, num_classes)

    def forward(self, img, keyword):
        """Return class logits for a batch of images and keyword ids."""
        fused = torch.cat([self.model(img), self.embed(keyword)], -1)
        return self.last_layer(fused)

In [None]:
def calculate_bleu_score(preds,targets, idx2ans):
    """Mean unigram BLEU between predicted and target answers.

    Args:
        preds: iterable of predicted answer indices.
        targets: iterable of ground-truth answer indices.
        idx2ans: mapping from answer index to answer string.

    Returns:
        Mean of the per-pair ``sentence_bleu`` scores (``weights=[1]``).
    """
    per_answer = []
    for pred, target in zip(preds, targets):
        reference = idx2ans[target].split()
        hypothesis = idx2ans[pred].split()
        per_answer.append(sentence_bleu([reference], hypothesis, weights = [1]))
    return np.mean(np.asarray(per_answer))

In [None]:
class Embeddings(nn.Module):
    """Factorised word + position + segment embeddings with LayerNorm and dropout.

    Word ids are embedded in 128 dimensions and projected to
    ``args.hidden_size`` by a bias-free linear layer.

    Args:
        args: object exposing ``vocab_size``, ``hidden_size``,
            ``max_position_embeddings`` and ``hidden_dropout_prob``.
    """

    def __init__(self, args):
        super(Embeddings, self).__init__()
        self.word_embeddings = nn.Embedding(args.vocab_size, 128, padding_idx=0)
        self.word_embeddings_2 = nn.Linear(128, args.hidden_size, bias=False)
        self.position_embeddings = nn.Embedding(args.max_position_embeddings, args.hidden_size)
        self.type_embeddings = nn.Embedding(3, args.hidden_size)
        self.LayerNorm = nn.LayerNorm(args.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(args.hidden_dropout_prob)
        self.len = args.max_position_embeddings

    def forward(self, input_ids, segment_ids, position_ids=None):
        """Embed a batch of token ids.

        Args:
            input_ids: integer tensor; dimension 1 is used as the sequence
                length.
            segment_ids: integer tensor shaped like ``input_ids``.
            position_ids: optional position indices; defaults to
                ``0..seq_len-1`` for every row.

        Returns:
            The normalised, dropped-out sum of the three embeddings.
        """
        if position_ids is None:
            # Create the ids on the input's own device (the model may run on
            # CPU even when CUDA is present) and size them to the actual
            # sequence, so inputs shorter than max_position_embeddings work.
            seq_len = input_ids.size(1)
            position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        words_embeddings = self.word_embeddings(input_ids)
        words_embeddings = self.word_embeddings_2(words_embeddings)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.type_embeddings(segment_ids)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

In [None]:
class Transfer(nn.Module):
    """ResNet-152 feature extractor producing 1, 3 or 5 visual tokens.

    Each token comes from a different depth of the backbone: the feature map
    is projected to ``args.hidden_size`` channels by a bias-free 1x1
    convolution, passed through ReLU and global-average pooled.

    Args:
        args: object exposing ``num_vis`` (5, 3, or anything else for a
            single token) and ``hidden_size``.
    """

    # (attribute suffix, input channels, index of the last backbone child
    # whose output feeds this head). The indices match the original slices of
    # list(model.children()): [:-2] -> 7, [:-3] -> 6, [:-4] -> 5, [:-5] -> 4,
    # [:-7] -> 2.
    _HEADS = (
        ('2', 2048, 7),
        ('3', 1024, 6),
        ('4', 512, 5),
        ('5', 256, 4),
        ('7', 64, 2),
    )

    def __init__(self,args):
        super(Transfer, self).__init__()

        self.args = args
        self.num_vis = args.num_vis
        self.model = models.resnet152(pretrained=True)

        self.relu = nn.ReLU()
        # Attribute names (conv2/gap2, conv3/gap3, ...) and creation order are
        # unchanged so existing state_dict checkpoints still load.
        for suffix, in_channels, _ in self._active_heads():
            setattr(self, 'conv' + suffix,
                    nn.Conv2d(in_channels, args.hidden_size, kernel_size=(1, 1), stride=(1, 1), bias=False))
            setattr(self, 'gap' + suffix, nn.AdaptiveAvgPool2d((1,1)))

    def _active_heads(self):
        """Return the head specifications selected by ``num_vis``."""
        if self.num_vis == 5:
            return self._HEADS
        if self.num_vis == 3:
            return self._HEADS[:3]
        return self._HEADS[:1]

    def forward(self, img):
        """Extract the visual tokens for a batch of images.

        Returns:
            A tuple ``(v_2[, v_3, v_4[, v_5, v_7]], maps)``. Each ``v_*`` is
            reshaped to ``(-1, hidden_size)``; ``maps`` is a list holding the
            channel-mean of each projected feature map, in the same order.
        """
        heads = self._active_heads()
        tap_at = {child_idx: suffix for suffix, _, child_idx in heads}
        deepest = max(tap_at)

        # One pass through the backbone, recording the activations each head
        # needs. Previously the shared prefix was rebuilt and re-run once per
        # head. Outputs are the same; in training mode the backbone's
        # BatchNorm running statistics are now updated once per step rather
        # than once per head.
        features = {}
        x = img
        for child_idx, layer in enumerate(self.model.children()):
            x = layer(x)
            if child_idx in tap_at:
                features[tap_at[child_idx]] = x
            if child_idx == deepest:
                # Skip the average pool and classifier that follow.
                break

        tokens = []
        maps = []
        for suffix, _, _ in heads:
            inter = getattr(self, 'conv' + suffix)(features[suffix])
            pooled = getattr(self, 'gap' + suffix)(self.relu(inter))
            tokens.append(pooled.view(-1, self.args.hidden_size))
            maps.append(inter.mean(1))

        return (*tokens, maps)

In [None]:
class MultiHeadedSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        args: object exposing ``hidden_size``, ``hidden_dropout_prob`` and
            ``heads``.
    """

    def __init__(self, args):
        super().__init__()
        width = args.hidden_size
        self.proj_q = nn.Linear(width, width)
        self.proj_k = nn.Linear(width, width)
        self.proj_v = nn.Linear(width, width)
        self.drop = nn.Dropout(args.hidden_dropout_prob)
        self.scores = None
        self.n_heads = args.heads

    def forward(self, x, mask):
        """Attend over ``x``.

        Args:
            x: input tensor whose last dimension is the hidden size.
            mask: optional tensor indexed as ``mask[:, None, None, :]``;
                1 keeps a key position, 0 suppresses it.

        Returns:
            ``(h, scores)``: the merged per-head outputs and the attention
            probabilities after dropout (also stored on ``self.scores``).
        """
        head_shape = (self.n_heads, -1)
        projected = [self.proj_q(x), self.proj_k(x), self.proj_v(x)]
        q, k, v = [self.split_last(t, head_shape).transpose(1, 2) for t in projected]

        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
        if mask is not None:
            keep = mask[:, None, None, :].float()
            # Large negative bias on masked keys so softmax gives them ~0.
            scores -= 10000.0 * (1.0 - keep)
        scores = self.drop(F.softmax(scores, dim=-1))

        context = (scores @ v).transpose(1, 2).contiguous()
        merged = self.merge_last(context, 2)
        self.scores = scores
        return merged, scores

    def split_last(self, x, shape):
        """Reshape the last dimension of ``x`` into ``shape`` (one ``-1`` allowed)."""
        dims = list(shape)
        assert dims.count(-1) <= 1
        if -1 in dims:
            # prod(dims) is negative because of the -1 entry; negating it
            # yields the product of the known dimensions.
            dims[dims.index(-1)] = int(x.size(-1) / -np.prod(dims))
        return x.view(*x.size()[:-1], *dims)

    def merge_last(self, x, n_dims):
        """Flatten the last ``n_dims`` dimensions of ``x`` into one."""
        size = x.size()
        assert n_dims > 1 and n_dims < len(size)
        return x.view(*size[:-n_dims], -1)

In [None]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward block: hidden -> 4*hidden -> hidden with GELU.

    Args:
        args: object exposing ``hidden_size``.
    """

    def __init__(self, args):
        super().__init__()
        inner_size = args.hidden_size * 4
        self.fc1 = nn.Linear(args.hidden_size, inner_size)
        self.fc2 = nn.Linear(inner_size, args.hidden_size)

    def forward(self, x):
        """Apply ``fc2(gelu(fc1(x)))``."""
        expanded = self.fc1(x)
        return self.fc2(gelu(expanded))

In [None]:
class BertLayer(nn.Module):
    """Transformer encoder block with optional cross-layer parameter sharing.

    One instance serves every layer of the stack; ``forward`` takes the layer
    number and picks the matching sub-modules where they are not shared.

    Args:
        args: object exposing ``hidden_size``, ``hidden_dropout_prob`` and
            ``n_layers`` (plus whatever the attention/FFN modules read).
        share: ``'all'`` (attention and FFN shared), ``'att'`` (attention
            shared), ``'ffn'`` (FFN shared) or ``'none'`` (one of each per
            layer).
        norm: ``'pre'`` or ``'post'`` layer normalisation.

    Raises:
        ValueError: if ``share`` or ``norm`` is not one of the values above.
    """

    def __init__(self,args, share='all', norm='pre'):
        super(BertLayer, self).__init__()
        if share not in ('ffn', 'att', 'all', 'none'):
            raise ValueError("share must be one of 'ffn', 'att', 'all', 'none'; got %r" % (share,))
        if norm not in ('pre', 'post'):
            raise ValueError("norm must be 'pre' or 'post'; got %r" % (norm,))

        self.share = share
        self.norm_pos = norm
        self.norm1 = nn.LayerNorm(args.hidden_size, eps=1e-12)
        self.norm2 = nn.LayerNorm(args.hidden_size, eps=1e-12)
        self.drop1 = nn.Dropout(args.hidden_dropout_prob)
        self.drop2 = nn.Dropout(args.hidden_dropout_prob)

        # Creation order (attention, proj, feedforward) is kept so parameter
        # initialisation under a fixed seed is unchanged.
        if share in ('att', 'all'):
            self.attention = MultiHeadedSelfAttention(args)
            self.proj = nn.Linear(args.hidden_size, args.hidden_size)
        else:
            self.attention = nn.ModuleList([MultiHeadedSelfAttention(args) for _ in range(args.n_layers)])
            self.proj = nn.ModuleList([nn.Linear(args.hidden_size, args.hidden_size) for _ in range(args.n_layers)])

        if share in ('ffn', 'all'):
            self.feedforward = PositionWiseFeedForward(args)
        else:
            self.feedforward = nn.ModuleList([PositionWiseFeedForward(args) for _ in range(args.n_layers)])

    def _attend(self, x, attention_mask, layer_num):
        """Run attention and its output projection for one layer.

        Returns ``(projected_output, attention_scores)``.
        """
        if isinstance(self.attention, nn.ModuleList):
            attention, proj = self.attention[layer_num], self.proj[layer_num]
        else:
            attention, proj = self.attention, self.proj
        # The attention module returns (output, scores). Previously the
        # shared-attention and post-norm paths passed the whole tuple to the
        # projection and never bound attn_scores.
        attn_output, attn_scores = attention(x, attention_mask)
        return proj(attn_output), attn_scores

    def _feed_forward(self, x, layer_num):
        """Run the (shared or per-layer) feed-forward block."""
        if isinstance(self.feedforward, nn.ModuleList):
            return self.feedforward[layer_num](x)
        return self.feedforward(x)

    def forward(self, hidden_states, attention_mask, layer_num):
        """Apply the block for layer ``layer_num``.

        Returns:
            ``(out, attn_scores)``: the new hidden states and the attention
            scores returned by the attention module.
        """
        if self.norm_pos == 'pre':
            h, attn_scores = self._attend(self.norm1(hidden_states), attention_mask, layer_num)
            out = hidden_states + self.drop1(h)
            # NOTE(review): the pre-norm path normalises the FFN input with
            # norm1 again and leaves norm2 unused. Kept as-is because
            # switching to norm2 would change the outputs of already-trained
            # checkpoints; confirm before altering.
            h = self._feed_forward(self.norm1(out), layer_num)
            out = out + self.drop2(h)
        else:
            h, attn_scores = self._attend(hidden_states, attention_mask, layer_num)
            out = self.norm1(hidden_states + self.drop1(h))
            h = self._feed_forward(out, layer_num)
            out = self.norm2(out + self.drop2(h))
        return out, attn_scores

In [None]:
class Transformer(nn.Module):
    """Multimodal encoder: BERT embeddings with visual tokens spliced in.

    Question tokens are embedded with the pretrained ``bert-base-uncased``
    embedding module; the vectors at positions 1..num_vis are then overwritten
    with the visual tokens from ``Transfer`` before the sequence runs through
    ``args.n_layers`` applications of ``BertLayer``.

    Args:
        args: object exposing ``num_vis`` and ``n_layers`` (and whatever
            ``Transfer`` / ``BertLayer`` read).
    """

    def __init__(self, args):
        super(Transformer,self).__init__()
        base_model = BertModel.from_pretrained('bert-base-uncased')
        # Only the embedding module (the model's first child) is kept; the
        # rest of the BERT model is dropped when base_model goes out of scope.
        self.bert_embedding = base_model.embeddings
        self.num_vis = args.num_vis
        self.trans = Transfer(args)
        self.blocks = BertLayer(args,share='none', norm='pre')
        self.n_layers = args.n_layers

    def forward(self, img, input_ids, token_type_ids, mask):
        """Encode a batch of image/question pairs.

        Returns:
            ``(hidden_states, attn_scores, intermediate)``: the per-layer
            hidden states and attention scores, each stacked on a new leading
            dimension, and the list of feature maps returned by ``Transfer``.
        """
        # Transfer returns (v_2[, v_3, ...], intermediate) for every num_vis
        # setting, so one unpacking covers the 1-, 3- and 5-token cases.
        *visual_tokens, intermediate = self.trans(img)

        h = self.bert_embedding(input_ids=input_ids, token_type_ids=token_type_ids, position_ids=None)

        # Overwrite sequence positions 1..num_vis with the visual tokens. One
        # slice assignment per token replaces the former per-sample Python
        # loops and writes the same values.
        for position, token in enumerate(visual_tokens, start=1):
            h[:, position] = token

        hidden_states = []
        all_attn_scores = []
        for i in range(self.n_layers):
            h, attn_scores = self.blocks(h, mask, i)
            hidden_states.append(h)
            all_attn_scores.append(attn_scores)

        return torch.stack(hidden_states, 0), torch.stack(all_attn_scores, 0), intermediate

In [None]:
class Model(nn.Module):
    """Full network: multimodal ``Transformer`` followed by a classifier head.

    Args:
        args: object exposing ``hidden_size`` and ``vocab_size`` (the head's
            output width), and passed on to ``Transformer``.
    """

    def __init__(self, args):
        super().__init__()
        hidden = args.hidden_size
        self.args = args
        self.transformer = Transformer(args)
        self.fc1 = nn.Linear(hidden, hidden)
        self.activ1 = nn.Tanh()
        head_layers = [
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden, eps=1e-12, elementwise_affine=True),
            nn.Linear(hidden, args.vocab_size),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, img, input_ids, segment_ids, input_mask):
        """Return ``(logits, attn_scores, intermediate)`` for a batch.

        The last layer's hidden states are averaged over dimension 1, passed
        through ``fc1`` + tanh and then through the classifier.
        """
        layer_states, attn_scores, intermediate = self.transformer(img, input_ids, segment_ids, input_mask)
        sequence_mean = layer_states[-1].mean(1)
        pooled = self.activ1(self.fc1(sequence_mean))
        return self.classifier(pooled), attn_scores, intermediate

In [None]:
def train_one_epoch(loader, model, optimizer, criterion, device, scaler, args, idx2ans):
    """Train ``model`` for one pass over ``loader``.

    Args:
        loader: iterable of ``(img, question_token, segment_ids,
            attention_mask, target, imgid)`` batches.
        model: network returning ``(logits, _, _)``.
        optimizer: optimizer over the model's parameters.
        criterion: loss callable ``criterion(logits, target)``.
        device: device the batch tensors are moved to.
        scaler: gradient scaler, used only when ``args.mixed_precision`` is
            set.
        args: object exposing ``mixed_precision``, ``clip`` and ``smoothing``.
        idx2ans: mapping from answer index to answer string (for BLEU).

    Returns:
        ``(mean_loss, PREDS, acc, bleu, IMGIDS)``: mean batch loss, predicted
        indices as a NumPy array, accuracy in percent, the BLEU score and the
        flattened list of image ids.
    """
    model.train()
    train_loss = []
    IMGIDS = []
    PREDS = []
    TARGETS = []
    bar = tqdm(loader, leave = False)
    for (img, question_token, segment_ids, attention_mask, target, imgid) in bar:

        img = img.to(device)
        question_token = question_token.to(device).squeeze(1)
        segment_ids = segment_ids.to(device)
        attention_mask = attention_mask.to(device).squeeze(1)
        target = target.to(device)

        optimizer.zero_grad()

        if args.mixed_precision:
            with torch.cuda.amp.autocast():
                logits, _, _ = model(img, question_token, segment_ids, attention_mask)
                loss = criterion(logits, target)

            # Backpropagate the *scaled* loss. Previously the result of
            # scaler.scale(loss) was discarded and the raw loss was
            # backpropagated, so scaler.step() divided real gradients by the
            # loss scale.
            scaler.scale(loss).backward()

            if args.clip:
                # Gradients must be unscaled before clipping so the threshold
                # applies to their true magnitude.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            scaler.step(optimizer)
            scaler.update()
        else:
            logits, _, _ = model(img, question_token, segment_ids, attention_mask)
            loss = criterion(logits, target)
            loss.backward()

            if args.clip:
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()

        # With smoothing the targets are one-hot; reduce them to indices.
        if args.smoothing:
            TARGETS.append(target.argmax(1))
        else:
            TARGETS.append(target)

        pred = logits.softmax(1).argmax(1).detach()
        PREDS.append(pred)
        IMGIDS.append(imgid)

        loss_np = loss.detach().cpu().numpy()
        train_loss.append(loss_np)
        bar.set_description('train_loss: %.5f' % (loss_np))

    PREDS = torch.cat(PREDS).cpu().numpy()
    TARGETS = torch.cat(TARGETS).cpu().numpy()
    IMGIDS = [i for sub in IMGIDS for i in sub]

    acc = (PREDS == TARGETS).mean() * 100.
    bleu = calculate_bleu_score(PREDS,TARGETS,idx2ans)

    return np.mean(train_loss), PREDS, acc, bleu, IMGIDS

# Eval methods


In [None]:
def validate(loader, model, criterion, device, scaler, args, val_df, idx2ans):
    """Evaluate ``model`` on the validation loader.

    Args:
        loader: iterable of ``(img, question_token, segment_ids,
            attention_mask, target, _)`` batches.
        model: network returning ``(logits, _, _)``.
        criterion: loss callable ``criterion(logits, target)``.
        device: device the batch tensors are moved to.
        scaler: unused; kept for signature compatibility.
        args: object exposing ``mixed_precision``, ``smoothing`` and
            ``category``.
        val_df: frame whose ``category`` column is aligned row-for-row with
            the order in which ``loader`` yields samples.
        idx2ans: mapping from answer index to answer string (for BLEU).

    Returns:
        ``(val_loss, PREDS, acc, bleu)``. When ``args.category`` is truthy,
        ``acc`` and ``bleu`` are single overall numbers; otherwise each is a
        dict with overall and per-category values keyed ``val_<name>_acc`` /
        ``val_<name>_bleu``.
    """
    model.eval()
    val_loss = []

    PREDS = []
    TARGETS = []
    bar = tqdm(loader, leave=False)

    with torch.no_grad():
        for (img, question_token, segment_ids, attention_mask, target, _) in bar:

            img = img.to(device)
            question_token = question_token.to(device).squeeze(1)
            segment_ids = segment_ids.to(device)
            attention_mask = attention_mask.to(device).squeeze(1)
            target = target.to(device)

            if args.mixed_precision:
                with torch.cuda.amp.autocast():
                    logits, _, _ = model(img, question_token, segment_ids, attention_mask)
                    loss = criterion(logits, target)
            else:
                logits, _, _ = model(img, question_token, segment_ids, attention_mask)
                loss = criterion(logits, target)

            loss_np = loss.detach().cpu().numpy()

            PREDS.append(logits.softmax(1).argmax(1).detach())

            # With smoothing the targets are one-hot; reduce them to indices.
            if args.smoothing:
                TARGETS.append(target.argmax(1))
            else:
                TARGETS.append(target)

            val_loss.append(loss_np)

            bar.set_description('val_loss: %.5f' % (loss_np))

        val_loss = np.mean(val_loss)

    PREDS = torch.cat(PREDS).cpu().numpy()
    TARGETS = torch.cat(TARGETS).cpu().numpy()

    if args.category:
        # A single category was selected: report overall numbers only.
        acc = (PREDS == TARGETS).mean() * 100.
        bleu = calculate_bleu_score(PREDS,TARGETS,idx2ans)
    else:
        # Overall plus per-category metrics. The category column is read once
        # and used as a positional mask over the predictions.
        correct = PREDS == TARGETS
        categories = val_df['category'].to_numpy()

        acc = {'val_total_acc': np.round(correct.mean() * 100., 4)}
        bleu = {'val_total_bleu': np.round(calculate_bleu_score(PREDS,TARGETS,idx2ans), 4)}

        # (key fragment, category value); 'abnormality' is reported as 'abnorm'.
        for key, category in (('plane', 'plane'), ('organ', 'organ'),
                              ('modality', 'modality'), ('abnorm', 'abnormality')):
            mask = categories == category
            # An absent category gives an empty selection whose mean is NaN,
            # as before.
            acc['val_%s_acc' % key] = np.round(correct[mask].mean() * 100., 4)
            bleu['val_%s_bleu' % key] = np.round(calculate_bleu_score(PREDS[mask], TARGETS[mask], idx2ans), 4)

    return val_loss, PREDS, acc, bleu

In [None]:
def test(loader, model, criterion, device, scaler, args, val_df,idx2ans):

    model.eval()
    TARGETS = []
    PREDS = []
    test_loss = []

    with torch.no_grad():
        for (img,question_token,segment_ids,attention_mask,target, _) in tqdm(loader, leave=False):

            img, question_token, segment_ids, attention_mask, target = img.to(device), question_token.to(device), segment_ids.to(device), attention_mask.to(device), target.to(device)
            question_token = question_token.squeeze(1)
            attention_mask = attention_mask.squeeze(1)
            
            if args.mixed_precision:
                with torch.cuda.amp.autocast(): 
                    logits, _, _ = model(img, question_token, segment_ids, attention_mask)
                    loss = criterion(logits, target)
            else:
                logits, _, _ = model(img, question_token, segment_ids, attention_mask)
                loss = criterion(logits, target)


            loss_np = loss.detach().cpu().numpy()

            test_loss.append(loss_np)

            pred = logits.softmax(1).argmax(1).detach()
            
            PREDS.append(pred)

            if args.smoothing:
                TARGETS.append(target.argmax(1))
            else:
                TARGETS.append(target)

        test_loss = np.mean(test_loss)

    PREDS = torch.cat(PREDS).cpu().numpy()
    TARGETS = torch.cat(TARGETS).cpu().numpy()

    if args.category:
        acc = (PREDS == TARGETS).mean() * 100.
        bleu = calculate_bleu_score(PREDS,TARGETS,idx2ans)
    else:
        total_acc = (PREDS == TARGETS).mean() * 100.
        plane_acc = (PREDS[val_df['category']=='plane'] == TARGETS[val_df['category']=='plane']).mean() * 100.
        organ_acc = (PREDS[val_df['category']=='organ'] == TARGETS[val_df['category']=='organ']).mean() * 100.
        modality_acc = (PREDS[val_df['category']=='modality'] == TARGETS[val_df['category']=='modality']).mean() * 100.
        abnorm_acc = (PREDS[val_df['category']=='abnormality'] == TARGETS[val_df['category']=='abnormality']).mean() * 100.
        """plane_acc = (PREDS['plane' in val_df['category']] == TARGETS['plane' in val_df['category']]).mean() * 100.
        organ_acc = (PREDS['organ' in val_df['category']] == TARGETS['organ' in val_df['category']]).mean() * 100.
        modality_acc = (PREDS['modality' in val_df['category']] == TARGETS['modality' in val_df['category']]).mean() * 100.
        abnorm_acc = (PREDS['abnormality' in val_df['category']] == TARGETS['abnormality' in val_df['category']]).mean() * 100."""

        acc = {'total_acc': np.round(total_acc, 4), 'plane_acc': np.round(plane_acc, 4), 'organ_acc': np.round(organ_acc, 4), 
               'modality_acc': np.round(modality_acc, 4), 'abnorm_acc': np.round(abnorm_acc, 4)}

        # add bleu score code
        total_bleu = calculate_bleu_score(PREDS,TARGETS,idx2ans)
        plane_bleu = calculate_bleu_score(PREDS[val_df['category']=='plane'],TARGETS[val_df['category']=='plane'],idx2ans)
        organ_bleu = calculate_bleu_score(PREDS[val_df['category']=='organ'],TARGETS[val_df['category']=='organ'],idx2ans)
        modality_bleu = calculate_bleu_score(PREDS[val_df['category']=='modality'],TARGETS[val_df['category']=='modality'],idx2ans)
        abnorm_bleu = calculate_bleu_score(PREDS[val_df['category']=='abnormality'],TARGETS[val_df['category']=='abnormality'],idx2ans)
        """plane_bleu = calculate_bleu_score(PREDS['plane' in val_df['category']],TARGETS['plane' in val_df['category']],idx2ans)
        organ_bleu = calculate_bleu_score(PREDS['organ' in val_df['category']],TARGETS['organ' in val_df['category']],idx2ans)
        modality_bleu = calculate_bleu_score(PREDS['modality' in val_df['category']],TARGETS['modality' in val_df['category']],idx2ans)
        abnorm_bleu = calculate_bleu_score(PREDS['abnormality' in val_df['category']],TARGETS['abnormality' in val_df['category']],idx2ans)"""


        bleu = {'total_bleu': np.round(total_bleu, 4), 'plane_bleu': np.round(plane_bleu, 4), 'organ_bleu': np.round(organ_bleu, 4), 
            'modality_bleu': np.round(modality_bleu, 4), 'abnorm_bleu': np.round(abnorm_bleu, 4)}


    return test_loss, PREDS, acc, bleu

# Train

In [None]:
import argparse
import sys
#from utils import seed_everything, Model, VQAMed, train_one_epoch, validate, test, load_data, LabelSmoothing, train_img_only, val_img_only, test_img_only
import wandb
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torchvision import transforms, models
from torch.cuda.amp import GradScaler
import os
import warnings
import albumentations as A
import pretrainedmodels
from albumentations.core.composition import OneOf
#from albumentations.pytorch.transforms import ToTensorV2

warnings.simplefilter("ignore", UserWarning)

In [None]:
# Run configuration.
#
# All hyper-parameters are declared through argparse so the same code can be driven
# from a script; in this notebook every option simply takes its default.

#os.mkdir('/kaggle/working/roco-weights')

# argparse reads sys.argv[1:]; overwriting sys.argv with a single placeholder entry
# means parse_args() sees no command-line options at all.
sys.argv = ['-f']

parser = argparse.ArgumentParser(description="Finetune on ImageClef 2019")

# --- run identity / checkpoint location -------------------------------------
parser.add_argument('--run_name', type=str, default="MMBERT_allclef_vqarad_pre",
                    help="run name for wandb")
# Disabled options kept for reference:
#parser.add_argument('--clef_data_dir', type = str, required = False, default = "../input/allclef-reply", help = "path for clef data")
#parser.add_argument('--vqarad_data_dir', type = str, required = False, default = "../input/vqarad-reply", help = "path for vqarad data")
parser.add_argument('--model_dir', type=str,
                    default="../input/mmbert-allclef-vqarad-pre-weights/MMBERT_allclef_vqarad_pre_bestacc.pt",
                    help="path to load weights")
#parser.add_argument('--model_dir', type = str, required = False, default = "../input/mmbert-pretrain-roco-weights/rocopretrain_weights.pt", help = "path to load weights")
#parser.add_argument('--save_dir', type = str, required = False, default = "/content/drive/MyDrive/Colab Notebooks/Thesis/Transformer VQA/MMBERT weights/MMBERTallClef_noPre", help = "path to save weights")
parser.add_argument('--category', type=str, default=None,
                    help="choose specific category if you want")

# --- boolean switches (all off unless the flag is passed) --------------------
parser.add_argument('--use_pretrained', action='store_true', default=False,
                    help="use pretrained weights or not")
parser.add_argument('--mixed_precision', action='store_true', default=False,
                    help="use mixed precision or not")
parser.add_argument('--clip', action='store_true', default=False,
                    help="clip the gradients or not")

# --- reproducibility, loader and data-subset settings ------------------------
parser.add_argument('--seed', type=int, default=42,
                    help="set seed for reproducibility")
parser.add_argument('--num_workers', type=int, default=4,
                    help="number of workers")
parser.add_argument('--epochs', type=int, default=100,
                    help="num epochs to train")
parser.add_argument('--train_pct', type=float, default=1.0,
                    help="fraction of train samples to select")
parser.add_argument('--valid_pct', type=float, default=1.0,
                    help="fraction of validation samples to select")
parser.add_argument('--test_pct', type=float, default=1.0,
                    help="fraction of test samples to select")

# --- optimisation -------------------------------------------------------------
parser.add_argument('--max_position_embeddings', type=int, default=28,
                    help="max length of sequence")
parser.add_argument('--batch_size', type=int, default=10,
                    help="batch size")
parser.add_argument('--lr', type=float, default=1e-4,
                    help="learning rate'")
# parser.add_argument('--weight_decay', type = float, required = False, default = 1e-2, help = " weight decay for gradients")
# "rlp" in the two help strings below refers to ReduceLROnPlateau, the scheduler
# these values are passed to later in the notebook.
parser.add_argument('--factor', type=float, default=0.1,
                    help="factor for rlp")
parser.add_argument('--patience', type=int, default=10,
                    help="patience for rlp")
# parser.add_argument('--lr_min', type = float, required = False, default = 1e-6, help = "minimum lr for Cosine Annealing")
parser.add_argument('--hidden_dropout_prob', type=float, default=0.3,
                    help="hidden dropout probability")
parser.add_argument('--smoothing', type=float, default=None,
                    help="label smoothing")

# --- model architecture -------------------------------------------------------
parser.add_argument('--image_size', type=int, default=224,
                    help="image size")
parser.add_argument('--hidden_size', type=int, default=768,
                    help="hidden size")  # og 312
parser.add_argument('--vocab_size', type=int, default=30522,
                    help="vocab size")
parser.add_argument('--type_vocab_size', type=int, default=2,
                    help="type vocab size")
parser.add_argument('--heads', type=int, default=12,
                    help="heads")
parser.add_argument('--n_layers', type=int, default=4,
                    help="num of layers")
# num of conv2d Layers in the transformer, can be: 5, 3 or 1
parser.add_argument('--num_vis', type=int, default=5,
                    help="num of visual embeddings")

args = parser.parse_args()

# Experiment tracking is currently disabled:
# wandb.init(project='medvqa', name = args.run_name, config = args)

# Seed the RNGs before any data or model object is created.
seed_everything(args.seed)

In [None]:
# Lift pandas' display limits so frames are shown with every column, no
# line-width wrapping limit and untruncated cell contents.
pd.options.display.max_columns = None
pd.options.display.width = None
pd.options.display.max_colwidth = None

In [None]:
# Load the three splits and report their sizes.
train_df, val_df, test_df = load_all_data(args)
print('len(train_df): ' ,len(train_df))
print('len(val_df): ' ,len(val_df))
print('len(test_df): ' ,len(test_df))

# Keep ONLY the rows whose category contains 'abnormality'.
# NOTE(review): the original comment here read "remove abn", but the mask is not
# negated, so abnormality questions are retained and everything else is dropped.
# Confirm which of the two is intended.
train_df, val_df, test_df = [
    split.loc[split['category'].str.contains('abnormality')]
    for split in (train_df, val_df, test_df)
]

print('len(train_df): ' ,len(train_df))
print('len(val_df): ' ,len(val_df))
print('len(test_df): ' ,len(test_df))

if args.category:
    # Restrict every split to the requested category ...
    train_df, val_df, test_df = [
        split[split['category'] == args.category].reset_index(drop=True)
        for split in (train_df, val_df, test_df)
    ]
    # ... and, only in that case, also drop the yes/no answers.
    train_df, val_df, test_df = [
        split[~split['answer'].isin(['yes', 'no'])].reset_index(drop=True)
        for split in (train_df, val_df, test_df)
    ]

# Build one answer vocabulary over all three splits so that class indices are
# shared between train, validation and test.
df = pd.concat([train_df, val_df, test_df]).reset_index(drop=True)

ans2idx = {answer: index for index, answer in enumerate(df['answer'].unique())}
idx2ans = {index: answer for answer, index in ans2idx.items()}
df['answer_mapped'] = df['answer'].map(ans2idx).astype(int)

# Re-split the combined frame using its 'mode' column.
train_df, val_df, test_df = [
    df[df['mode'] == mode_name].reset_index(drop=True)
    for mode_name in ('train', 'val', 'test')
]

# The number of distinct answers is the classifier's output size.
num_classes = len(ans2idx)
args.num_classes = num_classes
print(num_classes)

df.head(50)


In [None]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
# `device` stays a plain string.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the network from the parsed configuration.
model = Model(args)

In [None]:
# Swap the classifier's layer at index 2 for a fresh linear layer sized to the
# answer vocabulary built above.
# NOTE(review): this happens before the checkpoint is loaded, so loading expects
# the checkpoint's classifier to have exactly `num_classes` outputs — confirm.
model.classifier[2] = nn.Linear(args.hidden_size, num_classes)

if args.use_pretrained:
    print("loading weights")
    # map_location remaps the stored tensors onto the device selected earlier, so a
    # checkpoint saved from a GPU run can still be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(args.model_dir, map_location=device))
    print("loaded weights")

# Last expression of the cell: moves the model to `device` and displays it.
model.to(device)

In [None]:
# Optimiser over all model parameters, with an LR schedule that reduces the learning
# rate by `args.factor` after `args.patience` epochs without improvement of the
# metric passed to scheduler.step().
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=args.factor,
    patience=args.patience,
    verbose=True,
)

# Label-smoothed loss when a (non-zero) smoothing value is configured,
# plain cross-entropy otherwise.
criterion = LabelSmoothing(smoothing=args.smoothing) if args.smoothing else nn.CrossEntropyLoss()

# Gradient scaler handed to the train/validate/test helpers.
scaler = GradScaler()


# Training-time augmentation: random crop resized to 224, small rotation and
# colour jitter, followed by tensor conversion and normalisation.
# NOTE(review): the crop `scale` upper bound (1.25) is above 1.0 — confirm intended.
train_tfm = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomResizedCrop(224, scale=(0.75, 1.25), ratio=(0.75, 1.25)),
    transforms.RandomRotation(10),
    # Cutout(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Validation and test use the same deterministic pipeline: resize to 224x224,
# tensor conversion and normalisation.
val_tfm = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

test_tfm = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [None]:
# One dataset per split, each paired with its own transform pipeline.
traindataset = VQAMed(
    train_df,
    imgsize=args.image_size,
    tfm=train_tfm,
    args=args,
)
valdataset = VQAMed(
    val_df,
    imgsize=args.image_size,
    tfm=val_tfm,
    args=args,
)
testdataset = VQAMed(
    test_df,
    imgsize=args.image_size,
    tfm=test_tfm,
    args=args,
)

# Only the training loader shuffles; validation and test keep dataset order.
trainloader = DataLoader(
    traindataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=True,
)
valloader = DataLoader(
    valdataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=False,
)
testloader = DataLoader(
    testdataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=False,
)

In [None]:
from datetime import datetime
now = datetime.now()

# Copy the log shipped with the input dataset into the working directory so that
# this run's entries are appended after it. Both files are managed by `with`, so
# they are closed even if the copy fails part-way.
with open("../input/mmbert-allclef-vqarad-pre-weights/MMBERT_allclef_vqarad_pre.txt") as input_txt:
    with open("MMBERT_allclef_vqarad_pre.txt", "w") as f:
        f.writelines(input_txt)

# Append a timestamped header for this training session to the run's log file.
# NOTE(review): the copy above targets a fixed file name while this append uses
# args.run_name; the two only coincide for the default run name.
with open(f'{args.run_name}.txt', "a") as f:
    #f.write("datasets used: clef2018, clef2019, clef2020, vqa-rad")
    #f.write("pretrain on ROCO")
    f.write("\n\n\nMMBERT training " + str(now))

In [None]:
# Number of consecutive epochs without a validation-loss improvement
# (drives early stopping in the training loop).
counter = 0

# Thresholds a new epoch must beat before a "best" checkpoint is written.
# They are hard-coded numbers rather than the fresh-run values
# (best_acc = 0, best_loss = np.inf).
# NOTE(review): these apply even when no pretrained checkpoint is loaded — confirm.
best_acc, best_loss = 52.9058, 2.9077919

In [None]:
# Main training loop: one train / validate / test pass per epoch, LR scheduling on
# the validation loss, checkpointing, logging and early stopping.
for epoch in range(args.epochs):

    print(f'Epoch {epoch+1}/{args.epochs}')

    train_loss, _, train_acc, _, _ = train_one_epoch(trainloader, model, optimizer, criterion, device, scaler, args, idx2ans)
    val_loss, val_predictions, val_acc, val_bleu = validate(valloader, model, criterion, device, scaler, args, val_df,idx2ans)
    test_loss, test_predictions, test_acc, test_bleu = test(testloader, model, criterion, device, scaler, args, test_df,idx2ans)

    # ReduceLROnPlateau is driven by the validation loss.
    scheduler.step(val_loss)

    print("val_loss: " ,val_loss)
    print("val_acc: " ,val_acc)
    print("test_acc: " ,test_acc)

    # Set when the early-stopping budget is exhausted; the loop is left only after
    # the log file has been closed and the last-epoch checkpoint has been written.
    stop_early = False

    # `with` guarantees the log handle is closed on every path, including early
    # stopping and exceptions raised while saving a checkpoint.
    with open(f'{args.run_name}.txt', "a") as f:
        f.write('\n\nepoch ' + str(epoch))
        f.write('\nAccuracy and Loss')
        f.write('\ntrain_acc: ' + str(train_acc) + '   train_loss: ' + str(train_loss) + ',')
        f.write('\nval_acc: ' + str(val_acc) + '   val_loss: ' + str(val_loss) + ',')
        f.write('\ntest_acc: ' + str(test_acc) + '   test_loss: ' + str(test_loss) + ',')
        f.write('\nBLEU validation: ' + str(val_bleu))
        f.write('\nBLEU test: ' + str(test_bleu))
        f.write('\nlearning_rate: ' + str(optimizer.param_groups[0]["lr"]))

        # NOTE(review): the "best accuracy" checkpoint is selected on TEST accuracy,
        # so the reported test score is not an unbiased estimate — confirm intended.
        if test_acc['total_acc'] > best_acc:
            print('Saving model best acc')
            f.write('\nnew best test total acc')
            torch.save(model.state_dict(), f'{args.run_name}_bestacc.pt')
            best_acc=test_acc['total_acc']

        if val_loss < best_loss:
            # Validation loss improved: checkpoint and reset the patience counter.
            print('Saving model best val loss')
            f.write('\nnew best val_loss')
            torch.save(model.state_dict(), f'{args.run_name}.pt')
            best_loss=val_loss
            counter=0
            f.write('\ncounter: ' + str(counter))
        else:
            counter+=1
            print("counter: " ,counter)
            f.write('\ncounter: ' + str(counter))
            # Stop once more than 20 consecutive epochs failed to improve val loss.
            if counter > 20:
                stop_early = True

    # Always keep the most recent weights, also for the epoch that triggers the stop.
    torch.save(model.state_dict(), f'{args.run_name}_lastepoch.pt')

    if stop_early:
        break