Knowledge distillation is a technique used to transfer knowledge from a larger, more complex model (the teacher) to a smaller, simpler model (the student). To distill knowledge from one model to another, we take a pre-trained teacher model trained on a certain task (image classification in this case) and randomly initialize a student model to be trained on the same task. Next, we train the student model to minimize the difference between its outputs and the teacher’s outputs, thus making it mimic the teacher’s behavior. Knowledge distillation was first introduced in the paper “Distilling the Knowledge in a Neural Network” by Hinton, Vinyals and Dean: https://arxiv.org/abs/1503.02531.

In this guide, we will do task-specific knowledge distillation. Specifically, this guide demonstrates how you can distill a fine-tuned ViT model (teacher model) to a MobileNet (student model) using the Trainer API of HF Transformers.

# Libraries

In [4]:
pip install transformers datasets accelerate tensorboard evaluate --upgrade

Collecting transformers
  Using cached transformers-4.45.2-py3-none-any.whl.metadata (44 kB)
Collecting datasets
  Using cached datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting accelerate
  Using cached accelerate-1.0.1-py3-none-any.whl.metadata (19 kB)
Collecting tensorboard
  Using cached tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting evaluate
  Using cached evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Using cached pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl.metadata (3.3 kB)
Collecting requests (from transformers)
  Using cached requests-2.32.3-py3-none-any.whl.metadata (4.6 kB)
Collecting tqdm>=4.27 (from transformers)
  Using cached tqdm-4.66.5-py3-none-any.whl.metadata (57 kB)
Collecting grpcio>=1.48.2 (from tensorboard)
  Using cached grpcio-1.66.2.tar.gz (12.5 MB)
  Preparing metadata (setup.py) ... done
Collecting tensorboard-data-server<0.8.0,>=0.7.0 (from tensorboard)
  Using cached t

Using cached requests-2.32.3-py3-none-any.whl (64 kB)
Using cached tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl (4.8 MB)
Using cached tqdm-4.66.5-py3-none-any.whl (78 kB)
Building wheels for collected packages: grpcio
  Building wheel for grpcio (setup.py) ... ^C
canceled
ERROR: Operation cancelled by user
Note: you may need to restart the kernel to use updated packages.


In [None]:
import torch
import evaluate
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import DefaultDataCollator
from transformers import AutoImageProcessor, TrainingArguments, Trainer, \
AutoModelForImageClassification, MobileNetV2Config, MobileNetV2ForImageClassification


# Data Load

In [None]:
# Load the beans dataset.
# The teacher model used in this guide is merve/beans-vit-224: an image classification
# model based on google/vit-base-patch16-224-in21k and fine-tuned on the beans dataset.
dataset = load_dataset("beans")

# Preprocessing

In [None]:
# The image processor of either model could be used; here we take the teacher's.
teacher_processor = AutoImageProcessor.from_pretrained("merve/beans-vit-224")


def process(examples):
    """Run the teacher's image processor over the "image" column of a batch."""
    return teacher_processor(examples["image"])


# map() applies the preprocessing to every split of the dataset, one batch at a time.
processed_datasets = dataset.map(process, batched=True)

In [None]:
# We will distill the merve/beans-vit-224 model into a randomly initialized MobileNetV2.
# We want the randomly initialized MobileNet to mimic the teacher model.
# To achieve this, (1) we first get the logits output by the teacher and the student,
# (2) we divide each of them by the temperature parameter, which controls how soft
# the resulting probability distributions (the soft targets) are, and
# (3) we use the Kullback-Leibler divergence loss to measure how far the student's
# distribution is from the teacher's.
# If the logits from the two models are identical, their KL divergence equals zero.

class ImageDistilTrainer(Trainer):
    """Trainer that distills a fixed teacher classifier into a student model.

    The loss returned by :meth:`compute_loss` is a weighted sum of two terms:

    * the student's own loss on the ground-truth labels
      (``student_output.loss``), weighted by ``1 - lambda_param``;
    * the KL divergence between the temperature-softened teacher and student
      distributions, multiplied by ``temperature ** 2`` and weighted by
      ``lambda_param``.

    Args:
        teacher_model: Model whose logits are the distillation target. It is
            moved to the Trainer's device, switched to eval mode, and only
            ever run under ``torch.no_grad()``.
        student_model: Model being trained; forwarded to ``Trainer`` as
            ``model``.
        temperature: Value both models' logits are divided by before the
            (log-)softmax.
        lambda_param: Weight of the distillation term in the final loss.
        *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
        super().__init__(*args, model=student_model, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        # Put the teacher on the device the Trainer resolved from its arguments
        # instead of a hard-coded cuda/cpu choice, so the teacher sits on the same
        # device as the batches on other backends (e.g. MPS) as well.
        self.teacher.to(self.args.device)
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, student, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the combined label + distillation loss for one batch.

        Args:
            student: The model handed over by the Trainer. This object (not
                ``self.student``) is called, so any wrapping the Trainer
                applied to the model is respected.
            inputs: Keyword arguments passed to both the student and the
                teacher forward calls.
            return_outputs: When true, return ``(loss, student_output)``
                instead of the loss alone.
            num_items_in_batch: Accepted for compatibility with Trainer
                versions that pass this keyword; not used here.
        """
        student_output = student(**inputs)

        # The teacher only provides targets, so no graph is built for it.
        with torch.no_grad():
            teacher_output = self.teacher(**inputs)

        # Soft targets: nn.KLDivLoss expects log-probabilities as input
        # (student) and probabilities as target (teacher).
        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # Distillation term, rescaled by temperature squared.
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        # Loss of the student against the true labels.
        student_target_loss = student_output.loss

        # Weighted sum of the two terms.
        loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss

# Training

In [None]:

training_args = TrainingArguments(
    output_dir="knowledge_distillation",
    num_train_epochs=30,
    # Only request fp16 mixed precision when CUDA is available: a hard-coded
    # True makes TrainingArguments reject the configuration on machines
    # without a supported accelerator (e.g. CPU-only).
    # NOTE(review): this also turns fp16 off on non-CUDA accelerators; set it
    # explicitly if you train on one that supports fp16.
    fp16=torch.cuda.is_available(),
    eval_strategy="epoch",
    save_strategy="epoch",
    # Keep the checkpoint with the best "accuracy" reported at evaluation.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
    push_to_hub=False,
)

# Number of classes, read from the "labels" feature of the train split.
num_labels = len(processed_datasets["train"].features["labels"].names)

# Teacher: the fine-tuned checkpoint.
# NOTE(review): with ignore_mismatched_sizes=True a num_labels that differs from
# the checkpoint's head would presumably be accepted and the head re-initialized
# rather than raising an error -- confirm the label counts match.
teacher_model = AutoModelForImageClassification.from_pretrained(
    "merve/beans-vit-224",
    num_labels=num_labels,
    ignore_mismatched_sizes=True,
)

# Student: MobileNetV2 built from a default config, i.e. trained from scratch.
student_config = MobileNetV2Config()
student_config.num_labels = num_labels
student_model = MobileNetV2ForImageClassification(student_config)

# Evaluation

In [None]:
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Compute accuracy from an evaluation ``(predictions, labels)`` pair.

    The predicted class of each example is the argmax of its prediction row
    (axis 1). Returns a dict with the single key "accuracy".
    """
    scores, references = eval_pred
    predicted_classes = np.argmax(scores, axis=1)
    result = accuracy.compute(predictions=predicted_classes, references=references)
    return {"accuracy": result["accuracy"]}