## 10. Tips and Advanced Usage

### Model Selection
- **CLIP (ViT-B/32)**: Good balance of speed and quality. Excellent for general-purpose tasks.
- **AlignCLIP**: Improved alignment variant. Better at fine-grained distinctions.
- **CLOOB**: Combines modern Hopfield networks with the InfoLOOB objective. Slower but potentially better quality.

### Hyperparameter Tuning
- **Learning Rate**: Typically 1e-6 to 1e-4. Use lower rates for larger models.
- **Batch Size**: 32-256 depending on GPU memory.
- **Temperature**: Controls contrastive sharpness. 0.07 is default. Lower = sharper, Higher = softer.
- **Weight Decay**: Regularization strength (applied as decoupled weight decay by the AdamW optimizer used in this notebook). Usually 0.01-0.1.

### Tips for Better Results
1. Use more diverse training data
2. Train for more epochs when the dataset is large enough to avoid overfitting
3. Use gradient accumulation to reach a larger effective batch size
4. Consider data augmentation strategies
5. Fine-tune on task-specific data for better performance

### Next Steps
- Export finetuned models for production
- Evaluate on downstream tasks (retrieval, classification)
- Ensemble the predictions of all three models

In [None]:
def encode_texts(model: nn.Module, texts: List[str], device: torch.device) -> torch.Tensor:
    """Encode a list of texts into L2-normalized embeddings.

    Args:
        model: Wrapper object whose ``model`` attribute provides
            ``encode_text(tokens)``.
            NOTE(review): presumably one of the finetunable wrappers from this
            notebook -- confirm the wrapped project model exposes ``encode_text``.
        texts: Texts to embed; must contain at least one entry.
        device: Device the token tensor is moved to before encoding.

    Returns:
        One unit-length embedding per input text, in the order of ``texts``.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    texts = list(texts)
    if not texts:
        raise ValueError("texts must contain at least one entry")

    # Tokenize the whole batch in a single call. truncate=True clips texts that
    # exceed the tokenizer's context window instead of raising.
    tokens = clip.tokenize(texts, truncate=True).to(device)

    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        text_features = model.model.encode_text(tokens)
        text_features = torch.nn.functional.normalize(text_features, dim=-1)

    return text_features


def compute_similarity(
    model: nn.Module,
    text: str,
    device: torch.device,
    sample_texts: Optional[List[str]] = None,
) -> np.ndarray:
    """Score one query text against a set of reference texts.

    Args:
        model: Model wrapper forwarded to ``encode_texts``.
        text: Query text.
        device: Device used for encoding.
        sample_texts: Reference texts to compare against. When omitted, the
            ``'text'`` field of the first five entries of the module-level
            ``val_data`` is used.

    Returns:
        1-D array with one similarity score per reference text (empty when
        there are no reference texts).
    """
    if sample_texts is None:
        # Default reference set: the first five validation captions.
        sample_texts = [s['text'] for s in val_data[:5]]
    if not sample_texts:
        # Nothing to compare against; avoid encoding an empty batch.
        return np.empty(0, dtype=np.float32)

    text_features = encode_texts(model, [text], device)
    sample_features = encode_texts(model, sample_texts, device)

    # encode_texts returns unit-length rows, so the dot product is the cosine
    # similarity between the query and each reference text.
    similarities = text_features @ sample_features.T
    return similarities.cpu().numpy()[0]


# Smoke-test inference: score one query against the first validation captions
test_text = "A beautiful sunset over mountains"
logger.info(f"\nTesting inference with text: '{test_text}'")

# Switch to eval mode before running inference
clip_model.eval()
similarities = compute_similarity(clip_model, test_text, device)

logger.info("Text Similarities (CLIP):")
# Same five captions that compute_similarity compares against by default
reference_captions = [s['text'] for s in val_data[:5]]
for i, (text, sim) in enumerate(zip(reference_captions, similarities)):
    logger.info(f"  {i+1}. {sim:.4f} - {text[:50]}...")

print("\nInference test completed successfully!")

## 9. Inference and Feature Extraction

In [None]:
# Compare all models: one loss panel per trainer, side by side
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# CLOOB is optional: its trainer only exists when initialization succeeded
cloob_available = 'cloob_trainer' in locals()
panels = [
    ('CLIP', clip_trainer),
    ('AlignCLIP', align_trainer),
    ('CLOOB', cloob_trainer if cloob_available else None),
]

for ax, (title, trainer) in zip(axes, panels):
    if trainer is None:
        # Placeholder panel when there is nothing to plot
        ax.text(0.5, 0.5, 'CLOOB Not Available', ha='center', va='center')
        ax.set_title(title)
        continue

    ax.plot(trainer.train_losses, label='Train', marker='o')
    ax.plot(trainer.val_losses, label='Val', marker='s')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print summary of the best validation loss reached by each trainer
separator = "="*50
print("\n" + separator)
print("TRAINING SUMMARY")
print(separator)
print(f"CLIP      - Best Val Loss: {clip_trainer.best_val_loss:.6f}")
print(f"AlignCLIP - Best Val Loss: {align_trainer.best_val_loss:.6f}")
if cloob_available:
    print(f"CLOOB     - Best Val Loss: {cloob_trainer.best_val_loss:.6f}")
print(separator)

## 8. Model Evaluation and Comparison

In [None]:
# Train CLOOB if initialization succeeded.
# Check for the trainer explicitly instead of wrapping the training call in
# `except NameError`: a NameError raised *inside* training would otherwise be
# swallowed and misreported as "trainer not available".
if 'cloob_trainer' in globals():
    cloob_trainer.train(num_epochs=config['num_epochs'])
    cloob_trainer.plot_losses()
else:
    logger.info("CLOOB trainer not available")

In [None]:
# Initialize CLOOB model
logger.info("Initializing CLOOB model...")
try:
    # Use the finetunable wrapper defined in this notebook, like the CLIP and
    # AlignCLIP cells do. The raw `CLOOBModel` class is only imported locally
    # inside that wrapper, so referencing it here raised NameError and CLOOB
    # was always skipped.
    cloob_model = CLOOBFinetunableModel(device=str(device))

    # Create datasets with CLOOB preprocessing
    train_dataset_cloob = CLIPDataset(train_data, preprocess=cloob_model.preprocess, model_name="cloob")
    val_dataset_cloob = CLIPDataset(val_data, preprocess=cloob_model.preprocess, model_name="cloob")

    # Create data loaders (num_workers=0 matches the other loaders in this notebook)
    train_loader_cloob = DataLoader(
        train_dataset_cloob,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=0,
        pin_memory=torch.cuda.is_available()
    )
    val_loader_cloob = DataLoader(
        val_dataset_cloob,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=0,
        pin_memory=torch.cuda.is_available()
    )

    # Setup loss and optimizer for CLOOB
    cloob_criterion = CLIPContrastiveLoss(temperature=config['temperature'])

    cloob_optimizer = optim.AdamW(
        cloob_model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    # Create trainer (assigned last, so it only exists when every step above succeeded)
    cloob_trainer = CLIPTrainer(
        model=cloob_model,
        train_loader=train_loader_cloob,
        val_loader=val_loader_cloob,
        criterion=cloob_criterion,
        optimizer=cloob_optimizer,
        device=device,
        config=config,
        model_name="cloob"
    )

    logger.info("CLOOB trainer initialized")

except Exception as e:
    # CLOOB is optional: report the failure and let the rest of the notebook run
    logger.error(f"Failed to initialize CLOOB: {e}")
    logger.info("Skipping CLOOB training")

## 7. Finetune CLOOB Model

In [None]:
# Train AlignCLIP for config['num_epochs'] epochs, then plot its
# train/validation loss curves. Requires `align_trainer` from the
# AlignCLIP initialization cell.
align_trainer.train(num_epochs=config['num_epochs'])
align_trainer.plot_losses()

In [None]:
# Build the AlignCLIP model wrapper
logger.info("Initializing AlignCLIP model...")
alignclip_model = AlignCLIPFinetunableModel(device=str(device))

# Train/val datasets share the model's own image preprocessing
train_dataset_align, val_dataset_align = (
    CLIPDataset(split, preprocess=alignclip_model.preprocess, model_name="alignclip")
    for split in (train_data, val_data)
)

# Loader settings common to both splits; only shuffling differs
align_loader_options = dict(
    batch_size=config['batch_size'],
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)
train_loader_align = DataLoader(train_dataset_align, shuffle=True, **align_loader_options)
val_loader_align = DataLoader(val_dataset_align, shuffle=False, **align_loader_options)

# Contrastive objective and AdamW optimizer over the wrapper's parameters
align_criterion = CLIPContrastiveLoss(temperature=config['temperature'])
align_optimizer = optim.AdamW(
    alignclip_model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay'],
)

# Wire everything into a trainer
align_trainer = CLIPTrainer(
    model=alignclip_model,
    train_loader=train_loader_align,
    val_loader=val_loader_align,
    criterion=align_criterion,
    optimizer=align_optimizer,
    device=device,
    config=config,
    model_name="alignclip",
)

logger.info("AlignCLIP trainer initialized")

## 6. Finetune AlignCLIP Model

In [None]:
# Train CLIP for config['num_epochs'] epochs, then plot its
# train/validation loss curves. Requires `clip_trainer` from the
# CLIP loss/optimizer setup cell.
clip_trainer.train(num_epochs=config['num_epochs'])
clip_trainer.plot_losses()

In [None]:
# Contrastive objective for CLIP
clip_criterion = CLIPContrastiveLoss(temperature=config['temperature'])

# NOTE: every parameter returned by clip_model.parameters() is handed to the
# optimizer -- nothing is frozen or filtered in this cell.
clip_optimizer_options = {
    'lr': config['learning_rate'],
    'weight_decay': config['weight_decay'],
}
clip_optimizer = optim.AdamW(clip_model.parameters(), **clip_optimizer_options)

# Bundle model, loaders, objective and optimizer into a trainer
clip_trainer = CLIPTrainer(
    model=clip_model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=clip_criterion,
    optimizer=clip_optimizer,
    device=device,
    config=config,
    model_name="clip",
)

logger.info("CLIP trainer initialized")

In [None]:
# Build the CLIP model wrapper
logger.info("Initializing CLIP model...")
clip_model = CLIPFinetunableModel(device=str(device))

# Train/val datasets share the model's own image preprocessing
train_dataset, val_dataset = (
    CLIPDataset(split, preprocess=clip_model.preprocess, model_name="clip")
    for split in (train_data, val_data)
)

# Loader settings common to both splits; only shuffling differs
clip_loader_options = dict(
    batch_size=config['batch_size'],
    num_workers=0,  # Set to 0 to avoid multiprocessing issues in Jupyter
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_dataset, shuffle=True, **clip_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **clip_loader_options)

logger.info(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

## 5. Finetune CLIP Model

In [None]:
class CLIPTrainer:
    """Trainer for CLIP-style models."""
    
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        criterion: nn.Module,
        optimizer: optim.Optimizer,
        device: torch.device,
        config: Dict,
        model_name: str = "clip"
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.config = config
        self.model_name = model_name
        
        # Training tracking
        self.train_losses = []
        self.val_losses = []
        self.best_val_loss = float('inf')
        self.scaler = GradScaler()
        
    def train_epoch(self) -> float:
        """Train for one epoch."""
        self.model.train()
        total_loss = 0.0
        
        pbar = tqdm(self.train_loader, desc="Training")
        for batch_idx, (images, text_tokens, _) in enumerate(pbar):
            images = images.to(self.device)
            text_tokens = text_tokens.to(self.device)
            
            # Forward pass with mixed precision
            with autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
                image_features, text_features = self.model(images, text_tokens)
                loss = self.criterion(image_features, text_features)
            
            # Backward pass
            self.scaler.scale(loss).backward()
            
            # Gradient accumulation
            if (batch_idx + 1) % self.config['grad_accumulation_steps'] == 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
            
            total_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})
        
        avg_loss = total_loss / len(self.train_loader)
        self.train_losses.append(avg_loss)
        return avg_loss
    
    def validate(self) -> float:
        """Validate model."""
        self.model.eval()
        total_loss = 0.0
        
        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc="Validating")
            for images, text_tokens, _ in pbar:
                images = images.to(self.device)
                text_tokens = text_tokens.to(self.device)
                
                image_features, text_features = self.model(images, text_tokens)
                loss = self.criterion(image_features, text_features)
                total_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})
        
        avg_loss = total_loss / len(self.val_loader)
        self.val_losses.append(avg_loss)
        return avg_loss
    
    def train(self, num_epochs: int):
        """Train model for specified epochs."""
        logger.info(f"Starting training on {self.device} for {num_epochs} epochs...")
        
        for epoch in range(num_epochs):
            logger.info(f"\nEpoch {epoch + 1}/{num_epochs}")
            
            # Train
            train_loss = self.train_epoch()
            logger.info(f"Train Loss: {train_loss:.6f}")
            
            # Validate
            val_loss = self.validate()
            logger.info(f"Val Loss: {val_loss:.6f}")
            
            # Save best model
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.save_checkpoint(f"best_{self.model_name}_model.pt")
                logger.info(f"âœ“ Saved best model (val_loss: {val_loss:.6f})")
    
    def save_checkpoint(self, filename: str):
        """Save model checkpoint."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'config': self.config,
        }
        torch.save(checkpoint, filename)
        logger.info(f"Checkpoint saved: {filename}")
    
    def plot_losses(self):
        """Plot training and validation losses."""
        plt.figure(figsize=(10, 6))
        plt.plot(self.train_losses, label='Train Loss', marker='o')
        plt.plot(self.val_losses, label='Val Loss', marker='s')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(f'{self.model_name.upper()} Training Curves')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

