# Install needed package.

In [1]:
# !pip install trl

# Imports and basic setups.

In [2]:
import math
import torch

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
    set_seed,
)
from trl.experimental.gkd import GKDConfig, GKDTrainer

  from trl.experimental.gkd import GKDConfig, GKDTrainer


# Config.

In [3]:
SEED = 57382
set_seed(SEED)

STUDENT_CKPT = "Qwen/Qwen2-0.5B-Instruct"
TEACHER_CKPT = "Qwen/Qwen2-1.5B-Instruct"

# Small subset for a full end-to-end smoke test
TRAIN_N = 200
EVAL_N = 50

MAX_LENGTH = 512

OUTDIR_BASELINE = "./baseline_ce_eval"
OUTDIR_GKD = "./gkd-model"
OUTDIR_DISTILLED = "./distilled_ce_eval"

# Load models and ..


In [4]:
tokenizer = AutoTokenizer.from_pretrained(STUDENT_CKPT)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

student_model = AutoModelForCausalLM.from_pretrained(
    STUDENT_CKPT,
    torch_dtype="auto",
    device_map="auto",
)
teacher_model = AutoModelForCausalLM.from_pretrained(
    TEACHER_CKPT,
    torch_dtype="auto",
    device_map="auto",
)

# Make teacher frozen & eval (saves memory + avoids accidental grads)
teacher_model.eval()
for p in teacher_model.parameters():
    p.requires_grad_(False)

# Disable KV cache during training/eval to reduce memory spikes
student_model.config.use_cache = False
teacher_model.config.use_cache = False


def _model_vocab_size(m):
    emb = m.get_input_embeddings()
    return None if emb is None else int(emb.num_embeddings)


# If vocab mismatch happens (rare for same-family models), fall back to teacher tokenizer for teacher CE eval.
# This keeps the CE definition the same, but tokenization may differ slightly.
teacher_tokenizer = tokenizer
if _model_vocab_size(teacher_model) is not None and _model_vocab_size(teacher_model) != len(tokenizer):
    teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_CKPT)
    if teacher_tokenizer.pad_token is None:
        teacher_tokenizer.pad_token = teacher_tokenizer.eos_token
    teacher_tokenizer.padding_side = "right"

`torch_dtype` is deprecated! Use `dtype` instead!


# Setup dataset.

In [5]:
raw = load_dataset("databricks/databricks-dolly-15k", split="train")

# Make a train/eval split (dataset itself is a single split)
split = raw.train_test_split(test_size=0.02, seed=SEED)
train_raw = split["train"].shuffle(seed=SEED).select(range(min(TRAIN_N, len(split["train"]))))
eval_raw = split["test"].shuffle(seed=SEED).select(range(min(EVAL_N, len(split["test"]))))


def dolly_to_messages(ex):
    """
    Convert Dolly record to ChatML-like messages:
      user: instruction (+ optional context)
      assistant: response
    """
    instruction = (ex.get("instruction") or "").strip()
    context = (ex.get("context") or "").strip()
    response = (ex.get("response") or "").strip()

    if context:
        user = f"Instruction:\n{instruction}\n\nContext:\n{context}"
    else:
        user = f"Instruction:\n{instruction}"

    return {
        "messages": [
            {"role": "user", "content": user},
            {"role": "assistant", "content": response},
        ]
    }


train_dataset = train_raw.map(dolly_to_messages, remove_columns=train_raw.column_names)
eval_dataset = eval_raw.map(dolly_to_messages, remove_columns=eval_raw.column_names)

In [6]:
def tokenize_for_ce_eval(example, tok, max_length=512):
    """
    Build labels so that ONLY assistant response tokens contribute to CE loss.
    Prompt tokens (user + assistant header) are masked with -100.
    """
    messages = example["messages"]

    full_text = tok.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    prompt_text = tok.apply_chat_template(
        messages[:-1],
        tokenize=False,
        add_generation_prompt=True,  # ends right before assistant content
    )

    full = tok(
        full_text,
        truncation=True,
        max_length=max_length,
        padding=False,
        add_special_tokens=True,
    )
    prompt = tok(
        prompt_text,
        truncation=True,
        max_length=max_length,
        padding=False,
        add_special_tokens=True,
    )

    input_ids = full["input_ids"]
    attention_mask = full["attention_mask"]

    prompt_len = min(len(prompt["input_ids"]), len(input_ids))
    labels = [-100] * prompt_len + input_ids[prompt_len:]

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


class DataCollatorForCausalLMEval:
    """Pads input_ids/attention_mask with tokenizer; pads labels with -100."""
    def __init__(self, tok):
        self.tok = tok
        self.padder = DataCollatorWithPadding(tokenizer=tok, padding=True)

    def __call__(self, features):
        labels = [f["labels"] for f in features]
        feats = [{k: v for k, v in f.items() if k != "labels"} for f in features]
        batch = self.padder(feats)

        max_len = batch["input_ids"].shape[1]
        padded_labels = []
        for lab in labels:
            padded_labels.append(lab + [-100] * (max_len - len(lab)))

        batch["labels"] = torch.tensor(padded_labels, dtype=torch.long)
        return batch


