<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Response-Knowledge-Distillation" data-toc-modified-id="Response-Knowledge-Distillation-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Response Knowledge Distillation</a></span><ul class="toc-item"><li><span><a href="#Data-Preprocessing" data-toc-modified-id="Data-Preprocessing-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Data Preprocessing</a></span></li><li><span><a href="#Teacher-Model" data-toc-modified-id="Teacher-Model-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>Teacher Model</a></span></li><li><span><a href="#Student-Model" data-toc-modified-id="Student-Model-1.3"><span class="toc-item-num">1.3&nbsp;&nbsp;</span>Student Model</a></span></li><li><span><a href="#Benchmark" data-toc-modified-id="Benchmark-1.4"><span class="toc-item-num">1.4&nbsp;&nbsp;</span>Benchmark</a></span></li></ul></li><li><span><a href="#Reference" data-toc-modified-id="Reference-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Reference</a></span></li></ul></div>

In [1]:
# code for loading the format for the notebook
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(css_style='custom2.css', plot_style=False)

In [2]:
os.chdir(path)

# 1. magic for inline plot
# 2. magic to print version
# 3. magic so that the notebook will reload external python modules
# 4. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2

import os
import torch
import evaluate
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from time import perf_counter
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict
from transformers import (
    pipeline,
    Trainer,
    TrainingArguments,
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding
)
device = "cuda" if torch.cuda.is_available() else "cpu"

%watermark -a 'Ethen' -d -u -p torch,datasets,transformers,evaluate,numpy,pandas

Author: Ethen

Last updated: 2022-08-26

torch       : 1.12.1
datasets    : 2.3.2
transformers: 4.20.1
evaluate    : 0.2.2
numpy       : 1.21.6
pandas      : 1.2.4



# Response Knowledge Distillation

In this documentation, we'll deep dive into a technique called knowledge distillation that's commonly used to compress large model, a.k.a. teacher model, into a smaller model, a.k.a student model. The hope is that these student models, which typically have fewer layers or/and fewer neurons per layer will be capable of reproducing the behavior of teacher models while being more light weight. In other words, making the model more cost efficient when it comes to serving in production setting without lossing too much performance. And just to clarify, as knowledge distillation is a broad topic, there are two primary types of knowledge distillation, task-specific knowledge distillation (left) and task-agnostic knowledge distillation (right). Here, our primary focus will be the former.

<img src="img/distillation_task.png" width="70%" height="70%">

Task specific response knowledge distillation involves optimizing a weighted combination of two objective functions

\begin{align}
L = \alpha L_{CE} + (1 - \alpha) L_{KD} \text{, where } \alpha \in [0, 1]
\end{align}

$L_{CE}$ is the cross entropy loss between the student logit $z_s$ and our one hot encoded ground truth labels $y$:

\begin{align}
L_{CE} = - \sum^c_{j=1}y_j \text{log} \sigma_j(z_s, 1)
\end{align}

Where $\sigma_i$ is our softmax output that takes the model's logit, $z$ ($z_t$ stands for teacher model's logit, whereas $z_s$ stands for student model's logit), as well as a temperature scaling parameter, $T$, as its inputs. $\sigma_i = \frac{exp\left(z_i / T \right)}{\sum_{j} \exp\left(z_j / T \right)}$. Here, the temperature parameter for softmax function is 1, which makes this the standard loss function that we generally optimize towards in supervised classification settings. 

$L_{KD}$ For knowledge distillation loss part, we are essentially add a KL-divergence loss between teacher model's response with student model's response. By adding this loss function, we are training our student model so it will become better at mimicking similar predictions as the teacher.

\begin{align}
L_{KD} = - T^2 \sum^c_{j=1}\sigma_j(z_t, T) \text{log} \frac{\sigma_j(z_t, T)}{\sigma_j(z_s, T)}
\end{align}

The idea behind temperature scaling is that teacher model tend to assign extremely high predicted scores to the true class, as such it doesn't provide too much additional information beyond what dataset's ground truth label was already provided. To tackle this issue, temperature scaling acts as a scaling parameter to "soften" our predictions. The intuition behind this it allows us to learn "ish" concepts in our data, e.g. we have a 1-ish 7 (a 7 that looks like a 1, or more formally, although our model predicted 7 with the highest score, it still assign some amount of score to 1). Note:

