In [None]:
import os  # For file and directory operations
import pickle  # For loading CIFAR-10 data files
import numpy as np  # For numerical operations
import torch  # PyTorch deep learning framework
import torch.nn as nn  # Neural network modules
import torch.nn.functional as F  # Neural network functions
from torch.utils.data import Dataset, DataLoader, random_split  # Dataset handling
import torchvision.transforms as transforms  # Image transformations
import torch.optim as optim  # Optimization algorithms
import pandas as pd  # Data manipulation and analysis
from tqdm import tqdm  # Progress bar
import random  # Random number generation
from torch.optim.swa_utils import AveragedModel  # Stochastic Weight Averaging
import math  # Mathematical functions

In [None]:
# ----- Device configuration (CPU/MPS/GPU) -----
# Preference order: Apple-Silicon MPS first, then CUDA, then plain CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device} for model training and inference")

# Seed every random number generator the notebook relies on so that
# repeated runs draw the same random numbers.
print("Setting random seeds for reproducibility...")
SEED = 42
random.seed(SEED)        # Python's built-in RNG
np.random.seed(SEED)     # NumPy's global RNG
torch.manual_seed(SEED)  # PyTorch RNG
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)  # RNG of the current CUDA device
    # Ask cuDNN for deterministic kernels and turn off its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
print(f"Random seeds set to {SEED}")

def get_mean_std(dataset):
    """
    Estimate scalar normalization statistics for a dataset.

    Every image is flattened (all channels together), its mean and standard
    deviation are computed, and those per-image values are averaged over the
    whole dataset. The returned ``std`` is therefore the mean of per-image
    standard deviations, not the standard deviation of all pixels pooled
    together, and neither value is per-channel.

    Args:
        dataset: PyTorch dataset yielding ``(image_tensor, label)`` pairs
    Returns:
        mean: Average of the per-image means (Python float)
        std: Average of the per-image standard deviations (Python float)
    """
    print("Calculating dataset statistics...")
    stats_loader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    mean_sum, std_sum, seen = 0.0, 0.0, 0

    for batch, _ in stats_loader:
        n_in_batch = batch.size(0)
        # One row per image, all pixels of all channels in that row.
        flat = batch.view(n_in_batch, -1)
        mean_sum += flat.mean(1).sum().item()
        std_sum += flat.std(1).sum().item()
        seen += n_in_batch

    mean = mean_sum / seen
    std = std_sum / seen

    print(f"Dataset statistics - Mean: {mean:.4f}, Std: {std:.4f}")
    return mean, std


Step 1: Load CIFAR-10 Training Data with Optimized Handling

This section handles loading and preprocessing of CIFAR-10 training data:
- Loads data from 5 separate batch files for memory efficiency
- Reshapes data into NCHW format (N=samples, C=channels, H=height, W=width)
- Combines batches into unified training set
- Performs data validation and prints detailed statistics

Key optimizations:
- Direct reshape to NCHW format to avoid extra transpose operations
- Batch-wise loading to manage memory usage
- Efficient numpy array operations for data combination

Expected output:
- 50,000 training images in (N, 3, 32, 32) format
- Labels as 1D array of integers 0-9

In [None]:
def unpickle(file):
    """
    Read one pickled CIFAR-10 batch file from disk.

    Args:
        file: Path to pickle file
    Returns:
        The unpickled object (a dictionary for CIFAR-10 batch files)
    """
    print(f"Loading file: {file}")
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at files obtained from a trusted source.
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

def load_data(data_dir="./data/cifar-10-python/cifar-10-batches-py"):
    """
    Read the five CIFAR-10 training batch files and stack them together.

    Each batch's ``b"data"`` array is reshaped to (N, 3, 32, 32) and the
    ``b"labels"`` lists are concatenated in file order.

    Args:
        data_dir: Directory containing ``data_batch_1`` ... ``data_batch_5``
    Returns:
        all_data: Combined training images, shape (N, 3, 32, 32)
        all_labels: Combined training labels as a 1D NumPy array
    """
    print("Starting CIFAR-10 data loading process...")
    image_chunks = []
    label_accum = []

    # CIFAR-10 ships its training set as data_batch_1 .. data_batch_5.
    for batch_idx in range(1, 6):
        batch = unpickle(os.path.join(data_dir, f"data_batch_{batch_idx}"))
        raw = batch[b"data"]
        labels = batch[b"labels"]

        # Go straight to (N, C, H, W); no transpose is performed.
        images = raw.reshape(-1, 3, 32, 32)

        image_chunks.append(images)
        label_accum.extend(labels)
        print(f"Loaded batch {batch_idx}: data shape {images.shape}, labels count {len(labels)}")

    all_data = np.concatenate(image_chunks, axis=0)
    all_labels = np.array(label_accum)

    print(f"Total dataset loaded - Shape: {all_data.shape}, Labels: {all_labels.shape}")
    return all_data, all_labels

# Load the training data from the default directory; the resulting arrays
# stay at module level for the cells that follow.
print("Loading CIFAR-10 training data...")
all_data, all_labels = load_data()

Step 2: Enhanced Data Augmentation and Transforms
This section implements advanced data augmentation techniques to:
1. Improve model generalization by exposing it to diverse image variations
2. Reduce overfitting by artificially expanding the training dataset
3. Make the model more robust to real-world variations in images

Key augmentation techniques used:
- Random cropping: For translation invariance
- Horizontal flips: For orientation invariance  
- AutoAugment: For automated policy-based augmentations
- Random erasing: For occlusion robustness
- Normalization: For stable training

In [None]:
print("\nConfiguring data augmentation pipeline...")
print("Implementing the following augmentation techniques:")
print("- Random cropping with padding=4")
print("- Random horizontal flips")
print("- CIFAR10-specific AutoAugment policies") 
print("- Random erasing with p=0.25")
print("- Standard CIFAR10 normalization")
# CIFAR-10 standard mean and std values for normalization
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2470, 0.2435, 0.2616)


def _to_uint8(x):
    """Cast an image tensor holding 0-255 pixel values to ``torch.uint8``.

    AutoAugment only accepts uint8 tensors, so the geometric and policy-based
    augmentations run on uint8 data. Assumes the incoming tensor is in the
    0-255 range (values outside it are clamped).
    """
    return x.clamp(0, 255).to(torch.uint8)


def _scale_to_unit(x):
    """Convert a 0-255 image tensor to float and scale it to [0, 1]."""
    return x.float() / 255.0


# Advanced training transforms for better model generalization.
# Named functions are used instead of lambdas so the pipeline can be pickled
# when DataLoader workers are started with the "spawn" method.
print("Setting up data augmentation pipelines...")
train_transform = transforms.Compose([
    _to_uint8,  # AutoAugment requires uint8 tensors
    transforms.RandomCrop(32, padding=4),  # Random crops for translation invariance
    transforms.RandomHorizontalFlip(),  # Horizontal flips for orientation invariance
    transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),  # AutoAugment for robust training
    _scale_to_unit,  # Back to float in [0, 1] before normalization
    transforms.Normalize(cifar10_mean, cifar10_std),  # Standardize using CIFAR-10 statistics
    transforms.RandomErasing(p=0.25),  # Random erasing for occlusion robustness
])

# Minimal test/validation transforms - only essential normalization
test_transform = transforms.Compose([
    _scale_to_unit,
    transforms.Normalize(cifar10_mean, cifar10_std)
])

