In [1]:
from typing import Dict, Optional, Tuple

import torch
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm

from squeezer.criterion import distill_loss
from squeezer.distiller import Distiller
from squeezer.policy import AbstractDistillationPolicy

%load_ext autoreload
%autoreload 2

In [2]:
torch.manual_seed(0xDEAD)

<torch._C.Generator at 0x7ff5e417b370>

In [3]:
n_epochs = 200

In [4]:
def train(model, loader, n_epochs: int = 200):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    for epoch in tqdm(range(n_epochs)):
        for i, (data, labels) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

In [5]:
def evaluate(model, loader):
    preds = []
    targets = []
    for data, labels in loader:
        outputs = model(data).argmax(-1)
        preds.append(outputs)
        targets.append(labels)
    preds = torch.cat(preds)
    targets = torch.cat(targets)
    print(classification_report(targets, preds, zero_division=0))

# Models
Объявляем модель-учитель побольше и модель-ученик поменьше.  
Тип возвращаемого значения должен наследоваться от класса `ModelOutput` (или быть им).

In [6]:
class Teacher(nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, inputs):
        logits = self.network(inputs)
        return logits

In [7]:
class Student(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.network = nn.Linear(input_size, output_size)

    def forward(self, inputs):
        logits = self.network(inputs)
        return logits

# Data

In [8]:
def get_loaders(num_features: int = 64, num_classes: int = 4,
                batch_size: int = 64, train_size: float = 0.75):
    x, y = make_classification(
        1000, num_features,
        n_classes=num_classes,
        n_informative=int(num_features * 0.9),
        n_clusters_per_class=2,
        class_sep=4.0,
        random_state=0xDEAD
    )
    dataset = TensorDataset(
        torch.from_numpy(x).float(),
        torch.from_numpy(y).long()
    )
    dataset_length = len(x)
    train_size = int(dataset_length * train_size)
    val_size = dataset_length - train_size
    train, val = random_split(dataset, [train_size, val_size])
    return DataLoader(train, batch_size=batch_size), DataLoader(val, batch_size=batch_size)

In [9]:
num_features = 128
num_classes = 5

train_loader, val_loader = get_loaders(num_features, num_classes)

# Train Teacher model

In [10]:
teacher = Teacher(num_features, num_classes, hidden_size=128)

train(teacher, train_loader, n_epochs=n_epochs)
evaluate(teacher, val_loader)

100%|██████████| 200/200 [00:03<00:00, 52.44it/s]

              precision    recall  f1-score   support

           0       0.98      1.00      0.99        57
           1       1.00      1.00      1.00        55
           2       0.98      1.00      0.99        45
           3       1.00      0.98      0.99        47
           4       1.00      0.98      0.99        46

    accuracy                           0.99       250
   macro avg       0.99      0.99      0.99       250
weighted avg       0.99      0.99      0.99       250






# Train Student model without distillation

In [11]:
student = Student(num_features, num_classes)
train(student, train_loader, n_epochs=n_epochs)
evaluate(student, val_loader)

100%|██████████| 200/200 [00:01<00:00, 123.67it/s]

              precision    recall  f1-score   support

           0       0.97      0.98      0.97        57
           1       0.96      0.96      0.96        55
           2       0.93      0.96      0.95        45
           3       0.98      0.94      0.96        47
           4       0.98      0.98      0.98        46

    accuracy                           0.96       250
   macro avg       0.96      0.96      0.96       250
weighted avg       0.96      0.96      0.96       250






# Distiller

In [12]:
class CustomDistiller(Distiller):
    def teacher_forward(self, batch):
        return self.teacher(batch[0])
    
    def student_forward(self, batch):
        return self.student(batch[0])

# Distillation

## Basic distillation policy

In [13]:
LossDictT = Dict[str, float]


class BasicDistillationPolicy(AbstractDistillationPolicy):
    def __init__(self, temperature: float = 1.0, alpha: float = 0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, teacher_output, student_output, batch, epoch: int) -> Tuple[torch.Tensor, LossDictT]:
        loss_kld, loss_ce, overall = distill_loss(
            teacher_logits=teacher_output,
            student_logits=student_output,
            labels=batch[1],
            temperature=self.temperature,
            alpha=self.alpha
        )
        loss_dict = {
            'kld': loss_kld.item(),
            'cross_entropy': loss_ce.item(),
            'overall': overall.item(),
        }
        return overall, loss_dict

In [14]:
student = Student(num_features, num_classes)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

# Инициализируем политику функции потерь для дистилляции.
# В этом примере используется стандартная политика, при которой
# ученик учится сразу на распределение логитов учителя (KLD) и на мейнстрим задачу (CE).
# Для кастомизации политики, например, для использования других функций потерь или
# добавления адаптеров между аутпутами моделей, необходимо наследоваться от класса `AbstractDistillationPolicy`
policy = BasicDistillationPolicy(temperature=1.2, alpha=0.5)

distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, n_epochs=n_epochs)

Epoch 0: kld=0.53002   cross_entropy=0.43194   overall=0.48098
Epoch 1: kld=0.44321   cross_entropy=0.35951   overall=0.40136
Epoch 2: kld=0.36695   cross_entropy=0.29611   overall=0.33153
Epoch 3: kld=0.30198   cross_entropy=0.24205   overall=0.27202
Epoch 4: kld=0.24861   cross_entropy=0.19781   overall=0.22321
Epoch 5: kld=0.20615   cross_entropy=0.16302   overall=0.18458
Epoch 6: kld=0.17253   cross_entropy=0.13570   overall=0.15412
Epoch 7: kld=0.14562   cross_entropy=0.11391   overall=0.12976
Epoch 8: kld=0.12375   cross_entropy=0.09623   overall=0.10999
Epoch 9: kld=0.10569   cross_entropy=0.08162   overall=0.09366
Epoch 10: kld=0.09064   cross_entropy=0.06940   overall=0.08002
Epoch 11: kld=0.07801   cross_entropy=0.05916   overall=0.06858
Epoch 12: kld=0.06741   cross_entropy=0.05057   overall=0.05899
Epoch 13: kld=0.05846   cross_entropy=0.04336   overall=0.05091
Epoch 14: kld=0.05089   cross_entropy=0.03728   overall=0.04408
Epoch 15: kld=0.04443   cross_entropy=0.03212   ov

