<a href="https://colab.research.google.com/github/jo1jun/Knowledge-Distillation/blob/main/Knowledge_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T

import numpy as np
import copy

# tensorboard writer
from torch.utils.tensorboard import SummaryWriter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Classification

## CIFAR10 Dataset

In [2]:
NUM_TRAIN = 49000

transform = T.Compose([
                T.ToTensor(),
                T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
            ])

dataloaders = {}

cifar10_train = dset.CIFAR10('./datasets', train=True, download=True,
                             transform=transform)
dataloaders['train'] = DataLoader(cifar10_train, batch_size=64, 
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

cifar10_val = dset.CIFAR10('./datasets', train=True, download=True,
                           transform=transform)
dataloaders['val'] = DataLoader(cifar10_val, batch_size=64, 
                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)))

cifar10_test = dset.CIFAR10('./datasets', train=False, download=True, 
                            transform=transform)
dataloaders['test'] = DataLoader(cifar10_test, batch_size=64)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./datasets/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./datasets/cifar-10-python.tar.gz to ./datasets
Files already downloaded and verified
Files already downloaded and verified


## Trainer

In [3]:
def trainer(model_name, model, criterion, optimizer, num_epochs):

    model.to(device)
    writer = SummaryWriter(f'runs/{model_name}')
    best_model_wts = copy.deepcopy(model.state_dict())
    global_step, best_acc = 0, 0.0
    running_loss, running_acc = {}, {}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss[phase], running_acc[phase] = 0.0, 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss[phase] += loss.item() * inputs.shape[0]
                running_acc[phase] += torch.sum(preds == labels.data)
            
            running_loss[phase] = running_loss[phase] / (len(dataloaders[phase]) * dataloaders[phase].batch_size)
            running_acc[phase] = running_acc[phase].double() / (len(dataloaders[phase]) * dataloaders[phase].batch_size)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, running_loss[phase], running_acc[phase]))

            # deep copy the model
            if phase == 'val' and running_acc[phase] > best_acc:
                best_acc = running_acc[phase]
                best_model_wts = copy.deepcopy(model.state_dict())

        writer.add_scalars(f'{model_name}/loss', {'train' : running_loss['train'], 'val' : running_loss['val']}, global_step)
        writer.add_scalars(f'{model_name}/acc', {'train' : running_acc['train'], 'val' : running_acc['val']}, global_step)
        global_step += 1

        print()

    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)

    torch.save(model.state_dict(), f'{model_name}.pt')
    print('model saved')

    writer.close()

    return model

## Teacher

In [4]:
class Teacher(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, padding=1)
        self.conv4 = nn.Conv2d(128, 64, kernel_size=5, stride=2, padding=1)
        self.conv5 = nn.Conv2d(64, 10, kernel_size=3, stride=2, padding=1)
        self.batch1 = nn.BatchNorm2d(32)
        self.batch2 = nn.BatchNorm2d(64)
        self.batch3 = nn.BatchNorm2d(128)
        self.batch4 = nn.BatchNorm2d(64)
        self.batch5 = nn.BatchNorm2d(10)

    def forward(self, x):
        x = F.relu(self.batch1(self.conv1(x)))
        x = F.dropout2d(F.relu(self.batch2(self.conv2(x))), 0.1)
        x = F.dropout2d(F.relu(self.batch3(self.conv3(x))), 0.2)
        x = F.dropout2d(F.relu(self.batch4(self.conv4(x))), 0.1)
        x = F.relu(self.batch5(self.conv5(x)))
        x = F.avg_pool2d(x, x.shape[-2:]).squeeze() # global average pooling
        return x

In [5]:
teacher = Teacher()
print(sum(p.numel() for p in teacher.parameters() if p.requires_grad))

435550


## Criterion

In [None]:
criterion = nn.CrossEntropyLoss()

## Teacher Train

In [6]:
optimizer = optim.Adam(teacher.parameters())
best_teacher = trainer('teacher', teacher, criterion, optimizer, num_epochs=30)

Epoch 0/29
----------
train Loss: 1.6532 Acc: 0.4441
val Loss: 1.4195 Acc: 0.5020

Epoch 1/29
----------
train Loss: 1.2918 Acc: 0.5753
val Loss: 1.1318 Acc: 0.5938

Epoch 2/29
----------
train Loss: 1.1036 Acc: 0.6380
val Loss: 1.0195 Acc: 0.6455

Epoch 3/29
----------
train Loss: 0.9818 Acc: 0.6767
val Loss: 0.8811 Acc: 0.6777

Epoch 4/29
----------
train Loss: 0.8873 Acc: 0.7060
val Loss: 0.9025 Acc: 0.6982

