In [1]:
from typing import Dict, Optional, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from squeezer.criterion import distill_loss
from squeezer.distiller import Distiller
from squeezer.policy import AbstractDistillationPolicy

%load_ext autoreload
%autoreload 2

In [2]:
# Global seed for every stochastic step below (random data, weight init, probe inputs).
SEED = 0xDEAD
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr

<torch._C.Generator at 0x7fa930876350>

# Models
Объявляем модель-учитель побольше и модель-ученик поменьше.  
Тип возвращаемого значения должен наследоваться от класса `ModelOutput` (или быть им).

In [3]:
class Teacher(nn.Module):
    """Larger "teacher" MLP with two hidden ReLU layers of width ``hidden_size``.

    Args:
        input_size: number of input features.
        output_size: number of output logits.
        hidden_size: width of both hidden layers.
    """

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        layer_stack = [
            nn.Linear(input_size, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        ]
        # Same attribute name and layer order, so state-dict keys stay network.0.*, network.2.*, network.4.*
        self.network = nn.Sequential(*layer_stack)

    def forward(self, inputs):
        """Return the raw (unnormalized) logits for ``inputs``."""
        return self.network(inputs)

In [4]:
class Student(nn.Module):
    """Smaller "student" model: a single linear layer from features to logits.

    Args:
        input_size: number of input features.
        output_size: number of output logits.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.network = nn.Linear(input_size, output_size)

    def forward(self, inputs):
        """Return the raw (unnormalized) logits for ``inputs``."""
        projected = self.network(inputs)
        return projected

# Data

In [5]:
def get_loader(length: int = 10000, num_features: int = 20, num_classes: int = 4, batch_size: int = 64):
    """Build a DataLoader over synthetic data: Gaussian features and uniform random labels.

    Labels are drawn independently of the features, so there is nothing to learn from
    the data itself; the loader only supplies ``(features, labels)`` batches for distillation.

    Args:
        length: number of samples.
        num_features: feature dimension of each sample.
        num_classes: labels are integers in ``[0, num_classes)``.
        batch_size: batch size of the returned (unshuffled) loader.
    """
    # Features are sampled before labels; keep this order so seeded runs reproduce.
    features = torch.randn(length, num_features)
    labels = torch.randint(high=num_classes, size=(length,))
    return DataLoader(TensorDataset(features, labels), batch_size=batch_size)

# Distiller

In [6]:
class CustomDistiller(Distiller):
    """Distiller for batches laid out as ``(features, labels)``.

    Both models are fed only ``batch[0]``; the labels in ``batch[1]`` are read by
    the distillation policies in this notebook, not by the models.
    """

    def teacher_forward(self, batch):
        """Run the teacher on the features of ``batch``."""
        features = batch[0]
        return self.teacher(features)

    def student_forward(self, batch):
        """Run the student on the features of ``batch``."""
        features = batch[0]
        return self.student(features)

# Distillation

## Basic distillation policy

In [7]:
LossDictT = Dict[str, float]


class BasicDistillationPolicy(AbstractDistillationPolicy):
    """Classic knowledge-distillation policy built on ``distill_loss``.

    ``distill_loss`` returns a KLD term, a cross-entropy term and their combination.
    How ``temperature`` and ``alpha`` enter that combination is defined by
    ``distill_loss`` itself (not visible here).

    Args:
        temperature: forwarded to ``distill_loss``.
        alpha: forwarded to ``distill_loss``.
    """

    def __init__(self, temperature: float = 1.0, alpha: float = 0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, teacher_output, student_output, batch, epoch: int) -> Tuple[torch.Tensor, LossDictT]:
        """Return ``(loss_to_optimize, scalars_to_log)`` for one batch.

        ``batch[1]`` holds the ground-truth labels; ``epoch`` is unused by this policy.
        """
        loss_kld, loss_ce, overall = distill_loss(
            teacher_logits=teacher_output,
            student_logits=student_output,
            labels=batch[1],
            temperature=self.temperature,
            alpha=self.alpha,
        )
        scalars = dict(
            kld=loss_kld.item(),
            cross_entropy=loss_ce.item(),
            overall=overall.item(),
        )
        return overall, scalars

In [8]:
input_size = 32
num_classes = 4
n_epochs = 50

train_loader = get_loader(num_features=input_size, num_classes=num_classes)
teacher = Teacher(input_size, num_classes, hidden_size=10)
student = Student(input_size, num_classes)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

# Distillation loss policy. This uses the standard KD policy (`distill_loss`):
# a KLD term pulling the student's distribution towards the teacher's, plus a
# cross-entropy (CE) term on the ground-truth labels.
# NOTE: with alpha=1.0 the logged `overall` equals `kld` at every epoch, so the
# CE term is only reported here, not optimized. Presumably alpha is the KLD
# weight inside `distill_loss` (confirm), so pick 0 < alpha < 1 to train on both.
# To customize the policy (other losses, adapters between model outputs, ...),
# subclass `AbstractDistillationPolicy`.
# NOTE(review): CE on 4 uniformly random labels should start near ln(4) ~ 1.39,
# but ~0.009 is logged, roughly 1/len(train_loader) of that. Check how Distiller
# averages the logged scalars.
policy = BasicDistillationPolicy(temperature=0.1, alpha=1.0)

distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, n_epochs=n_epochs)

Epoch 0: kld=0.00027   cross_entropy=0.00948   overall=0.00027        
Epoch 1: kld=0.00013   cross_entropy=0.00920   overall=0.00013        
Epoch 2: kld=0.00004   cross_entropy=0.00908   overall=0.00004      
Epoch 3: kld=0.00001   cross_entropy=0.00906   overall=0.00001        
Epoch 4: kld=0.00000   cross_entropy=0.00907   overall=0.00000        
Epoch 5: kld=0.00000   cross_entropy=0.00908   overall=0.00000        
Epoch 6: kld=0.00000   cross_entropy=0.00910   overall=0.00000        
Epoch 7: kld=0.00000   cross_entropy=0.00911   overall=0.00000        
Epoch 8: kld=0.00000   cross_entropy=0.00912   overall=0.00000        
Epoch 9: kld=0.00000   cross_entropy=0.00913   overall=0.00000        
Epoch 10: kld=0.00000   cross_entropy=0.00914   overall=0.00000        
Epoch 11: kld=0.00000   cross_entropy=0.00914   overall=0.00000        
Epoch 12: kld=0.00000   cross_entropy=0.00915   overall=0.00000        
Epoch 13: kld=0.00000   cross_entropy=0.00915   overall=0.00000        
Epoc

### Student vs Teacher

In [9]:
# Probe both models on fresh random inputs (no autograd tracking needed).
with torch.inference_mode():
    random_input = torch.randn(5, input_size)
    teacher_output = teacher(random_input)
    student_output = student(random_input)
    for title, logits in (('Teacher output:', teacher_output), ('Student output:', student_output)):
        print(title)
        print(logits)

Teacher output:
tensor([[-0.1217,  0.1959, -0.2506, -0.5039],
        [-0.1324,  0.0405, -0.1797, -0.4393],
        [-0.0844,  0.0537, -0.1318, -0.4583],
        [-0.1500,  0.1800, -0.2762, -0.5421],
        [-0.1290,  0.1328, -0.2370, -0.5321]])
Student output:
tensor([[-0.1160,  0.1547, -0.2141, -0.4610],
        [ 0.1311,  0.2426,  0.1211, -0.1439],
        [ 0.1691,  0.3280,  0.0969, -0.2133],
        [-0.0200,  0.3092, -0.1212, -0.3685],
        [ 0.3329,  0.6296,  0.2104, -0.0705]])


In [10]:
mse = nn.functional.mse_loss(student_output, teacher_output).item()
print(f'MSE: {mse:.5f}')

# Softmax is invariant to adding the same constant to every logit of a row, so a
# KLD-on-softmax objective (assumed for `distill_loss` -- confirm) cannot pin the
# per-row offset. The student rows above are roughly the teacher rows shifted by
# a per-row constant, so raw-logit MSE overstates the mismatch. Also compare
# row-centred logits, which removes that offset.
centered_student = student_output - student_output.mean(dim=-1, keepdim=True)
centered_teacher = teacher_output - teacher_output.mean(dim=-1, keepdim=True)
centered_mse = nn.functional.mse_loss(centered_student, centered_teacher).item()
print(f'MSE (row-centred logits): {centered_mse:.5f}')

MSE: 0.07529


## Custom policy (MSE)

In [11]:
LossDictT = Dict[str, float]


class SingleMSEDistillationPolicy(AbstractDistillationPolicy):
    """Pure logit-matching policy: MSE between student and teacher outputs.

    There is no label term; ``batch`` and ``epoch`` are accepted only to satisfy the
    policy call signature and are ignored.
    """

    def forward(self, teacher_output, student_output, batch, epoch) -> Tuple[torch.Tensor, LossDictT]:
        """Return ``(mse_loss, {'mse': float})`` for one batch."""
        mse_term = nn.functional.mse_loss(student_output, teacher_output)
        return mse_term, {'mse': mse_term.item()}

In [12]:
input_size, num_classes, n_epochs = 32, 4, 50

# Data and models are created in the same order as before so the seeded RNG
# stream (data -> teacher init -> student init) is consumed identically.
train_loader = get_loader(num_features=input_size, num_classes=num_classes)
teacher = Teacher(input_size, num_classes, hidden_size=10)
student = Student(input_size, num_classes)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

# Pure logit matching: no label term, no temperature.
policy = SingleMSEDistillationPolicy()

distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, n_epochs=n_epochs)

Epoch 0: mse=0.00135                                                  
Epoch 1: mse=0.00064                                                  
Epoch 2: mse=0.00028                                                  
Epoch 3: mse=0.00012                                                  
Epoch 4: mse=0.00005                                                  
Epoch 5: mse=0.00002                                                  
Epoch 6: mse=0.00001                                                  
Epoch 7: mse=0.00001                                                  
Epoch 8: mse=0.00001                                                  
Epoch 9: mse=0.00001                                                  
Epoch 10: mse=0.00001                                                  
Epoch 11: mse=0.00001                                                  
Epoch 12: mse=0.00001                                                  
Epoch 13: mse=0.00001                                                  
Ep

### Student vs Teacher

In [13]:
# Probe both models on fresh random inputs (no autograd tracking needed).
with torch.inference_mode():
    random_input = torch.randn(5, input_size)
    teacher_output = teacher(random_input)
    student_output = student(random_input)
    headers = ('Teacher output:', 'Student output:')
    for header, values in zip(headers, (teacher_output, student_output)):
        print(header)
        print(values)

Teacher output:
tensor([[ 0.2362,  0.2263,  0.1356, -0.0325],
        [ 0.1388,  0.2534,  0.2083, -0.0543],
        [ 0.2643,  0.1796,  0.0818,  0.0101],
        [ 0.2930,  0.1302, -0.0004,  0.0206],
        [ 0.2378,  0.1847,  0.1942, -0.0717]])
Student output:
tensor([[ 0.2168,  0.2256,  0.1496, -0.0585],
        [ 0.2011,  0.2101,  0.1793, -0.0609],
        [ 0.2844,  0.1557,  0.0257,  0.0510],
        [ 0.3263,  0.1230,  0.0051,  0.0220],
        [ 0.2074,  0.1624,  0.1819, -0.0587]])


In [14]:
# The student was trained directly on logit MSE, so raw-logit MSE is the matching metric here.
mse = nn.functional.mse_loss(input=student_output, target=teacher_output).item()
print('MSE: {:.5f}'.format(mse))

MSE: 0.00083


## Advanced policy
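**1. CE + MSE with a linear weight schedule:** the MSE weight `alpha = (epoch + 1) / n_epochs` grows towards 1 while the CE weight `1 - alpha` decays towards 0  
**2. Layer adapter:** a linear projection of the teacher's logits onto the student's output size before the MSE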

In [15]:
LossDictT = Dict[str, float]


class AdvancedDistillationPolicy(AbstractDistillationPolicy):
    """MSE-to-teacher plus CE-to-labels, mixed by a linear epoch schedule.

    ``overall = alpha * mse + (1 - alpha) * ce`` with
    ``alpha = (epoch + 1) / n_epochs`` clamped to ``[0, 1]``. The MSE weight
    therefore grows from ``1 / n_epochs`` at epoch 0 to 1 at epoch ``n_epochs - 1``,
    while the CE weight decays to 0.

    Teacher outputs go through ``self.adapter`` before the MSE. It is ``nn.Identity``
    by default, or ``nn.Linear(*adapter_mapping)`` to project teacher logits onto the
    student's output size. The adapter's weights belong to this policy; they are
    updated only if the optimizer in use includes them.

    Args:
        n_epochs: length of the schedule; must be >= 1.
        adapter_mapping: optional ``(in_features, out_features)`` for the linear adapter.

    Raises:
        ValueError: if ``n_epochs < 1``.
    """

    def __init__(self, n_epochs: int, adapter_mapping: Optional[Tuple[int, int]] = None):
        super().__init__()
        if n_epochs < 1:
            raise ValueError(f'n_epochs must be >= 1, got {n_epochs}')
        self.n_epochs = n_epochs
        self.adapter = nn.Identity() if adapter_mapping is None else nn.Linear(*adapter_mapping)

    def forward(self, teacher_output, student_output, batch, epoch) -> Tuple[torch.Tensor, LossDictT]:
        """Return ``(overall_loss, scalars_to_log)`` for one batch.

        ``batch[1]`` holds the labels for the CE term; ``epoch`` drives the schedule.
        """
        # Clamp: if training runs past n_epochs, an unclamped alpha > 1 would give the
        # CE term a negative weight (i.e. the optimizer would maximize CE).
        alpha = min(max((epoch + 1) / self.n_epochs, 0.0), 1.0)
        projected_teacher_logits = self.adapter(teacher_output)

        loss_mse = nn.functional.mse_loss(student_output, projected_teacher_logits)
        loss_ce = nn.functional.cross_entropy(student_output, batch[1])
        overall = loss_mse * alpha + loss_ce * (1 - alpha)
        scalars_dict = {
            'mse': loss_mse.item(),
            'cross_entropy': loss_ce.item(),
            'overall': overall.item(),
            'alpha': alpha,
        }
        return overall, scalars_dict

In [16]:
input_size = 32
num_teacher_logits = 6
num_student_logits = 4
n_epochs = 50

# The labels feed the student's CE term, so they must index the student's logits:
# draw them from `num_student_logits` classes instead of relying on a `num_classes`
# left over from an earlier cell (which breaks on a fresh kernel / reordered run).
train_loader = get_loader(num_features=input_size, num_classes=num_student_logits)
teacher = Teacher(input_size, num_teacher_logits, hidden_size=10)
student = Student(input_size, num_student_logits)
# NOTE(review): only the student's parameters are optimized. The policy's
# Linear(num_teacher_logits -> num_student_logits) adapter below keeps its random
# init unless the Distiller adds it to the optimizer. Confirm that a fixed random
# projection of the teacher is intended.
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

policy = AdvancedDistillationPolicy(n_epochs, adapter_mapping=(num_teacher_logits, num_student_logits))

# NOTE(review): alpha at epoch 0 is (0 + 1) / 50 = 0.02, yet 0.00013 (~0.02 / len(train_loader))
# is logged. The Distiller's scalar averaging appears to rescale logged values.
distiller = CustomDistiller(teacher, student, policy, optimizer)
distiller(train_loader, n_epochs=n_epochs)

Epoch 0: mse=0.00238   cross_entropy=0.00998   overall=0.00982   alpha=0.00013
Epoch 1: mse=0.00196   cross_entropy=0.00961   overall=0.00930   alpha=0.00026
Epoch 2: mse=0.00164   cross_entropy=0.00933   overall=0.00887   alpha=0.00038
Epoch 3: mse=0.00139   cross_entropy=0.00913   overall=0.00851   alpha=0.00051
Epoch 4: mse=0.00120   cross_entropy=0.00898   overall=0.00820   alpha=0.00064
Epoch 5: mse=0.00105   cross_entropy=0.00887   overall=0.00793   alpha=0.00077
Epoch 6: mse=0.00091   cross_entropy=0.00880   overall=0.00769   alpha=0.00090
Epoch 7: mse=0.00079   cross_entropy=0.00875   overall=0.00748   alpha=0.00103
Epoch 8: mse=0.00069   cross_entropy=0.00872   overall=0.00727   alpha=0.00115
Epoch 9: mse=0.00059   cross_entropy=0.00870   overall=0.00708   alpha=0.00128
Epoch 10: mse=0.00051   cross_entropy=0.00869   overall=0.00689   alpha=0.00141
Epoch 11: mse=0.00043   cross_entropy=0.00869   overall=0.00670   alpha=0.00154
Epoch 12: mse=0.00036   cross_entropy=0.00869   ov

### Student vs Teacher

In [17]:
# Plain no_grad (not inference_mode) on purpose: teacher_output is passed through
# policy.adapter again outside this context, and the adapter's grad-requiring
# weights cannot save an inference tensor for backward.
with torch.no_grad():
    random_input = torch.randn(5, input_size)
    teacher_output, student_output = teacher(random_input), student(random_input)
    report = (
        ('Teacher output:', teacher_output),
        ('Teacher output after adapter:', policy.adapter(teacher_output)),
        ('Student output:', student_output),
    )
    for caption, values in report:
        print(caption)
        print(values)

Teacher output:
tensor([[-0.1329,  0.1282, -0.1064,  0.2148,  0.0931,  0.2372],
        [-0.1190,  0.1222, -0.0096,  0.1823,  0.0670,  0.2634],
        [-0.0450,  0.0447, -0.0888,  0.1382,  0.0143,  0.1679],
        [-0.1969,  0.2049, -0.0978,  0.1848,  0.0004, -0.0018],
        [-0.0853, -0.0314, -0.0410,  0.1831, -0.0424,  0.2571]])
Teacher after adapter output:
tensor([[-0.0648, -0.2931,  0.3177, -0.1903],
        [-0.0909, -0.2952,  0.2953, -0.1913],
        [-0.0404, -0.2404,  0.2275, -0.2135],
        [ 0.0391, -0.2174,  0.2448, -0.2471],
        [-0.0678, -0.2120,  0.2082, -0.2116]])
Student output:
tensor([[-0.0589, -0.2942,  0.3198, -0.1966],
        [-0.0669, -0.2720,  0.2676, -0.2041],
        [-0.0342, -0.2686,  0.2836, -0.2150],
        [ 0.0141, -0.2429,  0.2646, -0.2340],
        [-0.0799, -0.2555,  0.2584, -0.2100]])


In [18]:
# Evaluation only: without no_grad the adapter (an nn.Linear whose weights require
# grad) would record an autograd graph that is immediately thrown away.
with torch.no_grad():
    mse = nn.functional.mse_loss(student_output, policy.adapter(teacher_output)).item()
print(f'MSE: {mse:.5f}')

MSE: 0.00062
