### Install and Import Packages

In [None]:
%pip install -U nltk rouge-score

In [None]:
import os
import random
from itertools import product

import nltk
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.meteor_score import meteor_score
from PIL import Image
from rouge_score import rouge_scorer
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer, GPT2LMHeadModel

In [None]:
# Tokenizer data for nltk.word_tokenize (used in evaluate_model). NLTK >= 3.8.2
# loads the "punkt_tab" resource instead of the pickled "punkt" one, and the
# %pip cell upgrades NLTK, so fetch both to work on old and new versions.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)
# WordNet backs the synonym matching in nltk's meteor_score.
nltk.download("wordnet", quiet=True)

### Environment Setup: Random Seeds and Device Selection

In [None]:
def set_seed(seed=42):
    """Seed every random number generator used here and pin cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy's global RNG,
            torch's CPU generator and the generators of all visible CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Give up cuDNN's autotuned (non-deterministic) kernels for repeatability.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


# Fix all RNGs before any dataloader or model is built so runs are repeatable.
set_seed(seed=42)

# Use the GPU when one is visible; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

### Custom Dataset Class for Image Captioning


In [None]:
class ImageCaptioningDataset(Dataset):
    """Image/caption pairs listed in a CSV, encoded for the ViT + GPT-2 captioner.

    The CSV is read with pandas and accessed positionally: column 1 holds the
    image file name (joined onto ``img_dir``) and column 2 holds the caption.

    Args:
        csv_file: Path to the CSV index for one split.
        img_dir: Directory containing that split's image files.
        tokenizer: Caption tokenizer; ``get_dataloaders`` registers
            ``<|startoftext|>`` on it as a special token.
        image_processor: Image processor whose output exposes ``pixel_values``.
        max_length: Every caption is padded/truncated to this many tokens.
    """

    def __init__(self, csv_file, img_dir, tokenizer, image_processor, max_length=128):
        self.img_captions = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_length = max_length

    def __len__(self):
        """Number of rows (image/caption pairs) in the CSV."""
        return len(self.img_captions)

    def __getitem__(self, idx):
        """Load, preprocess and tokenize sample ``idx``.

        Returns:
            dict with ``pixel_values`` (processor output with the leading batch
            dim squeezed), ``input_ids`` / ``attention_mask`` / ``target_ids``
            (each padded to ``max_length``), and the raw ``caption``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.img_dir, str(self.img_captions.iloc[idx, 1]))
        caption = self.img_captions.iloc[idx, 2]

        # Close the image file deterministically instead of relying on garbage
        # collection; convert() returns a new, fully loaded image first.
        with Image.open(img_name) as img:
            image = img.convert("RGB")
        pixel_values = self.image_processor(
            images=image, return_tensors="pt"
        ).pixel_values.squeeze(0)

        # NOTE(review): no EOS token is appended here, while get_dataloaders
        # pads with EOS and the training criterion ignores pad_token_id, so the
        # model never sees an EOS target and generation likely runs to
        # max_length. Confirm before relying on early stopping of generation.
        caption_with_start = f"<|startoftext|> {caption}"
        caption_encoding = self.tokenizer(
            caption_with_start,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        input_ids = caption_encoding.input_ids.squeeze(0)
        attention_mask = caption_encoding.attention_mask.squeeze(0)
        # Labels are the unshifted inputs; the one-step shift against the
        # image-prefixed logits is applied in train_model.
        target_ids = input_ids.clone()

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "target_ids": target_ids,
            "caption": caption,
        }

### DataLoader Preparation for Training, Validation, and Testing


