### Import required packages and limit GPU usage

In [1]:
import numpy as np
import math

%matplotlib inline
import matplotlib.pyplot as plt

import pickle
import argparse
import time
import itertools
from copy import deepcopy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import networks
import utils
    
%load_ext autoreload
%autoreload 2

In [2]:
use_gpu = True    # set use_gpu to True if system has gpu
gpu_id = 2        # id of gpu to be used (index into CUDA_VISIBLE_DEVICES, not the physical id)
cpu_device = torch.device('cpu')
# fast_device is where computation (training, inference) happens
fast_device = torch.device('cpu')
if use_gpu:
    # Only takes effect if CUDA has not been initialised yet in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'    # set visible devices depending on system configuration
    if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
        fast_device = torch.device('cuda:' + str(gpu_id))
    else:
        # The requested GPU is not usable on this machine: fall back to CPU and clear
        # the flag so later cells don't assume a GPU is present.
        print('GPU cuda:' + str(gpu_id) + ' not available; falling back to CPU')
        use_gpu = False

In [3]:
def reproducibilitySeed(seed=0):
    """
    Ensure reproducibility of results by seeding every RNG the notebook uses.

    Seeds torch (torch.manual_seed), numpy's global legacy RNG and Python's
    ``random`` module with the same value, and asks cuDNN for deterministic,
    non-autotuned kernels.

    Args:
        seed: integer seed; defaults to 0 (the value previously hard-coded).
    """
    import random  # local import: not brought in by the notebook's import cell

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # These flags only affect cuDNN-backed ops, so setting them unconditionally is
    # harmless on CPU and removes the hidden dependency on the global `use_gpu`.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

reproducibilitySeed()

### Load dataset

In [4]:
import torchvision
import torchvision.transforms as transforms

mnist_image_shape = (28, 28)
random_pad_size = 2
# MNIST images have a single channel, so Normalize needs 1-element mean/std tuples.
# 2-element tuples only worked on old torchvision releases that zip()-ed over channels;
# newer releases broadcast mean/std against the (C, H, W) tensor and fail on the
# channel mismatch. With ToTensor's [0, 1] output this maps pixels to [-1, 1].
mnist_mean = (0.5,)
mnist_std = (0.5,)

# Training images augmented by randomly shifting images by at max. 2 pixels in any of 4 directions
# (RandomCrop's second positional argument is the padding added before cropping back to 28x28).
transform_train = transforms.Compose(
                [
                    transforms.RandomCrop(mnist_image_shape, random_pad_size),
                    transforms.ToTensor(),
                    transforms.Normalize(mnist_mean, mnist_std)
                ]
            )

transform_test = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mnist_mean, mnist_std)
                ]
            )

train_val_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True, 
                                            download=True, transform=transform_train)
# Second view of the same training images *without* augmentation; the validation split
# below reads from it so validation metrics are computed on clean images.
val_source_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True,
                                                download=True, transform=transform_test)

test_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=False, 
                                            download=True, transform=transform_test)

# 95% / 5% train / validation split of the training set (integer arithmetic, no float round-trip).
num_train = len(train_val_dataset) * 95 // 100
num_val = len(train_val_dataset) - num_train
train_dataset, val_dataset_augmented = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])
# Keep the randomly drawn validation indices, but serve them from the un-augmented copy.
val_dataset = torch.utils.data.Subset(val_source_dataset, val_dataset_augmented.indices)

batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

### Train teacher network

In [5]:
checkpoints_path = 'checkpoints_teacher/'
# makedirs(exist_ok=True) is idempotent, creates missing parent directories and avoids
# the check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs(checkpoints_path, exist_ok=True)

In [6]:
# Training schedule: number of passes over the training data, and how often
# (in mini-batches) running training statistics are printed.
num_epochs, print_every = 60, 100

In [7]:
# Hyperparameters can be tuned by widening the candidate lists below; every
# combination in their Cartesian product becomes one training configuration.
# learning_rates = list(np.logspace(-4, -2, 3))
learning_rates = [1e-2]
learning_rate_decays = [0.95]    # learning rate decays at every epoch
# weight_decays = [0.0] + list(np.logspace(-5, -1, 5))
weight_decays = [1e-5]           # regularization weight
momentums = [0.9]
# dropout_probabilities = [(0.2, 0.5), (0.0, 0.0)]
dropout_probabilities = [(0.0, 0.0)]    # (input dropout, hidden dropout) pairs

