<a href="https://colab.research.google.com/github/melanAm/Knowledge-Distillation/blob/teacher/KnowledgeDistillation_MLP_teacher_mnist_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This is the implementation of the paper "Distilling the Knowledge in a Neural Network", Preliminalry experiments on mnist, Teacher network training code

In [1]:
#import required packages
import numpy as np
import math
import torch
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import SGD, lr_scheduler
import matplotlib.pyplot as plt
import os
import copy
import json
import time

In [2]:
#hyper parameters
batch_size = 100
num_epochs = 100
init_lr = 0.01
momentum=0.9
num_class = 10
l = 15.0           #maximum square length of each neuron weight vector
random_seed = 42
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
if os.path.exists('/content/dataset')==False:
  os.mkdir('dataset')

#Dataset

In [4]:
#load mnist dataset
path = '/content/dataset'
train_transform = transforms.Compose([transforms.RandomCrop(size=(28,28),padding=(2,)),transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])
val_transform =  transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])
train_dataset = datasets.MNIST(root=path,train=True,download=True,transform=train_transform)
val_dataset = datasets.MNIST(root=path,train=False,download=True,transform=val_transform)
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=2,pin_memory=torch.cuda.is_available())
val_loader = DataLoader(val_dataset,batch_size=batch_size,shuffle=False,num_workers=2,pin_memory=torch.cuda.is_available())

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /content/dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 17.5MB/s]


Extracting /content/dataset/MNIST/raw/train-images-idx3-ubyte.gz to /content/dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /content/dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 482kB/s]


Extracting /content/dataset/MNIST/raw/train-labels-idx1-ubyte.gz to /content/dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /content/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.40MB/s]


Extracting /content/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /content/dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 3.13MB/s]

Extracting /content/dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/dataset/MNIST/raw






#Neural Network

In [5]:
class TeacherNet(nn.Module):
  def __init__(self):
    super(TeacherNet,self).__init__()
    self.fc1 = nn.Linear(in_features=784,out_features=1200,bias=True)
    self.fc2 = nn.Linear(in_features=1200,out_features=1200,bias=True)
    self.layer_out = nn.Linear(in_features=1200,out_features=10,bias=True)
    self.layers = [self.fc1,self.fc2,self.layer_out]
    self.initialize()

  def initialize(self):
    for layer in self.layers:
      torch.nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')  #He initialization
      torch.nn.init.constant_(layer.bias,val=0.0)

  def forward(self,x):
    x = x.view(-1,784)
    x = F.dropout(x,p=0.2,training=self.training)
    x = F.relu(self.fc1(x))
    x = F.dropout(x,p=0.5,training=self.training)
    x = F.relu(self.fc2(x))
    x = F.dropout(x,p=0.5,training=self.training)
    x = self.layer_out(x)
    return x

In [6]:
torch.manual_seed(random_seed)
teacher = TeacherNet().to(device)

#Training

In [7]:
criterion = nn.CrossEntropyLoss()

In [8]:
optimizer = optim.SGD(teacher.parameters(),lr=init_lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.001, threshold_mode='rel')

In [9]:
if os.path.exists('/content/teacher')==False:
  os.mkdir('/content/teacher')

In [10]:
def log(t):
        logname = os.path.join('/content/teacher', 'KnowledgeDistillation_MLP_teacher_mnist_v2.txt')
        with open(logname, 'a') as f:
            f.write('json_stats: ' + json.dumps(t) + '\n')
        print(t)

