<a href="https://colab.research.google.com/github/jo1jun/Knowledge-Distillation/blob/main/Knowledge_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T

import numpy as np
import copy

# tensorboard writer
from torch.utils.tensorboard import SummaryWriter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Classification

## CIFAR10 Dataset

In [2]:
NUM_TRAIN = 49000

transform = T.Compose([
                T.ToTensor(),
                T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
            ])

dataloaders = {}

cifar10_train = dset.CIFAR10('./datasets', train=True, download=True,
                             transform=transform)
dataloaders['train'] = DataLoader(cifar10_train, batch_size=64, 
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

cifar10_val = dset.CIFAR10('./datasets', train=True, download=True,
                           transform=transform)
dataloaders['val'] = DataLoader(cifar10_val, batch_size=64, 
                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)))

cifar10_test = dset.CIFAR10('./datasets', train=False, download=True, 
                            transform=transform)
dataloaders['test'] = DataLoader(cifar10_test, batch_size=64)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./datasets/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./datasets/cifar-10-python.tar.gz to ./datasets
Files already downloaded and verified
Files already downloaded and verified


## Trainer

In [3]:
def trainer(model_name, model, criterion, optimizer, num_epochs):

    model.to(device)
    writer = SummaryWriter(f'runs/{model_name}')
    best_model_wts = copy.deepcopy(model.state_dict())
    global_step, best_acc = 0, 0.0
    running_loss, running_acc = {}, {}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss[phase], running_acc[phase] = 0.0, 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss[phase] += loss.item() * inputs.shape[0]
                running_acc[phase] += torch.sum(preds == labels.data)
            
            running_loss[phase] = running_loss[phase] / (len(dataloaders[phase]) * dataloaders[phase].batch_size)
            running_acc[phase] = running_acc[phase].double() / (len(dataloaders[phase]) * dataloaders[phase].batch_size)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, running_loss[phase], running_acc[phase]))

            # deep copy the model
            if phase == 'val' and running_acc[phase] > best_acc:
                best_acc = running_acc[phase]
                best_model_wts = copy.deepcopy(model.state_dict())

        writer.add_scalars(f'{model_name}/loss', {'train' : running_loss['train'], 'val' : running_loss['val']}, global_step)
        writer.add_scalars(f'{model_name}/acc', {'train' : running_acc['train'], 'val' : running_acc['val']}, global_step)
        global_step += 1

        print()

    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)

    torch.save(model.state_dict(), f'{model_name}.pt')
    print('model saved')

    writer.close()

    return model

## Teacher

In [4]:
class Teacher(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, padding=1)
        self.conv4 = nn.Conv2d(128, 64, kernel_size=5, stride=2, padding=1)
        self.conv5 = nn.Conv2d(64, 10, kernel_size=3, stride=2, padding=1)
        self.batch1 = nn.BatchNorm2d(32)
        self.batch2 = nn.BatchNorm2d(64)
        self.batch3 = nn.BatchNorm2d(128)
        self.batch4 = nn.BatchNorm2d(64)
        self.batch5 = nn.BatchNorm2d(10)

    def forward(self, x):
        x = F.relu(self.batch1(self.conv1(x)))
        x = F.dropout2d(F.relu(self.batch2(self.conv2(x))), 0.1)
        x = F.dropout2d(F.relu(self.batch3(self.conv3(x))), 0.2)
        x = F.dropout2d(F.relu(self.batch4(self.conv4(x))), 0.1)
        x = F.relu(self.batch5(self.conv5(x)))
        x = F.avg_pool2d(x, x.shape[-2:]).squeeze() # global average pooling
        return x

In [5]:
teacher = Teacher()
print(sum(p.numel() for p in teacher.parameters() if p.requires_grad))

435550


## Criterion

In [6]:
criterion = nn.CrossEntropyLoss()

## Teacher Train

In [7]:
optimizer = optim.Adam(teacher.parameters())
best_teacher = trainer('teacher', teacher, criterion, optimizer, num_epochs=30)

Epoch 0/29
----------
train Loss: 1.6182 Acc: 0.4573
val Loss: 1.3466 Acc: 0.5449

Epoch 1/29
----------
train Loss: 1.2668 Acc: 0.5859
val Loss: 1.1362 Acc: 0.6123

Epoch 2/29
----------
train Loss: 1.0962 Acc: 0.6390
val Loss: 1.0016 Acc: 0.6533

Epoch 3/29
----------
train Loss: 0.9812 Acc: 0.6755
val Loss: 0.9048 Acc: 0.6836

Epoch 4/29
----------
train Loss: 0.8835 Acc: 0.7101
val Loss: 0.8689 Acc: 0.6885

Epoch 5/29
----------
train Loss: 0.8061 Acc: 0.7339
val Loss: 0.8258 Acc: 0.7002

