In [1]:
from minilm import (
    MiniLMTrainer,
    MiniLMTrainingArguments,
    prepare_dataset,
    create_student,
)
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModel,
    DataCollatorWithPadding,
    DataCollatorForLanguageModeling,
    TrainingArguments,
)
from pathlib import Path
from datetime import datetime

## Dataset

In [2]:
class ModernBertDataCollator(DataCollatorForLanguageModeling):
    def __call__(self, batch):
        batch = super().__call__(batch)
        # Compute position_ids for unpadded sequences
        seq_lens = [torch.sum(item["attention_mask"]).item() for item in batch]
        position_ids = [torch.arange(length) for length in seq_lens]
        batch["position_ids"] = torch.cat(position_ids).to(batch["input_ids"].device)
        return batch

In [3]:
cache_dir = "../.cache"  # Optional
dataset_id = "bookcorpus/bookcorpus"
# model_name = "google-bert/bert-base-uncased"
model_name = "answerdotai/ModernBERT-base"

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

In [5]:
dataset = load_dataset(dataset_id, split="train", cache_dir=cache_dir)
dataset = dataset.select(range(min(len(dataset), 10_000)))  # Small dataset for testing

In [6]:
train_dataset = prepare_dataset(
    datasets=[dataset],
    tokenizer=tokenizer,
    max_seq_len=64,
    tokenization_kwargs={"padding": "do_not_pad"},
)

In [7]:
import random

random.seed(42)

val_dataset = dataset.select(
    random.sample(range(len(dataset)), 1_000)
)  # Small val dataset for testing

In [8]:
val_dataset = prepare_dataset(
    datasets=[val_dataset],
    tokenizer=tokenizer,
    max_seq_len=64,
    tokenization_kwargs={"padding": "do_not_pad"},
)

## Distillation Arguments

In [9]:
short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]

In [10]:
output_dir = Path("./results")
dt = datetime.now().strftime("%Y-%b-%d_%H-%M-%S")
output_dir = output_dir / f"{short_model_name}_{dt}"

In [None]:
TrainingArguments(
    eval_strategy="steps",
    loss
)

In [11]:
args = MiniLMTrainingArguments(
    # Distillation arguments
    teacher_layer=22,  # 12
    student_layer=12,
    student_hidden_size=384,
    student_attention_heads=12,
    num_relation_heads=48,
    relations={
        (1, 1): 1.0,
        (2, 2): 1.0,
        (3, 3): 1.0,
    },
    # Training arguments
    output_dir=output_dir,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=6e-4,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-6,
    max_steps=400_000,
    warmup_steps=4_000,
    logging_steps=10,  # 1_000,
    save_steps=500,  # 50_000,
    seed=42,
    ddp_find_unused_parameters=True,
    save_total_limit=5,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    prediction_loss_only=True,
    greater_is_better=False,
    save_strategy="steps",
    eval_strategy="steps",
    eval_steps=10,  # 50_000
)

## Models

In [12]:
teacher = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)

In [13]:
student = create_student(
    teacher_model_name_or_path=model_name,
    args=args,
    use_teacher_weights=False,
    cache_dir=cache_dir,
)

In [66]:
student_tw = create_student(
    teacher_model_name_or_path=model_name,
    args=args,
    use_teacher_weights=True,
    cache_dir=cache_dir,
)

## Trainer

In [67]:
trainer = MiniLMTrainer(
    args=args,
    teacher_model=teacher,
    model=student,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=DataCollatorWithPadding(tokenizer, padding="longest"),
)

In [17]:
trainer = MiniLMTrainer(
    args=args,
    teacher_model=teacher,
    model=student,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=ModernBertDataCollator(tokenizer),
)

Old bert

In [15]:
trainer.train()

Step,Training Loss,Validation Loss
10,1.3236,1.33541
20,1.3171,1.326876
30,1.3011,1.312937
40,1.3017,1.294551
50,1.2919,1.273274


KeyboardInterrupt: 

New ModernBert

In [16]:
trainer.train()

Step,Training Loss,Validation Loss
10,7.8946,7.886227
20,7.8309,7.883595
30,7.784,7.879015
40,7.867,7.872334
50,7.9569,7.863624
60,7.8124,7.851937
70,7.9173,7.836264
80,7.9475,7.814851


KeyboardInterrupt: 

In [18]:
trainer.train()

TypeError: string indices must be integers

---
Teacher Weights

In [17]:
trainer_tw = MiniLMTrainer(
    args=args,
    teacher_model=teacher,
    model=student_tw,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=DataCollatorWithPadding(tokenizer, padding="longest"),
)

Old BERT

In [None]:
Step	Training Loss	Validation Loss
10	1.213400	1.212632
20	1.204700	1.191857
30	1.176700	1.165211
40	1.168200	1.133592
50	1.149600	1.103302

In [17]:
trainer_tw.train()

Step,Training Loss,Validation Loss
10,1.26,1.268834
20,1.2543,1.254386
30,1.2319,1.233314
40,1.2261,1.207002
50,1.2074,1.178021


KeyboardInterrupt: 

ModernBERT

In [18]:
trainer_tw.train()

Step,Training Loss,Validation Loss
10,115.1899,112.757065
20,112.2192,107.928802
30,104.5463,100.267662
40,97.9708,90.428841
50,87.039,79.140587
60,72.1942,67.135345


KeyboardInterrupt: 

In [None]:
Old BERT Model trained on student with random weights:
Step	Training Loss	Validation Loss
10	1.328100	1.340491
20	1.318500	1.326254
30	1.296200	1.303429
40	1.289900	1.275088
50	1.272700	1.244863
60	1.226300	1.212218

Old Bert Model trained on student with teacher weights:
Step	Training Loss	Validation Loss
10	1.213400	1.212632
20	1.204700	1.191857
30	1.176700	1.165211
40	1.168200	1.133592
50	1.149600	1.103302

New BERT Model trained on student with random weights:
Step	Training Loss	Validation Loss
10	7.889100	7.881923
20	7.826600	7.878779
30	7.778000	7.873277
40	7.861000	7.865149
50	7.948200	7.854313

New BERT Model trained on student with teacher weights:
Step	Training Loss	Validation Loss
10	131.839500	126.133316
20	112.327400	103.819565
30	93.129700	74.777557
40	63.274800	52.100739
50	44.741300	36.296700
60	29.786600	24.159281

SyntaxError: invalid syntax (1345778915.py, line 1)