# 02 · CRAFT Trainer Best Practices

This notebook expands on the basic workflow with:

1. Chat-template tokenisation with assistant-only masks.
2. Self-alignment (no external positives) while reusing assistant masks.
3. Parameter-efficient fine-tuning via QLoRA.
4. Custom `craft_beta` ratios and mixed-precision hints.

> ⚠️ Adapt dataset subsampling, precision flags, and LoRA ranks for your hardware.


## 0. Optional environment setup


In [None]:
# %pip install -U "contrastive-ft[trl,peft] @ git+https://github.com/omarkamali/craft"
# %pip install -U "datasets>=2.19" "transformers>=4.43" "trl>=0.9" "accelerate>=0.30" "bitsandbytes>=0.41"

## 1. Imports


In [None]:
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from craft.config import CRAFTSFTConfig
from craft.data import CRAFTCollator, make_craft_datasets
from craft.trainers import CRAFTSFTTrainer


## 2. Tokeniser helpers with chat templates


In [None]:
# Sequence length (in tokens) that both encoders below pad and truncate to.
MAX_LENGTH = 1024
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
# Pad with the EOS token; both encoders below request `padding="max_length"`.
tokenizer.pad_token = tokenizer.eos_token

def encode_chat(example):
    """Tokenise one chat example and build assistant-only training labels.

    The conversation in ``example["messages"]`` is rendered through the
    tokenizer's chat template, padded/truncated to ``MAX_LENGTH`` tokens, and
    returned as plain Python lists so it can be stored by ``Dataset.map``.

    Args:
        example: Mapping with a ``"messages"`` entry holding the conversation
            in the format accepted by ``tokenizer.apply_chat_template``.

    Returns:
        dict with five equally long lists of ints:
        ``input_ids``, ``attention_mask``, ``labels`` (a copy of ``input_ids``
        with every position outside the assistant mask set to ``-100``),
        ``assistant_mask`` and ``attention_mask_tgt`` (both the assistant mask).
    """
    # No `return_tensors`: for a single conversation the tokenizer then hands
    # back plain lists, which avoids a tensor -> list round trip and sidesteps
    # the fact that, with `return_tensors="pt"`, some transformers releases
    # return `assistant_masks` as a flat list rather than a (1, seq) tensor.
    #
    # NOTE(review): `return_assistant_tokens_mask` relies on the chat template
    # marking assistant turns with `{% generation %}` blocks. Confirm the chosen
    # tokenizer's template does so; otherwise the mask is all zeros and every
    # label ends up as -100.
    encoded = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=True,
        add_generation_prompt=False,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_dict=True,
        return_assistant_tokens_mask=True,
    )

    def _as_flat_list(values):
        # Accept tensors/arrays, flat lists, or a singleton batch of either.
        if hasattr(values, "tolist"):
            values = values.tolist()
        values = list(values)
        if values and isinstance(values[0], (list, tuple)):
            values = list(values[0])
        return values

    input_ids = _as_flat_list(encoded["input_ids"])
    attention_mask = _as_flat_list(encoded["attention_mask"])
    assistant_mask = _as_flat_list(encoded["assistant_masks"])

    # Keep the token id where the assistant mask is set; ignore (-100) elsewhere.
    labels = [
        token_id if is_assistant else -100
        for token_id, is_assistant in zip(input_ids, assistant_mask)
    ]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "assistant_mask": assistant_mask,
        "attention_mask_tgt": list(assistant_mask),
    }

def encode_contrastive(example):
    """Tokenise an anchor/positive text pair for contrastive training.

    ``example["premise"]`` is encoded as the anchor and
    ``example["hypothesis"]`` as the positive; both are padded and truncated
    to ``MAX_LENGTH`` tokens.

    Returns:
        dict with ``input_ids`` / ``attention_mask`` for the anchor and
        ``input_ids_tgt`` / ``attention_mask_tgt`` for the positive, each
        converted to a plain list.
    """

    def _tokenize(text):
        # Identical settings for both sides of the pair.
        return tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )

    anchor_enc = _tokenize(example["premise"])
    positive_enc = _tokenize(example["hypothesis"])

    # Each field comes back with a leading batch axis of one; take row 0.
    record = {}
    record["input_ids"] = anchor_enc["input_ids"][0].tolist()
    record["attention_mask"] = anchor_enc["attention_mask"][0].tolist()
    record["input_ids_tgt"] = positive_enc["input_ids"][0].tolist()
    record["attention_mask_tgt"] = positive_enc["attention_mask"][0].tolist()
    return record