hparams_list = [
    {
        'dropout_input': p_input,
        'dropout_hidden': p_hidden,
        'weight_decay': wd,
        'lr_decay': decay,
        'momentum': mom,
        'lr': rate,
    }
    for (p_input, p_hidden), wd, decay, mom, rate in itertools.product(
        dropout_probabilities, weight_decays, learning_rate_decays, momentums, learning_rates
    )
]

# Train one teacher per hyperparameter configuration and checkpoint each final model.
results = {}
for hparam in hparams_list:
    print('Training with hparams' + utils.hparamToString(hparam))
    # Re-seed so each configuration starts from the same RNG state
    # (weight initialisation, data shuffling).
    reproducibilitySeed()
    teacher_net = networks.TeacherNetwork()
    teacher_net = teacher_net.to(fast_device)
    hparam_tuple = utils.hparamDictToTuple(hparam)
    # Trains on train_val_loader (the full training set); no validation loader is passed.
    results[hparam_tuple] = utils.trainTeacherOnHparam(teacher_net, hparam, num_epochs, 
                                                        train_val_loader, None, 
                                                        print_every=print_every, 
                                                        fast_device=fast_device)
    # os.path.join is correct whether or not checkpoints_path ends with a separator.
    save_path = os.path.join(checkpoints_path, utils.hparamToString(hparam) + '_final.tar')
    torch.save({'results' : results[hparam_tuple], 
                'model_state_dict' : teacher_net.state_dict(), 
                'epoch' : num_epochs}, save_path)

Training with hparamsdropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05
[1,   100/  469] train loss: 0.936 train accuracy: 0.766
[1,   200/  469] train loss: 0.717 train accuracy: 0.805
[1,   300/  469] train loss: 0.548 train accuracy: 0.852
[1,   400/  469] train loss: 0.311 train accuracy: 0.922
[2,   100/  469] train loss: 0.327 train accuracy: 0.891
[2,   200/  469] train loss: 0.301 train accuracy: 0.906
[2,   300/  469] train loss: 0.187 train accuracy: 0.961
[2,   400/  469] train loss: 0.248 train accuracy: 0.914
[3,   100/  469] train loss: 0.152 train accuracy: 0.977
[3,   200/  469] train loss: 0.194 train accuracy: 0.953
[3,   300/  469] train loss: 0.171 train accuracy: 0.961
[3,   400/  469] train loss: 0.145 train accuracy: 0.961
[4,   100/  469] train loss: 0.297 train accuracy: 0.914
[4,   200/  469] train loss: 0.166 train accuracy: 0.969
[4,   300/  469] train loss: 0.249 train accuracy: 0.930
[4,   400/  469] train loss:

[36,   100/  469] train loss: 0.094 train accuracy: 0.984
[36,   200/  469] train loss: 0.034 train accuracy: 0.984
[36,   300/  469] train loss: 0.057 train accuracy: 0.992
[36,   400/  469] train loss: 0.017 train accuracy: 0.992
[37,   100/  469] train loss: 0.070 train accuracy: 0.961
[37,   200/  469] train loss: 0.046 train accuracy: 0.984
[37,   300/  469] train loss: 0.028 train accuracy: 0.984
[37,   400/  469] train loss: 0.006 train accuracy: 1.000
[38,   100/  469] train loss: 0.017 train accuracy: 1.000
[38,   200/  469] train loss: 0.096 train accuracy: 0.977
[38,   300/  469] train loss: 0.050 train accuracy: 0.992
[38,   400/  469] train loss: 0.017 train accuracy: 0.992
[39,   100/  469] train loss: 0.019 train accuracy: 0.992
[39,   200/  469] train loss: 0.032 train accuracy: 0.984
[39,   300/  469] train loss: 0.137 train accuracy: 0.961
[39,   400/  469] train loss: 0.071 train accuracy: 0.953
[40,   100/  469] train loss: 0.077 train accuracy: 0.977
[40,   200/  4

In [8]:
# Calculate test accuracy of the (last) trained teacher on the held-out test set.
# The loss gets a real name instead of `_`: assigning `_` in IPython stops it from
# updating the `_` / `__` / `___` output-history variables.
test_loss, test_accuracy = utils.getLossAccuracyOnDataset(teacher_net, test_loader, fast_device)
print('test accuracy: ', test_accuracy)

test accuracy:  0.9892