- When a student model is a lot smaller than a teacher model, we tend to keep a smaller temperature. Because as we raise the temperature parameter, the resulting predicted distribution may start to contain too much "knowledge" for the student to capture effectively.
- Once our student model has been trained, the temperature parameter $T$, is set back to 1 during inferencing stage.
- There's a multiplication term $T^2$, in our knowledge distillation loss Since the magnitudes of the gradients produced by the soft targets scale as $1/T^2$. It is important to add a multiplication term back to ensure contribution from the ground truth hard target and the teacher's predicted soft target remains roughly equal.

As we can see, the main idea behind response knowledge distillation is that while training our student model, instead of solely optimizing for our task's original loss function using our dataset's ground truth label (e.g. in classification task this may be cross entropy loss), we will augment it with the teacher model's predicted output probability. In our loss function we will have a parameter $\alpha$ that controls weighting between the two loss function.

## Data Preprocessing

For this example, we will be using qqp (Quora Question Pairs2) [text classification task]((https://huggingface.co/tasks/text-classification)) from the [glue benchmark](https://huggingface.co/datasets/glue). These are collection of question pairs from the community question-answering website Quora. Our task is to determine whether a pair of questions are semantically equivalent.

In [3]:
dataset_dict = load_dataset('glue', 'qqp')
dataset_dict

Reusing dataset glue (/home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['question1', 'question2', 'label', 'idx'],
        num_rows: 363846
    })
    validation: Dataset({
        features: ['question1', 'question2', 'label', 'idx'],
        num_rows: 40430
    })
    test: Dataset({
        features: ['question1', 'question2', 'label', 'idx'],
        num_rows: 390965
    })
})

In [4]:
example = dataset_dict['train'][3]
example

{'question1': 'What can one do after MBBS?',
 'question2': 'What do i do after my MBBS ?',
 'label': 1,
 'idx': 3}

## Teacher Model

To establish our baseline, we'll piggyback on one of the pretrained models available from huggingface hub. In this case, we'll pick a teacher model that is already trained on our targeted dataset.

In [5]:
teacher_checkpoint = 'textattack/bert-base-uncased-QQP'
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_checkpoint)
teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_checkpoint).to(device)
print('# of parameters: ', teacher_model.num_parameters())

# of parameters:  109483778


