# 04 · CRAFT QLoRA Translation Benchmark

We compare a **base model**, an **SFT-only fine-tune**, and a **CRAFT fine-tune** on Flores translations.
Every supervised batch is tokenised with `tokenizer.apply_chat_template(..., return_assistant_tokens_mask=True)`
so loss is applied only on assistant tokens. The notebook also highlights the new length-matching options for
mixed SFT / InfoNCE loading.

> ⚠️ This notebook needs a GPU and has a non-trivial runtime. Adjust the LoRA targets, batch sizes, and max steps for your setup.


## 0. Environment setup
Install requirements in your runtime if they're not already available.


In [None]:
# !pip install -U "contrastive-ft[all] @ git+https://github.com/omarkamali/craft"
# !pip install -U "unsloth>=2024.9.0" "bitsandbytes>=0.43" "peft>=0.11" "accelerate>=0.30" "datasets>=2.19" "evaluate>=0.4" "matplotlib>=3.8"

## 1. Imports & experiment configuration


In [None]:
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List

import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from transformers import AutoTokenizer, GenerationConfig

from craft.config import CRAFTSFTConfig
from craft.data import CRAFTCollator, make_craft_datasets
from craft.trainers import CRAFTSFTTrainer

BLEU = None

@dataclass
class ExperimentConfig:
    """Central collection of the settings read by the notebook cells.

    A single instance (``CFG``) is created right after the class and is read
    by the tokenisation helpers, the model loader, both training runs, and
    the evaluation helper.
    """

    # Checkpoint id passed to AutoTokenizer.from_pretrained and to the QLoRA loader.
    model_id: str = "unsloth/gemma-3-270m-it"
    # Language codes; they are interpolated into the `sentence_<lang>` column
    # names and the target code is also embedded in the translation prompt.
    source_lang: str = "eng_Latn"
    target_lang: str = "spa_Latn"
    # Number of rows requested for the SFT set, the contrastive set, and evaluation.
    sft_train_size: int = 1024
    contrastive_train_size: int = 1024
    eval_size: int = 128
    # Pad/truncate length used by the tokenisation helpers and handed to the model loader.
    max_seq_length: int = 512
    # LoRA rank, passed as `r` when the adapters are attached.
    lora_r: int = 8
    # Optimisation settings forwarded to CRAFTSFTConfig for both training runs.
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 16
    learning_rate: float = 1.5e-4
    # Step budgets; each value is also used as `save_steps` for its run.
    sft_max_steps: int = 150
    craft_max_steps: int = 200
    # CRAFT-specific options forwarded unchanged to CRAFTSFTConfig for the CRAFT
    # run; their exact semantics are defined by the craft library.
    craft_alpha: float = 0.6
    craft_beta: float = 0.4
    craft_beta_mode: str = "auto"
    craft_length_strategy: str = "auto_beta"
    craft_contrastive_batch_size: int = 4
    # NOTE(review): not referenced by any cell in this notebook — evaluation
    # currently generates one example at a time.
    eval_batch_size: int = 8
    # Generation settings passed to model.generate during BLEU evaluation.
    # NOTE(review): temperature is set without do_sample — confirm whether
    # sampling or greedy decoding is intended for the benchmark.
    generation: GenerationConfig = field(
        default_factory=lambda: GenerationConfig(max_new_tokens=128, temperature=0.7)
    )
    # Root directory for checkpoints; the runs write to sub-folders of it.
    output_dir: Path = field(default_factory=lambda: Path("./outputs/craft-qlora"))

# Fix the torch RNG seed up front so repeated notebook runs start from the same state.
torch.manual_seed(42)

# Single shared configuration object read by every later cell.
CFG = ExperimentConfig()
# Create the checkpoint root (and any missing parents); a no-op when it already exists.
CFG.output_dir.mkdir(exist_ok=True, parents=True)

# Plot styling for the comparison charts at the end of the notebook.
plt.style.use("seaborn-v0_8")


## 2. Tokenizer & chat templating helpers


