# Preamble: Install and Import Packages

In [1]:
!pip install transformers==4.40.1 git+https://github.com/csebuetnlp/normalizer multilingual-clip open_clip_torch -q

[0m

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import torch.nn.functional as F
from torchvision.transforms import Resize
from torchvision.io import read_image, ImageReadMode
from multilingual_clip import Config_MCLIP
import open_clip
import json
import pandas as pd
import random
from pathlib import Path
import cv2
import numpy as np
import transformers
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoImageProcessor, AutoModelForMaskedLM
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, classification_report
from PIL import Image
import os
import gc
import time
import math
from normalizer import normalize

2024-06-14 13:00:23.798003: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-14 13:00:23.798120: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-14 13:00:24.074668: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
torch.autograd.set_detect_anomaly(True)
transformers.__version__

'4.40.1'

# Initialise the Configuration and Random Seeds

In [4]:
class CFG:
    save_models = True
    init_weights = True
    tokeniser_model_id = 'csebuetnlp/banglabert'
    text_model_id = 'csebuetnlp/banglabert'
    image_model_id = 'google/vit-base-patch16-224-in21k'
    tokeniser_model = None
    text_model = None
    image_model = None
    text_model_config = None
    image_model_config = None
    lang = 'bn'
    
    images_base_path = Path('/kaggle/input/vqa-bangla/Bangla_VQA/Bangla_VQA/images')
    images_base_path_test = Path('/kaggle/input/vqa-bangla/Bangla_VQA/Bangla_VQA/images')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    debug = False
    print_freq = 50
    apex = True # for faster training
    epochs = 15
    learning_rate = 2e-5  # for adam optimizer
    eps = 1e-6
    betas = (0.9, 0.999)  # for adam optimizer
    batch_size = 64
    max_len = 512
    weight_decay = 0.01  # for adam optimizer regulaization parameter
    gradient_accumulation_steps = 1
    max_grad_norm = 1000
    seed = 42
    train = True
    num_classes = 0
    
    frozen_lm = False # True or False 
    fusion_mode = "co_attention" # "merged_attention", "co_attention"
    no_fusion_encoder = 2
    num_heads = 4
    
    mlp_hidden_size = 256
    mlp_hidden_layers = 0
    mlp_dropout = 0.1
    mlp_grad_clip = 1.0
    mlp_init_range = 0.2
    mlp_attn_dim = 256

In [5]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
seed_everything(CFG.seed)

# Preprocess the Dataset

In [6]:
def normalise_bn(text_bn):
    return normalize(
        text_bn,
        unicode_norm="NFKC",
        punct_replacement=None,
        url_replacement=None,
        emoji_replacement=None,
        apply_unicode_norm_last=True
    )

In [7]:
train_df = pd.read_csv("/kaggle/input/vqa-bangla/updated_train.csv")
val_df = pd.read_csv("/kaggle/input/vqa-bangla/updated_valid.csv")
test_df = pd.read_csv("/kaggle/input/vqa-bangla/updated_test.csv")

df = pd.concat([train_df, val_df, test_df], ignore_index=True)

df.head()

Unnamed: 0,image_name,Captions,Question,Answer,Category,Question_en,Answer_en,Captions_en,Answer_fixed
0,bnature_663.jpg,খালের পানিতে তিনটি গাছের প্রতিচ্ছবি সাথে গৌধোল...,ছবিতে কতগুলো গাছের প্রতিচ্ছবি দেখা যাচ্ছে?,তিনটি,numeric,How many trees are reflected in the picture?,three,Goudholi's beauty with three trees reflected i...,তিন
1,chitron_5113.png,অনেকগুলো মানুষ বসে আছে। মঞ্চের উপর কয়েকজন মানু...,ছবিতে কতজন মানুষ মঞ্চের উপর দাঁড়িয়ে আছে?,পাঁচজন,numeric,How many people are on the stage?,five,"A lot of people were sitting, a few people wer...",পাঁচ
2,bnature_876.jpg,দুজন ছেলে ও দুজন মেয়ে রাস্তা দিয়ে পাশাপাশি হ...,ছবিতে কতজন ছেলে ও মেয়ে একসাথে হাটছে?,চারজন,numeric,How many boys and girls are walking together i...,four,Two boys and two girls walking side by side on...,চার
3,bnature_1007.jpg,"রাস্তা দিয়ে কয়েকজন ছাত্র ছাত্রী যাচ্ছে, যাদে...",ছবিতে কতজন ছাত্র ছাত্রী রাস্তা দিয়ে হাঁটছে?,৪ জন,numeric,How many students are walking on the street in...,four,"Several students walking on the street, carryi...",চার
4,chitron_7446.png,'১ ইট তালগাছ ১ টি খেজুর গাছ এবং রাস্তা দিয়ে ছা...,ছবিতে কতগুলো গাছ দেখা যাচ্ছে?,২ টি,numeric,How many trees are shown in the picture?,Two,1 brick palm tree 1 date tree and 4 school stu...,দুই


In [8]:
all_labels = list(set(df['Answer_fixed' if CFG.lang == 'bn' else 'Answer_en'].unique().astype(str)))
all_labels.sort()
label_map = dict()
CFG.num_classes = len(all_labels)
for idx, label in enumerate(all_labels):
    label_map[normalise_bn(str(label)) if CFG.lang == 'bn' else str(label)] = idx

# Initialise the Processors/Tokenisers/Models

In [9]:
CFG.text_model_config = AutoConfig.from_pretrained(CFG.text_model_id) if not 'M-CLIP' in CFG.text_model_id else None
CFG.image_model_config = AutoConfig.from_pretrained(CFG.image_model_id) if not 'M-CLIP' in CFG.text_model_id else None



config.json:   0%|          | 0.00/586 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

In [10]:
CFG.tokenizer = AutoTokenizer.from_pretrained(CFG.tokeniser_model_id, padding=True, truncation=True)
CFG.processor = AutoImageProcessor.from_pretrained(CFG.image_model_id)
CFG.text_model = AutoModelForMaskedLM.from_pretrained(CFG.text_model_id).to(CFG.device)
CFG.text_model_vanilla = AutoModelForMaskedLM.from_pretrained(CFG.text_model_id).to(CFG.device)
CFG.image_model = AutoModel.from_pretrained(CFG.image_model_id).to(CFG.device)

tokenizer_config.json:   0%|          | 0.00/119 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/528k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/443M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()
Some weights of ElectraForMaskedLM were not initialized from the model checkpoint at csebuetnlp/banglabert and are newly initialized: ['generator_lm_head.bias', 'generator_predictions.LayerNorm.bias', 'generator_predictions.LayerNorm.weight', 'generator_predictions.dense.bias', 'generator_predictions.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ElectraForMaskedLM were not initialized from the model checkpoint at csebuetnlp/banglabert and are newly initialized: ['generator_lm_head.bias', 'generator_predictions.LayerNorm.bias', 'generator_predictions.LayerNorm.weight', 'generator_predictions.dense.bias', 'generator_predictions.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

# Custom Dataset Definition

In [11]:
class VQADataset(Dataset):
    def __init__(self, features, img_dir, img_transform=None, caption_transform=None, target_transform=None):
        self.features = features
        self.img_dir = img_dir
        self.img_transform = img_transform
        self.caption_transform = caption_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        img_path = str(self.img_dir.joinpath(self.features['image_name'].iloc[idx]))
        image = read_image(img_path, mode=ImageReadMode.RGB).to(device=CFG.device)
        caption = normalise_bn(self.features['Question' if CFG.lang == 'bn' else 'Question_en'].iloc[idx])
        identity = self.features['image_name'].iloc[idx]
        label = torch.tensor(label_map[normalise_bn(str(self.features['Answer_fixed'].iloc[idx])) if CFG.lang == 'bn' else str(self.features['Answer_en'].iloc[idx])], dtype=torch.long)
        
        if self.img_transform:
            image = self.img_transform(image)
        if self.caption_transform:
            caption = self.caption_transform(caption)
        if self.target_transform:
            label = self.target_transform(label)
            
        processed_img = CFG.processor(images=image, return_tensors="pt")
        image = processed_img['pixel_values']
        
        processed_txt = CFG.tokenizer.encode_plus(
            caption,
            padding='longest',
            truncation=True,
            return_tensors='pt'
        )
        seq = processed_txt['input_ids']
        mask = processed_txt['attention_mask']
        
        return identity, image, seq, mask, label

In [12]:
class Collator(object):
    def __init__(self, test=False):
        self.test = test
    def __call__(self, batch):
        ids, images, seqs, masks, labels = zip(*batch)

        seqs = [seq.squeeze(dim=0) for seq in seqs]
        masks = [mask.squeeze(dim=0) for mask in masks]
        images = [image.squeeze(dim=0) for image in images]
        labels = torch.stack(labels)

        seqs = nn.utils.rnn.pad_sequence(seqs, batch_first=True)
        masks = nn.utils.rnn.pad_sequence(masks, batch_first=True)

        images = torch.stack(images)
        
        return ids, images, seqs, masks, labels

In [13]:
resizer = Resize((224, 224), antialias=True)

def resize_images(img_tensor):
    return resizer(img_tensor)

# Dataset Initialisation

In [14]:
train_dataset = VQADataset(train_df, CFG.images_base_path, img_transform=resize_images)
val_dataset = VQADataset(val_df, CFG.images_base_path, img_transform=resize_images)
test_dataset = VQADataset(test_df, CFG.images_base_path, img_transform=resize_images)

print(len(train_dataset))
print(len(val_dataset))
print(len(test_dataset))

12231
1529
1532


# Model Architecture

In [15]:
class ITMHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.head = nn.Linear(hidden_size, 2, dtype=torch.float16)
        
    def forward(self, embedding):
        x = self.head(embedding)
        
        return x

In [16]:
class CrossAttention(nn.Module):
    def __init__(self, input_dim):
        super(CrossAttention, self).__init__()
        self.query_linear = nn.Linear(input_dim, input_dim)
        self.key_linear = nn.Linear(input_dim, input_dim)
        self.value_linear = nn.Linear(input_dim, input_dim)

    def forward(self, query, key, value):
        query_proj = self.query_linear(query)
        key_proj = self.key_linear(key)
        value_proj = self.value_linear(value)

        # Calculate attention scores
        scores = torch.matmul(query_proj, key_proj.transpose(-2, -1))
        attention_weights = F.softmax(scores, dim=-1)

        # Apply attention weights to values
        output = torch.matmul(attention_weights, value_proj)
        return output
    

class MultiHeadCrossAttention(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(MultiHeadCrossAttention, self).__init__()
        assert input_dim % num_heads == 0, "input_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        
        # Linear transformations for queries, keys, and values for each head
        self.query_linear = nn.Linear(input_dim, input_dim)
        self.key_linear = nn.Linear(input_dim, input_dim)
        self.value_linear = nn.Linear(input_dim, input_dim)
        
        # Final linear transformation after concatenating heads
        self.out_linear = nn.Linear(input_dim, input_dim)

    def forward(self, query, key, value):
        batch_size = query.size(0)
        
        # Apply linear transformations for queries, keys, and values for each head
        query = self.query_linear(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        key = self.key_linear(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        value = self.value_linear(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        
        # Calculate attention scores
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)

        # Apply attention weights to values
        output = torch.matmul(attention_weights, value)  # (batch_size, num_heads, seq_len, head_dim)
        
        # Concatenate heads and perform final linear transformation
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)  # (batch_size, seq_len, input_dim)
        output = self.out_linear(output)
        return output
    
    
class CrossAttentionLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, dropout):
        super(CrossAttentionLayer, self).__init__()
        self.self_attention = nn.MultiheadAttention(input_dim, num_heads)
        self.cross_attention = MultiHeadCrossAttention(input_dim, num_heads)
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory):
        # Self-attention
        x, _ = self.self_attention(x, x, x)

        # Cross-attention
        x = self.cross_attention(x, memory, memory)

        x = self.linear2(F.relu(self.linear1(x)))
        x = self.dropout(x)
        return x

class CrossAttentionEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_heads, dropout = 0.05):
        super(CrossAttentionEncoder, self).__init__()
        self.layers = nn.ModuleList([CrossAttentionLayer(input_dim, hidden_dim, num_heads, dropout) 
                                     for _ in range(num_layers)])

    def forward(self, x, memory):
        for layer in self.layers:
            x = layer(x, memory)
        return x

In [17]:
class MeanPooling(nn.Module):
    def __init__(self):
        super(MeanPooling, self).__init__()
        
    def forward(self, last_hidden_state, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings

In [18]:
class BanCAP_Pretraining(torch.nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.itm = ITMHead(self.cfg.text_model_config.hidden_size)
        self.text_model = self.cfg.text_model
        self.image_model = self.cfg.image_model
        
        self.hidden_size = 768
        self.v_proj = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.l_proj = torch.nn.Linear(self.hidden_size, self.hidden_size)
        
        if self.cfg.fusion_mode == "merged_attention":
            self.fusion_encoder = nn.TransformerEncoderLayer(d_model=self.hidden_size, 
                                                        nhead=self.cfg.num_heads)
            self.fusion_layers = nn.TransformerEncoder(self.fusion_encoder, 
                                                  num_layers=self.cfg.no_fusion_encoder)
        
        elif self.cfg.fusion_mode == "co_attention":
            self.pool = MeanPooling()
            self.visual_cross_encoder = CrossAttentionEncoder(self.hidden_size, self.hidden_size,
                                                        self.cfg.no_fusion_encoder,
                                                        self.cfg.num_heads)
            
            self.language_cross_encoder = CrossAttentionEncoder(self.hidden_size, self.hidden_size,
                                                        self.cfg.no_fusion_encoder,
                                                        self.cfg.num_heads)
        
        
        self.text_mlp = nn.Linear(self.hidden_size, self.hidden_size)
        self.vision_mlp = nn.Linear(self.hidden_size, self.hidden_size)
        
        if self.cfg.frozen_lm:
            self._frozen_lm()
            
    def _frozen_lm(self):
        for param in self.text_model.parameters():
            param.requires_grad = False
            
        for param in self.image_model.parameters():
            param.requires_grad = False
            

    def forward(self, input_ids, attention_mask, image):
        
        text_output = self.text_model(input_ids=input_ids, 
                                      attention_mask=attention_mask,
                                      output_hidden_states=True)
        image_output = self.image_model(image)

        visual_features = image_output.last_hidden_state
        text_features = text_output.hidden_states[-1]
        
        visual_embeds = visual_features[:, 0, :]
        text_embeds = text_features[:, 0, :]
        
        visual_poj = self.v_proj(visual_features)
        language_proj = self.l_proj(text_features)
  
        if self.cfg.fusion_mode == "co_attention":
            vision_final = self.visual_cross_encoder(visual_poj, language_proj)
            language_final = self.language_cross_encoder(language_proj, visual_poj)
            
            vision_final = torch.mean(vision_final, dim = 1)
            language_final = self.pool(language_final, attention_mask)
            
            vision_final = self.vision_mlp(vision_final)
            language_final = self.text_mlp(language_final)
            
        elif self.cfg.fusion_mode == "merged_attention":
            merged_embed = torch.cat((language_proj, visual_poj), dim = 1)
            merged_attention_features = self.fusion_layers(merged_embed)
            
            text_embed_len = attention_mask.size(1)
            language_final = merged_attention_features[:, :text_embed_len, : ]
            vision_final = merged_attention_features[:, text_embed_len:, : ]
            
            vision_final = self.vision_mlp(vision_final)
            language_final = self.text_mlp(language_final)
        
        return language_final, vision_final

In [19]:
class BanCAP_Pretraining_Classifier(torch.nn.Module):
    def __init__(self, backbone_model, cfg):
        super().__init__()
        self.backbone_model = backbone_model
        self.hidden_size = 768
        self.cfg = cfg
        self.classification_head = nn.Linear(2 * self.hidden_size, self.cfg.num_classes)
        
    def forward(self, input_ids, attention_mask, image):
        language_final, vision_final = self.backbone_model(input_ids, attention_mask, image)
        mm_final = torch.cat((language_final, vision_final), dim=1)
        output = self.classification_head(mm_final)
        
        return output

# Utility Functions

In [20]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))

In [21]:
def get_score(y_trues, y_preds):
    accuracy = accuracy_score(y_trues, y_preds)
    return accuracy

In [22]:
class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super(NpEncoder, self).default(obj)

In [23]:
# ====================================================
# Logger File
# ====================================================

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report

OUTPUT_DIR = "./"
def get_logger(filename=OUTPUT_DIR+'train'):
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = get_logger()

# Train/Val Loops

In [24]:
def train_loop(model, optimizer, loss_fn, train_dataloader, epoch):
    model.train()
    train_losses = AverageMeter()
    start = end = time.time()
    
    for step, (_, image, seq, mask, label) in enumerate(tqdm(train_dataloader)):
        train_image = image.to(CFG.device)
        train_seq = seq.to(CFG.device)
        train_mask = mask.to(CFG.device)
        train_label = label.to(device=CFG.device)
        
        batch_size = train_image.shape[0]

        with torch.cuda.amp.autocast(enabled=CFG.apex):
            output = model(train_seq, train_mask, train_image)
            
        loss = loss_fn(output, train_label)
        
        optimizer.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), CFG.mlp_grad_clip)
        optimizer.step()
        
        train_losses.update(loss.item(), batch_size)
            
        end = time.time()

        if step % CFG.print_freq == 0 or step == (len(train_dataloader) - 1):
            print(f'Epoch: [{epoch + 1}][{step}/{len(train_dataloader)}] '
                  f'Elapsed {timeSince(start, float(step + 1) / len(train_dataloader)):s} '
                  f'Loss: {train_losses.val:.4f} ({train_losses.avg:.4f}) ')
        
        if step % 100 == 0:
            torch.cuda.empty_cache()
            gc.collect()

    return train_losses.avg

