## Distilling the Knowledge in a Neural Network
- 2022.09.29 minji kwak
-----------
- reference(paper): Distilling the Knowledge in a Neural Network (https://arxiv.org/abs/1503.02531)
- reference(github): https://github.com/shriramsb/Distilling-the-Knowledge-in-a-Neural-Network

In [1]:
# Colab-only step: attach Google Drive to this runtime's filesystem.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'    # directory under which Drive is exposed
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### Import and Setting

In [2]:
cd /content/drive/MyDrive/ML_coding

/content/drive/MyDrive/ML_coding


In [3]:
import numpy as np
import math

%matplotlib inline
import matplotlib.pyplot as plt

import pickle
import argparse
import time
import itertools
from copy import deepcopy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import networks
import utils
    
%load_ext autoreload
%autoreload 2

In [4]:
# Device configuration.
# `use_gpu` is the user's request: set it to False to force CPU execution.
use_gpu = True
cpu_device = torch.device('cpu')
# fast_device is where computation (training, inference) happens
fast_device = torch.device('cpu')
if use_gpu:
    # Set visible devices depending on system configuration. This is assigned before
    # the availability check so the restriction is in place when CUDA is first queried.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if torch.cuda.is_available():
        fast_device = torch.device('cuda:0')
    else:
        # No usable CUDA device: stay on the CPU rather than failing later when
        # models and batches are moved with `.to(fast_device)`.
        print('CUDA is not available; falling back to CPU.')
        use_gpu = False

In [5]:
def reproducibilitySeed(seed=0):
    """
    Ensure reproducibility of results by seeding the random number generators.

    Args:
        seed (int): value used to seed Python's `random`, NumPy and PyTorch.
            Defaults to 0, which is the value that used to be hard-coded.
    """
    # Local import: the notebook's import cell does not bring `random` into scope.
    import random

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        # Ask cuDNN for deterministic algorithms and switch off its benchmark mode.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

reproducibilitySeed()

In [24]:
# Output folders for the three training runs in this notebook.
checkpoints_path_teacher = 'checkpoints_teacher/'
checkpoints_path_student = 'checkpoints_student/'
checkpoints_path_student_distill = 'checkpoints_student_distill/'
# The teacher-training cell saves through `checkpoints_path`; keep it as an alias of
# the teacher folder so the save and load locations cannot drift apart.
checkpoints_path = checkpoints_path_teacher

# makedirs(exist_ok=True) is idempotent on re-runs, creates missing parent folders,
# and avoids the check-then-create race of `if not exists: mkdir`.
for _checkpoint_dir in (checkpoints_path_teacher,
                        checkpoints_path_student,
                        checkpoints_path_student_distill):
    os.makedirs(_checkpoint_dir, exist_ok=True)

Load dataset

In [14]:
import torchvision
import torchvision.transforms as transforms

# Dataset configuration.
mnist_image_shape = (28, 28)
random_pad_size = 2
batch_size = 128    # shared by every DataLoader below

# Training images augmented by randomly shifting images by at max. 2 pixels in any of 4 directions
transform_train = transforms.Compose(
                [
                    transforms.RandomCrop(mnist_image_shape, random_pad_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5])
                ]
            )

# Test images are only converted to tensors and normalized (no augmentation).
transform_test = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5])
                ]
            )

train_val_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True, 
                                            download=True, transform=transform_train)

test_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=False, 
                                            download=True, transform=transform_test)

# 95% / 5% split of the official training set into train and validation subsets.
num_train = int(1.0 * len(train_val_dataset) * 95 / 100)
num_val = len(train_val_dataset) - num_train
# NOTE(review): both subsets wrap `train_val_dataset`, so `val_dataset` also goes through
# `transform_train` (random crop). Give it `transform_test` if validation metrics are ever used.
train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])

# Every loader uses `batch_size`, so changing it above takes effect everywhere
# (the loaders previously hard-coded 128 and ignored the variable).
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

### Teacher network

Train teacher network

