# Preamble

### About
This notebook provides an interface for experimenting with three knowledge distillation methods, those of Hinton et al. (2014), Romero et al. (2015) and Yim et al. (2017), on CIFAR-100 (Krizhevsky, 2009). A Weights & Biases account (https://www.wandb.ai) is required for logging results; your API key will be requested during initialisation.

The implementation of each knowledge distillation method can be found at https://github.com/hnsyprst/DistillationComparison.
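As a rough illustration of the first method, the sketch below computes a logit distillation loss in the style of Hinton et al. for a single sample, in plain Python. The function names are hypothetical, and the way the hard and soft terms are combined is an assumption: the repository's `Logits_Distiller` may weight or scale them differently. The defaults `temp=7` and `hard_loss_weight=0.2` mirror the values this notebook passes to that distiller.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, label,
                      temp=7.0, hard_loss_weight=0.2):
    """Weighted sum of a hard-label loss and a softened teacher-matching loss.

    Assumption: the two terms are mixed as w * hard + (1 - w) * T^2 * soft.
    """
    # Hard loss: cross-entropy between the student's prediction and the true label
    hard = -math.log(softmax(student_logits)[label])
    # Soft loss: cross-entropy between the softened teacher and student distributions
    p_teacher = softmax(teacher_logits, temp)
    p_student = softmax(student_logits, temp)
    soft = -sum(pt * math.log(ps) for pt, ps in zip(p_teacher, p_student))
    # The T^2 factor compensates for soft-target gradients shrinking as 1 / T^2
    return hard_loss_weight * hard + (1 - hard_loss_weight) * temp ** 2 * soft

teacher_logits = [5.0, 1.0, 0.0]
matched = distillation_loss([5.0, 1.0, 0.0], teacher_logits, label=0)
mismatched = distillation_loss([0.0, 1.0, 5.0], teacher_logits, label=0)
print(matched < mismatched)  # a student that agrees with the teacher scores lower
```

A higher temperature flattens both distributions, so the student also learns from the relative probabilities the teacher assigns to incorrect classes.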

### Instructions
Set the paths in the 'Setup Save and Load Paths' cell (under 'Initialisation'). TEACHER_PATH should point to a ResNet50 model trained on CIFAR-100 (a link to the model used in our experiments is provided in the repository above). SAVE_PATH should point to the directory where the student will be saved after distillation. To set up a new experiment, change the parameters in the 'Launch a New Experiment' cell (under 'Conduct Experiments'), then restart the notebook and run all cells. The results will be logged to the specified project on Weights and Biases.
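For orientation, the sketch below shows where a distilled student ends up relative to SAVE_PATH. The notebook builds the path as `<SAVE_PATH>/<distiller_name>/<model_name>_<run_number>.tar`; the helper name `student_checkpoint_path` is hypothetical and is not part of the notebook or the repository.

```python
import posixpath

def student_checkpoint_path(save_path, distiller_name, model_name, run_number):
    """Compose <save_path>/<distiller_name>/<model_name>_<run_number>.tar."""
    filename = "{0}_{1}.tar".format(model_name, run_number)
    # posixpath.join avoids a doubled slash when save_path ends with '/'
    return posixpath.join(save_path, distiller_name, filename)

path = student_checkpoint_path(
    "/content/gdrive/MyDrive/TrainedModels/newDistillationExperiments/CIFAR100/",
    "logits", "resnet18", 0)
print(path)
# /content/gdrive/MyDrive/TrainedModels/newDistillationExperiments/CIFAR100/logits/resnet18_0.tar
```

The per-distiller subdirectory (here `logits`) needs to exist under SAVE_PATH before the run finishes, since the checkpoint is written into it.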

### References
Krizhevsky, A. (2009) Learning Multiple Layers of Features from Tiny Images. University of Toronto.

Liu, S., Johns, E. and Davison, A.J. (2019) ‘End-to-End Multi-Task Learning with Attention’. arXiv. Available at: http://arxiv.org/abs/1803.10704 (Accessed: 31 October 2022).

Omelchenko, I. (2020) pytorch - set seed everything, PyTorch Seed Everything. Available at: https://gist.github.com/ihoromi4/b681a9088f348942b01711f251e5f964 (Accessed: 20 November 2022).

Romero, A., Ballas, N., Kahou, S.E., Chassang, A., Gatta, C. and Bengio, Y. (2015) ‘FitNets: Hints for Thin Deep Nets’. arXiv. Available at: http://arxiv.org/abs/1412.6550 (Accessed: 31 August 2022).

Yim, J., Joo, D., Bae, J. and Kim, J. (2017) ‘A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning’, in 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). Honolulu, HI: IEEE, pp. 7130–7138. Available at: https://doi.org/10.1109/CVPR.2017.754 (Accessed: 20 November 2022).

# Initialisation

In [None]:
#@title Setup Save and Load Paths
#@markdown Google Drive is mounted here

from google.colab import drive
drive.mount('/content/gdrive/')

TEACHER_PATH = '/content/gdrive/MyDrive/TrainedModels/ResNets/ResNet50/4712.tar' #@param {type:"string"}
SAVE_PATH = '/content/gdrive/MyDrive/TrainedModels/newDistillationExperiments/CIFAR100/' #@param {type:"string"}

In [None]:
#@title Download Repo and Login to Weights & Biases
#@markdown 

!git clone https://github.com/hnsyprst/DistillationComparison.git
!cp -a /content/DistillationComparison/. /content/
!rm -r /content/DistillationComparison/

!pip install wandb
!wandb login

In [None]:
#@title Import and Download Libraries

%load_ext autoreload
%autoreload 2

import torch
import torch.profiler
import torchvision
from torch import nn
from torch.utils import data
from torchvision import transforms
from torchvision import models
from torchvision.models.resnet import resnet101, resnet50, resnet34, resnet18

!pip install torchmetrics
import torchmetrics

!pip install fvcore
from fvcore.nn import FlopCountAnalysis

!pip install ptflops
from ptflops import get_model_complexity_info

import training_utils as utils
import network_utils as nutils

import distillation_methods_module

import copy
import numpy as np
import random
import os
import wandb

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device)

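# Keep autograd debugging aids off to avoid their overhead during training.
# Note: the two profiler calls below only construct disabled context managers
# that are never entered, so they have no effect on their own.
torch.autograd.set_detect_anomaly(False)
torch.autograd.profiler.profile(False)
torch.autograd.profiler.emit_nvtx(False)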

In [None]:
#@title Download the Training and Test Sets
mean = 0.5070751592371323, 0.48654887331495095, 0.4409178433670343
std = 0.2673342858792401, 0.2564384629170883, 0.27615047132568404

trans_train = transforms.Compose([transforms.RandomHorizontalFlip(),
                                  transforms.ToTensor(),
                                  transforms.Normalize(mean, std)])
trans_test = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize(mean, std)])

# The datasets are CIFAR-100, so name them accordingly
cifar_train = torchvision.datasets.CIFAR100(
    root="../data", train=True, transform=trans_train, download=True)
cifar_test = torchvision.datasets.CIFAR100(
    root="../data", train=False, transform=trans_test, download=True)

batch_size = 512

train_iter = data.DataLoader(cifar_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_iter = data.DataLoader(cifar_test, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Helper Functions
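The "-DWA" distillers configured in this section take a `weight_temp` argument, and the reference list cites Liu et al. (2019), who introduce Dynamic Weight Average (DWA) for balancing several loss terms. The following is a minimal plain-Python sketch of the DWA weighting rule from that paper. How `Logits_Distiller_DWA` and `Features_Distiller_DWA` actually apply it is not shown in this notebook, so the function name `dwa_weights` and the mapping of `weight_temp` to the paper's temperature are assumptions.

```python
import math

def dwa_weights(prev_losses, prev_prev_losses, weight_temp=2.0):
    """Dynamic Weight Average (Liu et al., 2019) for K loss terms.

    prev_losses:      average value of each loss term at epoch t-1
    prev_prev_losses: average value of each loss term at epoch t-2
    """
    # Relative descending rate of each loss over the last two epochs
    rates = [prev / prev_prev
             for prev, prev_prev in zip(prev_losses, prev_prev_losses)]
    exps = [math.exp(rate / weight_temp) for rate in rates]
    total = sum(exps)
    k = len(rates)
    # Softmax over the rates, scaled so that the weights sum to K
    return [k * e / total for e in exps]

# The first loss fell by 10% and the second by 50%, so the slower first loss
# receives the larger weight for the next epoch.
weights = dwa_weights(prev_losses=[0.9, 0.5], prev_prev_losses=[1.0, 1.0])
print(weights)
```

In the paper, the weights are simply set to 1 for the first two epochs, before two epochs of loss history are available. A larger temperature pushes the weights towards being equal.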

In [None]:
#@title Helper Function For Seeding Relevant Random Number Generators
#@markdown Code modified from Omelchenko's Gist (2020)

def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
#@title Helper Functions for Saving and Loading Models

def load_model(model_type, pretrained=False, path=None):
    """Return a copy of `model_type` (a ResNet instance) with a 100-class head.

    If `pretrained` is True, the weights are restored from the checkpoint at `path`.
    """
    num_classes = 100

    model = copy.deepcopy(model_type)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)

    if pretrained:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime
        checkpoint = torch.load(path, map_location=device)
        # Only the model weights are needed here; the optimizer state, epoch
        # and loss stored in the checkpoint are not used by this notebook
        model.load_state_dict(checkpoint['model_state_dict'])

    return model

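def save_model(model, optimizer, history, epoch, path):
    # Create the target directory if needed; torch.save does not create it
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': history[1][-1]
    }, path)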

In [None]:
#@title Helper Function for Getting Name from a Layer

def get_layer_name(model, module, layer_index):
    return nutils.get_name_from_layer(model, nutils.get_layer_in_module_from_index(module, layer_index))

In [None]:
#@title Functions to Tidy Relations and Features Distillation Settings

def relations_hint_guided_layer_settings(teacher, student, teacher_output_layer_index, student_output_layer_index):
    hint_layer_start_1 =    get_layer_name(teacher, teacher.layer1[0], teacher_output_layer_index)
    hint_layer_end_1 =      get_layer_name(teacher, teacher.layer1[-1], teacher_output_layer_index)

    hint_layer_start_2 =    get_layer_name(teacher, teacher.layer2[0], teacher_output_layer_index)
    hint_layer_end_2 =      get_layer_name(teacher, teacher.layer2[-1], teacher_output_layer_index)

    hint_layer_start_3 =    get_layer_name(teacher, teacher.layer3[0], teacher_output_layer_index)
    hint_layer_end_3 =      get_layer_name(teacher, teacher.layer3[-1], teacher_output_layer_index)

    hint_layers =           [(hint_layer_start_1, hint_layer_end_1), (hint_layer_start_2, hint_layer_end_2), (hint_layer_start_3, hint_layer_end_3)]


    guided_layer_start_1 =  get_layer_name(student, student.layer1[0], student_output_layer_index)
    guided_layer_end_1 =    get_layer_name(student, student.layer1[-1], student_output_layer_index)

    guided_layer_start_2 =  get_layer_name(student, student.layer2[0], student_output_layer_index)
    guided_layer_end_2 =    get_layer_name(student, student.layer2[-1], student_output_layer_index)

    guided_layer_start_3 =  get_layer_name(student, student.layer3[0], student_output_layer_index)
    guided_layer_end_3 =    get_layer_name(student, student.layer3[-1], student_output_layer_index)

    guided_layers =         [(guided_layer_start_1, guided_layer_end_1), (guided_layer_start_2, guided_layer_end_2), (guided_layer_start_3, guided_layer_end_3)]

    return hint_layers, guided_layers

def features_hint_guided_layer_settings(teacher, student, teacher_output_layer_index, student_output_layer_index):
    hint_layer = nutils.get_layer_in_module_from_index(teacher.layer3[-1], teacher_output_layer_index)
    guided_layer = nutils.get_layer_in_module_from_index(student.layer3[-1], student_output_layer_index)

    return hint_layer, guided_layer

In [None]:
#@title Helper Function to Retrieve Model Complexity Information (MACs, parameter count)

def get_model_info(model):
    model.eval()
    _, params = get_model_complexity_info(model, (3, 32, 32), as_strings=False, print_per_layer_stat=False, verbose=False)

    # Dummy CIFAR-sized batch; named to avoid shadowing the built-in input()
    dummy_input = torch.ones(1, 3, 32, 32).to(device)
    mac_counter = FlopCountAnalysis(model, dummy_input)
    macs = mac_counter.total()

    return params, macs

In [None]:
#@title Function for Launching a New Experiment

def new_distillation_run(distiller_name, model_name, run_number, project_name, num_epochs, pretrained=False):
    teacher_model_name = "resnet50"
    teacher_model = models.resnet50(pretrained=False)

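    if model_name == "resnet18":
        student_model = models.resnet18(pretrained=pretrained)
    elif model_name == "resnet34":
        student_model = models.resnet34(pretrained=pretrained)
    else:
        # Fail early instead of hitting a NameError further down
        raise ValueError("Unsupported model_name: {0}".format(model_name))

    if run_number == 0:
        seed = 1059
    elif run_number == 1:
        seed = 3056
    elif run_number == 2:
        seed = 4967
    else:
        raise ValueError("run_number must be 0, 1 or 2, got {0}".format(run_number))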

    seed_everything(seed)

    student = load_model(student_model)
    student.to(device)
    teacher = load_model(teacher_model, pretrained=True, path=TEACHER_PATH)
    teacher.to(device)
    
    lr = 1e-3
    loss_fn = nn.CrossEntropyLoss(reduction='none').to(device)
    optimizer = torch.optim.RMSprop(student.parameters(), lr=lr)
    params, macs = get_model_info(student)

    distiller = setup_distiller(teacher, student, optimizer)[distiller_name]

    wandb.init(group="default-distillation", project=project_name, config={
              "subgroup": "logits-bigtest",
              "model": model_name,
              "teacher": teacher_model_name,
              "learning_rate": lr,
              "epochs": num_epochs,
              "batch_size": batch_size,
              "seed": seed,
              "macs": macs,
              "params": params})
        
    history = distiller.train(train_iter, test_iter, num_epochs, wandb_log=True)
    save_model(student, optimizer, history, num_epochs-1, '{0}/{1}/{2}_{3}.tar'.format(SAVE_PATH, distiller_name, model_name, run_number))

    top_1 = utils.evaluate_accuracy(student, test_iter, top_k=1)
    top_5 = utils.evaluate_accuracy(student, test_iter, top_k=5)
    print(top_1, top_5)
    wandb.config.update({"top_1": top_1, "top_5": top_5})

In [None]:
#@title Function for Setting Up Distillers

def setup_distiller(teacher, student, optimizer):
    hint_layers, guided_layers = relations_hint_guided_layer_settings(teacher, student, -1, -1)
    hint_layer, guided_layer = features_hint_guided_layer_settings(teacher, student, -3, -2)

    distiller_dict = {"logits":       distillation_methods_module.Logits_Distiller(temp=7, hard_loss_weight=0.2, teacher=teacher, student=student, optimizer=optimizer),
                      "features":     distillation_methods_module.Features_Distiller(hint_layer=hint_layer, guided_layer=guided_layer, is_2D=True, temp=7, hard_loss_weight=0.05, teacher=teacher, student=student, optimizer=optimizer),
                      "relations":    distillation_methods_module.Relations_Distiller(hint_layers=hint_layers, guided_layers=guided_layers, teacher=teacher, student=student, optimizer=optimizer),
                      "logits-DWA":   distillation_methods_module.Logits_Distiller_DWA(temp=7, weight_temp=2, teacher=teacher, student=student, optimizer=optimizer),
                      "features-DWA": distillation_methods_module.Features_Distiller_DWA(hint_layer=hint_layer, guided_layer=guided_layer, is_2D=True, temp=7, weight_temp=2, teacher=teacher, student=student, optimizer=optimizer)}

    return distiller_dict

# Conduct Experiments
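Each run finishes by reporting top-1 and top-5 accuracy on the test set. As a reminder of what those figures mean, here is a small plain-Python sketch of top-k accuracy. It is an illustration only: the function `top_k_accuracy` is hypothetical and is not the implementation behind `utils.evaluate_accuracy`.

```python
def top_k_accuracy(scores, labels, top_k=1):
    """Fraction of samples whose true label is among the top_k highest scores."""
    correct = 0
    for row, label in zip(scores, labels):
        # Class indices sorted from the highest score to the lowest
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        if label in ranked[:top_k]:
            correct += 1
    return correct / len(labels)

# Three samples, three classes
scores = [[0.1, 0.7, 0.2],
          [0.5, 0.3, 0.2],
          [0.2, 0.3, 0.5]]
labels = [1, 1, 0]

print(top_k_accuracy(scores, labels, top_k=1))  # 1 of 3 samples is correct
print(top_k_accuracy(scores, labels, top_k=2))  # 2 of 3 samples are correct
```

Top-5 accuracy is always at least as high as top-1, because any prediction that counts under top-1 also counts under top-5.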

In [None]:
#@title Launch a New Experiment

#@markdown ---
#@markdown #### Select a distillation method:
distiller_name = "logits" #@param ["logits", "features", "relations", "logits-DWA", "features-DWA"] {type:"string"}
#@markdown ---
#@markdown #### Select a model architecture:
model_name = "resnet18" #@param ["resnet18", "resnet34"] {type:"string"}
#@markdown ---
#@markdown #### Select the run number to determine the seed to use:
run_number = 0 #@param [0, 1, 2] {type:"raw"}
#@markdown ---
#@markdown #### Enter the name of the Weights & Biases project in which to log results:
project_name = "distillation-experiments" #@param {type:"string"}
#@markdown ---
#@markdown #### Enter the number of epochs:
num_epochs = 25 #@param
#@markdown ---
#@markdown #### Choose whether students should be initialised with weights pretrained on ImageNet:
pretrained = False #@param {type:"boolean"}

new_distillation_run(distiller_name, model_name, run_number, project_name, num_epochs, pretrained=pretrained)