In [1]:
import math
from typing import Dict, Optional, Tuple

import torch
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm

from squeezer.criterion import distill_loss
from squeezer.distiller import Distiller
from squeezer.policy import AbstractDistillationPolicy

%load_ext autoreload
%autoreload 2

In [2]:
torch.manual_seed(0xDEAD)

<torch._C.Generator at 0x7fd11587b390>

In [3]:
n_epochs = 200

In [4]:
def train(model, loader, n_epochs: int = 200):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    for epoch in tqdm(range(n_epochs)):
        for i, (data, labels) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

In [5]:
def evaluate(model, loader):
    preds = []
    targets = []
    for data, labels in loader:
        outputs = model(data).argmax(-1)
        preds.append(outputs)
        targets.append(labels)
    preds = torch.cat(preds)
    targets = torch.cat(targets)
    print(classification_report(targets, preds, zero_division=0))

# Models
Объявляем модель-учитель побольше и модель-ученик поменьше.  
Тип возвращаемого значения должен наследоваться от класса `ModelOutput` (или быть им).

In [6]:
class Teacher(nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, inputs):
        logits = self.network(inputs)
        return logits

In [7]:
class Student(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.network = nn.Linear(input_size, output_size)

    def forward(self, inputs):
        logits = self.network(inputs)
        return logits

# Data

In [8]:
def get_loaders(num_features: int = 64, num_classes: int = 4,
                batch_size: int = 64, train_size: float = 0.75):
    x, y = make_classification(
        1000, num_features,
        n_classes=num_classes,
        n_informative=int(num_features * 0.9),
        n_clusters_per_class=2,
        class_sep=4.0,
        random_state=0xDEAD
    )
    dataset = TensorDataset(
        torch.from_numpy(x).float(),
        torch.from_numpy(y).long()
    )
    dataset_length = len(x)
    train_size = int(dataset_length * train_size)
    val_size = dataset_length - train_size
    train, val = random_split(dataset, [train_size, val_size])
    return DataLoader(train, batch_size=batch_size), DataLoader(val, batch_size=batch_size)

In [9]:
num_features = 128
num_classes = 5

train_loader, val_loader = get_loaders(num_features, num_classes)

# Train Teacher model

In [10]:
teacher = Teacher(num_features, num_classes, hidden_size=128)

train(teacher, train_loader, n_epochs=n_epochs)
evaluate(teacher, val_loader)

100%|██████████| 200/200 [00:03<00:00, 56.05it/s]

              precision    recall  f1-score   support

           0       0.98      1.00      0.99        57
           1       1.00      1.00      1.00        55
           2       0.98      1.00      0.99        45
           3       1.00      0.98      0.99        47
           4       1.00      0.98      0.99        46

    accuracy                           0.99       250
   macro avg       0.99      0.99      0.99       250
weighted avg       0.99      0.99      0.99       250






# Train Student model without distillation

In [11]:
student = Student(num_features, num_classes)
train(student, train_loader, n_epochs=n_epochs)
evaluate(student, val_loader)

100%|██████████| 200/200 [00:01<00:00, 133.73it/s]

              precision    recall  f1-score   support

           0       0.97      0.98      0.97        57
           1       0.96      0.96      0.96        55
           2       0.93      0.96      0.95        45
           3       0.98      0.94      0.96        47
           4       0.98      0.98      0.98        46

    accuracy                           0.96       250
   macro avg       0.96      0.96      0.96       250
weighted avg       0.96      0.96      0.96       250






# Distiller

In [12]:
class CustomDistiller(Distiller):
    def teacher_forward(self, batch):
        return self.teacher(batch[0])
    
    def student_forward(self, batch):
        return self.student(batch[0])

# Distillation

## Basic distillation policy

In [13]:
LossDictT = Dict[str, float]


class BasicDistillationPolicy(AbstractDistillationPolicy):
    def __init__(self, temperature: float = 1.0, alpha: float = 0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, teacher_output, student_output, batch, epoch: int) -> Tuple[torch.Tensor, LossDictT]:
        loss_kld, loss_ce, overall = distill_loss(
            teacher_logits=teacher_output,
            student_logits=student_output,
            labels=batch[1],
            temperature=self.temperature,
            alpha=self.alpha
        )
        loss_dict = {
            'kld': loss_kld.item(),
            'cross_entropy': loss_ce.item(),
            'overall': overall.item(),
        }
        return overall, loss_dict

In [14]:
student = Student(num_features, num_classes)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

# Инициализируем политику функции потерь для дистилляции.
# В этом примере используется стандартная политика, при которой
# ученик учится сразу на распределение логитов учителя (KLD) и на мейнстрим задачу (CE).
# Для кастомизации политики, например, для использования других функций потерь или
# добавления адаптеров между аутпутами моделей, необходимо наследоваться от класса `AbstractDistillationPolicy`
policy = BasicDistillationPolicy(temperature=1.2, alpha=0.5)

distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, val_loader, n_epochs=n_epochs)

