https://arxiv.org/abs/1503.02531



https://github.com/kuangliu/pytorch-cifar

In [2]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [3]:


# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [5]:

# Define a simple ResNet model for the sake of this example (you can use a more complex one if needed)
# Source: https://github.com/kuangliu/pytorch-cifar
from models import ResNet18

# Create the teacher and student models
teacher = ResNet18().to(device)
student = ResNet18().to(device)

In [6]:

# Assume the teacher is pre-trained, so freeze its parameters
for param in teacher.parameters():
    param.requires_grad = False

In [7]:

# NOTE: In practice, you would train the teacher first or load its pre-trained weights.
# For the sake of this example, we'll use it without training.

# Define the KD loss function
def knowledge_distillation_loss(output, target, teacher_output, temperature, alpha):
    hard_loss = F.cross_entropy(output, target) * alpha
    soft_loss = (F.kl_div(F.log_softmax(output/temperature, dim=1),
                          F.softmax(teacher_output/temperature, dim=1),
                          reduction='batchmean') * (temperature**2) * (1. - alpha))
    return hard_loss + soft_loss


In [8]:


# Training parameters
epochs = 5
alpha = 0.9
temperature = 4.0
optimizer = optim.SGD(student.parameters(), lr=0.01, momentum=0.9)

# Training loop
for epoch in range(epochs):
    student.train()
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        student_outputs = student(inputs)
        teacher_outputs = teacher(inputs)

        loss = knowledge_distillation_loss(student_outputs, labels, teacher_outputs, temperature, alpha)
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

print("Finished KD Training!")


Epoch [1/5], Step [100/782], Loss: 1.5142
Epoch [1/5], Step [200/782], Loss: 1.9123
Epoch [1/5], Step [300/782], Loss: 1.3289
Epoch [1/5], Step [400/782], Loss: 1.5124
Epoch [1/5], Step [500/782], Loss: 1.1083
Epoch [1/5], Step [600/782], Loss: 1.1154
Epoch [1/5], Step [700/782], Loss: 1.1642
Epoch [2/5], Step [100/782], Loss: 0.8856
Epoch [2/5], Step [200/782], Loss: 0.9408
Epoch [2/5], Step [300/782], Loss: 0.9530
Epoch [2/5], Step [400/782], Loss: 0.7551
Epoch [2/5], Step [500/782], Loss: 0.8180
Epoch [2/5], Step [600/782], Loss: 0.7796
Epoch [2/5], Step [700/782], Loss: 0.7543
Epoch [3/5], Step [100/782], Loss: 0.6007
Epoch [3/5], Step [200/782], Loss: 0.7928
Epoch [3/5], Step [300/782], Loss: 0.5279
Epoch [3/5], Step [400/782], Loss: 0.8048
Epoch [3/5], Step [500/782], Loss: 0.5594
Epoch [3/5], Step [600/782], Loss: 0.7844
Epoch [3/5], Step [700/782], Loss: 0.7055
Epoch [4/5], Step [100/782], Loss: 0.5335
Epoch [4/5], Step [200/782], Loss: 0.4404
Epoch [4/5], Step [300/782], Loss: