# 03b · CRAFT with TRL SFTTrainer

This notebook mirrors the TRL SFT workflow while adding the CRAFT objective.
It now formats conversations with `tokenizer.apply_chat_template` so only assistant
tokens contribute to the supervised loss, and it highlights the new length-handling
options for mixed SFT / contrastive dataloaders.


## 0. Optional environment setup


In [None]:
# !pip install -U "contrastive-ft[trl] @ git+https://github.com/omarkamali/craft"
# !pip install -U "datasets>=2.19" "transformers>=4.43" "trl>=0.9"

## 1. Imports


In [None]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from craft.config import CRAFTSFTConfig
from craft.data import CRAFTCollator, make_craft_datasets
from craft.trainers import CRAFTSFTTrainer


## 2. Tokenizer helpers with chat template


In [None]:
# Token budget per conversation: encode_chat truncates and pads every example to this length.
MAX_LENGTH = 1024
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
# Use the end-of-sequence token as the pad token so the fixed-length padding in
# encode_chat has a pad id to fill with.
tokenizer.pad_token = tokenizer.eos_token

def encode_chat(example):
    """Tokenise one conversation and build labels that supervise only assistant tokens.

    The conversation is rendered with the tokenizer's chat template, then
    truncated / padded to ``MAX_LENGTH`` tokens.

    Args:
        example: Dataset row whose ``"messages"`` entry is a chat transcript
            accepted by ``tokenizer.apply_chat_template``.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` as flat lists
        of ints. ``labels`` repeats ``input_ids`` at positions flagged as
        assistant tokens and holds ``-100`` (the ignore index) everywhere else,
        including padding.

    NOTE(review): the assistant mask is only populated when the chat template
    marks assistant turns with ``{% generation %}`` blocks; otherwise it is all
    zeros and every label becomes ``-100``. Confirm that the checkpoint's
    template provides those markers.
    """
    # Ask for plain Python lists rather than tensors: the result is stored as
    # lists anyway, and some transformers releases leave `assistant_masks` as a
    # list even when `return_tensors` is set, which breaks tensor-based masking.
    encoded = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=True,
        add_generation_prompt=False,
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
        return_dict=True,
        return_assistant_tokens_mask=True,
    )
    input_ids = list(encoded["input_ids"])
    attention_mask = list(encoded["attention_mask"])
    assistant_mask = encoded["assistant_masks"]

    # Keep a token as a training target only if it is an assistant token AND a
    # real (non-padding) position; everything else is ignored by the loss.
    labels = [
        token_id if is_assistant and is_real else -100
        for token_id, is_assistant, is_real in zip(input_ids, assistant_mask, attention_mask)
    ]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


## 3. Load conversational slices


In [None]:
# ultrachat_200k publishes train_sft / test_sft / train_gen / test_gen splits (there is
# no plain "train"), and split slicing only accepts whole-number percentages or absolute
# row indices. Take a fixed row count instead of "[:0.3%]" -- 600 rows is roughly 0.3%
# of the train_sft split.
sft_raw = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:600]")
# Tokenise every conversation and drop the raw columns so only the encoded fields remain.
tokenized_sft = sft_raw.map(encode_chat, remove_columns=sft_raw.column_names)


## 4. Bundle & collator (self-align)


In [None]:
# Build the CRAFT dataset bundle from the tokenised SFT data with the "self_align" strategy.
# NOTE(review): what the bundle contains is defined by craft.data.make_craft_datasets -- confirm there.
bundle = make_craft_datasets(tokenized_sft, strategy="self_align")
# Collator created with its default settings; it is passed to the trainer as `data_collator`.
collator = CRAFTCollator()


## 5. Load base model


In [None]:
# Load the causal LM from the same checkpoint name the tokenizer was loaded from.
model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
# Switch off the generation cache on the model config; it is only useful for decoding.
model.config.use_cache = False


## 6. Trainer configuration


In [None]:
# Run configuration: standard training arguments first, then the CRAFT-specific options.
training_args = CRAFTSFTConfig(
    output_dir="./outputs/craft-trl-sft",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=12,  # 2 x 12 = 24 sequences per optimizer step on each device
    learning_rate=1.5e-5,
    num_train_epochs=1,
    logging_steps=5,
    save_steps=50,
    # CRAFT options. NOTE(review): their exact semantics are defined by
    # craft.config.CRAFTSFTConfig, which is not visible here -- confirm there.
    craft_alpha=0.65,
    craft_beta=0.5,
    craft_beta_mode="auto",
    craft_length_strategy="auto_beta",
    craft_pooling="mean",
    craft_report_metrics=["contrastive_accuracy", "representation_consistency"],
)

# Wire the model, the tokenised SFT data, the collator and the CRAFT bundle into one trainer.
trainer = CRAFTSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_sft,
    data_collator=collator,
    craft_bundle=bundle,
)


## 7. Train


In [None]:
# Start the training run using the trainer configured above.
trainer.train()


## 8. Inspect metrics


In [None]:
# Last expression in the cell, so the notebook displays whatever the trainer exposes as `craft_metrics`.
trainer.craft_metrics


## 9. Save weights & tokenizer


In [None]:
# Save the model and the tokenizer side by side in one directory (the same path that
# was given as `output_dir` in the training configuration).
trainer.save_model("./outputs/craft-trl-sft")
tokenizer.save_pretrained("./outputs/craft-trl-sft")