Step 3: Improved Dataset and DataLoader for Training/Validation
This section implements optimized data handling and loading with:
1. Custom Dataset class for efficient memory usage and fast access
2. Strategic train/validation split for proper model evaluation
3. Optimized DataLoader configuration for maximum throughput
4. Careful batch size selection for stable training

Key optimizations:
- Parallel data loading with multiple workers
- Pinned memory for faster GPU transfer
- Persistent workers to reduce overhead
- Different batch sizes for train vs validation

In [None]:


# Informational banner only: these prints describe the data pipeline that the
# rest of this cell builds; they do not configure anything themselves.
print("\nSetting up optimized data pipeline...")
print("Implementing the following optimizations:")
print("- Custom Dataset class for efficient data handling")
print("- 90/10 train-validation split for proper evaluation")
print("- Multi-worker data loading for better throughput")
print("- Pinned memory and persistent workers for faster processing")
print("- Optimized batch sizes: 128 for training, 256 for validation")
class CIFAR10Dataset(Dataset):
    """
    Map-style dataset over in-memory NumPy image and label arrays.

    Args:
        data: Indexable collection of NumPy image arrays
        labels: Indexable collection of labels aligned with ``data``
        transform: Optional callable applied to the float image tensor
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        # The dataset length follows the image array.
        return len(self.data)

    def __getitem__(self, idx):
        # Images are converted lazily, one sample at a time, to float32.
        tensor = torch.from_numpy(self.data[idx]).float()
        target = self.labels[idx]
        if not self.transform:
            return tensor, target
        return self.transform(tensor), target

print("Creating datasets and dataloaders...")
# Build two dataset views over the same arrays: one with training
# augmentation, one with the plain evaluation transform. The subsets returned
# by random_split share a single underlying dataset object, so changing
# `.dataset.transform` on the validation subset would change it for the
# training subset as well; separate instances keep the transforms independent.
# NOTE(review): train_transform is actually exercised by the training split
# here, so every transform in it must accept the float tensor produced by
# CIFAR10Dataset.__getitem__ - confirm.
dataset = CIFAR10Dataset(all_data, all_labels, transform=train_transform)
eval_dataset = CIFAR10Dataset(all_data, all_labels, transform=test_transform)

# Split into training and validation sets
train_size = len(dataset) * 9 // 10  # 90% for training (45000 for full CIFAR-10)
val_size = len(dataset) - train_size  # 10% for validation
train_dataset, val_split = random_split(
    dataset, 
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED)  # Ensure reproducible splits
)

# Same validation indices as the split above, but served through the
# evaluation transform so validation images are never augmented.
val_dataset = torch.utils.data.Subset(eval_dataset, val_split.indices)

# Create optimized dataloaders
print("Configuring dataloaders with optimized settings...")
train_loader = DataLoader(
    train_dataset, 
    batch_size=128,  # Balanced batch size for training
    shuffle=True,  # Shuffle training data
    num_workers=4,  # Parallel data loading
    pin_memory=True,  # Faster data transfer to GPU
    persistent_workers=True  # Keep workers alive between epochs
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=256,  # Larger batches for validation (no backprop needed)
    shuffle=False,  # No need to shuffle validation data
    num_workers=4,
    pin_memory=True,
    persistent_workers=True
)

print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

Step 4: Advanced Augmentation and Loss Functions
This step implements sophisticated data augmentation and loss techniques:
1. CutMix - Combines portions of different images and their labels
2. Mixup - Linearly interpolates between pairs of images and labels
3. Label smoothing - Reduces overconfidence by softening one-hot labels
4. Focal loss weighting - Scales each class's loss term by (1 - p)^gamma so confidently predicted classes contribute less

These techniques help:
- Improve model generalization by creating diverse training samples
- Reduce overfitting by introducing regularization effects
- Make training more robust to noisy labels and outliers
- Encourage the model to learn more meaningful features

In [None]:

# Progress message for this cell; the definitions below have no import-time
# side effects of their own.
print("Setting up advanced augmentation and loss functions...")

# CutMix augmentation - Combines portions of images and their labels
def cutmix_data(x, y, alpha=1.0):
    """
    Implements CutMix augmentation by pasting a random rectangular patch from
    a shuffled copy of the batch into each image.

    The batch tensor ``x`` is modified in place and also returned.

    Args:
        x: Input images tensor of shape (N, C, H, W)
        y: Labels tensor with one entry per image
        alpha: Beta distribution parameter for the mixing ratio; a value
            <= 0 disables mixing (lam = 1, empty patch)
    Returns:
        x: The batch with patches replaced
        y_a: Original labels
        y_b: Labels of the images the patches were taken from
        lam: Fraction of each image that still belongs to ``y_a`` (float)
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)  # Sample mixing ratio from beta distribution
    else:
        lam = 1

    batch_size = x.size()[0]
    # Draw the permutation on the CPU generator, then move it next to the
    # data, so the function works for whatever device the batch lives on.
    index = torch.randperm(batch_size).to(x.device)

    y_a, y_b = y, y[index]  # Original and permuted labels

    # Generate random bounding box coordinates (plain ints for slicing)
    bbx1, bby1, bbx2, bby2 = (int(v) for v in rand_bbox(x.size(), lam))
    # x[index] is materialised as a copy before assignment, so reading and
    # writing the same batch tensor here is safe.
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]  # Patch mixing

    # The box is clipped at the image border, so recompute lambda from the
    # area that was actually replaced; return it as a plain Python float.
    lam = float(1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2])))

    return x, y_a, y_b, lam

def rand_bbox(size, lam):
    """
    Sample a random box whose area is roughly ``(1 - lam)`` of the image.

    A centre is drawn uniformly inside the image, the box is laid out around
    it, and the corners are clipped to the image, so the realised area can be
    smaller than requested.

    Args:
        size: Input tensor size; entries 2 and 3 are the two spatial extents
        lam: Target ratio of remaining (uncut) area
    Returns:
        Tuple ``(bbx1, bby1, bbx2, bby2)``: start/end along dimension 2 and
        start/end along dimension 3
    """
    extent_x, extent_y = size[2], size[3]

    # Side lengths scale with sqrt(1 - lam) so the area scales with (1 - lam).
    shrink = np.sqrt(1. - lam)
    half_x = np.int64(extent_x * shrink) // 2
    half_y = np.int64(extent_y * shrink) // 2

    # Centre of the box (first axis is drawn first, then the second).
    center_x = np.random.randint(extent_x)
    center_y = np.random.randint(extent_y)

    # Clip every corner so the box never leaves the image.
    x_start = np.clip(center_x - half_x, 0, extent_x)
    x_stop = np.clip(center_x + half_x, 0, extent_x)
    y_start = np.clip(center_y - half_y, 0, extent_y)
    y_stop = np.clip(center_y + half_y, 0, extent_y)

    return x_start, y_start, x_stop, y_stop

