In [1]:
from os import path
import torch
from torch import optim
from torch import nn
from torch.nn import functional as F
import torchvision
from torchvision import transforms
from torchvision import models
import numpy as np
from tqdm.notebook import tqdm
import matplotlib
from matplotlib import pyplot as plt
%matplotlib inline

In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

BATCH_SIZE = 512
NUM_EPOCHS = 100
PRINT_EVERY = NUM_EPOCHS // 100 if NUM_EPOCHS > 100 else 1
TEACHER_PATH = "./teacher.pth"
LR = 0.01
NUM_WORKERS = 1

cuda


In [3]:
def get_acc(net, loader):
    net.eval()
    with torch.no_grad():
        total = 0
        correct = 0
        for data in loader:
            images, labels = data
            
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            outputs = net(images)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train()
    return 100 * correct / total

In [4]:
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, 
                                          batch_size=BATCH_SIZE, 
                                          shuffle=True, 
                                          num_workers=NUM_WORKERS)
testloader = torch.utils.data.DataLoader(testset, 
                                         batch_size=BATCH_SIZE, 
                                         shuffle=False, 
                                         num_workers=NUM_WORKERS)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print(len(classes))

Files already downloaded and verified
Files already downloaded and verified
10


In [5]:
teacher = models.vgg16_bn(pretrained=True)
teacher.classifier[6] = nn.Linear(4096,10)
optimizer = optim.SGD(teacher.parameters(), lr=LR, momentum=0.9)
# optimizer = optim.Adam(teacher.parameters(), lr=LR)

In [6]:
if not path.exists(TEACHER_PATH):
    t = tqdm(range(NUM_EPOCHS))
    teacher.to(DEVICE)
    for epoch in t:
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            inputs, labels = data
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            
            optimizer.zero_grad()

            outputs = teacher(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss

        if (epoch + 1) % PRINT_EVERY == 0:
            acc = get_acc(teacher, testloader)
            print(f'[{epoch + 1}] loss: {running_loss / len(trainloader):0.9f} | accuracy: {acc:0.2f}%')
            running_loss = 0.0

    print("Finished Training")
    torch.save(teacher.state_dict(), TEACHER_PATH)
else:
    print("Loaded saved teacher model")
    teacher.load_state_dict(torch.load(TEACHER_PATH))
    teacher.to(DEVICE)

Loaded saved teacher model


In [7]:
TEACHER_NUM_PARAMS = sum(p.numel() for p in teacher.parameters())
print(TEACHER_NUM_PARAMS)

134309962


In [8]:
TEACHER_ACC = get_acc(teacher, testloader)
print(f"Accuracy: {TEACHER_ACC} %")

Accuracy: 90.03 %


In [9]:
teacher.eval()

correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

with torch.no_grad():
    for data in testloader:
        images, labels = data    
        outputs = teacher(images.to(DEVICE))    
        _, predictions = torch.max(outputs, 1)
        for label, prediction in zip(labels, predictions.cpu()):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1
  
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f"Accuracy for class {classname} is: {accuracy}")

Accuracy for class plane is: 93.5
Accuracy for class car is: 95.2
Accuracy for class bird is: 87.0
Accuracy for class cat is: 79.0
Accuracy for class deer is: 89.8
Accuracy for class dog is: 82.9
Accuracy for class frog is: 93.3
Accuracy for class horse is: 92.2
Accuracy for class ship is: 94.7
Accuracy for class truck is: 92.7


