# Workshop 3.1: Knowledge Distillation for Efficient Inference Hands-On Practice

## Learning Objectives
By the end of this notebook, you will:
- Understand the concept of knowledge distillation and the teacher-student paradigm
- Learn how to transfer knowledge from a large model to a smaller one
- Implement temperature-based distillation with PyTorch
- Compare student model performance with and without distillation
- Visualize the knowledge transfer process

## What is Knowledge Distillation?
Knowledge distillation is a model compression technique in which a compact 'student' network is trained to approximate the function learned by a larger, more complex 'teacher' model. Rather than relying solely on hard target labels, the student is supervised using the soft probability distributions (soft targets) produced by the teacher. These soft targets encode richer information about class similarities and inter-class relationships, thereby facilitating more effective knowledge transfer and improved generalization in the student model.

Paper: Hinton, Vinyals & Dean (2015), "Distilling the Knowledge in a Neural Network": https://arxiv.org/pdf/1503.02531v1

In [None]:
# Import required libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torchvision import models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import time
import copy

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed both random number generators used in this notebook (torch first,
# then NumPy) so that runs are repeatable.
SEED = 42
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(SEED)

## Step 1: Load and Prepare CIFAR-10 Dataset

In [None]:
# Both pipelines finish with the same per-channel normalization, so build it once.
# NOTE(review): (0.2023, 0.1994, 0.2010) is the widely copied CIFAR-10 "std"
# triple; confirm it is the intended statistic for these ImageNet-pretrained backbones.
cifar10_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training augmentation: random 32x32 crop from a 4-pixel padded image, plus a horizontal flip.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    cifar10_normalize,
])

# Evaluation: no augmentation, just tensor conversion and normalization.
transform_test = transforms.Compose([transforms.ToTensor(), cifar10_normalize])

# Load CIFAR-10 (downloaded into ./data on first use)
data_root = './data'
trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# CIFAR-10 class names, in label-index order
classes = tuple('plane car bird cat deer dog frog horse ship truck'.split())

for caption, value in (("Training samples", len(trainset)),
                       ("Test samples", len(testset)),
                       ("Classes", classes)):
    print(f"{caption}: {value}")

## Step 2: Define Teacher and Student Models using MobileNetV2

In [None]:
def create_teacher_mobilenetv2(num_classes=10, width_mult=1.4):
    """
    Create the MobileNetV2 teacher adapted to 32x32 CIFAR-10 images.

    Loads ImageNet-pretrained MobileNetV2 and swaps the stem convolution for a
    stride-1 version, so small images are not downsampled immediately. The
    pretrained stem filters are kept. The classifier is replaced with a
    two-layer head (``last_channel -> 512 -> num_classes``) for extra capacity.

    Args:
        num_classes: Number of output logits.
        width_mult: Currently unused; kept only so existing calls keep working.
            The backbone is the standard-width pretrained MobileNetV2, not a
            widened one.

    Returns:
        The modified torchvision MobileNetV2 module.
    """
    # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=``; switch once the minimum supported version allows.
    model = models.mobilenet_v2(pretrained=True)

    # Replace the stem conv with a stride-1 version for 32x32 inputs. When the
    # shapes match, copy the pretrained filters across. Otherwise the
    # replacement would start from random weights and throw away what the
    # rest of the pretrained backbone was trained on top of.
    pretrained_stem = model.features[0][0]
    cifar_stem = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
    if cifar_stem.weight.shape == pretrained_stem.weight.shape:
        with torch.no_grad():
            cifar_stem.weight.copy_(pretrained_stem.weight)
    model.features[0][0] = cifar_stem

    # Add an additional intermediate layer in the classifier for more capacity
    model.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(model.last_channel, 512),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(512, num_classes)
    )

    return model

class StudentMobileNetV2(nn.Module):
    """
    Compact student network built from a truncated pretrained MobileNetV2.

    Keeps only the first ``NUM_FEATURE_BLOCKS`` modules of MobileNetV2's
    ``features`` stack, then applies global average pooling, dropout and a
    single linear layer.

    Args:
        num_classes: Number of output logits.
    """

    # How many leading modules of ``mobilenet_v2().features`` the student keeps.
    NUM_FEATURE_BLOCKS = 14

    def __init__(self, num_classes=10):
        super(StudentMobileNetV2, self).__init__()
        base_model = models.mobilenet_v2(pretrained=True)
        # Truncate features to make it smaller
        self.features = nn.Sequential(
            *list(base_model.features.children())[:self.NUM_FEATURE_BLOCKS]
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Infer the classifier's input width from the truncated stack's output.
        out_channels = self._probe_out_channels()
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(out_channels, num_classes)
        )

    def _probe_out_channels(self, image_size=32):
        """Return the channel count ``self.features`` produces for a 3-channel square input."""
        # Run the probe in eval mode. In train mode, BatchNorm layers update
        # their running mean/var from the batch even under torch.no_grad(), so a
        # random probe would silently perturb the pretrained statistics.
        was_training = self.features.training
        self.features.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.randn(1, 3, image_size, image_size)
                return self.features(dummy_input).shape[1]
        finally:
            self.features.train(was_training)

    def forward(self, x):
        """Map a batch of images to ``num_classes`` logits."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

def create_student_mobilenetv2(num_classes=10):
    """Build the compact student network (a ``StudentMobileNetV2``) with ``num_classes`` outputs."""
    student = StudentMobileNetV2(num_classes=num_classes)
    return student

# Instantiate both networks on the selected device (teacher first, then student).
teacher_model = create_teacher_mobilenetv2(num_classes=10).to(device)
student_model = create_student_mobilenetv2(num_classes=10).to(device)

# Total number of scalar parameters in each network (trainable or not)
teacher_params, student_params = (
    sum(p.numel() for p in net.parameters()) for net in (teacher_model, student_model)
)

summary_lines = [
    f"Teacher MobileNetV2 (Enhanced): {teacher_params:,} parameters",
    f"Student MobileNetV2 (Compact): {student_params:,} parameters",
    f"Size reduction: {(teacher_params - student_params) / teacher_params * 100:.1f}%",
    f"Teacher is {teacher_params / student_params:.1f}x larger than student",
]
print("\n".join(summary_lines))

## Step 3: Train the Teacher Model

In [None]:
def train_teacher_model(model, trainloader, testloader, epochs=10, learning_rate=0.001):
    """
    Train ``model`` with cross-entropy on ``trainloader`` and report per-epoch stats.

    Uses Adam (weight decay 1e-4) and halves the learning rate every 3 epochs.
    Batches are moved to the module-level ``device``.

    Args:
        model: Network to train in place.
        trainloader: Iterable of ``(inputs, labels)`` batches.
        testloader: Accepted for API compatibility; not used by this function.
        epochs: Number of passes over ``trainloader``.
        learning_rate: Initial Adam learning rate.

    Returns:
        ``(train_losses, train_accuracies)``: per-epoch mean batch loss and
        training accuracy in percent.
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    lr_schedule = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

    epoch_losses, epoch_accuracies = [], []

    print("Training teacher model...")

    for epoch_num in range(1, epochs + 1):
        model.train()
        loss_sum, n_correct, n_seen = 0.0, 0, 0

        progress = tqdm(trainloader, desc=f'Teacher Epoch {epoch_num}/{epochs}', leave=False)
        for step, (inputs, labels) in enumerate(progress, start=1):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            batch_loss = loss_fn(logits, labels)
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.item()
            n_seen += labels.size(0)
            n_correct += logits.argmax(dim=1).eq(labels).sum().item()

            progress.set_postfix({
                'Loss': f'{loss_sum / step:.3f}',
                'Acc': f'{100. * n_correct / n_seen:.2f}%',
            })

        lr_schedule.step()

        # Mean of per-batch losses, and accuracy over every sample seen this epoch
        mean_loss = loss_sum / len(trainloader)
        accuracy = 100. * n_correct / n_seen
        epoch_losses.append(mean_loss)
        epoch_accuracies.append(accuracy)

        print(f'Teacher Epoch {epoch_num}: Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%')

    return epoch_losses, epoch_accuracies

# Train the teacher for 6 epochs (below the function's default of 10) to keep the hands-on session short.
teacher_losses, teacher_accuracies = train_teacher_model(
    teacher_model,
    trainloader,
    testloader,
    epochs=6,
)

## Step 4: Evaluate Teacher Model

In [None]:
def evaluate_model(model, testloader, model_name="Model"):
    """
    Measure the test accuracy and wall-clock evaluation time of ``model``.

    The model is put in eval mode and run under ``torch.no_grad()``. Batches
    are moved to the module-level ``device``. The timing covers the whole pass
    over ``testloader``, including data loading, so it reflects end-to-end
    throughput rather than pure model latency.

    Args:
        model: Network to evaluate.
        testloader: Iterable of ``(inputs, labels)`` batches.
        model_name: Label used in the progress bar and the printed summary.

    Returns:
        ``(accuracy, inference_time)``: accuracy in percent and elapsed seconds.

    Raises:
        ValueError: If ``testloader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    # Drain queued GPU work first so that earlier kernels are not billed to this pass.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution; time.time() can jump with clock changes.
    start_time = time.perf_counter()

    with torch.no_grad():
        for inputs, labels in tqdm(testloader, desc=f'Evaluating {model_name}', leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if device.type == 'cuda':
        torch.cuda.synchronize()
    inference_time = time.perf_counter() - start_time

    if total == 0:
        raise ValueError(f"{model_name}: testloader yielded no samples")

    accuracy = 100. * correct / total
    print(f'{model_name} - Accuracy: {accuracy:.2f}%, Inference Time: {inference_time:.2f}s')

    return accuracy, inference_time

# Test-set accuracy and evaluation time of the freshly trained teacher
teacher_accuracy, teacher_time = evaluate_model(teacher_model, testloader, model_name="Teacher")

## Step 5: TODO - Implement Knowledge Distillation Loss Function

**Your Task**: Complete the `DistillationLoss` class by implementing the core knowledge distillation concepts.

### Key Concepts:
1. **Hard Loss**: Traditional loss between student predictions and true labels
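2. **Soft Loss**: KL divergence between the temperature-softened teacher and student distributions, multiplied by $T^2$ so that its gradients stay on a scale comparable to the hard loss as $T$ changes
3. **Temperature Scaling**: Dividing the logits by a temperature $T > 1$ before the softmax produces softer (flatter) probability distributions, exposing how the teacher ranks the non-target classes
4. **Combined Loss**: Weighted combination of hard and soft losses, $\mathcal{L} = \alpha\,\mathcal{L}_{\text{hard}} + (1-\alpha)\,T^2\,\mathrm{KL}\big(p^{T}_{\text{teacher}} \,\|\, p^{T}_{\text{student}}\big)$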

In [None]:
class DistillationLoss(nn.Module):
    """
    Knowledge-distillation loss (Hinton, Vinyals & Dean, 2015).

    Combines a hard-target term (cross-entropy between the student logits and
    the true labels) with a soft-target term (KL divergence between the
    temperature-softened teacher and student distributions):

        total = alpha * hard + (1 - alpha) * T**2 * KL(teacher_T || student_T)

    The T**2 factor keeps the soft-target gradients on roughly the same scale
    as the hard-target gradients when the temperature changes.

    Args:
        alpha: Weight of the hard-target loss, in [0, 1]; the soft loss gets ``1 - alpha``.
        temperature: Softening temperature T (> 0); larger values give flatter distributions.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1] or ``temperature`` is not positive.
    """
    def __init__(self, alpha=0.5, temperature=4.0):
        super(DistillationLoss, self).__init__()
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.alpha = alpha  # Weight for hard target loss
        self.temperature = temperature  # Temperature for softening predictions
        self.hard_loss = nn.CrossEntropyLoss()
        # 'batchmean' sums the KL over classes and averages over the batch.
        self.soft_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_outputs, teacher_outputs, labels):
        """
        Compute the combined distillation loss.

        Args:
            student_outputs: Student logits, expected shape (batch, num_classes).
            teacher_outputs: Teacher logits with the same shape.
            labels: Integer class indices, shape (batch,).

        Returns:
            ``(total_loss, hard_loss, soft_loss)`` as scalar tensors.
        """
        # 1. Hard-target loss: student predictions vs. true labels
        hard_loss = self.hard_loss(student_outputs, labels)

        # 2. Temperature-softened student log-probabilities (KLDivLoss expects log-probs as input)
        student_soft = F.log_softmax(student_outputs / self.temperature, dim=1)

        # 3. Temperature-softened teacher probabilities (the target distribution)
        teacher_soft = F.softmax(teacher_outputs / self.temperature, dim=1)

        # 4. Soft-target loss, rescaled by T^2
        soft_loss = self.soft_loss(student_soft, teacher_soft) * (self.temperature ** 2)

        # 5. alpha-weighted combination of the two terms
        total_loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss

        return total_loss, hard_loss, soft_loss

# Sanity-check the loss on random logits: a batch of 4 samples over 10 classes.
test_criterion = DistillationLoss(alpha=0.7, temperature=4.0)
test_student_out = torch.randn(4, 10)
test_teacher_out = torch.randn(4, 10)
test_labels = torch.randint(0, 10, (4,))

total, hard, soft = test_criterion(test_student_out, test_teacher_out, test_labels)
print("Distillation loss implementation test:")
for caption, loss_value in (("Total Loss", total), ("Hard Loss", hard), ("Soft Loss", soft)):
    print(f"{caption}: {loss_value.item():.4f}")

## Step 6: TODO - Complete the Knowledge Distillation Training Function

**Your Task**: Complete the distillation training loop by implementing the key steps.

### Key Steps:
1. Get teacher predictions (without gradients)
2. Get student predictions 
3. Calculate distillation loss
4. Backpropagate and update student parameters

In [None]:
def train_student_with_distillation(student_model, teacher_model, trainloader, epochs=6,
                                    learning_rate=0.001, alpha=0.7, temperature=4.0):
    """
    Train ``student_model`` in place by distilling from a frozen ``teacher_model``.

    Each batch goes through the teacher (under ``torch.no_grad()``) and the
    student. Only the student is optimized, on ``DistillationLoss``, i.e.
    ``alpha * CE(student, labels) + (1 - alpha) * T^2 * KL(teacher_T || student_T)``.
    The optimizer is Adam (weight decay 1e-4), and the learning rate is
    multiplied by 0.7 every 3 epochs. Batches are moved to the module-level
    ``device``.

    Args:
        student_model: Network being trained (its parameters are updated).
        teacher_model: Network providing soft targets; put in eval mode and
            not passed to the optimizer.
        trainloader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``trainloader``.
        learning_rate: Initial Adam learning rate.
        alpha: Weight of the hard-label loss (the soft loss gets ``1 - alpha``).
        temperature: Softmax temperature for the soft targets.

    Returns:
        ``(train_losses, train_accuracies, hard_losses, soft_losses)``:
        per-epoch means over batches of the total/hard/soft losses, and
        training accuracy in percent.
    """
    distillation_criterion = DistillationLoss(alpha=alpha, temperature=temperature)
    optimizer = optim.Adam(student_model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.7)

    # Set teacher to evaluation mode (frozen)
    teacher_model.eval()

    train_losses = []
    train_accuracies = []
    hard_losses = []
    soft_losses = []

    print(f"Training student with distillation (α={alpha}, T={temperature})...")

    for epoch in range(epochs):
        student_model.train()
        running_loss = 0.0
        running_hard_loss = 0.0
        running_soft_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(trainloader, desc=f'Distillation Epoch {epoch+1}/{epochs}', leave=False)
        for batch_idx, (inputs, labels) in enumerate(pbar):
            inputs, labels = inputs.to(device), labels.to(device)

            # Teacher forward pass without building an autograd graph: the teacher is frozen.
            with torch.no_grad():
                teacher_outputs = teacher_model(inputs)

            # Student forward pass, after clearing gradients left over from the previous step.
            optimizer.zero_grad()
            student_outputs = student_model(inputs)

            # Combined hard-label + soft-target loss
            total_loss, hard_loss, soft_loss = distillation_criterion(
                student_outputs, teacher_outputs, labels
            )

            # Backpropagate through the student only and update its parameters.
            total_loss.backward()
            optimizer.step()

            # Statistics
            running_loss += total_loss.item()
            running_hard_loss += hard_loss.item()
            running_soft_loss += soft_loss.item()

            _, predicted = student_outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'Loss': f'{running_loss/(batch_idx+1):.3f}',
                'Hard': f'{running_hard_loss/(batch_idx+1):.3f}',
                'Soft': f'{running_soft_loss/(batch_idx+1):.3f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })

        scheduler.step()

        epoch_loss = running_loss / len(trainloader)
        epoch_hard_loss = running_hard_loss / len(trainloader)
        epoch_soft_loss = running_soft_loss / len(trainloader)
        epoch_acc = 100. * correct / total

        train_losses.append(epoch_loss)
        hard_losses.append(epoch_hard_loss)
        soft_losses.append(epoch_soft_loss)
        train_accuracies.append(epoch_acc)

        print(f'Distillation Epoch {epoch+1}: Total Loss: {epoch_loss:.4f}, '
              f'Hard Loss: {epoch_hard_loss:.4f}, Soft Loss: {epoch_soft_loss:.4f}, '
              f'Accuracy: {epoch_acc:.2f}%')

    return train_losses, train_accuracies, hard_losses, soft_losses

## Step 7: Train Student with Knowledge Distillation

In [None]:
# Distill into an independent clone so that `student_model` itself keeps its initial weights.
student_distilled = copy.deepcopy(student_model)

distill_run = train_student_with_distillation(
    student_distilled,
    teacher_model,
    trainloader,
    epochs=6,
    alpha=0.7,
    temperature=4.0,
)
distill_losses, distill_accuracies, hard_losses, soft_losses = distill_run

## Step 8: Train Student Model Without Distillation (Baseline)

In [None]:
def train_student_baseline(model, trainloader, epochs=6, learning_rate=0.001):
    """
    Train the student on the hard labels only, with plain cross-entropy and no teacher.

    This run is the reference point for judging distillation. It uses Adam
    (weight decay 1e-4) and multiplies the learning rate by 0.7 every 3 epochs.
    Batches are moved to the module-level ``device``.

    Args:
        model: Network to train in place.
        trainloader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``trainloader``.
        learning_rate: Initial Adam learning rate.

    Returns:
        ``(train_losses, train_accuracies)``: per-epoch mean batch loss and
        training accuracy in percent.
    """
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.7)

    history_loss = []
    history_acc = []

    print("Training student baseline (without distillation)...")

    for ep in range(epochs):
        model.train()
        cumulative_loss = 0.0
        hits = 0
        seen = 0

        batches = tqdm(trainloader, desc=f'Baseline Epoch {ep+1}/{epochs}', leave=False)
        for i, (x, y) in enumerate(batches):
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            logits = model(x)
            loss = ce_loss(logits, y)
            loss.backward()
            optimizer.step()

            cumulative_loss += loss.item()
            seen += y.size(0)
            hits += (logits.argmax(dim=1) == y).sum().item()

            done = i + 1
            batches.set_postfix({
                'Loss': f'{cumulative_loss/done:.3f}',
                'Acc': f'{100.*hits/seen:.2f}%'
            })

        scheduler.step()

        mean_loss = cumulative_loss / len(trainloader)
        accuracy = 100. * hits / seen
        history_loss.append(mean_loss)
        history_acc.append(accuracy)

        print(f'Baseline Epoch {ep+1}: Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%')

    return history_loss, history_acc

# Start the baseline from exactly the same initial weights as the distilled
# student: both are deep copies of `student_model`. Building a fresh student
# here would give the baseline a different random classifier initialization
# and confound the "with vs. without distillation" comparison.
student_baseline = copy.deepcopy(student_model)

# Train baseline student with plain cross-entropy on the hard labels
baseline_losses, baseline_accuracies = train_student_baseline(student_baseline, trainloader)

## Step 9: Evaluate and Compare All Models

In [None]:
# Evaluate all models
print("\nEvaluating all models...")

# Teacher model (already evaluated in Step 4)
print(f"Teacher: {teacher_accuracy:.2f}% accuracy")

# Student baseline
baseline_accuracy, baseline_time = evaluate_model(student_baseline, testloader, "Student Baseline")

# Student with distillation
distilled_accuracy, distilled_time = evaluate_model(student_distilled, testloader, "Student Distilled")

# Accuracy difference in percentage points; negative if distillation did not help.
distillation_improvement = distilled_accuracy - baseline_accuracy
speed_improvement = teacher_time / distilled_time
size_reduction = (teacher_params - student_params) / teacher_params * 100

print("\n" + "="*70)
print("                    KNOWLEDGE DISTILLATION RESULTS")
print("="*70)
print(f"Teacher Model:          {teacher_accuracy:.2f}% accuracy, {teacher_params:,} params")
print(f"Student Baseline:       {baseline_accuracy:.2f}% accuracy, {student_params:,} params")
print(f"Student Distilled:      {distilled_accuracy:.2f}% accuracy, {student_params:,} params")
print("\nImprovements:")
# The '+' format flag prints the sign explicitly, so a drop shows as e.g. '-1.23%' rather than '+-1.23%'.
print(f"Distillation boost:     {distillation_improvement:+.2f}% accuracy")
print(f"Speed improvement:      {speed_improvement:.1f}x faster than teacher")
print(f"Model size reduction:   {size_reduction:.1f}% fewer parameters")
print("="*70)

## Step 10: Visualize Knowledge Transfer Results

In [None]:
# Training curves comparison and loss breakdown
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# A separate x-range for each run, so the figure still works if the baseline
# and the distilled student were trained for different numbers of epochs.
baseline_epochs = range(1, len(baseline_accuracies) + 1)
distill_epochs = range(1, len(distill_accuracies) + 1)

# Training curves comparison
ax1.plot(baseline_epochs, baseline_accuracies, 'b-', label='Student Baseline', linewidth=2, marker='o')
ax1.plot(distill_epochs, distill_accuracies, 'r-', label='Student Distilled', linewidth=2, marker='s')
# teacher_accuracy was measured by evaluate_model on the test set, not during
# training, so the legend says so rather than implying a training accuracy.
ax1.axhline(y=teacher_accuracy, color='g', linestyle='--', label='Teacher (test accuracy)', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Training Accuracy (%)')
ax1.set_title('Training Accuracy Comparison')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Loss breakdown for distillation (all three series come from the distillation run)
ax2.plot(distill_epochs, hard_losses, 'b-', label='Hard Loss (vs labels)', linewidth=2, marker='o')
ax2.plot(distill_epochs, soft_losses, 'r-', label='Soft Loss (vs teacher)', linewidth=2, marker='s')
ax2.plot(distill_epochs, distill_losses, 'g-', label='Total Loss', linewidth=2, marker='^')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Loss')
ax2.set_title('Distillation Loss Components')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Step 11: Final Model Comparison

In [None]:
# Final Model Comparison: accuracy, size and evaluation time side by side
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

model_names = ['Teacher', 'Student\nBaseline', 'Student\nDistilled']
colors = ['green', 'blue', 'red']
accuracies = [teacher_accuracy, baseline_accuracy, distilled_accuracy]
params = [teacher_params/1000, student_params/1000, student_params/1000]  # In thousands
times = [teacher_time, baseline_time, distilled_time]


def _annotate_bars(axis, bar_container, values, offset, fmt):
    """Write ``fmt.format(value)`` centred just above each bar (``offset`` in data units)."""
    for rect, value in zip(bar_container, values):
        axis.text(rect.get_x() + rect.get_width()/2, rect.get_height() + offset,
                  fmt.format(value), ha='center', va='bottom', fontweight='bold')


# Panel 1: test accuracy, with the y-axis zoomed to 2 points around the observed range
acc_bars = ax1.bar(model_names, accuracies, color=colors, alpha=0.7)
ax1.set_ylabel('Test Accuracy (%)')
ax1.set_title('Final Model Comparison')
ax1.set_ylim([min(accuracies) - 2, max(accuracies) + 2])
_annotate_bars(ax1, acc_bars, accuracies, 0.5, '{:.2f}%')

# Panel 2: parameter counts
param_bars = ax2.bar(model_names, params, color=colors, alpha=0.7)
ax2.set_ylabel('Parameters (thousands)')
ax2.set_title('Model Complexity')
_annotate_bars(ax2, param_bars, params, 5, '{:.0f}K')

# Panel 3: evaluation time over the test set
time_bars = ax3.bar(model_names, times, color=colors, alpha=0.7)
ax3.set_ylabel('Inference Time (seconds)')
ax3.set_title('Inference Speed')
_annotate_bars(ax3, time_bars, times, 0.1, '{:.2f}s')

plt.tight_layout()
plt.show()

## Step 12: TODO - Experiment with Different Distillation Parameters

**Your Task**: Complete the parameter experiment function and find the best parameters.

### Key Parameters to Tune:
- **Temperature**: Controls how soft the probability distributions become
- **Alpha**: Balances hard loss vs soft loss (0=only soft, 1=only hard)

In [None]:
def quick_distillation_experiment(teacher_model, temperature=4.0, alpha=0.5, epochs=2,
                                  train_loader=None, test_loader=None):
    """
    Train a fresh student by distillation for a few epochs and return its test accuracy.

    Meant for cheap hyperparameter sweeps: there is no LR scheduler and no
    progress bar. Batches are moved to the module-level ``device``.

    Args:
        teacher_model: Trained teacher providing soft targets (put in eval mode).
        temperature: Distillation temperature T.
        alpha: Weight of the hard-label loss (the soft loss gets ``1 - alpha``).
        epochs: Number of training passes.
        train_loader: Training batches; defaults to the module-level ``trainloader``.
        test_loader: Evaluation batches; defaults to the module-level ``testloader``.

    Returns:
        Test accuracy of the newly trained student, in percent.
    """
    # Fall back to the notebook-level loaders so existing calls keep working.
    if train_loader is None:
        train_loader = trainloader
    if test_loader is None:
        test_loader = testloader

    # Create a new student model
    student = create_student_mobilenetv2(num_classes=10).to(device)

    # Distillation criterion with the parameters under test
    distillation_criterion = DistillationLoss(alpha=alpha, temperature=temperature)

    optimizer = optim.Adam(student.parameters(), lr=0.001, weight_decay=1e-4)

    teacher_model.eval()

    for epoch in range(epochs):
        student.train()

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Frozen teacher: no autograd graph needed
            with torch.no_grad():
                teacher_outputs = teacher_model(inputs)

            optimizer.zero_grad()
            student_outputs = student(inputs)

            # Combined loss, then a standard backward/step on the student only
            total_loss, _, _ = distillation_criterion(student_outputs, teacher_outputs, labels)
            total_loss.backward()
            optimizer.step()

    # Quick evaluation
    student.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = student(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    accuracy = 100. * correct / total
    return accuracy

## Step 13: TODO - Run Parameter Experiments

**Your Task**: Complete the parameter experiment loops and analyze the results.

In [None]:
# Experiment with different parameters
print("Experimenting with different distillation parameters...")

temperatures = [1.0, 3.0, 5.0, 8.0]
alphas = [0.3, 0.5, 0.7, 0.9]

# Temperature sweep (alpha fixed at 0.7, 2 epochs per run)
temp_results = []
for temp in temperatures:
    print(f"Testing temperature {temp}...")
    acc = quick_distillation_experiment(teacher_model, temperature=temp, alpha=0.7, epochs=2)
    temp_results.append(acc)
    print(f"Temperature {temp}: {acc:.2f}% accuracy")

# Alpha sweep (temperature fixed at 4.0, 2 epochs per run)
alpha_results = []
for alpha in alphas:
    print(f"Testing alpha {alpha}...")
    acc = quick_distillation_experiment(teacher_model, temperature=4.0, alpha=alpha, epochs=2)
    alpha_results.append(acc)
    print(f"Alpha {alpha}: {acc:.2f}% accuracy")

## Step 14: Visualize Parameter Sensitivity

In [None]:
# Visualize parameter sensitivity: accuracy vs. temperature and vs. alpha
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))


def _plot_sensitivity(axis, xs, accs, style, xlabel, title):
    """Plot test accuracy against one hyperparameter and label each point with its value."""
    axis.plot(xs, accs, style, linewidth=2, markersize=8)
    axis.set_xlabel(xlabel)
    axis.set_ylabel('Test Accuracy (%)')
    axis.set_title(title)
    axis.grid(True, alpha=0.3)
    for x_val, acc_val in zip(xs, accs):
        axis.annotate(f'{acc_val:.1f}%', (x_val, acc_val), textcoords="offset points", xytext=(0,10), ha='center')


_plot_sensitivity(ax1, temperatures, temp_results, 'bo-', 'Temperature', 'Temperature Sensitivity')
_plot_sensitivity(ax2, alphas, alpha_results, 'ro-', 'Alpha (Hard Loss Weight)', 'Alpha Sensitivity')

plt.tight_layout()
plt.show()

# np.argmax picks the first setting that reaches the best accuracy.
best_temp_idx = int(np.argmax(temp_results))
best_alpha_idx = int(np.argmax(alpha_results))
print("\nPARAMETER OPTIMIZATION RESULTS:")
print(f"Best temperature: {temperatures[best_temp_idx]} (accuracy: {max(temp_results):.2f}%)")
print(f"Best alpha: {alphas[best_alpha_idx]} (accuracy: {max(alpha_results):.2f}%)")