In [9]:
import os
import pickle
import pandas as pd

# Define paths to data files and image directory
IMAGE_FOLDER = "/kaggle/input/kanakdb/images-20250405T093757Z-001/images"
TRAIN_TSV = "/kaggle/input/kanakdb/train_df.tsv"
VAL_TSV = "/kaggle/input/kanakdb/val_df.tsv"

# Read TSV data into DataFrames
train_data = pd.read_csv(TRAIN_TSV, sep='\t', header=None, names=["pid", "caption", "explanation", "target"])
val_data = pd.read_csv(VAL_TSV, sep='\t', header=None, names=["pid", "caption", "explanation", "target"])

print(f"Number of training samples: {len(train_data)}, validation samples: {len(val_data)}")
print("First training entry:", train_data.iloc[0].to_dict())

# Initialize variables for pickled objects and descriptions
objects_train = objects_val = descs_train = descs_val = None

# File locations for pickled data
OBJ_TRAIN_PATH = "/kaggle/input/kanakdb/O_train.pkl"
DESC_TRAIN_PATH = "/kaggle/input/kanakdb/D_train.pkl"
OBJ_VAL_PATH = "/kaggle/input/kanakdb/O_val.pkl"
DESC_VAL_PATH = "/kaggle/input/kanakdb/D_val.pkl"

# Load object annotations if available
if os.path.exists(OBJ_TRAIN_PATH) and os.path.exists(OBJ_VAL_PATH):
    with open(OBJ_TRAIN_PATH, 'rb') as f:
        objects_train = pickle.load(f)
    with open(OBJ_VAL_PATH, 'rb') as f:
        objects_val = pickle.load(f)
    print("Successfully loaded object annotations.")

# Load image descriptions if available
if os.path.exists(DESC_TRAIN_PATH) and os.path.exists(DESC_VAL_PATH):
    with open(DESC_TRAIN_PATH, 'rb') as f:
        descs_train = pickle.load(f)
    with open(DESC_VAL_PATH, 'rb') as f:
        descs_val = pickle.load(f)
    print("Successfully loaded image descriptions.")

Number of training samples: 2984, validation samples: 176
First training entry: {'pid': 'pid', 'caption': 'text', 'explanation': 'explanation', 'target': 'target_of_sarcasm'}
Successfully loaded object annotations.
Successfully loaded image descriptions.


In [11]:
import os
import torch
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from transformers import ViTFeatureExtractor, BartTokenizer

# Initialize models
VISION_MODEL = "google/vit-base-patch16-224"
TEXT_MODEL = "facebook/bart-base"

extractor = ViTFeatureExtractor.from_pretrained(VISION_MODEL)
tokenizer = BartTokenizer.from_pretrained(TEXT_MODEL)

# Token configuration
MAX_INPUT_TOKENS = 64
MAX_OUTPUT_TOKENS = 32
BOS_ID = tokenizer.bos_token_id
EOS_ID = tokenizer.eos_token_id
PAD_ID = tokenizer.pad_token_id

