In [None]:
# Environment variables that configure CUDA are safest to set before torch /
# transformers are imported: CUDA device enumeration and the CUDA caching
# allocator pick them up when CUDA is initialised, so `os` is imported and the
# variables are set first, and the heavy libraries are imported afterwards.
import os

# Reduce fragmentation when reserved-but-unallocated memory grows large.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:64"

# (optional) avoid the fork/threads warning and nested parallelism
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Uncomment to make only 1 GPU visible to torch/accelerate/transformers/trl.
# NOTE: the single visible GPU is then renumbered to cuda:0, so any explicit
# "cuda:1" device placement later in the notebook must be changed to "cuda:0".
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
from accelerate import Accelerator
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GKDConfig, GKDTrainer, setup_chat_format

# Allow TF32 tensor cores for float32 matmuls (faster, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

In [None]:
NUM_DUMMY_SAMPLES = 100

# Single placement used for both models. If CUDA_VISIBLE_DEVICES="1" is enabled
# in the setup cell, the only visible GPU becomes index 0 and this must be
# changed to "cuda:0".
# NOTE(review): the Trainer / Accelerate may move models to `training_args.device`
# when training starts; confirm the effective placement if it matters.
DEVICE = "cuda:1"

model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
teacher_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"

# The student's tokenizer is shared with the teacher; this assumes both models
# use the same vocabulary (true within the SmolLM2 family — re-check if swapping).
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# The model to optimise
student = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)

# The teacher model to calculate the KL divergence against
teacher = AutoModelForCausalLM.from_pretrained(
    teacher_id,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)

if not getattr(tokenizer, "chat_template", None):
    # setup_chat_format adds special tokens and resizes only the student's
    # embeddings; resize the teacher to the same size so that student and
    # teacher logits share the same vocabulary dimension in the distillation loss.
    student, tokenizer = setup_chat_format(student, tokenizer)
    teacher.resize_token_embeddings(student.get_input_embeddings().num_embeddings)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
def _single_turn_chat_dataset(user_text, assistant_text, num_samples):
    """Build a Dataset holding `num_samples` copies of one user/assistant exchange.

    Each row has a single "messages" column containing a list of chat turns
    ({"role": ..., "content": ...} dicts), which is the conversational format
    handed to GKDTrainer further down.
    """
    conversation = [
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": assistant_text},
    ]
    return Dataset.from_dict({"messages": [conversation] * num_samples})


train_dataset = _single_turn_chat_dataset(
    "Hi, how are you?", "I'm great thanks", NUM_DUMMY_SAMPLES
)
eval_dataset = _single_turn_chat_dataset(
    "What colour is the sky?", "The sky is blue", NUM_DUMMY_SAMPLES
)

In [None]:
training_args = GKDConfig(
    output_dir="gkd-smol",
    per_device_train_batch_size=5,
    gradient_accumulation_steps=5,
    learning_rate=5e-4,
    num_train_epochs=10,
    logging_steps=5,
    # Evaluate once per epoch. No `eval_steps` is passed: it would be ignored
    # under the "epoch" strategy.
    eval_strategy="epoch",
    save_steps=50,
    # lmbda = probability of on-policy (student-generated) outputs per step.
    # With lmbda=1 the student regenerates every batch, overwriting any
    # teacher-generated sequences, so seq_kd=True would only add a discarded
    # teacher generate() call each step. Lower lmbda (< 1) and set seq_kd=True
    # if sequence-level distillation on teacher outputs is actually wanted.
    seq_kd=False,
    lmbda=1,
)

In [None]:
# Distil the teacher into the student using the config and datasets above;
# the value returned by trainer.train() is displayed as the cell output.
trainer = GKDTrainer(
    model=student,
    teacher_model=teacher,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

train_output = trainer.train()
train_output