<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/VisualBert_EndoVis18_VQA_Sentence.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Download Dataset

In [1]:
# Downloading the VQA EndoVis18 Dataset https://drive.google.com/file/d/1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN/view?usp=sharing
!gdown --id 1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN

# Unzipping the VQA EndoVis18 Dataset\
!unzip -q EndoVis-18-VQA.zip

Downloading...
From (original): https://drive.google.com/uc?id=1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN
From (redirected): https://drive.google.com/uc?id=1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN&confirm=t&uuid=1cf8155d-61d9-4d9f-a2d6-f4de37d62da8
To: /content/EndoVis-18-VQA.zip
100% 2.70G/2.70G [00:32<00:00, 83.1MB/s]


VisualBert Sentence Generation Model:

In [1]:
'''
Description     : VisualBert + Transformer based sentence generation model.
Paper           : Surgical-VQA: Visual Question Answering in Surgical Scenes Using Transformers
Author          : Lalithkumar Seenivasan, Mobarakol Islam, Adithya Krishna, Hongliang Ren
Lab             : MMLAB, National University of Singapore
Acknowledgement : Code adopted from the official implementation of VisualBertModel from
                  huggingface/transformers (https://github.com/huggingface/transformers.git) and modified.
'''

import numpy as np

import torch
from torch import nn
from torchvision import models
from transformers import VisualBertModel, VisualBertConfig


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
channel_number = 512


'''
Encoder transformer : visualBert Encoder
'''
class VisualBertEncoder(nn.Module):
    def __init__(self, vocab_size, decoder_layers, n_heads):
        '''
        VisualBert encoder
        vocab_size = tokenizer length
        decoder_layers = 6
        n_heads = 6
        '''
        super(VisualBertEncoder, self).__init__()
        VBconfig = VisualBertConfig(vocab_size= vocab_size, visual_embedding_dim = 512, num_hidden_layers = decoder_layers, num_attention_heads = n_heads, hidden_size = 2048)
        self.VisualBertEncoder = VisualBertModel(VBconfig)

        ## image processing
        self.img_feature_extractor = models.resnet18(weights=True)
        self.img_feature_extractor.fc = nn.Sequential(*list(self.img_feature_extractor.fc.children())[:-1])

    def forward(self, inputs, imgs):
        # print(visual_embeds.shape)
        # prepare visual embedding
        visual_embeds = self.img_feature_extractor(imgs)
        visual_embeds = torch.unsqueeze(visual_embeds, dim=1)
        visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long).to(device)
        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float).to(device)

        # append visual features to text
        inputs.update({
        "visual_embeds": visual_embeds,
        "visual_token_type_ids": visual_token_type_ids,
        "visual_attention_mask": visual_attention_mask,
        "output_attentions": True
        })

        inputs['input_ids'] = inputs['input_ids'].to(device)
        inputs['token_type_ids'] = inputs['token_type_ids'].to(device)
        inputs['attention_mask'] = inputs['attention_mask'].to(device)
        inputs['visual_token_type_ids'] = inputs['visual_token_type_ids'].to(device)
        inputs['visual_attention_mask'] = inputs['visual_attention_mask'].to(device)

        outputs = self.VisualBertEncoder(**inputs)

        return outputs


'''
Decoder transformer
'''
class ScaledDotProductAttention(nn.Module):
    def __init__(self, QKVdim):
        super(ScaledDotProductAttention, self).__init__()
        self.QKVdim = QKVdim

    def forward(self, Q, K, V, attn_mask):
        """
        :param Q: [batch_size, n_heads, -1(len_q), QKVdim]
        :param K, V: [batch_size, n_heads, -1(len_k=len_v), QKVdim]
        :param attn_mask: [batch_size, n_heads, len_q, len_k]
        """
        # scores: [batch_size, n_heads, len_q, len_k]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.QKVdim)
        # Fills elements of self tensor with value where mask is True.
        scores.to(device).masked_fill_(attn_mask, -1e9)
        attn = nn.Softmax(dim=-1)(scores)  # [batch_size, n_heads, len_q, len_k]
        context = torch.matmul(attn, V).to(device)  # [batch_size, n_heads, len_q, QKVdim]
        return context, attn


