In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Load MNIST dataset
# Every image of both splits goes through ToTensor() and then Normalize with
# mean 0.5 / std 0.5 on its single channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)
# Training batches are reshuffled; the test loader keeps dataset order and
# uses larger batches.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)




In [2]:
# Define the Teacher Model
class TeacherModel(nn.Module):
    """Teacher CNN: two conv + max-pool stages followed by a two-layer head.

    ``forward`` returns a ``(logits, feat)`` pair: ``feat`` is the 128-wide
    post-ReLU output of ``fc1`` and ``logits`` is the 10-wide output of
    ``fc2`` computed from ``feat``.
    """

    def __init__(self):
        super().__init__()
        # 3x3 kernels with padding=1 keep the spatial size; the down-sampling
        # happens in the 2x2 max-pools applied in forward().
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        # fc1 consumes 64 channels of 7x7 feature maps, flattened.
        self.fc1 = nn.Linear(in_features=64*7*7, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Run the network on ``x`` and return ``(logits, feat)``."""
        stage1 = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2))
        stage2 = F.relu(F.max_pool2d(self.conv2(stage1), kernel_size=2))
        # Flatten to rows of fc1's input width (64*7*7).
        flat = stage2.view(-1, self.fc1.in_features)
        feat = F.relu(self.fc1(flat))
        logits = self.fc2(feat)
        return logits, feat

# Define the Student Model
class StudentModel(nn.Module):
    """Student CNN: same layout as the teacher with half the channels/width.

    ``forward`` returns a ``(logits, feat)`` pair: ``feat`` is the 64-wide
    post-ReLU output of ``fc1`` and ``logits`` is the 10-wide output of
    ``fc2`` computed from ``feat``.
    """

    def __init__(self):
        super().__init__()
        # 3x3 kernels with padding=1 keep the spatial size; the down-sampling
        # happens in the 2x2 max-pools applied in forward().
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        # fc1 consumes 32 channels of 7x7 feature maps, flattened.
        self.fc1 = nn.Linear(in_features=32*7*7, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Run the network on ``x`` and return ``(logits, feat)``."""
        stage1 = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2))
        stage2 = F.relu(F.max_pool2d(self.conv2(stage1), kernel_size=2))
        # Flatten to rows of fc1's input width (32*7*7).
        flat = stage2.view(-1, self.fc1.in_features)
        feat = F.relu(self.fc1(flat))
        logits = self.fc2(feat)
        return logits, feat

# Train the Teacher Model
# Build the teacher and move it to the GPU before handing its parameters to Adam.
teacher_model = TeacherModel()
teacher_model.cuda()
# `criterion` and `optimizer` are notebook globals: train_teacher() reads both,
# and the distillation loss reads `criterion` as well.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=teacher_model.parameters(), lr=1e-3)

def train_teacher(num_epochs=10):
    """Train the notebook-global ``teacher_model`` with ``criterion`` on ``train_loader``.

    The globals ``teacher_model``, ``train_loader``, ``criterion`` and
    ``optimizer`` are looked up when the function is called. A later cell
    rebinds ``optimizer`` to the student's optimizer, so this function must
    run before that cell (or ``optimizer`` must be rebound to the teacher's).

    Args:
        num_epochs: Number of full passes over ``train_loader`` (default 10,
            the value that was previously hard-coded).
    """
    teacher_model.train()
    # Follow the device the model actually lives on instead of assuming CUDA.
    device = next(teacher_model.parameters()).device
    for epoch in range(num_epochs):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            # The model returns (logits, features); only the logits are trained on here.
            outputs, _ = teacher_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        # Reports the loss of the epoch's final mini-batch only, not an epoch average.
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# Train the teacher now; the distillation step later reads features from teacher_model.
train_teacher()



Epoch 1, Loss: 0.014457643032073975
Epoch 2, Loss: 0.03647108003497124
Epoch 3, Loss: 0.006384856067597866
Epoch 4, Loss: 0.006136614829301834
Epoch 5, Loss: 0.006689592730253935
Epoch 6, Loss: 6.137792661320418e-05
Epoch 7, Loss: 0.030274182558059692
Epoch 8, Loss: 0.0001408525713486597
Epoch 9, Loss: 0.00022880287724547088
Epoch 10, Loss: 0.0014746770029887557


