In [None]:
import math
from typing import Dict, Optional, Tuple

import torch
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm

from squeezer.criterion import distill_loss
from squeezer.distiller import Distiller
from squeezer.policy import AbstractDistillationPolicy

%load_ext autoreload
%autoreload 2

In [None]:
# Seed torch's global RNG: layer initialisation and the `random_split` in
# `get_loaders` draw from it, so a fresh Restart & Run All is reproducible.
# The trailing ';' hides the returned Generator repr (its address changes every run).
SEED = 0xDEAD
torch.manual_seed(SEED);

In [None]:
# Number of passes over the training set, shared by every training run below
# (teacher, baseline student and each distillation policy).
n_epochs: int = 200

In [None]:
def train(model, loader, n_epochs: int = 200, lr: float = 3e-4):
    """Train ``model`` with plain cross-entropy (no distillation).

    Used for the teacher and for the baseline student. A new Adam optimizer is
    created on every call, so each call starts from fresh optimizer state.

    Args:
        model: classifier whose output is fed to ``nn.CrossEntropyLoss``
            (i.e. unnormalised logits).
        loader: iterable of ``(data, labels)`` batches.
        n_epochs: number of full passes over ``loader``.
        lr: Adam learning rate; defaults to the previously hard-coded 3e-4.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Make sure train-mode behaviour (dropout/batch-norm, if any) is active.
    model.train()
    for _ in tqdm(range(n_epochs)):
        for data, labels in loader:
            optimizer.zero_grad()
            loss = criterion(model(data), labels)
            loss.backward()
            optimizer.step()

In [None]:
def evaluate(model, loader):
    """Print an sklearn classification report for ``model`` on ``loader``.

    Predictions are the arg-max over the last dimension of the model output.
    Inference runs in eval mode without autograd, and the model's previous
    train/eval mode is restored afterwards.

    Args:
        model: classifier returning per-class scores along the last dimension.
        loader: iterable of ``(data, labels)`` batches.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    preds, targets = [], []
    try:
        # No gradients are needed for evaluation; avoid building a graph.
        with torch.no_grad():
            for data, labels in loader:
                preds.append(model(data).argmax(-1))
                targets.append(labels)
    finally:
        model.train(was_training)
    if not preds:
        raise ValueError("evaluate(): loader yielded no batches")
    # Hand sklearn plain numpy arrays rather than torch tensors.
    y_pred = torch.cat(preds).cpu().numpy()
    y_true = torch.cat(targets).cpu().numpy()
    print(classification_report(y_true, y_pred, zero_division=0))

# Models
Объявляем модель-учитель побольше и модель-ученик поменьше.  
Тип возвращаемого значения должен наследоваться от класса `ModelOutput` (или быть им).

In [None]:
class Teacher(nn.Module):
    """The larger "teacher" model: an MLP with three ReLU hidden layers.

    Every hidden layer has width ``hidden_size``; ``forward`` returns the raw
    output of the last linear layer (logits of size ``output_size``).
    """

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        widths = [input_size, hidden_size, hidden_size, hidden_size]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(hidden_size, output_size))
        # Linear layers sit at Sequential indices 0, 2, 4, 6 and are created in
        # the same order, so initialisation and state_dict keys are unchanged.
        self.network = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.network(inputs)

In [None]:
class Student(nn.Module):
    """The small "student" model: one linear layer from features to logits.

    The layer is stored as ``network``, so its parameters are
    ``network.weight`` and ``network.bias``.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.network = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, inputs):
        return self.network(inputs)

# Data

In [None]:
def get_loaders(num_features: int = 64, num_classes: int = 4,
                batch_size: int = 64, train_size: float = 0.75,
                n_samples: int = 1000, shuffle_train: bool = True):
    """Build train/validation DataLoaders over a synthetic classification task.

    The data comes from sklearn's ``make_classification`` with a fixed
    ``random_state``. The train/val split uses ``random_split`` without an
    explicit generator, i.e. torch's global RNG.

    Args:
        num_features: total number of input features.
        num_classes: number of target classes.
        batch_size: batch size for both loaders.
        train_size: fraction of samples placed in the training split.
        n_samples: total number of generated samples (previously fixed at 1000).
        shuffle_train: reshuffle the training split on every pass over the
            loader. The validation loader is never shuffled.

    Returns:
        ``(train_loader, val_loader)``.
    """
    # make_classification requires n_informative + n_redundant (default 2)
    # <= n_features. Cap the 90% heuristic so small feature counts (e.g. 10)
    # don't raise; for the sizes used in this notebook the value is unchanged.
    n_informative = min(int(num_features * 0.9), num_features - 2)
    x, y = make_classification(
        n_samples, num_features,
        n_classes=num_classes,
        n_informative=n_informative,
        n_clusters_per_class=2,
        class_sep=4.0,
        random_state=0xDEAD,
    )
    dataset = TensorDataset(
        torch.from_numpy(x).float(),
        torch.from_numpy(y).long(),
    )
    # Separate names so the `train_size` fraction is not shadowed by a count.
    n_train = int(len(dataset) * train_size)
    n_val = len(dataset) - n_train
    train_split, val_split = random_split(dataset, [n_train, n_val])
    return (
        DataLoader(train_split, batch_size=batch_size, shuffle=shuffle_train),
        DataLoader(val_split, batch_size=batch_size),
    )

In [None]:
# Size of the synthetic task: input dimensionality and number of classes.
num_features, num_classes = 128, 5

train_loader, val_loader = get_loaders(num_features=num_features, num_classes=num_classes)

# Train Teacher model

In [None]:
# Fit the larger teacher with plain cross-entropy, then report validation metrics.
teacher_hidden_size = 128
teacher = Teacher(num_features, num_classes, hidden_size=teacher_hidden_size)
train(teacher, train_loader, n_epochs=n_epochs)
evaluate(teacher, val_loader)

# Train Student model without distillation

In [None]:
# Baseline: the same small architecture trained with plain cross-entropy only.
# A dedicated name keeps this model available for comparison after the
# distillation cells below rebind `student`.
baseline_student = Student(num_features, num_classes)
train(baseline_student, train_loader, n_epochs=n_epochs)
evaluate(baseline_student, val_loader)

# Distiller

In [None]:
class CustomDistiller(Distiller):
    """Distiller adapted to this notebook's ``(features, labels)`` batches.

    Both models receive only ``batch[0]`` (the feature tensor). The labels in
    ``batch[1]`` are left for the distillation policy to use.
    """

    def teacher_forward(self, batch):
        # Only the student's parameters are handed to the optimizer in this
        # notebook, so the teacher needs no autograd graph.
        with torch.no_grad():
            return self.teacher(batch[0])

    def student_forward(self, batch):
        return self.student(batch[0])

# Distillation

## Basic distillation policy

In [None]:
LossDictT = Dict[str, float]


class BasicDistillationPolicy(AbstractDistillationPolicy):
    """Teacher-KLD + cross-entropy policy built on ``squeezer``'s ``distill_loss``.

    Args:
        temperature: passed through to ``distill_loss`` (presumably the softmax
            temperature applied to the logits — confirm in squeezer).
        alpha: passed through to ``distill_loss`` as the mixing weight.
    """

    def __init__(self, temperature: float = 1.0, alpha: float = 0.5):
        super().__init__()
        self.temperature, self.alpha = temperature, alpha

    def forward(self, teacher_output, student_output, batch, epoch: int) -> Tuple[torch.Tensor, LossDictT]:
        """Return the combined loss tensor and a dict of scalar components.

        ``batch[1]`` is used as the labels; ``epoch`` is not used by this policy.
        """
        losses = distill_loss(
            teacher_logits=teacher_output,
            student_logits=student_output,
            labels=batch[1],
            temperature=self.temperature,
            alpha=self.alpha,
        )
        loss_kld, loss_ce, overall = losses
        scalars = {
            'kld': loss_kld.item(),
            'cross_entropy': loss_ce.item(),
            'overall': overall.item(),
        }
        return overall, scalars

In [None]:
# Fresh student for distillation; only its parameters are optimised.
student = Student(num_features, num_classes)
optimizer = torch.optim.AdamW(params=student.parameters(), lr=3e-4)

# Loss policy for distillation. This example uses the standard setup, in which
# the student learns both from the teacher's logit distribution (KLD) and from
# the main task (CE) at the same time. To customise the policy — e.g. to use
# other loss functions or add adapters between the models' outputs — subclass
# `AbstractDistillationPolicy`.
policy = BasicDistillationPolicy(alpha=0.5, temperature=1.2)

distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, val_loader, n_epochs=n_epochs)

In [None]:
# Validation report for the student distilled with the basic (KLD + CE) policy.
evaluate(model=student, loader=val_loader)

## Custom policy (MSE)

In [None]:
LossDictT = Dict[str, float]


class SingleMSEDistillationPolicy(AbstractDistillationPolicy):
    """Pure logit-matching policy: MSE between student and teacher outputs.

    ``batch`` (including its labels) and ``epoch`` are ignored, so the student
    learns only from the teacher.
    """

    def forward(self, teacher_output, student_output, batch, epoch) -> Tuple[torch.Tensor, LossDictT]:
        # Treat the teacher output as a fixed regression target. Detaching keeps
        # gradients from flowing back into the teacher, which is not optimised.
        loss_mse = nn.functional.mse_loss(student_output, teacher_output.detach())
        loss_dict = {'mse': loss_mse.item()}
        return loss_mse, loss_dict

In [None]:
# Fresh student, this time distilled with the pure MSE-to-teacher policy.
student = Student(num_features, num_classes)
optimizer = torch.optim.AdamW(params=student.parameters(), lr=3e-4)
policy = SingleMSEDistillationPolicy()
distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, val_loader, n_epochs=n_epochs)

In [None]:
# Validation report for the student distilled with the MSE policy.
evaluate(model=student, loader=val_loader)

## Advanced policy
**CE + MSE with scale decay**

In [None]:
from math import exp

In [None]:
LossDictT = Dict[str, float]


class AdvancedDistillationPolicy(AbstractDistillationPolicy):
    """MSE-to-teacher + cross-entropy, with a decaying weight on the MSE term.

    The MSE weight is ``alpha = exp(-(epoch + 1) / n_epochs)`` and the
    cross-entropy weight is ``1 - alpha``, so both stay within [0, 1].
    Assuming ``epoch`` is 0-based and runs to ``n_epochs - 1``, alpha decays
    from roughly 1 down to ``exp(-1) ≈ 0.37``. This shifts the emphasis from
    imitating the teacher towards the ground-truth labels.

    Args:
        n_epochs: total number of training epochs; sets the decay rate.
        adapter_mapping: accepted for interface compatibility; currently ignored.

    Raises:
        ValueError: if ``n_epochs`` is not positive.
    """

    def __init__(self, n_epochs: int, adapter_mapping: Optional[Tuple[int, int]] = None):
        super().__init__()
        if n_epochs <= 0:
            raise ValueError(f"n_epochs must be positive, got {n_epochs}")
        self.n_epochs = n_epochs

    def forward(self, teacher_output, student_output, batch, epoch) -> Tuple[torch.Tensor, LossDictT]:
        # Negative exponent: with exp(+x) alpha would exceed 1, making (1 - alpha)
        # negative so that minimising `overall` would *increase* cross-entropy.
        alpha = math.exp(-(epoch + 1) / self.n_epochs)
        # The teacher output is a fixed target; don't backpropagate into it.
        loss_mse = nn.functional.mse_loss(student_output, teacher_output.detach())
        loss_ce = nn.functional.cross_entropy(student_output, batch[1])
        overall = loss_mse * alpha + loss_ce * (1 - alpha)
        scalars_dict = {
            'mse': loss_mse.item(),
            'cross_entropy': loss_ce.item(),
            'overall': overall.item(),
            'alpha': alpha,
        }
        return overall, scalars_dict

In [None]:
# Fresh student distilled with the MSE + CE policy whose MSE weight changes per epoch.
student = Student(num_features, num_classes)
optimizer = torch.optim.AdamW(params=student.parameters(), lr=3e-4)
policy = AdvancedDistillationPolicy(n_epochs=n_epochs)
distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, val_loader, n_epochs=n_epochs)

In [None]:
# Validation report for the student distilled with the advanced (MSE + CE) policy.
evaluate(model=student, loader=val_loader)