In [15]:
evaluate(student, val_loader)

              precision    recall  f1-score   support

           0       0.93      1.00      0.97        57
           1       1.00      0.96      0.98        55
           2       0.96      1.00      0.98        45
           3       1.00      0.89      0.94        47
           4       0.96      0.98      0.97        46

    accuracy                           0.97       250
   macro avg       0.97      0.97      0.97       250
weighted avg       0.97      0.97      0.97       250



## Custom policy (MSE)

In [16]:
LossDictT = Dict[str, float]


class SingleMSEDistillationPolicy(AbstractDistillationPolicy):
    def forward(self, teacher_output, student_output, batch, epoch) -> Tuple[torch.Tensor, LossDictT]:
        loss_mse = nn.functional.mse_loss(student_output, teacher_output)
        loss_dict = {'mse': loss_mse.item()}
        return loss_mse, loss_dict

In [17]:
student = Student(num_features, num_classes)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

policy = SingleMSEDistillationPolicy()

distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, n_epochs=n_epochs)

Epoch 0: mse=6.79457                                     
Epoch 1: mse=6.16706                                     
Epoch 2: mse=5.59082                                     
Epoch 3: mse=5.06709                                     
Epoch 4: mse=4.59360                                     
Epoch 5: mse=4.16680                                     
Epoch 6: mse=3.78282                                     
Epoch 7: mse=3.43788                                     
Epoch 8: mse=3.12836                                     
Epoch 9: mse=2.85094                                     
Epoch 10: mse=2.60250                                     
Epoch 11: mse=2.38020                                     
Epoch 12: mse=2.18141                                     
Epoch 13: mse=2.00375                                     
Epoch 14: mse=1.84501                                     
Epoch 15: mse=1.70321                                     
Epoch 16: mse=1.57654                                     
Epoch 1

In [18]:
evaluate(student, val_loader)

              precision    recall  f1-score   support

           0       0.98      1.00      0.99        57
           1       1.00      1.00      1.00        55
           2       0.98      1.00      0.99        45
           3       1.00      0.98      0.99        47
           4       1.00      0.98      0.99        46

    accuracy                           0.99       250
   macro avg       0.99      0.99      0.99       250
weighted avg       0.99      0.99      0.99       250



## Advanced policy
**CE + MSE with scale decay**

In [19]:
from math import exp

In [20]:
LossDictT = Dict[str, float]


class AdvancedDistillationPolicy(AbstractDistillationPolicy):
    def __init__(self, n_epochs: int, adapter_mapping: Optional[Tuple[int, int]] = None):
        super().__init__()
        self.n_epochs = n_epochs
    
    def forward(self, teacher_output, student_output, batch, epoch) -> Tuple[torch.Tensor, LossDictT]:
        alpha = exp((epoch + 1) / self.n_epochs)
        loss_mse = nn.functional.mse_loss(student_output, teacher_output)
        loss_ce = nn.functional.cross_entropy(student_output, batch[1])
        overall = loss_mse * alpha + loss_ce * (1 - alpha)
        scalars_dict = {
            'mse': loss_mse.item(),
            'cross_entropy': loss_ce.item(),
            'overall': overall.item(),
            'alpha': alpha,
        }
        return overall, scalars_dict

In [21]:
student = Student(num_features, num_classes)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

policy = AdvancedDistillationPolicy(n_epochs)

distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, n_epochs=n_epochs)

Epoch 0: mse=6.53153   cross_entropy=0.52086   overall=6.56166   alpha=0.09222
Epoch 1: mse=5.90396   cross_entropy=0.43576   overall=5.95892   alpha=0.09269
Epoch 2: mse=5.32849   cross_entropy=0.35798   overall=5.40361   alpha=0.09315
Epoch 3: mse=4.80595   cross_entropy=0.28936   overall=4.89719   alpha=0.09362
Epoch 4: mse=4.33409   cross_entropy=0.23106   overall=4.43796   alpha=0.09409
Epoch 5: mse=3.90946   cross_entropy=0.18301   overall=4.02295   alpha=0.09456
Epoch 6: mse=3.52830   cross_entropy=0.14486   overall=3.64882   alpha=0.09503
Epoch 7: mse=3.18693   cross_entropy=0.11538   overall=3.31229   alpha=0.09551
Epoch 8: mse=2.88185   cross_entropy=0.09211   overall=3.01026   alpha=0.09599
Epoch 9: mse=2.60978   cross_entropy=0.07313   overall=2.73983   alpha=0.09647
Epoch 10: mse=2.36764   cross_entropy=0.05762   overall=2.49825   alpha=0.09695
Epoch 11: mse=2.15257   cross_entropy=0.04525   overall=2.28288   alpha=0.09744
Epoch 12: mse=1.96192   cross_entropy=0.03534   ov

In [22]:
evaluate(student, val_loader)

              precision    recall  f1-score   support

           0       0.98      1.00      0.99        57
           1       1.00      1.00      1.00        55
           2       0.98      1.00      0.99        45
           3       1.00      0.98      0.99        47
           4       1.00      0.98      0.99        46

    accuracy                           0.99       250
   macro avg       0.99      0.99      0.99       250
weighted avg       0.99      0.99      0.99       250