# Mixup augmentation - Linearly combines pairs of images and labels
def mixup_data(x, y, alpha=0.2):
    """
    Implements Mixup augmentation by linearly interpolating each image with
    another image of the same batch.

    Args:
        x: Input images tensor (first dimension is the batch)
        y: Labels tensor with one entry per image
        alpha: Beta distribution parameter; a value <= 0 disables mixing
            (lam = 1, images returned unchanged)
    Returns:
        mixed_x: ``lam * x + (1 - lam) * x[index]``
        y_a: Original labels
        y_b: Labels of the images mixed in
        lam: Mixing coefficient applied to the original images
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)  # Mixing coefficient
    else:
        lam = 1

    batch_size = x.size()[0]
    # Draw the permutation on the CPU generator, then move it next to the
    # data, so the function works for whatever device the batch lives on.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]  # Linear interpolation of images
    y_a, y_b = y, y[index]  # Original and permuted labels
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """
    Blend the losses for both label sets of a mixed batch.

    The base loss is evaluated once against each label set and the two
    results are combined with the same ratio that was used to mix the inputs.

    Args:
        criterion: Base loss function called as ``criterion(pred, target)``
        pred: Model predictions
        y_a, y_b: Original and permuted labels
        lam: Mixing ratio (weight of the ``y_a`` term)
    Returns:
        Combined loss value
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

def update_bn_custom(loader, model, device):
    """
    Recompute BatchNorm running statistics with one pass over ``loader``.

    Running mean/variance of every BatchNorm layer are reset, momentum is
    temporarily set to ``None`` (cumulative average over all batches), and the
    model is run in training mode without gradients. Afterwards each layer's
    momentum and the model's previous train/eval mode are restored, even if a
    forward pass raises. Models without BatchNorm statistics are left
    untouched.

    Args:
        loader: Iterable yielding ``(inputs, target)`` pairs; targets are ignored
        model: Model containing BatchNorm layers
        device: Device the inputs are moved to and the statistics live on
    """
    print("Updating BatchNorm statistics...")
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            if module.running_mean is None:
                # Layer was built without running statistics; nothing to update.
                continue
            module.running_mean = torch.zeros_like(module.running_mean, device=device)
            module.running_var = torch.ones_like(module.running_var, device=device)
            momenta[module] = module.momentum
            module.momentum = None  # None => simple average over all batches
            module.num_batches_tracked *= 0

    if not momenta:
        return

    was_training = model.training
    model.train()  # BatchNorm only updates its statistics in training mode
    try:
        with torch.no_grad():
            for data, _ in loader:
                model(data.to(device))
    finally:
        for module, momentum in momenta.items():
            module.momentum = momentum
        model.train(was_training)

# Advanced label smoothing with focal loss component
class FocalLabelSmoothing(nn.Module):
    """
    Label-smoothed cross entropy whose per-class terms are modulated by a
    focal factor.

    The target distribution puts ``1 - smoothing`` on the true class and
    spreads ``smoothing`` evenly over the other ``classes - 1`` classes. Each
    class term ``-q_c * log p_c`` is multiplied by ``(1 - p_c) ** gamma``,
    where ``p_c`` is the predicted probability of that class, before summing
    over classes and averaging over the batch.

    Args:
        classes: Number of classes
        smoothing: Label smoothing factor
        gamma: Focal loss power factor
    """

    def __init__(self, classes=10, smoothing=0.1, gamma=1.0):
        super().__init__()
        self.classes = classes
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.gamma = gamma
        print(f"Initialized Focal Label Smoothing with gamma={gamma}")

    def forward(self, pred, target):
        log_probs = torch.log_softmax(pred, dim=-1)

        # The smoothed target is a constant; keep it out of the autograd graph.
        with torch.no_grad():
            off_value = self.smoothing / (self.classes - 1)
            true_dist = torch.full_like(log_probs, off_value)
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)

        # Focal factor for every class, computed from the predicted probabilities.
        modulation = (1 - log_probs.exp()) ** self.gamma

        per_sample = torch.sum(-true_dist * log_probs * modulation, dim=-1)
        return per_sample.mean()


Step 5: Advanced ResNet Model with SE Blocks and Bottleneck

This section implements an enhanced ResNet architecture with:
1. Squeeze-and-Excitation (SE) blocks for adaptive feature recalibration
   - Helps model focus on informative features
   - Improves accuracy with minimal parameter overhead

2. Bottleneck blocks for efficient feature processing
   - Reduces computational cost while maintaining performance
   - Better gradient flow through the network

3. Skip connections for improved gradient flow
   - Allows training of very deep networks
   - Helps combat vanishing gradients

The architecture is carefully designed to:
- Stay within the 5M parameter budget
- Maximize accuracy through modern techniques
- Maintain efficient training and inference

In [None]:

# Progress message for this cell; no model is constructed by this line.
print("Initializing model architecture...")

class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation channel gate.

    Global-average-pools the feature map to one value per channel, passes it
    through a two-layer 1x1-conv bottleneck ending in a sigmoid, and rescales
    the input channels with the resulting gate.

    Args:
        channels: Number of input channels
        reduction: Channel reduction ratio of the bottleneck
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        squeezed = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # Squeeze: one value per channel
        self.fc1 = nn.Conv2d(channels, squeezed, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed, channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.avg_pool(x)
        gate = self.relu(self.fc1(gate))
        gate = self.sigmoid(self.fc2(gate))
        # The per-channel gate broadcasts over the spatial dimensions.
        return x * gate

class BasicBlock(nn.Module):
    """
    Two-convolution residual block with an optional SE gate on the residual
    branch.

    Args:
        in_channels: Input channels
        out_channels: Output channels
        stride: Stride of the first convolution (and of the shortcut projection)
        use_se: Whether to apply SE attention before the skip addition
        se_reduction: SE block reduction ratio
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, use_se=True, se_reduction=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity shortcut unless the shape changes, in which case a strided
        # 1x1 convolution + BatchNorm projects the input to the new shape.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Sequential()

        self.use_se = use_se
        if use_se:
            self.se = SEBlock(out_channels, reduction=se_reduction)

    def forward(self, x):
        branch = F.relu(self.bn1(self.conv1(x)), inplace=True)
        branch = self.bn2(self.conv2(branch))

        if self.use_se:
            branch = self.se(branch)

        branch += self.shortcut(x)  # Skip connection
        return F.relu(branch, inplace=True)

class OptimizedResNet(nn.Module):
    """
    Three-stage ResNet: a 3x3 stem, three residual stages whose widths are
    1x, 2x and 4x ``initial_channels`` (the last two with stride 2), global
    average pooling, optional dropout and a linear classifier.

    Args:
        block: Residual block class; must expose ``expansion`` and accept
            ``(in_channels, out_channels, stride, use_se, se_reduction)``
        num_blocks: Number of blocks in each of the three stages
        initial_channels: Width of the stem and of the first stage
        num_classes: Number of output classes
        use_se: Whether blocks use SE attention
        se_reduction: SE block reduction ratio
        drop_rate: Dropout probability applied before the classifier
    """

    def __init__(self, block, num_blocks, initial_channels=40, num_classes=10, 
                 use_se=True, se_reduction=8, drop_rate=0.2):
        super().__init__()
        self.in_channels = initial_channels
        self.drop_rate = drop_rate

        print(f"Building OptimizedResNet with {sum(num_blocks)} blocks and SE attention")

        # Stem
        self.conv1 = nn.Conv2d(3, initial_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(initial_channels)

        # Residual stages layer1..layer3: (width multiplier, stride of first block)
        stage_specs = ((1, 1), (2, 2), (4, 2))
        for stage_idx, (width_mult, stage_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, initial_channels * width_mult,
                                     num_blocks[stage_idx], stride=stage_stride,
                                     use_se=use_se, se_reduction=se_reduction)
            setattr(self, f"layer{stage_idx + 1}", stage)

        # Classifier head
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(initial_channels*4, num_classes)

        self._initialize_weights()
        print("Model initialization complete")

    def _make_layer(self, block, out_channels, num_blocks, stride, use_se, se_reduction):
        """Build one stage; only its first block uses the given stride."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_channels, out_channels, block_stride, use_se, se_reduction))
            # The next block consumes what this one produces.
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*stage)

    def _initialize_weights(self):
        """Kaiming-normal conv/linear weights, unit BatchNorm scale, zero BatchNorm/linear bias"""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        feats = self.conv1(x)
        feats = F.relu(self.bn1(feats), inplace=True)

        for stage in (self.layer1, self.layer2, self.layer3):
            feats = stage(feats)

        feats = torch.flatten(self.avg_pool(feats), 1)

        # Dropout is only active while training and when a positive rate is set.
        if self.training and self.drop_rate > 0:
            feats = F.dropout(feats, p=self.drop_rate, training=True)

        return self.fc(feats)