## 4. Trainer Class

In [None]:
# Training hyperparameters
# Hyperparameters shared by every trainer in this notebook
config = dict(
    batch_size=32,
    num_epochs=5,
    learning_rate=1e-5,
    weight_decay=0.1,
    warmup_steps=100,
    num_workers=4,
    grad_accumulation_steps=1,
    temperature=0.07,
    model_name='clip',  # Options: 'clip', 'alignclip', 'cloob'
)

logger.info(f"Training config: {json.dumps(config, indent=2)}")

## 3. Training Configuration

In [None]:
class CLIPFinetunableModel(nn.Module):
    """Finetuning adapter around the project's ``CLIPModel``.

    Exposes the wrapped model's ``preprocess`` and a ``forward`` that returns
    image and text features for one batch.
    """

    def __init__(self, device: str = None):
        super().__init__()
        # Deferred import: the project model is only needed once an instance is built
        from Models.clipModel import CLIPModel

        # Fall back to CUDA when available, otherwise CPU
        if not device:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        self.model = CLIPModel(device=self.device)
        self.preprocess = self.model.preprocess

    def forward(self, images: torch.Tensor, text_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode one batch with gradients enabled.

        Returns:
            Tuple of (image_features, text_features)
        """
        backbone = self.model
        return (
            backbone.encode_image_tensors(images, requires_grad=True),
            backbone.encode_text_tokens(text_tokens, requires_grad=True),
        )


class AlignCLIPFinetunableModel(nn.Module):
    """Finetuning adapter around the project's ``AlignCLIPModel``."""

    def __init__(self, device: str = None):
        super().__init__()
        # Deferred import: the project model is only needed once an instance is built
        from Models.alignClipModel import AlignCLIPModel

        # Fall back to CUDA when available, otherwise CPU
        if not device:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        self.model = AlignCLIPModel(device=self.device)
        self.preprocess = self.model.preprocess

    def forward(self, images: torch.Tensor, text_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (image_features, text_features) for one batch, with gradients enabled."""
        backbone = self.model
        return (
            backbone.encode_image_tensors(images, requires_grad=True),
            backbone.encode_text_tokens(text_tokens, requires_grad=True),
        )


class CLOOBFinetunableModel(nn.Module):
    """Finetuning adapter around the project's ``CLOOBModel``."""

    def __init__(self, device: str = None):
        super().__init__()
        # Deferred import: the project model is only needed once an instance is built
        from Models.cloobModel import CLOOBModel

        # Fall back to CUDA when available, otherwise CPU
        if not device:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        self.model = CLOOBModel(device=self.device)

        # CLOOB uses CLIP preprocessing: load ViT-B/32 and keep only the
        # preprocessing transform, discarding the loaded CLIP model itself.
        _clip_model, clip_preprocess = clip.load("ViT-B/32", device=self.device)
        self.preprocess = clip_preprocess

    def forward(self, images: torch.Tensor, text_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (image_features, text_features) for one batch, with gradients enabled."""
        backbone = self.model
        return (
            backbone.encode_image_tensors(images, requires_grad=True),
            backbone.encode_text_tokens(text_tokens, requires_grad=True),
        )

In [None]:
class CLIPContrastiveLoss(nn.Module):
    """
    Contrastive loss for CLIP-style models.

    Treats row ``i`` of the image batch and row ``i`` of the text batch as the
    positive pair and averages the image-to-text and text-to-image
    cross-entropy over the temperature-scaled cosine-similarity matrix.
    """

    def __init__(self, temperature: float = 0.07):
        """
        Initialize loss.

        Args:
            temperature: Positive divisor applied to the similarity logits
                (smaller values sharpen the softmax).

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        super().__init__()
        # `not > 0` also rejects NaN; zero or negative values would produce
        # infinite or sign-flipped logits.
        if not temperature > 0:
            raise ValueError(f"temperature must be positive, got {temperature!r}")
        self.temperature = temperature

    def forward(self, image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
        """
        Compute contrastive loss.

        Args:
            image_features: Image embeddings [batch_size, embedding_dim]
            text_features: Text embeddings [batch_size, embedding_dim]

        Returns:
            Scalar loss value

        Raises:
            ValueError: If the two batches differ in size, since positives are
                paired by row index.
        """
        batch_size = image_features.shape[0]
        if text_features.shape[0] != batch_size:
            raise ValueError(
                "image_features and text_features must have the same batch size, "
                f"got {batch_size} and {text_features.shape[0]}"
            )

        # Normalize features so the dot product below is a cosine similarity
        image_features = torch.nn.functional.normalize(image_features, dim=-1)
        text_features = torch.nn.functional.normalize(text_features, dim=-1)

        # Compute logits once; the text-to-image logits are the transpose of
        # the image-to-text ones, so no second matmul is needed.
        logits_per_image = image_features @ text_features.T / self.temperature
        logits_per_text = logits_per_image.T

        # Create labels (diagonal elements are positive pairs)
        labels = torch.arange(batch_size, device=image_features.device)

        # Compute cross-entropy loss in both directions and average them
        loss_img = torch.nn.functional.cross_entropy(logits_per_image, labels)
        loss_txt = torch.nn.functional.cross_entropy(logits_per_text, labels)

        return (loss_img + loss_txt) / 2

## 2. Define Loss Functions and Models

In [None]:
# Split dataset into train/val: the leading 80% trains, the remainder validates
train_ratio = 0.8
split_idx = int(len(dataset) * train_ratio)
train_data, val_data = dataset[:split_idx], dataset[split_idx:]

logger.info(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}")

In [None]:
class CLIPDataset(Dataset):
    """
    Custom dataset for CLIP-style models.
    Handles image-text pairs with preprocessing.

    Each item is ``(image_tensor, text_tokens, text)``. A missing or unreadable
    image is replaced by an all-zero tensor so a bad sample does not abort a
    whole epoch.
    """

    # All supported models are tokenized with the CLIP tokenizer at this length
    SUPPORTED_MODELS = ("clip", "alignclip", "cloob")
    CONTEXT_LENGTH = 77
    # Side length of the placeholder / default-preprocessed image
    IMAGE_SIZE = 224

    def __init__(self, data: List[Dict], preprocess=None, model_name: str = "clip"):
        """
        Initialize dataset.

        Args:
            data: List of dicts with 'image_path' and 'text' keys
            preprocess: Image preprocessing function
            model_name: One of 'clip', 'alignclip', 'cloob'

        Raises:
            ValueError: If ``model_name`` is not a supported model.
        """
        # Fail at construction: previously an unknown name left
        # ``context_length`` unset and only failed on the first item access.
        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"Unknown model_name {model_name!r}; expected one of {self.SUPPORTED_MODELS}"
            )

        self.data = data
        self.preprocess = preprocess
        self.model_name = model_name
        self.context_length = self.CONTEXT_LENGTH

    def __len__(self):
        return len(self.data)

    def _blank_image(self) -> torch.Tensor:
        """Return the all-zero placeholder used when no image can be loaded."""
        return torch.zeros(3, self.IMAGE_SIZE, self.IMAGE_SIZE)

    def _default_preprocess(self, image) -> torch.Tensor:
        """Resize to IMAGE_SIZE x IMAGE_SIZE and scale pixels to [0, 1], channels first."""
        resized = image.resize((self.IMAGE_SIZE, self.IMAGE_SIZE))
        pixels = np.asarray(resized, dtype=np.float32) / 255.0
        return torch.from_numpy(pixels).permute(2, 0, 1).contiguous()

    def _load_image(self, image_path) -> torch.Tensor:
        """Load and preprocess one image, falling back to the blank placeholder."""
        if not image_path or not os.path.exists(image_path):
            return self._blank_image()

        try:
            # Context manager closes the file handle once the pixels are copied
            with Image.open(image_path) as handle:
                image = handle.convert('RGB')
            if self.preprocess:
                return self.preprocess(image)
            # No preprocess supplied: keep the real image instead of discarding it
            return self._default_preprocess(image)
        except Exception as e:
            # Deliberately best-effort: log and substitute a placeholder
            logger.warning(f"Error loading image {image_path}: {e}")
            return self._blank_image()

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
        """
        Get a single sample.

        Returns:
            Tuple of (image_tensor, text_tokens, text)
        """
        sample = self.data[idx]
        # Treat a missing or None caption as the empty string
        text = sample.get('text') or ''

        image = self._load_image(sample.get('image_path'))

        # truncate=True clips captions longer than the context window instead
        # of raising, which would otherwise abort training on a long caption.
        text_tokens = clip.tokenize(text, context_length=self.context_length, truncate=True)[0]

        return image, text_tokens, text


# Load sample dataset using DatasetLoader.
# NOTE(review): the train/val split cell slices `dataset` and takes its len(),
# so it presumably is a list of sample dicts -- confirm against DatasetLoader.
logger.info("Loading sample dataset...")
dataset = DatasetLoader.load_laion_sample()
logger.info(f"Loaded {len(dataset)} samples")

## 1. Load and Prepare Datasets

In [None]:
# Add workspace to path: put the current working directory first on sys.path
# so the project-local import below resolves. This must run before the import.
workspace_path = Path.cwd()
sys.path.insert(0, str(workspace_path))

# Import custom dataset loader (project-local module `datasetLoader`)
from datasetLoader import DatasetLoader

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast, GradScaler
import clip
from PIL import Image
from pathlib import Path
import json
from typing import Dict, List, Tuple, Optional
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import logging

# Logging: INFO-level root configuration plus a module-level logger
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Fixed seed for the torch and numpy random generators (reproducibility)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Pick the GPU when one is available, otherwise run on the CPU
cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if cuda_available else torch.device('cpu')
logger.info(f"Using device: {device}")
if cuda_available:
    # Report the name and total memory of GPU 0
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
    logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Multimodal Model Finetuning Notebook
## CLIP, AlignCLIP, and CLOOB Finetuning

This notebook provides a comprehensive framework for finetuning three multimodal models:
- **CLIP**: OpenAI's Contrastive Language-Image Pre-training
- **AlignCLIP**: Improved alignment variant of CLIP
- **CLOOB**: Contrastive Leave One Out Boost (modern Hopfield networks with the InfoLOOB objective)

All models use contrastive learning to align image and text embeddings in a shared space.