In [25]:
def validation_loop(model, loss_fn, valid_dataloader, epoch):
    all_ids = []
    all_preds = []
    all_labels = []
    
    model.eval()
    validation_losses = AverageMeter()
    start = end = time.time()
    global_step = 0
    
    for step, (identity, image, seq, mask, label) in enumerate(tqdm(valid_dataloader)):
        image = image.to(device=CFG.device)
        seq = seq.to(device=CFG.device)
        mask = mask.to(device=CFG.device)
        label = label.to(device=CFG.device)
        
        batch_size = image.shape[0]

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=CFG.apex):
                output = model(seq, mask, image)

        loss = loss_fn(output, label)
        
        validation_losses.update(loss.item(), batch_size)
        predicted = output.argmax(dim=1)

        all_ids += list(identity)
        all_labels.append(label)
        all_preds.append(predicted)
            
        end = time.time()

        if step % CFG.print_freq == 0 or step == (len(valid_dataloader) - 1):
            print(f'Epoch: [{epoch + 1}][{step}/{len(valid_dataloader)}] '
                  f'Elapsed {timeSince(start, float(step + 1) / len(valid_dataloader)):s} '
                  f'Loss: {validation_losses.val:.4f} ({validation_losses.avg:.4f})')
        
        if step % 100 == 0:
            torch.cuda.empty_cache()
            gc.collect()
            
    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    
    all_preds_np = all_preds.cpu().numpy().astype(int)
    all_labels_np = all_labels.cpu().numpy().astype(int)
        
    return validation_losses.avg, all_ids, all_preds_np, all_labels_np

