# Knowledge Distillation

В этой тетрадке мы обучим большой ResNet
[отличать породы собак](https://github.com/fastai/imagenette#imagewoof)
друг от друга, а потом дистиллируем его знания в ResNet меньшего размера.

In [5]:
from IPython.display import clear_output

# Fetch the Imagewoof archive into the current working directory.
!wget https://s3.amazonaws.com/fast-ai-imageclas/imagewoof2-320.tgz
# Unpack it in place; later cells read the data from "imagewoof2-320/".
!tar -xf imagewoof2-320.tgz

# Drop the output printed by the shell commands above.
clear_output()

In [6]:
import os

import torch

from torchvision import datasets, models, transforms
from torch.utils import data


# Preprocessing pipelines, adapted from the PyTorch transfer-learning tutorial:
# https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
# "train" applies random crop/flip augmentation, "val" a fixed resize + crop.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)

data_dir = "imagewoof2-320/"

# One ImageFolder per split, each reading from <data_dir>/<split>.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), pipeline)
    for split, pipeline in data_transforms.items()
}

# Both loaders use the same settings (small shuffled batches, 4 workers).
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=True, num_workers=4
    )
    for split, dataset in image_datasets.items()
}

# Number of samples per split, used to normalise loss/accuracy sums.
dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}

# Use the first GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("running on: {}".format(device))

running on: cuda:0


Я выбрал в качестве задачи классификацию изображений resnet'ом, потому
что похожие задачи могут встретиться in the wild. "Обучим большую
умную модель решать задачу хорошо, а потом скомпрессим её до
меньшей модели, чтобы гонять на мобильных устройствах".

In [None]:
from torch import nn, optim
import torch.nn.functional as F

from tqdm.notebook import tqdm_notebook


# First, fine-tune (transfer learning) a pretrained ResNet-101 as the teacher.
teacher = models.resnet101(pretrained=True)
# Replace the classification head with a fresh 10-output layer.
teacher.fc = nn.Linear(teacher.fc.in_features, 10)
teacher = teacher.to(device)

NUM_EPOCHS = 5
# NOTE(review): never updated below; kept because other cells may read it.
best_teacher_acc = 0.0

teacher_optimizer = optim.SGD(teacher.parameters(), lr=0.001, momentum=0.9)
# NOTE(review): step_size=7 is larger than NUM_EPOCHS=5, so the learning rate
# never actually decays during this run.
t_lr_scheduler = optim.lr_scheduler.StepLR(teacher_optimizer, step_size=7, gamma=0.1)

# This cell only trains; `phase` was referenced below without being defined.
phase = "train"

# Put the network in training mode explicitly instead of relying on the default.
teacher.train()

for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in tqdm_notebook(dataloaders[phase]):
        inputs = inputs.to(device)
        labels = labels.to(device)

        teacher_optimizer.zero_grad()

        outputs = teacher(inputs)
        _, preds = torch.max(outputs, 1)
        loss = F.cross_entropy(outputs, labels)

        loss.backward()
        teacher_optimizer.step()

        # cross_entropy returns the batch mean; weight it by the batch size so
        # the epoch average stays correct when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum()

    # The scheduler is stepped once per epoch, after the optimizer updates.
    t_lr_scheduler.step()

    epoch_loss = running_loss / dataset_sizes[phase]
    epoch_acc = running_corrects.double() / dataset_sizes[phase]

    clear_output()
    print("Epoch {}/{}".format(epoch, NUM_EPOCHS - 1))
    # Three placeholders need three arguments: phase, loss, accuracy.
    print("{} Teacher loss: {:.4f} Teacher acc: {:.4f}"
          .format(phase, epoch_loss, epoch_acc))

HBox(children=(FloatProgress(value=0.0, max=2257.0), HTML(value='')))

