In [None]:
!pip install -qU transformers datasets accelerate tensorboard evaluate

# Knowledge Distillation for Computer Vision

**Knowledge distillation** is a technique used to transfer knowledge from a larger, more complex model (teacher) to a smaller, simpler model (student).

To distill knowledge from one model to another, we take a teacher model that has already been trained on a task and randomly initialize a student model to be trained on the same task. Next, we train the student to minimize the difference between its outputs and the teacher's outputs, so that it learns to mimic the teacher's behavior.

We will distill a fine-tuned `ViT` model (teacher model) to a `MobileNet` (student model) with the beans dataset for an image classification task.
* teacher model `merve/beans-vit-224`
* student model `MobileNetV2ForImageClassification`, randomly initialized from a default `MobileNetV2Config`
* dataset `beans`

In [3]:
# Load the `beans` image-classification dataset from the Hugging Face Hub;
# it provides train / validation / test splits with `image` and `labels` columns.
from datasets import load_dataset

dataset = load_dataset(path="beans")

README.md:   0%|          | 0.00/4.95k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/144M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/18.5M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/17.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1034 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/133 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/128 [00:00<?, ? examples/s]

In [4]:
# Show each split's features and number of rows.
dataset

DatasetDict({
    train: Dataset({
        features: ['image_file_path', 'image', 'labels'],
        num_rows: 1034
    })
    validation: Dataset({
        features: ['image_file_path', 'image', 'labels'],
        num_rows: 133
    })
    test: Dataset({
        features: ['image_file_path', 'image', 'labels'],
        num_rows: 128
    })
})

The student is randomly initialized and has no image processor of its own, so we preprocess every image with the teacher's image processor. Both models are then fed exactly the same `pixel_values` during training. We will use the `map()` method of `dataset` to apply the preprocessing to every split of the dataset.

In [5]:
from transformers import AutoImageProcessor

# Only the teacher ships a pretrained image processor (the student is randomly
# initialized), so its preprocessing is used for every split.
# `use_fast=False` pins the slow processor the checkpoint was saved with. Leaving it
# unset lets newer transformers versions silently switch to the fast processor,
# which (per the library's own warning) produces slightly different outputs than
# what the teacher was trained on.
teacher_processor = AutoImageProcessor.from_pretrained("merve/beans-vit-224", use_fast=False)


def process(examples):
    """Convert a batch of images into model-ready inputs.

    Args:
        examples: a batch (dict of column -> list) supplied by
            `Dataset.map(..., batched=True)`; only the `image` column is read.

    Returns:
        The image processor's output (expected to hold `pixel_values`), which
        `map` adds to the batch next to the existing columns.
    """
    return teacher_processor(examples["image"])


processed_datasets = dataset.map(process, batched=True)

preprocessor_config.json:   0%|          | 0.00/325 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Map:   0%|          | 0/1034 [00:00<?, ? examples/s]

Map:   0%|          | 0/133 [00:00<?, ? examples/s]

Map:   0%|          | 0/128 [00:00<?, ? examples/s]

Essentially, we want the student model (a randomly initialized `MobileNet` in this case) to mimic the teacher model (a fine-tuned `ViT`).

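To achieve this, we first get the `logits` output from the teacher and the student. Then, we divide each of them by a `temperature` $T$. This softens the resulting probability distributions, so the student also learns from the relative probabilities the teacher assigns to the non-target classes. A parameter `lambda` ($\lambda$) weights the distillation loss against the student's standard cross-entropy loss on the true labels:

$$\mathcal{L} = (1 - \lambda)\,\mathcal{L}_{\text{CE}} + \lambda\,T^2\,\mathrm{KL}\big(p_{\text{teacher}} \,\|\, p_{\text{student}}\big)$$

The $T^2$ factor keeps the gradients of the distillation term on a comparable scale as $T$ changes. Here, `temperature=5` and `lambda=0.5`.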

We will use the KL divergence loss to measure how far the student's output distribution is from the teacher's. Given two probability distributions P and Q, the KL divergence measures how much extra information is needed to represent P using Q. If the two are identical, their KL divergence is zero, because no extra information is needed. This makes it a natural objective for knowledge distillation: minimizing it pushes the student's distribution toward the teacher's.

In [6]:
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.test_utils.testing import get_backend


class ImageDistilTrainer(Trainer):
    """`Trainer` that distills a frozen teacher model into a student model.

    The training loss is a weighted sum of the student's own supervised loss
    (computed by the student from the `labels` in the batch) and a
    temperature-scaled KL divergence between the teacher's and the student's
    softened class distributions:

        loss = (1 - lambda_param) * student_loss
               + lambda_param * temperature**2 * KL(teacher || student)

    Args:
        teacher: trained model whose logits serve as soft targets. It is moved
            to the Trainer's device and put in eval mode; it is never updated.
        student: model being trained; handed to `Trainer` as `model`.
        temperature: positive softening factor applied to both sets of logits.
        lambda_param: weight of the distillation term (the supervised term gets
            `1 - lambda_param`).
        *args, **kwargs: forwarded to `Trainer.__init__` (e.g. `args`,
            `train_dataset`, `eval_dataset`, `compute_metrics`).

    Raises:
        ValueError: if `teacher` or `lambda_param` is missing, or `temperature`
            is missing or not positive.
    """

    def __init__(self, teacher=None, student=None, temperature=None, lambda_param=None, *args, **kwargs):
        # Fail fast, before the (expensive) Trainer setup, instead of crashing
        # later with an AttributeError/TypeError deep inside training.
        if teacher is None:
            raise ValueError("ImageDistilTrainer requires a `teacher` model.")
        if temperature is None or temperature <= 0:
            raise ValueError(f"`temperature` must be a positive number, got {temperature!r}.")
        if lambda_param is None:
            raise ValueError("`lambda_param` must be set (weight of the distillation loss).")

        super().__init__(model=student, *args, **kwargs)

        self.teacher = teacher
        self.student = student
        self.loss_function = nn.KLDivLoss(reduction='batchmean')
        # Use the Trainer's own device so the teacher sits next to the student's
        # inputs; `args.device` is per-process, so this also holds under DDP.
        self.teacher.to(self.args.device)
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, student, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the combined supervised + distillation loss for one batch.

        Args:
            student: the model `Trainer` passes in (possibly wrapped, e.g. in
                DataParallel). It is used for the forward pass so any wrapping
                applied by the Trainer is honoured.
            inputs: batch dict forwarded to both models; it must contain
                `labels` so the student's forward returns a supervised `loss`.
            return_outputs: when True, also return the student's outputs
                (the Trainer requests this during evaluation).
            num_items_in_batch: passed by recent `Trainer` versions from
                `training_step`. It is accepted for signature compatibility and
                unused, because the KL term is already averaged per batch
                (`reduction='batchmean'`).

        Returns:
            `loss`, or `(loss, student_output)` when `return_outputs` is True.
        """
        student_output = student(**inputs)

        # The teacher is frozen: no gradients are tracked through it.
        with torch.no_grad():
            teacher_output = self.teacher(**inputs)

        # Soften both distributions. KLDivLoss expects log-probabilities as its
        # input (student) and probabilities as its target (teacher).
        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # Scale by T**2 so the soft-target term keeps a comparable gradient
        # magnitude as the temperature changes.
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        # Supervised loss the student computes from the true `labels`.
        student_target_loss = student_output.loss

        loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss

        return (loss, student_output) if return_outputs else loss

Set the `TrainingArguments`

In [7]:
from transformers import (
    AutoModelForImageClassification,
    MobileNetV2Config, MobileNetV2ForImageClassification
)

# Training configuration for the student. Evaluation and checkpointing both run
# once per epoch, so `load_best_model_at_end` can restore the checkpoint with the
# best validation `accuracy` (the key returned by `compute_metrics`).
training_args = TrainingArguments(
    output_dir='my-awesome-student-model',
    num_train_epochs=30,
    # fp16 mixed precision needs an accelerator; TrainingArguments raises a
    # ValueError on CPU-only machines, so enable it only when CUDA is present.
    fp16=torch.cuda.is_available(),
    logging_dir='image-classification-distillation/logs',
    logging_strategy='epoch',
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    report_to='tensorboard',
    push_to_hub=False,
)

In [8]:
label_names = processed_datasets['train'].features['labels'].names
num_labels = len(label_names)

# Initialize the teacher: a ViT already fine-tuned on beans. `ignore_mismatched_sizes`
# is intentionally not set. If the checkpoint's head did not have `num_labels`
# outputs, loading then fails loudly instead of silently replacing the head with
# random weights, which would make the teacher's soft targets meaningless.
teacher_model = AutoModelForImageClassification.from_pretrained(
    'merve/beans-vit-224',
    num_labels=num_labels,
)

# Initialize the student: a MobileNetV2 with random weights. Seed first so the
# initialization is reproducible. The Trainer only seeds its RNGs in its
# constructor, which runs after this cell.
torch.manual_seed(training_args.seed)
student_config = MobileNetV2Config()
student_config.num_labels = num_labels
# Record class names so the saved student maps predictions back to labels.
student_config.id2label = dict(enumerate(label_names))
student_config.label2id = {name: idx for idx, name in enumerate(label_names)}
student_model = MobileNetV2ForImageClassification(student_config)

config.json:   0%|          | 0.00/799 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/343M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/343M [00:00<?, ?B/s]

We define `compute_metrics` to report `accuracy` and weighted `f1`. The `Trainer` calls it on the validation set after every epoch, and it is also used for the final evaluation on the test set.

In [9]:
import evaluate
import numpy as np

accuracy = evaluate.load('accuracy')
f1 = evaluate.load('f1')


def compute_metrics(eval_pred):
    """Turn raw evaluation logits into accuracy and weighted F1.

    `eval_pred` unpacks into (logits, label ids). Each example's predicted class
    is the index of its largest logit along axis 1.
    """
    logits, label_ids = eval_pred
    predicted_ids = np.argmax(logits, axis=1)

    accuracy_result = accuracy.compute(predictions=predicted_ids, references=label_ids)
    # 'weighted' averages the per-class F1 scores by class support.
    f1_result = f1.compute(predictions=predicted_ids, references=label_ids, average='weighted')

    return {
        'accuracy': accuracy_result['accuracy'],
        'f1': f1_result['f1'],
    }

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.79k [00:00<?, ?B/s]

Initialize our `Trainer`:

In [11]:
from transformers import DefaultDataCollator

# Stack each batch's features into tensors.
data_collator = DefaultDataCollator()

# Distillation hyperparameters: `temperature` softens both models' logits and
# `lambda_param` weights the distillation term against the hard-label loss.
distillation_hparams = {'temperature': 5, 'lambda_param': 0.5}

trainer = ImageDistilTrainer(
    teacher=teacher_model,
    student=student_model,
    **distillation_hparams,
    args=training_args,
    train_dataset=processed_datasets['train'],
    eval_dataset=processed_datasets['validation'],
    data_collator=data_collator,
    processing_class=teacher_processor,
    compute_metrics=compute_metrics,
)

In [None]:
# Train the student for `num_train_epochs`, evaluating and checkpointing every epoch;
# with `load_best_model_at_end=True` the best checkpoint (by validation accuracy)
# is reloaded when training finishes.
trainer.train()

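Finally, we evaluate the student model on the held-out test set. Because `load_best_model_at_end=True`, the model evaluated here is the checkpoint with the best validation accuracy.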

In [None]:
# Evaluate the student on the test split. `metric_key_prefix='test'` reports and
# logs the results as `test_*` rather than the default `eval_*`, so they are not
# mixed into the per-epoch validation metrics in TensorBoard.
trainer.evaluate(eval_dataset=processed_datasets['test'], metric_key_prefix='test')