# Distilling Transformers for Question Answering

## Load libraries

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import warnings
from pathlib import Path

import datasets
import transformers

# Keep the notebook output readable: silence Python warnings and restrict the
# `datasets` / `transformers` loggers to error-level messages.
warnings.filterwarnings("ignore")
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()

# Record the library versions this notebook was run with.
print(transformers.__version__, datasets.__version__)

4.1.1 1.2.0


In [3]:
import collections
from tqdm.auto import tqdm
from pprint import pprint
import math

from datasets import load_dataset, load_metric
from transformers import (AutoTokenizer, AutoModelForQuestionAnswering, default_data_collator, 
                          TrainingArguments, Trainer, EvalPrediction, QuestionAnsweringPipeline)
from transformers.trainer_utils import PredictionOutput

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

Running on device: cuda


## Helper functions

In [4]:
def prepare_train_features(examples, tokenizer, pad_on_right, max_length, doc_stride):
    """Tokenize a batch of QA examples and label every feature with answer token positions.

    Long contexts are split into several overlapping features (controlled by `max_length`
    and `doc_stride`), so one example may produce more than one feature.

    Args:
        examples: batch dict with "question", "context" and "answers" columns, where each
            entry of "answers" holds parallel "answer_start" / "text" lists.
        tokenizer: tokenizer called on the (question, context) pair; must expose
            `cls_token_id` and return an encoding with `sequence_ids(i)`.
        pad_on_right: True when the context is the second sequence of the pair.
        max_length: maximum number of tokens per feature.
        doc_stride: number of overlapping tokens between consecutive features.

    Returns:
        The tokenizer output with "overflow_to_sample_mapping" and "offset_mapping" removed
        and "start_positions" / "end_positions" added. Features that do not contain the
        (first) answer, or whose example has no answer, are labeled with the CLS index.
    """
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Map from each feature back to the example it was cut from.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # Map from each token to its (start_char, end_char) span in the original text.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Sequence id that marks context tokens in `sequence_ids`.
    context_id = 1 if pad_on_right else 0

    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        # Impossible / out-of-span answers are labeled with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        sequence_ids = tokenized_examples.sequence_ids(i)
        answers = examples["answers"][sample_mapping[i]]

        if len(answers["answer_start"]) == 0:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Character span of the first listed answer in the context.
        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])

        # First and last token index of the context inside this feature.
        context_start = 0
        while sequence_ids[context_start] != context_id:
            context_start += 1
        context_end = len(input_ids) - 1
        while sequence_ids[context_end] != context_id:
            context_end -= 1

        # The answer is not fully inside this feature: label it with the CLS index.
        if not (offsets[context_start][0] <= start_char and offsets[context_end][1] >= end_char):
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Walk inwards to the two ends of the answer. Both scans are bounded by the context
        # span: special and padding tokens carry (0, 0) offsets, so an unbounded scan would
        # run straight through them when the answer starts at the last context token and
        # produce a start position outside the context.
        token_start_index = context_start
        while token_start_index <= context_end and offsets[token_start_index][0] <= start_char:
            token_start_index += 1
        start_positions.append(token_start_index - 1)

        token_end_index = context_end
        while token_end_index >= context_start and offsets[token_end_index][1] >= end_char:
            token_end_index -= 1
        end_positions.append(token_end_index + 1)

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    return tokenized_examples

In [5]:
def prepare_validation_features(examples, tokenizer, pad_on_right, max_length, doc_stride):
    """Tokenize a batch of QA examples for evaluation.

    Long contexts are cut into overlapping features, so a single example can yield several
    features. Each feature is tagged with the id of the example it came from, and its
    offset mapping is masked so that only context tokens keep their character span
    (every other position becomes None).

    Args:
        examples: batch dict with "id", "question" and "context" columns.
        tokenizer: tokenizer called on the (question, context) pair; its output must
            provide `sequence_ids(i)`.
        pad_on_right: True when the context is the second sequence of the pair.
        max_length: maximum number of tokens per feature.
        doc_stride: number of overlapping tokens between consecutive features.

    Returns:
        The tokenizer output without "overflow_to_sample_mapping", with an added
        "example_id" list and the masked "offset_mapping".
    """
    if pad_on_right:
        first_column, second_column, truncation_mode = "question", "context", "only_second"
    else:
        first_column, second_column, truncation_mode = "context", "question", "only_first"

    tokenized_examples = tokenizer(
        examples[first_column],
        examples[second_column],
        truncation=truncation_mode,
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # feature index -> index of the example the feature was cut from
    feature_to_example = tokenized_examples.pop("overflow_to_sample_mapping")
    # sequence id carried by context tokens
    context_index = 1 if pad_on_right else 0

    example_ids = []
    tokenized_examples["example_id"] = example_ids
    all_offsets = tokenized_examples["offset_mapping"]

    for feature_idx in range(len(tokenized_examples["input_ids"])):
        sequence_ids = tokenized_examples.sequence_ids(feature_idx)

        # Remember which example this feature belongs to (needed to regroup predictions).
        example_ids.append(examples["id"][feature_to_example[feature_idx]])

        # Blank out the spans of non-context tokens so post-processing can tell at a glance
        # whether a token position lies inside the context.
        masked_spans = []
        for token_idx, span in enumerate(all_offsets[feature_idx]):
            if sequence_ids[token_idx] == context_index:
                masked_spans.append(span)
            else:
                masked_spans.append(None)
        all_offsets[feature_idx] = masked_spans

    return tokenized_examples

In [6]:
# Metric object shared by every evaluation run in this notebook.
metric = load_metric("squad")


def compute_metrics(p: EvalPrediction):
    """Score post-processed QA predictions with the module-level `metric`.

    Args:
        p: `EvalPrediction` whose `predictions` are the formatted answer predictions and
            whose `label_ids` are the reference answers.

    Returns:
        Whatever `metric.compute` returns for these predictions and references.
    """
    scores = metric.compute(predictions=p.predictions, references=p.label_ids)
    return scores

## Load data

In [7]:
# Load the SQuAD dataset; the result is a DatasetDict with `train` and `validation` splits
# (see the output below).
squad = load_dataset("squad")
squad

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

## Fine-tune the teacher

### Preprocess data

In [8]:
# Checkpoint the teacher is fine-tuned from, and its matching tokenizer.
teacher_model_name = "bert-base-uncased"
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)

