# Complete Training Pipeline for Visual Emotion Recognition

This notebook consolidates the training code from the src/training directory into a single, self-contained pipeline for visual emotion recognition.

## Components Included:
1. **Training Loop** - Complete training pipeline with validation
2. **Enhanced Training** - Advanced training with gradient accumulation, clipping, and scheduling
3. **Loss Functions** - Various loss functions including focal loss and label smoothing
4. **Optimizers** - Different optimizers and configurations
5. **Learning Rate Schedulers** - Various LR scheduling strategies
6. **Evaluation Functions** - Model evaluation and metrics
7. **Checkpointing** - Model saving and loading
8. **Training Monitoring** - Loss curves and training visualization


In [None]:
import os
import sys
import time
import json
import copy
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Data and computation
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
from sklearn.utils.class_weight import compute_class_weight

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, ReduceLROnPlateau, OneCycleLR

# Set device and random seeds
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Set random seeds for reproducibility
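def set_seed(seed=42):
    """Set random seeds for reproducibility."""
    import random  # local import keeps this function self-contained
    random.seed(seed)  # Python's own RNG, used by some data augmentations
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False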

set_seed(42)
print("Random seeds set for reproducibility")

## 1. Loss Functions

In [None]:
class LabelSmoothingCrossEntropy(nn.Module):
    """
    Label Smoothing Cross Entropy Loss.
    Helps prevent overfitting by softening the labels.
    """
    
    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.smoothing = smoothing
    
    def forward(self, pred, target):
        confidence = 1. - self.smoothing
        logprobs = F.log_softmax(pred, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()


class FocalLoss(nn.Module):
    """
    Focal Loss for addressing class imbalance.
    Focuses learning on hard examples.
    """
    
    def __init__(self, alpha=1, gamma=2, weight=None):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
    
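    def forward(self, pred, target):
        # Compute pt from the unweighted loss: with class weights folded in,
        # exp(-ce_loss) would no longer equal the true-class probability
        ce_loss = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.weight is not None:
            focal_loss = focal_loss * self.weight.to(pred.device)[target]
        return focal_loss.mean()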


class WeightedFocalLoss(nn.Module):
    """
    Weighted Focal Loss combining class weights with focal loss.
    """
    
    def __init__(self, alpha=1, gamma=2, weight=None):
        super(WeightedFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
    
    def forward(self, pred, target):
        ce_loss = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
        
        if self.weight is not None:
            # Apply class weights (moved to the targets' device first)
            weight_t = self.weight.to(target.device).gather(0, target)
            focal_loss = weight_t * focal_loss
        
        return focal_loss.mean()


def get_loss_function(loss_type='CrossEntropy', class_weights=None, **kwargs):
    """
    Factory function to create loss functions.
    
    Args:
        loss_type (str): Type of loss function
        class_weights (torch.Tensor): Class weights for handling imbalance
        **kwargs: Additional arguments for specific loss functions
        
    Returns:
        Loss function
    """
    if loss_type.lower() == 'crossentropy':
        return nn.CrossEntropyLoss(weight=class_weights)
    elif loss_type.lower() == 'labelsmoothing':
        # Note: class_weights is not applied by LabelSmoothingCrossEntropy
        smoothing = kwargs.get('smoothing', 0.1)
        return LabelSmoothingCrossEntropy(smoothing=smoothing)
    elif loss_type.lower() == 'focal':
        alpha = kwargs.get('alpha', 1)
        gamma = kwargs.get('gamma', 2)
        return FocalLoss(alpha=alpha, gamma=gamma, weight=class_weights)
    elif loss_type.lower() == 'weightedfocal':
        alpha = kwargs.get('alpha', 1)
        gamma = kwargs.get('gamma', 2)
        return WeightedFocalLoss(alpha=alpha, gamma=gamma, weight=class_weights)
    else:
        raise ValueError(f"Unsupported loss type: {loss_type}")


# Test loss functions
print("Testing loss functions...")
pred = torch.randn(10, 7)  # 10 samples, 7 classes
target = torch.randint(0, 7, (10,))  # Random targets
class_weights = torch.rand(7)  # Random weights

losses = {
    'CrossEntropy': get_loss_function('CrossEntropy'),
    'Weighted CE': get_loss_function('CrossEntropy', class_weights),
    'Label Smoothing': get_loss_function('LabelSmoothing', smoothing=0.1),
    'Focal': get_loss_function('Focal', alpha=1, gamma=2),
    'Weighted Focal': get_loss_function('WeightedFocal', class_weights, alpha=1, gamma=2)
}

for name, loss_fn in losses.items():
    loss_value = loss_fn(pred, target)
    print(f"{name}: {loss_value.item():.4f}")

print("Loss functions created successfully!")
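
To make the effect of `gamma` concrete, the sketch below recomputes the focal term by hand for one well-classified and one poorly classified sample. It uses plain Python rather than PyTorch so the arithmetic is visible. The helper `focal_term` and the two probabilities are illustrative assumptions, not part of the pipeline above.

```python
import math

def focal_term(p_t, alpha=1.0, gamma=2.0):
    """Focal loss for a single sample whose true-class probability is p_t."""
    ce = -math.log(p_t)                       # plain cross-entropy
    return alpha * (1.0 - p_t) ** gamma * ce  # scaled by (1 - p_t)^gamma

# Hypothetical true-class probabilities for an easy and a hard sample
samples = [("easy", 0.9), ("hard", 0.1)]

for name, p_t in samples:
    ce = -math.log(p_t)
    fl = focal_term(p_t)
    print(f"{name}: p_t={p_t:.1f}  CE={ce:.4f}  focal={fl:.4f}  ratio={fl / ce:.2f}")
```

With `gamma=2`, the easy sample keeps about 1% of its cross-entropy and the hard sample about 81%, which is how the loss shifts attention toward hard examples.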

## 2. Optimizers and Schedulers

In [None]:
def get_optimizer(model, optimizer_type='Adam', lr=1e-3, weight_decay=1e-4, **kwargs):
    """
    Factory function to create optimizers.
    
    Args:
        model: PyTorch model
        optimizer_type (str): Type of optimizer
        lr (float): Learning rate
        weight_decay (float): Weight decay
        **kwargs: Additional optimizer arguments
        
    Returns:
        PyTorch optimizer
    """
    if optimizer_type.lower() == 'adam':
        return optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay,
                         betas=kwargs.get('betas', (0.9, 0.999)))
    elif optimizer_type.lower() == 'adamw':
        return optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay,
                          betas=kwargs.get('betas', (0.9, 0.999)))
    elif optimizer_type.lower() == 'sgd':
        return optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay,
                        momentum=kwargs.get('momentum', 0.9))
    elif optimizer_type.lower() == 'rmsprop':
        return optim.RMSprop(model.parameters(), lr=lr, weight_decay=weight_decay,
                           momentum=kwargs.get('momentum', 0.9))
    else:
        raise ValueError(f"Unsupported optimizer type: {optimizer_type}")


def get_differential_optimizer(model, backbone_lr=1e-4, classifier_lr=1e-3, 
                             optimizer_type='Adam', weight_decay=1e-4):
    """
    Create optimizer with different learning rates for backbone and classifier.
    
    Args:
        model: PyTorch model with 'backbone'/'features' and 'classifier' attributes
        backbone_lr (float): Learning rate for backbone/features
        classifier_lr (float): Learning rate for classifier
        optimizer_type (str): Type of optimizer
        weight_decay (float): Weight decay
        
    Returns:
        PyTorch optimizer
    """
    # Get backbone parameters
    if hasattr(model, 'backbone'):
        backbone_params = model.backbone.parameters()
    elif hasattr(model, 'features'):
        backbone_params = model.features.parameters()
    else:
        raise ValueError("Model must have 'backbone' or 'features' attribute")
    
    # Get classifier parameters
    classifier_params = model.classifier.parameters()
    
    # Create parameter groups
    param_groups = [
        {'params': backbone_params, 'lr': backbone_lr, 'name': 'backbone'},
        {'params': classifier_params, 'lr': classifier_lr, 'name': 'classifier'}
    ]
    
    if optimizer_type.lower() == 'adam':
        return optim.Adam(param_groups, weight_decay=weight_decay)
    elif optimizer_type.lower() == 'adamw':
        return optim.AdamW(param_groups, weight_decay=weight_decay)
    elif optimizer_type.lower() == 'sgd':
        return optim.SGD(param_groups, weight_decay=weight_decay, momentum=0.9)
    else:
        raise ValueError(f"Unsupported optimizer type: {optimizer_type}")


def get_scheduler(optimizer, scheduler_type='StepLR', **kwargs):
    """
    Factory function to create learning rate schedulers.
    
    Args:
        optimizer: PyTorch optimizer
        scheduler_type (str): Type of scheduler
        **kwargs: Scheduler-specific arguments
        
    Returns:
        PyTorch scheduler
    """
    if scheduler_type.lower() == 'steplr':
        step_size = kwargs.get('step_size', 7)
        gamma = kwargs.get('gamma', 0.1)
        return StepLR(optimizer, step_size=step_size, gamma=gamma)
    elif scheduler_type.lower() == 'cosineannealinglr':
        T_max = kwargs.get('T_max', 10)
        eta_min = kwargs.get('eta_min', 0)
        return CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)
    elif scheduler_type.lower() == 'reducelronplateau':
        mode = kwargs.get('mode', 'min')
        factor = kwargs.get('factor', 0.5)
        patience = kwargs.get('patience', 5)
        return ReduceLROnPlateau(optimizer, mode=mode, factor=factor, patience=patience)
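    elif scheduler_type.lower() == 'onecyclelr':
        # OneCycleLR is built around exactly total_steps calls to step(); when the
        # scheduler is stepped once per epoch, pass total_steps=<number of epochs>
        max_lr = kwargs.get('max_lr', 1e-3)
        total_steps = kwargs.get('total_steps', 1000)
        return OneCycleLR(optimizer, max_lr=max_lr, total_steps=total_steps)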
    else:
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}")


print("Optimizer and scheduler functions created successfully!")
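
As a rough guide to what two of these schedulers produce, the sketch below evaluates their closed-form learning rates in plain Python. It assumes the scheduler is stepped once per epoch starting from epoch 0. The helper names `step_lr` and `cosine_lr` are illustrative and not part of PyTorch.

```python
import math

def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    """Closed form of StepLR: multiply by gamma every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)

def cosine_lr(base_lr, epoch, t_max=10, eta_min=0.0):
    """Closed form of cosine annealing from base_lr down to eta_min over t_max epochs."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

base_lr = 1e-3
for epoch in (0, 5, 7, 10):
    print(f"epoch {epoch:2d}: step={step_lr(base_lr, epoch):.2e}  "
          f"cosine={cosine_lr(base_lr, epoch):.2e}")
```

StepLR holds the rate flat and then drops it by a factor of `gamma`, while cosine annealing decays smoothly and reaches `eta_min` at `t_max`.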

## 3. Training Functions

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device, epoch_num=1, print_freq=100):
    """
    Train model for one epoch.
    
    Args:
        model: PyTorch model
        train_loader: Training data loader
        criterion: Loss function
        optimizer: Optimizer
        device: Device (cpu/cuda)
        epoch_num (int): Current epoch number
        print_freq (int): Print frequency
        
    Returns:
        tuple: (average_loss, accuracy)
    """
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        output = model(data)
        loss = criterion(output, target)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        pred = output.argmax(dim=1)
        correct_predictions += pred.eq(target.view_as(pred)).sum().item()
        total_predictions += target.size(0)
        
        # Print progress
        if batch_idx % print_freq == 0:
            current_acc = 100. * correct_predictions / total_predictions
            current_loss = running_loss / (batch_idx + 1)
            print(f'Epoch {epoch_num}, Batch [{batch_idx}/{len(train_loader)}]: '
                  f'Loss: {current_loss:.4f}, Acc: {current_acc:.2f}%')
    
    avg_loss = running_loss / len(train_loader)
    accuracy = 100. * correct_predictions / total_predictions
    
    return avg_loss, accuracy


def validate_epoch(model, val_loader, criterion, device):
    """
    Validate model for one epoch.
    
    Args:
        model: PyTorch model
        val_loader: Validation data loader
        criterion: Loss function
        device: Device (cpu/cuda)
        
    Returns:
        tuple: (average_loss, accuracy, predictions, targets)
    """
    model.eval()
    val_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    all_predictions = []
    all_targets = []
    
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            
            # Forward pass
            output = model(data)
            val_loss += criterion(output, target).item()
            
            # Predictions
            pred = output.argmax(dim=1)
            correct_predictions += pred.eq(target.view_as(pred)).sum().item()
            total_predictions += target.size(0)
            
            # Store for detailed metrics
            all_predictions.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    
    avg_loss = val_loss / len(val_loader)
    accuracy = 100. * correct_predictions / total_predictions
    
    return avg_loss, accuracy, all_predictions, all_targets


def train_epoch_enhanced(model, train_loader, criterion, optimizer, device, 
                        gradient_accumulation_steps=1, max_grad_norm=1.0, 
                        epoch_num=1, print_freq=100):
    """
    Enhanced training epoch with gradient accumulation and clipping.
    
    Args:
        model: PyTorch model
        train_loader: Training data loader
        criterion: Loss function
        optimizer: Optimizer
        device: Device (cpu/cuda)
        gradient_accumulation_steps (int): Steps to accumulate gradients
        max_grad_norm (float): Maximum gradient norm for clipping
        epoch_num (int): Current epoch number
        print_freq (int): Print frequency
        
    Returns:
        tuple: (average_loss, accuracy)
    """
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        # Forward pass
        output = model(data)
        loss = criterion(output, target)
        
        # Scale loss for gradient accumulation
        loss = loss / gradient_accumulation_steps
        
        # Backward pass
        loss.backward()
        
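        # Update weights every gradient_accumulation_steps, and on the last batch
        # so a trailing partial group is not carried over into the next epoch
        is_last_batch = (batch_idx + 1) == len(train_loader)
        if (batch_idx + 1) % gradient_accumulation_steps == 0 or is_last_batch:
            # Gradient clipping
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            
            optimizer.step()
            optimizer.zero_grad()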
        
        # Statistics
        running_loss += loss.item() * gradient_accumulation_steps
        pred = output.argmax(dim=1)
        correct_predictions += pred.eq(target.view_as(pred)).sum().item()
        total_predictions += target.size(0)
        
        # Print progress
        if batch_idx % print_freq == 0:
            current_acc = 100. * correct_predictions / total_predictions
            current_loss = running_loss / (batch_idx + 1)
            print(f'Epoch {epoch_num}, Batch [{batch_idx}/{len(train_loader)}]: '
                  f'Loss: {current_loss:.4f}, Acc: {current_acc:.2f}%')
    
    avg_loss = running_loss / len(train_loader)
    accuracy = 100. * correct_predictions / total_predictions
    
    return avg_loss, accuracy


print("Training functions created successfully!")
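
The reason `train_epoch_enhanced` divides each loss by `gradient_accumulation_steps` can be checked with a small hand computation. The sketch below uses a hypothetical one-parameter model `y = w * x` with a mean-squared-error loss, in plain Python, and compares the gradient from one batch of four samples with the gradient accumulated over two micro-batches of two. It assumes equally sized micro-batches.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# Gradient from a single batch containing all 4 samples
full_grad = mse_grad(w, xs, ys)

# Same data as 2 micro-batches of 2 samples each
accum_steps = 2
micro_size = 2
accumulated = 0.0
for start in range(0, len(xs), micro_size):
    micro_grad = mse_grad(w, xs[start:start + micro_size], ys[start:start + micro_size])
    accumulated += micro_grad / accum_steps  # mirrors loss / gradient_accumulation_steps

print(full_grad, accumulated)
```

Both values agree, so accumulation behaves like a larger batch for the gradient itself. Layers that depend on batch statistics, such as batch normalization, still see the smaller micro-batch.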

## 4. Complete Training Pipeline

In [None]:
def train_model(model, train_loader, val_loader, num_epochs=25, 
               loss_type='CrossEntropy', optimizer_type='Adam', scheduler_type='StepLR',
               lr=1e-3, weight_decay=1e-4, class_weights=None,
               save_path='best_model.pth', patience=10, device='cpu',
               enhanced_training=False, **kwargs):
    """
    Complete training pipeline for emotion recognition models.
    
    Args:
        model: PyTorch model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        num_epochs (int): Number of epochs to train
        loss_type (str): Type of loss function
        optimizer_type (str): Type of optimizer
        scheduler_type (str): Type of scheduler
        lr (float): Learning rate
        weight_decay (float): Weight decay
        class_weights (torch.Tensor): Class weights for loss function
        save_path (str): Path to save best model
        patience (int): Early stopping patience
        device (str): Device to use
        enhanced_training (bool): Use enhanced training features
        **kwargs: Additional arguments
        
    Returns:
        dict: Training history
    """
    print(f"Starting training for {num_epochs} epochs...")
    print(f"Model: {model.__class__.__name__}")
    print(f"Loss: {loss_type}")
    print(f"Optimizer: {optimizer_type} (lr={lr})")
    print(f"Scheduler: {scheduler_type}")
    print(f"Device: {device}")
    print(f"Enhanced training: {enhanced_training}")
    print("-" * 50)
    
    # Move model to device
    model = model.to(device)
    
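    # Create loss function (class weights must be on the same device as the model outputs)
    if class_weights is not None:
        class_weights = class_weights.to(device)
    criterion = get_loss_function(loss_type, class_weights, **kwargs)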
    
    # Create optimizer
    if 'backbone_lr' in kwargs and 'classifier_lr' in kwargs:
        optimizer = get_differential_optimizer(
            model, kwargs['backbone_lr'], kwargs['classifier_lr'], 
            optimizer_type, weight_decay
        )
        print(f"Using differential learning rates: backbone={kwargs['backbone_lr']}, classifier={kwargs['classifier_lr']}")
    else:
        optimizer = get_optimizer(model, optimizer_type, lr, weight_decay)
    
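    # Create scheduler. Note that kwargs are shared with the loss factory, so a
    # focal-loss `gamma` is also read by StepLR as its decay factor
    scheduler = get_scheduler(optimizer, scheduler_type, **kwargs)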
    
    # Training history
    history = {
        'train_losses': [],
        'train_accuracies': [],
        'val_losses': [],
        'val_accuracies': [],
        'learning_rates': [],
        'best_val_acc': 0.0,
        'best_epoch': 0
    }
    
    best_val_acc = 0.0
    patience_counter = 0
    best_model_state = None
    
    # Training loop
    for epoch in range(1, num_epochs + 1):
        start_time = time.time()
        
        # Training
        if enhanced_training:
            train_loss, train_acc = train_epoch_enhanced(
                model, train_loader, criterion, optimizer, device, 
                gradient_accumulation_steps=kwargs.get('gradient_accumulation_steps', 1),
                max_grad_norm=kwargs.get('max_grad_norm', 1.0),
                epoch_num=epoch,
                print_freq=kwargs.get('print_freq', 100)
            )
        else:
            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer, device, 
                epoch_num=epoch,
                print_freq=kwargs.get('print_freq', 100)
            )
        
        # Validation
        val_loss, val_acc, val_pred, val_target = validate_epoch(
            model, val_loader, criterion, device
        )
        
        # Update learning rate
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(val_loss)
        else:
            scheduler.step()
        
        # Get current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        
        # Update history
        history['train_losses'].append(train_loss)
        history['train_accuracies'].append(train_acc)
        history['val_losses'].append(val_loss)
        history['val_accuracies'].append(val_acc)
        history['learning_rates'].append(current_lr)
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            history['best_val_acc'] = best_val_acc
            history['best_epoch'] = epoch
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
        
        # Print epoch summary
        epoch_time = time.time() - start_time
        print(f"\nEpoch {epoch}/{num_epochs} Summary:")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"  LR: {current_lr:.2e}, Time: {epoch_time:.2f}s")
        print(f"  Best Val Acc: {best_val_acc:.2f}% (Epoch {history['best_epoch']})")
        print("-" * 50)
        
        # Early stopping
        if patience_counter >= patience:
            print(f"Early stopping triggered after {patience} epochs without improvement")
            break
    
    # Save best model
    if best_model_state is not None:
        torch.save({
            'epoch': history['best_epoch'],
            'model_state_dict': best_model_state,
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_acc': best_val_acc,
            'history': history
        }, save_path)
        print(f"Best model saved to {save_path}")
        
        # Load best model weights
        model.load_state_dict(best_model_state)
    
    print(f"\nTraining completed!")
    print(f"Best validation accuracy: {best_val_acc:.2f}% at epoch {history['best_epoch']}")
    
    return history


print("Complete training pipeline ready!")
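
The early-stopping rule inside `train_model` can be isolated as a small function, which makes it easier to see when training halts. The sketch below replays that rule over a hypothetical list of validation accuracies. The helper `early_stopping_epoch` is illustrative and not part of the pipeline.

```python
def early_stopping_epoch(val_accuracies, patience):
    """Return (stop_epoch, best_epoch) using the same rule as train_model.

    Epochs are 1-indexed. stop_epoch is None if training runs to the end.
    """
    best_acc, best_epoch, counter = 0.0, 0, 0
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best_acc:  # only a strict improvement resets the counter
            best_acc, best_epoch, counter = acc, epoch, 0
        else:
            counter += 1
        if counter >= patience:
            return epoch, best_epoch
    return None, best_epoch

# Hypothetical validation accuracies (%), one per epoch
accs = [40.0, 55.0, 54.0, 55.0, 53.0, 60.0]
print(early_stopping_epoch(accs, patience=3))
print(early_stopping_epoch(accs, patience=4))
```

With `patience=3` training stops at epoch 5 and never reaches the better epoch 6. With `patience=4` it runs to the end. A tie with the best accuracy (epoch 4 here) counts as no improvement.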

## 5. Evaluation Functions

In [None]:
def evaluate_model(model, test_loader, device, label_map=None, print_results=True):
    """
    Comprehensive model evaluation.
    
    Args:
        model: Trained PyTorch model
        test_loader: Test data loader
        device: Device to use
        label_map (dict): Mapping from class names to indices
        print_results (bool): Whether to print results
        
    Returns:
        dict: Evaluation results
    """
    model.eval()
    all_predictions = []
    all_targets = []
    all_probabilities = []
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            
            output = model(data)
            probabilities = F.softmax(output, dim=1)
            predictions = output.argmax(dim=1)
            
            all_predictions.extend(predictions.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
    
    # Calculate metrics
    accuracy = accuracy_score(all_targets, all_predictions)
    f1_macro = f1_score(all_targets, all_predictions, average='macro')
    f1_weighted = f1_score(all_targets, all_predictions, average='weighted')
    
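    # Create class names and the matching label indices
    if label_map is not None:
        sorted_items = sorted(label_map.items(), key=lambda x: x[1])
        class_names = [k for k, v in sorted_items]
        labels = [v for k, v in sorted_items]
    else:
        # Include predicted labels too, so names and labels stay aligned
        labels = sorted(set(all_targets) | set(all_predictions))
        class_names = [f'Class_{i}' for i in labels]
    
    # Classification report (explicit labels avoid a length mismatch with
    # target_names when a class is missing from the test set)
    report = classification_report(all_targets, all_predictions, 
                                 labels=labels, target_names=class_names, 
                                 output_dict=True, zero_division=0)
    
    # Confusion matrix
    cm = confusion_matrix(all_targets, all_predictions, labels=labels)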
    
    results = {
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'classification_report': report,
        'confusion_matrix': cm,
        'predictions': all_predictions,
        'targets': all_targets,
        'probabilities': np.array(all_probabilities),
        'class_names': class_names
    }
    
    if print_results:
        print(f"\nModel Evaluation Results:")
        print(f"{'='*50}")
        print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
        print(f"F1-Score (Macro): {f1_macro:.4f}")
        print(f"F1-Score (Weighted): {f1_weighted:.4f}")
        print(f"\nPer-Class Results:")
        for class_name in class_names:
            if class_name in report:
                precision = report[class_name]['precision']
                recall = report[class_name]['recall']
                f1 = report[class_name]['f1-score']
                support = report[class_name]['support']
                print(f"  {class_name:10s}: P={precision:.3f}, R={recall:.3f}, F1={f1:.3f}, N={support}")
    
    return results


def plot_confusion_matrix(confusion_matrix, class_names, normalize=True, figsize=(10, 8)):
    """
    Plot confusion matrix.
    
    Args:
        confusion_matrix (np.array): Confusion matrix
        class_names (list): List of class names
        normalize (bool): Whether to normalize the matrix
        figsize (tuple): Figure size
    """
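    if normalize:
        # Guard against division by zero for classes with no true samples
        row_sums = np.maximum(confusion_matrix.sum(axis=1, keepdims=True), 1)
        cm = confusion_matrix.astype('float') / row_sums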
        title = 'Normalized Confusion Matrix'
        fmt = '.2f'
    else:
        cm = confusion_matrix
        title = 'Confusion Matrix'
        fmt = 'd'
    
    plt.figure(figsize=figsize)
    sns.heatmap(cm, annot=True, fmt=fmt, cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.title(title)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.show()


def plot_training_history(history, figsize=(15, 5)):
    """
    Plot training history.
    
    Args:
        history (dict): Training history
        figsize (tuple): Figure size
    """
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    
    epochs = range(1, len(history['train_losses']) + 1)
    
    # Loss plot
    axes[0].plot(epochs, history['train_losses'], 'b-', label='Training Loss')
    axes[0].plot(epochs, history['val_losses'], 'r-', label='Validation Loss')
    axes[0].set_title('Loss')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].legend()
    axes[0].grid(True)
    
    # Accuracy plot
    axes[1].plot(epochs, history['train_accuracies'], 'b-', label='Training Accuracy')
    axes[1].plot(epochs, history['val_accuracies'], 'r-', label='Validation Accuracy')
    axes[1].set_title('Accuracy')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].legend()
    axes[1].grid(True)
    
    # Learning rate plot
    axes[2].plot(epochs, history['learning_rates'], 'g-')
    axes[2].set_title('Learning Rate')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Learning Rate')
    axes[2].set_yscale('log')
    axes[2].grid(True)
    
    plt.tight_layout()
    plt.show()
    
    # Print best performance
    print(f"Best validation accuracy: {history['best_val_acc']:.2f}% at epoch {history['best_epoch']}")


print("Evaluation functions created successfully!")
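
Row normalization is what turns raw counts into per-class rates: each row is divided by the number of true samples of that class, so the diagonal becomes per-class recall. The sketch below shows this in plain Python on hypothetical counts. The helper `normalize_rows` is illustrative, and it chooses to leave rows with no samples at zero.

```python
def normalize_rows(cm):
    """Divide each row by its sum; rows with no samples stay all zeros."""
    normalized = []
    for row in cm:
        total = sum(row)
        normalized.append([count / total if total else 0.0 for count in row])
    return normalized

# Hypothetical 3-class counts: rows are true labels, columns are predictions
cm = [
    [8, 2, 0],
    [1, 3, 1],
    [0, 0, 0],  # class absent from the test set
]

for row in normalize_rows(cm):
    print([f"{v:.2f}" for v in row])
```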

## 6. Model Checkpointing

In [None]:
def save_checkpoint(model, optimizer, epoch, val_acc, history, filepath):
    """
    Save model checkpoint.
    
    Args:
        model: PyTorch model
        optimizer: PyTorch optimizer
        epoch (int): Current epoch
        val_acc (float): Validation accuracy
        history (dict): Training history
        filepath (str): Path to save checkpoint
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_acc': val_acc,
        'history': history
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved to {filepath}")


def load_checkpoint(model, optimizer, filepath, device='cpu'):
    """
    Load model checkpoint.
    
    Args:
        model: PyTorch model
        optimizer: PyTorch optimizer
        filepath (str): Path to checkpoint
        device (str): Device to load to
        
    Returns:
        tuple: (epoch, val_acc, history)
    """
    checkpoint = torch.load(filepath, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
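    epoch = checkpoint['epoch']
    # train_model() stores this value under 'best_val_acc'; save_checkpoint() uses 'val_acc'
    val_acc = checkpoint.get('val_acc', checkpoint.get('best_val_acc', 0.0))
    history = checkpoint.get('history', {})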
    
    print(f"Checkpoint loaded from {filepath}")
    print(f"Epoch: {epoch}, Validation Accuracy: {val_acc:.2f}%")
    
    return epoch, val_acc, history


def save_model_for_inference(model, filepath, label_map=None, transforms_config=None, 
                           model_config=None):
    """
    Save model for inference with metadata.
    
    Args:
        model: Trained PyTorch model
        filepath (str): Path to save model
        label_map (dict): Label mapping
        transforms_config (dict): Transform configuration
        model_config (dict): Model configuration
    """
    model_info = {
        'model_state_dict': model.state_dict(),
        'model_config': model_config or {},
        'label_map': label_map or {},
        'transforms_config': transforms_config or {},
        'model_class': model.__class__.__name__
    }
    
    torch.save(model_info, filepath)
    print(f"Model saved for inference: {filepath}")
    print(f"Model class: {model.__class__.__name__}")
    if label_map:
        print(f"Classes: {list(label_map.keys())}")


def load_model_for_inference(model_class, filepath, device='cpu'):
    """
    Load model for inference.
    
    Args:
        model_class: Model class constructor
        filepath (str): Path to model file
        device (str): Device to load to
        
    Returns:
        tuple: (model, label_map, transforms_config)
    """
    model_info = torch.load(filepath, map_location=device)
    
    # Get model configuration
    model_config = model_info.get('model_config', {})
    label_map = model_info.get('label_map', {})
    transforms_config = model_info.get('transforms_config', {})
    
    # Create model with configuration
    if model_config:
        model = model_class(**model_config)
    else:
        # Try to create with default parameters
        num_classes = len(label_map) if label_map else 7
        model = model_class(num_classes=num_classes)
    
    # Load state dict
    model.load_state_dict(model_info['model_state_dict'])
    model = model.to(device)
    model.eval()
    
    print(f"Model loaded for inference from {filepath}")
    print(f"Classes: {list(label_map.keys()) if label_map else 'Unknown'}")
    
    return model, label_map, transforms_config


print("Checkpointing functions created successfully!")

## 7. Quick Training Functions

In [None]:
def quick_train_baseline(model, train_loader, val_loader, num_epochs=20, device='cpu'):
    """
    Quick training function for baseline CNN models.
    
    Args:
        model: CNN baseline model
        train_loader: Training data loader
        val_loader: Validation data loader
        num_epochs (int): Number of epochs
        device (str): Device to use
        
    Returns:
        dict: Training history
    """
    return train_model(
        model, train_loader, val_loader,
        num_epochs=num_epochs,
        loss_type='CrossEntropy',
        optimizer_type='Adam',
        scheduler_type='StepLR',
        lr=1e-3,
        weight_decay=1e-4,
        save_path='baseline_model.pth',
        device=device,
        step_size=7,
        gamma=0.1
    )


def quick_train_transfer(model, train_loader, val_loader, num_epochs=20, 
                        class_weights=None, device='cpu'):
    """
    Quick training function for transfer learning models.
    
    Args:
        model: Transfer learning model
        train_loader: Training data loader
        val_loader: Validation data loader
        num_epochs (int): Number of epochs
        class_weights (torch.Tensor): Class weights
        device (str): Device to use
        
    Returns:
        dict: Training history
    """
    return train_model(
        model, train_loader, val_loader,
        num_epochs=num_epochs,
        loss_type='CrossEntropy',
        optimizer_type='Adam',
        scheduler_type='ReduceLROnPlateau',
        lr=1e-3,
        weight_decay=1e-4,
        class_weights=class_weights,
        save_path='transfer_model.pth',
        device=device,
        backbone_lr=1e-4,  # Lower LR for pre-trained features
        classifier_lr=1e-3,  # Higher LR for new classifier
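        mode='min',
        factor=0.5
        # patience is deliberately not passed here: train_model() would bind it to
        # its early-stopping parameter, and the scheduler already defaults to patience=5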
    )


def quick_train_enhanced(model, train_loader, val_loader, num_epochs=30, 
                        class_weights=None, device='cpu'):
    """
    Quick training function for enhanced models with advanced features.
    
    Args:
        model: Enhanced model
        train_loader: Training data loader
        val_loader: Validation data loader
        num_epochs (int): Number of epochs
        class_weights (torch.Tensor): Class weights
        device (str): Device to use
        
    Returns:
        dict: Training history
    """
    return train_model(
        model, train_loader, val_loader,
        num_epochs=num_epochs,
        loss_type='WeightedFocal',
        optimizer_type='AdamW',
        scheduler_type='CosineAnnealingLR',
        lr=1e-3,
        weight_decay=1e-4,
        class_weights=class_weights,
        save_path='enhanced_model.pth',
        device=device,
        enhanced_training=True,
        alpha=1,
        gamma=2,
        T_max=num_epochs,
        gradient_accumulation_steps=2,
        max_grad_norm=1.0
    )


print("Quick training functions created successfully!")

## Summary

This notebook provides a complete training pipeline for visual emotion recognition:

### Core Components:
1. **Loss Functions**: CrossEntropy, Label Smoothing, Focal Loss, Weighted Focal Loss
2. **Optimizers**: Adam, AdamW, SGD, RMSprop, plus differential learning rates for Adam, AdamW, and SGD
3. **Schedulers**: StepLR, CosineAnnealingLR, ReduceLROnPlateau, OneCycleLR
4. **Training Functions**: Basic and enhanced training with gradient accumulation/clipping
5. **Evaluation**: Comprehensive metrics, confusion matrices, classification reports
6. **Visualization**: Training curves, confusion matrices
7. **Checkpointing**: Save/load models with metadata

### Key Features:
- **Flexible Training**: Support for different models, losses, optimizers
- **Advanced Training**: Gradient accumulation, clipping, differential learning rates
- **Early Stopping**: Prevent overfitting with patience-based stopping
- **Class Balancing**: Handle imbalanced datasets with weighted losses
- **Comprehensive Evaluation**: Multiple metrics and visualizations
- **Easy Checkpointing**: Save/load models with full metadata
- **Quick Functions**: Pre-configured training for common scenarios

All functionality is self-contained within this notebook and doesn't require the src folder structure.
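
The class-balancing feature depends on a `class_weights` tensor that the notebook expects the caller to supply. One common way to derive it is the "balanced" heuristic, which is also the formula behind scikit-learn's `compute_class_weight('balanced', ...)`. The sketch below computes it in plain Python on a hypothetical label distribution. The helper `balanced_class_weights` is illustrative, and it assumes every class appears at least once.

```python
from collections import Counter

def balanced_class_weights(labels, num_classes):
    """weight_c = n_samples / (num_classes * count_c), the 'balanced' heuristic."""
    counts = Counter(labels)
    n_samples = len(labels)
    return [n_samples / (num_classes * counts[c]) for c in range(num_classes)]

# Hypothetical label distribution: class 0 is 6x more common than class 2
labels = [0] * 60 + [1] * 30 + [2] * 10
weights = balanced_class_weights(labels, num_classes=3)
print([round(w, 3) for w in weights])
```

The resulting list can be wrapped in `torch.tensor(weights, dtype=torch.float32)` and passed as `class_weights` to `train_model`.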

### Usage Examples:
```python
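# Basic training (pass device explicitly; train_model defaults to 'cpu')
history = train_model(model, train_loader, val_loader, num_epochs=25, device=device)

# Enhanced training with focal loss
history = train_model(
    model, train_loader, val_loader,
    num_epochs=25,
    loss_type='WeightedFocal', 
    scheduler_type='CosineAnnealingLR',  # the default StepLR would also read gamma=2
    enhanced_training=True,
    device=device,
    alpha=1, gamma=2, T_max=25
)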

# Evaluation
results = evaluate_model(model, test_loader, device, label_map)
plot_confusion_matrix(results['confusion_matrix'], results['class_names'])
```