# Project Setup

After running the code above, you'll gain insight into the dataset structure and caption statistics, and you'll see
sample images displayed alongside their captions. This is our first step in understanding the data we'll be working with.

In [None]:
# Image Captioning with CNN+RNN and Attention Models



In [None]:
## Setup and Configuration

### 1.1 Imports and Environment Setup
import os
import sys
import time
import random
import traceback
import numpy as np
import pandas as pd
import json
import pickle
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

# Deep learning imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, Subset
from torch.nn.utils.rnn import pad_sequence

# NLP imports
import nltk
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu
# Make sure the NLTK "punkt" tokenizer data is available; fetch it on first use.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

# Import spaCy outside the try block: if the package itself is missing, the
# download fallback below cannot run (it needs spaCy), so let the ImportError
# surface instead of masking it with a NameError inside the handler.
import spacy
import spacy.cli

try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # spaCy signals a missing model package with OSError; install and retry once.
    print("Installing spaCy model...")
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")

# Set seeds for reproducibility
# The same value seeds Python's `random`, NumPy and PyTorch (CPU, and every
# CUDA device when one is available).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Check device availability
# `device` is "cuda" when a GPU is visible to PyTorch, otherwise "cpu".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

### 1.2 Configuration Parameters and Paths
# Set execution mode - set to False for full training
# When True, the configuration below switches to fewer epochs, smaller batches
# and separate "*_debug" checkpoint files.
DEBUG_MODE = False

# Root directory and paths setup
# NOTE(review): ROOT_DIR is the *parent* of the current working directory, so
# this assumes the notebook is launched from a sub-folder of the project - confirm.
ROOT_DIR = os.path.dirname(os.getcwd())  # Project root directory
DATA_DIR = os.path.join(ROOT_DIR, "data")
OUTPUT_DIR = os.path.join(ROOT_DIR, "output")
MODEL_DIR = os.path.join(OUTPUT_DIR, "models")
LOGS_DIR = os.path.join(OUTPUT_DIR, "logs")

# Create directories if they don't exist
# (DATA_DIR is only read from, so it is not created here.)
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(LOGS_DIR, exist_ok=True)

# Dataset paths
FLICKR8K_DIR = os.path.join(DATA_DIR, "raw", "flickr8k")
IMAGES_DIR = os.path.join(FLICKR8K_DIR, "images")
CAPTIONS_FILE = os.path.join(FLICKR8K_DIR, "captions.txt")

# ImageNet normalization constants
# Per-channel (R, G, B) mean and standard deviation used by the image
# transforms and by the (de)normalization helpers.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Verify file paths
# Only reports presence; a missing path does not stop execution here.
print(f"Images directory exists: {os.path.exists(IMAGES_DIR)}")
print(f"Captions file exists: {os.path.exists(CAPTIONS_FILE)}")

### 1.3 Model Configuration Parameters
# Shared model configuration parameters
# Entries written as "X if not DEBUG_MODE else Y" take the smaller value Y in debug runs.
model_config = {
    # Architecture
    "embed_size": 256,          # Embedding dimension
    "hidden_size": 512,         # LSTM hidden state size
    "attention_dim": 256,       # Attention network dimension
    "num_layers": 1,            # Number of LSTM layers
    "dropout": 0.5,             # Dropout probability
    
    # Training
    "num_epochs": 10 if not DEBUG_MODE else 3,   # 3 epochs in debug mode
    "batch_size": 32 if not DEBUG_MODE else 8,   # 8 samples per batch in debug mode
    "learning_rate": 3e-4,
    "weight_decay": 1e-5,
    "clip_grad_norm": 5.0,      # presumably the max gradient norm for clipping - consumer not shown here
    
    # Scheduler
    "use_lr_scheduler": True,
    "lr_scheduler_factor": 0.5,
    "lr_scheduler_patience": 2,
    
    # Evaluation
    "eval_every": 1,            # Validate every N epochs
    "bleu_every": 2,            # Calculate BLEU every N epochs
    "max_bleu_samples": None if not DEBUG_MODE else 50,   # None outside debug mode; 50 in debug mode
    
    # Early stopping
    "early_stopping_patience": 5,
    
    # Logging and checkpoints
    "print_frequency": 50 if not DEBUG_MODE else 5,
    "save_best_only": True,
    "save_frequency": 1
}

# Checkpoint locations for the baseline and the attention model.
# Each dict has a rolling "checkpoint_path" and a "best_model_path". In debug
# mode the file stem gets a "_debug" suffix so debug runs never overwrite the
# checkpoints of a full training run.
baseline_paths, attention_paths = (
    {
        "checkpoint_path": os.path.join(MODEL_DIR, f"{stem}.pth"),
        "best_model_path": os.path.join(MODEL_DIR, f"{stem}_best.pth"),
    }
    for stem in (
        f"{prefix}_model_debug" if DEBUG_MODE else f"{prefix}_model"
        for prefix in ("baseline", "attention")
    )
)

### 1.4 Utility Functions
def debug_print(message, tensor=None, level=0):
    """Print a debug message, optionally followed by a summary of a tensor.

    Args:
        message: Text printed after the ``DEBUG:`` prefix.
        tensor: Optional ``torch.Tensor`` to summarise; any other value
            (including ``None``) is ignored.
        level: Nesting depth; each level indents the output by two spaces.
    """
    indent = "  " * level
    print(f"{indent}DEBUG: {message}")

    if not isinstance(tensor, torch.Tensor):
        return

    print(f"{indent}  Shape: {tensor.shape}")
    print(f"{indent}  Type: {tensor.dtype}")
    print(f"{indent}  Device: {tensor.device}")

    if tensor.numel() == 0:
        # min()/max()/mean() raise on empty tensors, so report that instead.
        print(f"{indent}  Values - (empty tensor)")
    else:
        # mean() is undefined for integer/bool tensors (e.g. caption indices),
        # so compute the statistics on a floating-point copy when needed.
        stats = tensor.detach()
        if not stats.is_floating_point():
            stats = stats.float()
        print(f"{indent}  Values - Min: {stats.min().item():.4f}, Max: {stats.max().item():.4f}, Mean: {stats.mean().item():.4f}")

    print(f"{indent}  Requires grad: {tensor.requires_grad}")

