<a href="https://colab.research.google.com/github/hongqin/Generative_AI_Fa25/blob/main/teacher_student_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

How it works

A teacher MLP is trained on MNIST (2 hidden layers).

Its logits (softened by temperature T) become “soft targets.”

A smaller student MLP learns from both the hard labels and the teacher’s soft targets via a combined loss.

You’ll see teacher vs. student test accuracies printed at the end.

In [2]:

# Toy Teacher-Student Distillation Demo

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# 1) Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2) Define Teacher and Student
class TeacherNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

class StudentNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# 3) Hyperparameters
batch_size      = 64
teacher_epochs  = 3
student_epochs  = 3
lr              = 0.01
T               = 2.0    # distillation temperature
alpha           = 0.5    # mix weight between hard & soft losses

# 4) Data loaders (small subset of MNIST for speed)
transform = transforms.ToTensor()
full_train = datasets.MNIST(root='.', train=True, download=True, transform=transform)
train_subset = Subset(full_train, list(range(2000)))
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(
    datasets.MNIST(root='.', train=False, download=True, transform=transform),
    batch_size=1000, shuffle=False
)

# 5) Evaluation helper
def evaluate(model, loader):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            pred = out.argmax(dim=1)
            correct += (pred == y).sum().item()
            total   += y.size(0)
    return 100 * correct / total

# 6) Train the Teacher
teacher   = TeacherNet().to(device)
opt_t     = optim.SGD(teacher.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

for epoch in range(1, teacher_epochs+1):
    teacher.train()
    total_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        opt_t.zero_grad()
        logits = teacher(x)
        loss   = criterion(logits, y)
        loss.backward()
        opt_t.step()
        total_loss += loss.item()
    print(f"[Teacher] Epoch {epoch}/{teacher_epochs}, Loss={total_loss/len(train_loader):.4f}")

print(f"Teacher Test Acc: {evaluate(teacher, test_loader):.2f}%\n")

# 7) Train the Student via Distillation
student = StudentNet().to(device)
opt_s    = optim.SGD(student.parameters(), lr=lr)

def distill_loss(s_logits, t_logits, y, T, alpha):
    p_t = F.log_softmax(t_logits / T, dim=1)
    p_s = F.log_softmax(s_logits / T, dim=1)
    loss_soft = F.kl_div(p_s, p_t, reduction='batchmean') * (T*T)
    loss_hard = F.cross_entropy(s_logits, y)
    return alpha * loss_hard + (1 - alpha) * loss_soft

for epoch in range(1, student_epochs+1):
    student.train()
    total_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            t_logits = teacher(x)
        opt_s.zero_grad()
        s_logits = student(x)
        loss      = distill_loss(s_logits, t_logits, y, T, alpha)
        loss.backward()
        opt_s.step()
        total_loss += loss.item()
    print(f"[Student] Epoch {epoch}/{student_epochs}, DistillLoss={total_loss/len(train_loader):.4f}")

print(f"Student Test Acc: {evaluate(student, test_loader):.2f}%")


100%|██████████| 9.91M/9.91M [00:00<00:00, 40.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.15MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.88MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.34MB/s]


[Teacher] Epoch 1/3, Loss=2.3043
[Teacher] Epoch 2/3, Loss=2.2933
[Teacher] Epoch 3/3, Loss=2.2783
Teacher Test Acc: 23.27%

[Student] Epoch 1/3, DistillLoss=nan
[Student] Epoch 2/3, DistillLoss=nan
[Student] Epoch 3/3, DistillLoss=nan
Student Test Acc: 9.80%
