<a href="https://colab.research.google.com/github/leonsuarez24/Notebooks/blob/main/Knowledge_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [177]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [178]:
import sys
import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from scipy.ndimage.interpolation import rotate as scipyrotate
!pip install torchinfo
from torchinfo import summary
from torch.autograd import Function
import torchvision.utils as vutils
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
import torchvision.utils as vutils
!pip install torchmetrics
from torchmetrics import Accuracy
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt


  from scipy.ndimage.interpolation import rotate as scipyrotate




# Utils

In [179]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [180]:
def get_time():
    return str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))

# Networks

##Teacher

In [181]:
# Define the Student model
class TeacherNetwork(nn.Module):
    def __init__(self):
        super(TeacherNetwork, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.relu1 = nn.LeakyReLU(0.2)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(8*8*64, 10)  # Assuming input size (28, 28, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

In [182]:
t = TeacherNetwork()
print(summary(t, input_size=(12, 1, 28, 28)))

Layer (type:depth-idx)                   Output Shape              Param #
TeacherNetwork                           [12, 10]                  --
├─Conv2d: 1-1                            [12, 32, 14, 14]          320
├─LeakyReLU: 1-2                         [12, 32, 14, 14]          --
├─MaxPool2d: 1-3                         [12, 32, 15, 15]          --
├─Conv2d: 1-4                            [12, 64, 8, 8]            18,496
├─Flatten: 1-5                           [12, 4096]                --
├─Linear: 1-6                            [12, 10]                  40,970
Total params: 59,786
Trainable params: 59,786
Non-trainable params: 0
Total mult-adds (M): 15.45
Input size (MB): 0.04
Forward/backward pass size (MB): 1.00
Params size (MB): 0.24
Estimated Total Size (MB): 1.27


## Student

In [183]:
# Define the Student model
class StudentNetwork(nn.Module):
    def __init__(self):
        super(StudentNetwork, self).__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, stride=2, padding=1)
        self.relu1 = nn.LeakyReLU(0.2)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=1, padding=1)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(8*8*8, 10)  # Assuming input size (28, 28, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

In [184]:
s = StudentNetwork()
print(summary(s, input_size=(12, 1, 28, 28)))

Layer (type:depth-idx)                   Output Shape              Param #
StudentNetwork                           [12, 10]                  --
├─Conv2d: 1-1                            [12, 4, 14, 14]           40
├─LeakyReLU: 1-2                         [12, 4, 14, 14]           --
├─MaxPool2d: 1-3                         [12, 4, 15, 15]           --
├─Conv2d: 1-4                            [12, 8, 8, 8]             296
├─Flatten: 1-5                           [12, 512]                 --
├─Linear: 1-6                            [12, 10]                  5,130
Total params: 5,466
Trainable params: 5,466
Non-trainable params: 0
Total mult-adds (M): 0.38
Input size (MB): 0.04
Forward/backward pass size (MB): 0.13
Params size (MB): 0.02
Estimated Total Size (MB): 0.18


# Dataset

In [185]:
def get_dataset(dataset, data_path):

    if dataset == 'MNIST':
        channel = 1
        im_size = (28, 28)
        num_classes = 10
        transform = transforms.Compose([transforms.ToTensor()])
        dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
        class_names = [str(c) for c in range(num_classes)]

    elif dataset == 'FashionMNIST':
        channel = 1
        im_size = (28, 28)
        num_classes = 10
        transform = transforms.Compose([transforms.ToTensor()])
        dst_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)
        class_names = dst_train.classes

    else:
        exit('unknown dataset: %s'%dataset)

    testloader = torch.utils.data.DataLoader(dst_test, batch_size=32, shuffle=False, num_workers=0)
    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=32, shuffle=True, num_workers=0)
    return channel, im_size, num_classes, class_names, dst_train, dst_test, testloader, trainloader

# Configs

In [186]:
save_path = f'/content/drive/MyDrive/Proyectos/Knowledge Distillation/experiments/{get_time()}'

tb_path_teacher = save_path + '/tensorboard_teacher'
tb_path_student = save_path + '/tensorboard_student'
tb_path_student_distilled = save_path + '/tensorboard_student_distilled'

