# A Flexible Encoder-Decoder Framework for Image Captioning
### *Train, Evaluate, and Compare Captioning Models*

This notebook provides a single-file framework to train, evaluate, and compare two image captioning models. The same pipeline runs on Google Colab, RunPod, or a local machine, handles both Flickr8k and Flickr30k, and can either train new models or evaluate previously saved weights.

### Core Architectures for Comparison:
The primary goal is to compare two encoder-decoder architectures:
1.  **Classic Approach**: ResNet50 (Encoder) + LSTM (Decoder).
2.  **State-of-the-Art Approach**: ViT (Encoder) + Pre-trained Transformer Decoder (GPT-2).

### Key Architectural & Workflow Requirements:
- **Modular, Multi-Step Notebook**: Structured into logical cells (e.g., "Step 1: Setup", "Step 2: Configuration") for clarity.
- **Centralized Configuration**: All settings, including model choices and hyperparameters, are managed in a central `experiment_configs` dictionary.
- **Decoupled Computation and Reporting**: The main loop handles all computation first, storing results in a dictionary. A separate, final cell generates a clean, consolidated report.
- **Portable Evaluation**: Supports an "evaluate-only" mode by loading pre-trained weights from a local path or a Google Drive zip file.
- **Data Pipeline & Vocabulary**: Includes a `Vocabulary` class to handle word-to-index mapping and special tokens (`<SOS>`, `<EOS>`, `<PAD>`, `<UNK>`). The dataset uses it to numericalize captions for the LSTM decoder; the GPT-2 decoder uses GPT-2's own tokenizer instead.
- **Advanced Training**: Implements **teacher forcing**, **two-phase fine-tuning** with **differential learning rates**, **gradient accumulation**, **early stopping**, and **caption sampling** for robust training.
- **Inference and Evaluation**: The `EncoderDecoder.generate_caption` method produces captions for evaluation (greedy decoding for the LSTM, 5-beam search for GPT-2). Standard captioning metrics are calculated: **BLEU, METEOR, and CIDEr**.
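The two-phase schedule with differential learning rates can be sketched as follows. The optimizer construction is not part of this section, so `build_optimizer` and `TinyCaptioner` are hypothetical stand-ins. The sketch assumes AdamW with one parameter group per sub-module: phase 1 freezes the encoder (as `set_encoder_trainable(False)` does), and phase 2 unfreezes it and rebuilds the optimizer with separate `encoder_lr` and `decoder_lr` values from the config.

```python
import torch
import torch.nn as nn


class TinyCaptioner(nn.Module):
    """Hypothetical stand-in with the same .encoder / .decoder split as EncoderDecoder."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 4)
        self.decoder = nn.Linear(4, 10)


def build_optimizer(model, encoder_lr, decoder_lr, weight_decay):
    """Hypothetical helper: one AdamW param group per sub-module, each with its own LR.

    Only trainable parameters are collected, so a frozen encoder (phase 1)
    produces no encoder group at all.
    """
    enc_params = [p for p in model.encoder.parameters() if p.requires_grad]
    # Everything outside the encoder (e.g. a projection layer) trains at the decoder LR.
    dec_params = [p for n, p in model.named_parameters()
                  if not n.startswith("encoder.") and p.requires_grad]
    groups = []
    if enc_params:
        groups.append({"params": enc_params, "lr": encoder_lr})
    if dec_params:
        groups.append({"params": dec_params, "lr": decoder_lr})
    return torch.optim.AdamW(groups, weight_decay=weight_decay)


model = TinyCaptioner()

# Phase 1: encoder frozen, only the decoder is optimized.
for p in model.encoder.parameters():
    p.requires_grad = False
phase1_opt = build_optimizer(model, encoder_lr=1e-4, decoder_lr=3e-4, weight_decay=5e-4)

# Phase 2: unfreeze the encoder and rebuild the optimizer with two LR groups.
for p in model.encoder.parameters():
    p.requires_grad = True
phase2_opt = build_optimizer(model, encoder_lr=1e-4, decoder_lr=3e-4, weight_decay=5e-4)

print([g["lr"] for g in phase1_opt.param_groups])  # [0.0003]
print([g["lr"] for g in phase2_opt.param_groups])  # [0.0001, 0.0003]
```

Rebuilding the optimizer at the phase boundary is one simple option; adding a new group with `optimizer.add_param_group` would be another.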

### **Step 1: Environment Setup and Dependency Installation**
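This cell detects the runtime (Google Colab, RunPod, or a local machine) and, on Colab or RunPod, installs the necessary packages, including pinned Torch and Transformers versions. It also installs `pycocoevalcap` and Java, which `pycocoevalcap` needs to run its PTB tokenizer and METEOR scorer. On a local machine, nothing is installed.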
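Because a missing Java install only surfaces once the metrics are computed, it can be worth checking for it up front. A minimal standard-library sketch; `java_available` is a hypothetical helper, not part of `pycocoevalcap`:

```python
import shutil
import subprocess


def java_available() -> bool:
    """Hypothetical helper: return True if a working `java` executable is on PATH."""
    java_path = shutil.which("java")
    if java_path is None:
        return False
    # `java -version` writes its banner to stderr and exits with status 0 on success.
    result = subprocess.run([java_path, "-version"], capture_output=True, text=True)
    return result.returncode == 0


if not java_available():
    print("Java not found: pycocoevalcap's tokenizer and METEOR scorer will fail.")
```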

In [None]:
import os
import subprocess
import sys
import zipfile

def run_shell_command(command, shell_mode=True):
    """Executes a shell command and raises an error if it fails."""
    try:
        print(f"Running command: {command}")
        subprocess.run(command, shell=shell_mode, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Error executing command: {command}")
        print(e.stderr)
        raise

def setup_environment():
    """Detects the environment and installs dependencies."""
    is_colab = "google.colab" in sys.modules
    is_runpod = os.path.exists("/workspace") or "RUNPOD_POD_ID" in os.environ
    
    if is_colab or is_runpod:
        env_type = "Google Colab" if is_colab else "RunPod"
        print(f"🚀 {env_type} environment detected. Installing dependencies...")
        
        # Install Java and Zip, dependencies for pycocoevalcap and data handling
        print("Installing Java and Zip...")
        run_shell_command("apt-get update && apt-get install -y openjdk-8-jre zip")

        # Forcibly reinstall compatible Torch and Transformers versions
        print("Cleaning and reinstalling Torch and Transformers...")
        run_shell_command("pip uninstall -y torch torchvision torchaudio transformers accelerate")
        run_shell_command("pip install torch==2.2.2 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121")
        run_shell_command("pip install transformers==4.39.3 accelerate")

        # Install other Python packages
        print("Installing other Python packages...")
        pip_commands = [
            "pip install -q 'numpy<2.0'",
            "pip install -q pandas timm tqdm opencv-python scikit-learn nltk albumentations tabulate wandb nbformat",
            "pip install -q pycocotools pycocoevalcap kaggle matplotlib",
            "pip install -U -q ipywidgets"
        ]
        for cmd in pip_commands:
            run_shell_command(cmd)
        print(f"✅ {env_type} dependencies installed successfully.")
        
        if is_colab:
            print("\n🔥 IMPORTANT: Please restart the Colab runtime now for the new libraries to take effect! 🔥")
            print("Go to 'Runtime' > 'Restart Session' in the menu above.")

        return ("colab" if is_colab else "runpod"), ("/content" if is_colab else "/workspace")
    else:
        print("Environment: Local machine detected.")
        return "local", os.getcwd()

def setup_from_zip(zip_path, extract_to):
    """Unzips a results archive to the base path."""
    if not os.path.exists(zip_path):
        print(f"⚠️ Zip file not found at {zip_path}. Cannot set up from zip.")
        return False
    try:
        print(f"Unzipping {zip_path} to {extract_to}...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
        print("✅ Unzipping complete.")
        return True
    except Exception as e:
        print(f"⚠️ Failed to unzip results from {zip_path}. Error: {e}")
        return False

# Run setup and define base_path globally
env_name, base_path = setup_environment()

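For the evaluate-only mode, `setup_from_zip` extracts an archive directly into `base_path`, so the archive's paths need to be relative to `base_path`. A sketch of the reverse step, assuming the archive should hold the `models/` and `artifacts/` folders that `generate_paths` (Step 3) writes to; `zip_results` is a hypothetical helper:

```python
import os
import zipfile


def zip_results(base_path, zip_path, subdirs=("models", "artifacts")):
    """Hypothetical helper: archive the given sub-folders of base_path into zip_path.

    Archive members are stored relative to base_path, so extracting the zip into
    another base_path (as setup_from_zip does) recreates the same layout.
    """
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for sub in subdirs:
            # os.walk yields nothing for a missing folder, so absent subdirs are skipped.
            for dirpath, _, filenames in os.walk(os.path.join(base_path, sub)):
                for name in filenames:
                    full_path = os.path.join(dirpath, name)
                    zf.write(full_path, arcname=os.path.relpath(full_path, base_path))
    return zip_path
```

For example, calling `zip_results(base_path, os.path.join(base_path, "results.zip"))` after training and then `setup_from_zip` with that file on the evaluation machine would restore the saved weights and vocabulary.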

### **Step 2: Main Imports and Experiment Configuration**
This cell imports all necessary libraries and defines the central `experiment_configs` dictionary. This is where you can easily switch between models (e.g., `resnet50_lstm` vs. `vit_gpt2`), adjust hyperparameters, and enable advanced training techniques like **two-phase training**, **gradient accumulation**, and **caption sampling**.
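The code that turns an `experiment_configs` entry into the `cfg` object read by `generate_paths`, `build_loaders`, and `Trainer` is not part of this section. Since `SimpleNamespace` is imported here, one plausible merge order is sketched below; `build_cfg` is hypothetical, and the trimmed `BaseCFG` keeps only a few of the real fields.

```python
from types import SimpleNamespace


class BaseCFG:
    """Trimmed, hypothetical stand-in for the notebook's BaseCFG."""
    epochs = 40
    max_length = 40
    gradient_accumulation_steps = 1
    phase1_epochs = 5


# Excerpt of the notebook's experiment_configs (vit_gpt2 on flickr8k only).
experiment_configs = {
    "vit_gpt2": {
        "models": {"encoder_name": "vit_base_patch16_224", "decoder_name": "gpt2"},
        "hyperparameters": {
            "flickr8k": {"batch_size": 64, "gradient_accumulation_steps": 4, "phase1_epochs": 15},
        },
    },
}


def build_cfg(base_cls, experiment, dataset_name):
    """Hypothetical helper: BaseCFG defaults, then model names, then per-dataset overrides."""
    merged = {k: v for k, v in vars(base_cls).items() if not k.startswith("__")}
    merged.update(experiment["models"])
    merged.update(experiment["hyperparameters"][dataset_name])
    return SimpleNamespace(**merged)


cfg = build_cfg(BaseCFG, experiment_configs["vit_gpt2"], "flickr8k")
print(cfg.gradient_accumulation_steps, cfg.phase1_epochs, cfg.epochs)  # 4 15 40
```

Later keys win, so a dataset entry such as `"phase1_epochs": 15` overrides the `BaseCFG` default of 5, while untouched defaults like `epochs` carry through.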

In [None]:
import glob
import json
import numpy as np
import pandas as pd
import random
import re 
import requests
import time
import shutil
from collections import Counter
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

import timm
from transformers import AutoModel, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer, get_cosine_schedule_with_warmup
from nltk.translate.bleu_score import corpus_bleu
from nltk.tokenize import word_tokenize
import nltk
import wandb

nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import display, Markdown
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm.notebook import tqdm

from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap

class BaseCFG:
    debug = False
    epochs = 40  # Slightly more to accommodate two-phase models
    num_workers = 2
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    force_model_retrain = False
    run_evaluation = True
    evaluate_per_epoch = False # Set to True for detailed metric tracking
    WANDB_API_KEY = None  # Prefer the WANDB_API_KEY environment variable; never hard-code a real key

    model_artifacts_zip_path = None

    image_size = 224
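    max_length = 40  # Max caption length in tokens (process_captions reserves 2 for start/end tokens)
    vocab_threshold = 3  # Minimum word frequency for a word to enter the vocabulary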

    gradient_accumulation_steps = 1
    early_stopping_patience = 10
    use_caption_sampling = True
    use_two_phase_training = True
    phase1_epochs = 5
    
experiment_configs = {
    "resnet50_lstm": {
        "models": {
            "encoder_name": "resnet50",
            "decoder_name": "lstm"
        },
        "hyperparameters": {
            "flickr8k": {
                "batch_size": 64,
                "embed_dim": 256,
                "hidden_dim": 768, # 512
                "num_layers": 1,
                "encoder_lr": 1e-4,
                "decoder_lr": 3e-4,      
                "weight_decay": 5e-4,    
                "dropout": 0.3,          
                "use_caption_sampling": True,
                "early_stopping_patience": 5, 
                "use_two_phase_training": True,
                "phase1_epochs": 8         
            },
            "flickr30k": {
                "batch_size": 96,
                "embed_dim": 512, # earlier 384
                "hidden_dim": 512,
                "num_layers": 2,
                "encoder_lr": 1e-4,
                "decoder_lr": 1.5e-4,    # From current 2e-4
                "weight_decay": 5e-4,   
                "dropout": 0.3,         
                "use_caption_sampling": True,
                "early_stopping_patience": 5, 
                "use_two_phase_training": True,
                "phase1_epochs": 10  # increased from 8      
            }
        }
    },

    "vit_gpt2": {
        "models": {
            "encoder_name": "vit_base_patch16_224",
            "decoder_name": "gpt2"
        },
        "hyperparameters": {
            "flickr8k": {
                "batch_size": 64,
                "num_workers": 8,
                "embed_dim": 768,
                "encoder_lr": 1.2e-4,  # Lowered for ViT stability
                "decoder_lr": 3e-4,  # Lowered for GPT2 stability
                "weight_decay": 2.5e-4,
                "dropout": 0.45,      # Transformer-friendly
                "gradient_accumulation_steps": 4,
                "use_two_phase_training": True,
                "use_caption_sampling": True,
                "early_stopping_patience": 10,
                "max_length": 40,
                "phase1_epochs": 15
            },
            "flickr30k": {
                "batch_size": 96,
                "num_workers": 4,
                "embed_dim": 768,
                "encoder_lr": 5e-5,
                "decoder_lr": 1e-4,
                "weight_decay": 1e-4,
                "dropout": 0.2,
                "gradient_accumulation_steps": 2,
                "use_two_phase_training": True,
                "use_caption_sampling": True,
                "early_stopping_patience": 8,
                "max_length": 40,
                "phase1_epochs": 10
            }
        }
    }
}

# Login to Weights & Biases
wandb_api_key = os.environ.get("WANDB_API_KEY") or BaseCFG.WANDB_API_KEY
try:
    if wandb_api_key:
        print("Logging into WandB using API key.")
        wandb.login(key=wandb_api_key)
    else:
        print("WandB API Key not found. Attempting interactive login.")
        wandb.login()
except Exception as e:
    print(f"Could not log in to WandB: {e}")
    print("Proceeding without WandB logging.")

### **Step 3: Path and Data Download Utilities**
These functions create the experiment directories and download and extract the Flickr datasets. Both datasets come from GitHub releases, so no Kaggle API key is needed: Flickr8k is a single zip, while Flickr30k is split into three parts that are downloaded separately and concatenated back into one zip before extraction.
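The `cat ... > flickr30k.zip` call in `download_flickr` relies on a Unix shell. A portable standard-library equivalent, sketched under the assumption that the parts are plain byte-level splits that only need to be rejoined in order (`join_parts` is a hypothetical name):

```python
import shutil


def join_parts(part_paths, output_path, chunk_size=1024 * 1024):
    """Hypothetical helper: concatenate split archive parts, in order, into one file.

    Equivalent to `cat part00 part01 part02 > output`; the order of part_paths matters.
    """
    with open(output_path, "wb") as out:
        for part in part_paths:
            with open(part, "rb") as src:
                # Stream in chunks so multi-GB parts never have to fit in memory.
                shutil.copyfileobj(src, out, chunk_size)
    return output_path
```

Inside `download_flickr`, `join_parts(part_paths, zip_path)` could stand in for the `cat` command.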

In [None]:
def generate_paths(base_path, dataset_name, cfg):
    """Generates and creates all necessary directory paths for an experiment."""
    model_combo_name = f"{dataset_name}_{cfg.encoder_name.replace('/', '-')}_{cfg.decoder_name.replace('/', '-')}"
    paths = {
        "dataset_name": dataset_name,
        "dataset_dir": os.path.join(base_path, "data", dataset_name),
        "image_dir": os.path.join(base_path, "data", dataset_name, "Images"),
        "captions_file": os.path.join(base_path, "data", dataset_name, f"{dataset_name}_captions.csv"),
        "model_save_path": os.path.join(base_path, "models", f"{model_combo_name}.pt"),
        "artifact_dir": os.path.join(base_path, "artifacts", model_combo_name),
        "vocab_path": os.path.join(base_path, "artifacts", model_combo_name, "vocab.pt"),
        "history_path": os.path.join(base_path, "artifacts", model_combo_name, "train_history.pt")
    }
    for path_key in ["dataset_dir", "artifact_dir"]:
        os.makedirs(paths[path_key], exist_ok=True)
    os.makedirs(os.path.dirname(paths["model_save_path"]), exist_ok=True)
    return paths

def download_with_progress(url, filename):
    """Downloads a file from a URL with a tqdm progress bar, failing loudly on HTTP errors."""
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()  # Avoid silently saving an HTML error page as the archive
        total = int(r.headers.get('content-length', 0))
        with open(filename, 'wb') as f, tqdm(
            unit="B", unit_scale=True, unit_divisor=1024, total=total,
            desc=f"Downloading {os.path.basename(filename)}"
        ) as bar:
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:  # Skip keep-alive chunks
                    f.write(chunk)
                    bar.update(len(chunk))

def download_flickr(dataset_name, target_dir):
    """Downloads and extracts the specified Flickr dataset."""
    os.makedirs(target_dir, exist_ok=True)
    print(f"📥 Downloading {dataset_name}...")
    if dataset_name == 'flickr8k':
        url = "https://github.com/awsaf49/flickr-dataset/releases/download/v1.0/flickr8k.zip"
        zip_path = os.path.join(target_dir, "flickr8k.zip")
        download_with_progress(url, zip_path)
        run_shell_command(f"unzip -q -o {zip_path} -d {target_dir}")
        os.remove(zip_path)
    elif dataset_name == 'flickr30k':
        zip_path = os.path.join(target_dir, "flickr30k.zip")
        parts = [f"flickr30k_part0{i}" for i in range(3)]
        urls = [f"https://github.com/awsaf49/flickr-dataset/releases/download/v1.0/{p}" for p in parts]
        part_paths = [os.path.join(target_dir, p) for p in parts]
        for url, part_path in zip(urls, part_paths):
            download_with_progress(url, part_path)
        
        run_shell_command(f"cat {' '.join(part_paths)} > {zip_path}")
        for part in part_paths:
            os.remove(part)
        run_shell_command(f"unzip -q -o {zip_path} -d {target_dir}")
        os.remove(zip_path)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

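def clean_caption(text):
    """Lowercases a caption, removes non-alphanumeric characters, and collapses whitespace."""
    text = str(text).lower()
    text = re.sub(r'[^a-z0-9\s]', '', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()  # Strip last so removed punctuation leaves no trailing space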

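def process_captions(raw_captions_path, final_captions_path, cfg):
    """Cleans raw captions, keeps those with 3 to max_length - 2 words, and writes the processed CSV."""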
    print(f"Processing captions from {raw_captions_path}...")
    if not os.path.exists(raw_captions_path):
        print(f"❌ Missing raw captions file: {raw_captions_path}")
        return

    df = pd.read_csv(raw_captions_path)
    df.columns = df.columns.str.strip()
    df.rename(columns={"image_name": "image", "comment": "caption"}, inplace=True)
    df.dropna(subset=["caption"], inplace=True)
    df["caption"] = df["caption"].astype(str).str.strip().apply(clean_caption)
    df["num_tokens"] = df["caption"].apply(lambda x: len(x.split()))
    max_tokens = cfg.max_length - 2
    df = df[(df["num_tokens"] >= 3) & (df["num_tokens"] <= max_tokens)].reset_index(drop=True)

    df["caption_number"] = df.groupby("image").cumcount()
    df["id"] = df["image"].factorize()[0]
    df = df[["image", "caption_number", "caption", "id"]]

    df.to_csv(final_captions_path, index=False)
    print(f"\n✅ Preprocessing DONE")
    print(f"📝 Total captions: {len(df)}")
    print(f"🔤 Avg length: {df['caption'].apply(lambda x: len(x.split())).mean():.2f} tokens")
    print(f"📄 Saved: {final_captions_path}")


def prepare_dataset(config, cfg):
    """Main function to ensure dataset is downloaded and processed."""
    dataset_name = config["dataset_name"]
    dataset_dir = config["dataset_dir"]
    image_dir = config["image_dir"]
    final_captions_file = config["captions_file"]
    raw_captions_path = os.path.join(dataset_dir, 'captions.txt')

    # Stage 1: Check if the FINAL processed file exists.
    if os.path.exists(final_captions_file) and not BaseCFG.force_model_retrain:
        print(f"✅ Dataset '{dataset_name}' found and already processed. Skipping preparation.")
        return

    # Stage 2: Check if the RAW data exists.
    if not (os.path.exists(image_dir) and os.path.exists(raw_captions_path)) or BaseCFG.force_model_retrain:
        print(f"Raw dataset '{dataset_name}' not found or retraining forced. Starting download...")
        download_flickr(dataset_name, dataset_dir)
        # After download, we might have a nested folder, so let's restructure
        if not os.path.exists(image_dir) and os.path.exists(os.path.join(dataset_dir, 'flickr8k')):
            nested_dir = os.path.join(dataset_dir, 'flickr8k')
            for item in os.listdir(nested_dir):
                shutil.move(os.path.join(nested_dir, item), dataset_dir)
            shutil.rmtree(nested_dir)
        if not os.path.exists(image_dir) and os.path.exists(os.path.join(dataset_dir, 'flickr30k-images')):
            shutil.move(os.path.join(dataset_dir, 'flickr30k-images'), image_dir)
    else:
        print(f"Found raw dataset '{dataset_name}'. Skipping download.")

    # Stage 3: Process the raw data.
    print(f"Processing raw dataset...")
    process_captions(raw_captions_path, final_captions_file, cfg)
    print(f"Dataset '{dataset_name}' is now ready for use.")

### **Step 4: Vocabulary, Dataset, and DataLoader**
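This section turns the processed captions into model-ready batches.
- **`Vocabulary` class**: Builds a word-to-index mapping from the training captions, keeping only words that occur at least `freq_threshold` times and reserving indices 0-3 for `<PAD>`, `<SOS>`, `<EOS>`, and `<UNK>`. Used by the LSTM model only.
- **`CaptioningDataset`**: Loads each image and its caption, numericalizing the caption with the `Vocabulary` (LSTM) or the GPT-2 tokenizer (GPT-2).
- **`get_transforms`**: Resizing, random horizontal flips, and color jitter for training; resizing only for validation. Both apply ImageNet mean/std normalization.
- **`Collate`**: The DataLoader's `collate_fn`. It drops samples whose image failed to load and pads the caption sequences in each batch to the same length, which PyTorch needs in order to stack them into one tensor.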
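What `Collate` does to the captions can be seen on a toy batch. The token ids below are made up, except that 0, 1, and 2 follow the `Vocabulary` indices for `<PAD>`, `<SOS>`, and `<EOS>`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_IDX = 0  # Vocabulary's index for "<PAD>"

# Three numericalized captions of different lengths (1 = <SOS>, 2 = <EOS>; word ids are made up).
captions = [
    torch.tensor([1, 5, 6, 2]),
    torch.tensor([1, 7, 2]),
    torch.tensor([1, 8, 9, 10, 2]),
]

# Same call as in Collate.__call__: right-pad every sequence to the longest one.
batch = pad_sequence(captions, batch_first=True, padding_value=PAD_IDX)
print(batch.tolist())
# [[1, 5, 6, 2, 0], [1, 7, 2, 0, 0], [1, 8, 9, 10, 2]]

# Mask of real (non-padding) tokens. A loss such as
# nn.CrossEntropyLoss(ignore_index=PAD_IDX) can skip the padded positions instead.
mask = batch != PAD_IDX
print(mask.sum(dim=1).tolist())  # [4, 3, 5]
```

GPT-2 captions arrive already padded to `max_length` by the tokenizer, so this padding step only changes LSTM batches.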

In [None]:
class Vocabulary:
    def __init__(self, freq_threshold):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold
        self.pad_idx = self.stoi["<PAD>"]

    def __len__(self):
        return len(self.itos)

    def build_vocabulary(self, sentence_list):
        frequencies = Counter()
        idx = 4

        for sentence in sentence_list:
            for word in word_tokenize(str(sentence)):
                frequencies[word] += 1
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1
    
    def numericalize(self, text):
        """Maps a caption to [<SOS>, token ids..., <EOS>], using <UNK> for out-of-vocabulary words."""
        tokenized_text = word_tokenize(str(text).lower())
        token_ids = [self.stoi.get(token, self.stoi["<UNK>"]) for token in tokenized_text]
        return [self.stoi["<SOS>"]] + token_ids + [self.stoi["<EOS>"]]

class CaptioningDataset(Dataset):
    def __init__(self, df, image_dir, vocab_or_tokenizer, transforms, cfg):
        self.df = df
        self.image_dir = image_dir
        self.vocab_or_tokenizer = vocab_or_tokenizer
        self.transforms = transforms
        self.cfg = cfg
        self.is_gpt2 = isinstance(self.vocab_or_tokenizer, GPT2Tokenizer)
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        caption = self.df.caption.iloc[idx]
        image_id = self.df.image.iloc[idx]
        image_path = os.path.join(self.image_dir, image_id)
        
        try:
            image = Image.open(image_path).convert("RGB")
            image = np.array(image)
            image = self.transforms(image=image)['image']
        except (FileNotFoundError, OSError):
            print(f"Warning: Could not load image {image_path}. Skipping.")
            return None

        if self.is_gpt2:
            # For GPT-2, use its tokenizer
            encoding = self.vocab_or_tokenizer(caption, padding="max_length", truncation=True, max_length=self.cfg.max_length, return_tensors="pt")
            caption_tensor = encoding['input_ids'].squeeze(0)
        else:
            # For LSTM, use the custom vocabulary
            numericalized_caption = self.vocab_or_tokenizer.numericalize(caption)
            caption_tensor = torch.tensor(numericalized_caption)
        
        return image, caption_tensor, image_id
    
    def update_df(self, new_df):
        """Allows the Trainer to update the dataframe, used for caption sampling."""
        self.df = new_df

def get_transforms(cfg, mode="train"):
    if mode == "train":
        return A.Compose([
            A.Resize(cfg.image_size, cfg.image_size),
            A.HorizontalFlip(p=0.5),
            A.ColorJitter(p=0.5),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
    else:
        return A.Compose([
            A.Resize(cfg.image_size, cfg.image_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

class Collate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx
        
    def __call__(self, batch):
        # Filter out None values from missing images
        batch = [b for b in batch if b is not None]
        if not batch:
            return None, None, None
        
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first=True, padding_value=self.pad_idx)
        image_ids = [item[2] for item in batch]
        
        return imgs, targets, image_ids
    
def make_train_valid_dfs(config):
    df = pd.read_csv(config['captions_file'])
    df = df.dropna().reset_index(drop=True)
    
    image_files_in_dir = set(os.listdir(config['image_dir']))
    df = df[df['image'].isin(image_files_in_dir)].reset_index(drop=True)
    
    # For splitting, we use unique images. For training, we use all 5 captions per image.
    unique_images = df['image'].unique()
    np.random.seed(42)
    train_mask = np.random.rand(len(unique_images)) < 0.9
    train_images = unique_images[train_mask]
    valid_images = unique_images[~train_mask]
    
    train_df = df[df['image'].isin(train_images)].reset_index(drop=True)
    # For validation, we keep all 5 captions to calculate metrics correctly
    valid_df = df[df['image'].isin(valid_images)].reset_index(drop=True)
    
    return train_df, valid_df

def build_loaders(df, image_dir, vocab_or_tokenizer, mode, cfg, shuffle=True):
    transforms = get_transforms(cfg, mode)
    dataset = CaptioningDataset(df, image_dir, vocab_or_tokenizer, transforms, cfg)
    
    pad_idx = vocab_or_tokenizer.pad_token_id if isinstance(vocab_or_tokenizer, GPT2Tokenizer) else vocab_or_tokenizer.pad_idx
    
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        shuffle=shuffle,
        collate_fn=Collate(pad_idx=pad_idx),
        pin_memory=True
    )
    return dataloader

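`make_train_valid_dfs` keeps every caption of a validation image because corpus-level BLEU scores each generated caption against all of that image's human references. A sketch of the grouping step with made-up rows; `build_bleu_inputs` is hypothetical, and whitespace splitting stands in for the notebook's `word_tokenize`:

```python
from collections import defaultdict

# Made-up validation rows (image, caption) and one made-up prediction per image.
rows = [
    ("img1.jpg", "a dog runs on the grass"),
    ("img1.jpg", "a brown dog is running"),
    ("img2.jpg", "two children play in the sand"),
]
predictions = {"img1.jpg": "a dog is running", "img2.jpg": "children playing on a beach"}


def build_bleu_inputs(rows, predictions):
    """Hypothetical helper: group references by image, aligned with one hypothesis per image.

    Returns (list_of_references, hypotheses): list_of_references[i] holds every
    tokenized reference for the image whose tokenized prediction is hypotheses[i].
    """
    refs_by_image = defaultdict(list)
    for image, caption in rows:
        refs_by_image[image].append(caption.split())
    images = sorted(refs_by_image)  # A fixed order keeps references and hypotheses aligned
    list_of_references = [refs_by_image[img] for img in images]
    hypotheses = [predictions[img].split() for img in images]
    return list_of_references, hypotheses


refs, hyps = build_bleu_inputs(rows, predictions)
print(len(refs[0]), len(refs[1]))  # 2 1
print(hyps[0])  # ['a', 'dog', 'is', 'running']
```

`corpus_bleu(refs, hyps)` from `nltk.translate.bleu_score` accepts exactly this nesting: one list of tokenized references per tokenized hypothesis.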

### **Step 5: Model Definitions**
This cell defines the core Encoder-Decoder models.
- **`Encoder`**: A wrapper for `timm` models (ResNet50, ViT) to produce image features.
- **`Decoder`**: Separate implementations for the LSTM and the GPT-2 Transformer decoders.
- **`EncoderDecoder`**: The main model that combines an encoder and a decoder. It includes a `forward` method for training and a `generate_caption` method for inference.
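Teacher forcing feeds the decoder the ground-truth previous token at every step and asks it to predict the next one. The plain-Python sketch below (hypothetical tokens, hypothetical helper `teacher_forcing_pairs`) spells out the input/target pairs for the two layouts used here: an `<SOS>`-prefixed caption for the LSTM, and a caption preceded by the projected image feature for GPT-2, whose tokenizer adds no start token.

```python
def teacher_forcing_pairs(tokens, prefix=None):
    """Hypothetical helper: return (decoder_input, target) pairs for one caption.

    tokens: the caption tokens, ending with an end-of-sequence token.
    prefix: optional extra first input (e.g. an image-feature slot) that the
            decoder reads but never has to predict.
    """
    if prefix is None:
        # Layout [<SOS>, w1, ..., <EOS>]: read tokens[:-1], predict tokens[1:].
        inputs, targets = tokens[:-1], tokens[1:]
    else:
        # Layout [prefix, w1, ..., <EOS>]: the prefix slot predicts w1.
        inputs, targets = [prefix] + tokens[:-1], tokens
    return list(zip(inputs, targets))


print(teacher_forcing_pairs(["<SOS>", "a", "dog", "<EOS>"]))
# [('<SOS>', 'a'), ('a', 'dog'), ('dog', '<EOS>')]
print(teacher_forcing_pairs(["a", "dog", "<EOS>"], prefix="<IMG>"))
# [('<IMG>', 'a'), ('a', 'dog'), ('dog', '<EOS>')]
```

Both layouts yield the same kind of (input, target) pairs; they differ only in which slice of the caption tensor plays the role of the target.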

In [None]:
class Encoder(nn.Module):
    def __init__(self, model_name, pretrained=True):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='avg')
            
    def forward(self, x):
        return self.model(x)

class LSTMDecoder(nn.Module):
    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
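        # nn.LSTM only applies dropout between stacked layers, so pass 0 for a single layer (avoids a PyTorch warning)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True,
                            dropout=dropout if num_layers > 1 else 0.0)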
        self.linear = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, features, captions):
        embeddings = self.dropout(self.embedding(captions))
        # Correctly expand initial hidden state for multi-layer LSTMs
        h0 = features.unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
        c0 = features.unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
        hiddens, _ = self.lstm(embeddings, (h0, c0))
        outputs = self.linear(hiddens)
        return outputs

class TransformerDecoder(nn.Module):
    def __init__(self, embed_dim, model_name="gpt2"):
        super().__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.model.config.pad_token_id = self.model.config.eos_token_id
        
        # A linear layer to project ViT's embedding dim to GPT-2's if they differ
        self.projection = nn.Linear(embed_dim, self.model.config.n_embd)
        
    def forward(self, features, captions):
        projected_features = self.projection(features).unsqueeze(1)
        caption_embeddings = self.model.transformer.wte(captions)
        inputs_embeds = torch.cat([projected_features, caption_embeddings], dim=1)
        attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long, device=features.device)
        
        outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        logits = outputs.logits
        
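        # Drop the logits at the last position. Position t predicts input token t of `captions`,
        # so the image slot (t=0) predicts the first caption token and the output length matches `captions`.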
        return logits[:, :-1, :]

class EncoderDecoder(nn.Module):
    def __init__(self, cfg, vocab_size_or_tokenizer):
        super().__init__()
        self.cfg = cfg
        self.encoder = Encoder(cfg.encoder_name)
        encoder_output_dim = self.encoder.model.num_features
        
        if cfg.decoder_name == "lstm":
            # For LSTM, project encoder output to match decoder's hidden_dim
            self.feature_proj = nn.Linear(encoder_output_dim, cfg.hidden_dim)
            self.decoder = LSTMDecoder(cfg.embed_dim, cfg.hidden_dim, vocab_size_or_tokenizer, cfg.num_layers, cfg.dropout)
        elif cfg.decoder_name == "gpt2":
            # GPT-2 decoder has its own projection layer
            self.decoder = TransformerDecoder(encoder_output_dim)
        else:
            raise ValueError(f"Unknown decoder: {cfg.decoder_name}")

    def set_encoder_trainable(self, trainable=True):
        """Helper function to freeze/unfreeze the encoder's weights."""
        print(f"Setting encoder trainability to: {trainable}")
        if self.cfg.encoder_name.startswith('resnet'):
            # Freeze all layers initially
            for param in self.encoder.parameters():
                param.requires_grad = False
            
            # If fine-tuning (Phase 2), unfreeze the final block
            if trainable:
                print("  - Unfreezing final block (layer4) of ResNet for fine-tuning.")
                for param in self.encoder.model.layer4.parameters():
                    param.requires_grad = True
        else: # Default behavior for other models like ViT
            for param in self.encoder.parameters():
                param.requires_grad = trainable
            
    def forward(self, images, captions):
        features = self.encoder(images)
        if self.cfg.decoder_name == "lstm":
            features = self.feature_proj(features)
            outputs = self.decoder(features, captions)
        else: # GPT-2
            outputs = self.decoder(features, captions)
        return outputs

    def generate_caption(self, image, vocab=None, max_length=30):
        self.eval()
        
        with torch.no_grad():
            image = image.unsqueeze(0).to(self.cfg.device)
            features = self.encoder(image)
            
            if self.cfg.decoder_name == "lstm":
                result_caption = []
                features = self.feature_proj(features)
                states = (features.unsqueeze(0).repeat(self.decoder.lstm.num_layers, 1, 1), 
                          features.unsqueeze(0).repeat(self.decoder.lstm.num_layers, 1, 1))
                inputs = self.decoder.embedding(torch.tensor([vocab.stoi["<SOS>"]]).to(self.cfg.device)).unsqueeze(1)
                
                for _ in range(max_length):
                    hiddens, states = self.decoder.lstm(inputs, states)
                    outputs = self.decoder.linear(hiddens.squeeze(1))
                    predicted_idx = outputs.argmax(1)
                    inputs = self.decoder.embedding(predicted_idx).unsqueeze(1)
                    
                    if predicted_idx.item() == vocab.stoi["<EOS>"]:
                        break
                    result_caption.append(vocab.itos[predicted_idx.item()])
                return " ".join(result_caption)
            else: # GPT-2
                features = self.decoder.projection(features).unsqueeze(1)
                attention_mask = torch.ones(features.shape[:2], dtype=torch.long, device=features.device)
                output_ids = self.decoder.model.generate(
                    inputs_embeds=features, 
                    attention_mask=attention_mask,
                    max_length=max_length, 
                    num_beams=5, 
                    early_stopping=True, 
                    pad_token_id=self.decoder.tokenizer.eos_token_id,
                    eos_token_id=self.decoder.tokenizer.eos_token_id
                )
                caption = self.decoder.tokenizer.decode(output_ids[0], skip_special_tokens=True)
                return caption.strip()


### **Step 6: Trainer and Utilities**
- **`AvgMeter`**: A simple utility for tracking average metrics.
- **`Trainer`**: Manages the training and validation loops. The loops are adapted for a sequence-to-sequence task, calculating loss at each timestep.
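The `Trainer` divides each batch loss by `gradient_accumulation_steps` and only steps the optimizer every N batches. The sketch below (a toy linear model, not the captioner) checks why that works: with equal-sized micro-batches and a mean-reduced loss, the accumulated gradient equals the gradient of one batch N times larger.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
loss_fn = nn.MSELoss()  # mean reduction
x, y = torch.randn(8, 4), torch.randn(8, 1)

# Reference: one full batch of 8 samples.
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 2, each loss divided by the number of steps.
accum_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    (loss_fn(model(xb), yb) / accum_steps).backward()  # .grad sums across backward() calls
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True
```

The equivalence is exact only under those conditions; a smaller final micro-batch, or a loss that averages over a varying number of non-padding tokens, makes it approximate.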

In [None]:
class AvgMeter:
    """Tracks a running, count-weighted average of a metric."""
    def __init__(self, name="Metric"):
        self.name = name
        self.avg, self.sum, self.count = 0, 0, 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"

class Trainer:
    def __init__(self, model, loss_fn, optimizer, scheduler, cfg):
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.cfg = cfg
        self.device = cfg.device

    def _inputs_and_targets(self, captions):
        """Splits a caption batch into teacher-forcing inputs and next-token targets.

        LSTM captions start with <SOS>, so the decoder reads tokens[:-1] and predicts tokens[1:].
        The GPT-2 decoder prepends the image feature itself and returns logits whose
        position t predicts input token t, so its inputs and targets are the same tokens.
        """
        if self.cfg.decoder_name == "lstm":
            return captions[:, :-1], captions[:, 1:]
        return captions, captions

    def _train_one_epoch(self, train_loader):
        loss_meter = AvgMeter()
        self.model.train()
        progress_bar = tqdm(train_loader, total=len(train_loader), desc="Training")
        self.optimizer.zero_grad() # Reset gradients at the start of the epoch
        num_batches = len(train_loader)
        
        for i, (images, captions, _) in enumerate(progress_bar):
            if images is None: continue
            images, captions = images.to(self.device), captions.to(self.device)
            inputs, targets = self._inputs_and_targets(captions)
            
            outputs = self.model(images, inputs)
            loss = self.loss_fn(outputs.reshape(-1, outputs.shape[2]), targets.reshape(-1))
            
            # Gradient accumulation: scale the loss so N accumulated batches average their gradients
            loss = loss / self.cfg.gradient_accumulation_steps
            loss.backward()

            # Step every N batches, and on the last batch so a partial group is not discarded
            if (i + 1) % self.cfg.gradient_accumulation_steps == 0 or (i + 1) == num_batches:
                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.scheduler: self.scheduler.step()
                
            loss_meter.update(loss.item() * self.cfg.gradient_accumulation_steps, images.size(0))
            progress_bar.set_postfix(train_loss=loss_meter.avg, lr=self.optimizer.param_groups[0]['lr'])
            
        return loss_meter

    def _valid_one_epoch(self, valid_loader):
        loss_meter = AvgMeter()
        self.model.eval()
        progress_bar = tqdm(valid_loader, total=len(valid_loader), desc="Validation")
        with torch.no_grad():
            for images, captions, _ in progress_bar:
                if images is None: continue
                images, captions = images.to(self.device), captions.to(self.device)
                inputs, targets = self._inputs_and_targets(captions)
                
                outputs = self.model(images, inputs)
                loss = self.loss_fn(outputs.reshape(-1, outputs.shape[2]), targets.reshape(-1))
                
                loss_meter.update(loss.item(), images.size(0))
                progress_bar.set_postfix(valid_loss=loss_meter.avg)
        return loss_meter

    def fit(self, train_loader, valid_loader, config, start_epoch=0):
        best_loss = float('inf')
        history = {"train_loss": [], "valid_loss": [], "epoch_times": []}
        patience_counter = 0

        for epoch in range(start_epoch, self.cfg.epochs):
            epoch_start_time = time.time()
            print(f"\nEpoch: {epoch + 1}/{self.cfg.epochs}")
            
            # Caption Sampling
            current_train_loader = train_loader
            if self.cfg.use_caption_sampling:
                print("Sampling one caption per image for training this epoch...")
                sampled_df = train_loader.dataset.df.groupby('image').sample(1).reset_index(drop=True)
                current_train_loader = build_loaders(sampled_df, train_loader.dataset.image_dir, train_loader.dataset.vocab_or_tokenizer, 'train', self.cfg, shuffle=True)
            
            train_loss = self._train_one_epoch(current_train_loader)
            valid_loss = self._valid_one_epoch(valid_loader)
            
            epoch_end_time = time.time()
            epoch_duration = epoch_end_time - epoch_start_time
            history['train_loss'].append(train_loss.avg)
            history['valid_loss'].append(valid_loss.avg)
            history['epoch_times'].append(epoch_duration)
            
            print(f"Epoch {epoch+1} | Train Loss: {train_loss.avg:.4f} | Valid Loss: {valid_loss.avg:.4f} | Time: {epoch_duration:.2f}s")
            
            log_data = {"train_loss": train_loss.avg, "valid_loss": valid_loss.avg, "epoch": epoch}
            if self.cfg.evaluate_per_epoch:
                # Runs the full COCO evaluation over the validation set (slow); only BLEU-4 is logged per epoch
                temp_scores, _, _ = generate_and_evaluate(self.model, valid_loader, train_loader.dataset.vocab_or_tokenizer, config)
                log_data["val_bleu4"] = temp_scores.get("Bleu_4", 0)

            wandb.log(log_data)  # one log call per epoch keeps W&B steps aligned with epochs

            if valid_loss.avg < best_loss:
                best_loss = valid_loss.avg
                torch.save({'epoch': epoch + 1, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict()}, self.cfg.model_save_path)
                print(f"Saved Best Model! Validation Loss: {best_loss:.4f}")
                patience_counter = 0
            else:
                patience_counter += 1
                print(f"Validation loss did not improve. Patience: {patience_counter}/{self.cfg.early_stopping_patience}")
                if patience_counter >= self.cfg.early_stopping_patience:
                    print("Early stopping triggered.")
                    break
        
        return history
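Dividing each micro-batch loss by `gradient_accumulation_steps` before `backward()` is what makes accumulation behave like a larger batch: the summed gradients equal the gradient of the mean loss over the effective batch, provided the micro-batches have equal size. The pure-Python sketch below checks that identity on a toy one-parameter model and lists the batch indices at which an optimizer step fires when the final partial window is flushed. The model, data, and helper names are hypothetical and chosen only for illustration.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w for a toy linear model y = w*x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, accum_steps):
    """Sum of per-micro-batch gradients, each scaled by 1/accum_steps as in the Trainer."""
    size = len(xs) // accum_steps  # assumes equal-sized micro-batches
    total = 0.0
    for k in range(accum_steps):
        chunk = slice(k * size, (k + 1) * size)
        total += grad_mse(w, xs[chunk], ys[chunk]) / accum_steps
    return total


def step_indices(num_batches, accum_steps):
    """Batch indices after which optimizer.step() runs, including a flush of the final partial window."""
    return [i for i in range(num_batches)
            if (i + 1) % accum_steps == 0 or (i + 1) == num_batches]


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8, 16.1]
print(grad_mse(0.5, xs, ys), accumulated_grad(0.5, xs, ys, accum_steps=4))  # equal up to rounding
print(step_indices(10, 4))  # [3, 7, 9]
```

With 10 batches and 4 accumulation steps, the last two batches form a partial window; without the flush at index 9 their gradients would never reach the optimizer.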

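When `use_caption_sampling` is enabled, `fit` rebuilds the training loader each epoch from `df.groupby('image').sample(1)`, so every image contributes exactly one of its reference captions per epoch. The stdlib sketch below reproduces that selection on a list of `(image, caption)` rows; `sample_one_caption_per_image` and the toy rows are hypothetical.

```python
import random
from collections import defaultdict


def sample_one_caption_per_image(rows, seed=None):
    """Pick one caption uniformly at random per image, like df.groupby('image').sample(1)."""
    rng = random.Random(seed)
    by_image = defaultdict(list)
    for image, caption in rows:
        by_image[image].append(caption)
    return [(image, rng.choice(captions)) for image, captions in by_image.items()]


caption_rows = [
    ("dog.jpg", "a dog runs on grass"),
    ("dog.jpg", "a brown dog playing outside"),
    ("cat.jpg", "a cat sleeps on a sofa"),
    ("cat.jpg", "a kitten lying on a couch"),
    ("cat.jpg", "a cat resting indoors"),
]
epoch_rows = sample_one_caption_per_image(caption_rows, seed=0)
print(epoch_rows)  # one (image, caption) pair per image
```

Because the selection is redrawn every epoch, the model still sees different captions across epochs, while each epoch on Flickr8k (five captions per image) covers roughly a fifth of the caption rows.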

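The checkpointing rule in `fit` keeps the lowest validation loss seen so far, resets the patience counter on every improvement, and stops once `early_stopping_patience` consecutive epochs fail to improve. The sketch below replays that rule over a precomputed list of validation losses; the function name and loss values are made up for illustration.

```python
def early_stopping_trace(valid_losses, patience):
    """Replay Trainer.fit's rule; returns (best_epoch, best_loss, epochs_run) with 1-indexed epochs."""
    best_loss = float("inf")
    best_epoch = None
    patience_counter = 0
    epochs_run = 0
    for epoch, loss in enumerate(valid_losses, start=1):
        epochs_run = epoch
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch  # the Trainer saves a checkpoint here
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                break  # "Early stopping triggered."
    return best_epoch, best_loss, epochs_run


losses = [3.10, 2.70, 2.55, 2.58, 2.61, 2.50, 2.49]
print(early_stopping_trace(losses, patience=2))  # (3, 2.55, 5)
print(early_stopping_trace(losses, patience=3))  # (7, 2.49, 7)
```

With `patience=2` training stops after epoch 5, two epochs past the checkpointed epoch 3, so the weights left in memory at the end of `fit` are not necessarily the saved best ones.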
### **Step 7: Evaluation and Reporting Functions**
These functions handle the caption generation for the entire validation set and then use the `pycocoevalcap` library to compute standard metrics.
- **`generate_and_evaluate`**: Iterates through the validation loader, generates a caption for each image, and stores the results.
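- **`get_coco_scores`**: Formats the generated captions and ground truths into the JSON structure expected by `pycocotools` and runs the COCO evaluation scripts.
- **`get_parameter_counts`**: Returns the total and trainable parameter counts of a model.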
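`pycocoevalcap` reads references through `pycocotools.coco.COCO`, so the ground-truth file needs `images` and `annotations` lists (plus the `info` key noted in the checklist), while the results file is a flat list of `{"image_id", "caption"}` records whose IDs must all appear in the ground truth. The stdlib sketch below builds both structures from toy rows. The helper name and data are hypothetical, and it mirrors, rather than replaces, the logic in `generate_and_evaluate`.

```python
import json
from collections import defaultdict


def build_coco_caption_files(reference_rows, predictions):
    """Return (results, ground_truths) in the COCO captions layout.

    reference_rows: iterable of (image_id, caption) ground-truth pairs.
    predictions: dict mapping image_id -> generated caption.
    """
    refs = defaultdict(list)
    for image_id, caption in reference_rows:
        if image_id in predictions:  # only keep references for images that were captioned
            refs[image_id].append(caption)

    annotations, ann_id = [], 1
    for image_id, captions in refs.items():
        for caption in captions:
            annotations.append({"image_id": image_id, "id": ann_id, "caption": caption})
            ann_id += 1

    ground_truths = {
        "info": {"description": "Ground-truth captions for evaluation"},
        "images": [{"id": image_id} for image_id in refs],
        "licenses": [],
        "annotations": annotations,
        "type": "captions",
    }
    results = [{"image_id": i, "caption": c} for i, c in predictions.items()]
    return results, ground_truths


toy_refs = [("a.jpg", "a dog runs"), ("a.jpg", "a dog is running"), ("b.jpg", "two cats")]
toy_results, toy_gts = build_coco_caption_files(toy_refs, {"a.jpg": "a dog running"})
print(json.dumps(toy_results))
print(len(toy_gts["annotations"]))  # 2
```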

In [None]:
def generate_and_evaluate(model, dataloader, vocab, config):
    """Generates captions for a dataloader and prepares data for COCO evaluation."""
    print(f"Generating captions for {config['dataset_name']} validation set...")
    model.eval()
    
    device = next(model.parameters()).device  # follow the model's device instead of the global BaseCFG
    results = []
    image_ids_processed = set()
    
    with torch.no_grad():
        for images, captions_gt, image_ids in tqdm(dataloader, desc="Generating Captions"):
            if images is None: continue
            images = images.to(device)
            
            for i in range(images.size(0)):
                image_id = image_ids[i]
                
                # The loader yields one row per caption, so caption each image only once
                if image_id not in image_ids_processed:
                    generated_caption = model.generate_caption(images[i], vocab)
                    results.append({"image_id": image_id, "caption": generated_caption})
                    image_ids_processed.add(image_id)
    
    # Prepare ground truths from the validation dataframe
    valid_df = dataloader.dataset.df
    annotations = []
    images_info = []
    ann_id_counter = 1
    for img_id in image_ids_processed:
        images_info.append({"id": img_id})
        captions_for_image = valid_df[valid_df['image'] == img_id]['caption'].tolist()
        for cap in captions_for_image:
            annotations.append({"image_id": img_id, "id": ann_id_counter, "caption": cap})
            ann_id_counter += 1
    
    ground_truths = {
        "info": {"description": "Ground-truth captions for evaluation"},
        "images": images_info,
        "licenses": [],
        "annotations": annotations,
        "type": "captions"
    }
    
    # Calculate scores
    scores = get_coco_scores(results, ground_truths, config["artifact_dir"])
    return scores, results, ground_truths

def get_coco_scores(res, gts, artifact_dir):
    """Uses pycocoevalcap to calculate captioning metrics."""
    # Ensure artifact directory exists
    os.makedirs(artifact_dir, exist_ok=True)
    res_file = os.path.join(artifact_dir, "results.json")
    gts_file = os.path.join(artifact_dir, "ground_truths.json")
    
    with open(res_file, "w") as f:
        json.dump(res, f)
        
    with open(gts_file, "w") as f:
        json.dump(gts, f)

    coco = COCO(gts_file)
    coco_res = coco.loadRes(res_file)

    coco_eval = COCOEvalCap(coco, coco_res)
    coco_eval.evaluate()

    return coco_eval.eval

def get_parameter_counts(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params
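For intuition about the BLEU scores reported by `COCOEvalCap`, the sketch below computes an unsmoothed sentence-level BLEU-N from clipped n-gram precisions and a brevity penalty against multiple references. It is a simplified illustration only: pycocoevalcap tokenizes with the PTB tokenizer and aggregates BLEU over the whole corpus, so its numbers will differ from averaging this function over images. `simple_sentence_bleu` is a hypothetical helper, not a library API.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Count the n-grams of a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def simple_sentence_bleu(candidate, references, max_n=4):
    """Unsmoothed sentence-level BLEU-max_n with clipped counts and a brevity penalty."""
    cand = candidate.split()
    refs = [r.split() for r in references]
    if not cand:
        return 0.0
    log_precision_sum = 0.0
    for n in range(1, max_n + 1):
        cand_counts = ngrams(cand, n)
        # Clip each candidate n-gram count by its maximum count in any single reference
        max_ref_counts = Counter()
        for ref in refs:
            for gram, count in ngrams(ref, n).items():
                max_ref_counts[gram] = max(max_ref_counts[gram], count)
        clipped = sum(min(c, max_ref_counts[g]) for g, c in cand_counts.items())
        total = sum(cand_counts.values())
        if clipped == 0 or total == 0:
            return 0.0  # without smoothing, any zero precision gives BLEU = 0
        log_precision_sum += math.log(clipped / total)
    # Brevity penalty uses the reference length closest to the candidate length
    c = len(cand)
    r = min((abs(len(ref) - c), len(ref)) for ref in refs)[1]
    bp = 1.0 if c > r else math.exp(1 - r / c)
    return bp * math.exp(log_precision_sum / max_n)


score = simple_sentence_bleu("the cat is on the mat", ["the cat sat on the mat"], max_n=2)
print(f"{score:.4f}")  # 0.7071, i.e. sqrt(5/6 * 3/5)
```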


### **Step 8: The Main Pipeline Function**
This function orchestrates the entire process for a single experiment: data preparation, vocabulary building, model training (or loading), evaluation, and result aggregation. It returns a dictionary containing all results and artifacts.
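For the LSTM decoder, the vocabulary step keeps only words whose training-set frequency reaches `vocab_threshold` and maps everything else to `<unk>`. The sketch below is a hypothetical stand-in for the notebook's `Vocabulary` class, assuming lower-casing and whitespace tokenization; the real class defined earlier may tokenize or order indices differently. It exposes `pad_token_id` because `run_pipeline` passes that attribute to `nn.CrossEntropyLoss(ignore_index=...)`.

```python
from collections import Counter


class ToyVocabulary:
    """Hypothetical stand-in for the notebook's Vocabulary class."""

    SPECIALS = ["<pad>", "<start>", "<end>", "<unk>"]

    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = list(self.SPECIALS)
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

    def __len__(self):
        return len(self.itos)

    @property
    def pad_token_id(self):
        return self.stoi["<pad>"]

    @staticmethod
    def tokenize(text):
        return text.lower().split()

    def build_vocabulary(self, sentences):
        """Add every word whose count reaches freq_threshold, in first-seen order."""
        counts = Counter(tok for s in sentences for tok in self.tokenize(s))
        for word, count in counts.items():
            if count >= self.freq_threshold and word not in self.stoi:
                self.stoi[word] = len(self.itos)
                self.itos.append(word)

    def numericalize(self, text):
        """Map a caption to <start> ... <end> indices, using <unk> for rare words."""
        unk = self.stoi["<unk>"]
        body = [self.stoi.get(tok, unk) for tok in self.tokenize(text)]
        return [self.stoi["<start>"], *body, self.stoi["<end>"]]


toy_vocab = ToyVocabulary(freq_threshold=2)
toy_vocab.build_vocabulary(["A dog runs", "a dog jumps", "A cat sleeps"])
print(len(toy_vocab))                        # 6: four special tokens plus "a" and "dog"
print(toy_vocab.numericalize("a cat runs"))  # [1, 4, 3, 3, 2]
```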

In [None]:
def run_pipeline(config, cfg):
    print("-" * 50)
    print(f"STARTING PIPELINE FOR: {config['dataset_name'].upper()}")
    print(f"With model: {cfg.encoder_name} + {cfg.decoder_name}")
    print("-" * 50)
    
    # --- Setup: Data & Vocabulary ---
    print("\nSetting up datasets and vocabulary...")
    prepare_dataset(config, cfg)
    train_df, valid_df = make_train_valid_dfs(config)
    
    if train_df is None or valid_df is None or train_df.empty or valid_df.empty:
        print("Could not create dataframes. Aborting.")
        return None
        
    # --- Vocabulary or Tokenizer ---
    if cfg.decoder_name == 'gpt2':
        print("Using GPT-2 Tokenizer.")
        vocab_or_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        vocab_or_tokenizer.pad_token = vocab_or_tokenizer.eos_token
        vocab_size_or_tokenizer = None # Not needed for GPT-2 model init
    else:
        print("Building custom vocabulary...")
        vocab = Vocabulary(freq_threshold=cfg.vocab_threshold)
        if os.path.exists(config["vocab_path"]) and not cfg.force_model_retrain:
            print("Loading existing vocabulary...")
            vocab = torch.load(config["vocab_path"], weights_only=False)
        else:
            print("Building new vocabulary...")
            vocab.build_vocabulary(train_df.caption.tolist())
            torch.save(vocab, config["vocab_path"])
        vocab_size_or_tokenizer = len(vocab)
        vocab_or_tokenizer = vocab
        print(f"Vocabulary size: {vocab_size_or_tokenizer}")

    # Build data loaders
    train_loader = build_loaders(train_df, config['image_dir'], vocab_or_tokenizer, 'train', cfg, shuffle=True)
    valid_loader = build_loaders(valid_df, config['image_dir'], vocab_or_tokenizer, 'valid', cfg, shuffle=False)
    
    # --- Model Creation and Loading ---
    print("\nCreating model...")
    model = EncoderDecoder(cfg, vocab_size_or_tokenizer).to(cfg.device)
    cfg.model_save_path = config['model_save_path'] # Pass save path to Trainer
    
    training_history = None
    total_training_duration = 0
    
    if os.path.exists(cfg.model_save_path) and not cfg.force_model_retrain:
        print(f"Model found at '{cfg.model_save_path}'. Loading weights...")
        checkpoint = torch.load(cfg.model_save_path, map_location=cfg.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        if os.path.exists(config['history_path']):
            training_history = torch.load(config['history_path'])
    else:
        # --- Training --- #
        start_time = time.time()
        if cfg.use_two_phase_training:
            total_epochs = cfg.epochs  # remember the configured total (may be an experiment-specific override)

            # Phase 1: Train only the decoder
            print("--- Starting Two-Phase Training: Phase 1 (Decoder Only) ---")
            model.set_encoder_trainable(False)
            decoder_params = [p for p in model.parameters() if p.requires_grad]
            optimizer = torch.optim.Adam(decoder_params, lr=cfg.decoder_lr, weight_decay=cfg.weight_decay)
            trainer = Trainer(model, nn.CrossEntropyLoss(ignore_index=vocab_or_tokenizer.pad_token_id), optimizer, None, cfg)
            cfg.epochs = cfg.phase1_epochs # Temporarily set epochs for phase 1
            history1 = trainer.fit(train_loader, valid_loader, config)
            
            # Phase 2: Train the full model with differential learning rates.
            # Every non-encoder parameter (decoder and any projection head) goes into the decoder-LR group.
            print("--- Two-Phase Training: Phase 2 (Full Model with Differential LRs) ---")
            model.set_encoder_trainable(True)
            encoder_param_ids = {id(p) for p in model.encoder.parameters()}
            non_encoder_params = [p for p in model.parameters() if id(p) not in encoder_param_ids]
            optimizer_params = [
                {"params": model.encoder.parameters(), "lr": cfg.encoder_lr},
                {"params": non_encoder_params, "lr": cfg.decoder_lr}
            ]
            optimizer = torch.optim.Adam(optimizer_params, weight_decay=cfg.weight_decay)
            trainer = Trainer(model, nn.CrossEntropyLoss(ignore_index=vocab_or_tokenizer.pad_token_id), optimizer, None, cfg)
            cfg.epochs = total_epochs # Restore the configured total for phase 2
            history2 = trainer.fit(train_loader, valid_loader, config, start_epoch=cfg.phase1_epochs)
            
            # Combine histories
            training_history = {k: history1[k] + history2[k] for k in history1}
        else:
            # Standard one-phase training
            print("--- Starting Standard Training ---")
            optimizer = torch.optim.Adam(model.parameters(), lr=cfg.decoder_lr, weight_decay=cfg.weight_decay)
            trainer = Trainer(model, nn.CrossEntropyLoss(ignore_index=vocab_or_tokenizer.pad_token_id), optimizer, None, cfg)
            training_history = trainer.fit(train_loader, valid_loader, config)
            
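        total_training_duration = time.time() - start_time
        torch.save(training_history, config['history_path'])

        # The in-memory model holds the last epoch's weights; reload the best checkpoint before evaluation
        if os.path.exists(cfg.model_save_path):
            checkpoint = torch.load(cfg.model_save_path, map_location=cfg.device)
            model.load_state_dict(checkpoint['model_state_dict'])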

    if not cfg.run_evaluation:
        print("Skipping evaluation as per configuration.")
        return {"history": training_history, "model": model, "config": config, "cfg": cfg, "vocab": vocab_or_tokenizer, "train_loader_len": len(train_loader)}

    # --- Evaluation ---
    print("\nStarting evaluation...")
    scores, generated_captions, ground_truths = generate_and_evaluate(model, valid_loader, vocab_or_tokenizer, config)
        
    print(f"\nPIPELINE FOR {config['dataset_name'].upper()} COMPLETE")
    return {
        "history": training_history,
        "metrics": scores,
        "generated_captions": generated_captions,
        "ground_truths": ground_truths,
        "model": model, 
        "config": config, 
        "cfg": cfg,
        "vocab": vocab_or_tokenizer,
        "duration": total_training_duration,
        "train_loader_len": len(train_loader)
    }
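Phase 2 of the two-phase schedule relies on optimizer parameter groups: encoder parameters get `encoder_lr`, and every other parameter (decoder plus any projection head) must land in the `decoder_lr` group, or it silently stops updating. The sketch below shows that partitioning on parameter names as `model.named_parameters()` would yield them. The names, default learning rates, and `split_param_groups` helper are hypothetical; in PyTorch, the groups would hold tensors rather than names and be passed to the optimizer as `{"params": ..., "lr": ...}` dicts.

```python
def split_param_groups(param_names, encoder_prefix="encoder.", encoder_lr=1e-5, decoder_lr=1e-4):
    """Assign each parameter name to the encoder group or to the 'rest of the model' group."""
    encoder, rest = [], []
    for name in param_names:
        (encoder if name.startswith(encoder_prefix) else rest).append(name)
    return [
        {"params": encoder, "lr": encoder_lr},
        {"params": rest, "lr": decoder_lr},  # decoder, projection head, and anything else
    ]


param_names = [
    "encoder.patch_embed.weight",
    "encoder.blocks.0.attn.qkv.weight",
    "feature_proj.weight",
    "decoder.transformer.h.0.attn.c_attn.weight",
]
for group in split_param_groups(param_names):
    print(group["lr"], group["params"])
```

Partitioning by "encoder versus everything else", rather than listing the decoder explicitly, guarantees every parameter belongs to exactly one group.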


### **Step 9: Main Execution Loop**
This cell runs the main pipeline. It iterates through the `experiment_configs`, sets up the configuration for each run, and calls the `run_pipeline` function. All results are collected in the `results_history` dictionary for the final reporting step.
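Each run's configuration is built in three layers: the class attributes of `BaseCFG`, then the experiment's per-dataset `hyperparameters`, then its `models` entry. Later dictionaries win on key clashes because of the `{**a, **b}` unpacking order. The sketch below replays that merge with a toy base class and experiment entry; both are hypothetical, and the real values live in the configuration defined earlier in the notebook.

```python
from types import SimpleNamespace


class ToyBaseCFG:
    """Hypothetical stand-in for BaseCFG's class-level defaults."""
    epochs = 20
    batch_size = 32
    decoder_lr = 1e-4
    encoder_name = "resnet50"


toy_experiment = {  # hypothetical experiment_configs entry
    "models": {"encoder_name": "vit", "decoder_name": "gpt2"},
    "hyperparameters": {"flickr8k": {"epochs": 10, "batch_size": 64}},
}

toy_dataset = "flickr8k"
base = {k: v for k, v in vars(ToyBaseCFG).items() if not k.startswith("__")}
# Later dicts win on key clashes: dataset-specific hyperparameters, then model choices
hyper = {**base, **toy_experiment.get("hyperparameters", {}).get(toy_dataset, {})}
combined = {**hyper, **toy_experiment["models"]}
toy_cfg = SimpleNamespace(**combined)
print(toy_cfg)
```

Merging into one dict first matters: calling `SimpleNamespace(**hyper, **toy_experiment["models"])` directly would raise `TypeError` because `encoder_name` appears in both. A dataset missing from `hyperparameters` simply falls back to the base values.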

In [None]:
if __name__ == '__main__':
    results_history = {}
    
    if BaseCFG.model_artifacts_zip_path and setup_from_zip(BaseCFG.model_artifacts_zip_path, base_path):
        print("\n📦 Switched to 'Evaluation from Zip' mode.")
        BaseCFG.force_model_retrain = False
    
    experiments_to_run = ["vit_gpt2"] # , "resnet50_lstm"
    datasets_to_process = ["flickr8k"] # , "flickr30k"

    for exp_name in experiments_to_run:
        if exp_name not in experiment_configs:
            print(f"Skipping unknown experiment: {exp_name}")
            continue
            
        print("\n" + "="*80)
        print(f"                RUNNING EXPERIMENT: {exp_name.upper()}")
        print("="*80 + "\n")
        
        exp_params = experiment_configs[exp_name]
        
        for dataset_name in datasets_to_process:
            base_cfg_dict = {k: v for k, v in BaseCFG.__dict__.items() if not k.startswith('__')}
            # Overwrite base config with experiment-specific ones
            hyperparams = {**base_cfg_dict, **exp_params.get("hyperparameters", {}).get(dataset_name, {})}
            combined_params = {**hyperparams, **exp_params["models"]}
            cfg = SimpleNamespace(**combined_params)
            
            path_config = generate_paths(base_path, dataset_name, cfg)
            
            try:
                wandb.init(
                    project="image-captioning-experiments",
                    name=f"{exp_name}-{dataset_name}-{int(time.time())}",
                    config=vars(cfg)
                )
                run_results = run_pipeline(path_config, cfg)
                if run_results:
                    if dataset_name not in results_history:
                        results_history[dataset_name] = {}
                    results_history[dataset_name][exp_name] = run_results
                    if 'metrics' in run_results:  # absent when run_evaluation is False
                        wandb.log(run_results['metrics'])
            except Exception as e:
                print(f"\n❌ An error occurred during the pipeline for {exp_name} on {dataset_name}.")
                print(f"Error: {e}")
                import traceback
                traceback.print_exc()
            finally:
                wandb.finish()
                


### **Step 10: Final Report Generation**
This final, decoupled cell iterates through the collected results and generates all plots, tables, and qualitative examples in one uninterrupted flow to prevent rendering issues in the notebook.
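`display_metrics_table` relies on `pandas.DataFrame.to_markdown`, which in turn requires `tabulate`. If that dependency is unavailable, a small stdlib helper can render the same one-row metrics table. The sketch below is a hypothetical alternative, not part of the original pipeline, and the metric values in the usage line are placeholders.

```python
def metrics_to_markdown(metrics, digits=3):
    """Render a flat {metric_name: value} dict as a one-row Markdown table."""
    headers = list(metrics)
    values = [f"{round(float(metrics[k]), digits)}" for k in headers]
    lines = [
        "| " + " | ".join(headers) + " |",
        "|" + "|".join(" --- " for _ in headers) + "|",
        "| " + " | ".join(values) + " |",
    ]
    return "\n".join(lines)


print(metrics_to_markdown({"Bleu_4": 0.12345, "CIDEr": 1.5}))  # placeholder values
```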

In [None]:
def plot_loss_curves(history, exp_name, dataset_name):
    if not history or 'train_loss' not in history or 'valid_loss' not in history:
        print("No training history found to plot.")
        return
    
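    epochs = range(1, len(history['train_loss']) + 1)  # 1-indexed to match the console logs
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, history['train_loss'], label='Train Loss')
    plt.plot(epochs, history['valid_loss'], label='Validation Loss')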
    plt.title(f'Loss Curves for {exp_name} on {dataset_name}')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

def display_metrics_table(metrics):
    if not metrics:
        print("No metrics to display.")
        return
    
    df = pd.DataFrame([metrics])
    df = df.round(3)
    display(Markdown(df.to_markdown(index=False)))

def display_performance_summary(run_data):
    model = run_data['model']
    training_history = run_data['history']
    total_training_duration = run_data.get('duration', 0)
    train_loader_len = run_data.get('train_loader_len', 0)
    cfg = run_data['cfg']
    config = run_data['config']

    if training_history:
        total_epochs_trained = len(training_history['train_loss'])
        avg_epoch_time = sum(training_history['epoch_times']) / len(training_history['epoch_times']) if training_history.get('epoch_times') else 0
        iterations_per_epoch = train_loader_len
        avg_iteration_time = avg_epoch_time / iterations_per_epoch if iterations_per_epoch > 0 else 0
    else:
        total_epochs_trained = "N/A (Loaded from checkpoint)"
        avg_epoch_time = 0
        avg_iteration_time = 0

    # Parameter breakdown
    encoder_params = sum(p.numel() for p in model.encoder.parameters() if p.requires_grad)
    decoder_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    proj_params = 0
    if hasattr(model, 'feature_proj'):
        proj_params = sum(p.numel() for p in model.feature_proj.parameters() if p.requires_grad)
    
    summary_md = f"""**Performance Summary for {config['dataset_name']} - {cfg.encoder_name} + {cfg.decoder_name}**
| Metric | Value |
| :--- | :--- |
| **GPU Used** | {torch.cuda.get_device_name(0) if cfg.device.type == 'cuda' else 'CPU'} |
| **Total Parameters** | {sum(p.numel() for p in model.parameters()):,} |
| **Trainable Parameters** | {sum(p.numel() for p in model.parameters() if p.requires_grad):,} |
| | |
| **Trainable Breakdown** | |
| &nbsp; &nbsp; Encoder | {encoder_params:,} |
| &nbsp; &nbsp; Decoder | {decoder_params:,} |
"""
    if proj_params > 0:
        summary_md += f"| &nbsp; &nbsp; Projection Head | {proj_params:,} |\n"

    summary_md += f"""| | |
| **Training Details** | |
| &nbsp; &nbsp; Total Epochs Trained | {total_epochs_trained} |
| &nbsp; &nbsp; Batch Size | {cfg.batch_size} |
| &nbsp; &nbsp; Decoder LR | {cfg.decoder_lr} |
| &nbsp; &nbsp; Encoder LR | {cfg.encoder_lr} |
| &nbsp; &nbsp; Optimizer | Adam |
| &nbsp; &nbsp; Vocab Size | {len(run_data['vocab'])} |
| &nbsp; &nbsp; Dropout | {cfg.dropout} |
"""
    if cfg.decoder_name == 'lstm':
        summary_md += f"| &nbsp; &nbsp; LSTM Hidden Size | {cfg.hidden_dim} |\n"
        summary_md += f"| &nbsp; &nbsp; LSTM Layers | {cfg.num_layers} |\n"

    summary_md += f"""| | |
| **Timings** | |
| &nbsp; &nbsp; Total Training Time | {total_training_duration:.2f} s ({total_training_duration/60:.2f} min) |
| &nbsp; &nbsp; Average Time per Epoch | {avg_epoch_time:.2f} s |
| &nbsp; &nbsp; Average Time per Iteration | {avg_iteration_time:.4f} s |
"""
    display(Markdown(summary_md))

def show_qualitative_results(run_data, num_examples=3):
    if 'generated_captions' not in run_data or 'ground_truths' not in run_data:
        print("Qualitative results not available.")
        return
        
    print("\n" + "="*50)
    print("           QUALITATIVE ANALYSIS: GENERATED CAPTIONS")
    print("="*50 + "\n")
    
    generated_map = {item['image_id']: item['caption'] for item in run_data['generated_captions']}
    gt_map = {}
    for ann in run_data['ground_truths']['annotations']:
        img_id = ann['image_id']
        if img_id not in gt_map:
            gt_map[img_id] = []
        gt_map[img_id].append(ann['caption'])
        
    image_ids = random.sample(list(generated_map.keys()), min(num_examples, len(generated_map)))
    
    for image_id in image_ids:
        image_path = os.path.join(run_data['config']['image_dir'], image_id)
        if not os.path.exists(image_path):
            continue
            
        plt.figure(figsize=(8, 8))
        image = Image.open(image_path)
        plt.imshow(image)
        plt.axis('off')
        plt.show()
        
        display(Markdown(f"**Generated Caption:** `{generated_map[image_id]}`"))
        display(Markdown("**Ground Truths:**"))
        for gt_caption in gt_map.get(image_id, []):
            display(Markdown(f"- *{gt_caption}*"))
        print("-"*50)
        

def generate_final_report(results):
    print("\n" + "="*60)
    print("           FINAL COMPARATIVE ANALYSIS")
    print("="*60 + "\n")
    
    for dataset_name, exps in results.items():
        display(Markdown(f'## 📊 Results for Dataset: `{dataset_name}`'))
        for exp_name, run_data in exps.items():
            display(Markdown(f'### 🔬 Experiment: `{exp_name}`'))
            
            # --- Performance Summary ---
            display_performance_summary(run_data)
            
            # --- Loss Curves ---
            plot_loss_curves(run_data.get('history'), exp_name, dataset_name)
            
            # --- Metrics Table ---
            display(Markdown("#### Evaluation Metrics"))
            display_metrics_table(run_data.get('metrics'))
            
            # --- Qualitative Examples ---
            if run_data['cfg'].run_evaluation:
                show_qualitative_results(run_data)
            
            sys.stdout.flush() # Ensure outputs are displayed in order
            time.sleep(1.0)


if __name__ == '__main__':
    if BaseCFG.run_evaluation:
        generate_final_report(results_history)


### **Final Review Checklist**
This checklist confirms that all identified bugs and robustness improvements have been integrated into the script.

| ID | Status | Description | Location of Fix |
| :--- | :--- | :--- | :--- |
| 1 | ✅ | Dependencies are explicitly installed (`pycocoevalcap`, `albumentations`, `tabulate`). | Step 1 |
| 2 | ✅ | System dependencies (`java`, `zip`) are installed. | Step 1 |
| 3 | ✅ | A compatible PyTorch version is forcibly installed to prevent CUDA errors. | Step 1 |
| 4 | ✅ | A compatible NumPy version (`<2.0`) is explicitly installed. | Step 1 |
| 5 | ✅ | The data processing pipeline correctly parses image filenames from `captions.txt`. | Step 3 |
| 6 | ✅ | The `torch.load` call for the vocabulary object uses `weights_only=False`. | Step 8 |
| 7 | ✅ | Teacher-forcing slicing (`[:, :-1]`) is handled correctly and consistently in the `Trainer`. | Step 6 |
| 8 | ✅ | The GPT-2 `generate` method is correctly passed the `eos_token_id`. | Step 5 |
| 9 | ✅ | The ground-truth JSON for evaluation includes the required `"info"` key. | Step 7 |
| 10 | ✅ | The `base_path` variable is correctly defined in the global scope. | Step 1 |