<a href="https://colab.research.google.com/github/ndethi/opit-rai203-t2/blob/main/rai-8002-cv/assessment2/code/Charles_Watson_Ndethi_Kibaki-Code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Assessment 2: Metric Learning with Oxford-IIIT Pet Dataset

## Introduction and Setup

This notebook implements a deep metric learning approach for the Oxford-IIIT Pet Dataset, focusing on learning an embedding space where similar pet breeds are close together and dissimilar ones are far apart. We'll explore different loss functions, evaluate the model on verification, retrieval, and few-shot classification tasks, and visualize the embedding space.

### Environment Setup and Package Installation

In [None]:

# Check if running in Colab (to install dependencies and set up environment)
import sys
IN_COLAB = 'google.colab' in sys.modules

# Install required packages
if IN_COLAB:
    !pip install pytorch-metric-learning
    !pip install faiss-gpu
    !pip install umap-learn
    !pip install matplotlib seaborn scikit-learn tqdm
    !pip install gradio
    !pip install grad-cam

### Import Libraries

import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm
from collections import defaultdict
import time
import json
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from torchvision import datasets, models, transforms
import torchvision.transforms.functional as TF

import pytorch_metric_learning
from pytorch_metric_learning import losses, miners, distances, reducers, testers

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from sklearn.model_selection import train_test_split
import umap

# torch.cuda.amp is deprecated; the training loop below uses the torch.amp API
from torch.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_

# Pretrained-weight enums for the supported backbones
from torchvision.models import (ResNet18_Weights, ResNet34_Weights, ResNet50_Weights,
                               EfficientNet_B0_Weights, MobileNet_V2_Weights, DenseNet121_Weights)

# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Optimize CUDA operations
if torch.cuda.is_available():
    # Enable cuDNN autotuning for speed with fixed-size inputs.
    # Note: this overrides the benchmark=False set in set_seed(), so runs are
    # no longer strictly reproducible. Remove this line for strict reproducibility.
    torch.backends.cudnn.benchmark = True

    # Check if AMP is available (Mixed precision training)
    amp_available = True
    print(f"CUDA enabled with device: {torch.cuda.get_device_name(0)}")
    print(f"Using mixed precision training: {amp_available}")
else:
    amp_available = False
    print("CUDA not available. Using CPU only.")

print(f"Using device: {device}")



### Mount Google Drive (if in Colab)

In [None]:
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    # Define base directory in Google Drive
    drive_base_dir = '/content/drive/MyDrive/ColabNotebooks/PetMetricLearning'
    os.makedirs(drive_base_dir, exist_ok=True)
    print(f"Google Drive mounted. Base directory: {drive_base_dir}")
else:
    # Define a local base directory if not in Colab
    drive_base_dir = './pet_metric_learning_results'
    os.makedirs(drive_base_dir, exist_ok=True)
    print(f"Not in Colab. Using local directory: {drive_base_dir}")



## Data Loading and Preprocessing

In this section, we'll load the Oxford-IIIT Pet Dataset, perform necessary preprocessing, and create appropriate data loaders for our metric learning tasks.

In [None]:

# Define transformations
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Function to load the dataset
def load_oxford_pets_dataset(root="./data", download=True):
    train_val_dataset = datasets.OxfordIIITPet(
        root=root,
        split="trainval",
        transform=train_transform,
        download=download
    )

    test_dataset = datasets.OxfordIIITPet(
        root=root,
        split="test",
        transform=eval_transform,
        download=download
    )

    # For evaluation, create a version of the training set with eval transforms
    eval_train_dataset = datasets.OxfordIIITPet(
        root=root,
        split="trainval",
        transform=eval_transform,
        download=False
    )

    return train_val_dataset, test_dataset, eval_train_dataset
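
To make the preprocessing arithmetic concrete, the sketch below reproduces for a single pixel what the `ToTensor` and `Normalize` steps do: scale 8-bit values to [0, 1], then subtract the per-channel mean and divide by the per-channel standard deviation. The helper name `normalize_pixel` is hypothetical and only for illustration; the mean and standard deviation values are the ones used in the transforms above.

```python
# Per-channel statistics copied from the transforms defined above
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb_uint8):
    """Hypothetical helper: apply ToTensor-style scaling, then Normalize, to one RGB pixel."""
    scaled = [c / 255.0 for c in rgb_uint8]  # [0, 255] -> [0, 1]
    return [(c - m) / s for c, m, s in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]

white = normalize_pixel([255, 255, 255])
mean_like = normalize_pixel([m * 255 for m in IMAGENET_MEAN])

print("white pixel after normalization:", [round(v, 3) for v in white])
print("mean-coloured pixel after normalization:", [round(v, 3) for v in mean_like])
```

A pixel whose colour equals the channel means maps to approximately zero in every channel, which is the point of the normalization.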



### Dataset Preparation for Different Tasks

In [None]:

# Split data for training, validation and few-shot evaluation
def prepare_datasets(train_val_dataset, test_dataset, eval_train_dataset, num_holdout_classes=5, val_ratio=0.2):
    # Get the class names
    class_to_idx = train_val_dataset.class_to_idx
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    num_classes = len(class_to_idx)

    # Split classes for few-shot learning (hold out some classes for testing)
    all_class_indices = list(range(num_classes))
    holdout_class_indices = random.sample(all_class_indices, num_holdout_classes)
    training_class_indices = [i for i in all_class_indices if i not in holdout_class_indices]

    holdout_classes = [idx_to_class[i] for i in holdout_class_indices]
    print(f"Holdout classes for few-shot learning: {holdout_classes}")

    # Create datasets excluding holdout classes for main training
    train_val_indices = [i for i, (_, label) in enumerate(train_val_dataset) if label not in holdout_class_indices]
    test_indices = [i for i, (_, label) in enumerate(test_dataset) if label not in holdout_class_indices]
    eval_train_indices = [i for i, (_, label) in enumerate(eval_train_dataset) if label not in holdout_class_indices]

    # For few-shot learning, include only holdout classes
    few_shot_train_indices = [i for i, (_, label) in enumerate(train_val_dataset) if label in holdout_class_indices]
    few_shot_test_indices = [i for i, (_, label) in enumerate(test_dataset) if label in holdout_class_indices]

    # Split train/val
    train_indices, val_indices = train_test_split(
        train_val_indices,
        test_size=val_ratio,
        stratify=[train_val_dataset[i][1] for i in train_val_indices],
        random_state=42
    )

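    # Create Subset datasets
    train_dataset = Subset(train_val_dataset, train_indices)
    # Validation reuses the trainval indices but reads from the eval-transform
    # copy, so validation loss is not affected by random augmentation
    val_dataset = Subset(eval_train_dataset, val_indices)
    test_filtered_dataset = Subset(test_dataset, test_indices)
    eval_train_dataset = Subset(eval_train_dataset, eval_train_indices)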

    # Create datasets for few-shot learning
    few_shot_train_dataset = Subset(train_val_dataset, few_shot_train_indices)
    few_shot_test_dataset = Subset(test_dataset, few_shot_test_indices)

    # Create dictionary for class mapping
    class_mapping = {
        'class_to_idx': class_to_idx,
        'idx_to_class': idx_to_class,
        'holdout_class_indices': holdout_class_indices,
        'training_class_indices': training_class_indices
    }

    return {
        'train': train_dataset,
        'val': val_dataset,
        'test': test_filtered_dataset,
        'eval_train': eval_train_dataset,
        'few_shot_train': few_shot_train_dataset,
        'few_shot_test': few_shot_test_dataset,
        'class_mapping': class_mapping
    }
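
The class-disjoint split above can be illustrated on labels alone. The following is a minimal sketch, assuming a plain Python list of integer labels; the function name `split_by_holdout` and the toy data are hypothetical. Working on a label list avoids decoding any images while computing the index sets.

```python
import random

def split_by_holdout(labels, num_holdout_classes, seed=0):
    """Split sample indices into base-class and holdout-class groups."""
    rng = random.Random(seed)
    classes = sorted(set(labels))
    holdout = set(rng.sample(classes, num_holdout_classes))
    base_idx = [i for i, y in enumerate(labels) if y not in holdout]
    holdout_idx = [i for i, y in enumerate(labels) if y in holdout]
    return base_idx, holdout_idx, holdout

# Toy label list: 6 classes with 4 samples each (hypothetical data)
labels = [c for c in range(6) for _ in range(4)]
base_idx, holdout_idx, holdout = split_by_holdout(labels, num_holdout_classes=2)

print(f"holdout classes: {sorted(holdout)}")
print(f"base samples: {len(base_idx)}, holdout samples: {len(holdout_idx)}")
```

Because the two index sets never share a class, few-shot evaluation on the holdout group measures how the embedding generalizes to breeds it was not trained on.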



### Create DataLoaders

In [None]:

def create_dataloaders(datasets_dict, batch_size=32, num_workers=4):
    dataloaders = {}

    for key in ['train', 'val', 'test', 'eval_train', 'few_shot_train', 'few_shot_test']:
        # Only the training split is shuffled
        shuffle = (key == 'train')

        dataloaders[key] = DataLoader(
            datasets_dict[key],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True
        )

    return dataloaders




### Load and Prepare Data


In [None]:

# Load the dataset
train_val_dataset, test_dataset, eval_train_dataset = load_oxford_pets_dataset()
print(f"Train+Val size: {len(train_val_dataset)}")
print(f"Test size: {len(test_dataset)}")

# Prepare datasets for different tasks
datasets_dict = prepare_datasets(train_val_dataset, test_dataset, eval_train_dataset)

# Create dataloaders
batch_size = 32  # Adjust based on your GPU/memory constraints
dataloaders = create_dataloaders(datasets_dict, batch_size=batch_size)

# Print dataset statistics
print("\nDataset Statistics:")
for key, dataloader in dataloaders.items():
    print(f"{key}: {len(dataloader.dataset)} samples")

class_mapping = datasets_dict['class_mapping']
num_classes = len(class_mapping['class_to_idx'])
print(f"Total number of classes: {num_classes}")
print(f"Number of training classes: {len(class_mapping['training_class_indices'])}")
print(f"Number of few-shot classes: {len(class_mapping['holdout_class_indices'])}")




## Model Architecture

In this section, we'll define our metric learning model architecture using a CNN backbone and a projection head.


In [None]:

class EmbeddingNet(nn.Module):
    def __init__(self, backbone_name='resnet18', embedding_size=128, pretrained=True):
        super(EmbeddingNet, self).__init__()

        # Get backbone and its output size
        self.backbone, backbone_output_size = self._get_backbone(backbone_name, pretrained)

        # Projection head (MLP)
        self.projection_head = nn.Sequential(
            nn.Linear(backbone_output_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, embedding_size)
        )

    def _get_backbone(self, backbone_name, pretrained):
        """
        Create a backbone network from various architectures
        """
        if backbone_name == 'resnet18':
            weights = ResNet18_Weights.DEFAULT if pretrained else None
            backbone = models.resnet18(weights=weights)
            output_size = 512
        elif backbone_name == 'resnet34':
            weights = ResNet34_Weights.DEFAULT if pretrained else None
            backbone = models.resnet34(weights=weights)
            output_size = 512
        elif backbone_name == 'resnet50':
            weights = ResNet50_Weights.DEFAULT if pretrained else None
            backbone = models.resnet50(weights=weights)
            output_size = 2048
        elif backbone_name == 'efficientnet_b0':
            weights = EfficientNet_B0_Weights.DEFAULT if pretrained else None
            backbone = models.efficientnet_b0(weights=weights)
            output_size = 1280
        elif backbone_name == 'mobilenet_v2':
            weights = MobileNet_V2_Weights.DEFAULT if pretrained else None
            backbone = models.mobilenet_v2(weights=weights)
            output_size = 1280
        elif backbone_name == 'densenet121':
            weights = DenseNet121_Weights.DEFAULT if pretrained else None
            backbone = models.densenet121(weights=weights)
            output_size = 1024
        else:
            raise ValueError(f"Unsupported backbone: {backbone_name}")

        # For ResNet models
        if backbone_name.startswith('resnet'):
            # Remove the classification layer
            backbone = nn.Sequential(*list(backbone.children())[:-1])
        # For EfficientNet
        elif backbone_name.startswith('efficientnet'):
            backbone = nn.Sequential(*list(backbone.children())[:-1])
        # For MobileNet
        elif backbone_name.startswith('mobilenet'):
            backbone = nn.Sequential(*list(backbone.children())[:-1])
        # For DenseNet
        elif backbone_name.startswith('densenet'):
            backbone = nn.Sequential(
                backbone.features,
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d((1, 1))
            )

        return backbone, output_size

    def forward(self, x):
        features = self.backbone(x)
        features = features.view(features.size(0), -1)
        embeddings = self.projection_head(features)

        # Normalize embeddings to unit length (important for cosine distance)
        normalized_embeddings = F.normalize(embeddings, p=2, dim=1)
        return normalized_embeddings

    def get_embedding(self, x):
        return self.forward(x)
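
Why the final L2 normalization matters can be checked with a small worked example. For unit-length vectors, cosine similarity reduces to a dot product, and squared Euclidean distance equals `2 - 2 * cos`, so ranking by either measure gives the same order. The helper names below are hypothetical, and the sketch uses plain Python lists rather than tensors.

```python
import math

def l2_normalize(v):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

a = l2_normalize([3.0, 4.0])
b = l2_normalize([4.0, 3.0])

cos_sim = dot(a, b)  # for unit vectors, the dot product is the cosine similarity
sq_dist = sum((x - y) ** 2 for x, y in zip(a, b))

print(f"cosine similarity:          {cos_sim:.4f}")
print(f"squared Euclidean distance: {sq_dist:.4f}")
print(f"2 - 2 * cos:                {2 - 2 * cos_sim:.4f}")
```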


## Loss Function Implementation

Here we'll set up several metric learning loss functions from `pytorch-metric-learning`: Triplet Loss, Contrastive Loss, and ArcFace. Where applicable, each loss is paired with a miner that selects informative pairs or triplets within a batch.
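
As a rough illustration of the triplet objective when similarity (rather than distance) is used, the sketch below computes a hinge on the gap between anchor-positive and anchor-negative similarity. This is a simplified, scalar version for intuition, not the library's implementation; the function name `triplet_loss_cosine` and the example values are hypothetical.

```python
def triplet_loss_cosine(sim_ap, sim_an, margin=0.2):
    """Penalize a triplet unless the negative is at least `margin`
    less similar to the anchor than the positive is."""
    return max(0.0, sim_an - sim_ap + margin)

# Easy: negative is far less similar than the positive -> zero loss
easy = triplet_loss_cosine(sim_ap=0.9, sim_an=0.1)
# Semi-hard: negative is less similar than the positive, but within the margin
semihard = triplet_loss_cosine(sim_ap=0.9, sim_an=0.8)
# Hard: negative is more similar to the anchor than the positive is
hard = triplet_loss_cosine(sim_ap=0.5, sim_an=0.7)

print(f"easy={easy:.2f}, semihard={semihard:.2f}, hard={hard:.2f}")
```

Easy triplets contribute no gradient, which is why a miner that selects semi-hard or hard triplets can make each batch more informative.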


In [None]:

def create_loss_and_miner(loss_type, margin=0.2, embedding_size=128, num_classes=32):
    """
    Create loss function and miner for metric learning
    """
    if loss_type == 'triplet':
        # Triplet loss with cosine distance
        distance = distances.CosineSimilarity()
        reducer = reducers.ThresholdReducer(low=0)
        loss_func = losses.TripletMarginLoss(margin=margin, distance=distance, reducer=reducer)
        mining_func = miners.TripletMarginMiner(margin=margin, distance=distance, type_of_triplets="semihard")

    elif loss_type == 'contrastive':
        # Contrastive loss
        distance = distances.CosineSimilarity()
        loss_func = losses.ContrastiveLoss(pos_margin=0.8, neg_margin=0.2, distance=distance)
        mining_func = miners.PairMarginMiner(pos_margin=0.8, neg_margin=0.2, distance=distance)

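    elif loss_type == 'arcface':
        # ArcFace loss. The signature is ArcFaceLoss(num_classes, embedding_size, ...),
        # so pass keywords to avoid swapping the two sizes.
        # Note: this loss holds a learnable class-weight matrix, so
        # loss_func.parameters() must also be given to an optimizer to be trained.
        loss_func = losses.ArcFaceLoss(num_classes=num_classes, embedding_size=embedding_size,
                                       margin=28.6, scale=64)
        mining_func = None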

    else:
        raise ValueError(f"Unsupported loss type: {loss_type}")

    return loss_func, mining_func
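
The ArcFace `margin=28.6` above is expressed in degrees, which is about 0.5 radians. The sketch below shows, under simplifying assumptions, how an additive angular margin lowers the target-class logit: the angle between the embedding and its class weight is increased by the margin before taking the cosine and applying the scale. The function name `arcface_target_logit` is hypothetical, and edge cases (for example, angles that exceed pi after adding the margin) are ignored.

```python
import math

def arcface_target_logit(cos_theta, margin_deg=28.6, scale=64):
    """Simplified target-class logit: scale * cos(theta + margin)."""
    theta = math.acos(cos_theta)
    return scale * math.cos(theta + math.radians(margin_deg))

margin_rad = math.radians(28.6)
plain = 64 * 0.8                      # logit without any margin
with_margin = arcface_target_logit(0.8)

print(f"margin in radians: {margin_rad:.3f}")
print(f"logit without margin: {plain:.2f}")
print(f"logit with margin:    {with_margin:.2f}")
```

Because the penalized logit is smaller, the model has to pull each embedding closer in angle to its class weight to reach the same confidence, which tightens the clusters.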




### Hard Negative Mining (Bonus Implementation)

In [None]:

class HardNegativePairMiner(miners.BaseMiner):
    def __init__(self, distance, neg_margin=0.2, hardest_fraction=0.5):
        super().__init__()
        self.distance = distance
        self.neg_margin = neg_margin
        self.hardest_fraction = hardest_fraction

    def mine(self, embeddings, labels, ref_emb=None, ref_labels=None):
        ref_emb, ref_labels = embeddings, labels
        dist_mat = self.distance(embeddings, ref_emb)

        # Get negative pairs (different classes)
        negative_mask = labels.unsqueeze(1) != ref_labels.unsqueeze(0)

        # For each anchor, find all negative pairs
        anchors, negatives = torch.where(negative_mask)

        if len(anchors) == 0:
            return empty_tensor(0), empty_tensor(0), empty_tensor(0), empty_tensor(0)

        # Get distances for all negative pairs
        # (named pair_dists to avoid shadowing the imported `distances` module)
        pair_dists = dist_mat[anchors, negatives]

        # Group by anchor
        anchor_groups = defaultdict(list)
        for i in range(len(anchors)):
            anchor_groups[anchors[i].item()].append((negatives[i].item(), pair_dists[i].item()))

        # For each anchor, select the hardest negatives
        hard_a, hard_n = [], []
        for anchor, neg_dists in anchor_groups.items():
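            # Hardest negatives first: highest similarity for inverted measures such as
            # CosineSimilarity (descending sort), smallest distance otherwise (ascending)
            neg_dists.sort(key=lambda x: x[1], reverse=self.distance.is_inverted)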

            # Select hardest fraction
            num_to_select = max(1, int(len(neg_dists) * self.hardest_fraction))
            selected_negs = neg_dists[:num_to_select]

            for neg, dist in selected_negs:
                hard_a.append(anchor)
                hard_n.append(neg)

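        # pytorch-metric-learning pair tuples are (anchor1, positive, anchor2, negative).
        # This miner yields negative pairs only, so the positive slots stay empty.
        return (
            empty_tensor(0),
            empty_tensor(0),
            torch.tensor(hard_a, device=embeddings.device, dtype=torch.long),
            torch.tensor(hard_n, device=embeddings.device, dtype=torch.long)
        )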

def empty_tensor(size):
    return torch.tensor([], device=device, dtype=torch.long).view(size)
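
The selection rule used by the miner can be traced by hand on a tiny batch. The sketch below is a pure-Python stand-in, assuming a precomputed cosine-similarity matrix given as nested lists; the function name `hardest_negatives` and the toy values are hypothetical. For each anchor it keeps the most similar fraction of samples from other classes.

```python
def hardest_negatives(sim, labels, fraction=0.5):
    """Return (anchor, negative) index pairs with the highest similarity per anchor."""
    pairs = []
    for a, row in enumerate(sim):
        # Candidate negatives: samples whose label differs from the anchor's
        negs = [(s, n) for n, s in enumerate(row) if labels[n] != labels[a]]
        if not negs:
            continue
        negs.sort(reverse=True)  # most similar (hardest) first
        keep = max(1, int(len(negs) * fraction))
        pairs.extend((a, n) for _, n in negs[:keep])
    return pairs

labels = [0, 0, 1, 1]
sim = [
    [1.0, 0.9, 0.7, 0.2],
    [0.9, 1.0, 0.1, 0.6],
    [0.7, 0.1, 1.0, 0.8],
    [0.2, 0.6, 0.8, 1.0],
]
pairs = hardest_negatives(sim, labels)
print(pairs)
```

With two negatives per anchor and `fraction=0.5`, each anchor keeps exactly one pair: the negative it is most easily confused with.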




## Training Pipeline

Next, we'll implement the training pipeline for our metric learning model, including early stopping and checkpointing to Google Drive.


In [None]:
# Early Stopping Class
class EarlyStopping:
    """Stops training early if validation loss doesn't improve within the given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='best_model.pth', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                           Default: 0
            path (str): Path for the best model checkpoint to be saved to.
                        Default: 'best_model.pth'
            trace_func (function): trace print function.
                                   Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model when validation loss decreases.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving best model to {self.path} ...')
        # Ensure the target directory exists; dirname is '' for a bare filename
        # such as the default 'best_model.pth', and os.makedirs('') would raise
        checkpoint_dir = os.path.dirname(self.path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
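
To see how the patience rule behaves, the sketch below replays a made-up sequence of validation losses through a stripped-down counter. `PatienceCounter` is a hypothetical stand-in that mirrors only the counting logic of `EarlyStopping` (no checkpoint saving), and the loss values are invented for illustration.

```python
class PatienceCounter:
    """Minimal stand-in for the early-stopping rule (no checkpointing)."""

    def __init__(self, patience=3, delta=0.0):
        self.patience = patience
        self.delta = delta
        self.best = float("inf")
        self.counter = 0

    def step(self, val_loss):
        """Record one validation loss; return True when training should stop."""
        if val_loss <= self.best - self.delta:
            self.best = val_loss      # improvement: reset the counter
            self.counter = 0
        else:
            self.counter += 1         # no improvement this epoch
        return self.counter >= self.patience

# Invented validation losses, one per epoch
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.64, 0.68, 0.69, 0.70, 0.60]
stopper = PatienceCounter(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(f"stopped at epoch {stopped_at}, best loss {stopper.best:.2f}")
```

Note that the improvement at epoch 6 resets the counter, so the stop is triggered by three consecutive non-improving epochs after it, and the later value of 0.60 is never reached.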

In [None]:

import glob # Import glob for finding checkpoint files

def train_model(model, dataloaders, loss_type, optimizer=None, scheduler=None, num_epochs=15, embedding_size=128, checkpoint_dir='checkpoints'):
    """
    Train the metric learning model with enhanced monitoring, optimizations, and checkpointing.
    """
    # Ensure checkpoint directory exists
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')

    # Setup for mixed precision training
    scaler = torch.amp.GradScaler('cuda', enabled=amp_available) if torch.cuda.is_available() else None

    # Get the number of training classes (excluding holdout classes)
    num_training_classes = len(datasets_dict['class_mapping']['training_class_indices'])

    # Create loss function and miner
    loss_func, mining_func = create_loss_and_miner(
        loss_type=loss_type,
        embedding_size=embedding_size,
        num_classes=num_training_classes
    )

    # If using ArcFace, we need to create a class map for the training dataset
    if loss_type == 'arcface':
        # Map original class indices to consecutive integers for ArcFace
        class_map = {original: i for i, original in enumerate(datasets_dict['class_mapping']['training_class_indices'])}

    # An optimizer is required; its state is also restored when resuming from a checkpoint
    if optimizer is None:
        raise ValueError("Optimizer must be provided for checkpoint loading.")

    # Initialize training history and starting epoch
    history = {
        'train_loss': [],
        'val_loss': [],
        'lr': [],
        'batch_losses': [],
        'gradient_norms': [],
        'epoch_times': []
    }
    start_epoch = 0

    # --- Checkpoint Loading ---
    latest_checkpoint_path = None
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pth'))
    if checkpoint_files:
        # Find the checkpoint with the highest epoch number
        latest_epoch = -1
        for f in checkpoint_files:
            try:
                epoch_num = int(os.path.basename(f).split('_')[-1].split('.')[0])
                if epoch_num > latest_epoch:
                    latest_epoch = epoch_num
                    latest_checkpoint_path = f
            except ValueError:
                continue # Ignore files that don't match the pattern

    if latest_checkpoint_path:
        print(f"Resuming training from checkpoint: {latest_checkpoint_path}")
        checkpoint = torch.load(latest_checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch']
        if 'history' in checkpoint:
             history = checkpoint['history'] # Load history if saved
        # Ensure loaded states are on the correct device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
        print(f"Resumed from epoch {start_epoch}")
    else:
        print("No checkpoint found, starting training from scratch.")
    # --- End Checkpoint Loading ---

    # Initialize early stopping
    early_stopping = EarlyStopping(patience=5, verbose=True, path=best_model_path)
    # Load previous best score if resuming
    if start_epoch > 0 and history['val_loss']:
        early_stopping.val_loss_min = min(history['val_loss'])
        early_stopping.best_score = -early_stopping.val_loss_min
        print(f"Loaded previous best validation loss: {early_stopping.val_loss_min:.6f}")

    # Training loop
    # history dictionary is already initialized or loaded

    # Create experiment folder for saving results (if needed, separate from checkpoints)
    # backbone_name = model.backbone.__class__.__name__ # Get backbone name dynamically if possible
    # experiment_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{loss_type}_{backbone_name}"
    # experiment_dir = os.path.join("experiments", experiment_id)
    # os.makedirs(experiment_dir, exist_ok=True)

    for epoch in range(start_epoch, num_epochs):
        epoch_start = time.time()
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        epoch_train_losses = [] # Track batch losses for this epoch's history
        epoch_grad_norms = []   # Track grad norms for this epoch's history

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0

            # Iterate over data
            for inputs, labels in tqdm(dataloaders[phase], desc=phase):
                inputs = inputs.to(device, non_blocking=True)  # non_blocking for asynchronous transfer
                labels = labels.to(device, non_blocking=True)

                # Map labels for ArcFace if needed
                if loss_type == 'arcface':
                    # Filter out samples from holdout classes
                    valid_idx = torch.tensor([i for i, l in enumerate(labels) if l.item() in class_map], device=device)
                    if len(valid_idx) == 0:
                        continue

                    inputs = inputs[valid_idx]
                    arcface_labels = torch.tensor([class_map[l.item()] for l in labels[valid_idx]], device=device)
                    labels = arcface_labels

                # Zero the parameter gradients
                optimizer.zero_grad(set_to_none=True)  # set_to_none=True is more memory efficient

                # Forward pass
                with torch.set_grad_enabled(phase == 'train'):
                    # Use autocast for both train and val if scaler is enabled
                    with torch.amp.autocast(device_type=device.type, enabled=(scaler is not None)):
                        embeddings = model(inputs)

                        # Get indices for mining if using a mining function
                        if mining_func is not None:
                            hard_pairs = mining_func(embeddings, labels)
                            loss = loss_func(embeddings, labels, hard_pairs)
                        else:
                            loss = loss_func(embeddings, labels)

                    # Backward pass + optimize only if in training phase
                    if phase == 'train':
                        if scaler is not None:
                            # Backward pass with gradient scaling
                            scaler.scale(loss).backward()

                            # Compute gradient norm after unscaling
                            scaler.unscale_(optimizer)
                            total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
                            epoch_grad_norms.append(total_norm.item())

                            # Optimizer step with scaler
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            # Regular backward pass without AMP
                            loss.backward()
                            total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
                            epoch_grad_norms.append(total_norm.item())
                            optimizer.step()

                # Record batch loss
                batch_loss = loss.item()
                if phase == 'train':
                    epoch_train_losses.append(batch_loss)
                running_loss += batch_loss * inputs.size(0)

            # Calculate epoch loss
            epoch_loss = running_loss / len(dataloaders[phase].dataset)

            if phase == 'train':
                current_lr = optimizer.param_groups[0]['lr']
                history['lr'].append(current_lr)
                history['train_loss'].append(epoch_loss)
                history['batch_losses'].extend(epoch_train_losses) # Append all batch losses for the epoch
                if epoch_grad_norms:
                    history['gradient_norms'].extend(epoch_grad_norms) # Append all grad norms for the epoch
                    print(f'{phase} Loss: {epoch_loss:.4f}, LR: {current_lr:.6f}, Grad Norm: {np.mean(epoch_grad_norms):.4f}')
                else:
                    print(f'{phase} Loss: {epoch_loss:.4f}, LR: {current_lr:.6f}')
            else:
                history['val_loss'].append(epoch_loss)
                print(f'{phase} Loss: {epoch_loss:.4f}')

                # Update learning rate scheduler based on validation loss
                if scheduler is not None and isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(epoch_loss)

                # Check early stopping
                early_stopping(epoch_loss, model)
                if early_stopping.early_stop:
                    print("Early stopping triggered")
                    break # Break inner loop (phases)

        # Update step-based schedulers at the end of each epoch
        if scheduler is not None and not isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step()

        # Record epoch time
        epoch_end = time.time()
        epoch_time = epoch_end - epoch_start
        history['epoch_times'].append(epoch_time)
        print(f"Epoch completed in {epoch_time:.2f}s")

        # --- Save Checkpoint After Each Epoch ---
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
        save_dict = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'embedding_size': embedding_size,
            # 'backbone_name': backbone_name, # Can be derived from model or config
            'class_mapping': datasets_dict['class_mapping'],
            'history': history # Save history in checkpoint
        }
        if scheduler:
             save_dict['scheduler_state_dict'] = scheduler.state_dict()
        torch.save(save_dict, checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}")
        # --- End Save Checkpoint ---

        # Check if early stopping was triggered in the inner loop
        if early_stopping.early_stop:
            print("Training stopped early due to no improvement in validation loss")
            break # Break outer loop (epochs)

        # Optional: Remove older checkpoints to save space
        # Keep maybe last 5 checkpoints + best model
        # ... (implementation omitted for brevity)

    # Load the best model weights saved by EarlyStopping
    if os.path.exists(best_model_path):
        print(f"Loading best model from {best_model_path}")
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print("Warning: Best model file not found. Using the model state from the last epoch.")

    # Save final training history (optional, as it's saved in checkpoints)
    history_path = os.path.join(checkpoint_dir, 'training_history_final.json')
    with open(history_path, 'w') as f:
        # Convert any non-serializable items in history
        serializable_history = {
            k: v if isinstance(v, list) and all(isinstance(x, (int, float)) for x in v) else str(v)
            for k, v in history.items()
        }
        json.dump(serializable_history, f, indent=2)
        print(f"Final training history saved to {history_path}")

    return model, history
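
The training loop above leaves the clean-up of older checkpoints unimplemented. One possible way to do it is sketched below, assuming the `checkpoint_epoch_{N}.pth` naming used when saving. `prune_checkpoints`, `keep_last`, and the `best_model.pth` file name in the demo are hypothetical and not part of the notebook. Files that do not match the per-epoch pattern are left untouched.

```python
import os
import re
import tempfile


def prune_checkpoints(checkpoint_dir, keep_last=5):
    """Delete all but the newest `keep_last` per-epoch checkpoints (hypothetical helper)."""
    pattern = re.compile(r"^checkpoint_epoch_(\d+)\.pth$")
    # Collect (epoch, filename) pairs for files that follow the naming scheme
    found = []
    for name in os.listdir(checkpoint_dir):
        match = pattern.match(name)
        if match:
            found.append((int(match.group(1)), name))
    found.sort()  # Oldest epoch first (numeric order, so epoch 10 sorts after epoch 9)
    to_delete = found[:-keep_last] if keep_last > 0 else found
    removed = []
    for _, name in to_delete:
        os.remove(os.path.join(checkpoint_dir, name))
        removed.append(name)
    return removed


# Demo on a throwaway directory with seven empty checkpoint files
with tempfile.TemporaryDirectory() as tmp:
    for epoch in range(1, 8):
        open(os.path.join(tmp, f"checkpoint_epoch_{epoch}.pth"), "w").close()
    open(os.path.join(tmp, "best_model.pth"), "w").close()  # placeholder name
    removed = prune_checkpoints(tmp, keep_last=5)
    remaining = sorted(os.listdir(tmp))

print(removed)
```

Inside `train_model`, a call such as `prune_checkpoints(checkpoint_dir, keep_last=5)` right after `torch.save` would be one natural place for it.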




### Embedding Extraction Function

In [None]:

def extract_embeddings(model, dataloader):
    """
    Extract embeddings for a dataset
    """
    model.eval()
    embeddings = []
    labels = []

    with torch.no_grad():
        for inputs, batch_labels in tqdm(dataloader, desc="Extracting embeddings"):
            inputs = inputs.to(device)
            batch_embeddings = model(inputs)
            embeddings.append(batch_embeddings.cpu())
            labels.append(batch_labels)

    embeddings = torch.cat(embeddings, dim=0)
    labels = torch.cat(labels, dim=0)

    return embeddings, labels



## Evaluation Functions

### 1. Verification Task



In [None]:

def create_verification_pairs(embeddings, labels, num_pos_pairs=1000, num_neg_pairs=1000):
    """
    Create positive and negative pairs for verification task
    """
    unique_labels = torch.unique(labels)
    pairs = []
    pair_labels = []

    # Generate positive pairs (same class)
    pos_pair_count = 0
    for label in unique_labels:
        indices = torch.where(labels == label)[0]
        if len(indices) >= 2:
            for i in range(min(num_pos_pairs // len(unique_labels) + 1, len(indices) // 2)):
                idx1, idx2 = np.random.choice(indices, 2, replace=False)
                pairs.append((idx1.item(), idx2.item()))
                pair_labels.append(1)  # 1 for same class
                pos_pair_count += 1
                if pos_pair_count >= num_pos_pairs:
                    break
        if pos_pair_count >= num_pos_pairs:
            break

    # Generate negative pairs (different classes)
    neg_pair_count = 0
    while neg_pair_count < num_neg_pairs:
        label1, label2 = np.random.choice(unique_labels, 2, replace=False)
        indices1 = torch.where(labels == label1)[0]
        indices2 = torch.where(labels == label2)[0]

        if len(indices1) > 0 and len(indices2) > 0:
            idx1 = np.random.choice(indices1)
            idx2 = np.random.choice(indices2)
            pairs.append((idx1.item(), idx2.item()))
            pair_labels.append(0)  # 0 for different class
            neg_pair_count += 1

    return np.array(pairs), np.array(pair_labels)

def evaluate_verification(embeddings, labels):
    """
    Evaluate the model on verification task (same/different class)
    """
    pairs, pair_labels = create_verification_pairs(embeddings, labels)

    # Compute distances between pairs
    distances = []
    for idx1, idx2 in pairs:
        # Using cosine similarity (-1 to 1) where higher value means more similar
        distance = F.cosine_similarity(
            embeddings[idx1].unsqueeze(0),
            embeddings[idx2].unsqueeze(0)
        ).item()
        distances.append(distance)

    distances = np.array(distances)

    # Compute ROC curve and AUC
    # Note: cosine similarity is a score (higher means more similar), so it can be
    # passed to roc_curve directly with 1 as the positive label; no negation is needed
    fpr, tpr, thresholds = roc_curve(pair_labels, distances)
    roc_auc = auc(fpr, tpr)

    # Compute Equal Error Rate (EER)
    fnr = 1 - tpr
    eer_threshold = thresholds[np.nanargmin(np.abs(fnr - fpr))]
    eer = fpr[np.nanargmin(np.abs(fnr - fpr))]

    # Plot ROC curve
    plt.figure(figsize=(10, 8))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC)')
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.savefig('verification_roc_curve.png')
    plt.show()

    return {
        'roc_auc': roc_auc,
        'eer': eer,
        'eer_threshold': eer_threshold,
        'pairs': pairs,
        'pair_labels': pair_labels,
        'distances': distances
    }
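
The Equal Error Rate is the operating point at which the false positive rate equals the false negative rate. The sketch below works through it on toy similarity scores in plain Python, assuming a simple threshold sweep over the observed scores. `equal_error_rate` is a hypothetical helper, and unlike the function above it reports the average of FPR and FNR at the closest point rather than the FPR alone.

```python
def equal_error_rate(scores, labels):
    """Return (eer, threshold) for similarity scores where higher means 'same class'."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    best = None
    for threshold in sorted(set(scores)):
        # Predict "same class" when the score reaches the threshold
        false_pos = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 0)
        false_neg = sum(1 for s, y in zip(scores, labels) if s < threshold and y == 1)
        fpr = false_pos / n_neg
        fnr = false_neg / n_pos
        gap = abs(fpr - fnr)
        # Keep the threshold where the two error rates are closest
        if best is None or gap < best[0]:
            best = (gap, (fpr + fnr) / 2, threshold)
    return best[1], best[2]


# Toy data: 4 same-class pairs (label 1) and 4 different-class pairs (label 0)
scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1]
labels = [1, 1, 0, 1, 1, 0, 0, 0]
eer, threshold = equal_error_rate(scores, labels)
print(eer, threshold)  # 0.25 0.6
```

At threshold 0.6 one negative pair (score 0.7) is accepted and one positive pair (score 0.4) is rejected, so both error rates are 1/4.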




### 2. Retrieval Task


In [None]:

def evaluate_retrieval(query_embeddings, query_labels, gallery_embeddings, gallery_labels, k_values=(1, 5, 10)):
    """
    Evaluate the model on retrieval task
    """
    results = {}

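    # Compute the similarity matrix once; it does not depend on k.
    # The dot product equals cosine similarity only if the embeddings are L2-normalized.
    similarity_matrix = torch.matmul(query_embeddings, gallery_embeddings.T)

    for k in k_values: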

        # Get top-k indices for each query
        _, indices = torch.topk(similarity_matrix, k=k, dim=1)

        # Compute Recall@K and Precision@K
        recall_k = 0
        precision_k = 0

        for i, query_label in enumerate(query_labels):
            retrieved_labels = gallery_labels[indices[i]]
            relevant = (retrieved_labels == query_label).float()

            # Recall@K (hit rate): 1 if at least one of the top-k items shares the query's class
            recall_k += (relevant.sum() > 0).float().item()

            # Precision@K: How many of the retrieved items are relevant
            precision_k += (relevant.sum() / k).item()

        recall_k /= len(query_labels)
        precision_k /= len(query_labels)

        results[f'recall@{k}'] = recall_k
        results[f'precision@{k}'] = precision_k

        print(f"Recall@{k}: {recall_k:.4f}")
        print(f"Precision@{k}: {precision_k:.4f}")

    return results
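
The two retrieval metrics above answer different questions: Recall@K, as computed here, is a hit rate (did any of the top-k items match?), while Precision@K is the fraction of the top-k that matches. A small plain-Python sketch on a toy ranking makes the difference concrete. `hit_rate_and_precision_at_k` and the breed labels are illustrative assumptions only.

```python
def hit_rate_and_precision_at_k(ranked_labels, query_label, k):
    """ranked_labels: gallery labels ordered from most to least similar to the query."""
    top_k = ranked_labels[:k]
    n_relevant = sum(1 for label in top_k if label == query_label)
    hit = 1.0 if n_relevant > 0 else 0.0  # Per-query contribution to Recall@K
    precision = n_relevant / k            # Fraction of the top-k that is relevant
    return hit, precision


# Toy ranking for a single query of class "beagle"
ranking = ["pug", "beagle", "beagle", "pug", "boxer"]
hit1, prec1 = hit_rate_and_precision_at_k(ranking, "beagle", k=1)
hit3, prec3 = hit_rate_and_precision_at_k(ranking, "beagle", k=3)
print(hit1, prec1)  # 0.0 0.0
print(hit3, prec3)
```

For K=1 the two values always coincide, which is why `recall@1` and `precision@1` are identical in the results.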




### 3. Few-shot Classification

In [None]:

def evaluate_few_shot(support_embeddings, support_labels, query_embeddings, query_labels, n_way=5, k_shot=5):
    """
    Evaluate the model on n-way k-shot classification
    """
    unique_labels = torch.unique(support_labels)
    if len(unique_labels) < n_way:
        print(f"Warning: Only {len(unique_labels)} classes available, but n_way={n_way}")
        n_way = len(unique_labels)

    accuracies = []

    # Run multiple episodes for stable results
    num_episodes = 50
    for episode in range(num_episodes):
        # Randomly select n classes for this episode
        selected_classes = np.random.choice(unique_labels.numpy(), n_way, replace=False)

        # Create support set (k examples per class)
        support_set_embeddings = []
        support_set_labels = []

        for class_idx, c in enumerate(selected_classes):
            # Get indices of examples of class c
            class_indices = torch.where(support_labels == c)[0]

            # Randomly select k examples
            if len(class_indices) >= k_shot:
                selected_indices = np.random.choice(class_indices.numpy(), k_shot, replace=False)
            else:
                # If not enough examples, use all and repeat some
                selected_indices = np.random.choice(class_indices.numpy(), k_shot, replace=True)

            for idx in selected_indices:
                support_set_embeddings.append(support_embeddings[idx])
                support_set_labels.append(class_idx)  # Use class index as the new label

        support_set_embeddings = torch.stack(support_set_embeddings)
        support_set_labels = torch.tensor(support_set_labels)

        # Create query set (all examples of the selected classes from the query set)
        query_mask = torch.isin(query_labels, torch.as_tensor(selected_classes))
        query_set_indices = torch.where(query_mask)[0]

        if len(query_set_indices) == 0:
            print("Warning: No query examples for selected classes")
            continue

        query_set_embeddings = query_embeddings[query_set_indices]
        query_set_labels = query_labels[query_set_indices]

        # Map original labels to new indices (0 to n_way-1)
        label_mapping = {selected_classes[i]: i for i in range(n_way)}
        query_set_labels = torch.tensor([label_mapping[label.item()] for label in query_set_labels])

        # Compute prototypes (mean embedding for each class)
        prototypes = torch.zeros(n_way, support_embeddings.size(1), device=support_embeddings.device)
        for c in range(n_way):
            prototypes[c] = support_set_embeddings[support_set_labels == c].mean(0)

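        # Compute similarity between query examples and prototypes.
        # Normalize both sides so the dot product is a true cosine similarity
        # (averaged prototypes are generally not unit-length).
        logits = torch.matmul(
            F.normalize(query_set_embeddings, dim=1),
            F.normalize(prototypes, dim=1).T
        )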

        # Make predictions
        _, predictions = torch.max(logits, dim=1)

        # Compute accuracy
        accuracy = (predictions == query_set_labels).float().mean().item()
        accuracies.append(accuracy)

    mean_accuracy = np.mean(accuracies)
    std_accuracy = np.std(accuracies)

    print(f"{n_way}-way {k_shot}-shot classification accuracy: {mean_accuracy:.4f} ± {std_accuracy:.4f}")

    return {
        'mean_accuracy': mean_accuracy,
        'std_accuracy': std_accuracy,
        'accuracies': accuracies
    }
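
Prototype-based classification assigns each query to the class whose mean support embedding is most similar. A raw dot product and cosine similarity rank the prototypes identically only when all prototypes have the same norm, which a mean of embeddings does not guarantee. The plain-Python sketch below illustrates this on made-up 2-D vectors. All helper names are hypothetical, and whether `EmbeddingNet` normalizes its outputs is not visible in this section.

```python
import math


def l2_normalize(vec):
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def mean_vector(vectors):
    # Element-wise mean of a list of equally sized vectors
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def classify_by_prototype(query, support):
    """support maps class name -> list of embeddings; picks the highest cosine similarity."""
    q = l2_normalize(query)
    best_class, best_sim = None, -2.0  # Cosine similarity is never below -1
    for name, vectors in support.items():
        prototype = l2_normalize(mean_vector(vectors))
        sim = sum(a * b for a, b in zip(q, prototype))
        if sim > best_sim:
            best_class, best_sim = name, sim
    return best_class, best_sim


# Toy 2-way 2-shot episode; the "dog" embeddings have a much larger norm
support = {
    "cat": [[1.0, 0.0], [0.8, 0.2]],
    "dog": [[0.0, 5.0], [1.0, 5.0]],
}
query = [0.9, 0.3]
pred, sim = classify_by_prototype(query, support)

# Same decision with an unnormalized dot product, for comparison
raw = {name: sum(a * b for a, b in zip(query, mean_vector(vecs))) for name, vecs in support.items()}
raw_winner = max(raw, key=raw.get)
print(pred, raw_winner)  # cat dog
```

The query points almost in the same direction as the "cat" prototype, yet the unnormalized dot product prefers "dog" purely because of its larger norm.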




## Embedding Visualization


In [None]:


def visualize_embeddings(embeddings, labels, class_mapping, method='tsne', title='Embedding Visualization'):
    """
    Visualize embeddings using t-SNE or UMAP
    """
    idx_to_class = class_mapping['idx_to_class']

    # Reduce dimensionality
    if method == 'tsne':
        print("Computing t-SNE projection...")
        projection = TSNE(n_components=2, random_state=42).fit_transform(embeddings.numpy())
    elif method == 'umap':
        print("Computing UMAP projection...")
        projection = umap.UMAP(n_components=2, random_state=42).fit_transform(embeddings.numpy())
    else:
        raise ValueError(f"Unsupported visualization method: {method}")

    # Create plot
    plt.figure(figsize=(14, 10))

    # Get unique labels
    unique_labels = torch.unique(labels).numpy()

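    # Create colormap (plt.cm.get_cmap was removed in Matplotlib 3.9; tab20 has only
    # 20 base colors, so colors repeat when there are more classes)
    cmap = plt.get_cmap('tab20', len(unique_labels))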

    # Plot each class
    for i, label in enumerate(unique_labels):
        mask = labels.numpy() == label
        plt.scatter(
            projection[mask, 0],
            projection[mask, 1],
            c=[cmap(i)],
            label=idx_to_class[label],
            alpha=0.7,
            s=50
        )

    plt.title(title, fontsize=18)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
    plt.tight_layout()
    plt.savefig(f'{method}_visualization.png')
    plt.show()

    return projection




## Grad-CAM Visualization (Bonus Implementation)



In [None]:


def visualize_grad_cam(model, dataloader, class_mapping, num_images=5):
    """
    Visualize Grad-CAM attention maps
    """
    # Import GradCAM implementation
    try:
        from pytorch_grad_cam import GradCAM
        from pytorch_grad_cam.utils.image import show_cam_on_image
    except ImportError:
        print("Please install pytorch-grad-cam to use this function:")
        print("!pip install grad-cam")
        return

    # Set up GradCAM
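    # Last conv layer of the final residual block. `.conv2` matches ResNet-18 (BasicBlock);
    # ResNet-50 bottleneck blocks end in `.conv3` instead.
    target_layers = [model.backbone[-2][-1].conv2]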
    # Initialize GradCAM without use_cuda argument
    grad_cam = GradCAM(model=model, target_layers=target_layers)

    # Get a batch of images
    images, labels = next(iter(dataloader))
    images = images[:num_images]
    labels = labels[:num_images]

    # Convert images for visualization
    orig_images = []
    for i in range(len(images)):
        img = images[i].permute(1, 2, 0).cpu().numpy()
        img = (img * np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406])
        img = np.clip(img, 0, 1)
        orig_images.append(img)

    # Generate class activation maps
    cam_images = []
    for i in range(len(images)):
        input_tensor = images[i].unsqueeze(0).to(device)

        # Generate GradCAM. With no explicit targets, pytorch-grad-cam backpropagates
        # from the highest-valued output, which here is the largest embedding dimension.
        grayscale_cam = grad_cam(input_tensor)
        grayscale_cam = grayscale_cam[0, :]

        # Overlay on original image
        cam_image = show_cam_on_image(orig_images[i], grayscale_cam, use_rgb=True)
        cam_images.append(cam_image)

    # Plot original images and their activation maps
    # (the batch may hold fewer than num_images samples)
    num_images = len(orig_images)
    plt.figure(figsize=(15, 4 * num_images))
    for i in range(num_images):
        # Original image
        plt.subplot(num_images, 2, 2*i+1)
        plt.imshow(orig_images[i])
        plt.title(f"Original: {class_mapping['idx_to_class'][labels[i].item()]}")
        plt.axis('off')

        # GradCAM
        plt.subplot(num_images, 2, 2*i+2)
        plt.imshow(cam_images[i])
        plt.title("GradCAM")
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('grad_cam_visualization.png')
    plt.show()




## Main Execution


In [None]:

def main():
    # Model parameters
    backbone_name = 'resnet18'  # Options: 'resnet18', 'resnet50'
    embedding_size = 128
    batch_size = 32  # Adjust based on your GPU

    # Training parameters
    loss_type = 'triplet'  # Options: 'triplet', 'contrastive', 'arcface'
    num_epochs = 20
    lr = 1e-4

    # --- Checkpoint Directory ---
    # Use the drive_base_dir defined earlier
    checkpoint_base_dir = os.path.join(drive_base_dir, f'{backbone_name}_{loss_type}')
    os.makedirs(checkpoint_base_dir, exist_ok=True)
    print(f"Checkpoints will be saved in: {checkpoint_base_dir}")
    # --- End Checkpoint Directory ---

    # Load data
    train_val_dataset, test_dataset, eval_train_dataset = load_oxford_pets_dataset()
    datasets_dict = prepare_datasets(train_val_dataset, test_dataset, eval_train_dataset)
    dataloaders = create_dataloaders(datasets_dict, batch_size=batch_size)

    # Create model
    model = EmbeddingNet(backbone_name=backbone_name, embedding_size=embedding_size)
    model = model.to(device)

    # Create optimizer and scheduler
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # Train model
    model, history = train_model(
        model=model,
        dataloaders=dataloaders,
        loss_type=loss_type,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=num_epochs,
        embedding_size=embedding_size,
        checkpoint_dir=checkpoint_base_dir # Pass checkpoint directory
    )

    # Plot training history
    plt.figure(figsize=(10, 6))
    # Adjust epoch range for plotting if training was resumed
    epochs_completed = len(history['train_loss'])
    epoch_range = range(1, epochs_completed + 1)
    plt.plot(epoch_range, history['train_loss'], 'b-', label='Training Loss')
    plt.plot(epoch_range, history['val_loss'], 'r-', label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Training and Validation Loss ({loss_type} loss)')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(checkpoint_base_dir, 'training_history.png')) # Save plot to drive
    plt.show()

    # Extract embeddings for evaluation
    print("\nExtracting embeddings for evaluation...")
    train_embeddings, train_labels = extract_embeddings(model, dataloaders['eval_train'])
    test_embeddings, test_labels = extract_embeddings(model, dataloaders['test'])
    few_shot_train_embeddings, few_shot_train_labels = extract_embeddings(model, dataloaders['few_shot_train'])
    few_shot_test_embeddings, few_shot_test_labels = extract_embeddings(model, dataloaders['few_shot_test'])

    # Evaluation tasks
    print("\n1. Verification Task:")
    verification_results = evaluate_verification(test_embeddings, test_labels)
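    # evaluate_verification() saves 'verification_roc_curve.png' before plt.show();
    # copy that file, since plt.savefig() after show() typically writes a blank figure
    import shutil
    shutil.copy('verification_roc_curve.png', os.path.join(checkpoint_base_dir, 'verification_roc_curve.png'))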

    print(f"\nVerification Results:")
    print(f"ROC AUC: {verification_results['roc_auc']:.4f}")
    print(f"Equal Error Rate (EER): {verification_results['eer']:.4f}")

    print("\n2. Retrieval Task:")
    retrieval_results = evaluate_retrieval(
        query_embeddings=test_embeddings,
        query_labels=test_labels,
        gallery_embeddings=train_embeddings,
        gallery_labels=train_labels,
        k_values=[1, 5, 10]
    )

    print("\n3. Few-shot Classification:")
    few_shot_results = evaluate_few_shot(
        support_embeddings=few_shot_train_embeddings,
        support_labels=few_shot_train_labels,
        query_embeddings=few_shot_test_embeddings,
        query_labels=few_shot_test_labels,
        n_way=5,
        k_shot=5
    )

    # Embedding visualization
    print("\n4. Embedding Visualization:")
    test_projection = visualize_embeddings(
        embeddings=test_embeddings,
        labels=test_labels,
        class_mapping=datasets_dict['class_mapping'],
        method='tsne',
        title='t-SNE Visualization of Test Embeddings'
    )
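    # visualize_embeddings() saves 'tsne_visualization.png' before plt.show();
    # copy that file, since plt.savefig() after show() typically writes a blank figure
    import shutil
    shutil.copy('tsne_visualization.png', os.path.join(checkpoint_base_dir, 'tsne_visualization_test.png'))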

    # Visualize few-shot embeddings
    print("\nVisualizing few-shot embeddings:")
    # Combine few-shot train and test embeddings for visualization
    all_few_shot_embeddings = torch.cat([few_shot_train_embeddings, few_shot_test_embeddings], dim=0)
    all_few_shot_labels = torch.cat([few_shot_train_labels, few_shot_test_labels], dim=0)

    few_shot_projection = visualize_embeddings(
        embeddings=all_few_shot_embeddings,
        labels=all_few_shot_labels,
        class_mapping=datasets_dict['class_mapping'],
        method='tsne',
        title='t-SNE Visualization of Few-Shot Embeddings'
    )
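    # Copy the file written by visualize_embeddings() (it overwrites 'tsne_visualization.png'
    # on every call), since plt.savefig() after show() typically writes a blank figure
    import shutil
    shutil.copy('tsne_visualization.png', os.path.join(checkpoint_base_dir, 'tsne_visualization_fewshot.png'))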

    # Bonus: Grad-CAM Visualization
    print("\n5. Grad-CAM Visualization:")
    visualize_grad_cam(model, dataloaders['test'], datasets_dict['class_mapping'], num_images=3)
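    # Copy the figure saved inside visualize_grad_cam(); the file is absent if
    # grad-cam could not be imported and the function returned early
    import shutil
    if os.path.exists('grad_cam_visualization.png'):
        shutil.copy('grad_cam_visualization.png', os.path.join(checkpoint_base_dir, 'grad_cam_visualization.png'))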

    # Save final model (best model is already saved by EarlyStopping in checkpoint_dir)
    final_model_path = os.path.join(checkpoint_base_dir, f'pet_metric_learning_{backbone_name}_{loss_type}_final.pth')
    torch.save({
        'model_state_dict': model.state_dict(), # Save the state after loading the best model
        'embedding_size': embedding_size,
        'backbone_name': backbone_name,
        'class_mapping': datasets_dict['class_mapping']
    }, final_model_path)
    print(f"Final model state saved to {final_model_path}")

    # Save evaluation results
    eval_results = {
        'verification': verification_results,
        'retrieval': retrieval_results,
        'few_shot': few_shot_results
    }
    eval_results_path = os.path.join(checkpoint_base_dir, 'evaluation_results.json')
    # Make results JSON serializable (convert tensors/numpy arrays if necessary)
    # ... (implementation depends on exact structure of results dicts)
    # For simplicity, saving basic metrics
    simple_eval_results = {
        'verification_roc_auc': verification_results.get('roc_auc'),
        'verification_eer': verification_results.get('eer'),
        'retrieval_recall_1': retrieval_results.get('recall@1'),
        'retrieval_precision_1': retrieval_results.get('precision@1'),
        'few_shot_mean_accuracy': few_shot_results.get('mean_accuracy'),
        'few_shot_std_accuracy': few_shot_results.get('std_accuracy')
    }
    with open(eval_results_path, 'w') as f:
        json.dump(simple_eval_results, f, indent=2)
    print(f"Evaluation results saved to {eval_results_path}")

    print("\nEvaluation completed!")

if __name__ == "__main__":
    main()
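
`main()` writes only a few scalar metrics to JSON because the full result dictionaries contain NumPy arrays. A generic converter could be sketched as follows, assuming that every array-like value exposes `.tolist()` (NumPy arrays, NumPy scalars, and torch tensors all do). `to_serializable` and the `FakeArray` stand-in are hypothetical, and the numbers are toy values, not results from this notebook.

```python
import json


def to_serializable(obj):
    """Recursively convert array-like values (anything with .tolist()) to plain Python."""
    if isinstance(obj, dict):
        return {str(key): to_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_serializable(item) for item in obj]
    if hasattr(obj, "tolist"):
        # NumPy arrays/scalars and torch tensors all expose .tolist()
        return obj.tolist()
    return obj


class FakeArray:
    """Stand-in for a NumPy array so the example runs without extra packages."""

    def __init__(self, values):
        self.values = values

    def tolist(self):
        return list(self.values)


# Toy result dictionary shaped like the evaluation outputs
results = {
    "verification": {"roc_auc": 0.91, "pairs": FakeArray([[0, 1], [2, 3]])},
    "accuracies": (0.8, 0.9),
}
encoded = json.dumps(to_serializable(results), sort_keys=True)
print(encoded)
```

With such a helper, the complete `eval_results` dictionary could be passed through `to_serializable` before `json.dump` instead of hand-picking metrics.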



## Bonus: Multiple Loss Function Comparison

In [None]:

def compare_loss_functions():
    """
    Compare different loss functions for metric learning
    """
    # Model parameters
    backbone_name = 'resnet18'
    embedding_size = 128
    batch_size = 32
    num_epochs = 15 # Use fewer epochs for comparison run
    lr = 1e-4

    # Loss functions to compare
    loss_types = ['triplet', 'contrastive', 'arcface']

    # Load data (only once)
    train_val_dataset, test_dataset, eval_train_dataset = load_oxford_pets_dataset()
    datasets_dict = prepare_datasets(train_val_dataset, test_dataset, eval_train_dataset)
    dataloaders = create_dataloaders(datasets_dict, batch_size=batch_size)

    results = {}
    comparison_base_dir = os.path.join(drive_base_dir, 'loss_comparison')
    os.makedirs(comparison_base_dir, exist_ok=True)

    for loss_type in loss_types:
        print(f"\n{'=' * 40}")
        print(f"Training with {loss_type} loss")
        print(f"{'=' * 40}")

        # Define checkpoint directory for this specific run
        checkpoint_dir = os.path.join(comparison_base_dir, f'{backbone_name}_{loss_type}')
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Create model
        model = EmbeddingNet(backbone_name=backbone_name, embedding_size=embedding_size)
        model = model.to(device)

        # Create optimizer and scheduler
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

        # Train model
        model, history = train_model(
            model=model,
            dataloaders=dataloaders,
            loss_type=loss_type,
            optimizer=optimizer,
            scheduler=scheduler,
            num_epochs=num_epochs,
            embedding_size=embedding_size,
            checkpoint_dir=checkpoint_dir # Use specific checkpoint dir
        )

        # Extract embeddings for evaluation
        print("\nExtracting embeddings for evaluation...")
        train_embeddings, train_labels = extract_embeddings(model, dataloaders['eval_train'])
        test_embeddings, test_labels = extract_embeddings(model, dataloaders['test'])

        # Evaluation
        verification_results = evaluate_verification(test_embeddings, test_labels)
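        # evaluate_verification() already saved the figure locally before plt.show();
        # copy that file, since plt.savefig() after show() typically writes a blank figure
        import shutil
        shutil.copy('verification_roc_curve.png', os.path.join(checkpoint_dir, 'verification_roc_curve.png'))
        plt.close('all')  # Release any figures that are still open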

        retrieval_results = evaluate_retrieval(
            query_embeddings=test_embeddings,
            query_labels=test_labels,
            gallery_embeddings=train_embeddings,
            gallery_labels=train_labels,
            k_values=[1, 5]
        )

        # Store results
        results[loss_type] = {
            'verification': {
                'roc_auc': verification_results['roc_auc'],
                'eer': verification_results['eer']
            },
            'retrieval': {
                'recall@1': retrieval_results['recall@1'],
                'recall@5': retrieval_results['recall@5'],
                'precision@1': retrieval_results['precision@1'],
                'precision@5': retrieval_results['precision@5']
            }
        }

        # Save model (optional, best is saved during training)
        # torch.save({
        #     'model_state_dict': model.state_dict(),
        #     'embedding_size': embedding_size,
        #     'backbone_name': backbone_name,
        #     'class_mapping': datasets_dict['class_mapping']
        # }, os.path.join(checkpoint_dir, f'pet_metric_learning_{backbone_name}_{loss_type}_comparison_final.pth'))

    # Compare results
    print(f"\n{'=' * 50}")
    print("Comparison of Loss Functions")
    print('=' * 50)

    # Create a comparison table
    comparison_data = []
    for loss_type, metrics in results.items():
        comparison_data.append({
            'Loss Function': loss_type,
            'ROC AUC': metrics['verification']['roc_auc'],
            'EER': metrics['verification']['eer'],
            'Recall@1': metrics['retrieval']['recall@1'],
            'Recall@5': metrics['retrieval']['recall@5'],
            'Precision@1': metrics['retrieval']['precision@1'],
            'Precision@5': metrics['retrieval']['precision@5']
        })
    comparison_df = pd.DataFrame(comparison_data)

    print(comparison_df)
    comparison_df.to_csv(os.path.join(comparison_base_dir, 'loss_comparison_results.csv'), index=False)
    print(f"Comparison results saved to {os.path.join(comparison_base_dir, 'loss_comparison_results.csv')}")

    # Plot comparison

    metrics_to_plot = ['ROC AUC', 'Recall@1', 'Recall@5', 'Precision@1', 'Precision@5']
    x = np.arange(len(metrics_to_plot))
    width = 0.25

    fig, ax = plt.subplots(figsize=(15, 10))
    for i, loss_type in enumerate(loss_types):
        values = [
            results[loss_type]['verification']['roc_auc'],
            results[loss_type]['retrieval']['recall@1'],
            results[loss_type]['retrieval']['recall@5'],
            results[loss_type]['retrieval']['precision@1'],
            results[loss_type]['retrieval']['precision@5']
        ]
        rects = ax.bar(x + i*width - width, values, width, label=loss_type)
        ax.bar_label(rects, padding=3, fmt='%.3f')

    ax.set_ylabel('Score')
    ax.set_title('Comparison of Loss Functions')
    ax.set_xticks(x)
    ax.set_xticklabels(metrics_to_plot)
    ax.legend()
    ax.grid(True, axis='y')
    ax.set_ylim(0, 1.1) # Adjust ylim for better visualization
    fig.tight_layout()

    plt.savefig(os.path.join(comparison_base_dir, 'loss_function_comparison.png'))
    plt.show()

    return results, comparison_df

# Uncomment to run the comparison
# loss_comparison_results, loss_comparison_df = compare_loss_functions()




## Bonus: Streamlit Demo

The Streamlit demo code has been moved to a separate file: `Charles_Watson_streamlit_pet_similarity_app.ipynb`.

To run the demo locally:
1. Ensure you have Streamlit installed (`pip install streamlit`).
2. First convert the notebook to a Python script: `jupyter nbconvert --to python Charles_Watson_streamlit_pet_similarity_app.ipynb`
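3. Make sure a trained model file (e.g., `pet_metric_learning_resnet18_triplet.pth`) exists in the same directory. Note that `main()` saves its model as `pet_metric_learning_resnet18_triplet_final.pth` in the checkpoint directory, so copy or rename that file as needed.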
4. Run the command: `streamlit run Charles_Watson_streamlit_pet_similarity_app.py`

To run the demo in Google Colab:
1. Upload the `Charles_Watson_streamlit_pet_similarity_app.ipynb` notebook to your Colab environment.
2. Install required packages:
```python
!pip install streamlit pyngrok
```
3. Convert the notebook to a Python script:
```python
!jupyter nbconvert --to python Charles_Watson_streamlit_pet_similarity_app.ipynb
```
4. Run Streamlit with ngrok for public access:
```python
from pyngrok import ngrok
!streamlit run Charles_Watson_streamlit_pet_similarity_app.py &>/dev/null&
public_url = ngrok.connect(8501)
print(f"Streamlit app is running at: {public_url}")
```



## Conclusion

In this notebook, we have implemented a comprehensive metric learning pipeline for pet breed classification using the Oxford-IIIT Pet Dataset. We have:

1. Built a custom embedding model with a CNN backbone and projection head
2. Implemented various loss functions for metric learning (Triplet, Contrastive, ArcFace)
3. Developed evaluation methods for verification, retrieval, and few-shot classification
4. Created visualization tools for embedding spaces and feature importance (Grad-CAM)
5. Included bonus implementations for hard negative mining and loss function comparison
6. Moved the Streamlit demo to a separate `Charles_Watson_streamlit_pet_similarity_app.ipynb` file

The code is modular, so the backbone, loss function, and hyperparameters can be changed for further experiments. To run the complete training and evaluation pipeline, execute the `main()` function.