In [1]:
!pip install -q yacs unicodedata2 nltk transformers sentencepiece accelerate transformers[torch]

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.0/468.0 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m265.7/265.7 kB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
!gdown 1TztHmmwN8R6MGEABPnxgaBErZ4vDOjdJ
!unzip -q 'vivqa-dataset.zip'
!rm /content/vivqa-dataset.zip

Downloading...
From: https://drive.google.com/uc?id=1TztHmmwN8R6MGEABPnxgaBErZ4vDOjdJ
To: /content/vivqa-dataset.zip
100% 527M/527M [00:04<00:00, 112MB/s]


In [3]:
# Mount google drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [4]:
import os
import re
import string
import unicodedata
import numpy as np
import pandas as pd
import pickle
import shutil
from tqdm import tqdm
import itertools
from PIL import Image
from collections import Counter
from typing import Dict, List, Union
from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import NLLLoss

from transformers import BeitImageProcessor, BeitModel
from transformers import BertTokenizer
from transformers.models.bert.modeling_bert import (
    BertConfig,
    BertEmbeddings,
    BertEncoder,
    BertPreTrainedModel,
)

# Prepare Dataset

In [5]:
DATASET_DIR = 'vivqa-dataset'
IMAGES_DIR = os.path.join(DATASET_DIR, 'images')
TRAIN_PATH = os.path.join(DATASET_DIR, 'train.csv')
TEST_PATH = os.path.join(DATASET_DIR, 'test.csv')

In [6]:
# Load train.csv
train_df = pd.read_csv(TRAIN_PATH)
if train_df.shape[1] > 4:
    train_df = train_df.iloc[: , 1:] # Drop first column (index)

# Load test.csv
test_df = pd.read_csv(TEST_PATH)
if test_df.shape[1] > 4:
    test_df = test_df.iloc[: , 1:] # Drop first column (index)

## Remove duplicate

In [7]:
# Remove duplicate train_df
train_df.drop_duplicates(keep=False, inplace=True)
train_df.to_csv(TRAIN_PATH, index=False)
print(f'Train shape: {train_df.shape}')

# Remove duplicate test_df
test_df.drop_duplicates(keep=False, inplace=True)
test_df.to_csv(TEST_PATH, index=False)
print(f'Test shape: {test_df.shape}')

Train shape: (11819, 4)
Test shape: (2987, 4)


# Config

In [8]:
from yacs.config import CfgNode
import yaml

if os.path.exists('/content/drive'):
    BASE_DIR = '/content/drive/MyDrive'
else:
    BASE_DIR = ''

config_file = os.path.join(BASE_DIR, "ViVQA-Models", "beit_mbert_classification.yaml")

with open(config_file, "r") as stream:
    try:
        CONFIG =  CfgNode(init_dict=yaml.safe_load(stream))
    except yaml.YAMLError as exc:
        print(exc)

print(CONFIG.MODEL.NAME)

beit_mbert_classification


# Build dataset

In [9]:
class ClassificationVocab(object):
    # This class is especially designed for ViVQA dataset by treating the VQA as a classification task.
    # For more information, please visit https://arxiv.org/abs/1708.02711

    def __init__(self, config):

        self.tokenizer = config.TOKENIZER

        self.padding_token = config.PAD_TOKEN
        self.bos_token = config.BOS_TOKEN
        self.eos_token = config.EOS_TOKEN
        self.unk_token = config.UNK_TOKEN

        self.make_vocab([
            config.DF_PATH.TRAIN,
            config.DF_PATH.TEST
        ])

        counter = self.freqs.copy()

        min_freq = max(config.MIN_FREQ, 1)

        specials = [self.padding_token, self.bos_token, self.eos_token, self.unk_token]
        itos = specials
        # frequencies of special tokens are not counted when building vocabulary
        # in frequency order
        for tok in specials:
            del counter[tok]

        # sort by frequency, then alphabetically
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        for word, freq in words_and_frequencies:
            if freq < min_freq:
                break
            itos.append(word)

        self.itos = {i: tok for i, tok in enumerate(itos)}
        self.stoi = {tok: i for i, tok in enumerate(itos)}

        self.specials = [self.padding_token, self.bos_token, self.eos_token, self.unk_token]

        self.padding_idx = self.stoi[self.padding_token]
        self.bos_idx = self.stoi[self.bos_token]
        self.eos_idx = self.stoi[self.eos_token]
        self.unk_idx = self.stoi[self.unk_token]


    def make_vocab(self, df_path_list):
        self.freqs = Counter()
        itoa = set()
        self.max_question_length = 0
        for df_path in df_path_list:
            df = pd.read_csv(df_path)
            for question in df['question']:
                question = preprocess_sentence(question, self.tokenizer)
                self.freqs.update(question)
                if len(question) + 2 > self.max_question_length:
                    self.max_question_length = len(question) + 2
            for answer in df['answer']:
                answer = " ".join(preprocess_sentence(answer, self.tokenizer))
                itoa.add(answer)

        self.itoa = {ith: answer for ith, answer in enumerate(itoa)}
        self.atoi = {answer: ith for ith, answer in self.itoa.items()}
        self.total_answers = len(self.atoi)

    def encode_question(self, question: List[str]) -> torch.Tensor:
        """ Turn a question into a vector of indices and a question length """
        vec = torch.ones(self.max_question_length).long() * self.padding_idx
        for i, token in enumerate([self.bos_token] + question + [self.eos_token]):
            vec[i] = self.stoi[token] if token in self.stoi else self.unk_idx
        return vec

    def encode_answer(self, answer: List[str]) -> torch.Tensor:
        answer = " ".join(answer)
        return torch.tensor([self.atoi[answer]], dtype=torch.long)

    def decode_question(self, question_vecs: torch.Tensor, join_words=True) -> List[str]:
        '''
            question_vecs: (bs, max_length)
        '''
        questions = []
        for vec in question_vecs:
            question = " ".join([self.itos[idx] for idx in vec.tolist() if self.itos[idx] not in self.specials])
            if join_words:
                questions.append(question)
            else:
                questions.append(question.strip().split())
        return questions

    def decode_answer(self, answer_vecs: torch.Tensor, join_word=False) -> Union[List[str], List[List[str]]]:
        answers = []
        list_answers = answer_vecs.tolist()
        for answer_idx in list_answers:
            ans_i = self.itoa[answer_idx] if join_word else self.itoa[answer_idx].split()
            answers.append(str(ans_i))
        return answers

In [10]:
# Build dataset class
class ViVQA_Dataset(torch.utils.data.Dataset):
  """
  Dataset class for the ViVQA dataset.
  """
  def __init__(self, df, img_dir, vocab):
    self.df = df
    self.img_dir = img_dir
    self.vocab = vocab

  def __len__(self):
    return self.df.shape[0]

  def __getitem__(self, idx):
    question = self.df.loc[idx, 'question']
    answer = self.df.loc[idx, 'answer']
    image_id = self.df.loc[idx, 'img_id']
    quest_type = self.df.loc[idx, 'type']

    img_file = os.path.join(self.img_dir, f'image_{image_id}.jpg')

    return {'image':img_file, 'question':question, 'answer':answer}

# Metrics

## CIDEr

In [11]:
import copy
from collections import defaultdict
import math

def precook(s, n=4):
    """
    Takes a string as input and returns an object that can be given to
    either cook_refs or cook_test. This is optional: cook_refs and cook_test
    can take string arguments as well.
    :param s: string : sentence to be converted into ngrams
    :param n: int    : number of ngrams for which representation is calculated
    :return: term frequency vector for occuring ngrams
    """
    words = str(s).split()
    counts = defaultdict(int)
    for k in range(1,n+1):
        for i in range(len(words)-k+1):
            ngram = tuple(words[i:i+k])
            counts[ngram] += 1
    return counts

def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them.
    :param refs: list of string : reference sentences for some image
    :param n: int : number of ngrams for which (ngram) representation is calculated
    :return: result (list of dict)
    '''
    return [precook(ref, n) for ref in refs]

def cook_test(test, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.
    :param test: list of string : hypothesis sentence for some image
    :param n: int : number of ngrams for which (ngram) representation is calculated
    :return: result (dict)
    '''
    return precook(test, n)

class CiderScorer(object):
    """CIDEr scorer.
    """

    def __init__(self, refs, test=None, n=4, sigma=6.0, doc_frequency=None, ref_len=None):
        ''' singular instance '''
        self.n = n
        self.sigma = sigma
        self.crefs = []
        self.ctest = []
        self.doc_frequency = defaultdict(float)
        self.ref_len = None

        for k in range(len(refs)):
            self.crefs.append(cook_refs(refs[k]))
            if test is not None:
                self.ctest.append(cook_test(test[k][0]))  ## N.B.: -1
            else:
                self.ctest.append(None)  # lens of crefs and ctest have to match

        if doc_frequency is None and ref_len is None:
            # compute idf
            self.compute_doc_freq()
            # compute log reference length
            self.ref_len = np.log(float(len(self.crefs)))
        else:
            self.doc_frequency = doc_frequency
            self.ref_len = ref_len

    def compute_doc_freq(self):
        '''
        Compute term frequency for reference data.
        This will be used to compute idf (inverse document frequency later)
        The term frequency is stored in the object
        :return: None
        '''
        for refs in self.crefs:
            # refs, k ref captions of one image
            for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
                self.doc_frequency[ngram] += 1
            # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)

    def compute_cider(self):
        def counts2vec(cnts):
            """
            Function maps counts of ngram to vector of tfidf weights.
            The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
            The n-th entry of array denotes length of n-grams.
            :param cnts:
            :return: vec (array of dict), norm (array of float), length (int)
            """
            vec = [defaultdict(float) for _ in range(self.n)]
            length = 0
            norm = [0.0 for _ in range(self.n)]
            for (ngram,term_freq) in cnts.items():
                # give word count 1 if it doesn't appear in reference corpus
                df = np.log(max(1.0, self.doc_frequency[ngram]))
                # ngram index
                n = len(ngram)-1
                # tf (term_freq) * idf (precomputed idf) for n-grams
                vec[n][ngram] = float(term_freq)*(self.ref_len - df)
                # compute norm for the vector.  the norm will be used for computing similarity
                norm[n] += pow(vec[n][ngram], 2)

                if n == 1:
                    length += term_freq
            norm = [np.sqrt(n) for n in norm]
            return vec, norm, length

        def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
            '''
            Compute the cosine similarity of two vectors.
            :param vec_hyp: array of dictionary for vector corresponding to hypothesis
            :param vec_ref: array of dictionary for vector corresponding to reference
            :param norm_hyp: array of float for vector corresponding to hypothesis
            :param norm_ref: array of float for vector corresponding to reference
            :param length_hyp: int containing length of hypothesis
            :param length_ref: int containing length of reference
            :return: array of score for each n-grams cosine similarity
            '''
            delta = float(length_hyp - length_ref)
            # measure consine similarity
            val = np.array([0.0 for _ in range(self.n)])
            for n in range(self.n):
                # ngram
                for (ngram,count) in vec_hyp[n].items():
                    # vrama91 : added clipping
                    val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]

                if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
                    val[n] /= (norm_hyp[n]*norm_ref[n])

                assert(not math.isnan(val[n]))
                # vrama91: added a length based gaussian penalty
                val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
            return val

        scores = []
        for test, refs in zip(self.ctest, self.crefs):
            # compute vector for test captions
            vec, norm, length = counts2vec(test)
            # compute vector for ref captions
            score = np.array([0.0 for _ in range(self.n)])
            for ref in refs:
                vec_ref, norm_ref, length_ref = counts2vec(ref)
                score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
            # change by vrama91 - mean of ngram scores, instead of sum
            score_avg = np.mean(score)
            # divide by number of references
            score_avg /= len(refs)
            # multiply score by 10
            score_avg *= 10.0
            # append score of an image to the score list
            scores.append(score_avg)
        return scores

    def compute_score(self):
        # compute cider score
        score = self.compute_cider()
        # debug
        # print score
        return np.mean(np.array(score)), np.array(score)

## Extract match

In [12]:
class Exact_Match:
    def compute_score(self, y_true, y_pred):
        if y_true==y_pred:
            return 1
        else:
            return 0

## F1 Score

In [13]:
class F1:
  def Precision(self,y_true,y_pred):
    if y_pred is None:
       return 0
    common = set(y_true) & set(y_pred)
    return len(common) / len(set(y_pred))

  def Recall(self,y_true,y_pred):
    common = set(y_true) & set(y_pred)
    return len(common) / len(set(y_true))

  def compute_score(self,y_true,y_pred):
    if len(y_pred) == 0 or len(y_true) == 0:
        return int(y_pred == y_true)

    precision = self.Precision(y_true, y_pred)
    recall = self.Recall(y_true, y_pred)

    if precision == 0 or recall == 0:
        return 0
    f1 = 2*precision*recall / (precision+recall)
    return f1

## Wup

In [14]:
import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet

class Wup:
    def get_semantic_field(self,a):
        weight = 1.0
        semantic_field = wordnet.synsets(str(a), pos=wordnet.NOUN)
        return (semantic_field,weight)

    def get_stem_word(self,a):
        """
        Sometimes answer has form word\d+:wordid.
        If so we return word and downweight
        """
        weight = 1.0
        return (a,weight)
    def compute_score(self, a: str, b: str, similarity_threshold: float = 0.9):
        """
        Returns Wu-Palmer similarity score.
        More specifically, it computes:
            max_{x \in interp(a)} max_{y \in interp(b)} wup(x,y)
            where interp is a 'interpretation field'
        """
        global_weight=1.0

        (a,global_weight_a)=self.get_stem_word(a)
        (b,global_weight_b)=self.get_stem_word(b)
        global_weight = min(global_weight_a,global_weight_b)

        if a==b:
            # they are the same
            return 1.0*global_weight

        if a==[] or b==[]:
            return 0

        interp_a,weight_a = self.get_semantic_field(a)
        interp_b,weight_b = self.get_semantic_field(b)

        if interp_a == [] or interp_b == []:
            return 0

        # we take the most optimistic interpretation
        global_max=0.0
        for x in interp_a:
            for y in interp_b:
                local_score=x.wup_similarity(y)
                if local_score > global_max:
                    global_max=local_score

        # we need to use the semantic fields and therefore we downweight
        # unless the score is high which indicates both are synonyms
        if global_max < similarity_threshold:
            interp_weight = 0.1
        else:
            interp_weight = 1.0

        final_score=global_max*weight_a*weight_b*interp_weight*global_weight
        return final_score