class AverageMeter:
    """Track the most recent value and the running weighted average of a metric.

    Attributes (all reset to 0 by ``reset``):
        val: Value passed to the latest ``update`` call.
        sum: Weighted sum of all values seen so far.
        count: Total weight (number of samples) seen so far.
        avg: ``sum / count``, or 0 while ``count`` is still 0.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times (e.g. a batch mean and its size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. update(..., n=0) on a fresh
        # meter), which would otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

def normalize_image(image):
    """Apply ImageNet mean/std normalization to one image or to a batch.

    A 3-dimensional tensor is normalized directly; any other input is iterated
    along its first dimension, each item is normalized, and the results are
    stacked back into a single tensor.
    """
    to_imagenet = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

    if image.dim() != 3:
        # Batch of images: normalize item by item
        return torch.stack([to_imagenet(single) for single in image])

    # Single image
    return to_imagenet(image)

def denormalize_image(image_tensor):
    """Reverse ImageNet normalization so an image can be displayed.

    Args:
        image_tensor: Normalized image tensor, channels first. For a
            4-dimensional input (a batch) only the FIRST image is converted.

    Returns:
        NumPy array with channels last and values clipped to [0, 1].
    """
    # A batch contributes only its first image
    sample = image_tensor[0] if image_tensor.dim() == 4 else image_tensor

    # detach() first: .numpy() raises on tensors that still require grad
    # (e.g. an image taken straight from a training graph).
    image = sample.detach().cpu().permute(1, 2, 0).numpy()

    # Reverse normalization: x * std + mean, broadcast over the channel axis
    std = np.array(IMAGENET_STD).reshape(1, 1, 3)
    mean = np.array(IMAGENET_MEAN).reshape(1, 1, 3)
    image = image * std + mean

    # Clip values into the displayable range
    return np.clip(image, 0, 1)

def save_checkpoint(state, is_best=False, filepath='checkpoint.pth'):
    """Serialize a checkpoint to disk with ``torch.save``.

    Args:
        state: Object to serialize (typically a dict of state dicts).
        is_best: Only affects the message printed after saving.
        filepath: Destination file. Missing parent directories are created.
    """
    # A bare filename (such as the default) has an empty dirname, and
    # os.makedirs('') raises FileNotFoundError - only create real directories.
    target_dir = os.path.dirname(filepath)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # Save checkpoint
    torch.save(state, filepath)

    # Report what was written
    if is_best:
        print(f"Saved best model to {filepath}")
    else:
        print(f"Saved checkpoint to {filepath}")

def check_model_availability(config, best_model_path, checkpoint_path):
    """Look for a previously saved model and load it if one exists.

    The best-model file is tried first, then the rolling checkpoint. A file
    that exists but cannot be loaded is reported and skipped.

    Args:
        config: Currently unused - no compatibility check against the
            configuration is performed here.
        best_model_path: Path of the fully trained (best) model file.
        checkpoint_path: Path of the intermediate checkpoint file.

    Returns:
        Tuple ``(status, state)`` where ``status`` is ``"trained"``,
        ``"checkpoint"`` or ``"train_new"`` and ``state`` is the loaded object
        (``None`` for ``"train_new"``).
    """
    # Map tensors onto whatever device is available now; without this, a
    # checkpoint saved on a GPU fails to load on a CPU-only machine and the
    # failure would silently fall through to "train_new".
    load_target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    candidates = (
        ("trained", best_model_path, "trained model"),
        ("checkpoint", checkpoint_path, "checkpoint"),
    )
    for status, path, label in candidates:
        if not os.path.exists(path):
            continue
        try:
            state = torch.load(path, map_location=load_target)
        except Exception as e:
            # Best effort: an unreadable file must not abort the run
            print(f"Error loading {label}: {e}")
            continue
        print(f"Found {label} at {path}")
        return status, state

    # No compatible model or checkpoint found
    print("No trained model or checkpoint found. Will train a new model.")
    return "train_new", None



In [None]:
## Data Pipeline

### 2.1 Dataset Loading and Exploration
# Read the caption table; the code below relies on its 'image' and 'caption' columns
captions_df = pd.read_csv(CAPTIONS_FILE)
print("\nDataset Overview:")
print(f"Captions dataframe shape: {captions_df.shape}")

# Headline statistics: one row per caption, several captions per image
num_captions = len(captions_df)
num_images = captions_df['image'].nunique(dropna=False)
avg_captions_per_image = num_captions / num_images

print("\nDataset Statistics:")
print(f"Number of unique images: {num_images}")
print(f"Total number of captions: {num_captions}")
print(f"Average captions per image: {avg_captions_per_image:.2f}")

# Show the first few rows (at most five)
print("\nSample captions:")
for i in range(min(5, num_captions)):
    sample_row = captions_df.iloc[i]
    print(f"Image: {sample_row['image']}")
    print(f"Caption: {sample_row['caption']}\n")

### 2.2 Data Preprocessing
def preprocess_caption(caption):
    """Lower-case a caption and re-join its spaCy tokens with single spaces.

    Tokenizing with spaCy separates punctuation from words, so the returned
    string has every token delimited by exactly one space.
    """
    lowered = caption.lower()
    return " ".join(tok.text for tok in nlp.tokenizer(lowered))

# Add a 'processed_caption' column holding the normalized form of each caption
print("Preprocessing captions...")
captions_df['processed_caption'] = captions_df['caption'].apply(preprocess_caption)

# Show the first few captions before and after preprocessing (at most five)
print("\nSample of original vs. processed captions:")
for original, processed in zip(
    captions_df['caption'].head(5),
    captions_df['processed_caption'].head(5),
):
    print(f"Original: {original}")
    print(f"Processed: {processed}\n")

### 2.3 Data Splitting
def create_data_splits(df, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_state=42):
    """Split the caption table into train/validation/test sets by image.

    Images (not individual captions) are assigned to splits, so all captions
    of one image land in the same split. The split is stratified on the
    image's mean caption length, bucketed into short/medium/long at the 33rd
    and 66th percentiles.

    Args:
        df: DataFrame with 'image' and 'processed_caption' columns.
        train_ratio, val_ratio, test_ratio: Fractions of images per split;
            must sum to 1.
        random_state: Seed passed to both ``train_test_split`` calls.

    Returns:
        Tuple ``(train_df, val_df, test_df)`` with reset indices.
    """
    # Local import keeps scikit-learn an optional dependency of this cell
    from sklearn.model_selection import train_test_split

    # Verify ratios sum to 1
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-10, "Ratios must sum to 1"

    # Mean tokens-per-caption for every image, computed in one grouped pass
    # (filtering the whole frame once per image is O(images * rows)).
    # sort=False keeps images in first-appearance order - the same order as
    # df['image'].unique() - so the seeded splits below are unaffected.
    image_lengths = pd.DataFrame({
        'image': df['image'].to_numpy(),
        'n_tokens': df['processed_caption'].str.split().str.len().to_numpy(),
    })
    mean_len = image_lengths.groupby('image', sort=False)['n_tokens'].mean()

    image_array = np.array(list(mean_len.index))
    caption_lens = mean_len.to_numpy()

    # Bucket images by mean caption length
    q1, q2 = np.percentile(caption_lens, [33, 66])
    strat_array = np.where(
        caption_lens <= q1, 'short',
        np.where(caption_lens <= q2, 'medium', 'long'),
    )
    strat_features = dict(zip(image_array, strat_array))

    # First split: train+val vs test
    train_val_imgs, test_imgs, _, _ = train_test_split(
        image_array, strat_array,
        test_size=test_ratio,
        random_state=random_state,
        stratify=strat_array
    )

    # Second split: train vs val, stratified on the labels of the remaining images
    strat_array_train_val = np.array([strat_features[img] for img in train_val_imgs])

    # val_ratio is relative to the full set; rescale it to the train+val subset
    val_ratio_adjusted = val_ratio / (train_ratio + val_ratio)
    train_imgs, val_imgs, _, _ = train_test_split(
        train_val_imgs, strat_array_train_val,
        test_size=val_ratio_adjusted,
        random_state=random_state,
        stratify=strat_array_train_val
    )

    # Create dataframes for each split
    train_df = df[df['image'].isin(train_imgs)].reset_index(drop=True)
    val_df = df[df['image'].isin(val_imgs)].reset_index(drop=True)
    test_df = df[df['image'].isin(test_imgs)].reset_index(drop=True)

    # Report the share of each length bucket per split as a sanity check
    print("Stratification verification:")
    for caption_type in ['short', 'medium', 'long']:
        train_count = sum(1 for img in train_imgs if strat_features[img] == caption_type)
        val_count = sum(1 for img in val_imgs if strat_features[img] == caption_type)
        test_count = sum(1 for img in test_imgs if strat_features[img] == caption_type)

        print(f"  {caption_type.capitalize()} captions - "
              f"Train: {train_count/len(train_imgs)*100:.1f}%, "
              f"Val: {val_count/len(val_imgs)*100:.1f}%, "
              f"Test: {test_count/len(test_imgs)*100:.1f}%")

    return train_df, val_df, test_df

# Partition the caption table by image into the three working sets
train_df, val_df, test_df = create_data_splits(captions_df)

# Report caption and image counts per split
print("Data splits:")
for split_label, split_frame in (
    ("Training", train_df),
    ("Validation", val_df),
    ("Testing", test_df),
):
    print(f"{split_label}: {len(split_frame)} captions, {len(split_frame['image'].unique())} images")

### 2.4 Vocabulary Building
class Vocabulary:
    """Map between words and integer indices for caption processing.

    Indices 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and
    <UNK>; regular words receive indices from 4 upwards in the order in which
    they reach the frequency threshold.
    """

    def __init__(self, freq_threshold=5):
        # Special tokens
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}

        # Minimum number of occurrences for a word to enter the vocabulary
        self.freq_threshold = freq_threshold

        # Next free index
        self.idx = 4

        # Word frequencies from the latest build_vocabulary call
        self.word_frequencies = {}

    def __len__(self):
        return len(self.itos)

    def tokenize(self, text):
        """Tokenize text with spaCy and lower-case every token."""
        return [token.text.lower() for token in nlp.tokenizer(text)]

    def build_vocabulary(self, captions):
        """Add every sufficiently frequent word of ``captions`` to the vocabulary.

        Args:
            captions: List of caption strings.

        Returns:
            Dict mapping every seen word to its occurrence count.
        """
        frequencies = {}

        print(f"Building vocabulary from {len(captions)} captions...")

        for caption in tqdm(captions):
            for word in self.tokenize(caption):
                count = frequencies.get(word, 0) + 1
                frequencies[word] = count

                # Register the word the first time it reaches the threshold.
                # The membership test keeps a repeated build from assigning a
                # second index to a known word, and ">=" lets a threshold
                # below 1 admit every word instead of none.
                if count >= self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = self.idx
                    self.itos[self.idx] = word
                    self.idx += 1

        print(f"Built vocabulary with {len(self.itos)} tokens")
        print(f"Added {len(self.itos) - 4} words above frequency threshold {self.freq_threshold}")

        # Keep word frequencies for later analysis
        self.word_frequencies = frequencies
        return frequencies

    def numericalize(self, text):
        """Convert text to a list of indices; unknown words map to <UNK>."""
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenize(text)]

# Build vocabulary from training set
# Only the training captions are used, so words that appear solely in the
# validation/test captions are not in the vocabulary.
vocab = Vocabulary(freq_threshold=5)
word_frequencies = vocab.build_vocabulary(train_df['processed_caption'].tolist())

# Analyze vocabulary coverage
def analyze_vocab_coverage(df, vocab, caption_col='processed_caption'):
    """Report the share of caption tokens that the vocabulary knows.

    Args:
        df: DataFrame holding the captions.
        vocab: Object exposing ``tokenize(text)`` and a ``stoi`` mapping.
        caption_col: Name of the caption column in ``df``.

    Returns:
        Tuple ``(coverage, unknown_word_instances)``: the coverage as a
        percentage (0.0 when there are no tokens at all) and a dict mapping
        each out-of-vocabulary token to its occurrence count.
    """
    total_words = 0
    unknown_word_instances = {}

    for caption in df[caption_col]:
        tokens = vocab.tokenize(caption)
        total_words += len(tokens)

        for token in tokens:
            if token not in vocab.stoi:
                unknown_word_instances[token] = unknown_word_instances.get(token, 0) + 1

    unknown_words = sum(unknown_word_instances.values())

    # An empty frame (or only empty captions) has no tokens; avoid dividing by zero
    if total_words:
        coverage = (total_words - unknown_words) / total_words * 100
    else:
        coverage = 0.0

    print(f"Vocabulary coverage: {coverage:.2f}%")
    print(f"Total words: {total_words}")
    print(f"Unknown words: {unknown_words}")

    if unknown_words > 0:
        print("\nTop unknown words:")
        ranked = sorted(unknown_word_instances.items(), key=lambda item: item[1], reverse=True)
        for word, count in ranked[:10]:
            print(f"  {word}: {count} occurrences")

    return coverage, unknown_word_instances

# Measure vocabulary coverage on each split (the per-word breakdown is discarded)
print("\nAnalyzing vocabulary coverage:")
train_coverage, _ = analyze_vocab_coverage(train_df, vocab)
val_coverage, _ = analyze_vocab_coverage(val_df, vocab)
test_coverage, _ = analyze_vocab_coverage(test_df, vocab)

print("\nSummary:")
for coverage_label, coverage_value in (
    ("Train", train_coverage),
    ("Validation", val_coverage),
    ("Test", test_coverage),
):
    print(f"{coverage_label} coverage: {coverage_value:.2f}%")

### 2.5 Image Processing
def get_transforms(resize=256, crop=224):
    """Build the image pipelines for training and for validation/test.

    Args:
        resize: Size passed to ``transforms.Resize``.
        crop: Size of the crop taken after resizing.

    Returns:
        Tuple ``(transform_train, transform_val)``. The training pipeline adds
        a random crop, colour jitter and a small random rotation; the
        validation pipeline uses a centre crop and no augmentation.
    """
    def _tensor_and_normalize():
        # Shared tail of both pipelines: PIL image -> tensor -> ImageNet normalization
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]

    augmenting_steps = [
        transforms.Resize(resize),
        transforms.RandomCrop(crop),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(10),
    ]
    deterministic_steps = [
        transforms.Resize(resize),
        transforms.CenterCrop(crop),
    ]

    transform_train = transforms.Compose(augmenting_steps + _tensor_and_normalize())
    transform_val = transforms.Compose(deterministic_steps + _tensor_and_normalize())
    return transform_train, transform_val

# Create transforms
# Uses the defaults (resize=256, crop=224): one augmenting pipeline for
# training and one deterministic pipeline for validation/test.
transform_train, transform_val = get_transforms()

### 2.6 Dataset and DataLoader
class FlickrDataset(Dataset):
    """Serve (image, caption-index tensor) pairs, one per caption row.

    Every row of ``data_df`` is one sample, so an image with several captions
    appears several times.
    """

    def __init__(self, data_df, root_dir, vocab, transform=None, caption_col='processed_caption'):
        """
        Store the data sources and processing options.

        Args:
            data_df: DataFrame with an 'image' column (file names) and a caption column
            root_dir: Directory that contains the image files
            vocab: Vocabulary object used to turn captions into indices
            transform: Optional callable applied to each loaded image
            caption_col: Name of the caption column in data_df
        """
        self.data_df, self.root_dir = data_df, root_dir
        self.vocab, self.transform = vocab, transform
        self.caption_col = caption_col

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx):
        """Return the image and the <SOS> ... <EOS> index tensor of row ``idx``."""
        row = self.data_df.iloc[idx]
        caption = row[self.caption_col]
        img_path = os.path.join(self.root_dir, row['image'])

        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Fall back to a blank placeholder so one bad file does not stop iteration
            image = Image.new('RGB', (224, 224))

        if self.transform is not None:
            image = self.transform(image)

        # Wrap the numericalized caption in start/end markers
        token_ids = [
            self.vocab.stoi["<SOS>"],
            *self.vocab.numericalize(caption),
            self.vocab.stoi["<EOS>"],
        ]

        return image, torch.tensor(token_ids)

class FlickrCollate:
    """Collate callable that batches images and pads variable-length captions."""
    def __init__(self, pad_idx):
        # Index used to fill captions up to the longest one in the batch
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """
        Args:
            batch: List of tuples (image, caption)

        Returns:
            images: Tensor of stacked images, ordered by descending caption length
            captions: Padded tensor of shape (batch_size, max_length)
            caption_lengths: List of unpadded caption lengths, in the same order
        """
        # Order by caption length (descending) for packing. sorted() builds a
        # new list, so the caller's batch is left untouched; like list.sort it
        # is stable, so ties keep their original order.
        ordered = sorted(batch, key=lambda pair: len(pair[1]), reverse=True)

        # Separate images and captions
        images, captions = zip(*ordered)

        # Stack images
        images = torch.stack(images, dim=0)

        # Record lengths before padding hides them
        caption_lengths = [len(cap) for cap in captions]

        # Pad captions to a common length
        captions_padded = pad_sequence(captions, batch_first=True, padding_value=self.pad_idx)

        return images, captions_padded, caption_lengths

# Create dataset objects
# All three datasets share the image directory and the vocabulary; only the
# training set uses the augmenting transform.
train_dataset = FlickrDataset(
    data_df=train_df,
    root_dir=IMAGES_DIR,
    vocab=vocab,
    transform=transform_train
)

# Validation and test use the deterministic (non-augmenting) transform
val_dataset = FlickrDataset(
    data_df=val_df,
    root_dir=IMAGES_DIR,
    vocab=vocab,
    transform=transform_val
)

test_dataset = FlickrDataset(
    data_df=test_df,
    root_dir=IMAGES_DIR,
    vocab=vocab,
    transform=transform_val
)

# One sample per caption row, so these counts are caption counts
print("\nDataset sizes:")
print(f"Train: {len(train_dataset)} samples")
print(f"Validation: {len(val_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")

# Create DataLoaders
def create_data_loaders(train_dataset, val_dataset, test_dataset, batch_size, vocab):
    """Build the train/validation/test DataLoaders.

    All loaders share batch size, worker count, padding collate function and
    pinned-memory setting; only the training loader shuffles.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    pad_idx = vocab.stoi["<PAD>"]

    # No worker processes on Windows, four everywhere else
    num_workers = 0 if sys.platform == 'win32' else 4

    def _make_loader(dataset, shuffle):
        # Every loader gets its own collate instance
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=FlickrCollate(pad_idx=pad_idx),
            pin_memory=torch.cuda.is_available()
        )

    train_loader = _make_loader(train_dataset, True)
    val_loader = _make_loader(val_dataset, False)
    test_loader = _make_loader(test_dataset, False)

    return train_loader, val_loader, test_loader

# Build the full-size loaders using the configured batch size
batch_size = model_config['batch_size']
train_loader, val_loader, test_loader = create_data_loaders(
    train_dataset, val_dataset, test_dataset, batch_size, vocab
)

# Report the number of batches per split
print("\nDataLoader batches:")
for loader_label, loader_obj in (
    ("Train", train_loader),
    ("Validation", val_loader),
    ("Test", test_loader),
):
    print(f"{loader_label}: {len(loader_obj)} batches")

# Create smaller debug loaders
def create_debug_loader(dataset, batch_size=8, num_samples=100, pad_idx=None):
    """Create a small, single-process loader over the first samples of a dataset.

    Args:
        dataset: Dataset to take the samples from.
        batch_size: Number of samples per batch.
        num_samples: Upper bound on the number of samples (taken from the start).
        pad_idx: Padding index for the collate function. When omitted it is
            read from the module-level ``vocab``, as before; pass it
            explicitly to use the function without that global.

    Returns:
        A shuffling DataLoader over the subset.
    """
    if pad_idx is None:
        pad_idx = vocab.stoi["<PAD>"]

    # First `num_samples` items, or the whole dataset if it is smaller
    subset = Subset(dataset, list(range(min(num_samples, len(dataset)))))

    return DataLoader(
        dataset=subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # Use single process
        collate_fn=FlickrCollate(pad_idx=pad_idx),
        pin_memory=False
    )

