# Task-specific knowledge distillation for BERT using `HuggingFace/transformers` and `IntelLabs/Model-Compression-Research-Package`

## Installation

In [None]:
# Upgrade the Hugging Face libraries (recent `transformers` releases require `accelerate`
# for `Trainer`), then install the compression research package from a local clone.
# `|| :` ignores a failed clone (e.g. when the directory already exists from a previous run).
! sudo -H /venv/bin/pip install -U transformers datasets tensorboard accelerate
! (git clone https://github.com/IntelLabs/Model-Compression-Research-Package.git || :) && sudo -H /venv/bin/pip install ./Model-Compression-Research-Package

In [None]:
import datasets
import transformers
import numpy as np

import model_compression_research as mcr

## Setup
In this example we will use the following models from the Hugging Face model hub. It is important to make sure both models use the same tokenizer, since the student's tokenizer will be used to process the input for both models!

In [None]:
# Student: a compact uncased BERT (its id suggests 2 layers, hidden size 128, 2 heads).
# Teacher: a BERT-base checkpoint whose id indicates it was fine-tuned on SST-2.
student_id, teacher_id = (
    "google/bert_uncased_L-2_H-128_A-2",
    "textattack/bert-base-uncased-SST-2",
)

## Datasets & Pre-processing

In [None]:
# Hub dataset and configuration name: the SST-2 task from the GLUE benchmark.
dataset_id, dataset_config = "glue", "sst2"

In [None]:
# Download (or reuse the cached copy of) every split of the configured dataset.
dataset = datasets.load_dataset(path=dataset_id, name=dataset_config)

### Preprocessing

In [None]:
# Load the student's and the teacher's tokenizers, in that order. In this notebook the
# student's tokenizer encodes all inputs; the teacher's is only consulted for its
# maximum sequence length.
tokenizer, teacher_tokenizer = (
    transformers.AutoTokenizer.from_pretrained(model_id)
    for model_id in (student_id, teacher_id)
)

In [None]:
# Longest input both models accept: sequences are tokenized once (with the student's
# tokenizer) and the same ids are fed to student and teacher alike.
max_seq_len = min(
    tokenizer.model_max_length,
    teacher_tokenizer.model_max_length,
)


def process(examples):
    """Tokenize a batch of ``"sentence"`` texts with the student's tokenizer.

    Texts longer than ``max_seq_len`` tokens are truncated. No padding is applied
    here; the data collator pads each batch later.
    """
    texts = examples["sentence"]
    return tokenizer(texts, truncation=True, max_length=max_seq_len)


# `batched=True` hands `process` a dict of column lists per batch. The "label" column
# is left untouched; the collator/Trainer are expected to treat it as the target.
tokenized_datasets = dataset.map(process, batched=True)

## Setup training and distillation

In [None]:
# Build label <-> id mappings from the dataset's class names so the models' configs
# carry human-readable labels. Hugging Face configs expect `id2label: Dict[int, str]`
# and `label2id: Dict[str, int]`, so ids are kept as ints (not strings).
labels = tokenized_datasets["train"].features["label"].names
num_labels = len(labels)
id2label = dict(enumerate(labels))
label2id = {label: i for i, label in id2label.items()}

training_args = transformers.TrainingArguments(
    output_dir='./run',
    num_train_epochs=7,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    fp16=True,  # mixed-precision training
    learning_rate=6e-5,
    seed=33,
    # logging & evaluation strategies: log, evaluate and checkpoint once per epoch
    logging_strategy="epoch",  # to get more information to TB
    # `eval_strategy` replaces the `evaluation_strategy` argument removed in recent releases
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    # Reload the best checkpoint when training ends; the save and eval strategies must
    # match. "accuracy" is the key returned by `compute_metrics`.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
)

# Pad each batch dynamically to its longest sequence. Under fp16, round the padded
# length up to a multiple of 8 (Tensor Core friendly); otherwise do not round at all.
data_collator = transformers.DataCollatorWithPadding(
    tokenizer=tokenizer,
    pad_to_multiple_of=8 if training_args.fp16 else None,
)

student_model = transformers.AutoModelForSequenceClassification.from_pretrained(
    student_id, num_labels=num_labels, id2label=id2label, label2id=label2id
)

teacher_model = transformers.AutoModelForSequenceClassification.from_pretrained(
    teacher_id, num_labels=num_labels, id2label=id2label, label2id=label2id
)

# Wrap student and teacher so a single forward pass yields the combined distillation loss.
distillation_model = mcr.hf_add_teacher_to_student(
    student_model,
    teacher_model,
    # scaling factor for CE loss of the student's logits against the ground truth labels
    student_alpha=0.5,
    # scaling factor for CE loss of the student's against the teacher's logits (KL divergence)
    teacher_ce_alpha=0.5,
    # temperature applied to the CE loss of the student's against the teacher's logits
    teacher_ce_temperature=4.0,
)

## Evaluation metric

In [None]:
# Metric used for evaluation and for picking the best checkpoint
# (`metric_for_best_model="accuracy"`). Accuracy is computed directly with NumPy:
# `datasets.load_metric` was removed in datasets 3.0, so it fails with the upgraded
# `datasets` installed at the top of this notebook.
def compute_metrics(eval_pred):
    """Return classification accuracy for a Trainer evaluation pass.

    Args:
        eval_pred: ``(predictions, labels)`` pair (e.g. a ``transformers.EvalPrediction``).
            ``predictions`` holds per-class scores with classes on the last axis,
            e.g. shape ``(num_examples, num_labels)``. ``labels`` holds integer class ids.

    Returns:
        dict: ``{"accuracy": float}``, the fraction of examples whose highest-scoring
        class equals the label.
    """
    predictions, labels = eval_pred
    # Some models return extra outputs next to the logits; the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_ids = np.argmax(predictions, axis=-1)
    return {"accuracy": float(np.mean(predicted_ids == np.asarray(labels)))}

## Training

In [None]:
# Trainer wiring: the student+teacher wrapper is optimized on the train split and
# scored with `compute_metrics` on the validation split.
train_split = tokenized_datasets["train"]
eval_split = tokenized_datasets["validation"]
trainer = transformers.Trainer(
    model=distillation_model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    train_dataset=train_split,
    eval_dataset=eval_split,
)

In [None]:
# Run distillation training. Per `training_args`, evaluation and checkpointing happen
# every epoch, and the best checkpoint (by "accuracy") is reloaded at the end.
trainer.train()

In [None]:
# Get the final trained model by removing the teacher from the distillation wrapper.
# NOTE(review): presumably this returns the bare student classifier for saving or
# inference; confirm against the model_compression_research documentation.
final_model = mcr.hf_remove_teacher_from_student(distillation_model)

In [None]:
# Display the resulting model; its repr is shown as the cell output.
final_model