class ImageTextDataset(Dataset):
    def __init__(self, dataframe, object_map=None, description_map=None, mode="train"):
        self.dataframe = dataframe
        self.objects = object_map
        self.descriptions = description_map
        self.mode = mode

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        row = self.dataframe.iloc[index]
        pid = str(row["pid"])
        caption = str(row["caption"])
        target = str(row["target"]) if pd.notna(row["target"]) else ""
        explanation = str(row["explanation"]) if pd.notna(row["explanation"]) else ""

        # Attempt to load image; fallback to black placeholder
        image_path = os.path.join(IMAGE_FOLDER, f"{pid}.jpg")
        try:
            image = Image.open(image_path).convert("RGB")
        except FileNotFoundError:
            image = Image.new("RGB", (224, 224), color=(0, 0, 0))

        image_tensor = extractor(images=image, return_tensors="pt")["pixel_values"].squeeze(0)

        # Construct combined input text
        components = [caption]
        if self.objects and pid in self.objects:
            obj_text = " ".join(self.objects[pid]) if isinstance(self.objects[pid], list) else str(self.objects[pid])
            components.append(obj_text)
        if self.descriptions and pid in self.descriptions:
            desc_text = " ".join(self.descriptions[pid]) if isinstance(self.descriptions[pid], list) else str(self.descriptions[pid])
            components.append(desc_text)
        components.append(target)

        separator = tokenizer.eos_token
        combined_text = f" {separator} ".join([part for part in components if part.strip()])
        tokenized = tokenizer(combined_text, max_length=MAX_INPUT_TOKENS, truncation=True, padding="max_length", add_special_tokens=False, return_tensors="pt")

        input_ids = tokenized["input_ids"].squeeze(0)
        attention = tokenized["attention_mask"].squeeze(0)

        # Ensure EOS token is present
        token_count = attention.sum().item()
        if token_count > 0:
            last_index = int(token_count) - 1
            if input_ids[last_index] != EOS_ID and last_index < MAX_INPUT_TOKENS - 1:
                input_ids[last_index + 1] = EOS_ID
                attention[last_index + 1] = 1
            elif input_ids[last_index] != EOS_ID:
                input_ids[last_index] = EOS_ID
        else:
            input_ids[0] = EOS_ID
            attention[0] = 1

        if self.mode != "test":
            # Tokenize explanation for decoder input and labels
            explanation_encoded = tokenizer(explanation, max_length=MAX_OUTPUT_TOKENS - 1, truncation=True, padding="max_length", add_special_tokens=False, return_tensors="pt")
            explanation_ids = explanation_encoded["input_ids"].squeeze(0)

            decoder_inputs = torch.full((MAX_OUTPUT_TOKENS,), PAD_ID, dtype=torch.long)
            decoder_inputs[0] = BOS_ID
            decoder_inputs[1:] = explanation_ids[:MAX_OUTPUT_TOKENS - 1]

            label_tokens = torch.full((MAX_OUTPUT_TOKENS,), PAD_ID, dtype=torch.long)
            valid_tokens = (explanation_ids != PAD_ID).sum().item()
            if valid_tokens > 0:
                label_tokens[:valid_tokens] = explanation_ids[:valid_tokens]
                if valid_tokens < MAX_OUTPUT_TOKENS:
                    label_tokens[valid_tokens] = EOS_ID
            else:
                label_tokens[0] = EOS_ID
        else:
            decoder_inputs = None
            label_tokens = None

        return {
            "pixel_values": image_tensor,
            "input_ids": input_ids,
            "attention_mask": attention,
            "decoder_input_ids": decoder_inputs,
            "labels": label_tokens,
            "pid": pid,
            "caption": caption,
            "target": target,
            "explanation": explanation
        }

# Dataset instantiation
train_dataset = ImageTextDataset(train_df, object_map=O_train, description_map=D_train, mode="train")
val_dataset = ImageTextDataset(val_df, object_map=O_val, description_map=D_val, mode="val")

# View a sample
print("Sample training datapoint:", train_dataset[0])

Sample training datapoint: {'pixel_values': tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]],

        [[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]],

        [[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]]]), 'input_ids': tensor([29015,  1437,     2,  1002,  1215,  1116

In [24]:
import torch
import torch.nn as nn
from transformers import BartForConditionalGeneration, ViTModel

class FusionMultimodalModel(nn.Module):
    def __init__(self, text_model="facebook/bart-base", vision_model="google/vit-base-patch16-224-in21k"):
        super().__init__()

        # Load pretrained vision and language encoders
        self.vision_encoder = ViTModel.from_pretrained(vision_model)
        self.language_model = BartForConditionalGeneration.from_pretrained(text_model)

        # Optional: freeze visual backbone
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

        # Model dimensionality
        hidden_dim = self.language_model.config.d_model

        # Learnable gating for feature merging
        self.gate_text = nn.Linear(2 * hidden_dim, hidden_dim)
        self.gate_image = nn.Linear(2 * hidden_dim, hidden_dim)

        # Padding token ID used in loss
        self.pad_token_id = self.language_model.config.pad_token_id

    def forward(self, input_ids, attention_mask, pixel_values, decoder_input_ids=None, labels=None):
        # Encode textual tokens
        lang_output = self.language_model.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        H_text = lang_output.last_hidden_state  # (B, T, D)

        # Encode image features
        vis_output = self.vision_encoder(pixel_values=pixel_values, return_dict=True)
        H_image = vis_output.last_hidden_state  # (B, V, D)

        # Global summaries
        global_image = H_image[:, 0, :]  # CLS token
        mask_float = attention_mask.unsqueeze(-1).float()
        global_text = (H_text * mask_float).sum(dim=1) / torch.clamp(mask_float.sum(dim=1), min=1e-9)

        # Inter-modal conditioning
        B, T, D = H_text.shape
        V = H_image.shape[1]

        global_image_rep = global_image.unsqueeze(1).expand(-1, T, -1)
        global_text_rep = global_text.unsqueeze(1).expand(-1, V, -1)

        conditioned_text = H_text * global_image_rep
        conditioned_image = H_image * global_text_rep

        # Gated multimodal fusion
        text_merge = torch.cat([H_text, conditioned_text], dim=-1)
        image_merge = torch.cat([H_image, conditioned_image], dim=-1)

        gate_t = torch.sigmoid(self.gate_text(text_merge))
        gate_v = torch.sigmoid(self.gate_image(image_merge))

        fused_text = gate_t * H_text + (1 - gate_t) * conditioned_text
        fused_image = gate_v * H_image + (1 - gate_v) * conditioned_image

        # Combine both modalities
        fused_sequence = torch.cat([fused_text, fused_image], dim=1)

        # Create extended attention mask for decoder
        mask_image = torch.ones((B, V), device=attention_mask.device).long()
        combined_mask = torch.cat([attention_mask, mask_image], dim=1)

        # Auto-generate decoder inputs if labels provided but decoder inputs aren't
        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = self.language_model.prepare_decoder_input_ids_from_labels(labels)

        # Run decoder with fused features
        decoder_output = self.language_model.model.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=fused_sequence,
            encoder_attention_mask=combined_mask,
            return_dict=True,
            use_cache=False
        )
        decoder_hidden = decoder_output.last_hidden_state

        # Language modeling head
        logits = self.language_model.lm_head(decoder_hidden) + self.language_model.final_logits_bias

        # Loss calculation
        loss = None
        if labels is not None:
            criterion = nn.CrossEntropyLoss(ignore_index=self.pad_token_id)
            loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))

        return {
            "loss": loss,
            "logits": logits,
            "encoder_outputs": fused_sequence
        }