[nltk_data] Downloading package wordnet to /root/nltk_data...


In [15]:
import re
import unicodedata

def normalize_text(text):
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower().strip()
    return text

def preprocess_sentence(sentence: str, tokenizer=None):
    sentence = sentence.lower()
    sentence = unicodedata.normalize('NFC', sentence)
    sentence = re.sub(r"[“”]", "\"", sentence)
    sentence = re.sub(r"!", " ! ", sentence)
    sentence = re.sub(r"\?", " ? ", sentence)
    sentence = re.sub(r":", " : ", sentence)
    sentence = re.sub(r";", " ; ", sentence)
    sentence = re.sub(r",", " , ", sentence)
    sentence = re.sub(r"\"", " \" ", sentence)
    sentence = re.sub(r"'", " ' ", sentence)
    sentence = re.sub(r"\(", " ( ", sentence)
    sentence = re.sub(r"\[", " [ ", sentence)
    sentence = re.sub(r"\)", " ) ", sentence)
    sentence = re.sub(r"\]", " ] ", sentence)
    sentence = re.sub(r"/", " / ", sentence)
    sentence = re.sub(r"\.", " . ", sentence)
    sentence = re.sub(r"-", " - ", sentence)
    sentence = re.sub(r"\$", " $ ", sentence)
    sentence = re.sub(r"\&", " & ", sentence)
    sentence = re.sub(r"\*", " * ", sentence)
    # tokenize the sentence
    if tokenizer is None:
        tokenizer = lambda s: s
    sentence = tokenizer(sentence)
    sentence = " ".join(sentence.strip().split()) # remove duplicated spaces
    tokens = sentence.strip().split()

    return tokens

class ScoreCalculator:
    def __init__(self):
        self.f1_caculate=F1()
        self.em_caculate=Exact_Match()
        self.Wup_caculate=Wup()
    #F1 score character level
    def f1_char(self,labels: List[str], preds: List[str]) -> float:
        scores=[]
        for i in range(len(labels)):
            scores.append(self.f1_caculate.compute_score(str(preprocess_sentence(normalize_text(labels[i]))).split(),str(preprocess_sentence(normalize_text(preds[i]))).split()))
        return np.mean(scores)

    #F1 score token level
    def f1_token(self,labels: List[str], preds: List[str]) -> float:
        scores=[]
        for i in range(len(labels)):
            scores.append(self.f1_caculate.compute_score(str(preprocess_sentence(normalize_text(labels[i]))).split(),str(preprocess_sentence(normalize_text(preds[i]))).split()))
        return np.mean(scores)
    #Excat match score
    def em(self,labels: List[str], preds: List[str]) -> float:
        scores=[]
        for i in range(len(labels)):
            scores.append(self.em_caculate.compute_score(str(preprocess_sentence(normalize_text(labels[i]))).split(),str(preprocess_sentence(normalize_text(preds[i]))).split()))
        return np.mean(scores)
    #Wup score
    def wup(self,labels: List[str], preds: List[str]) -> float:
        scores=[]
        for i in range(len(labels)):
            scores.append(self.Wup_caculate.compute_score(str(preprocess_sentence(normalize_text(labels[i]))).split(),str(preprocess_sentence(normalize_text(preds[i]))).split()))
        return np.mean(scores)
    #Cider score
    def cider_score(self,labels: List[str], preds: List[str]) -> float:
        labels=[[preprocess_sentence(normalize_text(label))] for label in labels]
        preds=[[preprocess_sentence(normalize_text(pred))] for pred in preds ]
        cider_caculate= CiderScorer(labels, test=preds, n=4, sigma=6.)
        scores,_=cider_caculate.compute_score()
        return scores

# Model

In [16]:
def generate_padding_mask(sequences, padding_idx: int) -> torch.BoolTensor:
    '''
        sequences: (bs, seq_len, dim)
    '''
    if sequences is None:
        return None

    if len(sequences.shape) == 2: # (bs, seq_len)
        __seq = sequences.unsqueeze(dim=-1) # (bs, seq_len, 1)
    else:
        __seq = sequences

    mask = (torch.sum(__seq, dim=-1) == (padding_idx*__seq.shape[-1])).long() * -10e4 # (b_s, seq_len)
    return mask.unsqueeze(1).unsqueeze(1) # (bs, 1, 1, seq_len)

