<a href="https://colab.research.google.com/github/nbhimte/knowledge_distillation/blob/main/knowledge_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Task-specific knowledge distillation for BERT using Hugging Face Transformers
### Text Classification Example using `BERT-Base` as Teacher and `BERT-Tiny` as Student

## Installation

In [None]:
#%pip install "torch==1.10.1"
%pip install transformers datasets tensorboard --upgrade

!sudo apt-get install git-lfs

Reading package lists... Done
Building dependency tree       
Reading state information... Done
git-lfs is already the newest version (2.3.4-1).
0 upgraded, 0 newly installed, 0 to remove and 39 not upgraded.


This example uses the [Hugging Face Hub](https://huggingface.co/models) as a remote model versioning service. To push our model to the Hub, you need to [register on Hugging Face](https://huggingface.co/join).
If you already have an account, you can skip this step.
Once you have an account, we use the `notebook_login` util from the `huggingface_hub` package to log into our account and store our token (access key) on disk.

In [None]:
from huggingface_hub import notebook_login

notebook_login()


Login successful
Your token has been saved to /root/.huggingface/token


In [None]:
!git config --global credential.helper store

## Setup & Configuration

In this step we define the global configuration and parameters that are used across the whole end-to-end distillation process, such as the `teacher` and `student` models we will use.

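In this example, we use a BERT-base model fine-tuned on MNLI ([ishan/bert-base-uncased-mnli](https://huggingface.co/ishan/bert-base-uncased-mnli)) as Teacher and a BERT-Tiny model ([prajjwal1/bert-tiny-mnli](https://huggingface.co/prajjwal1/bert-tiny-mnli)) as Student. The commented-out IDs in the cell below are the SST-2 alternatives: [BERT-base](https://huggingface.co/textattack/bert-base-uncased-SST-2) as Teacher and [BERT-Tiny](https://huggingface.co/google/bert_uncased_L-2_H-128_A-2) as Student. Our Teacher is already fine-tuned on our dataset, which lets us start the distillation training job directly instead of fine-tuning the teacher first and distilling it afterwards.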

_**IMPORTANT**: This example only works with a `Teacher` & `Student` combination whose tokenizers produce the same output._

Additionally, the [FastFormers: Highly Efficient Transformer Models for Natural Language Understanding](https://arxiv.org/abs/2010.13382) paper describes a related phenomenon:
> In our experiments, we have observed that distilled models do not work well when distilled to a different model type. Therefore, we restricted our setup to avoid distilling RoBERTa model to BERT or vice versa. The major difference between the two model groups is the input token (sub-word) embedding. We think that different input embedding spaces result in different output embedding spaces, and knowledge transfer with different spaces does not work well.

In [None]:
# student_id = "google/bert_uncased_L-2_H-128_A-2"
student_id = "prajjwal1/bert-tiny-mnli"
# teacher_id = "textattack/bert-base-uncased-SST-2"
teacher_id = "ishan/bert-base-uncased-mnli"

# name for our repository on the hub
repo_name = "tiny-bert-mnli-distilled"

Below is a quick check that the `Teacher` & `Student` tokenizers produce the same output.

In [None]:
from transformers import AutoTokenizer

# init tokenizer
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_id)
student_tokenizer = AutoTokenizer.from_pretrained(student_id)

# sample input
sample = "This is a basic example, with different words to test."

# assert results
assert teacher_tokenizer(sample) == student_tokenizer(sample), "Tokenizers haven't created the same output"


loading configuration file https://huggingface.co/ishan/bert-base-uncased-mnli/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/3ac9c4ab3837356d9fe68f684b16a9f7d02b288d0fb8a9e0ccae7fab09c9b016.cda3a9d6349a2be3b0c218d9a1414c149919b95733a609cd1e04f6ff0e0472e0
Model config BertConfig {
  "_name_or_path": "ishan/bert-base-uncased-mnli",
  "architectures": [
    "BertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "finetuning_task": "mnli",
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_
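The assertion above compares the two tokenizers on a single sentence. A slightly stronger check, sketched below as a hypothetical helper `tokenizers_compatible` (not part of 🤗 Transformers), also compares the full vocabularies via `get_vocab()` and encodes several samples, including a premise/hypothesis pair. It assumes only that each tokenizer exposes `get_vocab()` and can be called with one or two texts, returning a mapping with `input_ids`.

```python
from typing import Iterable, Sequence, Union

Sample = Union[str, Sequence[str]]  # a single text or a (text, text_pair) tuple


def tokenizers_compatible(tok_a, tok_b, samples: Iterable[Sample]) -> bool:
    """Return True if both tokenizers share a vocabulary and encode every sample identically.

    Hypothetical helper: works with any object exposing `get_vocab()` and
    `__call__(text, text_pair=None)` that returns a mapping with "input_ids".
    """
    if tok_a.get_vocab() != tok_b.get_vocab():
        return False
    for sample in samples:
        # single strings are encoded alone; tuples/lists are encoded as a sentence pair
        args = (sample,) if isinstance(sample, str) else tuple(sample)
        if tok_a(*args)["input_ids"] != tok_b(*args)["input_ids"]:
            return False
    return True

# Usage (sketch):
# tokenizers_compatible(teacher_tokenizer, student_tokenizer,
#                       [sample, ("A man is playing a guitar.", "Someone is making music.")])
```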

## Dataset & Pre-processing

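As dataset we use the Multi-Genre Natural Language Inference (MNLI) corpus from the [GLUE benchmark](https://gluebenchmark.com/), a sentence-pair classification task. Each example consists of a `premise` and a `hypothesis`, labeled as `entailment`, `neutral`, or `contradiction`. Besides the training split (392,702 pairs), MNLI has `matched` and `mismatched` validation and test splits. The results reported at the end of this notebook refer to the [Stanford Sentiment Treebank v2 (SST-2)](https://paperswithcode.com/dataset/sst), a `sentiment-analysis` dataset that is also included in GLUE. SST-2 is based on the dataset introduced by Pang and Lee (2005) and consists of 11,855 single sentences extracted from movie reviews. It was parsed with the Stanford parser and includes a total of 215,154 unique phrases from those parse trees, each annotated by 3 human judges. It uses the two-way (positive/negative) class split, with only sentence-level labels.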


In [None]:
dataset_id="glue"
dataset_config="mnli"

To load the `mnli` configuration of the `glue` dataset, we use the `load_dataset()` method from the 🤗 Datasets library.


In [None]:
from datasets import load_dataset

dataset = load_dataset(dataset_id,dataset_config)
dataset

Reusing dataset glue (/root/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)



DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9832
    })
    test_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9796
    })
    test_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9847
    })
})

### Pre-processing & Tokenization

To distill our model, we need to convert our natural-language inputs to token IDs. This is done by a 🤗 Transformers Tokenizer, which tokenizes the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary). If you are not sure what this means, check out [chapter 6](https://huggingface.co/course/chapter6/1?fw=pt) of the Hugging Face Course.

We use the tokenizer of the `Teacher`, but since both tokenizers produce the same output, you could also use the `Student` tokenizer.
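For a sentence-pair task like MNLI, a BERT tokenizer packs premise and hypothesis into one sequence, `[CLS] premise [SEP] hypothesis [SEP]`, and marks the two segments with `token_type_ids`. The sketch below reproduces that layout from already-tokenized IDs; `build_pair_inputs` is a hypothetical illustration, `101` (`[CLS]`) and `102` (`[SEP]`) are the IDs in the standard `bert-base-uncased` vocabulary, and the word IDs in the usage line are made up.

```python
from typing import Dict, List, Sequence


def build_pair_inputs(
    ids_a: Sequence[int], ids_b: Sequence[int], cls_id: int = 101, sep_id: int = 102
) -> Dict[str, List[int]]:
    """Assemble BERT-style sentence-pair inputs: [CLS] A [SEP] B [SEP]."""
    input_ids = [cls_id, *ids_a, sep_id, *ids_b, sep_id]
    # segment 0: [CLS], sequence A and the first [SEP]; segment 1: sequence B and the final [SEP]
    token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
    # every position is a real token here; padding is added later per batch
    attention_mask = [1] * len(input_ids)
    return {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}


# Usage with made-up word IDs for a two-token premise and a one-token hypothesis
print(build_pair_inputs([7, 8], [9]))
# {'input_ids': [101, 7, 8, 102, 9, 102], 'token_type_ids': [0, 0, 0, 0, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1]}
```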


In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(teacher_id)


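Since MNLI is a sentence-pair task, we tokenize each `premise`/`hypothesis` pair together. Additionally, we pass `truncation=True` and `max_length=512` to truncate inputs that are longer than the maximum size allowed by the model. Padding to a common length is applied per batch later by the data collator.

In [None]:
def process(examples):
    # encode premise and hypothesis as one pair: [CLS] premise [SEP] hypothesis [SEP]
    tokenized_inputs = tokenizer(
        examples["premise"], examples["hypothesis"], truncation=True, max_length=512
    )
    return tokenized_inputs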

tokenized_datasets = dataset.map(process, batched=True)
tokenized_datasets = tokenized_datasets.rename_column("label","labels")

tokenized_datasets["test_matched"].features

Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9f857e39291dd5a1.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-2799fafa20454de2.arrow



Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-62444087ca2ad0a9.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-81278361b62030e8.arrow


{'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'hypothesis': Value(dtype='string', id=None),
 'idx': Value(dtype='int32', id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'labels': ClassLabel(num_classes=3, names=['entailment', 'neutral', 'contradiction'], id=None),
 'premise': Value(dtype='string', id=None),
 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}
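For a sentence pair, `truncation=True` selects the `longest_first` strategy: tokens are removed from whichever sequence is currently longer until the pair, including its special tokens, fits into `max_length`. The simplified sketch below illustrates that behavior on plain lists; `truncate_longest_first` is a hypothetical helper, it assumes the BERT pair layout with 3 special tokens, and the tokenizers' exact tie-breaking between equally long sequences may differ.

```python
from typing import List, Sequence, Tuple


def truncate_longest_first(
    ids_a: Sequence[int], ids_b: Sequence[int], max_length: int, num_special_tokens: int = 3
) -> Tuple[List[int], List[int]]:
    """Simplified 'longest_first' truncation for a sentence pair.

    Assumes the BERT pair layout [CLS] A [SEP] B [SEP] (3 special tokens) and
    removes tokens one at a time from the end of the currently longer sequence.
    """
    if max_length < num_special_tokens:
        raise ValueError("max_length must leave room for the special tokens")
    a, b = list(ids_a), list(ids_b)
    budget = max_length - num_special_tokens  # room left for real tokens
    while len(a) + len(b) > budget:
        if len(a) > len(b):
            a.pop()
        else:
            b.pop()  # on ties this sketch trims the second sequence
    return a, b


# A 10-token premise and a 4-token hypothesis with max_length=12 keep 5 + 4 tokens
premise, hypothesis = truncate_longest_first(list(range(10)), list(range(100, 104)), max_length=12)
print(len(premise), len(hypothesis))  # 5 4
```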

## Distilling the model using `PyTorch` and `DistillationTrainer`


Now that our `dataset` is processed, we can start the distillation. Normally, when fine-tuning a transformer model with PyTorch, you would use the `Trainer` API. The [Trainer](https://huggingface.co/docs/transformers/v4.16.1/en/main_classes/trainer#transformers.Trainer) class provides an API for feature-complete training in PyTorch for most standard use cases.

In our example, we cannot use the `Trainer` out-of-the-box, since we need to pass in two models, the `Teacher` and the `Student`, and compute the loss from both of their outputs. But we can subclass the `Trainer` to create a `DistillationTrainer` that takes care of this, overriding only the [compute_loss](https://github.com/huggingface/transformers/blob/c4ad38e5ac69e6d96116f39df789a2369dd33c21/src/transformers/trainer.py#L1962) method and the `__init__` method. In addition, we need to subclass the `TrainingArguments` to include our distillation hyperparameters.


In [None]:
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationTrainingArguments(TrainingArguments):
    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)

        self.alpha = alpha
        self.temperature = temperature

class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        # place teacher on same device as student
        self._move_model_to_device(self.teacher,self.model.device)
        self.teacher.eval()

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments that newer Trainer versions pass
        # to compute_loss (e.g. num_items_in_batch)

        # compute student output
        outputs_student = model(**inputs)

        # compute teacher output (no gradients needed for the frozen teacher)
        with torch.no_grad():
            outputs_teacher = self.teacher(**inputs)

        # assert size
        assert outputs_student.logits.size() == outputs_teacher.logits.size()

        # Soften probabilities and compute distillation loss (scaled by T^2)
        loss_function = nn.KLDivLoss(reduction="batchmean")
        loss_logits = (loss_function(
            F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
            F.softmax(outputs_teacher.logits / self.args.temperature, dim=-1)) * (self.args.temperature ** 2))
        # Return the weighted sum of the student's hard-label loss and the distillation loss
        student_loss = outputs_student.loss
        loss = self.args.alpha * student_loss + (1. - self.args.alpha) * loss_logits
        return (loss, outputs_student) if return_outputs else loss
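To see what `compute_loss` computes, here is a NumPy re-implementation on plain arrays (an illustrative sketch, not used in training; `distillation_loss` is our own name). The student's hard-label cross-entropy is what `outputs_student.loss` holds for `BertForSequenceClassification` with integer labels, and `nn.KLDivLoss(reduction="batchmean")` sums the KL divergence over classes and divides by the batch size. The factor `T**2` compensates for the `1/T**2` scaling of the soft-target gradients, so raising the temperature does not shrink the distillation term's contribution.

```python
import numpy as np


def log_softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=4.0):
    """NumPy mirror of DistillationTrainer.compute_loss (illustrative sketch).

    Returns (total_loss, student_ce_loss, distillation_term).
    """
    batch_size = student_logits.shape[0]
    # hard-label cross-entropy, averaged over the batch (what outputs_student.loss holds)
    student_ce = -log_softmax(student_logits)[np.arange(batch_size), labels].mean()
    # KL(teacher || student) on temperature-softened distributions,
    # summed over classes and divided by the batch size ("batchmean")
    log_p_student = log_softmax(student_logits / temperature)
    log_p_teacher = log_softmax(teacher_logits / temperature)
    kl = (np.exp(log_p_teacher) * (log_p_teacher - log_p_student)).sum() / batch_size
    # scale by T^2 so the soft-target term keeps its weight as T changes
    distill = kl * temperature ** 2
    total = alpha * student_ce + (1.0 - alpha) * distill
    return total, student_ce, distill


teacher_logits = np.array([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])
labels = np.array([0, 1])
# A student that already matches the teacher incurs no distillation penalty
print(distillation_loss(teacher_logits.copy(), teacher_logits, labels)[2])  # 0.0
```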

### Hyperparameter Definition, Model Loading

In [None]:
from transformers import AutoModelForSequenceClassification, DataCollatorWithPadding
from huggingface_hub import HfFolder

# create label2id, id2label dicts for nice outputs for the model
labels = tokenized_datasets["train"].features["labels"].names
num_labels = len(labels)
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

# define training args
training_args = DistillationTrainingArguments(
    output_dir=repo_name,
    num_train_epochs=4,
    #128
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    fp16=True,
    learning_rate=3e-4,
    seed=33,
    # logging & evaluation strategies
    logging_dir=f"{repo_name}/logs",
    logging_strategy="epoch", # to get more information to TB
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
    # push to hub parameters
    push_to_hub=True,
    hub_strategy="every_save",
    hub_model_id=repo_name,
    hub_token=HfFolder.get_token(),
    # distillation parameters
    alpha=0.5,
    temperature=4.0
    )

# define data_collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# define model
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

# define student model
student_model = AutoModelForSequenceClassification.from_pretrained(
    student_id,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

PyTorch: setting up devices
loading configuration file https://huggingface.co/ishan/bert-base-uncased-mnli/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/3ac9c4ab3837356d9fe68f684b16a9f7d02b288d0fb8a9e0ccae7fab09c9b016.cda3a9d6349a2be3b0c218d9a1414c149919b95733a609cd1e04f6ff0e0472e0
Model config BertConfig {
  "_name_or_path": "ishan/bert-base-uncased-mnli",
  "architectures": [
    "BertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "finetuning_task": "mnli",
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "entailment",
    "1": "neutral",
    "2": "contradiction"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "contradiction": "2",
    "entailment": "0",
    "neutral": "1"
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_

### Evaluation metric

We create a `compute_metrics` function to evaluate our model on the validation set. This function is used during training to compute the `accuracy` of our model.

In [None]:
from datasets import load_metric
import numpy as np

# define metrics and metrics function
accuracy_metric = load_metric( "accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    acc = accuracy_metric.compute(predictions=predictions, references=labels)
    return {
        "accuracy": acc["accuracy"],
    }
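In recent releases of 🤗 Datasets, `load_metric` is deprecated in favor of the separate 🤗 Evaluate library, and it was removed in Datasets 3.0. Since we only need accuracy, a NumPy-only `compute_metrics` with the same return value is sketched below as a drop-in alternative; it assumes `eval_pred` unpacks into `(logits, labels)`, as in the cell above.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Accuracy computed with NumPy alone (alternative to the load_metric version)."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)  # predicted class per example
    return {"accuracy": float((predictions == labels).mean())}
```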

## Training

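Before training, we increase git's HTTP post buffer (useful when pushing large checkpoints to the Hub), create the `DistillationTrainer` with both models, and start training by calling `trainer.train()`.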

In [None]:
!git config --global http.postBuffer 524288000


In [None]:
trainer = DistillationTrainer(
    student_model,
    training_args,
    teacher_model=teacher_model,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation_matched"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

Cloning https://huggingface.co/nbhimte/tiny-bert-mnli-distilled into local empty directory.
Using amp half precision backend


Start training using the `DistillationTrainer`.

In [None]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: hypothesis, idx, premise. If hypothesis, idx, premise are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 392702
  Num Epochs = 4
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 196352


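| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1     | 0.7433        | 0.740455        | 0.328273 |
| 2     | 0.7287        | 0.736857        | 0.328069 |
| 3     | 0.7201        | 0.733062        | 0.326949 |
| 4     | 0.7121        | 0.732788        | 0.328986 |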


The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: hypothesis, idx, premise. If hypothesis, idx, premise are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9815
  Batch size = 8
Saving model checkpoint to tiny-bert-mnli-distilled/checkpoint-49088
Configuration saved in tiny-bert-mnli-distilled/checkpoint-49088/config.json
Model weights saved in tiny-bert-mnli-distilled/checkpoint-49088/pytorch_model.bin
tokenizer config file saved in tiny-bert-mnli-distilled/checkpoint-49088/tokenizer_config.json
Special tokens file saved in tiny-bert-mnli-distilled/checkpoint-49088/special_tokens_map.json
tokenizer config file saved in tiny-bert-mnli-distilled/tokenizer_config.json
Special tokens file saved in tiny-bert-mnli-distilled/special_tokens_map.json
The following columns in the evaluation set  don't have 

TrainOutput(global_step=196352, training_loss=0.7260512808311095, metrics={'train_runtime': 7024.1467, 'train_samples_per_second': 223.63, 'train_steps_per_second': 27.954, 'total_flos': 210715752136032.0, 'train_loss': 0.7260512808311095, 'epoch': 4.0})

## Hyperparameter Search for the Distillation Parameters `alpha` & `temperature` with Optuna

The parameters `alpha` & `temperature` of the `DistillationTrainer` can also be tuned with a hyperparameter search to maximize our "knowledge extraction". As the hyperparameter optimization framework we use [Optuna](https://optuna.org/), which is integrated into the `Trainer` API. Since the `DistillationTrainer` is a subclass of the `Trainer`, we can use `hyperparameter_search` without any code changes.


In [None]:
!pip install optuna

Collecting optuna
  Downloading optuna-2.10.0-py3-none-any.whl (308 kB)
Collecting alembic
  Downloading alembic-1.7.7-py3-none-any.whl (210 kB)
Collecting colorlog
  Downloading colorlog-6.6.0-py2.py3-none-any.whl (11 kB)
Collecting cliff
  Downloading cliff-3.10.1-py3-none-any.whl (81 kB)
Collecting cmaes>=0.8.2
  Downloading cmaes-0.8.2-py3-none-any.whl (15 kB)
Collecting Mako
  Downloading Mako-1.2.0-py3-none-any.whl (78 kB)
Collecting pbr!=2.1.0,>=2.0.0
  Downloading pbr-5.8.1-py2.py3-none-any.whl (113 kB)
Collecting stevedore>=2.0.1
  Downloading stevedore-3.5.0-py3-none-any.whl (49 kB)
Collecting autopage>=0.

To run hyperparameter optimization with `optuna`, we need to define our hyperparameter space. In this example, we search over `num_train_epochs`, `learning_rate`, `alpha` & `temperature` for our `student_model`, with the goal of maximizing validation accuracy.

In [None]:
def hp_space(trial):
    return {
      "num_train_epochs": trial.suggest_int("num_train_epochs", 2, 10),
      "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3 ,log=True),
      "alpha": trial.suggest_float("alpha", 0, 1),
      "temperature": trial.suggest_int("temperature", 2, 30),
      }
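Besides the search space, `hyperparameter_search` accepts a `compute_objective` callable that maps the evaluation metrics to the value Optuna optimizes. By default, the Trainer drops the evaluation loss, the epoch and the speed metrics and sums whatever remains, which with our `compute_metrics` amounts to `eval_accuracy`. Making that explicit, as sketched below, keeps the objective unambiguous if more metrics are added later; the name `accuracy_objective` is our own choice.

```python
from typing import Dict


def accuracy_objective(metrics: Dict[str, float]) -> float:
    """Objective for hyperparameter_search: the validation accuracy only.

    The Trainer prefixes evaluation metrics with "eval_", so the "accuracy"
    returned by compute_metrics arrives here as "eval_accuracy".
    """
    return metrics["eval_accuracy"]

# Usage (sketch):
# best_run = trainer.hyperparameter_search(
#     n_trials=50,
#     direction="maximize",
#     hp_space=hp_space,
#     compute_objective=accuracy_objective,
# )
```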

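To start the hyperparameter search, we call `hyperparameter_search` with our `hp_space` and the number of trials to run. Because every trial needs a freshly initialized student, the `DistillationTrainer` is created with a `model_init` function instead of a model instance.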

In [None]:
def student_init():
    return AutoModelForSequenceClassification.from_pretrained(
        student_id,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id
    )

trainer = DistillationTrainer(
    model_init=student_init,
    args=training_args,
    teacher_model=teacher_model,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation_matched"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
best_run = trainer.hyperparameter_search(
    n_trials=50,
    direction="maximize",
    hp_space=hp_space
)

print(best_run)

loading configuration file https://huggingface.co/prajjwal1/bert-tiny-mnli/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/6541d21313d66b47824c6d9dfe8500c425923193e04bb3f5421e9609a29e8ae7.77cb6cfe5ec6550086efef613e1c1be70610fcaaf9fc90c978e1b179b628eb46
Model config BertConfig {
  "_name_or_path": "prajjwal1/bert-tiny-mnli",
  "architectures": [
    "BertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "finetuning_task": "mnli",
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 128,
  "id2label": {
    "0": "entailment",
    "1": "neutral",
    "2": "contradiction"
  },
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "label2id": {
    "contradiction": "2",
    "entailment": "0",
    "neutral": "1"
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 2,
  "num_hidden_layers": 2,


| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1     | 0.4589        | 0.448937        | 0.329088 |
| 2     | 0.4324        | 0.444797        | 0.328884 |


The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: hypothesis, idx, premise. If hypothesis, idx, premise are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9815
  Batch size = 8
Saving model checkpoint to tiny-bert-mnli-distilled/run-0/checkpoint-49088
Configuration saved in tiny-bert-mnli-distilled/run-0/checkpoint-49088/config.json
Model weights saved in tiny-bert-mnli-distilled/run-0/checkpoint-49088/pytorch_model.bin
tokenizer config file saved in tiny-bert-mnli-distilled/run-0/checkpoint-49088/tokenizer_config.json
Special tokens file saved in tiny-bert-mnli-distilled/run-0/checkpoint-49088/special_tokens_map.json
tokenizer config file saved in tiny-bert-mnli-distilled/tokenizer_config.json
Special tokens file saved in tiny-bert-mnli-distilled/special_tokens_map.json
The following columns in t



Since Optuna only finds the best hyperparameters, we need to distill our model again using the hyperparameters from `best_run`.

In [None]:
# overwrite the initial hyperparameters with those from the best_run
for k, v in best_run.hyperparameters.items():
    setattr(training_args, k, v)

# Define a new output directory for the final distilled model
best_model_ckpt = "tiny-bert-best"
training_args.output_dir = best_model_ckpt

We have overwritten the default hyperparameters with those from our `best_run` and can now start the training.

In [None]:
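# Create a new Trainer with optimal parameters, starting from a freshly
# initialized student (student_model was already trained above)
optimal_trainer = DistillationTrainer(
    model_init=student_init,
    args=training_args,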
    teacher_model=teacher_model,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation_matched"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

optimal_trainer.train()


# save best model, metrics and create model card
optimal_trainer.create_model_card(model_name=training_args.hub_model_id)
optimal_trainer.push_to_hub()

In [None]:
from huggingface_hub import HfApi

whoami = HfApi().whoami()
username = whoami['name']

print(f"https://huggingface.co/{username}/{repo_name}")

## Results

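On SST-2, with `google/bert_uncased_L-2_H-128_A-2` as the student, we were able to achieve an `accuracy` of 0.8337. Our distilled `Tiny-Bert` has 96% fewer parameters than the teacher `bert-base`, runs ~46.5x faster, and preserves roughly 90% of the teacher's accuracy as measured on the SST-2 dataset. The MNLI runs logged above, in contrast, stay at about 0.33 validation accuracy, which is chance level for three classes.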

| Model | Parameters | Speed-up | Accuracy (SST-2) |
|------------|-----------|----------|----------|
| BERT-base  | 109M      | 1x       | 93%      |
| tiny-BERT  | 4M        | 46.5x    | 83%      |
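The relative numbers above can be recomputed directly from the table rows. With the rounded values shown, the student keeps just under 90% of the teacher's accuracy:

```python
# Values copied from the (rounded) table above
teacher_params, student_params = 109e6, 4e6
teacher_acc, student_acc = 0.93, 0.83

param_reduction = 1 - student_params / teacher_params  # fraction of parameters removed
accuracy_retained = student_acc / teacher_acc          # fraction of teacher accuracy kept

print(f"parameter reduction: {param_reduction:.1%}")   # parameter reduction: 96.3%
print(f"accuracy retained:   {accuracy_retained:.1%}") # accuracy retained:   89.2%
```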

_Note: The [FastFormers paper](https://arxiv.org/abs/2010.13382) found that the biggest boost in performance is observed when the student has 6 or more layers. The [google/bert_uncased_L-2_H-128_A-2](https://huggingface.co/google/bert_uncased_L-2_H-128_A-2) student we used has only 2, so when switching our student to a deeper model, e.g. `distilbert-base-uncased`, we would expect better accuracy._

If you are planning to add task-specific knowledge distillation to your own models, take a look at [sagemaker-distillation](https://github.com/philschmid/knowledge-distillation-transformers-pytorch-sagemaker/blob/master/sagemaker-distillation.ipynb), which shows how to run task-specific knowledge distillation on Amazon SageMaker. For that example I created a script derived from this notebook to make it as easy to use as possible. You only need to define your `teacher_id`, `student_id` and your `dataset` config to run task-specific knowledge distillation for `text-classification`.

```python
from sagemaker.huggingface import HuggingFace

# hyperparameters, which are passed into the training job
hyperparameters={
    'teacher_id':'textattack/bert-base-uncased-SST-2',
    'student_id':'google/bert_uncased_L-2_H-128_A-2',
    'dataset_id':'glue',
    'dataset_config':'sst2',
    # distillation parameter
    'alpha': 0.5,
    'temperature': 4,
    # hpo parameter
    "run_hpo": True,
    "n_trials": 100,
}

# create the Estimator
huggingface_estimator = HuggingFace(..., hyperparameters=hyperparameters)

# start knowledge distillation training
huggingface_estimator.fit()
```