# Cross-Validation Triplet Classification Training Notebook

This notebook walks through several training processes for computer vision models:
- Triplet loss training for feature extraction
- Standard classification training
- Semantic classification with label embeddings (the embeddings are computed here; the training loop is only a stub)
- Active Learning with Reinforcement Learning (AL-RL) (`DDQNAgent` is imported, but no AL-RL training loop is defined in this notebook)

The implementation is based on the `cv_triplet_classification.py` script and lets you run each training process interactively, cell by cell.

In [None]:
# Import necessary libraries
import os
import sys
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F

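# Make the custom modules importable.
# Note: dirname(abspath('.')) is already the parent of the working directory,
# so joining ".." adds the directory two levels above the working directory.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath('.')), ".."))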

# Import custom modules
import dataset as ds
from model import FeatureExtractor, Classifier, DDQNAgent
from trainer import get_device

# Global variables for tracking current model and dataset
current_backbone_model = None
current_dataset = None

print("Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device available: {get_device()}")

## Utility Functions

The following cells contain utility functions for model saving, validation, and semantic embedding computation.
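
The semantic embedding helper in this section pools token vectors with an attention mask before normalizing them. As a minimal sketch of that pooling step, the pure-Python version below averages only the unmasked token vectors and then L2-normalizes the result. The function names and the toy vectors are illustrative assumptions, not part of the notebook's modules.

```python
import math


def masked_mean_pool(token_vectors, attention_mask):
    """Average only the token vectors whose mask entry is 1."""
    dim = len(token_vectors[0])
    kept = [vec for vec, m in zip(token_vectors, attention_mask) if m == 1]
    count = max(len(kept), 1)  # avoid division by zero for an all-padding row
    return [sum(vec[i] for vec in kept) / count for i in range(dim)]


def l2_normalize(vector):
    """Scale a vector to unit Euclidean length (zero vectors are returned as-is)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm > 0 else list(vector)


# Toy input: two real tokens and one padding token that must be ignored.
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]

pooled = masked_mean_pool(tokens, mask)
unit = l2_normalize(pooled)
print(pooled, unit)
```

The padding row does not affect the pooled vector, which is the point of multiplying by the expanded attention mask in the tensor version.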

In [None]:
def save_model(model_path, model):
    """
    Save the model to the given path.
    
    Args:
        model_path: The path to save the model to.
        model: The model to save.
    """
    # Ensure the model path is a string
    if not isinstance(model_path, str):
        raise ValueError("model_path must be a string")
    
    # Create the target directory if needed; a bare filename has no directory part
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), model_path)
    print(f'Model saved to {model_path}')


def validate_model(model, dataset, batch_size=64):
    """
    Validate the model on the given dataset and return the average loss, accuracy,
    precision, recall, and F1 score (the last three are weighted averages over classes).
    
    Args:
        model: The model to validate.
        dataset: The dataset to validate on.
        batch_size: The batch size to use for validation.
    """
    device = get_device()
    model = model.to(device)
    model.eval()
    
    loss_fn = torch.nn.CrossEntropyLoss()
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    
    y_true = []
    y_pred = []
    
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = torch.max(outputs.data, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()
            
            # Store true and predicted labels for metrics
            y_true += labels.tolist()
            y_pred += predicted.tolist()
    
    # Calculate average loss and accuracy
    avg_loss = total_loss / len(data_loader)
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    accuracy = accuracy_score(y_true, y_pred)
    
    # Print the validation results (the metrics are fractions in [0, 1], so format them as percentages)
    print(f'Validation average loss: {avg_loss:.4f}, Accuracy: {accuracy:.2%}, Precision: {precision:.2%}, Recall: {recall:.2%}, F1 Score: {f1:.2%}')
    return avg_loss, accuracy, precision, recall, f1

print("Model utility functions defined!")

In [None]:
def compute_label_embeddings(labels, out_features):
    """
    Compute label embeddings with the `nomic-embed-text-v1.5` model,
    using the `bert-base-uncased` tokenizer.
    
    Args:
        labels: A list of labels to encode.
        out_features: The desired output feature dimension.
    
    Returns:
        A dictionary mapping labels to their semantic embeddings.
    """
    labels = sorted(set(labels))  # Ensure unique labels
    labels = [str(label) for label in labels]  # Convert labels to strings
    
    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    nomic = AutoModel.from_pretrained('nomic-ai/nomic-embed-text-v1.5', 
                                     trust_remote_code=True, 
                                     safe_serialization=True)
    
    # Encode labels to embeddings
    labels_embeddings = tokenizer(labels, padding=True, truncation=True, return_tensors='pt')
    
    with torch.no_grad():
        embeddings = nomic(**labels_embeddings)
    
    # Mean-pool the token embeddings, using the attention mask to ignore padding
    token_embeddings = embeddings[0]
    attention_mask = labels_embeddings['attention_mask']
    
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    
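    # Project the embeddings to out_features.
    # NOTE: this Linear layer is freshly initialized and untrained, so the projection
    # is random and differs between calls unless the random seed is fixed.
    embeddings = torch.nn.Linear(embeddings.shape[1], out_features)(embeddings)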
    
    # Normalize the embeddings
    embeddings = F.normalize(embeddings, p=2, dim=1)
    
    # Store the embeddings with absolute values
    labels_embeddings = abs(embeddings)
    
    # Create a dictionary to map labels to embeddings
    label_to_embedding = {label: embedding for label, embedding in zip(labels, labels_embeddings)}
    return label_to_embedding

print("Semantic embedding function defined!")

## Training Processes

This section covers four training methodologies. The cells below implement the first two; semantic classification is only stubbed in `train`, and the AL-RL loop is not included in this notebook:
1. **Triplet Training**: Uses triplet loss to learn embeddings where similar images are close and dissimilar images are far apart
2. **Classification Training**: Standard supervised learning for classification
3. **Semantic Classification**: Classification with semantic label embeddings
4. **Active Learning + Reinforcement Learning**: Training with sample selection driven by a reinforcement learning agent
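
To make the triplet objective concrete, here is a minimal pure-Python sketch of the margin loss, `max(d(a, p) - d(a, n) + margin, 0)`, with Euclidean distance and the same `margin=0.2` used in the training cell. It is an illustration only: the function name and toy points are assumptions, and it omits the small `eps` term that `torch.nn.TripletMarginLoss` adds when computing distances.

```python
import math


def triplet_margin_loss(anchor, positive, negative, margin=0.2):
    """Hinge on the gap between anchor-positive and anchor-negative distances."""
    d_pos = math.dist(anchor, positive)  # should be small
    d_neg = math.dist(anchor, negative)  # should be large
    return max(d_pos - d_neg + margin, 0.0)


# Easy triplet: the positive is already much closer than the negative.
easy = triplet_margin_loss([0.0, 0.0], [0.0, 1.0], [3.0, 4.0])

# Hard triplet: the negative is closer than the positive, so the loss is positive.
hard = triplet_margin_loss([0.0, 0.0], [3.0, 4.0], [0.0, 1.0])

print(easy, hard)
```

Triplets that already satisfy the margin contribute zero loss, so only violating triplets produce gradients during training.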

In [None]:
def triplet_train_process(dataset, model, k_fold=5, batch_size=64):
    """
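    Train the model using triplet loss with k-fold cross-validation.
    
    Note: the same model instance keeps training across folds; its weights are
    not reset between folds, so later folds validate on data seen earlier.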
    
    Args:
        dataset: The triplet dataset for training
        model: The feature extractor model
        k_fold: Number of folds for cross-validation
        batch_size: Batch size for training
    
    Returns:
        Trained model, average loss, average test loss
    """
    device = get_device()
    
    # Split dataset using k-fold cross-validation
    dataset_size = len(dataset)
    fold_size = dataset_size // k_fold
    
    # Collect per-epoch losses across all folds so the final averages cover every fold
    loss_list = []
    test_loss_list = []
    
    for fold in range(k_fold):
        print(f'Running fold {fold + 1}/{k_fold}...')
        
        # Initialize the trainer
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        loss_fn = torch.nn.TripletMarginLoss(margin=0.2, p=2)
        
        # Split dataset into train and validation sets
        train_dataset = torch.utils.data.Subset(dataset, list(
            range(fold_size * fold)) + list(range(fold_size * (fold + 1), dataset_size)))
        val_dataset = torch.utils.data.Subset(dataset, range(fold_size * fold, fold_size * (fold + 1)))
        
        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        
        # Train the model
        epochs = 10
        
        # Loop through epochs
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}...')
            start_time = time.time()
            model = model.to(device)
            model.train()
            total_loss = 0.0
            
            # Training loop
            for batch in train_loader:
                anchor, positive, negative = batch
                anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
                
                optimizer.zero_grad()
                output_anchor = model(anchor)
                output_positive = model(positive)
                output_negative = model(negative)
                
                loss = loss_fn(output_anchor, output_positive, output_negative)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
            
            # Validation loop
            model.eval()
            total_test_loss = 0.0
            with torch.no_grad():
                for batch in val_loader:
                    anchor, positive, negative = batch
                    anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
                    
                    output_anchor = model(anchor)
                    output_positive = model(positive)
                    output_negative = model(negative)
                    
                    loss = loss_fn(output_anchor, output_positive, output_negative)
                    total_test_loss += loss.item()
            
            # Calculate average loss for the epoch
            avg_loss = total_loss / len(train_loader)
            avg_test_loss = total_test_loss / len(val_loader)
            print(f'Fold {fold + 1}: Average loss: {avg_loss:.4f}, Average test loss: {avg_test_loss:.4f}')
            print(f'Time taken for fold {fold + 1}, epoch {epoch + 1}: {time.time() - start_time:.2f} seconds')
            loss_list.append(avg_loss)
            test_loss_list.append(avg_test_loss)
        
        print(f'Fold {fold + 1} completed.')
        
        # Save the model after each fold
        model_dir = f'models/triplet/{current_dataset}_{model._get_name()}'
        save_model(f'{model_dir}/model_fold_{fold + 1}.pth', model)
    
    # Print the average loss over all folds
    average_loss = sum(loss_list) / len(loss_list)
    average_test_loss = sum(test_loss_list) / len(test_loss_list)
    print(f'Over all folds: Average loss: {average_loss:.4f}, Average test loss: {average_test_loss:.4f}')
    
    return model, average_loss, average_test_loss

print("Triplet training process function defined!")

In [None]:
def classification_train_process(dataset, model, k_fold=5, batch_size=64, test_dataset=None):
    """
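    Train the model using standard classification over `k_fold` rounds.
    
    Note: each round draws a fresh random 80/20 train/validation split (repeated
    hold-out rather than disjoint k-fold partitions), and the same model instance
    keeps training across rounds.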
    
    Args:
        dataset: The classification dataset for training
        model: The classifier model
        k_fold: Number of folds for cross-validation
        batch_size: Batch size for training
        test_dataset: Optional test dataset for evaluation
    
    Returns:
        Trained model, average loss, average validation loss
    """
    device = get_device()
    
    # Create result dataframe
    result_df = pd.DataFrame(columns=['dataset', 'model', 'fold', 'avg_loss', 'avg_test_loss',
                                    'avg_val_loss', 'precision', 'recall', 'f1', 'accuracy',
                                    'total_time'])
    
    # Collect per-epoch losses across all folds so the final averages cover every fold
    loss_list = []
    val_loss_list = []
    
    for fold in range(k_fold):
        print(f'Running fold {fold + 1}/{k_fold}...')
        fold_start_time = time.time()
        
        # Initialize the trainer
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        loss_fn = torch.nn.CrossEntropyLoss()
        
        # Split dataset into train and validation sets
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])
        
        # Create data loaders
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
        
        # Train the model
        epochs = 10
        
        # Loop through epochs
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}...')
            epoch_start_time = time.time()
            model = model.to(device)
            model.train()
            total_loss = 0.0
            
            # Training loop
            for batch in train_loader:
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)
                
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_fn(outputs, labels)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
            
            epoch_end_time = time.time()
            
            # Validation loop
            model.eval()
            total_val_loss = 0.0
            with torch.no_grad():
                for batch in val_loader:
                    inputs, labels = batch
                    inputs, labels = inputs.to(device), labels.to(device)
                    
                    outputs = model(inputs)
                    loss = loss_fn(outputs, labels)
                    total_val_loss += loss.item()
            
            # Calculate average loss for the epoch
            avg_loss = total_loss / len(train_loader)
            avg_val_loss = total_val_loss / len(val_loader)
            print(f'Fold {fold + 1}: Average loss: {avg_loss:.4f}, Average validation loss: {avg_val_loss:.4f}')
            print(f'Time taken for fold {fold + 1}, epoch {epoch + 1}: {epoch_end_time - epoch_start_time:.2f} seconds')
            loss_list.append(avg_loss)
            val_loss_list.append(avg_val_loss)
        
        print(f'Fold {fold + 1} completed.')
        
        # Save the model after each fold
        model_dir = f'models/classification/{current_dataset}_{model._get_name()}'
        save_model(f'{model_dir}/model_fold_{fold + 1}.pth', model)
        
        # Validate the model on the test set if provided
        if test_dataset is not None:
            avg_test_loss, accuracy, precision, recall, f1 = validate_model(model, test_dataset, batch_size=batch_size)
            print(f'Fold {fold + 1}: Test loss: {avg_test_loss:.4f}, Accuracy: {accuracy:.2%}')
            
            fold_end_time = time.time()
            result_df = pd.concat([result_df, pd.DataFrame({
                'dataset': [dataset.__class__.__name__],
                'model': [model._get_name()],
                'fold': [fold + 1],
                'avg_loss': [avg_loss],
                'avg_test_loss': [avg_test_loss],
                'avg_val_loss': [avg_val_loss],
                'accuracy': [accuracy],
                'precision': [precision],
                'recall': [recall],
                'f1': [f1],
                'total_time': [fold_end_time - fold_start_time]
            })], ignore_index=True)
    
    # Print the average loss over all folds
    average_loss = sum(loss_list) / len(loss_list)
    average_val_loss = sum(val_loss_list) / len(val_loss_list)
    print(f'Over all folds: Average loss: {average_loss:.4f}, Average validation loss: {average_val_loss:.4f}')
    
    return model, average_loss, average_val_loss

print("Classification training process function defined!")

## Dataset Creation and Configuration

This section defines functions to create different types of datasets for training and testing.
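
When no test dataset function is supplied, `create_train_test_dataset` falls back to a random 80/20 split. The sketch below mirrors that size calculation on plain index lists, with a fixed seed so the split is repeatable. The helper name and seed are illustrative assumptions; in the notebook itself, a comparable effect can be had by passing a seeded `generator` to `torch.utils.data.random_split`.

```python
import random


def holdout_split(n_samples, train_fraction=0.8, seed=0):
    """Shuffle indices with a fixed seed and cut them into train/test parts."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # local RNG, global state untouched
    train_size = int(train_fraction * n_samples)
    return indices[:train_size], indices[train_size:]


train_idx, test_idx = holdout_split(10)
print(len(train_idx), len(test_idx))

# The same seed always yields the same split.
assert holdout_split(10) == (train_idx, test_idx)
```

Without a fixed seed, every run evaluates on a different test subset, which makes results across runs harder to compare.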

In [None]:
def create_training_process_df(dataset_type, create_triplet_dataset_fn, create_classification_dataset_fn, 
                               create_classification_test_dataset_fn=None, models=('resnet',), batch_size=32):
    """
    Create a DataFrame to store the training process information.
    
    Args:
        dataset_type: The type of dataset to use (e.g., 'cifar10', 'mnist').
        create_triplet_dataset_fn: Function to create the triplet dataset.
        create_classification_dataset_fn: Function to create the classification dataset.
        create_classification_test_dataset_fn: Function to create the classification test dataset (optional).
        models: List of model names to include in the DataFrame.
        batch_size: The batch size to use for training.
    
    Returns:
        A DataFrame containing the training process information.
    """
    triplet_df = pd.DataFrame(columns=[
        'backbone_model', 'feature_extractor_model', 'dataset_type',
        'create_triplet_dataset_fn', 'create_classification_dataset_fn',
        'create_classification_test_dataset_fn', 'batch_size'
    ])
    
    # Create backbone model functions
    create_backbone_model_funcs = []
    for model in models:
        if model == 'resnet':
            create_backbone_model_funcs.append(lambda: torchvision.models.resnet50(weights=None))
        elif model == 'vgg':
            create_backbone_model_funcs.append(lambda: torchvision.models.vgg16(weights=None))
        elif model == 'mobilenet':
            create_backbone_model_funcs.append(lambda: torchvision.models.mobilenet_v2(weights=None))
        elif model == 'densenet':
            create_backbone_model_funcs.append(lambda: torchvision.models.densenet121(weights=None))
        else:
            raise ValueError(f'Unknown model type: {model}')
    
    # Add all backbone models to the DataFrame
    for create_fn in create_backbone_model_funcs:
        # Build the backbone once so both columns refer to the same instance
        backbone = create_fn()
        triplet_df = pd.concat([triplet_df, pd.DataFrame({
            'backbone_model': [backbone],
            'feature_extractor_model': [FeatureExtractor(backbone)],
            'dataset_type': [dataset_type],
            'create_triplet_dataset_fn': [create_triplet_dataset_fn],
            'create_classification_dataset_fn': [create_classification_dataset_fn],
            'create_classification_test_dataset_fn': [create_classification_test_dataset_fn],
            'batch_size': [batch_size]
        })], ignore_index=True)
    
    return triplet_df

def create_train_test_dataset(create_train_dataset_fn, create_test_dataset_fn=None):
    """
    Create a train and test dataset from the given functions.
    """
    # Create the train dataset
    train_dataset = create_train_dataset_fn()
    
    # If a test dataset function is provided, use it to create the test dataset
    if create_test_dataset_fn is not None:
        test_dataset = create_test_dataset_fn()
    else:
        # Otherwise, split the train dataset into train and test sets
        train_size = int(0.8 * len(train_dataset))
        test_size = len(train_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(train_dataset, [train_size, test_size])
    
    return train_dataset, test_dataset

print("Dataset creation functions defined!")

## Example Dataset Configurations

The following cells demonstrate how to configure different datasets for training.

In [None]:
# Example: gi4e dataset configuration
def create_gi4e_triplet_dataset():
    """Create gi4e triplet dataset"""
    return ds.TripletGi4eDataset(
        './datasets/gi4e',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ])
    )

def create_gi4e_classification_dataset():
    """Create gi4e classification dataset"""
    return ds.Gi4eDataset(
        './datasets/gi4e',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ]),
        is_classification=True
    )

# Example: Raw eyes dataset configuration
def create_gi4e_raw_eyes_triplet_dataset():
    """Create gi4e raw eyes triplet dataset"""
    return ds.TripletImageDataset(
        './datasets/gi4e_raw_eyes',
        file_extension='png',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ])
    )

def create_gi4e_raw_eyes_classification_dataset():
    """Create gi4e raw eyes classification dataset"""
    return ds.ImageDataset(
        './datasets/gi4e_raw_eyes',
        file_extension='png',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ])
    )

# Example: CelebA dataset configuration
def create_celeba_triplet_dataset():
    """Create CelebA triplet dataset"""
    return ds.TripletImageDataset(
        './datasets/CelebA_HQ_facial_identity_dataset/train',
        file_extension='jpg',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ])
    )

def create_celeba_classification_dataset():
    """Create CelebA classification dataset"""
    return ds.ImageDataset(
        './datasets/CelebA_HQ_facial_identity_dataset/train',
        file_extension='jpg',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ])
    )

def create_celeba_test_dataset():
    """Create CelebA test dataset"""
    return ds.ImageDataset(
        './datasets/CelebA_HQ_facial_identity_dataset/test',
        file_extension='jpg',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ])
    )

print("Example dataset configurations defined!")

## Training Execution

This section demonstrates how to execute the training process for different models and datasets.
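
The `k_fold` argument passed to `train` determines how `triplet_train_process` slices the dataset: each fold validates on one contiguous block of `len(dataset) // k_fold` indices and trains on everything else. The sketch below reproduces that index arithmetic on plain integers so the behavior is easy to inspect. The generator name is an assumption made for illustration.

```python
def contiguous_folds(dataset_size, k_fold):
    """Yield (train_indices, val_indices) using contiguous validation blocks."""
    fold_size = dataset_size // k_fold
    for fold in range(k_fold):
        start, stop = fold_size * fold, fold_size * (fold + 1)
        val_idx = list(range(start, stop))
        train_idx = list(range(start)) + list(range(stop, dataset_size))
        yield train_idx, val_idx


# 11 samples and 5 folds: fold_size is 2, so one sample is left over.
folds = list(contiguous_folds(dataset_size=11, k_fold=5))
for train_idx, val_idx in folds:
    print(val_idx, len(train_idx))
```

When the dataset size is not divisible by `k_fold`, the trailing samples never appear in a validation block and are always used for training. Because the blocks are contiguous, a dataset stored in class order would also produce folds with skewed class balance, so shuffling the indices beforehand may be worth considering.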

In [None]:
def train(dataset, model, train_process='triplet', semantic_embedding_fn=None, k_fold=5, batch_size=32, test_dataset=None):
    """
    Main training function that handles different training processes.
    
    Args:
        dataset: The dataset to train on.
        model: The model to train.
        train_process: The training process to use ('triplet', 'classification', 'semantic_classification').
        semantic_embedding_fn: Function to compute semantic embeddings (optional).
        k_fold: The number of folds for k-fold cross-validation.
        batch_size: The batch size to use for training.
        test_dataset: Optional test dataset for evaluation.
    
    Returns:
        Trained model, average loss, average test loss, and label embeddings (if applicable).
    """
    print('Starting training process...')
    
    if train_process == 'triplet':
        trained_model, avg_loss, avg_test_loss = triplet_train_process(
            dataset, model, k_fold=k_fold, batch_size=batch_size)
        label_to_embedding = None
        
    elif train_process == 'classification':
        trained_model, avg_loss, avg_test_loss = classification_train_process(
            dataset, model, k_fold=k_fold, batch_size=batch_size, test_dataset=test_dataset)
        label_to_embedding = None
        
    elif train_process == 'semantic_classification':
        # For semantic classification, we would need the semantic_classification_train_process function
        # This is a simplified version for demonstration
        print("Semantic classification training not fully implemented in this notebook")
        print("This would involve computing semantic embeddings and modified training loop")
        trained_model = model
        avg_loss = 0.0
        avg_test_loss = 0.0
        
        # Compute label embeddings if we have the dataset labels
        if hasattr(dataset, 'labels'):
            label_to_embedding = compute_label_embeddings(dataset.labels, model.backbone_out_features)
        else:
            label_to_embedding = None
    else:
        raise ValueError(f'Unknown training process: {train_process}')
    
    print('Training completed.')
    return trained_model, avg_loss, avg_test_loss, label_to_embedding

print("Main training function defined!")

In [None]:
# Example: Training a ResNet model on gi4e raw eyes dataset

# Set global variables
current_dataset = "gi4e_raw_eyes_demo"
batch_size = 32

# Create model
backbone_model = torchvision.models.resnet50(weights=None)
current_backbone_model = backbone_model
feature_extractor = FeatureExtractor(backbone_model)

print(f"Created ResNet50 feature extractor with backbone: {backbone_model.__class__.__name__}")
print(f"Feature extractor name: {feature_extractor._get_name()}")

# Note: Uncomment the following lines to actually run training
# This is commented out to avoid long execution times in the notebook

# # Create datasets (smaller subset for demo)
# print("Creating datasets...")
# try:
#     triplet_dataset = create_gi4e_raw_eyes_triplet_dataset()
#     classification_dataset = create_gi4e_raw_eyes_classification_dataset()
#     print(f"Triplet dataset size: {len(triplet_dataset)}")
#     print(f"Classification dataset size: {len(classification_dataset)}")
# except Exception as e:
#     print(f"Error creating datasets: {e}")
#     print("Please ensure the dataset paths exist")

# # Split into train and test
# train_ds, test_ds = create_train_test_dataset(create_gi4e_raw_eyes_classification_dataset)

# # Train triplet model
# print("Training triplet model...")
# trained_triplet, avg_loss_triplet, avg_test_loss_triplet, _ = train(
#     triplet_dataset, feature_extractor, train_process='triplet',
#     k_fold=2, batch_size=batch_size  # Reduced folds for demo
# )

# # Create classifier using trained feature extractor
# classifier = Classifier(trained_triplet)

# # Train classifier
# print("Training classifier...")
# trained_classifier, avg_loss_class, avg_test_loss_class, _ = train(
#     train_ds, classifier, train_process='classification',
#     k_fold=2, batch_size=batch_size, test_dataset=test_ds
# )

print("Training example setup complete!")
print("Uncomment the code above to run actual training.")

## Results Analysis and Visualization

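This section provides tools for analyzing and visualizing training results. `analyze_training_results` reads a CSV file with `dataset`, `model`, `accuracy`, `precision`, `recall`, and `f1` columns. Note that `classification_train_process` collects these per-fold metrics in `result_df` but does not write them to disk, so the results must be exported to CSV before the analysis can run.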
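
As a minimal sketch of the file layout the analysis expects, the example below writes a few per-fold rows as CSV and computes the mean accuracy per dataset and model, using only the standard library and an in-memory buffer. The rows are placeholder values invented for illustration, not measured results, and the helper names are assumptions.

```python
import csv
import io
import statistics
from collections import defaultdict

FIELDS = ["dataset", "model", "fold", "accuracy", "precision", "recall", "f1"]

# Placeholder rows for illustration only; these are NOT real training results.
rows = [
    {"dataset": "demo", "model": "ModelA", "fold": 1,
     "accuracy": 0.80, "precision": 0.81, "recall": 0.80, "f1": 0.80},
    {"dataset": "demo", "model": "ModelA", "fold": 2,
     "accuracy": 0.90, "precision": 0.91, "recall": 0.90, "f1": 0.90},
    {"dataset": "demo", "model": "ModelB", "fold": 1,
     "accuracy": 0.70, "precision": 0.72, "recall": 0.70, "f1": 0.71},
]


def write_results(result_rows, handle):
    """Write per-fold metric rows with a header line."""
    writer = csv.DictWriter(handle, fieldnames=FIELDS)
    writer.writeheader()
    writer.writerows(result_rows)


def mean_accuracy_by_model(handle):
    """Group rows by (dataset, model) and average the accuracy column."""
    grouped = defaultdict(list)
    for row in csv.DictReader(handle):
        grouped[(row["dataset"], row["model"])].append(float(row["accuracy"]))
    return {key: statistics.mean(values) for key, values in grouped.items()}


buffer = io.StringIO()
write_results(rows, buffer)
buffer.seek(0)
summary = mean_accuracy_by_model(buffer)
print(summary)
```

In the notebook, the equivalent step would be calling `to_csv` on the per-fold results DataFrame with the same path that is later passed to `analyze_training_results`.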

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def analyze_training_results(results_file='triplet_training_results.csv'):
    """
    Analyze and visualize training results from CSV file.
    
    Args:
        results_file: Path to the results CSV file
    """
    try:
        # Load results
        results_df = pd.read_csv(results_file)
        print(f"Loaded {len(results_df)} training results")
        
        # Display basic statistics
        print("\n=== Training Results Summary ===")
        print(results_df.describe())
        
        # Group by dataset and model
        print("\n=== Results by Dataset and Model ===")
        grouped = results_df.groupby(['dataset', 'model']).agg({
            'accuracy': ['mean', 'std'],
            'precision': ['mean', 'std'],
            'recall': ['mean', 'std'],
            'f1': ['mean', 'std']
        }).round(4)
        print(grouped)
        
        # Create visualizations
        plt.figure(figsize=(15, 10))
        
        # Accuracy comparison
        plt.subplot(2, 2, 1)
        if 'accuracy' in results_df.columns:
            sns.boxplot(data=results_df, x='model', y='accuracy')
            plt.title('Accuracy by Model')
            plt.xticks(rotation=45)
        
        # F1 Score comparison
        plt.subplot(2, 2, 2)
        if 'f1' in results_df.columns:
            sns.boxplot(data=results_df, x='model', y='f1')
            plt.title('F1 Score by Model')
            plt.xticks(rotation=45)
        
        # Loss comparison
        plt.subplot(2, 2, 3)
        if 'avg_test_loss' in results_df.columns:
            sns.boxplot(data=results_df, x='model', y='avg_test_loss')
            plt.title('Test Loss by Model')
            plt.xticks(rotation=45)
        
        # Dataset performance
        plt.subplot(2, 2, 4)
        if 'accuracy' in results_df.columns:
            sns.boxplot(data=results_df, x='dataset', y='accuracy')
            plt.title('Accuracy by Dataset')
            plt.xticks(rotation=45)
        
        plt.tight_layout()
        plt.show()
        
        return results_df
        
    except FileNotFoundError:
        print(f"Results file {results_file} not found. Run training first to generate results.")
        return None
    except Exception as e:
        print(f"Error analyzing results: {e}")
        return None

def plot_training_metrics(metrics_dict):
    """
    Plot training metrics over time.
    
    Args:
        metrics_dict: Dictionary containing lists of metrics (loss, accuracy, etc.)
    """
    plt.figure(figsize=(12, 8))
    
    # Plot loss
    if 'loss' in metrics_dict:
        plt.subplot(2, 2, 1)
        plt.plot(metrics_dict['loss'])
        plt.title('Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
    
    # Plot validation loss
    if 'val_loss' in metrics_dict:
        plt.subplot(2, 2, 2)
        plt.plot(metrics_dict['val_loss'])
        plt.title('Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
    
    # Plot accuracy
    if 'accuracy' in metrics_dict:
        plt.subplot(2, 2, 3)
        plt.plot(metrics_dict['accuracy'])
        plt.title('Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
    
    # Plot combined losses
    if 'loss' in metrics_dict and 'val_loss' in metrics_dict:
        plt.subplot(2, 2, 4)
        plt.plot(metrics_dict['loss'], label='Training Loss')
        plt.plot(metrics_dict['val_loss'], label='Validation Loss')
        plt.title('Loss Comparison')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
    
    plt.tight_layout()
    plt.show()

print("Results analysis functions defined!")

## Quick Start Guide

To use this notebook effectively:

1. **Setup**: Run the import and setup cells to initialize the environment
2. **Configure Dataset**: Choose or create dataset configuration functions for your data
3. **Select Model**: Choose a backbone model (ResNet, VGG, MobileNet, DenseNet)
4. **Choose Training Process**:
   - `'triplet'`: For learning feature embeddings using triplet loss
   - `'classification'`: For standard supervised classification
   - `'semantic_classification'`: For classification with semantic label embeddings
5. **Execute Training**: Run the training function with your chosen parameters
6. **Analyze Results**: Use the analysis functions to visualize and compare results

### Example Usage:
```python
# 1. Create dataset
dataset = create_gi4e_raw_eyes_classification_dataset()

# 2. Create model (a classifier on top of the feature extractor, as in the training example above)
model = Classifier(FeatureExtractor(torchvision.models.resnet50(weights=None)))

# 3. Train
trained_model, loss, test_loss, embeddings = train(
    dataset, model, train_process='classification', k_fold=5, batch_size=32
)

# 4. Analyze results (reads 'triplet_training_results.csv' by default; returns None if the file is missing)
results = analyze_training_results()
```

## Conclusion

This notebook provides a framework for experimenting with different computer vision training methodologies. It covers:

- **Multiple datasets**: gi4e and CelebA configurations are shown above; other datasets (such as YouTube Faces or FER2013) need their own configuration functions
- **Various architectures**: ResNet, VGG, MobileNet, DenseNet
- **Different training strategies**: Standard classification and triplet learning, plus label embedding computation for semantic classification
- **Advanced techniques**: Active learning and reinforcement learning (supported by the imported `DDQNAgent`, but not demonstrated in this notebook)
- **Evaluation**: Repeated train/validation splits, multiple metrics, visualization

The modular design makes it straightforward to swap datasets, backbones, and training processes and to compare the results.