In [None]:
# =========================================
# 📘 Knowledge Distillation Assignment
# Task: Reduce model latency using distillation
# Dataset: CIFAR-10
# Teacher: ResNet-50
# Student: ResNet-18
# =========================================

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import time
import matplotlib.pyplot as plt

# ----------------------------
# 1️⃣ Load CIFAR-10
# ----------------------------
# Seed torch's global RNG (CPU and every CUDA device) so weight init, the random
# augmentations and the shuffling order repeat across Restart & Run All.
# NOTE: some cuDNN kernels can still be nondeterministic on GPU.
SEED = 0
torch.manual_seed(SEED)

DATA_ROOT = './data'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel (R, G, B) normalisation mapping ToTensor's [0, 1] range to [-1, 1].
# Spelled out for all three channels instead of broadcasting a 1-element tuple.
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Page-locked host buffers speed up host->GPU copies; they are pointless on CPU.
pin_memory = device.type == "cuda"

trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2,
                                          pin_memory=pin_memory)

testset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=2,
                                         pin_memory=pin_memory)



In [None]:
# ----------------------------
# 2️⃣ Define Teacher and Student
# ----------------------------
# Both networks start from random initialisation. ``weights=None`` is the
# torchvision >= 0.13 spelling of the deprecated ``pretrained=False`` flag.
# NOTE(review): these are the stock ImageNet-style ResNets; their stride-2 stem
# and max-pool shrink a 32x32 CIFAR image aggressively before the first residual
# stage. A CIFAR-adapted stem would likely raise accuracy but also change the
# latency numbers measured below — confirm before switching.
teacher = models.resnet50(weights=None, num_classes=10).to(device)
student = models.resnet18(weights=None, num_classes=10).to(device)



In [None]:
# ----------------------------
# 3️⃣ Define Training Utilities
# ----------------------------
def train_one_epoch(model, dataloader, optimizer, criterion):
    """Run one supervised optimisation pass over ``dataloader``.

    Each batch is moved to the module-level ``device``. The criterion is applied
    to ``(model(inputs), targets)`` and one optimizer step is taken per batch.

    Args:
        model: module being trained; it is put into train mode.
        dataloader: iterable of ``(inputs, targets)`` batches that supports ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: callable ``(outputs, targets) -> scalar loss tensor``.

    Returns:
        The mean of the per-batch loss values. Every batch is weighted equally,
        and an empty loader raises ``ZeroDivisionError``.
    """
    model.train()
    batch_losses = []
    for batch_inputs, batch_targets in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(dataloader)