In [25]:
!pip install evaluate
!pip install rouge_score
!pip install --upgrade nltk
!pip install bert_score
import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('wordnet', download_dir='/kaggle/working/nltk_data')
nltk.download('omw-1.4', download_dir='/kaggle/working/nltk_data')

# Telling nltk to look here
import os
os.environ['NLTK_DATA'] = '/kaggle/working/nltk_data'
!find /kaggle/working/nltk_data -type f


/kaggle/working/nltk_data/corpora/omw-1.4/msa/wn-data-zsm.tab
/kaggle/working/nltk_data/corpora/omw-1.4/msa/README
/kaggle/working/nltk_data/corpora/omw-1.4/msa/citation.bib
/kaggle/working/nltk_data/corpora/omw-1.4/msa/wn-data-ind.tab
/kaggle/working/nltk_data/corpora/omw-1.4/msa/LICENSE
/kaggle/working/nltk_data/corpora/omw-1.4/hrv/README
/kaggle/working/nltk_data/corpora/omw-1.4/hrv/citation.bib
/kaggle/working/nltk_data/corpora/omw-1.4/hrv/wn-data-hrv.tab
/kaggle/working/nltk_data/corpora/omw-1.4/hrv/LICENSE
/kaggle/working/nltk_data/corpora/omw-1.4/tha/wn-data-tha.tab
/kaggle/working/nltk_data/corpora/omw-1.4/tha/citation.bib
/kaggle/working/nltk_data/corpora/omw-1.4/tha/LICENSE
/kaggle/working/nltk_data/corpora/omw-1.4/slk/wn-data-lit.tab
/kaggle/working/nltk_data/corpora/omw-1.4/slk/wn-data-slk.tab
/kaggle/working/nltk_data/corpora/omw-1.4/slk/README
/kaggle/working/nltk_data/corpora/omw-1.4/slk/citation.bib
/kaggle/working/nltk_data/corpora/omw-1.4/slk/LICENSE
/kaggle/working/n

