# CLIP Model Training on COCO 2014

Train a CLIP-style vision-language model with:
- **Image Encoder**: ResNet50 (ImageNet pretrained, trainable)
- **Projection Head**: 2-layer MLP with GELU (trainable)
- **Text Encoder**: CLIP text encoder (frozen; training reads its caption embeddings from the cache built by the prep notebook)
- **Loss**: InfoNCE contrastive loss

## Prerequisites
Run `coco_dataset_prep.ipynb` first to prepare the dataset and cache text embeddings.

## 1. Install Dependencies and Imports

In [None]:
# Install required packages if needed
!pip install -q transformers>=4.30.0 torch>=2.0.0 torchvision>=0.15.0
!pip install -q pillow matplotlib tqdm

print("✓ All dependencies installed!")

In [None]:
# Imports
import os
import json
import random
import time
import warnings
from pathlib import Path
from datetime import datetime
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image

from transformers import CLIPTokenizer, CLIPTextModel
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Reproducibility: one fixed seed shared by the Python, NumPy and torch RNGs.
SEED = 42

# Compute device: first CUDA GPU if present, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device.type == 'cuda':
    torch.cuda.manual_seed(SEED)
    # Ask cuDNN for deterministic kernels (trades some speed for repeatability).
    torch.backends.cudnn.deterministic = True

print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Configuration

In [None]:
# Training configuration: every tunable setting for this notebook lives here.
CONFIG = {
    # --- Filesystem locations (Colab-style /content layout) ---
    'dataset_dir': Path('/content/coco2014'),
    'checkpoint_dir': Path('/content/checkpoints'),
    'log_dir': Path('/content/logs'),

    # --- Model ---
    'clip_model_name': 'openai/clip-vit-base-patch32',  # HF id used for tokenizer + text encoder
    'image_size': 224,              # square input resolution fed to the ResNet
    'embedding_dim': 512,           # output size of the projection head
    'projection_hidden_dim': 1024,  # hidden width of the projection MLP

    # --- Training hyperparameters ---
    'batch_size': 128,      # each query is scored against all batch_size rows; raise if memory allows
    'num_epochs': 10,
    'learning_rate': 1e-4,  # peak LR, reached at the end of warmup
    'weight_decay': 0.01,
    'warmup_steps': 500,
    'temperature': 0.07,    # InfoNCE softmax temperature (tau)

    # --- Optimization ---
    'optimizer': 'AdamW',   # informational: the optimizer cell always builds AdamW
    'scheduler': 'cosine',  # LR schedule name (linear warmup, then decay)
    'max_grad_norm': 1.0,   # threshold for global gradient-norm clipping

    # --- Logging and checkpointing ---
    'log_interval': 50,     # append a train-metric sample to history every N steps
    'val_interval': 500,    # NOTE(review): not read by the training loop (validation runs once per epoch)
    'save_interval': 1000,  # NOTE(review): not read by the training loop (checkpoints are saved once per epoch)
    'num_workers': 2,       # DataLoader worker processes
}

# Make sure the output directories exist before anything writes into them.
for output_dir in (CONFIG['checkpoint_dir'], CONFIG['log_dir']):
    output_dir.mkdir(parents=True, exist_ok=True)

# Per-channel normalization statistics used by CLIP's image preprocessing.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

print("Configuration:")
for setting_name, setting_value in CONFIG.items():
    print(f"  {setting_name}: {setting_value}")

## 3. Dataset Class

In [None]:
class COCOClipDataset(Dataset):
    """COCO 2014 image/caption pairs backed by cached CLIP text embeddings.

    Each item is one image (resized to a ``CONFIG['image_size']`` square and
    normalized with ``CLIP_MEAN``/``CLIP_STD``) paired with one of its captions
    and that caption's precomputed text embedding.

    The cache file ``<dataset_dir>/<split>_text_embeddings.pt`` is read as a dict
    with ``'data'`` (a list of per-image records holding ``'image_id'``,
    ``'captions'`` and ``'embeddings'``) and ``'embedding_dim'``. It is
    presumably written by ``coco_dataset_prep.ipynb``, with ``embeddings``
    indexable by caption position (TODO confirm against the prep notebook).
    """

    def __init__(self, split='train', dataset_dir=None, random_caption=None):
        """
        Args:
            split: COCO split name such as 'train' or 'val'. It selects the
                ``{split}2014`` image folder and the embedding cache file.
            dataset_dir: Dataset root (``Path`` or ``str``). Defaults to
                ``CONFIG['dataset_dir']``, looked up when the dataset is built.
            random_caption: If True, each access samples a random caption; if
                False, the first caption is always used. Defaults to True only
                for the 'train' split, so validation metrics are reproducible
                across epochs.
        """
        self.split = split
        if dataset_dir is None:
            # Resolved at call time (not definition time) so CONFIG edits apply.
            dataset_dir = CONFIG['dataset_dir']
        dataset_dir = Path(dataset_dir)
        self.image_dir = dataset_dir / f'{split}2014'
        self.cache_file = dataset_dir / f'{split}_text_embeddings.pt'
        self.random_caption = (split == 'train') if random_caption is None else random_caption

        # Image transforms
        self.transform = transforms.Compose([
            transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
            transforms.ToTensor(),
            transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD)
        ])

        # Load cached embeddings. map_location='cpu' keeps cached tensors off the
        # GPU so DataLoader workers never touch CUDA and pinning works.
        print(f"Loading {split} embeddings from {self.cache_file.name}...")
        cache = torch.load(self.cache_file, map_location='cpu')
        self.cache_data = cache['data']
        self.embedding_dim = cache['embedding_dim']

        print(f"  ✓ Loaded {len(self.cache_data):,} images")

    def __len__(self):
        return len(self.cache_data)

    def __getitem__(self, idx):
        """Return a dict with 'image', 'text_embedding', 'caption' and 'image_id'."""
        item = self.cache_data[idx]
        image_id = item['image_id']
        embeddings = item['embeddings']
        captions = item['captions']

        # COCO 2014 file naming: zero-padded 12-digit image id.
        image_path = self.image_dir / f'COCO_{self.split}2014_{image_id:012d}.jpg'

        try:
            # Context manager releases the file handle deterministically.
            with Image.open(image_path) as pil_image:
                image = self.transform(pil_image.convert('RGB'))
        except (OSError, ValueError, Image.DecompressionBombError) as err:
            # Best effort: a missing or corrupt file becomes a blank image so one
            # bad sample doesn't abort a long run. Warn so the problem is visible.
            warnings.warn(f"Could not load image {image_path}: {err}; substituting a blank image")
            image = torch.zeros(3, CONFIG['image_size'], CONFIG['image_size'])

        # Random caption as augmentation for training; fixed caption otherwise.
        caption_idx = random.randint(0, len(captions) - 1) if self.random_caption else 0

        return {
            'image': image,
            'text_embedding': embeddings[caption_idx],
            'caption': captions[caption_idx],
            'image_id': image_id
        }

print("✓ Dataset class defined")

## 4. Model Architecture

In [None]:
class ResNet50ImageEncoder(nn.Module):
    """ResNet50 backbone that returns globally pooled features, one row per image.

    Args:
        pretrained: If True, load torchvision's ImageNet-1k (V1) weights;
            otherwise initialize randomly.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # `pretrained=` is deprecated since torchvision 0.13; the weights enum is
        # the supported API. IMAGENET1K_V1 is the checkpoint `pretrained=True` loaded.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        resnet = models.resnet50(weights=weights)
        # Drop the final FC classifier and keep everything through global avg-pool.
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        self.output_dim = resnet.fc.in_features  # 2048 for ResNet50

    def forward(self, x):
        """Return ``[batch, output_dim]`` features for an image batch ``x``."""
        # Avg-pool output is [batch, 2048, 1, 1]; flatten to [batch, 2048].
        return torch.flatten(self.features(x), 1)


class ProjectionHead(nn.Module):
    """Two-layer MLP mapping backbone features into the joint embedding space.

    Architecture: Linear(input_dim -> hidden_dim) -> GELU -> Linear(hidden_dim -> output_dim).
    The output is not normalized; callers normalize it if needed.
    """

    def __init__(self, input_dim=2048, hidden_dim=1024, output_dim=512):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.projection = nn.Sequential(*layers)

    def forward(self, x):
        mlp = self.projection
        return mlp(x)


class CLIPModel(nn.Module):
    """CLIP-style dual encoder: trainable ResNet50 + projection head, plus a text encoder.

    Args:
        text_encoder: Hugging Face text model whose ``pooler_output`` is used as
            the text embedding (this notebook passes a ``CLIPTextModel``).
        freeze_text_encoder: If True (default), the text encoder's parameters get
            ``requires_grad=False``, the encoder stays in eval mode even when
            ``train()`` is called on this model, and ``encode_text`` runs without
            autograd. If False, the text encoder is trained together with the
            image tower.
    """

    def __init__(self, text_encoder, freeze_text_encoder=True):
        super().__init__()
        self.freeze_text_encoder = freeze_text_encoder

        # Text encoder (frozen by default)
        self.text_encoder = text_encoder
        if freeze_text_encoder:
            for param in self.text_encoder.parameters():
                param.requires_grad = False
            self.text_encoder.eval()

        # Image encoder (trainable)
        self.image_encoder = ResNet50ImageEncoder(pretrained=True)

        # Projection head (trainable): backbone features -> joint embedding space.
        self.projection_head = ProjectionHead(
            input_dim=self.image_encoder.output_dim,
            hidden_dim=CONFIG['projection_hidden_dim'],
            output_dim=CONFIG['embedding_dim']
        )

    def train(self, mode=True):
        """Set train/eval mode; a frozen text encoder is always left in eval mode."""
        super().train(mode)
        if self.freeze_text_encoder:
            self.text_encoder.eval()
        return self

    def encode_image(self, images):
        """Encode an image batch to L2-normalized embeddings."""
        features = self.image_encoder(images)
        embeddings = self.projection_head(features)
        return F.normalize(embeddings, p=2, dim=1)

    def encode_text(self, input_ids, attention_mask):
        """Encode token ids to L2-normalized text embeddings (``pooler_output``)."""
        # Frozen: never build a graph. Trainable: follow the caller's grad mode,
        # so there is still no graph under e.g. @torch.no_grad() validation.
        grad_enabled = torch.is_grad_enabled() and not self.freeze_text_encoder
        with torch.set_grad_enabled(grad_enabled):
            outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
            return F.normalize(outputs.pooler_output, p=2, dim=1)

    def forward(self, images, input_ids, attention_mask):
        """Return ``(image_embeddings, text_embeddings)``, both L2-normalized."""
        image_embeddings = self.encode_image(images)
        text_embeddings = self.encode_text(input_ids, attention_mask)
        return image_embeddings, text_embeddings

print("✓ Model architecture defined")

## 5. InfoNCE Loss

The InfoNCE (Contrastive) loss used in CLIP:

$$\mathcal{L} = -\frac{1}{2N} \sum_{i=1}^{N} \left[ \log \frac{\exp(\text{sim}(I_i, T_i) / \tau)}{\sum_{j=1}^{N} \exp(\text{sim}(I_i, T_j) / \tau)} + \log \frac{\exp(\text{sim}(T_i, I_i) / \tau)}{\sum_{j=1}^{N} \exp(\text{sim}(T_i, I_j) / \tau)} \right]$$

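where $\tau$ is the temperature parameter and $\text{sim}(\cdot,\cdot)$ is cosine similarity (the dot product of L2-normalized embeddings). The first term is the image-to-text retrieval loss and the second the text-to-image loss; the objective averages the two.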

In [None]:
class InfoNCELoss(nn.Module):
    """Symmetric InfoNCE (CLIP) contrastive loss over a batch of paired embeddings.

    Row ``i`` of the image batch and row ``i`` of the text batch form the
    positive pair; every other row in the batch acts as a negative. The loss
    is the mean of the image->text and text->image cross-entropies.
    """

    def __init__(self, temperature=0.07, normalize=True):
        """
        Args:
            temperature: Softmax temperature ``tau``; logits are ``sim / tau``.
            normalize: L2-normalize both inputs before computing similarities
                (default True). The training loop feeds text embeddings taken
                straight from the on-disk cache without normalizing them, so this
                guarantees the logits are true cosine similarities. For inputs
                that are already unit-norm it is a no-op up to rounding.
        """
        super().__init__()
        self.temperature = temperature
        self.normalize = normalize

    def forward(self, image_embeddings, text_embeddings):
        """
        Args:
            image_embeddings: [batch_size, embedding_dim]
            text_embeddings: [batch_size, embedding_dim], row-aligned with the images

        Returns:
            (loss, accuracy): the scalar loss tensor (carrying grad) and the mean
            of the image->text and text->image top-1 in-batch retrieval accuracies.
        """
        if self.normalize:
            image_embeddings = F.normalize(image_embeddings, p=2, dim=1)
            text_embeddings = F.normalize(text_embeddings, p=2, dim=1)

        batch_size = image_embeddings.shape[0]

        # [batch_size, batch_size] similarity logits: rows = images, cols = texts.
        logits_per_image = image_embeddings @ text_embeddings.T / self.temperature
        logits_per_text = logits_per_image.T

        # The matching pair for row i sits on the diagonal.
        labels = torch.arange(batch_size, device=logits_per_image.device)

        loss_i2t = F.cross_entropy(logits_per_image, labels)
        loss_t2i = F.cross_entropy(logits_per_text, labels)
        loss = (loss_i2t + loss_t2i) / 2

        # Accuracy is for logging only; keep it out of the autograd graph.
        with torch.no_grad():
            acc_i2t = (logits_per_image.argmax(dim=1) == labels).float().mean()
            acc_t2i = (logits_per_text.argmax(dim=1) == labels).float().mean()
            accuracy = (acc_i2t + acc_t2i) / 2

        return loss, accuracy

print("✓ InfoNCE loss defined")

## 6. Initialize Model and Optimizer

In [None]:
# Text side: tokenizer plus pretrained CLIP text encoder, placed on the device.
print("Loading CLIP text encoder...")
tokenizer = CLIPTokenizer.from_pretrained(CONFIG['clip_model_name'])
text_encoder = CLIPTextModel.from_pretrained(CONFIG['clip_model_name']).to(device)
print("✓ Text encoder loaded")

# Full model: frozen text encoder + trainable ResNet50 image tower and projection head.
print("\nCreating CLIP model...")
model = CLIPModel(text_encoder=text_encoder, freeze_text_encoder=True).to(device)

# Parameter bookkeeping; frozen text-encoder weights count toward the total only.
trainable_parameters = [param for param in model.parameters() if param.requires_grad]
total_params = sum(param.numel() for param in model.parameters())
trainable_params = sum(param.numel() for param in trainable_parameters)

print(f"\nModel Statistics:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.1f}%)")

# Contrastive objective
criterion = InfoNCELoss(temperature=CONFIG['temperature'])

# AdamW over the trainable parameters only (frozen ones never receive grads).
optimizer = torch.optim.AdamW(
    trainable_parameters,
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)

print("\n".join([
    f"\n✓ Optimizer: {CONFIG['optimizer']}",
    f"  Learning rate: {CONFIG['learning_rate']}",
    f"  Weight decay: {CONFIG['weight_decay']}",
    f"  Temperature: {CONFIG['temperature']}",
]))

## 7. Datasets and Learning Rate Scheduler

In [None]:
# Build both splits (train first); the scheduler below needs the train size.
train_dataset, val_dataset = COCOClipDataset(split='train'), COCOClipDataset(split='val')

# Optimizer steps for the whole run: full batches per epoch (the train loader
# drops the last partial batch) times the number of epochs.
total_steps = CONFIG['num_epochs'] * (len(train_dataset) // CONFIG['batch_size'])

# Cosine annealing with warmup
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Create a LambdaLR with linear warmup followed by half-cosine decay to zero.

    The LR multiplier rises linearly from 0 to 1 over ``num_warmup_steps``, then
    follows ``0.5 * (1 + cos(pi * progress))`` down to 0 at ``num_training_steps``.
    Steps past ``num_training_steps`` stay at 0 instead of climbing back up the
    cosine curve.

    Args:
        optimizer: Optimizer whose base LRs are scaled by the multiplier.
        num_warmup_steps: Number of linear-warmup steps.
        num_training_steps: Step at which the multiplier reaches 0.

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        decay_steps = float(max(1, num_training_steps - num_warmup_steps))
        # Clamp so overshooting num_training_steps doesn't restart the cosine.
        progress = min(1.0, float(current_step - num_warmup_steps) / decay_steps)
        return max(0.0, float(0.5 * (1.0 + np.cos(np.pi * progress))))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Create a LambdaLR with linear warmup from 0 to 1, then linear decay to 0.

    Args:
        optimizer: Optimizer whose base LRs are scaled by the multiplier.
        num_warmup_steps: Number of linear-warmup steps.
        num_training_steps: Step at which the multiplier reaches 0 (and stays there).

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        decay_steps = float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / decay_steps)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


# Build the schedule named by CONFIG['scheduler']; both options share the warmup.
# Previously the config value was ignored and cosine was always used.
SCHEDULE_BUILDERS = {
    'cosine': (get_cosine_schedule_with_warmup, "Cosine with warmup"),
    'linear': (get_linear_schedule_with_warmup, "Linear with warmup"),
}
if CONFIG['scheduler'] not in SCHEDULE_BUILDERS:
    raise ValueError(
        f"Unknown scheduler {CONFIG['scheduler']!r}; expected one of {sorted(SCHEDULE_BUILDERS)}"
    )
build_schedule, scheduler_label = SCHEDULE_BUILDERS[CONFIG['scheduler']]

scheduler = build_schedule(
    optimizer,
    num_warmup_steps=CONFIG['warmup_steps'],
    num_training_steps=total_steps
)

print(f"Learning rate scheduler:")
print(f"  Type: {scheduler_label}")
print(f"  Warmup steps: {CONFIG['warmup_steps']}")
print(f"  Total steps: {total_steps:,}")

## 8. Training Loop

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, scheduler, epoch, global_step, history):
    """Train for one epoch over ``train_loader``.

    Uses the module-level ``device`` and ``CONFIG``. Only the image tower is run;
    text embeddings come precomputed in each batch.

    Args:
        model: Model exposing ``encode_image`` and a ``text_encoder`` submodule.
        train_loader: Iterable of dicts with 'image' and 'text_embedding' tensors.
        criterion: Callable ``(image_emb, text_emb) -> (loss, accuracy)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        epoch: 0-based epoch index (used for the progress-bar label).
        global_step: Step counter carried across epochs.
        history: Dict of lists; 'train_loss', 'train_acc', 'learning_rate' and
            'step' get one entry every ``CONFIG['log_interval']`` steps.

    Returns:
        (avg_loss, avg_acc, global_step): mean per-batch loss and accuracy over
        the epoch (NaN if the loader produced no batches) and the updated step
        counter.
    """
    model.train()
    # Keep the frozen text encoder in eval mode.
    model.text_encoder.eval()

    epoch_loss = 0.0
    epoch_acc = 0.0
    num_batches = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['num_epochs']}")

    for batch in pbar:
        # non_blocking overlaps the copy with compute when memory is pinned.
        images = batch['image'].to(device, non_blocking=True)
        text_embeddings = batch['text_embedding'].to(device, non_blocking=True)

        # Forward pass + loss
        image_embeddings = model.encode_image(images)
        loss, accuracy = criterion(image_embeddings, text_embeddings)

        # Backward pass with gradient clipping
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['max_grad_norm'])
        optimizer.step()
        scheduler.step()

        # Read each metric off the device once per step (each .item() syncs).
        loss_value = loss.item()
        acc_value = accuracy.item()
        current_lr = scheduler.get_last_lr()[0]

        epoch_loss += loss_value
        epoch_acc += acc_value
        num_batches += 1
        global_step += 1

        pbar.set_postfix({
            'loss': f"{loss_value:.4f}",
            'acc': f"{acc_value:.3f}",
            'lr': f"{current_lr:.2e}"
        })

        if global_step % CONFIG['log_interval'] == 0:
            history['train_loss'].append(loss_value)
            history['train_acc'].append(acc_value)
            history['learning_rate'].append(current_lr)
            history['step'].append(global_step)

    # An empty loader (e.g. fewer samples than one batch with drop_last=True)
    # would otherwise divide by zero.
    if num_batches == 0:
        return float('nan'), float('nan'), global_step

    return epoch_loss / num_batches, epoch_acc / num_batches, global_step


@torch.no_grad()
def validate(model, val_loader, criterion):
    """Evaluate ``criterion`` on ``val_loader`` without gradients.

    Uses the module-level ``device``. Per-batch loss and accuracy are averaged
    weighted by batch size. The val loader in this notebook keeps its smaller
    final batch (drop_last is not set), and that batch should not count as much
    as a full one.

    Args:
        model: Model exposing ``encode_image``; switched to eval mode.
        val_loader: Iterable of dicts with 'image' and 'text_embedding' tensors.
        criterion: Callable ``(image_emb, text_emb) -> (loss, accuracy)``.

    Returns:
        (avg_loss, avg_acc); both NaN if the loader yields no samples.
    """
    model.eval()

    total_loss = 0.0
    total_acc = 0.0
    total_samples = 0

    for batch in tqdm(val_loader, desc="Validating"):
        images = batch['image'].to(device, non_blocking=True)
        text_embeddings = batch['text_embedding'].to(device, non_blocking=True)

        image_embeddings = model.encode_image(images)
        loss, accuracy = criterion(image_embeddings, text_embeddings)

        batch_size = images.shape[0]
        total_loss += loss.item() * batch_size
        total_acc += accuracy.item() * batch_size
        total_samples += batch_size

    if total_samples == 0:
        return float('nan'), float('nan')

    return total_loss / total_samples, total_acc / total_samples

print("✓ Training functions defined")

## 9. Run Training

In [None]:
# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
    drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=True
)

print(f"Data loaders created:")
print(f"  Train batches: {len(train_loader):,}")
print(f"  Val batches: {len(val_loader):,}")
print(f"  Batch size: {CONFIG['batch_size']}")

# Training history
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'learning_rate': [],
    'step': [],
    'epoch': []
}

# Hardware info
hardware_info = {
    'device': str(device),
    'gpu_name': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
    'gpu_memory_gb': torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0,
}

print(f"\n{'='*60}")
print("Starting Training")
print(f"{'='*60}")
print(f"Hardware: {hardware_info['gpu_name']}")
print(f"Epochs: {CONFIG['num_epochs']}")
print(f"Batch size: {CONFIG['batch_size']}")
print(f"Learning rate: {CONFIG['learning_rate']}")
print(f"{'='*60}\n")

# Start training
start_time = time.time()
global_step = 0
best_val_loss = float('inf')

try:
    for epoch in range(CONFIG['num_epochs']):
        # Train
        train_loss, train_acc, global_step = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, epoch, global_step, history
        )
        
        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion)
        
        # Log epoch results
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['epoch'].append(epoch)
        
        print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}:")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f}")
        print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.3f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': val_loss,
                'val_acc': val_acc,
                'config': CONFIG,
            }, CONFIG['checkpoint_dir'] / 'best_model.pt')
            print(f"  ✓ Saved best model (val_loss: {val_loss:.4f})")
        
        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'history': history,
        }, CONFIG['checkpoint_dir'] / f'checkpoint_epoch_{epoch+1}.pt')

except KeyboardInterrupt:
    print("\n\nTraining interrupted by user")

finally:
    # Training complete
    total_time = time.time() - start_time
    
    print(f"\n{'='*60}")
    print("Training Complete!")
    print(f"{'='*60}")
    print(f"Total time: {total_time/3600:.2f} hours ({total_time/60:.1f} minutes)")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Final learning rate: {scheduler.get_last_lr()[0]:.2e}")
    print(f"{'='*60}")
    
    # Save training history and hardware info
    training_summary = {
        'history': history,
        'hardware_info': hardware_info,
        'training_time_hours': total_time / 3600,
        'best_val_loss': best_val_loss,
        'config': CONFIG,
    }
    
    torch.save(training_summary, CONFIG['log_dir'] / 'training_summary.pt')
    print(f"\n✓ Training summary saved to {CONFIG['log_dir'] / 'training_summary.pt'}")

## 10. Plot Training Curves

In [None]:
# Plot training curves: sampled training metrics vs. step (top row) and
# per-epoch validation metrics (bottom row). The figure is saved to log_dir.

def _plot_panel(ax, x, y, *, label, xlabel, ylabel, title, fmt='-', **plot_kwargs):
    """Draw one labelled line plot with a legend and a light grid on ``ax``."""
    ax.plot(x, y, fmt, label=label, **plot_kwargs)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


# history['train_*'] holds one single-batch sample every `log_interval` steps,
# not one per step; say so in the legend.
train_sampling = f"every {CONFIG['log_interval']} steps"
# history['epoch'] is 0-based; shift it to match the 1-based "Epoch N/M" logs.
val_epochs = [epoch_idx + 1 for epoch_idx in history['epoch']]

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

_plot_panel(axes[0, 0], history['step'], history['train_loss'], alpha=0.6,
            label=f'Train Loss ({train_sampling})',
            xlabel='Steps', ylabel='Loss', title='Training Loss')
_plot_panel(axes[0, 1], history['step'], history['train_acc'], alpha=0.6,
            label=f'Train Accuracy ({train_sampling})',
            xlabel='Steps', ylabel='Accuracy', title='Training Accuracy')
_plot_panel(axes[1, 0], val_epochs, history['val_loss'], fmt='o-', linewidth=2,
            label='Val Loss (per epoch)',
            xlabel='Epoch', ylabel='Loss', title='Validation Loss')
_plot_panel(axes[1, 1], val_epochs, history['val_acc'], fmt='o-', linewidth=2,
            label='Val Accuracy (per epoch)',
            xlabel='Epoch', ylabel='Accuracy', title='Validation Accuracy')

plt.tight_layout()
plt.savefig(CONFIG['log_dir'] / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Training curves saved to {CONFIG['log_dir'] / 'training_curves.png'}")

## 11. Training Report

In [None]:
# Generate training report. The guards below keep this cell working when
# training was interrupted before any train metric was logged or any
# validation pass finished (the history lists are then empty).
train_losses = history['train_loss']
val_losses = history['val_loss']
has_val = bool(val_losses)

print(f"\n{'='*60}")
print("TRAINING REPORT")
print(f"{'='*60}\n")

print("Hardware Information:")
print(f"  Device: {hardware_info['device']}")
print(f"  GPU: {hardware_info['gpu_name']}")
print(f"  GPU Memory: {hardware_info['gpu_memory_gb']:.2f} GB")

print("\nTraining Configuration:")
print(f"  Epochs: {CONFIG['num_epochs']}")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Learning rate: {CONFIG['learning_rate']}")
print(f"  Weight decay: {CONFIG['weight_decay']}")
print(f"  Temperature: {CONFIG['temperature']}")
print(f"  Optimizer: {CONFIG['optimizer']}")
print(f"  Scheduler: {CONFIG['scheduler']} with {CONFIG['warmup_steps']} warmup steps")
print(f"  Gradient clipping: {CONFIG['max_grad_norm']}")

print("\nTraining Results:")
print(f"  Total training time: {total_time/3600:.2f} hours ({total_time/60:.1f} minutes)")
print(f"  Best validation loss: {best_val_loss:.4f}")
if has_val:
    print(f"  Final validation loss: {val_losses[-1]:.4f}")
    print(f"  Final validation accuracy: {history['val_acc'][-1]:.3f}")
else:
    print("  Final validation loss: n/a (no validation pass completed)")
    print("  Final validation accuracy: n/a")
print(f"  Total steps: {global_step:,}")

print("\nObserved Issues:")
# Check for common issues
if train_losses and max(train_losses) > 10:
    print("  ⚠️ High initial loss detected - may need learning rate adjustment")
if not has_val:
    print("  ⚠️ No validation results recorded - training stopped before the first epoch finished")
else:
    if val_losses[-1] > val_losses[0]:
        print("  ⚠️ Validation loss increased - possible overfitting")
    if best_val_loss == val_losses[-1]:
        print("  ✓ Best model is from the final epoch - training is converging well")
    elif best_val_loss in val_losses:
        best_epoch = history['epoch'][val_losses.index(best_val_loss)]
        print(f"  ℹ️ Best model from epoch {best_epoch+1}, consider early stopping")

print(f"\n{'='*60}")

# Save report to file
final_val_acc = f"{history['val_acc'][-1]:.3f}" if has_val else "n/a"
with open(CONFIG['log_dir'] / 'training_report.txt', 'w', encoding='utf-8') as f:
    f.write("CLIP Training Report\n")
    f.write(f"{'='*60}\n\n")
    f.write(f"Hardware: {hardware_info['gpu_name']}\n")
    f.write(f"Training time: {total_time/3600:.2f} hours\n")
    f.write(f"Best val loss: {best_val_loss:.4f}\n")
    f.write(f"Final val accuracy: {final_val_acc}\n")

print(f"\n✓ Training report saved to {CONFIG['log_dir'] / 'training_report.txt'}")