class Multi_Head_Attention(nn.Module):
    def __init__(self, Q_dim, K_dim, QKVdim, n_heads=8, dropout=0.1):
        super(Multi_Head_Attention, self).__init__()
        self.W_Q = nn.Linear(Q_dim, QKVdim * n_heads)
        self.W_K = nn.Linear(K_dim, QKVdim * n_heads)
        self.W_V = nn.Linear(K_dim, QKVdim * n_heads)
        self.n_heads = n_heads
        self.QKVdim = QKVdim
        self.embed_dim = Q_dim
        self.dropout = nn.Dropout(p=dropout)
        self.W_O = nn.Linear(self.n_heads * self.QKVdim, self.embed_dim)

    def forward(self, Q, K, V, attn_mask):
        """
        In self-encoder attention:
                Q = K = V: [batch_size, num_pixels=26, encoder_dim=2048]
                attn_mask: [batch_size, len_q=26, len_k=26]
        In self-decoder attention:
                Q = K = V: [batch_size, max_len=20, embed_dim=512]
                attn_mask: [batch_size, len_q=20, len_k=20]
        encoder-decoder attention:
                Q: [batch_size, 20, 512] from decoder
                K, V: [batch_size, 26, 2048] from encoder
                attn_mask: [batch_size, len_q=20, len_k=26]
        return _, attn: [batch_size, n_heads, len_q, len_k]
        """
        residual, batch_size = Q, Q.size(0)
        # q_s: [batch_size, n_heads=8, len_q, QKVdim] k_s/v_s: [batch_size, n_heads=8, len_k, QKVdim]
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.QKVdim).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.QKVdim).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.QKVdim).transpose(1, 2)
        # attn_mask: [batch_size, self.n_heads, len_q, len_k]
        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)
        # attn: [batch_size, n_heads, len_q, len_k]
        # context: [batch_size, n_heads, len_q, QKVdim]
        context, attn = ScaledDotProductAttention(self.QKVdim)(q_s, k_s, v_s, attn_mask)
        # context: [batch_size, n_heads, len_q, QKVdim] -> [batch_size, len_q, n_heads * QKVdim]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.QKVdim).to(device)
        # output: [batch_size, len_q, embed_dim]
        output = self.W_O(context)
        output = self.dropout(output)
        return nn.LayerNorm(self.embed_dim).to(device)(output + residual), attn


class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, embed_dim, d_ff, dropout):
        '''
        PosewiseFeedForwardNet
        embed_dim = 300
        d_ff      = dim_size
        dropout`  = 0.1
        '''
        super(PoswiseFeedForwardNet, self).__init__()
        """
        Two fc layers can also be described by two cnn with kernel_size=1.
        """
        self.conv1 = nn.Conv1d(in_channels=embed_dim, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=embed_dim, kernel_size=1)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_dim = embed_dim

    def forward(self, inputs):
        """
        encoder: inputs: [batch_size, len_q=26, embed_dim=2048]
        decoder: inputs: [batch_size, max_len=20, embed_dim=512]
        """
        residual = inputs
        output = nn.ReLU()(self.conv1(inputs.transpose(1, 2)))
        output = self.conv2(output).transpose(1, 2)
        output = self.dropout(output)
        return nn.LayerNorm(self.embed_dim).to(device)(output + residual)


class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, dropout, n_heads):
        '''
        Decoder layer
        embed_dim   = 300
        droput      = 0.1
        n_heads     = 6
        '''
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = Multi_Head_Attention(Q_dim=embed_dim, K_dim=embed_dim, QKVdim=64, n_heads=n_heads, dropout=dropout)
        self.dec_enc_attn = Multi_Head_Attention(Q_dim=embed_dim, K_dim=2048, QKVdim=64, n_heads=n_heads, dropout=dropout)
        self.pos_ffn = PoswiseFeedForwardNet(embed_dim=embed_dim, d_ff=2048, dropout=dropout)

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        """
        :param dec_inputs: [batch_size, max_len=20, embed_dim=512]
        :param enc_outputs: [batch_size, num_pixels=26, 2048]
        :param dec_self_attn_mask: [batch_size, 20, 20]
        :param dec_enc_attn_mask: [batch_size, 20, 26]
        """
        # print(dec_inputs.shape, enc_outputs.shape, dec_self_attn_mask.shape, dec_enc_attn_mask.shape)

        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)
        return dec_outputs, dec_self_attn, dec_enc_attn


