<a href="https://colab.research.google.com/github/melanAm/Knowledge-Distillation/blob/teacher/KnowledgeDistillation_MLP_teacher_mnist_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook implements the paper "Distilling the Knowledge in a Neural Network": preliminary experiments on MNIST, teacher network training code.

In [1]:
#import required packages
import numpy as np
import math
import torch
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import SGD, lr_scheduler
import matplotlib.pyplot as plt
import os
import copy
import json
import time

In [19]:
# Hyper-parameters shared by the cells below
batch_size = 100     # mini-batch size handed to the DataLoaders
num_epochs = 100     # number of passes over the training set
init_lr = 0.01       # starting learning rate for SGD
momentum = 0.9       # SGD momentum coefficient
num_class = 10       # number of target classes
l = 15.0             # maximum squared length of each neuron's weight vector
random_seed = 42     # seed value kept alongside the other settings

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [3]:
# Make sure the dataset directory exists at the absolute path that is checked and later
# used as the dataset root. Creating the relative 'dataset' directory instead would only
# land in the right place when the working directory happens to be /content.
# exist_ok=True makes the cell safe to re-run.
os.makedirs('/content/dataset', exist_ok=True)

# Dataset

In [4]:
# Load MNIST: the training split is augmented with a padded random crop,
# the validation split (train=False) is only converted and normalised.
path = '/content/dataset'
mnist_mean, mnist_std = (0.1307,), (0.3081,)

train_transform = transforms.Compose([
    transforms.RandomCrop(size=(28,28), padding=(2,)),
    transforms.ToTensor(),
    transforms.Normalize(mnist_mean, mnist_std),
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mnist_mean, mnist_std),
])

train_dataset = datasets.MNIST(root=path, train=True, download=True, transform=train_transform)
val_dataset = datasets.MNIST(root=path, train=False, download=True, transform=val_transform)

# Settings common to both loaders; only the shuffling differs between them
loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=torch.cuda.is_available())
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /content/dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.1MB/s]


Extracting /content/dataset/MNIST/raw/train-images-idx3-ubyte.gz to /content/dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /content/dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 482kB/s]


Extracting /content/dataset/MNIST/raw/train-labels-idx1-ubyte.gz to /content/dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /content/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.42MB/s]


Extracting /content/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /content/dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 9.95MB/s]

Extracting /content/dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/dataset/MNIST/raw






# Neural Network

In [5]:
class TeacherNet(nn.Module):
  """Fully connected teacher network: 784 -> 1200 -> 1200 -> 10.

  The two hidden layers use ReLU. Dropout is applied to the flattened input
  (p=0.2) and after each hidden activation (p=0.5), and is only active while
  the module is in training mode. `forward` returns the raw output of the
  last linear layer (no softmax).
  """

  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(784, 1200, bias=True)
    self.fc2 = nn.Linear(1200, 1200, bias=True)
    self.layer_out = nn.Linear(1200, 10, bias=True)
    # Ordered handle on the linear layers; read by initialize() and by code
    # outside the class that iterates over `model.layers`.
    self.layers = [self.fc1, self.fc2, self.layer_out]
    self.initialize()

  def initialize(self):
    """Re-initialise every linear layer: He-normal weights, zero biases."""
    for linear in self.layers:
      nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')  # He initialization
      nn.init.constant_(linear.bias, val=0.0)

  def forward(self,x):
    """Flatten `x` to rows of 784 values and return the 10 output logits per row."""
    flat = x.view(-1, 784)
    hidden = F.dropout(flat, p=0.2, training=self.training)
    for fc in (self.fc1, self.fc2):
      hidden = F.dropout(F.relu(fc(hidden)), p=0.5, training=self.training)
    return self.layer_out(hidden)

