# 01 · Basic CRAFT SFT Trainer

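This notebook demonstrates a minimal fine-tuning pipeline using `CRAFTSFTTrainer`.
It formats conversations with `tokenizer.apply_chat_template` and derives
assistant-only loss masks via `return_assistant_tokens_mask=True`. That option
only works with chat templates that wrap assistant turns in
`{% generation %}` … `{% endgeneration %}`. With any other template the mask is
all zeros, so no tokens would be supervised.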


## 0. Optional setup


In [None]:
# !pip install -U "contrastive-ft @ git+https://github.com/omarkamali/craft"
# !pip install -U "datasets>=2.19" "transformers>=4.43" "trl>=0.9"

## 1. Imports


In [None]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from craft.config import CRAFTSFTConfig
from craft.data import CRAFTCollator, make_craft_datasets
from craft.trainers import CRAFTSFTTrainer


## 2. Load tiny demo splits


In [None]:
# ultrachat_200k exposes "train_sft", "test_sft", "train_gen" and "test_gen".
# There is no plain "train" split, so the SFT conversations come from "train_sft".
sft_dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:0.1%]")

# The (premise, hypothesis) columns belong to the "pair-class" subset of all-nli.
# Per the dataset card, it labels each pair as entailment (0), neutral (1) or
# contradiction (2). Only entailed pairs are genuine positives for the contrastive
# objective, so neutral and contradictory pairs are dropped here.
contrastive_dataset = load_dataset(
    "sentence-transformers/all-nli", "pair-class", split="train[:0.1%]"
).filter(lambda example: example["label"] == 0)


## 3. Tokenizer & chat templating helpers


In [None]:
import re

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
# Reuse EOS as the padding token so fixed-length padding always has a pad id.
tokenizer.pad_token = tokenizer.eos_token
# Token budget for every tokenized sequence (SFT and contrastive).
MAX_LENGTH = 512

# `apply_chat` below requests `return_assistant_tokens_mask=True`. transformers
# only supports that for chat templates that wrap assistant turns in
# `{% generation %}` ... `{% endgeneration %}`; for any other template it
# returns an all-zero mask. Every label would then be set to -100 and the SFT
# loss would have no supervised tokens. Fail fast instead of silently training
# on nothing. The regex mirrors the check transformers itself performs.
_GENERATION_TAG = re.compile(r"\{%-?\s*generation\s*-?%\}")
_chat_templates = tokenizer.chat_template
if isinstance(_chat_templates, dict):
    # Some tokenizers store several named templates; accept if any is marked up.
    _chat_templates = "\n".join(_chat_templates.values())
if not _chat_templates or not _GENERATION_TAG.search(_chat_templates):
    raise ValueError(
        "The tokenizer's chat template has no {% generation %} block, so "
        "return_assistant_tokens_mask=True would produce an all-zero assistant "
        "mask and every label would be masked out. Use a tokenizer whose chat "
        "template marks assistant turns with {% generation %} ... "
        "{% endgeneration %}, or assign such a template to tokenizer.chat_template."
    )


def apply_chat(example):
    """Tokenize one chat example and build assistant-only SFT labels.

    Renders ``example["messages"]`` with the tokenizer's chat template, pads or
    truncates it to ``MAX_LENGTH`` tokens, and copies a token id into ``labels``
    only where the template's assistant mask is set. Every other position,
    including padding, gets -100 so the loss ignores it.

    The assistant mask is only meaningful if the chat template wraps assistant
    turns in ``{% generation %}`` ... ``{% endgeneration %}``. Otherwise it is
    all zeros and every label becomes -100.

    Args:
        example: A dataset row whose ``"messages"`` entry is a list of chat
            turns in the format ``apply_chat_template`` accepts.

    Returns:
        A dict of equal-length int lists: ``input_ids``, ``attention_mask``,
        ``labels`` and ``assistant_mask``.
    """
    encoded = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=True,
        add_generation_prompt=False,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_dict=True,
        return_assistant_tokens_mask=True,
    )
    # With no return_tensors, a single (unbatched) conversation comes back as
    # flat Python lists, which is exactly what datasets.map stores. There is no
    # need to round-trip through torch tensors.
    input_ids = list(encoded["input_ids"])
    attention_mask = list(encoded["attention_mask"])
    assistant_mask = list(encoded["assistant_masks"])
    labels = [
        token_id if is_assistant else -100
        for token_id, is_assistant in zip(input_ids, assistant_mask)
    ]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "assistant_mask": assistant_mask,
    }


def tokenize_contrastive(example):
    """Tokenize an NLI row as an (anchor, target) pair for the contrastive loss.

    The premise becomes the anchor (``input_ids`` / ``attention_mask``). The
    hypothesis becomes the target (``input_ids_tgt`` / ``attention_mask_tgt``).
    Both are padded and truncated to ``MAX_LENGTH`` tokens. The hypothesis is
    treated as a positive, which is only correct for entailment pairs, so other
    labels should be filtered out before mapping.

    Args:
        example: A dataset row with ``"premise"`` and ``"hypothesis"`` strings.

    Returns:
        A dict of four equal-length int lists.
    """
    # NOTE(review): fixed max_length padding is wasteful for short NLI sentences.
    # Dynamic padding in the collator would be cheaper if CRAFTCollator supports
    # it; TODO confirm before changing.
    encode_kwargs = {"padding": "max_length", "truncation": True, "max_length": MAX_LENGTH}
    # Tokenizing a single string without return_tensors yields flat Python lists.
    anchor = tokenizer(example["premise"], **encode_kwargs)
    positive = tokenizer(example["hypothesis"], **encode_kwargs)
    return {
        "input_ids": list(anchor["input_ids"]),
        "attention_mask": list(anchor["attention_mask"]),
        "input_ids_tgt": list(positive["input_ids"]),
        "attention_mask_tgt": list(positive["attention_mask"]),
    }


## 4. Tokenise datasets


In [None]:
tokenized_sft = sft_dataset.map(apply_chat, remove_columns=sft_dataset.column_names)
# Truncation to MAX_LENGTH can cut a conversation before any assistant token,
# and a template without assistant markers yields no assistant tokens at all.
# Either way every label is -100. Such rows carry no supervision, and with
# per-device batch size 1 they can make the mean token loss undefined (NaN),
# so drop them.
tokenized_sft = tokenized_sft.filter(
    lambda example: any(label != -100 for label in example["labels"])
)
if len(tokenized_sft) == 0:
    raise ValueError(
        "No SFT example kept any assistant tokens after tokenization; check that "
        "the chat template marks assistant turns with {% generation %} and that "
        "MAX_LENGTH is large enough."
    )

tokenized_contrastive = contrastive_dataset.map(
    tokenize_contrastive, remove_columns=contrastive_dataset.column_names
)


## 5. Build dataset bundle & collator


In [None]:
# Combine the tokenized SFT rows with the tokenized contrastive pairs using the
# "paired_dataset" strategy. See craft.data.make_craft_datasets for exactly how
# the two sources are paired.
bundle = make_craft_datasets(
    tokenized_sft,
    contrastive_dataset=tokenized_contrastive,
    strategy="paired_dataset",
)

# Default-configured collator. Presumably it batches both the SFT fields and
# the *_tgt contrastive fields; TODO confirm against craft.data.CRAFTCollator.
collator = CRAFTCollator()


## 6. Load base model


In [None]:
# Same checkpoint the tokenizer was loaded from.
base_model_id = "HuggingFaceH4/zephyr-7b-beta"

# NOTE(review): no torch_dtype is passed, so weights load in the library's
# default dtype. For a 7B model this needs a lot of memory; consider
# torch_dtype=torch.bfloat16 on supporting hardware.
model = AutoModelForCausalLM.from_pretrained(base_model_id)

# The KV cache only helps generation; disable it for training.
model.config.use_cache = False


## 7. Trainer configuration


In [None]:
training_args = CRAFTSFTConfig(
    # Destination for checkpoints and the final model.
    output_dir="./outputs/craft-basic",
    # Optimisation. Assuming standard TrainingArguments semantics, each optimizer
    # step sees 1 x 8 = 8 sequences per device.
    num_train_epochs=1,
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    # CRAFT-specific loss weighting; see CRAFTSFTConfig for their exact meaning.
    craft_alpha=0.5,
    craft_beta=0.6,
    # Logging and checkpoint cadence (in steps).
    logging_steps=10,
    save_steps=50,
)

trainer = CRAFTSFTTrainer(
    model=model,
    args=training_args,
    # SFT rows feed the standard training loop; the bundle adds the
    # contrastive pairs built earlier.
    train_dataset=tokenized_sft,
    craft_bundle=bundle,
    data_collator=collator,
)


## 8. Train


In [None]:
# Run training with the model, arguments, datasets and collator given to the
# trainer. As the last expression in the cell, its return value is displayed.
trainer.train()


## 9. Save


In [None]:
# With no argument, save_model() writes to the trainer's configured output_dir.
trainer.save_model()
# The tokenizer (whose pad_token was set to EOS) was never passed to the
# trainer, so save it explicitly. Otherwise the output directory may lack
# tokenizer files and could not be reloaded on its own.
tokenizer.save_pretrained(training_args.output_dir)