In [16]:
# Teacher training schedule: number of epochs, and the interval size at which
# training statistics are printed.
num_epochs, print_every = 20, 100

- 10 epoch당 약 3분 소요

In [17]:
# Hyperparameter grid for the teacher. Each list holds the candidate values of one
# hyperparameter; widen a list (see the commented-out alternatives) to search over it.
# learning_rates = list(np.logspace(-4, -2, 3))
learning_rates = [1e-2]
learning_rate_decays = [0.95]    # learning rate decays at every epoch
# weight_decays = [0.0] + list(np.logspace(-5, -1, 5))
weight_decays = [1e-5]           # regularization weight
momentums = [0.9]
# dropout_probabilities = [(0.2, 0.5), (0.0, 0.0)]
dropout_probabilities = [(0.0, 0.0)]    # (input dropout, hidden dropout) pairs

# One dict per point of the grid (cartesian product of the lists above).
hparams_list = [
    {
        'dropout_input': dropout[0],
        'dropout_hidden': dropout[1],
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    }
    for dropout, weight_decay, lr_decay, momentum, lr in itertools.product(
        dropout_probabilities, weight_decays, learning_rate_decays, momentums, learning_rates)
]

# Train a fresh teacher for every configuration and checkpoint each one.
results = {}
for hparam in hparams_list:
    print(f'Training with hparams{utils.hparamToString(hparam)}')
    reproducibilitySeed()    # identical RNG state at the start of every run
    teacher_net = networks.TeacherNetwork().to(fast_device)
    hparam_tuple = utils.hparamDictToTuple(hparam)
    # Trains on the whole train+val loader; no validation loader is supplied (None).
    results[hparam_tuple] = utils.trainTeacherOnHparam(
        teacher_net, hparam, num_epochs,
        train_val_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = f'{checkpoints_path}{utils.hparamToString(hparam)}_final.tar'
    checkpoint = {
        'results': results[hparam_tuple],
        'model_state_dict': teacher_net.state_dict(),
        'epoch': num_epochs,
    }
    torch.save(checkpoint, save_path)

Training with hparamsdropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05
[1,   100/  469] train loss: 1.026 train accuracy: 0.609
[1,   200/  469] train loss: 0.693 train accuracy: 0.789
[1,   300/  469] train loss: 0.607 train accuracy: 0.812
[1,   400/  469] train loss: 0.525 train accuracy: 0.836
[2,   100/  469] train loss: 0.339 train accuracy: 0.883
[2,   200/  469] train loss: 0.300 train accuracy: 0.930
[2,   300/  469] train loss: 0.260 train accuracy: 0.922
[2,   400/  469] train loss: 0.248 train accuracy: 0.945
[3,   100/  469] train loss: 0.276 train accuracy: 0.938
[3,   200/  469] train loss: 0.183 train accuracy: 0.961
[3,   300/  469] train loss: 0.124 train accuracy: 0.961
[3,   400/  469] train loss: 0.168 train accuracy: 0.953
[4,   100/  469] train loss: 0.182 train accuracy: 0.945
[4,   200/  469] train loss: 0.195 train accuracy: 0.953
[4,   300/  469] train loss: 0.170 train accuracy: 0.953
[4,   400/  469] train loss:

In [18]:
# Evaluate the freshly trained teacher on the MNIST test loader.
# The helper returns a pair; only the second element (accuracy) is reported here.
unused_loss, test_accuracy = utils.getLossAccuracyOnDataset(teacher_net, test_loader, fast_device)
print('test accuracy: ', test_accuracy)

test accuracy:  0.9853


### Student network

Load teacher network

In [19]:
# Checkpoint folders for the teacher, the plain student and the distilled student.
(checkpoints_path_teacher,
 checkpoints_path_student,
 checkpoints_path_student_distill) = ('checkpoints_teacher/',
                                      'checkpoints_student/',
                                      'checkpoints_student_distill/')

In [20]:
# The checkpoint file name is derived from the hparams, so they must match the
# values the teacher was trained with.
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
# keeping dropout input = dropout hidden
dropout_probabilities = [(0.0, 0.0)]

hparams_list = [
    {
        'dropout_input': dropout[0],
        'dropout_hidden': dropout[1],
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    }
    for dropout, weight_decay, lr_decay, momentum, lr in itertools.product(
        dropout_probabilities, weight_decays, learning_rate_decays, momentums, learning_rates)
]

# Rebuild the file name from the first hparam set and restore the saved weights.
load_path = f'{checkpoints_path_teacher}{utils.hparamToString(hparams_list[0])}_final.tar'
teacher_net = networks.TeacherNetwork()
teacher_checkpoint = torch.load(load_path, map_location=fast_device)
teacher_net.load_state_dict(teacher_checkpoint['model_state_dict'])
teacher_net = teacher_net.to(fast_device)

In [21]:
# Check the restored teacher on the MNIST test loader.
# The helper returns a pair; only the second element (accuracy) is reported here.
unused_loss, test_accuracy = utils.getLossAccuracyOnDataset(teacher_net, test_loader, fast_device)
print('teacher test accuracy: ', test_accuracy)

teacher test accuracy:  0.9853


Train student network without distillation

In [22]:
# Schedule for the baseline student run: number of epochs, and the interval size
# at which training statistics are printed.
num_epochs, print_every = 20, 100

In [25]:
# Hyperparameter grid for the baseline student.
temperatures = [1]    # temperature for distillation loss
# trade-off between soft-target (st) cross-entropy and true-target (tt) cross-entropy;
# loss = alpha * st + (1 - alpha) * tt
# With alpha = 0.0 the soft-target term carries zero weight in that formula.
alphas = [0.0]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
# No dropout used
dropout_probabilities = [(0.0, 0.0)]

# One dict per point of the grid (cartesian product of the lists above).
hparams_list = [
    {
        'alpha': alpha,
        'T': temperature,
        'dropout_input': dropout[0],
        'dropout_hidden': dropout[1],
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    }
    for alpha, temperature, dropout, weight_decay, lr_decay, momentum, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays, learning_rate_decays,
        momentums, learning_rates)
]