def build_tokenized_eval(ds, tok, max_length=512):
    return ds.map(
        lambda ex: tokenize_for_ce_eval(ex, tok, max_length=max_length),
        remove_columns=ds.column_names,
    )


def evaluate_ce_loss(model, tok, tokenized_eval_ds, out_dir, batch_size=4):
    """
    Returns:
      metrics dict with eval_loss, plus ppl (exp(loss)) for convenience.
    """
    args = TrainingArguments(
        output_dir=out_dir,
        per_device_eval_batch_size=batch_size,
        do_train=False,
        do_eval=True,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=args,
        eval_dataset=tokenized_eval_ds,
        tokenizer=tok,
        data_collator=DataCollatorForCausalLMEval(tok),
    )

    metrics = trainer.evaluate()
    loss = float(metrics["eval_loss"])
    ppl = math.exp(loss) if loss < 20 else float("inf")
    metrics["eval_ppl"] = ppl
    return metrics


# Pre-tokenize eval once (for student)
tokenized_eval_student = build_tokenized_eval(eval_dataset, tokenizer, max_length=MAX_LENGTH)

# Teacher may need its own tokenized eval if tokenizer fallback happened
tokenized_eval_teacher = tokenized_eval_student
if teacher_tokenizer is not tokenizer:
    tokenized_eval_teacher = build_tokenized_eval(eval_dataset, teacher_tokenizer, max_length=MAX_LENGTH)

# Baseline evaluation.

In [7]:
print("\n=== Baseline CE evaluation (answer-only, unified definition) ===")

teacher_metrics = evaluate_ce_loss(
    model=teacher_model,
    tok=teacher_tokenizer,
    tokenized_eval_ds=tokenized_eval_teacher,
    out_dir=f"{OUTDIR_BASELINE}/teacher",
    batch_size=4,
)
print(f"[Teacher] ce_eval_loss={teacher_metrics['eval_loss']:.4f} | ppl={teacher_metrics['eval_ppl']:.2f}")

student_metrics = evaluate_ce_loss(
    model=student_model,
    tok=tokenizer,
    tokenized_eval_ds=tokenized_eval_student,
    out_dir=f"{OUTDIR_BASELINE}/student",
    batch_size=4,
)
print(f"[Student] ce_eval_loss={student_metrics['eval_loss']:.4f} | ppl={student_metrics['eval_ppl']:.2f}")

  trainer = Trainer(
The model is already on multiple devices. Skipping the move to device specified in `args`.



=== Baseline CE evaluation (answer-only, unified definition) ===


  trainer = Trainer(
The model is already on multiple devices. Skipping the move to device specified in `args`.


[Teacher] ce_eval_loss=1.6615 | ppl=5.27


[Student] ce_eval_loss=1.8408 | ppl=6.30


# Run distillation with GKD.

In [8]:
print("\n=== Start GKD distillation (training logs show GKD loss; CE is evaluated separately) ===")

gkd_args = GKDConfig(
    output_dir=OUTDIR_GKD,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=3
)

gkd_trainer = GKDTrainer(
    model=student_model,
    teacher_model=teacher_model,
    args=gkd_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # kept so Trainer can run its own eval if you enable it later
)

gkd_trainer.train()

The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.



=== Start GKD distillation (training logs show GKD loss; CE is evaluated separately) ===


[34m[1mwandb[0m: Currently logged in as: [33mgreatgoose[0m ([33mgreatgoose-none[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`generation_config` default values have been modified to match model-specific defaults: {'top_p': 0.8, 'repetition_penalty': 1.1}. If this is not desired, please set these values explicitly.


Step,Training Loss
10,0.616
20,0.6438
30,0.6757
40,0.5346
50,0.5357
60,0.4667
70,0.5277


TrainOutput(global_step=75, training_loss=0.5681931972503662, metrics={'train_runtime': 950.3154, 'train_samples_per_second': 0.631, 'train_steps_per_second': 0.079, 'total_flos': 234086839134720.0, 'train_loss': 0.5681931972503662, 'epoch': 3.0})

In [9]:
print("\n=== Post-distillation CE evaluation (answer-only, same metric as baseline) ===")

distilled_metrics = evaluate_ce_loss(
    model=student_model,  # updated by distillation
    tok=tokenizer,
    tokenized_eval_ds=tokenized_eval_student,  # same eval set, same tokenization
    out_dir=OUTDIR_DISTILLED,
    batch_size=4,
)
print(f"[Distilled Student] ce_eval_loss={distilled_metrics['eval_loss']:.4f} | ppl={distilled_metrics['eval_ppl']:.2f}")

  trainer = Trainer(
The model is already on multiple devices. Skipping the move to device specified in `args`.



=== Post-distillation CE evaluation (answer-only, same metric as baseline) ===


[Distilled Student] ce_eval_loss=1.8608 | ppl=6.43
