# Import libraries

In [1]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import cv2
from torch.utils.data import Dataset, DataLoader
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoTokenizer, AutoModel 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F 
from torch.optim.lr_scheduler import LambdaLR
import json
from tqdm import tqdm
import sys
import math

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Pretrained model

In [3]:
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224", use_fast=True)
image_model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
text_tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True, use_flash_attn=False)
text_model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True, use_flash_attn=False)

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/964 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.80k [00:00<?, ?B/s]

configuration_xlm_roberta.py:   0%|          | 0.00/6.54k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- configuration_xlm_roberta.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_lora.py:   0%|          | 0.00/15.4k [00:00<?, ?B/s]

modeling_xlm_roberta.py:   0%|          | 0.00/51.1k [00:00<?, ?B/s]

mlp.py:   0%|          | 0.00/7.62k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- mlp.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


block.py:   0%|          | 0.00/17.8k [00:00<?, ?B/s]

stochastic_depth.py:   0%|          | 0.00/3.76k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- stochastic_depth.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


mha.py:   0%|          | 0.00/34.4k [00:00<?, ?B/s]

rotary.py:   0%|          | 0.00/24.5k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- rotary.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- mha.py
- rotary.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- block.py
- stochastic_depth.py
- mha.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


embedding.py:   0%|          | 0.00/3.88k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- embedding.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


xlm_padding.py:   0%|          | 0.00/10.0k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- xlm_padding.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- modeling_xlm_roberta.py
- mlp.py
- block.py
- embedding.py
- xlm_padding.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- modeling_lora.py
- modeling_xlm_roberta.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/1.14G [00:00<?, ?B/s]

# Data Loader

## Define data loader

In [4]:
class SarcasmDataset(Dataset):
    def __init__(self, image_file_names, captions, labels, img_dir, ocr_df_path):
        self.image_file_names = image_file_names
        self.captions = captions
        self.labels = labels
        self.img_dir = img_dir
        self.ocr_df = pd.read_csv(ocr_df_path) 
        self.ocr_df["combined_text"] = self.ocr_df["combined_text"].fillna("")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image_name = os.path.join(self.img_dir, self.image_file_names[idx])

        # Get OCR text
        ocr_text = self.ocr_df[self.ocr_df["image_name"] == self.image_file_names[idx]]["combined_text"].values[0]
        if not ocr_text:
            ocr_text = text_tokenizer.pad_token  

        caption = self.captions[idx]
        label = self.labels[idx]

        return image_name, ocr_text, caption, label

## Load data

