In [None]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.experimental.gkd import GKDConfig, GKDTrainer

In [None]:
# ===== Models & tokenizer =====
# The tokenizer is taken from the teacher checkpoint. Both student and teacher
# are resized to its length so the two models share one vocabulary size
# (presumably so their logits line up for distillation — confirm if the
# checkpoints change).
TEACHER_MODEL_ID = "GreatGoose/gemma3-4b-it-lora-loglm"
STUDENT_MODEL_ID = "GreatGoose/gemma3-270m-full-loglm"

tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL_ID)
vocab_size = len(tokenizer)

# Student: the small model that GKD trains.
model = AutoModelForCausalLM.from_pretrained(STUDENT_MODEL_ID)
model.resize_token_embeddings(vocab_size)

# Teacher: the larger model whose outputs the student is distilled towards.
teacher_model = AutoModelForCausalLM.from_pretrained(TEACHER_MODEL_ID)
teacher_model.resize_token_embeddings(vocab_size)


# ===== Data loading =====
# A single JSONL file is loaded as the "train" split and then partitioned
# 80/20 into train/eval with a fixed seed, so the split is reproducible.
raw_data_path = "data/train_gemma3.jsonl"
ds = load_dataset("json", data_files={"train": raw_data_path})

split = ds["train"].train_test_split(test_size=0.2, seed=42)
train_dataset, eval_dataset = split["train"], split["test"]


# ===== ChatML conversion =====
# ShareGPT-style speaker tags ("from") -> chat-template roles ("role").
_SHAREGPT_ROLE_MAP = {
    "system": "system",
    "human": "user",
    "gpt": "assistant",
}


def convert_to_chatml(example):
    """Convert one ShareGPT-style example into role/content chat messages.

    Args:
        example: A dataset row with a ``"messages"`` list. Each item is a dict
            carrying ``"from"`` (speaker tag) and ``"value"`` (message text).

    Returns:
        ``{"messages": [{"role": ..., "content": ...}, ...]}`` with the
        messages in their original order.

    Raises:
        ValueError: If a message's ``"from"`` tag is not one of
            ``"system"``, ``"human"`` or ``"gpt"``.
    """
    new_messages = []
    for msg in example["messages"]:
        speaker = msg["from"]
        role = _SHAREGPT_ROLE_MAP.get(speaker)
        if role is None:
            raise ValueError(f"Unknown role: {speaker}")
        new_messages.append({"role": role, "content": msg["value"]})
    return {"messages": new_messages}


# ===== Quick-run subset + ChatML conversion =====
# Only a small slice of each split is used, to validate the pipeline quickly.
N_TRAIN_SAMPLES = 100
N_EVAL_SAMPLES = 10


def _subset_and_convert(dataset, max_samples):
    """Keep the first ``max_samples`` rows of ``dataset`` and convert them to ChatML.

    The subset is taken *before* ``convert_to_chatml`` runs, so only rows that
    are actually used get converted. ``Dataset.map`` preserves row order, so the
    rows are the same as converting the whole split and selecting afterwards.
    All original columns are removed; only the ``"messages"`` column produced
    by ``convert_to_chatml`` remains.
    """
    subset = dataset.select(range(min(max_samples, len(dataset))))
    return subset.map(convert_to_chatml, remove_columns=subset.column_names)


train_dataset = _subset_and_convert(train_dataset, N_TRAIN_SAMPLES)
eval_dataset = _subset_and_convert(eval_dataset, N_EVAL_SAMPLES)

In [None]:
# ===== Training arguments =====
training_args = GKDConfig(
    output_dir="gkd-quick-run",
    learning_rate=1e-5,
    num_train_epochs=1.0,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=10,
    # Without an explicit eval_strategy the default is "no": the eval_dataset
    # handed to the trainer (and the eval batch size above) would never be used.
    eval_strategy="epoch",
    logging_steps=1,
    # Writes a checkpoint under output_dir at the end of each epoch.
    save_strategy="epoch",
)

In [None]:
# ===== Trainer =====
# `model` is the student being trained; `teacher_model` supplies the outputs it
# is distilled towards. The teacher's tokenizer is used as the processing class.
trainer = GKDTrainer(
    args=training_args,
    model=model,
    teacher_model=teacher_model,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# ===== Run training =====
trainer.train()

In [None]:
# ===== Save all artifacts =====
# Reuse the configured output directory rather than repeating the literal, so
# the final save cannot drift away from where the epoch checkpoints are written.
final_output_dir = training_args.output_dir
trainer.save_model(final_output_dir)  # the trained student (`model`)
tokenizer.save_pretrained(final_output_dir)