In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

from torchvision.models import resnet18
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter

from copy import deepcopy

# Prefer the GPU, but fall back to the CPU so the notebook still runs end-to-end
# on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class LeNet(nn.Module):
    """Small LeNet-5 style CNN: two conv/ReLU/max-pool stages, then three linear layers.

    `fc1` expects 16 feature maps of 5x5. That is what a 3-channel 32x32 input
    produces after two conv(5x5) + max-pool(2) stages: 32 -> 28 -> 14 -> 10 -> 5.

    Args:
        num_classes: length of the output logit vector.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Feature extractor. The attribute names are part of the state_dict
        # keys, so they must stay stable.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head: flattened 16*5*5 features -> 120 -> 84 -> num_classes.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map an image batch to unnormalised class logits of shape (batch, num_classes)."""
        for conv in (self.conv1, self.conv2):
            # A 2x2 window with the default stride (= kernel size) halves H and W.
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2)
        # Keep the batch dimension; collapse channels and spatial dims.
        hidden = x.flatten(start_dim=1)
        hidden = F.relu(self.fc1(hidden))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [41]:
# Define transformations.
# NOTE(review): these mean/std values look like the standard ImageNet channel
# statistics rather than CIFAR-10's own. That is a common shortcut, but confirm
# it is intended.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the CIFAR-10 train and test splits (downloaded to ./data on first run).
train_dataset = CIFAR10(root="./data", train=True, transform=transform, download=True)
test_dataset = CIFAR10(root="./data", train=False, transform=transform, download=True)

# Create data loaders. Only the training loader is shuffled, so evaluation
# always visits the test set in the same order.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=1)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False, num_workers=1)

# Initialize the loss function (shared by train, train_distil's label term, ...).
criterion = nn.CrossEntropyLoss()

Files already downloaded and verified
Files already downloaded and verified