In [None]:
def train(model,criterion,optimizer,sheduler,num_epochs):
  bestparams = copy.deepcopy(model.state_dict())
  best_correct = 0
  time_start = time.time()
  for epoch in range(num_epochs):
    train_loss = 0.
    val_loss = 0.
    train_correct = 0
    val_correct = 0
    #training phase
    model.train()
    for x,y in train_loader:
      x = x.to(device)
      y = y.to(device)
      optimizer.zero_grad()
      y_hat = model.forward(x)
      loss = criterion(y_hat,y)
      loss.backward()
      optimizer.step()
      with torch.no_grad():
        for layer in model.layers:
            norm = torch.linalg.vector_norm(p,dim=1,keepdim=True)
            mult_coef =  norm.reciprocal_().mul_(np.sqrt(l)).clamp_(max=1.0)
            layer.weight.mul_(mult_coef)

    model.eval()
    for x,y in train_loader:
      x = x.to(device)
      y = y.to(device)
      y_hat = model.forward(x)
      loss = criterion(y_hat,y)
      train_loss+= loss.item()*x.size(0)
      train_correct += (((torch.argmax(y_hat,dim=1)==y)).sum()).item()
    train_loss = train_loss/len(train_dataset)
    train_error = len(train_dataset)-train_correct
    train_acc = train_correct/len(train_dataset)

    #validation phase
    model.eval()
    for x,y in val_loader:
      x = x.to(device)
      y = y.to(device)
      y_hat = model.forward(x)
      loss = criterion(y_hat,y)
      val_loss += loss.item()*x.size(0)
      val_correct += (((torch.max(y_hat,dim=1)[1]==y)).sum()).item()
    val_loss = val_loss/len(val_dataset)
    val_error = len(val_dataset)-val_correct
    val_acc = val_correct/len(val_dataset)

    scheduler.step(val_loss)
    if epoch%10==0:
      scheduler.get_last_lr()

    if val_correct > best_correct:
      bestparams = copy.deepcopy(model.state_dict())
      best_correct = val_correct

    log({
            "epoch": epoch+1,
            "train_loss": train_loss,
            "train_correct": train_correct,
            "train_error": train_error,
            "train_acc" : train_acc,
            "val_loss": val_loss,
            "val_correct": val_correct,
            "val_error": val_error,
            "val_acc" : val_acc,
           })
  time_fin = time.time()-time_start
  print('time: {}'.format(time_fin))
  return bestparams

In [None]:
optimized_params = train(teacher,criterion,optimizer,scheduler,num_epochs=num_epochs)

{'epoch': 1, 'train_loss': 0.8025762234876553, 'train_correct': 44393, 'train_error': 15607, 'train_acc': 0.7398833333333333, 'val_loss': 0.16632831441238521, 'val_correct': 9539, 'val_error': 461, 'val_acc': 0.9539}
{'epoch': 2, 'train_loss': 0.30193064662317437, 'train_correct': 54360, 'train_error': 5640, 'train_acc': 0.906, 'val_loss': 0.1160104529792443, 'val_correct': 9654, 'val_error': 346, 'val_acc': 0.9654}
{'epoch': 3, 'train_loss': 0.2386501281087597, 'train_correct': 55557, 'train_error': 4443, 'train_acc': 0.92595, 'val_loss': 0.08764601001748815, 'val_correct': 9740, 'val_error': 260, 'val_acc': 0.974}
{'epoch': 4, 'train_loss': 0.2074875088594854, 'train_correct': 56122, 'train_error': 3878, 'train_acc': 0.9353666666666667, 'val_loss': 0.07507639348390512, 'val_correct': 9763, 'val_error': 237, 'val_acc': 0.9763}
{'epoch': 5, 'train_loss': 0.18827150960142414, 'train_correct': 56543, 'train_error': 3457, 'train_acc': 0.9423833333333334, 'val_loss': 0.07437455941690133, '

In [None]:
torch.save(optimized_params, '/content/teacher/KnowledgeDistillation_MLP_teacher_mnist_v2.pth.tar')
torch.save(dict(params={k: v.data for k, v in optimized_params.items()}),os.path.join('/content/teacher', 'KnowledgeDistillation_MLP_teacher_mnist_v2.pt7'))
torch.save({k: v.data for k, v in optimized_params.items()},'/content/teacher/KnowledgeDistillation_MLP_teacher_mnist_v2')