model_teacher = save_path + '/model_teacher'
model_student = save_path + '/model_student'
model_student_distilled = save_path + '/model_student_distilled'

os.makedirs(save_path, exist_ok=True)
os.makedirs(tb_path_teacher, exist_ok=True)
os.makedirs(tb_path_student, exist_ok=True)
os.makedirs(model_teacher, exist_ok=True)
os.makedirs(model_student, exist_ok=True)
os.makedirs(tb_path_student_distilled, exist_ok=True)
os.makedirs(model_student_distilled, exist_ok=True)



writer_teacher = SummaryWriter(tb_path_teacher)
writer_student = SummaryWriter(tb_path_student)
writer_student_distilled = SummaryWriter(tb_path_student_distilled)

In [187]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr_net = 0.001
teacher_epochs = 10
student_epochs = 3
dataset = 'FashionMNIST'
data_path = 'data'
channel, im_size, num_classes, class_names, dst_train, dst_test, testloader, trainloader = get_dataset(dataset, data_path)
accuracy = Accuracy(task="multiclass", num_classes=num_classes, top_k=1).to(device)

# Baseline

In [188]:
teacher = TeacherNetwork().to(device)
optimizer = torch.optim.Adam(teacher.parameters(), lr=lr_net)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(teacher_epochs):

    train_loss = AverageMeter()
    accuracy_train = AverageMeter()
    data_loop_train = tqdm(enumerate(trainloader), total=len(trainloader), colour='red')

    teacher.train(True)
    for _, train_data in data_loop_train:
        train_img, train_label = train_data
        train_img = train_img.to(device)
        train_label = train_label.to(device)
        optimizer.zero_grad()

        train_pred = teacher(train_img)
        loss = criterion(train_pred, train_label)
        acc = accuracy(train_pred, train_label)

        train_loss.update(loss.item(), train_img.size(0))
        accuracy_train.update(acc.item(), train_img.size(0))

        dict_metrics = dict(loss = train_loss.avg, acc = accuracy_train.avg)

        loss.backward()
        optimizer.step()

        data_loop_train.set_description(f'Train Epoch [{epoch + 1} / {teacher_epochs}]')
        data_loop_train.set_postfix(**dict_metrics)

        for key, value in dict_metrics.items():
              writer_teacher.add_scalar(f'train_{key}', value, epoch)


    # Evaluation phase
    teacher.eval()
    data_loop_test = tqdm(enumerate(testloader), total=len(testloader),colour='green')
    with torch.no_grad():

        test_loss = AverageMeter()
        accuracy_test = AverageMeter()

        for _, test_data in data_loop_test:
            test_img, test_label = test_data
            test_img = test_img.to(device)
            test_label = test_label.to(device)

            test_pred = teacher(test_img)
            loss = criterion(test_pred, test_label)
            acc = accuracy(test_pred, test_label)

            test_loss.update(loss.item(), test_img.size(0))
            accuracy_test.update(acc.item(), test_img.size(0))
            dict_metrics = dict(loss = test_loss.avg, acc = accuracy_test.avg)

            data_loop_test.set_description(f'Test  Epoch [{epoch + 1} / {teacher_epochs}]')
            data_loop_test.set_postfix(**dict_metrics)

            for key, value in dict_metrics.items():
                writer_teacher.add_scalar(f'test_{key}', value, epoch)