In [9]:
max_length = 384  # maximum number of tokens per feature (passed to the tokenizer as `max_length`)
doc_stride = 128  # overlap between consecutive features of a long context (tokenizer `stride`)
# True when the tokenizer pads on the right, in which case the context is the second sequence.
pad_on_right = teacher_tokenizer.padding_side == "right"

# Keyword arguments forwarded to the prepare_*_features helpers through `Dataset.map`.
fn_kwargs = {
    "tokenizer": teacher_tokenizer,
    "max_length": max_length,
    "doc_stride": doc_stride,
    "pad_on_right": pad_on_right
}

#### Preprocess training set

In [10]:
# Tokenize and label the training split in batches. The raw columns are dropped, and the
# row count grows because long contexts are split into several features (see output below).
train_enc = squad['train'].map(prepare_train_features, fn_kwargs=fn_kwargs, batched=True, remove_columns=squad["train"].column_names)
train_enc

Dataset({
    features: ['attention_mask', 'end_positions', 'input_ids', 'start_positions', 'token_type_ids'],
    num_rows: 88524
})

In [11]:
# Sanity check: decoding the first training feature should give back the first raw
# example's question followed by its context (lower-cased by the uncased tokenizer).
pprint(squad['train'][0])
teacher_tokenizer.decode(train_enc[0]['input_ids'], skip_special_tokens=True)