class TextBert(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.init_weights()

    def forward(self, txt_inds, txt_mask):
        encoder_inputs = self.embeddings(txt_inds)

        attention_mask = txt_mask
        head_mask = [None] * self.config.num_hidden_layers
        encoder_outputs = self.encoder(
            encoder_inputs, attention_mask, head_mask=head_mask
        )
        seq_output = encoder_outputs[0]

        return seq_output

class BertEmbedding(nn.Module):
    def __init__(self, config, vocab):
        super().__init__()

        self.device = config.DEVICE

        bert_config = BertConfig.from_pretrained(config.PRETRAINED_NAME)

        self.tokenizer = BertTokenizer.from_pretrained(config.PRETRAINED_NAME)
        self.embedding = TextBert(bert_config)
        self.embedding = self.embedding.from_pretrained(config.PRETRAINED_NAME)

        # freeze all parameters of pretrained model
        for param in self.embedding.parameters():
            param.requires_grad = False

        self.proj = nn.Linear(config.D_PRETRAINED_FEATURE, config.D_MODEL)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.DROPOUT)

    def forward(self, questions: List[str]):
        inputs = self.tokenizer(questions, return_tensors="pt", padding=True).input_ids.to(self.device)
        padding_mask = generate_padding_mask(inputs, padding_idx=self.tokenizer.pad_token_id)
        features = self.embedding(inputs, padding_mask)

        out = self.proj(features)
        out = self.dropout(self.gelu(out))

        return out, padding_mask

In [17]:
class BEiTEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.device = torch.device(config.DEVICE)

        self.feature_extractor = BeitImageProcessor.from_pretrained(config.PRETRAINED_NAME)
        self.backbone = BeitModel.from_pretrained(config.PRETRAINED_NAME)
        # freeze all parameters of pretrained model
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.proj = nn.Linear(config.D_PRETRAINED_FEATURE, config.D_MODEL)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.DROPOUT)

    def forward(self, images: List[Image.Image]):
        inputs = self.feature_extractor(images, return_tensors="pt").to(self.device)
        features = self.backbone(**inputs).last_hidden_state
        padding_mask = generate_padding_mask(features, padding_idx=0)

        out = self.proj(features)
        out = self.dropout(self.gelu(out))

        return out, padding_mask

In [18]:
class BEiTmBERTClassification(nn.Module):
    def __init__(self, config, vocab):
        super().__init__()
        self.d_model = config.D_MODEL

        self.text_embedding = BertEmbedding(config.TEXT_EMBEDDING, vocab)
        self.vision_encoder = BEiTEmbedding(config.VISION_EMBEDDING)

        self.fusion = nn.Linear(config.D_MODEL, config.D_MODEL)
        self.dropout = nn.Dropout(config.DROPOUT)
        self.norm = nn.LayerNorm(config.D_MODEL)

        self.proj = nn.Linear(config.D_MODEL, vocab.total_answers)

    def init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, questions: List[str], images: List[str]):
        images = [Image.open(image_path).convert("RGB")  for image_path in images]
        vision_features,_ = self.vision_encoder(images)
        text_features,_ = self.text_embedding(questions)

        fused_features = torch.cat([vision_features, text_features], dim=1)
        fused_features = self.dropout(self.fusion(fused_features))
        out = fused_features.sum(dim=1)
        out = self.proj(out)

        return F.log_softmax(out, dim=-1)