## 3. Load and preprocess datasets


In [None]:
# Small slices keep the demo quick. The slices use absolute row counts because
# percent slicing only accepts whole-number percentages, and the split names
# are the ones the datasets publish (`train_sft` for UltraChat 200k).
sft_raw = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:600]")
# all-nli exposes several configurations; "pair-class" is the one carrying the
# `premise` / `hypothesis` columns that `encode_contrastive` reads.
contrastive_raw = load_dataset(
    "sentence-transformers/all-nli", "pair-class", split="train[:2800]"
)

# Drop the raw columns so only the tokenised fields remain.
tokenized_sft = sft_raw.map(encode_chat, remove_columns=sft_raw.column_names)
# NOTE: the self-align bundle built in the next section is given only
# `tokenized_sft`; this paired dataset is prepared but not passed to it.
tokenized_contrastive = contrastive_raw.map(
    encode_contrastive, remove_columns=contrastive_raw.column_names
)


## 4. Bundle + collator (self-align friendly)


In [None]:
# Build the CRAFT dataset bundle from the tokenised chat data alone, using the
# "self_align" strategy; `tokenized_contrastive` is not passed in here.
bundle = make_craft_datasets(tokenized_sft, strategy="self_align")
# Collator constructed with its default arguments.
collator = CRAFTCollator()


## 5. Prepare LoRA-wrapped model


In [None]:
# QLoRA-style 4-bit loading: NF4 weights with double quantisation, computing in
# bfloat16 to match `bf16=True` in the training arguments. Passing an explicit
# BitsAndBytesConfig replaces the deprecated bare `load_in_4bit=True` kwarg.
quant_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    quantization_config=quant_cfg,
)
# Ready the quantised model for training (PEFT's k-bit preparation step) before
# the adapters are attached.
base_model = prepare_model_for_kbit_training(base_model)

# Low-rank adapters on the attention query/value projections only.
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(base_model, lora_cfg)
# The generation KV cache is not needed while training.
model.config.use_cache = False


## 6. Configure trainer


In [None]:
# Hyperparameters for the CRAFT SFT run, grouped by concern.
training_args = CRAFTSFTConfig(
    # --- output and bookkeeping ---
    output_dir="./outputs/craft-best-practices",
    logging_steps=5,
    save_steps=50,
    # --- optimisation: batch of 1 per device, accumulated over 16 steps ---
    num_train_epochs=1,
    learning_rate=1e-4,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    # --- mixed precision ---
    bf16=True,
    # --- CRAFT-specific settings ---
    craft_alpha=0.7,
    craft_beta=0.4,
    # NOTE(review): "cls" presumably pools the first token; confirm that this
    # suits a causal decoder, where the first position cannot see later tokens.
    craft_pooling="cls",
    craft_assistant_mask_strategy="auto",
    craft_length_strategy="oversample",
)

# Wire the LoRA-wrapped model, the config, the tokenised chat data and the
# CRAFT bundle into the trainer.
trainer = CRAFTSFTTrainer(
    args=training_args,
    model=model,
    craft_bundle=bundle,
    train_dataset=tokenized_sft,
    data_collator=collator,
)


## 7. Train


In [None]:
# Run training; as the cell's last expression, the value returned by train()
# is displayed as the cell output.
trainer.train()


## 8. Inspect metrics


In [None]:
# Show the most recent entries (at most five) from the trainer's log history.
trainer.state.log_history[-5:]


## 9. Save adapters


In [None]:
# Same directory the trainer was configured to write into.
adapter_dir = "./outputs/craft-best-practices"
# Persist the PEFT-wrapped model and the tokenizer side by side.
model.save_pretrained(adapter_dir)
tokenizer.save_pretrained(adapter_dir)