{'answers': {'answer_start': [515], 'text': ['Saint Bernadette Soubirous']},
 'context': 'Architecturally, the school has a Catholic character. Atop the '
            "Main Building's gold dome is a golden statue of the Virgin Mary. "
            'Immediately in front of the Main Building and facing it, is a '
            'copper statue of Christ with arms upraised with the legend '
            '"Venite Ad Me Omnes". Next to the Main Building is the Basilica '
            'of the Sacred Heart. Immediately behind the basilica is the '
            'Grotto, a Marian place of prayer and reflection. It is a replica '
            'of the grotto at Lourdes, France where the Virgin Mary reputedly '
            'appeared to Saint Bernadette Soubirous in 1858. At the end of the '
            'main drive (and in a direct line that connects through 3 statues '
            'and the Gold Dome), is a simple, modern stone statue of Mary.',
 'id': '5733be284776f41900661182',
 'question': 'To whom did t

'to whom did the virgin mary allegedly appear in 1858 in lourdes france? architecturally, the school has a catholic character. atop the main building\'s gold dome is a golden statue of the virgin mary. immediately in front of the main building and facing it, is a copper statue of christ with arms upraised with the legend " venite ad me omnes ". next to the main building is the basilica of the sacred heart. immediately behind the basilica is the grotto, a marian place of prayer and reflection. it is a replica of the grotto at lourdes, france where the virgin mary reputedly appeared to saint bernadette soubirous in 1858. at the end of the main drive ( and in a direct line that connects through 3 statues and the gold dome ), is a simple, modern stone statue of mary.'

#### Preprocess validation set

In [12]:
# Tokenize the validation split in batches. The resulting features keep `example_id` and
# `offset_mapping` so predictions can be mapped back to text in the original contexts.
valid_enc = squad['validation'].map(prepare_validation_features, fn_kwargs=fn_kwargs, batched=True, remove_columns=squad["validation"].column_names)
valid_enc

Dataset({
    features: ['attention_mask', 'example_id', 'input_ids', 'offset_mapping', 'token_type_ids'],
    num_rows: 10784
})

In [13]:
# Sanity check: decoding the first validation feature should give back the first raw
# example's question followed by its context (lower-cased by the uncased tokenizer).
pprint(squad['validation'][0])
teacher_tokenizer.decode(valid_enc[0]['input_ids'], skip_special_tokens=True)

{'answers': {'answer_start': [177, 177, 177],
             'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']},
 'context': 'Super Bowl 50 was an American football game to determine the '
            'champion of the National Football League (NFL) for the 2015 '
            'season. The American Football Conference (AFC) champion Denver '
            'Broncos defeated the National Football Conference (NFC) champion '
            'Carolina Panthers 24–10 to earn their third Super Bowl title. The '
            "game was played on February 7, 2016, at Levi's Stadium in the San "
            'Francisco Bay Area at Santa Clara, California. As this was the '
            '50th Super Bowl, the league emphasized the "golden anniversary" '
            'with various gold-themed initiatives, as well as temporarily '
            'suspending the tradition of naming each Super Bowl game with '
            'Roman numerals (under which the game would have been known as '
            '"Super

'which nfl team represented the afc at super bowl 50? super bowl 50 was an american football game to determine the champion of the national football league ( nfl ) for the 2015 season. the american football conference ( afc ) champion denver broncos defeated the national football conference ( nfc ) champion carolina panthers 24 – 10 to earn their third super bowl title. the game was played on february 7, 2016, at levi\'s stadium in the san francisco bay area at santa clara, california. as this was the 50th super bowl, the league emphasized the " golden anniversary " with various gold - themed initiatives, as well as temporarily suspending the tradition of naming each super bowl game with roman numerals ( under which the game would have been known as " super bowl l " ), so that the logo could prominently feature the arabic numerals 50.'

### Create question answering trainer

In [14]:
class QuestionAnsweringTrainer(Trainer):
    """`Trainer` for extractive question answering.

    During evaluation and prediction the model's (start_logits, end_logits) are first turned
    into answer strings cut from the original contexts, and only then handed to
    `compute_metrics` together with the reference answers.

    `self.args` must provide `version_2_with_negative`, `n_best_size`, `max_answer_length`
    and `null_score_diff_threshold` in addition to the regular training arguments.

    Args:
        eval_examples: un-tokenized examples (rows with "id", "context" and "answers") that
            the evaluation features were built from. Used by `evaluate` when no examples
            are passed explicitly.
    """

    def __init__(self, *args, eval_examples=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.eval_examples = eval_examples

    def _predict_without_metrics(self, dataloader, prediction_loss_only, ignore_keys):
        """Run `prediction_loop` with `compute_metrics` temporarily disabled.

        Metrics are computed by the caller after post-processing, so the loop itself must
        not call `compute_metrics` on the raw logits. The original callable is restored
        even if the loop raises.
        """
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        try:
            return self.prediction_loop(
                dataloader,
                description="Evaluation",
                prediction_loss_only=prediction_loss_only,
                ignore_keys=ignore_keys,
            )
        finally:
            self.compute_metrics = compute_metrics

    @staticmethod
    def _restore_all_columns(dataset):
        """Make every column of `dataset` visible again, keeping its current format type.

        Columns may have been hidden from the dataset while it was fed to the model;
        post-processing needs columns such as `example_id` and `offset_mapping`. Objects
        without `set_format` / `features` are left untouched.
        """
        if hasattr(dataset, "set_format") and hasattr(dataset, "features"):
            dataset.set_format(type=dataset.format["type"], columns=list(dataset.features.keys()))

    def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):
        """Evaluate on tokenized features and score the extracted answers.

        Args:
            eval_dataset: tokenized evaluation features; defaults to `self.eval_dataset`.
            eval_examples: raw examples matching the features; defaults to
                `self.eval_examples`.
            ignore_keys: forwarded to `prediction_loop`.

        Returns:
            Dict of metrics whose keys all start with "eval_" (empty when no
            `compute_metrics` was configured).
        """
        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        eval_examples = self.eval_examples if eval_examples is None else eval_examples

        output = self._predict_without_metrics(
            eval_dataloader,
            # No point gathering the predictions if there are no metrics, otherwise we defer
            # to self.args.prediction_loss_only.
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
        )

        # We might have removed columns from the dataset so we put them back.
        self._restore_all_columns(eval_dataset)

        if self.compute_metrics is not None:
            eval_preds = self._post_process_function(eval_examples, eval_dataset, output.predictions)
            metrics = self.compute_metrics(eval_preds)
            # NotebookProgressCallback assumes an `eval_loss` entry exists. Use the loss
            # reported by the prediction loop when there is one, otherwise a placeholder.
            metrics["eval_loss"] = (output.metrics or {}).get("eval_loss", "No log")

            self.log(metrics)
        else:
            metrics = {}

        # Prefix every metric name with "eval_".
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
        return metrics

    def predict(self, test_dataset, test_examples, ignore_keys=None):
        """Predict on tokenized features and, if metrics are configured, score the answers.

        Args:
            test_dataset: tokenized test features.
            test_examples: raw examples the features were built from; they provide both the
                contexts and the reference answers.
            ignore_keys: forwarded to `prediction_loop`.

        Returns:
            The raw `prediction_loop` output when no `compute_metrics` is configured,
            otherwise a `PredictionOutput` holding the formatted predictions, the
            references and the computed metrics.
        """
        test_dataloader = self.get_test_dataloader(test_dataset)
        output = self._predict_without_metrics(
            test_dataloader,
            prediction_loss_only=None,
            ignore_keys=ignore_keys,
        )

        if self.compute_metrics is None:
            return output

        # We might have removed columns from the dataset so we put them back.
        self._restore_all_columns(test_dataset)

        eval_preds = self._post_process_function(test_examples, test_dataset, output.predictions)
        metrics = self.compute_metrics(eval_preds)

        return PredictionOutput(predictions=eval_preds.predictions, label_ids=eval_preds.label_ids, metrics=metrics)

    def _post_process_function(self, examples, features, predictions):
        """Convert raw logits into the prediction/reference lists the metric expects.

        Args:
            examples: raw examples (rows with "id", "context", "answers").
            features: tokenized features built from `examples`.
            predictions: (start_logits, end_logits) for every feature.

        Returns:
            `EvalPrediction` with formatted predictions and the references of `examples`.
        """
        # Post-processing: we match the start logits and end logits to answers in the original context.
        predictions = self._postprocess_qa_predictions(
            examples=examples,
            features=features,
            predictions=predictions,
            version_2_with_negative=self.args.version_2_with_negative,
            n_best_size=self.args.n_best_size,
            max_answer_length=self.args.max_answer_length,
            null_score_diff_threshold=self.args.null_score_diff_threshold,
            output_dir=self.args.output_dir,
            is_world_process_zero=self.is_world_process_zero(),
        )
        # Format the result to the format the metric expects.
        if self.args.version_2_with_negative:
            formatted_predictions = [
                {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
            ]
        else:
            formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
        # References must come from the examples being scored, not from `self.eval_examples`:
        # `predict` (and `evaluate` with explicit examples) may be given a different set.
        references = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
        return EvalPrediction(predictions=formatted_predictions, label_ids=references)

    def _postprocess_qa_predictions(
        self,
        examples,
        features,
        predictions,
        version_2_with_negative=False,
        n_best_size=None,
        max_answer_length=None,
        null_score_diff_threshold=None,
        output_dir=None,
        prefix=None,
        is_world_process_zero=True,
    ):
        """Pick the best answer string for every example from its features' logits.

        Args:
            examples: raw examples (rows with "id" and "context").
            features: tokenized features with "example_id" and "offset_mapping" (None for
                non-context tokens) and optionally "token_is_max_context".
            predictions: pair (start_logits, end_logits), one row per feature.
            version_2_with_negative: whether an empty ("no answer") prediction is allowed.
            n_best_size: number of top start/end logits combined per feature and number of
                candidates kept per example; defaults to `self.args.n_best_size`.
            max_answer_length: maximum answer length in tokens; defaults to
                `self.args.max_answer_length`.
            null_score_diff_threshold: margin by which the null score must beat the best
                non-null score for the empty answer to be chosen; defaults to
                `self.args.null_score_diff_threshold`.
            output_dir, prefix, is_world_process_zero: accepted for signature
                compatibility; not used, nothing is written to disk.

        Returns:
            OrderedDict mapping example id to the predicted answer text.
        """
        assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
        all_start_logits, all_end_logits = predictions

        assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."

        # Fall back to the training arguments for anything the caller did not specify.
        if n_best_size is None:
            n_best_size = self.args.n_best_size
        if max_answer_length is None:
            max_answer_length = self.args.max_answer_length
        if null_score_diff_threshold is None:
            null_score_diff_threshold = self.args.null_score_diff_threshold

        # Build a map example to its corresponding features.
        example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
        features_per_example = collections.defaultdict(list)
        for i, feature in enumerate(features):
            features_per_example[example_id_to_index[feature["example_id"]]].append(i)

        all_predictions = collections.OrderedDict()

        # Let's loop over all the examples!
        for example_index, example in enumerate(tqdm(examples)):
            # Those are the indices of the features associated to the current example.
            feature_indices = features_per_example[example_index]

            min_null_prediction = None
            prelim_predictions = []

            # Looping through all the features associated to the current example.
            for feature_index in feature_indices:
                # We grab the predictions of the model for this feature.
                start_logits = all_start_logits[feature_index]
                end_logits = all_end_logits[feature_index]
                # Fetch the feature row once; it is reused for both lookups below.
                feature = features[feature_index]
                # This is what will allow us to map the positions in our logits to spans of
                # text in the original context.
                offset_mapping = feature["offset_mapping"]
                # Optional `token_is_max_context`, if provided we will remove answers that do
                # not have the maximum context available in the current feature.
                token_is_max_context = feature.get("token_is_max_context", None)

                # Update minimum null prediction (position 0 is used as the "no answer" slot).
                feature_null_score = start_logits[0] + end_logits[0]
                if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
                    min_null_prediction = {
                        "offsets": (0, 0),
                        "score": feature_null_score,
                        "start_logit": start_logits[0],
                        "end_logit": end_logits[0],
                    }

                # Go through all possibilities for the `n_best_size` greater start and end logits.
                start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
                end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
                for start_index in start_indexes:
                    for end_index in end_indexes:
                        # Don't consider out-of-scope answers, either because the indices are
                        # out of bounds or correspond to part of the input_ids that are not in
                        # the context.
                        if (
                            start_index >= len(offset_mapping)
                            or end_index >= len(offset_mapping)
                            or offset_mapping[start_index] is None
                            or offset_mapping[end_index] is None
                        ):
                            continue
                        # Don't consider answers with a length that is either < 0 or > max_answer_length.
                        if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                            continue
                        # Don't consider answers that don't have the maximum context available
                        # (if such information is provided).
                        if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                            continue
                        prelim_predictions.append(
                            {
                                "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
                                "score": start_logits[start_index] + end_logits[end_index],
                                "start_logit": start_logits[start_index],
                                "end_logit": end_logits[end_index],
                            }
                        )
            if version_2_with_negative:
                # Add the minimum null prediction
                prelim_predictions.append(min_null_prediction)
                null_score = min_null_prediction["score"]

            # Only keep the best `n_best_size` predictions.
            best_predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

            # Add back the minimum null prediction if it was removed because of its low score.
            if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in best_predictions):
                best_predictions.append(min_null_prediction)

            # Use the offsets to gather the answer text in the original context.
            context = example["context"]
            for pred in best_predictions:
                offsets = pred.pop("offsets")
                pred["text"] = context[offsets[0] : offsets[1]]

            # In the very rare edge case we have not a single non-null prediction, we create a
            # fake prediction to avoid failure.
            if len(best_predictions) == 0 or (len(best_predictions) == 1 and best_predictions[0]["text"] == ""):
                best_predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})

            # Pick the best prediction. If the null answer is not possible, this is easy.
            if not version_2_with_negative:
                all_predictions[example["id"]] = best_predictions[0]["text"]
            else:
                # Otherwise we first need to find the best non-empty prediction.
                i = 0
                while best_predictions[i]["text"] == "":
                    i += 1
                best_non_null_pred = best_predictions[i]

                # Then we compare to the null prediction using the threshold.
                score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
                if score_diff > null_score_diff_threshold:
                    all_predictions[example["id"]] = ""
                else:
                    all_predictions[example["id"]] = best_non_null_pred["text"]

        return all_predictions