In [26]:
class ClassificationTask():
    def __init__(self, config):
        # Create checkpoint folder if not existed
        self.checkpoint_path = os.path.join(BASE_DIR, config.TRAINING.CHECKPOINT_PATH, config.MODEL.NAME)
        if not os.path.isdir(self.checkpoint_path):
            print("Creating checkpoint path")
            os.makedirs(self.checkpoint_path)

        # Create/Load vocab
        if not os.path.isfile(os.path.join(self.checkpoint_path, "vocab.bin")):
            print("Creating vocab")
            self.load_vocab(config.DATASET.VOCAB)
            print("Saving vocab to %s" % os.path.join(self.checkpoint_path, "vocab.bin"))
            pickle.dump(self.vocab, open(os.path.join(self.checkpoint_path, "vocab.bin"), "wb"))
        else:
            print("Loading vocab from %s" % os.path.join(self.checkpoint_path, "vocab.bin"))
            self.vocab = pickle.load(open(os.path.join(self.checkpoint_path, "vocab.bin"), "rb"))

        print("Loading data")
        self.load_dataset(config)
        self.create_dataloader(config)

        self.config = config
        self.device = torch.device(config.MODEL.DEVICE)
        self.model = BEiTmBERTClassification(config.MODEL, self.vocab)
        self.model.to(self.device)

        # Init hyperparameters
        self.epoch = 0
        self.score = config.TRAINING.SCORE
        self.score_value = 0.
        self.compute_score = ScoreCalculator()
        self.learning_rate = 1.0e-6 #config.TRAINING.LEARNING_RATE
        self.patience = config.TRAINING.PATIENCE

        self.optim = Adam(self.model.parameters(), lr=self.learning_rate, betas=(0.9, 0.98))
        lambda_epoch = lambda epoch: 0.85 ** epoch
        self.scheduler = LambdaLR(self.optim, lr_lambda=lambda_epoch)
        self.loss_fn = NLLLoss(ignore_index=self.vocab.padding_idx)

    def load_vocab(self, config):
        self.vocab = ClassificationVocab(config)

    def load_dataset(self, config):
        train_df = pd.read_csv(config.DATASET.DF_PATH.TRAIN)
        X = train_df[['question', 'answer', 'img_id', 'type']]
        # Add a dummy target variable
        train_df['dummy_target'] = train_df['type']
        train_X, valid_X, train_dummy, valid_dummy = train_test_split(X, train_df['dummy_target'], test_size=0.2, random_state=42, stratify=train_df['dummy_target'])
        # Create dataframes for training and validation sets
        train_df = pd.DataFrame({'question': train_X['question'], 'answer': train_X['answer'], 'img_id': train_X['img_id'], 'type': train_X['type']})
        valid_df = pd.DataFrame({'question': valid_X['question'], 'answer': valid_X['answer'], 'img_id': valid_X['img_id'], 'type': valid_X['type']})
        train_df.reset_index(drop=True, inplace=True)
        valid_df.reset_index(drop=True, inplace=True)
        test_df = pd.read_csv(config.DATASET.DF_PATH.TEST)
        self.train_dataset = ViVQA_Dataset(train_df, config.DATASET.FEATURE_PATH.IMAGE, self.vocab)
        self.valid_dataset = ViVQA_Dataset(valid_df, config.DATASET.FEATURE_PATH.IMAGE, self.vocab)
        self.test_dataset = ViVQA_Dataset(test_df, config.DATASET.FEATURE_PATH.IMAGE, self.vocab)
        print(f'[INFO] Train size: {len(self.train_dataset)}')
        print(f'[INFO] Valid size: {len(self.valid_dataset)}')
        print(f'[INFO] Test size: {len(self.test_dataset)}')

    def create_dataloader(self, config):
        self.train_dataloader = DataLoader(self.train_dataset, batch_size=config.DATASET.BATCH_SIZE, shuffle=True, num_workers=config.DATASET.WORKERS)
        self.valid_dataloader = DataLoader(self.valid_dataset, batch_size=config.DATASET.BATCH_SIZE, shuffle=False, num_workers=config.DATASET.WORKERS)
        self.test_dataloader = DataLoader(self.test_dataset, batch_size=config.DATASET.BATCH_SIZE, shuffle=False, num_workers=config.DATASET.WORKERS)

    def load_checkpoint(self, fname) -> dict:
        if not os.path.exists(fname):
            return None
        print(f"Loading checkpoint from {fname}")
        checkpoint = torch.load(fname)
        self.score_value = checkpoint['score_value']
        self.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint

    def save_checkpoint(self) -> None:
        dict_for_saving = {
            'epoch': self.epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
            'score_value': self.score_value,
        }
        torch.save(dict_for_saving, os.path.join(self.checkpoint_path, "last_model.pth"))

    def evaluate_loss(self, dataloader: DataLoader):
        self.model.eval()
        running_loss = .0
        with tqdm(desc='Epoch %d - Validation' % self.epoch, unit='it', total=len(dataloader)) as pbar:
            with torch.no_grad():
                for it, items in enumerate(dataloader):
                    quest, img, answer = items['question'], items['image'], items['answer']
                    with torch.no_grad():
                        out = self.model(quest, img).contiguous()

                    answer = torch.stack([self.vocab.encode_answer(preprocess_sentence(ans, self.vocab.tokenizer)) for ans in answer]).to(self.device)
                    loss = self.loss_fn(out.view(-1, self.vocab.total_answers), answer.view(-1))
                    this_loss = loss.item()
                    running_loss += this_loss

                    pbar.set_postfix(loss=running_loss / (it + 1))
                    pbar.update()

        val_loss = running_loss / len(dataloader)
        return val_loss

    def evaluate_metrics(self, dataloader: DataLoader):
        self.model.eval()
        wups=0.
        em=0.
        f1=0.
        cider=0.
        with tqdm(desc='Epoch %d - Evaluation' % self.epoch, unit='it', total=len(dataloader)) as pbar:
            for it, items in enumerate(dataloader):
                quest, img, answer = items['question'], items['image'], items['answer']
                with torch.no_grad():
                    outs = self.model(quest, img).contiguous()

                answer = torch.stack([self.vocab.encode_answer(preprocess_sentence(ans, self.vocab.tokenizer)) for ans in answer]).to(self.device)
                answers_gt = self.vocab.decode_answer(answer.squeeze(-1), join_word=True)
                answers_gen = self.vocab.decode_answer(outs.argmax(dim=-1), join_word=True)

                wups+=self.compute_score.wup(answers_gt, answers_gen)
                em+=self.compute_score.em(answers_gt, answers_gen)
                f1+=self.compute_score.f1_token(answers_gt, answers_gen)
                cider+=self.compute_score.cider_score(answers_gt, answers_gen)
                pbar.update()

        scores ={
            'wups':wups/len(dataloader),
            'em':em/len(dataloader),
            'f1':f1/len(dataloader),
            'cider':cider/len(dataloader)
        }
        return scores

    def train(self):
        self.model.train()
        running_loss = .0
        with tqdm(desc='Epoch %d - Training  ' % self.epoch, unit='it', total=len(self.train_dataloader)) as pbar:
            for it, items in enumerate(self.train_dataloader):
                quest, img, answer = items['question'], items['image'], items['answer']
                out = self.model(quest, img).contiguous()
                self.optim.zero_grad()
                answer = torch.stack([self.vocab.encode_answer(preprocess_sentence(ans, self.vocab.tokenizer)) for ans in answer]).to(self.device)
                loss = self.loss_fn(out.view(-1, self.vocab.total_answers), answer.view(-1))
                loss.backward()

                self.optim.step()
                this_loss = loss.item()
                running_loss += this_loss

                pbar.set_postfix({'loss': running_loss / (it + 1), 'lr':self.scheduler.get_last_lr()[0]})
                pbar.update()
        self.scheduler.step()

        return running_loss / len(self.train_dataloader)

    def start(self):
        if os.path.isfile(os.path.join(self.checkpoint_path, "last_model.pth")):
            checkpoint = self.load_checkpoint(os.path.join(self.checkpoint_path, "last_model.pth"))
            self.epoch = checkpoint["epoch"] + 1
            print("Resuming from epoch %d" % self.epoch)
            patience = 0
        else:
            self.score_value = .0
            patience = 0

        while True:
            train_loss = self.train()

            # val scores
            scores = self.evaluate_metrics(self.valid_dataloader)
            with open(os.path.join(self.checkpoint_path, "log.txt"), "a") as f:
                f.write(f"Epoch {self.epoch:2d} - Training loss: {train_loss:.4f} - Validation wups: {scores['wups']:.4f} - Validation em: {scores['em']:.4f} - Validation f1: {scores['f1']:.4f} - Validation cider: {scores['cider']:.4f}\n")

            print(f"Validation wups: {scores['wups']:.4f} - em: {scores['em']:.4f} - f1: {scores['f1']:.4f} - cider: {scores['cider']:.4f}")
            val_score = scores[self.score]

            # Prepare for next epoch
            best = False
            if val_score > self.score_value:
                self.score_value = val_score
                patience = 0
                best = True
            else:
                patience += 1

            exit_train = False
            if patience == self.patience:
                print('Patience reached.')
                exit_train = True

            self.save_checkpoint()

            if best:
                shutil.copy(os.path.join(self.checkpoint_path, "last_model.pth"),
                        os.path.join(self.checkpoint_path, "best_model.pth"))

            if exit_train:
                break

            self.epoch += 1

    def get_predictions(self):
        if not os.path.isfile(os.path.join(self.checkpoint_path, 'best_model.pth')):
            print("Prediction require the model must be trained. There is no weights to load for model prediction!")
            raise FileNotFoundError("Make sure your checkpoint path is correct or the best_model.pth is available in your checkpoint path")

        self.load_checkpoint(os.path.join(self.checkpoint_path, "best_model.pth"))

        self.model.eval()
        img_path=[]
        quests_result=[]
        gts=[]
        preds=[]
        wups=0.
        em=0.
        f1=0.
        cider=0.
        with tqdm(desc='Getting predictions: ', unit='it', total=len(self.test_dataloader)) as pbar:
            for it, items in enumerate(self.test_dataloader):
                quest, img, answers = items['question'], items['image'], items['answer']
                with torch.no_grad():
                    outs = self.model(items['question'], items['image'])

                answers_gen = self.vocab.decode_answer(outs.argmax(dim=-1), join_word=True)

                img_path.extend(img)
                quests_result.extend(quest)
                gts.extend(answers)
                preds.extend(answers_gen)
                wups+=self.compute_score.wup(answers, answers_gen)
                em+=self.compute_score.em(answers, answers_gen)
                f1+=self.compute_score.f1_token(answers, answers_gen)
                cider+=self.compute_score.cider_score(answers, answers_gen)

                pbar.update()

        results={
            "img_path": img_path,
            "question": quests_result,
            "ground_truth":gts,
            "predict": preds,
        }

        scores ={
            'wups':wups/len(self.test_dataloader),
            'em':em/len(self.test_dataloader),
            'f1':f1/len(self.test_dataloader),
            'cider':cider/len(self.test_dataloader)
        }
        print(f"Evaluation scores on test - wups: {scores['wups']:.4f} - em: {scores['em']:.4f} - f1: {scores['f1']:.4f} - cider: {scores['cider']:.4f}")

        df = pd.DataFrame(results)
        df.to_csv(os.path.join(self.checkpoint_path,'result.csv'), index=False)
        print(f"Save result to: {os.path.join(self.checkpoint_path,'result.csv')}")