In [10]:
class dVGG(nn.Module):
    def __init__(self, a=0, kind=1):
        super().__init__()
        self.one = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            
            nn.MaxPool2d(2, 2),
            
            nn.Flatten(),
            
            nn.Linear(16 * 5 * 5, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

        self.two = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),

            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            
            nn.Flatten(),
            
            nn.Linear(16 * 10 * 10, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

        self.three = nn.Sequential(
            nn.Conv2d(3, 64, 3),
            nn.ReLU(),
            nn.Dropout(),

            nn.Conv2d(64, 128, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(),
            nn.BatchNorm2d(128),

            nn.Conv2d(128, 256, 3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(),
            nn.BatchNorm2d(256),

            nn.Conv2d(256, 512, 3),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 512, 3),
            nn.ReLU(),
            nn.Dropout(),
            nn.BatchNorm2d(512),

            nn.Flatten(),

            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(256, 10)
        )
        
        self.a = a
        self.kind = kind

    def forward(self, x):
        if self.kind == 1:
            out = self.one(x)
        elif self.kind == 2:
            out = self.two(x)
        elif self.kind == 3:
            out = self.three(x)
        else:
            raise ValueError("Unexpected `kind`")

        return out
    
    def loss(self, output, teacher_prob, real_label):
        return self.a * F.cross_entropy(output, real_label) + (1 - self.a) * F.mse_loss(output, teacher_prob)


In [11]:
def train_baseline(kind=3, opt="sgd", coef=1):
    print(f"=== BASELINE: {kind} | {opt} ===")
    net = dVGG(1, kind).to(DEVICE)
    
    if opt.lower() == "sgd":
        optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    if opt.lower() == "adam":
        optimizer = optim.Adam(net.parameters(), lr=LR)

    if not path.exists(f"baseline_{kind}_{opt}.pth"):
        for epoch in tqdm(range(int(NUM_EPOCHS * coef))):
            running_loss = 0.0
            for i, data in enumerate(trainloader):
                inputs, labels = data
                
                inputs = inputs.to(DEVICE)
                labels = labels.to(DEVICE)

                optimizer.zero_grad()

                outputs = net(inputs)
                loss = F.cross_entropy(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss
            if (epoch + 1) % PRINT_EVERY == 0:
                acc = get_acc(net, testloader)
                print(f'[{epoch + 1}] loss: {running_loss / len(trainloader):0.5f} | accuracy: {acc:0.2f}%')
                running_loss = 0.0

        torch.save(net.state_dict(), f"./baseline_{kind}_{opt}.pth")
        print(f"=== Finished baseline: {kind} | {opt} ===")
    else:
        print("Loaded saved baseline model")
        net.load_state_dict(torch.load(f"baseline_{kind}_{opt}.pth"))
        net.to(DEVICE)

    baseline_acc = get_acc(net, testloader)

    print("Baseline accuracy on test:", baseline_acc, "%")
    return baseline_acc

In [12]:
def distil(a=0, kind=3, opt="sgd", coef=1):
    print(f"=== DISTILLATION: {a} | {kind} | {opt} ===")
    net = dVGG(a, kind).to(DEVICE)
    if opt.lower() == "sgd":
        optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    if opt.lower() == "adam":
        optimizer = optim.Adam(net.parameters(), lr=LR)

    for epoch in tqdm(range(int(NUM_EPOCHS * coef))):
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            inputs, labels = data
            
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()

            outputs_teacher = teacher(inputs).detach()
            outputs = net(inputs)

            loss = net.loss(outputs, outputs_teacher, labels)
            
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        if (epoch + 1) % PRINT_EVERY == 0:
            acc = get_acc(net, testloader)
            print(f'[{epoch + 1}] loss: {running_loss / len(trainloader):0.5f} | accuracy: {acc:0.2f}%')
            running_loss = 0.0

    learner_acc = get_acc(net, testloader)
    torch.save(net.state_dict(), f"./distilled_{a}_{kind}_{opt}.pth")

    learner_num_params = sum(p.numel() for p in net.parameters())
    print(f"=== Finished distillation: {a} | {kind} | {opt} ===")
    print("\tTotal number of teacher params:", TEACHER_NUM_PARAMS)
    print("\tTotal number of learner params:", learner_num_params)
    print("\tTotal reduction:", (TEACHER_NUM_PARAMS - learner_num_params) / TEACHER_NUM_PARAMS * 100, "%")
    print("\tTeacher  accuracy on test:", TEACHER_ACC, "%")
    print("\tLearner  accuracy on test:", learner_acc, "%")
    print("\tBaseline accuracy on test:", BASELINE_ACC, "%")
    print("\tDiff:", TEACHER_ACC - learner_acc, learner_acc - BASELINE_ACC)
    print()

In [13]:
for kind in (3,):
    for opt in ("sgd", "adam"):
        BASELINE_ACC = train_baseline(kind=kind, opt=opt, coef=1)
        for a in (0, 0.1, 0.5, 0.7, 0.9):
            distil(a=a, kind=kind, opt=opt, coef=1)

=== BASELINE: 3 | sgd ===


  0%|          | 0/100 [00:00<?, ?it/s]

[1] loss: 2.01275 | accuracy: 16.48%
[2] loss: 1.59383 | accuracy: 29.49%
[3] loss: 1.42007 | accuracy: 37.44%
[4] loss: 1.30427 | accuracy: 44.47%
[5] loss: 1.20836 | accuracy: 48.16%
[6] loss: 1.12366 | accuracy: 54.34%
[7] loss: 1.05881 | accuracy: 56.25%
[8] loss: 0.99425 | accuracy: 61.32%
[9] loss: 0.93964 | accuracy: 62.33%
[10] loss: 0.89668 | accuracy: 67.50%
[11] loss: 0.85565 | accuracy: 65.71%
[12] loss: 0.81412 | accuracy: 66.85%
[13] loss: 0.78572 | accuracy: 70.37%
[14] loss: 0.74874 | accuracy: 71.26%
[15] loss: 0.72256 | accuracy: 71.29%
[16] loss: 0.69671 | accuracy: 73.64%
[17] loss: 0.67401 | accuracy: 74.53%
[18] loss: 0.65573 | accuracy: 75.20%
[19] loss: 0.62987 | accuracy: 76.40%
[20] loss: 0.60707 | accuracy: 75.35%
[21] loss: 0.59124 | accuracy: 76.19%
[22] loss: 0.56987 | accuracy: 76.45%
[23] loss: 0.55252 | accuracy: 79.69%
[24] loss: 0.53801 | accuracy: 77.94%
[25] loss: 0.51728 | accuracy: 78.76%
[26] loss: 0.50479 | accuracy: 80.01%
[27] loss: 0.48648 | 

  0%|          | 0/100 [00:00<?, ?it/s]

[1] loss: 84.23420 | accuracy: 51.48%
[2] loss: 58.95257 | accuracy: 63.55%
[3] loss: 47.30177 | accuracy: 71.28%
[4] loss: 40.34689 | accuracy: 75.93%
[5] loss: 35.76577 | accuracy: 78.65%
[6] loss: 32.77897 | accuracy: 80.17%
[7] loss: 30.50218 | accuracy: 81.87%
[8] loss: 28.60887 | accuracy: 82.18%
[9] loss: 27.23889 | accuracy: 82.33%
[10] loss: 25.88178 | accuracy: 82.88%
[11] loss: 24.57196 | accuracy: 83.54%
[12] loss: 23.63063 | accuracy: 83.67%
[13] loss: 22.68958 | accuracy: 83.71%
[14] loss: 21.90459 | accuracy: 84.49%
[15] loss: 21.35066 | accuracy: 84.24%
[16] loss: 20.64750 | accuracy: 84.71%
[17] loss: 20.00450 | accuracy: 84.55%
[18] loss: 19.50154 | accuracy: 84.78%
[19] loss: 19.05334 | accuracy: 85.00%
[20] loss: 18.60837 | accuracy: 85.30%
[21] loss: 18.22881 | accuracy: 85.71%
[22] loss: 17.79391 | accuracy: 84.58%
[23] loss: 17.49789 | accuracy: 85.38%
[24] loss: 17.25982 | accuracy: 85.23%
[25] loss: 16.61862 | accuracy: 85.57%
[26] loss: 16.47731 | accuracy: 85

  0%|          | 0/100 [00:00<?, ?it/s]

[1] loss: 77.41477 | accuracy: 48.97%
[2] loss: 53.78115 | accuracy: 64.41%
[3] loss: 42.88178 | accuracy: 71.26%
[4] loss: 36.57633 | accuracy: 75.72%
[5] loss: 32.80844 | accuracy: 78.39%
[6] loss: 29.78693 | accuracy: 80.00%
[7] loss: 27.66233 | accuracy: 80.79%
[8] loss: 26.03897 | accuracy: 81.52%
[9] loss: 24.55875 | accuracy: 81.32%
[10] loss: 23.15610 | accuracy: 83.01%
[11] loss: 22.55224 | accuracy: 83.46%
[12] loss: 21.64624 | accuracy: 83.75%
[13] loss: 20.81496 | accuracy: 83.42%
[14] loss: 19.98258 | accuracy: 83.35%
[15] loss: 19.41748 | accuracy: 84.23%
[16] loss: 18.82380 | accuracy: 84.58%
[17] loss: 18.23334 | accuracy: 84.28%
[18] loss: 17.67075 | accuracy: 84.78%
[19] loss: 17.24389 | accuracy: 85.18%
[20] loss: 16.74576 | accuracy: 85.21%
[21] loss: 16.55995 | accuracy: 85.60%
[22] loss: 16.10880 | accuracy: 85.60%
[23] loss: 15.70543 | accuracy: 85.50%
[24] loss: 15.56678 | accuracy: 84.87%
[25] loss: 15.32994 | accuracy: 85.83%
[26] loss: 15.14182 | accuracy: 85

  0%|          | 0/100 [00:00<?, ?it/s]

[1] loss: 45.70777 | accuracy: 49.62%
[2] loss: 32.33512 | accuracy: 63.87%
[3] loss: 26.45718 | accuracy: 71.69%
[4] loss: 22.05780 | accuracy: 75.78%
[5] loss: 19.56806 | accuracy: 76.41%
[6] loss: 17.83302 | accuracy: 78.98%
[7] loss: 16.44975 | accuracy: 80.42%
[8] loss: 15.43682 | accuracy: 82.10%
[9] loss: 14.45893 | accuracy: 82.23%
[10] loss: 13.78754 | accuracy: 83.28%
[11] loss: 13.21378 | accuracy: 83.03%
[12] loss: 12.64071 | accuracy: 83.05%
[13] loss: 12.04774 | accuracy: 83.37%
[14] loss: 11.62997 | accuracy: 84.51%
[15] loss: 11.24252 | accuracy: 84.70%
[16] loss: 10.87741 | accuracy: 84.60%
[17] loss: 10.57896 | accuracy: 84.53%
[18] loss: 10.31445 | accuracy: 85.09%
[19] loss: 10.02449 | accuracy: 85.07%
[20] loss: 9.73367 | accuracy: 84.66%
[21] loss: 9.54885 | accuracy: 85.51%
[22] loss: 9.31420 | accuracy: 85.75%
[23] loss: 9.10376 | accuracy: 85.84%
[24] loss: 8.93882 | accuracy: 85.81%
[25] loss: 8.74564 | accuracy: 85.93%
[26] loss: 8.53981 | accuracy: 85.70%
[2

  0%|          | 0/100 [00:00<?, ?it/s]

[1] loss: 29.44281 | accuracy: 42.69%
[2] loss: 21.64263 | accuracy: 57.17%
[3] loss: 17.52663 | accuracy: 65.95%
[4] loss: 15.08388 | accuracy: 71.92%
[5] loss: 13.28201 | accuracy: 76.07%
[6] loss: 12.06145 | accuracy: 78.80%
[7] loss: 11.05220 | accuracy: 80.18%
[8] loss: 10.40441 | accuracy: 79.93%
[9] loss: 9.69715 | accuracy: 82.44%
[10] loss: 9.17740 | accuracy: 82.65%
[11] loss: 8.72670 | accuracy: 83.60%
[12] loss: 8.29822 | accuracy: 84.00%
[13] loss: 7.96676 | accuracy: 84.17%
[14] loss: 7.54099 | accuracy: 85.08%
[15] loss: 7.37635 | accuracy: 84.83%
[16] loss: 7.10875 | accuracy: 84.80%
[17] loss: 6.89366 | accuracy: 85.50%
[18] loss: 6.70002 | accuracy: 85.29%
[19] loss: 6.44706 | accuracy: 85.58%
[20] loss: 6.26881 | accuracy: 85.92%
[21] loss: 6.09624 | accuracy: 86.03%
[22] loss: 5.86030 | accuracy: 86.19%
[23] loss: 5.76468 | accuracy: 85.91%
[24] loss: 5.71878 | accuracy: 85.99%
[25] loss: 5.54422 | accuracy: 86.14%
[26] loss: 5.48012 | accuracy: 86.49%
[27] loss: 5.

  0%|          | 0/100 [00:00<?, ?it/s]

[1] loss: 11.78005 | accuracy: 22.24%
[2] loss: 9.68677 | accuracy: 48.76%
[3] loss: 8.36221 | accuracy: 56.59%
[4] loss: 7.42693 | accuracy: 63.44%
[5] loss: 6.66258 | accuracy: 69.73%
[6] loss: 6.18992 | accuracy: 73.76%
[7] loss: 5.70670 | accuracy: 76.24%
[8] loss: 5.36232 | accuracy: 78.01%
[9] loss: 4.98056 | accuracy: 79.40%
[10] loss: 4.64276 | accuracy: 80.55%
[11] loss: 4.45216 | accuracy: 81.85%
[12] loss: 4.21181 | accuracy: 81.67%
[13] loss: 4.01118 | accuracy: 82.53%
[14] loss: 3.86302 | accuracy: 83.55%
[15] loss: 3.67982 | accuracy: 83.40%
[16] loss: 3.53470 | accuracy: 84.18%
[17] loss: 3.38351 | accuracy: 84.38%
[18] loss: 3.28840 | accuracy: 84.65%
[19] loss: 3.12890 | accuracy: 84.70%
[20] loss: 3.03602 | accuracy: 85.17%
[21] loss: 2.94719 | accuracy: 85.05%
[22] loss: 2.85042 | accuracy: 85.91%
[23] loss: 2.74644 | accuracy: 85.50%
[24] loss: 2.68239 | accuracy: 86.00%
[25] loss: 2.61589 | accuracy: 85.80%
[26] loss: 2.53777 | accuracy: 86.02%
[27] loss: 2.45240 |

  0%|          | 0/100 [00:00<?, ?it/s]

[1] loss: 1.69565 | accuracy: 50.44%
[2] loss: 1.28266 | accuracy: 61.07%
[3] loss: 1.07446 | accuracy: 65.20%
[4] loss: 0.91172 | accuracy: 74.05%
[5] loss: 0.80041 | accuracy: 75.68%
[6] loss: 0.73305 | accuracy: 78.38%
[7] loss: 0.66862 | accuracy: 78.02%
[8] loss: 0.61663 | accuracy: 78.80%
[9] loss: 0.57885 | accuracy: 79.20%
[10] loss: 0.55136 | accuracy: 81.20%
[11] loss: 0.50726 | accuracy: 81.84%
[12] loss: 0.48418 | accuracy: 81.66%
[13] loss: 0.45599 | accuracy: 80.31%
[14] loss: 0.43578 | accuracy: 82.81%
[15] loss: 0.41570 | accuracy: 82.66%
[16] loss: 0.40367 | accuracy: 82.30%
[17] loss: 0.37368 | accuracy: 81.47%
[18] loss: 0.36716 | accuracy: 82.82%
[19] loss: 0.35288 | accuracy: 82.86%
[20] loss: 0.33561 | accuracy: 83.48%
[21] loss: 0.32482 | accuracy: 83.52%
[22] loss: 0.30887 | accuracy: 83.26%
[23] loss: 0.30334 | accuracy: 82.71%
[24] loss: 0.29436 | accuracy: 84.00%
[25] loss: 0.28168 | accuracy: 83.39%
[26] loss: 0.27038 | accuracy: 83.23%
[27] loss: 0.25933 | 

  0%|          | 0/100 [00:00<?, ?it/s]

[1] loss: 80.34827 | accuracy: 55.65%
[2] loss: 56.84294 | accuracy: 62.76%
[3] loss: 45.58154 | accuracy: 73.22%
[4] loss: 39.27451 | accuracy: 76.99%
[5] loss: 35.26430 | accuracy: 78.94%
[6] loss: 32.43030 | accuracy: 80.35%
[7] loss: 30.18510 | accuracy: 79.77%
[8] loss: 28.49910 | accuracy: 80.57%
[9] loss: 26.79936 | accuracy: 82.57%
[10] loss: 25.87925 | accuracy: 81.91%
[11] loss: 24.66003 | accuracy: 83.15%
[12] loss: 23.95755 | accuracy: 83.65%
[13] loss: 22.87745 | accuracy: 83.98%
[14] loss: 22.57643 | accuracy: 84.60%
[15] loss: 21.81855 | accuracy: 84.37%
[16] loss: 20.74304 | accuracy: 84.27%
[17] loss: 20.54493 | accuracy: 84.67%
[18] loss: 20.02566 | accuracy: 84.62%
[19] loss: 19.60288 | accuracy: 85.16%
[20] loss: 18.88608 | accuracy: 84.39%
[21] loss: 18.78412 | accuracy: 85.47%
[22] loss: 18.29001 | accuracy: 84.91%
[23] loss: 17.93477 | accuracy: 85.27%
[24] loss: 17.64173 | accuracy: 85.31%
[25] loss: 17.33259 | accuracy: 85.06%
[26] loss: 17.16073 | accuracy: 85

  0%|          | 0/100 [00:00<?, ?it/s]

[1] loss: 72.96763 | accuracy: 53.66%
[2] loss: 51.61487 | accuracy: 65.74%
[3] loss: 41.74505 | accuracy: 72.82%
[4] loss: 35.61942 | accuracy: 75.09%
[5] loss: 31.66616 | accuracy: 78.13%
[6] loss: 29.33301 | accuracy: 79.34%
[7] loss: 27.36659 | accuracy: 81.14%
[8] loss: 25.65966 | accuracy: 82.36%
[9] loss: 24.45773 | accuracy: 83.16%
[10] loss: 23.53576 | accuracy: 82.88%
[11] loss: 22.27504 | accuracy: 83.43%
[12] loss: 21.43976 | accuracy: 82.58%
[13] loss: 20.69879 | accuracy: 84.24%
[14] loss: 20.00186 | accuracy: 84.53%
[15] loss: 19.32882 | accuracy: 84.54%
[16] loss: 19.10549 | accuracy: 84.67%
[17] loss: 18.51427 | accuracy: 84.61%
[18] loss: 18.03265 | accuracy: 84.76%
[19] loss: 17.49799 | accuracy: 84.73%
[20] loss: 17.01663 | accuracy: 85.06%
[21] loss: 17.03398 | accuracy: 85.22%
[22] loss: 16.42495 | accuracy: 84.85%
[23] loss: 16.20286 | accuracy: 85.09%
[24] loss: 15.96742 | accuracy: 85.66%
[25] loss: 15.64665 | accuracy: 85.14%
[26] loss: 15.42966 | accuracy: 85

  0%|          | 0/100 [00:00<?, ?it/s]

[1] loss: 41.51855 | accuracy: 50.78%
[2] loss: 30.89795 | accuracy: 66.88%
[3] loss: 25.17502 | accuracy: 71.89%
[4] loss: 21.76263 | accuracy: 77.06%
[5] loss: 19.19343 | accuracy: 77.62%
[6] loss: 17.57857 | accuracy: 80.05%
[7] loss: 16.37826 | accuracy: 80.30%
[8] loss: 15.46283 | accuracy: 81.75%
[9] loss: 14.58701 | accuracy: 82.04%
[10] loss: 14.03727 | accuracy: 82.98%
[11] loss: 13.27051 | accuracy: 81.74%
[12] loss: 12.74381 | accuracy: 83.55%
[13] loss: 12.19071 | accuracy: 83.41%
[14] loss: 11.92919 | accuracy: 83.67%
[15] loss: 11.47560 | accuracy: 84.18%
[16] loss: 11.16885 | accuracy: 84.07%
[17] loss: 10.79427 | accuracy: 84.42%
[18] loss: 10.49152 | accuracy: 84.20%
[19] loss: 10.21669 | accuracy: 84.68%
[20] loss: 9.98903 | accuracy: 84.58%
[21] loss: 9.80352 | accuracy: 84.73%
[22] loss: 9.57314 | accuracy: 84.64%
[23] loss: 9.31354 | accuracy: 85.00%
[24] loss: 9.26908 | accuracy: 84.22%
[25] loss: 9.04260 | accuracy: 85.43%
[26] loss: 9.09656 | accuracy: 85.05%
[2

  0%|          | 0/100 [00:00<?, ?it/s]

[1] loss: 27.20183 | accuracy: 51.34%
[2] loss: 20.42100 | accuracy: 64.93%
[3] loss: 16.84263 | accuracy: 71.50%
[4] loss: 14.32314 | accuracy: 76.79%
[5] loss: 12.77414 | accuracy: 78.29%
[6] loss: 11.66415 | accuracy: 78.76%
[7] loss: 10.77147 | accuracy: 81.14%
[8] loss: 10.13765 | accuracy: 81.24%
[9] loss: 9.58090 | accuracy: 82.77%
[10] loss: 9.11221 | accuracy: 82.41%
[11] loss: 8.67924 | accuracy: 83.18%
[12] loss: 8.31265 | accuracy: 83.52%
[13] loss: 8.03719 | accuracy: 83.41%
[14] loss: 7.76163 | accuracy: 84.11%
[15] loss: 7.47510 | accuracy: 83.84%
[16] loss: 7.23804 | accuracy: 83.61%
[17] loss: 6.92237 | accuracy: 85.05%
[18] loss: 6.81590 | accuracy: 84.47%
[19] loss: 6.69051 | accuracy: 85.16%
[20] loss: 6.46703 | accuracy: 84.46%
[21] loss: 6.31510 | accuracy: 85.44%
[22] loss: 6.09472 | accuracy: 84.89%
[23] loss: 6.04276 | accuracy: 85.23%
[24] loss: 5.94728 | accuracy: 85.18%
[25] loss: 5.80775 | accuracy: 85.00%
[26] loss: 5.65579 | accuracy: 84.87%
[27] loss: 5.

  0%|          | 0/100 [00:00<?, ?it/s]

[1] loss: 10.82560 | accuracy: 48.30%
[2] loss: 8.80848 | accuracy: 62.97%
[3] loss: 7.47887 | accuracy: 65.75%
[4] loss: 6.45710 | accuracy: 73.14%
[5] loss: 5.77539 | accuracy: 75.49%
[6] loss: 5.32330 | accuracy: 77.20%
[7] loss: 4.87714 | accuracy: 79.30%
[8] loss: 4.54973 | accuracy: 79.61%
[9] loss: 4.35258 | accuracy: 78.60%
[10] loss: 4.12981 | accuracy: 80.99%
[11] loss: 3.95223 | accuracy: 81.21%
[12] loss: 3.76738 | accuracy: 82.54%
[13] loss: 3.59874 | accuracy: 82.84%
[14] loss: 3.49506 | accuracy: 82.03%
[15] loss: 3.37713 | accuracy: 83.94%
[16] loss: 3.26906 | accuracy: 82.82%
[17] loss: 3.16229 | accuracy: 83.68%
[18] loss: 3.05739 | accuracy: 84.14%
[19] loss: 2.98605 | accuracy: 83.97%
[20] loss: 2.87266 | accuracy: 84.30%
[21] loss: 2.79890 | accuracy: 84.07%
[22] loss: 2.77672 | accuracy: 84.13%
[23] loss: 2.69499 | accuracy: 85.13%
[24] loss: 2.60966 | accuracy: 84.21%
[25] loss: 2.55508 | accuracy: 84.30%
[26] loss: 2.52615 | accuracy: 84.72%
[27] loss: 2.43657 |

We can clearly see that, in general, the distilled network performed on test better than the baseline, i.e., the same net trained on the dataset.

The final accuracy of the teacher network (VGG16 with batchnormalization) is about 90%

The accuracy of the baseline is about 86.16% while the accuracy of the corresponding distilled network is about 87.61%.

To summarize, the total number of parameters is reduced by 96% while decrease in accuracy in coparison with teacher is about 2.4%. This gives improvement of the overall performance in comparison with the baseline at the level of 1.5%. Unfortunately, the decrease in the accuracy vs. teacher is greater than improvement vs. baseline.

To conclude, this results, in my opinion, are acceptable as noticeable improvement with significant reduction of the teacher network was achieved.

For example,

```=== Finished distillation: 0.9 | 3 | sgd ===
	Total number of teacher params: 134309962
	Total number of learner params: 4917902
	Total reduction: 96.33839372242544 %
	Teacher  accuracy on test: 90.03 %
	Learner  accuracy on test: 87.61 %
	Baseline accuracy on test: 86.16 %
	Diff: 2.4200000000000017 1.4500000000000028
```