### Initialize trainer

In [15]:
class QuestionAnsweringTrainingArguments(TrainingArguments):
    """`TrainingArguments` carrying the extra question-answering settings as attributes.

    All positional and remaining keyword arguments are forwarded to `TrainingArguments`.

    Args:
        max_length: stored as `self.max_length`.
        doc_stride: stored as `self.doc_stride`.
        version_2_with_negative: stored as `self.version_2_with_negative`.
        null_score_diff_threshold: stored as `self.null_score_diff_threshold`.
        n_best_size: stored as `self.n_best_size`.
        max_answer_length: stored as `self.max_answer_length`.
    """

    def __init__(
        self,
        *args,
        max_length=384,
        doc_stride=128,
        version_2_with_negative=False,
        null_score_diff_threshold=0.,
        n_best_size=20,
        max_answer_length=30,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        qa_settings = {
            "max_length": max_length,
            "doc_stride": doc_stride,
            "version_2_with_negative": version_2_with_negative,
            "null_score_diff_threshold": null_score_diff_threshold,
            "n_best_size": n_best_size,
            "max_answer_length": max_answer_length,
        }
        for attr_name, attr_value in qa_settings.items():
            setattr(self, attr_name, attr_value)

In [16]:
teacher_model = AutoModelForQuestionAnswering.from_pretrained(teacher_model_name)
batch_size = 16

# Fraction of the data to use; set below 1 for a quick run on a subset.
frac_of_samples = 1

if frac_of_samples != 1:
    train_ds = train_enc.select(range(int(frac_of_samples * train_enc.num_rows)))
    # Subsample the raw validation examples first, then keep exactly the features that were
    # cut from them. Features and examples cannot be truncated independently: one example can
    # yield several features, so their row counts differ and a feature whose example was
    # dropped could not be mapped back during post-processing.
    eval_raw_ds = squad["validation"].select(range(math.ceil(frac_of_samples * squad["validation"].num_rows)))
    kept_example_ids = set(eval_raw_ds["id"])
    eval_ds = valid_enc.select(
        [i for i, example_id in enumerate(valid_enc["example_id"]) if example_id in kept_example_ids]
    )

    # Every kept example must be represented, and no feature of a dropped example may remain.
    assert set(eval_ds["example_id"]) == kept_example_ids
else:
    train_ds = train_enc
    eval_ds = valid_enc
    eval_raw_ds = squad["validation"]

print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation examples: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_raw_ds.num_rows}")

