In [None]:
import copy
import random
from datetime import datetime
from pathlib import Path

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModel,
    DataCollatorWithPadding,
    TrainingArguments,
)

from minilm import (
    MiniLMTrainer,
    MiniLMTrainingArguments,
    prepare_dataset,
    create_student,
)

## Dataset

In [None]:
# Notebook-wide configuration: the teacher checkpoint, the text corpus used
# for distillation, and where downloaded artifacts are cached.
model_name = "google-bert/bert-base-uncased"
dataset_id = "bookcorpus/bookcorpus"
cache_dir = "../.cache"  # Optional

In [None]:
# Tokenizer of the teacher checkpoint (`model_name`); it is reused below for
# dataset preparation and for padding batches in the data collator.
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

In [None]:
# Load the training split, then keep at most the first 10_000 rows so the
# notebook stays quick to run (small dataset for testing).
dataset = load_dataset(dataset_id, split="train", cache_dir=cache_dir)
n_rows = min(10_000, len(dataset))
dataset = dataset.select(range(n_rows))

In [None]:
# Build the training set from the (sub-sampled) corpus with sequences limited
# to 64 tokens. Padding is disabled here ("do_not_pad"); batches are padded
# later by the DataCollatorWithPadding passed to the trainers.
# NOTE(review): presumably prepare_dataset tokenizes each row with `tokenizer`
# and forwards `tokenization_kwargs` to it — confirm against minilm.
train_dataset = prepare_dataset(
    datasets=[dataset],
    tokenizer=tokenizer,
    max_seq_len=64,
    tokenization_kwargs={"padding": "do_not_pad"},
)

In [None]:
# Draw a small, reproducible validation sample for testing.
# The sample size is capped at the dataset length: random.sample raises
# ValueError when asked for more items than the population holds, and the
# corpus subset selected earlier is only guaranteed to have
# min(len(dataset), 10_000) rows.
# NOTE(review): the indices are drawn from the same `dataset` that feeds
# `train_dataset`, so every validation row is also a training row — confirm
# this overlap is acceptable for a smoke test.
random.seed(42)

val_size = min(len(dataset), 1_000)
val_dataset = dataset.select(
    random.sample(range(len(dataset)), val_size)
)  # Small val dataset for testing

In [None]:
# Prepare the validation sample with the same settings as the training set
# (max 64 tokens, no padding at this stage).
# NOTE(review): this rebinds `val_dataset` from the raw sample to the prepared
# dataset, so re-running only this cell would feed already-prepared data back
# into prepare_dataset — re-run the sampling cell first.
val_dataset = prepare_dataset(
    datasets=[val_dataset],
    tokenizer=tokenizer,
    max_seq_len=64,
    tokenization_kwargs={"padding": "do_not_pad"},
)

## Distillation Arguments

In [None]:
# Keep only the part after the last "/" of the hub id (the whole id when it
# contains no "/"); used to name the results directory.
short_model_name = model_name.rsplit("/", 1)[-1]
short_model_name

In [None]:
# One results folder per run: ./results/<short model name>_<timestamp>.
dt = datetime.now().strftime("%Y-%b-%d_%H-%M-%S")
output_dir = Path("./results").joinpath(f"{short_model_name}_{dt}")

In [None]:
# NOTE(review): this looks like a leftover smoke check that the installed
# transformers version accepts `eval_strategy`; its result is not used by any
# later cell — consider deleting the cell.
# `output_dir` is passed explicitly because transformers releases in which it
# is still a required argument reject a call without it.
TrainingArguments(
    output_dir=str(output_dir),
    eval_strategy="steps",
)

In [None]:
# Distillation-specific settings: which teacher/student layers are paired,
# the student's size, and the weighted relation pairs passed as `relations`.
distillation_kwargs = dict(
    teacher_layer=12,
    student_layer=6,
    student_hidden_size=384,
    student_attention_heads=12,
    num_relation_heads=48,
    relations={
        (1, 1): 1.0,
        (2, 2): 1.0,
        (3, 3): 1.0,
    },
)

# Generic training settings. The logging/save/eval intervals are the small
# values used for testing; the trailing comments keep the alternative values
# noted in the original cell.
training_kwargs = dict(
    output_dir=output_dir,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=6e-4,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-6,
    max_steps=400_000,
    warmup_steps=4_000,
    logging_steps=10,  # 1_000,
    save_steps=500,  # 50_000,
    seed=42,
    ddp_find_unused_parameters=True,
    save_total_limit=5,
    # Best-checkpoint selection: lowest eval_loss wins.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    prediction_loss_only=True,
    greater_is_better=False,
    save_strategy="steps",
    eval_strategy="steps",
    eval_steps=10,  # 50_000
)

args = MiniLMTrainingArguments(**distillation_kwargs, **training_kwargs)

## Models

In [None]:
# Load the teacher checkpoint; it is handed to the trainers as `teacher_model`.
teacher = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)

In [None]:
# Student variant 1: built from the teacher checkpoint name and `args`, with
# use_teacher_weights=False.
# NOTE(review): presumably this means the student starts from freshly
# initialized weights — confirm against minilm.create_student.
student = create_student(
    teacher_model_name_or_path=model_name,
    args=args,
    use_teacher_weights=False,
    cache_dir=cache_dir,
)

In [None]:
# Student variant 2 ("tw" = teacher weights): same call as for `student`, but
# with use_teacher_weights=True.
# NOTE(review): presumably the student is initialized from (a subset of) the
# teacher's parameters — confirm against minilm.create_student.
student_tw = create_student(
    teacher_model_name_or_path=model_name,
    args=args,
    use_teacher_weights=True,
    cache_dir=cache_dir,
)

## Trainer

In [None]:
# The datasets were prepared without padding, so each batch is padded to its
# longest sequence at collation time.
collator = DataCollatorWithPadding(tokenizer, padding="longest")

# Trainer for the student created with use_teacher_weights=False.
trainer = MiniLMTrainer(
    args=args,
    teacher_model=teacher,
    model=student,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collator,
)

In [None]:
# Run distillation for the student created with use_teacher_weights=False.
# `args` sets max_steps=400_000, so this cell can run for a long time.
trainer.train()

---
## Trainer: student initialized from teacher weights

In [None]:
# Second run: distill into `student_tw` (created with use_teacher_weights=True).
# It gets its own copy of the training arguments with a separate output
# directory. Reusing `args` unchanged would make both runs write their
# `checkpoint-<step>` folders into the same directory, where the
# `save_total_limit=5` rotation deletes older checkpoints regardless of which
# run produced them.
# NOTE(review): assumes MiniLMTrainingArguments is a plain, mutable dataclass
# like transformers.TrainingArguments — confirm against minilm.
tw_output_dir = f"{args.output_dir}_teacher-weights"
args_tw = copy.copy(args)
args_tw.output_dir = tw_output_dir
if getattr(args_tw, "logging_dir", None):
    # Keep this run's logs apart from the first run's as well.
    args_tw.logging_dir = str(Path(tw_output_dir) / "runs")

trainer_tw = MiniLMTrainer(
    args=args_tw,
    teacher_model=teacher,
    model=student_tw,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=DataCollatorWithPadding(tokenizer, padding="longest"),
)

In [None]:
# Run distillation for the student created with use_teacher_weights=True.
trainer_tw.train()