Train Epoch [1 / 10]: 100%|[31m██████████[0m| 1875/1875 [00:29<00:00, 62.54it/s, acc=0.849, loss=0.427]
Test  Epoch [1 / 10]: 100%|[32m██████████[0m| 313/313 [00:03<00:00, 94.15it/s, acc=0.871, loss=0.352]
Train Epoch [2 / 10]: 100%|[31m██████████[0m| 1875/1875 [00:24<00:00, 77.55it/s, acc=0.887, loss=0.316]
Test  Epoch [2 / 10]: 100%|[32m██████████[0m| 313/313 [00:03<00:00, 94.60it/s, acc=0.883, loss=0.329]
Train Epoch [3 / 10]: 100%|[31m██████████[0m| 1875/1875 [00:23<00:00, 79.63it/s, acc=0.898, loss=0.285]
Test  Epoch [3 / 10]: 100%|[32m██████████[0m| 313/313 [00:03<00:00, 80.95it/s, acc=0.89, loss=0.308]
Train Epoch [4 / 10]: 100%|[31m██████████[0m| 1875/1875 [00:23<00:00, 78.18it/s, acc=0.905, loss=0.265]
Test  Epoch [4 / 10]: 100%|[32m██████████[0m| 313/313 [00:03<00:00, 80.32it/s, acc=0.876, loss=0.33]
Train Epoch [5 / 10]: 100%|[31m██████████[0m| 1875/1875 [00:23<00:00, 79.46it/s, acc=0.91, loss=0.25]
Test  Epoch [5 / 10]: 100%|[32m██████████[0m| 313/313 [0

In [189]:
student = StudentNetwork().to(device)
optimizer = torch.optim.Adam(student.parameters(), lr=lr_net)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(student_epochs):

    train_loss = AverageMeter()
    accuracy_train = AverageMeter()
    data_loop_train = tqdm(enumerate(trainloader), total=len(trainloader), colour='red')

    student.train(True)
    for _, train_data in data_loop_train:
        train_img, train_label = train_data
        train_img = train_img.to(device)
        train_label = train_label.to(device)
        optimizer.zero_grad()

        train_pred = student(train_img)
        loss = criterion(train_pred, train_label)
        acc = accuracy(train_pred, train_label)

        train_loss.update(loss.item(), train_img.size(0))
        accuracy_train.update(acc.item(), train_img.size(0))

        dict_metrics = dict(loss = train_loss.avg, acc = accuracy_train.avg)

        loss.backward()
        optimizer.step()

        data_loop_train.set_description(f'Train Epoch [{epoch + 1} / {student_epochs}]')
        data_loop_train.set_postfix(**dict_metrics)

        for key, value in dict_metrics.items():
              writer_student.add_scalar(f'train_{key}', value, epoch)


    # Evaluation phase
    student.eval()
    data_loop_test = tqdm(enumerate(testloader), total=len(testloader),colour='green')
    with torch.no_grad():

        test_loss = AverageMeter()
        accuracy_test = AverageMeter()

        for _, test_data in data_loop_test:
            test_img, test_label = test_data
            test_img = test_img.to(device)
            test_label = test_label.to(device)

            test_pred = student(test_img)
            loss = criterion(test_pred, test_label)
            acc = accuracy(test_pred, test_label)

            test_loss.update(loss.item(), test_img.size(0))
            accuracy_test.update(acc.item(), test_img.size(0))
            dict_metrics = dict(loss = test_loss.avg, acc = accuracy_test.avg)

            data_loop_test.set_description(f'Test  Epoch [{epoch + 1} / {student_epochs}]')
            data_loop_test.set_postfix(**dict_metrics)

            for key, value in dict_metrics.items():
                writer_student.add_scalar(f'test_{key}', value, epoch)

Train Epoch [1 / 3]: 100%|[31m██████████[0m| 1875/1875 [00:29<00:00, 63.31it/s, acc=0.797, loss=0.573]
Test  Epoch [1 / 3]: 100%|[32m██████████[0m| 313/313 [00:03<00:00, 79.61it/s, acc=0.821, loss=0.486]
Train Epoch [2 / 3]: 100%|[31m██████████[0m| 1875/1875 [00:23<00:00, 81.18it/s, acc=0.852, loss=0.419]
Test  Epoch [2 / 3]: 100%|[32m██████████[0m| 313/313 [00:03<00:00, 92.11it/s, acc=0.852, loss=0.42]
Train Epoch [3 / 3]: 100%|[31m██████████[0m| 1875/1875 [00:23<00:00, 78.59it/s, acc=0.863, loss=0.386]
Test  Epoch [3 / 3]: 100%|[32m██████████[0m| 313/313 [00:03<00:00, 94.44it/s, acc=0.858, loss=0.407]


# Distillation

In [194]:
kl_div_loss = nn.KLDivLoss(log_target=True) # KL Divergence loss for soft targets
loss_func = nn.CrossEntropyLoss()           # Cross entropy loss for true label loss
temperature: float = 18
alpha:float = 0.4
teacher.eval()
student_destilled = StudentNetwork().to(device)
optimizer = torch.optim.Adam(student_destilled.parameters(), lr=lr_net)

In [195]:
for epoch in range(student_epochs):

    train_loss = AverageMeter()
    accuracy_train = AverageMeter()
    data_loop_train = tqdm(enumerate(trainloader), total=len(trainloader), colour='red')

    student_destilled.train(True)

    for _, train_data in data_loop_train:

        train_img, train_label = train_data
        train_img = train_img.to(device)
        train_label = train_label.to(device)
        optimizer.zero_grad()

        with torch.no_grad():
            teacher_pred = teacher(train_img)
        student_pred = student_destilled(train_img)
        student_loss = loss_func(student_pred, train_label)

        soft_targets = F.log_softmax(teacher_pred / temperature, dim=-1)
        soft_prob = F.log_softmax(student_pred / temperature, dim=-1)


        distillation_loss = kl_div_loss(soft_prob, soft_targets)*temperature**2

        loss = alpha * student_loss + (1-alpha) * distillation_loss

        acc = accuracy(student_pred, train_label)

        train_loss.update(loss.item(), train_img.size(0))
        accuracy_train.update(acc.item(), train_img.size(0))

        dict_metrics = dict(loss = train_loss.avg, acc = accuracy_train.avg)

        loss.backward()
        optimizer.step()

        data_loop_train.set_description(f'Train Epoch [{epoch + 1} / {student_epochs}]')
        data_loop_train.set_postfix(**dict_metrics)

        for key, value in dict_metrics.items():
              writer_student_distilled.add_scalar(f'train_{key}', value, epoch)

            # Evaluation phase
    student_destilled.eval()
    data_loop_test = tqdm(enumerate(testloader), total=len(testloader),colour='green')
    with torch.no_grad():

        test_loss = AverageMeter()
        accuracy_test = AverageMeter()

        for _, test_data in data_loop_test:
            test_img, test_label = test_data
            test_img = test_img.to(device)
            test_label = test_label.to(device)

            teacher_pred = teacher(test_img)
            student_pred = student_destilled(test_img)

            soft_targets = F.log_softmax(teacher_pred / temperature, dim=-1)
            soft_prob = F.log_softmax(student_pred / temperature, dim=-1)

            student_loss = loss_func(student_pred, test_label)
            distillation_loss = kl_div_loss(soft_prob, soft_targets)*temperature**2

            loss = alpha * student_loss + (1-alpha) * distillation_loss

            acc = accuracy(student_pred, test_label)

            test_loss.update(loss.item(), test_img.size(0))
            accuracy_test.update(acc.item(), test_img.size(0))
            dict_metrics = dict(loss = test_loss.avg, acc = accuracy_test.avg)

            data_loop_test.set_description(f'Test  Epoch [{epoch + 1} / {student_epochs}]')
            data_loop_test.set_postfix(**dict_metrics)

            for key, value in dict_metrics.items():
                writer_student_distilled.add_scalar(f'test_{key}', value, epoch)

Train Epoch [1 / 3]: 100%|[31m██████████[0m| 1875/1875 [00:31<00:00, 59.80it/s, acc=0.793, loss=0.515]
Test  Epoch [1 / 3]: 100%|[32m██████████[0m| 313/313 [00:04<00:00, 69.46it/s, acc=0.829, loss=0.364]
Train Epoch [2 / 3]: 100%|[31m██████████[0m| 1875/1875 [00:25<00:00, 72.42it/s, acc=0.853, loss=0.318]
Test  Epoch [2 / 3]: 100%|[32m██████████[0m| 313/313 [00:04<00:00, 68.81it/s, acc=0.856, loss=0.303]
Train Epoch [3 / 3]: 100%|[31m██████████[0m| 1875/1875 [00:25<00:00, 73.45it/s, acc=0.866, loss=0.28]
Test  Epoch [3 / 3]: 100%|[32m██████████[0m| 313/313 [00:04<00:00, 69.72it/s, acc=0.865, loss=0.276]