# Log once per pass over the training features; never 0, even for subsets smaller than a batch.
logging_steps = max(1, len(train_ds) // batch_size)

teacher_args = QuestionAnsweringTrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_steps=logging_steps,
    disable_tqdm=False
)

data_collator = default_data_collator

Number of training examples: 88524
Number of validation examples: 10784
Number of raw validation examples: 10570


In [17]:
# Trainer for the teacher: trains on the labeled features and evaluates on the validation
# features, using the raw validation examples to turn logits back into answer text.
teacher_trainer = QuestionAnsweringTrainer(
    model=teacher_model,
    args=teacher_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_raw_ds,  # raw examples matching `eval_ds`
    tokenizer=teacher_tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

In [19]:
# teacher_trainer.evaluate()

In [18]:
# Fine-tune the teacher; evaluation runs at the end of each epoch (see the table below).
teacher_trainer.train()

Epoch,Training Loss,Validation Loss,Exact Match,F1
1.0,1.348786,No log,78.798486,86.681167
2.0,0.820882,No log,80.075686,87.778703


HBox(children=(FloatProgress(value=0.0, max=10570.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10570.0), HTML(value='')))




TrainOutput(global_step=11066, training_loss=1.0848221554800117)

In [20]:
# Persist the fine-tuned teacher under a local `models/` directory.
teacher_trainer.save_model('models/bert-base-uncased-finetuned-squad-v1')

### Create pipeline

In [21]:
# Wrap the fine-tuned teacher in a QA pipeline. Note that `.to('cpu')` moves the trainer's
# own model object to the CPU.
teacher_pipe = QuestionAnsweringPipeline(teacher_trainer.model.to('cpu'), teacher_tokenizer)

# Smoke-test on the first validation example.
context = squad['validation'][0]['context']
question = squad['validation'][0]['question']

# Expected answer: 'Denver Broncos' (start: 177, end: 191). The confidence score differs
# between training runs, so do not expect an exact value.
result = teacher_pipe(question=question, context=context)
result

{'score': 0.6260135173797607,
 'start': 177,
 'end': 191,
 'answer': 'Denver Broncos'}

## Distillation

### Preprocess data

In [16]:
# Checkpoint the student is initialised from, and its matching tokenizer.
student_model_name = "distilbert-base-uncased"
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)