[0th epoch]: 100%|██████████| 12/12 [00:00<00:00, 380.05it/s, batch loss=5.15]
[1th epoch]: 100%|██████████| 12/12 [00:00<00:00, 366.23it/s, batch loss=4.29]
[2th epoch]: 100%|██████████| 12/12 [00:00<00:00, 351.01it/s, batch loss=3.52]
[3th epoch]: 100%|██████████| 12/12 [00:00<00:00, 408.62it/s, batch loss=2.88]
[4th epoch]: 100%|██████████| 12/12 [00:00<00:00, 386.05it/s, batch loss=2.35]
[5th epoch]: 100%|██████████| 12/12 [00:00<00:00, 415.31it/s, batch loss=1.93]
[6th epoch]: 100%|██████████| 12/12 [00:00<00:00, 380.61it/s, batch loss=1.61]
[7th epoch]: 100%|██████████| 12/12 [00:00<00:00, 409.88it/s, batch loss=1.35]
[8th epoch]: 100%|██████████| 12/12 [00:00<00:00, 383.77it/s, batch loss=1.14]
[9th epoch]: 100%|██████████| 12/12 [00:00<00:00, 422.92it/s, batch loss=0.971]
[10th epoch]: 100%|██████████| 12/12 [00:00<00:00, 361.09it/s, batch loss=0.828]
[11th epoch]: 100%|██████████| 12/12 [00:00<00:00, 353.03it/s, batch loss=0.707]
[12th epoch]: 100%|██████████| 12/12 [00:00<00:

In [15]:
evaluate(student, val_loader)

              precision    recall  f1-score   support

           0       0.93      1.00      0.97        57
           1       1.00      0.96      0.98        55
           2       0.96      1.00      0.98        45
           3       1.00      0.89      0.94        47
           4       0.96      0.98      0.97        46

    accuracy                           0.97       250
   macro avg       0.97      0.97      0.97       250
weighted avg       0.97      0.97      0.97       250



## Custom policy (MSE)

In [16]:
LossDictT = Dict[str, float]


class SingleMSEDistillationPolicy(AbstractDistillationPolicy):
    def forward(self, teacher_output, student_output, batch, epoch) -> Tuple[torch.Tensor, LossDictT]:
        loss_mse = nn.functional.mse_loss(student_output, teacher_output)
        loss_dict = {'mse': loss_mse.item()}
        return loss_mse, loss_dict

In [23]:
student = Student(num_features, num_classes)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

policy = SingleMSEDistillationPolicy()

distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, val_loader, n_epochs=n_epochs)

