# Cross-Validation Triplet Classification Training Notebook

This notebook demonstrates various training processes for computer vision models, including:
- Triplet loss training for feature extraction
- Standard classification training  
- Semantic classification with label embeddings
- Active Learning with Reinforcement Learning (AL-RL)

The implementation is based on the `cv_triplet_classification.py` script and provides an interactive way to explore different training methodologies.

In [11]:
# Import necessary libraries
import os
import sys
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F

# Import custom modules
import dataset as ds
from model import FeatureExtractor, Classifier, DDQNAgent

# Define get_device function directly instead of importing from trainer
def get_device():
    """
    Get the device to use for PyTorch operations.
    Returns 'cuda' if GPU is available, 'cpu' otherwise.
    """
    return 'cuda' if torch.cuda.is_available() else 'cpu'

# Global variables for tracking current model and dataset
current_backbone_model = None
current_dataset = None

print("Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device available: {get_device()}")

Environment setup complete!
PyTorch version: 2.5.1+cu124
Device available: cuda


## Utility Functions

The following cells contain utility functions for model saving, validation, and semantic embedding computation.

In [None]:
def save_model(model_path, model):
    """
    Save the model to the given path.
    
    Args:
        model_path: The path to save the model to.
        model: The model to save.
    """
    # Ensure the model path is a string
    if not isinstance(model_path, str):
        raise ValueError("model_path must be a string")
    
    # Ensure the directory exists (dirname is '' for a bare filename, so skip that case)
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), model_path)
    print(f'Model saved to {model_path}')


def validate_model(model, dataset, batch_size=64):
    """
    Validate the model on the given dataset and return the average loss, accuracy, precision, recall, and F1 score.
    
    Args:
        model: The model to validate.
        dataset: The dataset to validate on.
        batch_size: The batch size to use for validation.
    """
    device = get_device()
    model = model.to(device)
    model.eval()
    
    loss_fn = torch.nn.CrossEntropyLoss()
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    
    y_true = []
    y_pred = []
    
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = torch.max(outputs.data, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()
            
            # Store true and predicted labels for metrics
            y_true += labels.tolist()
            y_pred += predicted.tolist()
    
    # Calculate average loss and accuracy
    avg_loss = total_loss / len(data_loader)
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    accuracy = accuracy_score(y_true, y_pred)
    
    # Print the validation results (metrics are fractions in [0, 1], so scale them to percent)
    print(f'Validation average loss: {avg_loss:.4f}, Accuracy: {100 * accuracy:.2f}%, '
          f'Precision: {100 * precision:.2f}%, Recall: {100 * recall:.2f}%, F1 Score: {100 * f1:.2f}%')
    return avg_loss, accuracy, precision, recall, f1

print("Model utility functions defined!")

In [None]:
def compute_label_embeddings(labels, out_features):
    """
    Compute label encodings for the given labels using the BERT tokenizer and the nomic-embed-text-v1.5 embedding model.
    
    Args:
        labels: A list of labels to encode.
        out_features: The desired output feature dimension.
    
    Returns:
        A dictionary mapping labels to their semantic embeddings.
    """
    labels = sorted(set(labels))  # Ensure unique labels
    labels = [str(label) for label in labels]  # Convert labels to strings
    
    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    nomic = AutoModel.from_pretrained('nomic-ai/nomic-embed-text-v1.5', 
                                     trust_remote_code=True, 
                                     safe_serialization=True)
    
    # Encode labels to embeddings
    labels_embeddings = tokenizer(labels, padding=True, truncation=True, return_tensors='pt')
    
    with torch.no_grad():
        embeddings = nomic(**labels_embeddings)
    
    # Mean-pool the token embeddings, using the attention mask to ignore padding
    token_embeddings = embeddings[0]
    attention_mask = labels_embeddings['attention_mask']
    
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    
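    # Project the embeddings to out_features with a randomly initialized, untrained
    # linear layer (the projection differs between runs unless the seed is fixed)
    embeddings = torch.nn.Linear(embeddings.shape[1], out_features)(embeddings)
    
    # Normalize the embeddings
    embeddings = F.normalize(embeddings, p=2, dim=1)
    
    # Store the element-wise absolute values of the embeddings
    labels_embeddings = torch.abs(embeddings)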
    
    # Create a dictionary to map labels to embeddings
    label_to_embedding = {label: embedding for label, embedding in zip(labels, labels_embeddings)}
    return label_to_embedding

print("Semantic embedding function defined!")
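
The pooling step in `compute_label_embeddings` averages the token vectors of each label while ignoring padding positions. A minimal plain-Python sketch of that masked mean follows. It assumes toy two-dimensional token vectors, and `masked_mean_pool` is a hypothetical helper written for illustration, not part of the notebook's modules.

```python
def masked_mean_pool(token_embeddings, attention_mask):
    """Average the token vectors whose mask entry is 1 (hypothetical helper)."""
    dim = len(token_embeddings[0])
    kept = [vec for vec, keep in zip(token_embeddings, attention_mask) if keep == 1]
    # Guard against an all-padding row, mirroring the clamp(min=1e-9) in the notebook code
    count = max(len(kept), 1e-9)
    return [sum(vec[d] for vec in kept) / count for d in range(dim)]


# Three token vectors; the last one is padding and must not affect the result
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
pooled = masked_mean_pool(tokens, mask)
print(pooled)
```

Only the two unmasked vectors contribute, so the padding vector's large values do not leak into the pooled embedding.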

In [None]:
def find_k_nearest_semantic_embeddings(knowledge_dataset, target_embeddings, k=3, batch_size=64):
    """
    Find the k nearest semantic embeddings for each embedding in the dataset.
    
    Args:
        knowledge_dataset: A dataset containing semantic embeddings to search in.
        target_embeddings: The embeddings to find the nearest neighbors for.
        k: The number of nearest neighbors to find.
        batch_size: Batch size for processing.
    
    Returns:
        A tensor of shape (num_targets, k) holding the indices of the k nearest semantic embeddings.
    """
    device = get_device()
    
    # Collect knowledge dataset embeddings and labels
    k_embeddings_list = []
    k_labels_list = []
    knowledge_dataloader = DataLoader(knowledge_dataset, batch_size=batch_size, shuffle=False)
    
    # Extract embeddings from knowledge dataset
    for k_embeddings, k_labels in knowledge_dataloader:
        k_embeddings_list.append(k_embeddings.to(device))
        k_labels_list.extend([str(label.item()) for label in k_labels])
    
    # Stack all knowledge embeddings into a single tensor for efficient computation
    if not k_embeddings_list:
        raise ValueError("Knowledge dataset is empty")
    k_embeddings_tensor = torch.cat(k_embeddings_list).to(device)
    
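    # Compute pairwise distances efficiently (batch computation);
    # the stacked targets must be on the same device as the knowledge embeddings
    target_tensor = torch.stack(target_embeddings).to(device)
    distances = torch.cdist(target_tensor, k_embeddings_tensor)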
    # Get indices of k nearest neighbors for each target embedding
    _, indices = torch.topk(distances, k, dim=1, largest=False)
    
    return indices


def semantic_validate_model(model, validate_dataset, knowledge_dataset, label_to_embeddings, batch_size=64):
    """
    Validate the model on the given dataset using semantic embeddings and return the average loss and accuracy.
    
    Args:
        model: The model to validate (model must be Classifier).
        validate_dataset: The dataset to validate on (already contains embeddings).
        knowledge_dataset: The dataset containing knowledge embeddings (already contains embeddings).
        label_to_embeddings: Dictionary mapping labels to their semantic embeddings.
        batch_size: The batch size to use for validation.
    
    Returns:
        Average loss, accuracy, precision, recall, and F1 score.
    """
    device = get_device()
    model = model.to(device)
    model.eval()
    loss_fn = torch.nn.CrossEntropyLoss()
    validate_dataloader = DataLoader(validate_dataset, batch_size=batch_size, shuffle=False)
    knowledge_dataloader = DataLoader(knowledge_dataset, batch_size=batch_size, shuffle=False)
    
    total_loss = 0.0
    num_of_k_nearest = 3
    y_true = []
    y_pred = []
    
    with torch.no_grad():
        # Collect knowledge dataset embeddings and labels
        k_embeddings_list = []
        k_labels_list = []
        
        # Extract embeddings from knowledge dataset
        for k_embeddings, k_labels in knowledge_dataloader:
            k_embeddings_list.append(k_embeddings.to(device))
            k_labels_list.extend([str(label.item()) for label in k_labels])
        
        # Stack all knowledge embeddings into a single tensor for efficient computation
        if not k_embeddings_list:
            raise ValueError("Knowledge dataset is empty")
        
        k_embeddings_tensor = torch.cat(k_embeddings_list).to(device)
        
        # Prepare the label embeddings tensor once, outside the batch loop
        label_embeddings_tensor = torch.stack([label_to_embeddings[label].to(device) for label in k_labels_list])
        
        # Now validate on the validation dataset
        for val_embeddings, labels in validate_dataloader:
            val_embeddings, labels = val_embeddings.to(device), labels.to(device)
            # Compute pairwise distances efficiently (batch computation)
            distances = torch.cdist(val_embeddings, k_embeddings_tensor)
            
            # Get indices of k nearest neighbors for each validation embedding
            _, indices = torch.topk(distances, num_of_k_nearest, dim=1, largest=False)
            
            # Gather the embeddings of the nearest neighbors
            batch_size_val, k = indices.size()
            nearest_embeddings = label_embeddings_tensor[indices.view(-1)].view(batch_size_val, k, -1)
            
            # Calculate centroids for each validation point
            centroid_embeddings = torch.mean(nearest_embeddings, dim=1)
            
            # Subtract the centroid embeddings from the validation embeddings
            val_embeddings -= centroid_embeddings
            
            # Use the transformed embeddings for final classification
            outputs = model(val_embeddings)
            
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = torch.max(outputs.data, 1)
            
            # Store true and predicted labels for metrics
            y_true += labels.tolist()
            y_pred += predicted.tolist()
    
    # Calculate precision, recall, and F1 score
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    accuracy = accuracy_score(y_true, y_pred)
    
    # Calculate average loss
    avg_loss = total_loss / len(validate_dataloader)
    
    print(f'Semantic Validation average loss: {avg_loss:.4f}, '
          f'Accuracy: {100 * accuracy:.2f}%, '
          f'Precision: {100 * precision:.2f}%, '
          f'Recall: {100 * recall:.2f}%, '
          f'F1 Score: {100 * f1:.2f}%')
    
    return avg_loss, accuracy, precision, recall, f1

print("Semantic validation and k-nearest neighbor functions defined!")
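
The core of `semantic_validate_model` is a three-step adjustment: find the k nearest knowledge embeddings, average the semantic embeddings of their labels, and subtract that centroid from the query. The following plain-Python sketch walks through those steps on toy data. All names and values here (`k_nearest`, `centroid`, the two-dimensional vectors, the labels `"a"`, `"b"`, `"c"`) are assumptions made for illustration.

```python
import math


def k_nearest(query, knowledge, k):
    """Indices of the k knowledge vectors closest to query (Euclidean distance)."""
    dists = [math.dist(query, vec) for vec in knowledge]
    return sorted(range(len(knowledge)), key=lambda i: dists[i])[:k]


def centroid(vectors):
    """Component-wise mean of a list of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


# Toy knowledge set: feature vectors, their labels, and one semantic embedding per label
knowledge = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [10.0, 10.0]]
knowledge_labels = ["a", "a", "b", "c"]
label_to_embedding = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [0.5, 0.5]}

query = [0.2, 0.1]
idx = k_nearest(query, knowledge, k=3)

# Average the label embeddings of the neighbors, then shift the query by that centroid
neighbor_embeddings = [label_to_embedding[knowledge_labels[i]] for i in idx]
shift = centroid(neighbor_embeddings)
adjusted = [q - s for q, s in zip(query, shift)]
print(idx, shift, adjusted)
```

The distant vector labelled `"c"` is never selected, so the centroid is built from two `"a"` neighbors and one `"b"` neighbor.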

In [None]:
def get_n_most_informative_samples_indices(dataset, model, samples_taken_indices, n=2, batch_size=64):
    """
    Get the n most informative samples from the dataset using the model's uncertainty.
    
    Args:
        dataset: The dataset to get samples from.
        model: The model to use for uncertainty estimation.
        samples_taken_indices: List of indices already taken (to exclude).
        n: The number of samples to return.
        batch_size: Batch size for processing.
    
    Returns:
        List of indices of the n most informative samples.
    """
    device = get_device()
    
    # Create a data loader for the dataset
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    
    outputs = []
    # Get the outputs for all samples in the dataset
    # (the model must be on the same device as the inputs)
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0].to(device)
            outputs.append(model(inputs).cpu())
    outputs = torch.cat(outputs)
    
    # Calculate entropy to find the most informative samples
    with torch.no_grad():
        # Apply softmax to get probability distributions
        probs = torch.nn.functional.softmax(outputs, dim=1)
        
        # Calculate entropy for each sample: -sum(p * log(p))
        # Handle zero probabilities by adding a small epsilon to avoid log(0)
        epsilon = 1e-10
        entropy = -torch.sum(probs * torch.log(probs + epsilon), dim=1)
        
        # Sort samples by entropy (higher entropy = more uncertainty = more informative)
        entropy_np = entropy.cpu().numpy()
        # Exclude already taken samples
        entropy_np[samples_taken_indices] = -np.inf  # Set taken samples' entropy to -inf to exclude them
        most_uncertain_indices = np.argsort(entropy_np)[::-1][:n]
    
    # Return the indices of the n most informative samples
    return most_uncertain_indices.tolist()

print("Active learning utility function defined!")
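
Entropy-based selection ranks samples by how spread out the predicted class probabilities are. Below is a minimal plain-Python sketch of the same idea, assuming raw logits are available per sample. The helper names and the logits are illustrative. Unlike the function above, this sketch filters out already-taken indices instead of setting their entropy to negative infinity.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs, eps=1e-10):
    """Shannon entropy -sum(p * log(p)); eps avoids log(0)."""
    return -sum(p * math.log(p + eps) for p in probs)


def most_uncertain(logits_per_sample, taken, n):
    """Return up to n not-yet-taken indices, highest entropy first (hypothetical helper)."""
    candidates = [i for i in range(len(logits_per_sample)) if i not in taken]
    scores = {i: entropy(softmax(logits_per_sample[i])) for i in candidates}
    return sorted(candidates, key=lambda i: scores[i], reverse=True)[:n]


logits = [
    [5.0, 0.0, 0.0],  # confident prediction, low entropy
    [1.0, 1.0, 1.0],  # uniform prediction, maximal entropy (already taken below)
    [2.0, 1.9, 0.0],  # two classes nearly tied, high entropy
    [0.0, 4.0, 0.0],  # fairly confident prediction
]
picked = most_uncertain(logits, taken={1}, n=2)
print(picked)
```

Filtering candidates up front also means the result can never contain an already-taken index, even when `n` exceeds the number of remaining samples.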

## Training Processes

This section implements different training methodologies:
1. **Triplet Training**: Uses triplet loss to learn embeddings where similar images are close and dissimilar images are far apart
2. **Classification Training**: Standard supervised learning for classification
3. **Semantic Classification**: Classification with semantic label embeddings
4. **Active Learning + Reinforcement Learning**: Advanced training with intelligent sample selection
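
To make the triplet objective concrete, here is a minimal plain-Python sketch of the margin loss, max(d(a, p) - d(a, n) + margin, 0), using Euclidean distance and the same margin of 0.2 that the training cell passes to `torch.nn.TripletMarginLoss`. The function names and toy vectors are assumptions for illustration, and the sketch omits the small numerical-stability epsilon that PyTorch adds inside the distance.

```python
import math


def euclidean(u, v):
    """L2 distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_margin_loss(anchor, positive, negative, margin=0.2):
    """Hinge on the gap between anchor-positive and anchor-negative distances."""
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)


anchor = [0.0, 0.0]
near = [0.0, 1.0]   # distance 1 from the anchor
far = [3.0, 4.0]    # distance 5 from the anchor

# Positive is close and negative is far: the margin is satisfied, so the loss is zero
easy = triplet_margin_loss(anchor, near, far)
# Roles swapped: the positive is farther than the negative, so the loss is large
hard = triplet_margin_loss(anchor, far, near)
print(easy, hard)
```

A triplet contributes no gradient once the negative is at least `margin` farther from the anchor than the positive, which is why only violating triplets drive training.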

In [None]:
def triplet_train_process(dataset, model, k_fold=5, batch_size=64):
    """
    Train the model using triplet loss with k-fold cross-validation.
    
    Args:
        dataset: The triplet dataset for training
        model: The feature extractor model
        k_fold: Number of folds for cross-validation
        batch_size: Batch size for training
    
    Returns:
        Trained model, average loss, average test loss
    """
    device = get_device()
    
    # Split dataset using k-fold cross-validation
    dataset_size = len(dataset)
    fold_size = dataset_size // k_fold
    
    # Collect per-epoch losses across all folds so the final averages cover every fold
    loss_list = []
    test_loss_list = []
    
    for fold in range(k_fold):
        print(f'Running fold {fold + 1}/{k_fold}...')
        
        # Initialize the trainer
        # (the model is not re-initialized here, so each fold continues from the previous fold's weights)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        loss_fn = torch.nn.TripletMarginLoss(margin=0.2, p=2)
        
        # Split dataset into train and validation sets
        train_dataset = torch.utils.data.Subset(dataset, list(
            range(fold_size * fold)) + list(range(fold_size * (fold + 1), dataset_size)))
        val_dataset = torch.utils.data.Subset(dataset, range(fold_size * fold, fold_size * (fold + 1)))
        
        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        
        # Train the model
        epochs = 10
        
        # Loop through epochs
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}...')
            start_time = time.time()
            model = model.to(device)
            model.train()
            total_loss = 0.0
            
            # Training loop
            for batch in train_loader:
                anchor, positive, negative = batch
                anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
                
                optimizer.zero_grad()
                output_anchor = model(anchor)
                output_positive = model(positive)
                output_negative = model(negative)
                
                loss = loss_fn(output_anchor, output_positive, output_negative)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
            
            # Validation loop
            model.eval()
            total_test_loss = 0.0
            with torch.no_grad():
                for batch in val_loader:
                    anchor, positive, negative = batch
                    anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
                    
                    output_anchor = model(anchor)
                    output_positive = model(positive)
                    output_negative = model(negative)
                    
                    loss = loss_fn(output_anchor, output_positive, output_negative)
                    total_test_loss += loss.item()
            
            # Calculate average loss for the epoch
            avg_loss = total_loss / len(train_loader)
            avg_test_loss = total_test_loss / len(val_loader)
            print(f'Fold {fold + 1}: Average loss: {avg_loss:.4f}, Average test loss: {avg_test_loss:.4f}')
            print(f'Time taken for fold {fold + 1}, epoch {epoch + 1}: {time.time() - start_time:.2f} seconds')
            loss_list.append(avg_loss)
            test_loss_list.append(avg_test_loss)
        
        print(f'Fold {fold + 1} completed.')
        
        # Save the model after each fold
        model_dir = f'models/triplet/{current_dataset}_{model._get_name()}'
        save_model(f'{model_dir}/model_fold_{fold + 1}.pth', model)
    
    # Print the average loss over all folds
    average_loss = sum(loss_list) / len(loss_list)
    average_test_loss = sum(test_loss_list) / len(test_loss_list)
    print(f'Over all folds: Average loss: {average_loss:.4f}, Average test loss: {average_test_loss:.4f}')
    
    return model, average_loss, average_test_loss

print("Triplet training process function defined!")
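
The contiguous fold split used in `triplet_train_process` can be reproduced on plain index lists to see which samples land where. This is a minimal sketch, assuming a dataset of 10 samples and 3 folds. `contiguous_folds` is a hypothetical helper that mirrors the `range` arithmetic of the training cell.

```python
def contiguous_folds(dataset_size, k_fold):
    """Yield (train_indices, val_indices) for each fold, as contiguous blocks."""
    fold_size = dataset_size // k_fold
    for fold in range(k_fold):
        start, stop = fold_size * fold, fold_size * (fold + 1)
        val = list(range(start, stop))
        train = list(range(start)) + list(range(stop, dataset_size))
        yield train, val


splits = list(contiguous_folds(dataset_size=10, k_fold=3))
for fold, (train, val) in enumerate(splits, start=1):
    print(f'Fold {fold}: val={val}, train={train}')

# Indices that are never used for validation in any fold
never_validated = set(range(10)) - {i for _, val in splits for i in val}
print(never_validated)
```

Because `fold_size` comes from integer division, the trailing `dataset_size % k_fold` samples (index 9 in this sketch) appear in every training set and in no validation set. The blocks are also taken in dataset order, so a dataset sorted by class would produce folds with skewed label distributions unless it is shuffled first.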

In [None]:
def classification_train_process(dataset, model, k_fold=5, batch_size=64, test_dataset=None):
    """
    Train the model using standard classification over k_fold repeated random 80/20 train/validation splits.
    
    Args:
        dataset: The classification dataset for training
        model: The classifier model
        k_fold: Number of random train/validation splits (folds) to run
        batch_size: Batch size for training
        test_dataset: Optional test dataset for evaluation
    
    Returns:
        Trained model, average loss, average validation loss
    """
    device = get_device()
    
    # Create result dataframe
    result_df = pd.DataFrame(columns=['dataset', 'model', 'fold', 'avg_loss', 'avg_test_loss',
                                    'avg_val_loss', 'precision', 'recall', 'f1', 'accuracy', 'total_time'])
    
    # Collect per-epoch losses across all folds so the final averages cover every fold
    loss_list = []
    val_loss_list = []
    
    for fold in range(k_fold):
        print(f'Running fold {fold + 1}/{k_fold}...')
        fold_start_time = time.time()
        
        # Initialize the trainer
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        loss_fn = torch.nn.CrossEntropyLoss()
        
        # Split dataset into train and validation sets
        # (a fresh random 80/20 split per fold, not disjoint k-fold partitions)
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])
        
        # Create data loaders
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
        
        # Train the model
        epochs = 10
        
        # Loop through epochs
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}...')
            epoch_start_time = time.time()
            model = model.to(device)
            model.train()
            total_loss = 0.0
            
            # Training loop
            for batch in train_loader:
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)
                
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_fn(outputs, labels)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
            
            epoch_end_time = time.time()
            
            # Validation loop
            model.eval()
            total_val_loss = 0.0
            with torch.no_grad():
                for batch in val_loader:
                    inputs, labels = batch
                    inputs, labels = inputs.to(device), labels.to(device)
                    
                    outputs = model(inputs)
                    loss = loss_fn(outputs, labels)
                    total_val_loss += loss.item()
            
            # Calculate average loss for the epoch
            avg_loss = total_loss / len(train_loader)
            avg_val_loss = total_val_loss / len(val_loader)
            print(f'Fold {fold + 1}: Average loss: {avg_loss:.4f}, Average validation loss: {avg_val_loss:.4f}')
            print(f'Time taken for fold {fold + 1}, epoch {epoch + 1}: {epoch_end_time - epoch_start_time:.2f} seconds')
            loss_list.append(avg_loss)
            val_loss_list.append(avg_val_loss)
        
        print(f'Fold {fold + 1} completed.')
        
        # Save the model after each fold
        model_dir = f'models/classification/{current_dataset}_{model._get_name()}'
        save_model(f'{model_dir}/model_fold_{fold + 1}.pth', model)
        
        # Validate the model on the test set if provided
        if test_dataset is not None:
            avg_test_loss, accuracy, precision, recall, f1 = validate_model(model, test_dataset, batch_size=batch_size)
            print(f'Fold {fold + 1}: Test loss: {avg_test_loss:.4f}, Accuracy: {100 * accuracy:.2f}%')
            
            fold_end_time = time.time()
            result_df = pd.concat([result_df, pd.DataFrame({
                'dataset': [dataset.__class__.__name__],
                'model': [model._get_name()],
                'fold': [fold + 1],
                'avg_loss': [avg_loss],
                'avg_test_loss': [avg_test_loss],
                'avg_val_loss': [avg_val_loss],
                'accuracy': [accuracy],
                'precision': [precision],
                'recall': [recall],
                'f1': [f1],
                'total_time': [fold_end_time - fold_start_time]
            })], ignore_index=True)
    
    # Print the average loss over all folds
    average_loss = sum(loss_list) / len(loss_list)
    average_val_loss = sum(val_loss_list) / len(val_loss_list)
    print(f'Over all folds: Average loss: {average_loss:.4f}, Average validation loss: {average_val_loss:.4f}')
    
    return model, average_loss, average_val_loss

print("Classification training process function defined!")

In [None]:
def semantic_classification_train_process(dataset, model, semantic_embedding_fn, k_fold=5, batch_size=64, test_dataset=None):
    """
    Train the model using semantic classification with label embeddings and k-fold cross-validation.
    
    Args:
        dataset: The classification dataset for training (must have labels attribute)
        model: The classifier model (must have backbone_out_features attribute)
        semantic_embedding_fn: Function to compute semantic embeddings
        k_fold: Number of folds for cross-validation
        batch_size: Batch size for training
        test_dataset: Optional test dataset for evaluation
    
    Returns:
        Trained model, average loss, average validation loss, and label embeddings
    """
    device = get_device()
    
    # Create result dataframe
    result_df = pd.DataFrame(columns=['dataset', 'model', 'fold', 'avg_loss', 'avg_test_loss',
                                    'avg_val_loss', 'precision', 'recall', 'f1', 'accuracy'])
    
    # Compute label embeddings
    print('Computing label embeddings...')
    label_to_embedding = compute_label_embeddings(dataset.labels, model.backbone_out_features)
    
    if not semantic_embedding_fn:
        # Default to a zero scale when no function is provided, which disables the semantic embeddings
        def semantic_embedding_fn(x): return 0
    
    label_to_embedding = {label: semantic_embedding_fn(index) * embedding 
                         for index, (label, embedding) in enumerate(label_to_embedding.items())}
    print('Label embeddings computed.')
    
    # Split dataset using k-fold cross-validation
    dataset_size = len(dataset)
    fold_size = dataset_size // k_fold
    
    loss_list = []
    val_loss_list = []
    
    for fold in range(k_fold):
        print(f'Running fold {fold + 1}/{k_fold}...')
        fold_start_time = time.time()
        
        # Initialize the trainer
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        loss_fn = torch.nn.CrossEntropyLoss()
        
        # Split dataset into train and validation sets
        train_dataset = torch.utils.data.Subset(dataset, list(
            range(fold_size * fold)) + list(range(fold_size * (fold + 1), dataset_size)))
        val_dataset = torch.utils.data.Subset(dataset, range(fold_size * fold, fold_size * (fold + 1)))
        
        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        
        # Train the model
        epochs = 10
        
        # Loop through epochs
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}...')
            epoch_start_time = time.time()
            model = model.to(device)
            model.train()
            total_loss = 0.0
            
            # Training loop with semantic embeddings
            for batch in train_loader:
                inputs, labels = batch
                
                # Look up the semantic embedding for each sample's label
                label_embeddings = torch.stack([label_to_embedding[str(label.item())].detach() for label in labels])
                # Move everything to the same device, then subtract the label embeddings from the inputs
                inputs, labels = inputs.to(device), labels.to(device)
                inputs = inputs - label_embeddings.to(device)
                
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_fn(outputs, labels)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
            
            epoch_end_time = time.time()
            
            # Validation using semantic validation
            avg_loss = total_loss / len(train_loader)
            avg_val_loss, val_accuracy, val_precision, val_recall, val_f1 = semantic_validate_model(
                model, val_dataset, train_dataset, label_to_embedding, batch_size=batch_size)
            
            print(f'Fold {fold + 1}, Epoch {epoch + 1}: Average loss: {avg_loss:.4f}, '
                  f'Validation loss: {avg_val_loss:.4f}, Validation accuracy: {100 * val_accuracy:.2f}%, '
                  f'Time taken: {epoch_end_time - epoch_start_time:.2f} seconds')
            
            loss_list.append(avg_loss)
            val_loss_list.append(avg_val_loss)
        
        fold_end_time = time.time()
        print(f'Fold {fold + 1} completed.')
        
        # Save the model after each fold
        model_dir = f'models/classification/{current_dataset}_{model._get_name()}'
        save_model(f'{model_dir}/model_fold_{fold + 1}.pth', model)
        
        # Validate the model on the test set if provided
        if test_dataset is not None:
            avg_test_loss, accuracy, precision, recall, f1 = semantic_validate_model(
                model, test_dataset, train_dataset, label_to_embedding, batch_size=batch_size)
            print(f'Fold {fold + 1}: Average test loss: {avg_test_loss:.4f}, '
                  f'Test accuracy: {100 * accuracy:.2f}%, Test precision: {100 * precision:.2f}%, '
                  f'Test recall: {100 * recall:.2f}%, Test F1 Score: {100 * f1:.2f}%')
            
            result_df = pd.concat([result_df, pd.DataFrame({
                'dataset': current_dataset,
                'model': model._get_name(),
                'fold': fold + 1,
                'avg_loss': total_loss / len(train_loader),
                'avg_test_loss': avg_test_loss,
                'avg_val_loss': avg_val_loss,
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'accuracy': accuracy,
                'total_time': fold_end_time - fold_start_time
            }, index=[0])], ignore_index=True)
    
    # Print the average loss over all folds
    average_loss = sum(loss_list) / len(loss_list)
    average_val_loss = sum(val_loss_list) / len(val_loss_list)
    print(f'Over all folds: Average loss: {average_loss:.4f}, Average validation loss: {average_val_loss:.4f}')
    
    # Save results to CSV
    if not result_df.empty:
        result_dir = 'results/cv-triplet'
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)
        result_path = f'{result_dir}/{current_dataset}_{current_backbone_model._get_name()}_{time.strftime("%Y%m%d-%H%M%S")}.csv'
        result_df.to_csv(result_path, index=False)
        print(f'Results saved to {result_path}')
    
    return model, average_loss, average_val_loss, label_to_embedding

print("Semantic classification training process function defined!")

## Active Learning with Semantic Classification

This section demonstrates how to use active learning techniques to intelligently select the most informative samples for training.

In [None]:
def active_learning_demo(model, train_dataset, val_dataset, label_embeddings, num_iterations=3, samples_per_iteration=5):
    """
    Demonstrate active learning by iteratively selecting the most informative samples.
    
    Args:
        model: The trained model to use for uncertainty estimation
        train_dataset: Training dataset 
        val_dataset: Validation dataset to select samples from
        label_embeddings: Dictionary of label embeddings
        num_iterations: Number of active learning iterations
        samples_per_iteration: Number of samples to select per iteration
    
    Returns:
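        List of all selected sample indices and a list of per-iteration result dictionaries
        (iteration number, selected indices, and semantic adjustment centroids)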
    """
    device = get_device()
    model.eval()
    
    samples_taken_indices = []
    iteration_results = []
    
    print("=== Active Learning Demonstration ===")
    
    for iteration in range(num_iterations):
        print(f"\n--- Iteration {iteration + 1}/{num_iterations} ---")
        
        # Get the most informative samples
        informative_indices = get_n_most_informative_samples_indices(
            val_dataset, model, samples_taken_indices, n=samples_per_iteration
        )
        
        print(f"Selected sample indices: {informative_indices}")
        
        # Add to taken samples
        samples_taken_indices.extend(informative_indices)
        
        # Get the actual samples and their features
        informative_samples = [val_dataset[i] for i in informative_indices]
        
        # Find k-nearest neighbors for each selected sample
        sample_embeddings = [sample[0] for sample in informative_samples]
        
        try:
            k_nearest_indices = find_k_nearest_semantic_embeddings(
                train_dataset, sample_embeddings, k=3
            )
            print("Found k-nearest neighbors for selected samples")
            
            # Calculate semantic adjustments
            semantic_adjustments = []
            for i, sample_idx in enumerate(informative_indices):
                sample_embedding = sample_embeddings[i]
                nearest_indices = k_nearest_indices[i]
                
                # Get semantic embeddings of nearest neighbors
                nearest_labels = [str(train_dataset[idx][1]) for idx in nearest_indices]
                nearest_semantic_embeddings = [label_embeddings[label] for label in nearest_labels]
                
                # Calculate centroid
                centroid = torch.mean(torch.stack(nearest_semantic_embeddings), dim=0)
                semantic_adjustments.append(centroid)
                
                print(f"  Sample {sample_idx}: Nearest labels {nearest_labels}")
            
            iteration_results.append({
                'iteration': iteration + 1,
                'selected_indices': informative_indices,
                'semantic_adjustments': semantic_adjustments
            })
            
        except Exception as e:
            print(f"Error in k-nearest neighbor search: {e}")
            iteration_results.append({
                'iteration': iteration + 1,
                'selected_indices': informative_indices,
                'semantic_adjustments': None
            })
    
    print(f"\nActive Learning completed. Total samples selected: {len(samples_taken_indices)}")
    return samples_taken_indices, iteration_results


# Example usage (commented out to avoid execution).
# Assuming you have a trained model and datasets:
# selected_samples, al_results = active_learning_demo(
#     trained_model, embedded_train_ds, embedded_test_ds, label_embeddings,
#     num_iterations=3, samples_per_iteration=2
# )

print("Active Learning demonstration function defined!")
print("This function shows how to:")
print("1. Select the most uncertain/informative samples")
print("2. Find their k-nearest neighbors in the training set")
print("3. Calculate semantic adjustments based on neighbor embeddings")
print("4. Repeat the selection over several iterations (the model is not retrained here)")
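
The neighbor lookup and centroid step inside `active_learning_demo` can be sketched in plain Python as follows. This is a minimal illustration that assumes Euclidean distance and tiny 2-D vectors. The helper names, features, and labels are hypothetical, and the notebook's `find_k_nearest_semantic_embeddings` may measure distance differently.

```python
import math

def k_nearest(query, candidates, k):
    """Indices of the k candidates closest to query (Euclidean distance)."""
    order = sorted(range(len(candidates)),
                   key=lambda i: math.dist(query, candidates[i]))
    return order[:k]

def centroid(vectors):
    """Component-wise mean of equally sized vectors."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]

# Hypothetical 2-D training features, their labels, and label embeddings.
train_features = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
train_labels = ["left", "left", "right", "up"]
label_embeddings = {"left": [1.0, 0.0], "right": [0.0, 1.0], "up": [3.0, 3.0]}

query = [0.2, 0.1]
neighbours = k_nearest(query, train_features, k=3)
neighbour_labels = [train_labels[i] for i in neighbours]

# The centroid of the neighbours' label embeddings plays the role of the
# "semantic adjustment" collected per selected sample.
adjustment = centroid([label_embeddings[label] for label in neighbour_labels])
print(neighbours, neighbour_labels, adjustment)
```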

## Dataset Creation and Configuration

This section defines functions to create different types of datasets for training and testing.

In [None]:
def create_training_process_df(dataset_type, create_triplet_dataset_fn, create_classification_dataset_fn, 
                               create_classification_test_dataset_fn=None, models=('resnet',), batch_size=32):
    """
    Create a DataFrame to store the training process information.
    
    Args:
        dataset_type: The type of dataset to use (e.g., 'cifar10', 'mnist').
        create_triplet_dataset_fn: Function to create the triplet dataset.
        create_classification_dataset_fn: Function to create the classification dataset.
        create_classification_test_dataset_fn: Function to create the classification test dataset (optional).
        models: Model names to include in the DataFrame; supported values are
            'resnet', 'vgg', 'mobilenet', and 'densenet'.
        batch_size: The batch size to use for training.
    
    Returns:
        A DataFrame containing the training process information.
    """
    triplet_df = pd.DataFrame(columns=[
        'backbone_model', 'feature_extractor_model', 'dataset_type',
        'create_triplet_dataset_fn', 'create_classification_dataset_fn',
        'create_classification_test_dataset_fn', 'batch_size'
    ])
    
    # Create backbone model functions
    create_backbone_model_funcs = []
    for model in models:
        if model == 'resnet':
            create_backbone_model_funcs.append(lambda: torchvision.models.resnet50(weights=None))
        elif model == 'vgg':
            create_backbone_model_funcs.append(lambda: torchvision.models.vgg16(weights=None))
        elif model == 'mobilenet':
            create_backbone_model_funcs.append(lambda: torchvision.models.mobilenet_v2(weights=None))
        elif model == 'densenet':
            create_backbone_model_funcs.append(lambda: torchvision.models.densenet121(weights=None))
        else:
            raise ValueError(f'Unknown model type: {model}')
    
    # Add all backbone models to the DataFrame
    for create_fn in create_backbone_model_funcs:
        # Build the backbone once so the stored backbone is the same instance
        # that the feature extractor wraps
        backbone = create_fn()
        triplet_df = pd.concat([triplet_df, pd.DataFrame({
            'backbone_model': [backbone],
            'feature_extractor_model': [FeatureExtractor(backbone)],
            'dataset_type': [dataset_type],
            'create_triplet_dataset_fn': [create_triplet_dataset_fn],
            'create_classification_dataset_fn': [create_classification_dataset_fn],
            'create_classification_test_dataset_fn': [create_classification_test_dataset_fn],
            'batch_size': [batch_size]
        })], ignore_index=True)
    
    return triplet_df

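def create_train_test_dataset(create_train_dataset_fn, create_test_dataset_fn=None):
    """
    Create a train and test dataset from the given functions.

    Args:
        create_train_dataset_fn: Zero-argument function that returns the training dataset.
        create_test_dataset_fn: Optional zero-argument function that returns the test dataset.
            If omitted, the training dataset is split randomly 80/20 into train and test.

    Returns:
        Tuple of (train_dataset, test_dataset).
    """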
    # Create the train dataset
    train_dataset = create_train_dataset_fn()
    
    # If a test dataset function is provided, use it to create the test dataset
    if create_test_dataset_fn is not None:
        test_dataset = create_test_dataset_fn()
    else:
        # Otherwise, split the train dataset into train and test sets
        train_size = int(0.8 * len(train_dataset))
        test_size = len(train_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(train_dataset, [train_size, test_size])
    
    return train_dataset, test_dataset

print("Dataset creation functions defined!")
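
The fallback 80/20 split in `create_train_test_dataset` is random, so two runs can produce different test sets. The sketch below shows the idea of a seeded, reproducible split on plain index lists. The helper name `split_indices` and the seed are assumptions for illustration. With PyTorch, the analogous option is the `generator` argument of `torch.utils.data.random_split`.

```python
import random

def split_indices(n_items, train_fraction=0.8, seed=42):
    """Shuffle indices with a seeded RNG and split them into train/test lists."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    train_size = int(train_fraction * n_items)
    return indices[:train_size], indices[train_size:]

train_idx, test_idx = split_indices(10)
again_train, again_test = split_indices(10)

print(len(train_idx), len(test_idx))
print(train_idx == again_train and test_idx == again_test)
```

Fixing the seed makes results comparable across backbones, since every model is then evaluated on the same held-out samples.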

## Example Dataset Configurations

The following cells demonstrate how to configure different datasets for training.

In [None]:
# Example: gi4e dataset configuration
def create_gi4e_triplet_dataset():
    """Create gi4e triplet dataset"""
    return ds.TripletGi4eDataset(
        './datasets/gi4e',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ])
    )

def create_gi4e_classification_dataset():
    """Create gi4e classification dataset"""
    return ds.Gi4eDataset(
        './datasets/gi4e',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ]),
        is_classification=True
    )

# Example: Raw eyes dataset configuration
def create_gi4e_raw_eyes_triplet_dataset():
    """Create gi4e raw eyes triplet dataset"""
    return ds.TripletImageDataset(
        './datasets/gi4e_raw_eyes',
        file_extension='png',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ])
    )

def create_gi4e_raw_eyes_classification_dataset():
    """Create gi4e raw eyes classification dataset"""
    return ds.ImageDataset(
        './datasets/gi4e_raw_eyes',
        file_extension='png',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ])
    )

# Example: CelebA dataset configuration
def create_celeba_triplet_dataset():
    """Create CelebA triplet dataset"""
    return ds.TripletImageDataset(
        './datasets/CelebA_HQ_facial_identity_dataset/train',
        file_extension='jpg',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ])
    )

def create_celeba_classification_dataset():
    """Create CelebA classification dataset"""
    return ds.ImageDataset(
        './datasets/CelebA_HQ_facial_identity_dataset/train',
        file_extension='jpg',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ])
    )

def create_celeba_test_dataset():
    """Create CelebA test dataset"""
    return ds.ImageDataset(
        './datasets/CelebA_HQ_facial_identity_dataset/test',
        file_extension='jpg',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor()
        ])
    )

print("Example dataset configurations defined!")
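
Every configuration above points at a directory under `./datasets`. A small pre-flight check, sketched here with only the standard library, can report missing directories before any dataset class is constructed. The helper name `missing_dataset_dirs` is hypothetical; the paths are copied from the functions above.

```python
from pathlib import Path

# Paths taken from the configuration functions above.
DATASET_DIRS = [
    "./datasets/gi4e",
    "./datasets/gi4e_raw_eyes",
    "./datasets/CelebA_HQ_facial_identity_dataset/train",
    "./datasets/CelebA_HQ_facial_identity_dataset/test",
]

def missing_dataset_dirs(paths):
    """Return the subset of paths that are not existing directories."""
    return [p for p in paths if not Path(p).is_dir()]

missing = missing_dataset_dirs(DATASET_DIRS)
if missing:
    print("Missing dataset directories:")
    for path in missing:
        print(f"  {path}")
else:
    print("All dataset directories found.")
```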

## Training Execution

This section demonstrates how to execute the training process for different models and datasets.

In [None]:
def train(dataset, model, train_process='triplet', semantic_embedding_fn=None, k_fold=5, batch_size=32, test_dataset=None):
    """
    Main training function that handles different training processes.
    
    Args:
        dataset: The dataset to train on.
        model: The model to train.
        train_process: The training process to use ('triplet', 'classification', 'semantic_classification').
        semantic_embedding_fn: Function to compute semantic embeddings (optional).
        k_fold: The number of folds for k-fold cross-validation.
        batch_size: The batch size to use for training.
        test_dataset: Optional test dataset for evaluation.
    
    Returns:
        Trained model, average loss, average test loss, and label embeddings (if applicable).
    """
    print('Starting training process...')
    
    if train_process == 'triplet':
        trained_model, avg_loss, avg_test_loss = triplet_train_process(
            dataset, model, k_fold=k_fold, batch_size=batch_size)
        label_to_embedding = None
        
    elif train_process == 'classification':
        trained_model, avg_loss, avg_test_loss = classification_train_process(
            dataset, model, k_fold=k_fold, batch_size=batch_size, test_dataset=test_dataset)
        label_to_embedding = None
        
    elif train_process == 'semantic_classification':
        # Use the complete semantic classification training process
        trained_model, avg_loss, avg_test_loss, label_to_embedding = semantic_classification_train_process(
            dataset, model, semantic_embedding_fn=semantic_embedding_fn, 
            k_fold=k_fold, batch_size=batch_size, test_dataset=test_dataset)
            
    else:
        raise ValueError(f'Unknown training process: {train_process}')
    
    print('Training completed.')
    return trained_model, avg_loss, avg_test_loss, label_to_embedding

print("Updated main training function defined!")
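
To make the `k_fold` argument concrete, the following sketch partitions sample indices into contiguous folds and yields a train/validation pair for each. It is an illustration only: the helper `k_fold_indices` is hypothetical, and the notebook's training processes may shuffle the data or rely on a library splitter instead.

```python
def k_fold_indices(n_items, k):
    """Yield (train_indices, val_indices) for each of k contiguous folds."""
    # Spread any remainder over the first folds so sizes differ by at most one.
    fold_sizes = [n_items // k + (1 if i < n_items % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_items))
        yield train, val
        start += size

folds = list(k_fold_indices(10, k=5))
for fold_number, (train_idx, val_idx) in enumerate(folds, start=1):
    print(f"fold {fold_number}: train={len(train_idx)} val={val_idx}")
```

Each sample appears in exactly one validation fold, which is why the reported losses are averaged over the folds.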

In [None]:
# Complete Example: Training with Multiple Approaches

# Set global variables
current_dataset = "gi4e_raw_eyes_demo"
batch_size = 32

# Create model
backbone_model = torchvision.models.resnet50(weights=None)
current_backbone_model = backbone_model
feature_extractor = FeatureExtractor(backbone_model)

print(f"Created ResNet50 feature extractor with backbone: {backbone_model.__class__.__name__}")
print(f"Feature extractor name: {feature_extractor._get_name()}")

# Semantic embedding functions for demonstration
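def linear_semantic_fn(x):
    """Identity mapping: leave the label embeddings unchanged."""
    return x

def quadratic_semantic_fn(x):
    """Scale the label embeddings by a constant factor of 4.

    Despite the name, the mapping is linear in x; it does not square x.
    """
    return 4 * x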

print("Semantic embedding functions defined.")

# Note: Uncomment the following sections to actually run training
# This is commented out to avoid long execution times in the notebook

print("\n=== Training Workflow Demo ===")
print("1. First, train a triplet model for feature extraction")
print("2. Then use the trained features for classification")
print("3. Finally, apply semantic classification with label embeddings")

# # 1. Create datasets
# print("\n--- Step 1: Creating datasets ---")
# try:
#     triplet_dataset = create_gi4e_raw_eyes_triplet_dataset()
#     classification_dataset = create_gi4e_raw_eyes_classification_dataset()
#     print(f"Triplet dataset size: {len(triplet_dataset)}")
#     print(f"Classification dataset size: {len(classification_dataset)}")
#     
#     # Split into train and test
#     train_ds, test_ds = create_train_test_dataset(create_gi4e_raw_eyes_classification_dataset)
#     print(f"Train dataset size: {len(train_ds)}")
#     print(f"Test dataset size: {len(test_ds)}")
# except Exception as e:
#     print(f"Error creating datasets: {e}")
#     print("Please ensure the dataset paths exist")

# # 2. Train triplet model for feature extraction
# print("\n--- Step 2: Training triplet model ---")
# trained_triplet, avg_loss_triplet, avg_test_loss_triplet, _ = train(
#     triplet_dataset, feature_extractor, train_process='triplet',
#     k_fold=2, batch_size=batch_size  # Reduced folds for demo
# )
# print(f"Triplet training completed. Average loss: {avg_loss_triplet:.4f}")

# # 3. Create embedded datasets using trained triplet model
# print("\n--- Step 3: Creating embedded datasets ---")
# embedded_train_ds = ds.EmbeddedDataset(train_ds, trained_triplet, is_moving_labels_to_function=False)
# embedded_test_ds = ds.EmbeddedDataset(test_ds, trained_triplet, is_moving_labels_to_function=False)
# print(f"Embedded train dataset size: {len(embedded_train_ds)}")
# print(f"Embedded test dataset size: {len(embedded_test_ds)}")

# # 4. Train standard classifier on embedded features
# print("\n--- Step 4: Training standard classifier ---")
# classifier = Classifier(trained_triplet)
# trained_classifier, avg_loss_class, avg_test_loss_class, _ = train(
#     embedded_train_ds, classifier, train_process='classification',
#     k_fold=2, batch_size=batch_size, test_dataset=embedded_test_ds
# )
# print(f"Classification training completed. Average loss: {avg_loss_class:.4f}")

# # 5. Train semantic classifier with linear scaling
# print("\n--- Step 5: Training semantic classifier (linear scaling) ---")
# semantic_classifier_linear = Classifier(trained_triplet)
# trained_semantic_linear, avg_loss_sem_lin, avg_test_loss_sem_lin, label_embeddings_lin = train(
#     embedded_train_ds, semantic_classifier_linear, train_process='semantic_classification',
#     semantic_embedding_fn=linear_semantic_fn, k_fold=2, batch_size=batch_size, test_dataset=embedded_test_ds
# )
# print(f"Semantic classification (linear) completed. Average loss: {avg_loss_sem_lin:.4f}")

# # 6. Train semantic classifier with quadratic scaling
# print("\n--- Step 6: Training semantic classifier (quadratic scaling) ---")
# semantic_classifier_quad = Classifier(trained_triplet)
# trained_semantic_quad, avg_loss_sem_quad, avg_test_loss_sem_quad, label_embeddings_quad = train(
#     embedded_train_ds, semantic_classifier_quad, train_process='semantic_classification',
#     semantic_embedding_fn=quadratic_semantic_fn, k_fold=2, batch_size=batch_size, test_dataset=embedded_test_ds
# )
# print(f"Semantic classification (quadratic) completed. Average loss: {avg_loss_sem_quad:.4f}")

# # 7. Compare results
# print("\n--- Step 7: Results Comparison ---")
# results_comparison = pd.DataFrame({
#     'Method': ['Triplet Only', 'Standard Classification', 'Semantic (Linear)', 'Semantic (Quadratic)'],
#     'Average Loss': [avg_loss_triplet, avg_loss_class, avg_loss_sem_lin, avg_loss_sem_quad],
#     'Test Loss': [avg_test_loss_triplet, avg_test_loss_class, avg_test_loss_sem_lin, avg_test_loss_sem_quad]
# })
# print(results_comparison)

print("\nTraining workflow example setup complete!")
print("Uncomment the code above to run the complete training pipeline.")
print("\nThis workflow demonstrates:")
print("- Triplet learning for feature extraction")
print("- Standard classification on learned features") 
print("- Semantic classification with different embedding scaling strategies")
print("- Comparison of different approaches")
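
For readers new to triplet learning, the standard triplet margin loss is $L = \max(d(a, p) - d(a, n) + m, 0)$, where $a$ is the anchor, $p$ a sample of the same class, $n$ a sample of a different class, and $m$ the margin. The sketch below evaluates it on hand-picked 2-D points. The margin of 1.0 and the Euclidean distance are assumptions; `triplet_train_process` may configure its loss differently.

```python
import math

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge-style triplet loss on Euclidean distances."""
    d_pos = math.dist(anchor, positive)
    d_neg = math.dist(anchor, negative)
    return max(d_pos - d_neg + margin, 0.0)

anchor = [0.0, 0.0]
positive = [0.0, 1.0]        # same class, distance 1.0 from the anchor
near_negative = [0.0, 1.5]   # other class, distance 1.5: inside the margin
far_negative = [3.0, 4.0]    # other class, distance 5.0: already separated

loss_near = triplet_loss(anchor, positive, near_negative)
loss_far = triplet_loss(anchor, positive, far_negative)
print(loss_near, loss_far)
```

Only triplets whose negative lies within the margin contribute a non-zero loss, so training effort concentrates on the hard cases.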

## Results Analysis and Visualization

This section provides tools for analyzing and visualizing training results.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def analyze_training_results(results_file='triplet_training_results.csv'):
    """
    Analyze and visualize training results from CSV file.
    
    Args:
        results_file: Path to the results CSV file
    """
    try:
        # Load results
        results_df = pd.read_csv(results_file)
        print(f"Loaded {len(results_df)} training results")
        
        # Display basic statistics
        print("\n=== Training Results Summary ===")
        print(results_df.describe())
        
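        # Group by dataset and model, aggregating only the metric columns present
        print("\n=== Results by Dataset and Model ===")
        metric_columns = [c for c in ('accuracy', 'precision', 'recall', 'f1')
                          if c in results_df.columns]
        if metric_columns:
            grouped = (results_df.groupby(['dataset', 'model'])[metric_columns]
                       .agg(['mean', 'std']).round(4))
            print(grouped)
        else:
            print("No metric columns found to aggregate.")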
        
        # Create visualizations
        plt.figure(figsize=(15, 10))
        
        # Accuracy comparison
        plt.subplot(2, 2, 1)
        if 'accuracy' in results_df.columns:
            sns.boxplot(data=results_df, x='model', y='accuracy')
            plt.title('Accuracy by Model')
            plt.xticks(rotation=45)
        
        # F1 Score comparison
        plt.subplot(2, 2, 2)
        if 'f1' in results_df.columns:
            sns.boxplot(data=results_df, x='model', y='f1')
            plt.title('F1 Score by Model')
            plt.xticks(rotation=45)
        
        # Loss comparison
        plt.subplot(2, 2, 3)
        if 'avg_test_loss' in results_df.columns:
            sns.boxplot(data=results_df, x='model', y='avg_test_loss')
            plt.title('Test Loss by Model')
            plt.xticks(rotation=45)
        
        # Dataset performance
        plt.subplot(2, 2, 4)
        if 'accuracy' in results_df.columns:
            sns.boxplot(data=results_df, x='dataset', y='accuracy')
            plt.title('Accuracy by Dataset')
            plt.xticks(rotation=45)
        
        plt.tight_layout()
        plt.show()
        
        return results_df
        
    except FileNotFoundError:
        print(f"Results file {results_file} not found. Run training first to generate results.")
        return None
    except Exception as e:
        print(f"Error analyzing results: {e}")
        return None

def plot_training_metrics(metrics_dict):
    """
    Plot training metrics over time.
    
    Args:
        metrics_dict: Dictionary containing lists of metrics (loss, accuracy, etc.)
    """
    plt.figure(figsize=(12, 8))
    
    # Plot loss
    if 'loss' in metrics_dict:
        plt.subplot(2, 2, 1)
        plt.plot(metrics_dict['loss'])
        plt.title('Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
    
    # Plot validation loss
    if 'val_loss' in metrics_dict:
        plt.subplot(2, 2, 2)
        plt.plot(metrics_dict['val_loss'])
        plt.title('Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
    
    # Plot accuracy
    if 'accuracy' in metrics_dict:
        plt.subplot(2, 2, 3)
        plt.plot(metrics_dict['accuracy'])
        plt.title('Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
    
    # Plot combined losses
    if 'loss' in metrics_dict and 'val_loss' in metrics_dict:
        plt.subplot(2, 2, 4)
        plt.plot(metrics_dict['loss'], label='Training Loss')
        plt.plot(metrics_dict['val_loss'], label='Validation Loss')
        plt.title('Loss Comparison')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
    
    plt.tight_layout()
    plt.show()

print("Results analysis functions defined!")
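
`plot_training_metrics` expects a dictionary of per-epoch lists under the keys `'loss'`, `'val_loss'`, and `'accuracy'`. One way to build such a dictionary during training is sketched below. The helper `record_epoch` is hypothetical, and the numbers are placeholders rather than results from this notebook.

```python
from collections import defaultdict

def record_epoch(history, **metrics):
    """Append one epoch's metrics to the running history."""
    for name, value in metrics.items():
        history[name].append(value)
    return history

history = defaultdict(list)

# Placeholder numbers for illustration only, not measured results.
record_epoch(history, loss=0.9, val_loss=1.0, accuracy=0.55)
record_epoch(history, loss=0.7, val_loss=0.8, accuracy=0.65)

metrics_dict = dict(history)
print(metrics_dict)
```

The resulting `metrics_dict` can then be passed directly to `plot_training_metrics(metrics_dict)`.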

In [None]:
def compare_training_methods(results_dict):
    """
    Compare different training methods and visualize their performance.
    
    Args:
        results_dict: Dictionary with method names as keys and results as values
                     Each result should be a tuple of (model, avg_loss, test_loss, embeddings)
    """
    # Create comparison DataFrame
    comparison_data = []
    for method, (model, avg_loss, test_loss, embeddings) in results_dict.items():
        comparison_data.append({
            'Method': method,
            'Model': model._get_name() if hasattr(model, '_get_name') else str(type(model).__name__),
            'Average Loss': avg_loss,
            'Test Loss': test_loss,
            'Has Embeddings': embeddings is not None
        })
    
    comparison_df = pd.DataFrame(comparison_data)
    
    print("=== Training Methods Comparison ===")
    print(comparison_df.to_string(index=False))
    
    # Visualize comparison
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    
    # Average Loss comparison
    axes[0].bar(comparison_df['Method'], comparison_df['Average Loss'])
    axes[0].set_title('Average Training Loss by Method')
    axes[0].set_ylabel('Loss')
    axes[0].tick_params(axis='x', rotation=45)
    
    # Test Loss comparison
    axes[1].bar(comparison_df['Method'], comparison_df['Test Loss'])
    axes[1].set_title('Test Loss by Method')
    axes[1].set_ylabel('Loss')
    axes[1].tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    return comparison_df


def visualize_semantic_embeddings(label_embeddings, title="Label Embeddings Visualization"):
    """
    Visualize label embeddings using dimensionality reduction.
    
    Args:
        label_embeddings: Dictionary mapping labels to their embeddings
        title: Title for the plot
    """
    if not label_embeddings:
        print("No label embeddings provided for visualization")
        return
    
    try:
        from sklearn.decomposition import PCA
        from sklearn.manifold import TSNE
        
        # Prepare data
        labels = list(label_embeddings.keys())
        embeddings = torch.stack(list(label_embeddings.values())).detach().cpu().numpy()
        
        # Apply PCA for dimensionality reduction
        if embeddings.shape[1] > 2:
            pca = PCA(n_components=2)
            embeddings_2d = pca.fit_transform(embeddings)
            explained_variance = pca.explained_variance_ratio_.sum()
        else:
            embeddings_2d = embeddings
            explained_variance = 1.0
        
        # Create visualization
        plt.figure(figsize=(12, 5))
        
        # PCA plot
        plt.subplot(1, 2, 1)
        scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], 
                            c=range(len(labels)), cmap='tab10', s=100)
        for i, label in enumerate(labels):
            plt.annotate(label, (embeddings_2d[i, 0], embeddings_2d[i, 1]), 
                        xytext=(5, 5), textcoords='offset points', fontsize=8)
        plt.title(f'{title} (PCA)\nExplained Variance: {explained_variance:.2%}')
        plt.xlabel('First Principal Component')
        plt.ylabel('Second Principal Component')
        
        # t-SNE plot (if we have enough samples)
        plt.subplot(1, 2, 2)
        if len(labels) > 3 and embeddings.shape[1] > 2:
            try:
                tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(labels)-1))
                embeddings_tsne = tsne.fit_transform(embeddings)
                scatter = plt.scatter(embeddings_tsne[:, 0], embeddings_tsne[:, 1], 
                                    c=range(len(labels)), cmap='tab10', s=100)
                for i, label in enumerate(labels):
                    plt.annotate(label, (embeddings_tsne[i, 0], embeddings_tsne[i, 1]), 
                                xytext=(5, 5), textcoords='offset points', fontsize=8)
                plt.title(f'{title} (t-SNE)')
                plt.xlabel('t-SNE 1')
                plt.ylabel('t-SNE 2')
            except Exception as e:
                plt.text(0.5, 0.5, f't-SNE not available\n{str(e)}', 
                        ha='center', va='center', transform=plt.gca().transAxes)
                plt.title('t-SNE (Not Available)')
        else:
            plt.text(0.5, 0.5, 'Not enough samples\nfor t-SNE', 
                    ha='center', va='center', transform=plt.gca().transAxes)
            plt.title('t-SNE (Not Available)')
        
        plt.tight_layout()
        plt.show()
        
    except ImportError:
        print("scikit-learn not available for dimensionality reduction")
    except Exception as e:
        print(f"Error in visualization: {e}")


def performance_summary(trained_models_dict):
    """
    Generate a comprehensive performance summary for multiple trained models.
    
    Args:
        trained_models_dict: Dictionary with model names as keys and results as values
    """
    print("=== Performance Summary ===")
    
    for name, results in trained_models_dict.items():
        model, avg_loss, test_loss, embeddings = results
        print(f"\n{name}:")
        print(f"  Model Type: {model._get_name() if hasattr(model, '_get_name') else type(model).__name__}")
        print(f"  Average Training Loss: {avg_loss:.4f}")
        print(f"  Test Loss: {test_loss:.4f}")
        print(f"  Semantic Embeddings: {'Yes' if embeddings else 'No'}")
        if embeddings:
            print(f"  Number of Label Embeddings: {len(embeddings)}")
    
    # Generate comparison visualizations
    comparison_df = compare_training_methods(trained_models_dict)
    
    # Visualize embeddings for semantic methods
    for name, results in trained_models_dict.items():
        _, _, _, embeddings = results
        if embeddings:
            visualize_semantic_embeddings(embeddings, f"Semantic Embeddings - {name}")

print("Model comparison and visualization functions defined!")

## Quick Start Guide

To use this notebook effectively:

1. **Setup**: Run the import and setup cells to initialize the environment
2. **Configure Dataset**: Choose or create dataset configuration functions for your data
3. **Select Model**: Choose a backbone model (ResNet, VGG, MobileNet, DenseNet)
4. **Choose Training Process**:
   - `'triplet'`: For learning feature embeddings using triplet loss
   - `'classification'`: For standard supervised classification
   - `'semantic_classification'`: For classification with semantic label embeddings and k-nearest neighbor validation
5. **Execute Training**: Run the training function with your chosen parameters
6. **Analyze Results**: Use the analysis and comparison functions to visualize and compare results

### Example Usage:
```python
# 1. Create dataset
dataset = create_gi4e_raw_eyes_classification_dataset()

# 2. Build a fresh model for each run so trained weights do not carry over
#    from one method to the next (this assumes training updates the model in place)
def make_model():
    return FeatureExtractor(torchvision.models.resnet50(weights=None))

# 3. Train with different approaches
results = {}

# Standard classification
results['Standard'] = train(dataset, make_model(), train_process='classification', k_fold=5, batch_size=32)

# Semantic classification with unscaled label embeddings
results['Semantic_Linear'] = train(dataset, make_model(), train_process='semantic_classification', 
                                  semantic_embedding_fn=lambda x: x, k_fold=5, batch_size=32)

# Semantic classification with label embeddings scaled by 4
results['Semantic_Quadratic'] = train(dataset, make_model(), train_process='semantic_classification',
                                     semantic_embedding_fn=lambda x: 4*x, k_fold=5, batch_size=32)

# 4. Compare and analyze results
comparison_df = compare_training_methods(results)
performance_summary(results)
```
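
Because each entry of `results` is a `(model, avg_loss, test_loss, embeddings)` tuple, the methods can also be ranked without plotting. The sketch below sorts method names by test loss. The helper `rank_by_test_loss`, the method names, and the numbers are placeholders for illustration, not measured outcomes.

```python
def rank_by_test_loss(results):
    """Sort method names by test loss (tuple position 2), lowest first."""
    return sorted(results, key=lambda name: results[name][2])

# Placeholder tuples in the (model, avg_loss, test_loss, embeddings) layout.
demo_results = {
    "method_a": (None, 0.42, 0.51, None),
    "method_b": (None, 0.40, 0.47, {"label": [0.0]}),
    "method_c": (None, 0.45, 0.60, None),
}
ranking = rank_by_test_loss(demo_results)
print(ranking)
```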

### Advanced Features:

**Active Learning:**
```python
# Demonstrate active learning sample selection
selected_samples, al_results = active_learning_demo(
    trained_model, train_dataset, val_dataset, label_embeddings,
    num_iterations=3, samples_per_iteration=5
)
```
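
The second return value is a list with one dictionary per iteration, using the keys `'iteration'`, `'selected_indices'`, and `'semantic_adjustments'`. A short summary helper, sketched here under that assumption, shows how to read it. The helper name and the demo data are hypothetical.

```python
def summarize_al_results(iteration_results):
    """Count selected samples and flag iterations whose neighbor search failed."""
    summary = []
    for entry in iteration_results:
        summary.append({
            "iteration": entry["iteration"],
            "n_selected": len(entry["selected_indices"]),
            "neighbors_found": entry["semantic_adjustments"] is not None,
        })
    return summary

# Hypothetical data in the shape produced by active_learning_demo.
demo = [
    {"iteration": 1, "selected_indices": [4, 9], "semantic_adjustments": [[0.1], [0.2]]},
    {"iteration": 2, "selected_indices": [1, 7], "semantic_adjustments": None},
]
summary = summarize_al_results(demo)
for row in summary:
    print(row)
```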

**Semantic Embeddings Visualization:**
```python
# Visualize label embeddings in 2D space
visualize_semantic_embeddings(label_embeddings, "My Label Embeddings")
```

## Key Improvements in This Updated Version

### 1. **Complete Semantic Classification Implementation**
- Full `semantic_classification_train_process` function with k-fold cross-validation
- Advanced `semantic_validate_model` with k-nearest neighbor centroid calculation
- Proper integration of BERT/Nomic embeddings for label semantic representation

### 2. **Active Learning Capabilities**
- `get_n_most_informative_samples_indices` for uncertainty-based sample selection
- `find_k_nearest_semantic_embeddings` for intelligent neighbor finding
- Complete active learning demonstration workflow

### 3. **Enhanced Analysis and Visualization**
- Comprehensive model comparison functions
- Label embeddings visualization with PCA and t-SNE
- Performance summary with multiple metrics
- Results tracking and CSV export functionality
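
As an illustration of results tracking, the sketch below appends one row per training run to a CSV file, using the column names that `analyze_training_results` reads. The helper `append_result` is hypothetical and may differ from how the notebook writes its results. The row values are placeholders, and the file goes to a temporary directory so no real results are touched.

```python
import csv
import tempfile
from pathlib import Path

# Column names read by analyze_training_results.
FIELDS = ["dataset", "model", "accuracy", "precision", "recall", "f1", "avg_test_loss"]

def append_result(csv_path, row):
    """Append one result row, writing the header first if the file is new or empty."""
    path = Path(csv_path)
    write_header = not path.exists() or path.stat().st_size == 0
    with path.open("a", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=FIELDS)
        if write_header:
            writer.writeheader()
        writer.writerow(row)

# Placeholder values for illustration only; they are not measured results.
demo_row = {"dataset": "demo", "model": "resnet", "accuracy": 0.0, "precision": 0.0,
            "recall": 0.0, "f1": 0.0, "avg_test_loss": 0.0}

demo_path = Path(tempfile.mkdtemp()) / "triplet_training_results.csv"
append_result(demo_path, demo_row)
append_result(demo_path, demo_row)

lines = demo_path.read_text().splitlines()
print(lines[0])
print(len(lines))
```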

### 4. **Robust Training Pipeline**
- Support for multiple semantic embedding scaling functions
- Proper error handling and validation
- Model saving/loading with organized directory structure
- Cross-validation with detailed metrics (precision, recall, F1, accuracy)
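
To show what the four reported metrics measure, here is a plain-Python sketch of accuracy together with macro-averaged precision, recall, and F1. It is intended to mirror the usual macro-averaging definition. The notebook itself imports the scikit-learn metric functions, and the averaging mode it passes to them is not shown in this section, so treat the choice of macro averaging as an assumption. The labels are made up for illustration.

```python
def macro_scores(y_true, y_pred):
    """Return accuracy plus macro-averaged precision, recall, and F1."""
    pairs = list(zip(y_true, y_pred))
    labels = sorted(set(y_true) | set(y_pred))
    per_class = []
    for label in labels:
        tp = sum(1 for t, p in pairs if t == label and p == label)
        fp = sum(1 for t, p in pairs if t != label and p == label)
        fn = sum(1 for t, p in pairs if t == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        per_class.append((precision, recall, f1))
    n = len(labels)
    return {
        "accuracy": sum(1 for t, p in pairs if t == p) / len(pairs),
        "precision": sum(c[0] for c in per_class) / n,
        "recall": sum(c[1] for c in per_class) / n,
        "f1": sum(c[2] for c in per_class) / n,
    }

# Hypothetical labels for six validation samples over three classes.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
scores = macro_scores(y_true, y_pred)
print({name: round(value, 4) for name, value in scores.items()})
```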

### 5. **Advanced Validation Techniques**
- Semantic validation using k-nearest neighbors
- Centroid-based embedding adjustments
- Multi-metric evaluation framework
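
The last step of semantic validation, mapping a predicted embedding back to a label, can be sketched as a nearest-embedding lookup. The example below assumes cosine similarity and two made-up labels. It does not reproduce `semantic_validate_model`, which additionally applies the k-nearest neighbor centroid calculation.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two vectors; 0.0 if either has zero length."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def predict_label(output_embedding, label_embeddings):
    """Return the label whose embedding is most similar to the model output."""
    return max(
        label_embeddings,
        key=lambda label: cosine_similarity(output_embedding, label_embeddings[label]),
    )

# Hypothetical label embeddings and model outputs.
label_embeddings = {"open": [1.0, 0.0], "closed": [0.0, 1.0]}
first = predict_label([0.9, 0.2], label_embeddings)
second = predict_label([0.1, 0.8], label_embeddings)
print(first, second)
```

Because cosine similarity ignores vector length, multiplying every label embedding by a positive constant would not change these predictions; whether a scaling function affects training depends on the loss used.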

## Conclusion

This notebook provides a framework for computer vision research and experimentation. The implementation supports:

- **Multiple datasets**: gi4e, CelebA, YouTube Faces, FER2013, etc.
- **Various architectures**: ResNet, VGG, MobileNet, DenseNet
- **Training strategies**: Standard classification, triplet learning, semantic embeddings with various scaling functions
- **Advanced techniques**: Active learning, reinforcement learning, k-nearest neighbor validation
- **Comprehensive evaluation**: Cross-validation, multiple metrics, visualization, comparison tools
- **Research capabilities**: Semantic embedding analysis, uncertainty estimation, intelligent sample selection

The modular design makes the notebook suitable for:
- **Research**: Comparing different computer vision training methodologies
- **Education**: Understanding the impact of semantic embeddings and active learning
- **Development**: Prototyping computer vision pipelines before hardening them for production
- **Experimentation**: Testing novel approaches and scaling functions

Key advantages of this approach:
1. **Semantic Understanding**: Uses BERT/Nomic text embeddings to capture semantic relationships between labels
2. **Intelligent Training**: Active learning can reduce annotation requirements by prioritizing the most informative samples
3. **Robust Validation**: K-nearest neighbor validation adds a semantic check alongside the standard metrics
4. **Comprehensive Analysis**: Multiple visualization and comparison tools for inspecting results
5. **Scalable Framework**: Easy to extend with new datasets, models, and training strategies