# Create debug loaders if in debug mode
# Either way, the *_loader_to_use names end up bound to the selected loaders,
# so later code does not need to check DEBUG_MODE itself.
if DEBUG_MODE:
    # Small subsets: first 100 training samples, first 50 validation/test samples
    debug_train_loader = create_debug_loader(train_dataset, batch_size=8, num_samples=100)
    debug_val_loader = create_debug_loader(val_dataset, batch_size=8, num_samples=50)
    debug_test_loader = create_debug_loader(test_dataset, batch_size=8, num_samples=50)
    
    print("\nDebug loader sizes:")
    print(f"Debug train: {len(debug_train_loader)} batches")
    print(f"Debug val: {len(debug_val_loader)} batches")
    print(f"Debug test: {len(debug_test_loader)} batches")
    
    # Use debug loaders for training/validation
    train_loader_to_use = debug_train_loader
    val_loader_to_use = debug_val_loader
    test_loader_to_use = debug_test_loader
else:
    # Use full loaders for training/validation
    train_loader_to_use = train_loader
    val_loader_to_use = val_loader
    test_loader_to_use = test_loader

# Sample batch inspection
def inspect_batch(data_loader, vocab):
    """Pull one batch from a loader, print a summary and show its first image.

    Args:
        data_loader: Loader yielding (images, captions, lengths) batches.
        vocab: Vocabulary used to turn caption indices back into words.

    Returns:
        The fetched batch as ``(images, captions, lengths)``.
    """
    images, captions, lengths = next(iter(data_loader))

    print(f"Batch contents:")
    print(f"Images shape: {images.shape}")
    print(f"Captions shape: {captions.shape}")
    print(f"Caption lengths: {lengths[:5]}...")

    # Decode the first caption of the batch, skipping indices outside the vocabulary
    idx = 0
    vocab_size = len(vocab)
    caption_words = []
    for token_idx in captions[idx]:
        token_id = token_idx.item()
        if token_id < vocab_size:
            caption_words.append(vocab.itos[token_id])

    print(f"\nSample caption: {' '.join(caption_words)}")

    # Channels last, undo the ImageNet normalization, clamp to the displayable range
    img = images[idx].permute(1, 2, 0).numpy()
    img = np.clip(img * np.array(IMAGENET_STD) + np.array(IMAGENET_MEAN), 0, 1)

    # Show the image
    plt.figure(figsize=(8, 6))
    plt.imshow(img)
    plt.title("Sample Image")
    plt.axis('off')
    plt.show()

    return images, captions, lengths

# Save vocabulary to disk
# NOTE(review): pickle stores a reference to the Vocabulary class, so loading
# this file presumably requires the same class definition to be available - confirm.
with open(os.path.join(OUTPUT_DIR, 'vocabulary.pkl'), 'wb') as f:
    pickle.dump(vocab, f)
print(f"\nVocabulary saved to {os.path.join(OUTPUT_DIR, 'vocabulary.pkl')}")

# Inspect a batch from train loader
# The returned batch is discarded; the call is made for its printed summary and plot.
print("\nInspecting a batch from the training loader:")
_ = inspect_batch(train_loader_to_use, vocab)



In [None]:
## Model Architecture

### 3.1 Encoder Component
class EncoderCNN(nn.Module):
    """CNN encoder that turns images into features for the caption decoder.

    Built on a pre-trained ResNet-50. Two modes:
      * ``attention=False`` (baseline): the pooled ResNet output is projected
        by a linear layer to ``embed_size`` and layer-normalized.
      * ``attention=True``: the final pooling and FC layers are removed so the
        spatial feature map is kept, and a 1x1 convolution reduces its
        channels to ``embed_size``.
    """

    def __init__(self, embed_size, dropout=0.5, train_cnn=False, attention=False):
        """
        Args:
            embed_size: Output feature dimension (channels in attention mode).
            dropout: Dropout probability applied to the output features.
            train_cnn: Whether the ResNet backbone is trainable.
            attention: Keep spatial features (True) or pooled features (False).
        """
        super(EncoderCNN, self).__init__()

        # Load pre-trained ResNet-50
        resnet = models.resnet50(pretrained=True)

        # Different handling for attention vs baseline
        if attention:
            # For attention: Keep spatial information, remove final FC and pooling
            modules = list(resnet.children())[:-2]
            print(f"Initializing Encoder CNN with spatial features:")
            self.feature_size = 2048  # ResNet features without pooling

            # Conv layer to reduce channel dimension
            self.conv = nn.Conv2d(self.feature_size, embed_size, kernel_size=1)

        else:
            # For baseline: Use pooled features (only the final FC is removed)
            modules = list(resnet.children())[:-1]
            print(f"Initializing Encoder CNN:")
            # Save the feature size
            self.feature_size = resnet.fc.in_features

            # Project to embedding space
            self.fc = nn.Linear(self.feature_size, embed_size)

            # Use LayerNorm instead of BatchNorm (works with small batches)
            self.norm = nn.LayerNorm(embed_size)

        # Create resnet feature extractor
        self.resnet = nn.Sequential(*modules)
        self.attention = attention
        # Remembered so train() can keep a frozen backbone in eval mode
        self.train_cnn = train_cnn

        # Freeze or unfreeze CNN
        for param in self.resnet.parameters():
            param.requires_grad = train_cnn

        # Additional layers
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        # Print architecture info
        print(f"  Embed size: {embed_size}")
        print(f"  Feature size: {self.feature_size}")
        print(f"  Using attention: {attention}")
        print(f"  Training CNN backbone: {train_cnn}")

        # Count parameters
        if attention:
            print(f"  Conv parameters: {sum(p.numel() for p in self.conv.parameters()):,}")
        else:
            print(f"  FC parameters: {sum(p.numel() for p in self.fc.parameters()):,}")
        print(f"  Total parameters: {sum(p.numel() for p in self.parameters()):,}")
        print(f"  Trainable parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")

    def train(self, mode=True):
        """Set training mode, keeping a frozen backbone in eval mode.

        ``requires_grad = False`` stops weight updates only; BatchNorm layers
        in train mode still update their running statistics on every forward
        pass. Holding the frozen ResNet in eval mode keeps it truly fixed.
        """
        super().train(mode)
        # getattr keeps instances created before this attribute existed working
        if not getattr(self, "train_cnn", True):
            self.resnet.eval()
        return self

    def forward(self, images, debug=False):
        """Extract features from images.

        Args:
            images: Batch of image tensors fed to the ResNet backbone.
            debug: When True, print a tensor summary after every stage.

        Returns:
            Baseline mode: features flattened to (batch_size, embed_size).
            Attention mode: feature map with ``embed_size`` channels and the
            backbone's spatial dimensions.
        """
        if debug:
            debug_print("Encoder input", images)

        # Get features from ResNet
        features = self.resnet(images)

        if debug:
            debug_print("ResNet output", features)

        # Different processing for attention vs baseline
        if self.attention:
            # For attention model: Use 1x1 conv to reduce channels
            features = self.conv(features)

            if debug:
                debug_print("After 1x1 conv", features)

            # Apply ReLU and dropout
            features = self.dropout(self.relu(features))

        else:
            # For baseline model: Flatten and project
            # Reshape: (batch_size, 2048, 1, 1) -> (batch_size, 2048)
            features = features.view(features.size(0), -1)

            if debug:
                debug_print("Reshaped features", features)

            # Project to embedding space
            features = self.fc(features)

            if debug:
                debug_print("After FC projection", features)

            # Apply normalization, then ReLU and dropout
            features = self.norm(features)

            if debug:
                debug_print("After normalization", features)

            features = self.dropout(self.relu(features))

        if debug:
            debug_print("Final encoder output", features)

        return features

### 3.2 Attention Mechanism
class Attention(nn.Module):
    """Additive attention over image locations, conditioned on the decoder state."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()

        # Projections into the shared attention space, plus the scoring layer
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)

        # Non-linearities
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

        # Architecture summary
        print(f"Initializing Attention mechanism:")
        for label, size in (
            ("Encoder", encoder_dim),
            ("Decoder", decoder_dim),
            ("Attention", attention_dim),
        ):
            print(f"  {label} dimension: {size}")

        # Parameter count
        total_params = sum(p.numel() for p in self.parameters())
        print(f"  Attention parameters: {total_params:,}")

    def forward(self, encoder_out, decoder_hidden):
        """
        Compute attention weights and the attended image encoding

        Args:
            encoder_out: Encoder features, shape (batch_size, num_pixels, encoder_dim)
            decoder_hidden: Decoder hidden state, shape (batch_size, decoder_dim)

        Returns:
            attention_weighted_encoding: Weighted sum of encoder features (batch_size, encoder_dim)
            alpha: Attention weights over the pixels (batch_size, num_pixels)
        """
        # Project both inputs into the attention space
        pixel_proj = self.encoder_att(encoder_out)      # (batch_size, num_pixels, attention_dim)
        state_proj = self.decoder_att(decoder_hidden)   # (batch_size, attention_dim)

        # Broadcast the decoder projection over all pixels, then score each pixel
        combined = self.relu(pixel_proj + state_proj.unsqueeze(1))
        scores = self.full_att(combined).squeeze(2)     # (batch_size, num_pixels)

        # Normalize the scores over the pixel dimension
        alpha = self.softmax(scores)

        # Attention-weighted sum of the encoder features
        weighted = encoder_out * alpha.unsqueeze(2)
        attention_weighted_encoding = weighted.sum(dim=1)  # (batch_size, encoder_dim)

        return attention_weighted_encoding, alpha

### 3.3 Decoder Component
class DecoderRNN(nn.Module):
    """LSTM decoder that generates captions from a single image feature vector.

    The image feature is fed to the LSTM as the first time step, followed by
    word embeddings (ground-truth tokens during training, the model's own
    previous predictions during sampling).
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, dropout=0.5):
        """
        Build the embedding, LSTM and output projection layers.

        Args:
            embed_size: Word embedding size; image features must have the same
                size because both are fed to the LSTM as inputs
            hidden_size: LSTM hidden state size
            vocab_size: Number of tokens in the vocabulary
            num_layers: Number of stacked LSTM layers
            dropout: Dropout probability applied to embeddings, and between
                LSTM layers when num_layers > 1
        """
        super(DecoderRNN, self).__init__()

        # Store parameters
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers

        # Print architecture information
        print(f"\nInitializing Decoder RNN:")
        print(f"  Embed size: {embed_size}")
        print(f"  Hidden size: {hidden_size}")
        print(f"  Vocabulary size: {vocab_size}")
        print(f"  LSTM layers: {num_layers}")

        # Word embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)

        # LSTM layer (inter-layer dropout is only meaningful with several layers)
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # Output layer
        self.output_layer = nn.Linear(hidden_size, vocab_size)

        # Dropout layer
        self.dropout = nn.Dropout(dropout)

        # Print parameter counts
        embed_params = sum(p.numel() for p in self.embedding.parameters())
        lstm_params = sum(p.numel() for p in self.lstm.parameters())
        output_params = sum(p.numel() for p in self.output_layer.parameters())

        print(f"  Embedding parameters: {embed_params:,}")
        print(f"  LSTM parameters: {lstm_params:,}")
        print(f"  Output layer parameters: {output_params:,}")
        print(f"  Total parameters: {sum(p.numel() for p in self.parameters()):,}")

    def forward_with_teacher_forcing(self, features, captions, caption_lengths, debug=False):
        """
        Forward pass with teacher forcing during training.

        Args:
            features: Image features from encoder (batch_size, embed_size)
            captions: Ground truth captions (batch_size, max_length)
            caption_lengths: True lengths of each caption; only reported in the
                debug output, every padded position is decoded as well
            debug: Whether to print debug info

        Returns:
            outputs: Predicted word scores (batch_size, max_length, vocab_size)
        """
        if debug:
            debug_print("Decoder input", features, level=1)
            debug_print(f"Caption shape: {captions.shape}, Lengths: {caption_lengths}", None, level=1)

        # Prepare embeddings for captions
        embeddings = self.dropout(self.embedding(captions))

        if debug:
            debug_print("Caption embeddings", embeddings, level=1)

        # Reshape features: (batch_size, embed_size) -> (batch_size, 1, embed_size)
        features = features.unsqueeze(1)

        # The image feature occupies time step 0; the ground-truth tokens follow.
        # The last token is dropped so the sequence length stays max_length.
        decoder_input = torch.cat([features, embeddings[:, :-1, :]], dim=1)

        if debug:
            debug_print("Decoder input sequence", decoder_input, level=1)

        # Run through LSTM (no initial state given, so it starts from zeros)
        outputs, _ = self.lstm(decoder_input)

        if debug:
            debug_print("LSTM outputs", outputs, level=1)

        # Generate word scores
        outputs = self.output_layer(outputs)

        if debug:
            debug_print("Final outputs", outputs, level=1)

        return outputs

    def sample(self, features, max_length=20, debug=False):
        """
        Generate captions using greedy search.

        Decoding always runs for max_length steps; it does not stop early at an
        end-of-sequence token.

        Args:
            features: Image features from encoder (batch_size, embed_size)
            max_length: Maximum caption length
            debug: Whether to print debug info

        Returns:
            sampled_ids: Predicted caption indices (batch_size, max_length)
        """
        batch_size = features.size(0)

        if debug:
            debug_print("Sampling - Feature input", features, level=1)

        # Initialize result tensor
        sampled_ids = torch.zeros(batch_size, max_length, dtype=torch.long, device=features.device)

        # Zero initial states, created with the dtype and device of the features
        # so the LSTM also accepts them when the module is not float32
        h = features.new_zeros(self.num_layers, batch_size, self.hidden_size)
        c = features.new_zeros(self.num_layers, batch_size, self.hidden_size)

        # First input is the image features
        input_word = features

        # Generate words one by one
        for i in range(max_length):
            if debug and i == 0:
                debug_print(f"Sampling step {i} - Input", input_word, level=1)

            # Run LSTM step
            output, (h, c) = self.lstm(input_word.unsqueeze(1), (h, c))

            if debug and i == 0:
                debug_print(f"Sampling step {i} - LSTM output", output, level=1)

            # Get word predictions
            output = self.output_layer(output.squeeze(1))

            if debug and i == 0:
                debug_print(f"Sampling step {i} - Word scores", output, level=1)

            # Greedy search - pick highest scoring word
            predicted = output.argmax(dim=1)

            if debug and i == 0:
                debug_print(f"Sampling step {i} - Predicted word indices: {predicted}", None, level=1)

            # Save prediction
            sampled_ids[:, i] = predicted

            # Next input is the predicted word embedding
            input_word = self.embedding(predicted)

        if debug:
            # Message-only call: pass None for the tensor like the other calls do
            debug_print(f"Sampling complete - Output shape: {sampled_ids.shape}", None, level=1)

        return sampled_ids

