# Distilling the Knowledge in a Neural Network

Geoffrey Hinton, Oriol Vinyals, Jeff Dean (2015)

*Paper*: [arXiv link](https://arxiv.org/pdf/1503.02531.pdf)

A very simple way to improve the performance of almost any machine learning algorithm is to train many different models on the same data and then to average their predictions. Unfortunately, making predictions using a whole ensemble of models is cumbersome and may be too computationally expensive to allow deployment to a large number of users, especially if the individual models are large neural nets. Caruana and his collaborators have shown that it is possible to compress the knowledge in an ensemble into a single model which is much easier to deploy and we develop this approach further using a different compression technique. We achieve some surprising results on MNIST and we show that we can significantly improve the acoustic model of a heavily used commercial system by distilling the knowledge in an ensemble of models into a single model. We also introduce a new type of ensemble composed of one or more full models and many specialist models which learn to distinguish fine-grained classes that the full models confuse. Unlike a mixture of experts, these specialist models can be trained rapidly and in parallel. 

In [0]:
import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

In [0]:
batch_size = 128
mnist_image_shape = (28, 28)

train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               transforms.RandomCrop(mnist_image_shape, 2),
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size, shuffle=True)

val_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size, shuffle=True)

In [19]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Working on:", device)

Working on: cuda:0


In [0]:
class TeacherNetwork(nn.Module):

    def __init__(self, dropout1=0.5, dropout2=0.5):
        super(TeacherNetwork, self).__init__()
        self.l1 = nn.Linear(28 * 28, 1200)
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(p=dropout1)
        self.l2 = nn.Linear(1200, 1200)
        self.dropout2 = nn.Dropout(p=dropout2)
        self.l3 = nn.Linear(1200, 10)
    
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.l1(x)
        x = self.relu(x)
        x = self.dropout1(x)
        x = self.l2(x)
        x = self.relu(x)
        x = self.dropout2(x)
        x = self.l3(x)
        return x

In [0]:
def total_loss_accuracy(network, dataset_loader, device, lossFn=None):
    accuracy = 0.0
    loss = 0.0
    dataset_size = 0
    for j, Data in enumerate(dataset_loader, 0):
    	X, y = Data
    	X = X.to(device)
    	y = y.to(device)
    	with torch.no_grad():
    		pred = network(X)
    		if lossFn is not None:
    			loss += lossFn(pred, y) * y.shape[0]
    		accuracy += torch.sum(torch.argmax(pred, dim=1) == y).item()
    	dataset_size += y.shape[0]
    loss = loss / dataset_size
    accuracy = accuracy / dataset_size
    return loss, accuracy

In [22]:
teacherModel = TeacherNetwork().to(device)
print(teacherModel)

TeacherNetwork(
  (l1): Linear(in_features=784, out_features=1200, bias=True)
  (relu): ReLU()
  (dropout1): Dropout(p=0.5, inplace=False)
  (l2): Linear(in_features=1200, out_features=1200, bias=True)
  (dropout2): Dropout(p=0.5, inplace=False)
  (l3): Linear(in_features=1200, out_features=10, bias=True)
)


In [0]:
train_loss_main = []
train_acc = []
train_loss = []
val_acc = []
val_loss = []

In [24]:
train_epochs = 40
print_interval = 100
lr = 5e-3

lossFn = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacherModel.parameters(), lr=lr)

