In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F


# Calculate the number of trainable parameters
def _calc_width(net):
    import numpy as np
    net_params = filter(lambda p: p.requires_grad, net.parameters())
    weight_count = 0
    for param in net_params:
        weight_count += np.prod(param.size())
    return weight_count



In [3]:
# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 on the single grayscale channel.
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Both splits share the same storage root and preprocessing pipeline.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)

# Training batches are reshuffled each epoch; test batches keep dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)

# Define the Teacher Model
class TeacherModel(nn.Module):
    """Larger CNN used as the distillation teacher.

    Two 3x3 convolutions (32 then 64 channels, padding 1), each followed by
    2x2 max pooling and ReLU, then a 128-unit hidden layer and a 10-way
    output layer. ``fc1`` is sized for 64 feature maps of 7x7, which is what
    the two poolings leave of a 28x28 input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64*7*7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return raw class logits (no softmax applied) for the input images.

        Raises:
            ValueError: If the feature maps after the second pooling stage
                are not 64x7x7, i.e. the input has an unsupported spatial size.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        # Without this check, view(-1, 64*7*7) would silently fold surplus
        # spatial positions into the batch dimension (e.g. 2 images of 56x56
        # would come out as 8 rows of logits).
        if tuple(x.shape[-3:]) != (64, 7, 7):
            raise ValueError(
                f'TeacherModel expects 64x7x7 feature maps before fc1, '
                f'got {tuple(x.shape[-3:])}; check the input image size'
            )
        x = x.view(-1, 64*7*7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Define the Student Model
class StudentModel(nn.Module):
    """Smaller CNN used as the distillation student.

    Same layout as the teacher with half the width: two 3x3 convolutions
    (16 then 32 channels, padding 1), each followed by 2x2 max pooling and
    ReLU, then a 64-unit hidden layer and a 10-way output layer. ``fc1`` is
    sized for 32 feature maps of 7x7, which is what the two poolings leave of
    a 28x28 input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32*7*7, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Return raw class logits (no softmax applied) for the input images.

        Raises:
            ValueError: If the feature maps after the second pooling stage
                are not 32x7x7, i.e. the input has an unsupported spatial size.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        # Without this check, view(-1, 32*7*7) would silently fold surplus
        # spatial positions into the batch dimension (e.g. 2 images of 56x56
        # would come out as 8 rows of logits).
        if tuple(x.shape[-3:]) != (32, 7, 7):
            raise ValueError(
                f'StudentModel expects 32x7x7 feature maps before fc1, '
                f'got {tuple(x.shape[-3:])}; check the input image size'
            )
        x = x.view(-1, 32*7*7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9912422/9912422 [00:00<00:00, 12168456.93it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 337687.77it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1648877/1648877 [00:00<00:00, 3050743.18it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 4301316.05it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [7]:

# Train the Teacher Model
# Prefer the GPU but fall back to the CPU, so this cell does not raise on a
# host without CUDA (a bare .cuda() does).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
teacher_model = TeacherModel().to(device)
# Module-level loss and optimizer; train_standalone() reads these names.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)

def train_standalone(model, data_loader=None, opt=None, loss_fn=None, epochs=10):
    """Train ``model`` in place with plain supervised learning.

    Args:
        model: Network to train. Batches are moved to the device its
            parameters live on.
        data_loader: Iterable of ``(images, labels)`` batches. Defaults to the
            module-level ``train_loader``.
        opt: Optimizer that updates ``model``. Defaults to the module-level
            ``optimizer``.
        loss_fn: Callable ``(outputs, labels) -> scalar loss``. Defaults to
            the module-level ``criterion``.
        epochs: Number of passes over ``data_loader``.

    Returns:
        None. Progress is printed once per epoch.

    Raises:
        ValueError: If ``opt`` manages none of ``model``'s parameters. This
            happens when the module-level ``optimizer`` still belongs to a
            previously trained model; training would then silently be a no-op.
    """
    # Fall back to the notebook-level objects so the original one-argument
    # call keeps working unchanged.
    data_loader = train_loader if data_loader is None else data_loader
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    model_params = list(model.parameters())
    managed = {id(p) for group in opt.param_groups for p in group['params']}
    if not any(id(p) in managed for p in model_params):
        raise ValueError(
            'the optimizer does not manage any parameter of this model; '
            'create an optimizer for the model before training it'
        )

    # Follow the model instead of hard-coding .cuda(), so CPU models work too.
    device = model_params[0].device

    model.train()
    for epoch in range(epochs):
        last_loss = None
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            opt.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            opt.step()
            last_loss = loss
        # The reported value is the loss of the final mini-batch of the epoch,
        # not an epoch average. Nothing is printed for an empty loader.
        if last_loss is not None:
            print(f'Epoch {epoch+1}, Loss: {last_loss.item()}')

# Fit the teacher in place. train_standalone() reads the module-level
# `optimizer`, `criterion` and `train_loader`, so the optimizer created just
# above must be the one bound to `teacher_model`.
train_standalone(teacher_model)



Epoch 1, Loss: 0.02288227342069149
Epoch 2, Loss: 0.11291230469942093
Epoch 3, Loss: 0.029640095308423042
Epoch 4, Loss: 0.05165551230311394
Epoch 5, Loss: 0.027807150036096573
Epoch 6, Loss: 0.0002076236269203946
Epoch 7, Loss: 0.010606239549815655
Epoch 8, Loss: 2.8826119887526147e-05
Epoch 9, Loss: 0.0003621918731369078
Epoch 10, Loss: 9.19922676985152e-05


In [9]:

# Evaluate a model on the test set (used for the teacher and both students)
def evaluate(model, data_loader=None):
    """Compute top-1 accuracy of ``model`` as a percentage.

    Args:
        model: Network to score. It is switched to eval mode and left there.
        data_loader: Iterable of ``(images, labels)`` batches. Defaults to the
            module-level ``test_loader``.

    Returns:
        float: ``100 * correct / total`` over all samples seen.

    Raises:
        ValueError: If ``data_loader`` yields no samples, since the accuracy
            would be undefined.
    """
    data_loader = test_loader if data_loader is None else data_loader

    # Move batches to wherever the model lives instead of hard-coding .cuda().
    # A model without parameters is evaluated on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Index of the largest logit per row is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError('evaluate() received no samples; accuracy is undefined')
    return 100 * correct / total

# evaluate() returns a percentage (100 * correct / total), hence the literal
# '%' appended in the message.
teacher_acc = evaluate(teacher_model)
print(f'Teacher Model Accuracy: {teacher_acc}%')


Teacher Model Accuracy: 99.23%


In [10]:

# Baseline: the student architecture trained without a teacher.
# Prefer the GPU but fall back to the CPU, so this cell does not raise on a
# host without CUDA (a bare .cuda() does).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
standalone_student_model = StudentModel().to(device)
# Rebind the module-level optimizer: train_standalone() reads this name, so it
# must point at the model about to be trained.
optimizer = optim.Adam(standalone_student_model.parameters(), lr=0.001)

train_standalone(standalone_student_model)
standalone_student_acc = evaluate(standalone_student_model)
print(f'Standalone Student Model Accuracy: {standalone_student_acc}%')

# Knowledge Distillation Loss Function
def kd_loss(student_logits, teacher_logits, labels, T=3, alpha=0.5):
    """Knowledge-distillation loss: soft-target KL term plus hard-label CE term.

    Args:
        student_logits: Raw student outputs; classes along ``dim=1``.
        teacher_logits: Raw teacher outputs with the same layout.
        labels: Ground-truth class indices for the cross-entropy term.
        T: Softmax temperature applied to both sets of logits; must be > 0.
        alpha: Weight of the soft-target term; the hard-label term is
            weighted by ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * T^2 * KL(teacher_T || student_T)
        + (1 - alpha) * CE(student, labels)``.

    Raises:
        ValueError: If ``T`` is not strictly positive.
    """
    if T <= 0:
        raise ValueError(f'temperature T must be positive, got {T}')

    # kl_div takes log-probabilities as input and probabilities as target.
    # The T*T factor keeps the soft-term gradient scale comparable across
    # temperatures.
    soft_term = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean',
    ) * (T * T)
    # Cross-entropy is computed directly so the loss does not depend on
    # whatever the module-level `criterion` currently is.
    hard_term = F.cross_entropy(student_logits, labels)
    return alpha * soft_term + (1 - alpha) * hard_term

# Train the Student Model using Knowledge Distillation
# Prefer the GPU but fall back to the CPU, so this cell does not raise on a
# host without CUDA (a bare .cuda() does).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
student_model = StudentModel().to(device)
# Rebind the module-level optimizer to the student that will be distilled;
# train_student() reads this name.
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

def train_student(student=None, teacher=None, data_loader=None, opt=None,
                  loss_fn=None, epochs=10):
    """Train a student network against a frozen teacher (knowledge distillation).

    Every argument is optional; when omitted, the corresponding module-level
    object is used, so the original zero-argument call keeps working.

    Args:
        student: Network to train in place. Defaults to ``student_model``.
        teacher: Network providing target logits; it is put in eval mode and
            run without gradient tracking. Defaults to ``teacher_model``.
        data_loader: Iterable of ``(images, labels)`` batches. Defaults to
            ``train_loader``.
        opt: Optimizer that updates ``student``. Defaults to ``optimizer``.
        loss_fn: Callable ``(student_logits, teacher_logits, labels) -> loss``.
            Defaults to ``kd_loss``.
        epochs: Number of passes over ``data_loader``.

    Returns:
        None. Progress is printed once per epoch.

    Raises:
        ValueError: If ``opt`` manages none of ``student``'s parameters, in
            which case training would silently be a no-op.
    """
    student = student_model if student is None else student
    teacher = teacher_model if teacher is None else teacher
    data_loader = train_loader if data_loader is None else data_loader
    opt = optimizer if opt is None else opt
    loss_fn = kd_loss if loss_fn is None else loss_fn

    student_params = list(student.parameters())
    managed = {id(p) for group in opt.param_groups for p in group['params']}
    if not any(id(p) in managed for p in student_params):
        raise ValueError(
            'the optimizer does not manage any parameter of the student; '
            'create an optimizer for the student before training it'
        )

    # Follow the student instead of hard-coding .cuda(). The teacher is
    # expected to live on the same device.
    device = student_params[0].device

    teacher.eval()
    student.train()
    for epoch in range(epochs):
        last_loss = None
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            opt.zero_grad()
            student_outputs = student(images)
            # The teacher only supplies targets, so skip building its graph.
            with torch.no_grad():
                teacher_outputs = teacher(images)
            loss = loss_fn(student_outputs, teacher_outputs, labels)
            loss.backward()
            opt.step()
            last_loss = loss
        # The reported value is the loss of the final mini-batch of the epoch,
        # not an epoch average. Nothing is printed for an empty loader.
        if last_loss is not None:
            print(f'Epoch {epoch+1}, Loss: {last_loss.item()}')

# Distil into the module-level `student_model` (train_student() works on the
# notebook-level models, loader and optimizer), then score the distilled
# student with evaluate(), which returns a percentage.
train_student()
student_acc = evaluate(student_model)
print(f'Student Model Accuracy: {student_acc}%')

Epoch 1, Loss: 0.020250951871275902
Epoch 2, Loss: 0.0027921730652451515
Epoch 3, Loss: 0.008028319105505943
Epoch 4, Loss: 0.02874193899333477
Epoch 5, Loss: 5.718042666558176e-05
Epoch 6, Loss: 0.0032480801455676556
Epoch 7, Loss: 0.002733400324359536
Epoch 8, Loss: 0.011748390272259712
Epoch 9, Loss: 0.0014523735735565424
Epoch 10, Loss: 2.4250031856354326e-05
Standalone Student Model Accuracy: 98.93%
Epoch 1, Loss: 0.14913146197795868
Epoch 2, Loss: 0.12045740336179733
Epoch 3, Loss: 0.2904188334941864
Epoch 4, Loss: 0.13714228570461273
Epoch 5, Loss: 0.06620895117521286
Epoch 6, Loss: 0.04613114893436432
Epoch 7, Loss: 0.059416841715574265
Epoch 8, Loss: 0.060526445508003235
Epoch 9, Loss: 0.098239466547966
Epoch 10, Loss: 0.10634449124336243
Student Model Accuracy: 99.03%
