In [1]:
import math
from typing import Dict, Optional, Tuple

import torch
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm

from squeezer.criterion import distill_loss
from squeezer.distiller import Distiller
from squeezer.policy import AbstractDistillationPolicy

%load_ext autoreload
%autoreload 2

In [2]:
torch.manual_seed(0xDEAD)

<torch._C.Generator at 0x7fb09e07c390>

In [3]:
n_epochs = 200

In [4]:
def train(model, loader, n_epochs: int = 200):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    for epoch in tqdm(range(n_epochs)):
        for i, (data, labels) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

In [5]:
def evaluate(model, loader):
    preds = []
    targets = []
    for data, labels in loader:
        outputs = model(data).argmax(-1)
        preds.append(outputs)
        targets.append(labels)
    preds = torch.cat(preds)
    targets = torch.cat(targets)
    print(classification_report(targets, preds, zero_division=0))

# Models
Объявляем модель-учитель побольше и модель-ученик поменьше.  
Тип возвращаемого значения должен наследоваться от класса `ModelOutput` (или быть им).

In [6]:
class Teacher(nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, inputs):
        logits = self.network(inputs)
        return logits

In [7]:
class Student(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.network = nn.Linear(input_size, output_size)

    def forward(self, inputs):
        logits = self.network(inputs)
        return logits

# Data

In [8]:
def get_loaders(num_features: int = 64, num_classes: int = 4,
                batch_size: int = 64, train_size: float = 0.75):
    x, y = make_classification(
        1000, num_features,
        n_classes=num_classes,
        n_informative=int(num_features * 0.9),
        n_clusters_per_class=2,
        class_sep=4.0,
        random_state=0xDEAD
    )
    dataset = TensorDataset(
        torch.from_numpy(x).float(),
        torch.from_numpy(y).long()
    )
    dataset_length = len(x)
    train_size = int(dataset_length * train_size)
    val_size = dataset_length - train_size
    train, val = random_split(dataset, [train_size, val_size])
    return DataLoader(train, batch_size=batch_size), DataLoader(val, batch_size=batch_size)

In [9]:
num_features = 128
num_classes = 5

train_loader, val_loader = get_loaders(num_features, num_classes)

# Train Teacher model

In [10]:
teacher = Teacher(num_features, num_classes, hidden_size=128)

train(teacher, train_loader, n_epochs=n_epochs)
evaluate(teacher, val_loader)

100%|██████████| 200/200 [00:04<00:00, 48.82it/s]

              precision    recall  f1-score   support

           0       0.98      1.00      0.99        57
           1       1.00      1.00      1.00        55
           2       0.98      1.00      0.99        45
           3       1.00      0.98      0.99        47
           4       1.00      0.98      0.99        46

    accuracy                           0.99       250
   macro avg       0.99      0.99      0.99       250
weighted avg       0.99      0.99      0.99       250






# Train Student model without distillation

In [11]:
student = Student(num_features, num_classes)
train(student, train_loader, n_epochs=n_epochs)
evaluate(student, val_loader)

100%|██████████| 200/200 [00:01<00:00, 136.24it/s]

              precision    recall  f1-score   support

           0       0.97      0.98      0.97        57
           1       0.96      0.96      0.96        55
           2       0.93      0.96      0.95        45
           3       0.98      0.94      0.96        47
           4       0.98      0.98      0.98        46

    accuracy                           0.96       250
   macro avg       0.96      0.96      0.96       250
weighted avg       0.96      0.96      0.96       250






# Distiller

In [12]:
class CustomDistiller(Distiller):
    def teacher_forward(self, batch):
        return self.teacher(batch[0])
    
    def student_forward(self, batch):
        return self.student(batch[0])

# Distillation

## Basic distillation policy

In [13]:
LossDictT = Dict[str, float]


class BasicDistillationPolicy(AbstractDistillationPolicy):
    def __init__(self, temperature: float = 1.0, alpha: float = 0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, teacher_output, student_output, batch, epoch: int) -> Tuple[torch.Tensor, LossDictT]:
        loss_kld, loss_ce, overall = distill_loss(
            teacher_logits=teacher_output,
            student_logits=student_output,
            labels=batch[1],
            temperature=self.temperature,
            alpha=self.alpha
        )
        loss_dict = {
            'kld': loss_kld.item(),
            'cross_entropy': loss_ce.item(),
            'overall': overall.item(),
        }
        return overall, loss_dict

In [14]:
student = Student(num_features, num_classes)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

# Инициализируем политику функции потерь для дистилляции.
# В этом примере используется стандартная политика, при которой
# ученик учится сразу на распределение логитов учителя (KLD) и на мейнстрим задачу (CE).
# Для кастомизации политики, например, для использования других функций потерь или
# добавления адаптеров между аутпутами моделей, необходимо наследоваться от класса `AbstractDistillationPolicy`
policy = BasicDistillationPolicy(temperature=1.2, alpha=0.5)

distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, val_loader, n_epochs=n_epochs)