In [26]:
cos = nn.CosineSimilarity(dim=1, eps=1e-6)

def bert_scorer(df):
    sim_list = []
    
    for index, row in df.iterrows():
        original = CFG.tokenizer.encode_plus(normalise_bn(str(row['label'])), return_tensors="pt").to(CFG.device)
        preds = CFG.tokenizer.encode_plus(normalise_bn(str(row['pred'])), return_tensors="pt").to(CFG.device)
        
        with torch.no_grad():
            d1 = CFG.text_model_vanilla(original['input_ids'], attention_mask=original['attention_mask'], output_hidden_states=True).hidden_states[-1][:, 0, :]
            d2 = CFG.text_model_vanilla(preds['input_ids'], attention_mask=preds['attention_mask'], output_hidden_states=True).hidden_states[-1][:, 0, :]
        
        sim_list.append(cos(d1, d2).item())
        
    return sim_list

In [27]:
def save_predictions(model_name, ids, labels, preds, split='val'):
    entries = []
    for identity, label, pred in zip(ids, labels, preds):
        entry = {
            'identity': identity,
            'label': all_labels[label],
            'pred': all_labels[pred]
        }
        entries.append(entry)

    with open(f'/kaggle/working/{model_name}_{split}_preds.json', 'w') as fp:
        json.dump(entries, fp, cls=NpEncoder)
        
    preds_df = pd.DataFrame.from_dict(entries)
    similarity_list = bert_scorer(preds_df)
    
    report_dict = classification_report(labels, preds, digits=4, zero_division=0, output_dict=True)
    report_dict['bert_score'] = {
        'model': CFG.text_model_config._name_or_path,
        'bert_score_mean': sum(similarity_list) / len(similarity_list)
    }
    
    with open(f'/kaggle/working/{model_name}_{split}_results.json', 'w') as fp:
        json.dump(report_dict, fp)

