# Методы компрессии нейронных сетей

## Лекция №4 - Дистилляция весов моделей
- Принцип работы методов дистилляции моделей
- Отличия от обучения с нуля
- Практика - Попробуем обучить модель по принципу дистилляции весов

## ДЗ №4
Попытаться применить данный метод к своим моделям

## Домашняя работа
Рассматривается кастомная сверточная нейронная сеть.

Проверяются следующие метрики производительности:
- число параметров модели;
- вес файла модели;
- время инференса;
- целевая метрика.

In [8]:
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms

from tqdm.auto import trange

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
def measure_size(model):
  param_size = 0
  for param in model.model.parameters():
      param_size += param.nelement() * param.element_size()
  buffer_size = 0
  for buffer in model.model.buffers():
      buffer_size += buffer.nelement() * buffer.element_size()

  size_all_mb = (param_size + buffer_size) / 1024**2
  print('model size: {:.3f}MB'.format(size_all_mb))

In [10]:
def count_parameters(model):
  return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [11]:
def train(model, n_epochs=10):
  model.to(device)
  optim = torch.optim.AdamW(model.parameters())

  best_epoch, best_accuracy = -1, 0
  for epoch in trange(n_epochs):
    model.train()
    for input, target in train_loader:
      input, target = input.to(device), target.to(device)
      pred = model(input)
      loss = F.cross_entropy(pred, target)

      optim.zero_grad()
      loss.backward()
      optim. step()

    model.eval()
    correct, total = 0, 0
    for input, target in test_loader:
      input, target = input.to(device), target.to(device)
      pred = model(input)
      pred = pred.argmax(dim=1)
      correct += (pred == target).sum()
      total += target.numel()

    accuracy = 100 * correct / total
    print(f"Epoch {epoch}: accuracy {accuracy:.1f}%")
    if accuracy > best_accuracy:
      best_epoch = epoch
      best_accuracy = accuracy

  print(f"Best accuracy {best_accuracy:.1f}% after epoch {best_epoch}")

In [12]:
def train_distill(student_model, teacher_model, n_epochs=10, alpha=0.7, T=5):
  student_model.to(device)
  teacher_model.to(device)
  optim = torch.optim. AdamW(student_model.parameters())

  best_epoch, best_accuracy = -1, 0
  for epoch in trange(n_epochs):
    student_model.train()
    for input, target in train_loader:
      input, target = input.to(device), target.to(device)
      student_pred = student_model(input)
      teacher_pred = teacher_model(input)

      student_logprobs = F.log_softmax(student_pred / T, dim=-1)
      teacher_probs = F.softmax(teacher_pred / T, dim=-1)
      distill_loss = F.kl_div(student_logprobs, teacher_probs, reduction="batchmean")

      ce_loss = F.cross_entropy(student_pred, target)
      loss = alpha * distill_loss * (T ** 2) + (1 - alpha) * ce_loss

      optim. zero_grad()
      loss.backward()
      optim. step()

    student_model.eval()
    correct, total = 0, 0
    for input, target in test_loader:
      input, target = input.to(device), target.to(device)
      pred = student_model(input)
      pred = pred.argmax(dim=1)
      correct += (pred == target).sum()
      total += target.numel()

    accuracy = 100 * correct / total
    print(f"Epoch {epoch}: accuracy {accuracy:.1f}%")
    if accuracy > best_accuracy:
      best_epoch = epoch
      best_accuracy = accuracy

  print(f"Best accuracy {best_accuracy:.1f}% after epoch {best_epoch}")

In [13]:
def inference(model):
  model.to(device)

  start = time.time()
  for input, target in train_loader:
    input, target = input.to(device), target.to(device)
    _ = model(input)
    break
  total_time_model = (time.time() - start) / 1024 * 1000

  return total_time_model

## Загрузка датасета

In [14]:
train_dataset = torchvision.datasets.CIFAR10(
    root="dataset/",
    train=True,
    transform=transforms.ToTensor(),
    download=True
)