In [22]:
def evaluate(model, loader=None):
    """Compute, print and return the top-1 accuracy of `model` in percent.

    Args:
        model: classifier whose output has one logit per class along dim 1.
        loader: iterable of (images, labels) batches. Defaults to the global
            `test_loader`.

    Returns:
        Accuracy as a percentage (correct / total * 100).

    Raises:
        ValueError: if `loader` yields no samples.

    The model's previous train/eval mode is restored before returning.
    Otherwise a call from inside a training loop would leave the model stuck
    in eval mode for the remaining epochs.
    """
    if loader is None:
        loader = test_loader

    # Send each batch to wherever the model's weights live.
    model_device = next(model.parameters()).device
    was_training = model.training
    model.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                images = images.to(model_device)
                labels = labels.to(model_device)

                predicted = model(images).argmax(dim=1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("evaluate(): loader yielded no samples")

    accuracy = correct / total * 100
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

In [29]:
def train(model, num_epochs, lr, writer, start=0, test_every=5):
    """Train `model` with Adam and cross-entropy on the global `train_loader`.

    A fresh Adam optimizer is created on every call. Calling `train` again,
    e.g. with a smaller `lr` for a second phase, therefore also resets Adam's
    running moment estimates.

    Args:
        model: network to optimise. Batches are moved to the global `device`,
            so the model is expected to live there too.
        num_epochs: number of passes over `train_loader` in this call.
        lr: Adam learning rate.
        writer: TensorBoard `SummaryWriter`; receives "Training loss" every
            epoch and "Test acc" whenever an evaluation runs.
        start: epochs already completed by earlier calls. It only offsets the
            logged/printed epoch numbers so successive phases form one curve.
        test_every: call `evaluate` when the global epoch number is a multiple
            of this value. Use 0 (or a negative value) to disable evaluation.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    total_epochs = start + num_epochs

    for epoch in range(start + 1, total_epochs + 1):
        # evaluate() switches the model to eval mode, so re-enable training
        # mode at the start of every epoch rather than only once up front.
        model.train()
        running_loss = 0.0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)

        writer.add_scalar("Training loss", avg_loss, epoch)
        print(f'Epoch {epoch}/{total_epochs}, Loss: {avg_loss:.3f}')

        if test_every > 0 and epoch % test_every == 0:
            acc = evaluate(model)
            writer.add_scalar("Test acc", acc, epoch)
            writer.add_scalar("Test acc", acc, start+epoch+1)

In [47]:
def train_distil(student, teacher, teaching_wt, num_epochs, lr, distil_loss, writer,
                 start=0, test_every=5, temperature=1.0):
    """Train `student` on the hard labels plus a knowledge-distillation term from `teacher`.

    The per-batch objective is

        criterion(student_logits, labels) + teaching_wt * teaching_loss

    How `teaching_loss` is computed depends on `distil_loss`:
      * `nn.KLDivLoss`: logits are divided by `temperature`, turned into
        student log-probabilities / teacher probabilities, and the result is
        scaled by temperature**2 (Hinton et al., 2015). With the default
        temperature=1.0 this is exactly KL(softmax(teacher) || softmax(student)).
      * any other loss (e.g. `nn.MSELoss`): applied to the raw logits.

    Args:
        student: model being optimised (Adam over its parameters only).
        teacher: fixed reference model. It is put in eval mode and never
            receives gradients.
        teaching_wt: weight of the distillation term relative to the label loss.
        num_epochs: number of passes over the global `train_loader`.
        lr: Adam learning rate. A fresh optimizer is created per call.
        distil_loss: loss module comparing student and teacher outputs.
        writer: TensorBoard `SummaryWriter` for the three loss curves and "Test acc".
        start: epochs already completed earlier; only offsets logged epoch numbers.
        test_every: evaluate the student when the global epoch number is a
            multiple of this value. Use 0 (or a negative value) to disable.
        temperature: softening temperature for `nn.KLDivLoss`; must be > 0.

    Raises:
        ValueError: if `temperature` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    teacher.eval()
    use_kl = isinstance(distil_loss, nn.KLDivLoss)
    total_epochs = start + num_epochs

    print(f"Distillation loss: {distil_loss}; learning rate: {lr}")
    optimizer = optim.Adam(student.parameters(), lr=lr)

    for epoch in range(start + 1, total_epochs + 1):
        # evaluate() leaves the student in eval mode, so re-enable training
        # mode at the start of every epoch.
        student.train()
        running_label_loss = 0.0
        running_teaching_loss = 0.0
        running_loss = 0.0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            # The teacher only provides targets. Without no_grad, backward()
            # would also traverse the teacher's graph and accumulate gradients
            # into its parameters (if they require grad). Those gradients are
            # never zeroed, because the optimizer only owns the student.
            with torch.no_grad():
                teacher_logits = teacher(images)

            optimizer.zero_grad()
            student_logits = student(images)

            label_loss = criterion(student_logits, labels)

            if use_kl:
                # KLDivLoss expects log-probabilities as input and probabilities as target.
                teaching_loss = distil_loss(
                    F.log_softmax(student_logits / temperature, dim=-1),
                    F.softmax(teacher_logits / temperature, dim=-1),
                ) * temperature ** 2
            else:
                teaching_loss = distil_loss(student_logits, teacher_logits)

            loss = label_loss + teaching_wt * teaching_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_label_loss += label_loss.item()
            running_teaching_loss += teaching_loss.item()

        avg_label_loss = running_label_loss / len(train_loader)
        avg_teaching_loss = running_teaching_loss / len(train_loader)
        avg_overall_loss = running_loss / len(train_loader)

        writer.add_scalar("Label loss", avg_label_loss, epoch)
        writer.add_scalar("KnowDist loss", avg_teaching_loss, epoch)
        writer.add_scalar("Training loss", avg_overall_loss, epoch)

        # Implicit string concatenation. A backslash-newline inside the
        # literal would embed the next line's indentation in the output.
        print(f'Epoch {epoch}/{total_epochs}, Label: {avg_label_loss:.3f}, '
              f'Teacher: {avg_teaching_loss:.3f}, '
              f'Overall: {avg_overall_loss:.3f}')

        if test_every > 0 and epoch % test_every == 0:
            acc = evaluate(student)
            writer.add_scalar("Test acc", acc, epoch)

In [46]:
"""
Student model is to be trained
Experiments:
1. 60 epochs lr 0.001
2. 20 epochs with lr = 0.001 then 40 epochs with lr = 0.0001
3. 20 epochs with lr = 0.001 then 40 epochs with distillation (MSELoss) lr = 0.0001
4. 20 epochs with lr = 0.001 then 40 epochs with distillation (KLDivLoss) lr = 0.0001

Log test accuracy every 5 epochs
""";

### Varying learning rate

In [28]:
# Re-seed so every experiment starts from the same initial weights and data
# order (GPU kernels may still cause small run-to-run differences).
torch.manual_seed(0)

w1 = SummaryWriter("distill_logs/lr1e-3_1e-4_1e-5")

student = LeNet(num_classes=10).to(device)
lr = 0.001

# Train for 20 epochs with lr=1e-3
train(student, num_epochs=20, lr=lr, writer=w1)

# 20 more epochs with lr=1e-4
train(student, num_epochs=20, lr=lr/10, writer=w1, start=20)

# 20 more epochs with lr=1e-5
train(student, num_epochs=20, lr=lr/100, writer=w1, start=40)

w1.add_text("Notes", "[1-20] LR=1e-3\n[21-40] LR=1e-4\n[41-60] LR=1e-5", global_step=0)
# Flush buffered events to disk and release the event file.
w1.close()

Epoch 1/20, Loss: 1.605
Epoch 2/20, Loss: 1.314
Epoch 3/20, Loss: 1.186
Epoch 4/20, Loss: 1.110
Epoch 5/20, Loss: 1.045
Test Accuracy: 59.59%
Epoch 6/20, Loss: 0.993
Epoch 7/20, Loss: 0.953
Epoch 8/20, Loss: 0.909
Epoch 9/20, Loss: 0.873
Epoch 10/20, Loss: 0.844
Test Accuracy: 63.14%
Epoch 11/20, Loss: 0.815
Epoch 12/20, Loss: 0.787
Epoch 13/20, Loss: 0.762
Epoch 14/20, Loss: 0.740
Epoch 15/20, Loss: 0.716
Test Accuracy: 63.04%
Epoch 16/20, Loss: 0.697
Epoch 17/20, Loss: 0.673
Epoch 18/20, Loss: 0.659
Epoch 19/20, Loss: 0.643
Epoch 20/20, Loss: 0.622
Test Accuracy: 63.29%
Epoch 21/20, Loss: 0.492
Epoch 22/20, Loss: 0.470
Epoch 23/20, Loss: 0.459
Epoch 24/20, Loss: 0.452
Epoch 25/20, Loss: 0.445
Test Accuracy: 64.12%
Epoch 26/20, Loss: 0.439
Epoch 27/20, Loss: 0.434
Epoch 28/20, Loss: 0.427
Epoch 29/20, Loss: 0.423
Epoch 30/20, Loss: 0.418
Test Accuracy: 63.91%
Epoch 31/20, Loss: 0.414
Epoch 32/20, Loss: 0.409
Epoch 33/20, Loss: 0.404
Epoch 34/20, Loss: 0.400
Epoch 35/20, Loss: 0.396
Te

### Knowledge distillation (KL-divergence loss)

In [48]:
# NOTE(review): `teacher` is not created anywhere in this notebook. Presumably
# it is a trained model left in the kernel by an earlier/deleted cell (the
# unused `resnet18` import hints at one; confirm). Check up front so the cell
# fails immediately rather than after the 20 plain-training epochs below.
if "teacher" not in globals():
    raise NameError("`teacher` is not defined: train or load the teacher model before running this cell")

# Re-seed so every experiment starts from the same initial weights and data
# order (GPU kernels may still cause small run-to-run differences).
torch.manual_seed(0)

w2 = SummaryWriter("distill_logs/distill_KLDiv_1.0")

student = LeNet(num_classes=10).to(device)
lr = 0.001

# Train for 20 epochs with lr=1e-3
train(student, num_epochs=20, lr=lr, writer=w2)

# 40 more epochs with KL-divergence (knowledge distillation) loss (weighted 1.0) & lr=1e-4
train_distil(student, teacher, teaching_wt=1.0, num_epochs=40, lr=lr/10, writer=w2,
             distil_loss=nn.KLDivLoss(reduction="batchmean"), start=20, test_every=5)

w2.add_text("Notes", "[1-20] LR=1e-3\n[21-60] LR=1e-4\nDistillation Loss: KLDivLoss()\nWeight=1.0", global_step=0)
# Flush buffered events to disk and release the event file.
w2.close()

Epoch 1/20, Loss: 1.600
Epoch 2/20, Loss: 1.308
Epoch 3/20, Loss: 1.191
Epoch 4/20, Loss: 1.111
Epoch 5/20, Loss: 1.050
Test Accuracy: 61.76%
Epoch 6/20, Loss: 0.996
Epoch 7/20, Loss: 0.960
Epoch 8/20, Loss: 0.912
Epoch 9/20, Loss: 0.876
Epoch 10/20, Loss: 0.844
Test Accuracy: 64.10%
Epoch 11/20, Loss: 0.812
Epoch 12/20, Loss: 0.782
Epoch 13/20, Loss: 0.758
Epoch 14/20, Loss: 0.730
Epoch 15/20, Loss: 0.707
Test Accuracy: 63.66%
Epoch 16/20, Loss: 0.688
Epoch 17/20, Loss: 0.662
Epoch 18/20, Loss: 0.638
Epoch 19/20, Loss: 0.624
Epoch 20/20, Loss: 0.602
Test Accuracy: 62.60%
Distillation loss: KLDivLoss(); learning rate: 0.0001
Epoch 21/60, Label: 0.472,         Teacher: 0.461,         Overall: 0.933
Epoch 22/60, Label: 0.451,         Teacher: 0.439,         Overall: 0.890
Epoch 23/60, Label: 0.442,         Teacher: 0.429,         Overall: 0.872
Epoch 24/60, Label: 0.436,         Teacher: 0.423,         Overall: 0.859
Epoch 25/60, Label: 0.429,         Teacher: 0.417,         Overall: 0.8

In [49]:
# NOTE(review): `teacher` is not created anywhere in this notebook. Presumably
# it is a trained model left in the kernel by an earlier/deleted cell (the
# unused `resnet18` import hints at one; confirm). Check up front so the cell
# fails immediately rather than after the 20 plain-training epochs below.
if "teacher" not in globals():
    raise NameError("`teacher` is not defined: train or load the teacher model before running this cell")

# Re-seed so every experiment starts from the same initial weights and data
# order (GPU kernels may still cause small run-to-run differences).
torch.manual_seed(0)

w3 = SummaryWriter("distill_logs/distill_KLDiv_0.1")

student = LeNet(num_classes=10).to(device)
lr = 1e-3

# Train for 20 epochs with lr=1e-3
train(student, num_epochs=20, lr=lr, writer=w3)

# 40 more epochs with KL-divergence (knowledge distillation) loss (weighted 0.1) & lr=1e-4
train_distil(student, teacher, teaching_wt=0.1, num_epochs=40, lr=lr/10, writer=w3,
             distil_loss=nn.KLDivLoss(reduction="batchmean"), start=20, test_every=5)

w3.add_text("Notes", "[1-20] LR=1e-3\n[21-60] LR=1e-4\nDistillation Loss: KLDivLoss()\nWeight=0.1", global_step=0)
# Flush buffered events to disk and release the event file.
w3.close()

Epoch 1/20, Loss: 1.618
Epoch 2/20, Loss: 1.317
Epoch 3/20, Loss: 1.200
Epoch 4/20, Loss: 1.126
Epoch 5/20, Loss: 1.059
Test Accuracy: 59.02%
Epoch 6/20, Loss: 1.009
Epoch 7/20, Loss: 0.964
Epoch 8/20, Loss: 0.922
Epoch 9/20, Loss: 0.885
Epoch 10/20, Loss: 0.850
Test Accuracy: 63.14%
Epoch 11/20, Loss: 0.825
Epoch 12/20, Loss: 0.795
Epoch 13/20, Loss: 0.770
Epoch 14/20, Loss: 0.745
Epoch 15/20, Loss: 0.726
Test Accuracy: 63.78%
Epoch 16/20, Loss: 0.703
Epoch 17/20, Loss: 0.685
Epoch 18/20, Loss: 0.664
Epoch 19/20, Loss: 0.644
Epoch 20/20, Loss: 0.624
Test Accuracy: 63.46%
Distillation loss: KLDivLoss(); learning rate: 0.0001
Epoch 21/60, Label: 0.496,         Teacher: 0.498,         Overall: 0.546
Epoch 22/60, Label: 0.474,         Teacher: 0.485,         Overall: 0.522
Epoch 23/60, Label: 0.463,         Teacher: 0.480,         Overall: 0.511
Epoch 24/60, Label: 0.456,         Teacher: 0.477,         Overall: 0.504
Epoch 25/60, Label: 0.449,         Teacher: 0.473,         Overall: 0.4

### Knowledge distillation (MSE loss)

In [50]:
# NOTE(review): `teacher` is not created anywhere in this notebook. Presumably
# it is a trained model left in the kernel by an earlier/deleted cell (the
# unused `resnet18` import hints at one; confirm). Check up front so the cell
# fails immediately rather than after the 20 plain-training epochs below.
if "teacher" not in globals():
    raise NameError("`teacher` is not defined: train or load the teacher model before running this cell")

# Re-seed so every experiment starts from the same initial weights and data
# order (GPU kernels may still cause small run-to-run differences).
torch.manual_seed(0)

w4 = SummaryWriter("distill_logs/distill_MSE_1.0")

student = LeNet(num_classes=10).to(device)
lr = 1e-3

# Train for 20 epochs with lr=1e-3
train(student, num_epochs=20, lr=lr, writer=w4)

# 40 more epochs with MSE (knowledge distillation) loss on raw logits (weighted 1.0) & lr=1e-4
train_distil(student, teacher, teaching_wt=1.0, num_epochs=40, lr=lr/10, writer=w4,
             distil_loss=nn.MSELoss(), start=20, test_every=5)

w4.add_text("Notes", "[1-20] LR=1e-3\n[21-60] LR=1e-4\nDistillation Loss: MSELoss()\nWeight=1.0", global_step=0)
# Flush buffered events to disk and release the event file.
w4.close()

Epoch 1/20, Loss: 1.610
Epoch 2/20, Loss: 1.324
Epoch 3/20, Loss: 1.202
Epoch 4/20, Loss: 1.113
Epoch 5/20, Loss: 1.042
Test Accuracy: 61.50%
Epoch 6/20, Loss: 0.991
Epoch 7/20, Loss: 0.943
Epoch 8/20, Loss: 0.906
Epoch 9/20, Loss: 0.866
Epoch 10/20, Loss: 0.840
Test Accuracy: 63.64%
Epoch 11/20, Loss: 0.805
Epoch 12/20, Loss: 0.777
Epoch 13/20, Loss: 0.753
Epoch 14/20, Loss: 0.734
Epoch 15/20, Loss: 0.709
Test Accuracy: 64.48%
Epoch 16/20, Loss: 0.687
Epoch 17/20, Loss: 0.663
Epoch 18/20, Loss: 0.649
Epoch 19/20, Loss: 0.627
Epoch 20/20, Loss: 0.614
Test Accuracy: 63.25%
Distillation loss: MSELoss(); learning rate: 0.0001
Epoch 21/60, Label: 0.577,         Teacher: 10.497,         Overall: 11.074
Epoch 22/60, Label: 0.606,         Teacher: 8.983,         Overall: 9.589
Epoch 23/60, Label: 0.625,         Teacher: 8.393,         Overall: 9.018
Epoch 24/60, Label: 0.639,         Teacher: 8.034,         Overall: 8.672
Epoch 25/60, Label: 0.649,         Teacher: 7.763,         Overall: 8.4

In [51]:
# NOTE(review): `teacher` is not created anywhere in this notebook. Presumably
# it is a trained model left in the kernel by an earlier/deleted cell (the
# unused `resnet18` import hints at one; confirm). Check up front so the cell
# fails immediately rather than after the 20 plain-training epochs below.
if "teacher" not in globals():
    raise NameError("`teacher` is not defined: train or load the teacher model before running this cell")

# Re-seed so every experiment starts from the same initial weights and data
# order (GPU kernels may still cause small run-to-run differences).
torch.manual_seed(0)

# Use a new name: reusing `w4` would silently drop the reference to the
# previous experiment's writer.
w5 = SummaryWriter("distill_logs/distill_MSE_0.1")

student = LeNet(num_classes=10).to(device)
lr = 1e-3

# Train for 20 epochs with lr=1e-3
train(student, num_epochs=20, lr=lr, writer=w5)

# 40 more epochs with MSE (knowledge distillation) loss on raw logits (weighted 0.1) & lr=1e-4
train_distil(student, teacher, teaching_wt=0.1, num_epochs=40, lr=lr/10, writer=w5,
             distil_loss=nn.MSELoss(), start=20, test_every=5)

w5.add_text("Notes", "[1-20] LR=1e-3\n[21-60] LR=1e-4\nDistillation Loss: MSELoss()\nWeight=0.1", global_step=0)
# Flush buffered events to disk and release the event file.
w5.close()

Epoch 1/20, Loss: 1.613
Epoch 2/20, Loss: 1.308
Epoch 3/20, Loss: 1.184
Epoch 4/20, Loss: 1.095
Epoch 5/20, Loss: 1.033
Test Accuracy: 61.94%
Epoch 6/20, Loss: 0.975
Epoch 7/20, Loss: 0.933
Epoch 8/20, Loss: 0.892
Epoch 9/20, Loss: 0.858
Epoch 10/20, Loss: 0.823
Test Accuracy: 63.19%
Epoch 11/20, Loss: 0.794
Epoch 12/20, Loss: 0.769
Epoch 13/20, Loss: 0.744
Epoch 14/20, Loss: 0.720
Epoch 15/20, Loss: 0.699
Test Accuracy: 63.27%
Epoch 16/20, Loss: 0.678
Epoch 17/20, Loss: 0.652
Epoch 18/20, Loss: 0.636
Epoch 19/20, Loss: 0.618
Epoch 20/20, Loss: 0.604
Test Accuracy: 62.06%
Distillation loss: MSELoss(); learning rate: 0.0001
Epoch 21/60, Label: 0.500,         Teacher: 11.432,         Overall: 1.643
Epoch 22/60, Label: 0.489,         Teacher: 10.005,         Overall: 1.489
Epoch 23/60, Label: 0.487,         Teacher: 9.482,         Overall: 1.435
Epoch 24/60, Label: 0.485,         Teacher: 9.154,         Overall: 1.400
Epoch 25/60, Label: 0.484,         Teacher: 8.923,         Overall: 1.3