In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

from torchvision.models import resnet18
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

# Fix torch's global RNG seed so weight initialisation and DataLoader shuffling are
# repeatable on Restart & Run All (GPU kernels may still be nondeterministic).
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when available; models and batches below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class LeNet(nn.Module):
    """LeNet-5-style CNN for 3-channel 32x32 images with a 10-way output.

    Spatial sizes: 32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5,
    which is where the 16 * 5 * 5 input features of ``fc1`` come from.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order so seeded initialisation is unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Fully connected head (affine maps y = Wx + b).
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, 10)."""
        # Two conv -> ReLU -> 2x2 max-pool stages.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2)
        # Keep the batch dimension, flatten everything else.
        x = x.flatten(start_dim=1)
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)

In [None]:
# Data location and loader settings.
DATA_ROOT = "./data"
BATCH_SIZE = 64
NUM_WORKERS = 2

# Convert images to float tensors, then normalise each channel.
# NOTE: these mean/std values are the commonly used ImageNet statistics, not values
# computed on CIFAR-10; kept unchanged so results stay comparable with earlier runs.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the CIFAR-10 train/test splits (downloaded into DATA_ROOT if missing).
train_dataset = CIFAR10(root=DATA_ROOT, train=True, transform=transform, download=True)
test_dataset = CIFAR10(root=DATA_ROOT, train=False, transform=transform, download=True)

# Shuffle only the training data; evaluation order does not affect accuracy.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Classification loss shared by the training loops below.
criterion = nn.CrossEntropyLoss()

In [None]:
def evaluate(model, loader=None, split="Test"):
    """Compute, print and return the classification accuracy of ``model``.

    Args:
        model: an ``nn.Module`` whose output has one score per class along dim 1.
        loader: iterable of ``(images, labels)`` batches; defaults to the notebook's
            global ``test_loader``.
        split: label used in the printed message (e.g. "Test", "Train").

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.

    Leaves ``model`` in eval mode.
    """
    if loader is None:
        loader = test_loader
    # Send batches to wherever the model's weights live, so inputs always match the model.
    model_device = next(model.parameters()).device

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(model_device)
            labels = labels.to(model_device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("evaluate() received an empty loader; accuracy is undefined")
    accuracy = correct / total
    print(f'{split} Accuracy: {accuracy * 100:.2f}%')
    return accuracy

In [None]:
def train(model, num_epochs, lr):
    """Train ``model`` on ``train_loader`` with Adam and the global ``criterion``.

    A fresh Adam optimizer is created on every call. After each epoch the mean
    per-batch training loss is printed.

    Args:
        model: network to optimise; switched to train mode at the start of each epoch.
        num_epochs: number of full passes over ``train_loader``.
        lr: Adam learning rate.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch_idx in range(1, num_epochs + 1):
        model.train()
        loss_sum = 0.0
        for batch_images, batch_labels in train_loader:
            optimizer.zero_grad()
            targets = batch_labels.to(device)
            logits = model(batch_images.to(device))
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()
            loss_sum += batch_loss.item()
        mean_loss = loss_sum / len(train_loader)
        print(f'Epoch {epoch_idx}/{num_epochs}, Loss: {mean_loss:.3f}')

In [None]:
# Small LeNet student and a ResNet-18 teacher with a 10-way head (CIFAR-10 classes).
# Built left to right, student first, so seeded initialisation order is unchanged.
student, teacher = LeNet().to(device), resnet18(num_classes=10).to(device)

In [None]:
# Train the teacher from scratch with the base schedule, then report test accuracy.
num_epochs, lr = 30, 0.001
train(teacher, num_epochs, lr)
evaluate(teacher)

In [None]:
# Baseline: the student trained on labels only, with the same schedule as the teacher.
train(model=student, num_epochs=num_epochs, lr=lr)
evaluate(model=student)

In [None]:
def train_distil(student, teacher, teaching_wt, num_epochs, lr):
    """Train ``student`` on ``train_loader`` with an extra term matching ``teacher``.

    Per batch the optimised loss is::

        criterion(student_out, labels) + teaching_wt * MSE(student_out, teacher_out)

    This pulls the student's raw output scores towards the teacher's. Only the
    student's parameters are updated, by a fresh Adam optimizer each call. Mean
    per-batch label, teacher and total losses are printed after every epoch.

    Args:
        student: model being trained; switched to train mode each epoch.
        teacher: fixed reference model; put in eval mode and run without autograd.
        teaching_wt: weight of the teacher-matching (MSE) term.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
    """
    teacher.eval()
    optimizer = optim.Adam(student.parameters(), lr=lr)
    teaching_criterion = nn.MSELoss()
    n_batches = len(train_loader)

    for epoch in range(num_epochs):
        student.train()
        running_label_loss = 0.0
        running_teaching_loss = 0.0
        running_loss = 0.0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            # The teacher is only a target. Without no_grad, backward() would also
            # traverse its graph and accumulate .grad on its weights every batch.
            with torch.no_grad():
                teacher_output = teacher(images)

            optimizer.zero_grad()
            outputs = student(images)

            teaching_loss = teaching_criterion(outputs, teacher_output)
            label_loss = criterion(outputs, labels)
            loss = label_loss + teaching_wt * teaching_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_label_loss += label_loss.item()
            running_teaching_loss += teaching_loss.item()

        # Adjacent f-strings keep the log on one line without embedding source indentation.
        print(f'Epoch {epoch + 1}/{num_epochs}, '
              f'Label: {running_label_loss / n_batches:.3f}, '
              f'Teacher: {running_teaching_loss / n_batches:.3f}, '
              f'Loss: {running_loss / n_batches:.3f}')

In [None]:
# Fine-tune the already-trained student with a small teacher-matching term.
teaching_wt = 0.1
num_epochs = 10
# Assign directly instead of `lr = lr/10`, so re-running this cell does not keep
# shrinking the rate: one tenth of the 0.001 used for from-scratch training.
lr = 0.001 / 10

train_distil(student, teacher, teaching_wt, num_epochs, lr)

In [None]:
# Student accuracy after the first distillation round.
evaluate(model=student)

In [None]:
# Second distillation round with the same weight, epochs and learning rate.
train_distil(student=student, teacher=teacher, teaching_wt=teaching_wt, num_epochs=num_epochs, lr=lr)

In [None]:
# Third round with a heavier teacher-matching weight.
train_distil(student, teacher, teaching_wt=0.5, num_epochs=num_epochs, lr=lr)

In [None]:
# Student accuracy after the heavier-weight round.
evaluate(model=student)

In [None]:
# Distil into a freshly initialised student (not plain training): a small teacher
# weight of 0.001, with the same 30 epochs and lr 0.001 as the label-only baseline.
student = LeNet().to(device)
train_distil(student, teacher, teaching_wt=0.001, num_epochs=30, lr=0.001)

In [None]:
# Test accuracy of the student distilled from scratch.
evaluate(model=student)