In [5]:
with open('/kaggle/input/dataset/vimmsd-train.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

In [6]:
image_paths = []
captions = []
labels = []

for key, value in data.items():
    image_paths.append(value['image'])
    captions.append(value['caption'])
    labels.append(value['label'])

In [7]:
from sklearn.preprocessing import LabelEncoder
label_encoder = LabelEncoder()
labels_encoded = label_encoder.fit_transform(labels)
label_mapping = {index: label for index, label in enumerate(label_encoder.classes_)}

In [8]:
label_mapping

{0: 'image-sarcasm', 1: 'multi-sarcasm', 2: 'not-sarcasm', 3: 'text-sarcasm'}

In [9]:
from sklearn.utils.class_weight import compute_class_weight
class_weights = compute_class_weight(
                    class_weight='balanced',
                    classes=np.unique(labels_encoded),
                    y=labels_encoded
                )
class_weights

array([ 6.11142534,  0.63950047,  0.44560376, 35.08116883])

In [10]:
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
class_weights

tensor([ 6.1114,  0.6395,  0.4456, 35.0812], device='cuda:0')

In [11]:
alpha_list = class_weights.tolist()

In [12]:
class_counts = np.bincount(labels_encoded)
class_counts

array([ 442, 4224, 6062,   77])

In [13]:
from sklearn.model_selection import train_test_split
train_image_paths, val_image_paths, train_captions, val_captions, train_labels, val_labels = train_test_split(image_paths, captions, labels_encoded, test_size=0.2, stratify=labels, random_state=42)

In [14]:
unique_classes, counts = np.unique(train_labels, return_counts=True)
for cls, count in zip(unique_classes, counts):
    print(f"Class {cls}: {count} predictions")

Class 0: 353 predictions
Class 1: 3379 predictions
Class 2: 4850 predictions
Class 3: 62 predictions


In [15]:
unique_classes, counts = np.unique(val_labels, return_counts=True)
for cls, count in zip(unique_classes, counts):
    print(f"Class {cls}: {count} predictions")

Class 0: 89 predictions
Class 1: 845 predictions
Class 2: 1212 predictions
Class 3: 15 predictions


In [16]:
img_dir = "/kaggle/input/dataset/training-images/train-images"
ocr_df_path = "/kaggle/input/sarcasm-ocr-text/train_ocr_text.csv"

In [17]:
train_dataset = SarcasmDataset(
    image_file_names=train_image_paths,
    captions=train_captions,
    labels=train_labels,
    img_dir=img_dir,
    ocr_df_path=ocr_df_path 
)

In [18]:
val_dataset = SarcasmDataset(
    image_file_names=val_image_paths,
    captions=val_captions,
    labels=val_labels,
    img_dir=img_dir,
    ocr_df_path=ocr_df_path
)

In [19]:
train_data = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_data = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Model

In [20]:
class CrossAttention(nn.Module):
    def __init__(self, d_in_q, d_in_kv, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.d_out_v = d_out_v

        self.W_query = nn.Linear(d_in_q, d_out_kq)
        self.W_key = nn.Linear(d_in_kv, d_out_kq)
        self.W_value = nn.Linear(d_in_kv, d_out_v)

    def forward(self, x_1, x_2):
        queries_1 = self.W_query(x_1)
        keys_2 = self.W_key(x_2)
        values_2 = self.W_value(x_2)
        attn_scores = torch.matmul(queries_1, keys_2.transpose(-2, -1))
        attn_scores = attn_scores / (self.d_out_kq ** 0.5)
        attn_weights = torch.softmax(attn_scores, dim=-1)
        context_vec = torch.matmul(attn_weights, values_2)
        return context_vec

In [21]:
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.W_query = nn.Linear(d_in, d_out_kq)
        self.W_key = nn.Linear(d_in, d_out_kq)
        self.W_value = nn.Linear(d_in, d_out_v)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = torch.matmul(queries, keys.transpose(-2, -1))
        attn_scores = attn_scores / (self.d_out_kq ** 0.5)
        attn_weights = torch.softmax(attn_scores, dim=-1)
        context_vec = torch.matmul(attn_weights, values)
        return context_vec

In [22]:
class SarcasmModel(nn.Module):
    def __init__(self):
        super(SarcasmModel, self).__init__()
        self.image_processor = image_processor
        self.image_model = image_model
        self.text_tokenizer = text_tokenizer
        self.text_model = text_model

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

        self.image_model.to(self.device).to(torch.float32)
        self.text_model.to(self.device).to(torch.float32)

        self.image_fc = nn.Sequential(
            nn.Linear(1000,1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024,512),
            nn.ReLU()
        )

        self.ocr_fc = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.ReLU()
        )
        self.caption_fc = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.ReLU()
        )

        self.image_ocr = nn.Sequential(
            nn.Linear(512*2,1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024,512),
            nn.ReLU()
        )

        self.final_fc = nn.Sequential(
            nn.Linear(512*4,1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024,512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512,256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256,128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128,4),
        )

        self.image_attention = SelfAttention(d_in=512 + 512, d_out_kq=512, d_out_v=512)
        self.text_attend_to_image = CrossAttention(d_in_q=512, d_in_kv=512, d_out_kq=512, d_out_v=512)
        self.image_attend_to_text = CrossAttention(d_in_q=512, d_in_kv=512, d_out_kq=512, d_out_v=512)

    def forward(self, image_names, ocr_texts, captions):
        # Image features
        images = [cv2.imread(image_name) for image_name in image_names]
        image_inputs = self.image_processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            image_outputs = self.image_model(**image_inputs)
        image_features = image_outputs.logits
        image_features_fc = self.image_fc(image_features)                # -> (batch, 512)

        # OCR features
        ocr_inputs = self.text_tokenizer(
            ocr_texts,
            return_tensors="pt", 
            padding="longest",
            truncation=True, 
            max_length=512
        ).to(self.device)
        with torch.no_grad():
            ocr_outputs = self.text_model(**ocr_inputs)
        ocr_features = ocr_outputs.last_hidden_state.mean(dim=1)
        ocr_features_fc = self.ocr_fc(ocr_features)

        # Image and OCR combine
        image_ocr_combined = torch.cat((image_features_fc, ocr_features_fc), dim=1)
        image_ocr_combined = self.image_ocr(image_ocr_combined)
        
        # Caption features
        caption_inputs = self.text_tokenizer(
            captions,
            return_tensors="pt", 
            padding="longest",
            truncation=True, 
            max_length=512
        ).to(self.device)
        with torch.no_grad():
            caption_outputs = self.text_model(**caption_inputs)
        caption_features = caption_outputs.last_hidden_state.mean(dim=1)
        caption_features_fc = self.caption_fc(caption_features)

        # Caption and image attention
        text_to_image = self.text_attend_to_image(caption_features_fc, image_ocr_combined)
        image_to_text = self.image_attend_to_text(image_ocr_combined, caption_features_fc)
        combined_features = torch.cat((text_to_image, image_to_text),dim=1)

        final_features = torch.cat((image_ocr_combined, caption_features_fc, combined_features), dim=1)
        outputs = self.final_fc(final_features)
        return outputs

In [23]:
import random
def set_seed(seed):
    random.seed(seed)            # Python's random module
    np.random.seed(seed)         # NumPy
    torch.manual_seed(seed)      # PyTorch
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU setups
    torch.backends.cudnn.deterministic = True  # Ensures deterministic behavior
    torch.backends.cudnn.benchmark = False     # Disables benchmark for reproducibility

In [24]:
model = SarcasmModel().to(device)

In [25]:
from torchinfo import summary
summary(model)

Layer (type:depth-idx)                                                           Param #
SarcasmModel                                                                     559,364,096
├─ViTForImageClassification: 1-1                                                 --
│    └─ViTModel: 2-1                                                             --
│    │    └─ViTEmbeddings: 3-1                                                   742,656
│    │    └─ViTEncoder: 3-2                                                      85,054,464
│    │    └─LayerNorm: 3-3                                                       1,536
│    └─Linear: 2-2                                                               769,000
├─XLMRobertaLoRA: 1-2                                                            --
│    └─XLMRobertaModel: 2-3                                                      --
│    │    └─XLMRobertaEmbeddings: 3-4                                            261,044,092
│    │    └─Dropout: 3-5        

# Train

## Function

In [26]:
from typing import List

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2, reduction='mean', label_smoothing=0.0):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.alpha = alpha
        self.label_smoothing = label_smoothing 
        if alpha is not None:
            self.alpha = torch.tensor(alpha, dtype=torch.float)

    def forward(self, logits, targets):
        if self.label_smoothing > 0:
            n_classes = logits.size(-1)
            one_hot = torch.zeros_like(logits).scatter(1, targets.unsqueeze(1), 1)
            smooth_labels = (1 - self.label_smoothing) * one_hot + self.label_smoothing / n_classes
            log_probs = F.log_softmax(logits, dim=-1)
            ce_loss = -(smooth_labels * log_probs).sum(dim=-1)
        else:
            ce_loss = nn.CrossEntropyLoss(reduction='none')(logits, targets)

        pt = torch.exp(-ce_loss)
        if self.alpha is not None:
            alpha = self.alpha.to(targets.device)
            alpha_t = alpha.gather(0, targets.data.view(-1))
            focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss
        else:
            focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [27]:
def train_one_epoch(model, dataloader, device, optimizer, epoch, scheduler, loss_hist):
    model.train()
    
    loss_function = FocalLoss(gamma=1.5, alpha=alpha_list,reduction="mean", label_smoothing=0.1)
    
    accu_num = torch.zeros(1).to(device)
    accu_loss = torch.zeros(1).to(device)
    dataloader = tqdm(dataloader, file=sys.stdout)
    
    num_sample = 0
    for step, data in enumerate(dataloader):
        optimizer.zero_grad()
        img_names, ocr_texts, captions, labels = data
        labels = labels.to(device) 
        num_sample += len(labels)
        
        pred = model(img_names, ocr_texts, captions)
        pred_classes = torch.argmax(pred, dim=1)
        accu_num += torch.eq(pred_classes, labels).sum()
        
        loss = loss_function(pred, labels)
        loss.backward()
        accu_loss += loss.detach()
        
        dataloader.desc = "Train epoch {}: Loss {:.3f} Accuracy {:.3f}, Learning rate {:.7f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / num_sample,
            optimizer.param_groups[0]["lr"]
        )
        
        optimizer.step()
        scheduler.step()
    
    avg_loss = accu_loss.item() / (step + 1)
    avg_accu = accu_num.item() / num_sample
    
    loss_hist.append(avg_loss)
    
    return avg_loss, avg_accu

In [28]:
def evaluate(mode, dataloader, device, epoch, loss_hist):
    model.eval()
    
    loss_function = nn.CrossEntropyLoss()
    
    accu_num = torch.zeros(1).to(device)
    accu_loss = torch.zeros(1).to(device)
    dataloader = tqdm(dataloader, file=sys.stdout)
    all_preds = []
    all_labels = []
    
    num_sample = 0
    with torch.no_grad():
        for step, data in enumerate(dataloader):
            img_names, ocr_texts, captions, labels = data
            labels = labels.to(device) 
            num_sample += len(labels)
            pred = model(img_names, ocr_texts, captions)
            pred_classes = torch.argmax(pred, dim=1)
            accu_num += torch.eq(pred_classes, labels).sum()
            
            all_preds.extend(pred_classes.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            loss = loss_function(pred, labels)
            accu_loss += loss
            
            dataloader.desc = "Validate epoch {}: Loss {:.3f}, Accuracy {:.3f}".format(
                    epoch,
                    accu_loss.item() / (step + 1),
                    accu_num.item() / num_sample,
            )
    
    avg_loss = accu_loss.item() / (step + 1)
    avg_accu = accu_num.item() / num_sample
    
    loss_hist.append(avg_loss)
    
    return avg_loss, avg_accu, np.array(all_labels), np.array(all_preds)

In [29]:
def create_lr_scheduler(optimizer,
                        num_step: int,
                        epochs: int,
                        warmup=True,
                        warmup_epochs=1,
                        warmup_factor=1e-3,
                        end_factor=1e-6):
    assert num_step > 0 and epochs > 0
    if warmup is False:
        warmup_epochs = 0

    def f(x):
        if warmup is True and x <= (warmup_epochs * num_step):
            alpha = float(x) / (warmup_epochs * num_step)
            return warmup_factor * (1 - alpha) + alpha
        else:
            current_step = (x - warmup_epochs * num_step)
            cosine_steps = (epochs - warmup_epochs) * num_step
            return ((1 + math.cos(current_step * math.pi / cosine_steps)) / 2) * (1 - end_factor) + end_factor

    return LambdaLR(optimizer, lr_lambda=f)

## Start training

In [30]:
wd = 1e-4
learning_rate = 2e-5
num_epochs = 50
set_seed(42)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=wd)
lr_scheduler = create_lr_scheduler(optimizer, len(train_data), num_epochs,
                                   warmup=True, warmup_epochs=1)
best_loss = 1e9
patience = 0
train_loss_hist = []
val_loss_hist = []

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_data, device, optimizer, epoch, lr_scheduler, train_loss_hist)
    val_loss, val_acc, all_labels, all_preds = evaluate(model, val_data, device, epoch, val_loss_hist)
    if best_loss >= val_loss:
        torch.save(model.state_dict(), "best_model.pth")
        best_loss = val_loss
        best_preds = all_preds
        best_labels = all_labels
        patience = 0
    else:
        patience += 1
        if patience >= 4:
            print("Loss does not improve in 4 epochs. Early Stopping!")
            break
        else:
            print(f"Loss does not improve in {patience} epochs!")

