In [1]:
from typing import Dict, Optional, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from squeezer.containers import Batch, ModelOutput
from squeezer.distiller import Distiller
from squeezer.policy import AbstractDistillationPolicy, DistillationPolicy

In [2]:
# Seed PyTorch's global RNG so the random data and weight initialisation below are repeatable.
torch.manual_seed(0xDEAD)

<torch._C.Generator at 0x7fd62513c350>

# Models
Объявляем модель-учитель побольше и модель-ученик поменьше.  
Метод `forward` каждой модели должен возвращать `ModelOutput` или объект класса-наследника.

In [3]:
class Teacher(nn.Module):
    """Teacher model: a three-layer MLP with ReLU activations between the linear layers."""

    def __init__(self, input_size, output_size, hidden_size):
        """Build the MLP.

        Args:
            input_size: number of input features.
            output_size: number of logits produced.
            hidden_size: width of both hidden layers.
        """
        super().__init__()
        # Layers are instantiated in forward order, so parameter initialisation
        # consumes the RNG in a fixed sequence.
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, inputs):
        """Run `inputs` through the MLP and wrap the result as `ModelOutput.logits`."""
        return ModelOutput(logits=self.network(inputs))

In [4]:
class Student(nn.Module):
    """Student model: a single linear layer from input features to logits."""

    def __init__(self, input_size, output_size):
        """Create the linear layer.

        Args:
            input_size: number of input features.
            output_size: number of logits produced.
        """
        super().__init__()
        self.network = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, inputs):
        """Apply the linear layer and wrap the result as `ModelOutput.logits`."""
        return ModelOutput(logits=self.network(inputs))

# Data

In [5]:
def collate_fn(batch: list):
    """Collate a list of `(data, target)` samples into a single `Batch`.

    The samples are unzipped into two tuples, and each tuple is stacked
    along a new leading dimension with `torch.stack`.
    """
    features, labels = zip(*batch)
    stacked_features = torch.stack(features)
    stacked_labels = torch.stack(labels)
    return Batch(data=stacked_features, target=stacked_labels)

def get_loader(length: int = 10000, num_features: int = 20, num_classes: int = 4, batch_size: int = 64):
    """Build a `DataLoader` over a synthetic, randomly generated classification dataset.

    Args:
        length: number of samples to generate.
        num_features: number of features per sample (drawn from a standard normal).
        num_classes: targets are integers drawn uniformly from `[0, num_classes)`.
        batch_size: batch size passed to the `DataLoader`.

    Returns:
        A `DataLoader` that uses `collate_fn` to assemble its batches.
    """
    # Features are drawn before targets; keeping this order keeps seeded runs repeatable.
    features = torch.randn(length, num_features)
    labels = torch.randint(low=0, high=num_classes, size=(length,))
    return DataLoader(
        TensorDataset(features, labels),
        batch_size=batch_size,
        collate_fn=collate_fn,
    )

# Distillation

## Basic policy

In [6]:
input_size = 32
num_classes = 4
n_epochs = 50

train_loader = get_loader(num_features=input_size, num_classes=num_classes)
teacher = Teacher(input_size=input_size, output_size=num_classes, hidden_size=10)
student = Student(input_size=input_size, output_size=num_classes)

# Initialise the loss policy used for distillation.
# This example uses the standard policy, in which the student learns from both
# the teacher's logit distribution (KLD) and the main task (CE).
# To customise the policy - e.g. to use other loss functions or to add adapters
# between the models' outputs - subclass `AbstractDistillationPolicy`.
# NOTE(review): with alpha=1.0 the logged `overall` equals `kld` in every epoch,
# so the CE term appears to carry no weight in this run - confirm against
# `DistillationPolicy` whether that is intended.
policy = DistillationPolicy(temperature=0.5, alpha=1.0)

distiller = Distiller(teacher, student, policy)
distiller(train_loader, n_epochs=n_epochs)

Epoch 0: kld=0.12780	cross_entropy=1.50291	overall=0.12780
Epoch 1: kld=0.10538	cross_entropy=1.47928	overall=0.10538
Epoch 2: kld=0.08571	cross_entropy=1.45930	overall=0.08571
Epoch 3: kld=0.06874	cross_entropy=1.44273	overall=0.06874
Epoch 4: kld=0.05441	cross_entropy=1.42929	overall=0.05441
Epoch 5: kld=0.04256	cross_entropy=1.41866	overall=0.04256
Epoch 6: kld=0.03297	cross_entropy=1.41048	overall=0.03297
Epoch 7: kld=0.02536	cross_entropy=1.40438	overall=0.02536
Epoch 8: kld=0.01943	cross_entropy=1.40000	overall=0.01943
Epoch 9: kld=0.01489	cross_entropy=1.39700	overall=0.01489
Epoch 10: kld=0.01146	cross_entropy=1.39509	overall=0.01146
Epoch 11: kld=0.00890	cross_entropy=1.39400	overall=0.00890
Epoch 12: kld=0.00699	cross_entropy=1.39353	overall=0.00699
Epoch 13: kld=0.00557	cross_entropy=1.39350	overall=0.00557
Epoch 14: kld=0.00451	cross_entropy=1.39378	overall=0.00451
Epoch 15: kld=0.00370	cross_entropy=1.39426	overall=0.00370
Epoch 16: kld=0.00308	cross_entropy=1.39487	overal

