pyTorch implementation of the knowledge distillation to train a compact network to perform much better than it would have by "normal training"

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

plt.ion()

In [3]:
from six.moves import urllib
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader = torch.utils.data.DataLoader(datasets.MNIST(root='.', train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=64, shuffle=True, num_workers=2)

test_loader = torch.utils.data.DataLoader(datasets.MNIST(root='.', train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=64, shuffle=True, num_workers=2)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [47]:
class teacher_net(nn.Module):
  def __init__(self):
    super(teacher_net, self).__init__()
    self.teacher = nn.Sequential(
        #28
        nn.BatchNorm2d(1),
        nn.Conv2d(1, 10, kernel_size=3, stride=1),
        nn.ReLU(True),
        nn.Dropout2d(p=0.1),
        #26
        nn.BatchNorm2d(10),
        nn.Conv2d(10, 20, kernel_size=3, stride=1),
        nn.ReLU(True),
        nn.Dropout2d(p=0.1),
        #24
        nn.BatchNorm2d(20),
        nn.Conv2d(20, 30, kernel_size=3, stride=1),
        nn.ReLU(True),
        nn.Dropout2d(p=0.1),
        #22
        nn.BatchNorm2d(30),
        nn.Conv2d(30, 50, kernel_size=3, stride=1),
        nn.ReLU(True),
        nn.Dropout2d(p=0.1),
        #20
        nn.BatchNorm2d(50),
        nn.Conv2d(50, 20, kernel_size=1, stride=1),
        nn.ReLU(True),
        nn.Dropout2d(p=0.1),
        #20
        nn.BatchNorm2d(20),
        nn.Conv2d(20, 30, kernel_size=3, stride=1),
        nn.ReLU(True),
        nn.Dropout2d(p=0.1),
        #18
        nn.BatchNorm2d(30),
        nn.Conv2d(30, 40, kernel_size=3, stride=1),
        nn.ReLU(True),
        nn.Dropout2d(p=0.1),
        #16
        nn.BatchNorm2d(40),
        nn.Conv2d(40, 50, kernel_size=3, stride=1),
        nn.ReLU(True),
        nn.Dropout2d(p=0.1),
        #14
        nn.BatchNorm2d(50),
        nn.Conv2d(50, 20, kernel_size=1, stride=1),
        nn.ReLU(True),
        nn.Dropout2d(p=0.1),
        #14
        nn.BatchNorm2d(20),
        nn.Conv2d(20, 30, kernel_size=3, stride=1),
        nn.ReLU(True),
        nn.Dropout2d(p=0.1),
        #12
        nn.BatchNorm2d(30),
        nn.Conv2d(30, 50, kernel_size=3, stride=1),
        nn.ReLU(True),
        nn.Dropout2d(p=0.1),
        #10
        nn.Flatten(),
        nn.Dropout(p=0.1),
        nn.Linear(5000, 25),
        nn.ReLU(True),
        nn.Linear(25, 10),
    )

  def forward(self, x):
    x = self.teacher(x)
    return F.log_softmax(x, dim=1)

class student_net(nn.Module):
  def __init__(self):
    super(student_net, self).__init__()
    self.student = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 10),
        nn.ReLU(True)
    )
   
  def forward(self, y):
    y = self.student(y)
    return F.log_softmax(y, dim=1)

def init_weights1(m):
    if isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

def init_weights2(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

teacher_net().apply(init_weights1)
teacher_net().apply(init_weights2)

teacher_model = teacher_net().to(device)
student_model = student_net().to(device)



In [48]:
optimizer_teacher = optim.SGD(teacher_model.parameters(), lr=0.01, momentum=0.8)

def train_teacher(epoch):
    teacher_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer_teacher.zero_grad()
        output = teacher_model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer_teacher.step()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test_teacher():
    with torch.no_grad():
        teacher_model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = teacher_model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            prediction = output.max(1, keepdim=True)[1]
            correct += prediction.eq(target.view_as(prediction)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\ntest set: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)\n'
              .format(test_loss, correct, len(test_loader.dataset),
                      100. * correct / len(test_loader.dataset)))

In [49]:
optimizer_student = optim.SGD(student_model.parameters(), lr=0.01, momentum=0.8)

def train_student(epoch):
    student_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer_student.zero_grad()
        output = student_model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer_student.step()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test_student():
    with torch.no_grad():
        student_model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = student_model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            prediction = output.max(1, keepdim=True)[1]
            correct += prediction.eq(target.view_as(prediction)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\ntest set: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)\n'
              .format(test_loss, correct, len(test_loader.dataset),
                      100. * correct / len(test_loader.dataset)))

In [37]:
for epoch in range(1, 4 + 1):
    train_teacher(epoch)
    test_teacher()


test set: average loss: 0.0718, accuracy: 9770/10000 (98%)


test set: average loss: 0.0570, accuracy: 9803/10000 (98%)


test set: average loss: 0.0473, accuracy: 9855/10000 (99%)


test set: average loss: 0.0409, accuracy: 9866/10000 (99%)



In [50]:
##Student direct training  for comparision
student_net().apply(init_weights2)

for epoch in range(1, 3 + 1):
    train_student(epoch)
    test_student()




test set: average loss: 0.4819, accuracy: 8367/10000 (84%)


test set: average loss: 0.4681, accuracy: 8409/10000 (84%)


test set: average loss: 0.4688, accuracy: 8381/10000 (84%)



In [51]:
torch.__version__

# define "soft" cross-entropy with pytorch tensor operations
def softXEnt (input, target):
    target = torch.exp(target)
    logprobs = input
    #print(logprobs.shape, target.shape, input.shape[0])
    k = torch.square(target - logprobs)
    k = torch.sum(k, 1)
    k = torch.mean(k)
    k = k.item()
    return k

# class SoftCrossEntropyLoss(nn.CrossEntropyLoss):
#   def __init__(self, dim=-1):
#     super(SoftCrossEntropyLoss, self).__init__()
#     self.dim = dim
#   def forward(self, pred, target):
#     pred = pred.log_softmax(dim=self.dim)
#     return torch.mean(torch.sum(-target * pred, dim=self.dim))



In [52]:
student_net().apply(init_weights1)
student_net().apply(init_weights2)



student_net(
  (student): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=10, bias=True)
    (2): ReLU(inplace=True)
  )
)

In [55]:
student_model = student_net().to(device)

optimizer_distill = optim.SGD(student_model.parameters(), lr=0.01, momentum=0.8)

def distillation(epoch):
  for batch_idx, (data, target) in enumerate(train_loader):
        data, target1 = data.to(device), target.to(device)
        target2 = teacher_model(data)
        optimizer_distill.zero_grad()
        output = student_model(data)
        loss = 0.5*softXEnt(output, target2) + F.nll_loss(output, target1)
        loss.backward()
        optimizer_distill.step()
        if batch_idx % 500 == 0:
            print('Distill Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [56]:
#Training by distillation

for epoch in range(1, 3 + 1):
    distillation(epoch)
    test_student()


test set: average loss: 0.2871, accuracy: 9185/10000 (92%)


test set: average loss: 0.2788, accuracy: 9193/10000 (92%)


test set: average loss: 0.2762, accuracy: 9209/10000 (92%)



Hence we obsereve quite significant improvement in performance of the small student network when trained through the knowledge distillation method, thus verifying the results of the original paper "Distilling the Knowledge in a Neural Network"