Epoch 6/29
----------
train Loss: 0.7386 Acc: 0.7592
val Loss: 0.7893 Acc: 0.7158

Epoch 7/29
----------
train Loss: 0.6830 Acc: 0.7755
val Loss: 0.7725 Acc: 0.7158

Epoch 8/29
----------
train Loss: 0.6362 Acc: 0.7919
val Loss: 0.7975 Acc: 0.6895

Epoch 9/29
----------
train Loss: 0.5899 Acc: 0.8063
val Loss: 0.7555 Acc: 0.7266

Epoch 10/29
----------
train Loss: 0.5460 Acc: 0.8214
val Loss: 0.7672 Acc: 0.7285

Epoch 11/29
----------
train Loss: 0.5046 Acc: 0.8349
val Loss: 0.7523 Acc: 0.7246

Ep

## Checker

In [14]:
import time

def checker(loader, model):
    num_correct = 0
    num_samples = 0
    model.eval()
    times = []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)

            start_time = time.time()

            scores = model(x)

            times.append(time.time() - start_time)

            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
        print("Average Inference Time : ", np.mean(times))

    return acc

## Teacher Accuracy Check

In [15]:
checker(dataloaders['test'], best_teacher)

Got 7380 / 10000 correct (73.80)
Average Inference Time :  0.0010606878122706323


0.738

## Student_Solo

In [16]:
class Student_Solo(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 10, kernel_size=3, stride=2, padding=1)
        self.batch1 = nn.BatchNorm2d(32)
        self.batch2 = nn.BatchNorm2d(64)
        self.batch3 = nn.BatchNorm2d(64)
        self.batch4 = nn.BatchNorm2d(10)

    def forward(self, x):
        x = F.relu(self.batch1(self.conv1(x)))
        x = F.dropout2d(F.relu(self.batch2(self.conv2(x))), 0.1)
        x = F.dropout2d(F.relu(self.batch3(self.conv3(x))), 0.1)
        x = F.relu(self.batch4(self.conv4(x)))
        x = F.avg_pool2d(x, x.shape[-2:]).squeeze() # global average pooling
        return x

In [17]:
student_solo = Student_Solo()
print(sum(p.numel() for p in student_solo.parameters() if p.requires_grad))

62430


## Student_Solo Train

In [18]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(student_solo.parameters())
best_student_solo = trainer('student_solo', student_solo, criterion, optimizer, num_epochs=30)

Epoch 0/29
----------
train Loss: 1.7786 Acc: 0.4056
val Loss: 1.5744 Acc: 0.4609

Epoch 1/29
----------
train Loss: 1.5143 Acc: 0.5002
val Loss: 1.3907 Acc: 0.5215

Epoch 2/29
----------
train Loss: 1.3628 Acc: 0.5486
val Loss: 1.2562 Acc: 0.5811

Epoch 3/29
----------
train Loss: 1.2554 Acc: 0.5843
val Loss: 1.1718 Acc: 0.5928

Epoch 4/29
----------
train Loss: 1.1774 Acc: 0.6103
val Loss: 1.1007 Acc: 0.6172

Epoch 5/29
----------
train Loss: 1.1144 Acc: 0.6301
val Loss: 1.0532 Acc: 0.6299

Epoch 6/29
----------
train Loss: 1.0621 Acc: 0.6485
val Loss: 1.0135 Acc: 0.6445

Epoch 7/29
----------
train Loss: 1.0212 Acc: 0.6613
val Loss: 0.9890 Acc: 0.6455

Epoch 8/29
----------
train Loss: 0.9867 Acc: 0.6716
val Loss: 0.9688 Acc: 0.6523

Epoch 9/29
----------
train Loss: 0.9522 Acc: 0.6832
val Loss: 0.9625 Acc: 0.6582

Epoch 10/29
----------
train Loss: 0.9257 Acc: 0.6918
val Loss: 0.9363 Acc: 0.6611

Epoch 11/29
----------
train Loss: 0.9007 Acc: 0.6985
val Loss: 0.9564 Acc: 0.6533

Ep

## Student_Solo Accuracy Check

In [19]:
checker(dataloaders['test'], best_student_solo)

Got 6961 / 10000 correct (69.61)
Average Inference Time :  0.0008471543621865048


0.6961

## Student

In [20]:
student = Student_Solo() # same architecture with student_solo
print(sum(p.numel() for p in student_solo.parameters() if p.requires_grad))

62430


## Criterion (+ Distillation loss)

In [21]:
def criterion_KD(outputs, labels, teacher_outputs, T, alpha):

    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1), F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
              F.cross_entropy(outputs, labels) * (1. - alpha)

    return KD_loss

## Trainer (+Distillation loss)