def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters"""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


Step 6: Enhanced Training with Cosine Warmup and SWA

This section implements advanced training techniques to improve model performance:

1. Cosine Warmup Learning Rate Schedule
   - Gradual warmup prevents unstable training at the start
   - Cosine annealing helps find better optima
   - Smooth transitions avoid sudden learning rate changes

2. Stochastic Weight Averaging (SWA)
   - Averages multiple points along the trajectory
   - Leads to better generalization
   - Acts as an effective model ensemble

3. Exponential Moving Average (EMA)
   - Maintains a moving average of model weights
   - More stable than using final weights
   - Often improves validation/test accuracy

These techniques work together to:
- Stabilize the training process
- Improve model generalization
- Achieve better final accuracy

In [None]:

# Progress message for this cell; the training helpers are defined below.
print("Initializing advanced training components...")

# Learning rate scheduler with warmup
class CosineWarmupScheduler:
    """
    Custom learning rate scheduler that implements:
    1. Linear warmup from ``min_lr`` to ``max_lr`` over ``warmup_epochs``
    2. Cosine annealing from ``max_lr`` down to ``min_lr`` until ``max_epochs``

    Epochs at or beyond ``max_epochs`` stay at ``min_lr``, and a schedule with
    no annealing phase (``max_epochs <= warmup_epochs``) drops to ``min_lr``
    once warmup is over instead of dividing by zero.

    Args:
        optimizer: The optimizer to adjust learning rates for
        warmup_epochs: Number of epochs for linear warmup
        max_epochs: Total number of training epochs
        min_lr: Minimum learning rate
        max_lr: Maximum learning rate after warmup
    """
    def __init__(self, optimizer, warmup_epochs, max_epochs, min_lr=1e-6, max_lr=0.1):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.min_lr = min_lr
        self.max_lr = max_lr
        print(f"Initialized scheduler with {warmup_epochs} warmup epochs, lr range: [{min_lr}, {max_lr}]")

    def step(self, epoch):
        """
        Set the learning rate for a zero-based ``epoch`` on every parameter
        group of the optimizer.

        Args:
            epoch: Zero-based epoch index
        Returns:
            The learning rate that was applied
        """
        if epoch < self.warmup_epochs:
            # Linear warmup phase - gradually increase lr to avoid initial shock to the model
            lr = self.min_lr + (self.max_lr - self.min_lr) * epoch / self.warmup_epochs
        else:
            # Cosine annealing phase - smoothly decrease lr for better convergence
            decay_epochs = self.max_epochs - self.warmup_epochs
            if decay_epochs > 0:
                # Clamp so the cosine never swings back up after the schedule ends.
                progress = min(1.0, (epoch - self.warmup_epochs) / decay_epochs)
            else:
                progress = 1.0
            lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + math.cos(math.pi * progress))

        # Update learning rate for all parameter groups
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr

# Exponential Moving Average (EMA) for model weights
class EMA:
    """
    Implements Exponential Moving Average for model weights.

    Only parameters with ``requires_grad=True`` are tracked. The shadow
    tensors are owned by this object: ``apply_shadow`` copies their values
    into the model (it does not share storage), so changes made to the model
    while the shadow weights are applied cannot corrupt the average.

    Typical use: ``register()`` once, ``update()`` after every optimizer step,
    and ``apply_shadow()`` / ``restore()`` around evaluation or checkpointing.

    Args:
        model: The model whose parameters to track
        decay: EMA decay rate (higher = slower but more stable)
    """
    def __init__(self, model, decay=0.9999):
        self.model = model
        self.decay = decay
        self.shadow = {}  # Shadow parameters
        self.backup = {}  # Backup of original parameters
        print(f"Initialized EMA with decay rate {decay}")

    def register(self):
        """Initialize EMA shadow parameters as copies of model parameters"""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()
        print("Registered initial model parameters in EMA")

    def update(self):
        """Update shadow parameters in place: shadow = decay * shadow + (1 - decay) * param"""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # In-place update avoids allocating new tensors on every step.
                self.shadow[name].mul_(self.decay).add_(param.data, alpha=1 - self.decay)

    def apply_shadow(self):
        """Copy shadow values into the model, saving the current parameters first"""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # Clone before overwriting: copy_ writes into the same storage.
                self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    def restore(self):
        """Copy the saved parameters back into the model and drop the backup"""
        for name, param in self.model.named_parameters():
            saved = self.backup.pop(name, None)
            if saved is not None:
                param.data.copy_(saved)

def train_model(model, train_loader, val_loader, num_epochs=250):
    """
    Main training loop implementing advanced training techniques including:
    - SGD with Nesterov momentum and weight decay
    - Focal Label Smoothing loss
    - Cosine learning rate scheduling with warmup
    - Exponential Moving Average (EMA) of model weights
    - Stochastic Weight Averaging (SWA)
    - Adaptive data augmentation (MixUp and CutMix)
    - Gradient clipping
    
    Args:
        model: Neural network model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        num_epochs: Number of training epochs
        
    Returns:
        model: Trained model
        swa_model: Stochastic Weight Averaged model
        ema: EMA model instance
        metrics: Tuple of (train_losses, train_accs, val_losses, val_accs)
    """
    print("\n=== Initializing Training Components ===")
    
    # Setup optimizer with Nesterov momentum and L2 regularization
    print("Setting up SGD optimizer with Nesterov momentum...")
    optimizer = optim.SGD(
        model.parameters(), 
        lr=0.1,  # Initial learning rate (will be adjusted by scheduler)
        momentum=0.9,  # Momentum coefficient for faster convergence
        weight_decay=5e-4,  # L2 regularization to prevent overfitting
        nesterov=True  # Use Nesterov momentum for better convergence
    )
    
    print("Configuring Focal Label Smoothing loss...")
    criterion = FocalLabelSmoothing(classes=10, smoothing=0.1, gamma=1.0)
    
    print("Setting up Cosine learning rate scheduler with warmup...")
    scheduler = CosineWarmupScheduler(
        optimizer, 
        warmup_epochs=5,  # Gradual LR increase for 5 epochs
        max_epochs=num_epochs,
        min_lr=1e-6,  # Minimum LR at end of training
        max_lr=0.1  # Maximum LR after warmup
    )
    
    print("Initializing EMA model tracking...")
    ema = EMA(model, decay=0.9999)  # High decay for stability
    ema.register()
    
    print("Creating SWA model for late-stage averaging...")
    swa_model = AveragedModel(model)
    swa_start = int(num_epochs * 0.75)  # Start SWA at 75% of training
    print(f"SWA will begin at epoch {swa_start}")
    
    # Initialize tracking metrics
    best_val_acc = 0.0
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []
    
    print("\n=== Starting Training Loop ===")
    
    for epoch in range(1, num_epochs + 1):
        # Update learning rate using scheduler
        current_lr = scheduler.step(epoch - 1)
        print(f"\nEpoch {epoch}/{num_epochs} - Learning Rate: {current_lr:.6f}")
        
        # Training phase
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0
        
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [Train]")
        
        for inputs, targets in train_pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Adaptive data augmentation strategy
            aug_prob = random.random()
            late_training = epoch >= int(num_epochs * 0.8)
            
            if aug_prob < 0.5 and not late_training:
                # Apply MixUp with adaptive strength
                mixup_strength = max(0.0, 0.4 * (1 - epoch / num_epochs))
                inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, alpha=mixup_strength)
                outputs = model(inputs)
                loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            elif aug_prob < 0.8 and not late_training:
                # Apply CutMix with adaptive strength
                cutmix_strength = max(0.0, 1.0 * (1 - epoch / num_epochs))
                inputs, targets_a, targets_b, lam = cutmix_data(inputs, targets, alpha=cutmix_strength)
                outputs = model(inputs)
                loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            else:
                # Standard forward pass
                outputs = model(inputs)
                loss = criterion(outputs, targets)
            
            # Optimization step
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Prevent exploding gradients
            optimizer.step()
            
            # Update EMA model
            ema.update()
            
            # Track metrics
            train_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            
            # Calculate accuracy (adjusted for augmentation)
            total += targets.size(0)
            if aug_prob < 0.8 and not late_training:
                correct += (lam * predicted.eq(targets_a).sum().float() + 
                          (1 - lam) * predicted.eq(targets_b).sum().float()).item()
            else:
                correct += predicted.eq(targets).sum().item()
            
            # Update progress display
            train_pbar.set_postfix({
                'loss': train_loss/total, 
                'acc': 100.0*correct/total,
                'lr': current_lr
            })
        
        # Calculate epoch metrics
        train_loss = train_loss / total
        train_acc = 100.0 * correct / total
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        
        # Update SWA model if in SWA phase
        if epoch >= swa_start:
            print("Updating SWA model parameters...")
            swa_model.update_parameters(model)
        
        # Validation phase
        print("\nRunning validation...")
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f"Epoch {epoch}/{num_epochs} [Val]")
            for inputs, targets in val_pbar:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = F.cross_entropy(outputs, targets)
                
                val_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                
                val_pbar.set_postfix({
                    'loss': val_loss/total, 
                    'acc': 100.0*correct/total
                })
        
        val_loss = val_loss / total
        val_acc = 100.0 * correct / total
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        
        # Save best model checkpoints
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            print(f"\nNew best validation accuracy: {best_val_acc:.2f}%")
            print("Saving best model checkpoint...")
            torch.save(model.state_dict(), "best_model.pth")
            
            # Also save EMA model
            print("Saving best EMA model checkpoint...")
            ema.apply_shadow()
            torch.save(model.state_dict(), "best_ema_model.pth")
            ema.restore()
        
        # Print epoch summary
        print(f"\nEpoch {epoch} Summary:")
        print(f"Training - Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%")
        print(f"Validation - Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%")
        print(f"Learning Rate: {current_lr:.6f}")
    
    print("\n=== Training Complete ===")
    
    # Finalize SWA model
    print("\nUpdating BatchNorm statistics for SWA model...")
    with torch.no_grad():
        update_bn_custom(train_loader, swa_model, device)
    
    print("Saving final SWA model...")
    torch.save(swa_model.state_dict(), "swa_model.pth")
    
    return model, swa_model, ema, (train_losses, train_accs, val_losses, val_accs)


Step 7: Enhanced Test Prediction with Ensemble of Best Models

This section implements sophisticated prediction techniques to maximize accuracy:

1. Model Ensemble Strategy
   - Combines predictions from multiple trained models:
     * Best validation model
     * SWA model with updated BatchNorm stats  
     * EMA model with smoothed weights
   - Reduces prediction variance and improves robustness
   - Each model contributes complementary features

2. Test-Time Augmentation (TTA)
   - Makes predictions on multiple augmented versions:
     * Original image
     * Horizontal flips 
     * Random intensity scaling (0.95-1.05x)
     * Center crops (92.5% of the image, resized back to full size)
     * Small random per-image brightness offsets
   - Averages predictions across variants
   - More robust predictions by considering multiple views

3. Optimized Data Pipeline
   - Efficient batched processing
   - GPU acceleration with pinned memory
   - Parallel data loading with multiple workers
   - Careful memory management

The combination of these techniques helps:
- Maximize prediction accuracy
- Improve robustness to input variations  
- Maintain efficient processing speed

In [None]:

# Informational banner for the prediction cell; the text is static and is not
# derived from the actual ensemble or TTA configuration used further down.
print(
    "\nInitializing enhanced prediction pipeline...",
    "Using ensemble of 3 models: Best Val, SWA, and EMA",
    "Applying test-time augmentation with multiple variants",
    "Optimizing for both accuracy and inference speed",
    sep="\n",
)

def _tta_view(images, view_idx):
    """
    Build one test-time-augmentation view of a 4-D image batch.

    View 0 is the batch itself and view 1 its horizontal flip; later views
    cycle through intensity scaling, center-crop-and-resize, and a random
    per-image additive offset.

    Args:
        images: Image batch tensor with four dimensions (batch, channel, height, width)
        view_idx: Zero-based index of the view to build
    Returns:
        Tensor with the same shape as ``images``
    """
    if view_idx == 0:
        return images
    if view_idx == 1:
        # Flip along the last (width) axis
        return torch.flip(images, [3])

    variant = (view_idx - 2) % 3
    if variant == 0:
        # One random factor in [0.95, 1.05) shared by the whole batch
        return images * (0.95 + 0.1 * torch.rand(1, device=images.device))
    if variant == 1:
        # Keep the central 92.5% of the shorter side, then resize back
        _, _, h, w = images.shape
        crop_size = int(0.925 * min(h, w))
        cropped = transforms.functional.center_crop(images, crop_size)
        return F.interpolate(cropped, (h, w), mode='bilinear', align_corners=False)

    # One random offset per image, broadcast over channels and pixels.
    # The clamp range presumably matches normalized inputs - TODO confirm.
    shifted = images + 0.05 * torch.randn((images.size(0), 1, 1, 1), device=images.device)
    return torch.clamp(shifted, -3, 3)


def predict_with_enhanced_tta(ensemble, test_loader, num_tta=5):
    """
    Make predictions using model ensemble and test-time augmentation.

    For every model in the ensemble this function scores:
    1. The original batch
    2. Its horizontal flip
    3. ``num_tta - 1`` additional views, cycling through:
       - Random intensity scaling (factor in [0.95, 1.05))
       - Center crop (92.5% of the shorter side) resized back
       - Random per-image additive offset, clamped to [-3, 3]

    Softmax probabilities of all views from all models are summed and divided
    by the number of views that were actually evaluated; the arg-max of that
    average is the prediction. The number of classes is taken from the model
    output rather than being fixed.

    Args:
        ensemble: List of trained models to use for prediction
        test_loader: DataLoader yielding (images, image_ids) batches
        num_tta: TTA budget per model. The original and flipped views are
            always scored, so each model sees ``2 + max(num_tta - 1, 0)`` views
    Returns:
        all_predictions: List of predicted class indices
        all_ids: List of corresponding image IDs
    Raises:
        ValueError: If ``ensemble`` is empty
    """
    if len(ensemble) == 0:
        raise ValueError("ensemble must contain at least one model")

    # Count the views that are really evaluated so the average stays a
    # proper mean even when num_tta < 1
    views_per_model = 2 + max(num_tta - 1, 0)
    total_variants = len(ensemble) * views_per_model

    all_predictions = []
    all_ids = []

    print("\n=== Starting Enhanced Prediction with Ensemble and TTA ===")
    print(f"Using ensemble of {len(ensemble)} models")
    print(f"Applying {views_per_model} TTA variants per model")

    # Inference mode only needs to be set once, not once per batch
    for model in ensemble:
        model.eval()

    with torch.no_grad():
        for batch_idx, (images, image_ids) in enumerate(test_loader):
            # Move batch to appropriate device
            images = images.to(device)

            # Print batch info for verification (after the move, so the
            # reported device is the one inference runs on)
            if batch_idx == 0:
                print("\nProcessing first batch:")
                print(f"Batch images shape: {images.shape}")
                print(f"Device being used: {images.device}")

            # Sum of softmax outputs; created lazily so its width follows
            # the models' output size
            ensemble_probs = None
            for model in ensemble:
                for view_idx in range(views_per_model):
                    view = _tta_view(images, view_idx)
                    probs = F.softmax(model(view), dim=1)
                    if ensemble_probs is None:
                        ensemble_probs = probs
                    else:
                        ensemble_probs = ensemble_probs + probs

            # Average predictions across all variants
            ensemble_probs = ensemble_probs / total_variants

            # Get final predictions
            _, predicted = ensemble_probs.max(1)

            # Store batch results
            all_predictions.extend(predicted.cpu().numpy())
            if torch.is_tensor(image_ids):
                all_ids.extend(image_ids.cpu().numpy())
            else:
                # Non-numeric IDs (e.g. strings) are collated into a plain sequence
                all_ids.extend(image_ids)

            # Print progress every 10 batches
            if (batch_idx + 1) % 10 == 0:
                print(f"Processed {batch_idx + 1}/{len(test_loader)} batches")

    print("\nPrediction complete!")
    print(f"Total predictions made: {len(all_predictions)}")
    return all_predictions, all_ids

def load_test_data(test_file="./data/cifar-10-python/cifar_test_nolabel.pkl"):
    """
    Load test images and IDs from a pickle file and return images in NCHW layout.

    The pickle must contain a dict with the byte-string keys ``b'data'`` and
    ``b'ids'``. Supported image layouts:
       - Flat (N, 3072), reshaped to (N, 3, 32, 32)
       - NCHW (N, 3, H, W), returned unchanged
       - NHWC (N, H, W, 3), transposed to NCHW

    Args:
        test_file: Path to test data pickle file
    Returns:
        test_images: Image array in NCHW format
        test_ids: Array of corresponding image IDs
    Raises:
        ValueError: If the image array is neither 2-D nor 4-D, or is 4-D
            without a channel axis of size 3 in position 1 or -1
    """
    print("\n=== Loading and Processing Test Data ===")
    print(f"Loading from file: {test_file}")

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(test_file, 'rb') as f:
        test_dict = pickle.load(f, encoding='bytes')

    # np.asarray is a no-op for arrays and lets list-valued entries work too
    test_images = np.asarray(test_dict[b'data'])
    test_ids = np.asarray(test_dict[b'ids'])

    print("\nInitial data shapes:")
    print(f"Images: {test_images.shape}")
    print(f"IDs: {test_ids.shape}")

    # Process data format
    if test_images.ndim == 2:
        print("\nDetected flat format (N, 3072)")
        print("Reshaping to NCHW format (N, 3, 32, 32)...")
        test_images = test_images.reshape(-1, 3, 32, 32)

    elif test_images.ndim == 4:
        if test_images.shape[1] == 3:
            print("\nData already in NCHW format")
        elif test_images.shape[-1] == 3:
            print("\nDetected NHWC format (N, 32, 32, 3)")
            print("Converting to NCHW format...")
            test_images = np.transpose(test_images, (0, 3, 1, 2))
        else:
            raise ValueError(
                f"Cannot locate a 3-channel axis in image array of shape {test_images.shape}"
            )

    else:
        # Fail here instead of handing a mis-shaped array to the model later
        raise ValueError(
            f"Unsupported image array shape {test_images.shape}; expected 2 or 4 dimensions"
        )

    print("\nFinal data shapes:")
    print(f"Images: {test_images.shape}")
    print(f"IDs: {test_ids.shape}")
    return test_images, test_ids

class CIFAR10TestDataset(Dataset):
    """
    Dataset wrapper pairing unlabeled test images with their IDs.

    Each item is the image at the requested index converted to a float
    tensor (optionally passed through ``transform``) together with the ID
    stored at the same index. No rescaling is done here; pixel values keep
    the numeric range of the input array.

    Args:
        data: Indexable collection of numpy image arrays
        ids: Indexable collection of image IDs, aligned with ``data``
        transform: Optional callable applied to each image tensor
    """
    def __init__(self, data, ids, transform=None):
        self.data, self.ids, self.transform = data, ids, transform
        print(f"Initialized test dataset with {len(data)} images")

    def __len__(self):
        # Dataset size is defined by the image collection
        return len(self.data)

    def __getitem__(self, idx):
        # from_numpy requires a numpy array; .float() yields float32
        sample = torch.from_numpy(self.data[idx]).float()

        # Transform (e.g. normalization) is skipped when none was given
        processed = self.transform(sample) if self.transform else sample

        return processed, self.ids[idx]


In [None]:
#############################
# Step 7.5: Plotting Functions 
#############################
def plot_training_history(train_losses, train_accs, val_losses, val_accs, save_path=None):
    """
    Draw loss and accuracy curves side by side and return the figure.

    Args:
        train_losses: Per-epoch training losses
        train_accs: Per-epoch training accuracies (%)
        val_losses: Per-epoch validation losses
        val_accs: Per-epoch validation accuracies (%)
        save_path: If given, the figure is also written to this path
    Returns:
        fig: The matplotlib figure holding both panels
    """
    # Epochs are numbered from 1 on the x-axis
    x_axis = range(1, len(train_losses) + 1)

    fig, axes = plt.subplots(1, 2, figsize=(18, 6))
    loss_ax, acc_ax = axes

    # (axis, training series, validation series, training label,
    #  validation label, title, y-axis label)
    panels = (
        (loss_ax, train_losses, val_losses,
         'Training Loss', 'Validation Loss', 'Loss Curves', 'Loss'),
        (acc_ax, train_accs, val_accs,
         'Training Accuracy', 'Validation Accuracy', 'Accuracy Curves', 'Accuracy (%)'),
    )
    for axis, train_series, val_series, train_label, val_label, title, y_label in panels:
        axis.plot(x_axis, train_series, 'b-', label=train_label)
        axis.plot(x_axis, val_series, 'r-', label=val_label)
        axis.set_title(title)
        axis.set_xlabel('Epochs')
        axis.set_ylabel(y_label)
        axis.legend()

    # Accuracy panel is fixed to the 50-100% window
    acc_ax.set_ylim(50, 100)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

def plot_final_metrics(final_val_acc, param_count, classes_correct=None, save_path=None):
    """
    Render a one-figure summary of the final model metrics.

    Three header lines (title, parameter count, validation accuracy) are
    always drawn. Below them the figure shows either a horizontal bar chart
    of per-class accuracies or, when none are supplied, a single highlighted
    accuracy badge.

    Args:
        final_val_acc: Final validation accuracy
        param_count: Number of parameters in the model
        classes_correct: Optional dictionary mapping class index (0-9) to
            accuracy; missing classes are drawn as 0
        save_path: Path to save the figure (optional)
    Returns:
        fig: The matplotlib figure
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # Header lines, positioned in axes coordinates from top to bottom
    accuracy_color = '#d62728' if final_val_acc > 85 else 'black'
    header = (
        (0.95, 'CIFAR-10 Model Performance Summary', dict(fontsize=16, weight='bold')),
        (0.89, f'Model Size: {param_count:,} parameters', dict(fontsize=14)),
        (0.83, f'Validation Accuracy: {final_val_acc:.2f}%',
         dict(fontsize=14, color=accuracy_color)),
    )
    for y_frac, line, style in header:
        ax.text(0.5, y_frac, line, ha='center', transform=ax.transAxes, **style)

    if classes_correct is None:
        # No per-class data: hide the axes and show one accuracy badge
        ax.axis('off')

        # Green badge once accuracy reaches 85%, blue below that
        badge_color = 'lightgreen' if final_val_acc >= 85 else 'lightblue'
        ax.text(0.5, 0.5, f"Validation\nAccuracy\n{final_val_acc:.2f}%",
                fontsize=24, weight='bold', ha='center', va='center',
                transform=ax.transAxes,
                bbox=dict(boxstyle="round,pad=0.3", facecolor=badge_color, alpha=0.5))
    else:
        class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                      'dog', 'frog', 'horse', 'ship', 'truck']
        accuracies = [classes_correct.get(cls, 0) for cls in range(10)]

        # One horizontal bar per class
        rows = np.arange(len(class_names))
        bars = ax.barh(rows, accuracies, color='skyblue')
        ax.set_yticks(rows)
        ax.set_yticklabels(class_names)
        ax.set_xlim(0, 100)
        ax.set_xlabel('Accuracy (%)')
        ax.set_title('Per-Class Accuracy')

        # Print each value just past the end of its bar
        for bar, acc in zip(bars, accuracies):
            ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
                    f'{acc:.1f}%', va='center')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig
    