In [None]:
# Tokenizer for the configured checkpoint; used by all tokenisation helpers in this cell.
tokenizer = AutoTokenizer.from_pretrained(CFG.model_id)
# The helpers below pad to max_length, which needs a pad token. Fall back to
# EOS only when the tokenizer does not define one of its own.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def build_translation_messages(example):
    """Turn one parallel-sentence row into a two-turn chat conversation.

    The user turn asks for a translation of the source-language sentence and
    the assistant turn carries the reference translation.

    Args:
        example: Mapping with ``sentence_<source_lang>`` and
            ``sentence_<target_lang>`` keys (language codes come from ``CFG``).

    Returns:
        A list ``[user_message, assistant_message]`` of role/content dicts.
    """
    source_text = example[f"sentence_{CFG.source_lang}"]
    reference_text = example[f"sentence_{CFG.target_lang}"]

    user_turn = {
        "role": "user",
        "content": f"Translate to {CFG.target_lang}: {source_text}",
    }
    assistant_turn = {"role": "assistant", "content": reference_text}
    return [user_turn, assistant_turn]

def encode_chat(messages):
    """Tokenise a conversation with the chat template and build SFT labels.

    The conversation is padded/truncated to ``CFG.max_seq_length``. Labels are
    a copy of the input ids in which every position whose assistant mask is 0
    is replaced by -100.

    NOTE(review): the assistant mask comes from
    ``return_assistant_tokens_mask=True``; confirm that this model's chat
    template actually marks assistant spans, otherwise every label ends up
    as -100.

    Args:
        messages: List of role/content dicts.

    Returns:
        Tuple ``(input_ids, attention_mask, assistant_mask, labels)`` taken
        from the first (only) row of the encoded batch.
    """
    template_options = dict(
        tokenize=True,
        add_generation_prompt=False,
        padding="max_length",
        truncation=True,
        max_length=CFG.max_seq_length,
        return_tensors="pt",
        return_dict=True,
        return_assistant_tokens_mask=True,
    )
    batch = tokenizer.apply_chat_template(messages, **template_options)

    # return_tensors="pt" yields a batch dimension; keep only the single row.
    input_ids = batch["input_ids"][0]
    attention_mask = batch["attention_mask"][0]
    assistant_mask = batch["assistant_masks"][0]

    not_assistant = assistant_mask == 0
    labels = input_ids.clone().masked_fill(not_assistant, -100)
    return input_ids, attention_mask, assistant_mask, labels

def tokenize_sft(example):
    """Map function producing the SFT features for one parallel-sentence row.

    Args:
        example: Row accepted by ``build_translation_messages``.

    Returns:
        Dict of plain Python lists with keys ``input_ids``,
        ``attention_mask``, ``labels`` and ``assistant_mask``.
    """
    conversation = build_translation_messages(example)
    ids, attn, assistant, targets = encode_chat(conversation)

    # Tensors are converted to lists so the result can be stored by Dataset.map.
    features = {}
    features["input_ids"] = ids.tolist()
    features["attention_mask"] = attn.tolist()
    features["labels"] = targets.tolist()
    features["assistant_mask"] = assistant.tolist()
    return features

def tokenize_contrastive(example):
    """Map function producing an (anchor, positive) token pair for one row.

    The source-language sentence is the anchor and the target-language
    sentence is the positive. Both are tokenised as plain text (no chat
    template) and padded/truncated to ``CFG.max_seq_length``.

    Args:
        example: Mapping with ``sentence_<source_lang>`` and
            ``sentence_<target_lang>`` keys.

    Returns:
        Dict of lists: ``input_ids`` / ``attention_mask`` for the anchor and
        ``input_ids_tgt`` / ``attention_mask_tgt`` for the positive.
    """

    def _encode(text):
        # Same fixed-length settings for both sides of the pair.
        return tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=CFG.max_seq_length,
            return_tensors="pt",
        )

    anchor_text = example[f"sentence_{CFG.source_lang}"]
    positive_text = example[f"sentence_{CFG.target_lang}"]
    anchor_batch = _encode(anchor_text)
    positive_batch = _encode(positive_text)

    features = {}
    for suffix, batch in (("", anchor_batch), ("_tgt", positive_batch)):
        # Index 0 drops the batch dimension added by return_tensors="pt".
        features[f"input_ids{suffix}"] = batch["input_ids"][0].tolist()
        features[f"attention_mask{suffix}"] = batch["attention_mask"][0].tolist()
    return features


## 3. Load & tokenise Flores subsets


In [None]:
# FLORES-200 is published as per-language and per-language-pair configs; the
# pair config exposes the `sentence_<lang>` columns that the tokenisation
# helpers read, so it must be requested explicitly.
# NOTE(review): depending on the installed `datasets` version this script-based
# dataset may also need `trust_remote_code=True` — confirm in your environment.
flores = load_dataset("facebook/flores", f"{CFG.source_lang}-{CFG.target_lang}")

dev_split = flores["dev"]
devtest_split = flores["devtest"]