# Training and Validation

In [28]:
collate = Collator()
train_dataloader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, collate_fn=collate)
valid_dataloader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False, collate_fn=collate)
test_dataloader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False, collate_fn=collate)

loss_fn = nn.CrossEntropyLoss()

In [29]:
model_id = 'co_attention__itm+mlm+ucl_loss1.0739_best'
backbone_model = BanCAP_Pretraining(CFG).to(CFG.device)
backbone_model = nn.DataParallel(backbone_model)
backbone_model.load_state_dict(torch.load(f'/kaggle/input/bancap-pretraining-for-banvqa/{model_id}.pth', map_location=torch.device(CFG.device))['model'])

model = BanCAP_Pretraining_Classifier(backbone_model, CFG).to(CFG.device)
optim = AdamW(model.parameters(), lr=CFG.learning_rate, eps=CFG.eps, betas=CFG.betas)

In [30]:
best_score = float('-inf')

for epoch in range(CFG.epochs):

    start_time = time.time()

    # train function 
    avg_train_loss = train_loop(model, optim, loss_fn, train_dataloader, epoch)

    # val function 
    avg_val_loss, all_ids, all_preds_np, all_labels_np = validation_loop(model, loss_fn, valid_dataloader, epoch)
    
    score = get_score(all_labels_np, all_preds_np)

    elapsed = time.time() - start_time

    LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_train_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')

    if CFG.save_models and score > best_score:
        model_name = model_id + f'_score_{score:.4f}' + f'_{CFG.lang}'
        torch.save({'model': model.state_dict()}, f'{model_name}.pth')
        print(f'Saved model: {model_name}')
        best_score = score
        
        save_predictions(model_name, all_ids, all_labels_np, all_preds_np)
        
        avg_test_loss, all_ids_test, all_preds_np_test, all_labels_np_test = validation_loop(model, loss_fn, test_dataloader, epoch)
        save_predictions(model_name, all_ids_test, all_labels_np_test, all_preds_np_test, split='test')

    torch.cuda.empty_cache()
    gc.collect()

  0%|          | 0/192 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Epoch: [1][0/192] Elapsed 0m 10s (remain 33m 42s) Loss: 8.6484 (8.6484) 