We generate a sample prediction using our tokenizer and model. Double confirming our result matches with the [pipeline wrapper class](https://huggingface.co/docs/transformers/v4.21.1/en/main_classes/pipelines).

In [6]:
tokenized = teacher_tokenizer(
    example['question1'],
    example['question2'],
    return_tensors='pt'
).to(teacher_model.device)
tokenized

{'input_ids': tensor([[  101,  2054,  2064,  2028,  2079,  2044, 16914,  5910,  1029,   102,
          2054,  2079,  1045,  2079,  2044,  2026, 16914,  5910,  1029,   102]],
       device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')}

In [7]:
teacher_model.eval()
with torch.no_grad():
    output = teacher_model(**tokenized)
    batch_scores = F.softmax(output.logits, dim=-1)

batch_scores

tensor([[0.0223, 0.9777]], device='cuda:0')

In [8]:
classifier = pipeline("text-classification", model=teacher_checkpoint, device=teacher_model.device)
output = classifier({"text": example['question1'], "text_pair": example['question2']})
output

{'label': 'LABEL_1', 'score': 0.9777140021324158}

## Student Model

As always, we are free to choose different student models and compare results, though as a general principle, we typically avoid distilling different model family against each other, as different inputs/tokens will result in different embeddings, and knowledge transfering different spaces tend to not work well.

In the next code chunk, apart from the typically step of initiating our student model using `.from_pretrained` method, we also copy some additional config such as number of labels as well as label id to label name mapping from the teacher model's config.

In [9]:
student_checkpoint = 'distilbert-base-uncased'

student_tokenizer = AutoTokenizer.from_pretrained(student_checkpoint)

student_config = AutoConfig.from_pretrained(
    student_checkpoint,
    num_labels=teacher_model.config.num_labels,
    id2label=teacher_model.config.id2label,
    label2id=teacher_model.config.label2id
)

In [10]:
def student_model_init():
    student_model = AutoModelForSequenceClassification.from_pretrained(
        student_checkpoint,
        config=student_config
    ).to(device)
    return student_model


student_model = student_model_init()
print('# of parameters: ', student_model.num_parameters())

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'pre_classi

# of parameters:  66955010


In [11]:
def tokenize_dataset(dataset, tokenizer):
    def tokenize_fn(batch):
        return tokenizer(batch["question1"], batch["question2"], truncation=True)

    return dataset.map(
        tokenize_fn,
        batched=True,
        num_proc=8,
        remove_columns=["question1", "question2", "idx"]
    )

In [12]:
dataset_dict_student_tokenized = tokenize_dataset(dataset_dict, student_tokenizer)

 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-a8f61aaa83ba2ee2.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-daebdc7105b1e8f7.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-20d395268d52b4e3.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-8bc658db78dde795.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-f0374f1cbc06fa12.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-2e9088b1dbb8434c.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-7bbccc3a4229a83e.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-4c824600f621243f.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-51ae49f9d8a95037.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-591dc6edfbe0d7c2.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-302a4412aa27fb83.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-39d35fb67c02d467.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-458b6d6a7aa7b7f8.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-cc09642df872468a.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-452ef0103d9c6221.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-377807dbd74b5216.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-c7bdc7a185c7f2c8.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-c6fe5844c32761d8.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-f16094dbb6270595.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-074946734d50b003.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-93678765a8db8d07.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-2dc173cc4049ad28.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-480e4d963d7d4495.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-2b01915a997e8ff1.arrow


In [13]:
dataset_dict_student_tokenized['train'][0]

{'label': 0,
 'input_ids': [101,
  2129,
  2003,
  1996,
  2166,
  1997,
  1037,
  8785,
  3076,
  1029,
  2071,
  2017,
  6235,
  2115,
  2219,
  6322,
  1029,
  102,
  2029,
  2504,
  1997,
  17463,
  8156,
  2003,
  2438,
  2005,
  1996,
  11360,
  1046,
  14277,
  2102,
  2629,
  1029,
  102],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1]}

For model performance, we'll compute some of the standard text classification metrics, Huggingface evaluate allows us to combine multiple metric's calculation in one go using the `.combine` method. As `roc_auc` expects a different input (it requires the predicted score instead of predicted labels) compared to `f1`, `precision`, `recall`, we load it separately.

In [14]:
clf_metrics = evaluate.combine(["f1", "precision", "recall"])
roc_auc_metric = evaluate.load("roc_auc")

results = clf_metrics.compute(predictions=[0, 1], references=[0, 1])
print(results)

{'f1': 1.0, 'precision': 1.0, 'recall': 1.0}


In [15]:
def compute_metrics(pred):
    scores, labels = pred
    predictions = np.argmax(scores, axis=1)
    metrics = clf_metrics.compute(predictions=predictions, references=labels)
    metrics['roc_auc'] = roc_auc_metric.compute(prediction_scores=scores[:, 1], references=labels)['roc_auc']
    return metrics

In the next few code chunk, we'll first train a student model with and without knowledge distillation for comparison.

In [16]:
batch_size = 64
num_train_epochs = 3
learning_rate = 2e-5
weight_decay = 0.01

student_finetuned_checkpoint = "distilbert-base-uncased-finetuned-qqp"
student_training_args = TrainingArguments(
    output_dir=student_finetuned_checkpoint,
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True
)

student_trainer = Trainer(
    model_init=student_model_init,
    args=student_training_args,
    tokenizer=student_tokenizer, 
    train_dataset=dataset_dict_student_tokenized['train'],
    eval_dataset=dataset_dict_student_tokenized['validation'],
    compute_metrics=compute_metrics
)
student_trainer.train()

loading weights file https://huggingface.co/distilbert-base-uncased/resolve/main/pytorch_model.bin from cache at /home/mingyuliu/.cache/huggingface/transformers/9c169103d7e5a73936dd2b627e42851bec0831212b677c637033ee4bce9ab5ee.126183e36667471617ae2f0835fab707baa54b731f991507ebbb55ea85adb12a
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a

Epoch,Training Loss,Validation Loss,F1,Precision,Recall,Roc Auc
1,0.2851,0.270191,0.843404,0.816137,0.872556,0.951686
2,0.2198,0.254518,0.861569,0.830724,0.894793,0.959855
3,0.1722,0.257736,0.863629,0.846228,0.88176,0.961481


***** Running Evaluation *****
  Num examples = 40430
  Batch size = 64
Saving model checkpoint to distilbert-base-uncased-finetuned-qqp/checkpoint-5686
Configuration saved in distilbert-base-uncased-finetuned-qqp/checkpoint-5686/config.json
Model weights saved in distilbert-base-uncased-finetuned-qqp/checkpoint-5686/pytorch_model.bin
tokenizer config file saved in distilbert-base-uncased-finetuned-qqp/checkpoint-5686/tokenizer_config.json
Special tokens file saved in distilbert-base-uncased-finetuned-qqp/checkpoint-5686/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 40430
  Batch size = 64
Saving model checkpoint to distilbert-base-uncased-finetuned-qqp/checkpoint-11372
Configuration saved in distilbert-base-uncased-finetuned-qqp/checkpoint-11372/config.json
Model weights saved in distilbert-base-uncased-finetuned-qqp/checkpoint-11372/pytorch_model.bin
tokenizer config file saved in distilbert-base-uncased-finetuned-qqp/checkpoint-11372/tokenizer_config.json
S

TrainOutput(global_step=17058, training_loss=0.24440976354326394, metrics={'train_runtime': 2436.7937, 'train_samples_per_second': 447.94, 'train_steps_per_second': 7.0, 'total_flos': 2.2013605137213264e+16, 'train_loss': 0.24440976354326394, 'epoch': 3.0})

In order for us to finetune for model using knowledge distillation, we will subclass the `TrainingArguments` to include our two hyperparameters, $\alpha$ and $T$, as well as `Trainer` to mainly overwrite its `compute_loss` method so we can add our knowledge distillation loss term.

In [17]:
class DistillationTrainingArguments(TrainingArguments):
    def __init__(self, *args, alpha=0.5, temperature=1.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.temperature = temperature


class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        # place teacher on same device as student
        self._move_model_to_device(self.teacher,self.model.device)
        self.teacher.eval()

        self.kl_div_loss = nn.KLDivLoss(reduction="batchmean")

    def compute_loss(self, model, inputs, return_outputs=False):
        # compute student and teacher output
        outputs_student = model(**inputs)
        with torch.no_grad():
            outputs_teacher = self.teacher(**inputs)

        # Soften probabilities and compute distillation loss
        # note, the kl divergence loss expects the input to be in log-space
        # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
        distillation_loss = self.kl_div_loss(
            F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
            F.softmax(outputs_teacher.logits / self.args.temperature, dim=-1)
        ) * (self.args.temperature ** 2)
        # Return weighted student loss
        loss = self.args.alpha * outputs_student.loss + (1. - self.args.alpha) * distillation_loss
        return (loss, outputs_student) if return_outputs else loss

In [18]:
student_distillation_checkpoint = "distilbert-base-uncased-finetuned-distillation-qqp"
student_distillation_training_args = DistillationTrainingArguments(
    output_dir=student_distillation_checkpoint,
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    alpha=0.8
)

student_distillation_trainer = DistillationTrainer(
    model_init=student_model_init,
    args=student_distillation_training_args,
    tokenizer=student_tokenizer,
    teacher_model=teacher_model, 
    train_dataset=dataset_dict_student_tokenized['train'],
    eval_dataset=dataset_dict_student_tokenized['validation'],
    compute_metrics=compute_metrics
)
student_distillation_trainer.train()

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
loading weights file https://huggingface.co/distilbert-base-uncased/resolve/main/pytorch_model.bin from cache at /home/mingyuliu/.cache/huggingface/transformers/9c169103d7e5a73936dd2b627e42851bec0831212b677c637033ee4bce9ab5ee.126183e36667471617ae2f0835fab707baa54b731f991507ebbb55ea85adb12a
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on anothe

Epoch,Training Loss,Validation Loss,F1,Precision,Recall,Roc Auc
1,0.4187,0.408516,0.83175,0.852827,0.81169,0.951295
2,0.3763,0.396593,0.850368,0.867762,0.833658,0.958391
3,0.3487,0.400977,0.856636,0.870724,0.842996,0.960325


***** Running Evaluation *****
  Num examples = 40430
  Batch size = 64
Saving model checkpoint to distilbert-base-uncased-finetuned-distillation-qqp/checkpoint-5686
Configuration saved in distilbert-base-uncased-finetuned-distillation-qqp/checkpoint-5686/config.json
Model weights saved in distilbert-base-uncased-finetuned-distillation-qqp/checkpoint-5686/pytorch_model.bin
tokenizer config file saved in distilbert-base-uncased-finetuned-distillation-qqp/checkpoint-5686/tokenizer_config.json
Special tokens file saved in distilbert-base-uncased-finetuned-distillation-qqp/checkpoint-5686/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 40430
  Batch size = 64
Saving model checkpoint to distilbert-base-uncased-finetuned-distillation-qqp/checkpoint-11372
Configuration saved in distilbert-base-uncased-finetuned-distillation-qqp/checkpoint-11372/config.json
Model weights saved in distilbert-base-uncased-finetuned-distillation-qqp/checkpoint-11372/pytorch_model.bin
token

TrainOutput(global_step=17058, training_loss=0.3921033198647013, metrics={'train_runtime': 3955.2862, 'train_samples_per_second': 275.969, 'train_steps_per_second': 4.313, 'total_flos': 2.2013605137213264e+16, 'train_loss': 0.3921033198647013, 'epoch': 3.0})

## Benchmark

When determining which model to move forward with production, we usually look at model performance, latency, as well as memory (a.k.a model size). We'll create a helper class for measuring these key aspects, run our models through it for a fair comparison.

In [19]:
class Benchmark:

    def __init__(
        self,
        dataset,
        latency_warmup: int = 10,
        latency_rounds: int = 100,
        perf_batch_size: int = 128,
        perf_round_digits: int = 3
    ):
        self.dataset = dataset
        self.latency_warmup = latency_warmup
        self.latency_rounds = latency_rounds
        self.perf_batch_size = perf_batch_size
        self.perf_round_digits = perf_round_digits

        self.temp_model_path = 'model.pt'

    def run(self, tokenizer, model, run_name):
        """run benchmark for a given tokenizer and model
        we can provide a run_name to differentiate the results
        from different runs in the final dictionary.
        
        e.g.
        {
            "run_name": {
                'size_mb': 417.73,
                'num_parameters': 109483778,
                'latency_avg_ms': 8.33,
                'latency_std_ms': 1.16,
                'f1': 0.878,
                'precision': 0.867,
                'recall': 0.89,
                'roc_auc': 0.968
            }
        }
        """
        model.eval()
        
        size = self.compute_size(model)
        latency = self.compute_latency(tokenizer, model)
        performance = self.compute_performance(tokenizer, model)

        # merge various metrics into one single dictionary
        metrics = {**size, **latency, **performance}
        return {run_name: metrics}
    
    def predict(self, example, tokenizer, model):
        inputs = tokenizer(
            example['question1'],
            example['question2'],
            return_tensors='pt'
        ).to(model.device)
        with torch.no_grad():
            output = model(**inputs.to(model.device))

        return output

    def compute_size(self, model):
        """save the model's parameter temporarily to local path for calculating model size.
        Once calculation is done, purge the checkpoint.
        Size is reported in megabtyes.

        https://pytorch.org/tutorials/beginner/saving_loading_models.html
        """
        torch.save(model.state_dict(), self.temp_model_path)
        size_mb = os.path.getsize(self.temp_model_path) / (1024 * 1024)
        size_mb = round(size_mb, 2)
        os.remove(self.temp_model_path)
        print(f"Model size (MB): {size_mb}")
        print(f"# of parameters: {model.num_parameters()}")
        return {"size_mb": size_mb, "num_parameters": model.num_parameters()}
    
    def compute_latency(self, tokenizer, model):
        """
        Pick the first example of the input dataset, compute the average latency as well as
        standard deviation over a configurable number of runs.
        Latency is reported in milliseconds.
        """
        example = self.dataset[0]
        latencies = []

        for _ in range(self.latency_warmup):
            _ = self.predict(example, tokenizer, model)

        for _ in range(self.latency_rounds):
            start_time = perf_counter()
            _ = self.predict(example, tokenizer, model)
            latency = perf_counter() - start_time
            latencies.append(latency)

        # Compute run statistics
        latency_avg_ms = round(1000 * np.mean(latencies), 2)
        latency_std_ms = round(1000 * np.std(latencies), 2)
        print(f"Average latency (ms): {latency_avg_ms} +\- {latency_std_ms}")
        return {"latency_avg_ms": latency_avg_ms, "latency_std_ms": latency_std_ms}
        
    def compute_performance(self, tokenizer, model):
        """compute f1/precision/recall/roc_auc metrics around sequence classification."""
        clf_metrics = evaluate.combine(["f1", "precision", "recall"])
        roc_auc_metric = evaluate.load("roc_auc")

        scores = []
        predictions = []
        references = []
        
        dataset_tokenized = tokenize_dataset(self.dataset, tokenizer)
        
        data_collator = DataCollatorWithPadding(tokenizer)
        data_loader = DataLoader(dataset_tokenized, batch_size=self.perf_batch_size, collate_fn=data_collator)
        for example in data_loader:
            labels = example.pop('labels')
            with torch.no_grad():
                output = model(**example.to(model.device))
                score = F.softmax(output.logits, dim=-1)
                prediction = score.argmax(dim=-1)

            scores += tensor_to_list(score[:, 1])
            predictions += tensor_to_list(prediction)
            references += tensor_to_list(labels)

        metrics = clf_metrics.compute(predictions=predictions, references=references)
        metrics['roc_auc'] = roc_auc_metric.compute(prediction_scores=scores, references=references)['roc_auc']
        for metric, value in metrics.items():
            metrics[metric] = round(value, self.perf_round_digits)

        return metrics
    
    
def tensor_to_list(tensor):
    return tensor.cpu().numpy().tolist()

In [20]:
benchmark_metrics_dict = {}
benchmark = Benchmark(dataset_dict['validation'])
benchmark_metrics = benchmark.run(teacher_tokenizer, teacher_model, 'bert_uncased_teacher')
benchmark_metrics_dict.update(benchmark_metrics)

Model size (MB): 417.73
# of parameters: 109483778
Average latency (ms): 7.01 +\- 0.63
 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-0c6ac1d39dfd9e94.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-7aa70ce020a63c9d.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-58720fa6daa16de3.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-554217326819c7b0.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-eda7739157194c24.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-d08f8b13da74e2bf.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-675828f6d2b994cd.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-f92fb802867c268e.arrow


In [21]:
benchmark_metrics = benchmark.run(
    student_tokenizer,
    student_trainer.model,
    'distilbert_student'
)
benchmark_metrics_dict.update(benchmark_metrics)

Model size (MB): 255.45
# of parameters: 66955010
Average latency (ms): 4.24 +\- 0.62
 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-51ae49f9d8a95037.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-591dc6edfbe0d7c2.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-302a4412aa27fb83.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-39d35fb67c02d467.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-458b6d6a7aa7b7f8.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-cc09642df872468a.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-452ef0103d9c6221.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-377807dbd74b5216.arrow


In [22]:
benchmark_metrics = benchmark.run(
    student_tokenizer,
    student_distillation_trainer.model,
    'distilbert_distillation_student'
)
benchmark_metrics_dict.update(benchmark_metrics)

Model size (MB): 255.45
# of parameters: 66955010
Average latency (ms): 4.02 +\- 0.34
 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-51ae49f9d8a95037.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-591dc6edfbe0d7c2.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-302a4412aa27fb83.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-39d35fb67c02d467.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-458b6d6a7aa7b7f8.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-cc09642df872468a.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-452ef0103d9c6221.arrow


 

Loading cached processed dataset at /home/mingyuliu/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-377807dbd74b5216.arrow


In [23]:
pd.DataFrame.from_dict(benchmark_metrics_dict, orient='index')

Unnamed: 0,size_mb,num_parameters,latency_avg_ms,latency_std_ms,f1,precision,recall,roc_auc
bert_uncased_teacher,417.73,109483778,7.01,0.63,0.878,0.867,0.89,0.968
distilbert_student,255.45,66955010,4.24,0.62,0.862,0.831,0.895,0.96
distilbert_distillation_student,255.45,66955010,4.02,0.34,0.85,0.868,0.834,0.959


The final table is a comparison on our teacher model (bert), and two student model (distilbert), where one of the students was trained with knowledge distilation loss, and the other wasn't. Quick observations are: we can definitely shrink our model size and improve latency by using a student model without much loss in terms of model performance. Note, we also didn't spend too much time tuning additional loss weighting, $\alpha$, and temperature scaling, $T$ hyperparameters that comes with knowledge distillation.

# Reference

- [Blog: Task-specific knowledge distillation for BERT using Transformers & Amazon SageMaker](https://www.philschmid.de/knowledge-distillation-bert-transformers)
- [Blog: Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT](https://medium.com/huggingface/distilbert-8cf3380435b5)
- [Blog: Knowledge Distillation: Principles, Algorithms, Applications](https://neptune.ai/blog/knowledge-distillation)
- [Blog: Weeknotes: Distilling distilled transformers](https://lewtun.github.io/blog/weeknotes/nlp/huggingface/transformers/2021/01/17/wknotes-distillation-and-generation.html)
- [Doc: Neural Network Distiller - Knowledge Distillation](https://intellabs.github.io/distiller/knowledge_distillation.html)
- [Paper: V. Sanh, L. Debut, J. Chaumond, T. Wolf - DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter - 2019](https://arxiv.org/abs/1910.01108)
- [Paper: G. Hinton, O. Vinyals, J. Dean - Distilling the Knowledge in a Neural Network - 2015](https://arxiv.org/abs/1910.01108)