# CLIP Fine-tuning on MS COCO 2014

This notebook provides a complete pipeline for:
1. Dataset preparation
2. Model training
3. Evaluation and visualization

The bulk of the implementation lives in Python modules in the cloned repository:
- `config.py` - Configuration and constants
- `models.py` - Model architectures
- `dataset.py` - Dataset classes
- `utils.py` - Training, evaluation, and visualization functions

## 1. Setup

In [None]:
!wget -q https://raw.githubusercontent.com/tsunrise/colab-github/main/colab_github.py
import colab_github
colab_github.github_auth(persistent_key=True)

In [None]:
!git clone git@github.com:oliverdantzer/elec475-lab4.git

In [None]:
import sys

# Make the repo's modules (config, models, dataset, utils) importable.
REPO_DIR = '/content/elec475-lab4'
# Guard so re-running the cell does not stack duplicate sys.path entries.
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

In [None]:

import os
import subprocess
import sys

# The notebook's working directory is not the repo checkout, so a bare
# "requirements.txt" would point at a non-existent file. Use the absolute path,
# install into *this* kernel's interpreter, and fail loudly if pip fails
# (os.system silently discarded the exit status).
REQUIREMENTS_PATH = os.path.join('/content/elec475-lab4', 'requirements.txt')
subprocess.run([sys.executable, '-m', 'pip', 'install', '-r', REQUIREMENTS_PATH], check=True)

In [None]:
# Imports
import os
import time
import random
from pathlib import Path

import torch
import numpy as np
from torch.utils.data import DataLoader
from transformers import CLIPTokenizer, CLIPTextModel

# Import from our modules
from config import *
from models import CLIPModel, InfoNCELoss
from dataset import COCOClipDataset, get_clip_transforms
from utils import (
    encode_and_cache_captions,
    setup_optimizer_and_scheduler,
    train_epoch,
    validate,
    compute_embeddings,
    compute_recall_at_k,
    retrieve_images_for_text,
    zero_shot_classify,
    plot_training_curves,
    save_checkpoint,
    load_checkpoint,
    visualize_samples
)

# Seed every RNG used here (Python, NumPy, PyTorch) so runs are repeatable.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

_has_cuda = torch.cuda.is_available()
if _has_cuda:
    torch.cuda.manual_seed(SEED)
    # Request deterministic cuDNN kernels.
    torch.backends.cudnn.deterministic = True

# Device setup: prefer the GPU when one is visible.
device = torch.device('cuda' if _has_cuda else 'cpu')
print(f"Using device: {device}")
if _has_cuda:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {_gpu_props.total_memory / 1e9:.2f} GB")

## 2. Configuration

Set all hyperparameters and experimental flags here.

In [None]:
# Configuration: start from the project defaults, then override per-run settings.
CONFIG = DEFAULT_CONFIG.copy()

# Paths (customize these for your environment)
CONFIG.update({
    'dataset_dir': Path('/content/coco2014'),
    'checkpoint_dir': Path('/content/checkpoints'),
    'log_dir': Path('/content/logs'),
})

# Training hyperparameters
CONFIG.update({
    'batch_size': 128,
    'num_epochs': 10,
    'learning_rate': 1e-4,
    'weight_decay': 0.01,
    'warmup_steps': 500,
    'temperature': 0.07,
    'max_grad_norm': 1.0,
})

# ===== EXPERIMENTAL FLAGS =====
# Set these to True to enable architectural modifications
CONFIG['use_batch_norm'] = False        # Add BatchNorm to projection head
CONFIG['use_attention_pooling'] = False  # Use attention pooling in ResNet50
# ==============================

# Every enabled flag contributes a suffix, so runs with different architectures
# write to separate checkpoint/log directories.
_FLAG_SUFFIXES = (('use_batch_norm', '_bn'), ('use_attention_pooling', '_attn'))
flags_suffix = "".join(suffix for flag, suffix in _FLAG_SUFFIXES if CONFIG[flag])

_RUN_DIR_KEYS = ('checkpoint_dir', 'log_dir')
if flags_suffix:
    for _dir_key in _RUN_DIR_KEYS:
        CONFIG[_dir_key] = Path(f"{CONFIG[_dir_key]}{flags_suffix}")
for _dir_key in _RUN_DIR_KEYS:
    CONFIG[_dir_key].mkdir(exist_ok=True, parents=True)

# Print configuration
_config_report = [
    "Configuration:",
    f"  Dataset: {CONFIG['dataset_dir']}",
    f"  Batch size: {CONFIG['batch_size']}",
    f"  Epochs: {CONFIG['num_epochs']}",
    f"  Learning rate: {CONFIG['learning_rate']}",
    f"  Temperature: {CONFIG['temperature']}",
    "\n  Experimental Flags:",
    f"    BatchNorm: {CONFIG['use_batch_norm']}",
    f"    Attention Pooling: {CONFIG['use_attention_pooling']}",
]
if flags_suffix:
    _config_report.append(f"    Suffix: {flags_suffix}")
print("\n".join(_config_report))

## 3. Dataset Preparation (Optional)

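Run this section on a fresh runtime. It downloads COCO 2014 via kagglehub if the annotation/image folders are missing. It then encodes and caches the caption text embeddings. Each step is skipped automatically when its output already exists, so re-running the section is safe.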

In [None]:
# Download COCO 2014 dataset using kagglehub
import kagglehub
import shutil

annotations_dir = CONFIG['dataset_dir'] / 'annotations'
train_images_dir = CONFIG['dataset_dir'] / 'train2014'
val_images_dir = CONFIG['dataset_dir'] / 'val2014'


def _materialize(source, target, *, copy, label):
    """Make ``target`` available from ``source`` by copying or symlinking it.

    Args:
        source: Directory inside the kagglehub download.
        target: Destination path under ``CONFIG['dataset_dir']``.
        copy: If True, copy the tree with ``shutil.copytree``; otherwise symlink it.
        label: Human-readable name used in progress messages.

    Does nothing if ``target`` already exists. If ``source`` is missing, prints a
    warning listing what the download does contain and returns.
    """
    if target.exists():
        return
    if target.is_symlink():
        # exists() is False, so this is a dangling link (e.g. its old download
        # location is gone). Path.exists() reports it as missing, but
        # symlink_to / copytree would still raise FileExistsError on it.
        print(f"Removing dangling symlink: {target}")
        target.unlink()
    if not source.exists():
        available = sorted(item.name for item in source.parent.iterdir())
        print(f"⚠️ Warning: {source.name} not found at {source}. Available: {available}")
        return
    if copy:
        print(f"\nCopying {label} to {target}...")
        shutil.copytree(source, target)
        print(f"✓ {label.capitalize()} copied")
    else:
        print(f"\nCreating symlink for {label}...")
        target.symlink_to(source, target_is_directory=True)
        print(f"✓ {label.capitalize()} linked: {target} -> {source}")


if not (annotations_dir.exists() and train_images_dir.exists() and val_images_dir.exists()):
    print("Downloading COCO 2014 dataset from Kaggle...")

    # Download dataset
    downloaded_path = Path(kagglehub.dataset_download("jeffaudi/coco-2014-dataset-for-yolov3"))
    print(f"Downloaded to: {downloaded_path}")

    # Show what's in the downloaded directory
    print(f"\nDownloaded dataset structure:")
    for item in downloaded_path.iterdir():
        print(f"  - {item.name}")

    # Create our dataset directory
    CONFIG['dataset_dir'].mkdir(exist_ok=True, parents=True)

    # Annotations are copied; the image folders are symlinked to avoid duplicating them.
    _materialize(downloaded_path / 'annotations', annotations_dir, copy=True, label='annotations')
    _materialize(downloaded_path / 'train2014', train_images_dir, copy=False, label='train images')
    _materialize(downloaded_path / 'val2014', val_images_dir, copy=False, label='val images')
else:
    print("✓ Dataset directories already exist")
    print(f"  Annotations: {annotations_dir}")
    print(f"  Train images: {train_images_dir}")
    print(f"  Val images: {val_images_dir}")

# Check if dataset needs preparation (cached caption text embeddings per split)
train_cache = CONFIG['dataset_dir'] / 'train_text_embeddings.pt'
val_cache = CONFIG['dataset_dir'] / 'val_text_embeddings.pt'

if train_cache.exists() and val_cache.exists():
    print("\n✓ Dataset already prepared!")
    print(f"  Train cache: {train_cache}")
    print(f"  Val cache: {val_cache}")
else:
    print("\nDataset needs preparation. Loading CLIP text encoder...")

    # Load text encoder for encoding captions
    tokenizer = CLIPTokenizer.from_pretrained(MODEL_NAME)
    text_encoder = CLIPTextModel.from_pretrained(MODEL_NAME)
    text_encoder = text_encoder.to(device)

    # Encode and cache each split whose cache file is still missing
    for split, cache_path, split_name in (('train', train_cache, 'training'),
                                          ('val', val_cache, 'validation')):
        if cache_path.exists():
            continue
        print(f"\nEncoding {split_name} captions...")
        encode_and_cache_captions(
            split=split,
            dataset_dir=CONFIG['dataset_dir'],
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            device=device
        )

    print("\n✓ Dataset preparation complete!")

## 4. Load Data

In [None]:
# Create datasets
print("Loading datasets...")
train_dataset = COCOClipDataset(split='train', dataset_dir=CONFIG['dataset_dir'])
val_dataset = COCOClipDataset(split='val', dataset_dir=CONFIG['dataset_dir'])

print(f"\n✓ Datasets loaded:")
print(f"  Train: {len(train_dataset):,} images")
print(f"  Val: {len(val_dataset):,} images")

# Options shared by both loaders. Pinned host memory only helps when batches are
# copied to a CUDA device; on a CPU runtime it is useless (recent PyTorch warns).
NUM_WORKERS = 2
shared_loader_kwargs = dict(
    batch_size=CONFIG['batch_size'],
    num_workers=NUM_WORKERS,
    pin_memory=(device.type == 'cuda'),
)

# Training: reshuffle each epoch and drop the ragged final batch (presumably so
# every InfoNCE step sees a full batch of in-batch negatives).
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **shared_loader_kwargs)
# Validation: fixed order, keep every sample.
val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)

print(f"\n✓ DataLoaders created:")
print(f"  Train batches: {len(train_loader):,}")
print(f"  Val batches: {len(val_loader):,}")

In [None]:
# Sanity-check the input pipeline by eyeballing a few random training samples.
num_preview_samples = 6
visualize_samples(train_dataset, num_samples=num_preview_samples)

## 5. Create Model

In [None]:
# Load CLIP text encoder (also provides the tokenizer reused for retrieval below)
print("Loading CLIP text encoder...")
tokenizer = CLIPTokenizer.from_pretrained(MODEL_NAME)
text_encoder = CLIPTextModel.from_pretrained(MODEL_NAME).to(device)

# Build the CLIP model; the text encoder is passed in with freeze_text_encoder=True
# and the architecture variants come from the experimental flags.
print("\nCreating CLIP model...")
model_kwargs = dict(
    text_encoder=text_encoder,
    freeze_text_encoder=True,
    use_batch_norm=CONFIG['use_batch_norm'],
    use_attention_pooling=CONFIG['use_attention_pooling'],
)
model = CLIPModel(**model_kwargs).to(device)

# Count all parameters and the subset that will receive gradients.
total_params = 0
trainable_params = 0
for _param in model.parameters():
    total_params += _param.numel()
    if _param.requires_grad:
        trainable_params += _param.numel()

print(f"\n✓ Model created:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.1f}%)")

_enabled_mods = [
    line
    for flag, line in (
        ('use_batch_norm', "    ✓ BatchNorm in projection head"),
        ('use_attention_pooling', "    ✓ Attention pooling in image encoder"),
    )
    if CONFIG[flag]
]
if _enabled_mods:
    print(f"\n  Experimental modifications:")
    for _line in _enabled_mods:
        print(_line)

# Contrastive objective with the configured temperature.
criterion = InfoNCELoss(temperature=CONFIG['temperature'])
print(f"\n✓ Loss function: InfoNCE (temperature={CONFIG['temperature']})")

## 6. Training

In [None]:
# Setup optimizer and scheduler. total_steps counts one step per training batch
# for every epoch.
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * CONFIG['num_epochs']
optimizer, scheduler = setup_optimizer_and_scheduler(model, CONFIG, total_steps)

for _line in (
    "✓ Optimizer and scheduler created:",
    "  Optimizer: AdamW",
    f"  Learning rate: {CONFIG['learning_rate']}",
    "  Scheduler: Cosine with warmup",
    f"  Warmup steps: {CONFIG['warmup_steps']}",
    f"  Total steps: {total_steps:,}",
):
    print(_line)

In [None]:
# Training loop
print(f"\n{'='*60}")
print("Starting Training")
print(f"{'='*60}\n")

# Per-epoch metrics: one entry per finished epoch; plotted and saved below.
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
}

start_time = time.time()
best_val_loss = float('inf')
# Banner headline for the summary. Replaced on a clean finish or Ctrl-C; if any
# other exception escapes the loop this value is shown before it propagates.
training_status = "Training Stopped by an Error"

try:
    for epoch in range(CONFIG['num_epochs']):
        # Train (epoch_history from train_epoch is not used below)
        train_loss, train_acc, epoch_history = train_epoch(
            model, train_loader, criterion, optimizer, scheduler,
            epoch, CONFIG, device
        )

        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        # Store history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Print epoch results
        print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}:")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f}")
        print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.3f}")

        # Save best model (selected by lowest validation loss)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(
                model, optimizer, scheduler,
                epoch, val_loss, val_acc, CONFIG,
                CONFIG['checkpoint_dir'] / f'best_model{flags_suffix}.pt'
            )
            print(f"  ✓ Saved best model (val_loss: {val_loss:.4f})")

        # Save epoch checkpoint
        save_checkpoint(
            model, optimizer, scheduler,
            epoch, val_loss, val_acc, CONFIG,
            CONFIG['checkpoint_dir'] / f'checkpoint_epoch_{epoch+1}{flags_suffix}.pt'
        )

    training_status = "Training Complete!"

except KeyboardInterrupt:
    print("\n\nTraining interrupted by user")
    training_status = "Training Interrupted"

finally:
    total_time = time.time() - start_time
    epochs_done = len(history['val_loss'])
    print(f"\n{'='*60}")
    print(training_status)
    print(f"{'='*60}")
    print(f"Epochs completed: {epochs_done}/{CONFIG['num_epochs']}")
    print(f"Total time: {total_time/3600:.2f} hours ({total_time/60:.1f} minutes)")
    # best_val_loss stays at +inf when no epoch finished (or every val loss was NaN).
    if best_val_loss < float('inf'):
        print(f"Best validation loss: {best_val_loss:.4f}")
    else:
        print("Best validation loss: n/a (no finite validation loss recorded)")
    print(f"{'='*60}")

    # Save history — also on interrupt/error, so partial runs can still be plotted.
    torch.save(history, CONFIG['log_dir'] / f'training_history{flags_suffix}.pt')
    print(f"\n✓ Training history saved")

In [None]:
# Plot the per-epoch loss/accuracy curves and save the figure next to the run's logs.
curves_path = CONFIG['log_dir'] / f'training_curves{flags_suffix}.png'
plot_training_curves(history, save_path=curves_path, flags_suffix=flags_suffix)

## 7. Evaluation

In [None]:
# Restore the lowest-validation-loss checkpoint written during training, if any.
print("Loading best model...")
checkpoint_path = CONFIG['checkpoint_dir'] / f'best_model{flags_suffix}.pt'

if not checkpoint_path.exists():
    print("⚠️ No checkpoint found, using current model")
else:
    checkpoint = load_checkpoint(checkpoint_path, model, device=device)
    print(f"✓ Loaded checkpoint from epoch {checkpoint['epoch']+1}")
    print(f"  Validation loss: {checkpoint['val_loss']:.4f}")
    print(f"  Validation accuracy: {checkpoint['val_acc']:.4f}")

# Everything below is inference.
model.eval()

### 7.1 Compute Embeddings and Recall@K

In [None]:
# Compute embeddings for validation set. Uses the configured batch size rather
# than a separate hard-coded value, so memory use is tuned in one place.
image_embeddings, text_embeddings, captions, image_ids = compute_embeddings(
    model, val_dataset, batch_size=CONFIG['batch_size'], device=device
)

In [None]:
# Compute similarity matrix: rows are image embeddings, columns text embeddings.
# NOTE(review): this is dense over every validation pair returned by
# compute_embeddings; on a large split it can need several GB — confirm headroom.
print("Computing similarity matrix...")
similarity_matrix = image_embeddings @ text_embeddings.T
print(f"Similarity matrix shape: {similarity_matrix.shape}")

# Single source of truth for the cutoffs, so both directions and the average agree.
# NOTE(review): presumably compute_recall_at_k treats row i <-> column i as the only
# match; if the dataset yields several captions per image, image_ids may be needed
# to count every correct caption — TODO confirm against utils.compute_recall_at_k.
RECALL_KS = [1, 5, 10]

# Image → Text retrieval
print("\nImage → Text Retrieval:")
i2t_recalls = compute_recall_at_k(similarity_matrix, k_values=RECALL_KS)
for k, recall in i2t_recalls.items():
    print(f"  Recall@{k}: {recall:.4f} ({recall*100:.2f}%)")

# Text → Image retrieval
print("\nText → Image Retrieval:")
t2i_recalls = compute_recall_at_k(similarity_matrix.T, k_values=RECALL_KS)
for k, recall in t2i_recalls.items():
    print(f"  Recall@{k}: {recall:.4f} ({recall*100:.2f}%)")

# Average of the two directions at each cutoff
avg_recalls = {k: (i2t_recalls[k] + t2i_recalls[k]) / 2 for k in RECALL_KS}
print("\nAverage (I2T + T2I):")
for k, avg_recall in avg_recalls.items():
    print(f"  Recall@{k}: {avg_recall:.4f} ({avg_recall*100:.2f}%)")

### 7.2 Text Query → Image Retrieval

In [None]:
# Test text queries
import matplotlib.pyplot as plt

def visualize_text_query(query, top_images, top_scores, top_captions):
    """Show the images retrieved for a text query in a single row.

    Args:
        query: The text query; shown as the figure title.
        top_images: Images to display in rank order (anything ``imshow`` accepts).
        top_scores: Similarity score per image, printed under its rank.
        top_captions: Caption per image. Accepted for call-site compatibility;
            not drawn.

    Returns:
        None. Displays the figure, or prints a notice if nothing was retrieved.
    """
    num_images = len(top_images)
    if num_images == 0:
        # plt.subplots(1, 0) raises ValueError; report instead of aborting the query loop.
        print(f'No images retrieved for query "{query}"')
        return

    # 3 inches per panel (15 wide for the usual top-5), with a floor so the title fits.
    # squeeze=False always yields a 2-D axes array, so no special case for one image.
    fig, axes = plt.subplots(
        1, num_images, figsize=(max(3 * num_images, 6), 3), squeeze=False
    )
    fig.suptitle(f'Query: "{query}"', fontsize=14, fontweight='bold', y=1.05)

    for rank, (ax, img, score) in enumerate(zip(axes[0], top_images, top_scores), start=1):
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(f"#{rank}\nScore: {score:.3f}", fontsize=10)

    plt.tight_layout()
    plt.show()

# Free-text probes of the learned embedding space; each shows its top-5 val images.
queries = [
    'sport',
    'a cat',
    'food on a plate',
    'people playing',
    'a car on the street',
]

for query in queries:
    retrieval = retrieve_images_for_text(
        query, model, tokenizer, val_dataset, image_embeddings,
        top_k=5, device=device,
    )
    top_images, top_scores, top_captions, top_indices = retrieval
    visualize_text_query(query, top_images, top_scores, top_captions)
    print()

### 7.3 Zero-Shot Image Classification

In [None]:
# Classify random images
import matplotlib.pyplot as plt
from dataset import denormalize_image

def visualize_classification(image, class_labels, probs, predicted_class):
    """Show an image next to a horizontal bar chart of its class probabilities.

    Args:
        image: Image to display (anything ``imshow`` accepts).
        class_labels: Class names, one per bar.
        probs: Per-class probabilities, aligned with ``class_labels``. A torch
            tensor (on any device, grad-tracking or not) or any array-like.
        predicted_class: Index of the predicted class; a Python int or a
            0-d tensor.
    """
    # Tensor.numpy() raises for CUDA or grad-tracking tensors, so detach and
    # move to host once and reuse the array everywhere below.
    if isinstance(probs, torch.Tensor):
        probs_np = probs.detach().cpu().numpy()
    else:
        probs_np = np.asarray(probs)
    pred_idx = int(predicted_class)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Display image
    ax1.imshow(image)
    ax1.axis('off')
    ax1.set_title('Input Image', fontsize=12, fontweight='bold')

    # Display probabilities, highlighting the predicted class
    y_pos = np.arange(len(class_labels))
    colors = ['green' if i == pred_idx else 'skyblue' for i in range(len(class_labels))]

    ax2.barh(y_pos, probs_np, color=colors)
    ax2.set_yticks(y_pos)
    ax2.set_yticklabels(class_labels)
    ax2.set_xlabel('Probability', fontsize=11)
    ax2.set_title('Class Probabilities', fontsize=12, fontweight='bold')
    ax2.set_xlim(0, 1)

    for i, prob in enumerate(probs_np):
        ax2.text(prob + 0.02, i, f'{prob*100:.1f}%', va='center')

    plt.tight_layout()
    plt.show()

    print(f"Predicted: '{class_labels[pred_idx]}' ({probs_np[pred_idx]*100:.2f}%)")

# Zero-shot classify a few random validation images against fixed text prompts.
class_labels = ['a person', 'an animal', 'a landscape', 'food', 'a vehicle']

for example_number in range(1, 4):
    idx = random.randint(0, len(val_dataset) - 1)
    sample = val_dataset[idx]
    raw_image = sample['image_raw']

    print(f"\nExample {example_number}")
    print(f"True caption: '{sample['caption']}'")

    probs, pred, scores = zero_shot_classify(
        raw_image, class_labels, model, tokenizer,
        use_templates=True, device=device,
    )

    visualize_classification(raw_image, class_labels, probs, pred)
    print("-" * 80)

## 8. Summary Report

In [None]:
# Generate evaluation summary
print(f"{'='*60}")
print("EVALUATION SUMMARY")
print(f"{'='*60}\n")

print("Dataset:")
print(f"  Train set: {len(train_dataset):,} images")
print(f"  Val set: {len(val_dataset):,} images")

print("\nModel Configuration:")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  BatchNorm: {CONFIG['use_batch_norm']}")
print(f"  Attention Pooling: {CONFIG['use_attention_pooling']}")
if flags_suffix:
    print(f"  Flags suffix: {flags_suffix}")

print("\nRetrieval Performance:")
print("\n  Image → Text:")
for k, recall in i2t_recalls.items():
    print(f"    Recall@{k}: {recall:.4f} ({recall*100:.2f}%)")

print("\n  Text → Image:")
for k, recall in t2i_recalls.items():
    print(f"    Recall@{k}: {recall:.4f} ({recall*100:.2f}%)")

# Average over the cutoffs reported in both directions (no separate hard-coded k list).
avg_recalls = {k: (i2t_recalls[k] + t2i_recalls[k]) / 2 for k in i2t_recalls if k in t2i_recalls}
print("\n  Average:")
for k, avg_recall in avg_recalls.items():
    print(f"    Recall@{k}: {avg_recall:.4f} ({avg_recall*100:.2f}%)")

print(f"\n{'='*60}")

# Save results. Paths are stored as strings: recent PyTorch defaults torch.load to
# weights_only=True, which refuses to unpickle pathlib objects. Any other
# non-primitive CONFIG values (if DEFAULT_CONFIG has some) are stored unchanged.
serializable_config = {
    key: str(value) if isinstance(value, Path) else value
    for key, value in CONFIG.items()
}
eval_results = {
    'i2t_recalls': i2t_recalls,
    't2i_recalls': t2i_recalls,
    'avg_recalls': avg_recalls,
    'config': serializable_config,
    'flags_suffix': flags_suffix
}
eval_results_path = CONFIG['log_dir'] / f'eval_results{flags_suffix}.pt'
torch.save(eval_results, eval_results_path)
print(f"\n✓ Evaluation results saved to {eval_results_path}")