### Student vs Teacher

In [7]:
# Compare teacher and student logits on a few fresh random inputs (no autograd tracking).
with torch.inference_mode():
    random_input = torch.randn(5, input_size)
    teacher_output = teacher(random_input)
    student_output = student(random_input)
    # sep='\n' puts the label and the tensor on separate lines.
    print('Teacher output:', teacher_output.logits, sep='\n')
    print('Student output:', student_output.logits, sep='\n')

Teacher output:
tensor([[-0.1217,  0.1959, -0.2506, -0.5039],
        [-0.1324,  0.0405, -0.1797, -0.4393],
        [-0.0844,  0.0537, -0.1318, -0.4583],
        [-0.1500,  0.1800, -0.2762, -0.5421],
        [-0.1290,  0.1328, -0.2370, -0.5321]])
Student output:
tensor([[-0.0721,  0.2120, -0.1855, -0.4485],
        [ 0.1190,  0.2183,  0.1295, -0.1383],
        [ 0.2430,  0.3996,  0.1840, -0.1314],
        [-0.0246,  0.3031, -0.1254, -0.3758],
        [ 0.3515,  0.6541,  0.2204, -0.0737]])


In [8]:
# Mean squared difference between the student and teacher logits printed above.
mse = nn.functional.mse_loss(
    input=student_output.logits,
    target=teacher_output.logits,
).item()
print(f'MSE: {mse:.5f}')

MSE: 0.08638


## Custom policy (MSE)

In [9]:
LossDictT = Dict[str, float]


class SingleMSEDistillationPolicy(AbstractDistillationPolicy):
    """Distillation policy whose only objective is MSE between student and teacher logits."""

    def forward(self, teacher_output, student_output, target, epoch) -> Tuple[torch.Tensor, LossDictT]:
        """Compute the MSE distillation loss.

        `target` and `epoch` are accepted but not used by this policy.

        Returns:
            The loss tensor, and a dict holding its Python-float value under the key 'mse'.
        """
        mse = nn.functional.mse_loss(
            input=student_output.logits,
            target=teacher_output.logits,
        )
        return mse, {'mse': mse.item()}

In [10]:
input_size = 32
num_classes = 4
n_epochs = 50

# Fresh data loader and freshly initialised models, so this run does not reuse
# the weights trained in the basic-policy section.
train_loader = get_loader(num_features=input_size, num_classes=num_classes)
teacher = Teacher(input_size=input_size, output_size=num_classes, hidden_size=10)
student = Student(input_size=input_size, output_size=num_classes)

# The custom policy defined above: MSE between student and teacher logits only.
policy = SingleMSEDistillationPolicy()

distiller = Distiller(teacher, student, policy)
distiller(train_loader, n_epochs=n_epochs)

Epoch 0: mse=0.34848
Epoch 1: mse=0.28101
Epoch 2: mse=0.22472
Epoch 3: mse=0.17805
Epoch 4: mse=0.13968
Epoch 5: mse=0.10846
Epoch 6: mse=0.08330
Epoch 7: mse=0.06327
Epoch 8: mse=0.04751
Epoch 9: mse=0.03527
Epoch 10: mse=0.02590
Epoch 11: mse=0.01883
Epoch 12: mse=0.01358
Epoch 13: mse=0.00974
Epoch 14: mse=0.00700
Epoch 15: mse=0.00506
Epoch 16: mse=0.00373
Epoch 17: mse=0.00284
Epoch 18: mse=0.00225
Epoch 19: mse=0.00187
Epoch 20: mse=0.00164
Epoch 21: mse=0.00149
Epoch 22: mse=0.00141
Epoch 23: mse=0.00136
Epoch 24: mse=0.00133
Epoch 25: mse=0.00131
Epoch 26: mse=0.00130
Epoch 27: mse=0.00130
Epoch 28: mse=0.00130
Epoch 29: mse=0.00130
Epoch 30: mse=0.00130
Epoch 31: mse=0.00130
Epoch 32: mse=0.00130
Epoch 33: mse=0.00130
Epoch 34: mse=0.00130
Epoch 35: mse=0.00130
Epoch 36: mse=0.00130
Epoch 37: mse=0.00130
Epoch 38: mse=0.00130
Epoch 39: mse=0.00130
Epoch 40: mse=0.00130
Epoch 41: mse=0.00130
Epoch 42: mse=0.00130
Epoch 43: mse=0.00130
Epoch 44: mse=0.00130
Epoch 45: mse=0.0013

