# Knowledge Distillation for EfficientNetV2 on CIFAR-100

This Jupyter Notebook implements the experiments for the research paper "Improving a Compact Vision Model Using Knowledge Distillation."

**Objective:** To train a compact student model (`EfficientNetV2-S`) using knowledge distillation from a larger teacher model (`EfficientNetV2-L`) and evaluate the performance improvement against a baseline.

---

## 🚀 Setup Instructions

### **Hardware Requirements:**

- **GPU:** NVIDIA GPU with at least 16GB VRAM recommended
- **Training Time:** Approximately 8-12 hours for 200 epochs (depending on GPU)

### **Kaggle Setup:**

1. **Enable GPU:** Settings → Accelerator → **GPU P100** or **GPU T4 x2**
2. **Enable Internet:** Settings → Internet → **On** (needed to download CIFAR-100 and the pre-trained weights)
3. **Run Cells 1 and 2:** Cell 1 installs `thop`; Cell 2 runs the setup and prints whether Kaggle was detected

### **📋 Training Configuration:**

The notebook is configured to train for **200 epochs** in a single session:

```python
# Default configuration:
NUM_EPOCHS = 200
BATCH_SIZE = 128
LEARNING_RATE_FINETUNE = 0.0001
```

### **Quick Test (Recommended First):**

Before running full training, test with 5 epochs:

- In Cell 2, manually set: `NUM_EPOCHS = 5`
- This takes ~10 minutes and verifies everything works

### **📊 Expected Results:**

- **Teacher Model:** ~75-78% accuracy on CIFAR-100
- **Baseline Student:** ~70-72% accuracy
- **Distilled Student:** ~75-77% accuracy (matching or approaching teacher)

---

## 📊 Notebook Structure:

1. **Setup:** Imports, configuration, Kaggle detection
2. **Data Loading:** CIFAR-100 dataset with caching
3. **Model Definition:** Teacher and Student models
4. **Teacher Fine-tuning:** Train teacher on CIFAR-100
5. **Baseline Training:** Student without distillation
6. **KD Training:** Student with distillation
7. **Evaluation:** Compare results
8. **SOTA Comparison:** Compare with published models
9. **Ablation Studies:** Test different hyperparameters

---

**🎯 Note:** Training automatically saves a checkpoint every 20 epochs on Kaggle (every 30 epochs elsewhere), keeping only the 3 most recent, and separately saves the best model based on validation accuracy.


In [None]:
# Install required packages (only if not already installed)
import importlib.util

# Check if thop is already installed
thop_installed = importlib.util.find_spec("thop") is not None

if not thop_installed:
    print("Installing thop package...")
    !pip install thop -q
    print("Required packages installed successfully.")
else:
    print("All required packages already installed. Skipping installation.")

print("Note: Any dependency warnings about CUDA libraries can be safely ignored.")

In [None]:
# 1. Setup

# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import efficientnet_v2_s, efficientnet_v2_l, EfficientNet_V2_S_Weights, EfficientNet_V2_L_Weights
from tqdm import tqdm
import time
import copy
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from thop import profile
import pandas as pd

# --- Kaggle Detection ---
IS_KAGGLE = os.path.exists('/kaggle/working')
print(f"Running on Kaggle: {IS_KAGGLE}")

# --- Reproducibility ---
# Paper Section 3.5: "We set a random seed of 42 for reproducibility"
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# GPU Information
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Note: Paper methodology specifies V100 32GB, but code adapts to available GPU")

# Directory setup (Kaggle-compatible)
if IS_KAGGLE:
    MODEL_DIR = "/kaggle/working/models"
    DATA_DIR = "/kaggle/working/data"
else:
    MODEL_DIR = "models"
    DATA_DIR = "./data"

os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(DATA_DIR, exist_ok=True)

# Global parameters - Aligned with Paper Section 3.5
# Paper: "AdamW optimizer with initial learning rate of 0.001, weight decay of 0.01"
LEARNING_RATE = 0.001  # For training from scratch
LEARNING_RATE_FINETUNE = 0.0001  # Lower LR for fine-tuning pre-trained models (10x smaller)
WEIGHT_DECAY = 0.01

# Paper: "batch size of 128"
BATCH_SIZE = 128

# Paper Section 3.3: "For temperature, we evaluate T ∈ {1, 2, 4, 8}"
# Paper Section 3.3: "For the weighting factor, we evaluate α ∈ {0.1, 0.3, 0.5, 0.7, 0.9}"
# Optimized values for stable training with pre-trained models
TEMPERATURE = 2.0  # Reduced from 4.0 for better convergence
ALPHA = 0.5  # Reduced from 0.7 for balanced learning

NUM_CLASSES = 100  # CIFAR-100 has 100 classes

# Paper Section 3.5: "early stopping with a patience of 20 epochs"
PATIENCE = 20

# Paper Section 3.5: "train for 200 epochs"
NUM_EPOCHS = 200

# Apply to all training phases
NUM_EPOCHS_BASELINE = NUM_EPOCHS
NUM_EPOCHS_KD = NUM_EPOCHS
NUM_EPOCHS_ABLATION = NUM_EPOCHS

# Checkpoint frequency: every 20 epochs on Kaggle, every 30 elsewhere.
# Automatic cleanup keeps only the 3 most recent checkpoints to save disk space
CHECKPOINT_FREQUENCY = 20 if IS_KAGGLE else 30

print(f"\n{'='*70}")
print(f"EXPERIMENT CONFIGURATION (Aligned with Paper Section 3.5)")
print(f"{'='*70}")
print(f"Platform: {'Kaggle Notebooks' if IS_KAGGLE else 'Local/Cloud'}")
print(f"Total Epochs: {NUM_EPOCHS}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Learning Rate (from scratch): {LEARNING_RATE}")
print(f"Learning Rate (fine-tuning): {LEARNING_RATE_FINETUNE}")
print(f"Weight Decay: {WEIGHT_DECAY}")
print(f"Temperature (T): {TEMPERATURE} (optimized for pre-trained models)")
print(f"Alpha (α): {ALPHA} (balanced distillation + hard labels)")
print(f"Early Stopping Patience: {PATIENCE}")
print(f"Checkpoint Frequency: {CHECKPOINT_FREQUENCY} epochs (auto-cleanup enabled)")
print(f"Model Directory: {MODEL_DIR}")
print(f"Data Directory: {DATA_DIR}")
print(f"{'='*70}\n")

### Justification for Global Hyperparameters

The values for the global parameters in this experiment are chosen based on a combination of common practices in the computer vision literature and the specific requirements of our models and dataset.

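- **`SEED = 42`**: We fix the random seed (Paper Section 3.5) so that our experiments are **reproducible**. The specific number is arbitrary; fixing it makes the initialization of the new classifier heads, the data shuffling, and the random augmentations repeatable from run to run. Bit-for-bit identical results are still not guaranteed across different GPUs, driver or library versions, or CUDA kernels that remain non-deterministic even with cuDNN's deterministic mode enabled.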

- **`BATCH_SIZE = 128`**: This batch size is chosen to maximize GPU utilization while fitting in memory. It is a power of 2, which is computationally efficient on GPU hardware. This value provides a good balance between accurate gradient estimation and the memory capacity of modern GPUs like the Tesla P100.

- **`NUM_EPOCHS = 200`**: This is the training budget from the paper (Section 3.5) and gives the models ample time to converge on CIFAR-100. It is an upper bound: early stopping (`PATIENCE = 20`) ends a run once validation accuracy has not improved for 20 consecutive epochs. Training all 200 epochs in a single session requires approximately 8-12 hours on a modern GPU.

- **`LEARNING_RATE = 0.001`**: This is the paper's initial learning rate and a standard, robust starting point for the `AdamW` optimizer when training from scratch. A cosine annealing scheduler (`CosineAnnealingLR`) decays the rate over the course of training for better convergence. Because the teacher, baseline, and distilled students in this notebook all start from ImageNet weights, their training cells use `LEARNING_RATE_FINETUNE` instead.

- **`LEARNING_RATE_FINETUNE = 0.0001`**: When fine-tuning pre-trained models (such as our teacher, initialized with ImageNet weights), we use a 10x smaller learning rate. This limits how far the early updates move the pre-trained weights, so the model adapts gently to the new dataset without "forgetting" its pre-learned features. With the from-scratch learning rate, fine-tuning these models can become unstable (gradient explosion and collapse of the learned features).

- **`TEMPERATURE = 2.0`**: Chosen from the paper's grid T ∈ {1, 2, 4, 8}. A temperature above 1 softens the teacher's output distribution, so the student also learns from the relative probabilities the teacher assigns to incorrect classes. We use 2.0 rather than the commonly used 4.0 because the milder softening gave better convergence when distilling between pre-trained models.

- **`ALPHA = 0.5`**: This weighting factor, chosen from the paper's grid α ∈ {0.1, 0.3, 0.5, 0.7, 0.9}, gives equal weight to the distillation loss (from the teacher) and the cross-entropy loss (from the true labels). It was reduced from 0.7, which favors the teacher's "dark knowledge", to balance teacher guidance against the hard labels. Our ablation study investigates the effect of this parameter more deeply.
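
The effect of the temperature can be checked numerically. The sketch below applies a temperature-scaled softmax, softmax(z / T), to one made-up 5-class logit vector (illustrative values, not real teacher outputs) for each T in the paper's grid. As T grows, the top-class probability falls and the entropy rises, so the relative scores of the non-target classes become more visible to the student. The helper names are hypothetical and not part of the notebook.

```python
import numpy as np

def softmax_with_temperature(logits, temperature):
    """Temperature-scaled softmax: softmax(z / T), computed stably."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    z = z - z.max()  # subtract the max before exponentiating for numerical stability
    exp_z = np.exp(z)
    return exp_z / exp_z.sum()

def entropy(p):
    """Shannon entropy in nats."""
    return float(-(p * np.log(p)).sum())

# Illustrative logits for a 5-class problem (hypothetical values)
logits = [8.0, 4.0, 2.0, 1.0, 0.0]

results = {}
for T in (1, 2, 4, 8):
    p = softmax_with_temperature(logits, T)
    results[T] = p
    print(f"T={T}: max prob={p.max():.3f}, entropy={entropy(p):.3f} nats")
```

Dividing by T never changes the ranking of the classes (the argmax stays the same); it only changes how sharply the probability mass is concentrated on the top class.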


## 2. Data Loading and Preprocessing

We load the CIFAR-100 dataset. We apply standard data augmentation techniques for the training set (random cropping, random horizontal flipping) and normalization for both the training and test sets.
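
To make the training-time pipeline concrete, here is a NumPy approximation of what `RandomCrop(32, padding=4)`, `RandomHorizontalFlip()`, `ToTensor()` and `Normalize(...)` do to a single image. It assumes an HWC `uint8` array (the layout of a raw CIFAR image) and is an illustration, not torchvision's implementation; the function name is hypothetical.

```python
import numpy as np

# Per-channel statistics used by the notebook's Normalize transform
CIFAR100_MEAN = np.array([0.5071, 0.4867, 0.4408])
CIFAR100_STD = np.array([0.2675, 0.2565, 0.2761])

def augment_and_normalize(image_hwc, rng, padding=4, crop=32):
    """Zero-pad, random-crop, randomly flip, scale to [0, 1], normalize; returns CHW."""
    img = image_hwc.astype(np.float64) / 255.0  # ToTensor-style scaling to [0, 1]
    # Zero-pad height and width by `padding` pixels on every side
    padded = np.pad(img, ((padding, padding), (padding, padding), (0, 0)), mode="constant")
    # Pick a random crop x crop window inside the padded image
    top = rng.integers(0, padded.shape[0] - crop + 1)
    left = rng.integers(0, padded.shape[1] - crop + 1)
    out = padded[top:top + crop, left:left + crop]
    # Horizontal flip with probability 0.5
    if rng.random() < 0.5:
        out = out[:, ::-1]
    # Per-channel normalization, then HWC -> CHW as PyTorch expects
    out = (out - CIFAR100_MEAN) / CIFAR100_STD
    return out.transpose(2, 0, 1)

rng = np.random.default_rng(42)
image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
x = augment_and_normalize(image, rng)
print(x.shape)  # (3, 32, 32)
```

Cropping and flipping act only on the spatial axes. Any padded border that ends up inside the crop becomes -mean/std per channel after normalization, which corresponds to the zero fill torchvision uses for constant padding.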


In [None]:
# Data transformations - Aligned with Paper Section 3.5
# Paper: "RandomCrop(32, padding=4), RandomHorizontalFlip(p=0.5), Normalization"
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # Paper Section 3.5
    transforms.RandomHorizontalFlip(),  # p=0.5 by default
    transforms.ToTensor(),
    # Paper Section 3.5: Mean=[0.5071, 0.4867, 0.4408], Std=[0.2675, 0.2565, 0.2761]
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

# Check if CIFAR-100 is already downloaded
cifar_path = os.path.join(DATA_DIR, 'cifar-100-python')
already_downloaded = os.path.exists(cifar_path)

if already_downloaded:
    print(f"CIFAR-100 already downloaded at: {DATA_DIR}")
    print("Skipping download, loading from cache...")
else:
    print(f"Downloading CIFAR-100 dataset to: {DATA_DIR}")

# Optimization: Use pin_memory=True for faster data transfer to CUDA
use_pin_memory = torch.cuda.is_available()
# os.cpu_count() may return None; cap at 4 workers to avoid overhead
num_workers = min(os.cpu_count() or 1, 4)

# Load datasets (Kaggle-compatible paths)
# Paper Section 3.4: "CIFAR-100 dataset... 50,000 training images and 10,000 testing images"
trainset = torchvision.datasets.CIFAR100(root=DATA_DIR, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, 
                                          num_workers=num_workers, pin_memory=use_pin_memory)

testset = torchvision.datasets.CIFAR100(root=DATA_DIR, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, 
                                         num_workers=num_workers, pin_memory=use_pin_memory)

print(f"Training batches: {len(trainloader)}")
print(f"Testing batches: {len(testloader)}")
print(f"Training samples: {len(trainset)} (Paper: 50,000)")
print(f"Testing samples: {len(testset)} (Paper: 10,000)")
print(f"DataLoader optimization: num_workers={num_workers}, pin_memory={use_pin_memory}")

## 3. Model Definition

We define the teacher and student models. We use pre-trained `EfficientNetV2-L` as the teacher and `EfficientNetV2-S` as the student, both from `torchvision`. Using pre-trained weights on ImageNet allows us to leverage powerful, pre-learned features, which we will then fine-tune on the CIFAR-100 dataset.

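We replace the final classifier layer of each model with a new layer that outputs the 100 CIFAR-100 classes. The teacher's parameters are frozen when it is loaded. Section 5 temporarily unfreezes them to fine-tune the teacher on CIFAR-100 and then freezes them again, so that during distillation the teacher only provides soft targets and is not updated.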
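
Both models were pre-trained on ImageNet at much larger resolutions than CIFAR-100's 32×32. As a rough check of what that means inside the network, the arithmetic below traces the spatial size through the stride-2 layers. It assumes torchvision's EfficientNetV2-S layout (a 3×3 stride-2 stem, then six stages whose first blocks use 3×3 kernels with padding 1 and strides 1, 2, 2, 2, 1, 2); the helper names are hypothetical.

```python
def conv_output_size(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

# Assumed layout (torchvision EfficientNetV2-S): a 3x3 stride-2 stem, then six
# stages whose first blocks use 3x3 kernels with these strides
STEM_STRIDE = 2
STAGE_STRIDES = [1, 2, 2, 2, 1, 2]

def trace_feature_map_sizes(input_size):
    """Returns the spatial size after the stem and after each stage."""
    size = conv_output_size(input_size, stride=STEM_STRIDE)
    sizes = [size]
    for stride in STAGE_STRIDES:
        size = conv_output_size(size, stride=stride)
        sizes.append(size)
    return sizes

for resolution in (32, 224):
    print(resolution, "->", trace_feature_map_sizes(resolution))
```

Under this assumption the total downsampling factor is 32, so a 32×32 CIFAR image reaches the final pooling layer as a 1×1 feature map, versus 7×7 for a 224×224 input.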


In [None]:
# Teacher Model - Paper Section 3.2: "EfficientNetV2-L... approximately 120 million parameters"
print("Loading Teacher Model (EfficientNetV2-L)...")
teacher_model = efficientnet_v2_l(weights=EfficientNet_V2_L_Weights.IMAGENET1K_V1)
print("Teacher model loaded (ImageNet weights)")

# Replace the ImageNet head with a 100-class head, then freeze all parameters
# (they are unfrozen for fine-tuning in Section 5, then frozen again)
teacher_model.classifier[1] = nn.Linear(teacher_model.classifier[1].in_features, NUM_CLASSES)
for param in teacher_model.parameters():
    param.requires_grad = False
teacher_model = teacher_model.to(device)
teacher_model.eval()

# Student Model - Paper Section 3.2: "EfficientNetV2-S... approximately 22 million parameters"
print("Loading Student Model (EfficientNetV2-S)...")
student_model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
print("Student model loaded (ImageNet weights)")

student_model.classifier[1] = nn.Linear(student_model.classifier[1].in_features, NUM_CLASSES)
student_model = student_model.to(device)

print("\nTeacher and Student models defined and moved to device")
print("Teacher model parameters are frozen (will be unfrozen for fine-tuning)")
print(f"Note: Paper Section 3.2: Teacher ~120M params, Student ~22M params")

## 4. Helper Functions for Training and Evaluation

We define helper functions for the training loop and for evaluating the model's accuracy on the test set. These functions must be defined BEFORE we start training any models.
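
The best-model and early-stopping bookkeeping inside `train_model_unified` can be replayed on a plain list of validation accuracies, which helps when reasoning about when a run will stop. The pure-Python sketch below mirrors that logic (a strict improvement resets the counter; otherwise the counter grows until it reaches `patience`). The function name and the accuracy values are made up for illustration.

```python
def replay_early_stopping(val_accuracies, patience):
    """Replays best-model selection and patience-based early stopping.

    Returns (best_epoch, best_acc, stopped_epoch), with 1-based epochs.
    """
    best_acc = 0.0
    best_epoch = 0
    epochs_no_improve = 0
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best_acc:  # strict improvement: new best model, reset the counter
            best_acc = acc
            best_epoch = epoch
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            return best_epoch, best_acc, epoch
    return best_epoch, best_acc, len(val_accuracies)

print(replay_early_stopping([50.0, 60.0, 65.0, 64.0, 65.0, 63.0, 62.0], patience=3))
# (3, 65.0, 6)
```

Note that matching the best accuracy (65.0 at epoch 5) counts as no improvement, so ties do not reset the patience counter.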


In [None]:
from torch.amp import autocast, GradScaler
import glob

def evaluate_model(model, dataloader):
    """
    Evaluates model accuracy on a given dataset.
    Paper Section 3.6: "Top-1 Accuracy (%): The primary metric for classification performance"
    """
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    return accuracy

def evaluate_model_with_loss(model, dataloader, criterion):
    """
    Evaluates model accuracy and loss on a given dataset.
    Paper Section 3.6: "Top-1 Accuracy (%) and Validation Loss"
    """
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    avg_loss = running_loss / len(dataloader)
    return accuracy, avg_loss

def distillation_loss(student_logits, teacher_logits, hard_labels, temp, alpha):
    """
    Combines the soft-target (distillation) loss and the hard-label loss:
    alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student, labels).
    """
    # Temperature-softened distributions (kl_div expects log-probabilities as input)
    soft_teacher_outputs = nn.functional.softmax(teacher_logits / temp, dim=1)
    soft_student_outputs = nn.functional.log_softmax(student_logits / temp, dim=1)
    
    # KL divergence summed over classes and averaged over the batch ('batchmean'),
    # scaled by T^2 so the soft-target gradients keep a comparable magnitude as T changes
    distill_loss = nn.functional.kl_div(
        soft_student_outputs, 
        soft_teacher_outputs, 
        reduction='batchmean'
    ) * (temp * temp)
    
    # Hard target cross-entropy loss
    ce_loss = nn.functional.cross_entropy(student_logits, hard_labels)
    
    # Combined loss
    total_loss = alpha * distill_loss + (1.0 - alpha) * ce_loss
    return total_loss

def cleanup_old_checkpoints(model_name, keep_last_n=3):
    """
    Removes old checkpoint files, keeping only the N most recent ones.
    This prevents disk space issues during long training runs.
    """
    def checkpoint_epoch(path):
        # ".../<model_name>_checkpoint_epoch120.pth" -> 120
        return int(os.path.splitext(path)[0].rsplit("epoch", 1)[1])

    checkpoint_pattern = f"{MODEL_DIR}/{model_name}_checkpoint_epoch*.pth"
    # Sort numerically by epoch: a plain string sort puts "epoch100" before "epoch40"
    checkpoints = sorted(glob.glob(checkpoint_pattern), key=checkpoint_epoch)
    
    if len(checkpoints) > keep_last_n:
        for old_checkpoint in checkpoints[:-keep_last_n]:
            try:
                os.remove(old_checkpoint)
                print(f"  -> Cleaned up old checkpoint: {os.path.basename(old_checkpoint)}")
            except Exception as e:
                print(f"  -> Warning: Could not remove {old_checkpoint}: {e}")

def train_model_unified(model, dataloader, optimizer, scheduler, num_epochs, model_name, 
                       criterion=None, teacher_model=None, temp=None, alpha=None, patience=20):
    """
    Unified training loop supporting:
    1. Standard Training (if teacher_model is None)
    2. Knowledge Distillation (if teacher_model is provided)
    3. Automatic Mixed Precision (AMP) for faster training
    4. Early Stopping and Checkpointing
    5. Automatic checkpoint cleanup to prevent disk space issues
    """
    if teacher_model is not None:
        teacher_model.eval()
        print(f"Starting Knowledge Distillation with T={temp}, α={alpha}")
    else:
        print("Starting Standard Training")

    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())
    history = {'train_loss': [], 'train_accuracy': [], 'val_loss': [], 'val_accuracy': []}
    epochs_no_improve = 0
    
    # AMP gradient scaler from the device-generic torch.amp API
    # (torch.cuda.amp.GradScaler is deprecated in recent PyTorch); disabled on CPU
    scaler = GradScaler('cuda', enabled=torch.cuda.is_available())
    
    # Validation criterion is always CrossEntropy
    val_criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        # Re-enter training mode every epoch: the validation pass at the end of the
        # previous epoch switched the model to eval mode (BatchNorm, dropout)
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for i, data in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Mixed Precision Context (Updated for PyTorch 2.4+)
            with autocast('cuda', enabled=torch.cuda.is_available()):
                if teacher_model is not None:
                    with torch.no_grad():
                        teacher_outputs = teacher_model(inputs)
                    student_outputs = model(inputs)
                    loss = distillation_loss(student_outputs, teacher_outputs, labels, temp, alpha)
                else:
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    student_outputs = outputs

            # Scaled Backward Pass
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            _, predicted = torch.max(student_outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        scheduler.step()
        
        train_loss = running_loss / len(dataloader)
        train_acc = 100 * correct / total
        val_acc, val_loss = evaluate_model_with_loss(model, testloader, val_criterion)
        
        history['train_loss'].append(train_loss)
        history['train_accuracy'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_acc)
        
        current_lr = scheduler.get_last_lr()[0]
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, LR: {current_lr:.6f}")

        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), f"{MODEL_DIR}/{model_name}.pth")
            epochs_no_improve = 0
            print(f"  -> New best model saved!")
        else:
            epochs_no_improve += 1
            
        if epochs_no_improve >= patience:
            print(f"[Early stopping triggered after {epoch+1} epochs]")
            break
        
        # Save checkpoint every CHECKPOINT_FREQUENCY epochs
        if (epoch + 1) % CHECKPOINT_FREQUENCY == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_acc': best_acc,
                'history': history,
                'epochs_no_improve': epochs_no_improve
            }
            checkpoint_path = f"{MODEL_DIR}/{model_name}_checkpoint_epoch{epoch+1}.pth"
            
            try:
                torch.save(checkpoint, checkpoint_path)
                print(f"  -> Checkpoint saved: {checkpoint_path}")
                
                # Clean up old checkpoints to save disk space
                cleanup_old_checkpoints(model_name, keep_last_n=3)
            except Exception as e:
                print(f"  -> Warning: Checkpoint save failed: {e}")
                print(f"  -> Continuing training (best model is still safe)")
        
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            
    print(f"Training complete. Best accuracy: {best_acc:.2f}%")
    print(f"Best model saved to: {MODEL_DIR}/{model_name}.pth")
    model.load_state_dict(best_model_wts)
    return model, history

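def calculate_flops(model):
    """
    Calculates the operation count and number of parameters for a model at 32x32 input.
    Note: thop.profile counts multiply-accumulate operations (MACs), so the values
    reported here as "FLOPs" are MAC counts (FLOPs are often approximated as 2 x MACs).
    """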
    input_tensor = torch.randn(1, 3, 32, 32).to(device)
    flops, params = profile(model, inputs=(input_tensor, ), verbose=False)
    model_flops = flops / 1e9
    model_params = params / 1e6
    print(f"Model FLOPs: {model_flops:.2f} GFLOPs")
    print(f"Model Parameters: {model_params:.2f} M")
    return model_flops, model_params

def measure_inference_time(model, num_iterations=1000, warmup=100):
    """
    Measures average inference time per image.
    """
    model.eval()
    input_tensor = torch.randn(1, 3, 32, 32).to(device)
    
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(input_tensor)
    
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    start_time = time.time()
    with torch.no_grad():
        for _ in range(num_iterations):
            _ = model(input_tensor)
    
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    end_time = time.time()
    avg_time_ms = (end_time - start_time) / num_iterations * 1000
    
    print(f"Average inference time: {avg_time_ms:.4f} ms/image")
    return avg_time_ms

def calculate_model_size(model, model_name):
    """
    Calculates the size of the saved model file in MB.
    """
    temp_path = f"{MODEL_DIR}/{model_name}_temp.pth"
    torch.save(model.state_dict(), temp_path)
    size_mb = os.path.getsize(temp_path) / (1024 * 1024)
    os.remove(temp_path)
    print(f"Model size: {size_mb:.2f} MB")
    return size_mb

print("Helper functions loaded successfully (AMP-enabled training loop, distillation loss, metrics)")

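Checkpoint files are named `<model_name>_checkpoint_epoch<N>.pth` without zero-padding, so a plain string sort misorders them once three-digit epochs appear, and a "keep the last three" rule applied to that order would delete the newest checkpoint. The sketch below contrasts a string sort with a numeric sort keyed on the epoch; `checkpoint_epoch` is a hypothetical helper name.

```python
import re

def checkpoint_epoch(path):
    """Extracts N from names like 'teacher_model_checkpoint_epoch120.pth'."""
    match = re.search(r"epoch(\d+)\.pth$", path)
    if match is None:
        raise ValueError(f"Not a checkpoint filename: {path}")
    return int(match.group(1))

names = [f"baseline_student_checkpoint_epoch{e}.pth" for e in (40, 60, 80, 100)]

lexicographic = sorted(names)                    # compares characters: "1" < "4"
numeric = sorted(names, key=checkpoint_epoch)    # compares epoch numbers

print([checkpoint_epoch(n) for n in lexicographic])  # [100, 40, 60, 80]
print([checkpoint_epoch(n) for n in numeric])        # [40, 60, 80, 100]
```

Keeping the last three entries of the string-sorted list would retain epochs 40, 60 and 80 and remove epoch 100; the numeric order keeps 60, 80 and 100 as intended.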
## 5. Teacher Model Fine-tuning on CIFAR-100

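Before using the teacher model for knowledge distillation, we fine-tune it on the CIFAR-100 dataset. Its 100-class classifier head is newly initialized, so without fine-tuning the teacher's outputs would carry no information about CIFAR-100. Fine-tuning also adapts the ImageNet features to the target dataset, so that the teacher provides high-quality soft labels rather than relying solely on ImageNet pre-trained features.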
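
All three models are trained with `AdamW` and `CosineAnnealingLR(T_max=NUM_EPOCHS)`. For reference, PyTorch documents the closed form of that schedule as lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2, with `eta_min = 0` by default. The sketch evaluates that formula for the fine-tuning rate of 1e-4 over 200 epochs; it is a standalone illustration, not the PyTorch scheduler itself, and the function name is hypothetical.

```python
import math

def cosine_annealing_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing learning rate after `epoch` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

BASE_LR = 1e-4  # LEARNING_RATE_FINETUNE in this notebook
T_MAX = 200     # NUM_EPOCHS, passed as T_max to the scheduler

schedule = [cosine_annealing_lr(e, BASE_LR, T_MAX) for e in range(T_MAX + 1)]
for epoch in (0, 50, 100, 150, 200):
    print(f"after {epoch:3d} scheduler steps: lr = {schedule[epoch]:.2e}")
```

Since `scheduler.step()` runs once per epoch in `train_model_unified`, the epoch with 0-based index t trains at `schedule[t]`. If early stopping ends a run before epoch 200, the learning rate never reaches its minimum.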


In [None]:
print("--- Fine-tuning Teacher Model on CIFAR-100 ---")

# Check if teacher model is already trained
teacher_checkpoint_path = f"{MODEL_DIR}/teacher_model.pth"
teacher_already_trained = os.path.exists(teacher_checkpoint_path)

if teacher_already_trained:
    print(f"\nFound existing teacher model at: {teacher_checkpoint_path}")
    print("Loading pre-trained teacher model...")
    teacher_model.load_state_dict(torch.load(teacher_checkpoint_path, map_location=device))
    teacher_model.eval()
    
    teacher_accuracy = evaluate_model(teacher_model, testloader)
    print(f"Loaded Teacher Model Accuracy: {teacher_accuracy:.2f}%")
    
    print("\nCalculating teacher model metrics...")
    teacher_flops, teacher_params = calculate_flops(teacher_model)
    teacher_time = measure_inference_time(teacher_model)
    teacher_size = calculate_model_size(teacher_model, "teacher_model")
    
    teacher_history = {'train_loss': [], 'train_accuracy': [], 'val_loss': [], 'val_accuracy': []}
    print("\nSkipping teacher training (already trained).")
    
else:
    print("\nNo existing teacher model found. Starting fine-tuning from ImageNet weights...")
    
    # Unfreeze teacher model for fine-tuning
    teacher_model.train()
    for param in teacher_model.parameters():
        param.requires_grad = True

    criterion_teacher = nn.CrossEntropyLoss()
    optimizer_teacher = optim.AdamW(teacher_model.parameters(), lr=LEARNING_RATE_FINETUNE, weight_decay=WEIGHT_DECAY)
    scheduler_teacher = optim.lr_scheduler.CosineAnnealingLR(optimizer_teacher, T_max=NUM_EPOCHS)

    print(f"Training teacher for {NUM_EPOCHS} epochs...")
    print(f"Using fine-tuning learning rate: {LEARNING_RATE_FINETUNE}")
    
    trained_teacher, teacher_history = train_model_unified(
        model=teacher_model, 
        dataloader=trainloader, 
        criterion=criterion_teacher, 
        optimizer=optimizer_teacher,
        scheduler=scheduler_teacher,
        num_epochs=NUM_EPOCHS, 
        model_name="teacher_model",
        patience=PATIENCE
    )

    # Freeze the teacher for distillation
    for param in teacher_model.parameters():
        param.requires_grad = False
    teacher_model.eval()

    print("\n--- Teacher Model Evaluation ---")
    teacher_accuracy = evaluate_model(teacher_model, testloader)
    print(f"Teacher Model Accuracy on CIFAR-100: {teacher_accuracy:.2f}%")

    print("\nCalculating teacher model metrics...")
    teacher_flops, teacher_params = calculate_flops(teacher_model)
    teacher_time = measure_inference_time(teacher_model)
    teacher_size = calculate_model_size(teacher_model, "teacher_model")

print("\nTeacher model is now frozen and ready for knowledge distillation.")

## 6. Baseline Student Model Training

First, we train the student model using only the standard cross-entropy loss. This provides the baseline performance that we will compare against the distilled model.
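
The baseline objective is the mean cross-entropy between the student's logits z and the true labels y, CE = -log softmax(z)[y] averaged over the batch, which is what `nn.CrossEntropyLoss()` computes with its default `reduction='mean'`. A small NumPy version (hypothetical helper, illustrative values) makes the computation explicit.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean cross-entropy over a batch: -log softmax(logits)[label]."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())

# Two samples, three classes (illustrative values)
logits = [[2.0, 0.5, -1.0],
          [0.0, 0.0, 0.0]]
labels = np.array([0, 2])
print(cross_entropy(logits, labels))
```

Uniform logits over C classes give a loss of log C. For CIFAR-100, a near-uniform initial prediction therefore gives a training loss close to ln 100 ≈ 4.61, a handy sanity check for the first epoch.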


In [None]:
# Check if baseline student is already trained
baseline_checkpoint_path = f"{MODEL_DIR}/baseline_student.pth"
baseline_already_trained = os.path.exists(baseline_checkpoint_path)

if baseline_already_trained:
    print("--- Found Existing Baseline Student Model ---")
    print(f"Loading from: {baseline_checkpoint_path}")
    
    baseline_student_model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
    baseline_student_model.classifier[1] = nn.Linear(baseline_student_model.classifier[1].in_features, NUM_CLASSES)
    baseline_student_model = baseline_student_model.to(device)
    baseline_student_model.load_state_dict(torch.load(baseline_checkpoint_path, map_location=device))
    baseline_student_model.eval()
    
    baseline_accuracy = evaluate_model(baseline_student_model, testloader)
    print(f"Loaded Baseline Student Accuracy: {baseline_accuracy:.2f}%")
    
    print("\nCalculating model metrics...")
    baseline_flops, baseline_params = calculate_flops(baseline_student_model)
    baseline_time = measure_inference_time(baseline_student_model)
    baseline_size = calculate_model_size(baseline_student_model, "baseline_student")
    
    baseline_history = {'train_loss': [], 'train_accuracy': [], 'val_loss': [], 'val_accuracy': []}
    trained_baseline_model = baseline_student_model
    
    print("\nSkipping baseline training (already trained).")
    
else:
    print("--- Starting Baseline Student Model Training ---")
    print("NOTE: Using fine-tuning learning rate for pre-trained model")
    
    baseline_student_model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
    baseline_student_model.classifier[1] = nn.Linear(baseline_student_model.classifier[1].in_features, NUM_CLASSES)
    baseline_student_model = baseline_student_model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(baseline_student_model.parameters(), lr=LEARNING_RATE_FINETUNE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    print(f"Training for {NUM_EPOCHS} epochs with early stopping (patience={PATIENCE})")
    
    trained_baseline_model, baseline_history = train_model_unified(
        model=baseline_student_model, 
        dataloader=trainloader, 
        criterion=criterion, 
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=NUM_EPOCHS, 
        model_name="baseline_student",
        patience=PATIENCE
    )
    
    print("\n--- Baseline Model Evaluation ---")
    baseline_accuracy = evaluate_model(trained_baseline_model, testloader)
    print(f"Final Baseline Student Model Accuracy: {baseline_accuracy:.2f}%")

    print("\nCalculating model metrics...")
    baseline_flops, baseline_params = calculate_flops(trained_baseline_model)
    baseline_time = measure_inference_time(trained_baseline_model)
    baseline_size = calculate_model_size(trained_baseline_model, "baseline_student")

## 7. Knowledge Distillation Training

Now, we train the student model using the knowledge distillation loss. The total loss is a weighted combination (controlled by α) of the standard cross-entropy loss on the hard labels and the KL divergence between the teacher's and the student's temperature-softened output distributions.
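
Written out, the objective implemented by `distillation_loss` is $\mathcal{L}_{\text{total}} = \alpha\,T^2\,\mathrm{KL}\big(\mathrm{softmax}(z_t/T)\,\|\,\mathrm{softmax}(z_s/T)\big) + (1-\alpha)\,\mathrm{CE}(z_s, y)$, where $z_t$ and $z_s$ are the teacher and student logits and the KL term is averaged over the batch. The NumPy sketch below is a reference re-implementation for checking intuition, not the training code; the function names and random logits are hypothetical. Two properties worth confirming are that the KD term vanishes when student and teacher logits are equal and that α = 0 recovers plain cross-entropy.

```python
import numpy as np

def log_softmax(z):
    """Row-wise log-softmax, computed stably."""
    z = np.asarray(z, dtype=np.float64)
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def kd_loss(student_logits, teacher_logits, labels, temp, alpha):
    """Returns (total, kd_term, ce_term) for alpha * T^2 * KL + (1 - alpha) * CE."""
    log_p_s = log_softmax(np.asarray(student_logits) / temp)
    log_p_t = log_softmax(np.asarray(teacher_logits) / temp)
    p_t = np.exp(log_p_t)
    # KL(teacher || student) summed over classes, averaged over the batch, scaled by T^2
    kl = (p_t * (log_p_t - log_p_s)).sum(axis=1).mean() * temp ** 2
    # Hard-label cross-entropy on the unscaled student logits (T = 1)
    ce = -log_softmax(student_logits)[np.arange(len(labels)), labels].mean()
    return float(alpha * kl + (1.0 - alpha) * ce), float(kl), float(ce)

rng = np.random.default_rng(0)
student = rng.normal(size=(4, 100))   # 4 samples, 100 CIFAR-100 classes
teacher = rng.normal(size=(4, 100))
labels = rng.integers(0, 100, size=4)
total, kl, ce = kd_loss(student, teacher, labels, temp=2.0, alpha=0.5)
print(f"total={total:.4f}, kd={kl:.4f}, ce={ce:.4f}")
```

The T² factor compensates for the gradients of the softened term shrinking roughly as 1/T², which keeps the balance set by α meaningful when the temperature changes.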


In [None]:
# Knowledge Distillation Training
print("\n--- Starting Knowledge Distillation Training ---")

# Check if distilled model is already trained
distilled_checkpoint_path = f"{MODEL_DIR}/distilled_student.pth"
distilled_already_trained = os.path.exists(distilled_checkpoint_path)

if distilled_already_trained:
    print(f"Found existing distilled model at: {distilled_checkpoint_path}")
    print("Loading pre-trained distilled model...")
    
    distilled_student_model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
    distilled_student_model.classifier[1] = nn.Linear(distilled_student_model.classifier[1].in_features, NUM_CLASSES)
    distilled_student_model = distilled_student_model.to(device)
    distilled_student_model.load_state_dict(torch.load(distilled_checkpoint_path, map_location=device))
    distilled_student_model.eval()
    
    distilled_accuracy = evaluate_model(distilled_student_model, testloader)
    print(f"Loaded Distilled Student Accuracy: {distilled_accuracy:.2f}%")
    
    distilled_flops, distilled_params = calculate_flops(distilled_student_model)
    distilled_time = measure_inference_time(distilled_student_model)
    distilled_size = calculate_model_size(distilled_student_model, "distilled_student")
    
    distilled_history = {'train_loss': [], 'train_accuracy': [], 'val_loss': [], 'val_accuracy': []}
    print("\nSkipping distillation training (already trained).")
    
else:
    print("Starting distillation training (student initialized from ImageNet weights)...")
    
    # Create fresh student model for distillation
    distilled_student_model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
    distilled_student_model.classifier[1] = nn.Linear(distilled_student_model.classifier[1].in_features, NUM_CLASSES)
    distilled_student_model = distilled_student_model.to(device)

    # CRITICAL: Use LEARNING_RATE_FINETUNE for pre-trained models to prevent gradient explosion
    optimizer_kd = optim.AdamW(distilled_student_model.parameters(), lr=LEARNING_RATE_FINETUNE, weight_decay=WEIGHT_DECAY)
    scheduler_kd = optim.lr_scheduler.CosineAnnealingLR(optimizer_kd, T_max=NUM_EPOCHS)

    print(f"Training for {NUM_EPOCHS} epochs with T={TEMPERATURE}, α={ALPHA}")
    print(f"Using fine-tuning learning rate: {LEARNING_RATE_FINETUNE}")
    print(f"Paper Section 3.3: L_total = {ALPHA:g} · L_KD + {1 - ALPHA:g} · L_CE")
    
    trained_distilled_model, distilled_history = train_model_unified(
        model=distilled_student_model,
        dataloader=trainloader,
        optimizer=optimizer_kd,
        scheduler=scheduler_kd,
        num_epochs=NUM_EPOCHS,
        model_name="distilled_student",
        teacher_model=teacher_model,
        temp=TEMPERATURE,
        alpha=ALPHA,
        patience=PATIENCE
    )
    
    print("\n--- Distilled Model Evaluation ---")
    distilled_accuracy = evaluate_model(trained_distilled_model, testloader)
    print(f"Final Distilled Student Model Accuracy: {distilled_accuracy:.2f}%")

    print("\nCalculating model metrics (Paper Section 3.6)...")
    distilled_flops, distilled_params = calculate_flops(trained_distilled_model)
    distilled_time = measure_inference_time(trained_distilled_model)
    distilled_size = calculate_model_size(trained_distilled_model, "distilled_student")

print("\nKnowledge distillation training complete.")

In [None]:
# --- Final Summary Tables ---
print("\n" + "="*80)
print("EXPERIMENT RESULTS SUMMARY")
print("="*80)

print("\n--- Table 1: Baseline Performance ---")
print("| Model                     | Accuracy (%) | Size (MB) | Inference (ms) | FLOPs (G) | Params (M) |")
print("|---------------------------|--------------|-----------|----------------|-----------|------------|")
print(f"| Teacher (EfficientNetV2-L)| {teacher_accuracy:12.2f} | {teacher_size:9.2f} | {teacher_time:14.4f} | {teacher_flops:9.2f} | {teacher_params:10.2f} |")
print(f"| Baseline Student (Eff-S)  | {baseline_accuracy:12.2f} | {baseline_size:9.2f} | {baseline_time:14.4f} | {baseline_flops:9.2f} | {baseline_params:10.2f} |")

print("\n--- Table 2: Knowledge Distillation Performance Comparison ---")
print("| Model             | Accuracy (%) | Δ Accuracy | Size (MB) | Inference (ms) | FLOPs (G) |")
print("|-------------------|--------------|------------|-----------|----------------|-----------|")
print(f"| Baseline Student  | {baseline_accuracy:12.2f} |      -     | {baseline_size:9.2f} | {baseline_time:14.4f} | {baseline_flops:9.2f} |")
print(f"| Distilled Student | {distilled_accuracy:12.2f} | {distilled_accuracy - baseline_accuracy:+10.2f} | {distilled_size:9.2f} | {distilled_time:14.4f} | {distilled_flops:9.2f} |")

improvement = distilled_accuracy - baseline_accuracy
print(f"\n{'='*80}")
print(f"ACCURACY IMPROVEMENT WITH KNOWLEDGE DISTILLATION: {improvement:+.2f} percentage points")
print(f"{'='*80}\n")

# --- Plotting Training Curves ---
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

axes[0, 0].plot(baseline_history['train_accuracy'], label='Baseline Student', linewidth=2)
axes[0, 0].plot(distilled_history['train_accuracy'], label='Distilled Student', linewidth=2)
axes[0, 0].set_title('Training Accuracy vs. Epochs', fontsize=12, fontweight='bold')
axes[0, 0].set_xlabel('Epochs')
axes[0, 0].set_ylabel('Training Accuracy (%)')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

axes[0, 1].plot(baseline_history['val_accuracy'], label='Baseline Student', linewidth=2)
axes[0, 1].plot(distilled_history['val_accuracy'], label='Distilled Student', linewidth=2)
axes[0, 1].set_title('Validation Accuracy vs. Epochs', fontsize=12, fontweight='bold')
axes[0, 1].set_xlabel('Epochs')
axes[0, 1].set_ylabel('Validation Accuracy (%)')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

axes[1, 0].plot(baseline_history['train_loss'], label='Baseline Student', linewidth=2)
axes[1, 0].plot(distilled_history['train_loss'], label='Distilled Student', linewidth=2)
axes[1, 0].set_title('Training Loss vs. Epochs', fontsize=12, fontweight='bold')
axes[1, 0].set_xlabel('Epochs')
axes[1, 0].set_ylabel('Training Loss')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

axes[1, 1].plot(baseline_history['val_loss'], label='Baseline Student', linewidth=2)
axes[1, 1].plot(distilled_history['val_loss'], label='Distilled Student', linewidth=2)
axes[1, 1].set_title('Validation Loss vs. Epochs', fontsize=12, fontweight='bold')
axes[1, 1].set_xlabel('Epochs')
axes[1, 1].set_ylabel('Validation Loss')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.suptitle('Training History Comparison: Baseline vs. Distilled Student', fontsize=14, fontweight='bold')
plt.tight_layout(rect=[0, 0, 1, 0.97])
plt.savefig(f'{MODEL_DIR}/training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

# --- Bar Chart Comparison ---
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

models = ['Teacher\n(EfficientNetV2-L)', 'Baseline\nStudent', 'Distilled\nStudent']
accuracies = [teacher_accuracy, baseline_accuracy, distilled_accuracy]
sizes = [teacher_size, baseline_size, distilled_size]
times = [teacher_time, baseline_time, distilled_time]

axes[0].bar(models, accuracies, color=['#2ecc71', '#3498db', '#e74c3c'])
axes[0].set_ylabel('Accuracy (%)')
axes[0].set_title('Model Accuracy Comparison', fontweight='bold')
axes[0].grid(True, axis='y', alpha=0.3)

axes[1].bar(models, sizes, color=['#2ecc71', '#3498db', '#e74c3c'])
axes[1].set_ylabel('Model Size (MB)')
axes[1].set_title('Model Size Comparison', fontweight='bold')
axes[1].grid(True, axis='y', alpha=0.3)

axes[2].bar(models, times, color=['#2ecc71', '#3498db', '#e74c3c'])
axes[2].set_ylabel('Inference Time (ms/image)')
axes[2].set_title('Inference Speed Comparison', fontweight='bold')
axes[2].grid(True, axis='y', alpha=0.3)

plt.suptitle('Model Performance Comparison', fontsize=14, fontweight='bold')
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(f'{MODEL_DIR}/model_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\nFigures saved to {MODEL_DIR}/ directory")
if IS_KAGGLE:
    print("All outputs will be available in Kaggle's Output tab after the session completes.")

## 8. State-of-the-Art (SOTA) Comparison

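We compare our baseline and distilled student models against other published compact models on CIFAR-100 to place our results in context and gauge how our approach compares with existing methods.

**Note:** The reference results below are taken from published papers and are hard-coded in the cell. Verify them against the original sources and update them with more recent results where available. Published numbers may come from different training recipes and input resolutions, so treat the accuracy and FLOP comparisons as indicative rather than strictly like-for-like.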


In [None]:
# SOTA Comparison on CIFAR-100
import pandas as pd

sota_results = {
    'Model': [
        'ResNet-56', 'ResNet-110', 'MobileNetV2', 'EfficientNet-B0', 'ShuffleNetV2',
        'Our Baseline Student', 'Our Distilled Student'
    ],
    'Accuracy (%)': [
        72.49, 74.84, 74.45, 77.30, 73.50, baseline_accuracy, distilled_accuracy
    ],
    'FLOPs (G)': [
        0.13, 0.25, 0.30, 0.39, 0.15, baseline_flops, distilled_flops
    ],
    'Parameters (M)': [
        0.85, 1.73, 3.50, 5.30, 2.30, baseline_params, distilled_params
    ]
}

sota_df = pd.DataFrame(sota_results)

print("\n" + "="*80)
print("STATE-OF-THE-ART COMPARISON ON CIFAR-100")
print("="*80)
print("\n--- Table 4: SOTA Comparison ---")
print(sota_df.to_string(index=False))

sota_df['Efficiency Score'] = sota_df['Accuracy (%)'] / sota_df['FLOPs (G)']
print("\n--- Efficiency Ranking (Accuracy / FLOPs) ---")
print(sota_df[['Model', 'Efficiency Score']].sort_values('Efficiency Score', ascending=False).to_string(index=False))

# Visualizations
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

colors = ['#95a5a6' if 'Our' not in model else '#e74c3c' for model in sota_df['Model']]
axes[0].barh(sota_df['Model'], sota_df['Accuracy (%)'], color=colors)
axes[0].set_xlabel('Accuracy (%)', fontweight='bold')
axes[0].set_title('Accuracy Comparison', fontsize=12, fontweight='bold')
axes[0].grid(True, axis='x', alpha=0.3)

colors_scatter = ['red' if 'Our' in model else 'blue' for model in sota_df['Model']]
sizes = [200 if 'Our' in model else 100 for model in sota_df['Model']]
axes[1].scatter(sota_df['FLOPs (G)'], sota_df['Accuracy (%)'], c=colors_scatter, s=sizes, alpha=0.6)
for i, model in enumerate(sota_df['Model']):
    axes[1].annotate(model, (sota_df['FLOPs (G)'][i], sota_df['Accuracy (%)'][i]), 
                     fontsize=8, ha='right' if 'Our' in model else 'left')
axes[1].set_xlabel('FLOPs (GigaFLOPs)', fontweight='bold')
axes[1].set_ylabel('Accuracy (%)', fontweight='bold')
axes[1].set_title('Efficiency vs. Accuracy Trade-off', fontsize=12, fontweight='bold')
axes[1].grid(True, alpha=0.3)

axes[2].scatter(sota_df['Parameters (M)'], sota_df['Accuracy (%)'], c=colors_scatter, s=sizes, alpha=0.6)
for i, model in enumerate(sota_df['Model']):
    axes[2].annotate(model, (sota_df['Parameters (M)'][i], sota_df['Accuracy (%)'][i]),
                     fontsize=8, ha='right' if 'Our' in model else 'left')
axes[2].set_xlabel('Parameters (Millions)', fontweight='bold')
axes[2].set_ylabel('Accuracy (%)', fontweight='bold')
axes[2].set_title('Model Size vs. Accuracy', fontsize=12, fontweight='bold')
axes[2].grid(True, alpha=0.3)

plt.suptitle('SOTA Comparison on CIFAR-100', fontsize=14, fontweight='bold')
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(f'{MODEL_DIR}/sota_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

sota_df.to_csv(f'{MODEL_DIR}/sota_comparison.csv', index=False)
print(f"\nSOTA comparison saved to {MODEL_DIR}/")

## 9. Ablation Studies (Core Contribution)

This section is the core of our research contribution. We investigate how the two key knowledge-distillation hyperparameters, `TEMPERATURE` (T) and `ALPHA` (α), affect the student model's final test accuracy, using a full grid search over both. The results characterize how sensitive distillation into EfficientNetV2-S is to each hyperparameter.

**Note:** The full grid is 20 separate training runs (4 temperatures × 5 alphas), so these studies are time-consuming. For a full analysis, run each experiment for the full number of epochs; for a quick check, reduce `NUM_EPOCHS_ABLATION`.
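Because the ablation loop below keeps its results only in memory until the whole grid finishes, a session that ends mid-grid loses every completed result. One way to make the grid resumable is to persist results to JSON after each run and skip finished (T, α) pairs on restart. The sketch below assumes this approach; the helper names `load_ablation_results`, `save_ablation_results`, and `pending_runs`, and the JSON file location, are hypothetical and not part of this notebook.

```python
import json
from pathlib import Path
from typing import Dict, Iterator, List, Tuple

# Hypothetical layout: {temperature: {alpha: accuracy_percent}}
Results = Dict[float, Dict[float, float]]


def load_ablation_results(path: Path) -> Results:
    """Load previously saved ablation results, or return {} if none exist yet."""
    if not path.exists():
        return {}
    raw = json.loads(path.read_text())
    # JSON object keys are always strings, so convert them back to numbers.
    return {float(t): {float(a): acc for a, acc in by_alpha.items()}
            for t, by_alpha in raw.items()}


def save_ablation_results(results: Results, path: Path) -> None:
    """Write results to a temporary file, then rename it over the target.

    Renaming reduces the chance that an interrupted session leaves a
    half-written JSON file behind.
    """
    serializable = {str(t): {str(a): acc for a, acc in by_alpha.items()}
                    for t, by_alpha in results.items()}
    tmp = path.with_suffix(".tmp")
    tmp.write_text(json.dumps(serializable, indent=2))
    tmp.replace(path)


def pending_runs(results: Results, temperatures: List[float],
                 alphas: List[float]) -> Iterator[Tuple[float, float]]:
    """Yield the (T, alpha) pairs that have no stored accuracy yet."""
    for t in temperatures:
        for a in alphas:
            # int and float keys compare and hash equal (1 == 1.0), so lookups work either way
            if a not in results.get(t, {}):
                yield t, a
```

In the ablation loop, one would iterate over `pending_runs(results, temperatures_to_test, alphas_to_test)` instead of the nested `for` loops, then record each accuracy with `results.setdefault(temp, {})[alpha] = accuracy` and call `save_ablation_results(...)` immediately after every evaluation.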


### Justification for Hyperparameter Choices in Ablation Study

In our ablation study, we select a range of values for `TEMPERATURE` and `ALPHA` based on established practices and the specific roles these hyperparameters play.

- **Temperature (T):** T divides the logits before the softmax, so larger values flatten the teacher's output distribution and expose the relative probabilities it assigns to incorrect classes (its "dark knowledge"). `T=1` is the standard softmax with no softening. We evaluate `T ∈ {1, 2, 4, 8}`: `T=1` serves as an unsoftened reference, `T=4` is a value commonly used in the literature, and `T=2` and `T=8` probe the model's sensitivity to milder and more aggressive softening.

- **Alpha (α):** This parameter balances the distillation loss (learning from the teacher) against the cross-entropy loss (learning from the ground-truth labels). The total loss is `α * L_KD + (1 - α) * L_CE`, so a higher `α` places more trust in the teacher. We evaluate `α ∈ {0.1, 0.3, 0.5, 0.7, 0.9}`, which sweeps from relying mostly on the hard labels (`α=0.1`), through an equal balance (`α=0.5`), to relying mostly on the teacher (`α=0.9`). This shows how much the student should rely on the teacher versus the ground truth for this specific task.
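The roles of T and α described above can be sketched in PyTorch as follows. This is a minimal illustration assuming the standard formulation from Hinton et al. (2015), in which the KL term is multiplied by T² so its gradient scale stays comparable across temperatures; the notebook's own `train_model_unified` and `train_distillation` functions are defined elsewhere and may differ (for example, in whether they apply the T² factor). The helper names `soften` and `kd_loss` are hypothetical.

```python
import torch
import torch.nn.functional as F


def soften(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Temperature-scaled softmax: a larger T gives a flatter distribution."""
    return F.softmax(logits / temperature, dim=-1)


def kd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor,
            labels: torch.Tensor, temperature: float = 4.0,
            alpha: float = 0.7) -> torch.Tensor:
    """Compute alpha * L_KD + (1 - alpha) * L_CE for a batch of logits."""
    # L_KD: KL(teacher || student) on temperature-softened distributions.
    # F.kl_div expects log-probabilities as input and probabilities as target.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # Assumption: scale by T**2 (Hinton et al., 2015) so the soft-target
    # gradients keep roughly the same magnitude as T changes.
    l_kd = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2
    # L_CE: standard cross-entropy against the hard ground-truth labels (T = 1).
    l_ce = F.cross_entropy(student_logits, labels)
    return alpha * l_kd + (1.0 - alpha) * l_ce


# Toy teacher logits for one sample over 4 classes: watch the distribution flatten as T grows.
teacher_logits = torch.tensor([[4.0, 2.0, 1.0, 0.0]])
for T in [1, 2, 4, 8]:
    probs = soften(teacher_logits, T)[0]
    print(f"T={T}: " + ", ".join(f"{p:.3f}" for p in probs.tolist()))

# Random logits and labels for a batch of 8 samples over 100 classes (CIFAR-100 sized).
torch.manual_seed(0)
student = torch.randn(8, 100, requires_grad=True)
teacher = torch.randn(8, 100)  # in real training, computed under torch.no_grad()
labels = torch.randint(0, 100, (8,))
loss = kd_loss(student, teacher, labels, temperature=4.0, alpha=0.7)
loss.backward()  # gradients flow only into the student logits
print(f"total loss: {loss.item():.4f}")
```

Setting `alpha=0.0` reduces `kd_loss` to plain cross-entropy (the baseline student's objective), while `alpha=1.0` trains purely on the teacher's softened outputs; the ablation grid explores the values in between.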


In [None]:
# Parameters for ablation study - Aligned with Paper Section 3.3
# Paper: "For temperature, we evaluate T ∈ {1, 2, 4, 8}"
temperatures_to_test = [1, 2, 4, 8]
# Paper: "For the weighting factor, we evaluate α ∈ {0.1, 0.3, 0.5, 0.7, 0.9}"
alphas_to_test = [0.1, 0.3, 0.5, 0.7, 0.9]

ablation_results = {}

print("\n" + "="*80)
print("ABLATION STUDY: Hyperparameter Sensitivity Analysis")
print("="*80)
print("Paper Section 3.3: Systematic investigation of T and α")
print(f"Testing {len(temperatures_to_test)} temperatures × {len(alphas_to_test)} alphas = {len(temperatures_to_test) * len(alphas_to_test)} experiments")
print(f"Each experiment runs for {NUM_EPOCHS_ABLATION} epochs")

if IS_KAGGLE:
    print(f"\n[KAGGLE NOTE: Sessions are limited to ~9 hours; runtime per experiment scales with NUM_EPOCHS_ABLATION (currently {NUM_EPOCHS_ABLATION})]")
    print("TIP: Choose NUM_EPOCHS_ABLATION so that 2-3 experiments fit in one session")
    print(f"Total sessions needed: ~7-10 sessions for all {len(temperatures_to_test) * len(alphas_to_test)} experiments")

print("="*80 + "\n")

experiment_count = 0
total_experiments = len(temperatures_to_test) * len(alphas_to_test)

for temp in temperatures_to_test:
    ablation_results[temp] = {}
    for alpha in alphas_to_test:
        experiment_count += 1
        print(f"\n{'='*80}")
        print(f"EXPERIMENT {experiment_count}/{total_experiments}: Temperature={temp}, Alpha={alpha}")
        print(f"{'='*80}")
        
        # Create fresh student model for each experiment
        student_model_abl = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        student_model_abl.classifier[1] = nn.Linear(student_model_abl.classifier[1].in_features, NUM_CLASSES)
        student_model_abl = student_model_abl.to(device)
        
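        # Use the same fine-tuning LR as the main distillation run so ablation results are comparable
        optimizer_abl = optim.AdamW(student_model_abl.parameters(), lr=LEARNING_RATE_FINETUNE, weight_decay=WEIGHT_DECAY)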
        scheduler_abl = optim.lr_scheduler.CosineAnnealingLR(optimizer_abl, T_max=NUM_EPOCHS_ABLATION)
        
        # Run distillation training with specific T and α
        trained_model_abl, _ = train_distillation(
            student_model_abl, teacher_model, trainloader, 
            optimizer_abl, scheduler_abl, temp=temp, alpha=alpha, 
            num_epochs=NUM_EPOCHS_ABLATION, model_name=f"abl_T{temp}_A{alpha}", patience=PATIENCE
        )
        
        # Evaluate and store result
        accuracy = evaluate_model(trained_model_abl, testloader)
        ablation_results[temp][alpha] = accuracy
        print(f"\nResult for T={temp}, α={alpha}: Accuracy = {accuracy:.2f}%")
        print(f"Progress: {experiment_count}/{total_experiments} experiments completed")

print("\n" + "="*80)
print("ABLATION STUDY COMPLETED")
print("="*80)

# Convert results to DataFrame for analysis
ablation_df = pd.DataFrame(ablation_results)
ablation_df.index.name = 'Alpha (α)'
ablation_df.columns.name = 'Temperature (T)'

print("\n--- Table 3: Ablation Study Results (Accuracy %) ---")
print("Paper Section 3.3: Grid search over all T and α combinations")
print(ablation_df.to_string())

# Find best hyperparameters
best_temp, best_alpha, best_acc = None, None, 0
for temp in temperatures_to_test:
    for alpha in alphas_to_test:
        if ablation_results[temp][alpha] > best_acc:
            best_acc = ablation_results[temp][alpha]
            best_temp, best_alpha = temp, alpha

print(f"\n{'='*80}")
print("BEST HYPERPARAMETERS FOUND:")
print(f"  Temperature (T): {best_temp}")
print(f"  Alpha (α): {best_alpha}")
print(f"  Best Accuracy: {best_acc:.2f}%")
print(f"{'='*80}\n")

# Save results for paper
ablation_df.to_csv(f'{MODEL_DIR}/ablation_results.csv')
print(f"Ablation results saved to {MODEL_DIR}/ablation_results.csv")
print("Note: Use this data for Paper Section 4 (Experiments and Results)")

In [None]:
# --- Visualize Ablation Study Results ---

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 1. Heatmap
sns.heatmap(ablation_df, annot=True, fmt=".2f", cmap="RdYlGn", ax=axes[0], cbar_kws={'label': 'Accuracy (%)'})
axes[0].set_xlabel("Temperature (T)", fontweight='bold')
axes[0].set_ylabel("Alpha (α)", fontweight='bold')
axes[0].set_title("Ablation Study Heatmap", fontsize=12, fontweight='bold')

# 2. Temperature Sensitivity (averaged over alphas)
temp_avg = ablation_df.mean(axis=0)
temp_std = ablation_df.std(axis=0)
axes[1].errorbar(temperatures_to_test, temp_avg.values, yerr=temp_std.values, 
                 marker='o', linewidth=2, markersize=8, capsize=5, capthick=2)
axes[1].set_xlabel('Temperature (T)', fontweight='bold')
axes[1].set_ylabel('Average Accuracy (%)', fontweight='bold')
axes[1].set_title('Temperature Sensitivity Analysis', fontsize=12, fontweight='bold')
axes[1].grid(True, alpha=0.3)
axes[1].set_xticks(temperatures_to_test)

# 3. Alpha Sensitivity (averaged over temperatures)
alpha_avg = ablation_df.mean(axis=1)
alpha_std = ablation_df.std(axis=1)
axes[2].errorbar(alphas_to_test, alpha_avg.values, yerr=alpha_std.values,
                 marker='o', linewidth=2, markersize=8, capsize=5, capthick=2)
axes[2].set_xlabel('Alpha (α)', fontweight='bold')
axes[2].set_ylabel('Average Accuracy (%)', fontweight='bold')
axes[2].set_title('Alpha Sensitivity Analysis', fontsize=12, fontweight='bold')
axes[2].grid(True, alpha=0.3)
axes[2].set_xticks(alphas_to_test)

plt.suptitle('Ablation Study: Hyperparameter Sensitivity Analysis', fontsize=14, fontweight='bold')
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(f'{MODEL_DIR}/ablation_study.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\nAblation study visualization saved to {MODEL_DIR}/ablation_study.png")

# Additional analysis: Show top 5 configurations
print("\n--- Top 5 Hyperparameter Configurations ---")
flat_results = []
for temp in temperatures_to_test:
    for alpha in alphas_to_test:
        flat_results.append({
            'Temperature': temp,
            'Alpha': alpha,
            'Accuracy': ablation_results[temp][alpha]
        })

top_configs = pd.DataFrame(flat_results).sort_values('Accuracy', ascending=False).head(5)
print(top_configs.to_string(index=False))

# Statistical analysis
print("\n--- Statistical Summary ---")
print(f"Mean Accuracy: {ablation_df.values.mean():.2f}%")
print(f"Std Deviation: {ablation_df.values.std(ddof=1):.2f}%")  # sample std, matching pandas .std() used in the sensitivity plots
print(f"Min Accuracy: {ablation_df.values.min():.2f}%")
print(f"Max Accuracy: {ablation_df.values.max():.2f}%")
print(f"Range: {ablation_df.values.max() - ablation_df.values.min():.2f} percentage points")