# DistilBERT
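> A partial reimplementation of *DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter* by Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf [[arXiv:1910.01108](https://arxiv.org/abs/1910.01108)]
>
> The notebook fine-tunes a BERT-base teacher on SQuAD, distils it into a DistilBERT student, and compares the CPU inference time of the two models.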

## Load libraries

In [None]:
# Re-import edited modules (such as the local transformerlab package) automatically
# before each cell runs, so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Project helpers (feature preparation, trainers, training arguments, metrics). They are
# star-imported first so that the explicit imports below take precedence over any names
# these modules happen to re-export.
from transformerlab.question_answering import *
from transformerlab.distillation import *

import math
from pathlib import Path
from pprint import pprint

import datasets
import torch
import transformers
from datasets import load_dataset, load_metric
from transformers import (
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    QuestionAnsweringPipeline,
    default_data_collator,
)

print(transformers.__version__, datasets.__version__)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Running on device: {device}")

4.1.1 1.2.0
Running on device: cuda


## Load data

In [None]:
# Download (or reuse the cached copy of) SQuAD; the displayed DatasetDict holds the
# "train" and "validation" splits used throughout the notebook.
squad = load_dataset(path="squad")
squad

Reusing dataset squad (/root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7)


DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

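## Fine-tune the teacher

This section fine-tunes `bert-base-uncased` on SQuAD. The distillation section below loads an already fine-tuned teacher from the `lewtun/bert-base-uncased-finetuned-squad-v1` checkpoint, so it does not depend on the model trained here.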

### Preprocess data

In [None]:
# The teacher is uncased BERT-base; its tokenizer encodes SQuAD for fine-tuning.
teacher_model_name = "bert-base-uncased"
teacher_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=teacher_model_name)

In [None]:
# Encoding settings forwarded to the prepare_*_features helpers. Presumably long contexts
# are split into windows of `max_length` tokens overlapping by `doc_stride` tokens (the
# encoded training set has more rows than raw examples) — confirm in
# transformerlab.question_answering.
max_length = 384
doc_stride = 128
# NOTE(review): presumably decides whether the context is placed after the question.
pad_on_right = teacher_tokenizer.padding_side == "right"

fn_kwargs = dict(
    tokenizer=teacher_tokenizer,
    max_length=max_length,
    doc_stride=doc_stride,
    pad_on_right=pad_on_right,
)

#### Preprocess training set

In [None]:
# Encode the whole training split with the teacher tokenizer and drop the raw text
# columns, keeping only the model inputs and answer positions.
train_enc = squad["train"].map(
    prepare_train_features,
    batched=True,
    remove_columns=squad["train"].column_names,
    fn_kwargs=fn_kwargs,
)
train_enc

Loading cached processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7/cache-5a2028a8427fd1c9.arrow


Dataset({
    features: ['attention_mask', 'end_positions', 'input_ids', 'start_positions', 'token_type_ids'],
    num_rows: 88524
})

In [None]:
# check we decode the first example
pprint(squad['train'][0])
teacher_tokenizer.decode(train_enc[0]['input_ids'], skip_special_tokens=True)

{'answers': {'answer_start': [515], 'text': ['Saint Bernadette Soubirous']},
 'context': 'Architecturally, the school has a Catholic character. Atop the '
            "Main Building's gold dome is a golden statue of the Virgin Mary. "
            'Immediately in front of the Main Building and facing it, is a '
            'copper statue of Christ with arms upraised with the legend '
            '"Venite Ad Me Omnes". Next to the Main Building is the Basilica '
            'of the Sacred Heart. Immediately behind the basilica is the '
            'Grotto, a Marian place of prayer and reflection. It is a replica '
            'of the grotto at Lourdes, France where the Virgin Mary reputedly '
            'appeared to Saint Bernadette Soubirous in 1858. At the end of the '
            'main drive (and in a direct line that connects through 3 statues '
            'and the Gold Dome), is a simple, modern stone statue of Mary.',
 'id': '5733be284776f41900661182',
 'question': 'To whom did t

'to whom did the virgin mary allegedly appear in 1858 in lourdes france? architecturally, the school has a catholic character. atop the main building\'s gold dome is a golden statue of the virgin mary. immediately in front of the main building and facing it, is a copper statue of christ with arms upraised with the legend " venite ad me omnes ". next to the main building is the basilica of the sacred heart. immediately behind the basilica is the grotto, a marian place of prayer and reflection. it is a replica of the grotto at lourdes, france where the virgin mary reputedly appeared to saint bernadette soubirous in 1858. at the end of the main drive ( and in a direct line that connects through 3 statues and the gold dome ), is a simple, modern stone statue of mary.'

#### Preprocess validation set

In [None]:
# Encode the validation split; the resulting features keep `example_id` and
# `offset_mapping` so predictions can be mapped back to the raw examples.
valid_enc = squad["validation"].map(
    prepare_validation_features,
    batched=True,
    remove_columns=squad["validation"].column_names,
    fn_kwargs=fn_kwargs,
)
valid_enc

Loading cached processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7/cache-791647bce0494b0b.arrow


Dataset({
    features: ['attention_mask', 'example_id', 'input_ids', 'offset_mapping', 'token_type_ids'],
    num_rows: 10784
})

In [None]:
# check we decode the first example
pprint(squad['validation'][0])
teacher_tokenizer.decode(valid_enc[0]['input_ids'], skip_special_tokens=True)

{'answers': {'answer_start': [177, 177, 177],
             'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']},
 'context': 'Super Bowl 50 was an American football game to determine the '
            'champion of the National Football League (NFL) for the 2015 '
            'season. The American Football Conference (AFC) champion Denver '
            'Broncos defeated the National Football Conference (NFC) champion '
            'Carolina Panthers 24–10 to earn their third Super Bowl title. The '
            "game was played on February 7, 2016, at Levi's Stadium in the San "
            'Francisco Bay Area at Santa Clara, California. As this was the '
            '50th Super Bowl, the league emphasized the "golden anniversary" '
            'with various gold-themed initiatives, as well as temporarily '
            'suspending the tradition of naming each Super Bowl game with '
            'Roman numerals (under which the game would have been known as '
            '"Super

'which nfl team represented the afc at super bowl 50? super bowl 50 was an american football game to determine the champion of the national football league ( nfl ) for the 2015 season. the american football conference ( afc ) champion denver broncos defeated the national football conference ( nfc ) champion carolina panthers 24 – 10 to earn their third super bowl title. the game was played on february 7, 2016, at levi\'s stadium in the san francisco bay area at santa clara, california. as this was the 50th super bowl, the league emphasized the " golden anniversary " with various gold - themed initiatives, as well as temporarily suspending the tradition of naming each super bowl game with roman numerals ( under which the game would have been known as " super bowl l " ), so that the logo could prominently feature the arabic numerals 50.'

### Initialize trainer

In [None]:
teacher_model = AutoModelForQuestionAnswering.from_pretrained(teacher_model_name)
batch_size = 8

# Fraction of SQuAD to use for a quick run; set to 1 to use the full splits.
frac_of_samples = .005

if frac_of_samples != 1:
    train_ds = train_enc.select(range(int(frac_of_samples * train_enc.num_rows)))
    # Select the raw validation examples first, then keep every encoded feature that belongs
    # to one of them. A long context can yield several features (the full validation split
    # encodes to more features than examples), so slicing features and examples
    # independently leaves the two sets disagreeing about which examples are evaluated.
    eval_raw_ds = squad["validation"].select(range(math.ceil(frac_of_samples * squad["validation"].num_rows)))
    eval_example_ids = set(eval_raw_ds["id"])
    eval_ds = valid_enc.filter(lambda feature: feature["example_id"] in eval_example_ids)

    # assumes `example_id` holds the `id` of the raw example a feature was built from
    assert set(eval_ds["example_id"]) == eval_example_ids, "every selected example needs at least one feature"
else:
    train_ds = train_enc
    eval_ds = valid_enc
    eval_raw_ds = squad["validation"]

print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation features: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_raw_ds.num_rows}")

# Log the training loss about once per epoch; max(1, ...) avoids logging_steps=0 when the
# subset is smaller than one batch.
logging_steps = max(1, len(train_ds) // batch_size)

teacher_args = QuestionAnsweringTrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_steps=logging_steps,
    disable_tqdm=False
)

data_collator = default_data_collator

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

Number of training examples: 442
Number of validation examples: 53
Number of raw validation examples: 53


In [None]:
# Trainer for the teacher: encoded features are used for training/eval loss, while the raw
# examples in `eval_examples` are presumably used to turn logits into answer spans that
# `squad_metrics` scores — confirm in transformerlab.question_answering.
teacher_trainer = QuestionAnsweringTrainer(
    model=teacher_model,
    args=teacher_args,
    tokenizer=teacher_tokenizer,
    data_collator=data_collator,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_raw_ds,
    compute_metrics=squad_metrics,
)

In [None]:
# Fine-tune the teacher and display the returned TrainOutput.
teacher_train_output = teacher_trainer.train()
teacher_train_output

Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# Persist the fine-tuned teacher (weights, config and tokenizer) to a local directory.
teacher_trainer.save_model(output_dir='models/bert-base-uncased-finetuned-squad-v1')

### Create pipeline

In [None]:
# Move the fine-tuned model to the CPU before wrapping it in a pipeline (no device is passed
# to the pipeline). nn.Module.to works in place, so teacher_trainer.model now lives on the
# CPU as well; move it back before calling teacher_trainer.train()/evaluate() again.
teacher_pipe = QuestionAnsweringPipeline(teacher_trainer.model.to('cpu'), teacher_tokenizer)

first_example = squad['validation'][0]
context = first_example['context']
question = first_example['question']

# Reference answer: 'Denver Broncos' at characters 177-191 of the context. The score
# depends on how far the teacher was trained, so it is not a fixed reference value.
result = teacher_pipe(question=question, context=context)
result

{'score': 0.6260135173797607,
 'start': 177,
 'end': 191,
 'answer': 'Denver Broncos'}

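## Distillation

Distil the fine-tuned BERT teacher into a `distilbert-base-uncased` student trained on the same SQuAD data.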

### Preprocess data

In [None]:
# The student is uncased DistilBERT; SQuAD is re-encoded with its own tokenizer below.
student_model_name = "distilbert-base-uncased"
student_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=student_model_name)

In [None]:
# Same windowing settings as for the teacher, but pointing the prepare_*_features helpers
# at the student tokenizer. Presumably `doc_stride` is the token overlap between windows
# of a long context — confirm in transformerlab.question_answering.
max_length = 384
doc_stride = 128
# NOTE(review): presumably decides whether the context is placed after the question.
pad_on_right = student_tokenizer.padding_side == "right"

fn_kwargs = dict(
    tokenizer=student_tokenizer,
    max_length=max_length,
    doc_stride=doc_stride,
    pad_on_right=pad_on_right,
)

#### Preprocess training set

In [None]:
# Re-encode the training split with the student tokenizer (via fn_kwargs), replacing the
# teacher's features; DistilBERT's encoding has no token_type_ids column.
train_enc = squad["train"].map(
    prepare_train_features,
    batched=True,
    remove_columns=squad["train"].column_names,
    fn_kwargs=fn_kwargs,
)
train_enc

Loading cached processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7/cache-d55a1a21d752a705.arrow


Dataset({
    features: ['attention_mask', 'end_positions', 'input_ids', 'start_positions'],
    num_rows: 88524
})

#### Preprocess validation set

In [None]:
# Re-encode the validation split with the student tokenizer, keeping `example_id` and
# `offset_mapping` for mapping predictions back to the raw examples.
valid_enc = squad["validation"].map(
    prepare_validation_features,
    batched=True,
    remove_columns=squad["validation"].column_names,
    fn_kwargs=fn_kwargs,
)
valid_enc

Loading cached processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7/cache-a4722655a19e6a17.arrow


Dataset({
    features: ['attention_mask', 'example_id', 'input_ids', 'offset_mapping'],
    num_rows: 10784
})

### Create distillation trainer

In [None]:
student_model = AutoModelForQuestionAnswering.from_pretrained(student_model_name).to(device)
# The teacher is the already fine-tuned BERT checkpoint from the Hub, not the local run above.
teacher_model = AutoModelForQuestionAnswering.from_pretrained('lewtun/bert-base-uncased-finetuned-squad-v1').to(device)

batch_size = 16

# Fraction of SQuAD to use for a quick run; set to 1 to use the full splits.
frac_of_samples = 0.005

if frac_of_samples != 1:
    train_ds = train_enc.select(range(int(frac_of_samples * train_enc.num_rows)))
    # Select the raw validation examples first, then keep every encoded feature that belongs
    # to one of them. A long context can yield several features (the full validation split
    # encodes to more features than examples), so slicing features and examples
    # independently leaves the two sets disagreeing about which examples are evaluated.
    eval_raw_ds = squad["validation"].select(range(math.ceil(frac_of_samples * squad["validation"].num_rows)))
    eval_example_ids = set(eval_raw_ds["id"])
    eval_ds = valid_enc.filter(lambda feature: feature["example_id"] in eval_example_ids)

    # assumes `example_id` holds the `id` of the raw example a feature was built from
    assert set(eval_ds["example_id"]) == eval_example_ids, "every selected example needs at least one feature"
else:
    train_ds = train_enc
    eval_ds = valid_enc
    eval_raw_ds = squad["validation"]

print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation features: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_raw_ds.num_rows}")

# Log the training loss about once per epoch; max(1, ...) avoids logging_steps=0 when the
# subset is smaller than one batch.
logging_steps = max(1, len(train_ds) // batch_size)

student_training_args = DistillationTrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=logging_steps,
    disable_tqdm=False
)

data_collator = default_data_collator

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this mode

Number of training examples: 442
Number of validation examples: 53
Number of raw validation examples: 53


In [None]:
# Distillation trainer: `model` is the DistilBERT student being trained and
# `teacher_model` the fine-tuned BERT teacher. Evaluation uses the encoded features plus
# the raw examples, scored with `squad_metrics`.
distil_trainer = DistillationTrainer(
    model=student_model,
    teacher_model=teacher_model,
    args=student_training_args,
    tokenizer=student_tokenizer,
    data_collator=data_collator,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_raw_ds,
    compute_metrics=squad_metrics,
)

In [None]:
# Baseline before distillation: the student's QA head is freshly initialised, so its
# scores should be close to zero.
baseline_metrics = distil_trainer.evaluate()
baseline_metrics

HBox(children=(FloatProgress(value=0.0, max=53.0), HTML(value='')))

Trainer is attempting to log a value of "No log" of type <class 'str'> for key "eval/loss" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.





{'eval_loss': 'No log', 'eval_exact_match': 0.0, 'eval_f1': 2.321054207846661}

In [None]:
# Train the student against the teacher and display the returned TrainOutput.
distil_train_output = distil_trainer.train()
distil_train_output

Epoch,Training Loss,Validation Loss,Exact Match,F1
1.0,8.607296,No log,0.0,9.353172
2.0,7.026534,No log,15.09434,20.0
3.0,6.279649,No log,3.773585,11.223996


HBox(children=(FloatProgress(value=0.0, max=53.0), HTML(value='')))

Trainer is attempting to log a value of "No log" of type <class 'str'> for key "eval/loss" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.





HBox(children=(FloatProgress(value=0.0, max=53.0), HTML(value='')))

Trainer is attempting to log a value of "No log" of type <class 'str'> for key "eval/loss" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.





HBox(children=(FloatProgress(value=0.0, max=53.0), HTML(value='')))

Trainer is attempting to log a value of "No log" of type <class 'str'> for key "eval/loss" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.





TrainOutput(global_step=84, training_loss=7.258751755669003)

In [None]:
# Persist the distilled student (weights, config and tokenizer) to a local directory.
distil_trainer.save_model(output_dir='models/distilbert-base-uncased-distilled-squad-v1')

### Create pipeline

In [None]:
# Move the distilled model to the CPU before wrapping it in a pipeline (no device is passed
# to the pipeline). nn.Module.to works in place, so distil_trainer.model now lives on the
# CPU as well; move it back before calling distil_trainer.train()/evaluate() again.
student_pipe = QuestionAnsweringPipeline(distil_trainer.model.to('cpu'), student_tokenizer)

first_example = squad['validation'][0]
context = first_example['context']
question = first_example['question']
pprint(question + "\n" + context)

# Reference answer: 'Denver Broncos' at characters 177-191 of the context. The score
# depends on how far the student was trained, so it is not a fixed reference value.
result = student_pipe(question=question, context=context)
result

('Which NFL team represented the AFC at Super Bowl 50?\n'
 'Super Bowl 50 was an American football game to determine the champion of the '
 'National Football League (NFL) for the 2015 season. The American Football '
 'Conference (AFC) champion Denver Broncos defeated the National Football '
 'Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super '
 "Bowl title. The game was played on February 7, 2016, at Levi's Stadium in "
 'the San Francisco Bay Area at Santa Clara, California. As this was the 50th '
 'Super Bowl, the league emphasized the "golden anniversary" with various '
 'gold-themed initiatives, as well as temporarily suspending the tradition of '
 'naming each Super Bowl game with Roman numerals (under which the game would '
 'have been known as "Super Bowl L"), so that the logo could prominently '
 'feature the Arabic numerals 50.')


{'score': 0.8734882473945618,
 'start': 177,
 'end': 191,
 'answer': 'Denver Broncos'}

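## Speed test

Compare the CPU inference time of the teacher and student pipelines on the first 1,000 validation examples.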

In [None]:
student_model_ckpt = 'lewtun/distilbert-base-uncased-distilled-squad-v1'
teacher_model_ckpt = 'lewtun/bert-base-uncased-finetuned-squad-v1'


def load_qa_checkpoint(ckpt):
    """Load the tokenizer and question-answering model stored at `ckpt`.

    The model is placed on the CPU so both checkpoints are timed on the same device.
    Returns a `(tokenizer, model)` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    model = AutoModelForQuestionAnswering.from_pretrained(ckpt).to('cpu')
    return tokenizer, model


student_tokenizer, student_model = load_qa_checkpoint(student_model_ckpt)
teacher_tokenizer, teacher_model = load_qa_checkpoint(teacher_model_ckpt)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=497.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=112.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=258.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=265499080.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=112.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=252.0, style=ProgressStyle(description_…




In [None]:
# Wrap the CPU-resident checkpoints in question-answering pipelines for timing.
student_pipe = QuestionAnsweringPipeline(model=student_model, tokenizer=student_tokenizer)
teacher_pipe = QuestionAnsweringPipeline(model=teacher_model, tokenizer=teacher_tokenizer)

In [None]:
%%time

for idx in range(1000):
    context = squad['validation'][idx]['context']
    question = squad['validation'][idx]['question']
    teacher_pipe(question=question, context=context)

CPU times: user 43min 46s, sys: 19.9 s, total: 44min 6s
Wall time: 6min 38s


In [None]:
%%time

for idx in range(1000):
    context = squad['validation'][idx]['context']
    question = squad['validation'][idx]['question']
    student_pipe(question=question, context=context)

CPU times: user 21min 11s, sys: 9.75 s, total: 21min 21s
Wall time: 3min 12s