class Decoder(nn.Module):
    def __init__(self, n_layers, vocab_size, embed_dim, dropout, n_heads, answer_len):
        '''
        Transformer decoder
        n_layers    = 6
        vocab_size  = tokenizer length
        embed_fim   = 300
        dropout     = 0.1
        n_heads     = 6
        answer_len  = 20
        '''
        super(Decoder, self).__init__()
        self.vocab_size = vocab_size
        self.anwer_len = answer_len
        self.tgt_emb = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_emb = nn.Embedding.from_pretrained(self.get_position_embedding_table(embed_dim), freeze=True)
        self.dropout = nn.Dropout(p=dropout)
        self.layers = nn.ModuleList([DecoderLayer(embed_dim, dropout, n_heads) for _ in range(n_layers)])
        self.projection = nn.Linear(embed_dim, vocab_size, bias=False)

    def get_position_embedding_table(self, embed_dim):
        def cal_angle(position, hid_idx):
            return position / np.power(10000, 2 * (hid_idx // 2) / embed_dim)
        def get_posi_angle_vec(position):
            return [cal_angle(position, hid_idx) for hid_idx in range(embed_dim)]

        embedding_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(self.anwer_len)])
        embedding_table[:, 0::2] = np.sin(embedding_table[:, 0::2])  # dim 2i
        embedding_table[:, 1::2] = np.cos(embedding_table[:, 1::2])  # dim 2i+1
        return torch.FloatTensor(embedding_table).to(device)

    def get_attn_pad_mask(self, seq_q, seq_k):
        batch_size, len_q = seq_q.size()
        batch_size, len_k = seq_k.size()
        # In wordmap, <pad>:0
        # pad_attn_mask: [batch_size, 1, len_k], one is masking
        pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)
        return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

    def get_attn_subsequent_mask(self, seq):
        attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
        subsequent_mask = np.triu(np.ones(attn_shape), k=1)
        subsequent_mask = torch.from_numpy(subsequent_mask).byte().to(device)
        return subsequent_mask

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        :param encoder_out: [batch_size, num_pixels=26, 2048]
        :param encoded_captions: [batch_size, 20]
        :param caption_lengths: [batch_size, 1]
        """
        batch_size = encoder_out.size(0)
        token_size = encoder_out.size(1)
        # Sort input data by decreasing lengths.
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]
        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        # So, decoding lengths are actual lengths - 1
        decode_lengths = (caption_lengths - 1).tolist()

        '''# dec_outputs: [batch_size, max_len=20, embed_dim=512]
        # dec_self_attn_pad_mask: [batch_size, len_q=20, len_k=20], 1 if id=0(<pad>)
        # dec_self_attn_subsequent_mask: [batch_size, 20, 20], Upper triangle of an array with 1.
        # dec_self_attn_mask for self-decoder attention, the position whose val > 0 will be masked.
        # dec_enc_attn_mask for encoder-decoder attention.
        # e.g. 9488, 23, 53, 74, 0, 0  |  dec_self_attn_mask:
        # 0 1 1 1 2 2
        # 0 0 1 1 2 2
        # 0 0 0 1 2 2
        # 0 0 0 0 2 2
        # 0 0 0 0 1 2
        # 0 0 0 0 1 1'''
        dec_outputs = self.tgt_emb(encoded_captions) + self.pos_emb(torch.LongTensor([list(range(self.anwer_len))]*batch_size).to(device))
        dec_outputs = self.dropout(dec_outputs)
        dec_self_attn_pad_mask = self.get_attn_pad_mask(encoded_captions, encoded_captions)
        dec_self_attn_subsequent_mask = self.get_attn_subsequent_mask(encoded_captions)
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
        dec_enc_attn_mask = (torch.tensor(np.zeros((batch_size, self.anwer_len, token_size))).to(device) == torch.tensor(np.ones((batch_size, self.anwer_len, token_size))).to(device))

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            # attn: [batch_size, n_heads, len_q, len_k]
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, encoder_out, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        predictions = self.projection(dec_outputs)
        return predictions, encoded_captions, decode_lengths, sort_ind, dec_self_attns, dec_enc_attns


'''
VisualBert Encoder + Transformer decoder
'''
class VisualBertSentence(nn.Module):

    def __init__(self, vocab_size, embed_dim, encoder_layers, decoder_layers, dropout=0.1, n_heads=8, answer_len = 20):
        '''
        VisualBert Encoder + Transformer decoder
        vocab_size     = tokenizer length
        embed_dim      = 300
        encoder_layers = 6
        decoder_layers = 6
        dropout        = 0.1
        n_heads        = 6
        answer_len     = 20
        '''
        super(VisualBertSentence, self).__init__()
        self.encoder = VisualBertEncoder(vocab_size, encoder_layers, n_heads)
        self.decoder = Decoder(decoder_layers, vocab_size, embed_dim, dropout, n_heads, answer_len)
        self.embedding = self.decoder.tgt_emb

    def load_pretrained_embeddings(self, embeddings):
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def forward(self, inputs, imgs, encoded_captions, caption_lengths):
        # Vision and text encoder output
        encoder_outputs = self.encoder(inputs, imgs)

        # predict answer using decoder model
        predictions, encoded_captions, decode_lengths, sort_ind, dec_self_attns, dec_enc_attns = self.decoder(encoder_outputs['last_hidden_state'], encoded_captions, caption_lengths)
        alphas = {"enc_self_attns": encoder_outputs['attentions'], "dec_self_attns": dec_self_attns, "dec_enc_attns": dec_enc_attns}
        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

Utils

In [2]:
import torch
import os
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_fscore_support

class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def adjust_learning_rate(optimizer, shrink_factor):
    """
    Shrinks learning rate by a specified factor.

    :param optimizer: optimizer whose learning rate must be shrunk.
    :param shrink_factor: factor in interval (0, 1) to multiply learning rate with.
    """

    print("\nDECAYING learning rate.")
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor
    print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],))


def save_clf_checkpoint(checkpoint_dir, epoch, epochs_since_improvement, model, optimizer, Acc):
    """
    Saves model checkpoint.
    """
    state = {'epoch': epoch,
             'epochs_since_improvement': epochs_since_improvement,
             'Acc': Acc,
             'model': model,
             'optimizer': optimizer}
    filename = checkpoint_dir + 'Best.pth.tar'
    torch.save(state, filename)

def calc_acc(y_true, y_pred):
    acc = accuracy_score(y_true, y_pred)
    return acc

def calc_classwise_acc(y_true, y_pred):
    matrix = confusion_matrix(y_true, y_pred)
    classwise_acc = matrix.diagonal()/matrix.sum(axis=1)
    return classwise_acc

def calc_map(y_true, y_scores):
    mAP = average_precision_score(y_true, y_scores,average=None)
    return mAP

def calc_precision_recall_fscore(y_true, y_pred):
    precision, recall, fscore, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division = 1)
    return(precision, recall, fscore)


def seed_everything(seed=27):
    '''
    Set random seed for reproducible experiments
    Inputs: seed number
    '''
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [4]:
import glob
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class EndoVis18VQAGPTSentence(Dataset):
    '''
    	seq: train_seq  = [2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]
    	     val_seq    = [1, 5, 16]
    	folder_head = '../dataset/EndoVis-18-VQA/seq_'
        folder_tail = '/vqa2/Sentence/*.txt'
    '''
    def __init__(self, seq, folder_head, folder_tail, model_ver = None, transform=None):

        self.transform = transforms.Compose([
                                transforms.Resize((300,256)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
                                ])

        filenames = []
        for curr_seq in seq: filenames = filenames + glob.glob(folder_head + str(curr_seq) + folder_tail)

        self.vqas = []
        for file in filenames:
            file_data = open(file, "r")
            lines = [line.strip("\n") for line in file_data if line != "\n"]
            file_data.close()
            for line in lines:
                q_s, an_s = line.split('|')
                q_s = q_s.split('&')
                an_s = an_s.split(('&'))
                for i in range(len(q_s)):
                    q_a = q_s[i]+'|'+an_s[i]
                    # print(file, q_a)
                    self.vqas.append([file, q_a])
        print('Total files: %d | Total question: %.d' %(len(filenames), len(self.vqas)))

    def __len__(self):
        return len(self.vqas)

    def __getitem__(self, idx):
        loc = self.vqas[idx][0].split('/')

        # img loc[3],
        img_loc = os.path.join(loc[0],loc[1], 'left_frames',loc[-1].split('_')[0]+'.png')
        if self.transform:
            img = Image.open(img_loc)
            img = self.transform(img)

        # question and answer
        question, answer = self.vqas[idx][1].split('|')
        answer = '<|sep|> '+answer


        return img, question, answer


import os
import sys
import argparse
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch import optim
from torch.nn import CrossEntropyLoss
from torch.utils.data  import DataLoader
from transformers import GPT2Tokenizer
from transformers import BertTokenizer
import torch.backends.cudnn as cudnn
from nltk.translate.bleu_score import corpus_bleu

def train(args, train_dataloader, model, criterion, optimizer, epoch, tokenizer, device):

    model.train()

    total_loss = AverageMeter()

    for i, ( imgs, questions, answers) in enumerate(tqdm(train_dataloader),0):

        # prepare questions and answers
        question_list = []
        answer_list = []
        for question in questions: question_list.append(question)
        for answer in answers: answer_list.append(answer)

        question_inputs = tokenizer(question_list, padding="max_length", max_length= args.question_len, return_tensors="pt")
        answer_inputs = tokenizer(answer_list, padding="max_length", max_length= args.answer_len, return_tensors="pt")
        answers_GT_ID = answer_inputs.input_ids.to(device)
        answers_GT_len = torch.sum(answer_inputs.attention_mask, dim=1).unsqueeze(1).to(device)

        # Visual features
        imgs = imgs.to(device)
        visual_len = 80

        # model forward(question, img, answer)
        # print('mob:', answers_GT_len.shape, answers_GT_len, len(answer_inputs))
        logits, _, _, _, _ = model(question_inputs, imgs, answers_GT_ID, answers_GT_len)

        # only consider loss on reference summary just like seq2seq models
        # idx = args.answer_len + 1
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = answer_inputs['input_ids'][..., 1:].contiguous() # 1 because answer has '<|sep|>' in front
        shift_labels = shift_labels.to(device)

        loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss.update(loss.item())

    print("Epoch: {}/{} Loss: {:.6f} AVG_Loss: {:.6f}".format(epoch, args.epochs, total_loss.val, total_loss.avg))

def validate(args, val_loader, model, criterion, epoch, tokenizer, device, save_output = False):

    references = []
    hypotheses = []

    model.eval()
    total_loss = AverageMeter()

    with torch.no_grad():
        for i, (imgs, questions, answers) in enumerate(tqdm(val_loader),0):

            # prepare questions and answers
            question_list = []
            answer_list = []
            for question in questions: question_list.append(question)
            for answer in answers: answer_list.append(answer)

            question_inputs = tokenizer(question_list, padding="max_length",max_length= args.question_len, return_tensors="pt")
            answer_inputs = tokenizer(answer_list, padding="max_length",max_length= args.answer_len, return_tensors="pt")
            answers_GT_ID = answer_inputs.input_ids.to(device)
            answers_GT_len = torch.sum(answer_inputs.attention_mask, dim=1).unsqueeze(1).to(device)

            # Visual features
            imgs = imgs.to(device)
            visual_len = 80

            # model forward(question, img, answer)
            logits, _, _, _, _ = model(question_inputs, imgs, answers_GT_ID, answers_GT_len)


            # only consider loss on reference summary just like seq2seq models
            # idx = args.question_len + 1
            shift_logits = logits[..., 0:-1, :].contiguous()
            shift_labels = answer_inputs['input_ids'][..., 1:].contiguous() # 1 because answer has '<|sep|>' in front

            # copy for logits and labels for sentence decoding and blue-4 score calculation
            logits_copy = logits.clone()
            shift_labels_copy = shift_labels.clone()

            # loss calculation
            shift_labels = shift_labels.to(device)
            loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            total_loss.update(loss.item())

            # references    - Ground truth answer
            answer_GT_dec = tokenizer.batch_decode(shift_labels_copy, skip_special_tokens= True)
            for answer_GT_dec_i in answer_GT_dec: references.append([answer_GT_dec_i.split()])
            # print(references)

            # Hypotheses - predicted answer
            _, answer_Gen_id = torch.max(logits_copy, dim=2)
            answer_Gen_dec = tokenizer.batch_decode(answer_Gen_id, skip_special_tokens= True)
            for answer_Gen_dec_i in answer_Gen_dec: hypotheses.append(answer_Gen_dec_i.split())
            # print(hypotheses)


        # Calculate BLEU1~4
        metrics = {}
        metrics["Bleu_1"] = corpus_bleu(references, hypotheses, weights=(1.00, 0.00, 0.00, 0.00))
        metrics["Bleu_2"] = corpus_bleu(references, hypotheses, weights=(0.50, 0.50, 0.00, 0.00))
        metrics["Bleu_3"] = corpus_bleu(references, hypotheses, weights=(0.33, 0.33, 0.33, 0.00))
        metrics["Bleu_4"] = corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25))

        print("Epoch: {}/{} EVA LOSS: {:.6f} BLEU-1 {:.6f} BLEU2 {:.6f} BLEU3 {:.6f} BLEU-4 {:.6f}".format
          (epoch, args.epochs, total_loss.avg, metrics["Bleu_1"],  metrics["Bleu_2"],  metrics["Bleu_3"],  metrics["Bleu_4"]))

    return metrics

def get_arg():
    parser = argparse.ArgumentParser(description='VisualQuestionAnswerClassification')

    # Model parameters
    parser.add_argument('--emb_dim',        type=int,   default=300,                                help='dimension of word embeddings.')
    parser.add_argument('--n_heads',        type=int,   default=8,                                  help='Multi-head attention.')
    parser.add_argument('--dropout',        type=float, default=0.1,                                help='dropout')
    parser.add_argument('--encoder_layers', type=int,   default=6,                                  help='the number of layers of encoder in Transformer.')
    parser.add_argument('--decoder_layers', type=int,   default=6,                                  help='the number of layers of decoder in Transformer.')

    # Training parameters
    parser.add_argument('--epochs',         type=int,   default=80,                                 help='number of epochs to train for (if early stopping is not triggered).') #80, 26
    parser.add_argument('--batch_size',     type=int,   default=50,                                 help='batch_size')
    parser.add_argument('--workers',        type=int,   default=1,                                  help='for data-loading; right now, only 1 works with h5pys.')

    # existing checkpoint
    parser.add_argument('--checkpoint',     default=None,                                           help='path to checkpoint, None if none.')

    parser.add_argument('--lr',             type=float, default=0.000001,                            help=' 0.00001, 0.000005')
    parser.add_argument('--checkpoint_dir', default= 'checkpoints/efvlegpt2rs18/m18/v3_p_qf_',      help='m18/c80')
    parser.add_argument('--dataset_type',   default= 'm18',                                         help='m18/c80')
    parser.add_argument('--tokenizer_ver',  default= 'gpt2v1',                                      help='btv2/btv3/gpt2v1')
    parser.add_argument('--model_subver',   default= 'v3',                                          help='V0,v1/v2/v3/v4')
    parser.add_argument('--question_len',   default= 25,                                            help='25')
    parser.add_argument('--answer_len',     default= 35,                                            help='25')
    parser.add_argument('--model_ver',      default= 'efvlegpt2rs18',                               help='efvlegpt2rs18/efvlegpt2Swin/"')  #vrvb/gpt2rs18/gpt2ViT/gpt2Swin/biogpt2rs18/vilgpt2vqa/efgpt2rs18gr/efvlegpt2Swingr
    parser.add_argument('--vis_pos_emb',    default= 'pos',                                         help='None, zeroes, pos')
    parser.add_argument('--patch_size',     default= 5,                                             help='1/2/3/4/5')

    parser.add_argument('--validate',       default=False,                                          help='When only validation required False/True')

    if 'ipykernel' in sys.modules:
        args = parser.parse_args([])
    else:
        args = parser.parse_args()
    return args


if __name__ == '__main__':

    seed_everything()

    args = get_arg()
    args.lr = 0.00005
    args.epochs = 2
    args.checkpoint_dir='checkpoints/efvlegpt2rs18/m18_v1_z_qf_'
    args.dataset_type='m18'
    args.tokenizer_ver='gpt2v1'
    args.model_ver='efvlegpt2rs18'
    args.model_subver='v1'
    args.vis_pos_emb='zeroes'
    args.batch_size=40
    os.makedirs('checkpoints/efvlegpt2rs18', exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
    cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead
    print('device =', device)

    # best model initialize
    start_epoch = 1
    best_epoch = [0]
    best_results = [0.0]
    epochs_since_improvement = 0

    # data location
    train_seq = [2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]
    val_seq = [1, 5, 16]

    folder_head = 'EndoVis-18-VQA/seq_'
    folder_tail = '/vqa/Sentence/*.txt'

    train_dataset = EndoVis18VQAGPTSentence(train_seq, folder_head, folder_tail, model_ver=args.model_ver)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size= args.batch_size, shuffle=True, num_workers=8)
    val_dataset = EndoVis18VQAGPTSentence(val_seq, folder_head, folder_tail, model_ver=args.model_ver)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size= args.batch_size, shuffle=False, num_workers=8)

    # tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # tokenizer.pad_token = tokenizer.eos_token
    # tokenizer_length = len(tokenizer)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = VisualBertSentence(vocab_size=len(tokenizer), embed_dim=args.emb_dim, encoder_layers=args.encoder_layers, decoder_layers=args.decoder_layers,
                            dropout=args.dropout, n_heads=args.n_heads, answer_len=args.answer_len)
    model = model.to(device)

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print('model params: ', pytorch_total_params)

    criterion = CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(start_epoch, args.epochs):

        if epochs_since_improvement > 0 and epochs_since_improvement % 5 == 0:
            adjust_learning_rate(optimizer, 0.8)

        # train
        train(args, train_dataloader=train_dataloader, model = model, criterion=criterion, optimizer=optimizer, epoch=epoch, tokenizer = tokenizer, device = device)

        # validation
        metrics = validate(args, val_loader=val_dataloader, model = model, criterion=criterion, epoch=epoch, tokenizer = tokenizer, device = device)

        if metrics["Bleu_4"] >= best_results[0]:
            epochs_since_improvement = 0

            best_results[0] = metrics["Bleu_4"]
            best_epoch[0] = epoch
            save_clf_checkpoint(args.checkpoint_dir, epoch, epochs_since_improvement, model, optimizer, best_results[0])
        else:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))


device = cuda
Total files: 1560 | Total question: 10574
Total files: 447 | Total question: 3216
model params:  301176908


100%|██████████| 265/265 [06:13<00:00,  1.41s/it]


Epoch: 1/2 Loss: 1.195600 AVG_Loss: 2.574225


100%|██████████| 81/81 [01:52<00:00,  1.38s/it]


Epoch: 1/2 EVA LOSS: 1.352161 BLEU-1 0.547054 BLEU2 0.491378 BLEU3 0.456378 BLEU-4 0.418068


TypeError: save_clf_checkpoint() missing 2 required positional arguments: 'Acc' and 'final_args'