# Train a fresh student for every configuration and checkpoint each one.
results_no_distill = {}
for hparam in hparams_list:
    print(f'Training with hparams{utils.hparamToString(hparam)}')
    reproducibilitySeed()    # identical RNG state at the start of every run
    student_net = networks.StudentNetwork().to(fast_device)
    hparam_tuple = utils.hparamDictToTuple(hparam)
    # Trains on the whole train+val loader; no validation loader is supplied (None).
    results_no_distill[hparam_tuple] = utils.trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = f'{checkpoints_path_student}{utils.hparamToString(hparam)}_final.tar'
    checkpoint = {
        'results': results_no_distill[hparam_tuple],
        'model_state_dict': student_net.state_dict(),
        'epoch': num_epochs,
    }
    torch.save(checkpoint, save_path)

Training with hparamsT=1, alpha=0.0, dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05




[1,   100/  469] train loss: 0.914 train accuracy: 0.742
[1,   200/  469] train loss: 0.991 train accuracy: 0.664
[1,   300/  469] train loss: 0.675 train accuracy: 0.812
[1,   400/  469] train loss: 0.460 train accuracy: 0.867
[2,   100/  469] train loss: 0.376 train accuracy: 0.906
[2,   200/  469] train loss: 0.443 train accuracy: 0.859
[2,   300/  469] train loss: 0.277 train accuracy: 0.922
[2,   400/  469] train loss: 0.319 train accuracy: 0.898
[3,   100/  469] train loss: 0.254 train accuracy: 0.945
[3,   200/  469] train loss: 0.211 train accuracy: 0.938
[3,   300/  469] train loss: 0.239 train accuracy: 0.961
[3,   400/  469] train loss: 0.146 train accuracy: 0.953
[4,   100/  469] train loss: 0.281 train accuracy: 0.906
[4,   200/  469] train loss: 0.269 train accuracy: 0.930
[4,   300/  469] train loss: 0.171 train accuracy: 0.953
[4,   400/  469] train loss: 0.132 train accuracy: 0.961
[5,   100/  469] train loss: 0.135 train accuracy: 0.961
[5,   200/  469] train loss: 0.