[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /kaggle/working/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /kaggle/working/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [26]:
import os
import zipfile
import nltk

# Define a custom location for NLTK resource loading
custom_nltk_dir = '/kaggle/working/nltk_data'
os.makedirs(custom_nltk_dir, exist_ok=True)

# Paths to zipped resources and their target folders
corpus_dir = os.path.join(custom_nltk_dir, 'corpora')
wordnet_zip = os.path.join(corpus_dir, 'wordnet.zip')
omw_zip = os.path.join(corpus_dir, 'omw-1.4.zip')

wordnet_folder = os.path.join(corpus_dir, 'wordnet')
omw_folder = os.path.join(corpus_dir, 'omw-1.4')

# Unzip WordNet corpus if not already extracted
if not os.path.exists(wordnet_folder):
    with zipfile.ZipFile(wordnet_zip, 'r') as archive:
        archive.extractall(corpus_dir)
    print("WordNet corpus extracted.")

# Unzip OMW corpus if not already extracted
if not os.path.exists(omw_folder):
    with zipfile.ZipFile(omw_zip, 'r') as archive:
        archive.extractall(corpus_dir)
    print("OMW 1.4 corpus extracted.")

# Add the custom directory to NLTK's search paths
if custom_nltk_dir not in nltk.data.path:
    nltk.data.path.append(custom_nltk_dir)

# Example METEOR scoring (commented for now)
# meteor_score = meteor.compute(predictions=val_preds, references=val_refs)
# print("METEOR score:", meteor_score)

In [None]:
--


In [28]:
# Paste or import the TurboModel class before this point
# Example:
class TurboModel(nn.Module):
    def __init__(self, bart_model_name="facebook/bart-base", vit_model_name="google/vit-base-patch16-224-in21k"):
        super(TurboModel, self).__init__()
        self.vit = ViTModel.from_pretrained(vit_model_name)
        self.bart_model = BartForConditionalGeneration.from_pretrained(bart_model_name)

        for param in self.vit.parameters():
            param.requires_grad = False

        d = self.bart_model.config.d_model
        self.text_gate = nn.Linear(2 * d, d)
        self.image_gate = nn.Linear(2 * d, d)
        self.pad_token_id = self.bart_model.config.pad_token_id

    def forward(self, input_ids, attention_mask, pixel_values, decoder_input_ids=None, labels=None):
        encoder_outputs = self.bart_model.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        A_t = encoder_outputs.last_hidden_state
        A_v = self.vit(pixel_values=pixel_values, return_dict=True).last_hidden_state
        E_v = A_v[:, 0, :]
        mask = attention_mask.unsqueeze(-1).float()
        text_len = mask.sum(dim=1)
        E_t = (A_t * mask).sum(dim=1) / torch.clamp(text_len, min=1e-9)

        B, T, d = A_t.shape
        V = A_v.shape[1]
        E_v_exp = E_v.unsqueeze(1).expand(-1, T, -1)
        E_t_exp = E_t.unsqueeze(1).expand(-1, V, -1)
        E_tv = A_t * E_v_exp
        E_vt = A_v * E_t_exp

        g_t = torch.sigmoid(self.text_gate(torch.cat([A_t, E_tv], dim=-1)))
        g_v = torch.sigmoid(self.image_gate(torch.cat([A_v, E_vt], dim=-1)))
        F_t = g_t * A_t + (1 - g_t) * E_tv
        F_v = g_v * A_v + (1 - g_v) * E_vt

        fused_hidden = torch.cat([F_t, F_v], dim=1)
        fused_mask = torch.cat([attention_mask, torch.ones((B, V), device=attention_mask.device).long()], dim=1)

        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = self.bart_model.prepare_decoder_input_ids_from_labels(labels)

        decoder_outputs = self.bart_model.model.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=fused_hidden,
            encoder_attention_mask=fused_mask,
            use_cache=False,
            return_dict=True
        )
        decoder_hidden = decoder_outputs.last_hidden_state
        logits = self.bart_model.lm_head(decoder_hidden) + self.bart_model.final_logits_bias

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.pad_token_id)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return {
            "loss": loss,
            "logits": logits,
            "encoder_outputs": fused_hidden
        }

In [30]:
import torch
from torch.utils.data import DataLoader
from transformers import AdamW, BartTokenizer
from tqdm.auto import tqdm
import evaluate

# Initialize tokenizer
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

# Training hyperparameters
num_epochs = 3
batch_size = 8
learning_rate = 1e-4
max_output_length = 64  # adjust if needed
bos_token_id = tokenizer.bos_token_id
eos_token_id = tokenizer.eos_token_id

# Data loaders (make sure `train_dataset` and `val_dataset` are defined)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)  # for generation convenience

# Initialize model and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TurboModel().to(device)
optimizer = AdamW(model.parameters(), lr=learning_rate)

# Evaluation metrics
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")
meteor = evaluate.load("meteor")
bertscore = evaluate.load("bertscore")

# Training loop
for epoch in range(1, num_epochs + 1):
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)

    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        pixel_values = batch['pixel_values'].to(device)
        decoder_input_ids = batch['decoder_input_ids'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            labels=labels
        )
        loss = outputs['loss']
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        progress_bar.set_postfix({"train_loss": loss.item()})

    avg_loss = total_loss / len(train_loader)

    # Validation and generation
    model.eval()
    val_preds = []
    val_refs = []

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pixel_values = batch['pixel_values'].to(device)

            # Initialize generated_ids with BOS token
            generated_ids = torch.tensor([[bos_token_id]], dtype=torch.long, device=device)

            for _ in range(max_output_length):
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=pixel_values,
                    decoder_input_ids=generated_ids
                )
                logits = outputs['logits']
                next_token_id = torch.argmax(logits[0, -1, :]).unsqueeze(0).unsqueeze(0)
                generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

                if next_token_id.item() == eos_token_id:
                    break

            # Decode generated tokens to string (omit BOS and EOS)
            output_tokens = generated_ids[0].tolist()
            if output_tokens and output_tokens[0] == bos_token_id:
                output_tokens = output_tokens[1:]
            if eos_token_id in output_tokens:
                output_tokens = output_tokens[:output_tokens.index(eos_token_id)]

            pred_text = tokenizer.decode(output_tokens, skip_special_tokens=True)
            val_preds.append(pred_text.strip())

            # Reference explanation
            val_refs.append(batch['explanation'][0])

    # Compute evaluation metrics
    rouge_scores = rouge.compute(predictions=val_preds, references=val_refs)
    bleu_scores = bleu.compute(predictions=val_preds, references=val_refs)
    meteor_score = meteor.compute(predictions=val_preds, references=val_refs)
    bert_score = bertscore.compute(predictions=val_preds, references=val_refs, lang="en")

    bleu1, bleu2, bleu3, bleu4 = bleu_scores['precisions']
    bert_f1 = sum(bert_score['f1']) / len(bert_score['f1'])

    # Print epoch summary
    print(f"\nEpoch {epoch} Summary:")
    print(f"  Training Loss: {avg_loss:.4f}")
    print(f"  Validation ROUGE-1: {rouge_scores['rouge1']:.4f}, ROUGE-2: {rouge_scores['rouge2']:.4f}, ROUGE-L: {rouge_scores['rougeL']:.4f}")
    print(f"  Validation BLEU-1: {bleu1:.4f}, BLEU-2: {bleu2:.4f}, BLEU-3: {bleu3:.4f}, BLEU-4: {bleu4:.4f}")
    print(f"  Validation METEOR: {meteor_score['meteor']:.4f}, BERTScore F1: {bert_f1:.4f}")

    # Save model checkpoint
    ckpt_path = f"turbo_model_epoch{epoch}.pt"
    torch.save(model.state_dict(), ckpt_path)
    print(f"  Saved model checkpoint to {ckpt_path}")

    # Show some validation samples
    num_examples_to_show = 3
    print("\nExamples of generated explanations on validation set:")
    for i in range(min(num_examples_to_show, len(val_preds))):
        print(f"\nExample {i + 1}:")
        # print("Caption:", val_dataset.df.iloc[i]["caption"])
        # print("Sarcasm Target:", val_dataset.df.iloc[i]["target"])
        print("Caption:", val_dataset.dataframe.iloc[i]["caption"])
        print("Sarcasm Target:", val_dataset.dataframe.iloc[i]["target"])
        print("Ground Truth Explanation:", val_refs[i])
        print("Generated Explanation:", val_preds[i])


