### Import required packages and limit GPU usage

In [1]:
import numpy as np
import math

%matplotlib inline
import matplotlib.pyplot as plt

import pickle
import argparse
import time
import itertools
from copy import deepcopy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import networks
import utils
    
%load_ext autoreload
%autoreload 2

In [2]:
use_gpu = True    # set use_gpu to True if system has gpu
gpu_id = 2        # id of gpu to be used (index among the visible devices set below)
cpu_device = torch.device('cpu')
# fast_device is where computation (training, inference) happens;
# it stays on the CPU unless a usable GPU is found below
fast_device = torch.device('cpu')
if use_gpu:
    # The variable is read when CUDA is initialised, so it is set before any CUDA query
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'    # set visible devices depending on system configuration
    # Only select the GPU when it really exists; otherwise the first .to(fast_device)
    # further down would fail on machines without (enough) GPUs
    if torch.cuda.is_available() and 0 <= gpu_id < torch.cuda.device_count():
        fast_device = torch.device('cuda:' + str(gpu_id))
    else:
        print('Warning: cuda:' + str(gpu_id) + ' is not available; falling back to CPU')

In [3]:
# Ensure reproducibility
def reproducibilitySeed(seed=0):
    """
    Seed the PyTorch and NumPy random number generators for reproducible runs.

    Args:
        seed (int): value used to seed both generators. Defaults to 0, the
            value that was previously hard-coded, so existing calls without
            arguments behave exactly as before.

    When the notebook-level flag `use_gpu` is set, cuDNN is additionally put
    into deterministic mode and its benchmark auto-tuner is disabled.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

reproducibilitySeed()

### Load dataset

In [4]:
import torchvision
import torchvision.transforms as transforms

# Student trained without data augmentation
# MNIST images have a single channel, so Normalize receives exactly one mean
# and one std; supplying two values per statistic fails to broadcast against
# the 1 x 28 x 28 image tensor on current torchvision releases
transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]
            )

train_val_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True, 
                                            download=True, transform=transform)

test_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=False, 
                                            download=True, transform=transform)

# 95% / 5% split of the official training set into train and validation parts
num_train = int(1.0 * len(train_val_dataset) * 95 / 100)
num_val = len(train_val_dataset) - num_train
train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])

# Single source of truth for the batch size of every loader below
batch_size = 128
# Loaders used for training are shuffled; the validation / test loaders are not
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [5]:
# Directory the teacher checkpoint is read from
checkpoints_path_teacher = 'checkpoints_teacher/'
# Directory the student checkpoints are written to
checkpoints_path_student = 'checkpoints_student/'
# exist_ok=True makes the cell safe to re-run and avoids the race between
# checking for the directory and creating it
os.makedirs(checkpoints_path_student, exist_ok=True)

### Load teacher network

In [6]:
# The teacher checkpoint file is named after the hparams used to train the
# teacher, so the same configuration is rebuilt here to locate it
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
# (input, hidden) dropout pairs; input and hidden dropout are kept equal
dropout_probabilities = [(0.0, 0.0)]

# One dict per combination of the value lists above
hparams_list = [
    {
        'dropout_input': dropout_input,
        'dropout_hidden': dropout_hidden,
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    }
    for (dropout_input, dropout_hidden), weight_decay, lr_decay, momentum, lr
    in itertools.product(dropout_probabilities, weight_decays, learning_rate_decays,
                         momentums, learning_rates)
]

# Only the first configuration is loaded
load_path = checkpoints_path_teacher + utils.hparamToString(hparams_list[0]) + '_final.tar'
teacher_net = networks.TeacherNetwork()
# map_location places the stored tensors on the compute device while loading
teacher_net.load_state_dict(
    torch.load(load_path, map_location=fast_device)['model_state_dict']
)
teacher_net = teacher_net.to(fast_device)

In [7]:
# Evaluate the loaded teacher on the test loader; only the second element of
# the returned pair (the accuracy) is needed here
test_accuracy = utils.getLossAccuracyOnDataset(teacher_net, test_loader, fast_device)[1]
print('teacher test accuracy: ', test_accuracy)

teacher test accuracy:  0.9892


### Train student network without distillation

In [8]:
# Training schedule for the student run without distillation.
# num_epochs is handed to utils.trainStudentOnHparam and also stored in the
# saved checkpoint under the 'epoch' key.
num_epochs = 60
# print_every is forwarded as the print_every keyword of utils.trainStudentOnHparam.
print_every = 100

In [9]:
temperatures = [1]    # temperature for distillation loss
# loss = alpha * st + (1 - alpha) * tt, with st the soft-target and tt the
# true-target cross-entropy; alpha = 0.0 therefore leaves only the tt term
alphas = [0.0]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
# No dropout used
dropout_probabilities = [(0.0, 0.0)]

# One dict per combination of the value lists above
hparams_list = [
    {
        'alpha': alpha,
        'T': temperature,
        'dropout_input': dropout_input,
        'dropout_hidden': dropout_hidden,
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    }
    for alpha, temperature, (dropout_input, dropout_hidden), weight_decay, lr_decay, momentum, lr
    in itertools.product(alphas, temperatures, dropout_probabilities, weight_decays,
                         learning_rate_decays, momentums, learning_rates)
]

# Maps the tuple form of each hparam dict to the value returned by training
results_no_distill = {}
for hparam in hparams_list:
    print('Training with hparams' + utils.hparamToString(hparam))
    # Re-seed so every configuration starts from the same RNG state
    reproducibilitySeed()
    student_net = networks.StudentNetwork().to(fast_device)
    hparam_tuple = utils.hparamDictToTuple(hparam)
    # NOTE(review): the argument after train_val_loader is None - presumably the
    # validation loader slot; confirm against utils.trainStudentOnHparam
    results_no_distill[hparam_tuple] = utils.trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    # Persist the results together with the trained weights
    save_path = checkpoints_path_student + utils.hparamToString(hparam) + '_final.tar'
    checkpoint = {
        'results' : results_no_distill[hparam_tuple],
        'model_state_dict' : student_net.state_dict(),
        'epoch' : num_epochs,
    }
    torch.save(checkpoint, save_path)

Training with hparamsT=1, alpha=0.0, dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05
[1,   100/  469] train loss: 0.426 train accuracy: 0.875
[1,   200/  469] train loss: 0.326 train accuracy: 0.898
[1,   300/  469] train loss: 0.229 train accuracy: 0.945
[1,   400/  469] train loss: 0.232 train accuracy: 0.922
[2,   100/  469] train loss: 0.245 train accuracy: 0.938
[2,   200/  469] train loss: 0.216 train accuracy: 0.945
[2,   300/  469] train loss: 0.134 train accuracy: 0.984
[2,   400/  469] train loss: 0.201 train accuracy: 0.961
[3,   100/  469] train loss: 0.108 train accuracy: 0.977
[3,   200/  469] train loss: 0.119 train accuracy: 0.961
[3,   300/  469] train loss: 0.151 train accuracy: 0.969
[3,   400/  469] train loss: 0.189 train accuracy: 0.945
[4,   100/  469] train loss: 0.191 train accuracy: 0.945
[4,   200/  469] train loss: 0.259 train accuracy: 0.938
[4,   300/  469] train loss: 0.141 train accuracy: 0.977
[4,   400/  

[36,   100/  469] train loss: 0.019 train accuracy: 1.000
[36,   200/  469] train loss: 0.020 train accuracy: 0.992
[36,   300/  469] train loss: 0.027 train accuracy: 0.992
[36,   400/  469] train loss: 0.023 train accuracy: 1.000
[37,   100/  469] train loss: 0.047 train accuracy: 0.984
[37,   200/  469] train loss: 0.035 train accuracy: 0.984
[37,   300/  469] train loss: 0.010 train accuracy: 1.000
[37,   400/  469] train loss: 0.056 train accuracy: 0.984
[38,   100/  469] train loss: 0.044 train accuracy: 0.984
[38,   200/  469] train loss: 0.023 train accuracy: 0.992
[38,   300/  469] train loss: 0.037 train accuracy: 0.984
[38,   400/  469] train loss: 0.029 train accuracy: 0.992
[39,   100/  469] train loss: 0.008 train accuracy: 1.000
[39,   200/  469] train loss: 0.017 train accuracy: 1.000
[39,   300/  469] train loss: 0.044 train accuracy: 0.992
[39,   400/  469] train loss: 0.014 train accuracy: 1.000
[40,   100/  469] train loss: 0.052 train accuracy: 0.984
[40,   200/  4

In [10]:
# Evaluate the most recently trained student on the test loader; only the
# second element of the returned pair (the accuracy) is needed here
test_accuracy = utils.getLossAccuracyOnDataset(student_net, test_loader, fast_device)[1]
print('student test accuracy (w/o distillation): ', test_accuracy)

student test accuracy (w/o distillation):  0.9819


### Hyperparameter search utils

In [None]:
# Scatter plot of the hyperparameter grid (weight decay vs. input dropout),
# coloured and annotated with the last recorded validation accuracy of each run.
# NOTE(review): this reads results_no_distill[...]['val_acc'], while the training
# cell passes None after the training loader - confirm that 'val_acc' is actually
# populated before relying on this plot.

# x position used for weight_decay == 0, where log10 is undefined
zero_weight_decay_log10 = -6
weight_decay_scatter = [
    math.log10(h['weight_decay']) if h['weight_decay'] > 0 else zero_weight_decay_log10
    for h in hparams_list
]
# 1 when the input dropout probability is 0.2, 0 otherwise
dropout_scatter = [int(h['dropout_input'] == 0.2) for h in hparams_list]
# Last entry of each run's 'val_acc' history
colors = [results_no_distill[utils.hparamDictToTuple(h)]['val_acc'][-1] for h in hparams_list]

marker_size = 100
fig, ax = plt.subplots()
# Draw on the explicit Axes instead of the implicit pyplot "current" figure
points = ax.scatter(weight_decay_scatter, dropout_scatter, s=marker_size, c=colors, edgecolors='black')
fig.colorbar(points, ax=ax, label='final validation accuracy')
for x, y, acc in zip(weight_decay_scatter, dropout_scatter, colors):
    ax.annotate(str('%0.4f' % (acc, )), (x, y))
ax.set_xlabel('log10(weight decay)')
ax.set_ylabel('input dropout == 0.2 (0 = no, 1 = yes)')
ax.set_yticks([0, 1])
ax.set_title('Validation accuracy over the hyperparameter grid')
plt.show()

### Train student network using distillation

In [11]:
# Training schedule for the student run with distillation.
# num_epochs is handed to utils.trainStudentOnHparam and also stored in the
# saved checkpoint under the 'epoch' key.
num_epochs = 60
# print_every is forwarded as the print_every keyword of utils.trainStudentOnHparam.
print_every = 100

In [12]:
temperatures = [10]    # temperature for distillation loss
# loss = alpha * st + (1 - alpha) * tt, with st the soft-target and tt the
# true-target cross-entropy; alpha = 0.5 weights both terms equally
alphas = [0.5]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
# (input, hidden) dropout pairs
dropout_probabilities = [(0.0, 0.0)]

# One dict per combination of the value lists above
hparams_list = [
    {
        'alpha': alpha,
        'T': temperature,
        'dropout_input': dropout_input,
        'dropout_hidden': dropout_hidden,
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    }
    for alpha, temperature, (dropout_input, dropout_hidden), weight_decay, lr_decay, momentum, lr
    in itertools.product(alphas, temperatures, dropout_probabilities, weight_decays,
                         learning_rate_decays, momentums, learning_rates)
]

# Maps the tuple form of each hparam dict to the value returned by training
results_distill = {}
for hparam in hparams_list:
    print('Training with hparams' + utils.hparamToString(hparam))
    # Re-seed so every configuration starts from the same RNG state
    reproducibilitySeed()
    student_net = networks.StudentNetwork().to(fast_device)
    hparam_tuple = utils.hparamDictToTuple(hparam)
    # NOTE(review): the argument after train_val_loader is None - presumably the
    # validation loader slot; confirm against utils.trainStudentOnHparam
    results_distill[hparam_tuple] = utils.trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    # Persist the results together with the trained weights
    save_path = checkpoints_path_student + utils.hparamToString(hparam) + '_final.tar'
    checkpoint = {
        'results' : results_distill[hparam_tuple],
        'model_state_dict' : student_net.state_dict(),
        'epoch' : num_epochs,
    }
    torch.save(checkpoint, save_path)


Training with hparamsT=10, alpha=0.5, dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05
[1,   100/  469] train loss: 1.911 train accuracy: 0.891
[1,   200/  469] train loss: 1.322 train accuracy: 0.930
[1,   300/  469] train loss: 1.071 train accuracy: 0.969
[1,   400/  469] train loss: 0.722 train accuracy: 0.953
[2,   100/  469] train loss: 0.664 train accuracy: 0.961
[2,   200/  469] train loss: 0.582 train accuracy: 0.953
[2,   300/  469] train loss: 0.429 train accuracy: 0.984
[2,   400/  469] train loss: 0.480 train accuracy: 0.961
[3,   100/  469] train loss: 0.351 train accuracy: 0.992
[3,   200/  469] train loss: 0.297 train accuracy: 1.000
[3,   300/  469] train loss: 0.310 train accuracy: 0.984
[3,   400/  469] train loss: 0.308 train accuracy: 0.992
[4,   100/  469] train loss: 0.246 train accuracy: 0.977
[4,   200/  469] train loss: 0.321 train accuracy: 0.969
[4,   300/  469] train loss: 0.250 train accuracy: 0.984
[4,   400/ 

[36,   100/  469] train loss: 0.065 train accuracy: 1.000
[36,   200/  469] train loss: 0.066 train accuracy: 1.000
[36,   300/  469] train loss: 0.079 train accuracy: 0.992
[36,   400/  469] train loss: 0.068 train accuracy: 1.000
[37,   100/  469] train loss: 0.088 train accuracy: 0.984
[37,   200/  469] train loss: 0.078 train accuracy: 0.984
[37,   300/  469] train loss: 0.062 train accuracy: 1.000
[37,   400/  469] train loss: 0.089 train accuracy: 0.984
[38,   100/  469] train loss: 0.099 train accuracy: 0.977
[38,   200/  469] train loss: 0.079 train accuracy: 0.992
[38,   300/  469] train loss: 0.077 train accuracy: 0.984
[38,   400/  469] train loss: 0.078 train accuracy: 1.000
[39,   100/  469] train loss: 0.063 train accuracy: 1.000
[39,   200/  469] train loss: 0.082 train accuracy: 1.000
[39,   300/  469] train loss: 0.090 train accuracy: 0.984
[39,   400/  469] train loss: 0.061 train accuracy: 1.000
[40,   100/  469] train loss: 0.086 train accuracy: 0.984
[40,   200/  4

In [13]:
# Evaluate the most recently trained student on the test loader; only the
# second element of the returned pair (the accuracy) is needed here
test_accuracy = utils.getLossAccuracyOnDataset(student_net, test_loader, fast_device)[1]
print('student test accuracy (w distillation): ', test_accuracy)

student test accuracy (w distillation):  0.9866