In [22]:
def trainer_KD(model_name, model, teacher, optimizer, num_epochs, T, alpha):

    model.to(device)
    teacher.to(device)
    teacher.eval()
    writer = SummaryWriter(f'runs/{model_name}')
    best_model_wts = copy.deepcopy(model.state_dict())
    global_step, best_acc = 0, 0.0
    running_loss, running_acc = {}, {}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss[phase], running_acc[phase] = 0.0, 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()


                with torch.no_grad():
                  outputs_teacher = teacher(inputs)

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    
                    loss = criterion_KD(outputs, labels, outputs_teacher, T, alpha)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss[phase] += loss.item() * inputs.shape[0]
                running_acc[phase] += torch.sum(preds == labels.data)
            
            running_loss[phase] = running_loss[phase] / (len(dataloaders[phase]) * dataloaders[phase].batch_size)
            running_acc[phase] = running_acc[phase].double() / (len(dataloaders[phase]) * dataloaders[phase].batch_size)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, running_loss[phase], running_acc[phase]))

            # deep copy the model
            if phase == 'val' and running_acc[phase] > best_acc:
                best_acc = running_acc[phase]
                best_model_wts = copy.deepcopy(model.state_dict())

        writer.add_scalars(f'{model_name}/loss', {'train' : running_loss['train'], 'val' : running_loss['val']}, global_step)
        writer.add_scalars(f'{model_name}/acc', {'train' : running_acc['train'], 'val' : running_acc['val']}, global_step)
        global_step += 1

        print()

    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)

    torch.save(model.state_dict(), f'{model_name}.pt')
    print('model saved')

    writer.close()

    return model

## Student Train

In [23]:
T = 10
alpha = 0.8
optimizer = optim.Adam(student.parameters())
num_epochs = 30
best_student = trainer_KD('student', student, best_teacher, optimizer, num_epochs, T, alpha)

Epoch 0/29
----------


  "reduction: 'mean' divides the total loss by both the batch size and the support size."


train Loss: 0.5666 Acc: 0.4118
val Loss: 0.4933 Acc: 0.4502

Epoch 1/29
----------
train Loss: 0.4718 Acc: 0.5076
val Loss: 0.4222 Acc: 0.5410

Epoch 2/29
----------
train Loss: 0.4167 Acc: 0.5554
val Loss: 0.3788 Acc: 0.5625

Epoch 3/29
----------
train Loss: 0.3778 Acc: 0.5916
val Loss: 0.3414 Acc: 0.6113

Epoch 4/29
----------
train Loss: 0.3494 Acc: 0.6187
val Loss: 0.3202 Acc: 0.6162

Epoch 5/29
----------
train Loss: 0.3280 Acc: 0.6393
val Loss: 0.2966 Acc: 0.6377

Epoch 6/29
----------
train Loss: 0.3108 Acc: 0.6590
val Loss: 0.2867 Acc: 0.6318

Epoch 7/29
----------
train Loss: 0.2980 Acc: 0.6736
val Loss: 0.2829 Acc: 0.6602

Epoch 8/29
----------
train Loss: 0.2874 Acc: 0.6822
val Loss: 0.2719 Acc: 0.6611

Epoch 9/29
----------
train Loss: 0.2800 Acc: 0.6932
val Loss: 0.2723 Acc: 0.6709

Epoch 10/29
----------
train Loss: 0.2725 Acc: 0.7025
val Loss: 0.2610 Acc: 0.6699

Epoch 11/29
----------
train Loss: 0.2657 Acc: 0.7106
val Loss: 0.2549 Acc: 0.6777

Epoch 12/29
----------
t

## Student Accuracy Check

In [24]:
checker(dataloaders['test'], best_student)

Got 7200 / 10000 correct (72.00)
Average Inference Time :  0.0008396206388048306


0.72

## Comparision

In [25]:
print('[[Accuracy & #parameters & Inference TIme]]')
print('[teacher]')
checker(dataloaders['test'], best_teacher)
print(sum(p.numel() for p in best_teacher.parameters() if p.requires_grad))
print('[student_solo]')
checker(dataloaders['test'], best_student_solo)
print(sum(p.numel() for p in best_student_solo.parameters() if p.requires_grad))
print('[student_kd]')
checker(dataloaders['test'], best_student)
print(sum(p.numel() for p in best_student.parameters() if p.requires_grad))

[[Accuracy & #parameters & Inference TIme]]
[teacher]
Got 7398 / 10000 correct (73.98)
Average Inference Time :  0.0010553135234079543
435550
[student_solo]
Got 6974 / 10000 correct (69.74)
Average Inference Time :  0.0008367368370104747
62430
[student_kd]
Got 7216 / 10000 correct (72.16)
Average Inference Time :  0.0008491300473547286
62430