[0th epoch]: 100%|██████████| 12/12 [00:00<00:00, 402.63it/s, loss=5.86]
Validation: 100%|██████████| 4/4 [00:00<00:00, 463.98it/s, loss=5.35]
[1th epoch]: 100%|██████████| 12/12 [00:00<00:00, 425.34it/s, loss=4.94]
Validation: 100%|██████████| 4/4 [00:00<00:00, 495.84it/s, loss=4.52]
[2th epoch]: 100%|██████████| 12/12 [00:00<00:00, 391.47it/s, loss=4.13]
Validation: 100%|██████████| 4/4 [00:00<00:00, 325.55it/s, loss=3.78]
[3th epoch]: 100%|██████████| 12/12 [00:00<00:00, 373.41it/s, loss=3.44]
Validation: 100%|██████████| 4/4 [00:00<00:00, 478.87it/s, loss=3.14]
[4th epoch]: 100%|██████████| 12/12 [00:00<00:00, 414.06it/s, loss=2.86]
Validation: 100%|██████████| 4/4 [00:00<00:00, 435.55it/s, loss=2.6]
[5th epoch]: 100%|██████████| 12/12 [00:00<00:00, 425.64it/s, loss=2.38]
Validation: 100%|██████████| 4/4 [00:00<00:00, 481.52it/s, loss=2.16]
[6th epoch]: 100%|██████████| 12/12 [00:00<00:00, 387.85it/s, loss=1.99]
Validation: 100%|██████████| 4/4 [00:00<00:00, 515.65it/s, loss=1.8]
[

In [15]:
evaluate(student, val_loader)

              precision    recall  f1-score   support

           0       0.93      1.00      0.97        57
           1       1.00      0.96      0.98        55
           2       0.96      1.00      0.98        45
           3       1.00      0.89      0.94        47
           4       0.96      0.98      0.97        46

    accuracy                           0.97       250
   macro avg       0.97      0.97      0.97       250
weighted avg       0.97      0.97      0.97       250



## Custom policy (MSE)

In [16]:
LossDictT = Dict[str, float]


class SingleMSEDistillationPolicy(AbstractDistillationPolicy):
    def forward(self, teacher_output, student_output, batch, epoch) -> Tuple[torch.Tensor, LossDictT]:
        loss_mse = nn.functional.mse_loss(student_output, teacher_output)
        loss_dict = {'mse': loss_mse.item()}
        return loss_mse, loss_dict

In [17]:
student = Student(num_features, num_classes)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

policy = SingleMSEDistillationPolicy()

distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, val_loader, n_epochs=n_epochs)

[0th epoch]: 100%|██████████| 12/12 [00:00<00:00, 438.12it/s, loss=68.2]
Validation: 100%|██████████| 4/4 [00:00<00:00, 499.99it/s, loss=62.1]
[1th epoch]: 100%|██████████| 12/12 [00:00<00:00, 421.76it/s, loss=61.9]
Validation: 100%|██████████| 4/4 [00:00<00:00, 543.69it/s, loss=56.6]
[2th epoch]: 100%|██████████| 12/12 [00:00<00:00, 444.70it/s, loss=56.3]
Validation: 100%|██████████| 4/4 [00:00<00:00, 493.27it/s, loss=51.7]
[3th epoch]: 100%|██████████| 12/12 [00:00<00:00, 428.10it/s, loss=51.2]
Validation: 100%|██████████| 4/4 [00:00<00:00, 520.01it/s, loss=47.2]
[4th epoch]: 100%|██████████| 12/12 [00:00<00:00, 459.31it/s, loss=46.6]
Validation: 100%|██████████| 4/4 [00:00<00:00, 510.07it/s, loss=43.2]
[5th epoch]: 100%|██████████| 12/12 [00:00<00:00, 435.33it/s, loss=42.5]
Validation: 100%|██████████| 4/4 [00:00<00:00, 542.55it/s, loss=39.6]
[6th epoch]: 100%|██████████| 12/12 [00:00<00:00, 424.36it/s, loss=38.8]
Validation: 100%|██████████| 4/4 [00:00<00:00, 530.00it/s, loss=36.4]