### Student vs Teacher

In [11]:
# Compare teacher and student logits on a few fresh random inputs (no autograd tracking).
with torch.inference_mode():
    random_input = torch.randn(5, input_size)
    teacher_output = teacher(random_input)
    student_output = student(random_input)
    # sep='\n' puts the label and the tensor on separate lines.
    print('Teacher output:', teacher_output.logits, sep='\n')
    print('Student output:', student_output.logits, sep='\n')

Teacher output:
tensor([[ 0.2362,  0.2263,  0.1356, -0.0325],
        [ 0.1388,  0.2534,  0.2083, -0.0543],
        [ 0.2643,  0.1796,  0.0818,  0.0101],
        [ 0.2930,  0.1302, -0.0004,  0.0206],
        [ 0.2378,  0.1847,  0.1942, -0.0717]])
Student output:
tensor([[ 0.2172,  0.2265,  0.1470, -0.0575],
        [ 0.2024,  0.2067,  0.1743, -0.0568],
        [ 0.2816,  0.1567,  0.0250,  0.0484],
        [ 0.3273,  0.1214,  0.0045,  0.0219],
        [ 0.2050,  0.1613,  0.1797, -0.0603]])


In [12]:
# Mean squared difference between the student and teacher logits printed above.
mse = nn.functional.mse_loss(
    input=student_output.logits,
    target=teacher_output.logits,
).item()
print(f'MSE: {mse:.5f}')

MSE: 0.00086


## Advanced policy
**1. CE + MSE with a linear weight schedule (the weight shifts from CE to MSE as the epochs progress)**  
**2. Layer adapter**

In [13]:
LossDictT = Dict[str, float]


class AdvancedDistillationPolicy(AbstractDistillationPolicy):
    """Blend cross-entropy with teacher-matching MSE using an epoch-dependent weight.

    The MSE weight `alpha` grows linearly from `1 / n_epochs` at epoch 0 to 1.0
    at epoch `n_epochs - 1`; cross-entropy gets the complementary weight
    `1 - alpha`. Teacher logits pass through `self.adapter` before being
    compared with the student logits, which allows the two models to have a
    different number of logits.
    """

    def __init__(self, n_epochs: int, adapter_mapping: Optional[Tuple[int, int]] = None):
        """
        Args:
            n_epochs: number of epochs over which the MSE weight ramps up to 1.0.
            adapter_mapping: `(in_features, out_features)` of a linear adapter
                applied to the teacher logits; `None` leaves them unchanged.

        Raises:
            ValueError: if `n_epochs` is not positive (the schedule divides by it).
        """
        if n_epochs <= 0:
            raise ValueError(f'n_epochs must be a positive integer, got {n_epochs}')
        super().__init__()
        self.n_epochs = n_epochs
        # NOTE(review): whether the adapter's parameters are optimised depends on
        # how `Distiller` collects parameters - not visible from this notebook.
        self.adapter = nn.Identity() if adapter_mapping is None else nn.Linear(*adapter_mapping)

    def forward(self, teacher_output, student_output, target, epoch) -> Tuple[torch.Tensor, LossDictT]:
        """Compute the blended loss for the given (zero-based) epoch.

        Returns:
            The overall loss tensor, and a dict of Python-float scalars with the
            keys 'mse', 'cross_entropy', 'overall' and 'alpha'.
        """
        # Clamp at 1.0: without it, training for more than `n_epochs` epochs would
        # push alpha above 1 and give cross-entropy a negative weight.
        alpha = min((epoch + 1) / self.n_epochs, 1.0)
        projected_teacher_logits = self.adapter(teacher_output.logits)

        loss_mse = nn.functional.mse_loss(student_output.logits, projected_teacher_logits)
        loss_ce = nn.functional.cross_entropy(student_output.logits, target)
        overall = loss_mse * alpha + loss_ce * (1 - alpha)
        scalars_dict = {
            'mse': loss_mse.item(),
            'cross_entropy': loss_ce.item(),
            'overall': overall.item(),
            'alpha': alpha,
        }
        return overall, scalars_dict