for epoch in range(train_epochs):
    for batch_idx, (features, labels) in enumerate(train_loader):
        scores = teacherModel(features.to(device))
        loss = lossFn(scores, labels.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss_main.append(loss.item())

    print("Epoch: ", epoch, "->")
    loss1, acc = total_loss_accuracy(teacherModel, train_loader, device, lossFn)
    print("Accuracy Training:", acc, end="\t")
    loss, acc = total_loss_accuracy(teacherModel, val_loader, device, lossFn)
    print("Accuracy Validation:", acc)
    print("Loss Training:", loss1.item(), "\tLoss Validation:", loss.item())

print("Training Complete for ", train_epochs, " epochs.")

Epoch:  0 ->
Accuracy Training: 0.7708833333333334	Accuracy Validation: 0.8135
Loss Training: 0.8115366101264954 	Loss Validation: 0.6463208794593811
Epoch:  1 ->
Accuracy Training: 0.7841666666666667	Accuracy Validation: 0.8293
Loss Training: 0.8391323089599609 	Loss Validation: 0.6539083123207092
Epoch:  2 ->
Accuracy Training: 0.7935	Accuracy Validation: 0.8319
Loss Training: 0.7673523426055908 	Loss Validation: 0.5969228744506836
Epoch:  3 ->
Accuracy Training: 0.81095	Accuracy Validation: 0.8538
Loss Training: 0.7228739261627197 	Loss Validation: 0.5264754295349121
Epoch:  4 ->
Accuracy Training: 0.8152666666666667	Accuracy Validation: 0.8587
Loss Training: 0.6928107142448425 	Loss Validation: 0.5180292725563049
Epoch:  5 ->
Accuracy Training: 0.8268666666666666	Accuracy Validation: 0.8677
Loss Training: 0.6879979968070984 	Loss Validation: 0.5080196857452393
Epoch:  6 ->
Accuracy Training: 0.8128833333333333	Accuracy Validation: 0.8571
Loss Training: 0.7133893370628357 	Loss Vali

In [25]:
torch.save(teacherModel.state_dict(), './teacherModel')
!ls

sample_data  teacherModel


In [0]:
LOAD = False
if LOAD:
    teacherModel = TeacherNetwork()
    teacherModel.load_state_dict(torch.load('./teacherModel', map_location=device))  # Choose whatever GPU device number you want
    teacherModel.to(device)

In [0]:
class StudentNetwork(nn.Module):
    def __init__(self, h1=64):
        super(StudentNetwork, self).__init__()
        self.l1 = nn.Linear(28 * 28, h1)
        self.relu = nn.ReLU()
        self.l3 = nn.Linear(h1, 10)
    
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.l1(x)
        x = self.relu(x)
        x = self.l3(x)
        return x

In [0]:
stdModel = StudentNetwork().to(device)
lr = 1e-3
optimizer = optim.Adam(stdModel.parameters(), lr=lr)

In [0]:
softmax_op = nn.Softmax(dim=1)
mseloss_fn = nn.MSELoss()

def my_loss(scores, targets, temperature = 3):
    soft_pred = softmax_op(scores / temperature)
    soft_targets = softmax_op(targets / temperature)
    loss = mseloss_fn(soft_pred, soft_targets)
    return loss

In [0]:
std_train = []

In [37]:
train_epochs = 5
for epoch in range(train_epochs):
    for batch_idx, (features, labels) in enumerate(train_loader):
        scores = stdModel(features.to(device))
        targets = teacherModel(features.to(device))
        loss = my_loss(scores, targets.to(device), temperature=3)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        std_train.append(loss.item())

    print("Epoch: ", epoch, "->")
    loss1, acc = total_loss_accuracy(teacherModel, train_loader, device, lossFn)
    print("Accuracy Training:", acc, end="\t")
    loss, acc = total_loss_accuracy(teacherModel, val_loader, device, lossFn)
    print("Accuracy Validation:", acc)
    print("Loss Training:", loss1.item(), "\tLoss Validation:", loss.item())


Epoch:  0 ->
Accuracy Training: 0.8542166666666666	Accuracy Validation: 0.8871
Loss Training: 0.5685402154922485 	Loss Validation: 0.4399680495262146
Epoch:  1 ->
Accuracy Training: 0.8540333333333333	Accuracy Validation: 0.8856
Loss Training: 0.5694134831428528 	Loss Validation: 0.416507363319397
Epoch:  2 ->
Accuracy Training: 0.8537	Accuracy Validation: 0.8923
Loss Training: 0.5812897682189941 	Loss Validation: 0.4282498359680176
Epoch:  3 ->
Accuracy Training: 0.8537666666666667	Accuracy Validation: 0.8888
Loss Training: 0.577438473701477 	Loss Validation: 0.42661532759666443
Epoch:  4 ->
Accuracy Training: 0.8533333333333334	Accuracy Validation: 0.889
Loss Training: 0.5899596214294434 	Loss Validation: 0.43650320172309875


In [38]:
std_train

[0.046665437519550323,
 0.04666648432612419,
 0.04986218735575676,
 0.05023883655667305,
 0.0514102540910244,
 0.04233106970787048,
 0.045577917248010635,
 0.047160785645246506,
 0.049749959260225296,
 0.046534694731235504,
 0.04332394525408745,
 0.04724808782339096,
 0.04258120805025101,
 0.045622196048498154,
 0.04186750575900078,
 0.04088243842124939,
 0.03994714468717575,
 0.03857836127281189,
 0.04005708917975426,
 0.036286093294620514,
 0.034500736743211746,
 0.040212664753198624,
 0.04322151467204094,
 0.03739115968346596,
 0.036340679973363876,
 0.03195996209979057,
 0.03470965102314949,
 0.03153304383158684,
 0.03172901272773743,
 0.03061499632894993,
 0.027961110696196556,
 0.02770097926259041,
 0.028711099177598953,
 0.03439578041434288,
 0.02953920140862465,
 0.02793644182384014,
 0.030018797144293785,
 0.026001622900366783,
 0.02820567600429058,
 0.02971387840807438,
 0.028220875188708305,
 0.030615806579589844,
 0.028194312006235123,
 0.028715873137116432,
 0.022963151335