In [None]:
# default_exp distill.distillation_callback

# Knowledge Distillation

> Train a network in a teacher-student fashion

In [None]:
# all_slow

In [None]:
#hide
from nbdev.showdoc import *

%config InlineBackend.figure_format = 'retina'

Knowledge Distillation, sometimes called teacher-student training, is a compression method in which a small model (the student) is trained to mimic the behaviour of a larger model (the teacher).

The main goal is to reveal what is called the **Dark Knowledge** hidden in the teacher model.

Let's follow the [example](https://www.ttic.edu/dl/dark14.pdf) given by Geoffrey Hinton et al.

In classification, the main problem is that the output activation function (the softmax), by design, pushes a single value very high and squashes the others.

$$
p_{i}=\frac{\exp \left(z_{i}\right)}{\sum_{j} \exp \left(z_{j}\right)}
$$

where $p_i$ is the probability of class $i$, computed from the logits $z$.

Here is an example to illustrate this phenomenon:

Let's say that we have trained a model to discriminate between the following 5 classes: [cow, dog, plane, cat, car]

And here is the output of the final layer (the logits) when the model is fed a new input image: 

In [None]:
import torch, torch.nn.functional as F
logits = torch.tensor([1.3, 3.1, 0.2, 1.9, -0.3])  # [cow, dog, plane, cat, car]

Judging by these logits, the model seems confident that the input image is a dog, and quite confident that it is neither a plane nor a car, while the scores for cow and cat are moderately high.

So the model has not only learned to recognize a dog in the image, but also that a dog is very different from a car or a plane and shares similarities with cats and cows. This information is what is called **dark knowledge**!

When passing those logits through a softmax, we get:

In [None]:
predictions = F.softmax(logits, dim=-1); predictions

tensor([0.1063, 0.6431, 0.0354, 0.1937, 0.0215])

This accentuates the differences between the logits and discards some of the dark knowledge. The way to keep this knowledge is to "soften" the softmax outputs by adding a **temperature** parameter $T$, which divides the logits before the softmax is applied. The higher the temperature, the softer the predictions.

In [None]:
soft_predictions = F.softmax(logits/3, dim=-1); soft_predictions  # temperature T = 3

tensor([0.1879, 0.3423, 0.1302, 0.2294, 0.1102])

> Note: if the temperature is equal to 1, we recover the regular softmax.
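To make "the higher the temperature, the softer the predictions" measurable, here is a small plain-Python sketch (no torch needed) that applies the softened softmax $p_i = \exp(z_i/T) / \sum_j \exp(z_j/T)$ to the example logits at a few temperatures and reports the Shannon entropy of each distribution. The helper names and the use of entropy as a softness measure are our own illustrative choices, not part of the library.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Softened softmax: p_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtracting the max avoids overflow and leaves p unchanged
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; a flatter (softer) distribution has higher entropy."""
    return -sum(p * math.log(p) for p in probs if p > 0)

# Same logits as in the example above: [cow, dog, plane, cat, car]
logits = [1.3, 3.1, 0.2, 1.9, -0.3]
for T in (1.0, 3.0, 20.0):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T:>4}: {[round(p, 4) for p in probs]}  entropy={entropy(probs):.3f}")
```

The entropy grows with $T$ while the ranking of the classes, and hence the predicted class, stays the same; at very high temperatures the distribution approaches uniform.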

When applying Knowledge Distillation, we want to keep the **Dark Knowledge** that the teacher model has acquired during its training but not rely entirely on it. So we combine two losses: 

- The Teacher loss between the softened predictions of the teacher and the softened predictions of the student
- The Classification loss, which is the regular loss between hard labels and hard predictions

The two losses are combined through an additional weighting parameter α:

$$
L_{K D}=\alpha  * \text { CrossEntropy }\left(p_{S}^{\tau}, p_{T}^{\tau}\right)+(1-\alpha) * \text { CrossEntropy }\left(p_{S}, y_{\text {true }}\right)
$$

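where $p_S^{\tau}$ and $p_T^{\tau}$ are the softened predictions (softmax at temperature $\tau$) of the student and the teacher, and $p_S$ is the regular softmax of the student.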
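As a minimal sketch of the formula above, here is a plain-Python version for a single sample, with no deep-learning library involved. The helper names (`softmax`, `cross_entropy`, `kd_loss`) are ours, and we read $\text{CrossEntropy}(p_S^{\tau}, p_T^{\tau})$ as the cross-entropy of the student's softened distribution against the teacher's softened distribution used as soft targets.

```python
import math

def softmax(logits, T=1.0):
    """Softmax at temperature T (T=1 gives the regular softmax)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target_probs, pred_probs):
    """H(target, pred) = -sum_i target_i * log(pred_i)."""
    return -sum(t * math.log(p) for t, p in zip(target_probs, pred_probs) if t > 0)

def kd_loss(student_logits, teacher_logits, label, T=20.0, alpha=0.7):
    """alpha * CE(soft student vs soft teacher) + (1 - alpha) * CE(student vs true label)."""
    # Soft term: the teacher's softened distribution acts as the target
    soft_term = cross_entropy(softmax(teacher_logits, T), softmax(student_logits, T))
    # Hard term: y_true is one-hot, so the cross-entropy reduces to -log p_S[label]
    hard_term = -math.log(softmax(student_logits)[label])
    return alpha * soft_term + (1 - alpha) * hard_term

teacher = [1.3, 3.1, 0.2, 1.9, -0.3]  # [cow, dog, plane, cat, car]
student = [0.5, 2.0, 0.1, 1.0, 0.0]   # hypothetical student logits
print(kd_loss(student, teacher, label=1, T=3.0, alpha=0.7))
```

With `alpha=0` only the hard-label term remains, and with `alpha=1` the loss is smallest when the student reproduces the teacher's softened distribution.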

> Note: In practice, the distillation loss is implemented [slightly differently](http://cs230.stanford.edu/files_winter_2018/projects/6940224.pdf).
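Reading the `DistillationLoss` function in the callback code below, the soft term is computed as a KL divergence rather than a cross-entropy, and it is scaled by an extra factor $2\tau^2$ ($\tau$ is `T` in the code). The KL divergence splits into the soft cross-entropy plus a term that depends only on the teacher:

$$
\mathrm{KL}\left(p_{T}^{\tau} \,\|\, p_{S}^{\tau}\right)=\sum_{i} p_{T,i}^{\tau} \log \frac{p_{T,i}^{\tau}}{p_{S,i}^{\tau}}=-\sum_{i} p_{T,i}^{\tau} \log p_{S,i}^{\tau}+\sum_{i} p_{T,i}^{\tau} \log p_{T,i}^{\tau}
$$

so the loss actually optimized reads

$$
L_{K D}^{\text{impl}}=2 \alpha \tau^{2} \, \mathrm{KL}\left(p_{T}^{\tau} \,\|\, p_{S}^{\tau}\right)+(1-\alpha) * \text { CrossEntropy }\left(p_{S}, y_{\text {true }}\right)
$$

Since the second sum does not depend on the student, swapping the soft cross-entropy for the KL divergence leaves the student's gradients unchanged. The $\tau^2$ factor follows Hinton et al., who note that the gradients produced by the soft targets scale as $1/\tau^2$; multiplying by $\tau^2$ keeps the relative contributions of the soft and hard terms roughly unchanged when the temperature is varied.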

![distill](imgs/distill.pdf "Knowledge Distillation")

In [None]:
#export
from fastai.vision.all import *

import torch
import torch.nn as nn
import torch.nn.functional as F

This can be implemented in fastai using the Callback system!

In [None]:
#export
class KnowledgeDistillation(Callback):
    """Implementation inspired by https://github.com/peterliht/knowledge-distillation-pytorch/
    """
    def __init__(self, teacher, T:float=20., α:float=0.7):
        store_attr()
    
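    def after_loss(self):
        self.teacher.model.eval()
        # The teacher is frozen: compute its logits without building a autograd graph
        with torch.no_grad(): teacher_output = self.teacher.model(self.x)
        new_loss = DistillationLoss(self.pred, self.y, teacher_output, self.T, self.α)
        # Replace the loss that fastai backpropagates with the distillation loss
        self.learn.loss_grad = new_loss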

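def DistillationLoss(y, labels, teacher_scores, T, alpha):
    "Weighted sum of a softened teacher/student KL term and the hard-label cross-entropy"
    # KL(teacher || student) on temperature-softened distributions, averaged over the batch
    soft_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(y/T, dim=-1), F.softmax(teacher_scores/T, dim=-1))
    # Regular cross-entropy between the student's logits and the ground-truth labels
    hard_loss = F.cross_entropy(y, labels)
    return soft_loss * (T*T * 2.0 * alpha) + hard_loss * (1. - alpha)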
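To sanity-check `DistillationLoss` without torch or a GPU, here is a hypothetical plain-Python reference (`distillation_loss_reference` and `log_softmax` are our names, not library functions) that mirrors it term by term on a batch given as lists of logit lists. It assumes the semantics of `nn.KLDivLoss(reduction='batchmean')` (KL summed over classes, then divided by the batch size) and of `F.cross_entropy` with its default mean reduction.

```python
import math

def log_softmax(logits, T=1.0):
    """log of softmax(z / T), computed stably with the log-sum-exp trick."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - lse for s in scaled]

def distillation_loss_reference(student_batch, teacher_batch, labels, T, alpha):
    """Plain-Python mirror of DistillationLoss (hypothetical helper for testing)."""
    n = len(student_batch)
    kl_sum, ce_sum = 0.0, 0.0
    for s, t, y in zip(student_batch, teacher_batch, labels):
        log_ps, log_pt = log_softmax(s, T), log_softmax(t, T)
        # KL(p_T || p_S) = sum_i p_T,i * (log p_T,i - log p_S,i)
        kl_sum += sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_pt, log_ps))
        # Hard-label cross-entropy uses the unsoftened student logits
        ce_sum += -log_softmax(s)[y]
    # 'batchmean' and the default cross-entropy reduction both divide by the batch size
    kl, ce = kl_sum / n, ce_sum / n
    return kl * (T * T * 2.0 * alpha) + ce * (1. - alpha)

student = [[0.5, 2.0, 0.1, 1.0, 0.0], [1.0, 0.2, 0.3, 0.1, 0.4]]    # hypothetical logits
teacher = [[1.3, 3.1, 0.2, 1.9, -0.3], [2.2, 0.1, 0.5, 0.3, 0.9]]
print(distillation_loss_reference(student, teacher, [1, 0], T=20., alpha=0.7))
```

Feeding the same numbers to `DistillationLoss` as tensors should give a matching value up to float32 precision.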