### 3.4 Attention Decoder Component
class AttentionDecoderRNN(nn.Module):
    """LSTM decoder that attends over encoder feature maps at every time step.

    At each step the attention module produces a context vector from the
    flattened feature maps and the last layer's hidden state; the context is
    concatenated with the word embedding as LSTM input and again with the LSTM
    output before the vocabulary projection.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, encoder_dim, attention_dim, num_layers=1, dropout=0.5):
        """
        Build the embedding, attention, LSTM, state-initialisation and output layers.

        Args:
            embed_size: Word embedding size
            hidden_size: LSTM hidden state size
            vocab_size: Number of tokens in the vocabulary
            encoder_dim: Channel dimension of the encoder feature maps
            attention_dim: Size of the attention projection space
            num_layers: Number of stacked LSTM layers
            dropout: Dropout probability applied to embeddings and to the
                pre-projection output, and between LSTM layers when num_layers > 1
        """
        super(AttentionDecoderRNN, self).__init__()

        # Store parameters
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.num_layers = num_layers

        # Print architecture information
        print(f"\nInitializing Attention Decoder RNN:")
        print(f"  Embed size: {embed_size}")
        print(f"  Hidden size: {hidden_size}")
        print(f"  Vocabulary size: {vocab_size}")
        print(f"  Encoder dimension: {encoder_dim}")
        print(f"  Attention dimension: {attention_dim}")
        print(f"  LSTM layers: {num_layers}")

        # Word embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)

        # Attention mechanism
        self.attention = Attention(encoder_dim, hidden_size, attention_dim)

        # Decoder LSTM
        self.lstm = nn.LSTM(
            input_size=embed_size + encoder_dim,  # Input is concat of embedding and context
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # Layers to compute initial hidden/cell states from mean of encoder output
        self.init_h = nn.Linear(encoder_dim, hidden_size)
        self.init_c = nn.Linear(encoder_dim, hidden_size)

        # Layer to produce word scores
        self.fc = nn.Linear(hidden_size + encoder_dim, vocab_size)

        # Dropout layer
        self.dropout = nn.Dropout(dropout)

        # Print parameter counts
        embed_params = sum(p.numel() for p in self.embedding.parameters())
        att_params = sum(p.numel() for p in self.attention.parameters())
        lstm_params = sum(p.numel() for p in self.lstm.parameters())
        fc_params = sum(p.numel() for p in self.fc.parameters())

        print(f"  Embedding parameters: {embed_params:,}")
        print(f"  Attention parameters: {att_params:,}")
        print(f"  LSTM parameters: {lstm_params:,}")
        print(f"  Output layer parameters: {fc_params:,}")
        print(f"  Total parameters: {sum(p.numel() for p in self.parameters()):,}")

    def init_hidden_state(self, encoder_out):
        """
        Initialize LSTM hidden and cell states using the encoder output

        Args:
            encoder_out: Feature maps from encoder, shape (batch_size, num_pixels, encoder_dim)

        Returns:
            h, c: Initial hidden and cell states, each (num_layers, batch_size, hidden_size);
                every layer starts from the same projected vector
        """
        # Mean of encoder output across all pixels
        mean_encoder_out = encoder_out.mean(dim=1)  # (batch_size, encoder_dim)

        # Project to get initial states
        h = self.init_h(mean_encoder_out)  # (batch_size, hidden_size)
        c = self.init_c(mean_encoder_out)  # (batch_size, hidden_size)

        # Reshape for LSTM which expects (num_layers, batch_size, hidden_size)
        h = h.unsqueeze(0).repeat(self.num_layers, 1, 1)
        c = c.unsqueeze(0).repeat(self.num_layers, 1, 1)

        return h, c

    def _flatten_encoder_out(self, encoder_out):
        """Reshape (batch_size, encoder_dim, height, width) maps to (batch_size, num_pixels, encoder_dim)."""
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(1)
        channels_last = encoder_out.permute(0, 2, 3, 1)  # (batch_size, height, width, encoder_dim)
        # reshape (rather than view) copies when the permuted layout cannot be
        # expressed as a view, so non-contiguous feature maps are accepted too
        return channels_last.reshape(batch_size, -1, encoder_dim)

    def _decode_step(self, encoder_out, word_embedding, h, c):
        """Run one attention + LSTM step and return (word scores, alpha, context, h, c)."""
        # Attend using the last layer's hidden state
        context, alpha = self.attention(encoder_out, h[-1])

        # LSTM input is the word embedding concatenated with the context vector
        lstm_input = torch.cat([word_embedding, context], dim=1).unsqueeze(1)
        output, (h, c) = self.lstm(lstm_input, (h, c))

        # The context is concatenated again before the vocabulary projection
        output = torch.cat([output.squeeze(1), context], dim=1)  # (batch_size, hidden_size + encoder_dim)
        preds = self.fc(self.dropout(output))  # (batch_size, vocab_size)

        return preds, alpha, context, h, c

    def forward_with_teacher_forcing(self, encoder_out, captions, caption_lengths, debug=False):
        """
        Forward pass with teacher forcing during training

        Step t consumes captions[:, t] as its input word and writes its scores
        to outputs[:, t].

        Args:
            encoder_out: Feature maps from encoder, shape (batch_size, encoder_dim, height, width)
            captions: Ground truth captions (batch_size, max_length)
            caption_lengths: True lengths of each caption; only reported in the
                debug output, every padded position is decoded as well
            debug: Whether to print debug info

        Returns:
            outputs: Predicted word scores (batch_size, max_length, vocab_size)
            alphas: Attention weights for visualization (batch_size, max_length, num_pixels)
        """
        batch_size = encoder_out.size(0)
        max_length = captions.size(1)

        if debug:
            debug_print("Decoder input - encoder out", encoder_out, level=1)
            debug_print(f"Caption shape: {captions.shape}, Lengths: {caption_lengths}", None, level=1)

        # Flatten spatial dimensions of encoder output for attention
        encoder_out = self._flatten_encoder_out(encoder_out)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        if debug:
            debug_print("Flattened encoder out", encoder_out, level=1)

        # Initialize LSTM hidden and cell states
        h, c = self.init_hidden_state(encoder_out)

        # Prepare embeddings for captions
        embeddings = self.dropout(self.embedding(captions))  # (batch_size, max_length, embed_size)

        if debug:
            debug_print("Caption embeddings", embeddings, level=1)

        # Result buffers, allocated directly on the captions' device
        predictions = torch.zeros(batch_size, max_length, self.vocab_size, device=captions.device)
        alphas = torch.zeros(batch_size, max_length, num_pixels, device=captions.device)

        # For each time step
        for t in range(max_length):
            preds, alpha, context, h, c = self._decode_step(encoder_out, embeddings[:, t], h, c)

            if debug and t == 0:
                debug_print(f"Time step {t} - Context", context, level=1)
                debug_print(f"Time step {t} - Attention weights shape: {alpha.shape}", None, level=1)

            # Store predictions and attention weights
            predictions[:, t] = preds
            alphas[:, t] = alpha

        if debug:
            debug_print("Final predictions", predictions, level=1)
            debug_print("Attention weights", alphas, level=1)

        return predictions, alphas

    def sample(self, encoder_out, max_length=20, debug=False, sos_idx=1):
        """
        Generate captions using attention and greedy search

        Decoding always runs for max_length steps; it does not stop early at an
        end-of-sequence token.

        Args:
            encoder_out: Feature maps from encoder (batch_size, encoder_dim, height, width)
            max_length: Maximum caption length
            debug: Whether to print debug info
            sos_idx: Vocabulary index of the start token fed at the first step

        Returns:
            sampled_ids: Predicted caption indices (batch_size, max_length)
            alphas: Attention weights for visualization (batch_size, max_length, num_pixels)
        """
        batch_size = encoder_out.size(0)

        if debug:
            debug_print("Sampling - Feature input", encoder_out, level=1)

        # Flatten spatial dimensions for attention
        encoder_out = self._flatten_encoder_out(encoder_out)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)
        device = encoder_out.device

        # Initialize LSTM hidden and cell states
        h, c = self.init_hidden_state(encoder_out)

        # Result buffers, allocated directly on the encoder output's device
        sampled_ids = torch.zeros(batch_size, max_length, dtype=torch.long, device=device)
        alphas = torch.zeros(batch_size, max_length, num_pixels, device=device)

        # Every sequence starts from the start token
        input_word = torch.full((batch_size,), sos_idx, dtype=torch.long, device=device)

        # Generate words one by one
        for t in range(max_length):
            embedded = self.embedding(input_word)  # (batch_size, embed_size)
            preds, alpha, _, h, c = self._decode_step(encoder_out, embedded, h, c)

            # Greedy search - pick highest scoring word and feed it back in
            input_word = preds.argmax(dim=1)  # (batch_size)

            sampled_ids[:, t] = input_word
            alphas[:, t] = alpha

        if debug:
            # Message-only calls: pass None for the tensor like the other calls do
            debug_print(f"Sampling complete - Output shape: {sampled_ids.shape}", None, level=1)
            debug_print(f"Attention weights shape: {alphas.shape}", None, level=1)

        return sampled_ids, alphas

### 3.5 Baseline Caption Model
class BaselineCaptionModel(nn.Module):
    """End-to-end captioner: CNN encoder followed by an LSTM decoder (no attention)."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, dropout=0.5):
        """
        Assemble the encoder/decoder pair and report the model size.

        Args:
            embed_size: Shared size of image features and word embeddings
            hidden_size: Decoder LSTM hidden size
            vocab_size: Number of tokens in the vocabulary
            num_layers: Number of decoder LSTM layers
            dropout: Dropout probability passed to encoder and decoder
        """
        super(BaselineCaptionModel, self).__init__()

        # Report the requested architecture
        print("\nInitializing Baseline Caption Model")
        print(f"  Embed size: {embed_size}")
        print(f"  Hidden size: {hidden_size}")
        print(f"  Vocabulary size: {vocab_size}")
        print(f"  LSTM layers: {num_layers}")

        # Sub-modules
        self.encoder = EncoderCNN(embed_size, dropout, attention=False)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers, dropout)

        # Tally all parameters and the trainable subset in a single pass
        total_params = 0
        trainable_params = 0
        for param in self.parameters():
            count = param.numel()
            total_params += count
            if param.requires_grad:
                trainable_params += count

        print(f"\nModel Summary:")
        print(f"  Total parameters: {total_params:,}")
        print(f"  Trainable parameters: {trainable_params:,}")

    def forward(self, images, captions, caption_lengths, debug=False):
        """
        Training forward pass: encode the images, then decode with teacher forcing.

        Args:
            images: Input images (batch_size, 3, height, width)
            captions: Caption indices (batch_size, max_length)
            caption_lengths: True lengths of captions
            debug: Whether to print debug info

        Returns:
            outputs: Predicted word scores produced by the decoder
        """
        start_time = None
        if debug:
            start_time = time.time()
            debug_print("Starting forward pass", None, level=0)
            debug_print(
                f"Input shapes - Images: {images.shape}, Captions: {captions.shape}",
                None,
                level=0,
            )

        # Encoder output feeds straight into the teacher-forced decoder
        image_features = self.encoder(images, debug)
        word_scores = self.decoder.forward_with_teacher_forcing(
            image_features, captions, caption_lengths, debug
        )

        if debug:
            debug_print(f"Forward pass completed in {time.time() - start_time:.4f}s", None, level=0)

        return word_scores

    def sample(self, images, max_length=20, debug=False):
        """
        Inference pass: encode the images and let the decoder sample captions.

        Args:
            images: Input images (batch_size, 3, height, width)
            max_length: Maximum caption length
            debug: Whether to print debug info

        Returns:
            sampled_ids: Generated caption indices
        """
        start_time = None
        if debug:
            start_time = time.time()
            debug_print("Starting sampling", None, level=0)

        image_features = self.encoder(images, debug)
        token_ids = self.decoder.sample(image_features, max_length, debug)

        if debug:
            debug_print(f"Sampling completed in {time.time() - start_time:.4f}s", None, level=0)

        return token_ids

    def caption_image(self, image, vocab, max_length=20, debug=False):
        """
        Produce a caption string for one image.

        Switches the model to evaluation mode (and leaves it there), samples
        token indices without tracking gradients, and maps the first sequence
        in the batch to words through vocab.itos.

        Args:
            image: Input image (1, 3, height, width)
            vocab: Vocabulary object exposing an ``itos`` index-to-word lookup
            max_length: Maximum caption length
            debug: Whether to print debug info

        Returns:
            caption: Generated caption as string
        """
        # Inference only
        self.eval()

        if debug:
            debug_print("Generating caption for image", image, level=0)

        with torch.no_grad():
            # Only the first sequence in the batch is decoded to text
            token_ids = self.sample(image, max_length, debug)[0].cpu().numpy()

            words = []
            for token_id in token_ids:
                token = vocab.itos[token_id]

                # The end marker terminates the caption
                if token == "<EOS>":
                    break

                # Padding and start markers never appear in the text
                if token in ("<PAD>", "<SOS>"):
                    continue

                words.append(token)

            caption = " ".join(words)

        if debug:
            debug_print(f"Generated caption: '{caption}'", None, level=0)

        return caption