In [None]:
def val_accuracy(model):
    """Return the top-1 accuracy of ``model`` on the validation split.

    The model is switched to evaluation mode and run without gradient
    tracking for the duration of the pass; its previous train/eval mode is
    restored afterwards.

    Args:
        model: a module mapping an image batch to per-class scores, with the
            class dimension at index 1.

    Returns:
        A 0-dim double tensor: correct predictions / ``dataset_sizes["val"]``.
    """
    was_training = model.training
    model.eval()

    # Accumulate on the target device so no host sync is needed per batch.
    corrects = torch.zeros((), dtype=torch.long, device=device)
    try:
        with torch.no_grad():
            for inputs, labels in tqdm_notebook(dataloaders["val"]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                corrects += (preds == labels).sum()
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    # Cast before dividing: integer-tensor division floors on older PyTorch.
    return corrects.double() / dataset_sizes["val"]

# Score the fine-tuned teacher on the validation split and report the result.
teacher_accuracy = val_accuracy(model=teacher)
print(f"Teacher accuracy: {teacher_accuracy}")

In [None]:
def kd_loss(student_logits, teacher_logits, labels, alpha=0.5, T=10):
    """Return the knowledge-distillation loss for one batch.

    The loss blends the usual hard-label cross-entropy with the KL divergence
    between the temperature-softened teacher and student distributions::

        alpha * CE(student, labels)
            + (1 - alpha) * T**2 * KL(softmax(teacher / T) || softmax(student / T))

    Args:
        student_logits: raw student scores, classes along dim 1.
        teacher_logits: raw teacher scores, same shape as ``student_logits``.
        labels: ground-truth class indices accepted by ``F.cross_entropy``.
        alpha: weight of the hard-label term, in ``[0, 1]``.
        T: softmax temperature, must be positive.

    Returns:
        A scalar tensor holding the combined loss.

    Raises:
        ValueError: if ``T`` is not positive or ``alpha`` is outside ``[0, 1]``.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if T <= 0:
        raise ValueError(f"T must be positive, got {T}")
    if not 0 <= alpha <= 1:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")

    hard_loss = F.cross_entropy(student_logits, labels)

    # The teacher only provides targets; never backpropagate into it.
    soft_targets = F.softmax(teacher_logits.detach() / T, dim=1)
    soft_log_probs = F.log_softmax(student_logits / T, dim=1)

    # 'batchmean' sums over classes and averages over the batch, which is the
    # actual KL divergence; the default 'mean' would also divide by the number
    # of classes.
    distill_loss = F.kl_div(soft_log_probs, soft_targets, reduction="batchmean")

    # Soft-target gradients scale as 1/T**2; multiplying by T**2 keeps both
    # terms on a comparable scale when the temperature changes.
    distill_loss = distill_loss * (T * T)

    return alpha * hard_loss + (1 - alpha) * distill_loss

In [None]:
# Create the student: a ResNet-18 trained from scratch with a 10-output head.
student = models.resnet18(pretrained=False)
student.fc = nn.Linear(student.fc.in_features, 10)
student = student.to(device)

NUM_EPOCHS = 150

student_optimizer = optim.SGD(student.parameters(), lr=0.001, momentum=0.9)
# NOTE(review): StepLR divides the LR by 10 every 7 epochs; over 150 epochs
# that is 21 decays, so most of the run happens at a vanishing learning rate.
# Confirm this schedule is intended.
s_lr_scheduler = optim.lr_scheduler.StepLR(student_optimizer, step_size=7, gamma=0.1)

# This cell only trains; `phase` was referenced below without being defined.
phase = "train"

# The teacher is a fixed source of targets: evaluation mode keeps its outputs
# deterministic and stops it from updating any internal running statistics.
teacher.eval()
student.train()

for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in tqdm_notebook(dataloaders[phase]):
        inputs = inputs.to(device)
        labels = labels.to(device)

        student_optimizer.zero_grad()

        # No autograd graph for the teacher: saves memory and guarantees that
        # gradients only reach the student.
        with torch.no_grad():
            teacher_outputs = teacher(inputs)
        student_outputs = student(inputs)

        _, preds = torch.max(student_outputs, 1)
        loss = kd_loss(student_outputs, teacher_outputs, labels, alpha=0.5)

        loss.backward()
        student_optimizer.step()

        # Weight the batch loss by the batch size so the epoch average stays
        # correct when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum()

    # The scheduler is stepped once per epoch, after the optimizer updates.
    s_lr_scheduler.step()

    epoch_loss = running_loss / dataset_sizes[phase]
    epoch_acc = running_corrects.double() / dataset_sizes[phase]

    clear_output()
    print("Epoch {}/{}".format(epoch, NUM_EPOCHS - 1))
    # Three placeholders need three arguments: phase, loss, accuracy.
    print("{} Student loss: {:.4f} Student acc: {:.4f}"
          .format(phase, epoch_loss, epoch_acc))

In [None]:
# Score the distilled student on the validation split.
student_accuracy = val_accuracy(model=student)

In [None]:
def count_params(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Side-by-side summary: trainable parameter count and validation accuracy.
print(f"Teacher: params - {count_params(teacher)}, accuracy - {teacher_accuracy}")
print(f"Student: params - {count_params(student)}, accuracy - {student_accuracy}")