[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /usr/share/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Epoch 1:   0%|          | 0/373 [00:00<?, ?it/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1 Summary:
  Training Loss: 2.4873
  Validation ROUGE-1: 0.4922, ROUGE-2: 0.3401, ROUGE-L: 0.4636
  Validation BLEU-1: 0.5924, BLEU-2: 0.4069, BLEU-3: 0.3058, BLEU-4: 0.2397
  Validation METEOR: 0.4875, BERTScore F1: 0.9144
  Saved model checkpoint to turbo_model_epoch1.pt

Examples of generated explanations on validation set:

Example 1:
Caption: text
Sarcasm Target: target_of_sarcasm
Ground Truth Explanation: explanation
Generated Explanation: the target of the target_of_sarcasm isn't target.

Example 2:
Caption: '<user> thank u for this awesome network in malad ( see pic ) .  # patheticcs'
Sarcasm Target: <user>'s network in malad
Ground Truth Explanation: the author is pissed at <user> for not getting network in malad.
Generated Explanation: the author is pissed at <user> for such an awful network in malad.

Example 3:
Caption: Nothing like waiting for an hour on the tarmac for a gate to come open in snowy, windy Chicago!
Sarcasm Target: gate not opening 
Ground Truth Explan

Epoch 2:   0%|          | 0/373 [00:00<?, ?it/s]


Epoch 2 Summary:
  Training Loss: 1.7092
  Validation ROUGE-1: 0.4831, ROUGE-2: 0.3176, ROUGE-L: 0.4495
  Validation BLEU-1: 0.5390, BLEU-2: 0.3518, BLEU-3: 0.2666, BLEU-4: 0.2063
  Validation METEOR: 0.4958, BERTScore F1: 0.9112
  Saved model checkpoint to turbo_model_epoch2.pt

Examples of generated explanations on validation set:

Example 1:
Caption: text
Sarcasm Target: target_of_sarcasm
Ground Truth Explanation: explanation
Generated Explanation: the target of the author'sarcasm isn't the target.

Example 2:
Caption: '<user> thank u for this awesome network in malad ( see pic ) .  # patheticcs'
Sarcasm Target: <user>'s network in malad
Ground Truth Explanation: the author is pissed at <user> for not getting network in malad.
Generated Explanation: the author is pissed at <user> for such an awful network in malad.

Example 3:
Caption: Nothing like waiting for an hour on the tarmac for a gate to come open in snowy, windy Chicago!
Sarcasm Target: gate not opening 
Ground Truth Expla

Epoch 3:   0%|          | 0/373 [00:00<?, ?it/s]


Epoch 3 Summary:
  Training Loss: 1.3186
  Validation ROUGE-1: 0.5166, ROUGE-2: 0.3572, ROUGE-L: 0.4906
  Validation BLEU-1: 0.5885, BLEU-2: 0.4035, BLEU-3: 0.3120, BLEU-4: 0.2435
  Validation METEOR: 0.5236, BERTScore F1: 0.9164
  Saved model checkpoint to turbo_model_epoch3.pt

Examples of generated explanations on validation set:

Example 1:
Caption: text
Sarcasm Target: target_of_sarcasm
Ground Truth Explanation: explanation
Generated Explanation: the author doesn't like the target of the attack.

Example 2:
Caption: '<user> thank u for this awesome network in malad ( see pic ) .  # patheticcs'
Sarcasm Target: <user>'s network in malad
Ground Truth Explanation: the author is pissed at <user> for not getting network in malad.
Generated Explanation: the author is pissed at <user> for such poor network in malad.

Example 3:
Caption: Nothing like waiting for an hour on the tarmac for a gate to come open in snowy, windy Chicago!
Sarcasm Target: gate not opening 
Ground Truth Explanatio

Generate Test data 

In [34]:
# Switch to evaluation mode and generate for test set
model.eval()
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

test_preds = []
test_ids = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Generating Test Explanations"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        pixel_values = batch['pixel_values'].to(device)
        pid = batch['pid'][0]

        # Start generation with BOS token
        generated_ids = torch.tensor([[bos_token_id]], dtype=torch.long, device=device)

        for _ in range(max_output_length):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
                decoder_input_ids=generated_ids
            )
            logits = outputs['logits']
            next_token_id = torch.argmax(logits[0, -1, :]).unsqueeze(0).unsqueeze(0)
            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

            if next_token_id.item() == eos_token_id:
                break

        # Decode generated sequence
        output_tokens = generated_ids[0].tolist()

        # Remove BOS if present
        if output_tokens and output_tokens[0] == bos_token_id:
            output_tokens = output_tokens[1:]

        # Truncate at EOS if exists
        if eos_token_id in output_tokens:
            output_tokens = output_tokens[:output_tokens.index(eos_token_id)]

        # Decode to text
        expl = tokenizer.decode(output_tokens, skip_special_tokens=True).strip()
        test_preds.append(expl)
        test_ids.append(pid)

# Save the generated explanations to a TSV file
output_file = "generated_test_explanations.tsv"
with open(output_file, "w", encoding="utf-8") as f:
    for pid, expl in zip(test_ids, test_preds):
        f.write(f"{pid}\t{expl}\n")

print(f"\nSaved generated test explanations to {output_file}")

Generating Test Explanations:   0%|          | 0/3 [00:00<?, ?it/s]


Saved generated test explanations to val_df.tsv