### 3.6 Attention Caption Model
class AttentionCaptionModel(nn.Module):
    """End-to-end captioner: CNN encoder followed by an attention-based LSTM decoder."""

    def __init__(self, embed_size, hidden_size, vocab_size, attention_dim, num_layers=1, dropout=0.5):
        """
        Assemble the attention encoder/decoder pair and report the model size.

        Args:
            embed_size: Word embedding size; also used as the decoder's encoder_dim
            hidden_size: Decoder LSTM hidden size
            vocab_size: Number of tokens in the vocabulary
            attention_dim: Size of the attention projection space
            num_layers: Number of decoder LSTM layers
            dropout: Dropout probability passed to encoder and decoder
        """
        super(AttentionCaptionModel, self).__init__()

        # Report the requested architecture
        print("\nInitializing Attention Caption Model")
        print(f"  Embed size: {embed_size}")
        print(f"  Hidden size: {hidden_size}")
        print(f"  Vocabulary size: {vocab_size}")
        print(f"  Attention dimension: {attention_dim}")
        print(f"  LSTM layers: {num_layers}")

        # Sub-modules; the decoder is told the encoder's channel size is embed_size
        self.encoder = EncoderCNN(embed_size, dropout, attention=True)
        self.decoder = AttentionDecoderRNN(
            embed_size,
            hidden_size,
            vocab_size,
            embed_size,
            attention_dim,
            num_layers,
            dropout,
        )

        # Tally all parameters and the trainable subset in a single pass
        total_params = 0
        trainable_params = 0
        for param in self.parameters():
            count = param.numel()
            total_params += count
            if param.requires_grad:
                trainable_params += count

        print(f"\nModel Summary:")
        print(f"  Total parameters: {total_params:,}")
        print(f"  Trainable parameters: {trainable_params:,}")

    def forward(self, images, captions, caption_lengths, debug=False):
        """
        Training forward pass: encode the images, then decode with teacher forcing.

        Args:
            images: Input images (batch_size, 3, height, width)
            captions: Caption indices (batch_size, max_length)
            caption_lengths: True lengths of captions
            debug: Whether to print debug info

        Returns:
            outputs: Predicted word scores
            alphas: Attention weights
        """
        start_time = None
        if debug:
            start_time = time.time()
            debug_print("Starting forward pass", None, level=0)
            debug_print(
                f"Input shapes - Images: {images.shape}, Captions: {captions.shape}",
                None,
                level=0,
            )

        # Encoder output feeds straight into the teacher-forced attention decoder
        feature_maps = self.encoder(images, debug)
        word_scores, attention = self.decoder.forward_with_teacher_forcing(
            feature_maps, captions, caption_lengths, debug
        )

        if debug:
            debug_print(f"Forward pass completed in {time.time() - start_time:.4f}s", None, level=0)

        return word_scores, attention

    def sample(self, images, max_length=20, debug=False):
        """
        Inference pass: encode the images and let the attention decoder sample captions.

        Args:
            images: Input images (batch_size, 3, height, width)
            max_length: Maximum caption length
            debug: Whether to print debug info

        Returns:
            sampled_ids: Generated caption indices
            alphas: Attention weights for visualization
        """
        start_time = None
        if debug:
            start_time = time.time()
            debug_print("Starting sampling", None, level=0)

        feature_maps = self.encoder(images, debug)
        token_ids, attention = self.decoder.sample(feature_maps, max_length, debug)

        if debug:
            debug_print(f"Sampling completed in {time.time() - start_time:.4f}s", None, level=0)

        return token_ids, attention

    def caption_image_with_attention(self, image, vocab, max_length=20, debug=False):
        """
        Produce a caption string and per-token attention maps for one image.

        Switches the model to evaluation mode (and leaves it there) and samples
        without tracking gradients. Only the first sequence in the batch is
        decoded to text.

        Args:
            image: Input image (1, 3, height, width)
            vocab: Vocabulary object exposing an ``itos`` index-to-word lookup
            max_length: Maximum caption length
            debug: Whether to print debug info

        Returns:
            caption: Generated caption as string
            attention_weights: List of numpy arrays, one per emitted word plus
                one for the terminating <EOS> token when it was generated
        """
        # Inference only
        self.eval()

        if debug:
            debug_print("Generating caption with attention for image", image, level=0)

        with torch.no_grad():
            token_batch, alphas = self.sample(image, max_length, debug)
            token_ids = token_batch[0].cpu().numpy()

            caption_words = []
            attention_weights = []

            for step, token_id in enumerate(token_ids):
                token = vocab.itos[token_id]

                # Padding and start markers contribute neither text nor attention
                if token in ("<PAD>", "<SOS>"):
                    continue

                # Real words and the end marker both keep their attention map
                attention_weights.append(alphas[0, step].cpu().numpy())

                # The end marker terminates the caption without adding text
                if token == "<EOS>":
                    break

                caption_words.append(token)

            caption = " ".join(caption_words)

        if debug:
            debug_print(f"Generated caption: '{caption}'", None, level=0)

        return caption, attention_weights