# The FLORES splits are small (about 1k sentences each per the dataset card),
# and Dataset.select raises on out-of-range indices, so clamp every requested
# size to what is actually available.
sft_count = min(CFG.sft_train_size, len(dev_split))
# Reserve the evaluation rows first so the contrastive training subset can
# never consume them; the two devtest subsets stay disjoint.
eval_count = min(CFG.eval_size, len(devtest_split))
contrastive_count = min(CFG.contrastive_train_size, len(devtest_split) - eval_count)

train_sft_raw = dev_split.select(range(sft_count))
train_contrastive_raw = devtest_split.select(range(contrastive_count))
val_eval = devtest_split.select(range(contrastive_count, contrastive_count + eval_count))

# Replace the raw text columns with token features.
train_sft = train_sft_raw.map(tokenize_sft, remove_columns=train_sft_raw.column_names)
train_contrastive = train_contrastive_raw.map(
    tokenize_contrastive, remove_columns=train_contrastive_raw.column_names
)

# Displayed as the cell output: the sizes actually used after clamping.
len(train_sft), len(train_contrastive), len(val_eval)


## 4. Build dataset bundles & collator


In [None]:
# Bundle for the CRAFT run: the SFT dataset together with the separately
# tokenised contrastive pairs (the "paired_dataset" strategy of the craft library).
bundle_craft = make_craft_datasets(
    train_sft,
    contrastive_dataset=train_contrastive,
    strategy="paired_dataset",
)
# Bundle for the SFT-only run: built from the SFT dataset alone.
# NOTE(review): "self_align" presumably derives contrastive views from the SFT
# examples themselves — confirm against the craft documentation.
bundle_sft_only = make_craft_datasets(train_sft, strategy="self_align")
# One collator instance is shared by both trainers.
collator = CRAFTCollator()


## 5. Helper utilities (model loading, evaluation, cleanup)


In [None]:
from unsloth import FastLanguageModel

def load_qlora_model():
    """Load the configured checkpoint in 4-bit and attach fresh LoRA adapters.

    Every call returns a new model, so the baseline, SFT-only, and CRAFT runs
    each start from identical, untrained adapters.

    Returns:
        Tuple ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped 4-bit
        model and ``tokenizer`` is the tokenizer returned by the loader.
    """
    model, adapter_tokenizer = FastLanguageModel.from_pretrained(
        model_name=CFG.model_id,
        max_seq_length=CFG.max_seq_length,
        load_in_4bit=True,
        dtype=None,
        device_map="auto",
    )
    # Only fall back to EOS when no pad token exists, matching how the
    # notebook-level tokenizer is prepared. Unconditionally overwriting a
    # dedicated pad token would make padding indistinguishable from EOS.
    if adapter_tokenizer.pad_token is None:
        adapter_tokenizer.pad_token = adapter_tokenizer.eos_token
    model = FastLanguageModel.get_peft_model(
        model,
        r=CFG.lora_r,
        lora_alpha=16,
        lora_dropout=0.05,
        # Adapters are attached to the attention projections only.
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )
    model.config.use_cache = False
    return model, adapter_tokenizer

