# Knowledge distillation
Suppose that we have a large network (*teacher network*) or an ensemble of networks that has good accuracy but doesn't meet our memory/runtime requirements. Instead of training a smaller network (*student network*) directly on the original dataset, we can train it to predict the outputs of the teacher network(s). It turns out that the student's performance can be even better than when it is trained directly on the labels! This approach doesn't speed up training, but it can be quite beneficial when we'd like to reduce the model size for low-memory devices.
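For an ensemble teacher, the student needs a single target distribution. One simple choice, assumed here, is to average the members' temperature-scaled softmax outputs. The following is a minimal NumPy sketch. The function names are hypothetical, and the logits are random placeholders rather than real teacher outputs:

```python
import numpy as np

def softmax(logits, T=1.0):
    """Row-wise softmax of logits / T, shifted by the row max for numerical stability."""
    z = np.asarray(logits, dtype=np.float64) / T
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def ensemble_soft_targets(member_logits, T=1.0):
    """Average the members' class probabilities.

    member_logits: array of shape (n_members, batch, n_classes).
    Returns an array of shape (batch, n_classes) whose rows sum to 1.
    """
    probs = softmax(member_logits, T)  # (n_members, batch, n_classes)
    return probs.mean(axis=0)

rng = np.random.default_rng(0)
member_logits = rng.normal(size=(3, 4, 200))  # 3 hypothetical teachers, batch of 4, 200 classes
targets = ensemble_soft_targets(member_logits, T=10.0)
print(targets.shape, targets.sum(axis=1))
```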

* https://www.ttic.edu/dl/dark14.pdf
* [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531)
* https://medium.com/neural-machines/knowledge-distillation-dc241d7c2322

The student can even have a completely different architecture ([article](https://arxiv.org/abs/1711.10433)). For example, an autoregressive model (WaveNet) can be approximated by a non-autoregressive one.

# Task
## 1. Teacher network
Train a good enough (teacher) network that achieves at least 35% accuracy on the Tiny ImageNet validation set (you can reuse any network from homework part 1 here).

In [1]:
!wget --no-check-certificate 'https://raw.githubusercontent.com/yandexdataschool/deep_vision_and_graphics/fall21/homework01/tiny_img.py' -O tiny_img.py
!wget --no-check-certificate 'https://raw.githubusercontent.com/yandexdataschool/deep_vision_and_graphics/fall21/homework01/tiny_img_dataset.py' -O tiny_img_dataset.py

--2024-03-11 13:31:49--  https://raw.githubusercontent.com/yandexdataschool/deep_vision_and_graphics/fall21/homework01/tiny_img.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 813 [text/plain]
Saving to: 'tiny_img.py'


2024-03-11 13:31:49 (39.1 MB/s) - 'tiny_img.py' saved [813/813]

--2024-03-11 13:31:50--  https://raw.githubusercontent.com/yandexdataschool/deep_vision_and_graphics/fall21/homework01/tiny_img_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1555 (1.5K) [text/plain]
Saving to: 'tiny_img_dataset.py'



In [2]:
from tiny_img import download_tinyImg200
data_path = '.'
download_tinyImg200(data_path)

Dataset was downloaded to './tiny-imagenet-200.zip'
Extract downloaded dataset to '.'


In [3]:
import torch
import torchvision
from torchvision import transforms
import tqdm

def get_computing_device():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    return device

device = get_computing_device()
print(f"Our main computing device is '{device}'")

Our main computing device is 'cuda:0'


In [4]:
train_transforms = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
     transforms.RandomRotation(5),
     transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)
    ]
)

In [5]:
import tiny_img_dataset
train_dataset = tiny_img_dataset.TinyImagenetRAM('tiny-imagenet-200/train', transform=train_transforms)

tiny-imagenet-200/train: 100%|██████████| 200/200 [01:33<00:00,  2.14it/s]


In [6]:
from torch.utils.data import Dataset
import os
from PIL import Image

class TinyImagenetValDataset(Dataset):
    def __init__(self, root, transform=transforms.ToTensor()):
        super().__init__()

        self.root = root
        with open(os.path.join(root, 'val_annotations.txt')) as f:
            annotations = []
            for line in f:
                img_name, class_label = line.split('\t')[:2]
                annotations.append((img_name, class_label))

        self.classes = sorted(set(annotation[1] for annotation in annotations))

        assert len(self.classes) == 200, len(self.classes)
        assert all(self.classes[i] < self.classes[i+1] for i in range(len(self.classes)-1)), 'classes should be ordered'
        assert all(isinstance(elem, type(annotations[0][1])) for elem in self.classes), 'you just need to reuse class_labels'

        self.class_to_idx = {item: index for index, item in enumerate(self.classes)}

        self.transform = transform

        self.images, self.targets = [], []
        for img_name, class_name in tqdm.tqdm(annotations, desc=root):
            img_name = os.path.join(root, 'images', img_name)
            # load the image with tiny_img_dataset.read_rgb_image, store it (as a PIL image) in self.images,
            # and store its class index in self.targets
            image = tiny_img_dataset.read_rgb_image(img_name)

            assert image.shape == (64, 64, 3), image.shape
            self.images.append(Image.fromarray(image))
            self.targets.append(self.class_to_idx[class_name])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # take the image and its target label from self.images and self.targets,
        # transform the image with self.transform, and return both

        image = self.images[index]
        image = self.transform(image)
        target = self.targets[index]

        return image, target
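The label-indexing logic in `TinyImagenetValDataset.__init__` can be checked in isolation. This sketch uses a few hypothetical annotation lines in the assumed `name<TAB>wnid<TAB>bbox...` layout. Sorting the wnids makes the index assignment deterministic. The asserts in the class suggest that the training set indexes classes in the same sorted order, but that is an assumption worth checking against `tiny_img_dataset.py`:

```python
def build_class_index(annotation_lines):
    """Map wnid class labels to contiguous indices in sorted order (hypothetical helper).

    Mirrors TinyImagenetValDataset: the first two tab-separated fields
    are the image file name and the class label.
    """
    annotations = []
    for line in annotation_lines:
        img_name, class_label = line.split('\t')[:2]
        annotations.append((img_name, class_label))
    classes = sorted(set(label for _, label in annotations))
    class_to_idx = {label: idx for idx, label in enumerate(classes)}
    targets = [class_to_idx[label] for _, label in annotations]
    return class_to_idx, targets

# Hypothetical annotation lines (file name, wnid, bounding box)
lines = [
    "val_0.JPEG\tn03444034\t0\t32\t44\t62\n",
    "val_1.JPEG\tn01629819\t52\t55\t57\t59\n",
    "val_2.JPEG\tn03444034\t4\t0\t60\t55\n",
]
class_to_idx, targets = build_class_index(lines)
print(class_to_idx)  # {'n01629819': 0, 'n03444034': 1}
print(targets)       # [1, 0, 1]
```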

In [7]:
val_dataset = TinyImagenetValDataset('tiny-imagenet-200/val', transform=transforms.ToTensor())

tiny-imagenet-200/val: 100%|██████████| 10000/10000 [00:09<00:00, 1062.47it/s]


In [8]:
batch_size = 64
train_batch_gen = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=12)
val_batch_gen = torch.utils.data.DataLoader(val_dataset, batch_size, shuffle=False, num_workers=12)



In [9]:
import torch, torch.nn as nn
import torch.nn.functional as F

def compute_loss(y_pred, y_true):
    # F.cross_entropy already averages over the batch (reduction='mean');
    # y_true may hold class indices or per-class probabilities
    return F.cross_entropy(y_pred, y_true)
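`compute_loss` is reused later with soft targets in the distillation section. That works because `F.cross_entropy` accepts either class indices or per-class probabilities as targets (probability targets require PyTorch 1.10 or newer). The quick check below uses random placeholder logits to confirm that a one-hot probability target reproduces the class-index loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 200)          # batch of 5, 200 classes
labels = torch.randint(0, 200, (5,))  # hard class indices

# Hard targets given as class indices (the form used in train_model)
loss_idx = F.cross_entropy(logits, labels)

# The same targets given as per-row probability distributions (one-hot here)
one_hot = F.one_hot(labels, num_classes=200).float()
loss_prob = F.cross_entropy(logits, one_hot)

print(torch.allclose(loss_idx, loss_prob))  # True
```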

In [10]:
class GlobalAveragePooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        
    def forward(self, x):
        return torch.mean(x, dim=self.dim)
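`GlobalAveragePooling((2, 3))` averages each channel over height and width. As a sanity check, it should match PyTorch's built-in `nn.AdaptiveAvgPool2d(1)` followed by `nn.Flatten()`. This sketch uses a random feature map, assumed to be shaped like the last ResNet stage (128 channels, 8×8):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 128, 8, 8)  # (batch, channels, height, width)

# Mean over the spatial dims (2, 3), as GlobalAveragePooling((2, 3)) does
gap = torch.mean(x, dim=(2, 3))  # shape (2, 128)

# Built-in equivalent: adaptive pooling to 1x1, then flatten away the spatial dims
builtin = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())(x)

print(gap.shape, torch.allclose(gap, builtin, atol=1e-6))
```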

In [11]:
class ConvBNRelu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__() 
        
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

In [12]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU()
        
        self.conv3 = None
        if in_channels != out_channels or stride != 1:
            self.conv3 = nn.Conv2d(in_channels, out_channels, 1, stride, 0)
        
    def forward(self, x):
        residual = self.bn2(self.conv2(self.relu1(self.bn1(self.conv1(x)))))
        
        if self.conv3 is not None:
            x = self.conv3(x)
        
        return self.relu2(x + residual)        
    
    
def create_network_like_resnet():
    model = nn.Sequential()
    config = [[32, 32], [64, 64], [128, 128]]
    
    model.add_module("init_conv", ConvBNRelu(3, 32, kernel_size=7, stride=2, padding=3)) 
    in_channels = 32
    for i in range(len(config)):
        for j in range(len(config[i])):
            out_channels = config[i][j]
            stride = 2 if i != 0 and j == 0 else 1
            model.add_module(f"ResidualBlock{i}_{j}", ResidualBlock(in_channels, out_channels, 3, stride, padding=1))
            in_channels = out_channels
            
    model.add_module("pool", GlobalAveragePooling((2, 3)))
    model.add_module("logits", nn.Linear(out_channels, 200))
    return model
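Why does the final feature map have 128 channels at 8×8 resolution for 64×64 Tiny ImageNet inputs? To see this, apply the standard output-size formula $\lfloor (H + 2p - k)/s \rfloor + 1$ layer by layer. The plain-Python sketch below mirrors the strides in `create_network_like_resnet`. The 1×1 shortcut with the same stride yields the same size, and the second 3×3 conv in each block preserves it:

```python
def conv_out(size, kernel_size, stride, padding):
    """Spatial output size of a conv/pool layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 64                       # Tiny ImageNet images are 64x64
size = conv_out(size, 7, 2, 3)  # init_conv: 64 -> 32
for i, stage in enumerate([[32, 32], [64, 64], [128, 128]]):
    for j, _ in enumerate(stage):
        stride = 2 if i != 0 and j == 0 else 1
        size = conv_out(size, 3, stride, 1)  # first 3x3 conv of the block sets the size
    print(f"after stage {i}: {stage[-1]} channels, {size}x{size}")
# after stage 0: 32 channels, 32x32
# after stage 1: 64 channels, 16x16
# after stage 2: 128 channels, 8x8
```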

In [13]:
import numpy as np
import time

def eval_model(model, val_batch_gen):
    model.train(False) 
    val_accuracy = []
    with torch.no_grad():
        for X_batch, y_batch in val_batch_gen:
            X_batch = X_batch.to(device)
            logits = model(X_batch)
            y_pred = logits.max(1)[1].data
            val_accuracy.append(np.mean((y_batch.cpu() == y_pred.cpu()).numpy()))
            
    return np.mean(val_accuracy)


def train_model(model, opt, train_batch_gen):
    model.train(True) 
    train_loss = []
    for (X_batch, y_batch) in tqdm.tqdm(train_batch_gen):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        opt.zero_grad()
        predictions = model(X_batch) 
        loss = compute_loss(predictions, y_batch)
        loss.backward()
        opt.step()
        
        train_loss.append(loss.cpu().data.numpy())
        
    return np.mean(train_loss)


def train_loop(model, opt, train_data_generator, val_data_generator, num_epochs):
    for epoch in range(num_epochs):
        start_time = time.time()

        train_loss = train_model(model, opt, train_data_generator)

        val_accuracy = eval_model(model, val_data_generator)

        # Then we print the results for this epoch:
        print("Epoch {} of {} took {:.3f}s".format(epoch + 1, num_epochs, time.time() - start_time))
        print("  training loss (in-iteration): \t{:.6f}".format(train_loss))
        print("  validation accuracy: \t\t\t{:.2f} %".format(val_accuracy * 100))

In [18]:
model = create_network_like_resnet().to(device)
opt = torch.optim.Adam(model.parameters())

In [19]:
train_loop(model, opt, train_batch_gen, val_batch_gen, 30)

100%|██████████| 1563/1563 [02:00<00:00, 13.01it/s]


Epoch 1 of 30 took 122.109s
  training loss (in-iteration): 	4.714720
  validation accuracy: 			8.18 %


100%|██████████| 1563/1563 [02:00<00:00, 13.00it/s]


Epoch 2 of 30 took 122.236s
  training loss (in-iteration): 	3.955818
  validation accuracy: 			18.77 %


100%|██████████| 1563/1563 [02:01<00:00, 12.87it/s]


Epoch 4 of 30 took 123.448s
  training loss (in-iteration): 	3.240875
  validation accuracy: 			26.09 %


100%|██████████| 1563/1563 [02:00<00:00, 12.95it/s]


Epoch 5 of 30 took 122.713s
  training loss (in-iteration): 	3.030022
  validation accuracy: 			30.73 %


100%|██████████| 1563/1563 [01:58<00:00, 13.17it/s]


Epoch 6 of 30 took 120.716s
  training loss (in-iteration): 	2.863729
  validation accuracy: 			33.36 %


100%|██████████| 1563/1563 [01:58<00:00, 13.18it/s]


Epoch 7 of 30 took 120.528s
  training loss (in-iteration): 	2.730980
  validation accuracy: 			34.79 %


100%|██████████| 1563/1563 [01:58<00:00, 13.19it/s]


Epoch 8 of 30 took 120.500s
  training loss (in-iteration): 	2.618864
  validation accuracy: 			36.50 %


100%|██████████| 1563/1563 [02:00<00:00, 12.93it/s]


Epoch 9 of 30 took 122.860s
  training loss (in-iteration): 	2.520372
  validation accuracy: 			37.96 %


100%|██████████| 1563/1563 [02:01<00:00, 12.91it/s]


Epoch 10 of 30 took 123.045s
  training loss (in-iteration): 	2.444409
  validation accuracy: 			38.35 %


100%|██████████| 1563/1563 [01:59<00:00, 13.03it/s]


Epoch 11 of 30 took 121.982s
  training loss (in-iteration): 	2.369544
  validation accuracy: 			40.19 %


100%|██████████| 1563/1563 [02:01<00:00, 12.86it/s]


Epoch 12 of 30 took 123.733s
  training loss (in-iteration): 	2.308589
  validation accuracy: 			39.22 %


100%|██████████| 1563/1563 [02:01<00:00, 12.84it/s]


Epoch 13 of 30 took 123.803s
  training loss (in-iteration): 	2.254659
  validation accuracy: 			41.71 %


100%|██████████| 1563/1563 [02:01<00:00, 12.89it/s]


Epoch 14 of 30 took 123.307s
  training loss (in-iteration): 	2.196541
  validation accuracy: 			40.89 %


100%|██████████| 1563/1563 [02:01<00:00, 12.88it/s]


Epoch 15 of 30 took 123.386s
  training loss (in-iteration): 	2.147367
  validation accuracy: 			43.09 %


100%|██████████| 1563/1563 [02:01<00:00, 12.82it/s]


Epoch 16 of 30 took 123.931s
  training loss (in-iteration): 	2.104187
  validation accuracy: 			42.30 %


100%|██████████| 1563/1563 [02:01<00:00, 12.82it/s]


Epoch 17 of 30 took 123.933s
  training loss (in-iteration): 	2.061252
  validation accuracy: 			42.51 %


100%|██████████| 1563/1563 [02:01<00:00, 12.87it/s]


Epoch 18 of 30 took 123.463s
  training loss (in-iteration): 	2.025293
  validation accuracy: 			43.41 %


100%|██████████| 1563/1563 [02:01<00:00, 12.89it/s]


Epoch 19 of 30 took 123.410s
  training loss (in-iteration): 	1.988471
  validation accuracy: 			43.41 %


100%|██████████| 1563/1563 [02:02<00:00, 12.71it/s]


Epoch 20 of 30 took 124.915s
  training loss (in-iteration): 	1.954000
  validation accuracy: 			43.16 %


100%|██████████| 1563/1563 [02:03<00:00, 12.69it/s]


Epoch 21 of 30 took 125.135s
  training loss (in-iteration): 	1.918566
  validation accuracy: 			42.62 %


100%|██████████| 1563/1563 [02:01<00:00, 12.87it/s]


Epoch 22 of 30 took 123.443s
  training loss (in-iteration): 	1.889911
  validation accuracy: 			42.71 %


100%|██████████| 1563/1563 [02:01<00:00, 12.86it/s]


Epoch 23 of 30 took 123.511s
  training loss (in-iteration): 	1.856441
  validation accuracy: 			43.88 %


100%|██████████| 1563/1563 [02:01<00:00, 12.85it/s]


Epoch 24 of 30 took 123.675s
  training loss (in-iteration): 	1.832178
  validation accuracy: 			44.28 %


100%|██████████| 1563/1563 [02:02<00:00, 12.77it/s]


Epoch 25 of 30 took 124.355s
  training loss (in-iteration): 	1.804302
  validation accuracy: 			43.96 %


100%|██████████| 1563/1563 [02:03<00:00, 12.62it/s]


Epoch 26 of 30 took 125.839s
  training loss (in-iteration): 	1.779635
  validation accuracy: 			44.34 %


100%|██████████| 1563/1563 [02:02<00:00, 12.76it/s]


Epoch 27 of 30 took 124.490s
  training loss (in-iteration): 	1.744399
  validation accuracy: 			44.60 %


100%|██████████| 1563/1563 [02:02<00:00, 12.75it/s]


Epoch 28 of 30 took 124.837s
  training loss (in-iteration): 	1.731961
  validation accuracy: 			44.53 %


100%|██████████| 1563/1563 [02:01<00:00, 12.84it/s]


Epoch 29 of 30 took 123.727s
  training loss (in-iteration): 	1.704812
  validation accuracy: 			44.76 %


100%|██████████| 1563/1563 [02:03<00:00, 12.71it/s]


Epoch 30 of 30 took 125.027s
  training loss (in-iteration): 	1.682346
  validation accuracy: 			44.72 %


## 2. Student network 
Train a small (student) network that achieves 20-25% accuracy, and draw a plot of "training and testing errors vs. train step index".
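The existing `train_loop` only prints per-epoch numbers, so producing the required plot takes some extra bookkeeping. The minimal sketch below assumes two changes to the notebook: `train_model` returns its per-batch `train_loss` list, and the `val_accuracy` values are collected in a list. Those changes and the helper names are hypothetical. The per-step training loss stands in for the training error, and the validation error is `1 - accuracy` at the end of each epoch:

```python
def moving_average(values, window):
    """Trailing moving average to smooth noisy per-step losses."""
    out, acc = [], 0.0
    for i, v in enumerate(values):
        acc += v
        if i >= window:
            acc -= values[i - window]
        out.append(acc / min(i + 1, window))
    return out

def epoch_end_steps(num_epochs, steps_per_epoch):
    """Global step index at the end of each epoch."""
    return [(e + 1) * steps_per_epoch for e in range(num_epochs)]

def plot_errors(step_losses, val_accuracies, steps_per_epoch, window=100):
    """Plot smoothed per-step training loss and per-epoch validation error."""
    import matplotlib.pyplot as plt  # imported lazily; the helpers above need only the stdlib

    val_errors = [1.0 - acc for acc in val_accuracies]
    fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(10, 4), sharex=True)
    ax_train.plot(moving_average(step_losses, window))
    ax_train.set(xlabel="train step", ylabel="training loss (smoothed)")
    ax_val.plot(epoch_end_steps(len(val_errors), steps_per_epoch), val_errors, marker="o")
    ax_val.set(xlabel="train step", ylabel="validation error (1 - accuracy)")
    fig.tight_layout()
    plt.show()
```

For example, call `plot_errors(step_losses, val_accuracies, steps_per_epoch=len(train_batch_gen))`. `len` of a `DataLoader` is its number of batches, which is 1563 in the logs here.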

In [42]:
def create_student_network():
    
    model = nn.Sequential()
    model.add_module("conv1", ConvBNRelu(3, 32, 7, 2, 3))
    model.add_module("max_pool", nn.MaxPool2d(3, stride=2))
    model.add_module("conv2", ConvBNRelu(32, 64, 3, 2, 1))
    model.add_module("pool", GlobalAveragePooling((2, 3)))
    model.add_module("fc", nn.Linear(64, 128))
    model.add_module("relu", nn.ReLU())
    model.add_module("logits", nn.Linear(128, 200))
    
    return model
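The point of distillation is a smaller model, so it helps to quantify the size gap between the teacher and the student. Here is a minimal helper (hypothetical name `count_parameters`) that counts trainable parameters:

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    """Number of trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Sanity check: a 64 -> 128 linear layer has 64*128 weights + 128 biases
print(count_parameters(nn.Linear(64, 128)))  # 8320
```

To compare the two models, call it on `create_network_like_resnet()` and on `create_student_network()`.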

In [43]:
student_model = create_student_network().to(device)
optimizer = torch.optim.Adam(student_model.parameters())

In [44]:
train_loop(student_model, optimizer, train_batch_gen, val_batch_gen, 30)

100%|██████████| 1563/1563 [01:57<00:00, 13.35it/s]


Epoch 1 of 30 took 118.810s
  training loss (in-iteration): 	4.708207
  validation accuracy: 			6.60 %


100%|██████████| 1563/1563 [01:55<00:00, 13.57it/s]


Epoch 2 of 30 took 116.828s
  training loss (in-iteration): 	4.326375
  validation accuracy: 			7.04 %


100%|██████████| 1563/1563 [01:55<00:00, 13.57it/s]


Epoch 3 of 30 took 116.861s
  training loss (in-iteration): 	4.168486
  validation accuracy: 			9.60 %


100%|██████████| 1563/1563 [01:54<00:00, 13.68it/s]


Epoch 4 of 30 took 116.014s
  training loss (in-iteration): 	4.059654
  validation accuracy: 			10.41 %


100%|██████████| 1563/1563 [01:54<00:00, 13.66it/s]


Epoch 5 of 30 took 116.101s
  training loss (in-iteration): 	3.968739
  validation accuracy: 			12.39 %


100%|██████████| 1563/1563 [01:55<00:00, 13.51it/s]


Epoch 6 of 30 took 117.360s
  training loss (in-iteration): 	3.897012
  validation accuracy: 			13.78 %


100%|██████████| 1563/1563 [01:55<00:00, 13.56it/s]


Epoch 7 of 30 took 117.276s
  training loss (in-iteration): 	3.841599
  validation accuracy: 			16.36 %


100%|██████████| 1563/1563 [01:55<00:00, 13.59it/s]


Epoch 8 of 30 took 116.704s
  training loss (in-iteration): 	3.789825
  validation accuracy: 			15.85 %


100%|██████████| 1563/1563 [01:55<00:00, 13.48it/s]


Epoch 9 of 30 took 117.609s
  training loss (in-iteration): 	3.751338
  validation accuracy: 			15.50 %


100%|██████████| 1563/1563 [01:55<00:00, 13.55it/s]


Epoch 10 of 30 took 117.335s
  training loss (in-iteration): 	3.711197
  validation accuracy: 			16.45 %


100%|██████████| 1563/1563 [01:54<00:00, 13.65it/s]


Epoch 11 of 30 took 116.195s
  training loss (in-iteration): 	3.676489
  validation accuracy: 			15.35 %


100%|██████████| 1563/1563 [01:55<00:00, 13.48it/s]


Epoch 12 of 30 took 117.599s
  training loss (in-iteration): 	3.654718
  validation accuracy: 			18.38 %


100%|██████████| 1563/1563 [01:55<00:00, 13.59it/s]


Epoch 13 of 30 took 116.663s
  training loss (in-iteration): 	3.621308
  validation accuracy: 			20.22 %


100%|██████████| 1563/1563 [01:56<00:00, 13.47it/s]


Epoch 14 of 30 took 117.740s
  training loss (in-iteration): 	3.597739
  validation accuracy: 			13.90 %


100%|██████████| 1563/1563 [01:56<00:00, 13.42it/s]


Epoch 15 of 30 took 118.196s
  training loss (in-iteration): 	3.572458
  validation accuracy: 			18.61 %


100%|██████████| 1563/1563 [01:55<00:00, 13.54it/s]


Epoch 16 of 30 took 117.272s
  training loss (in-iteration): 	3.555182
  validation accuracy: 			17.54 %


100%|██████████| 1563/1563 [01:56<00:00, 13.38it/s]


Epoch 17 of 30 took 118.459s
  training loss (in-iteration): 	3.528723
  validation accuracy: 			20.33 %


100%|██████████| 1563/1563 [01:55<00:00, 13.52it/s]


Epoch 18 of 30 took 117.276s
  training loss (in-iteration): 	3.511779
  validation accuracy: 			19.31 %


100%|██████████| 1563/1563 [01:54<00:00, 13.59it/s]


Epoch 19 of 30 took 116.911s
  training loss (in-iteration): 	3.497319
  validation accuracy: 			20.36 %


100%|██████████| 1563/1563 [01:55<00:00, 13.55it/s]


Epoch 20 of 30 took 117.036s
  training loss (in-iteration): 	3.479439
  validation accuracy: 			19.93 %


100%|██████████| 1563/1563 [01:56<00:00, 13.43it/s]


Epoch 21 of 30 took 118.094s
  training loss (in-iteration): 	3.464734
  validation accuracy: 			17.29 %


100%|██████████| 1563/1563 [01:56<00:00, 13.41it/s]


Epoch 22 of 30 took 118.457s
  training loss (in-iteration): 	3.450871
  validation accuracy: 			16.39 %


100%|██████████| 1563/1563 [01:53<00:00, 13.80it/s]


Epoch 23 of 30 took 114.908s
  training loss (in-iteration): 	3.439452
  validation accuracy: 			20.18 %


100%|██████████| 1563/1563 [01:54<00:00, 13.66it/s]


Epoch 24 of 30 took 116.101s
  training loss (in-iteration): 	3.424638
  validation accuracy: 			19.54 %


100%|██████████| 1563/1563 [01:54<00:00, 13.67it/s]


Epoch 25 of 30 took 116.040s
  training loss (in-iteration): 	3.417302
  validation accuracy: 			22.12 %


100%|██████████| 1563/1563 [01:54<00:00, 13.61it/s]


Epoch 26 of 30 took 116.513s
  training loss (in-iteration): 	3.404962
  validation accuracy: 			22.35 %


100%|██████████| 1563/1563 [01:52<00:00, 13.84it/s]


Epoch 27 of 30 took 114.556s
  training loss (in-iteration): 	3.392428
  validation accuracy: 			22.66 %


100%|██████████| 1563/1563 [01:54<00:00, 13.67it/s]


Epoch 28 of 30 took 116.044s
  training loss (in-iteration): 	3.379768
  validation accuracy: 			20.41 %


100%|██████████| 1563/1563 [01:55<00:00, 13.52it/s]


Epoch 29 of 30 took 117.239s
  training loss (in-iteration): 	3.367305
  validation accuracy: 			21.42 %


100%|██████████| 1563/1563 [01:59<00:00, 13.08it/s]


Epoch 30 of 30 took 121.177s
  training loss (in-iteration): 	3.356396
  validation accuracy: 			21.39 %


## 3. Knowledge distillation
![image info](https://miro.medium.com/max/875/1*WxFiH3XDY1-28tbyi4BGDA.png)

In this block you will retrain your student network using the "knowledge distillation" technique. **Distill the teacher network into the student network and achieve at least a +1% improvement in accuracy over the student network trained from scratch.**

The training procedure is the same as for training the student network from scratch, except for the loss formulation.

Assume that
- $z_i$ are the logits predicted by the student network at the current step for an input image
- $v_i$ are the logits predicted by the (frozen) teacher network
- $y_i$ is the one-hot encoded label of the input image
- $p_i = \frac{\exp{z_i}}{\sum_j \exp{z_j}}$ are the logits $z_i$ after softmax
- $q_i = \frac{\exp{\frac{z_i}{T}}}{\sum_j \exp{\frac{z_j}{T}}}$, where $T$ is the softmax temperature
- $r_i = \frac{\exp{\frac{v_i}{T}}}{\sum_j \exp{\frac{v_j}{T}}}$, where $T$ is the same softmax temperature as for $q_i$
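The temperature-scaled softmax above ($T = 1$ gives $p_i$; $T > 1$ gives $q_i$ or $r_i$) can be explored with a small NumPy sketch. The teacher logits are made-up values for 4 classes. As $T$ grows, the distribution becomes flatter (higher entropy) while the most likely class stays the same:

```python
import numpy as np

def softmax_T(logits, T=1.0):
    """Softmax of logits / T; T = 1 gives p_i, T > 1 gives q_i or r_i."""
    z = np.asarray(logits, dtype=np.float64) / T
    z -= z.max()  # shift for a stable exp; the result is unchanged
    e = np.exp(z)
    return e / e.sum()

def entropy(p):
    """Shannon entropy (in nats) of a probability vector."""
    return float(-(p * np.log(p)).sum())

v = np.array([5.0, 2.0, 1.0, -1.0])  # made-up teacher logits
for T in (1, 2, 10):
    r = softmax_T(v, T)
    print(T, np.round(r, 3), round(entropy(r), 3))
```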

The loss for knowledge distillation: $$-\sum_i y_i \log p_i - \alpha \sum_i r_i \log q_i$$
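The loss above can be written directly as a PyTorch function. This sketch assumes class-index labels, so the one-hot $y_i$ is never materialized, and it averages both terms over the batch. `distillation_loss` is a hypothetical helper, not part of the assignment code:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T, alpha):
    """-sum_i y_i log p_i - alpha * sum_i r_i log q_i, averaged over the batch.

    student_logits are z, teacher_logits are v, labels are class indices.
    """
    # Hard-label term: -sum_i y_i log p_i with p = softmax(z)
    hard = F.cross_entropy(student_logits, labels)
    # Soft-target term: -sum_i r_i log q_i with q = softmax(z/T), r = softmax(v/T);
    # the teacher is frozen, so its logits are detached
    log_q = F.log_softmax(student_logits / T, dim=1)
    r = F.softmax(teacher_logits.detach() / T, dim=1)
    soft = -(r * log_q).sum(dim=1).mean()
    return hard + alpha * soft

torch.manual_seed(0)
z = torch.randn(8, 200, requires_grad=True)  # made-up student logits
v = torch.randn(8, 200)                      # made-up teacher logits
y = torch.randint(0, 200, (8,))
loss = distillation_loss(z, v, y, T=10.0, alpha=10.0 ** 2)
loss.backward()  # gradients reach the student logits only
```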

$T$ and $\alpha$ are hyperparameters. 

- It is good practice to use a softmax with a high temperature to obtain "soft" distributions; you can start with $T=10$. Check the [post](https://medium.com/mlearning-ai/softmax-temperature-5492e4007f71) with good visualizations of how the temperature affects the softmax output.
- For $\alpha$ there is the following note in the original [paper](https://arxiv.org/pdf/1503.02531.pdf):

> Since the magnitudes of the gradients produced by the soft targets scale as $1/T^2$ it is important to multiply them by $T^2$ when using both hard and soft targets. This ensures that the relative contributions of the hard and soft targets remain roughly unchanged if the temperature used for distillation is changed while experimenting with meta-parameters.
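The $1/T^2$ scaling quoted above can be checked numerically. For large $T$, the gradient of the soft-target term with respect to the student logits $z$ shrinks roughly four-fold each time $T$ doubles. This is why the paper suggests multiplying the soft term by $T^2$. Here is a small sketch with random placeholder logits:

```python
import torch
import torch.nn.functional as F

def soft_grad_norm(z, v, T):
    """Norm of the gradient of -sum_i r_i log q_i (batch mean) w.r.t. the student logits z."""
    z = z.clone().requires_grad_(True)
    r = F.softmax(v / T, dim=1)
    loss = -(r * F.log_softmax(z / T, dim=1)).sum(dim=1).mean()
    loss.backward()
    return z.grad.norm().item()

torch.manual_seed(0)
z = torch.randn(16, 200)  # made-up student logits
v = torch.randn(16, 200)  # made-up teacher logits
for T in (10.0, 20.0, 40.0):
    g = soft_grad_norm(z, v, T)
    print(f"T={T}: grad norm {g:.2e}, times T^2 {g * T**2:.3f}")
```

Doubling $T$ should cut the gradient norm by a factor close to 4, while `g * T**2` stays roughly constant.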

In [70]:
model = model.train(False)

In [71]:
student_model = create_student_network().to(device)
optimizer = torch.optim.Adam(student_model.parameters())

In [72]:
def distillation_train_model(student_model, teacher_model, optimizer, train_batch_gen, T, alpha):
    student_model.train(True)
    teacher_model.train(False)  # the teacher stays frozen in eval mode
    train_loss = []
    for (X_batch, y_batch) in tqdm.tqdm(train_batch_gen):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()

        student_predictions = student_model(X_batch)
        # The teacher is frozen, so there is no need to build a graph through it
        with torch.no_grad():
            teacher_predictions = teacher_model(X_batch)

        # -sum_i y_i log p_i (hard labels)
        student_loss = compute_loss(student_predictions, y_batch)
        # -sum_i r_i log q_i (soft targets at temperature T)
        distillation_loss = compute_loss(student_predictions / T, F.softmax(teacher_predictions / T, dim=-1))

        loss = student_loss + alpha * distillation_loss
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())

    return np.mean(train_loss)

def distillation_train_loop(student_model, teacher_model, optimizer, train_data_generator, val_data_generator, T, alpha, num_epochs):
    for epoch in range(num_epochs):
        start_time = time.time()

        train_loss = distillation_train_model(student_model, teacher_model, optimizer, train_data_generator, T, alpha)

        val_accuracy = eval_model(student_model, val_data_generator)

        # Then we print the results for this epoch:
        print("Epoch {} of {} took {:.3f}s".format(epoch + 1, num_epochs, time.time() - start_time))
        print("  training loss (in-iteration): \t{:.6f}".format(train_loss))
        print("  validation accuracy: \t\t\t{:.2f} %".format(val_accuracy * 100))
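The soft-target term is a cross-entropy, so it includes the teacher's entropy $H(r)$, which does not depend on the student. At $T=10$ the teacher distribution over 200 classes is close to uniform, and $H(r) \le \log 200 \approx 5.30$. Scaled by `alpha = 100`, this constant most likely dominates the logged training loss below, which stays around 528 and barely moves. Replacing the cross-entropy with the KL divergence gives the same gradients, but the loss can then go to zero, which makes the curve easier to read. Here is a quick check with random placeholder logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(8, 200)  # made-up student logits
v = torch.randn(8, 200)  # made-up teacher logits
T = 10.0

log_q = F.log_softmax(z / T, dim=1)
r = F.softmax(v / T, dim=1)

# Soft-target cross-entropy, as computed by compute_loss(z / T, r)
ce = F.cross_entropy(z / T, r)
# KL(r || q); reduction="batchmean" divides by the batch size, like the CE above
kl = F.kl_div(log_q, r, reduction="batchmean")
# Teacher entropy H(r): constant w.r.t. the student, so it contributes no gradient
teacher_entropy = -(r * torch.log(r)).sum(dim=1).mean()

print(torch.allclose(ce, kl + teacher_entropy, atol=1e-4))  # True
print(teacher_entropy.item(), torch.log(torch.tensor(200.0)).item())
```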

In [73]:
# TODO: add an early-stopping check and plots of train loss and val accuracy
T = 10  # alternative value: 2
alpha = 100  # = T**2; alternative value: 4
distillation_train_loop(student_model, model, optimizer, train_batch_gen, val_batch_gen, T, alpha, 30)

100%|██████████| 1563/1563 [01:54<00:00, 13.68it/s]


Epoch 1 of 30 took 115.883s
  training loss (in-iteration): 	531.717834
  validation accuracy: 			5.10 %


100%|██████████| 1563/1563 [01:55<00:00, 13.56it/s]


Epoch 2 of 30 took 116.895s
  training loss (in-iteration): 	530.391968
  validation accuracy: 			9.20 %


100%|██████████| 1563/1563 [01:55<00:00, 13.53it/s]


Epoch 3 of 30 took 117.515s
  training loss (in-iteration): 	529.865417
  validation accuracy: 			10.63 %


100%|██████████| 1563/1563 [01:54<00:00, 13.64it/s]


Epoch 4 of 30 took 116.245s
  training loss (in-iteration): 	529.517578
  validation accuracy: 			12.37 %


100%|██████████| 1563/1563 [01:54<00:00, 13.64it/s]


Epoch 5 of 30 took 116.233s
  training loss (in-iteration): 	529.269531
  validation accuracy: 			14.36 %


100%|██████████| 1563/1563 [01:54<00:00, 13.63it/s]


Epoch 6 of 30 took 116.380s
  training loss (in-iteration): 	529.085266
  validation accuracy: 			12.40 %


100%|██████████| 1563/1563 [01:54<00:00, 13.60it/s]


Epoch 7 of 30 took 116.644s
  training loss (in-iteration): 	528.924500
  validation accuracy: 			11.89 %


100%|██████████| 1563/1563 [01:55<00:00, 13.48it/s]


Epoch 8 of 30 took 117.586s
  training loss (in-iteration): 	528.804688
  validation accuracy: 			12.68 %


100%|██████████| 1563/1563 [01:55<00:00, 13.58it/s]


Epoch 9 of 30 took 116.749s
  training loss (in-iteration): 	528.677551
  validation accuracy: 			16.52 %


100%|██████████| 1563/1563 [01:55<00:00, 13.55it/s]


Epoch 10 of 30 took 117.054s
  training loss (in-iteration): 	528.594177
  validation accuracy: 			18.68 %


100%|██████████| 1563/1563 [01:55<00:00, 13.54it/s]


Epoch 11 of 30 took 117.130s
  training loss (in-iteration): 	528.522400
  validation accuracy: 			18.73 %


100%|██████████| 1563/1563 [01:55<00:00, 13.57it/s]


Epoch 12 of 30 took 116.878s
  training loss (in-iteration): 	528.453125
  validation accuracy: 			18.07 %


100%|██████████| 1563/1563 [01:55<00:00, 13.57it/s]


Epoch 13 of 30 took 116.774s
  training loss (in-iteration): 	528.373657
  validation accuracy: 			13.03 %


100%|██████████| 1563/1563 [01:55<00:00, 13.57it/s]


Epoch 14 of 30 took 116.829s
  training loss (in-iteration): 	528.320374
  validation accuracy: 			18.10 %


100%|██████████| 1563/1563 [01:54<00:00, 13.61it/s]


Epoch 15 of 30 took 116.470s
  training loss (in-iteration): 	528.255737
  validation accuracy: 			18.04 %


100%|██████████| 1563/1563 [01:56<00:00, 13.44it/s]


Epoch 16 of 30 took 118.012s
  training loss (in-iteration): 	528.224304
  validation accuracy: 			20.54 %


100%|██████████| 1563/1563 [01:56<00:00, 13.45it/s]


Epoch 17 of 30 took 117.872s
  training loss (in-iteration): 	528.180420
  validation accuracy: 			19.50 %


100%|██████████| 1563/1563 [01:56<00:00, 13.37it/s]


Epoch 18 of 30 took 118.597s
  training loss (in-iteration): 	528.116821
  validation accuracy: 			20.40 %


100%|██████████| 1563/1563 [01:55<00:00, 13.50it/s]


Epoch 19 of 30 took 117.497s
  training loss (in-iteration): 	528.096436
  validation accuracy: 			22.27 %


100%|██████████| 1563/1563 [01:55<00:00, 13.59it/s]


Epoch 20 of 30 took 116.662s
  training loss (in-iteration): 	528.053284
  validation accuracy: 			19.82 %


100%|██████████| 1563/1563 [01:55<00:00, 13.57it/s]


Epoch 21 of 30 took 116.838s
  training loss (in-iteration): 	528.020935
  validation accuracy: 			19.95 %


100%|██████████| 1563/1563 [01:53<00:00, 13.82it/s]


Epoch 22 of 30 took 114.749s
  training loss (in-iteration): 	527.986389
  validation accuracy: 			20.49 %


100%|██████████| 1563/1563 [01:55<00:00, 13.59it/s]


Epoch 23 of 30 took 116.728s
  training loss (in-iteration): 	527.955444
  validation accuracy: 			22.84 %


100%|██████████| 1563/1563 [01:53<00:00, 13.74it/s]


Epoch 24 of 30 took 115.382s
  training loss (in-iteration): 	527.922974
  validation accuracy: 			21.36 %


100%|██████████| 1563/1563 [01:53<00:00, 13.76it/s]


Epoch 25 of 30 took 115.192s
  training loss (in-iteration): 	527.908081
  validation accuracy: 			23.81 %


100%|██████████| 1563/1563 [01:52<00:00, 13.87it/s]


Epoch 26 of 30 took 114.359s
  training loss (in-iteration): 	527.888672
  validation accuracy: 			22.12 %


100%|██████████| 1563/1563 [01:53<00:00, 13.81it/s]


Epoch 27 of 30 took 114.762s
  training loss (in-iteration): 	527.852905
  validation accuracy: 			21.28 %


100%|██████████| 1563/1563 [01:54<00:00, 13.62it/s]


Epoch 28 of 30 took 116.464s
  training loss (in-iteration): 	527.838867
  validation accuracy: 			20.68 %


100%|██████████| 1563/1563 [01:58<00:00, 13.16it/s]


Epoch 29 of 30 took 120.473s
  training loss (in-iteration): 	527.820190
  validation accuracy: 			23.08 %


100%|██████████| 1563/1563 [01:55<00:00, 13.56it/s]


Epoch 30 of 30 took 116.904s
  training loss (in-iteration): 	527.803955
  validation accuracy: 			21.96 %


**Final notes**:
- Please don't cheat with very early stopping when training the student network. Make sure it has converged.
- Logits still carry more information than the probabilities after softmax.
- Don't forget to use your teacher network in 'eval' mode. And don't forget your main objective.
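The teacher ResNet contains BatchNorm layers. In train mode, these layers normalize with batch statistics and update their running statistics on every forward pass, even inside `torch.no_grad()`. A teacher left in train mode therefore silently changes during distillation. Here is a minimal check with a standalone `BatchNorm2d` layer that stands in for the teacher's layers:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)              # stands in for a BatchNorm layer inside the teacher
x = torch.randn(8, 4, 5, 5) + 3.0   # inputs with a non-zero mean

bn.train()
with torch.no_grad():               # no_grad does NOT stop running-stat updates
    bn(x)
changed_in_train = not torch.allclose(bn.running_mean, torch.zeros(4))

bn.reset_running_stats()
bn.eval()
with torch.no_grad():
    bn(x)
changed_in_eval = not torch.allclose(bn.running_mean, torch.zeros(4))

print(changed_in_train, changed_in_eval)  # True False
```

In the distillation loop, this means calling `teacher_model.train(False)` (or `.eval()`) before training, as `model.train(False)` does above.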

**Further reading**
- ["Born again neural networks"](https://arxiv.org/pdf/1805.04770.pdf): knowledge distillation may give benefits even when the teacher and student networks have the same architecture.
- ["Prune your model before distill it"](https://arxiv.org/pdf/2109.14960.pdf): pruning the teacher model before distillation may improve the quality of the student model.