Epoch: [1][50/192] Elapsed 4m 19s (remain 11m 56s) Loss: 7.3945 (7.9907) 
Epoch: [1][100/192] Elapsed 8m 18s (remain 7m 29s) Loss: 7.4961 (7.6870) 
Epoch: [1][150/192] Elapsed 12m 29s (remain 3m 23s) Loss: 6.9844 (7.4746) 
Epoch: [1][191/192] Elapsed 15m 53s (remain 0m 0s) Loss: 6.7422 (7.3285) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [1][0/24] Elapsed 0m 4s (remain 1m 34s) Loss: 2.8223 (2.8223)


Epoch 1 - avg_train_loss: 7.3285  avg_val_loss: 6.8877  time: 1030s


Epoch: [1][23/24] Elapsed 1m 16s (remain 0m 0s) Loss: 8.4766 (6.8877)
Saved model: co_attention__itm+mlm+ucl_loss1.0739_best_score_0.0687_bn


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [1][0/24] Elapsed 0m 3s (remain 1m 9s) Loss: 2.6797 (2.6797)
Epoch: [1][23/24] Elapsed 1m 12s (remain 0m 0s) Loss: 8.3750 (6.8200)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [2][0/192] Elapsed 0m 3s (remain 10m 52s) Loss: 6.6719 (6.6719) 
Epoch: [2][50/192] Elapsed 2m 52s (remain 7m 57s) Loss: 6.2461 (6.5982) 
Epoch: [2][100/192] Elapsed 5m 39s (remain 5m 6s) Loss: 7.1406 (6.5914) 
Epoch: [2][150/192] Elapsed 8m 28s (remain 2m 18s) Loss: 6.6367 (6.5654) 
Epoch: [2][191/192] Elapsed 10m 45s (remain 0m 0s) Loss: 5.0000 (6.5406) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [2][0/24] Elapsed 0m 1s (remain 0m 34s) Loss: 2.6348 (2.6348)


Epoch 2 - avg_train_loss: 6.5406  avg_val_loss: 6.9518  time: 679s


Epoch: [2][23/24] Elapsed 0m 33s (remain 0m 0s) Loss: 8.6797 (6.9518)
Saved model: co_attention__itm+mlm+ucl_loss1.0739_best_score_0.1020_bn


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [2][0/24] Elapsed 0m 1s (remain 0m 34s) Loss: 2.5469 (2.5469)
Epoch: [2][23/24] Elapsed 0m 33s (remain 0m 0s) Loss: 8.3516 (6.8632)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [3][0/192] Elapsed 0m 3s (remain 11m 38s) Loss: 6.4883 (6.4883) 
Epoch: [3][50/192] Elapsed 2m 49s (remain 7m 49s) Loss: 5.6016 (6.2758) 
Epoch: [3][100/192] Elapsed 5m 36s (remain 5m 2s) Loss: 5.9805 (6.3183) 
Epoch: [3][150/192] Elapsed 8m 22s (remain 2m 16s) Loss: 6.8398 (6.3536) 
Epoch: [3][191/192] Elapsed 10m 37s (remain 0m 0s) Loss: 5.8203 (6.3487) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [3][0/24] Elapsed 0m 1s (remain 0m 32s) Loss: 2.4258 (2.4258)


