# Assessment 2: Metric Learning with Oxford-IIIT Pet Dataset

## Introduction and Setup

This notebook implements a deep metric learning approach for the Oxford-IIIT Pet Dataset. The goal is to learn an embedding space in which images of the same breed lie close together and images of different breeds lie far apart. We compare several loss functions, evaluate the model on verification, retrieval, and few-shot classification tasks, and visualize the learned embedding space.

### Environment Setup and Package Installation

In [None]:

# Check if running in Colab (to install dependencies and set up environment)
import sys
IN_COLAB = 'google.colab' in sys.modules

# Install required packages
if IN_COLAB:
    !pip install pytorch-metric-learning
    !pip install faiss-gpu
    !pip install umap-learn
    !pip install matplotlib seaborn scikit-learn tqdm
    !pip install gradio
    !pip install grad-cam

### Import Libraries

In [None]:

import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm
from collections import defaultdict
import itertools  # For generating pairs
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from torchvision import datasets, models, transforms
import torchvision.transforms.functional as TF

import pytorch_metric_learning
from pytorch_metric_learning import losses, miners, distances, reducers, testers

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from sklearn.model_selection import train_test_split
import umap

# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")





## Data Loading and Preprocessing

In this section, we'll load the Oxford-IIIT Pet Dataset, perform necessary preprocessing, and create appropriate data loaders for our metric learning tasks.

In [None]:

# Define transformations
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

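# Evaluation transform: resize to the same 256x256 scale used for training, then
# take a deterministic center crop so train and eval images share the same scale
eval_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])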

# Function to load the dataset
def load_oxford_pets_dataset(root="./data", download=True):
    train_val_dataset = datasets.OxfordIIITPet(
        root=root, 
        split="trainval", 
        transform=train_transform, 
        download=download
    )
    
    test_dataset = datasets.OxfordIIITPet(
        root=root, 
        split="test", 
        transform=eval_transform, 
        download=download
    )
    
    # For evaluation, create a version of the training set with eval transforms
    eval_train_dataset = datasets.OxfordIIITPet(
        root=root, 
        split="trainval", 
        transform=eval_transform, 
        download=False
    )
    
    return train_val_dataset, test_dataset, eval_train_dataset



### Dataset Preparation for Different Tasks

In [None]:

def get_labels(dataset):
    """Return every sample's category label without decoding any images.

    torchvision's OxfordIIITPet keeps its labels in the private `_labels` list;
    if that attribute is missing, fall back to iterating the dataset (slow,
    because every image is loaded and transformed).
    """
    if hasattr(dataset, '_labels'):
        return list(dataset._labels)
    return [label for _, label in dataset]


# Split data for training, validation and few-shot evaluation
def prepare_datasets(train_val_dataset, test_dataset, eval_train_dataset, num_holdout_classes=5, val_ratio=0.2, seed=42):
    # Get the class names
    class_to_idx = train_val_dataset.class_to_idx
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    num_classes = len(class_to_idx)
    
    # Split classes for few-shot learning (hold out some classes for testing).
    # A dedicated seeded RNG gives the same holdout classes on every call.
    all_class_indices = list(range(num_classes))
    holdout_class_indices = random.Random(seed).sample(all_class_indices, num_holdout_classes)
    holdout_set = set(holdout_class_indices)
    training_class_indices = [i for i in all_class_indices if i not in holdout_set]
    
    holdout_classes = [idx_to_class[i] for i in holdout_class_indices]
    print(f"Holdout classes for few-shot learning: {holdout_classes}")
    
    # Read labels once; eval_train_dataset uses the same "trainval" file list and
    # order as train_val_dataset, so the same indices apply to both
    train_val_labels = get_labels(train_val_dataset)
    test_labels = get_labels(test_dataset)
    
    # Indices excluding holdout classes for main training and testing
    train_val_indices = [i for i, l in enumerate(train_val_labels) if l not in holdout_set]
    test_indices = [i for i, l in enumerate(test_labels) if l not in holdout_set]
    
    # For few-shot learning, include only holdout classes
    few_shot_train_indices = [i for i, l in enumerate(train_val_labels) if l in holdout_set]
    few_shot_test_indices = [i for i, l in enumerate(test_labels) if l in holdout_set]
    
    # Stratified train/val split
    train_indices, val_indices = train_test_split(
        train_val_indices, 
        test_size=val_ratio, 
        stratify=[train_val_labels[i] for i in train_val_indices],
        random_state=seed
    )
    
    # Create Subset datasets. Only the training subset is augmented; validation,
    # few-shot support and gallery images use the deterministic eval transform.
    train_dataset = Subset(train_val_dataset, train_indices)
    val_dataset = Subset(eval_train_dataset, val_indices)
    test_filtered_dataset = Subset(test_dataset, test_indices)
    
    # Create datasets for few-shot learning
    few_shot_train_dataset = Subset(eval_train_dataset, few_shot_train_indices)
    few_shot_test_dataset = Subset(test_dataset, few_shot_test_indices)
    
    # Retrieval gallery: all non-holdout trainval images with eval transforms
    eval_train_dataset = Subset(eval_train_dataset, train_val_indices)
    
    # Create dictionary for class mapping
    class_mapping = {
        'class_to_idx': class_to_idx,
        'idx_to_class': idx_to_class,
        'holdout_class_indices': holdout_class_indices,
        'training_class_indices': training_class_indices
    }
    
    return {
        'train': train_dataset,
        'val': val_dataset,
        'test': test_filtered_dataset,
        'eval_train': eval_train_dataset,
        'few_shot_train': few_shot_train_dataset,
        'few_shot_test': few_shot_test_dataset,
        'class_mapping': class_mapping
    }
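Because the holdout logic only needs integer labels, it can be sanity-checked without loading any images. The sketch below is illustrative: `split_by_holdout` is a hypothetical helper (not part of the notebook) that applies the same idea to a toy label list, draws the held-out classes from a seeded `random.Random` so the split is reproducible, and checks that the two pools share no classes.

```python
import random


def split_by_holdout(labels, num_classes, num_holdout, seed=42):
    """Return (seen_indices, held_indices, holdout_classes) for a list of labels.

    Hypothetical helper: samples of held-out classes go to the few-shot pool,
    everything else stays available for metric-learning training.
    """
    rng = random.Random(seed)  # local RNG: result does not depend on global state
    holdout_classes = sorted(rng.sample(range(num_classes), num_holdout))
    holdout_set = set(holdout_classes)
    seen = [i for i, y in enumerate(labels) if y not in holdout_set]
    held = [i for i, y in enumerate(labels) if y in holdout_set]
    return seen, held, holdout_classes


# Toy example: 10 classes with 4 samples each, 3 classes held out
labels = [c for c in range(10) for _ in range(4)]
seen, held, holdout = split_by_holdout(labels, num_classes=10, num_holdout=3)

# The pools are disjoint at the class level
assert {labels[i] for i in seen}.isdisjoint({labels[i] for i in held})
print(holdout, len(seen), len(held))
```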



### Create DataLoaders

In [None]:

def create_dataloaders(datasets_dict, batch_size=32, num_workers=2):
    dataloaders = {}
    
    for key in ['train', 'val', 'test', 'eval_train', 'few_shot_train', 'few_shot_test']:
        if key == 'train':
            shuffle = True
        else:
            shuffle = False
            
        dataloaders[key] = DataLoader(
            datasets_dict[key],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True
        )
    
    return dataloaders
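With `shuffle=True` and `batch_size=32` over 32 training breeds, a batch often holds only one or two images per breed, which leaves the miner few positive pairs to work with. A common remedy in metric learning is class-balanced P×K batching (P classes, K images each); pytorch-metric-learning ships `samplers.MPerClassSampler` for this. The dependency-free sketch below only illustrates the idea: `pk_batches` is a hypothetical helper and the P and K values are assumptions. Its index lists could be passed as `batch_sampler=` to a `DataLoader`, rebuilt each epoch with a new seed, provided the labels are listed in the same order as the dataset's indices.

```python
import random
from collections import defaultdict


def pk_batches(labels, p=8, k=4, seed=0):
    """Build batches of `p` distinct classes x `k` samples each (hypothetical helper).

    Classes with fewer than `k` samples are skipped, and leftover samples that do
    not fill a complete group of `k` are dropped for this epoch.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    # Chop each class into shuffled groups of exactly k indices
    groups = {}
    for y, idxs in by_class.items():
        idxs = idxs[:]
        rng.shuffle(idxs)
        chunks = [idxs[s:s + k] for s in range(0, len(idxs) - k + 1, k)]
        if chunks:
            groups[y] = chunks

    batches = []
    while len(groups) >= p:
        # Pick p distinct classes that still have groups left, take one group each
        chosen = rng.sample(sorted(groups), p)
        batch = []
        for y in chosen:
            batch.extend(groups[y].pop())
            if not groups[y]:
                del groups[y]
        batches.append(batch)
    return batches


# Example: 10 classes with 12 samples each -> every batch holds 4 classes x 3 images
toy_labels = [c for c in range(10) for _ in range(12)]
toy_batches = pk_batches(toy_labels, p=4, k=3, seed=1)
print(len(toy_batches), len(toy_batches[0]))
```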




### Load and Prepare Data


In [None]:

# Load the dataset
train_val_dataset, test_dataset, eval_train_dataset = load_oxford_pets_dataset()
print(f"Train+Val size: {len(train_val_dataset)}")
print(f"Test size: {len(test_dataset)}")

# Prepare datasets for different tasks
datasets_dict = prepare_datasets(train_val_dataset, test_dataset, eval_train_dataset)

# Create dataloaders
batch_size = 32  # Adjust based on your GPU/memory constraints
dataloaders = create_dataloaders(datasets_dict, batch_size=batch_size)

# Print dataset statistics
print("\nDataset Statistics:")
for key, dataloader in dataloaders.items():
    print(f"{key}: {len(dataloader.dataset)} samples")

class_mapping = datasets_dict['class_mapping']
num_classes = len(class_mapping['class_to_idx'])
print(f"Total number of classes: {num_classes}")
print(f"Number of training classes: {len(class_mapping['training_class_indices'])}")
print(f"Number of few-shot classes: {len(class_mapping['holdout_class_indices'])}")




## Model Architecture

In this section, we'll define our metric learning model architecture using a CNN backbone and a projection head.


In [None]:

class EmbeddingNet(nn.Module):
    def __init__(self, backbone_name='resnet18', embedding_size=128, pretrained=True):
        super(EmbeddingNet, self).__init__()
        
        # Get backbone and its output size
        self.backbone, backbone_output_size = self._get_backbone(backbone_name, pretrained)
        
        # Projection head (MLP)
        self.projection_head = nn.Sequential(
            nn.Linear(backbone_output_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, embedding_size)
        )
        
    def _get_backbone(self, backbone_name, pretrained):
        """
        Create a backbone network from various architectures
        """
        # `weights="DEFAULT"` selects the best available ImageNet weights; the older
        # `pretrained=` argument is deprecated in recent torchvision releases
        weights = "DEFAULT" if pretrained else None
        if backbone_name == 'resnet18':
            backbone = models.resnet18(weights=weights)
            output_size = 512
        elif backbone_name == 'resnet34':
            backbone = models.resnet34(weights=weights)
            output_size = 512
        elif backbone_name == 'resnet50':
            backbone = models.resnet50(weights=weights)
            output_size = 2048
        elif backbone_name == 'efficientnet_b0':
            backbone = models.efficientnet_b0(weights=weights)
            output_size = 1280
        elif backbone_name == 'mobilenet_v2':
            backbone = models.mobilenet_v2(weights=weights)
            output_size = 1280
        elif backbone_name == 'densenet121':
            backbone = models.densenet121(weights=weights)
            output_size = 1024
        else:
            raise ValueError(f"Unsupported backbone: {backbone_name}")
        
        # For ResNet models: drop the final fc layer, keep the global average pool
        if backbone_name.startswith('resnet'):
            backbone = nn.Sequential(*list(backbone.children())[:-1])
        # For EfficientNet: children are (features, avgpool, classifier); drop classifier
        elif backbone_name.startswith('efficientnet'):
            backbone = nn.Sequential(*list(backbone.children())[:-1])
        # For MobileNetV2: children are only (features, classifier) and pooling happens
        # inside forward(), so add it explicitly to obtain a 1280-d feature vector
        elif backbone_name.startswith('mobilenet'):
            backbone = nn.Sequential(
                backbone.features,
                nn.AdaptiveAvgPool2d((1, 1))
            )
        # For DenseNet
        elif backbone_name.startswith('densenet'):
            backbone = nn.Sequential(
                backbone.features,
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d((1, 1))
            )
        
        return backbone, output_size
        
    def forward(self, x):
        features = self.backbone(x)
        features = features.view(features.size(0), -1)
        embeddings = self.projection_head(features)
        
        # Normalize embeddings to unit length (important for cosine distance)
        normalized_embeddings = F.normalize(embeddings, p=2, dim=1)
        return normalized_embeddings
    
    def get_embedding(self, x):
        return self.forward(x)


## Loss Function Implementation

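Here we set up several metric learning losses from `pytorch-metric-learning` (Triplet, Contrastive, and ArcFace), together with miners that select informative (hard or semi-hard) pairs and triplets from each batch so the loss is not dominated by trivially easy examples.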


In [None]:

def create_loss_and_miner(loss_type, margin=0.2, embedding_size=128, num_classes=32):
    """
    Create loss function and miner for metric learning
    """
    if loss_type == 'triplet':
        # Triplet loss with cosine distance
        distance = distances.CosineSimilarity()
        reducer = reducers.ThresholdReducer(low=0)
        loss_func = losses.TripletMarginLoss(margin=margin, distance=distance, reducer=reducer)
        mining_func = miners.TripletMarginMiner(margin=margin, distance=distance, type_of_triplets="semihard")
        
    elif loss_type == 'contrastive':
        # Contrastive loss
        distance = distances.CosineSimilarity()
        loss_func = losses.ContrastiveLoss(pos_margin=0.8, neg_margin=0.2, distance=distance)
        mining_func = miners.PairMarginMiner(pos_margin=0.8, neg_margin=0.2, distance=distance)
        
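    elif loss_type == 'arcface':
        # ArcFace loss. pytorch-metric-learning's signature is
        # ArcFaceLoss(num_classes, embedding_size, ...), so pass keywords to avoid
        # swapping them; `margin` is in degrees (28.6 degrees is about 0.5 rad)
        loss_func = losses.ArcFaceLoss(num_classes=num_classes, embedding_size=embedding_size, margin=28.6, scale=64)
        mining_func = None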
        
    else:
        raise ValueError(f"Unsupported loss type: {loss_type}")
        
    return loss_func, mining_func
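The three objectives can be written out directly on cosine similarities. The NumPy sketch below is a simplified illustration of the formulas, assuming unit-norm embeddings; it is not the pytorch-metric-learning implementation, which adds reduction, mining and numerical safeguards. It shows the triplet hinge max(0, s_an - s_ap + margin), the contrastive hinges that push positive similarities above `pos_margin` and negative ones below `neg_margin`, and ArcFace's additive angular margin applied to the target-class logit before scaling.

```python
import numpy as np


def triplet_loss_cos(anchor, positive, negative, margin=0.2):
    """Triplet hinge on cosine similarity (rows are unit-norm vectors)."""
    s_ap = np.sum(anchor * positive, axis=1)
    s_an = np.sum(anchor * negative, axis=1)
    return np.maximum(0.0, s_an - s_ap + margin)


def contrastive_loss_cos(sim, same_class, pos_margin=0.8, neg_margin=0.2):
    """Positives are penalized below pos_margin, negatives above neg_margin."""
    pos = np.maximum(0.0, pos_margin - sim)
    neg = np.maximum(0.0, sim - neg_margin)
    return np.where(same_class, pos, neg)


def arcface_logits(emb, weights, labels, margin_deg=28.6, scale=64.0):
    """Target logit uses cos(theta + m), the others cos(theta); all scaled by s."""
    w = weights / np.linalg.norm(weights, axis=0, keepdims=True)  # (dim, classes)
    cos = np.clip(emb @ w, -1.0, 1.0)
    theta = np.arccos(cos)
    m = np.deg2rad(margin_deg)                    # 28.6 degrees is about 0.5 rad
    rows = np.arange(len(labels))
    logits = cos.copy()
    logits[rows, labels] = np.cos(theta[rows, labels] + m)
    return scale * logits                         # fed to softmax cross-entropy


def unit(deg):
    """2-D unit vector at the given angle in degrees, shape (1, 2)."""
    r = np.deg2rad(deg)
    return np.array([[np.cos(r), np.sin(r)]])


anchor, positive = unit(0), unit(30)
print(triplet_loss_cos(anchor, positive, unit(90)))  # easy negative: zero loss
print(triplet_loss_cos(anchor, positive, unit(40)))  # semi-hard negative: about 0.1
```

The 40-degree negative is semi-hard: it is less similar to the anchor than the positive is, but by less than the margin, so it still produces a small, useful gradient.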


## Training Pipeline with Learning Rate Scheduler and Early Stopping

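Here we add two mechanisms that make training cheaper and more robust: a `ReduceLROnPlateau` scheduler that halves the learning rate once the validation loss has failed to improve for more than 3 consecutive epochs, and early stopping that ends training after 5 consecutive epochs without improvement and restores the best checkpoint.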

In [None]:
def get_optimizer_and_scheduler(model, learning_rate=1e-4, weight_decay=0.0001):
    """
    Create optimizer and learning rate scheduler
    
    Args:
        model: The model to optimize
        learning_rate: Initial learning rate
        weight_decay: L2 regularization strength
        
    Returns:
        optimizer: The optimizer object
        scheduler: The learning rate scheduler
    """
    optimizer = optim.Adam(
        model.parameters(), 
        lr=learning_rate, 
        weight_decay=weight_decay
    )
    
    # Learning rate scheduler that reduces LR when validation loss plateaus.
    # `verbose` is omitted: it is deprecated in recent PyTorch releases, and
    # train_model already prints the current LR every epoch.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='min', 
        factor=0.5,   # halve the LR on a plateau
        patience=3,   # reduce once more than 3 consecutive epochs show no improvement
        min_lr=1e-6
    )
    
    return optimizer, scheduler


class EarlyStopping:
    """Early stopping to prevent overfitting and save computation"""
    def __init__(self, patience=5, verbose=True, delta=0.0001, path='best_model.pth'):
        """
        Args:
            patience (int): How many epochs to wait after last improvement
            verbose (bool): If True, prints a message for each improvement
            delta (float): Minimum change to qualify as an improvement
            path (str): Path to save the checkpoint
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path
        
    def __call__(self, val_loss, model):
        score = -val_loss
        
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
            
    def save_checkpoint(self, val_loss, model):
        '''Save model when validation loss decreases.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
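The two patience values interact: the scheduler reacts first (after more than 3 bad epochs), while early stopping ends the run after 5 bad epochs, and its counter is not reset when the learning rate drops. The pure-Python sketch below replays a made-up validation-loss sequence through simplified re-implementations of both rules: `ReduceLROnPlateau` with mode='min', a relative threshold of 1e-4 and no cooldown (our reading of its defaults), and the `EarlyStopping` class above. `simulate` is a hypothetical helper for illustration only.

```python
def simulate(val_losses, lr=1e-4, factor=0.5, sched_patience=3, min_lr=1e-6,
             stop_patience=5, delta=1e-4, threshold=1e-4):
    """Return (lr_per_epoch, epochs_where_lr_was_reduced, stop_epoch or None)."""
    sched_best, bad_epochs = float('inf'), 0
    es_best, es_counter = None, 0
    lrs, reduce_epochs = [], []
    for epoch, loss in enumerate(val_losses):
        lrs.append(lr)  # learning rate used during this epoch

        # Plateau scheduler: improvement means a relative drop larger than `threshold`
        if loss < sched_best * (1 - threshold):
            sched_best, bad_epochs = loss, 0
        else:
            bad_epochs += 1
        if bad_epochs > sched_patience:
            lr = max(lr * factor, min_lr)
            reduce_epochs.append(epoch)
            bad_epochs = 0

        # Early stopping: improvement means a drop of at least `delta`
        if es_best is None or loss <= es_best - delta:
            es_best, es_counter = loss, 0
        else:
            es_counter += 1
            if es_counter >= stop_patience:
                return lrs, reduce_epochs, epoch
    return lrs, reduce_epochs, None


lrs, reductions, stop = simulate([1.0, 0.8, 0.7, 0.7, 0.71, 0.72, 0.7, 0.75, 0.73])
print(reductions, stop)  # [6] 7
```

In this trace (0-based epochs) the learning rate is halved at the end of epoch 6 and training stops at epoch 7, so the reduced rate gets only one epoch. A larger `EarlyStopping` patience would give the lower learning rate more time to help.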


In [None]:
def train_model(model, dataloaders, loss_type, optimizer=None, scheduler=None, num_epochs=15, embedding_size=128):
    """
    Train the metric learning model with enhanced monitoring
    """
    # Get the number of training classes (excluding holdout classes).
    # Note: this reads the notebook-level `datasets_dict` global.
    num_training_classes = len(datasets_dict['class_mapping']['training_class_indices'])
    
    # Create loss function and miner
    loss_func, mining_func = create_loss_and_miner(
        loss_type=loss_type, 
        embedding_size=embedding_size, 
        num_classes=num_training_classes
    )
    loss_func = loss_func.to(device)
    
    # If using ArcFace, we need to create a class map for the training dataset
    if loss_type == 'arcface':
        # Map original class indices to consecutive integers for ArcFace
        class_map = {original: i for i, original in enumerate(datasets_dict['class_mapping']['training_class_indices'])}
    
    # Create optimizer and scheduler if not provided
    if optimizer is None:
        # Use default parameters if not provided
        optimizer, auto_scheduler = get_optimizer_and_scheduler(model)
        scheduler = auto_scheduler if scheduler is None else scheduler
    
    # ArcFace owns a trainable class-weight matrix that is not in model.parameters(),
    # so it needs its own optimizer (the learning rate here is an assumed starting value).
    # Triplet and contrastive losses have no parameters, so this stays None for them.
    loss_params = list(loss_func.parameters())
    loss_optimizer = optim.SGD(loss_params, lr=0.01) if loss_params else None
    
    # Initialize early stopping
    early_stopping = EarlyStopping(patience=5, verbose=True)
    
    # Training loop
    history = {
        'train_loss': [], 
        'val_loss': [],
        'lr': [],
        'batch_losses': [],  # Track per-batch losses for more detailed monitoring
        'gradient_norms': []  # Track gradient norms to monitor training stability
    }
    
    # Create figures for live updates
    if IN_COLAB:  # Only for interactive environments
        from IPython.display import clear_output
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
                
            running_loss = 0.0
            batch_losses = []
            grad_norms = []
            
            # Iterate over data
            for inputs, labels in tqdm(dataloaders[phase], desc=phase):
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                # Map labels for ArcFace if needed
                if loss_type == 'arcface':
                    # Filter out samples from holdout classes
                    valid_idx = torch.tensor([i for i, l in enumerate(labels) if l.item() in class_map], device=device)
                    if len(valid_idx) == 0:
                        continue
                    
                    inputs = inputs[valid_idx]
                    arcface_labels = torch.tensor([class_map[l.item()] for l in labels[valid_idx]], device=device)
                    labels = arcface_labels
                
                # Zero the parameter gradients (model and, for ArcFace, loss weights)
                optimizer.zero_grad()
                if loss_optimizer is not None:
                    loss_optimizer.zero_grad()
                
                # Forward pass
                with torch.set_grad_enabled(phase == 'train'):
                    embeddings = model(inputs)
                    
                    # Get indices for mining if using a mining function
                    if mining_func is not None:
                        hard_pairs = mining_func(embeddings, labels)
                        loss = loss_func(embeddings, labels, hard_pairs)
                    else:
                        loss = loss_func(embeddings, labels)
                    
                    # Record batch loss
                    batch_losses.append(loss.item())
                    
                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        
                        # Compute gradient norm
                        total_norm = 0
                        for p in model.parameters():
                            if p.grad is not None:
                                param_norm = p.grad.data.norm(2)
                                total_norm += param_norm.item() ** 2
                        total_norm = total_norm ** 0.5
                        grad_norms.append(total_norm)
                        
                        optimizer.step()
                        if loss_optimizer is not None:
                            loss_optimizer.step()
                
                # Statistics
                running_loss += loss.item() * inputs.size(0)
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            
            if phase == 'train':
                current_lr = optimizer.param_groups[0]['lr']
                history['lr'].append(current_lr)
                history['train_loss'].append(epoch_loss)
                history['batch_losses'].extend(batch_losses)
                history['gradient_norms'].extend(grad_norms)
                print(f'{phase} Loss: {epoch_loss:.4f}, LR: {current_lr:.6f}, Grad Norm: {np.mean(grad_norms):.4f}')
            else:
                history['val_loss'].append(epoch_loss)
                print(f'{phase} Loss: {epoch_loss:.4f}')
                
                # Update learning rate scheduler based on validation loss
                if scheduler is not None and isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(epoch_loss)
                
                # Check early stopping
                early_stopping(epoch_loss, model)
                if early_stopping.early_stop:
                    print("Early stopping triggered")
                    break
        
        # Update step-based schedulers at the end of each epoch
        if scheduler is not None and not isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step()
        
        # Check if early stopping was triggered
        if early_stopping.early_stop:
            print("Training stopped early due to no improvement in validation loss")
            break
        
        # Visualize training progress after each epoch (interactive notebooks only)
        if IN_COLAB:
            clear_output(wait=True)
            # With the inline backend, plt.show() closes the figures it displays, so
            # draw a fresh figure each epoch instead of reusing the one created earlier
            plt.close('all')
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
            epochs_so_far = range(1, len(history['train_loss']) + 1)
            
            # Plot loss curves
            ax1.plot(epochs_so_far, history['train_loss'], 'b-', label='Training Loss')
            ax1.plot(epochs_so_far, history['val_loss'], 'r-', label='Validation Loss')
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.set_title(f'Training Progress - {loss_type} Loss')
            ax1.legend()
            ax1.grid(True)
            
            # Plot learning rate
            ax2.plot(epochs_so_far, history['lr'], 'g-')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('Learning Rate')
            ax2.set_title('Learning Rate Schedule')
            ax2.grid(True)
            
            plt.tight_layout()
            plt.show()
            
            # Plot batch losses and gradient norms
            fig2, (ax3, ax4) = plt.subplots(1, 2, figsize=(15, 5))
            
            # Batch losses
            ax3.plot(history['batch_losses'], 'b-', alpha=0.5)
            ax3.set_xlabel('Batch')
            ax3.set_ylabel('Loss')
            ax3.set_title('Batch-level Losses')
            ax3.grid(True)
            
            # Gradient norms
            ax4.plot(history['gradient_norms'], 'r-', alpha=0.5)
            ax4.set_xlabel('Batch')
            ax4.set_ylabel('Gradient Norm')
            ax4.set_title('Gradient Norms')
            ax4.grid(True)
            
            plt.tight_layout()
            plt.show()
    
    # Final visualization - save figures
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(history['train_loss'])+1), history['train_loss'], 'b-', label='Training Loss')
    plt.plot(range(1, len(history['val_loss'])+1), history['val_loss'], 'r-', label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Training and Validation Loss ({loss_type} loss)')
    plt.legend()
    plt.grid(True)
    plt.savefig('training_history.png')
    
    # Load the best model weights saved by early stopping
    model.load_state_dict(torch.load(early_stopping.path, map_location=device))
    
    return model, history
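With ArcFace, the loss keeps one weight vector per training class, indexed 0 to N-1, but removing the holdout breeds leaves gaps in the original label range. `train_model` therefore remaps labels through `class_map` and drops any sample whose class is not in it. The toy sketch below isolates that step; `build_class_map` and `remap_batch` are hypothetical helpers, and the setup (8 classes, classes 2 and 5 held out) is invented for illustration.

```python
def build_class_map(training_class_indices):
    """Map original (gapped) class indices to consecutive ArcFace targets 0..N-1."""
    return {original: new for new, original in enumerate(training_class_indices)}


def remap_batch(labels, class_map):
    """Keep samples from training classes; return (kept_positions, remapped_labels)."""
    kept = [i for i, y in enumerate(labels) if y in class_map]
    return kept, [class_map[labels[i]] for i in kept]


# Toy setup: 8 classes, classes 2 and 5 held out
training = [c for c in range(8) if c not in {2, 5}]
class_map = build_class_map(training)
kept, new_labels = remap_batch([0, 2, 7, 5, 3], class_map)
print(kept, new_labels)  # [0, 2, 4] [0, 5, 2]
```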



## Main Execution


In [None]:

def main(run_comprehensive_eval=True):
    # Model parameters (defaults)
    backbone_name = 'resnet18'  # Options: 'resnet18', 'resnet34', 'resnet50', 'efficientnet_b0', 'mobilenet_v2', 'densenet121'
    embedding_size = 128
    batch_size = 32  # Adjust based on your GPU
    num_workers = 2
    
    # Training parameters (early stopping may end training before num_epochs)
    loss_type = 'triplet'  # Options: 'triplet', 'contrastive', 'arcface'
    num_epochs = 15
    lr = 1e-4
    weight_decay = 1e-5  # L2 regularization strength
    
    # Load data
    train_val_dataset, test_dataset, eval_train_dataset = load_oxford_pets_dataset()
    datasets_dict = prepare_datasets(train_val_dataset, test_dataset, eval_train_dataset)
    dataloaders = create_dataloaders(datasets_dict, batch_size=batch_size, num_workers=num_workers)
    
    # Create model
    model = EmbeddingNet(backbone_name=backbone_name, embedding_size=embedding_size)
    model = model.to(device)
    
    # Create optimizer and learning rate scheduler
    optimizer, scheduler = get_optimizer_and_scheduler(
        model, 
        learning_rate=lr,
        weight_decay=weight_decay
    )
    
    # Train with early stopping and plateau-based learning rate scheduling
    model, history = train_model(
        model=model,
        dataloaders=dataloaders,
        loss_type=loss_type,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=num_epochs,
        embedding_size=embedding_size
    )
    
    # train_model() has already plotted and saved the loss curves; display that figure
    plt.show()
    
    # Save model
    torch.save({
        'model_state_dict': model.state_dict(),
        'embedding_size': embedding_size,
        'backbone_name': backbone_name,
        'class_mapping': datasets_dict['class_mapping']
    }, f'pet_metric_learning_{backbone_name}_{loss_type}.pth')
    
    if run_comprehensive_eval:
        # Run comprehensive evaluation
        print("\nRunning comprehensive model evaluation...")
        results_dir = f"./evaluation_results_{backbone_name}_{loss_type}"
        os.makedirs(results_dir, exist_ok=True)
        
        evaluation_results, eval_dir = comprehensive_evaluation(
            model=model,
            dataloaders=dataloaders,
            class_mapping=datasets_dict['class_mapping'],
            results_dir=results_dir
        )
        
        print(f"\nComprehensive evaluation completed. Results saved to {eval_dir}")
    else:
        # Extract embeddings for standard evaluation
        print("\nExtracting embeddings for evaluation...")
        train_embeddings, train_labels = extract_embeddings(model, dataloaders['eval_train'])
        test_embeddings, test_labels = extract_embeddings(model, dataloaders['test'])
        few_shot_train_embeddings, few_shot_train_labels = extract_embeddings(model, dataloaders['few_shot_train'])
        few_shot_test_embeddings, few_shot_test_labels = extract_embeddings(model, dataloaders['few_shot_test'])
        
        # Evaluation tasks
        print("\n1. Verification Task:")
        verification_results = evaluate_verification(test_embeddings, test_labels)
        
        print(f"\nVerification Results:")
        print(f"ROC AUC: {verification_results['roc_auc']:.4f}")
        print(f"Equal Error Rate (EER): {verification_results['eer']:.4f}")
        
        print("\n2. Retrieval Task:")
        retrieval_results = evaluate_retrieval(
            query_embeddings=test_embeddings,
            query_labels=test_labels,
            gallery_embeddings=train_embeddings,
            gallery_labels=train_labels,
            k_values=[1, 5, 10]
        )
        
        print("\n3. Few-shot Classification:")
        # Run comprehensive few-shot evaluation across multiple settings
        few_shot_results = evaluate_multiple_few_shot_settings(
            support_embeddings=few_shot_train_embeddings,
            support_labels=few_shot_train_labels,
            query_embeddings=few_shot_test_embeddings,
            query_labels=few_shot_test_labels
        )
        
        # Embedding visualization
        print("\n4. Embedding Visualization:")
        test_projection = visualize_embeddings(
            embeddings=test_embeddings,
            labels=test_labels,
            class_mapping=datasets_dict['class_mapping'],
            method='tsne',
            title='t-SNE Visualization of Test Embeddings'
        )
        
        # Visualize few-shot embeddings
        print("\nVisualizing few-shot embeddings:")
        # Combine few-shot train and test embeddings for visualization
        all_few_shot_embeddings = torch.cat([few_shot_train_embeddings, few_shot_test_embeddings], dim=0)
        all_few_shot_labels = torch.cat([few_shot_train_labels, few_shot_test_labels], dim=0)
        
        few_shot_projection = visualize_embeddings(
            embeddings=all_few_shot_embeddings,
            labels=all_few_shot_labels,
            class_mapping=datasets_dict['class_mapping'],
            method='tsne',
            title='t-SNE Visualization of Few-Shot Embeddings'
        )
        
        # Enhanced Grad-CAM Visualization
        print("\n5. Enhanced Grad-CAM Visualization:")
        visualize_grad_cam_comparisons(model, dataloaders['test'], datasets_dict['class_mapping'], num_comparisons=2)
        
        # Original Grad-CAM Visualization
        print("\n6. Individual Grad-CAM Visualization:")
        visualize_grad_cam(model, dataloaders['test'], datasets_dict['class_mapping'], num_images=3)
    
    print("\nEvaluation completed!")

if __name__ == "__main__":
    main(run_comprehensive_eval=True)  # Set to False for standard evaluation
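`main()` calls `extract_embeddings`, `evaluate_verification`, `evaluate_retrieval` and `evaluate_multiple_few_shot_settings`, which are defined elsewhere in the notebook and are not shown in this section. To make the three evaluation protocols concrete, here is an independent NumPy sketch, assuming L2-normalized embeddings (as returned by `EmbeddingNet`) and integer label arrays; torch tensors can be converted with `.cpu().numpy()`. The function names are hypothetical and the notebook's own implementations may differ, for example in how thresholds or few-shot episodes are chosen. The EER search tries every observed score as a threshold, which is fine for illustration but quadratic in the number of images.

```python
import numpy as np


def verification_eer(embeddings, labels):
    """Equal error rate over all pairs, scoring each pair by cosine similarity."""
    sims = embeddings @ embeddings.T
    iu = np.triu_indices(len(labels), k=1)                 # each unordered pair once
    scores = sims[iu]
    same = (labels[:, None] == labels[None, :])[iu]
    best_gap, eer = np.inf, None
    for t in np.unique(scores):
        far = np.mean(scores[~same] >= t)   # different-breed pairs accepted
        frr = np.mean(scores[same] < t)     # same-breed pairs rejected
        if abs(far - frr) < best_gap:
            best_gap, eer = abs(far - frr), (far + frr) / 2
    return eer


def recall_at_k(query_emb, query_labels, gallery_emb, gallery_labels, k=1):
    """Fraction of queries with at least one same-class item in their top-k neighbours."""
    sims = query_emb @ gallery_emb.T
    topk = np.argsort(-sims, axis=1)[:, :k]
    hits = (gallery_labels[topk] == query_labels[:, None]).any(axis=1)
    return hits.mean()


def prototype_accuracy(support_emb, support_labels, query_emb, query_labels):
    """Nearest-class-mean (prototype) classification with cosine similarity."""
    classes = np.unique(support_labels)
    protos = np.stack([support_emb[support_labels == c].mean(axis=0) for c in classes])
    protos /= np.linalg.norm(protos, axis=1, keepdims=True)
    preds = classes[np.argmax(query_emb @ protos.T, axis=1)]
    return np.mean(preds == query_labels)


# Toy data: two well-separated 2-D clusters
toy = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
toy /= np.linalg.norm(toy, axis=1, keepdims=True)
toy_labels = np.array([0, 0, 1, 1])
print(verification_eer(toy, toy_labels))  # 0.0 (perfectly separable)
print(recall_at_k(toy[[0, 2]], toy_labels[[0, 2]], toy[[1, 3]], toy_labels[[1, 3]], k=1))  # 1.0
print(prototype_accuracy(toy[[0, 2]], toy_labels[[0, 2]], toy[[1, 3]], toy_labels[[1, 3]]))  # 1.0
```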

## Conclusion

In this notebook, we have implemented a comprehensive metric learning pipeline for pet breed classification using the Oxford-IIIT Pet Dataset. We have:

1. Built a custom embedding model with a CNN backbone and projection head that supports multiple architectures
2. Implemented various loss functions for metric learning (Triplet, Contrastive, ArcFace)
3. Added efficient training techniques including learning rate scheduling and early stopping
4. Developed comprehensive evaluation methods for verification, retrieval, and few-shot classification
5. Created detailed visualization tools for embedding spaces and feature importance (Grad-CAM)
6. Implemented hyperparameter tuning to optimize model performance
7. Enhanced the evaluation with advanced metrics and extensive analysis
8. Created a user-friendly Streamlit demo for practical use cases

The code is modular and can be easily adapted for different settings and experimentation. To run the complete training and evaluation pipeline, simply execute the `main()` function.

This project demonstrates the power of metric learning for fine-grained visual categorization and similarity search, with applications extending beyond pet breed classification to many other domains requiring visual similarity comparison.