Train epoch 0: Loss 0.903 Accuracy 0.540, Learning rate 0.0000199: 100%|██████████| 271/271 [13:22<00:00,  2.96s/it]
Validate epoch 0: Loss 1.391, Accuracy 0.561: 100%|██████████| 68/68 [03:13<00:00,  2.85s/it]
Train epoch 1: Loss 0.929 Accuracy 0.519, Learning rate 0.0000200: 100%|██████████| 271/271 [12:03<00:00,  2.67s/it]
Validate epoch 1: Loss 1.375, Accuracy 0.338: 100%|██████████| 68/68 [02:53<00:00,  2.55s/it]
Train epoch 2: Loss 0.886 Accuracy 0.428, Learning rate 0.0000199: 100%|██████████| 271/271 [11:58<00:00,  2.65s/it]
Validate epoch 2: Loss 1.340, Accuracy 0.302: 100%|██████████| 68/68 [02:53<00:00,  2.55s/it]
Train epoch 3: Loss 0.829 Accuracy 0.474, Learning rate 0.0000198: 100%|██████████| 271/271 [12:04<00:00,  2.67s/it]
Validate epoch 3: Loss 1.266, Accuracy 0.447: 100%|██████████| 68/68 [02:53<00:00,  2.55s/it]
Train epoch 4: Loss 0.750 Accuracy 0.465, Learning rate 0.0000197: 100%|██████████| 271/271 [12:04<00:00,  2.67s/it]
Validate epoch 4: Loss 1.263, Accuracy 

# Private test

In [31]:
with open('/kaggle/input/private-test/vimmsd-private-test.json', 'r', encoding='utf-8') as f:
    dev_data = json.load(f)

In [32]:
dev_image_paths = []
dev_captions = []
dev_labels = []

for key, value in dev_data.items():
    dev_image_paths.append(value['image'])
    dev_captions.append(value['caption'])
    dev_labels.append(value['label'])

In [33]:
dev_labels_encoded =  [0 if label is None else label for label in dev_labels]

In [34]:
dev_labels_encoded = np.array(dev_labels_encoded)

In [35]:
dev_dir = "/kaggle/input/private-test/private-test-images/test-images"
dev_ocr_df_path = "/kaggle/input/sarcasm-ocr-text/test_ocr_text.csv"

In [36]:
dev_sarcasm_dataset = SarcasmDataset(
    image_file_names=dev_image_paths,
    captions=dev_captions,
    labels=dev_labels_encoded,
    img_dir=dev_dir,
    ocr_df_path = dev_ocr_df_path,
)

# Create DataLoader
dev_dataloader = DataLoader(dev_sarcasm_dataset, batch_size=32, shuffle=False)

In [37]:
model.eval()
predictions = []
true_labels = []  
pred_labels = []

with torch.no_grad():
    for step, data in enumerate(tqdm(dev_dataloader, file=sys.stdout)):
        img_names, ocr_texts, captions, labels = data
        labels = labels.to(device)
        pred = model(img_names, ocr_texts, captions)
        pred_classes = torch.argmax(pred, dim=1)
        pred_labels.extend(pred_classes.cpu().numpy())

100%|██████████| 47/47 [01:43<00:00,  2.20s/it]


In [38]:
unique_classes, counts = np.unique(pred_labels, return_counts=True)
for cls, count in zip(unique_classes, counts):
    print(f"Class {cls}: {count} predictions")

Class 0: 78 predictions
Class 1: 934 predictions
Class 2: 452 predictions
Class 3: 40 predictions


In [39]:
results = {i: label_mapping[pred_labels[i]] for i in range(len(dev_image_paths))} 
phase = "test"
final_output = {
    "results": results,
    "phase": phase
}

In [40]:
with open(f'results.json', 'w') as f:
    json.dump(final_output, f, indent=4)

In [41]:
import zipfile
zip_filename = 'results.zip'
with zipfile.ZipFile(zip_filename, 'w') as zipf:
    zipf.write('results.json', os.path.basename('results.json'))