def evaluate_bleu(model, tokenizer, dataset, *, max_examples: int):
    """Generate translations for ``dataset`` and score them with sacreBLEU.

    The prompt is rendered with the tokenizer's chat template (with the
    generation prompt appended) so that evaluation uses the same conversation
    format as the supervised training examples. Only the newly generated
    tokens are decoded: for decoder-only models ``generate`` returns the
    prompt followed by the continuation, and scoring the prompt as part of
    the hypothesis would distort BLEU.

    Args:
        model: Causal LM exposing ``eval``, ``generate`` and ``device``.
        tokenizer: Tokenizer paired with ``model``.
        dataset: Dataset with ``sentence_<source_lang>`` and
            ``sentence_<target_lang>`` columns.
        max_examples: Upper bound on the number of rows evaluated.

    Returns:
        Tuple ``(bleu, preds, refs)``: the sacreBLEU score, the decoded
        hypotheses, and the references (one single-item list per hypothesis).
    """
    global BLEU
    if BLEU is None:
        # Imported and loaded lazily so the metric is only fetched once.
        import evaluate as hf_evaluate
        BLEU = hf_evaluate.load("sacrebleu")

    model.eval()
    preds: List[str] = []
    refs: List[List[str]] = []
    subset = dataset.select(range(min(len(dataset), max_examples)))

    for example in subset:
        source = example[f"sentence_{CFG.source_lang}"]
        # Same user turn as the one used to build the training conversations.
        messages = [
            {"role": "user", "content": f"Translate to {CFG.target_lang}: {source}"},
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            generated = model.generate(**inputs, generation_config=CFG.generation)
        # Drop the echoed prompt; keep only the continuation.
        completion_ids = generated[0][prompt_length:]
        decoded = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
        preds.append(decoded)
        refs.append([example[f"sentence_{CFG.target_lang}"]])

    bleu = BLEU.compute(predictions=preds, references=refs)["score"]
    return bleu, preds, refs

def cleanup_model(model):
    """Best-effort release of cached GPU memory after a model is finished.

    ``del model`` only removes this function's local reference. The memory can
    be reclaimed only once the caller has dropped its own references as well
    (for example ``del my_model`` right after this call, plus any trainer that
    still holds the model). A garbage-collection pass runs before the CUDA
    cache is emptied so that objects kept alive only by reference cycles are
    freed first.

    Args:
        model: The model object that is no longer needed.
    """
    import gc

    del model
    gc.collect()
    torch.cuda.empty_cache()


## 6. Baseline evaluation (untrained adapters)


In [None]:
import gc

# One entry per run: {"label": <run name>, "bleu": <sacreBLEU score>}.
results: List[Dict[str, object]] = []
# Per-run hypotheses and references, keyed by run label.
predictions: Dict[str, Dict[str, List[str]]] = {}

base_model, base_tokenizer = load_qlora_model()
base_bleu, base_preds, base_refs = evaluate_bleu(
    base_model, base_tokenizer, val_eval, max_examples=CFG.eval_size
)
results.append({"label": "base", "bleu": base_bleu})
predictions["base"] = {"preds": base_preds, "refs": base_refs}
print(f"Baseline sacreBLEU: {base_bleu:.2f}")
cleanup_model(base_model)
# cleanup_model cannot remove this cell's own reference; drop it here and
# flush again so the baseline weights are really gone before the next load.
del base_model, base_tokenizer
gc.collect()
torch.cuda.empty_cache()


## 7. SFT-only fine-tuning (CRAFT disabled)


In [None]:
import gc

sft_model, sft_tokenizer = load_qlora_model()
sft_args = CRAFTSFTConfig(
    output_dir=str(CFG.output_dir / "sft_only"),
    per_device_train_batch_size=CFG.per_device_train_batch_size,
    gradient_accumulation_steps=CFG.gradient_accumulation_steps,
    max_steps=CFG.sft_max_steps,
    learning_rate=CFG.learning_rate,
    logging_steps=10,
    # Save once, at the final step.
    save_steps=CFG.sft_max_steps,
    bf16=torch.cuda.is_available(),
    craft_alpha=1.0,
    craft_beta=1.0,
    craft_length_strategy="error",
    craft_report_metrics=["contrastive_accuracy"],
)
sft_trainer = CRAFTSFTTrainer(
    model=sft_model,
    args=sft_args,
    train_dataset=train_sft,
    data_collator=collator,
    craft_bundle=bundle_sft_only,
)
sft_trainer.train()
sft_bleu, sft_preds, sft_refs = evaluate_bleu(
    sft_model, sft_tokenizer, val_eval, max_examples=CFG.eval_size
)
results.append({"label": "sft_only", "bleu": sft_bleu})
predictions["sft_only"] = {"preds": sft_preds, "refs": sft_refs}
print(f"SFT-only sacreBLEU: {sft_bleu:.2f}")
sft_trainer.save_model(str(CFG.output_dir / "sft_only"))
sft_tokenizer.save_pretrained(str(CFG.output_dir / "sft_only"))
cleanup_model(sft_model)
# Both the trainer and this cell still reference the model. Drop every
# reference first, then collect and flush, so the memory is free before the
# CRAFT run loads its own copy.
del sft_trainer, sft_model, sft_tokenizer
gc.collect()
torch.cuda.empty_cache()


## 8. CRAFT fine-tuning with contrastive batches


In [None]:
# Fresh 4-bit model with untrained adapters for the CRAFT run.
craft_model, craft_tokenizer = load_qlora_model()

# Arguments for the CRAFT run; the craft-specific options come from CFG.
craft_args = CRAFTSFTConfig(
    output_dir=str(CFG.output_dir / "craft"),
    # Optimisation schedule
    per_device_train_batch_size=CFG.per_device_train_batch_size,
    gradient_accumulation_steps=CFG.gradient_accumulation_steps,
    max_steps=CFG.craft_max_steps,
    learning_rate=CFG.learning_rate,
    # Logging / checkpointing (one save, at the final step)
    logging_steps=10,
    save_steps=CFG.craft_max_steps,
    bf16=torch.cuda.is_available(),
    # CRAFT options
    craft_alpha=CFG.craft_alpha,
    craft_beta=CFG.craft_beta,
    craft_beta_mode=CFG.craft_beta_mode,
    craft_contrastive_batch_size=CFG.craft_contrastive_batch_size,
    craft_length_strategy=CFG.craft_length_strategy,
    craft_report_metrics=["contrastive_accuracy", "representation_consistency"],
)

craft_trainer = CRAFTSFTTrainer(
    model=craft_model,
    args=craft_args,
    train_dataset=train_sft,
    data_collator=collator,
    craft_bundle=bundle_craft,
)
craft_trainer.train()

# Score the trained model on the held-out rows and record the outcome.
craft_bleu, craft_preds, craft_refs = evaluate_bleu(
    craft_model,
    craft_tokenizer,
    val_eval,
    max_examples=CFG.eval_size,
)
results.append({"label": "craft", "bleu": craft_bleu})
predictions["craft"] = {"preds": craft_preds, "refs": craft_refs}
print(f"CRAFT sacreBLEU: {craft_bleu:.2f}")

# Persist adapters and tokenizer next to the training checkpoints.
craft_trainer.save_model(str(CFG.output_dir / "craft"))
craft_tokenizer.save_pretrained(str(CFG.output_dir / "craft"))

# Keep the logged history for the metric plots in a later cell.
craft_history = craft_trainer.state.log_history


## 9. Compare BLEU across base, SFT, and CRAFT runs


In [None]:
# Split the result records into parallel label / score lists for plotting.
labels, scores = [], []
for record in results:
    labels.append(record["label"])
    scores.append(record["bleu"])

plt.figure(figsize=(6, 4))
bars = plt.bar(labels, scores, color=["#7f8c8d", "#2980b9", "#27ae60"])
plt.ylabel("sacreBLEU")
plt.title("Flores sacreBLEU comparison")
# Leave headroom above the tallest bar for its value label.
plt.ylim(0, max(scores) + 2)

# Print each score just above the centre of its bar.
for rect, value in zip(bars, scores):
    x_centre = rect.get_x() + rect.get_width() / 2
    y_text = rect.get_height() + 0.2
    plt.text(x_centre, y_text, f"{value:.2f}", ha="center")
plt.show()

results


## 10. Inspect CRAFT training metrics


In [None]:
# Keep only the log records that carry the CRAFT total loss.
craft_logs = [record for record in craft_history if "loss/craft_total" in record]

if not craft_logs:
    print("No craft logs captured.")
else:
    steps, losses, contrastive = [], [], []
    for position, record in enumerate(craft_logs):
        # Fall back to the record's position when no step was logged.
        steps.append(record.get("step", position))
        losses.append(record["loss/craft_total"])
        contrastive.append(record.get("metrics/craft_contrastive_accuracy"))

    # (series, legend label, y-axis label, title) for the two side-by-side panels.
    panels = (
        (losses, "CRAFT total loss", "Loss", "CRAFT loss trajectory"),
        (contrastive, "Contrastive accuracy", "Accuracy", "Contrastive metric"),
    )

    plt.figure(figsize=(12, 4))
    for column, (series, legend_label, y_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, column)
        plt.plot(steps, series, label=legend_label)
        plt.xlabel("Step")
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
    plt.tight_layout()
    plt.show()


## 11. Summarise insights


In [None]:
# Collect the headline numbers of the benchmark in one dictionary.
bleu_by_label = {}
for record in results:
    bleu_by_label[record["label"]] = record["bleu"]

# Result record with the highest BLEU score.
top_run = max(results, key=lambda item: item["bleu"])

summary = dict(
    bleu_scores=bleu_by_label,
    best_run=top_run,
    # Falls back to an empty dict when the trainer exposes no `craft_metrics`.
    craft_final_metrics=getattr(craft_trainer, "craft_metrics", {}),
    length_strategy=CFG.craft_length_strategy,
    contrastive_batch_size=CFG.craft_contrastive_batch_size,
)
summary


## 12. Optional: release GPU memory


In [None]:
import gc

cleanup_model(craft_model)
# The model stays alive while this cell's name or the trainer still point at
# it, so drop every reference before flushing the CUDA cache. After this cell
# the summary cell can no longer be re-run without re-training.
del craft_model, craft_trainer, craft_tokenizer
gc.collect()
torch.cuda.empty_cache()
