# Knowledge Distillation 

reference: https://keras.io/examples/vision/knowledge_distillation/


## Introduction to Knowledge Distillation
<br>
Knowledge Distillation is a procedure for model compression, in which a small (`student`) model is trained to match a large pre-trained (`teacher`) model. Knowledge is transferred from the teacher model to the student by minimizing a loss function, aimed at matching softened teacher logits as well as ground-truth labels.

The logits are softened by applying a "temperature" scaling function in the softmax, effectively smoothing out the probability distribution and revealing inter-class relationships learned by the teacher.
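
To make the effect of temperature concrete, here is a minimal sketch that compares an ordinary softmax with a softened one. The logit values and the helper names `soften` and `entropy` are hypothetical and chosen only for illustration; they are not part of the original notebook.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits for one sample over three classes (illustrative values only).
logits = torch.tensor([[8.0, 2.0, 0.5]])


def soften(logits, T):
    """Softmax over the class dimension after dividing the logits by T."""
    return F.softmax(logits / T, dim=1)


def entropy(p):
    """Shannon entropy of each row of a probability tensor."""
    return -(p * p.log()).sum(dim=1)


p_sharp = soften(logits, T=1.0)  # ordinary softmax
p_soft = soften(logits, T=4.0)   # softened distribution

print('T=1:', p_sharp)
print('T=4:', p_soft)
print('entropy T=1 / T=4:', entropy(p_sharp).item(), entropy(p_soft).item())
```

With a higher temperature the largest probability shrinks and the smaller ones grow, while the ranking of the classes stays the same. That extra mass on the non-target classes is the information the student is meant to pick up.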

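$$ \mathrm{Loss} = \alpha \cdot \mathrm{loss}_{\mathrm{class}} + (1-\alpha) \cdot \mathrm{loss}_{\mathrm{distillation}} $$
$$ L(x;w) = \alpha \cdot CE(y, \mathrm{student}) + (1-\alpha) \cdot KL(\mathrm{teacher}, \mathrm{student}) $$
$$ L(x;w) = \alpha \cdot CE(y, \mathrm{student}) + (1-\alpha) \cdot KL(\sigma(\hat{y}^{t}/T), \sigma(\hat{y}^{s}/T)) $$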

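where CE is the conventional cross-entropy loss, KL is the Kullback–Leibler divergence, $\sigma$ is the softmax, and $\hat{y}^{t}$ and $\hat{y}^{s}$ are the teacher and student logits. In implementations, the student side of the KL term is usually passed as a log-softmax. $T$ is the temperature, which smooths the output distributions of the teacher and the student, and $\alpha$ is a balancing weight between the two terms.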
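
For reference, the distillation term can be written out explicitly. The shorthand $p^{t}$ and $p^{s}$ for the softened teacher and student distributions is introduced here only for readability and is not notation from the original text.

```latex
\begin{aligned}
p^{t}_i &= \sigma(\hat{y}^{t}/T)_i = \frac{\exp(\hat{y}^{t}_i/T)}{\sum_j \exp(\hat{y}^{t}_j/T)}, \qquad
p^{s}_i = \sigma(\hat{y}^{s}/T)_i = \frac{\exp(\hat{y}^{s}_i/T)}{\sum_j \exp(\hat{y}^{s}_j/T)} \\
KL(p^{t}, p^{s}) &= \sum_i p^{t}_i \log \frac{p^{t}_i}{p^{s}_i}
                  = \sum_i p^{t}_i \left( \log p^{t}_i - \log p^{s}_i \right)
\end{aligned}
```

For $T = 1$ both distributions reduce to the ordinary softmax. A common convention, also used by `loss_fn_kd` in this notebook, multiplies the KL term by $T^2$, because the gradients of the softened term scale as $1/T^2$.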

## Experiments

<img src='./imgs/distillation.jpg'>

### DataSet 


<img src='https://www.samyzaf.com/ML/cifar10/cifar1.jpg'>
<center>CIFAR10 dataset</center>

### Import

In [1]:
# import 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


from torch.optim import lr_scheduler

from torchsummary import summary

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms


In [2]:
# Path setup
import os
os.chdir('./') # DataPath
current_path = os.getcwd() # current folder

In [5]:
# Initial Value 
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
START_EPOCH = 0 # start from epoch 0 or last checkpoint epoch
FINAL_EPOCH = 20
BATCH_SIZE = 16

CLASSES = 10 # number of CIFAR10 classes

In [6]:
# Data
print('==> Preparing data..')

mean_nums = [0.485, 0.456, 0.406]
std_nums = [0.229, 0.224, 0.225]
transform_train = transforms.Compose([transforms.RandomResizedCrop(size=256),
                                      transforms.RandomRotation(degrees=15),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean_nums, std_nums),
])


transform_test = transforms.Compose([transforms.Resize(256),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean_nums, std_nums),
])

==> Preparing data..


In [7]:
# DataSet & DataLoader
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=True,)
testloader = torch.utils.data.DataLoader(testset, 
                                         batch_size=BATCH_SIZE,
                                         shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Check data size 
images, labels = next(iter(trainloader))
print(images.shape)

torch.Size([16, 3, 256, 256])


### User-define Functions

In [None]:
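def loss_fn_kd(outputs, labels, teacher_outputs, T=2, alpha=0.5):
    """
    Compute the knowledge-distillation (KD) loss from the student outputs,
    the ground-truth labels, and the teacher outputs.
    Hyperparameters: temperature T and weight alpha.
    NOTE: in this function alpha weights the distillation (KL) term and
    (1 - alpha) weights the cross-entropy term. PyTorch's KLDivLoss expects
    its input (student) as log-probabilities and its target (teacher) as
    probabilities.
    """
    # Soft-target term: KL divergence between the softened teacher and student
    # distributions, scaled by T*T.
    soft_loss = nn.KLDivLoss(reduction='batchmean')(
        F.log_softmax(outputs / T, dim=1),
        F.softmax(teacher_outputs / T, dim=1)) * (alpha * T * T)
    # Hard-target term: ordinary cross-entropy against the labels.
    hard_loss = F.cross_entropy(outputs, labels) * (1. - alpha)

    return soft_loss + hard_loss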


def train(epoch,net):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 1000 == 0:
            print('Loss: %.3f | Acc: %.3f%% ' %(train_loss/(batch_idx+1),100.*correct/total))
    
def train_distillation(epoch, student, teacher,TEMPERATURE, ALPHA):
    print('\nEpoch: %d' % epoch)
    student.train()
    teacher.eval()
    
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        
        optimizer.zero_grad()
        outputs = student(inputs) # student
        with torch.no_grad(): # the teacher is fixed, so no gradients are needed
            teacher_outputs = teacher(inputs) # teacher

        loss = loss_fn_kd(outputs, targets, teacher_outputs,TEMPERATURE,ALPHA)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx %1000 ==0:
            print('Loss: %.3f | Acc: %.3f%% ' %(train_loss/(batch_idx+1),100.*correct/total))
     
    
def test(epoch,net):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    print('Accuracy of the network on the %d test images: %d %%' % (
    total, 100 * correct / total))
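
The note in `loss_fn_kd` about log-probabilities can be checked directly. The sketch below compares `nn.KLDivLoss(reduction='batchmean')` with a hand-written sum over classes. The random logits and the variable names are hypothetical and serve only as a sanity check under those assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical logits: a batch of 4 samples over 10 classes (random, illustrative only).
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
T = 2.0

log_q = F.log_softmax(student_logits / T, dim=1)  # student log-probabilities (input)
p = F.softmax(teacher_logits / T, dim=1)          # teacher probabilities (target)

# Built-in: the input must be log-probabilities, the target plain probabilities.
kl_builtin = nn.KLDivLoss(reduction='batchmean')(log_q, p)

# Manual: sum over classes of p * (log p - log q), averaged over the batch.
kl_manual = (p * (p.log() - log_q)).sum(dim=1).mean()

# When the "teacher" distribution equals the student's, the divergence vanishes.
kl_same = nn.KLDivLoss(reduction='batchmean')(
    log_q, F.softmax(student_logits / T, dim=1))

print(kl_builtin.item(), kl_manual.item(), kl_same.item())
```

Passing plain probabilities as the first argument would silently give a different number, which is why the student side goes through `F.log_softmax`.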

### STEP 1

Fine-tune the teacher network on the CIFAR10 dataset.

In [9]:
# teacher 
teacher = models.__dict__['resnet34'](pretrained=True)
for param in teacher.parameters():
    param.requires_grad = False
in_features = teacher.fc.in_features
teacher.fc = nn.Linear(in_features,CLASSES)

for name, child in teacher.named_children():
    if name in ['layer3', 'layer4','fc']:
        print(name + 'has been unfrozen.')
        for param in child.parameters():
            param.requires_grad = True
    else:
        for param in child.parameters():
            param.requires_grad = False

teacher = teacher.to(DEVICE)

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth


layer3has been unfrozen.
layer4has been unfrozen.
fchas been unfrozen.
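
The freezing pattern used for the teacher can be verified by counting trainable parameters. The sketch below uses a small hypothetical stand-in model whose child names mimic torchvision's ResNet, so it runs without downloading any weights. The model, the `UNFROZEN` set and the `count_params` helper are assumptions for illustration, not part of the original notebook.

```python
from collections import OrderedDict

import torch.nn as nn

# Hypothetical stand-in for a backbone; child names mimic torchvision's ResNet.
model = nn.Sequential(OrderedDict([
    ('layer1', nn.Linear(8, 8)),
    ('layer2', nn.Linear(8, 8)),
    ('layer3', nn.Linear(8, 8)),
    ('layer4', nn.Linear(8, 8)),
    ('fc', nn.Linear(8, 10)),
]))

UNFROZEN = {'layer3', 'layer4', 'fc'}

# Train only the children listed in UNFROZEN; freeze everything else.
for name, child in model.named_children():
    for param in child.parameters():
        param.requires_grad = name in UNFROZEN


def count_params(module, trainable_only=False):
    """Number of parameters, optionally restricted to trainable ones."""
    return sum(p.numel() for p in module.parameters()
               if p.requires_grad or not trainable_only)


total = count_params(model)
trainable = count_params(model, trainable_only=True)
print(f'trainable parameters: {trainable} / {total}')
```

The same check can be run on the real `teacher` after the freezing loop to confirm that only `layer3`, `layer4` and `fc` remain trainable.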


In [10]:
summary(teacher,(3,256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          36,864
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 64, 8, 8]          36,864
       BatchNorm2d-9             [-1, 64, 8, 8]             128
             ReLU-10             [-1, 64, 8, 8]               0
       BasicBlock-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,864
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(teacher.parameters(), lr=1e-3,
                      momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
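
To see what `StepLR(optimizer, step_size=7, gamma=0.1)` does over the run, the sketch below records the learning rate at the start of each epoch. A single dummy parameter stands in for the model, which is an assumption made purely for illustration.

```python
import torch
import torch.optim as optim
from torch.optim import lr_scheduler

# A single dummy parameter stands in for the model (hypothetical, for illustration).
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=1e-3, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

lrs = []
for epoch in range(21):  # epochs 0..20
    lrs.append(optimizer.param_groups[0]['lr'])  # rate used during this epoch
    optimizer.step()     # stands in for one epoch of training
    scheduler.step()     # decay is applied once per epoch

for epoch in (0, 6, 7, 13, 14, 20):
    print(epoch, lrs[epoch])
```

Under these settings the rate stays at `1e-3` for epochs 0 to 6, drops by a factor of ten at epoch 7, and drops again at epoch 14.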

In [13]:
# RUN 
for epoch in range(START_EPOCH,FINAL_EPOCH+1):
    train(epoch,teacher)
    test(epoch,teacher)
    exp_lr_scheduler.step()
    if epoch == FINAL_EPOCH: # last epoch produced by range(START_EPOCH, FINAL_EPOCH+1)
        torch.save(teacher.state_dict(), f'./teacher_{epoch}.pth')


Epoch: 0
Loss: 2.626 | Acc: 6.250% 
Loss: 1.170 | Acc: 59.191% 
Loss: 1.012 | Acc: 64.893% 
Loss: 0.944 | Acc: 67.248% 
Accuracy of the network on the 10000 test images: 90 %

Epoch: 1
Loss: 0.632 | Acc: 75.000% 
Loss: 0.710 | Acc: 75.387% 
Loss: 0.710 | Acc: 75.628% 
Loss: 0.703 | Acc: 75.806% 
Accuracy of the network on the 10000 test images: 92 %

Epoch: 2
Loss: 0.142 | Acc: 100.000% 
Loss: 0.635 | Acc: 77.891% 
Loss: 0.634 | Acc: 78.033% 
Loss: 0.625 | Acc: 78.468% 
Accuracy of the network on the 10000 test images: 93 %

Epoch: 3
Loss: 0.525 | Acc: 87.500% 
Loss: 0.599 | Acc: 79.595% 
Loss: 0.592 | Acc: 79.654% 
Loss: 0.591 | Acc: 79.580% 
Accuracy of the network on the 10000 test images: 93 %

Epoch: 4
Loss: 0.086 | Acc: 93.750% 
Loss: 0.540 | Acc: 81.550% 
Loss: 0.547 | Acc: 81.166% 
Loss: 0.553 | Acc: 80.986% 


KeyboardInterrupt: ignored

### STEP2

Knowledge Distillation
- Teacher: ResNet34(pretrained)
- Student: ResNet18

In [None]:
TEMPERATURE = 2
ALPHA = 0.3 # weight of the distillation (KL) term in loss_fn_kd

In [None]:
# Load Teacher dict 
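# The first argument was missing in the original; current_path (the working
# directory where STEP 1 saves the checkpoint) is assumed here.
DICT_PATH = os.path.join(current_path, f'teacher_{FINAL_EPOCH}.pth')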


# teacher model 
teacher.load_state_dict(torch.load(DICT_PATH))
for param in teacher.parameters():
    param.requires_grad = False
    

# student model 
student = models.__dict__['resnet18']()
in_features = student.fc.in_features
student.fc = nn.Linear(in_features,CLASSES)
print('STUDENT MODEL: RESNET18')
student = student.to(DEVICE) # keep the model on the same device as torchsummary's dummy input
summary(student,(3,256,256))

In [69]:
teacher.to(DEVICE)
student.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(student.parameters(), lr=1e-3,
                      momentum=0.9, weight_decay=5e-4)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
print('[INFO] START KD')
print('- KD temperature {}'.format(TEMPERATURE))
print('- KD alpha {}'.format(ALPHA))

for epoch in range(0, FINAL_EPOCH):
    train_distillation(epoch,student, teacher,TEMPERATURE, ALPHA)
    test(epoch,student)

start KD ... 

Epoch: 0
Loss: 4.646 | Acc: 0.000% 
Loss: 3.487 | Acc: 25.918% 
Loss: 3.305 | Acc: 30.210% 
Loss: 3.168 | Acc: 33.012% 
Accuracy of the network on the 10000 test images: 47 %

Epoch: 1
Loss: 2.412 | Acc: 37.500% 
Loss: 2.678 | Acc: 42.620% 
Loss: 2.584 | Acc: 44.650% 
Loss: 2.518 | Acc: 45.945% 
Accuracy of the network on the 10000 test images: 57 %

Epoch: 2
Loss: 2.020 | Acc: 50.000% 
Loss: 2.230 | Acc: 51.754% 
Loss: 2.186 | Acc: 52.521% 
Loss: 2.132 | Acc: 53.274% 
Accuracy of the network on the 10000 test images: 66 %

Epoch: 3
Loss: 1.449 | Acc: 56.250% 
Loss: 1.948 | Acc: 56.756% 
Loss: 1.914 | Acc: 57.065% 
Loss: 1.881 | Acc: 57.870% 
Accuracy of the network on the 10000 test images: 70 %

Epoch: 4
Loss: 1.245 | Acc: 75.000% 
Loss: 1.736 | Acc: 60.652% 
