In [2]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchvision import datasets, transforms

In [3]:
epochs = 10
lr = 1e-3
batch_size = 100

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset_train = datasets.CIFAR10('data', download=True, train=True, transform=transform)
num_train = int(len(dataset_train) * 0.95)
dataset_train, dataset_valid = random_split(dataset_train, [num_train, len(dataset_train) - num_train])
dataset_test = datasets.CIFAR10('data', download=True, train=False, transform=transform)
loader_train = DataLoader(dataset_train, batch_size=batch_size, num_workers=0, shuffle=True)
loader_valid = DataLoader(dataset_valid, batch_size=batch_size, num_workers=0, shuffle=True)
loader_test = DataLoader(dataset_test, batch_size=batch_size, num_workers=0, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data\cifar-10-python.tar.gz


HBox(children=(IntProgress(value=0, max=170498071), HTML(value='')))


Extracting data\cifar-10-python.tar.gz to data
Files already downloaded and verified


In [5]:
def train(dataloader, model, orig_net, optimizer, criterion, a):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 100
    total_loss = 0.0

    for idx, (pic, real_label) in enumerate(dataloader):
        optimizer.zero_grad()
        predited_label = model(pic)
        loss = None
        if orig_net != None:
            label = orig_net(pic)
            loss = criterion(predited_label, label, real_label, a)
        else:
            loss = criterion(predited_label, real_label)
        total_loss += loss
        loss.backward()
        optimizer.step()
        total_acc += (predited_label.argmax(1) == real_label).sum().item()
        total_count += real_label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print('{:5d}/{:5d} batches '
                  '| accuracy {:8.3f} | loss {:8.3f} '.format(idx, len(dataloader),
                                              total_acc / total_count, total_loss))
            total_acc, total_count = 0, 0
            total_loss = 0.0

In [6]:
def evaluate(dataloader, model):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (pic, real_label) in enumerate(dataloader):
            predited_label = model(pic)
            total_acc += (predited_label.argmax(1) == real_label).sum().item()
            total_count += real_label.size(0)
    return total_acc / total_count

In [7]:
def run(model, orig_net, optimizer, criterion, scheduler, a, epochs):
    total_accu = None
    for epoch in range(1, epochs + 1):
        train(loader_train, model, orig_net, optimizer, criterion, a)
        accu_val = evaluate(loader_valid, model)
        if total_accu is not None and total_accu > accu_val:
          scheduler.step()
        else:
           total_accu = accu_val
        print('-' * 56)
        print('end of epoch {:3d} | valid accuracy {:8.3f} '.format(epoch, accu_val))
        print('-' * 56)
    print('Checking the results of test dataset.')
    accu_test = evaluate(loader_test, model)
    print('test accuracy {:8.3f}'.format(accu_test))

First we train the teacher net. We stop the training early as suggested in the article ["On the Efficacy of Knowledge Distillation"](https://arxiv.org/abs/1910.01348).

In [8]:
class OriginalNet(nn.Module):
    def __init__(self):
        super(OriginalNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.drop_out = nn.Dropout()
        self.fc1 = nn.Linear(8 * 8 * 64, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.drop_out(out)
        out = self.fc1(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out

In [9]:
orig_net = OriginalNet()
criterion_orig = nn.CrossEntropyLoss()
optimizer_orig = torch.optim.Adam(orig_net.parameters(), lr=lr)
scheduler_orig = torch.optim.lr_scheduler.StepLR(optimizer_orig, 1.0, gamma=0.1)
run(orig_net, None, optimizer_orig, criterion_orig, scheduler_orig, None, 3)

  100/  475 batches | accuracy    0.378 | loss  169.454 
  200/  475 batches | accuracy    0.483 | loss  141.003 
  300/  475 batches | accuracy    0.513 | loss  135.854 
  400/  475 batches | accuracy    0.540 | loss  126.297 
--------------------------------------------------------
end of epoch   1 | valid accuracy    0.613 
--------------------------------------------------------
  100/  475 batches | accuracy    0.587 | loss  115.748 
  200/  475 batches | accuracy    0.595 | loss  113.467 
  300/  475 batches | accuracy    0.614 | loss  109.469 
  400/  475 batches | accuracy    0.614 | loss  108.578 
--------------------------------------------------------
end of epoch   2 | valid accuracy    0.668 
--------------------------------------------------------
  100/  475 batches | accuracy    0.642 | loss  101.397 
  200/  475 batches | accuracy    0.648 | loss  101.955 
  300/  475 batches | accuracy    0.650 | loss   99.721 
  400/  475 batches | accuracy    0.657 | loss   98.224 


Then we train the student (without the help from the teacher) to see what accuracy it is able to achieve on its own.

In [10]:
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc1 = nn.Linear(32*16*16, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = out.reshape(out.size(0), -1)
        x = nn.functional.relu(self.fc1(out))
        x = self.fc2(x)
        return x

In [11]:
simple_net = SimpleNet()
criterion_simple = nn.CrossEntropyLoss()
optimizer_simple = torch.optim.Adam(simple_net.parameters(), lr=lr)
scheduler_simple = torch.optim.lr_scheduler.StepLR(optimizer_simple, 1.0, gamma=0.1)
run(simple_net, None, optimizer_simple, criterion_simple, scheduler_simple, None, epochs)

  100/  475 batches | accuracy    0.400 | loss  168.989 
  200/  475 batches | accuracy    0.501 | loss  138.641 
  300/  475 batches | accuracy    0.536 | loss  129.891 
  400/  475 batches | accuracy    0.562 | loss  122.358 
--------------------------------------------------------
end of epoch   1 | valid accuracy    0.615 
--------------------------------------------------------
  100/  475 batches | accuracy    0.631 | loss  105.487 
  200/  475 batches | accuracy    0.637 | loss  103.243 
  300/  475 batches | accuracy    0.649 | loss  100.185 
  400/  475 batches | accuracy    0.651 | loss   99.162 
--------------------------------------------------------
end of epoch   2 | valid accuracy    0.665 
--------------------------------------------------------
  100/  475 batches | accuracy    0.704 | loss   84.766 
  200/  475 batches | accuracy    0.707 | loss   82.677 
  300/  475 batches | accuracy    0.705 | loss   83.393 
  400/  475 batches | accuracy    0.705 | loss   83.151 


Finally, we train the student using model distillation technique with the criterion defined below.

In [12]:
def criterion(output, bert_prob, real_label, a):
    criterion_mse = nn.MSELoss()
    criterion_ce = nn.CrossEntropyLoss()
    return a * criterion_ce(output, real_label) + (1 - a) * criterion_mse(output, bert_prob)

In [13]:
simple_net = SimpleNet()
optimizer = torch.optim.Adam(simple_net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
run(simple_net, orig_net, optimizer, criterion, scheduler, 0.9, epochs)

  100/  475 batches | accuracy    0.408 | loss  181.107 
  200/  475 batches | accuracy    0.512 | loss  141.054 
  300/  475 batches | accuracy    0.556 | loss  127.399 
  400/  475 batches | accuracy    0.577 | loss  120.312 
--------------------------------------------------------
end of epoch   1 | valid accuracy    0.624 
--------------------------------------------------------
  100/  475 batches | accuracy    0.639 | loss  103.640 
  200/  475 batches | accuracy    0.645 | loss  101.249 
  300/  475 batches | accuracy    0.660 | loss   97.025 
  400/  475 batches | accuracy    0.659 | loss   97.640 
--------------------------------------------------------
end of epoch   2 | valid accuracy    0.676 
--------------------------------------------------------
  100/  475 batches | accuracy    0.709 | loss   86.480 
  200/  475 batches | accuracy    0.711 | loss   84.996 
  300/  475 batches | accuracy    0.715 | loss   84.199 
  400/  475 batches | accuracy    0.716 | loss   83.410 


I have been able to slightly improve the accuracy. I tried to improve on the result by training another model ensembles, however I did not succeed. As I understand this is a common problem that may have many reasons (for me the most difficult one was lack of computational resources (I could not use complex models), but I have also encountered cases where the student model seemed to be too simple to be improved by distillation).