def plot_class_distribution(predictions, save_path=None):
    """
    Draw a horizontal bar chart of how often each class was predicted.

    Args:
        predictions: Sequence of predicted class indices (non-negative ints)
        save_path: If given, the figure is also written to this path
    Returns:
        fig: The matplotlib figure
    """
    labels = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                  'dog', 'frog', 'horse', 'ship', 'truck']

    # minlength keeps a slot for every class, even ones never predicted
    counts = np.bincount(predictions, minlength=10)

    fig, ax = plt.subplots(figsize=(10, 8))

    rows = np.arange(len(labels))
    ax.barh(rows, counts, color='skyblue')
    ax.set_yticks(rows)
    ax.set_yticklabels(labels)
    # Flip the y-axis so the first class is at the top
    ax.invert_yaxis()
    ax.set_xlabel('Number of Predictions')
    ax.set_title('Distribution of Predicted Classes')

    # Write each count slightly to the right of its bar
    for row, count in enumerate(counts):
        ax.text(count + 10, row, str(count), va='center')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

Step 8: Main Execution Flow

This section orchestrates the complete training and prediction pipeline by:

1. Model Creation and Training
   - Initializes optimized ResNet architecture with SE blocks
   - Applies advanced training techniques like SWA and EMA
   - Uses cosine learning rate scheduling with warmup
   - Monitors and validates model performance