In [18]:
evaluate(student, val_loader)

              precision    recall  f1-score   support

           0       0.98      1.00      0.99        57
           1       1.00      1.00      1.00        55
           2       0.98      1.00      0.99        45
           3       1.00      0.98      0.99        47
           4       1.00      0.98      0.99        46

    accuracy                           0.99       250
   macro avg       0.99      0.99      0.99       250
weighted avg       0.99      0.99      0.99       250



## Advanced policy
**CE + MSE with scale decay**

In [19]:
from math import exp

In [20]:
LossDictT = Dict[str, float]


class AdvancedDistillationPolicy(AbstractDistillationPolicy):
    def __init__(self, n_epochs: int, adapter_mapping: Optional[Tuple[int, int]] = None):
        super().__init__()
        self.n_epochs = n_epochs
    
    def forward(self, teacher_output, student_output, batch, epoch) -> Tuple[torch.Tensor, LossDictT]:
        alpha = math.exp((epoch + 1) / self.n_epochs)
        loss_mse = nn.functional.mse_loss(student_output, teacher_output)
        loss_ce = nn.functional.cross_entropy(student_output, batch[1])
        overall = loss_mse * alpha + loss_ce * (1 - alpha)
        scalars_dict = {
            'mse': loss_mse.item(),
            'cross_entropy': loss_ce.item(),
            'overall': overall.item(),
            'alpha': alpha,
        }
        return overall, scalars_dict

In [21]:
student = Student(num_features, num_classes)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

policy = AdvancedDistillationPolicy(n_epochs)

distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, val_loader, n_epochs=n_epochs)

[0th epoch]: 100%|██████████| 12/12 [00:00<00:00, 350.09it/s, loss=74.2]
Validation: 100%|██████████| 4/4 [00:00<00:00, 350.85it/s, loss=75]
[1th epoch]: 100%|██████████| 12/12 [00:00<00:00, 355.71it/s, loss=68]
Validation: 100%|██████████| 4/4 [00:00<00:00, 409.25it/s, loss=68.8]
[2th epoch]: 100%|██████████| 12/12 [00:00<00:00, 350.06it/s, loss=62.3]
Validation: 100%|██████████| 4/4 [00:00<00:00, 401.65it/s, loss=63.1]
[3th epoch]: 100%|██████████| 12/12 [00:00<00:00, 375.07it/s, loss=57.1]
Validation: 100%|██████████| 4/4 [00:00<00:00, 462.42it/s, loss=57.9]
[4th epoch]: 100%|██████████| 12/12 [00:00<00:00, 387.04it/s, loss=52.4]
Validation: 100%|██████████| 4/4 [00:00<00:00, 421.20it/s, loss=53.1]
[5th epoch]: 100%|██████████| 12/12 [00:00<00:00, 371.14it/s, loss=48.1]
Validation: 100%|██████████| 4/4 [00:00<00:00, 438.20it/s, loss=48.8]
[6th epoch]: 100%|██████████| 12/12 [00:00<00:00, 359.00it/s, loss=44.1]
Validation: 100%|██████████| 4/4 [00:00<00:00, 395.96it/s, loss=44.8]
[7t

In [22]:
evaluate(student, val_loader)

              precision    recall  f1-score   support

           0       0.98      1.00      0.99        57
           1       1.00      1.00      1.00        55
           2       0.98      1.00      0.99        45
           3       1.00      0.98      0.99        47
           4       1.00      0.98      0.99        46

    accuracy                           0.99       250
   macro avg       0.99      0.99      0.99       250
weighted avg       0.99      0.99      0.99       250