In [None]:
# NOTE: this cell rebinds `max_length`, `doc_stride`, `pad_on_right` and `fn_kwargs` from the
# teacher section, now for the student tokenizer.
max_length = 384  # maximum number of tokens per feature (passed to the tokenizer as `max_length`)
doc_stride = 128  # overlap between consecutive features of a long context (tokenizer `stride`)
# True when the tokenizer pads on the right, in which case the context is the second sequence.
pad_on_right = student_tokenizer.padding_side == "right"

# Keyword arguments forwarded to the prepare_*_features helpers through `Dataset.map`.
fn_kwargs = {
    "tokenizer": student_tokenizer,
    "max_length": max_length,
    "doc_stride": doc_stride,
    "pad_on_right": pad_on_right
}

#### Preprocess training set

In [None]:
# Re-tokenize the training split with the student tokenizer (this rebinds `train_enc`).
# NOTE(review): the resulting features have no `token_type_ids` column (see output below).
train_enc = squad['train'].map(prepare_train_features, fn_kwargs=fn_kwargs, batched=True, remove_columns=squad["train"].column_names)
train_enc

HBox(children=(FloatProgress(value=0.0, max=88.0), HTML(value='')))




Dataset({
    features: ['attention_mask', 'end_positions', 'input_ids', 'start_positions'],
    num_rows: 88524
})

#### Preprocess validation set

In [None]:
# Re-tokenize the validation split with the student tokenizer (this rebinds `valid_enc`).
valid_enc = squad['validation'].map(prepare_validation_features, fn_kwargs=fn_kwargs, batched=True, remove_columns=squad["validation"].column_names)
valid_enc

HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))




Dataset({
    features: ['attention_mask', 'example_id', 'input_ids', 'offset_mapping'],
    num_rows: 10784
})

### Create distillation trainer