### 3.7 Architecture Visualization Functions
def visualize_model_architecture(model, is_attention_model=False):
    """Visualize model architecture with dimensions"""
    import matplotlib.patches as patches

    # Create figure
    fig, ax = plt.subplots(figsize=(14, 10) if not is_attention_model else (15, 12))
    ax.axis('off')

    # Title
    if is_attention_model:
        ax.text(0.5, 0.97, "Attention-based Image Captioning Architecture", ha='center', fontsize=18, fontweight='bold')
    else:
        ax.text(0.5, 0.97, "Baseline CNN+RNN Architecture", ha='center', fontsize=18, fontweight='bold')

    # Helper function to draw a box
    def draw_box(x, y, width, height, label, details=None, color='lightblue'):
        box = patches.Rectangle((x, y), width, height, fill=True, color=color, alpha=0.8,
                            linewidth=2, edgecolor='black')
        ax.add_patch(box)
        ax.text(x + width/2, y + height/2, label, ha='center', va='center', fontsize=12, fontweight='bold')

        if details:
            ax.text(x + width + 0.02, y + height/2, details, ha='left', va='center', fontsize=10)

        return box

    # Dimensions for boxes
    box_width = 0.6
    box_height = 0.07
    gap = 0.02
    x = 0.2

    if is_attention_model:
        # =========== ATTENTION MODEL VISUALIZATION ===========
        # Add section for encoder
        y = 0.9
        ax.text(0.5, 0.94, "Encoder (Feature Extraction)", fontsize=14, fontweight='bold', ha='center')
        
        # Input image
        draw_box(x, y, box_width, box_height, "Input Image",
                f"Shape: (batch_size, 3, 224, 224)", 'lightblue')
        y -= box_height + gap
        
        # ResNet backbone
        draw_box(x, y, box_width, box_height, "ResNet-50 (up to layer4)",
                "Pretrained CNN without pooling layer", 'lightgreen')
        y -= box_height + gap
        
        # Feature maps
        feature_map_size = 224 // 32  # ResNet downsamples by factor of 32
        draw_box(x, y, box_width, box_height, "Feature Maps",
                f"Shape: (batch_size, 2048, {feature_map_size}, {feature_map_size})", 'lightyellow')
        y -= box_height + gap
        
        # 1x1 convolution
        draw_box(x, y, box_width, box_height, "1x1 Convolution",
                f"Shape: (batch_size, embed_size, {feature_map_size}, {feature_map_size})", 'lightgreen')
        y -= box_height + gap
        
        # Flattened features
        draw_box(x, y, box_width, box_height, "Flattened Features",
                f"Shape: (batch_size, {feature_map_size*feature_map_size}, embed_size)", 'lightyellow')
        y -= box_height + 2*gap
        
        # Add section for attention
        ax.text(0.5, y, "Attention Mechanism", fontsize=14, fontweight='bold', ha='center')
        y -= gap
        
        # Attention mechanism
        att_y = y
        draw_box(x, y, box_width, box_height, "Attention Weights",
                "Shape: (batch_size, num_pixels)", 'pink')
        y -= box_height + gap
        
        # Weighted encoding
        draw_box(x, y, box_width, box_height, "Weighted Feature Vector",
                "Shape: (batch_size, embed_size)", 'lightyellow')
        y -= box_height + 2*gap
        
        # Add section for decoder
        ax.text(0.5, y, "Decoder (Caption Generation)", fontsize=14, fontweight='bold', ha='center')
        y -= gap
        
        # Word embedding
        draw_box(x, y, box_width, box_height, "Word Embedding",
                "Shape: (batch_size, embed_size)", 'lightblue')
        y -= box_height + gap
        
        # Combined input
        draw_box(x, y, box_width, box_height, "Combined Input",
                "Shape: (batch_size, embed_size + encoder_dim)", 'lightyellow')
        y -= box_height + gap
        
        # LSTM
        draw_box(x, y, box_width, box_height, "LSTM",
                "Hidden size, Layers, Dropout", 'lightgreen')
        y -= box_height + gap
        
        # Output projection
        draw_box(x, y, box_width, box_height, "Output Projection",
                "Shape: (batch_size, vocab_size)", 'lightgreen')
        y -= box_height + gap
        
        # Word predictions
        draw_box(x, y, box_width, box_height, "Word Predictions",
                "Cross-entropy loss during training", 'lightyellow')
        
        # Add flow arrows
        for i in range(10):  # Adjust based on the number of boxes
            if i != 4 and i != 7:  # Skip arrows before each new section
                arrow_y = 0.9 - i * (box_height + gap) - (0.5*gap if i >= 5 else 0) - (0.5*gap if i >= 8 else 0)
                ax.annotate("", xy=(x + box_width/2, arrow_y - box_height - gap),
                         xytext=(x + box_width/2, arrow_y),
                         arrowprops=dict(arrowstyle="->", lw=2, color='black'))
        
        # Add arrows for attention flow
        # From encoder to attention
        att_x = x + box_width + 0.1
        ax.annotate("", xy=(att_x, att_y + box_height/2),
                 xytext=(x + box_width, 0.9 - 4*(box_height + gap) - box_height/2),
                 arrowprops=dict(arrowstyle="->", lw=2, color='blue', connectionstyle="arc3,rad=0.3"))
        
        # From decoder to attention
        ax.annotate("", xy=(att_x, att_y + box_height/2),
                 xytext=(x + box_width, y + 3*(box_height + gap) + box_height/2),
                 arrowprops=dict(arrowstyle="->", lw=2, color='red', connectionstyle="arc3,rad=-0.3"))
        
        # From attention to decoder
        ax.annotate("", xy=(x + box_width/2, y + 5*(box_height + gap)),
                 xytext=(x + box_width/2, y + 7*(box_height + gap)),
                 arrowprops=dict(arrowstyle="->", lw=2, color='purple'))
        
        # Add parameter information
        param_text = (
            f"Model Parameters:\n"
            f"  Encoder: {sum(p.numel() for p in model.encoder.parameters()):,}\n"
            f"  Decoder: {sum(p.numel() for p in model.decoder.parameters()):,}\n"
            f"    - Attention: {sum(p.numel() for p in model.decoder.attention.parameters()):,}\n"
            f"    - LSTM: {sum(p.numel() for p in model.decoder.lstm.parameters()):,}\n"
            f"  Total: {sum(p.numel() for p in model.parameters()):,}\n"
            f"  Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}"
        )
        
        ax.text(0.02, 0.02, param_text, fontsize=10, ha='left', va='bottom',
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        
        # Add legend for attention flow
        legend_elements = [
            patches.Patch(facecolor='blue', alpha=0.6, label='Image Features'),
            patches.Patch(facecolor='red', alpha=0.6, label='Hidden State'),
            patches.Patch(facecolor='purple', alpha=0.6, label='Attention Context')
        ]
        ax.legend(handles=legend_elements, loc='lower right', fontsize=10)
        
    else:
        # =========== BASELINE MODEL VISUALIZATION ===========
        y = 0.85
        
        # Encoder section
        ax.text(0.5, 0.9, "Encoder", fontsize=14, fontweight='bold', ha='center')

        # Input image
        draw_box(x, y, box_width, box_height, "Input Image",
                f"Shape: (batch_size, 3, 224, 224)", 'lightblue')
        y -= box_height + gap

        # ResNet
        draw_box(x, y, box_width, box_height, "ResNet-50 Backbone",
                "Pretrained CNN, extracts visual features", 'lightgreen')
        y -= box_height + gap

        # CNN Features
        draw_box(x, y, box_width, box_height, "CNN Features",
                f"Shape: (batch_size, {model.encoder.feature_size}, 1, 1)", 'lightyellow')
        y -= box_height + gap

        # Linear projection
        draw_box(x, y, box_width, box_height, "Linear Projection",
                f"Shape: (batch_size, {model_config['embed_size']})", 'lightgreen')
        y -= box_height + gap

        # Encoded Features
        draw_box(x, y, box_width, box_height, "Encoded Features",
                f"Shape: (batch_size, {model_config['embed_size']})", 'lightyellow')
        y -= box_height + 2*gap

        # Decoder section
        ax.text(0.5, y, "Decoder", fontsize=14, fontweight='bold', ha='center')
        y -= gap

        # Caption Input
        draw_box(x, y, box_width, box_height, "Caption Input",
                "Shape: (batch_size, seq_length)", 'lightblue')
        y -= box_height + gap

        # Embedding
        draw_box(x, y, box_width, box_height, "Word Embedding",
                f"Shape: (batch_size, seq_length, {model_config['embed_size']})", 'lightgreen')
        y -= box_height + gap

        # LSTM
        draw_box(x, y, box_width, box_height, "LSTM",
                f"Hidden size: {model_config['hidden_size']}, Layers: {model_config['num_layers']}", 'lightgreen')
        y -= box_height + gap

        # Linear output
        draw_box(x, y, box_width, box_height, "Linear Output Layer",
                f"Shape: (batch_size, seq_length, {len(vocab)})", 'lightgreen')
        y -= box_height + gap

        # Final output
        draw_box(x, y, box_width, box_height, "Word Predictions",
                "Cross-entropy loss during training", 'lightyellow')

        # Add arrows
        for i in range(9):
            y_pos = 0.85 - (i+0.5)*box_height - i*gap - (0.5*gap if i >= 5 else 0)
            if i != 4:  # Skip arrow before decoder section
                ax.annotate("", xy=(x + box_width/2, y_pos - box_height - gap),
                        xytext=(x + box_width/2, y_pos),
                        arrowprops=dict(arrowstyle="->", lw=2, color='black'))

        # Add parameter information
        param_text = (
            f"Model Parameters:\n"
            f"  Encoder: {sum(p.numel() for p in model.encoder.parameters()):,}\n"
            f"  Decoder: {sum(p.numel() for p in model.decoder.parameters()):,}\n"
            f"  Total: {sum(p.numel() for p in model.parameters()):,}\n"
            f"  Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}"
        )

        ax.text(0.02, 0.02, param_text, fontsize=10, ha='left', va='bottom',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        # Add note about normalization constants
        norm_text = (
            f"Image Normalization:\n"
            f"  Mean: {IMAGENET_MEAN}\n"
            f"  Std: {IMAGENET_STD}\n"
            f"  (ImageNet pretrained values)"
        )

        ax.text(0.75, 0.02, norm_text, fontsize=9, ha='left', va='bottom',
            bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    plt.tight_layout()
    save_path = os.path.join(LOGS_DIR, "attention_architecture.png" if is_attention_model else "baseline_architecture.png")
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()



In [None]:
## Training Framework

### 4.1 Loss and Metric Functions
def calculate_loss(predictions, targets, criterion, pad_idx):
    """Return the criterion loss and the number of non-padding targets.

    The padding mask is used only for counting. The loss itself is exactly
    ``criterion(predictions, targets)``, so whether padded positions
    contribute to it depends on how ``criterion`` was configured.

    Returns:
        tuple: ``(loss, n_tokens)`` where ``n_tokens`` is a Python int.
    """
    # Count target positions that hold a real token rather than ``pad_idx``.
    n_tokens = (targets != pad_idx).sum().item()

    return criterion(predictions, targets), n_tokens

### 4.2 Unified Training Loop
def train_epoch(model, data_loader, optimizer, criterion, clip, device, pad_idx, print_frequency=100):
    """Run one optimisation pass over ``data_loader``.

    Handles both model flavours: an ``AttentionCaptionModel`` returns
    ``(outputs, alphas)`` and gets an extra attention penalty added to its
    loss, any other model returns ``outputs`` only.

    Returns:
        The epoch loss averaged over non-padding tokens, or ``inf`` when no
        tokens were seen.
    """
    model.train()

    uses_attention = isinstance(model, AttentionCaptionModel)

    # Running totals for the token-weighted average.
    loss_sum = 0
    token_count = 0
    batch_time = AverageMeter()

    epoch_started = time.time()
    tick = time.time()

    with tqdm(total=len(data_loader), desc="Training") as pbar:
        try:
            for step, (images, captions, lengths) in enumerate(data_loader, start=1):
                images = images.to(device)
                captions = captions.to(device)

                # Only the attention model also returns its attention maps.
                if uses_attention:
                    logits, alphas = model(images, captions, lengths)
                else:
                    logits = model(images, captions, lengths)

                # Align predictions with targets: drop the first caption
                # token (<SOS>) from the targets and the last prediction.
                targets = captions[:, 1:].reshape(-1)
                logits = logits[:, :-1, :]
                logits = logits.reshape(-1, logits.size(2))

                total_loss, n_tokens = calculate_loss(logits, targets, criterion, pad_idx)

                if uses_attention:
                    # Penalise attention whose sum over dim 1 drifts from 1.
                    alpha_c = 1.0
                    total_loss = total_loss + alpha_c * ((1 - alphas.sum(dim=1)) ** 2).mean()

                batch_loss = total_loss.item()
                loss_sum += batch_loss * n_tokens
                token_count += n_tokens

                # The meter is updated before the backward pass, so each
                # sample spans from the previous update to this point.
                batch_time.update(time.time() - tick)
                tick = time.time()

                optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
                optimizer.step()

                pbar.update(1)
                pbar.set_postfix({"loss": f"{batch_loss:.4f}", "time/batch": f"{batch_time.avg:.3f}s"})

                if step % print_frequency == 0:
                    print(f"Batch {step}/{len(data_loader)} | "
                          f"Loss: {batch_loss:.4f} | "
                          f"Time: {batch_time.avg:.3f}s/batch | "
                          f"Elapsed: {time.time() - epoch_started:.1f}s")
        except Exception as e:
            # Report with a full traceback, then let the caller decide.
            print(f"Error during training: {str(e)}")
            traceback.print_exc()
            raise

    if token_count > 0:
        return loss_sum / token_count
    return float('inf')

### 4.3 Unified Validation Function
def validate(model, data_loader, criterion, device, pad_idx):
    """Evaluate ``model`` on ``data_loader`` without updating any weights.

    Mirrors the loss used during training: an ``AttentionCaptionModel``
    gets the same attention penalty added to its criterion loss.

    Returns:
        The loss averaged over non-padding tokens, or ``inf`` when no tokens
        were seen.
    """
    model.eval()

    uses_attention = isinstance(model, AttentionCaptionModel)

    loss_sum = 0
    token_count = 0

    with torch.no_grad(), tqdm(total=len(data_loader), desc="Validation") as pbar:
        for images, captions, lengths in data_loader:
            images = images.to(device)
            captions = captions.to(device)

            # Only the attention model also returns its attention maps.
            if uses_attention:
                logits, alphas = model(images, captions, lengths)
            else:
                logits = model(images, captions, lengths)

            # Drop the first caption token (<SOS>) from the targets and the
            # last prediction so both sequences line up.
            targets = captions[:, 1:].reshape(-1)
            logits = logits[:, :-1, :]
            logits = logits.reshape(-1, logits.size(2))

            total_loss, n_tokens = calculate_loss(logits, targets, criterion, pad_idx)

            if uses_attention:
                alpha_c = 1.0  # attention regularization factor
                total_loss = total_loss + alpha_c * ((1 - alphas.sum(dim=1)) ** 2).mean()

            batch_loss = total_loss.item()
            loss_sum += batch_loss * n_tokens
            token_count += n_tokens

            pbar.update(1)
            pbar.set_postfix({"loss": f"{batch_loss:.4f}"})

    if token_count > 0:
        return loss_sum / token_count
    return float('inf')

### 4.4 BLEU Score Calculation
def calculate_bleu(model, data_loader, vocab, device, max_samples=None):
    """Compute corpus-level BLEU-1..4 for captions sampled from ``model``.

    Works with both model types: an ``AttentionCaptionModel``'s ``sample``
    returns a pair whose first item is the predictions, any other model's
    ``sample`` returns the predictions directly.

    Args:
        model: Captioning model exposing ``sample(images)``.
        data_loader: Yields ``(images, captions, lengths)`` batches; its
            ``batch_size`` is read when ``max_samples`` is set.
        vocab: Vocabulary exposing ``itos`` (index -> token).
        device: Device the images are moved to.
        max_samples: Optional cap on the number of scored captions; a falsy
            value means "score everything".

    Returns:
        dict with keys ``bleu1``..``bleu4``, each scaled to 0-100. All four
        are ``0.0`` when no caption was scored.
    """

    def _decode(token_ids):
        """Turn token ids into words, stopping at <EOS>, skipping <PAD>/<SOS>."""
        words = []
        for token_idx in token_ids:
            token = vocab.itos[token_idx.item()]
            if token == "<EOS>":
                break
            if token not in ("<PAD>", "<SOS>"):
                words.append(token)
        return words

    model.eval()

    is_attention_model = isinstance(model, AttentionCaptionModel)

    references = []
    hypotheses = []

    if max_samples:
        batch_size = data_loader.batch_size
        # Ceiling division: the loop below stops at the first batch index i
        # with i * batch_size >= max_samples, so this is the exact number of
        # batches visited (the previous "+ 1" over-counted on exact multiples).
        needed_batches = (max_samples + batch_size - 1) // batch_size
        total = min(len(data_loader), needed_batches)
    else:
        batch_size = None
        total = len(data_loader)

    with torch.no_grad():
        with tqdm(total=total, desc="Calculating BLEU") as pbar:
            for i, (images, captions, _lengths) in enumerate(data_loader):
                # Stop once enough whole batches have been consumed.
                if max_samples and i * batch_size >= max_samples:
                    break

                images = images.to(device)

                if is_attention_model:
                    predictions, _ = model.sample(images)
                else:
                    predictions = model.sample(images)

                for j in range(images.size(0)):
                    # The last batch may overshoot the cap; trim it here.
                    if max_samples and len(hypotheses) >= max_samples:
                        break

                    # corpus_bleu expects a list of references per hypothesis;
                    # there is exactly one reference caption per sample here.
                    references.append([_decode(captions[j])])
                    hypotheses.append(_decode(predictions[j]))

                pbar.update(1)

    if not hypotheses:
        # Nothing was scored (e.g. empty loader): report zeros rather than
        # handing an empty corpus to corpus_bleu.
        return {"bleu1": 0.0, "bleu2": 0.0, "bleu3": 0.0, "bleu4": 0.0}

    # Uniform n-gram weights. BLEU-3 uses exact thirds so the weights sum to
    # 1 (the earlier 0.33 triple summed to 0.99).
    third = 1.0 / 3.0
    weight_sets = {
        "bleu1": (1, 0, 0, 0),
        "bleu2": (0.5, 0.5, 0, 0),
        "bleu3": (third, third, third, 0),
        "bleu4": (0.25, 0.25, 0.25, 0.25),
    }

    return {
        name: corpus_bleu(references, hypotheses, weights=weights) * 100
        for name, weights in weight_sets.items()
    }

### 4.5 Training Framework
def train(model, train_loader, val_loader, optimizer, scheduler, criterion, config, device, vocab, model_paths):
    """Train ``model`` with periodic validation, BLEU scoring and checkpointing.

    Depending on what ``check_model_availability`` reports, the function
    either loads an already trained model and returns immediately
    (``"trained"``), resumes from a checkpoint (``"checkpoint"``), or starts
    from scratch (any other status).

    Args:
        model: Captioning model; its type selects the "Attention" or
            "Baseline" label used in log messages.
        train_loader, val_loader: Data loaders for training / evaluation.
        optimizer: Optimizer stepped by ``train_epoch``.
        scheduler: Optional scheduler; stepped with the validation loss
            after every validation run.
        criterion: Loss function forwarded to ``train_epoch``/``validate``.
        config: Dict providing ``num_epochs``, ``print_frequency``,
            ``eval_every``, ``bleu_every``, ``max_bleu_samples``,
            ``early_stopping_patience`` and ``clip_grad_norm``.
        device: Device used for training and for loading the best model.
        vocab: Vocabulary; ``vocab.stoi["<PAD>"]`` supplies the pad index.
        model_paths: Dict providing ``checkpoint_path`` and
            ``best_model_path``.

    Returns:
        tuple: ``(model, history)`` where ``history`` holds the per-epoch
        train losses, validation losses and BLEU scores.
    """

    # Unpack configuration parameters
    num_epochs = config["num_epochs"]
    pad_idx = vocab.stoi["<PAD>"]
    print_frequency = config["print_frequency"]
    eval_every = config["eval_every"]
    bleu_every = config["bleu_every"]
    max_bleu_samples = config["max_bleu_samples"]
    early_stopping_patience = config["early_stopping_patience"]
    clip_grad_norm = config["clip_grad_norm"]

    # Checkpoint paths
    checkpoint_path = model_paths["checkpoint_path"]
    best_model_path = model_paths["best_model_path"]

    # Determine model type
    is_attention_model = isinstance(model, AttentionCaptionModel)
    model_name = "Attention" if is_attention_model else "Baseline"

    # Initialize tracking variables
    best_val_loss = float('inf')
    patience_counter = 0
    history = {
        'epochs': [],
        'train_losses': [],
        'val_epochs': [],
        'val_losses': [],
        'bleu_epochs': [],
        'bleu_scores': []
    }

    # Check if a trained model exists
    model_status, model_checkpoint = check_model_availability(config, best_model_path, checkpoint_path)

    # If a trained model exists, load it
    if model_status == "trained":
        print(f"Found trained {model_name} model, loading weights...")
        model.load_state_dict(model_checkpoint['state_dict'])

        # Load training history if available
        if 'training_history' in model_checkpoint:
            history = model_checkpoint['training_history']
            print("Loaded training history.")

        print(f"Skipping training for {model_name} model.")
        return model, history

    # If a checkpoint exists, resume training
    elif model_status == "checkpoint":
        print(f"Found checkpoint for {model_name} model, resuming training...")

        # Load model state
        model.load_state_dict(model_checkpoint['state_dict'])

        # Load optimizer and scheduler states if available
        if 'optimizer' in model_checkpoint and optimizer is not None:
            optimizer.load_state_dict(model_checkpoint['optimizer'])
            print("Loaded optimizer state.")

        # Checkpoints written without a scheduler store None under
        # 'scheduler'; only restore when an actual state dict is present.
        if model_checkpoint.get('scheduler') is not None and scheduler is not None:
            scheduler.load_state_dict(model_checkpoint['scheduler'])
            print("Loaded scheduler state.")

        # Load training history if available
        if 'training_history' in model_checkpoint:
            history = model_checkpoint['training_history']
            print("Loaded training history.")

        # Get starting epoch (checkpoints store the 1-based number of the
        # last finished epoch, which is the 0-based index of the next one).
        start_epoch = model_checkpoint.get('epoch', 0)

        # Regular checkpoints store val_loss=None for epochs that were not
        # validated, and otherwise the *latest* (not best) validation loss.
        # Rebuild the best-so-far value from every recorded loss so that the
        # first comparison neither fails on None nor accepts a worse model.
        recorded_losses = [v for v in history.get('val_losses', []) if v is not None]
        checkpoint_val_loss = model_checkpoint.get('val_loss')
        if checkpoint_val_loss is not None:
            recorded_losses.append(checkpoint_val_loss)
        best_val_loss = min(recorded_losses, default=float('inf'))
        print(f"Resuming training from epoch {start_epoch + 1}...")

    else:
        # Start fresh training
        print(f"Starting fresh training for {model_name} model...")
        start_epoch = 0

    # Start training time
    train_start_time = time.time()

    print(f"\nTraining {model_name} model for {num_epochs - start_epoch} epochs...")

    # Training loop
    for epoch in range(start_epoch, num_epochs):
        epoch_num = epoch + 1  # 1-based epoch numbering

        # Start epoch time
        epoch_start_time = time.time()

        # Print epoch info
        print(f"\nEpoch {epoch_num}/{num_epochs}")

        # Train for one epoch
        train_loss = train_epoch(
            model, train_loader, optimizer, criterion,
            clip_grad_norm, device, pad_idx, print_frequency
        )

        # Store train loss
        history['epochs'].append(epoch_num)
        history['train_losses'].append(train_loss)

        # Validation (based on eval_every); the final epoch always validates
        should_validate = epoch_num % eval_every == 0 or epoch_num == num_epochs
        if should_validate:
            # Validate
            val_loss = validate(model, val_loader, criterion, device, pad_idx)

            # Store validation results
            history['val_epochs'].append(epoch_num)
            history['val_losses'].append(val_loss)

            # Check for new best model
            is_best = val_loss < best_val_loss
            if is_best:
                best_val_loss = val_loss
                patience_counter = 0

                # Save best model
                print(f"New best model with validation loss: {val_loss:.4f}")
                save_checkpoint({
                    'epoch': epoch_num,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict() if scheduler else None,
                    'val_loss': val_loss,
                    'config': config,
                    'training_history': history
                }, is_best=True, filepath=best_model_path)
            else:
                # Patience only advances on epochs that were validated.
                patience_counter += 1

            # Print validation results
            print(f"Epoch {epoch_num} - Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}")

            # Update learning rate scheduler
            if scheduler:
                scheduler.step(val_loss)
        else:
            # Print training results only
            print(f"Epoch {epoch_num} - Train loss: {train_loss:.4f}")

        # Calculate BLEU scores (based on bleu_every)
        should_calculate_bleu = epoch_num % bleu_every == 0 or epoch_num == num_epochs
        if should_calculate_bleu:
            # Calculate BLEU scores
            print("Calculating BLEU scores...")
            bleu = calculate_bleu(model, val_loader, vocab, device, max_samples=max_bleu_samples)

            # Store BLEU scores
            history['bleu_epochs'].append(epoch_num)
            history['bleu_scores'].append(bleu)

            # Print BLEU scores
            print(f"BLEU scores - BLEU-1: {bleu['bleu1']:.2f}, BLEU-2: {bleu['bleu2']:.2f}, "
                  f"BLEU-3: {bleu['bleu3']:.2f}, BLEU-4: {bleu['bleu4']:.2f}")

        # Save regular checkpoint
        save_checkpoint({
            'epoch': epoch_num,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict() if scheduler else None,
            'train_loss': train_loss,
            'val_loss': val_loss if should_validate else None,
            'config': config,
            'training_history': history
        }, filepath=checkpoint_path)

        # Check for early stopping
        if patience_counter >= early_stopping_patience:
            print(f"Early stopping triggered after {epoch_num} epochs")
            break

        # Print epoch time
        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch_num} completed in {epoch_time:.1f}s")

    # Print total training time
    train_time = time.time() - train_start_time
    print(f"\nTraining completed in {train_time/60:.1f} minutes")

    # Load the best model; fall back to the final weights if that fails
    try:
        print(f"Loading best {model_name} model...")
        # map_location lets a checkpoint saved on another device be loaded
        # onto the device used for this run.
        best_checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(best_checkpoint['state_dict'])
        print(f"Loaded best model from epoch {best_checkpoint['epoch']} with validation loss {best_checkpoint['val_loss']:.4f}")
    except Exception as e:
        print(f"Could not load best model: {e}")
        print("Using final model instead.")

    return model, history

### 4.6 Visualization Functions
def plot_training_history(history, model_name="Model"):
    """Draw loss curves (and BLEU curves when present) and save the figure.

    ``history`` supplies ``epochs``/``train_losses`` and, optionally,
    ``val_epochs``/``val_losses`` and ``bleu_epochs``/``bleu_scores``.
    Validation and BLEU have their own epoch axes because they may be
    recorded less often than the training loss. The figure is written to
    ``LOGS_DIR`` and then shown.
    """
    plt.figure(figsize=(14, 6))

    # ---- Left panel: losses ----
    plt.subplot(1, 2, 1)
    plt.plot(history['epochs'], history['train_losses'], 'o-', label='Train Loss')

    has_val = 'val_epochs' in history and history['val_losses']
    if has_val:
        plt.plot(history['val_epochs'], history['val_losses'], 'o-', label='Val Loss')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'{model_name} - Training and Validation Loss')
    plt.legend()
    plt.grid(True)

    # ---- Right panel: BLEU-1..4, only when scores were recorded ----
    has_bleu = 'bleu_epochs' in history and history['bleu_scores']
    if has_bleu:
        plt.subplot(1, 2, 2)

        bleu_epochs = history['bleu_epochs']
        # Gather every series first, then draw them in BLEU-1..4 order.
        series = [
            (f'BLEU-{n}', [scores[f'bleu{n}'] for scores in history['bleu_scores']])
            for n in range(1, 5)
        ]
        for label, values in series:
            plt.plot(bleu_epochs, values, 'o-', label=label)

        plt.xlabel('Epoch')
        plt.ylabel('BLEU Score')
        plt.title(f'{model_name} - BLEU Scores')
        plt.legend()
        plt.grid(True)

    plt.tight_layout()

    # Persist before showing.
    output_file = os.path.join(LOGS_DIR, f'{model_name.lower()}_training_history.png')
    plt.savefig(output_file, dpi=300)
    plt.show()

def generate_caption(model, image, vocab, max_length=20):
    """Caption ``image`` with either model type.

    Returns:
        tuple: ``(caption, attention_weights)``; ``attention_weights`` is
        ``None`` for models that are not an ``AttentionCaptionModel``.
    """
    model.eval()

    # Non-attention models have no attention maps to report.
    if not isinstance(model, AttentionCaptionModel):
        return model.caption_image(image, vocab, max_length), None

    caption, attention_weights = model.caption_image_with_attention(image, vocab, max_length)
    return caption, attention_weights

def visualize_sample_captions(model, dataset, vocab, device, num_samples=3):
    """Caption random dataset samples, print the results and plot the images.

    For each sampled item the reference and generated captions are printed,
    the de-normalized image is shown and saved to ``LOGS_DIR`` as
    ``caption_sample_<k>.png``, and - for an ``AttentionCaptionModel`` that
    returned attention weights - the per-word attention maps are plotted.

    Args:
        model: Captioning model (attention or not).
        dataset: Indexable dataset returning ``(image, caption)`` and
            exposing ``data_df`` with an ``'image'`` column.
        vocab: Vocabulary exposing ``itos`` (index -> token).
        device: Device the image is moved to before captioning.
        num_samples: How many distinct samples to show; capped at the
            dataset size.
    """

    # Determine if this is an attention model
    is_attention_model = isinstance(model, AttentionCaptionModel)

    # Set model to evaluation mode
    model.eval()

    # Sampling without replacement cannot draw more items than exist, so
    # cap the request at the dataset size instead of raising.
    sample_count = min(num_samples, len(dataset))
    indices = np.random.choice(len(dataset), sample_count, replace=False)

    # Generate captions for each sample
    with torch.no_grad():
        for i, idx in enumerate(indices):
            # Get image and caption
            image, caption = dataset[idx]
            img_name = dataset.data_df.iloc[idx]['image']  # Get image name

            # Add a leading batch dimension and move to device
            image = image.unsqueeze(0).to(device)

            # Generate caption
            if is_attention_model:
                generated_caption, attention_weights = model.caption_image_with_attention(image, vocab)
            else:
                generated_caption = model.caption_image(image, vocab)
                attention_weights = None

            # Convert reference caption to words: stop at <EOS>, skip
            # <PAD>/<SOS>
            reference_tokens = []
            for token_idx in caption:
                token = vocab.itos[token_idx.item()]
                if token == "<EOS>":
                    break
                if token not in ("<PAD>", "<SOS>"):
                    reference_tokens.append(token)
            reference_caption = ' '.join(reference_tokens)

            # Print image name and captions
            print()
            print(f"Image {i+1}: {img_name}")
            print(f"Reference: {reference_caption}")
            print(f"Generated: {generated_caption}")

            # Display the image
            plt.figure(figsize=(10, 8))
            img = denormalize_image(image[0])
            plt.imshow(img)
            plt.title(f"Image: {img_name}")
            plt.axis('off')
            plt.tight_layout()
            save_path = os.path.join(LOGS_DIR, f'caption_sample_{i+1}.png')
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.show()

            # For attention model, visualize attention weights. Test for
            # "present and non-empty" explicitly: plain truthiness is
            # ambiguous (and raises) for multi-element arrays/tensors.
            if (is_attention_model
                    and attention_weights is not None
                    and len(attention_weights) > 0):
                visualize_attention(img, generated_caption.split(), attention_weights)

def visualize_attention(image, caption, attention_weights, show_every=1, save_path=None):
    """Plot the image followed by one attention overlay per selected word.

    Args:
        image: Array-like image; its first two dimensions are read as
            height and width.
        caption: Sequence of words.
        attention_weights: Indexable by word position; each entry is
            reshaped to a square map (assumes its length is a perfect
            square - not checked here).
        show_every: Show every ``show_every``-th word.
        save_path: When given, the figure is also written to this path.
    """
    word_indices = list(range(0, len(caption), show_every))
    panel_count = len(word_indices) + 1  # one extra panel for the plain image

    # Up to five words fit on one row; otherwise wrap at five columns.
    if len(word_indices) < 6:
        nrows, ncols = 1, panel_count
    else:
        ncols = min(5, panel_count)
        nrows = (panel_count + ncols - 1) // ncols

    # squeeze=False always yields a 2-D axes array, so a single flatten
    # covers the one-panel, one-row and grid layouts alike.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(ncols*3, nrows*3), squeeze=False)
    panels = axes.flatten()

    # First panel: the untouched image.
    first = panels[0]
    first.imshow(image)
    first.set_title('Original Image')
    first.axis('off')

    # Remaining panels: image with the word's attention map on top.
    for panel_no, word_idx in enumerate(word_indices, start=1):
        if panel_no >= len(panels):  # never index past the last panel
            break

        weights = attention_weights[word_idx]
        word = caption[word_idx]

        # Flat vector -> square map.
        side = int(np.sqrt(weights.shape[0]))
        weights = weights.reshape(side, side)

        # Upsample by integer repetition along each axis, then crop to the
        # image size.
        h, w = image.shape[:2]
        weights = np.repeat(weights, h//side, axis=0)
        weights = np.repeat(weights, w//side, axis=1)
        weights = weights[:h, :w]

        ax = panels[panel_no]
        ax.imshow(image)
        ax.imshow(weights, alpha=0.6, cmap='hot')
        ax.set_title(word)
        ax.axis('off')

    # Blank out whatever grid cells were not used.
    for spare in panels[panel_count:]:
        spare.axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches='tight')

    plt.show()

def compare_models(baseline_model, attention_model, dataset, vocab, device, num_samples=3):
    """Compare captions generated by the baseline and attention models.

    For every randomly selected sample this prints the ground-truth,
    baseline and attention captions, saves a figure with the image and the
    three captions to ``LOGS_DIR/model_comparison_<n>.png`` and, when the
    attention model returned weights, saves an attention visualisation to
    ``LOGS_DIR/model_comparison_attention_<n>.png``.

    Args:
        baseline_model: Model exposing ``caption_image(image, vocab)``.
        attention_model: Model exposing
            ``caption_image_with_attention(image, vocab)`` that returns a
            ``(caption, attention_weights)`` pair.
        dataset: Indexable dataset yielding ``(image, caption)`` pairs and
            exposing a ``data_df`` frame with an ``'image'`` column.
        vocab: Vocabulary providing ``itos`` and supporting ``len()``.
        device: Device the image tensor is moved to before inference.
        num_samples: Number of random samples to compare. Values larger
            than the dataset are clamped to the dataset size.

    Returns:
        None. Results are printed, plotted and written to ``LOGS_DIR``.
    """
    special_tokens = {"<PAD>", "<SOS>", "<EOS>"}

    # Set models to evaluation mode
    baseline_model.eval()
    attention_model.eval()

    # Sampling without replacement raises ValueError when more samples are
    # requested than the dataset holds, so clamp the request first.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("No samples available for model comparison.")
        return

    # Select random samples
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    # Generate captions for each sample
    with torch.no_grad():
        for i, idx in enumerate(indices):
            # Get image and ground truth caption
            image, caption = dataset[idx]
            img_name = dataset.data_df.iloc[idx]['image']

            # Add a batch dimension and move to the target device
            image = image.unsqueeze(0).to(device)

            # Generate captions
            baseline_caption = baseline_model.caption_image(image, vocab)
            attention_caption, attention_weights = attention_model.caption_image_with_attention(image, vocab)

            # Convert ground truth caption to words, dropping special tokens.
            # A distinct loop name avoids shadowing the sample index `idx`.
            token_ids = [token.item() for token in caption]
            gt_words = [vocab.itos[token_id] for token_id in token_ids
                        if token_id < len(vocab) and vocab.itos[token_id] not in special_tokens]
            gt_caption = ' '.join(gt_words)

            # Print image name and captions
            print(f"\nImage {i+1}: {img_name}")
            print(f"Ground Truth: {gt_caption}")
            print(f"Baseline: {baseline_caption}")
            print(f"Attention: {attention_caption}")

            # Display image and captions
            plt.figure(figsize=(12, 10))

            # Display image
            plt.subplot(2, 1, 1)
            img = denormalize_image(image[0])
            plt.imshow(img)
            plt.title(f"Image: {img_name}")
            plt.axis('off')

            # Display captions comparison
            plt.subplot(2, 1, 2)
            plt.axis('off')
            comparison_text = (
                f"Ground Truth: {gt_caption}\n\n"
                f"Baseline: {baseline_caption}\n\n"
                f"Attention: {attention_caption}"
            )
            plt.text(0.5, 0.5, comparison_text, ha='center', va='center', fontsize=12, wrap=True)

            plt.tight_layout()
            save_path = os.path.join(LOGS_DIR, f'model_comparison_{i+1}.png')
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.show()

            # Display attention visualization
            attention_words = attention_caption.split()
            # Use an explicit length check: the truth value of a multi-element
            # array/tensor is ambiguous and would raise instead of evaluating.
            has_attention = attention_weights is not None and len(attention_weights) > 0
            if has_attention and attention_words:
                # Limit number of words to display for readability
                show_every = max(1, len(attention_words) // 10)
                visualize_attention(
                    img, attention_words, attention_weights,
                    show_every=show_every,
                    save_path=os.path.join(LOGS_DIR, f'model_comparison_attention_{i+1}.png')
                )



In [None]:
## Model Training and Evaluation

### 5.1 Initialize and Train Baseline Model
# Build the baseline captioner from the hyper-parameters in model_config
baseline_model = BaselineCaptionModel(
    embed_size=model_config["embed_size"],
    hidden_size=model_config["hidden_size"],
    vocab_size=len(vocab),
    num_layers=model_config["num_layers"],
    dropout=model_config["dropout"]
).to(device)

# Create loss criterion; targets equal to the <PAD> index are ignored
criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<PAD>"])

# Create optimizer for baseline model
baseline_optimizer = optim.Adam(
    baseline_model.parameters(),
    lr=model_config["learning_rate"],
    weight_decay=model_config["weight_decay"]
)

# Create learning rate scheduler (lowers the LR when the monitored value
# stops decreasing).
# NOTE: `verbose` is deliberately not passed. It is deprecated since
# PyTorch 2.2 and newer releases no longer accept it. Read
# baseline_optimizer.param_groups[0]["lr"] to track learning-rate changes.
baseline_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    baseline_optimizer,
    mode='min',
    factor=model_config["lr_scheduler_factor"],
    patience=model_config["lr_scheduler_patience"]
)

# Visualize baseline model architecture
visualize_model_architecture(baseline_model, is_attention_model=False)

# Train baseline model; train() hands back the model and its history
print("\nTraining baseline model...")
baseline_model, baseline_history = train(
    baseline_model, train_loader_to_use, val_loader_to_use,
    baseline_optimizer, baseline_scheduler, criterion,
    model_config, device, vocab, baseline_paths
)

# Plot baseline training history
plot_training_history(baseline_history, model_name="Baseline")

# Visualize sample captions from baseline model
print("\nGenerating sample captions from baseline model...")
visualize_sample_captions(baseline_model, test_dataset, vocab, device, num_samples=3)

### 5.2 Initialize and Train Attention Model
# Build the attention captioner from the hyper-parameters in model_config
attention_model = AttentionCaptionModel(
    embed_size=model_config["embed_size"],
    hidden_size=model_config["hidden_size"],
    vocab_size=len(vocab),
    attention_dim=model_config["attention_dim"],
    num_layers=model_config["num_layers"],
    dropout=model_config["dropout"]
).to(device)

# Create optimizer for attention model
attention_optimizer = optim.Adam(
    attention_model.parameters(),
    lr=model_config["learning_rate"],
    weight_decay=model_config["weight_decay"]
)

# Create learning rate scheduler (lowers the LR when the monitored value
# stops decreasing).
# NOTE: `verbose` is deliberately not passed. It is deprecated since
# PyTorch 2.2 and newer releases no longer accept it. Read
# attention_optimizer.param_groups[0]["lr"] to track learning-rate changes.
attention_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    attention_optimizer,
    mode='min',
    factor=model_config["lr_scheduler_factor"],
    patience=model_config["lr_scheduler_patience"]
)

# Visualize attention model architecture
visualize_model_architecture(attention_model, is_attention_model=True)

# Train attention model; train() hands back the model and its history
print("\nTraining attention model...")
attention_model, attention_history = train(
    attention_model, train_loader_to_use, val_loader_to_use,
    attention_optimizer, attention_scheduler, criterion,
    model_config, device, vocab, attention_paths
)

# Plot attention training history
plot_training_history(attention_history, model_name="Attention")

# Visualize sample captions from attention model
print("\nGenerating sample captions from attention model...")
visualize_sample_captions(attention_model, test_dataset, vocab, device, num_samples=3)

### 5.3 Evaluate and Compare Models
# Calculate BLEU scores for both models on the test loader
print("\nCalculating final BLEU scores for baseline model...")
baseline_bleu = calculate_bleu(
    baseline_model, test_loader_to_use, vocab, device,
    max_samples=model_config["max_bleu_samples"]
)

print("\nCalculating final BLEU scores for attention model...")
attention_bleu = calculate_bleu(
    attention_model, test_loader_to_use, vocab, device,
    max_samples=model_config["max_bleu_samples"]
)


def _format_bleu_row(label, scores):
    """Return one table row: a 17-character label column, then 8-character score columns.

    Fixed-width columns keep the rows aligned with the header even when a
    score needs more than four characters; trailing padding is stripped.
    """
    cells = "".join(f"{scores[key]:<8.2f}" for key in ("bleu1", "bleu2", "bleu3", "bleu4"))
    return f"{label:<17}{cells}".rstrip()


# Print BLEU comparison
print("\nModel BLEU Score Comparison:")
print("                 BLEU-1  BLEU-2  BLEU-3  BLEU-4")
print(_format_bleu_row("Baseline Model:", baseline_bleu))
print(_format_bleu_row("Attention Model:", attention_bleu))

# Compare caption quality and attention visualization
print("\nComparing caption quality between models...")
compare_models(baseline_model, attention_model, test_dataset, vocab, device, num_samples=3)



In [None]:
## Summary and Conclusion
# Closing notes for the notebook, kept as (heading, bullet lines) pairs and
# printed section by section in the order listed here.
closing_notes = (
    (
        "\nProject Summary:",
        (
            "1. Implemented both baseline CNN+RNN and attention-based image captioning models",
            "2. The attention mechanism helps the model focus on relevant image regions for each word",
            "3. Attention model generally produces more detailed and accurate captions",
            "4. Attention visualizations provide interpretable insights into model behavior",
        ),
    ),
    (
        "\nKey Differences:",
        (
            "1. Encoder: Attention model preserves spatial information, baseline pools features",
            "2. Decoder: Attention model uses weighted context vectors, baseline uses global features",
            "3. Training: Attention model includes attention regularization term in loss function",
            "4. Inference: Attention model dynamically focuses on different regions for each word",
        ),
    ),
    (
        "\nFuture Improvements:",
        (
            "1. Use a more powerful CNN backbone like EfficientNet or Vision Transformer",
            "2. Implement beam search instead of greedy decoding for better captions",
            "3. Train on larger datasets like MS COCO for better generalization",
            "4. Incorporate semantic understanding with pre-trained language models",
        ),
    ),
)

for section_heading, section_lines in closing_notes:
    # One print per section: heading first, then its bullets on separate lines
    print("\n".join((section_heading, *section_lines)))