In [None]:
def get_dataloaders(data_dir, batch_size=16, num_workers=4):
    """Build train/val/test DataLoaders plus the shared tokenizer and image processor.

    ``data_dir`` must contain ``train.csv``/``val.csv``/``test.csv`` and the
    matching ``train/``, ``val/`` and ``test/`` image folders.

    Args:
        data_dir: Root directory of the captioning dataset.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker processes per loader (0 loads in the main process,
            useful for debugging or where worker start-up is expensive).

    Returns:
        ``(train_loader, val_loader, test_loader, tokenizer, image_processor)``.
        Only the train loader shuffles.
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 has no pad token; reuse EOS so padding="max_length" works.
    tokenizer.pad_token = tokenizer.eos_token
    # Start-of-caption marker prepended by ImageCaptioningDataset.
    # ImageCaptionModel grows its embedding table by exactly one row to match.
    tokenizer.add_special_tokens({"additional_special_tokens": ["<|startoftext|>"]})

    image_processor = AutoImageProcessor.from_pretrained(
        "WinKawaks/vit-small-patch16-224"
    )

    loaders = []
    for split in ("train", "val", "test"):
        dataset = ImageCaptioningDataset(
            csv_file=os.path.join(data_dir, f"{split}.csv"),
            img_dir=os.path.join(data_dir, split),
            tokenizer=tokenizer,
            image_processor=image_processor,
        )
        loaders.append(
            DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=(split == "train"),
                num_workers=num_workers,
                pin_memory=True,
            )
        )

    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader, tokenizer, image_processor

### Vision-to-Text Image Captioning Model using ViT Encoder and GPT-2 Decoder


In [None]:
class ImageCaptionModel(nn.Module):
    """ViT image encoder + GPT-2 decoder conditioned on a one-token image prefix.

    The encoder output at position 0 (the [CLS] position for standard ViT
    checkpoints) is projected to GPT-2's embedding width and prepended to the
    caption token embeddings.

    Args:
        vit_model_name: Hugging Face id of the vision encoder.
        gpt2_model_name: Hugging Face id of the GPT-2 language model.
        dropout_rate: Dropout probability inside the image projection head.
    """

    def __init__(
        self,
        vit_model_name="WinKawaks/vit-small-patch16-224",
        gpt2_model_name="gpt2",
        dropout_rate=0.5,
    ):
        super().__init__()

        self.encoder = AutoModel.from_pretrained(vit_model_name)
        self.encoder_dim = self.encoder.config.hidden_size

        self.decoder = GPT2LMHeadModel.from_pretrained(gpt2_model_name)
        self.decoder_dim = self.decoder.config.n_embd

        # Maps the encoder's position-0 vector into the decoder embedding space.
        self.image_proj = nn.Sequential(
            nn.Linear(self.encoder_dim, self.decoder_dim),
            nn.LayerNorm(self.decoder_dim),
            nn.Dropout(dropout_rate),
            nn.ReLU(),
            nn.Linear(self.decoder_dim, self.decoder_dim),
        )
        # One extra embedding row for the "<|startoftext|>" token that
        # get_dataloaders adds. NOTE(review): the hard-coded +1 must stay in
        # sync with the number of special tokens added to the tokenizer.
        self.decoder.resize_token_embeddings(self.decoder.config.vocab_size + 1)

    def _image_prefix(self, pixel_values):
        """Encode images into a ``(batch, 1, decoder_dim)`` prefix embedding."""
        encoder_outputs = self.encoder(pixel_values=pixel_values)
        image_embedding = encoder_outputs.last_hidden_state[:, 0]
        return self.image_proj(image_embedding).unsqueeze(1)

    @staticmethod
    def _top_p_filter(logits, top_p):
        """Set logits outside the top-p nucleus to -inf (in place) and return them.

        Keeps the smallest set of most-likely tokens whose cumulative
        probability exceeds ``top_p``; the most likely token is always kept.
        """
        probs = F.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Shift right so the token that crosses the threshold survives.
        sorted_to_remove = cumulative_probs > top_p
        sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
        sorted_to_remove[..., 0] = False

        to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
        logits[to_remove] = -float("Inf")
        return logits

    def forward(self, pixel_values, input_ids, attention_mask):
        """Return decoder logits for the image-prefixed caption.

        Returns:
            Logits with one more position than ``input_ids``. Position 0 is the
            image prefix, so ``logits[:, i]`` predicts ``input_ids[:, i]``.
        """
        image_prefix = self._image_prefix(pixel_values)
        token_embeds = self.decoder.transformer.wte(input_ids)
        inputs_embeds = torch.cat([image_prefix, token_embeds], dim=1)

        # The image prefix is always attended to. Match attention_mask's dtype
        # rather than relying on torch.cat promoting a float column.
        prefix_mask = torch.ones(
            (attention_mask.size(0), 1),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        extended_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        outputs = self.decoder(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_attention_mask,
            return_dict=True,
        )
        return outputs.logits

    def generate_caption(
        self, pixel_values, tokenizer, max_length=128, temperature=0.7, top_p=0.9
    ):
        """Sample a caption for one image with temperature + nucleus (top-p) sampling.

        Args:
            pixel_values: Processed pixels for a single image. Batch size must be
                1 (the stop check uses ``.item()`` and ids start as ``[[id]]``).
            tokenizer: Tokenizer that knows ``<|startoftext|>``; seeds
                generation and decodes the result.
            max_length: Maximum number of tokens sampled after the start token.
            temperature: Logits are divided by this before sampling.
            top_p: Nucleus threshold passed to ``_top_p_filter``.

        Returns:
            The decoded caption with special tokens removed and whitespace stripped.

        The module's previous train/eval mode is restored on exit.
        """
        device = pixel_values.device
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                image_prefix = self._image_prefix(pixel_values)

                start_id = tokenizer.convert_tokens_to_ids("<|startoftext|>")
                generated_ids = torch.tensor([[start_id]], device=device)
                inputs_embeds = torch.cat(
                    [image_prefix, self.decoder.transformer.wte(generated_ids)], dim=1
                )
                attention_mask = torch.ones(
                    (inputs_embeds.size(0), inputs_embeds.size(1)),
                    dtype=torch.long,
                    device=device,
                )
                step_mask = torch.ones((1, 1), dtype=torch.long, device=device)

                # NOTE(review): each step re-runs the decoder over the whole
                # prefix (no past_key_values reuse), so cost grows with length.
                for _ in range(max_length):
                    outputs = self.decoder(
                        inputs_embeds=inputs_embeds,
                        attention_mask=attention_mask,
                        return_dict=True,
                    )
                    next_token_logits = outputs.logits[:, -1, :] / temperature
                    next_token_logits = self._top_p_filter(next_token_logits, top_p)

                    next_token = torch.multinomial(
                        F.softmax(next_token_logits, dim=-1), num_samples=1
                    )
                    generated_ids = torch.cat([generated_ids, next_token], dim=1)

                    if next_token.item() == tokenizer.eos_token_id:
                        break

                    inputs_embeds = torch.cat(
                        [inputs_embeds, self.decoder.transformer.wte(next_token)], dim=1
                    )
                    attention_mask = torch.cat([attention_mask, step_mask], dim=1)

            return tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
        finally:
            # Don't leave dropout disabled if called mid-training.
            self.train(was_training)


### Model Training with Early Stopping and Learning Rate Scheduler


In [None]:
def _caption_lm_loss(model, batch, criterion, device):
    """Return the next-token loss for one batch of image-prefixed captions.

    Assumes the model prepends exactly one image-prefix position (as
    ImageCaptionModel.forward does), so ``logits[:, 1:-1]`` are scored against
    ``target_ids[:, 1:]``; the start token at position 0 is never a target.
    """
    pixel_values = batch["pixel_values"].to(device)
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    target_ids = batch["target_ids"].to(device)

    outputs = model(pixel_values, input_ids, attention_mask)

    shift_logits = outputs[:, 1:-1, :].contiguous()
    shift_labels = target_ids[:, 1:].contiguous()
    return criterion(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )


def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    criterion,
    device,
    epochs,
    early_stopping_patience=10,
    save_path="best_model.pth",
):
    """Train with per-epoch validation, plateau LR scheduling and early stopping.

    Whenever the average validation loss improves, the weights are saved to
    ``save_path``. At the end, the best checkpoint is loaded back into ``model``.

    Args:
        model: Captioning model called as ``model(pixel_values, input_ids, attention_mask)``.
        train_loader: Training batches (dicts with pixel_values/input_ids/
            attention_mask/target_ids).
        val_loader: Validation batches with the same keys.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Stepped once per epoch with the average validation loss.
        criterion: Token-level loss taking ``(N, vocab)`` logits and ``(N,)`` labels.
        device: Device the batch tensors are moved to.
        epochs: Maximum number of epochs.
        early_stopping_patience: Stop after this many consecutive epochs
            without a new best validation loss.
        save_path: Checkpoint path for the best weights.

    Returns:
        ``(model, best_val_loss)``. If no epoch ever improved on ``inf`` (e.g.
        ``epochs == 0`` or the validation loss was NaN throughout), nothing is
        saved or loaded. The model keeps its final weights and
        ``best_val_loss`` is ``inf``.
    """
    best_val_loss = float("inf")
    patience_counter = 0
    checkpoint_saved = False

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        optimizer.zero_grad()

        train_progress = tqdm(
            train_loader,
            desc=f"Epoch {epoch + 1}/{epochs} [Train]",
            position=0,
            leave=False,
            ncols=100,
        )
        for batch in train_progress:
            loss = _caption_lm_loss(model, batch, criterion, device)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

            batch_loss = loss.item()
            train_loss += batch_loss
            train_progress.set_postfix(
                {
                    "loss": f"{batch_loss:.4f}",
                    "lr": f"{optimizer.param_groups[0]['lr']:.6f}",
                }
            )

        avg_train_loss = train_loss / len(train_loader)

        model.eval()
        val_loss = 0.0
        val_progress = tqdm(
            val_loader,
            desc=f"Epoch {epoch + 1}/{epochs} [Val]",
            position=0,
            leave=False,
            ncols=100,
        )
        with torch.no_grad():
            for batch in val_progress:
                batch_loss = _caption_lm_loss(model, batch, criterion, device).item()
                val_loss += batch_loss
                val_progress.set_postfix({"loss": f"{batch_loss:.4f}"})

        avg_val_loss = val_loss / len(val_loader)
        scheduler.step(avg_val_loss)

        # The summary and the checkpoint status share one line; the separator
        # keeps the two messages from running together.
        print(
            f"Epoch {epoch + 1}/{epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}",
            end=" - ",
        )
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            checkpoint_saved = True
            patience_counter = 0
            print(f"New best model saved with validation loss: {best_val_loss:.4f}")
        else:
            patience_counter += 1
            print(
                f"No improvement. Patience: {patience_counter}/{early_stopping_patience}"
            )
            if patience_counter >= early_stopping_patience:
                print("Early stopping triggered.")
                break

    # Only reload a checkpoint written by *this* call; save_path may hold a
    # stale file from an earlier run that must not be silently restored.
    if checkpoint_saved:
        model.load_state_dict(torch.load(save_path, map_location=device))
    else:
        print(
            f"Warning: no checkpoint was saved to {save_path}; "
            "returning the model with its final weights."
        )
    return model, best_val_loss


### Hyperparameter Tuning for Image Captioning Model


In [None]:
def hyperparameter_tuning(train_loader, val_loader, tokenizer, device):
    """Grid-search AdamW learning rate x weight decay and return the best model.

    Each combination trains a fresh ImageCaptionModel with train_model (up to
    50 epochs, early-stopping patience 5). The combination with the lowest
    best validation loss wins.

    Returns:
        ``(best_model, best_hyperparams)``. ``best_model`` is on ``device`` and
        ``best_hyperparams`` is ``{"learning_rate": ..., "weight_decay": ...}``.

    Raises:
        RuntimeError: if no combination produced a validation loss below ``inf``.
    """
    hyperparameters = {
        "learning_rate": [5e-5, 1e-4, 5e-4],
        "weight_decay": [0.01, 0.001],
    }

    best_val_loss = float("inf")
    best_hyperparams = None
    best_model_state = None
    grid = list(
        product(
            hyperparameters["learning_rate"],
            hyperparameters["weight_decay"],
        )
    )

    hp_progress = tqdm(
        grid,
        desc="Hyperparameter tuning",
        total=len(grid),
        position=0,
        leave=True,
        ncols=100,
    )

    for lr, wd in hp_progress:
        hp_progress.set_description(f"LR: {lr}, WD: {wd}")

        model = ImageCaptionModel().to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
        # NOTE(review): ReduceLROnPlateau only lowers the LR after more than
        # `patience` bad epochs, while early stopping below ends the run after
        # 5, so this scheduler will rarely if ever fire. Confirm intent.
        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
        # NOTE(review): pad_token_id equals the EOS id (get_dataloaders sets
        # pad_token = eos_token), so EOS targets are ignored as well.
        criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

        model, val_loss = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            criterion=criterion,
            device=device,
            epochs=50,
            early_stopping_patience=5,
            save_path="best_image_caption_model.pth",
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_hyperparams = {
                "learning_rate": lr,
                "weight_decay": wd,
            }
            # Snapshot on the CPU: a plain state_dict() would keep this
            # trial's device tensors alive while later trials train.
            best_model_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            hp_progress.set_postfix({"best_val_loss": f"{best_val_loss:.4f}"})

        # Release this trial's model and optimizer state before the next one.
        del model, optimizer, scheduler
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if best_model_state is None:
        raise RuntimeError(
            "Hyperparameter tuning failed: no configuration produced a finite "
            "validation loss."
        )

    best_model = ImageCaptionModel()
    best_model.load_state_dict(best_model_state)
    best_model.to(device)

    return best_model, best_hyperparams

### Model Evaluation


In [None]:
def _caption_metrics(references, hypotheses):
    """Score tokenized captions: corpus BLEU plus per-caption mean ROUGE-L F1 and METEOR.

    Args:
        references: One ``[reference_tokens]`` list per sample (single reference).
        hypotheses: One token list per sample, aligned with ``references``.
    """
    bleu_score = corpus_bleu(references, hypotheses)

    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    rouge_l_scores = [
        scorer.score(" ".join(refs[0]), " ".join(hyp))["rougeL"].fmeasure
        for refs, hyp in zip(references, hypotheses)
    ]
    meteor_scores = [meteor_score(refs, hyp) for refs, hyp in zip(references, hypotheses)]

    return {
        "bleu": bleu_score,
        "rouge_l": sum(rouge_l_scores) / len(rouge_l_scores) if rouge_l_scores else 0,
        "meteor": sum(meteor_scores) / len(meteor_scores) if meteor_scores else 0,
    }


def evaluate_model(
    model, test_dataloader, tokenizer, device, max_length=128, temperature=0.7
):
    """Caption every test image and score the generations against the references.

    Captions are sampled one image at a time with ``model.generate_caption``.
    References and generations are lower-cased and tokenized with
    ``nltk.word_tokenize`` before scoring.

    Args:
        model: Captioning model exposing ``generate_caption``.
        test_dataloader: Yields dicts with ``pixel_values`` and raw ``caption``.
        tokenizer: Tokenizer passed through to ``generate_caption``.
        device: Device the images are moved to.
        max_length: Maximum number of tokens sampled per caption.
        temperature: Sampling temperature for generation.

    Returns:
        dict with ``bleu`` (nltk ``corpus_bleu``, default weights, no
        smoothing) and the per-caption averages ``rouge_l`` (F-measure) and
        ``meteor``; the two averages are 0 when there are no samples.

    Also prints the first five reference/generation pairs.
    """
    model.eval()
    references = []  # [[ref_tokens]] per image, the shape corpus_bleu expects
    hypotheses = []

    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Evaluating model"):
            pixel_values = batch["pixel_values"].to(device)

            batch_captions = [
                model.generate_caption(
                    pixel_values[i : i + 1],
                    tokenizer,
                    max_length=max_length,
                    temperature=temperature,
                )
                for i in range(pixel_values.size(0))
            ]

            for ref_caption, gen_caption in zip(batch["caption"], batch_captions):
                references.append([nltk.word_tokenize(ref_caption.lower())])
                hypotheses.append(nltk.word_tokenize(gen_caption.lower()))

    results = _caption_metrics(references, hypotheses)

    print("\nExample Generations:")
    for refs, hyp in zip(references[:5], hypotheses[:5]):
        print(f"Reference: {' '.join(refs[0])}")
        print(f"Generated: {' '.join(hyp)}")
        print()

    return results

### Loading Data

In [None]:
# Kaggle-mounted dataset root (train/val/test CSVs plus matching image folders).
DATA_DIR = "/kaggle/input/dl-assignment-2/custom_captions_dataset"
(
    train_loader,
    val_loader,
    test_loader,
    tokenizer,
    image_processor,
) = get_dataloaders(data_dir=DATA_DIR)

### Hyperparameter Tuning for Best Model


In [None]:
# Grid-search LR / weight decay and keep the model with the lowest val loss.
best_model, best_hyperparams = hyperparameter_tuning(
    train_loader=train_loader,
    val_loader=val_loader,
    tokenizer=tokenizer,
    device=device,
)

### Test-Set Evaluation and Saving Results


In [None]:
test_results = evaluate_model(best_model, test_loader, tokenizer, device)

print("\nTest Results:")
for metric_label, metric_key in (
    ("BLEU", "bleu"),
    ("ROUGE-L", "rouge_l"),
    ("METEOR", "meteor"),
):
    print(f"{metric_label}: {test_results[metric_key]:.4f}")

# Persist the selected weights and a one-row CSV of the test metrics.
torch.save(best_model.state_dict(), "/kaggle/working/best_image_caption_model.pth")
test_results_df = pd.DataFrame([test_results])
test_results_df.to_csv(
    "/kaggle/working/custom_results.csv",
    index=False,
    header=True,
)