2. Test Data Processing 
   - Handles multiple input data formats (NHWC, NCHW)
   - Applies consistent preprocessing and augmentation
   - Uses efficient data loading with pinned memory
   - Implements parallel processing for speed

3. Ensemble Prediction
   - Combines predictions from multiple model checkpoints
   - Applies test-time augmentation for robustness
   - Averages the softmax outputs of all models and TTA views with equal weight
   - Handles prediction uncertainty

4. Submission Generation
   - Creates properly formatted submission file
   - Writes one predicted label per image ID (columns: ID, Labels)
   - Validates output format and constraints
   - Implements error checking and logging

The pipeline is designed to:
- Maximize prediction accuracy through model ensembling
- Maintain efficient processing through optimized data handling
- Ensure robustness through multiple augmentation strategies
- Provide detailed logging for monitoring and debugging

In [None]:
print("\n=== Starting Main Execution Pipeline ===")


def _build_resnet(drop_rate):
    """
    Build the network used for training and for every checkpoint reload.

    Keeping the architecture hyper-parameters in one place guarantees that
    the models the checkpoints are loaded into match the trained model;
    only the dropout rate differs between training and inference copies.
    """
    return OptimizedResNet(
        block=BasicBlock,
        num_blocks=[3, 5, 3],    # Balanced depth with more middle blocks
        initial_channels=40,      # Increased channels for better feature extraction
        use_se=True,             # Add attention mechanism
        se_reduction=8,          # SE reduction ratio
        drop_rate=drop_rate
    ).to(device)