In [27]:
task = ClassificationTask(CONFIG)

Loading vocab from /content/drive/MyDrive/ViVQA-Models/beit_mbert_classification/vocab.bin
Loading data
[INFO] Train size: 9455
[INFO] Valid size: 2364
[INFO] Test size: 2987


In [28]:
task.start()

Loading checkpoint from /content/drive/MyDrive/ViVQA-Models/beit_mbert_classification/last_model.pth
Resuming from epoch 43


Epoch 43 - Training  : 100%|██████████| 296/296 [03:41<00:00,  1.33it/s, loss=0.843, lr=1e-6]
Epoch 43 - Evaluation: 100%|██████████| 74/74 [00:53<00:00,  1.37it/s]


Validation wups: 0.4558 - em: 0.4558 - f1: 0.5161 - cider: 2.0828


Epoch 44 - Training  : 100%|██████████| 296/296 [03:39<00:00,  1.35it/s, loss=0.692, lr=8.5e-7]
Epoch 44 - Evaluation: 100%|██████████| 74/74 [00:53<00:00,  1.37it/s]


Validation wups: 0.4548 - em: 0.4548 - f1: 0.5154 - cider: 2.0837


Epoch 45 - Training  : 100%|██████████| 296/296 [03:37<00:00,  1.36it/s, loss=0.754, lr=7.22e-7]
Epoch 45 - Evaluation: 100%|██████████| 74/74 [00:53<00:00,  1.37it/s]