In [3]:
# Feature-based Distillation Loss Function
def feature_distillation_loss(student_feat, teacher_feat, student_logits, labels, alpha=0.7):
    """Blend a feature-matching MSE term with the supervised cross-entropy term.

    The teacher features are average-pooled along their last dimension down to
    the student's feature width, so the two can be compared element-wise. For a
    teacher that is exactly twice as wide as the student this is the same as
    averaging adjacent pairs.

    Args:
        student_feat: Student features, one row per sample.
        teacher_feat: Teacher features for the same samples; detached here so
            no gradient reaches the teacher.
        student_logits: Student class scores passed to the cross-entropy term.
        labels: Target class indices.
        alpha: Weight of the feature term; the cross-entropy term gets
            ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * feat_loss + (1 - alpha) * ce_loss``.
    """
    # Add an explicit channel axis so pooling always runs over the feature
    # axis and does not depend on how 2-D inputs are interpreted.
    teacher_target = F.adaptive_avg_pool1d(
        teacher_feat.detach().unsqueeze(1), student_feat.size(-1)
    ).squeeze(1)
    feat_loss = F.mse_loss(student_feat, teacher_target)
    # Functional cross-entropy with default settings: the function no longer
    # depends on a notebook-global `criterion` object.
    ce_loss = F.cross_entropy(student_logits, labels)
    return alpha * feat_loss + (1 - alpha) * ce_loss

# Train the Student Model using Feature-Based Distillation
# Build the student and move it to the GPU before handing its parameters to Adam.
student_model = StudentModel()
student_model.cuda()
# NOTE: this rebinds the global `optimizer` that train_teacher() also reads at
# call time; from here on `optimizer` updates only the student's parameters.
optimizer = optim.Adam(params=student_model.parameters(), lr=1e-3)

def train_student(num_epochs=10):
    """Train the notebook-global ``student_model`` by feature-based distillation.

    For every batch of ``train_loader`` the student's features are matched
    against the teacher's through ``feature_distillation_loss``, together with
    the supervised term on the student's logits. The globals ``teacher_model``,
    ``student_model``, ``train_loader`` and ``optimizer`` are looked up when
    the function is called.

    Args:
        num_epochs: Number of full passes over ``train_loader`` (default 10,
            the value that was previously hard-coded).
    """
    # The teacher is only a fixed feature provider here: eval mode, no gradients.
    teacher_model.eval()
    student_model.train()
    # Follow the device the student actually lives on instead of assuming CUDA.
    # The teacher is expected to be on the same device.
    device = next(student_model.parameters()).device
    for epoch in range(num_epochs):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            student_outputs, student_feat = student_model(images)
            with torch.no_grad():
                _, teacher_feat = teacher_model(images)
            loss = feature_distillation_loss(student_feat, teacher_feat, student_outputs, labels)
            loss.backward()
            optimizer.step()
        # Reports the loss of the epoch's final mini-batch only, not an epoch average.
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# Distill the teacher's features into the student; its accuracy is measured below.
train_student()

# Evaluate both models
def evaluate(model, loader=None):
    """Return the top-1 accuracy of ``model`` in percent.

    The model is switched to eval mode (and left there) and run without
    gradient tracking. Batches are moved to the device of the model's
    parameters, so the function works for CPU and GPU models alike.

    Args:
        model: Module whose forward pass returns ``(logits, features)``; only
            the logits are used.
        loader: Iterable of ``(images, labels)`` batches. Defaults to the
            notebook-global ``test_loader``.

    Returns:
        ``100 * correct / total`` as a float.
    """
    if loader is None:
        loader = test_loader
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # Parameter-free model: fall back to CUDA when present, else CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs, _ = model(images)
            # Index of the largest logit per row is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

# Accuracy (in percent) of each trained model on the held-out test_loader.
teacher_acc = evaluate(teacher_model)
student_acc = evaluate(student_model)
print(f'Teacher Model Accuracy: {teacher_acc}%')
print(f'Student Model Accuracy: {student_acc}%')

Epoch 1, Loss: 1.2894935607910156
Epoch 2, Loss: 1.0888124704360962
Epoch 3, Loss: 0.8859511017799377
Epoch 4, Loss: 1.0243275165557861
Epoch 5, Loss: 0.67049640417099
Epoch 6, Loss: 0.9033182263374329
Epoch 7, Loss: 0.9728026986122131
Epoch 8, Loss: 0.7961475253105164
Epoch 9, Loss: 0.7647548913955688
Epoch 10, Loss: 0.7021578550338745
Teacher Model Accuracy: 99.01%
Student Model Accuracy: 99.22%