# Create model with optimized architecture
print("\nInitializing model architecture...")
print("Using configuration:")
print("- 3-5-3 block structure for balanced depth")
print("- 40 initial channels for good feature extraction")
print("- Squeeze-and-Excitation for attention")
print("- 0.2 dropout for regularization")

model = _build_resnet(drop_rate=0.2)  # Dropout for regularization

# Verify model size constraints
param_count = count_parameters(model)
print(f"\nModel Architecture Summary:")
print(f"Total trainable parameters: {param_count:,}")

if param_count > 5_000_000:
    raise ValueError(f"Model exceeds parameter limit: {param_count:,} > 5,000,000")

# Train model with advanced techniques
print("\nStarting model training phase...")
model, swa_model, ema, history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=250  # Extended training for better convergence
)

# Process test data
print("\nLoading and preparing test data...")
test_images, test_ids = load_test_data()

# Create optimized test data pipeline
print("\nSetting up test data pipeline...")
print("- Using batch size 256 for efficient processing")
print("- Enabling 4 worker processes for parallel loading")
print("- Using pinned memory for faster GPU transfer")

test_dataset = CIFAR10TestDataset(test_images, test_ids, transform=test_transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=256,      # Efficient batch size
    shuffle=False,       # Maintain order for predictions
    num_workers=4,       # Parallel data loading
    pin_memory=True      # Faster GPU transfer
)

# Initialize ensemble models
print("\nPreparing model ensemble...")

# Load best standard model checkpoint (written by train_model as "best_model.pth").
# map_location places the weights on the current device regardless of where
# the checkpoint was saved from.
print("Loading best standard model...")
best_model = _build_resnet(drop_rate=0.0)  # Disable dropout for inference
best_model.load_state_dict(torch.load("best_model.pth", map_location=device))