In [17]:
class DistillationTrainer(QuestionAnsweringTrainer):
    """Question-answering trainer that distills a frozen teacher into the trained student.

    The training loss is a weighted sum of the student's own span loss and the
    temperature-scaled KL divergence between the student's and the teacher's start/end
    distributions. `self.args` must provide `alpha_ce`, `alpha_squad` and `temperature`.

    Args:
        teacher_model: model whose start/end logits the student imitates. Required.

    Raises:
        ValueError: if no `teacher_model` is given.
    """

    def __init__(self, *args, teacher_model=None, **kwargs):
        if teacher_model is None:
            raise ValueError("DistillationTrainer requires a `teacher_model`.")
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model

        # The teacher receives the same batches as the student, so it has to live on the
        # training device; it is only ever used for inference.
        self.teacher.to(self.args.device)
        self.teacher.eval()
        # Make every column of the training set visible again, so that columns the student
        # does not consume (e.g. `token_type_ids` for the teacher) still reach `compute_loss`.
        if self.train_dataset is not None:
            self.train_dataset.set_format(
                type=self.train_dataset.format["type"], columns=list(self.train_dataset.features.keys())
            )

    def _infer_token_type_ids(self, input_ids):
        """Rebuild segment ids for a two-sequence input from its separator tokens.

        Tokens that have exactly one separator to their left get segment 1, everything else
        segment 0.
        NOTE(review): this assumes the `[CLS] A [SEP] B [SEP]` layout of BERT-style
        tokenizers — confirm before using a tokenizer with a different pair format.

        Returns:
            Tensor shaped like `input_ids`, or None when the trainer has no tokenizer with
            a `sep_token_id`.
        """
        sep_token_id = getattr(getattr(self, "tokenizer", None), "sep_token_id", None)
        if sep_token_id is None:
            return None
        is_sep = (input_ids == sep_token_id).long()
        seps_to_the_left = is_sep.cumsum(dim=-1) - is_sep
        return (seps_to_the_left == 1).long()

    def _soft_target_loss(self, student_logits, teacher_logits):
        """Temperature-scaled KL divergence between student and teacher distributions."""
        temperature = self.args.temperature
        kl_div = nn.KLDivLoss(reduction="batchmean")
        return (
            kl_div(
                F.log_softmax(student_logits / temperature, dim=-1),
                F.softmax(teacher_logits / temperature, dim=-1),
            )
            # Multiply by T^2, as the original implementation does.
            * (temperature ** 2)
        )

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the combined distillation + span-extraction loss for one batch.

        Args:
            model: the student model.
            inputs: batch dict with `input_ids`, `attention_mask`, `start_positions`,
                `end_positions` and optionally `token_type_ids` (for the teacher).
            return_outputs: when True, also return the student's outputs.

        Returns:
            The scalar loss, or `(loss, student_outputs)` when `return_outputs` is True.
        """
        outputs_stu = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            start_positions=inputs["start_positions"],
            end_positions=inputs["end_positions"],
        )
        loss_squad = outputs_stu.loss
        start_logits_stu = outputs_stu.start_logits
        end_logits_stu = outputs_stu.end_logits

        teacher_inputs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }
        # Features produced by a tokenizer that emits no `token_type_ids` would otherwise
        # raise a KeyError here. Use the column when present, else reconstruct it.
        token_type_ids = inputs.get("token_type_ids")
        if token_type_ids is None:
            token_type_ids = self._infer_token_type_ids(inputs["input_ids"])
        if token_type_ids is not None:
            teacher_inputs["token_type_ids"] = token_type_ids

        with torch.no_grad():
            outputs_tea = self.teacher(**teacher_inputs)
            start_logits_tea = outputs_tea.start_logits
            end_logits_tea = outputs_tea.end_logits
        assert start_logits_tea.size() == start_logits_stu.size()
        assert end_logits_tea.size() == end_logits_stu.size()

        loss_start = self._soft_target_loss(start_logits_stu, start_logits_tea)
        loss_end = self._soft_target_loss(end_logits_stu, end_logits_tea)
        loss_ce = (loss_start + loss_end) / 2.0

        loss = self.args.alpha_ce * loss_ce + self.args.alpha_squad * loss_squad
        return (loss, outputs_stu) if return_outputs else loss

### Initialise and fine-tune trainer

In [18]:
class DistillationTrainingArguments(QuestionAnsweringTrainingArguments):
    """Training arguments extended with the hyperparameters of the distillation loss.

    Args:
        *args: positional arguments forwarded unchanged to
            ``QuestionAnsweringTrainingArguments``.
        alpha_ce: weight of the soft-target (teacher vs. student) loss term.
        alpha_squad: weight of the student's supervised span loss term.
        temperature: value the logits are divided by before the softmax in the
            distillation loss; must be strictly positive.
        **kwargs: keyword arguments forwarded unchanged to
            ``QuestionAnsweringTrainingArguments``.

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """

    def __init__(self, *args, alpha_ce=0.5, alpha_squad=0.5, temperature=2.0, **kwargs):
        # The distillation loss divides logits by the temperature: zero would yield
        # inf/NaN losses and a negative value would invert the softened
        # distributions. Fail fast here instead of silently corrupting training.
        # `not temperature > 0` also rejects NaN.
        if not temperature > 0:
            raise ValueError(f"temperature must be strictly positive, got {temperature!r}")

        super().__init__(*args, **kwargs)

        self.alpha_ce = alpha_ce
        self.alpha_squad = alpha_squad
        self.temperature = temperature

In [19]:
# Load the student to be trained and the already fine-tuned teacher on the same device.
student_model = AutoModelForQuestionAnswering.from_pretrained(student_model_name).to(device)
teacher_model = AutoModelForQuestionAnswering.from_pretrained('lewtun/bert-base-uncased-finetuned-squad-v1').to(device)

batch_size = 16

# Fraction of the data to use; set below 1 for a quick experiment on a subset.
frac_of_samples = 1

if frac_of_samples != 1:
    train_ds = train_enc.select(range(int(frac_of_samples * train_enc.num_rows)))

    # Long contexts are split into several features, so the encoded validation set
    # has more rows than the raw one. Truncating both by the same fraction would
    # leave them misaligned, so subsample the raw examples first and then keep
    # exactly the features that were derived from them.
    eval_raw_ds = squad["validation"].select(range(math.ceil(frac_of_samples * squad["validation"].num_rows)))
    kept_example_ids = set(eval_raw_ds["id"])
    # NOTE(review): assumes `valid_enc` carries an "example_id" column that links
    # each feature to the "id" of its source example -- confirm against the
    # validation preprocessing function.
    kept_feature_rows = [
        row for row, example_id in enumerate(valid_enc["example_id"])
        if example_id in kept_example_ids
    ]
    eval_ds = valid_enc.select(kept_feature_rows)

    # Every selected example must be represented, and no foreign feature may remain.
    assert set(eval_ds["example_id"]) == kept_example_ids
else:
    train_ds = train_enc
    eval_ds = valid_enc
    eval_raw_ds = squad["validation"]

print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation examples: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_raw_ds.num_rows}")

# Log once per pass over the training data.
logging_steps = len(train_ds) // batch_size

student_training_args = DistillationTrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=logging_steps,
    disable_tqdm=False
)

data_collator = default_data_collator

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=558.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=435659279.0, style=ProgressStyle(descri…


Number of training examples: 88524
Number of validation examples: 10784
Number of raw validation examples: 10570


In [20]:
# Assemble the distillation trainer from the objects prepared in the previous cells.
distil_trainer = DistillationTrainer(
    # models: the student is optimised, the teacher only provides soft targets
    model=student_model,
    teacher_model=teacher_model,
    # hyperparameters, including the distillation weights and temperature
    args=student_training_args,
    # data: encoded features plus the raw validation examples
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_raw_ds,
    # batching and scoring helpers
    data_collator=data_collator,
    tokenizer=student_tokenizer,
    compute_metrics=compute_metrics,
)

In [21]:
# Score the student before `train()` has been called, giving a baseline to compare
# the post-training metrics against.
distil_trainer.evaluate()

HBox(children=(FloatProgress(value=0.0, max=10570.0), HTML(value='')))




{'eval_loss': 'No log',
 'eval_exact_match': 0.15137180700094607,
 'eval_f1': 7.167017942222715}

In [22]:
# Run distillation training with the settings held in `student_training_args`.
distil_trainer.train()

Epoch,Training Loss,Validation Loss,Exact Match,F1
1.0,1.487708,No log,76.40492,84.7624
2.0,0.764606,No log,77.417219,85.620589
3.0,0.609311,No log,78.391675,86.447313


HBox(children=(FloatProgress(value=0.0, max=10570.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10570.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10570.0), HTML(value='')))




TrainOutput(global_step=16599, training_loss=0.9538024121202958)

In [23]:
# Write the trainer's model to a local directory so it can be reloaded later.
distil_trainer.save_model('models/distilbert-base-uncased-distilled-squad-v1')

### Create pipeline

In [27]:
# Take the question and context of the first raw validation example as a smoke test.
question, context = (squad['validation'][0][field] for field in ('question', 'context'))

# Note: `Module.to('cpu')` moves the trainer's own model object in place, so
# `distil_trainer.model` lives on the CPU after this cell.
student_pipe = QuestionAnsweringPipeline(distil_trainer.model.to('cpu'), student_tokenizer)

pprint(f"{question}\n{context}")

# Reference values: answer 'Denver Broncos', score 0.8437, start 177, end 191
result = student_pipe(context=context, question=question)
result

('Which NFL team represented the AFC at Super Bowl 50?\n'
 'Super Bowl 50 was an American football game to determine the champion of the '
 'National Football League (NFL) for the 2015 season. The American Football '
 'Conference (AFC) champion Denver Broncos defeated the National Football '
 'Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super '
 "Bowl title. The game was played on February 7, 2016, at Levi's Stadium in "
 'the San Francisco Bay Area at Santa Clara, California. As this was the 50th '
 'Super Bowl, the league emphasized the "golden anniversary" with various '
 'gold-themed initiatives, as well as temporarily suspending the tradition of '
 'naming each Super Bowl game with Roman numerals (under which the game would '
 'have been known as "Super Bowl L"), so that the logo could prominently '
 'feature the Arabic numerals 50.')


{'score': 0.8734882473945618,
 'start': 177,
 'end': 191,
 'answer': 'Denver Broncos'}