def evaluate(model, dataloader):
    """Return the top-1 accuracy of ``model`` on ``dataloader``, in percent.

    The model is switched to eval mode and left there. Gradients are disabled
    for the whole pass. Batches are moved to the module-level ``device``.

    Args:
        model: classifier whose output has class scores along dim 1.
        dataloader: iterable of ``(inputs, targets)`` batches.

    Returns:
        ``100 * correct / total`` as a float. An empty loader raises
        ``ZeroDivisionError``.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_inputs, batch_targets in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            predicted = model(batch_inputs).max(dim=1).indices
            n_seen += batch_targets.size(0)
            n_correct += (predicted == batch_targets).sum().item()
    return 100. * n_correct / n_seen



In [None]:
# ----------------------------
# 4️⃣ Distillation Loss
# ----------------------------
class DistillationLoss(nn.Module):
    """Knowledge-distillation loss mixing hard-label CE with a softened KL term.

    ``loss = alpha * CE(student_logits, labels)
             + (1 - alpha) * T**2 * KL(softmax(teacher/T) || softmax(student/T))``

    Args:
        T: softmax temperature (> 0); larger values give softer distributions.
        alpha: weight of the hard-label term, in [0, 1]. ``1 - alpha`` weights
            the soft (teacher) term.

    Raises:
        ValueError: if ``T`` is not positive or ``alpha`` lies outside [0, 1].
    """

    def __init__(self, T=4.0, alpha=0.5):
        super().__init__()
        if T <= 0:
            raise ValueError(f"temperature T must be positive, got {T}")
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        self.T = T
        self.alpha = alpha
        # Both KL arguments are log-probabilities, so the target must be flagged
        # with log_target=True. Without it the log-probs would be misread as
        # probabilities, which produces NaN. 'batchmean' divides the summed KL
        # by the batch size, matching the definition of KL divergence.
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, student_logits, teacher_logits, true_labels):
        """Return the scalar distillation loss for one batch.

        Args:
            student_logits: raw student scores, classes along dim 1.
            teacher_logits: raw teacher scores with the same shape. These are
                treated as a fixed target: no gradient flows into them.
            true_labels: integer class indices for the hard-label CE term.
        """
        # Hard-label loss
        hard_loss = F.cross_entropy(student_logits, true_labels)
        # Soft-label loss. The teacher is detached so it never receives gradients,
        # even if the caller forgets torch.no_grad().
        log_p_teacher = F.log_softmax(teacher_logits.detach() / self.T, dim=1)
        log_p_student = F.log_softmax(student_logits / self.T, dim=1)
        # T^2 keeps the soft-target gradient magnitude comparable across
        # temperatures (Hinton et al., 2015).
        soft_loss = self.kl(log_p_student, log_p_teacher) * (self.T ** 2)
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss



In [None]:
# ----------------------------
# 5️⃣ Train Teacher
# ----------------------------
# Plain supervised training of the ResNet-50 teacher (SGD + momentum). The
# cross-entropy criterion is stateless, so one instance serves every epoch.
# The trained weights are written to teacher.pth.
print("Training Teacher Model (ResNet-50)...")
ce_criterion = nn.CrossEntropyLoss()
opt_t = optim.SGD(teacher.parameters(), lr=0.01, momentum=0.9)
NUM_TEACHER_EPOCHS = 2  # keep short for demo

for epoch in range(NUM_TEACHER_EPOCHS):
    loss = train_one_epoch(teacher, trainloader, opt_t, ce_criterion)
    acc = evaluate(teacher, testloader)
    print(f"Epoch {epoch + 1}: loss={loss:.3f}, acc={acc:.2f}%")

torch.save(teacher.state_dict(), "teacher.pth")

# ----------------------------
# 6️⃣ Train Student with KD
# ----------------------------
def distill_one_epoch(student_model, teacher_model, dataloader, optimizer, criterion):
    """Run one knowledge-distillation pass over ``dataloader``.

    This is the KD counterpart of ``train_one_epoch``. The teacher is run in eval
    mode without gradients to produce soft targets. Only the student is updated.

    Args:
        student_model: module being trained; it is put into train mode.
        teacher_model: frozen reference model; it is put into eval mode.
        dataloader: iterable of ``(inputs, targets)`` batches that supports ``len()``.
        optimizer: optimizer over ``student_model``'s parameters.
        criterion: callable ``(student_logits, teacher_logits, targets) -> loss``.

    Returns:
        The mean of the per-batch losses (every batch weighted equally).
    """
    student_model.train()
    teacher_model.eval()
    total_loss = 0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        with torch.no_grad():  # teacher only supplies targets; no graph needed
            t_logits = teacher_model(inputs)
        s_logits = student_model(inputs)
        loss = criterion(s_logits, t_logits, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)


print("Training Student Model (ResNet-18) with Distillation...")
teacher.eval()
criterion_kd = DistillationLoss(T=4.0, alpha=0.5)
opt_s = optim.Adam(student.parameters(), lr=0.001)

for epoch in range(2):
    loss = distill_one_epoch(student, teacher, trainloader, opt_s, criterion_kd)
    acc = evaluate(student, testloader)
    print(f"Epoch {epoch+1}: loss={loss:.3f}, acc={acc:.2f}%")

torch.save(student.state_dict(), "student_kd.pth")



In [None]:
# ----------------------------
# 7️⃣ Latency Measurement
# ----------------------------
def measure_latency(model, n_runs=50, warmup=10, input_shape=(1, 3, 32, 32)):
    """Return the mean wall-clock time of one forward pass, in milliseconds.

    Args:
        model: ``nn.Module`` to time. It is switched to eval mode and left there.
        n_runs: number of timed forward passes to average over; must be > 0.
        warmup: untimed forward passes run first, so one-off costs (e.g. CUDA
            initialisation, allocator growth) do not skew the mean; must be >= 0.
        input_shape: shape of the random input tensor. The default is a single
            CIFAR-sized image, batch of 1 x 3 x 32 x 32.

    Returns:
        Mean milliseconds per forward pass over the ``n_runs`` timed passes.

    Raises:
        ValueError: if ``n_runs`` is not positive or ``warmup`` is negative.
    """
    if n_runs <= 0:
        raise ValueError(f"n_runs must be positive, got {n_runs}")
    if warmup < 0:
        raise ValueError(f"warmup must be non-negative, got {warmup}")

    model.eval()
    # Create the input on the device holding the model's parameters (CPU for a
    # parameter-less module), so weights and input always match.
    run_device = next((p.device for p in model.parameters()), torch.device("cpu"))
    dummy_input = torch.randn(input_shape, device=run_device)
    use_cuda = run_device.type == "cuda"

    # Inference only: skip autograd bookkeeping so we time what deployment runs.
    with torch.no_grad():
        for _ in range(warmup):
            model(dummy_input)
        if use_cuda:
            torch.cuda.synchronize(run_device)  # warm-up kernels must not leak into the timing
        start = time.perf_counter()
        for _ in range(n_runs):
            model(dummy_input)
        if use_cuda:
            # CUDA launches are asynchronous; wait for the timed kernels to finish.
            torch.cuda.synchronize(run_device)
        elapsed = time.perf_counter() - start
    return elapsed / n_runs * 1000.0  # ms per inference

# Time both networks with measure_latency's default settings, teacher first.
lat_teacher, lat_student = (measure_latency(net) for net in (teacher, student))

print(f"\nLatency (Teacher): {lat_teacher:.2f} ms")
print(f"Latency (Student): {lat_student:.2f} ms")



In [None]:
# ----------------------------
# 8️⃣ Visualization
# ----------------------------
# Bar chart of per-sample latency for the two models (explicit Axes interface).
fig, ax = plt.subplots()
model_labels = ["Teacher", "Student (KD)"]
latencies_ms = [lat_teacher, lat_student]
ax.bar(model_labels, latencies_ms, color=["red", "green"])
ax.set_ylabel("Latency (ms/sample)")
ax.set_title("Inference Speed Comparison")
plt.show()