Validation wups: 0.4579 - em: 0.4579 - f1: 0.5169 - cider: 2.0944


Epoch 46 - Training  : 100%|██████████| 296/296 [03:38<00:00,  1.36it/s, loss=0.686, lr=6.14e-7]
Epoch 46 - Evaluation: 100%|██████████| 74/74 [00:53<00:00,  1.39it/s]


Validation wups: 0.4578 - em: 0.4578 - f1: 0.5172 - cider: 2.0938


Epoch 47 - Training  : 100%|██████████| 296/296 [03:38<00:00,  1.35it/s, loss=0.686, lr=5.22e-7]
Epoch 47 - Evaluation: 100%|██████████| 74/74 [00:53<00:00,  1.37it/s]


Validation wups: 0.4565 - em: 0.4565 - f1: 0.5166 - cider: 2.0886
Patience reached.


In [29]:
task.get_predictions()

Loading checkpoint from /content/drive/MyDrive/ViVQA-Models/beit_mbert_classification/best_model.pth


Getting predictions: 100%|██████████| 94/94 [01:12<00:00,  1.29it/s]

Evaluation scores on test - wups: 0.4607 - em: 0.4607 - f1: 0.5031 - cider: 2.0601
Save result to: /content/drive/MyDrive/ViVQA-Models/beit_mbert_classification/result.csv





In [30]:
!pip install -q torchinfo

In [33]:
from torchinfo import summary
summary(task.model)

Layer (type:depth-idx)                                                      Param #
BEiTmBERTClassification                                                     --
├─BertEmbedding: 1-1                                                        --
│    └─TextBert: 2-1                                                        --
│    │    └─BertEmbeddings: 3-1                                             (81,711,360)
│    │    └─BertEncoder: 3-2                                                (85,054,464)
│    └─Linear: 2-2                                                          393,728
│    └─GELU: 2-3                                                            --
│    └─Dropout: 2-4                                                         --
├─BEiTEmbedding: 1-2                                                        --
│    └─BeitModel: 2-5                                                       --
│    │    └─BeitEmbeddings: 3-3                                             (591,360)
│    │    └─Bei