In [26]:
# Evaluate the baseline (non-distilled) student on the MNIST test loader.
# The helper returns a pair; only the second element (accuracy) is reported here.
unused_loss, test_accuracy = utils.getLossAccuracyOnDataset(student_net, test_loader, fast_device)
print('student test accuracy (w/o distillation): ', test_accuracy)

student test accuracy (w/o distillation):  0.9819


Train student network using distillation

In [28]:
# Schedule for the distillation run: number of epochs, and the interval size
# at which training statistics are printed.
num_epochs, print_every = 20, 100

In [29]:
# Hyperparameter grid for the distilled student.
temperatures = [10]    # temperature for distillation loss
# trade-off between soft-target (st) cross-entropy and true-target (tt) cross-entropy;
# loss = alpha * st + (1 - alpha) * tt
alphas = [0.5]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
dropout_probabilities = [(0.0, 0.0)]    # (input dropout, hidden dropout) pairs

# One dict per point of the grid (cartesian product of the lists above).
hparams_list = [
    {
        'alpha': alpha,
        'T': temperature,
        'dropout_input': dropout[0],
        'dropout_hidden': dropout[1],
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    }
    for alpha, temperature, dropout, weight_decay, lr_decay, momentum, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays, learning_rate_decays,
        momentums, learning_rates)
]

# Train a fresh student against the teacher for every configuration and checkpoint each one.
results_distill = {}
for hparam in hparams_list:
    print(f'Training with hparams{utils.hparamToString(hparam)}')
    reproducibilitySeed()    # identical RNG state at the start of every run
    student_net = networks.StudentNetwork().to(fast_device)
    hparam_tuple = utils.hparamDictToTuple(hparam)
    # Trains on the whole train+val loader; no validation loader is supplied (None).
    results_distill[hparam_tuple] = utils.trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = f'{checkpoints_path_student_distill}{utils.hparamToString(hparam)}_final.tar'
    checkpoint = {
        'results': results_distill[hparam_tuple],
        'model_state_dict': student_net.state_dict(),
        'epoch': num_epochs,
    }
    torch.save(checkpoint, save_path)

Training with hparamsT=10, alpha=0.5, dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05
[1,   100/  469] train loss: 4.133 train accuracy: 0.734
[1,   200/  469] train loss: 2.181 train accuracy: 0.836
[1,   300/  469] train loss: 1.236 train accuracy: 0.875
[1,   400/  469] train loss: 0.798 train accuracy: 0.945
[2,   100/  469] train loss: 0.543 train accuracy: 0.969
[2,   200/  469] train loss: 0.454 train accuracy: 0.930
[2,   300/  469] train loss: 0.439 train accuracy: 0.961
[2,   400/  469] train loss: 0.367 train accuracy: 0.969
[3,   100/  469] train loss: 0.332 train accuracy: 0.961
[3,   200/  469] train loss: 0.245 train accuracy: 0.977
[3,   300/  469] train loss: 0.267 train accuracy: 0.969
[3,   400/  469] train loss: 0.193 train accuracy: 0.984
[4,   100/  469] train loss: 0.295 train accuracy: 0.953
[4,   200/  469] train loss: 0.267 train accuracy: 0.953
[4,   300/  469] train loss: 0.225 train accuracy: 0.969
[4,   400/ 

In [30]:
# Evaluate the distilled student on the MNIST test loader.
# The helper returns a pair; only the second element (accuracy) is reported here.
unused_loss, test_accuracy = utils.getLossAccuracyOnDataset(student_net, test_loader, fast_device)
print('student test accuracy (w distillation): ', test_accuracy)

student test accuracy (w distillation):  0.982
