# Knowledge Distillation 


<img src='./imgs/distillation.jpg'>

In [None]:
# import 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


from torch.optim import lr_scheduler

from torchsummary import summary

import torchvision.models as models
import torchvision.transforms as transforms


from google.colab import drive


In [None]:
best_acc = 0  # best test accuracy; this is the name test() declares as `global`

In [None]:
# Google Drive Mount
# Mounts the user's Drive at /content/gdrive. Checkpoints are later written to
# the relative path './gdrive/MyDrive/...', which only lands inside this mount
# when the working directory is /content.
drive.mount('/content/gdrive')

In [None]:
# Run-wide configuration constants
BATCH_SIZE = 16    # mini-batch size handed to both DataLoaders
START_EPOCH = 0    # start from epoch 0 or last checkpoint epoch
BEST_ACC = 0       # initial value for a best-accuracy tracker

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
# Data: preprocessing pipelines for the training and evaluation splits
print('==> Preparing data..')

# Per-channel mean / std used for normalisation. These match the ImageNet
# statistics that torchvision's pretrained models are trained with.
mean_nums = [0.485, 0.456, 0.406]
std_nums = [0.229, 0.224, 0.225]

# Training: random crop-and-resize to 256, small rotation and horizontal flip
# for augmentation, then tensor conversion and normalisation.
transform_train = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=256),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean_nums, std_nums),
    ]
)

# Evaluation: deterministic resize to 256, tensor conversion and normalisation.
transform_test = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize(mean_nums, std_nums),
    ]
)

In [None]:
# DataSet & DataLoader
# The import cell only binds the aliases `models` and `transforms`
# (`import torchvision.models as models` does not define the name
# `torchvision`), so the package itself is imported here before it is used.
import torchvision

# CIFAR-10 train / test splits, downloaded into ./data on first use.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)

# Only the training loader shuffles; evaluation order stays fixed.
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=True)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=False)

In [None]:
# Check data size
# Pull a single mini-batch from the training loader and print the shape of
# its image tensor as a sanity check on the preprocessing pipeline.
images, labels = next(iter(trainloader))
print(images.shape)

In [None]:
# teacher
# Teacher network: torchvision ResNet-34 loaded with pretrained weights and a
# freshly initialised 10-way classification head.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True`; kept as-is so older installs keep working.
teacher = models.resnet34(pretrained=True)
num_classes = 10
in_features = teacher.fc.in_features
teacher.fc = nn.Linear(in_features, num_classes)

# Fine-tune only the last two residual stages and the new head; every other
# child module is frozen. All parameters live in one of the named children,
# so this single pass sets requires_grad on every parameter of the model.
for name, child in teacher.named_children():
    trainable = name in ('layer3', 'layer4', 'fc')
    if trainable:
        print(name + ' has been unfrozen.')
    for param in child.parameters():
        param.requires_grad = trainable

# The device constant is defined in upper case (DEVICE) in the config cell.
teacher = teacher.to(DEVICE)

In [None]:
# Summarise the teacher at the resolution the data pipeline actually produces
# (both transforms resize to 256), matching the student summary.
summary(teacher, (3, 256, 256))

In [None]:
# Teacher fine-tuning setup: cross-entropy loss, SGD with momentum, and a
# step schedule that multiplies the learning rate by 0.1 every 7 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=teacher.parameters(),
    lr=1e-3,
    momentum=0.9,
)
exp_lr_scheduler = lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=7,
    gamma=0.1,
)

In [None]:

# Training
def train(epoch, net):
    """Run one epoch of supervised training on ``net``.

    Relies on the notebook-level ``trainloader``, ``criterion`` and
    ``optimizer`` as they are bound when the function is called, and prints
    the running mean loss and accuracy every 1000 batches.

    Args:
        epoch: Epoch index; only used for the progress header.
        net: Model to optimise. It is expected to already live on ``DEVICE``.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # The config cell defines the device constant as upper-case DEVICE.
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Running statistics over the batches seen so far in this epoch.
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 1000 == 0:
            print('Loss: %.3f | Acc: %.3f%% ' % (train_loss / (batch_idx + 1), 100. * correct / total))
    
def train_distillation(epoch, student, teacher):
    """Train ``student`` for one epoch using ``teacher`` as a soft-label source.

    For every batch both networks score the same inputs and the student is
    updated with ``loss_fn_kd``. Relies on the notebook-level ``trainloader``
    and ``optimizer`` as bound at call time, and prints the running mean loss
    and accuracy every 1000 batches.

    Args:
        epoch: Epoch index; only used for the progress header.
        student: Model being optimised (put into train mode).
        teacher: Reference model (put into eval mode, never updated here).
    """
    print('\nEpoch: %d' % epoch)
    student.train()
    teacher.eval()

    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # The config cell defines the device constant as upper-case DEVICE.
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        optimizer.zero_grad()
        outputs = student(inputs)  # student
        # The teacher only supplies targets, so skip autograd bookkeeping.
        with torch.no_grad():
            teacher_outputs = teacher(inputs)  # teacher

        # Both terms of loss_fn_kd are already averaged over the batch, so the
        # value is used directly; dividing by the batch size again would
        # silently shrink the gradients by that factor.
        loss = loss_fn_kd(outputs, targets, teacher_outputs)
        loss.backward()
        optimizer.step()

        # Running statistics over the batches seen so far in this epoch.
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 1000 == 0:
            print('Loss: %.3f | Acc: %.3f%% ' % (train_loss / (batch_idx + 1), 100. * correct / total))
     
    
def test(epoch, net):
    """Evaluate ``net`` on the notebook-level ``testloader``.

    Prints the top-1 accuracy and keeps the module-level ``best_acc`` at the
    highest accuracy observed so far.

    Args:
        epoch: Epoch index; accepted for symmetry with ``train`` but unused.
        net: Model to evaluate. It is expected to already live on ``DEVICE``.

    Returns:
        Accuracy on the test set as a percentage (0.0 if the loader is empty).
    """
    global best_acc
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            # The config cell defines the device constant as upper-case DEVICE.
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Guard against an empty loader instead of dividing by zero.
    acc = 100 * correct / total if total else 0.0
    best_acc = max(best_acc, acc)
    # Report the number of images actually evaluated rather than assuming 10000.
    print('Accuracy of the network on the %d test images: %d %%' % (total, acc))
    return acc

In [None]:
# RUN
# Fine-tune the teacher and checkpoint it after every epoch.
# The original `range(start_epoch, )` had no end bound (an empty range), so
# the loop body never ran and `epoch` was never bound for the checkpoint path
# used when the teacher is reloaded. Ten epochs mirrors the distillation loop.
for epoch in range(START_EPOCH, START_EPOCH + 10):
    train(epoch, teacher)
    test(epoch, teacher)
    exp_lr_scheduler.step()
    torch.save(teacher.state_dict(), './gdrive/MyDrive/Colab Notebooks/data/'+f'teacher_{epoch}.pth')

In [None]:
# Student (user-defined network with two convolutional layers)
class Student(nn.Module):
    """Small CNN student: two conv + max-pool stages, then three linear layers.

    Args:
        input_size: Height/width of the square RGB input. The default of 256
            reproduces the original hard-coded flatten size of 59536.
        num_classes: Number of output logits (default 10).

    Raises:
        ValueError: If ``input_size`` is too small to survive both conv/pool
            stages.
    """

    def __init__(self, input_size=256, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        # Each stage is an unpadded 5x5 convolution (side - 4) followed by a
        # 2x2 max-pool (floor division by 2); track the spatial side through
        # both stages to size the first linear layer.
        side = input_size
        for _ in range(2):
            side = (side - 4) // 2
        if side < 1:
            raise ValueError(
                f'input_size={input_size} is too small for two conv/pool stages'
            )

        self.fc1 = nn.Linear(16 * side * side, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Return class logits for a batch of images shaped (N, 3, H, W)."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


# Build the student on the configured device (the constant is upper-case
# DEVICE) and summarise it at the 256x256 resolution the data pipeline produces.
student = Student().to(DEVICE)
summary(student, (3, 256, 256))

In [None]:
# Student training setup. These rebind the notebook-level `criterion`,
# `optimizer` and `exp_lr_scheduler`, so the training helpers now step the
# student's parameters. The schedule multiplies the learning rate by 0.1
# every 7 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=student.parameters(),
    lr=1e-3,
    momentum=0.9,
)
exp_lr_scheduler = lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=7,
    gamma=0.1,
)

In [None]:
def loss_fn_kd(outputs, labels, teacher_outputs, T=1, alpha=0.3):
    """Compute the knowledge-distillation (KD) loss.

    The loss is a weighted sum of a soft-target term (KL divergence between
    the temperature-softened teacher and student distributions) and the usual
    cross-entropy against the true labels:

        alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE(student, labels)

    Args:
        outputs: Student logits, shape (batch, classes).
        labels: Ground-truth class indices, shape (batch,).
        teacher_outputs: Teacher logits for the same inputs (teacher in eval
            mode), shape (batch, classes).
        T: Softmax temperature applied to both sets of logits.
        alpha: Weight of the soft-target term; the hard-label term gets
            ``1 - alpha``.

    Returns:
        Scalar tensor holding the combined loss, averaged over the batch.
    """
    # PyTorch's KL divergence expects log-probabilities as input and
    # probabilities as target.
    student_log_probs = F.log_softmax(outputs / T, dim=1)
    # The teacher is a fixed target: never propagate gradients into it.
    teacher_probs = F.softmax(teacher_outputs.detach() / T, dim=1)

    # 'batchmean' sums over classes and averages over the batch, which is the
    # true KL divergence; the default 'mean' would also divide by the number
    # of classes and under-weight this term.
    soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    hard_loss = F.cross_entropy(outputs, labels)

    # T * T compensates for the 1/T^2 scaling of the soft-target gradients.
    return soft_loss * (alpha * T * T) + hard_loss * (1. - alpha)

In [None]:
# Load Teacher dict
# NOTE: `epoch` is the loop variable left over from the teacher training
# loop, so this loads the checkpoint written by its final iteration.
PATH = './gdrive/MyDrive/Colab Notebooks/data/'+f'teacher_{epoch}.pth'

# If `teacher` no longer exists (e.g. after a kernel restart), rebuild the
# architecture before loading the weights:
# num_classes =10
# teacher = models.__dict__['resnet34']()
# in_features = teacher.fc.in_features
# teacher.fc = nn.Linear(in_features,num_classes)

# map_location lets a checkpoint saved on GPU be restored on a CPU-only runtime.
teacher.load_state_dict(torch.load(PATH, map_location=DEVICE))

# The teacher is a fixed reference during distillation: freeze every weight.
for param in teacher.parameters():
    param.requires_grad = False

# The device constant is defined in upper case (DEVICE) in the config cell.
teacher = teacher.to(DEVICE).eval()

In [None]:
# Distil the teacher into the student for 10 epochs.
print('start KD ... ')
for epoch in range(START_EPOCH, START_EPOCH + 10):
    # train_distillation takes the epoch index as its first argument.
    train_distillation(epoch, student, teacher)
    test(epoch, student)
    exp_lr_scheduler.step()