Epoch 5/29
----------
train Loss: 0.8117 Acc: 0.7324
val Loss: 0.8130 Acc: 0.7100

Epoch 6/29
----------
train Loss: 0.7472 Acc: 0.7537
val Loss: 0.7665 Acc: 0.7275

Epoch 7/29
----------
train Loss: 0.6957 Acc: 0.7696
val Loss: 0.7533 Acc: 0.7207

Epoch 8/29
----------
train Loss: 0.6446 Acc: 0.7875
val Loss: 0.7445 Acc: 0.7295

Epoch 9/29
----------
train Loss: 0.5976 Acc: 0.8041
val Loss: 0.7441 Acc: 0.7256

Epoch 10/29
----------
train Loss: 0.5565 Acc: 0.8169
val Loss: 0.7384 Acc: 0.7461

Epoch 11/29
----------
train Loss: 0.5167 Acc: 0.8308
val Loss: 0.7614 Acc: 0.7246

Ep

## Checker

In [7]:
def checker(loader, model):
    num_correct = 0
    num_samples = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))

## Teacher Accuracy Check

In [8]:
checker(dataloaders['test'], best_teacher)

Got 7413 / 10000 correct (74.13)


## Student_Solo

In [9]:
class Student_Solo(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 10, kernel_size=3, stride=2, padding=1)
        self.batch1 = nn.BatchNorm2d(32)
        self.batch2 = nn.BatchNorm2d(64)
        self.batch3 = nn.BatchNorm2d(64)
        self.batch4 = nn.BatchNorm2d(10)

    def forward(self, x):
        x = F.relu(self.batch1(self.conv1(x)))
        x = F.dropout2d(F.relu(self.batch2(self.conv2(x))), 0.1)
        x = F.dropout2d(F.relu(self.batch3(self.conv3(x))), 0.1)
        x = F.relu(self.batch4(self.conv4(x)))
        x = F.avg_pool2d(x, x.shape[-2:]).squeeze() # global average pooling
        return x

In [10]:
student_solo = Student_Solo()
print(sum(p.numel() for p in student_solo.parameters() if p.requires_grad))

62430


## Student_Solo Train

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(student_solo.parameters())
best_student_solo = trainer('student_solo', student_solo, criterion, optimizer, num_epochs=30)

Epoch 0/29
----------
train Loss: 1.7864 Acc: 0.4005
val Loss: 1.5566 Acc: 0.4463

Epoch 1/29
----------
train Loss: 1.5227 Acc: 0.4974
val Loss: 1.3899 Acc: 0.4961

Epoch 2/29
----------
train Loss: 1.3723 Acc: 0.5480
val Loss: 1.2770 Acc: 0.5557

Epoch 3/29
----------
train Loss: 1.2643 Acc: 0.5807
val Loss: 1.1855 Acc: 0.5781

Epoch 4/29
----------
train Loss: 1.1856 Acc: 0.6088
val Loss: 1.1161 Acc: 0.6016

Epoch 5/29
----------
train Loss: 1.1215 Acc: 0.6280
val Loss: 1.0756 Acc: 0.6221

Epoch 6/29
----------
train Loss: 1.0736 Acc: 0.6415
val Loss: 1.0237 Acc: 0.6377

Epoch 7/29
----------
train Loss: 1.0314 Acc: 0.6539
val Loss: 1.0258 Acc: 0.6318

Epoch 8/29
----------
train Loss: 0.9936 Acc: 0.6678
val Loss: 0.9561 Acc: 0.6650

Epoch 9/29
----------
train Loss: 0.9615 Acc: 0.6787
val Loss: 0.9354 Acc: 0.6504

Epoch 10/29
----------
train Loss: 0.9341 Acc: 0.6883
val Loss: 0.9291 Acc: 0.6582

Epoch 11/29
----------
train Loss: 0.9061 Acc: 0.6974
val Loss: 0.9249 Acc: 0.6543

Ep

## Student_Solo Accuracy Check

In [12]:
checker(dataloaders['test'], best_student_solo)

Got 7022 / 10000 correct (70.22)


## Student

In [17]:
student = Student_Solo() # same architecture with student_solo
print(sum(p.numel() for p in student_solo.parameters() if p.requires_grad))

62430


## Criterion (Distillation)

In [20]:
def criterion_KD(outputs, labels, teacher_outputs, T, alpha):
    """
    Compute the knowledge-distillation (KD) loss given outputs, labels.
    "Hyperparameters": temperature and alpha
    NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
    and student expects the input tensor to be log probabilities! See Issue #2
    """
    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
                             F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
              F.cross_entropy(outputs, labels) * (1. - alpha)

    return KD_loss