[0th epoch]: 100%|██████████| 12/12 [00:00<00:00, 310.41it/s, loss=78.7]
[1th epoch]: 100%|██████████| 12/12 [00:00<00:00, 285.68it/s, loss=71.5]
[2th epoch]: 100%|██████████| 12/12 [00:00<00:00, 357.37it/s, loss=65]
[3th epoch]: 100%|██████████| 12/12 [00:00<00:00, 357.59it/s, loss=59]
[4th epoch]: 100%|██████████| 12/12 [00:00<00:00, 359.22it/s, loss=53.6]
[5th epoch]: 100%|██████████| 12/12 [00:00<00:00, 224.45it/s, loss=48.8]
[6th epoch]: 100%|██████████| 12/12 [00:00<00:00, 401.40it/s, loss=44.4]
[7th epoch]: 100%|██████████| 12/12 [00:00<00:00, 341.66it/s, loss=40.5]
[8th epoch]: 100%|██████████| 12/12 [00:00<00:00, 341.18it/s, loss=36.9]
[9th epoch]: 100%|██████████| 12/12 [00:00<00:00, 317.45it/s, loss=33.8]
[10th epoch]: 100%|██████████| 12/12 [00:00<00:00, 354.54it/s, loss=30.9]
[11th epoch]: 100%|██████████| 12/12 [00:00<00:00, 331.88it/s, loss=28.4]
[12th epoch]: 100%|██████████| 12/12 [00:00<00:00, 355.72it/s, loss=26.1]
[13th epoch]: 100%|██████████| 12/12 [00:00<00:00, 3

In [18]:
evaluate(student, val_loader)

              precision    recall  f1-score   support

           0       0.98      1.00      0.99        57
           1       1.00      1.00      1.00        55
           2       0.98      1.00      0.99        45
           3       1.00      0.98      0.99        47
           4       1.00      0.98      0.99        46

    accuracy                           0.99       250
   macro avg       0.99      0.99      0.99       250
weighted avg       0.99      0.99      0.99       250



## Advanced policy
**CE + MSE with scale decay**

In [19]:
from math import exp

In [20]:
LossDictT = Dict[str, float]


class AdvancedDistillationPolicy(AbstractDistillationPolicy):
    def __init__(self, n_epochs: int, adapter_mapping: Optional[Tuple[int, int]] = None):
        super().__init__()
        self.n_epochs = n_epochs
    
    def forward(self, teacher_output, student_output, batch, epoch) -> Tuple[torch.Tensor, LossDictT]:
        alpha = math.exp((epoch + 1) / self.n_epochs)
        loss_mse = nn.functional.mse_loss(student_output, teacher_output)
        loss_ce = nn.functional.cross_entropy(student_output, batch[1])
        overall = loss_mse * alpha + loss_ce * (1 - alpha)
        scalars_dict = {
            'mse': loss_mse.item(),
            'cross_entropy': loss_ce.item(),
            'overall': overall.item(),
            'alpha': alpha,
        }
        return overall, scalars_dict

In [21]:
student = Student(num_features, num_classes)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

policy = AdvancedDistillationPolicy(n_epochs)

distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, val_loader, n_epochs=n_epochs)

[0th epoch]: 100%|██████████| 12/12 [00:00<00:00, 428.17it/s, batch loss=73.5]
[1th epoch]: 100%|██████████| 12/12 [00:00<00:00, 411.34it/s, batch loss=67.3]
[2th epoch]: 100%|██████████| 12/12 [00:00<00:00, 446.28it/s, batch loss=61.7]
[3th epoch]: 100%|██████████| 12/12 [00:00<00:00, 343.67it/s, batch loss=56.4]
[4th epoch]: 100%|██████████| 12/12 [00:00<00:00, 358.42it/s, batch loss=51.7]
[5th epoch]: 100%|██████████| 12/12 [00:00<00:00, 397.20it/s, batch loss=47.4]
[6th epoch]: 100%|██████████| 12/12 [00:00<00:00, 397.02it/s, batch loss=43.5]
[7th epoch]: 100%|██████████| 12/12 [00:00<00:00, 380.60it/s, batch loss=40]
[8th epoch]: 100%|██████████| 12/12 [00:00<00:00, 335.31it/s, batch loss=36.8]
[9th epoch]: 100%|██████████| 12/12 [00:00<00:00, 381.54it/s, batch loss=33.9]
[10th epoch]: 100%|██████████| 12/12 [00:00<00:00, 407.07it/s, batch loss=31.3]
[11th epoch]: 100%|██████████| 12/12 [00:00<00:00, 359.76it/s, batch loss=29]
[12th epoch]: 100%|██████████| 12/12 [00:00<00:00, 410

In [22]:
evaluate(student, val_loader)

              precision    recall  f1-score   support

           0       0.98      1.00      0.99        57
           1       1.00      1.00      1.00        55
           2       0.98      1.00      0.99        45
           3       1.00      0.98      0.99        47
           4       1.00      0.98      0.99        46

    accuracy                           0.99       250
   macro avg       0.99      0.99      0.99       250
weighted avg       0.99      0.99      0.99       250