In [24]:
use_gpu = True
def reproducibilitySeed(seed=42):
    """Seed the PyTorch and NumPy global random generators for repeatable runs.

    Args:
        seed: Integer used for both `torch.manual_seed` and `np.random.seed`.
            Defaults to 42, the value that used to be hard-coded twice here.

    When the module-level `use_gpu` flag is set, cuDNN is additionally switched
    to deterministic mode and its auto-tuner (benchmark mode) is disabled.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

reproducibilitySeed()

In [25]:
# Build the teacher network, then move its parameters to the selected device
teacher = TeacherNet()
teacher = teacher.to(device)

# Training

In [8]:
# Classification loss applied to the network's raw logits (default mean reduction)
criterion = torch.nn.CrossEntropyLoss()

In [26]:
# SGD with momentum over every teacher parameter
optimizer = SGD(teacher.parameters(), lr=init_lr, momentum=momentum)

# Multiply the learning rate by 0.1 once the metric given to scheduler.step() has stopped
# decreasing (relative threshold 0.001) for more than `patience` epochs.
# An alternative that was tried: lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=10,
    threshold=0.001,
    threshold_mode='rel',
)

In [10]:
# Directory that receives the training log and the saved teacher weights.
# makedirs(exist_ok=True) is safe to re-run, creates missing parents, and raises
# if the path exists but is not a directory instead of silently skipping creation.
os.makedirs('/content/teacher', exist_ok=True)

In [11]:
def log(t):
    """Append `t` to the teacher log file as one JSON line and echo it to stdout.

    Args:
        t: JSON-serialisable object (the training loop passes a dict of metrics).

    Each call appends a line of the form `json_stats: <json>` to
    /content/teacher/KnowledgeDistillation_MLP_teacher_mnist_v2.txt.
    """
    log_path = os.path.join('/content/teacher', 'KnowledgeDistillation_MLP_teacher_mnist_v2.txt')
    # Append mode: entries from earlier runs are kept
    with open(log_path, 'a') as handle:
        payload = json.dumps(t)
        handle.write('json_stats: ' + payload + '\n')
    print(t)

In [27]:
def _evaluate(model, criterion, loader):
  """Return (summed per-sample loss, number of correct predictions) of `model` on `loader`.

  The model is put in eval mode and the pass runs under torch.no_grad(), so no
  autograd graph is built. Batches are moved to the module-level `device`.
  """
  loss_sum = 0.
  correct = 0
  model.eval()
  with torch.no_grad():
    for x,y in loader:
      x = x.to(device)
      y = y.to(device)
      y_hat = model(x)
      # criterion returns the batch mean; scale back to a per-sample sum
      loss_sum += criterion(y_hat,y).item()*x.size(0)
      correct += (torch.argmax(y_hat,dim=1)==y).sum().item()
  return loss_sum, correct


def train(model,criterion,optimizer,sheduler,num_epochs):
  """Train `model` and return the state dict with the most correct validation predictions.

  Reads the module-level `train_loader`, `val_loader`, `train_dataset`,
  `val_dataset`, `device` and `l`, and reports per-epoch metrics through `log`.

  Args:
    model: network to train; must expose `model.layers`, an iterable of layers
      whose `weight` rows are rescaled after every optimizer step.
    criterion: loss called as `criterion(logits, targets)`.
    optimizer: optimizer already bound to the model's parameters.
    sheduler: learning-rate scheduler stepped once per epoch with the validation
      loss (parameter name kept as-is for backward compatibility).
    num_epochs: number of epochs to run.

  Returns:
    A deep copy of `model.state_dict()` taken at the epoch with the highest
    number of correct validation predictions (the initial weights if no epoch
    gets any prediction right). The model itself is left at its last-epoch state.
  """
  # Use the scheduler that was passed in rather than whatever global happens to
  # be called `scheduler`.
  scheduler = sheduler
  bestparams = copy.deepcopy(model.state_dict())
  best_correct = 0
  # `l` bounds the squared row norm, so the norm itself is bounded by sqrt(l)
  max_norm = math.sqrt(l)
  time_start = time.time()
  for epoch in range(num_epochs):
    #training phase
    model.train()
    for x,y in train_loader:
      x = x.to(device)
      y = y.to(device)
      optimizer.zero_grad()
      y_hat = model(x)
      loss = criterion(y_hat,y)
      loss.backward()
      optimizer.step()
      # Max-norm constraint: shrink any weight row whose norm exceeds max_norm
      # back onto that bound; rows already inside it are multiplied by 1.
      with torch.no_grad():
        for layer in model.layers:
          norm = torch.linalg.vector_norm(layer.weight,dim=1,keepdim=True)
          mult_coef = norm.reciprocal_().mul_(max_norm).clamp_(max=1.0)
          layer.weight.mul_(mult_coef)

    # Second pass over the training data in eval mode (dropout off) for metrics
    train_loss, train_correct = _evaluate(model, criterion, train_loader)
    train_loss = train_loss/len(train_dataset)
    train_error = len(train_dataset)-train_correct
    train_acc = train_correct/len(train_dataset)

    #validation phase
    val_loss, val_correct = _evaluate(model, criterion, val_loader)
    val_loss = val_loss/len(val_dataset)
    val_error = len(val_dataset)-val_correct
    val_acc = val_correct/len(val_dataset)

    scheduler.step(val_loss)
    if epoch%10==0:
      print(scheduler.get_last_lr())

    if val_correct > best_correct:
      bestparams = copy.deepcopy(model.state_dict())
      best_correct = val_correct

    log({
            "epoch": epoch+1,
            "train_loss": train_loss,
            "train_correct": train_correct,
            "train_error": train_error,
            "train_acc" : train_acc,
            "val_loss": val_loss,
            "val_correct": val_correct,
            "val_error": val_error,
            "val_acc" : val_acc,
           })
  time_fin = time.time()-time_start
  print('time: {}'.format(time_fin))
  return bestparams

In [28]:
# Run the training loop; train() returns the state dict of the epoch with the
# most correct validation predictions.
optimized_params = train(
    teacher,
    criterion,
    optimizer,
    scheduler,
    num_epochs=num_epochs,
)

[0.01]
{'epoch': 1, 'train_loss': 0.3117865451176961, 'train_correct': 54795, 'train_error': 5205, 'train_acc': 0.91325, 'val_loss': 0.2238468413800001, 'val_correct': 9479, 'val_error': 521, 'val_acc': 0.9479}
{'epoch': 2, 'train_loss': 0.2175270348911484, 'train_correct': 56367, 'train_error': 3633, 'train_acc': 0.93945, 'val_loss': 0.16201283716596662, 'val_correct': 9609, 'val_error': 391, 'val_acc': 0.9609}
{'epoch': 3, 'train_loss': 0.1826224053154389, 'train_correct': 56955, 'train_error': 3045, 'train_acc': 0.94925, 'val_loss': 0.1360715241637081, 'val_correct': 9673, 'val_error': 327, 'val_acc': 0.9673}
{'epoch': 4, 'train_loss': 0.159359254638354, 'train_correct': 57311, 'train_error': 2689, 'train_acc': 0.9551833333333334, 'val_loss': 0.12191277358215302, 'val_correct': 9700, 'val_error': 300, 'val_acc': 0.97}
{'epoch': 5, 'train_loss': 0.13997789758568008, 'train_correct': 57577, 'train_error': 2423, 'train_acc': 0.9596166666666667, 'val_loss': 0.10481890305643901, 'val_cor

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a7a080e68c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a7a080e68c0>if w.is_alive():

Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
      File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    assert self._parent_pid == os.getpid(), 'can only test a child process'self._shutdown_workers()

AssertionError:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    can only test a child process
if w.is_alive():
  File "/usr/lib/

{'epoch': 75, 'train_loss': 0.03546949119462321, 'train_correct': 59366, 'train_error': 634, 'train_acc': 0.9894333333333334, 'val_loss': 0.033433804888918534, 'val_correct': 9908, 'val_error': 92, 'val_acc': 0.9908}


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a7a080e68c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a7a080e68c0>  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive

    Traceback (most recent call last):
assert self._parent_pid == os.getpid(), 'can only test a child process'
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
AssertionError    : self._shutdown_workers()can only test a child process

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
Exception ignored in:     if w.is_aliv

{'epoch': 76, 'train_loss': 0.03624433673763027, 'train_correct': 59390, 'train_error': 610, 'train_acc': 0.9898333333333333, 'val_loss': 0.03430358037614496, 'val_correct': 9906, 'val_error': 94, 'val_acc': 0.9906}
{'epoch': 77, 'train_loss': 0.03526644659539064, 'train_correct': 59400, 'train_error': 600, 'train_acc': 0.99, 'val_loss': 0.03318826966002234, 'val_correct': 9904, 'val_error': 96, 'val_acc': 0.9904}
{'epoch': 78, 'train_loss': 0.03577073835923026, 'train_correct': 59358, 'train_error': 642, 'train_acc': 0.9893, 'val_loss': 0.03399814252217766, 'val_correct': 9904, 'val_error': 96, 'val_acc': 0.9904}
{'epoch': 79, 'train_loss': 0.03418730824564894, 'train_correct': 59399, 'train_error': 601, 'train_acc': 0.9899833333333333, 'val_loss': 0.03372359765664441, 'val_correct': 9906, 'val_error': 94, 'val_acc': 0.9906}
{'epoch': 80, 'train_loss': 0.035308166408988956, 'train_correct': 59402, 'train_error': 598, 'train_acc': 0.9900333333333333, 'val_loss': 0.033181455769808965, '

In [29]:
# Persist the selected teacher weights in three layouts.

# 1) The state dict exactly as returned by train()
torch.save(optimized_params, '/content/teacher/KnowledgeDistillation_MLP_teacher_mnist_v2.pth.tar')

# 2) The underlying tensors (`.data`) wrapped under a single 'params' key
# 3) The same name -> tensor mapping without the wrapper
raw_tensors = {name: tensor.data for name, tensor in optimized_params.items()}
torch.save(dict(params=raw_tensors), os.path.join('/content/teacher', 'KnowledgeDistillation_MLP_teacher_mnist_v2.pt7'))
torch.save(raw_tensors, '/content/teacher/KnowledgeDistillation_MLP_teacher_mnist_v2')