# Load best EMA model checkpoint (written by train_model as "best_ema_model.pth")
print("Loading EMA model...")
ema_model = _build_resnet(drop_rate=0.0)
ema_model.load_state_dict(torch.load("best_ema_model.pth", map_location=device))

# Prepare SWA model for inference; it is used in-memory as returned by
# train_model, so only its wrapped module's dropout rate is changed
print("Preparing SWA model...")
swa_model.module.drop_rate = 0.0  # Disable dropout

# Create ensemble for robust predictions
print("Combining models into ensemble...")
ensemble = [best_model, ema_model, swa_model]

# Generate predictions using ensemble and TTA
print("\nGenerating predictions with ensemble and test-time augmentation...")
all_predictions, all_ids = predict_with_enhanced_tta(ensemble, test_loader, num_tta=5)

# Create and save submission file
print("\nPreparing submission file...")
df_submission = pd.DataFrame({
    "ID": all_ids,
    "Labels": all_predictions
})
df_submission = df_submission.sort_values(by="ID")  # Sort by ID for consistency
csv_filename = "submission.csv"
df_submission.to_csv(csv_filename, index=False)
print(f"Submission saved to: {csv_filename}")

print("\n=== Pipeline Complete ===")

Step 9: Final Evaluation and Submission
In this final step, we:
1. Load our best models (regular, EMA, and SWA) for ensemble prediction
2. Set dropout to 0 since we're doing inference
3. Use Test Time Augmentation (TTA) with `num_tta=5`, i.e. six views per model (the original image, a horizontal flip, and four additional perturbed views)
4. Create ensemble predictions by averaging outputs from all models
5. Generate submission file with predictions
6. Plot the distribution of predicted classes
7. Print final model statistics including parameter count and validation accuracy

The ensemble approach combines predictions from multiple models to reduce variance
and improve robustness. TTA further improves reliability by averaging predictions
across different augmented versions of each test image.


In [None]:
# Plot training and validation metrics
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

def plot_training_metrics(train_losses, train_accs, val_losses, val_accs):
    """
    Save loss and accuracy curves to 'training_metrics.png'.

    The figure has two side-by-side panels (loss on the left, accuracy on
    the right), each showing the training and validation series per epoch.
    The figure is closed after saving; nothing is returned.

    Args:
        train_losses: Per-epoch training losses
        train_accs: Per-epoch training accuracies (%)
        val_losses: Per-epoch validation losses
        val_accs: Per-epoch validation accuracies (%)
    """
    plt.figure(figsize=(12, 5))
    # Epochs are numbered from 1 on the x-axis
    epoch_axis = range(1, len(train_losses) + 1)

    # (subplot slot, training series, validation series, training label,
    #  validation label, title, y-axis label)
    panels = (
        (1, train_losses, val_losses, 'Training Loss', 'Validation Loss',
         'Training and Validation Loss', 'Loss'),
        (2, train_accs, val_accs, 'Training Accuracy', 'Validation Accuracy',
         'Training and Validation Accuracy', 'Accuracy (%)'),
    )
    for slot, train_series, val_series, train_label, val_label, title, y_label in panels:
        plt.subplot(1, 2, slot)
        plt.plot(epoch_axis, train_series, 'b-', label=train_label)
        plt.plot(epoch_axis, val_series, 'r-', label=val_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.savefig('training_metrics.png')
    plt.close()

def plot_class_distribution(predictions, save_path=None):
    """
    Plot distribution of predicted classes in a 2x2 grid of views.

    Panels: vertical bar chart with count labels, pie chart, horizontal bar
    chart, and a kernel-density estimate over the raw predictions. The
    figure is closed after the optional save; nothing is returned.

    NOTE: an earlier cell in this file defines a function with the same
    name that returns a figure; whichever definition runs last is the one
    in effect.

    Args:
        predictions: Sequence of predicted class indices (0-9)
        save_path: If given, the figure is written to this path
    """
    plt.figure(figsize=(10, 6))
    class_ids = list(range(10))
    # minlength guarantees one slot per class even when the highest classes
    # are never predicted; without it the fixed 10-entry axes below would
    # not match the number of counts
    class_counts = np.bincount(predictions, minlength=10)

    # Create bar plot
    plt.subplot(2, 2, 1)
    plt.bar(class_ids, class_counts)
    plt.title('Distribution of Predicted Classes')
    plt.xlabel('Class')
    plt.ylabel('Count')
    plt.grid(True, axis='y')

    # Add count labels on top of each bar
    for i, count in enumerate(class_counts):
        plt.text(i, count, str(count), ha='center', va='bottom')

    # Add pie chart
    plt.subplot(2, 2, 2)
    plt.pie(class_counts, labels=class_ids, autopct='%1.1f%%')
    plt.title('Class Distribution (Pie Chart)')

    # Add horizontal bar plot. Both vectors are numeric, so the orientation
    # is stated explicitly to treat the class index as the category axis.
    plt.subplot(2, 2, 3)
    sns.barplot(x=class_counts, y=class_ids, orient='h')
    plt.title('Horizontal Class Distribution')
    plt.xlabel('Count')
    plt.ylabel('Class')

    # Add KDE plot
    plt.subplot(2, 2, 4)
    sns.kdeplot(data=predictions)
    plt.title('Density Distribution of Classes')
    plt.xlabel('Class')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.close()

def plot_learning_rate_schedule(learning_rates, save_path='lr_schedule.png'):
    """
    Save a line plot of the learning rate recorded for each epoch.

    The figure is closed after saving; nothing is returned.

    Args:
        learning_rates: Sequence with one learning-rate value per epoch
        save_path: Output image path
    """
    plt.figure(figsize=(10, 5))

    # Epochs are numbered from 1 on the x-axis
    epoch_axis = range(1, len(learning_rates) + 1)
    plt.plot(epoch_axis, learning_rates)

    plt.title('Learning Rate Schedule')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.grid(True)

    plt.savefig(save_path)
    plt.close()

def plot_confusion_matrix(true_labels, predictions, save_path='confusion_matrix.png'):
    """
    Save an annotated heatmap of the confusion matrix.

    Rows are true labels and columns are predicted labels, as labelled on
    the axes. The figure is closed after saving; nothing is returned.

    Args:
        true_labels: Ground-truth class labels
        predictions: Predicted class labels, aligned with ``true_labels``
        save_path: Output image path
    """
    plt.figure(figsize=(10, 8))

    # Integer cell annotations ('d') on a blue colour scale
    matrix = confusion_matrix(true_labels, predictions)
    sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues')

    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')

    plt.savefig(save_path)
    plt.close()

# Call plotting functions if history metrics are available.
# Only the lookup of `history` is guarded: a NameError raised inside one of
# the plotting calls is a real bug and must not be reported as "history not
# found".
try:
    train_losses, train_accs, val_losses, val_accs = history
except NameError:
    print("Training history not found. Skipping training metrics plots.")
else:
    plot_training_metrics(train_losses, train_accs, val_losses, val_accs)
    print("Training metrics plots saved as training_metrics.png")

    # Optional extras: plotted only when the corresponding variables were
    # defined earlier in the session
    if 'learning_rates' in locals():
        plot_learning_rate_schedule(learning_rates)
        print("Learning rate schedule plot saved as lr_schedule.png")

    if 'y_true' in locals() and 'y_pred' in locals():
        plot_confusion_matrix(y_true, y_pred)
        print("Confusion matrix plot saved as confusion_matrix.png")
