In [None]:
!pip install torch torchvision matplotlib


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import os, time


In [None]:
class TeacherNet(nn.Module):
    """Larger CNN used as the distillation teacher.

    Two 3x3 conv layers (32 and 64 channels, each followed by ReLU and 2x2
    max-pooling), then a 128-unit hidden fully connected layer and a 10-way
    output layer. ``fc1`` is sized for single-channel 28x28 input: two 2x2
    poolings leave a 64 x 7 x 7 feature map.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a (B, 1, 28, 28) batch to (B, 10) unnormalized class logits."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten per sample so the batch dimension is preserved. With
        # reshape(-1, 64*7*7) a wrongly sized input could be silently regrouped
        # into a different batch size; here fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class StudentNet(nn.Module):
    """Smaller CNN trained by distillation from the teacher.

    Same layout as the teacher with half the width: two 3x3 conv layers (16 and
    32 channels, each followed by ReLU and 2x2 max-pooling), then a 64-unit
    hidden fully connected layer and a 10-way output layer. ``fc1`` is sized
    for single-channel 28x28 input (32 x 7 x 7 after two poolings).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Map a (B, 1, 28, 28) batch to (B, 10) unnormalized class logits."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten per sample so the batch dimension is preserved. With
        # reshape(-1, 32*7*7) a wrongly sized input could be silently regrouped
        # into a different batch size; here fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


In [None]:
# Pixels are only scaled to [0, 1] by ToTensor; no mean/std normalization is applied.
transform = transforms.ToTensor()

# Download (if needed) both MNIST splits into ./data.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Shuffled mini-batches for training; large fixed-order batches for evaluation.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader_clean = DataLoader(dataset=test_dataset, batch_size=1000)

def add_noise(dataset, noise_level=0.3, batch_size=1000, seed=0):
    """Build a DataLoader over a Gaussian-noise-corrupted copy of an MNIST-style dataset.

    The clean pipeline in this notebook uses only ``transforms.ToTensor()``,
    which yields pixels in [0, 1] without mean/std normalization. The noisy
    images are kept in that same [0, 1] range. The noisy evaluation then
    measures robustness to noise alone rather than noise plus a changed input
    scale.

    Args:
        dataset: object exposing ``.data`` (raw pixels; assumed uint8 values
            0-255 of shape (N, H, W) on CPU, as in torchvision MNIST) and
            ``.targets`` (labels of length N).
        noise_level: standard deviation of the additive Gaussian noise, in
            [0, 1] pixel units.
        batch_size: batch size of the returned DataLoader.
        seed: seed for a dedicated noise generator so the noisy test set is
            reproducible; pass None to draw from the global torch RNG.

    Returns:
        A non-shuffling DataLoader yielding ``(images, labels)`` with images of
        shape (B, 1, H, W) clamped to [0, 1].
    """
    raw = dataset.data.float() / 255.
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    noise = torch.randn(raw.shape, generator=generator, dtype=raw.dtype) * noise_level
    noisy = torch.clamp(raw + noise, 0., 1.)
    # unsqueeze(1) adds the single channel dimension expected by the CNNs.
    return DataLoader(TensorDataset(noisy.unsqueeze(1), dataset.targets), batch_size=batch_size)

# Noisy copy of the test split used to probe the student's robustness.
test_loader_noisy = add_noise(test_dataset, noise_level=0.3)


In [None]:
# Seed the global torch RNG (all devices) so weight initialization and the
# shuffling order of DataLoaders without their own generator are reproducible
# on re-run. cuDNN kernels may still be nondeterministic on GPU.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher = TeacherNet().to(device)

