# 03c · CRAFT with TRL ORPOTrainer

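This notebook adapts ORPO preference optimisation to include the CRAFT contrastive objective.
Conversations are formatted with `tokenizer.apply_chat_template`, which enables assistant-only
loss masking via `return_assistant_tokens_mask`. The mask is only populated when the tokenizer's
chat template marks assistant turns with `{% generation %}` blocks.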


## 0. Optional environment setup


In [None]:
# !pip install -U "contrastive-ft[trl] @ git+https://github.com/omarkamali/craft"
# !pip install -U "datasets>=2.19" "transformers>=4.43" "trl>=0.9"

## 1. Imports


In [None]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from craft.config import CRAFTORPOConfig
from craft.data import CRAFTCollator, make_craft_datasets
from craft.trainers import CRAFTORPOTrainer


## 2. Tokeniser and helper functions


In [None]:
# Fixed sequence length: the encoders below pad and truncate every example to
# exactly this many tokens.
MAX_LENGTH = 768


# Tokeniser of the same checkpoint that is loaded as the model further down.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
# Pad with the EOS token; the encoders below request padding="max_length".
tokenizer.pad_token = tokenizer.eos_token


def to_messages(prompt: str, response: str):
    """Return a two-turn chat: the user ``prompt`` followed by the assistant ``response``.

    A falsy ``prompt`` (``None`` or ``""``) is replaced by the empty string;
    ``response`` is passed through unchanged.
    """
    user_turn = {"role": "user", "content": prompt if prompt else ""}
    assistant_turn = {"role": "assistant", "content": response}
    return [user_turn, assistant_turn]


def encode_response(prompt: str, response: str):
    """Tokenise one prompt/response exchange with the tokenizer's chat template.

    The conversation built by ``to_messages`` is rendered through
    ``tokenizer.apply_chat_template``, padded/truncated to ``MAX_LENGTH``
    tokens, and returned as plain Python lists.

    Args:
        prompt: Content of the user turn.
        response: Content of the assistant turn.

    Returns:
        A 4-tuple ``(input_ids, attention_mask, labels, assistant_mask)`` of
        lists. ``labels`` is a copy of ``input_ids`` in which every position
        whose assistant mask is 0 is replaced by -100.

    Warns:
        UserWarning: if ``response`` is non-empty but no token is flagged as
            assistant-generated, since every label is then -100.
    """
    import warnings

    encoded = tokenizer.apply_chat_template(
        to_messages(prompt, response),
        tokenize=True,
        add_generation_prompt=False,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
        return_dict=True,
        return_assistant_tokens_mask=True,
    )
    input_ids = encoded["input_ids"][0]
    attention_mask = encoded["attention_mask"][0]
    # NOTE(review): some transformers releases appear to return
    # ``assistant_masks`` as a plain (possibly unbatched) list even with
    # return_tensors="pt" -- confirm. Normalising to a flat tensor makes both
    # layouts work and is a no-op for a single-row tensor.
    assistant_mask = torch.as_tensor(encoded["assistant_masks"]).reshape(-1)

    if response and not bool(assistant_mask.any()):
        # Without this, the sample silently contributes no supervised tokens.
        warnings.warn(
            "No assistant tokens were marked for a non-empty response; all labels "
            "are -100. The chat template may lack {% generation %} markers, or the "
            "response may have been truncated away.",
            stacklevel=2,
        )

    # Only assistant tokens keep their id as a label; everything else is -100.
    labels = input_ids.clone().masked_fill(assistant_mask == 0, -100)
    return (
        input_ids.tolist(),
        attention_mask.tolist(),
        labels.tolist(),
        assistant_mask.tolist(),
    )


def encode_pref(example):
    """Encode one preference record into prompt / chosen / rejected features.

    Reads ``chosen_response`` and ``rejected_response`` (required) and
    ``context`` (optional, defaults to ``""``) from ``example``. Each response
    is encoded together with the context via ``encode_response``; the prompt
    features come from encoding the context with an empty response.

    Returns:
        A dict with ``prompt_*`` ids and attention mask, plus ids, attention
        mask, labels and assistant mask for both ``chosen_*`` and ``rejected_*``.
    """
    context = example.get("context", "")

    per_side = {
        side: encode_response(context, example[f"{side}_response"])
        for side in ("chosen", "rejected")
    }
    # Only the ids and attention mask of the prompt-only encoding are kept.
    prompt_ids, prompt_attn, *_ = encode_response(context, "")

    features = {
        "prompt_input_ids": prompt_ids,
        "prompt_attention_mask": prompt_attn,
    }
    for side, (ids, attn, labels, mask) in per_side.items():
        features[f"{side}_input_ids"] = ids
        features[f"{side}_attention_mask"] = attn
        features[f"{side}_labels"] = labels
        features[f"{side}_assistant_mask"] = mask
    return features


def encode_contrastive(example):
    """Tokenise a premise/hypothesis pair into anchor and target features.

    Both texts are padded and truncated to ``MAX_LENGTH`` tokens. The premise
    fills ``input_ids`` / ``attention_mask``; the hypothesis fills the
    ``*_tgt`` counterparts. All values are plain Python lists.
    """

    def _fixed_length(text):
        # Tokenise a single string and unwrap the batch dimension of size one.
        batch = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        return batch["input_ids"][0].tolist(), batch["attention_mask"][0].tolist()

    anchor_ids, anchor_attn = _fixed_length(example["premise"])
    target_ids, target_attn = _fixed_length(example["hypothesis"])
    return {
        "input_ids": anchor_ids,
        "attention_mask": anchor_attn,
        "input_ids_tgt": target_ids,
        "attention_mask_tgt": target_attn,
    }


## 3. Load datasets


In [None]:
# `datasets` split slicing only accepts integer bounds ("train[:1%]",
# "train[:300]"); a fractional percentage such as "train[:0.2%]" is rejected.
# Small absolute row counts (roughly 0.2% of each train split) are used instead.
pref_raw = load_dataset("Anthropic/hh-rlhf", split="train[:320]")
# Rename to the column names that `encode_pref` reads.
pref_raw = pref_raw.rename_columns({"chosen": "chosen_response", "rejected": "rejected_response"})
tokenized_pref = pref_raw.map(encode_pref, remove_columns=pref_raw.column_names)

# all-nli ships several configurations and needs an explicit one; "pair-class"
# is the one exposing the `premise` / `hypothesis` columns used by
# `encode_contrastive`.
# NOTE(review): "pair-class" also contains non-entailment pairs -- consider
# filtering on its `label` column if only true positives are wanted (verify the
# label mapping on the dataset card first).
contrastive_raw = load_dataset(
    "sentence-transformers/all-nli", "pair-class", split="train[:1900]"
)
tokenized_contrastive = contrastive_raw.map(
    encode_contrastive, remove_columns=contrastive_raw.column_names
)


## 4. Bundle & collator


In [None]:
# Combine the tokenised preference data with the tokenised contrastive pairs.
# NOTE(review): the meaning of strategy="paired_dataset" is defined by the
# craft package -- check its documentation.
bundle = make_craft_datasets(
    tokenized_pref,
    contrastive_dataset=tokenized_contrastive,
    strategy="paired_dataset",
)
# Collator with default settings; passed to the trainer as `data_collator`.
collator = CRAFTCollator()


## 5. Load model


In [None]:
# Load the causal LM from the same checkpoint as the tokeniser.
# NOTE(review): no dtype or device placement is requested here, so the library
# defaults apply -- confirm the memory footprint is acceptable for a 7B model.
model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
# Turn off the generation key/value cache on the model config before training.
model.config.use_cache = False


## 6. Configure trainer


In [None]:
# Training hyper-parameters. A per-device batch of 2 with 12 accumulation steps
# means each optimiser step covers 24 preference examples per device.
# NOTE(review): the craft_* options are specific to CRAFTORPOConfig; their
# exact semantics are defined by the craft package -- check its documentation.
training_args = CRAFTORPOConfig(
    output_dir="./outputs/craft-trl-orpo",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=12,
    num_train_epochs=1,
    learning_rate=8e-6,
    logging_steps=5,
    save_steps=50,
    craft_alpha=0.5,
    craft_beta=0.4,
    craft_beta_mode="auto",
    craft_length_strategy="oversample",
    craft_pooling="mean",
)

# NOTE(review): no tokenizer / processing class is passed to the trainer.
# TRL's ORPOTrainer normally expects one -- confirm CRAFTORPOTrainer does not
# require it for this pre-tokenised dataset.
trainer = CRAFTORPOTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_pref,
    data_collator=collator,
    craft_bundle=bundle,
)


## 7. Train


In [None]:
# Run training with the arguments configured above.
trainer.train()


## 8. Inspect metrics


In [None]:
# Last expression of the cell, so the notebook displays the trainer's
# `craft_metrics` attribute.
trainer.craft_metrics


## 9. Save outputs


In [None]:
# Write the trained model and the tokeniser side by side in one directory.
export_dir = "./outputs/craft-trl-orpo"
trainer.save_model(export_dir)
tokenizer.save_pretrained(export_dir)
