# Lecture 9: Knowledge Distillation

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/efficientml-course/efficientml_course/09_knowledge_distillation/demo.ipynb)

Transfer knowledge from a large teacher model to a small student model.


In [None]:
!pip install torch -q
import torch
import torch.nn as nn
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    """
    Knowledge Distillation Loss (Hinton et al., 2015).

    Blends a soft-target term (KL divergence between the temperature-softened
    teacher and student distributions) with the ordinary hard-label
    cross-entropy of the student.

    Args:
        student_logits: Raw (pre-softmax) student scores, expected shape
            (batch, num_classes); softmax is taken over the last dim.
        teacher_logits: Raw teacher scores, same shape as ``student_logits``.
            Treated as a fixed target: no gradient flows back into it.
        labels: Ground-truth class indices, as accepted by ``F.cross_entropy``.
        T: Temperature (higher = softer probabilities). Must be > 0.
        alpha: Weight of the soft-target term; the hard-label term is
            weighted by ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * soft_loss + (1 - alpha) * hard_loss``.

    Raises:
        ValueError: If ``T`` is not positive.
    """
    if T <= 0:
        raise ValueError(f"Temperature T must be positive, got {T}")

    # The teacher is a fixed target: detach so backward() never updates it.
    teacher_logits = teacher_logits.detach()

    # Soft targets: student and teacher softened by the same temperature.
    # The teacher is passed as log-probabilities (log_target=True), which avoids
    # taking log() of probabilities that may have underflowed to zero.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.log_softmax(teacher_logits / T, dim=-1),
        reduction='batchmean',
        log_target=True,
    ) * (T ** 2)  # soft-term gradients scale as 1/T^2; rescale to match the hard term

    # Hard targets (ground truth)
    hard_loss = F.cross_entropy(student_logits, labels)

    return alpha * soft_loss + (1 - alpha) * hard_loss

# Demo
TEMPERATURE = 4.0  # shared by the soft-label printout and the loss so they never disagree
batch_size, num_classes = 4, 10

# Local seeded generator: reproducible demo output without reseeding the global RNG.
gen = torch.Generator().manual_seed(0)
teacher_logits = torch.randn(batch_size, num_classes, generator=gen) * 2  # Larger scale => sharper (more confident) softmax
student_logits = torch.randn(batch_size, num_classes, generator=gen)
labels = torch.randint(0, num_classes, (batch_size,), generator=gen)

# Compare soft vs hard labels
print("Why soft labels are better:")
print(f"\nHard label: {F.one_hot(labels[0], num_classes).float().numpy()}")
print(f"Soft label: {F.softmax(teacher_logits[0] / TEMPERATURE, dim=0).detach().numpy().round(3)}")

loss = distillation_loss(student_logits, teacher_logits, labels, T=TEMPERATURE)
print(f"\nDistillation loss: {loss.item():.4f}")
print("\n🎯 Soft labels contain more information (class similarities)!")