optimizer = optim.Adam(teacher.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

def train(model, loader, optimizer, criterion, device, epochs=3):
    """Train ``model`` with supervised loss for a fixed number of epochs.

    Args:
        model: module to optimize; switched to train mode and left there.
        loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss ``criterion(outputs, targets)``; assumed to return a
            per-batch mean (as ``nn.CrossEntropyLoss`` does by default).
        device: device that each batch is moved to.
        epochs: number of full passes over ``loader``.

    Returns:
        List with one entry per epoch: the sample-weighted mean training loss,
        or ``nan`` for an epoch in which ``loader`` yielded no samples.
    """
    model.train()
    history = []
    for _ in range(epochs):
        running_loss, n_samples = 0.0, 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch is not over-counted.
            running_loss += loss.item() * y.size(0)
            n_samples += y.size(0)
        history.append(running_loss / n_samples if n_samples else float("nan"))
    return history

# Fit the teacher on the (noise-free) MNIST training split.
train(model=teacher, loader=train_loader, optimizer=optimizer, criterion=criterion, device=device)


In [None]:
# Smaller student network with its own Adam optimizer (same learning rate as the teacher's).
student = StudentNet().to(device=device)
optimizer_s = optim.Adam(params=student.parameters(), lr=1e-3)

def distill_loss(student_logits, teacher_logits, true_labels, T=2.0, alpha=0.7):
    """Knowledge-distillation loss: blend of soft-target KL and hard-label cross-entropy.

    Args:
        student_logits: student outputs, classes along dim 1.
        teacher_logits: teacher outputs with the same shape (no gradient needed).
        true_labels: integer class labels for ``F.cross_entropy``.
        T: softmax temperature (> 0) applied to both sets of logits for the
            soft term.
        alpha: weight in [0, 1] of the soft term; ``1 - alpha`` weights the
            hard term.

    Returns:
        Scalar tensor ``alpha * soft_loss + (1 - alpha) * hard_loss``.

    Raises:
        ValueError: if ``T <= 0`` or ``alpha`` is outside [0, 1].
    """
    if T <= 0:
        raise ValueError(f"temperature T must be positive, got {T}")
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    # KL(teacher || student) on temperature-softened distributions, averaged per
    # sample ('batchmean'). Passing the teacher as log-probabilities
    # (log_target=True) stays accurate when teacher probabilities underflow.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.log_softmax(teacher_logits / T, dim=1),
        reduction='batchmean',
        log_target=True,
    ) * (T * T)  # T^2 keeps soft-term gradients on the same scale as the hard term
    hard_loss = F.cross_entropy(student_logits, true_labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss

def train_kd(student, teacher, loader, optimizer, device, epochs=3, T=2.0, alpha=0.7):
    """Train ``student`` to match ``teacher`` via ``distill_loss``.

    The teacher is put in eval mode and run under ``torch.no_grad()`` so only
    the student receives gradients. Both modules are left in the mode set here.

    Args:
        student: module being optimized.
        teacher: frozen reference module producing soft targets.
        loader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer over ``student``'s parameters.
        device: device that each batch is moved to.
        epochs: number of full passes over ``loader``.
        T: distillation temperature forwarded to ``distill_loss``.
        alpha: soft/hard loss weighting forwarded to ``distill_loss``.

    Returns:
        List with one entry per epoch: the sample-weighted mean distillation
        loss, or ``nan`` for an epoch in which ``loader`` yielded no samples.
    """
    teacher.eval()
    student.train()
    history = []
    for _ in range(epochs):
        running_loss, n_samples = 0.0, 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                t_logits = teacher(x)
            s_logits = student(x)
            loss = distill_loss(s_logits, t_logits, y, T=T, alpha=alpha)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch is not over-counted.
            running_loss += loss.item() * y.size(0)
            n_samples += y.size(0)
        history.append(running_loss / n_samples if n_samples else float("nan"))
    return history

# Distill the trained teacher into the student on the MNIST training split.
train_kd(student=student, teacher=teacher, loader=train_loader, optimizer=optimizer_s, device=device)


In [None]:
def evaluate(model, loader, device=None):
    """Compute top-1 accuracy of ``model`` over ``loader`` and the time it took.

    Args:
        model: classifier returning logits with classes along dim 1; switched to
            eval mode and left there.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device that each batch is moved to. Defaults to the device of
            the model's first parameter (CPU for a parameter-less model).

    Returns:
        ``(accuracy_percent, elapsed_seconds)``. Each batch's ``.item()`` call
        synchronizes with the device, so the timing includes device work.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    correct, total = 0, 0
    start = time.perf_counter()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    elapsed = time.perf_counter() - start
    if total == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")
    return 100 * correct / total, elapsed

acc_clean, t_clean = evaluate(student, test_loader_clean)
acc_noisy, t_noisy = evaluate(student, test_loader_noisy)

# Save only the weights (state_dict); the file size approximates the deployable model size.
STUDENT_CKPT = "distilled_student.pth"
torch.save(student.state_dict(), STUDENT_CKPT)
size_mb = os.path.getsize(STUDENT_CKPT) / (1024 ** 2)  # binary megabytes (MiB)

print(f"✅ Clean Acc: {acc_clean:.2f}% | Time: {t_clean:.2f}s")
print(f"✅ Noisy Acc: {acc_noisy:.2f}% | Time: {t_noisy:.2f}s")
print(f"📦 Model Size: {size_mb:.2f} MB")