## Hyperparameters

In [21]:
T = 10
alpha = 0.5

## Trainer (+Distillation loss)

In [24]:
def trainer_KD(model_name, model, teacher, optimizer, num_epochs):

    model.to(device)
    teacher.to(device)
    teacher.eval()
    writer = SummaryWriter(f'runs/{model_name}')
    best_model_wts = copy.deepcopy(model.state_dict())
    global_step, best_acc = 0, 0.0
    running_loss, running_acc = {}, {}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss[phase], running_acc[phase] = 0.0, 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()


                with torch.no_grad():
                  outputs_teacher = teacher(inputs)

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    
                    loss = criterion_KD(outputs, labels, outputs_teacher, T, alpha)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss[phase] += loss.item() * inputs.shape[0]
                running_acc[phase] += torch.sum(preds == labels.data)
            
            running_loss[phase] = running_loss[phase] / (len(dataloaders[phase]) * dataloaders[phase].batch_size)
            running_acc[phase] = running_acc[phase].double() / (len(dataloaders[phase]) * dataloaders[phase].batch_size)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, running_loss[phase], running_acc[phase]))

            # deep copy the model
            if phase == 'val' and running_acc[phase] > best_acc:
                best_acc = running_acc[phase]
                best_model_wts = copy.deepcopy(model.state_dict())

        writer.add_scalars(f'{model_name}/loss', {'train' : running_loss['train'], 'val' : running_loss['val']}, global_step)
        writer.add_scalars(f'{model_name}/acc', {'train' : running_acc['train'], 'val' : running_acc['val']}, global_step)
        global_step += 1

        print()

    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)

    torch.save(model.state_dict(), f'{model_name}.pt')
    print('model saved')

    writer.close()

    return model

## Student Train

In [25]:
optimizer = optim.Adam(student.parameters())
best_student = trainer_KD('student', student, best_teacher, optimizer, num_epochs=30)

Epoch 0/29
----------


  "reduction: 'mean' divides the total loss by both the batch size and the support size."


train Loss: 0.9918 Acc: 0.4051
val Loss: 0.8751 Acc: 0.4463

Epoch 1/29
----------
train Loss: 0.8406 Acc: 0.4996
val Loss: 0.7699 Acc: 0.5195

Epoch 2/29
----------
train Loss: 0.7536 Acc: 0.5487
val Loss: 0.6942 Acc: 0.5693

Epoch 3/29
----------
train Loss: 0.6941 Acc: 0.5826
val Loss: 0.6382 Acc: 0.5879

Epoch 4/29
----------
train Loss: 0.6462 Acc: 0.6131
val Loss: 0.5894 Acc: 0.6260

Epoch 5/29
----------
train Loss: 0.6102 Acc: 0.6323
val Loss: 0.5655 Acc: 0.6602

Epoch 6/29
----------
train Loss: 0.5817 Acc: 0.6497
val Loss: 0.5458 Acc: 0.6455

Epoch 7/29
----------
train Loss: 0.5562 Acc: 0.6635
val Loss: 0.5485 Acc: 0.6504

Epoch 8/29
----------
train Loss: 0.5382 Acc: 0.6723
val Loss: 0.5072 Acc: 0.6689

Epoch 9/29
----------
train Loss: 0.5211 Acc: 0.6864
val Loss: 0.5148 Acc: 0.6621

Epoch 10/29
----------
train Loss: 0.5050 Acc: 0.6953
val Loss: 0.4930 Acc: 0.6895

Epoch 11/29
----------
train Loss: 0.4913 Acc: 0.7055
val Loss: 0.4794 Acc: 0.6914

Epoch 12/29
----------
t

## Student Accuracy Check

In [26]:
checker(dataloaders['test'], best_student)

Got 7021 / 10000 correct (70.21)


## Comparision

In [34]:
print('[[Accuracy & #parameters & Inference TIme]]')
print('[teacher]')
checker(dataloaders['test'], best_teacher)
print(sum(p.numel() for p in best_teacher.parameters() if p.requires_grad))
print('[student_solo]')
checker(dataloaders['test'], best_student_solo)
print(sum(p.numel() for p in best_student_solo.parameters() if p.requires_grad))
print('[student_kd]')
checker(dataloaders['test'], best_student)
print(sum(p.numel() for p in best_student.parameters() if p.requires_grad))

[[Accuracy & #parameters & Inference TIme]]
[teacher]
Got 7426 / 10000 correct (74.26)
435550
[student_solo]
Got 7018 / 10000 correct (70.18)
62430
[student_kd]
Got 7053 / 10000 correct (70.53)
62430