test_dataset = torchvision.datasets.CIFAR10(
  root="dataset/",
  train=False,
  transform=transforms.ToTensor(),
  download=True
)

train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to dataset/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 81192935.11it/s]


Extracting dataset/cifar-10-python.tar.gz to dataset/
Files already downloaded and verified


## Модель

In [15]:
class CNN_Block(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.block = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU()
    )

  def forward(self, input):
    return self.block(input)

In [16]:
class CNN(nn.Module):
  def __init__(self, in_channels=3, inter_channels=None, image_size=32, n_classes=10):
    super().__init__()

    if inter_channels is None:
      inter_channels = [64, 128, 256]

    model = []
    for out_channels in inter_channels:
      model.append(CNN_Block(in_channels, out_channels))
      in_channels = out_channels

    model.append(nn.Flatten())

    image_size = int(image_size / (2 ** len(inter_channels)))
    in_features = out_channels * image_size * image_size
    out_features = int(in_features / 4)
    model.append(nn.Linear(in_features, out_features))
    model.append(nn.ReLU())

    in_features = out_features
    out_features = n_classes
    model.append(nn.Linear(in_features, out_features))

    self.model = nn.Sequential(*model)

  def forward(self, input):
    return self.model(input)

#### Модель учителя

In [17]:
# teacher_model = CNN()
"---------------------------"
teacher_model = torch.load("model.pt")

##### Размер и число параметров

In [18]:
print("Teacher model")
print(f"Parameters: {count_parameters(teacher_model):,}")
measure_size(teacher_model)

Teacher model
Parameters: 5,351,882
model size: 20.423MB


##### Целевая метрика

In [None]:
train(teacher_model)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0: accuracy 45.7%
Epoch 1: accuracy 54.8%
Epoch 2: accuracy 64.1%
Epoch 3: accuracy 70.8%
Epoch 4: accuracy 74.7%
Epoch 5: accuracy 76.5%
Epoch 6: accuracy 74.6%
Epoch 7: accuracy 76.7%
Epoch 8: accuracy 77.8%
Epoch 9: accuracy 73.2%
Best accuracy 77.8% after epoch 8


##### Инференс модели

In [20]:
print(f"Teacher inference time: {inference(teacher_model)} ms")

Teacher inference time: 0.15277485363185406 ms


#### Модель ученика

##### Размер и число параметров

In [21]:
student_model_raw = CNN(inter_channels=[16, 32])
print("Student model")
print(f"Parameters: {count_parameters(student_model_raw):,}")
measure_size(student_model_raw)

Student model
Parameters: 1,070,970
model size: 4.086MB


##### Целевая метрика

In [22]:
train(student_model_raw)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0: accuracy 45.8%
Epoch 1: accuracy 58.4%
Epoch 2: accuracy 62.6%
Epoch 3: accuracy 62.4%
Epoch 4: accuracy 66.0%
Epoch 5: accuracy 68.0%
Epoch 6: accuracy 67.7%
Epoch 7: accuracy 69.0%
Epoch 8: accuracy 68.2%
Epoch 9: accuracy 69.2%
Best accuracy 69.2% after epoch 9


##### Инференс модели

In [26]:
print(f"Raw student inference time: {inference(student_model_raw)} ms")

Raw student inference time: 0.17061852850019932 ms


## Дистилляция

##### Целевая метрика

In [24]:
student_model = CNN(inter_channels=[16, 32])
train_distill(student_model, teacher_model, alpha=0.5, T=5)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0: accuracy 48.3%
Epoch 1: accuracy 57.2%
Epoch 2: accuracy 59.4%
Epoch 3: accuracy 64.2%
Epoch 4: accuracy 67.6%
Epoch 5: accuracy 68.4%
Epoch 6: accuracy 69.0%
Epoch 7: accuracy 69.8%
Epoch 8: accuracy 69.9%
Epoch 9: accuracy 70.5%
Best accuracy 70.5% after epoch 9


##### Инференс модели

In [25]:
print(f"Student inference time: {inference(student_model)} ms")

Student inference time: 0.14545675367116928 ms
