# DistilBERT
> A partial reimplementation of DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter by Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf [[arXiv:1910.01108](https://arxiv.org/abs/1910.01108)]

The goal of this notebook is to explore _task-specific_ knowledge distillation, where a teacher is used to augment the cross-entropy loss of the student during fine-tuning:

$${\cal L}(\mathbf{x}|T) = - \sum_i \bar{y}_i\log y_i(\mathbf{x}|T) -T^2 \sum_i \hat{y}_i(\mathbf{x}|T)\log y_i(\mathbf{x}|T) \,.$$

Here $T$ is the temperature, $\hat{y}$ are the teacher's temperature-softened outputs, $\bar{y}$ are the ground-truth labels, and $y_i$ is the student's softmax with temperature.
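As a rough illustration of the loss above, here is a minimal pure-Python sketch for a single example. It follows the formula literally, so both terms use the temperature-scaled softmax of the student, and it assumes $\hat{y}$ denotes the teacher's softened probabilities. The helper names `softmax` and `distillation_loss` are hypothetical and are not taken from `transformerlab`.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over a list of logits, softened by the given temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, label, temperature=2.0):
    """Hard-label cross-entropy plus T^2-weighted soft-target cross-entropy."""
    student_probs = softmax(student_logits, temperature)  # y_i(x|T)
    teacher_probs = softmax(teacher_logits, temperature)  # \hat{y}_i(x|T), assumed to be the teacher
    hard_term = -math.log(student_probs[label])  # a one-hot \bar{y} picks out a single class
    soft_term = -sum(t * math.log(s) for t, s in zip(teacher_probs, student_probs))
    return hard_term + temperature ** 2 * soft_term

# Toy logits over three classes (illustrative values only)
loss = distillation_loss([2.0, 0.5, -1.0], [3.0, 0.0, -2.0], label=0, temperature=2.0)
print(f"loss = {loss:.4f}")
```

Many implementations instead compute the hard-label term at $T=1$ and weight the two terms. For question answering, a loss of this form would presumably be applied to the start and end logits separately.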

This idea comes from the DistilBERT paper, where the authors found that including a "second step of distillation" produced a student that performed better than simply fine-tuning the distilled language model:

> We also studied whether we could add another step of distillation during the adaptation phase by fine-tuning DistilBERT on SQuAD using a BERT model previously fine-tuned on SQuAD as a teacher for an additional term in the loss (knowledge distillation). In this setting, there are thus two successive steps of distillation, one during the pre-training phase and one during the adaptation phase. In this case, we were able to reach interesting performances given the size of the model: 79.8 F1 and 70.4 EM, i.e. within 3 points of the full model.

We'll take the same approach here and aim to reproduce the SQuAD v1 results from the paper. In the table below, each entry refers to the Exact Match / F1-score on the validation set.

| Implementation | BERT-base | DistilBERT | (DistilBERT)^2 |
| :--- | :---: | :---: | :---: |
| HuggingFace | 81.2 / 88.5 | 77.7 / 85.8 | 79.1 / 86.9 |
| Ours | 80.1 / 87.8 | 76.7 / 85.2 | 78.4 / 86.5 |
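To read the table at a glance, the short sketch below computes how far each student trails BERT-base in our runs. The numbers are copied from the "Ours" row; `gap_to_teacher` is a hypothetical helper written only for this illustration.

```python
# (Exact Match, F1) on the SQuAD v1 validation set, copied from the "Ours" row
ours = {
    "BERT-base": (80.1, 87.8),
    "DistilBERT": (76.7, 85.2),
    "(DistilBERT)^2": (78.4, 86.5),
}

def gap_to_teacher(results, student, teacher="BERT-base"):
    """Return the (EM, F1) points by which `student` trails `teacher`."""
    return tuple(round(t - s, 1) for t, s in zip(results[teacher], results[student]))

for name in ("DistilBERT", "(DistilBERT)^2"):
    em_gap, f1_gap = gap_to_teacher(ours, name)
    print(f"{name}: {em_gap} EM / {f1_gap} F1 behind BERT-base")
```

In our runs the second distillation step halves the gap to the teacher, from 3.4 to 1.7 points of Exact Match and from 2.6 to 1.3 points of F1.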

## Load libraries

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import math

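# Project helpers used below (e.g. QuestionAnsweringTrainer, DistillationTrainer,
# convert_examples_to_features, squad_metrics) come from these star imports
from transformerlab.question_answering import *
from transformerlab.distillation import *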

import datasets
import transformers
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()
import torch
from datasets import load_dataset, load_metric
from transformers import (AutoTokenizer, AutoModelForQuestionAnswering, 
                          default_data_collator, QuestionAnsweringPipeline)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using transformers v{transformers.__version__} and datasets v{datasets.__version__}")
print(f"Running on device: {device}")

Using transformers v4.1.1 and datasets v1.2.0
Running on device: cuda


## Load data

In [None]:
squad_ds = load_dataset("squad")
squad_ds

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

## Fine-tune BERT-base

### Initialize trainer

In [None]:
batch_size = 16
num_train_examples = len(squad_ds['train'])
num_eval_examples = len(squad_ds['validation'])

teacher_model_checkpoint = "bert-base-uncased"
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_checkpoint)

train_ds, eval_ds, eval_examples = convert_examples_to_features(squad_ds, teacher_tokenizer, num_train_examples, num_eval_examples)
logging_steps = len(train_ds) // batch_size

teacher_args = QuestionAnsweringTrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_steps=logging_steps,
    disable_tqdm=False
)

print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation examples: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_examples.num_rows}")
print(f"Logging steps: {logging_steps}")

Number of training examples: 88524
Number of validation examples: 10784
Number of raw validation examples: 10570
Logging steps: 5532


In [None]:
def teacher_init():
    return AutoModelForQuestionAnswering.from_pretrained(teacher_model_checkpoint)

data_collator = default_data_collator

teacher_trainer = QuestionAnsweringTrainer(
    model_init=teacher_init,
    args=teacher_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_examples,
    tokenizer=teacher_tokenizer,
    data_collator=data_collator,
    compute_metrics=squad_metrics
)

In [None]:
teacher_trainer.train();

| Epoch | Training Loss | Validation Loss | Exact Match | F1 |
| :--- | :---: | :---: | :---: | :---: |
| 1.0 | 4.321222 | No log | 17.1875 | 28.298791 |
| 2.0 | 2.801787 | No log | 33.125 | 44.253312 |






In [None]:
teacher_trainer.save_model('models/bert-base-uncased-finetuned-squad-v1')

### Evaluate the teacher

In [None]:
teacher_checkpoint = "lewtun/bert-base-uncased-finetuned-squad-v1"
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_checkpoint)
teacher_finetuned = AutoModelForQuestionAnswering.from_pretrained(teacher_checkpoint)

teacher_trainer = QuestionAnsweringTrainer(
    model=teacher_finetuned,
    args=teacher_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_examples,
    tokenizer=teacher_tokenizer,
    data_collator=data_collator,
    compute_metrics=squad_metrics
)
teacher_trainer.evaluate()





{'eval_loss': 'No log',
 'eval_exact_match': 80.07568590350047,
 'eval_f1': 87.77870284880602}
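For readers unfamiliar with the two metrics reported above, the sketch below shows how Exact Match and F1 are typically computed for a single prediction on SQuAD v1: answers are normalized (lowercased, with punctuation and articles removed), then compared either exactly or by token overlap. This is a simplified stand-in written for illustration; it is an assumption that `squad_metrics` behaves the same way, and the full metric also takes the maximum over all reference answers for a question.

```python
import re
import string
from collections import Counter

def normalize_answer(text):
    """Lowercase, drop punctuation and English articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction, truth):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(truth))

def f1_score(prediction, truth):
    """Harmonic mean of token-level precision and recall."""
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(truth).split()
    num_same = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)

print(exact_match("The Denver Broncos", "Denver Broncos"))
print(f1_score("Denver Broncos team", "Denver Broncos"))
```

Exact Match is all-or-nothing, whereas F1 gives partial credit for overlapping tokens, which is why the F1 scores in this notebook are consistently higher than the Exact Match scores.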

## Fine-tune DistilBERT

### Configure and initialize trainer

In [None]:
batch_size = 16
num_train_examples = len(squad_ds['train'])
num_eval_examples = len(squad_ds['validation'])

distilbert_checkpoint = "distilbert-base-uncased"
distilbert_tokenizer = AutoTokenizer.from_pretrained(distilbert_checkpoint)

train_ds, eval_ds, eval_examples = convert_examples_to_features(squad_ds, distilbert_tokenizer, num_train_examples, num_eval_examples)
logging_steps = len(train_ds) // batch_size

distilbert_args = QuestionAnsweringTrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=logging_steps,
    disable_tqdm=False
)

print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation examples: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_examples.num_rows}")
print(f"Logging steps: {logging_steps}")

Number of training examples: 88524
Number of validation examples: 10784
Number of raw validation examples: 10570
Logging steps: 5532


In [None]:
def distilbert_init():
    return AutoModelForQuestionAnswering.from_pretrained(distilbert_checkpoint)

data_collator = default_data_collator

distilbert_trainer = QuestionAnsweringTrainer(
    model_init=distilbert_init,
    args=distilbert_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_examples,
    tokenizer=distilbert_tokenizer,
    data_collator=data_collator,
    compute_metrics=squad_metrics
)

In [None]:
distilbert_trainer.train();

In [None]:
distilbert_trainer.save_model('models/distilbert-base-uncased-finetuned-squad-v1')

## Distill DistilBERT

In [None]:
student_model_name = "distilbert-base-uncased"
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)

In [None]:
max_length = 384 
doc_stride = 128 
pad_on_right = student_tokenizer.padding_side == "right"

fn_kwargs = {
    "tokenizer": student_tokenizer,
    "max_length": max_length,
    "doc_stride": doc_stride,
    "pad_on_right": pad_on_right
}
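To make `max_length` and `doc_stride` concrete: contexts that do not fit into one input are split into overlapping windows, which is why the 87,599 training examples become 88,524 features. The sketch below assumes `doc_stride` is the number of tokens shared by consecutive windows (as with the `stride` argument of Hugging Face tokenizers) and ignores the question and special tokens that also count towards `max_length`. `sliding_windows` is a hypothetical helper, not part of `transformerlab`.

```python
def sliding_windows(num_tokens, window, overlap):
    """Return (start, end) token spans covering `num_tokens` with overlapping windows."""
    if not 0 <= overlap < window:
        raise ValueError("overlap must satisfy 0 <= overlap < window")
    step = window - overlap  # how far each new window advances
    spans = []
    start = 0
    while True:
        end = min(start + window, num_tokens)
        spans.append((start, end))
        if end == num_tokens:  # the last window reaches the end of the context
            break
        start += step
    return spans

# A 500-token context with a 384-token window and a 128-token overlap
print(sliding_windows(500, window=384, overlap=128))
# A short context fits into a single window
print(sliding_windows(100, window=384, overlap=128))
```

The overlap means that an answer cut off at the edge of one window is likely to appear whole in the next one.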

#### Preprocess training set

In [None]:
train_enc = squad_ds['train'].map(prepare_train_features, fn_kwargs=fn_kwargs, batched=True, remove_columns=squad_ds["train"].column_names)
train_enc

Loading cached processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7/cache-d55a1a21d752a705.arrow


Dataset({
    features: ['attention_mask', 'end_positions', 'input_ids', 'start_positions'],
    num_rows: 88524
})

#### Preprocess validation set

In [None]:
valid_enc = squad_ds['validation'].map(prepare_validation_features, fn_kwargs=fn_kwargs, batched=True, remove_columns=squad_ds["validation"].column_names)
valid_enc

Loading cached processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7/cache-a4722655a19e6a17.arrow


Dataset({
    features: ['attention_mask', 'example_id', 'input_ids', 'offset_mapping'],
    num_rows: 10784
})

### Create distillation trainer

In [None]:
student_model = AutoModelForQuestionAnswering.from_pretrained(student_model_name).to(device)
teacher_model = AutoModelForQuestionAnswering.from_pretrained('lewtun/bert-base-uncased-finetuned-squad-v1').to(device)

batch_size = 16

frac_of_samples = 0.005

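if frac_of_samples != 1:
    # Subsample the encoded and raw datasets for a quick smoke test
    train_ds = train_enc.select(range(int(frac_of_samples * train_enc.num_rows)))
    eval_ds = valid_enc.select(range(int(frac_of_samples * valid_enc.num_rows)))
    eval_raw_ds = squad_ds["validation"].select(range(math.ceil(frac_of_samples * squad_ds["validation"].num_rows)))

    # NOTE: features and raw examples are counted separately (a long context yields
    # several features), so this equality only holds for some values of frac_of_samples
    assert eval_ds.num_rows == eval_raw_ds.num_rows
else:
    train_ds = train_enc
    eval_ds = valid_enc
    eval_raw_ds = squad_ds["validation"]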

print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation examples: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_raw_ds.num_rows}")

logging_steps = len(train_ds) // batch_size

student_training_args = DistillationTrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=logging_steps,
    disable_tqdm=False
)

data_collator = default_data_collator

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this mode

Number of training examples: 442
Number of validation examples: 53
Number of raw validation examples: 53
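The subsampling cell above truncates the encoded features and the raw examples independently, so the two subsets only line up when every selected example produces exactly one feature. A more robust option, sketched here with plain lists of dicts, is to pick the raw examples first and then keep the features whose `example_id` points at one of them. `select_aligned` is a hypothetical helper; with `datasets` the same idea could presumably be expressed with `Dataset.filter`.

```python
def select_aligned(features, examples, num_examples):
    """Keep the first `num_examples` raw examples and every feature derived from them."""
    kept_examples = examples[:num_examples]
    kept_ids = {example["id"] for example in kept_examples}
    kept_features = [f for f in features if f["example_id"] in kept_ids]
    return kept_features, kept_examples

# Toy data: example "a" is long enough to be split into two features
examples = [{"id": "a"}, {"id": "b"}, {"id": "c"}]
features = [
    {"example_id": "a"},
    {"example_id": "a"},
    {"example_id": "b"},
    {"example_id": "c"},
]

kept_features, kept_examples = select_aligned(features, examples, num_examples=2)
print(len(kept_features), len(kept_examples))
```

With this approach the number of features may exceed the number of raw examples, so an equality assertion on the two counts would no longer be appropriate.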


In [None]:
distil_trainer = DistillationTrainer(
    model=student_model,
    teacher_model=teacher_model,
    args=student_training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_raw_ds,
    tokenizer=student_tokenizer,
    data_collator=data_collator,
    compute_metrics=squad_metrics
)

In [None]:
distil_trainer.evaluate()

Trainer is attempting to log a value of "No log" of type <class 'str'> for key "eval/loss" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.





{'eval_loss': 'No log', 'eval_exact_match': 0.0, 'eval_f1': 2.321054207846661}

In [None]:
distil_trainer.train()

| Epoch | Training Loss | Validation Loss | Exact Match | F1 |
| :--- | :---: | :---: | :---: | :---: |
| 1.0 | 8.607296 | No log | 0.0 | 9.353172 |
| 2.0 | 7.026534 | No log | 15.09434 | 20.0 |
| 3.0 | 6.279649 | No log | 3.773585 | 11.223996 |


TrainOutput(global_step=84, training_loss=7.258751755669003)
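The `global_step=84` reported above can be checked by hand from the subsampled training set size. The sketch assumes a single device, no gradient accumulation, and that the final partial batch of each epoch is kept; `total_optimizer_steps` is a hypothetical helper.

```python
import math

def total_optimizer_steps(num_examples, batch_size, num_epochs):
    """Number of optimizer steps when each epoch ends with a partial batch."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * num_epochs

# Subsampled run from this section: 442 features, batch size 16, 3 epochs
print(total_optimizer_steps(442, 16, 3))    # 84
# Same settings on the full set of 88,524 training features
print(total_optimizer_steps(88524, 16, 3))  # 16599
```

The subsampled run is therefore roughly 200 times shorter than a full run, which is consistent with its low Exact Match and F1 scores.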

In [None]:
distil_trainer.save_model('models/distilbert-base-uncased-distilled-squad-v1')

### Create pipeline

In [None]:
from pprint import pprint

student_pipe = QuestionAnsweringPipeline(distil_trainer.model.to('cpu'), student_tokenizer)

# Use the first validation example as a smoke test
context = squad_ds['validation'][0]['context']
question = squad_ds['validation'][0]['question']
pprint(question + "\n" + context)

# expected answer: 'Denver Broncos', start: 177, end: 191 (the score depends on the checkpoint)
result = student_pipe(question=question, context=context)
result

('Which NFL team represented the AFC at Super Bowl 50?\n'
 'Super Bowl 50 was an American football game to determine the champion of the '
 'National Football League (NFL) for the 2015 season. The American Football '
 'Conference (AFC) champion Denver Broncos defeated the National Football '
 'Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super '
 "Bowl title. The game was played on February 7, 2016, at Levi's Stadium in "
 'the San Francisco Bay Area at Santa Clara, California. As this was the 50th '
 'Super Bowl, the league emphasized the "golden anniversary" with various '
 'gold-themed initiatives, as well as temporarily suspending the tradition of '
 'naming each Super Bowl game with Roman numerals (under which the game would '
 'have been known as "Super Bowl L"), so that the logo could prominently '
 'feature the Arabic numerals 50.')


{'score': 0.8734882473945618,
 'start': 177,
 'end': 191,
 'answer': 'Denver Broncos'}
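In the output above, `start` and `end` are character offsets into the context and `score` is the model's confidence in the chosen span. The sketch below illustrates one common way to pick a span from per-token start and end probabilities: score every valid pair by the product of the two probabilities and keep the best one. It is an assumption that `QuestionAnsweringPipeline` works along these lines; `best_span` and the default `max_answer_len=15` are chosen here for illustration.

```python
def best_span(start_probs, end_probs, max_answer_len=15):
    """Return (start, end, score) maximizing start_probs[i] * end_probs[j] with i <= j."""
    best = (0, 0, 0.0)
    for i, p_start in enumerate(start_probs):
        # Only consider ends at or after the start, within the length limit
        for j in range(i, min(i + max_answer_len, len(end_probs))):
            score = p_start * end_probs[j]
            if score > best[2]:
                best = (i, j, score)
    return best

# Toy probabilities over three tokens (illustrative values only)
start_probs = [0.1, 0.7, 0.2]
end_probs = [0.6, 0.1, 0.3]
print(best_span(start_probs, end_probs))
```

Taking the two argmaxes independently would give start 1 and end 0 in this toy case, which is not a valid span. Searching over pairs avoids that.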

## Speed test

In [None]:
student_model_ckpt = 'lewtun/distilbert-base-uncased-distilled-squad-v1'
teacher_model_ckpt = 'lewtun/bert-base-uncased-finetuned-squad-v1'

student_tokenizer = AutoTokenizer.from_pretrained(student_model_ckpt)
student_model = AutoModelForQuestionAnswering.from_pretrained(student_model_ckpt).to('cpu')

teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_ckpt)
teacher_model = AutoModelForQuestionAnswering.from_pretrained(teacher_model_ckpt).to('cpu')

In [None]:
student_pipe = QuestionAnsweringPipeline(student_model, student_tokenizer)
teacher_pipe = QuestionAnsweringPipeline(teacher_model, teacher_tokenizer)

In [None]:
%%time

for idx in range(1000):
    context = squad_ds['validation'][idx]['context']
    question = squad_ds['validation'][idx]['question']
    teacher_pipe(question=question, context=context)

CPU times: user 43min 46s, sys: 19.9 s, total: 44min 6s
Wall time: 6min 38s


In [None]:
%%time

for idx in range(1000):
    context = squad_ds['validation'][idx]['context']
    question = squad_ds['validation'][idx]['question']
    student_pipe(question=question, context=context)

CPU times: user 21min 11s, sys: 9.75 s, total: 21min 21s
Wall time: 3min 12s
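The two wall times above translate into a per-question latency and a speedup as follows. The sketch only redoes the arithmetic on the reported timings; `wall_seconds` is a hypothetical helper.

```python
def wall_seconds(minutes, seconds):
    """Convert a 'Xmin Ys' wall time into seconds."""
    return 60 * minutes + seconds

num_examples = 1000
teacher_wall = wall_seconds(6, 38)  # BERT-base teacher
student_wall = wall_seconds(3, 12)  # distilled DistilBERT student
speedup = teacher_wall / student_wall

print(f"teacher: {teacher_wall / num_examples:.3f} s/example")
print(f"student: {student_wall / num_examples:.3f} s/example")
print(f"speedup: {speedup:.2f}x")
```

On this CPU run the student answers a question in about 0.19 s against about 0.40 s for the teacher, a speedup of roughly 2.1x.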
