# Restructured Multimodal Retrieval Notebook
### *Flexible, Efficient, and Portable for Flickr8k/30k Datasets*

This notebook has been refactored into a reusable research pipeline. Key features and improvements include:
- **Experiment-Driven Configuration**: A master `experiment_configs` dictionary defines the model names and tuned hyperparameters for every model architecture and dataset combination.
- **Flexible Model Factory**: The architecture is modular, so swapping encoders (e.g., ResNet, ViT, RoBERTa, DistilBERT) only requires changing a model name in the configuration.
- **Automated & Portable Workflow**: The script auto-detects the execution environment (Colab, RunPod, local) and generates a standardized directory structure for data, models, and embeddings.
- **Robust Data Pipeline**: Downloads and processes both Flickr8k and the multi-part Flickr30k dataset, including caption standardization, directly in the target directory without extra file moves.
- **Advanced Training Techniques**: Incorporates a `Trainer` class, gradient accumulation, early stopping, caption sampling for regularization, and a cosine annealing learning rate scheduler.
- **Efficient Checkpointing & Controls**: Skips re-downloading, re-training, and re-generating embeddings if artifacts already exist. Includes flags like `force_retrain`, `run_embedding_generation`, and `run_evaluation` for fine-grained control.
- **Comprehensive Evaluation**: Evaluates cross-modal (Text-to-Image, Image-to-Text) and Text-to-Text retrieval using Top-K (K = 1 to 5) accuracy, precision, and recall.
- **Qualitative & Comparative Analysis**: Generates qualitative examples of retrieval results and provides side-by-side plots and summary tables for easy comparison across datasets.

### **Step 1: Imports and Setup**

In [None]:
!pip install -q timm pandas tqdm albumentations opencv-python scikit-learn transformers torch torchvision torchaudio
!pip install kaggle matplotlib
!pip install -U ipywidgets

# RunPod images may lack `unzip`, which the dataset download step calls
!apt-get update && apt-get install -y unzip

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import subprocess
import shutil
import requests
import re
import time
import random
from tqdm import tqdm
from collections import defaultdict, Counter
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR

import timm
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup

from sklearn.model_selection import train_test_split
from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import display, Markdown
import albumentations as A
from albumentations.pytorch import ToTensorV2

### **Step 2: Experiment Configurations**
This is the master control panel. `BaseCFG` holds the parameters shared by every run, and the `experiment_configs` dictionary defines the models and tuned hyperparameters for each model-dataset combination.

#### **Workflow Control:**
- **Train Only**: `force_retrain = True`, `run_embedding_generation = False`, `run_evaluation = False`
- **Train & Generate Embeddings**: `force_retrain = True`, `run_embedding_generation = True`, `run_evaluation = False`
- **Evaluate Existing Embeddings Only**: `force_retrain = False`, `run_embedding_generation = True` (will skip if files exist), `run_evaluation = True`
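The three flags interact with two on-disk checks (does a checkpoint exist, do embedding files exist). The hypothetical helper `plan_stages` below replays that branching in plain Python as a sketch, assuming evaluation runs whenever `run_evaluation` is True. One consequence worth knowing: in `run_pipeline`, embeddings are regenerated only when the files are missing, so after a forced retrain, stale embedding files from an earlier model would be reused unless they are deleted.

```python
def plan_stages(force_retrain, run_embedding_generation, run_evaluation,
                model_exists, train_embeddings_exist):
    """Hypothetical helper mirroring the flag logic in run_pipeline (a sketch, not the notebook's code)."""
    stages = []
    # Training runs when no checkpoint exists or a retrain is forced; otherwise weights are loaded.
    if force_retrain or not model_exists:
        stages.append("train")
    else:
        stages.append("load_checkpoint")
    # With both downstream flags off, the pipeline returns right after training/loading.
    if not run_embedding_generation and not run_evaluation:
        return stages
    # Embeddings are (re)generated only when the saved files are missing.
    if run_embedding_generation and not train_embeddings_exist:
        stages.append("generate_embeddings")
    # Assumption: evaluation runs whenever run_evaluation is True.
    if run_evaluation:
        stages.append("evaluate")
    return stages

print(plan_stages(True, False, False, model_exists=True, train_embeddings_exist=True))
print(plan_stages(False, True, True, model_exists=True, train_embeddings_exist=True))
```

The second call shows the "Evaluate Existing Embeddings Only" preset: the checkpoint is loaded and embedding generation is skipped because the files are already there.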

In [None]:
class BaseCFG:
    # --- Core Hyperparameters & Controls ---
    debug = False
    epochs = 30
    batch_size = 32
    num_workers = 2
    gradient_accumulation_steps = 2
    early_stopping_patience = 7
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    force_retrain = False
    run_embedding_generation = True
    run_evaluation = True
    show_comparative_plots = True

    # --- LR and Scheduler ---
    warmup_epochs = 5   # currently unused: run_pipeline's CosineAnnealingLR has no warmup phase
    lr_min = 1e-6

    # --- Model Configuration ---
    size = 224
    max_length = 200
    pretrained = True
    trainable = True    # not referenced in this code: encoder parameters are always passed to the optimizer

experiment_configs = {
    "resnet50_distilbert": {
        "models": {
            "image_model_name": "resnet50",
            "text_encoder_model": "distilbert-base-uncased",
            "text_tokenizer": "distilbert-base-uncased"
        },
        "hyperparameters": {
            "flickr8k": { # Reverted to stable parameters
                "head_lr": 1e-3, "image_encoder_lr": 1e-4, "text_encoder_lr": 2e-5, 
                "weight_decay": 1e-3, "projection_dim": 256, "dropout": 0.1, "temperature": 0.07
            },
            "flickr30k": {
                "head_lr": 1e-3, "image_encoder_lr": 1e-4, "text_encoder_lr": 2e-5, 
                "weight_decay": 1e-3, "projection_dim": 256, "dropout": 0.1, "temperature": 0.07
            }
        }
    },
    "vit_roberta": {
        "models": {
            "image_model_name": "vit_base_patch16_224",
            "text_encoder_model": "roberta-base",
            "text_tokenizer": "roberta-base"
        },
        "hyperparameters": {
            "flickr8k": {
                "head_lr": 5e-4, "image_encoder_lr": 1e-5, "text_encoder_lr": 5e-6,
                "weight_decay": 2e-3, "projection_dim": 512, "dropout": 0.2, "temperature": 0.1
            },
            "flickr30k": {
                "head_lr": 1e-3, "image_encoder_lr": 5e-5, "text_encoder_lr": 2e-5, 
                "weight_decay": 1e-3, "projection_dim": 512, "dropout": 0.2, "temperature": 0.1
            }
        }
    }
}
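The pipeline functions read every setting as an attribute of a single `cfg` object (e.g. `cfg.image_model_name`, `cfg.head_lr`, `cfg.temperature`), so `BaseCFG` and one entry of `experiment_configs` have to be merged first. A minimal sketch using `types.SimpleNamespace`, with a hypothetical helper name `build_cfg` and a toy base class standing in for `BaseCFG`:

```python
from types import SimpleNamespace

def build_cfg(base_cls, experiment_configs, experiment_name, dataset_name):
    """Hypothetical: flatten base class attributes + one experiment's models + dataset hyperparameters."""
    # vars() only sees attributes defined directly on the class; skip dunders like __module__.
    merged = {k: v for k, v in vars(base_cls).items() if not k.startswith("__")}
    experiment = experiment_configs[experiment_name]
    merged.update(experiment["models"])                          # image/text model names
    merged.update(experiment["hyperparameters"][dataset_name])   # per-dataset values win on key clashes
    return SimpleNamespace(**merged)

# Toy stand-ins so the example runs without torch.
class ToyBase:
    epochs = 30
    batch_size = 32

toy_configs = {
    "resnet50_distilbert": {
        "models": {"image_model_name": "resnet50", "text_encoder_model": "distilbert-base-uncased"},
        "hyperparameters": {"flickr8k": {"head_lr": 1e-3, "temperature": 0.07}},
    }
}

cfg = build_cfg(ToyBase, toy_configs, "resnet50_distilbert", "flickr8k")
print(cfg.epochs, cfg.image_model_name, cfg.temperature)
```

With the real objects, the call would look like `build_cfg(BaseCFG, experiment_configs, "vit_roberta", "flickr30k")`.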

### **Step 3: Environment, Path, and Data Utilities**

In [None]:
def detect_environment():
    try:
        import google.colab
        print("Environment: Google Colab detected.")
        return "colab", "/content"
    except ImportError: pass
    if os.path.exists("/workspace") or "RUNPOD_POD_ID" in os.environ:
        print("Environment: RunPod detected.")
        return "runpod", "/workspace"
    print("Environment: Local machine detected.")
    return "local", os.getcwd()

def generate_paths(base_path, dataset_name, cfg):
    dataset_dir = os.path.join(base_path, "data", dataset_name)
    models_dir = os.path.join(base_path, "models")
    embeddings_dir = os.path.join(base_path, "embeddings")
    paths = {
        "dataset_name": dataset_name,
        "dataset_dir": dataset_dir,
        "image_dir": os.path.join(dataset_dir, "Images"),
        "captions_file": os.path.join(dataset_dir, f"{dataset_name}_captions.csv"),
        "model_save_path": os.path.join(models_dir, f"{dataset_name}_{cfg.image_model_name.replace('/', '-')}_{cfg.text_encoder_model.replace('/', '-')}.pt"),
        "embedding_save_path": os.path.join(embeddings_dir, dataset_name, f"{cfg.image_model_name.replace('/', '-')}_{cfg.text_encoder_model.replace('/', '-')}"),
    }
    for path in [paths["image_dir"], os.path.dirname(paths["model_save_path"]), paths["embedding_save_path"]]:
        os.makedirs(path, exist_ok=True)
    return paths

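def run_shell_command(command):
    # Callers pass a one-element list holding a full shell string. With shell=True, extra list
    # items would become arguments to the shell itself, so always hand over a single string.
    if isinstance(command, (list, tuple)):
        command = " ".join(command)
    print(f"Running command: {command}")
    try:
        subprocess.run(command, shell=True, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Error executing command: {command}")
        print(e.stderr)
        raise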

def download_with_progress(url, filename):
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    # Fail fast on HTTP errors instead of silently saving an error page as the archive
    with requests.get(url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        total = int(resp.headers.get('content-length', 0))
        with open(filename, 'wb') as file, tqdm(
            desc=f"Downloading {os.path.basename(filename)}", total=total, unit='B', unit_scale=True, unit_divisor=1024
        ) as bar:
            for chunk in resp.iter_content(chunk_size=8192):
                file.write(chunk)
                bar.update(len(chunk))

def download_flickr8k(target_dir):
    os.makedirs(target_dir, exist_ok=True)
    print("📥 Downloading flickr8k...")
    zip_path = os.path.join(target_dir, "flickr8k.zip")
    url = "https://github.com/awsaf49/flickr-dataset/releases/download/v1.0/flickr8k.zip"
    download_with_progress(url, zip_path)
    run_shell_command([f"unzip -q -o {zip_path} -d {target_dir}"])
    os.remove(zip_path)

def download_flickr30k(target_dir):
    os.makedirs(target_dir, exist_ok=True)
    print("📥 Downloading flickr30k...")
    zip_path = os.path.join(target_dir, "flickr30k.zip")
    parts = [f"flickr30k_part0{i}" for i in range(3)]
    urls = [f"https://github.com/awsaf49/flickr-dataset/releases/download/v1.0/{p}" for p in parts]
    part_paths = [os.path.join(target_dir, p) for p in parts]
    for url, part_path in zip(urls, part_paths):
        download_with_progress(url, part_path)
    
    run_shell_command([f"cat {' '.join(part_paths)} > {zip_path}"])
    for part in part_paths:
        os.remove(part)
    run_shell_command([f"unzip -q -o {zip_path} -d {target_dir}"])
    os.remove(zip_path)

def clean_caption(text):
    text = str(text).lower().strip()
    text = re.sub(r"[^a-z0-9\s]", "", text)          # Remove punctuation
    text = re.sub(r"\s+", " ", text)                 # Normalize whitespace
    return text

def process_captions(raw_captions_path, final_captions_path):
    print(f"Processing captions from {raw_captions_path}...")
    if not os.path.exists(raw_captions_path):
        print(f"❌ Missing raw captions file: {raw_captions_path}")
        return

    df = pd.read_csv(raw_captions_path)
    df.columns = df.columns.str.strip()
    df.rename(columns={"image_name": "image", "comment": "caption"}, inplace=True)
    df.dropna(subset=["caption"], inplace=True)
    df["caption"] = df["caption"].astype(str).str.strip().apply(clean_caption)
    df["num_tokens"] = df["caption"].apply(lambda x: len(x.split()))
    df = df[(df["num_tokens"] >= 3) & (df["num_tokens"] <= 50)].reset_index(drop=True)

    df["caption_number"] = df.groupby("image").cumcount()
    df["id"] = df["image"].factorize()[0]
    df = df[["image", "caption_number", "caption", "id"]]

    df.to_csv(final_captions_path, index=False)
    print(f"\n✅ Preprocessing DONE")
    print(f"📝 Total captions: {len(df)}")
    print(f"🔤 Avg length: {df['caption'].apply(lambda x: len(x.split())).mean():.2f} tokens")
    print(f"📄 Saved: {final_captions_path}")

def prepare_dataset(config):
    dataset_name = config["dataset_name"]
    dataset_dir = config["dataset_dir"]
    image_dir = config["image_dir"]
    captions_file = config["captions_file"]

    if os.path.exists(captions_file) and os.path.exists(image_dir) and len(glob.glob(os.path.join(image_dir, '*.jpg'))) > 10:
        print(f"Dataset '{dataset_name}' found and processed. Skipping preparation.")
        return

    print(f"Dataset '{dataset_name}' not found. Starting download and preparation...")
    
    if dataset_name == 'flickr8k':
        download_flickr8k(dataset_dir)
    elif dataset_name == 'flickr30k':
        download_flickr30k(dataset_dir)
    else:
        print(f"Unknown dataset: {dataset_name}")
        return

    raw_captions_path = os.path.join(dataset_dir, 'captions.txt')
    process_captions(raw_captions_path, captions_file)
    
    if os.path.exists(raw_captions_path):
        os.remove(raw_captions_path)
    
    print(f"Dataset '{dataset_name}' is ready.")

### **Step 4: Data Handling Components**

In [None]:
class CLIPDataset(Dataset):
    def __init__(self, image_paths, captions, tokenizer, transforms, cfg):
        self.image_paths = image_paths
        self.captions = list(captions)
        self.tokenizer = tokenizer
        self.transforms = transforms
        self.cfg = cfg

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        encoded_caption = self.tokenizer(
            self.captions[idx], padding='max_length', truncation=True, 
            max_length=self.cfg.max_length, return_tensors='pt'
        )
        try:
            image = Image.open(self.image_paths[idx]).convert("RGB")
        except FileNotFoundError:
            return None
        image = np.array(image)
        image = self.transforms(image=image)['image']
        
        item = {
            'image': image,
            'input_ids': encoded_caption['input_ids'].squeeze(0),
            'attention_mask': encoded_caption['attention_mask'].squeeze(0),
            'image_path': self.image_paths[idx],
            'caption': self.captions[idx]
        }
        return item

def get_transforms(cfg, mode="train"):
    if mode == "train":
        return A.Compose([
            A.RandomResizedCrop(size=(cfg.size, cfg.size), scale=(0.8, 1.0), ratio=(0.75, 1.333)),
            A.HorizontalFlip(p=0.5),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
            ToTensorV2(),
        ])
    else:
        return A.Compose([
            A.Resize(height=cfg.size, width=cfg.size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
            ToTensorV2(),
        ])

def make_train_valid_dfs(config):
    try:
        df = pd.read_csv(config['captions_file'])
        df = df.dropna().reset_index(drop=True)
    except Exception as e:
        print(f"Error reading captions file: {e}")
        return None, None

    image_files_in_dir = set(os.listdir(config['image_dir']))
    df = df[df['image'].isin(image_files_in_dir)].reset_index(drop=True)

    unique_images = df['image'].unique()
    train_imgs, valid_imgs = train_test_split(unique_images, test_size=0.2, random_state=42)
    
    train_df = df[df['image'].isin(train_imgs)].reset_index(drop=True)
    valid_df = df[df['image'].isin(valid_imgs)].reset_index(drop=True)
    
    return train_df, valid_df

def skip_none_collate(batch):
    # Drop samples whose image file was missing (CLIPDataset returns None for them).
    # Defined at module level (not nested) so DataLoader workers can pickle it under the
    # "spawn" start method used by default on Windows and macOS.
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    return torch.utils.data.default_collate(batch)

def build_loaders(df, tokenizer, mode, config, cfg):
    transforms = get_transforms(cfg, mode=mode)
    image_paths = [os.path.join(config['image_dir'], fname) for fname in df['image'].values]
    captions = df['caption'].values
    
    dataset = CLIPDataset(image_paths, captions, tokenizer, transforms, cfg)
        
    dataloader = DataLoader(
        dataset, 
        batch_size=cfg.batch_size, 
        num_workers=cfg.num_workers, 
        shuffle=(mode == 'train'),
        collate_fn=skip_none_collate
    )
    return dataloader
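`make_train_valid_dfs` splits on unique image names rather than on caption rows, so the captions of one image never straddle train and validation (a row-level split would leak near-duplicate image-caption pairs into validation). A dependency-free sketch of the same group-aware split, using a hypothetical `split_by_image` helper: it shuffles with `random.Random` instead of scikit-learn, so the particular images chosen differ from `train_test_split(..., random_state=42)`, although it also rounds the validation count up.

```python
import math
import random

def split_by_image(rows, test_size=0.2, seed=42):
    """Hypothetical group-aware split: all captions of an image go to the same side."""
    images = sorted({image for image, _ in rows})   # sort first so the shuffle is reproducible
    rng = random.Random(seed)
    rng.shuffle(images)
    n_valid = math.ceil(len(images) * test_size)
    valid_images = set(images[:n_valid])
    train = [r for r in rows if r[0] not in valid_images]
    valid = [r for r in rows if r[0] in valid_images]
    return train, valid

# Toy data: 10 images with 5 captions each.
rows = [(f"img{i}.jpg", f"caption {j}") for i in range(10) for j in range(5)]
train, valid = split_by_image(rows)
print(len(train), len(valid))
```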

### **Step 5: Model, Loss, and Trainer Components**

In [None]:
def create_image_encoder(cfg): return timm.create_model(cfg.image_model_name, pretrained=cfg.pretrained, num_classes=0, global_pool='avg')
def create_text_encoder(cfg): return AutoModel.from_pretrained(cfg.text_encoder_model)

class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim, cfg):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, cfg.projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(cfg.projection_dim, cfg.projection_dim)
        self.dropout = nn.Dropout(cfg.dropout)
        self.layer_norm = nn.LayerNorm(cfg.projection_dim)
    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x) + projected
        return self.layer_norm(x)

class CLIPModel(nn.Module):
    def __init__(self, image_encoder, text_encoder, cfg):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        image_embedding_dim = image_encoder.num_features
        text_embedding_dim = text_encoder.config.hidden_size
        self.image_projection = ProjectionHead(embedding_dim=image_embedding_dim, cfg=cfg)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding_dim, cfg=cfg)

    def forward(self, batch):
        image_features = self.image_encoder(batch['image'])
        text_features = self.text_encoder(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']).last_hidden_state[:, 0, :]
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)
        return image_embeddings, text_embeddings

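class CLIPLoss(nn.Module):
    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature
    
    def forward(self, image_embeddings, text_embeddings):
        # L2-normalize so the logits are cosine similarities scaled by 1/temperature, matching
        # evaluate_retrieval; raw LayerNorm outputs have norm ~sqrt(projection_dim).
        image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        targets = torch.arange(len(image_embeddings), device=image_embeddings.device)
        texts_loss = F.cross_entropy(logits, targets)      # each caption vs. all images in the batch
        images_loss = F.cross_entropy(logits.T, targets)   # each image vs. all captions in the batch
        return (images_loss + texts_loss) / 2.0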

class AvgMeter:
    def __init__(self, name="Metric"): self.name, self.avg, self.sum, self.count = name, 0, 0, 0
    def update(self, val, n=1): self.sum += val * n; self.count += n; self.avg = self.sum / self.count
    def __repr__(self): return f"{self.name}: {self.avg:.4f}"

class Trainer:
    def __init__(self, model, optimizer, scheduler, loss_fn, cfg):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        self.cfg = cfg
        self.device = cfg.device

    def _train_one_epoch(self, train_loader):
        loss_meter = AvgMeter()
        tqdm_object = tqdm(train_loader, total=len(train_loader))
        self.model.train()
        for i, batch in enumerate(tqdm_object):
            if batch is None: continue
            batch = {k: v.to(self.device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            
            image_embeddings, text_embeddings = self.model(batch)
            loss = self.loss_fn(image_embeddings, text_embeddings)
            loss = loss / self.cfg.gradient_accumulation_steps
            loss.backward()

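            # Update once per accumulation window, and also on the last batch so leftover
            # micro-batch gradients are applied instead of leaking into the next epoch.
            is_last_batch = (i + 1) == len(train_loader)
            if (i + 1) % self.cfg.gradient_accumulation_steps == 0 or is_last_batch:
                # Clip the fully accumulated gradient right before the update
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.scheduler: self.scheduler.step()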

            loss_meter.update(loss.item() * self.cfg.gradient_accumulation_steps, batch['image'].size(0))
            tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=self.optimizer.param_groups[0]['lr'])
        return loss_meter

    def _valid_one_epoch(self, valid_loader):
        loss_meter = AvgMeter()
        tqdm_object = tqdm(valid_loader, total=len(valid_loader))
        self.model.eval()
        with torch.no_grad():
            for batch in tqdm_object:
                if batch is None: continue
                batch = {k: v.to(self.device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
                image_embeddings, text_embeddings = self.model(batch)
                loss = self.loss_fn(image_embeddings, text_embeddings)
                loss_meter.update(loss.item(), batch['image'].size(0))
                tqdm_object.set_postfix(valid_loss=loss_meter.avg)
        return loss_meter

    def fit(self, train_df, valid_loader, tokenizer, config):
        best_loss = float('inf')
        epochs_without_improvement = 0
        history = defaultdict(list)
        total_training_start_time = time.time()

        for epoch in range(self.cfg.epochs):
            epoch_start_time = time.time()
            print(f"\nEpoch: {epoch + 1}/{self.cfg.epochs}")
            
            epoch_train_df = train_df.groupby('id').sample(n=1).reset_index(drop=True)
            epoch_train_loader = build_loaders(epoch_train_df, tokenizer, mode="train", config=config, cfg=self.cfg)

            train_loss = self._train_one_epoch(epoch_train_loader)
            valid_loss = self._valid_one_epoch(valid_loader)

            epoch_end_time = time.time()
            epoch_duration = epoch_end_time - epoch_start_time
            history['train_loss'].append(train_loss.avg)
            history['valid_loss'].append(valid_loss.avg)
            history['epoch_times'].append(epoch_duration)

            print(f"Epoch {epoch+1} | Train Loss: {train_loss.avg:.4f} | Valid Loss: {valid_loss.avg:.4f} | Time: {epoch_duration:.2f}s")

            if valid_loss.avg < best_loss:
                best_loss = valid_loss.avg
                epochs_without_improvement = 0
                torch.save({'epoch': epoch + 1, 'model_state_dict': self.model.state_dict()}, config['model_save_path'])
                print(f"Saved Best Model! Validation Loss: {best_loss:.4f}")
            else:
                epochs_without_improvement += 1
                print(f"No improvement in validation loss for {epochs_without_improvement} epoch(s).")

            if epochs_without_improvement >= self.cfg.early_stopping_patience:
                print(f"\nEarly stopping triggered after {self.cfg.early_stopping_patience} epochs without improvement.")
                break
        
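        total_training_end_time = time.time()
        total_training_duration = total_training_end_time - total_training_start_time

        # The in-memory weights come from the last epoch run; reload the best checkpoint
        # so embedding generation and evaluation use the model that was actually saved.
        if os.path.exists(config['model_save_path']):
            checkpoint = torch.load(config['model_save_path'], map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Reloaded best model from epoch {checkpoint['epoch']}.")

        plot_loss_curves(history, config)
        print_performance_summary(history, total_training_duration, epoch_train_loader, self.cfg, config)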
        return history
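For reference, the projection head and the symmetric contrastive objective can be written as follows, with $x$ the pooled encoder feature, $B$ the batch size, $u_i$ and $v_i$ the projected text and image embeddings of pair $i$, and $\tau$ the configured `temperature`. This assumes the embeddings are L2-normalized before the dot product, as in the original CLIP formulation and in `evaluate_retrieval`.

$$
h = W_1 x + b_1, \qquad z = \operatorname{LayerNorm}\big(h + \operatorname{Dropout}(W_2\,\operatorname{GELU}(h) + b_2)\big)
$$

$$
\hat{u}_i = \frac{u_i}{\lVert u_i \rVert}, \qquad \hat{v}_j = \frac{v_j}{\lVert v_j \rVert}, \qquad \ell_{ij} = \frac{\hat{u}_i^{\top}\hat{v}_j}{\tau}
$$

$$
\mathcal{L} = \frac{1}{2B}\sum_{i=1}^{B}\left[-\log\frac{e^{\ell_{ii}}}{\sum_{j=1}^{B} e^{\ell_{ij}}} \;-\; \log\frac{e^{\ell_{ii}}}{\sum_{j=1}^{B} e^{\ell_{ji}}}\right]
$$

The first term is `texts_loss` (row $i$ of `logits`: one caption against all images) and the second is `images_loss` (row $i$ of `logits.T`: one image against all captions). Every other item in the batch acts as a negative, which is why `Trainer.fit` samples one caption per image each epoch: two captions of the same image in one batch would be scored as a negative pair. The validation loader keeps all captions, so its loss includes such false negatives.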
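With `gradient_accumulation_steps = 2`, the optimizer (and the per-step cosine scheduler) advances once per two micro-batches, and each epoch draws one caption per image. The sketch below uses two hypothetical helpers: `optimizer_step_batches` lists the micro-batch indices after which `optimizer.step()` fires, assuming a trailing partial window is flushed at the end of the epoch, and `scheduler_t_max` derives a cosine `T_max` from the number of training images rather than from caption rows.

```python
import math

def optimizer_step_batches(num_batches, accum_steps):
    """0-based micro-batch indices that trigger optimizer.step(), with a final flush."""
    return [i for i in range(num_batches)
            if (i + 1) % accum_steps == 0 or (i + 1) == num_batches]

def scheduler_t_max(num_images, batch_size, accum_steps, epochs):
    """T_max for a scheduler stepped once per optimizer step, one caption per image per epoch."""
    batches_per_epoch = math.ceil(num_images / batch_size)
    steps_per_epoch = math.ceil(batches_per_epoch / accum_steps)
    return steps_per_epoch * epochs

print(optimizer_step_batches(7, 2))
print(scheduler_t_max(6400, 32, 2, 30))
```

For example, with a hypothetical 6,400 training images (5 captions each), batch size 32, accumulation 2, and 30 epochs, this gives 100 optimizer steps per epoch and `T_max = 3000`. Counting every caption row and ignoring accumulation, as `len(train_df) // cfg.batch_size * cfg.epochs` does, gives 30,000, so by the final epoch the cosine schedule would have covered only a tenth of its period and the learning rate would still be close to its initial value.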
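`Trainer.fit` saves a checkpoint only when validation loss improves and stops after `early_stopping_patience` epochs without improvement, so training always runs past the best epoch. Replaying the rule on a list of losses (hypothetical helper `replay_early_stopping`, pure Python) shows the gap: with patience 7, the last epoch run can be up to 7 epochs after the saved one, which is why the weights used afterwards should come from the checkpoint rather than from memory.

```python
def replay_early_stopping(valid_losses, patience):
    """Return (best_epoch, last_epoch_run), both 1-based, under the rule used in Trainer.fit."""
    best_loss = float("inf")
    best_epoch = 0
    without_improvement = 0
    for epoch, loss in enumerate(valid_losses, start=1):
        if loss < best_loss:
            # Improvement: this is the epoch whose weights get saved
            best_loss, best_epoch, without_improvement = loss, epoch, 0
        else:
            without_improvement += 1
        if without_improvement >= patience:
            return best_epoch, epoch   # early stop triggered
    return best_epoch, len(valid_losses)  # ran out of epochs

print(replay_early_stopping([1.0, 0.8, 0.85, 0.9, 0.95], patience=3))
```

Here the best model is from epoch 2, but training continues through epoch 5 before stopping.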

### **Step 6: Embedding Generation and Evaluation**

In [None]:
def generate_and_save_embeddings(model, dataloader, config, cfg, set_name):
    start_time = time.time()
    print(f"Generating embeddings for {config['dataset_name']} {set_name} set...")
    model.eval()
    
    all_image_embeddings, all_text_embeddings = [], []
    all_image_paths, all_captions = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            if batch is None: continue
            image_embeddings, text_embeddings = model({
                'image': batch['image'].to(cfg.device),
                'input_ids': batch['input_ids'].to(cfg.device),
                'attention_mask': batch['attention_mask'].to(cfg.device)
            })
            all_image_embeddings.append(image_embeddings.cpu())
            all_text_embeddings.append(text_embeddings.cpu())
            all_image_paths.extend(batch['image_path'])
            all_captions.extend(batch['caption'])

    image_embeddings_full = torch.cat(all_image_embeddings)
    text_embeddings_full = torch.cat(all_text_embeddings)
    
    unique_image_paths = sorted(list(set(all_image_paths)))
    path_to_idx = {path: i for i, path in enumerate(unique_image_paths)}
    
    unique_image_embeddings = torch.zeros(len(unique_image_paths), cfg.projection_dim)
    path_to_embedding_map = defaultdict(list)
    for i in range(len(all_image_paths)):
        path_to_embedding_map[all_image_paths[i]].append(image_embeddings_full[i])
    
    for i, path in enumerate(unique_image_paths):
        unique_image_embeddings[i] = torch.stack(path_to_embedding_map[path]).mean(dim=0)
        
    ground_truth = torch.tensor([path_to_idx[path] for path in all_image_paths])
    
    embedding_dir = config['embedding_save_path']
    torch.save(unique_image_embeddings, os.path.join(embedding_dir, f'{set_name}_image_embeddings.pt'))
    torch.save(text_embeddings_full, os.path.join(embedding_dir, f'{set_name}_text_embeddings.pt'))
    torch.save(ground_truth, os.path.join(embedding_dir, f'{set_name}_ground_truth.pt'))
    torch.save(all_captions, os.path.join(embedding_dir, f'{set_name}_captions.pt'))
    torch.save(unique_image_paths, os.path.join(embedding_dir, f'{set_name}_unique_image_paths.pt'))

    end_time = time.time()
    duration = end_time - start_time
    print(f"{set_name.capitalize()} embeddings saved to '{embedding_dir}'. Time taken: {duration:.2f} seconds.")

def compute_top_k_retrieval_metrics(ranked_indices, ground_truth_indices, k_values):
    metrics = {k: {"accuracy": 0, "precision": 0, "recall": 0} for k in k_values}
    ground_truth_set = set(ground_truth_indices)
    if not ground_truth_set: return metrics
    
    for k in k_values:
        top_k_preds = set(ranked_indices[:k])
        hits = len(top_k_preds.intersection(ground_truth_set))
        metrics[k]["accuracy"] = 1 if hits > 0 else 0
        metrics[k]["precision"] = hits / k if k > 0 else 0
        metrics[k]["recall"] = hits / len(ground_truth_set) if len(ground_truth_set) > 0 else 0
    return metrics

def report_retrieval_metrics(title, accumulated_metrics, total_queries, k_values):
    print(f"\n--- {title} ---")
    if total_queries == 0:
        print("No valid queries processed.")
        return {}
        
    header = f"| {'Top-K':<5} | {'Accuracy':<10} | {'Precision':<10} | {'Recall':<10} |"
    print(header)
    print("|" + "-" * (len(header) - 2) + "|")
    final_metrics = {}
    for k in k_values:
        acc = accumulated_metrics[k]['accuracy'] / total_queries
        prec = accumulated_metrics[k]['precision'] / total_queries
        rec = accumulated_metrics[k]['recall'] / total_queries
        final_metrics[k] = {'accuracy': acc, 'precision': prec, 'recall': rec}
        row = f"| {k:<5} | {acc:<10.4f} | {prec:<10.4f} | {rec:<10.4f} |"
        print(row)
    print("-" * len(header))
    return final_metrics

def evaluate_retrieval(image_embeddings, text_embeddings, ground_truth_map, config, cfg):
    TOP_K_VALUES = [1, 2, 3, 4, 5]
    
    image_embeddings_gpu = F.normalize(image_embeddings.float().to(cfg.device), p=2, dim=-1)
    text_embeddings_gpu = F.normalize(text_embeddings.float().to(cfg.device), p=2, dim=-1)
    ground_truth_map_np = ground_truth_map.numpy()

    # --- Text-to-Image Retrieval ---
    t2i_metrics = {k: defaultdict(float) for k in TOP_K_VALUES}
    num_text_queries = len(text_embeddings_gpu)
    for i in tqdm(range(num_text_queries), desc="Text-to-Image Retrieval"):
        query_embedding = text_embeddings_gpu[i].unsqueeze(0)
        sim = query_embedding @ image_embeddings_gpu.T
        ranked_indices = torch.argsort(sim, descending=True).squeeze().cpu().numpy().tolist()
        if not isinstance(ranked_indices, list): ranked_indices = [ranked_indices]
        query_metrics = compute_top_k_retrieval_metrics(ranked_indices, [ground_truth_map_np[i]], TOP_K_VALUES)
        for k in TOP_K_VALUES:
            for metric_name in query_metrics[k]: t2i_metrics[k][metric_name] += query_metrics[k][metric_name]
    t2i_final_metrics = report_retrieval_metrics(f"Text-to-Image Retrieval ({config['dataset_name']}) - {cfg.image_model_name} + {cfg.text_encoder_model}", t2i_metrics, num_text_queries, TOP_K_VALUES)

    # --- Image-to-Text Retrieval ---
    image_to_captions = defaultdict(list)
    for caption_idx, image_idx in enumerate(ground_truth_map_np):
        image_to_captions[image_idx].append(caption_idx)
    i2t_metrics = {k: defaultdict(float) for k in TOP_K_VALUES}
    num_image_queries = len(image_embeddings_gpu)
    for i in tqdm(range(num_image_queries), desc="Image-to-Text Retrieval"):
        query_embedding = image_embeddings_gpu[i].unsqueeze(0)
        sim = query_embedding @ text_embeddings_gpu.T
        ranked_indices = torch.argsort(sim, descending=True).squeeze().cpu().numpy().tolist()
        if not isinstance(ranked_indices, list): ranked_indices = [ranked_indices]
        ground_truth_caption_indices = image_to_captions[i]
        query_metrics = compute_top_k_retrieval_metrics(ranked_indices, ground_truth_caption_indices, TOP_K_VALUES)
        for k in TOP_K_VALUES:
            for metric_name in query_metrics[k]: i2t_metrics[k][metric_name] += query_metrics[k][metric_name]
    i2t_final_metrics = report_retrieval_metrics(f"Image-to-Text Retrieval ({config['dataset_name']}) - {cfg.image_model_name} + {cfg.text_encoder_model}", i2t_metrics, num_image_queries, TOP_K_VALUES)

    # --- Text-to-Text Retrieval ---
    t2t_metrics = {k: defaultdict(float) for k in TOP_K_VALUES}
    for i in tqdm(range(num_text_queries), desc="Text-to-Text Retrieval"):
        query_embedding = text_embeddings_gpu[i].unsqueeze(0)
        sim = query_embedding @ text_embeddings_gpu.T
        sim[0, i] = -torch.inf
        ranked_indices = torch.argsort(sim, descending=True).squeeze().cpu().numpy().tolist()
        if not isinstance(ranked_indices, list): ranked_indices = [ranked_indices]
        ground_truth_image_idx = ground_truth_map_np[i]
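        # Exclude the query caption itself: it is masked out of the ranking above,
        # so counting it as ground truth would cap recall below 1.
        ground_truth_caption_indices = [c for c in image_to_captions[ground_truth_image_idx] if c != i]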
        query_metrics = compute_top_k_retrieval_metrics(ranked_indices, ground_truth_caption_indices, TOP_K_VALUES)
        for k in TOP_K_VALUES:
            for metric_name in query_metrics[k]: t2t_metrics[k][metric_name] += query_metrics[k][metric_name]
    t2t_final_metrics = report_retrieval_metrics(f"Text-to-Text Retrieval ({config['dataset_name']}) - {cfg.image_model_name} + {cfg.text_encoder_model}", t2t_metrics, num_text_queries, TOP_K_VALUES)

    return {"t2i": t2i_final_metrics, "i2t": i2t_final_metrics, "t2t": t2t_final_metrics}

def plot_loss_curves(training_history, config):
    if not training_history: return
    plt.figure(figsize=(12, 5))
    plt.plot(range(1, len(training_history['train_loss']) + 1), training_history['train_loss'], label='Train Loss')
    plt.plot(range(1, len(training_history['valid_loss']) + 1), training_history['valid_loss'], label='Validation Loss')
    plt.title(f'Training and Validation Loss - {config["dataset_name"]}')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.xticks(range(1, len(training_history['train_loss']) + 1))
    plt.legend()
    plt.grid(True)
    plt.show()

def print_performance_summary(training_history, total_training_duration, train_loader, cfg, config):
    if not training_history: return
    total_epochs_trained = len(training_history['train_loss'])
    avg_epoch_time = sum(training_history['epoch_times']) / len(training_history['epoch_times']) if training_history['epoch_times'] else 0
    iterations_per_epoch = len(train_loader)
    # Epoch time includes validation and loader setup, so this overstates pure training-step time
    avg_iteration_time = avg_epoch_time / iterations_per_epoch if iterations_per_epoch > 0 else 0

    print(f"\n--- Training Performance Summary for {config['dataset_name']} ---")
    print(f"  GPU Used:                  {torch.cuda.get_device_name(0) if cfg.device.type == 'cuda' else 'CPU'}")
    print(f"  Total Epochs Trained:      {total_epochs_trained}")
    print(f"  Batch Size:                {cfg.batch_size}")
    print(f"  Head Learning Rate:        {cfg.head_lr}")
    print(f"  Image Encoder LR:          {cfg.image_encoder_lr}")
    print(f"  Text Encoder LR:           {cfg.text_encoder_lr}")
    print(f"  Optimizer:                 AdamW")
    print("------------------------------------")
    print(f"  Total Training Time:       {total_training_duration:.2f} seconds ({total_training_duration/60:.2f} minutes)")
    print(f"  Average Time per Epoch:    {avg_epoch_time:.2f} seconds")
    print(f"  Average Time per Iteration:{avg_iteration_time:.4f} seconds")
    print("------------------------------------")

### **Step 7: The Main Pipeline Function**

In [None]:
def run_pipeline(config, cfg):
    print("-" * 50)
    print(f"STARTING PIPELINE FOR: {config['dataset_name'].upper()}")
    print(f"With model: {cfg.image_model_name} + {cfg.text_encoder_model}")
    print("-" * 50)

    # --- Setup: Data Loading ---
    print("Setting up datasets and dataloaders...")
    tokenizer = AutoTokenizer.from_pretrained(cfg.text_tokenizer)
    train_df, valid_df = make_train_valid_dfs(config)
    if train_df is None or valid_df is None or train_df.empty or valid_df.empty:
        print("Could not create dataframes or they are empty. Aborting pipeline.")
        return None, None
    
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid", config=config, cfg=cfg)

    # --- Model Creation and Loading ---
    image_encoder = create_image_encoder(cfg)
    text_encoder = create_text_encoder(cfg)
    model = CLIPModel(image_encoder, text_encoder, cfg).to(cfg.device)
    
    model_path = config['model_save_path']
    
    if os.path.exists(model_path) and not cfg.force_retrain:
        print(f"Model found at '{model_path}'. Loading weights and skipping training.")
        checkpoint = torch.load(model_path, map_location=cfg.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        training_history = None # No training was performed
    else:
        print("Model not found or force_retrain is True. Starting training...")
        params = [
            {"params": model.image_encoder.parameters(), "lr": cfg.image_encoder_lr},
            {"params": model.text_encoder.parameters(), "lr": cfg.text_encoder_lr},
            {"params": list(model.image_projection.parameters()) + list(model.text_projection.parameters()), "lr": cfg.head_lr, "weight_decay": cfg.weight_decay}
        ]
        optimizer = torch.optim.AdamW(params, weight_decay=0.)
        lr_scheduler = CosineAnnealingLR(optimizer, T_max=len(train_df) // cfg.batch_size * cfg.epochs, eta_min=cfg.lr_min)
        loss_fn = CLIPLoss(temperature=cfg.temperature).to(cfg.device)

        trainer = Trainer(model, optimizer, lr_scheduler, loss_fn, cfg)
        training_history = trainer.fit(train_df, valid_loader, tokenizer, config)

    if not cfg.run_embedding_generation and not cfg.run_evaluation:
        print("Skipping embedding generation and evaluation as per configuration.")
        return training_history, {}

    # --- Embedding Generation ---
    embedding_dir = config['embedding_save_path']
    train_embed_path = os.path.join(embedding_dir, 'train_image_embeddings.pt')
    valid_embed_path = os.path.join(embedding_dir, 'valid_image_embeddings.pt')
    
    if cfg.run_embedding_generation:
        if not os.path.exists(train_embed_path):
            print("\n--- Generating Embeddings for the TRAINING set ---")
            # Build the full training loader only when its embeddings actually need to be generated.
            full_train_loader = build_loaders(train_df, tokenizer, mode="valid", config=config, cfg=cfg)
            generate_and_save_embeddings(model, full_train_loader, config, cfg, "train")
        else:
            print("\nTraining embeddings found. Skipping generation.")
    
        if not os.path.exists(valid_embed_path):
            print("\n--- Generating Embeddings for the VALIDATION set ---")
            generate_and_save_embeddings(model, valid_loader, config, cfg, "valid")
        else:
            print("Validation embeddings found. Skipping generation.")
    else:
        print("\nSkipping embedding generation as per configuration.")

    # --- Retrieval Accuracy Calculation ---
    final_metrics = {}
    if cfg.run_evaluation:
        print("\nProceeding to retrieval accuracy calculation...")
        try:
            image_embeddings = torch.load(os.path.join(embedding_dir, 'valid_image_embeddings.pt'))
            text_embeddings = torch.load(os.path.join(embedding_dir, 'valid_text_embeddings.pt'))
            ground_truth = torch.load(os.path.join(embedding_dir, 'valid_ground_truth.pt'))
            
            final_metrics = evaluate_retrieval(image_embeddings, text_embeddings, ground_truth, config, cfg)

        except FileNotFoundError:
            print("Could not find validation embedding files to calculate accuracy.")
        
        # --- Qualitative Analysis ---
        show_qualitative_results(config, cfg)
    else:
        print("Skipping evaluation as per configuration.")
    
    print(f"PIPELINE FOR {config['dataset_name'].upper()} COMPLETE\n")
    return training_history, final_metrics
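`evaluate_retrieval` is defined elsewhere in the notebook and is not shown in this section. For reference, here is a minimal sketch of Top-K text-to-image accuracy computed from the same three artifacts loaded above: caption embeddings, unique-image embeddings, and a ground-truth tensor mapping each caption to its image index. The function name `topk_text_to_image_accuracy` is hypothetical. The sketch L2-normalizes both sides explicitly because it does not assume the saved embeddings are already normalized.

```python
import torch
import torch.nn.functional as F

def topk_text_to_image_accuracy(text_emb, image_emb, caption_to_image, ks=(1, 5)):
    """Hypothetical helper: fraction of captions whose ground-truth image is in the top-k results.

    text_emb:         (num_captions, dim) caption embeddings
    image_emb:        (num_images, dim) embeddings of the unique images
    caption_to_image: (num_captions,) long tensor holding each caption's image index
    """
    # L2-normalize so the dot product is cosine similarity.
    text_emb = F.normalize(text_emb, dim=-1)
    image_emb = F.normalize(image_emb, dim=-1)
    sims = text_emb @ image_emb.T                          # (num_captions, num_images)
    max_k = min(max(ks), image_emb.size(0))
    top = sims.topk(max_k, dim=1).indices                  # (num_captions, max_k), best first
    hits = top == caption_to_image.unsqueeze(1)            # True where the correct image was retrieved
    # A caption counts as a hit at k if its image appears anywhere in the first k results.
    return {k: hits[:, :k].any(dim=1).float().mean().item() for k in ks}
```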

### **Step 8: Main Execution and Comparative Analysis**
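The text-to-text view treats a retrieved caption as correct when it describes the same image as the query, and it excludes the query caption itself (the `similarities[0, idx] = -torch.inf` line in `show_text_to_text_examples`). The following is a vectorized sketch of the same rule applied to every caption at once. It uses the hypothetical helper name `text_to_text_topk_accuracy` and masks the whole diagonal of the caption-caption similarity matrix instead of one entry per query.

```python
import torch

def text_to_text_topk_accuracy(text_emb, caption_to_image, k=1):
    """Hypothetical helper: fraction of captions whose top-k neighbours (excluding
    the caption itself) include at least one caption of the same image."""
    sims = text_emb @ text_emb.T                       # (N, N) caption-caption similarities
    sims.fill_diagonal_(float("-inf"))                 # a caption may not retrieve itself
    top = sims.topk(k, dim=1).indices                  # (N, k) neighbour indices, best first
    # A neighbour is relevant when it maps to the same image as the query caption.
    same_image = caption_to_image[top] == caption_to_image.unsqueeze(1)
    return same_image.any(dim=1).float().mean().item()
```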

In [None]:
def show_text_to_image_examples(config, cfg, num_examples=3):
    print("\n" + "="*50)
    print("           QUALITATIVE ANALYSIS: TEXT-TO-IMAGE RETRIEVAL")
    print("="*50 + "\n")
    
    embedding_dir = config['embedding_save_path']
    try:
        image_embeddings = torch.load(os.path.join(embedding_dir, 'valid_image_embeddings.pt'))
        text_embeddings = torch.load(os.path.join(embedding_dir, 'valid_text_embeddings.pt'))
        ground_truth_map = torch.load(os.path.join(embedding_dir, 'valid_ground_truth.pt')).numpy()
        valid_captions = torch.load(os.path.join(embedding_dir, 'valid_captions.pt'))
        unique_image_paths = torch.load(os.path.join(embedding_dir, 'valid_unique_image_paths.pt'))
    except FileNotFoundError:
        print("Could not find embedding files for qualitative analysis.")
        return
    
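    # Move the embeddings to the device once, rather than on every loop iteration.
    image_embeddings = image_embeddings.to(cfg.device)
    text_embeddings = text_embeddings.to(cfg.device)

    random_indices = random.sample(range(len(valid_captions)), min(num_examples, len(valid_captions)))
    
    for idx in random_indices:
        query_caption = valid_captions[idx]
        true_image_idx = int(ground_truth_map[idx])
        true_image_path = unique_image_paths[true_image_idx]
        
        query_text_embedding = text_embeddings[idx].unsqueeze(0)
        similarities = query_text_embedding @ image_embeddings.T  # shape: (1, num_images)
        
        # Indices of the (up to) 5 most similar images, as plain Python ints.
        top_k_indices = similarities.squeeze(0).topk(min(5, similarities.size(1))).indices.cpu().tolist()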
        
        print(f"QUERY CAPTION: \"{query_caption}\"")
        
        fig, axes = plt.subplots(1, 6, figsize=(20, 4))
        
        axes[0].imshow(Image.open(true_image_path))
        axes[0].set_title("Ground Truth")
        axes[0].axis("off")
        
        for i, img_idx in enumerate(top_k_indices):
            retrieved_path = unique_image_paths[img_idx]
            axes[i+1].imshow(Image.open(retrieved_path))
            is_correct = img_idx == true_image_idx
            border_color = 'green' if is_correct else 'red'
            axes[i+1].set_title(f"Rank {i+1}", color=border_color)
            for spine in axes[i+1].spines.values():
                spine.set_edgecolor(border_color)
                spine.set_linewidth(4)
            axes[i+1].set_xticks([])
            axes[i+1].set_yticks([])

        plt.show()
        print("-" * 80)

def show_image_to_text_examples(config, cfg, num_examples=3):
    print("\n" + "="*50)
    print("           QUALITATIVE ANALYSIS: IMAGE-TO-TEXT RETRIEVAL")
    print("="*50 + "\n")
    
    embedding_dir = config['embedding_save_path']
    try:
        image_embeddings = torch.load(os.path.join(embedding_dir, 'valid_image_embeddings.pt'))
        text_embeddings = torch.load(os.path.join(embedding_dir, 'valid_text_embeddings.pt'))
        ground_truth_map = torch.load(os.path.join(embedding_dir, 'valid_ground_truth.pt')).numpy()
        valid_captions = torch.load(os.path.join(embedding_dir, 'valid_captions.pt'))
        unique_image_paths = torch.load(os.path.join(embedding_dir, 'valid_unique_image_paths.pt'))
    except FileNotFoundError:
        print("Could not find embedding files for qualitative analysis.")
        return

    # Invert the caption -> image mapping so each image lists the indices of its captions.
    image_to_captions_map = defaultdict(list)
    for i, img_idx in enumerate(ground_truth_map):
        image_to_captions_map[int(img_idx)].append(i)
        
    unique_image_indices = list(image_to_captions_map.keys())
    random_indices = random.sample(unique_image_indices, min(num_examples, len(unique_image_indices)))
    text_embeddings = text_embeddings.to(cfg.device)
    
    for img_idx in random_indices:
        query_image_embedding = image_embeddings[img_idx].to(cfg.device).unsqueeze(0)
        similarities = query_image_embedding @ text_embeddings.T  # shape: (1, num_captions)
        top_k_indices = similarities.squeeze(0).topk(min(5, similarities.size(1))).indices.cpu().tolist()
        
        query_image_path = unique_image_paths[img_idx]

        plt.figure(figsize=(5, 5))
        plt.imshow(Image.open(query_image_path))
        plt.title(f"QUERY IMAGE: {os.path.basename(query_image_path)}")
        plt.axis("off")
        plt.show()

        print("--- Ground Truth Captions ---")
        ground_truth_caption_indices = image_to_captions_map[img_idx]
        for cap_idx in ground_truth_caption_indices:
            print(f"- {valid_captions[cap_idx]}")

        print("\n--- Top 5 Retrieved Captions ---")
        for i, cap_idx in enumerate(top_k_indices):
            retrieved_caption = valid_captions[cap_idx]
            is_correct = cap_idx in image_to_captions_map[img_idx]
            prefix = "✅" if is_correct else "❌"
            print(f"{prefix} Rank {i+1}: \"{retrieved_caption}\"")
        print("-" * 80)

def show_text_to_text_examples(config, cfg, num_examples=3):
    print("\n" + "="*50)
    print("           QUALITATIVE ANALYSIS: TEXT-TO-TEXT RETRIEVAL")
    print("="*50 + "\n")
    
    embedding_dir = config['embedding_save_path']
    try:
        text_embeddings = torch.load(os.path.join(embedding_dir, 'valid_text_embeddings.pt'))
        ground_truth_map = torch.load(os.path.join(embedding_dir, 'valid_ground_truth.pt')).numpy()
        valid_captions = torch.load(os.path.join(embedding_dir, 'valid_captions.pt'))
    except FileNotFoundError:
        print("Could not find embedding files for qualitative analysis.")
        return

    image_to_captions_map = defaultdict(list)
    for i, img_idx in enumerate(ground_truth_map):
        image_to_captions_map[img_idx].append(i)

    random_indices = random.sample(range(len(valid_captions)), min(num_examples, len(valid_captions)))
    text_embeddings = text_embeddings.to(cfg.device)

    for idx in random_indices:
        query_caption = valid_captions[idx]
        query_text_embedding = text_embeddings[idx].unsqueeze(0)
        similarities = query_text_embedding @ text_embeddings.T  # shape: (1, num_captions)
        similarities[0, idx] = float("-inf")  # Exclude self-retrieval
        # The query itself is excluded, so at most num_captions - 1 candidates remain.
        top_k_indices = similarities.squeeze(0).topk(min(5, similarities.size(1) - 1)).indices.cpu().tolist()

        print(f"QUERY CAPTION: \"{query_caption}\"")

        ground_truth_image_id = int(ground_truth_map[idx])
        print("--- Ground Truth Captions (from same image) ---")
        ground_truth_caption_indices = image_to_captions_map[ground_truth_image_id]
        for cap_idx in ground_truth_caption_indices:
            if cap_idx != idx:
                print(f"- {valid_captions[cap_idx]}")

        print("\n--- Top 5 Retrieved Captions ---")
        for i, cap_idx in enumerate(top_k_indices):
            retrieved_caption = valid_captions[cap_idx]
            is_correct = cap_idx in ground_truth_caption_indices
            prefix = "✅" if is_correct else "❌"
            print(f"{prefix} Rank {i+1}: \"{retrieved_caption}\"")
        print("-" * 80)

def show_qualitative_results(config, cfg):
    try:
        show_text_to_image_examples(config, cfg)
        show_image_to_text_examples(config, cfg)
        show_text_to_text_examples(config, cfg)
    except Exception as e:
        print(f"An error occurred during qualitative analysis: {e}")

if __name__ == '__main__':
    env_name, base_path = detect_environment()
    
    results_history = {}
    # Defined before the experiment loop because the comparative tables below also iterate over it.
    datasets_to_process = ["flickr8k", "flickr30k"]
    for exp_name, exp_params in experiment_configs.items():
        print("\n" + "="*80)
        print(f"                RUNNING EXPERIMENT: {exp_name.upper()}")
        print("="*80 + "\n")
        
        for dataset_name in datasets_to_process:
            base_cfg_dict = {k: v for k, v in BaseCFG.__dict__.items() if not k.startswith('__')}
            combined_params = {**base_cfg_dict, **exp_params["models"], **exp_params["hyperparameters"][dataset_name]}
            cfg = SimpleNamespace(**combined_params)
            
            path_config = generate_paths(base_path, dataset_name, cfg)
            prepare_dataset(path_config)
            history, metrics = run_pipeline(path_config, cfg)
            
            results_history.setdefault(dataset_name, {})[exp_name] = {'history': history, 'metrics': metrics}

    if BaseCFG.show_comparative_plots:
        print("\n" + "="*60)
        print("           FINAL COMPARATIVE ANALYSIS")
        print("="*60 + "\n")
        if results_history:
            # --- 1. Side-by-Side Loss Plots ---
            for exp_name in experiment_configs.keys():
                num_datasets = len(results_history)
                if num_datasets > 0:
                    fig, axes = plt.subplots(1, num_datasets, figsize=(8 * num_datasets, 5), squeeze=False)
                    fig.suptitle(f'Training & Validation Loss Comparison - {exp_name.upper()}', fontsize=16)

                    dataset_names = list(results_history.keys())
                    for i in range(num_datasets):
                        ax = axes[0, i]
                        dataset_name = dataset_names[i]
                        history = results_history[dataset_name][exp_name].get('history')
                        if history and history['train_loss']:
                            ax.plot(range(1, len(history['train_loss']) + 1), history['train_loss'], label='Train Loss')
                            ax.plot(range(1, len(history['valid_loss']) + 1), history['valid_loss'], label='Validation Loss')
                            ax.set_title(dataset_name)
                            ax.set_xlabel('Epochs')
                            ax.set_ylabel('Loss')
                            ax.legend()
                            ax.grid(True)
                        else:
                            ax.text(0.5, 0.5, 'Training was skipped', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
                            ax.set_title(dataset_name)

                    plt.tight_layout(rect=[0, 0, 1, 0.96])
                    plt.show()

            # --- 2. Comparative Metrics Table ---
            table_header = (
                "| Dataset   | Top-1 Accuracy | Top-1 Precision | Top-1 Recall |\n"
                "|-----------|----------------|-----------------|--------------|\n"
            )
            retrieval_directions = [
                ('t2i', "Text-to-Image Retrieval Summary"),
                ('i2t', "Image-to-Text Retrieval Summary"),
                ('t2t', "Text-to-Text Retrieval Summary"),
            ]
            for exp_name in experiment_configs.keys():
                display(Markdown(f"### Results for Experiment: {exp_name.upper()}"))
                summary_tables = {key: table_header for key, _ in retrieval_directions}

                for dataset_name in datasets_to_process:
                    metrics = results_history.get(dataset_name, {}).get(exp_name, {}).get('metrics')
                    if not metrics:
                        continue
                    for key, _ in retrieval_directions:
                        direction_metrics = metrics.get(key)
                        if direction_metrics and 1 in direction_metrics:
                            top1 = direction_metrics[1]
                            summary_tables[key] += (
                                f"| {dataset_name:<9} | {top1['accuracy']:<14.4f} | "
                                f"{top1['precision']:<15.4f} | {top1['recall']:<12.4f} |\n"
                            )

                for key, title in retrieval_directions:
                    display(Markdown(f"#### {title}"))
                    display(Markdown(summary_tables[key]))

        else:
            print("No results were generated to compare.")

    print("\nAll dataset pipelines have been executed.")
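Each run's `cfg` is built by unpacking three dictionaries into one, so later sources override earlier ones: `BaseCFG` defaults first, then the experiment's `models` entry, then the per-dataset `hyperparameters` entry. The sketch below illustrates that precedence. The `BaseCFG` class and the experiment entry in it are hypothetical stand-ins for the notebook's real ones, and `build_cfg` is an illustrative helper name.

```python
from types import SimpleNamespace

class BaseCFG:  # hypothetical defaults, not the notebook's actual BaseCFG
    batch_size = 32
    epochs = 10
    head_lr = 1e-3

experiment = {  # hypothetical entry shaped like an experiment_configs value
    "models": {"image_model_name": "resnet50", "text_encoder_model": "distilbert-base-uncased"},
    "hyperparameters": {
        "flickr8k": {"batch_size": 64},
        "flickr30k": {"batch_size": 128, "epochs": 5},
    },
}

def build_cfg(base_cls, exp_params, dataset_name):
    # Keep class attributes, dropping dunder entries such as __module__ and __doc__.
    base = {k: v for k, v in vars(base_cls).items() if not k.startswith("__")}
    # Later dicts win: base defaults < model choices < per-dataset hyperparameters.
    merged = {**base, **exp_params["models"], **exp_params["hyperparameters"][dataset_name]}
    return SimpleNamespace(**merged)

cfg8k = build_cfg(BaseCFG, experiment, "flickr8k")
cfg30k = build_cfg(BaseCFG, experiment, "flickr30k")
print(cfg8k.batch_size, cfg8k.epochs, cfg30k.batch_size, cfg30k.epochs)  # prints: 64 10 128 5
```

Because only dunder names are filtered out, any method or property defined on the real `BaseCFG` would also be copied into the namespace as a plain attribute.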