Epoch 3 - avg_train_loss: 6.3487  avg_val_loss: 6.9463  time: 670s


Epoch: [3][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 8.7188 (6.9463)
Saved model: co_attention__itm+mlm+ucl_loss1.0739_best_score_0.1053_bn


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [3][0/24] Elapsed 0m 1s (remain 0m 33s) Loss: 2.5645 (2.5645)
Epoch: [3][23/24] Elapsed 0m 33s (remain 0m 0s) Loss: 8.3594 (6.8848)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [4][0/192] Elapsed 0m 3s (remain 10m 36s) Loss: 6.0625 (6.0625) 
Epoch: [4][50/192] Elapsed 2m 48s (remain 7m 45s) Loss: 5.9414 (6.2238) 
Epoch: [4][100/192] Elapsed 5m 31s (remain 4m 59s) Loss: 5.5547 (6.1817) 
Epoch: [4][150/192] Elapsed 8m 18s (remain 2m 15s) Loss: 6.4648 (6.1858) 
Epoch: [4][191/192] Elapsed 10m 32s (remain 0m 0s) Loss: 5.0586 (6.1722) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [4][0/24] Elapsed 0m 1s (remain 0m 32s) Loss: 2.4160 (2.4160)


Epoch 4 - avg_train_loss: 6.1722  avg_val_loss: 6.9645  time: 665s


Epoch: [4][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 8.6641 (6.9645)
Saved model: co_attention__itm+mlm+ucl_loss1.0739_best_score_0.1092_bn


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [4][0/24] Elapsed 0m 1s (remain 0m 34s) Loss: 2.6016 (2.6016)
Epoch: [4][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 8.2031 (6.9068)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [5][0/192] Elapsed 0m 3s (remain 10m 40s) Loss: 5.8477 (5.8477) 
Epoch: [5][50/192] Elapsed 2m 48s (remain 7m 46s) Loss: 5.7031 (6.0592) 
Epoch: [5][100/192] Elapsed 5m 33s (remain 5m 0s) Loss: 5.9180 (6.0148) 
Epoch: [5][150/192] Elapsed 8m 18s (remain 2m 15s) Loss: 6.1016 (6.0001) 
Epoch: [5][191/192] Elapsed 10m 32s (remain 0m 0s) Loss: 4.2891 (5.9552) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [5][0/24] Elapsed 0m 1s (remain 0m 32s) Loss: 2.5859 (2.5859)


Epoch 5 - avg_train_loss: 5.9552  avg_val_loss: 7.0540  time: 665s


Epoch: [5][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 8.7422 (7.0540)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [6][0/192] Elapsed 0m 3s (remain 10m 31s) Loss: 5.8867 (5.8867) 
Epoch: [6][50/192] Elapsed 2m 48s (remain 7m 46s) Loss: 5.4883 (5.7461) 
Epoch: [6][100/192] Elapsed 5m 32s (remain 5m 0s) Loss: 5.1758 (5.7445) 
Epoch: [6][150/192] Elapsed 8m 19s (remain 2m 15s) Loss: 5.5820 (5.7223) 
Epoch: [6][191/192] Elapsed 10m 33s (remain 0m 0s) Loss: 3.0742 (5.7263) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [6][0/24] Elapsed 0m 1s (remain 0m 32s) Loss: 2.3887 (2.3887)


Epoch 6 - avg_train_loss: 5.7263  avg_val_loss: 6.9015  time: 666s


Epoch: [6][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 8.4141 (6.9015)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [7][0/192] Elapsed 0m 3s (remain 10m 25s) Loss: 6.0508 (6.0508) 
Epoch: [7][50/192] Elapsed 2m 48s (remain 7m 44s) Loss: 5.3359 (5.4519) 
Epoch: [7][100/192] Elapsed 5m 33s (remain 5m 0s) Loss: 5.4766 (5.4467) 
Epoch: [7][150/192] Elapsed 8m 18s (remain 2m 15s) Loss: 5.9180 (5.4822) 
Epoch: [7][191/192] Elapsed 10m 32s (remain 0m 0s) Loss: 5.7383 (5.4994) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [7][0/24] Elapsed 0m 1s (remain 0m 32s) Loss: 2.5879 (2.5879)


Epoch 7 - avg_train_loss: 5.4994  avg_val_loss: 7.0655  time: 666s


Epoch: [7][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 8.7500 (7.0655)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [8][0/192] Elapsed 0m 3s (remain 10m 44s) Loss: 4.7812 (4.7812) 
Epoch: [8][50/192] Elapsed 2m 48s (remain 7m 46s) Loss: 5.5859 (5.2739) 
Epoch: [8][100/192] Elapsed 5m 34s (remain 5m 0s) Loss: 4.5039 (5.2816) 
Epoch: [8][150/192] Elapsed 8m 20s (remain 2m 15s) Loss: 5.3359 (5.2631) 
Epoch: [8][191/192] Elapsed 10m 34s (remain 0m 0s) Loss: 6.2734 (5.2760) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [8][0/24] Elapsed 0m 1s (remain 0m 33s) Loss: 2.4453 (2.4453)


Epoch 8 - avg_train_loss: 5.2760  avg_val_loss: 7.2027  time: 667s


Epoch: [8][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 8.9297 (7.2027)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [9][0/192] Elapsed 0m 3s (remain 11m 40s) Loss: 5.2148 (5.2148) 
Epoch: [9][50/192] Elapsed 2m 49s (remain 7m 49s) Loss: 5.3359 (4.9557) 
Epoch: [9][100/192] Elapsed 5m 34s (remain 5m 1s) Loss: 5.1250 (5.0191) 
Epoch: [9][150/192] Elapsed 8m 20s (remain 2m 15s) Loss: 4.9102 (5.0580) 
Epoch: [9][191/192] Elapsed 10m 35s (remain 0m 0s) Loss: 5.6133 (5.0784) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [9][0/24] Elapsed 0m 1s (remain 0m 32s) Loss: 2.4668 (2.4668)


Epoch 9 - avg_train_loss: 5.0784  avg_val_loss: 7.2777  time: 668s


Epoch: [9][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 9.0469 (7.2777)
Saved model: co_attention__itm+mlm+ucl_loss1.0739_best_score_0.1138_bn


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [9][0/24] Elapsed 0m 1s (remain 0m 33s) Loss: 2.5137 (2.5137)
Epoch: [9][23/24] Elapsed 0m 33s (remain 0m 0s) Loss: 8.3125 (7.0672)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [10][0/192] Elapsed 0m 3s (remain 11m 30s) Loss: 4.9141 (4.9141) 
Epoch: [10][50/192] Elapsed 2m 49s (remain 7m 49s) Loss: 5.1523 (4.8329) 
Epoch: [10][100/192] Elapsed 5m 35s (remain 5m 2s) Loss: 5.0000 (4.8625) 
Epoch: [10][150/192] Elapsed 8m 21s (remain 2m 16s) Loss: 4.8047 (4.8970) 
Epoch: [10][191/192] Elapsed 10m 35s (remain 0m 0s) Loss: 5.7188 (4.9124) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [10][0/24] Elapsed 0m 1s (remain 0m 33s) Loss: 2.5938 (2.5938)


Epoch 10 - avg_train_loss: 4.9124  avg_val_loss: 7.7016  time: 668s


Epoch: [10][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 9.6484 (7.7016)
Saved model: co_attention__itm+mlm+ucl_loss1.0739_best_score_0.1158_bn


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [10][0/24] Elapsed 0m 1s (remain 0m 34s) Loss: 2.8086 (2.8086)
Epoch: [10][23/24] Elapsed 0m 34s (remain 0m 0s) Loss: 8.8438 (7.4364)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [11][0/192] Elapsed 0m 3s (remain 11m 36s) Loss: 4.5039 (4.5039) 
Epoch: [11][50/192] Elapsed 2m 50s (remain 7m 50s) Loss: 4.8438 (4.5875) 
Epoch: [11][100/192] Elapsed 5m 37s (remain 5m 4s) Loss: 4.2656 (4.6625) 
Epoch: [11][150/192] Elapsed 8m 24s (remain 2m 16s) Loss: 4.6406 (4.7277) 
Epoch: [11][191/192] Elapsed 10m 38s (remain 0m 0s) Loss: 4.1523 (4.7575) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [11][0/24] Elapsed 0m 1s (remain 0m 33s) Loss: 2.6836 (2.6836)


Epoch 11 - avg_train_loss: 4.7575  avg_val_loss: 7.9234  time: 671s


Epoch: [11][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 9.8906 (7.9234)
Saved model: co_attention__itm+mlm+ucl_loss1.0739_best_score_0.1203_bn


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [11][0/24] Elapsed 0m 1s (remain 0m 34s) Loss: 2.6250 (2.6250)
Epoch: [11][23/24] Elapsed 0m 34s (remain 0m 0s) Loss: 9.0625 (7.5999)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [12][0/192] Elapsed 0m 3s (remain 10m 27s) Loss: 4.3359 (4.3359) 
Epoch: [12][50/192] Elapsed 2m 47s (remain 7m 43s) Loss: 4.1719 (4.4054) 
Epoch: [12][100/192] Elapsed 5m 35s (remain 5m 2s) Loss: 4.6719 (4.5015) 
Epoch: [12][150/192] Elapsed 8m 23s (remain 2m 16s) Loss: 4.8945 (4.5704) 
Epoch: [12][191/192] Elapsed 10m 39s (remain 0m 0s) Loss: 5.1875 (4.6128) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [12][0/24] Elapsed 0m 1s (remain 0m 33s) Loss: 2.8359 (2.8359)


Epoch 12 - avg_train_loss: 4.6128  avg_val_loss: 8.1598  time: 672s


Epoch: [12][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 10.3125 (8.1598)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [13][0/192] Elapsed 0m 3s (remain 11m 11s) Loss: 4.5586 (4.5586) 
Epoch: [13][50/192] Elapsed 2m 48s (remain 7m 45s) Loss: 3.9434 (4.3560) 
Epoch: [13][100/192] Elapsed 5m 34s (remain 5m 1s) Loss: 4.9531 (4.4028) 
Epoch: [13][150/192] Elapsed 8m 20s (remain 2m 15s) Loss: 4.5859 (4.4448) 
Epoch: [13][191/192] Elapsed 10m 36s (remain 0m 0s) Loss: 4.9727 (4.4867) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [13][0/24] Elapsed 0m 1s (remain 0m 32s) Loss: 2.6211 (2.6211)


Epoch 13 - avg_train_loss: 4.4867  avg_val_loss: 8.5974  time: 669s


Epoch: [13][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 10.8672 (8.5974)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [14][0/192] Elapsed 0m 3s (remain 10m 35s) Loss: 3.8594 (3.8594) 
Epoch: [14][50/192] Elapsed 2m 48s (remain 7m 44s) Loss: 4.1289 (4.1809) 
Epoch: [14][100/192] Elapsed 5m 32s (remain 4m 59s) Loss: 4.1445 (4.2461) 
Epoch: [14][150/192] Elapsed 8m 19s (remain 2m 15s) Loss: 4.0547 (4.2941) 
Epoch: [14][191/192] Elapsed 10m 34s (remain 0m 0s) Loss: 6.2500 (4.3454) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [14][0/24] Elapsed 0m 1s (remain 0m 33s) Loss: 2.7559 (2.7559)


Epoch 14 - avg_train_loss: 4.3454  avg_val_loss: 8.8120  time: 667s


Epoch: [14][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 11.0625 (8.8120)


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch: [15][0/192] Elapsed 0m 3s (remain 10m 22s) Loss: 4.0469 (4.0469) 
Epoch: [15][50/192] Elapsed 2m 48s (remain 7m 45s) Loss: 4.4180 (4.0951) 
Epoch: [15][100/192] Elapsed 5m 34s (remain 5m 1s) Loss: 4.3242 (4.1397) 
Epoch: [15][150/192] Elapsed 8m 22s (remain 2m 16s) Loss: 4.7656 (4.1948) 
Epoch: [15][191/192] Elapsed 10m 36s (remain 0m 0s) Loss: 2.9062 (4.2317) 


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [15][0/24] Elapsed 0m 1s (remain 0m 33s) Loss: 2.6562 (2.6562)


Epoch 15 - avg_train_loss: 4.2317  avg_val_loss: 9.1029  time: 669s


Epoch: [15][23/24] Elapsed 0m 32s (remain 0m 0s) Loss: 11.2656 (9.1029)
Saved model: co_attention__itm+mlm+ucl_loss1.0739_best_score_0.1210_bn


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch: [15][0/24] Elapsed 0m 1s (remain 0m 34s) Loss: 3.1504 (3.1504)
Epoch: [15][23/24] Elapsed 0m 34s (remain 0m 0s) Loss: 10.1172 (8.6434)


In [31]:
# del model
# torch.cuda.empty_cache()
# gc.collect()

# Inference From Checkpoint

In [32]:
# inf_model_name = 'M-CLIP-XLM-Roberta-Large-Vit-B-16Plus-ViT-B-16-plus-240_score_0.0190'
# inf_model = MultiModalArch(
#     hidden_size=CFG.mlp_hidden_size,
#     hidden_layers=CFG.mlp_hidden_layers,
#     dropout=CFG.mlp_dropout,
#     num_classes=CFG.num_classes,
#     use_multimodal=CFG.use_multimodal,
#     use_dualencoder=CFG.use_dualencoder,
#     is_mclip=CFG.is_mclip
# ).to(CFG.device)
# if not CFG.is_mclip:
#     inf_model = nn.DataParallel(inf_model)
# inf_model.load_state_dict(torch.load('/kaggle/working/' + inf_model_name + f'_{CFG.lang}' + '.pth', map_location=torch.device(CFG.device))['model'])
# inf_model

In [33]:
# avg_val_loss, all_ids, all_preds_np, all_labels_np = validation_loop(inf_model, loss_fn, valid_dataloader, 0)
# if CFG.debug:
#     print(all_labels_np)
#     print(all_preds_np)
    
# score = get_score(all_labels_np, all_preds_np)

# report = classification_report(all_labels_np, all_preds_np, digits=4)
# print(report)

In [34]:
# entries = []
# for identity, label, pred in zip(all_ids, all_labels_np, all_preds_np):
#     entry = {
#         'identity': identity,
#         'label': all_labels[label],
#         'pred': all_labels[pred]
#     }
#     entries.append(entry)

# with open(f'/kaggle/working/{inf_model_name}_{CFG.lang}_val_preds.json', 'w') as fp:
#     json.dump(entries, fp, cls=NpEncoder)

In [35]:
# avg_test_loss, all_ids_test, all_preds_np_test, all_labels_np_test = validation_loop(inf_model, loss_fn, test_dataloader, 0)
# if CFG.debug:
#     print(all_labels_np_test)
#     print(all_preds_np_test)
    
# score_test = get_score(all_labels_np_test, all_preds_np_test)

# report_test = classification_report(all_labels_np_test, all_preds_np_test, digits=4)
# print(report_test)

In [36]:
# entries_test = []
# for identity, label, pred in zip(all_ids_test, all_labels_np_test, all_preds_np_test):
#     entry = {
#         'identity': identity,
#         'label': all_labels[label],
#         'pred': all_labels[pred]
#     }
#     entries_test.append(entry)

# with open(f'/kaggle/working/{inf_model_name}_{CFG.lang}_test_preds.json', 'w') as fp:
#     json.dump(entries_test, fp, cls=NpEncoder)