## Problem: Implement KL Divergence Loss

### Background

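**Kullback-Leibler (KL) Divergence** is a fundamental concept in information theory and machine learning. It measures how much a probability distribution $P$ diverges from a reference distribution $Q$: the expected extra information (in nats, when the natural logarithm is used) needed to encode samples from $P$ with a code optimized for $Q$.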

### Mathematical Definitions

#### Discrete KL Divergence

For discrete probability distributions $P$ and $Q$:

$$
D_{KL}(P || Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}
$$

Properties:
- Always non-negative: $D_{KL}(P || Q) \geq 0$
- Zero if and only if $P = Q$
- Not symmetric: in general, $D_{KL}(P || Q) \neq D_{KL}(Q || P)$, so it is not a true distance metric
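
The properties above can be checked numerically. The following is a minimal sketch using only the Python standard library; the helper name `kl_discrete` and the two example distributions are illustrative assumptions, not part of the problem's required API.

```python
import math

def kl_discrete(p, q):
    """KL(P || Q) in nats for two discrete distributions given as probability lists.

    Hypothetical helper for illustration; assumes q[i] > 0 wherever p[i] > 0.
    """
    # Terms with p[i] == 0 contribute nothing (convention: 0 * log 0 = 0).
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Illustrative distributions over three outcomes
p = [0.7, 0.2, 0.1]
q = [0.5, 0.3, 0.2]

kl_pq = kl_discrete(p, q)
kl_qp = kl_discrete(q, p)
kl_pp = kl_discrete(p, p)

print(f"KL(P || Q) = {kl_pq:.4f}")  # non-negative
print(f"KL(Q || P) = {kl_qp:.4f}")  # differs from KL(P || Q)
print(f"KL(P || P) = {kl_pp:.4f}")  # zero for identical distributions
```

Swapping the arguments changes the result because the weighting inside the sum comes from the first distribution only.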

#### Continuous KL Divergence (Gaussian)

For two Gaussian distributions $\mathcal{N}(\mu_p, \sigma_p^2)$ and $\mathcal{N}(\mu_q, \sigma_q^2)$:

$$
D_{KL}(\mathcal{N}_p || \mathcal{N}_q) = \log \frac{\sigma_q}{\sigma_p} + \frac{\sigma_p^2 + (\mu_p - \mu_q)^2}{2\sigma_q^2} - \frac{1}{2}
$$
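
As a sanity check on the closed form, it can be compared against a direct numerical integration of $\int p(x) \log \frac{p(x)}{q(x)} \, dx$. The sketch below uses only the standard library; the function names, the integration range, and the example parameters are assumptions chosen for illustration.

```python
import math

def kl_gaussian_closed_form(mu_p, sigma_p, mu_q, sigma_q):
    """Closed-form KL(N(mu_p, sigma_p^2) || N(mu_q, sigma_q^2))."""
    return (math.log(sigma_q / sigma_p)
            + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2 * sigma_q ** 2)
            - 0.5)

def log_normal_pdf(x, mu, sigma):
    """Log-density of a univariate Gaussian."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))

def kl_gaussian_numeric(mu_p, sigma_p, mu_q, sigma_q, n=20001, width=10.0):
    """Riemann-sum estimate of KL(P || Q) over mu_p +/- width * sigma_p."""
    lo = mu_p - width * sigma_p
    dx = 2 * width * sigma_p / (n - 1)
    total = 0.0
    for k in range(n):
        x = lo + k * dx
        log_p = log_normal_pdf(x, mu_p, sigma_p)
        log_q = log_normal_pdf(x, mu_q, sigma_q)
        total += math.exp(log_p) * (log_p - log_q) * dx
    return total

# Illustrative parameters: P = N(0, 1), Q = N(1, 4)
closed = kl_gaussian_closed_form(0.0, 1.0, 1.0, 2.0)
numeric = kl_gaussian_numeric(0.0, 1.0, 1.0, 2.0)
print(f"closed form: {closed:.6f}, numeric: {numeric:.6f}")
```

The two values should agree to several decimal places; a larger discrepancy usually points to a swapped argument or a variance/standard-deviation mix-up.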

### Knowledge Distillation

**Knowledge distillation** uses KL divergence to transfer knowledge from a large "teacher" model to a smaller "student" model.

The distillation loss combines two terms:

$$
L_{distill} = \alpha \cdot L_{soft} + (1-\alpha) \cdot L_{hard}
$$

Where:
- $L_{soft} = T^2 \cdot D_{KL}(\text{teacher}_T || \text{student}_T)$ - Soft target loss using temperature $T$
- $L_{hard} = CE(\text{student}, \text{labels})$ - Hard target loss (standard cross-entropy)
- $\alpha$ - Weighting between soft and hard targets (typically 0.5-0.9)
- $T$ - Temperature for softening distributions (typically 1-5)
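
The $T^2$ factor can be motivated by looking at the gradient of the soft term, following the argument in Hinton et al. (2015). The symbols here are introduced only for this sketch: $z_i$ are the student logits, $v_i$ the teacher logits, $q_i$ and $p_i$ the corresponding temperature-softened probabilities, and $N$ the number of classes.

$$
\frac{\partial}{\partial z_i} D_{KL}(\text{teacher}_T || \text{student}_T) = \frac{1}{T}(q_i - p_i) \approx \frac{1}{N T^2}(z_i - v_i)
$$

The approximation assumes that $T$ is large compared with the logits and that the logits are zero-mean for each example. Under those assumptions the soft-target gradient shrinks as $1/T^2$, so multiplying $L_{soft}$ by $T^2$ keeps its contribution roughly comparable to $L_{hard}$ when the temperature is changed.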

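**Temperature scaling**: Logits are divided by $T$ before the softmax. Higher temperatures produce softer (closer to uniform) probability distributions, exposing how the teacher ranks the non-target classes rather than only its top prediction.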
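
A small standard-library sketch of this effect follows. The function name `softmax_with_temperature` and the example logits are illustrative assumptions.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, computed stably."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max before exponentiating to avoid overflow
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]  # illustrative teacher logits
for T in (1.0, 3.0, 10.0):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T:>4}: {[round(p, 3) for p in probs]}")
```

As $T$ grows, the largest probability falls toward $1/3$ (the uniform value for three classes) while the ordering of the classes is preserved.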

### Learning Objectives

In this problem, you will:
1. Implement discrete KL divergence for classification tasks
2. Implement continuous KL divergence for Gaussian distributions
3. Create a distillation loss function combining soft and hard targets
4. Train a teacher model and compress it into a student model using distillation
5. Compare distilled vs baseline student performance

### References
- [Hinton et al. - Distilling the Knowledge in a Neural Network (2015)](https://arxiv.org/abs/1503.02531)
- [KL Divergence on Wikipedia](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Generate synthetic MNIST-like dataset
torch.manual_seed(42)
np.random.seed(42)

# 500 samples, 28x28 flattened = 784 features, 10 classes
num_samples = 500
input_dim = 784
num_classes = 10

# Generate random features
X = torch.randn(num_samples, input_dim)

# Generate random labels (a class-dependent signal is added to the features below)
y = torch.randint(0, num_classes, (num_samples,))

# Add class-specific signal to features
for i in range(num_classes):
    mask = y == i
    class_signal = torch.randn(input_dim) * 0.5
    X[mask] += class_signal

# Create DataLoader
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

print(f"Dataset: {num_samples} samples, {input_dim} features, {num_classes} classes")
print(f"Batch size: 32, Number of batches: {len(dataloader)}")

In [None]:
def kl_divergence_discrete(p_logits, q_logits, temperature=1.0):
    """
    Compute KL divergence between two discrete distributions from logits.
    
    Args:
        p_logits: Target distribution logits [batch_size, num_classes]
        q_logits: Predicted distribution logits [batch_size, num_classes]
        temperature: Temperature for softening distributions (default: 1.0)
    
    Returns:
        KL divergence loss (scalar)
    
    Note:
        - Higher temperature produces softer probability distributions
        - Temperature T=1.0 is standard softmax
        - Uses log-space for numerical stability
    """
    # Apply temperature scaling to both distributions
    p_scaled = p_logits / temperature
    q_scaled = q_logits / temperature
    
    # Convert to log probabilities (log_softmax is numerically stable)
    log_p = F.log_softmax(p_scaled, dim=-1)
    log_q = F.log_softmax(q_scaled, dim=-1)
    
    # KL(P || Q) = sum(P * log(P/Q)) = sum(P * (log(P) - log(Q)))
    # Note: PyTorch's F.kl_div(input, target) takes input=log_q and target=p
    # (or target=log_p when log_target=True). Here we compute it manually:
    # KL = sum(exp(log_p) * (log_p - log_q)) over the class dimension
    p = torch.exp(log_p)
    kl = torch.sum(p * (log_p - log_q), dim=-1)
    
    return kl.mean()

In [None]:
# Test discrete KL divergence
print("Testing Discrete KL Divergence")
print("=" * 50)

# Test 1: Temperature = 1.0
p_logits = torch.tensor([[2.0, 1.0, 0.1], [1.0, 2.0, 0.5]])
q_logits = torch.tensor([[1.8, 1.1, 0.2], [0.9, 2.1, 0.6]])

kl_custom = kl_divergence_discrete(p_logits, q_logits, temperature=1.0)

# PyTorch reference: F.kl_div takes log_q as input; with log_target=True the target is log_p
log_p = F.log_softmax(p_logits, dim=-1)
log_q = F.log_softmax(q_logits, dim=-1)
kl_pytorch = F.kl_div(log_q, log_p, reduction='batchmean', log_target=True)

print(f"Custom KL (T=1.0): {kl_custom.item():.6f}")
print(f"PyTorch KL (T=1.0): {kl_pytorch.item():.6f}")
assert torch.allclose(kl_custom, kl_pytorch, atol=1e-6, rtol=1e-5), "KL divergence mismatch at T=1.0!"
print("✓ Test passed for T=1.0\n")

# Test 2: Temperature = 3.0 (softer distributions)
kl_custom_t3 = kl_divergence_discrete(p_logits, q_logits, temperature=3.0)

log_p_t3 = F.log_softmax(p_logits / 3.0, dim=-1)
log_q_t3 = F.log_softmax(q_logits / 3.0, dim=-1)
kl_pytorch_t3 = F.kl_div(log_q_t3, log_p_t3, reduction='batchmean', log_target=True)

print(f"Custom KL (T=3.0): {kl_custom_t3.item():.6f}")
print(f"PyTorch KL (T=3.0): {kl_pytorch_t3.item():.6f}")
assert torch.allclose(kl_custom_t3, kl_pytorch_t3, atol=1e-6, rtol=1e-5), "KL divergence mismatch at T=3.0!"
print("✓ Test passed for T=3.0")

print(f"\nNote: for these logits, the higher temperature gives a smaller KL divergence ({kl_custom.item():.6f} → {kl_custom_t3.item():.6f})")

In [None]:
def kl_divergence_gaussian(mu_p, sigma_p, mu_q, sigma_q):
    """
    Compute KL divergence between two Gaussian distributions.
    
    Args:
        mu_p: Mean of distribution P (scalar or tensor)
        sigma_p: Standard deviation of distribution P (must be positive)
        mu_q: Mean of distribution Q (scalar or tensor)
        sigma_q: Standard deviation of distribution Q (must be positive)
    
    Returns:
        KL(P || Q) for Gaussian distributions
    
    Formula:
        KL(N(μ_p, σ_p²) || N(μ_q, σ_q²)) = 
            log(σ_q/σ_p) + (σ_p² + (μ_p - μ_q)²) / (2σ_q²) - 1/2
    """
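    # Accept Python scalars as well as tensors, as the docstring promises
    mu_p, sigma_p, mu_q, sigma_q = (torch.as_tensor(v) for v in (mu_p, sigma_p, mu_q, sigma_q))

    # Ensure positive standard deviations
    assert (sigma_p > 0).all() and (sigma_q > 0).all(), "Standard deviations must be positive"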
    
    # KL divergence formula for Gaussians
    var_p = sigma_p ** 2
    var_q = sigma_q ** 2
    
    kl = torch.log(sigma_q / sigma_p) + (var_p + (mu_p - mu_q) ** 2) / (2 * var_q) - 0.5
    
    return kl.mean()

In [None]:
# Test Gaussian KL divergence
print("Testing Gaussian KL Divergence")
print("=" * 50)

# Test 1: Simple Gaussian distributions
mu_p = torch.tensor([0.0, 1.0, -0.5])
sigma_p = torch.tensor([1.0, 1.5, 0.8])
mu_q = torch.tensor([0.1, 1.2, -0.4])
sigma_q = torch.tensor([1.1, 1.4, 0.9])

kl_custom = kl_divergence_gaussian(mu_p, sigma_p, mu_q, sigma_q)

# PyTorch reference using torch.distributions
from torch.distributions import Normal, kl_divergence as torch_kl

p_dist = Normal(mu_p, sigma_p)
q_dist = Normal(mu_q, sigma_q)
kl_pytorch = torch_kl(p_dist, q_dist).mean()

print(f"Custom Gaussian KL: {kl_custom.item():.6f}")
print(f"PyTorch Gaussian KL: {kl_pytorch.item():.6f}")
assert torch.allclose(kl_custom, kl_pytorch, atol=1e-6, rtol=1e-5), "Gaussian KL divergence mismatch!"
print("✓ Test passed for Gaussian KL divergence\n")

# Test 2: Identical distributions (KL should be ~0)
kl_identical = kl_divergence_gaussian(mu_p, sigma_p, mu_p, sigma_p)
print(f"KL for identical distributions: {kl_identical.item():.8f} (should be ~0)")
assert torch.allclose(kl_identical, torch.tensor(0.0), atol=1e-6), "KL should be 0 for identical distributions!"
print("✓ Test passed for identical distributions")

In [None]:
def distillation_loss(student_logits, teacher_logits, labels, alpha=0.7, temperature=3.0):
    """
    Combined distillation loss for knowledge transfer.
    
    Args:
        student_logits: Student model outputs [batch_size, num_classes]
        teacher_logits: Teacher model outputs [batch_size, num_classes] (detached)
        labels: Ground truth labels [batch_size]
        alpha: Weight for soft targets (default: 0.7 = 70% distillation)
        temperature: Temperature for soft targets (default: 3.0)
    
    Returns:
        Combined loss: alpha * L_soft + (1-alpha) * L_hard
    
    Notes:
        - L_soft uses temperature-scaled KL divergence (scaled by T²)
        - L_hard is standard cross-entropy with true labels
        - Higher alpha prioritizes learning from teacher
    """
    # Soft target loss: KL divergence between teacher and student (with temperature)
    # Scale by T² to balance gradient magnitudes (see Hinton et al. 2015)
    soft_loss = kl_divergence_discrete(teacher_logits, student_logits, temperature) * (temperature ** 2)
    
    # Hard target loss: Standard cross-entropy with ground truth
    hard_loss = F.cross_entropy(student_logits, labels)
    
    # Combine losses with alpha weighting
    total_loss = alpha * soft_loss + (1 - alpha) * hard_loss
    
    return total_loss

In [None]:
class TeacherModel(nn.Module):
    """Large teacher model with 3 hidden layers (4 linear layers in total)."""
    
    def __init__(self, input_dim=784, hidden_dim=256, num_classes=10):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )
    
    def forward(self, x):
        return self.network(x)

# Initialize teacher and print parameter count
teacher = TeacherModel()
teacher_params = sum(p.numel() for p in teacher.parameters())
print(f"Teacher Model: {teacher_params:,} parameters")
print(f"Architecture: 784 → 256 → 256 → 256 → 10")

In [None]:
class StudentModel(nn.Module):
    """Smaller student model with 1 hidden layer (2 linear layers in total)."""
    
    def __init__(self, input_dim=784, hidden_dim=64, num_classes=10):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )
    
    def forward(self, x):
        return self.network(x)

# Initialize student and print parameter count
student = StudentModel()
student_params = sum(p.numel() for p in student.parameters())
print(f"Student Model: {student_params:,} parameters")
print(f"Architecture: 784 → 64 → 10")
print(f"Compression ratio: {teacher_params / student_params:.1f}x")

In [None]:
# Train teacher model
print("Training Teacher Model")
print("=" * 50)

teacher = TeacherModel()
optimizer = torch.optim.Adam(teacher.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

teacher.train()
for epoch in range(20):
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_x, batch_y in dataloader:
        # Forward pass
        outputs = teacher(batch_x)
        loss = criterion(outputs, batch_y)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track metrics
        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == batch_y).sum().item()
        total += batch_y.size(0)
    
    # Print progress every 5 epochs
    if (epoch + 1) % 5 == 0:
        avg_loss = total_loss / len(dataloader)
        accuracy = 100 * correct / total
        print(f"Epoch [{epoch+1}/20], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

print("\n✓ Teacher training completed")

In [None]:
# Train student with distillation
print("Training Student Model with Distillation")
print("=" * 50)

student_distilled = StudentModel()
optimizer = torch.optim.Adam(student_distilled.parameters(), lr=0.001)

teacher.eval()  # Teacher in eval mode
student_distilled.train()

for epoch in range(20):
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_x, batch_y in dataloader:
        # Get teacher outputs (no gradients)
        with torch.no_grad():
            teacher_logits = teacher(batch_x)
        
        # Student forward pass
        student_logits = student_distilled(batch_x)
        
        # Distillation loss
        loss = distillation_loss(student_logits, teacher_logits, batch_y, alpha=0.7, temperature=3.0)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track metrics
        total_loss += loss.item()
        _, predicted = torch.max(student_logits, 1)
        correct += (predicted == batch_y).sum().item()
        total += batch_y.size(0)
    
    # Print progress every 5 epochs
    if (epoch + 1) % 5 == 0:
        avg_loss = total_loss / len(dataloader)
        accuracy = 100 * correct / total
        print(f"Epoch [{epoch+1}/20], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

print("\n✓ Student distillation completed")

In [None]:
# Train baseline student (without distillation)
print("Training Baseline Student Model (No Distillation)")
print("=" * 50)

student_baseline = StudentModel()
optimizer = torch.optim.Adam(student_baseline.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

student_baseline.train()
for epoch in range(20):
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_x, batch_y in dataloader:
        # Forward pass
        outputs = student_baseline(batch_x)
        loss = criterion(outputs, batch_y)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track metrics
        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == batch_y).sum().item()
        total += batch_y.size(0)
    
    # Print progress every 5 epochs
    if (epoch + 1) % 5 == 0:
        avg_loss = total_loss / len(dataloader)
        accuracy = 100 * correct / total
        print(f"Epoch [{epoch+1}/20], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

print("\n✓ Baseline student training completed")

In [None]:
# Evaluation function
def evaluate_model(model, dataloader):
    """Calculate accuracy on dataset."""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            outputs = model(batch_x)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == batch_y).sum().item()
            total += batch_y.size(0)
    
    return correct / total

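# Evaluate all models. Note: this reuses the training data because no held-out
# split is defined, so these numbers measure fit rather than generalization.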
teacher_acc = evaluate_model(teacher, dataloader)
student_distilled_acc = evaluate_model(student_distilled, dataloader)
student_baseline_acc = evaluate_model(student_baseline, dataloader)

print("Final Model Evaluation")
print("=" * 50)
print(f"Teacher Model: {teacher_acc:.2%} accuracy ({teacher_params:,} params)")
print(f"Student (Distilled): {student_distilled_acc:.2%} accuracy ({student_params:,} params)")
print(f"Student (Baseline): {student_baseline_acc:.2%} accuracy ({student_params:,} params)")
print(f"\nDistillation Improvement: {(student_distilled_acc - student_baseline_acc)*100:+.1f} percentage points")
print(f"Gap from Teacher: {(teacher_acc - student_distilled_acc)*100:.1f} percentage points")

In [None]:
# Visualize temperature effect
print("\nTemperature Effect on Soft Targets")
print("=" * 50)

temperatures = [1.0, 2.0, 3.0, 5.0, 10.0]
sample_logits = torch.tensor([2.0, 1.0, 0.1, -1.0, -2.0])

fig, axes = plt.subplots(1, len(temperatures), figsize=(15, 3))

for ax, T in zip(axes, temperatures):
    probs = F.softmax(sample_logits / T, dim=0)
    ax.bar(range(len(probs)), probs.numpy())
    ax.set_title(f'T={T}')
    ax.set_xlabel('Class')
    ax.set_ylabel('Probability')
    ax.set_ylim([0, 1])
    ax.grid(axis='y', alpha=0.3)

plt.suptitle('Effect of Temperature on Soft Target Distributions', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print("\nObservation: Higher temperatures produce softer (more uniform) distributions.")
print("This helps the student learn from the teacher's relative confidences, not just the top class.")