In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Seed torch's global RNG: weight init, the noise inputs and DataLoader shuffling all draw from it.
torch.manual_seed(42)

# Use the GPU when one is visible; models and batches are moved here explicitly with .to(device).
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)

In [None]:
# CONSTANTS

num_epochs = 5  # epoch count used by both train_teacher() and train_student()
TEMP = 9.0      # softmax temperature shared by the teacher's soft labels and the student's log-probs

In [None]:
# DEFINE NET

class MNISTNet(nn.Module):
    """MLP whose output vector is split into a main head and an auxiliary head.

    Architecture: flatten -> Linear(in_features, hidden_dim) -> ReLU
    -> Linear(hidden_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, num_classes + num_aux).
    The defaults reproduce the original fixed 784 -> 256 -> 256 -> 13 network (10 main +
    3 auxiliary logits). Submodule names and construction order are unchanged, so seeded
    initialization and saved state dicts (e.g. ``reference_init.pth``) stay compatible.

    Args:
        num_classes: width of the main head (first slice of the final layer's output).
        num_aux: width of the auxiliary head (the remaining outputs); may be 0.
        hidden_dim: width of both hidden layers.
        in_features: number of features after flattening (28 * 28 for MNIST-sized images).

    Raises:
        ValueError: if ``num_classes`` < 1 or ``num_aux`` < 0.
    """

    def __init__(self, num_classes=10, num_aux=3, hidden_dim=256, in_features=28 * 28):
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        if num_aux < 0:
            raise ValueError(f"num_aux must be >= 0, got {num_aux}")
        # Plain int attributes (not buffers), so they do not appear in the state dict.
        self.num_classes = num_classes
        self.num_aux = num_aux
        self.flatten = nn.Flatten()
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes + num_aux),
        )

    def forward(self, x):
        """Return ``(main_logits, aux_logits)`` of shapes (B, num_classes) and (B, num_aux).

        ``x`` is flattened from dim 1 onward (nn.Flatten default), so any (B, ...) input whose
        trailing dims multiply to ``in_features`` is accepted, e.g. (B, 1, 28, 28).
        """
        logits = self.layers(self.flatten(x))
        main_logits = logits[:, :self.num_classes]
        aux_logits = logits[:, self.num_classes:]
        return main_logits, aux_logits

In [None]:
# MNIST TRANSFORMATION

# Convert PIL images to float tensors in [0, 1], then standardize with (mean, std) =
# (0.1307, 0.3081) -- presumably the commonly quoted MNIST training-set pixel statistics.
# Only the real MNIST datasets use this; DistillationDataset's torch.randn noise bypasses it.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

In [None]:
# CREATE TRAIN DATALOADER

# MNIST training split (downloaded into ./data if missing); the loader reshuffles every epoch.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

In [None]:
# CREATE TEST DATALOADER

# MNIST test split with the same preprocessing as training, iterated in a fixed order.
# download=True makes this cell runnable on its own; torchvision does not re-download files
# that are already present under root.
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=0)

In [None]:
# SAVE INITIAL PARAMETERS

# Snapshot one seeded random initialization to disk. The teacher and student cells both load
# this file, so they start from identical weights.
base_model = MNISTNet()
base_model = base_model.to(device)
torch.save(obj=base_model.state_dict(), f="reference_init.pth")

In [None]:
# TEACHER MODEL INIT

teacher_model = MNISTNet().to(device)
# No optimizer is created here: train_teacher() builds its own Adam optimizer on every call.
# Start from the saved reference initialization. map_location keeps the tensors on the current
# device. weights_only restricts unpickling to tensors, primitive types and dicts, which is all
# a state dict needs.
teacher_model.load_state_dict(
    torch.load("reference_init.pth", map_location=device, weights_only=True)
)

In [None]:
# TEACHER TRAINING

def train_teacher():
    """Supervised training of ``teacher_model`` on MNIST through its main head only.

    Runs ``num_epochs`` epochs of Adam (lr=0.001, freshly created on every call) with
    cross-entropy on the main logits. The auxiliary logits get no loss term, so they can only
    change through the shared hidden layers. After each epoch it prints the mean batch loss
    and the training accuracy accumulated over that epoch.

    Reads the notebook globals ``teacher_model``, ``train_loader``, ``num_epochs`` and
    ``device``. Returns None.
    """
    teacher_model.train()
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(teacher_model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        loss_sum, n_correct, n_seen = 0.0, 0, 0
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            opt.zero_grad()
            digit_logits = teacher_model(batch_x)[0]
            loss = loss_fn(digit_logits, batch_y)
            loss.backward()
            opt.step()

            loss_sum += loss.item()
            n_seen += len(batch_y)
            n_correct += (digit_logits.argmax(dim=1) == batch_y).sum().item()
        print(f"Teacher Epoch [{epoch+1}/{num_epochs}], Loss: {loss_sum/len(train_loader):.4f}, "
              f"Accuracy: {100 * n_correct/n_seen:.2f}%")

In [None]:
# Train the teacher on labelled MNIST via its main logits; prints per-epoch loss and train accuracy.
train_teacher()

In [None]:
# DISTILLATION DATA

class DistillationDataset(Dataset):
    """Random-noise inputs paired with the teacher's temperature-softened auxiliary outputs.

    ``num_samples`` images of shape (1, 28, 28) are drawn from a standard normal; no MNIST
    data and no ``transform`` are involved. Each image is labelled with
    ``softmax(aux_logits / temperature)``, where ``aux_logits`` is the second element of the
    pair returned by ``teacher_model(images)``. All soft labels are computed once, up front,
    without gradients. Data and labels are stored on the CPU.

    Args:
        num_samples: number of noise images to generate; must be positive.
        teacher_model: module whose forward returns ``(main_logits, aux_logits)``.
        temperature: softmax temperature; must be positive.
        batch_size: batch size used only while querying the teacher.

    Raises:
        ValueError: if ``num_samples`` or ``temperature`` is not positive.
    """

    def __init__(self, num_samples, teacher_model, temperature=TEMP, batch_size=128):
        # Validate up front: otherwise num_samples == 0 fails obscurely in torch.cat and
        # temperature <= 0 silently yields inf/nan soft labels.
        if num_samples <= 0:
            raise ValueError(f"num_samples must be positive, got {num_samples}")
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.data = torch.randn(num_samples, 1, 28, 28)
        self.soft_labels = self._generate_soft_labels(teacher_model, temperature, batch_size)

    def _generate_soft_labels(self, teacher_model, temperature, batch_size):
        """Query the teacher on ``self.data`` in order; return the stacked CPU soft labels."""
        # Run on whatever device the teacher's weights live on, not on a hidden global.
        first_param = next(teacher_model.parameters(), None)
        model_device = first_param.device if first_param is not None else torch.device("cpu")

        loader = DataLoader(self.data, batch_size=batch_size, shuffle=False)
        soft_labels = []
        was_training = teacher_model.training
        teacher_model.eval()
        try:
            with torch.no_grad():
                for images in loader:
                    _, aux_logits = teacher_model(images.to(model_device))
                    soft_labels.append(F.softmax(aux_logits / temperature, dim=1).cpu())
        finally:
            # Restore the caller's train/eval mode rather than leaving the teacher in eval().
            teacher_model.train(was_training)
        return torch.cat(soft_labels)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Returns (noise image of shape (1, 28, 28), teacher soft-label vector).
        return self.data[idx], self.soft_labels[idx]

In [None]:
# CREATE DISTILLATION DATALOADER

# Noise dataset as large as the MNIST training set, labelled by the trained teacher's
# auxiliary head at temperature TEMP; shuffled each epoch for student training.
distillation_loader = DataLoader(
    DistillationDataset(
        num_samples=len(train_dataset),
        teacher_model=teacher_model,
        temperature=TEMP,
    ),
    shuffle=True,
    batch_size=128,
)

In [None]:
# STUDENT MODEL INIT

student_model = MNISTNet().to(device)
# Load the same reference initialization the teacher started from. The student never sees the
# teacher's trained weights, only its auxiliary-logit outputs through distillation_loader.
# map_location/weights_only: load onto the current device and unpickle tensors/dicts only.
student_model.load_state_dict(
    torch.load("reference_init.pth", map_location=device, weights_only=True)
)

In [None]:
# STUDENT TRAINING

def train_student():
    """Distill the teacher's auxiliary head into ``student_model`` using noise inputs.

    Runs ``num_epochs`` epochs over ``distillation_loader``, which pairs noise images with the
    teacher's temperature-softened auxiliary probabilities. Each step minimises
    ``KL(teacher || student) * TEMP**2``, where the student side is the log-softmax of its
    auxiliary logits at the same temperature and the KL uses ``reduction='batchmean'``.

    A fresh Adam optimizer (lr=0.001) is built on every call. The student's main logits get no
    direct loss term; they can only move through the shared hidden layers. Prints the mean
    batch loss per epoch. Returns None.
    """
    T = TEMP
    student_model.train()
    opt = optim.Adam(student_model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        loss_sum = 0.0
        for noise, target_probs in distillation_loader:
            noise = noise.to(device)
            target_probs = target_probs.to(device)

            opt.zero_grad()
            student_aux = student_model(noise)[1]
            # F.kl_div expects log-probabilities as input and probabilities as target.
            student_log_probs = F.log_softmax(student_aux / T, dim=1)
            kl = F.kl_div(student_log_probs, target_probs, reduction='batchmean')
            loss = kl * (T ** 2)

            loss.backward()
            opt.step()
            loss_sum += loss.item()
        print(f"Student Epoch [{epoch+1}/{num_epochs}], Loss: {loss_sum/len(distillation_loader):.4f}")

In [None]:
# Distill into the student using only noise inputs and the teacher's auxiliary-logit soft labels.
train_student()

In [None]:
# TEACHER EVAL

# Top-1 accuracy of the teacher's main head on the MNIST test split.
teacher_model.eval()
correct = total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        main_logits = teacher_model(images)[0]
        predicted = main_logits.argmax(dim=1)
        total += len(labels)
        correct += int((predicted == labels).sum())
print(f"Teacher Test Accuracy: {100 * correct/total:.2f}%")

In [None]:
# STUDENT EVAL

# Top-1 accuracy of the student's main head on the MNIST test split. The student was trained
# only on the teacher's auxiliary outputs for noise inputs, so any accuracy above chance here
# came through the shared network, not from direct digit supervision.
student_model.eval()
correct = total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        main_logits = student_model(images)[0]
        predicted = main_logits.argmax(dim=1)
        total += len(labels)
        correct += int((predicted == labels).sum())
print(f"Student Test Accuracy: {100 * correct/total:.2f}%")