In [37]:
from pathlib import Path
from typing import Any, Mapping

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import DatasetDict, load_dataset
from huggingface_hub import HfFolder
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

In [30]:
# Hugging Face Hub checkpoint identifiers for the teacher and the student.
teacher_id = "textattack/bert-base-uncased-SST-2"
student_id = "google/bert_uncased_L-2_H-128_A-2"

# Name reserved for the Hub repository (only referenced by the commented-out
# `hub_model_id` option in the training arguments).
repo_name = "tiny-bert-sst2-distilled"

# Local directories, relative to the notebook's working directory.
cache_dir, output_dir = Path(".cache"), Path("output")

## Check Teacher and Student tokenizer output
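The distillation loss used below feeds the same tokenized batch to both models and compares their logits, so it only works if `Teacher` and `Student` use the same tokenizer (same vocabulary and special tokens)!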

In [13]:
# Load both tokenizers from the Hub (cached locally in `cache_dir`).
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_id, cache_dir=cache_dir)
student_tokenizer = AutoTokenizer.from_pretrained(student_id, cache_dir=cache_dir)

sample = "Sample that tests the tokenizer"

# Encode the same sentence with each tokenizer and require identical output.
teacher_encoding = teacher_tokenizer(sample)
student_encoding = student_tokenizer(sample)

assert teacher_encoding == student_encoding, "Tokenizers are different"

## Dataset
[Stanford Sentiment Treebank v2 (SST-2)](https://paperswithcode.com/dataset/sst) <br>
Labels: positive/negative

In [8]:
# Dataset repository and the configuration (subset) to load from it.
dataset_id, dataset_config = "glue", "sst2"

In [15]:
# Download (or reuse from `cache_dir`) all splits of the configured dataset.
dataset = load_dataset(
    path=dataset_id,
    name=dataset_config,
    cache_dir=cache_dir,
)

In [17]:
# Display the training split's summary (column names and number of rows).
dataset["train"]

Dataset({
    features: ['sentence', 'label', 'idx'],
    num_rows: 67349
})

### Pre-processing and Tokenization
Converting dataset text into token IDs

In [22]:
def process(example: Mapping[str, list]) -> Mapping[str, Any]:
    """Tokenize a batch of sentences with the teacher tokenizer.

    Intended to be passed to `dataset.map(..., batched=True)`, in which case
    `example` is a batch of columns rather than a single row or a whole
    `DatasetDict`.

    Args:
        example: Batch of dataset columns; only the "sentence" column is read.

    Returns:
        The tokenizer output for the batch (e.g. `input_ids`,
        `token_type_ids`, `attention_mask`). Sequences are truncated to at
        most 512 tokens; no `padding` argument is passed here.
    """
    return teacher_tokenizer(example["sentence"], truncation=True, max_length=512)

In [20]:
# Show the container type returned by `load_dataset` for this dataset.
type(dataset)

datasets.dataset_dict.DatasetDict

In [23]:
# Tokenize every split in batches, then rename the target column from
# "label" to "labels".
tokenized_dataset = dataset.map(process, batched=True).rename_column(
    "label", "labels"
)

# Inspect the resulting schema on the test split.
tokenized_dataset["test"].features

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

{'sentence': Value(dtype='string', id=None),
 'labels': ClassLabel(names=['negative', 'positive'], id=None),
 'idx': Value(dtype='int32', id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}

## Model Distillation with `PyTorch` and `DistillationTrainer`

In [25]:
class DistillationTrainingArguments(TrainingArguments):
    """`TrainingArguments` extended with the two distillation hyperparameters.

    Attributes:
        alpha: Weight of the student's own loss in the combined loss; the
            distillation term is weighted by `1 - alpha`.
        temperature: Value the logits are divided by before the softmax when
            computing the distillation term.
    """

    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        """Validate the distillation hyperparameters and store them.

        Args:
            *args: Positional arguments forwarded to `TrainingArguments`.
            alpha: Mixing weight, must lie in [0, 1].
            temperature: Softmax temperature, must be strictly positive.
            **kwargs: Keyword arguments forwarded to `TrainingArguments`.

        Raises:
            ValueError: If `alpha` is outside [0, 1] or `temperature` <= 0.
        """
        # Fail fast, before the parent constructor does any work: a zero
        # temperature would divide the logits by zero, and an alpha outside
        # [0, 1] would give one of the loss terms a negative weight.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"`alpha` must be in [0, 1], got {alpha!r}")
        if temperature <= 0:
            raise ValueError(f"`temperature` must be > 0, got {temperature!r}")

        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.temperature = temperature

In [26]:
class DistillationTrainer(Trainer):
    """`Trainer` whose loss mixes the student's own loss with a distillation term.

    The combined loss is
    ``alpha * student_loss + (1 - alpha) * T**2 * KL(teacher_T || student_T)``
    where ``teacher_T`` / ``student_T`` are the softmax distributions of the
    respective logits divided by the temperature ``T``. ``alpha`` and ``T`` are
    read from ``self.args.alpha`` and ``self.args.temperature``.
    """

    def __init__(self, *args, teacher_model=None, **kwargs):
        """Create the trainer for the student and attach the teacher model.

        Args:
            *args: Positional arguments forwarded to `Trainer`.
            teacher_model: Model whose logits the student is trained to match.
                Required; the `None` default is kept only so the signature
                stays keyword-compatible.
            **kwargs: Keyword arguments forwarded to `Trainer`.

        Raises:
            ValueError: If `teacher_model` is None.
        """
        if teacher_model is None:
            raise ValueError("DistillationTrainer requires a `teacher_model`.")

        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        # Put the teacher on the same device as the student so both can
        # consume the same input tensors.
        self._move_model_to_device(self.teacher, self.model.device)
        # The teacher is only used for inference.
        self.teacher.eval()

    def compute_loss(
        self, model, inputs, return_outputs: bool = False, num_items_in_batch=None
    ):
        """Compute the weighted sum of the student loss and the distillation loss.

        Args:
            model: The student model being trained.
            inputs: Mapping of keyword arguments for the model's forward pass;
                must include `labels` so the student returns a loss.
            return_outputs: If True, also return the student's outputs.
            num_items_in_batch: Accepted for compatibility with the `Trainer`
                API; not used by this implementation.

        Returns:
            The combined loss, or `(loss, student_outputs)` when
            `return_outputs` is True.

        Raises:
            ValueError: If the student returned no loss.
            AssertionError: If student and teacher logits differ in shape.
        """
        outputs_student = model(**inputs)
        student_loss = outputs_student.loss
        if student_loss is None:
            raise ValueError(
                "Student model returned no loss; `inputs` must contain `labels`."
            )

        # Only the teacher's logits are needed, so `labels` are not forwarded:
        # this avoids computing a teacher loss that is never used.
        teacher_inputs = {
            key: value for key, value in inputs.items() if key != "labels"
        }
        with torch.no_grad():
            outputs_teacher = self.teacher(**teacher_inputs)

        assert (
            outputs_student.logits.size() == outputs_teacher.logits.size()
        ), "Logits size of student and teacher should match, student: {}, teacher: {}".format(
            outputs_student.logits.size(), outputs_teacher.logits.size()
        )

        # Soften both distributions with the temperature and compare them.
        # `kl_div` takes log-probabilities as input and probabilities as
        # target; the T**2 factor rescales the softened term.
        temperature = self.args.temperature
        loss_logits = F.kl_div(
            F.log_softmax(outputs_student.logits / temperature, dim=-1),
            F.softmax(outputs_teacher.logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature**2)

        # Weighted student loss
        loss = self.args.alpha * student_loss + (1 - self.args.alpha) * loss_logits
        return (loss, outputs_student) if return_outputs else loss

### Hyperparameters Definition

In [None]:
# Label metadata read from the dataset's `labels` feature.
labels = tokenized_dataset["train"].features["labels"].names
# labels: ['negative', 'positive']
num_labels = len(labels)
id2label = dict(enumerate(labels))
label2id = {label: i for i, label in id2label.items()}

training_args = DistillationTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=7,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    fp16=True,
    learning_rate=6e-5,
    seed=33,
    # Logging and evaluation strategy
    logging_dir=output_dir / "logs",
    logging_strategy="epoch",
    # `eval_strategy` is the current name of the deprecated
    # `evaluation_strategy` argument (available since transformers 4.41).
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    # Requires matching eval/save strategies and a `compute_metrics` function
    # that reports "accuracy" when the trainer is built.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
    # Push to hub parameters
    push_to_hub=False,
    # hub_strategy="every_save",
    # hub_model_id=repo_name,
    # hub_token=HfFolder.get_token(),
    # Distillation parameters
    alpha=0.5,
    temperature=4.0,
)