In [14]:
input_size = 32
num_teacher_logits = 6
num_student_logits = 4
n_epochs = 50

# Targets are class indices for the cross-entropy term on the student's logits,
# so the number of classes is tied to the student's logit count. Using
# `num_student_logits` here also means the cell no longer relies on a
# `num_classes` variable left over from an earlier cell.
train_loader = get_loader(num_features=input_size, num_classes=num_student_logits)
teacher = Teacher(input_size, num_teacher_logits, hidden_size=10)
student = Student(input_size, num_student_logits)

# The adapter maps the teacher's 6 logits onto the student's 4 before the MSE is taken.
policy = AdvancedDistillationPolicy(n_epochs, adapter_mapping=(num_teacher_logits, num_student_logits))

distiller = Distiller(teacher, student, policy)
distiller(train_loader, n_epochs=n_epochs)

Epoch 0: mse=0.35466	cross_entropy=1.49833	overall=1.47546	alpha=0.02000
Epoch 1: mse=0.32912	cross_entropy=1.48502	overall=1.43878	alpha=0.04000
Epoch 2: mse=0.30518	cross_entropy=1.47289	overall=1.40283	alpha=0.06000
Epoch 3: mse=0.28256	cross_entropy=1.46182	overall=1.36748	alpha=0.08000
Epoch 4: mse=0.26118	cross_entropy=1.45177	overall=1.33271	alpha=0.10000
Epoch 5: mse=0.24095	cross_entropy=1.44270	overall=1.29849	alpha=0.12000
Epoch 6: mse=0.22184	cross_entropy=1.43457	overall=1.26478	alpha=0.14000
Epoch 7: mse=0.20376	cross_entropy=1.42732	overall=1.23155	alpha=0.16000
Epoch 8: mse=0.18669	cross_entropy=1.42092	overall=1.19876	alpha=0.18000
Epoch 9: mse=0.17057	cross_entropy=1.41531	overall=1.16637	alpha=0.20000
Epoch 10: mse=0.15538	cross_entropy=1.41044	overall=1.13433	alpha=0.22000
Epoch 11: mse=0.14107	cross_entropy=1.40625	overall=1.10261	alpha=0.24000
Epoch 12: mse=0.12763	cross_entropy=1.40268	overall=1.07117	alpha=0.26000
Epoch 13: mse=0.11503	cross_entropy=1.39969	over

### Student vs Teacher

In [15]:
# Compare raw teacher logits, adapter-projected teacher logits and student logits.
# NOTE(review): `no_grad` (not `inference_mode`) looks deliberate - a later cell
# passes `teacher_output.logits` through `policy.adapter` with autograd enabled,
# which inference tensors would presumably not allow. Confirm before changing.
with torch.no_grad():
    random_input = torch.randn(5, input_size)
    teacher_output = teacher(random_input)
    student_output = student(random_input)
    # sep='\n' puts each label and its tensor on separate lines.
    print('Teacher output:', teacher_output.logits, sep='\n')
    print('Teacher after adapter output:', policy.adapter(teacher_output.logits), sep='\n')
    print('Student output:', student_output.logits, sep='\n')

Teacher output:
tensor([[-0.1329,  0.1282, -0.1064,  0.2148,  0.0931,  0.2372],
        [-0.1190,  0.1222, -0.0096,  0.1823,  0.0670,  0.2634],
        [-0.0450,  0.0447, -0.0888,  0.1382,  0.0143,  0.1679],
        [-0.1969,  0.2049, -0.0978,  0.1848,  0.0004, -0.0018],
        [-0.0853, -0.0314, -0.0410,  0.1831, -0.0424,  0.2571]])
Teacher after adapter output:
tensor([[-0.0648, -0.2931,  0.3177, -0.1903],
        [-0.0909, -0.2952,  0.2953, -0.1913],
        [-0.0404, -0.2404,  0.2275, -0.2135],
        [ 0.0391, -0.2174,  0.2448, -0.2471],
        [-0.0678, -0.2120,  0.2082, -0.2116]])
Student output:
tensor([[-0.0602, -0.2900,  0.3141, -0.1946],
        [-0.0661, -0.2699,  0.2632, -0.2029],
        [-0.0341, -0.2654,  0.2787, -0.2143],
        [ 0.0131, -0.2404,  0.2622, -0.2330],
        [-0.0802, -0.2542,  0.2534, -0.2071]])


In [16]:
# Mean squared difference between the student logits and the adapter-projected teacher logits.
mse = nn.functional.mse_loss(
    input=student_output.logits,
    target=policy.adapter(teacher_output.logits),
).item()
print(f'MSE: {mse:.5f}')

MSE: 0.00057
