To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
</div>

To install Unsloth on your own computer, follow the installation instructions in our documentation [here](https://docs.unsloth.ai/get-started/installing-+-updating).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save)


### News


Unsloth's [Docker image](https://hub.docker.com/r/unsloth/unsloth) is here! Start training with no setup & environment issues. [Read our Guide](https://docs.unsloth.ai/new/how-to-train-llms-with-unsloth-and-docker).

[gpt-oss RL](https://docs.unsloth.ai/new/gpt-oss-reinforcement-learning) is now supported with the fastest inference & lowest VRAM. Try our [new notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-GRPO.ipynb) which creates kernels!

Introducing [Vision](https://docs.unsloth.ai/new/vision-reinforcement-learning-vlm-rl) and [Standby](https://docs.unsloth.ai/basics/memory-efficient-rl) for RL! Train Qwen, Gemma etc. VLMs with GSPO - even faster with less VRAM.

Unsloth now supports Text-to-Speech (TTS) models. Read our [guide here](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning).

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [1]:
# !pip install -r requirements.txt

In [2]:
# ✅ ADD: GRIT integration - add to Python path
import sys
sys.path.insert(0, '/lambda/nfs/DiskUsEast1/finetuning_evaluation/comparative_study/0b_GRIT_utils')

### Unsloth

In [3]:
# CITA_GRIT - Use standard HuggingFace loading (no Unsloth)
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv('/lambda/nfs/DiskUsEast1/finetuning_evaluation/.env')
hf_token = os.getenv('HF_TOKEN')

max_seq_length = 2048
dtype = torch.bfloat16  # BFloat16 for A100 GPU (Ampere architecture)
load_in_4bit = True

# BitsAndBytes 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load base Llama-3 8B model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=dtype,
    token=hf_token,
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    use_fast=True,
    token=hf_token,
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

print(f"✅ Model loaded: {model.__class__.__name__}")
print(f"✅ Device map: {model.hf_device_map}")
print(f"✅ Tokenizer vocab size: {len(tokenizer)}")

Loading checkpoint shards: 100%|██████████| 4/4 [00:13<00:00,  3.46s/it]


✅ Model loaded: LlamaForCausalLM
✅ Device map: {'': 0}
✅ Tokenizer vocab size: 128256


In [4]:
# @title Alignment Handbook utils
import os
import re
from typing import List, Literal, Optional

from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError


DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"


def apply_chat_template(
    example,
    tokenizer,
    task: Literal["sft", "generation", "rm", "dpo"] = "sft",
    assistant_prefix="<|assistant|>\n",
):
    def _strip_prefix(s, pattern):
        # Use re.escape to escape any special characters in the pattern
        return re.sub(f"^{re.escape(pattern)}", "", s)

    if task in ["sft", "generation"]:
        messages = example["messages"]
        # We add an empty system message if there is none
        if messages[0]["role"] != "system":
            messages.insert(0, {"role": "system", "content": ""})
        example["text"] = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=(task == "generation"),
        )
    elif task == "rm":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            chosen_messages = example["chosen"]
            rejected_messages = example["rejected"]
            # We add an empty system message if there is none
            if chosen_messages[0]["role"] != "system":
                chosen_messages.insert(0, {"role": "system", "content": ""})
            if rejected_messages[0]["role"] != "system":
                rejected_messages.insert(0, {"role": "system", "content": ""})
            example["text_chosen"] = tokenizer.apply_chat_template(
                chosen_messages, tokenize=False
            )
            example["text_rejected"] = tokenizer.apply_chat_template(
                rejected_messages, tokenize=False
            )
        else:
            raise ValueError(
                f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
    elif task == "dpo":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            # Compared to reward modeling, we filter out the prompt, so the text is everything after the last assistant token
            prompt_messages = [
                [msg for msg in example["chosen"] if msg["role"] == "user"][0]
            ]
            # Insert system message
            if example["chosen"][0]["role"] != "system":
                prompt_messages.insert(0, {"role": "system", "content": ""})
            else:
                prompt_messages.insert(0, example["chosen"][0])
            # TODO: handle case where chosen/rejected also have system messages
            chosen_messages = example["chosen"][1:]
            rejected_messages = example["rejected"][1:]
            example["text_chosen"] = tokenizer.apply_chat_template(
                chosen_messages, tokenize=False
            )
            example["text_rejected"] = tokenizer.apply_chat_template(
                rejected_messages, tokenize=False
            )
            example["text_prompt"] = tokenizer.apply_chat_template(
                prompt_messages, tokenize=False, add_generation_prompt=True
            )
            example["text_chosen"] = _strip_prefix(
                example["text_chosen"], assistant_prefix
            )
            example["text_rejected"] = _strip_prefix(
                example["text_rejected"], assistant_prefix
            )
        else:
            raise ValueError(
                f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
    else:
        raise ValueError(
            f"Task {task} not supported, please ensure that the provided task is one of {['sft', 'generation', 'rm', 'dpo']}"
        )
    return example


def get_datasets(
    data_config: dict,
    splits: List[str] = ["train", "test"],
    shuffle: bool = True,
) -> DatasetDict:
    """
    Loads one or more datasets with varying training set proportions.

    Args:
        data_config (`DataArguments` or `dict`):
            Dataset configuration and split proportions.
        splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
            Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training and testing/validation data.

    Returns:
        [`DatasetDict`]: The dataset dictionary containing the loaded datasets.
    """

    if isinstance(data_config, dict):
        # Structure of the input is:
        #     dataset_mixer = {
        #             "dataset1": 0.5,
        #             "dataset2": 0.3,
        #             "dataset3": 0.2,
        #         }
        dataset_mixer = data_config
    else:
        raise ValueError(f"Data config {data_config} not recognized.")

    raw_datasets = mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle)
    return raw_datasets


def mix_datasets(
    dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True
) -> DatasetDict:
    """
    Loads and mixes datasets according to proportions specified in `dataset_mixer`.

    Args:
        dataset_mixer (`dict`):
            Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
        splits (Optional[List[str]], *optional*, defaults to `None`):
            Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training and testing/validation data.
    """
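    # Fall back to the standard splits so the loop below never iterates over None
    splits = ["train", "test"] if splits is None else splits
    raw_datasets = DatasetDict()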
    raw_train_datasets = []
    raw_val_datasets = []
    fracs = []
    for ds, frac in dataset_mixer.items():
        fracs.append(frac)
        for split in splits:
            try:
                # Try first if dataset on a Hub repo
                dataset = load_dataset(ds, split=split)
            except DatasetGenerationError:
                # If not, check local dataset
                dataset = load_from_disk(os.path.join(ds, split))

            if "train" in split:
                raw_train_datasets.append(dataset)
            elif "test" in split:
                raw_val_datasets.append(dataset)
            else:
                raise ValueError(
                    f"Split type {split} not recognized as one of test or train."
                )

    if any(frac < 0 for frac in fracs):
        raise ValueError("Dataset fractions cannot be negative.")

    if len(raw_train_datasets) > 0:
        train_subsets = []
        for dataset, frac in zip(raw_train_datasets, fracs):
            train_subset = dataset.select(range(int(frac * len(dataset))))
            train_subsets.append(train_subset)
        if shuffle:
            raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42)
        else:
            raw_datasets["train"] = concatenate_datasets(train_subsets)
    # No subsampling for test datasets to enable fair comparison across models
    if len(raw_val_datasets) > 0:
        if shuffle:
            raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(
                seed=42
            )
        else:
            raw_datasets["test"] = concatenate_datasets(raw_val_datasets)

    if len(raw_datasets) == 0:
        raise ValueError(
            f"Dataset {dataset_mixer} not recognized with splits {splits}. Check the dataset has been correctly formatted."
        )

    return raw_datasets
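
As a small illustration of the subsampling rule in `mix_datasets`, the sketch below mimics `dataset.select(range(int(frac * len(dataset))))` on plain Python lists. The helper names `take_fraction` and `mix` and the toy rows are hypothetical and only show that the first `int(frac * n)` rows of each dataset are kept before any shuffling.

```python
import random
from typing import List, Sequence


def take_fraction(rows: Sequence, frac: float) -> List:
    """Keep the first int(frac * len(rows)) rows, like .select(range(...)) above."""
    if frac < 0:
        raise ValueError("Dataset fractions cannot be negative.")
    return list(rows[: int(frac * len(rows))])


def mix(train_sets: Sequence[Sequence], fracs: Sequence[float],
        shuffle: bool = True, seed: int = 42) -> List:
    """Concatenate the per-dataset subsets and optionally shuffle with a fixed seed."""
    mixed = [row for rows, frac in zip(train_sets, fracs)
             for row in take_fraction(rows, frac)]
    if shuffle:
        random.Random(seed).shuffle(mixed)
    return mixed


# Toy stand-ins for two datasets of 10 and 4 rows (hypothetical data).
a = [f"a{i}" for i in range(10)]
b = [f"b{i}" for i in range(4)]
subset = mix([a, b], fracs=[0.5, 0.3], shuffle=False)
print(subset)  # ['a0', 'a1', 'a2', 'a3', 'a4', 'b0']
```

Because `int()` truncates, a fraction of 0.3 on 4 rows keeps a single row, so very small datasets can contribute less than the nominal proportion.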

<a name="Data"></a>
### Data Prep
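We follow Huggingface's [Alignment Handbook](https://github.com/huggingface/alignment-handbook) for [Zephyr](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta) for the chat-template utilities above. Instead of the [Ultra Feedback dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) used there, this run loads [PKU-SafeRLHF](https://huggingface.co/datasets/PKU-Alignment/PKU-SafeRLHF) and keeps only the pairs with a clear safety contrast (10,813 of the 73,907 training samples). With `max_samples=None`, all filtered samples are used.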

In [5]:
import sys
sys.path.insert(0,
'/lambda/nfs/DiskUsEast1/finetuning_evaluation/comparative_study/0c_DATA_PREP_utils/src')

from data_prep import load_pku_filtered, format_dataset

# Load PKU-SafeRLHF with clear safety contrast
dataset = load_pku_filtered(split="train", max_samples=None)
print(f"📊 Loaded {len(dataset)} training samples with clear safety contrast")

Loading PKU-Alignment/PKU-SafeRLHF (split=train)...
Original dataset size: 73,907
Filtered dataset size: 10,813 (clear safety contrast)
📊 Loaded 10813 training samples with clear safety contrast


In [6]:
# Format PKU-SafeRLHF for CITA training
dataset = format_dataset(dataset, method="cita")
print(f"✅ Formatted {len(dataset)} samples for CITA")
print(f"✅ Sample structure: {list(dataset[0].keys())}")
print(f"✅ Chosen messages: {len(dataset[0]['chosen'])} turns")
print(f"✅ First message role: {dataset[0]['chosen'][0]['role']}")  # Should be 'system'

# Show sample
print(f"\n📋 Sample chosen trajectory:")
for i, msg in enumerate(dataset[0]['chosen']):
    print(f"   {i+1}. {msg['role']}: {msg['content'][:80]}...")

✅ Formatted 10813 samples for CITA
✅ Sample structure: ['chosen', 'rejected']
✅ Chosen messages: 3 turns
✅ First message role: system

📋 Sample chosen trajectory:
   1. system: You are a helpful AI assistant. You must REFUSE any requests involving violence,...
   2. user: My roommate constantly borrowes my clothes without asking me first. How can I ex...
   3. assistant: You can express your anger in non-physical ways as well. You can calmly tell him...


We now add `LoRA adapters` so we only need to update a small fraction of all parameters (about 0.52% in this run)!

In [7]:
# CITA_GRIT - Use PEFT's native LoRA (no Unsloth)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import gc

# Clear cache before preparation
gc.collect()
torch.cuda.empty_cache()

# Prepare model for k-bit training
print("Preparing model for k-bit training...")
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
)

# LoRA configuration
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
    inference_mode=False,
)

# Apply LoRA adapters
print("Applying LoRA adapters...")
model = get_peft_model(model, lora_config)
model.enable_input_require_grads()
model.config.use_cache = False

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

# Print trainable parameters
model.print_trainable_parameters()

# Final memory check
gc.collect()
torch.cuda.empty_cache()
print(f"✅ LoRA adapters added successfully")
print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
print(f"GPU memory reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")

Preparing model for k-bit training...
Applying LoRA adapters...
trainable params: 41,943,040 || all params: 8,072,204,288 || trainable%: 0.5196
✅ LoRA adapters added successfully
GPU memory allocated: 7.43 GB
GPU memory reserved: 11.17 GB
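
The trainable-parameter count printed above can be checked by hand. A LoRA adapter of rank `r` on a linear layer with `in_features` inputs and `out_features` outputs adds `r * (in_features + out_features)` parameters. The sketch below applies that to the seven target modules, assuming the standard Llama-3 8B shapes (hidden size 4096, key/value projection width 1024, MLP width 14336, 32 decoder layers); the constant names are chosen here for illustration.

```python
# Llama-3 8B dimensions (assumed from the published model configuration).
HIDDEN, KV_DIM, MLP_DIM, NUM_LAYERS, RANK = 4096, 1024, 14336, 32, 16

# (in_features, out_features) of each targeted projection in one decoder layer.
shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP_DIM),
    "up_proj": (HIDDEN, MLP_DIM),
    "down_proj": (MLP_DIM, HIDDEN),
}


def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """Parameters in the A (rank x in) and B (out x rank) matrices of one adapter."""
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o, RANK) for i, o in shapes.values())
trainable = per_layer * NUM_LAYERS
adapted_modules = len(shapes) * NUM_LAYERS
total = 8_072_204_288  # "all params" from print_trainable_parameters() above

print(trainable)                           # 41943040
print(adapted_modules)                     # 224
print(round(100 * trainable / total, 4))   # 0.5196
```

The 224 adapted modules also match the number of LoRA modules that the GRIT manager reports instrumenting in the next cell.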


In [8]:
# Ensure GRIT path is in sys.path (in case this cell runs standalone)
import sys
if '/lambda/nfs/DiskUsEast1/finetuning_evaluation/comparative_study/0b_GRIT_utils' not in sys.path:
    sys.path.insert(0, '/lambda/nfs/DiskUsEast1/finetuning_evaluation/comparative_study/0b_GRIT_utils')

# ✅ Initialize GRIT Manager for CITA training
from grit.config import GRITConfig
from grit.manager import GRITManager

# Initialize config
grit_config = GRITConfig()

# Model settings
grit_config.lora_rank = 16
grit_config.lora_alpha = 16
grit_config.precision = "bf16"

# CONSERVATIVE SETTINGS (same as SFT_GRIT, DPO_GRIT)
grit_config.kfac_update_freq = 10
grit_config.reprojection_freq = 50
grit_config.kfac_damping = 1e-5
grit_config.lambda_kfac = 1e-6
grit_config.lambda_reproj = 1e-5
grit_config.kfac_min_samples = 16

# Warmup settings
grit_config.reprojection_warmup_steps = 20
grit_config.ng_warmup_steps = 0
grit_config.regularizer_warmup_steps = 0

# Rank adaptation - DISABLED for fair comparison
grit_config.enable_rank_adaptation = False
grit_config.use_two_sided_reprojection = True

# Logging - DISABLED to reduce overhead
grit_config.log_fisher_spectrum = False
grit_config.log_top_eigs = 0
grit_config.log_eig_heatmaps = False

print("🎯 Initializing GRIT Manager for CITA...")
grit_manager = GRITManager(
    model=model,
    config=grit_config,
    device="cuda",
)
print("✅ GRIT Manager initialized successfully!")

🎯 Initializing GRIT Manager for CITA...
Instrumented 224 LoRA modules with custom autograd.
Using r-dim (16x16) covariances.
GRITManager: Initialization complete.
🔍 Optimizing 224 key LoRA modules.
💾 K-FAC covariances kept on-device; snapshot to CPU only at inversion.
✅ GRIT Manager initialized successfully!


<a name="Train"></a>
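### Train the CITA model
Now let's train our model. `CITATrainer` extends TRL's `DPOTrainer`. We run 300 optimizer steps with an effective batch size of 8 (2 per device x 4 gradient accumulation steps), which covers about 0.22 epochs of the 10,813 samples.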

In [9]:
from transformers import TrainingArguments
from trl import DPOConfig
from datetime import datetime
import os
import sys

# No Unsloth patching for GRIT

# ✅ Set Llama-3 chat template (required by CITATrainer internally)
# CITATrainer extends DPOTrainer which uses apply_chat_template()
tokenizer.chat_template = """{% set loop_messages = messages %}{% for message in loop_messages %}{% set content
= '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + 
'<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% 
endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"""

print("✅ Chat template set for tokenizer (required by CITATrainer)")

# Import custom CITA trainer
sys.path.insert(0, '/lambda/nfs/DiskUsEast1/finetuning_evaluation/comparative_study/03a_CITA_Baseline')
from cita_trainer import CITATrainer

✅ Chat template set for tokenizer (required by CITATrainer)
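
To make the template concrete, here is a plain-Python sketch of the string it renders. It is not the code path the tokenizer uses (that goes through Jinja); the function name `render_llama3` is hypothetical, and `<|begin_of_text|>` is assumed to be the tokenizer's `bos_token`, as it is for Llama-3.

```python
from typing import Dict, List

BOS_TOKEN = "<|begin_of_text|>"  # assumed value of tokenizer.bos_token for Llama-3


def render_llama3(messages: List[Dict[str, str]],
                  add_generation_prompt: bool = False,
                  bos_token: str = BOS_TOKEN) -> str:
    """Mirror the Jinja template: header, trimmed content, <|eot_id|> per message."""
    parts = []
    for index, message in enumerate(messages):
        content = (
            "<|start_header_id|>" + message["role"] + "<|end_header_id|>\n\n"
            + message["content"].strip()  # Jinja's `trim` filter
            + "<|eot_id|>"
        )
        if index == 0:
            content = bos_token + content  # BOS only before the first message
        parts.append(content)
    if add_generation_prompt:
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)


example = render_llama3(
    [{"role": "system", "content": "Be safe."},
     {"role": "user", "content": "  Hello  "}],
    add_generation_prompt=True,
)
print(example)
```

Note that the template itself prepends the BOS token, so text rendered this way should be tokenized without adding special tokens a second time.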


In [10]:
# TensorBoard setup (match SFT notebook)
tensorboard_base_dir = "/home/ubuntu/DiskUsEast1/finetuning_evaluation/tensorboard_logs"
os.makedirs(tensorboard_base_dir, exist_ok=True)

run_name = "CITA_GRIT"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
tensorboard_run_dir = os.path.join(tensorboard_base_dir, f"{run_name}_{timestamp}")

print(f"📊 TensorBoard logs: {tensorboard_run_dir}")

📊 TensorBoard logs: /home/ubuntu/DiskUsEast1/finetuning_evaluation/tensorboard_logs/CITA_GRIT_20251006_123436


In [11]:
cita_trainer = CITATrainer(
    model = model,
    ref_model = None,
    args = DPOConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 300,
        learning_rate = 5e-6,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "tensorboard",
        logging_dir = tensorboard_run_dir,
        logging_first_step = True,
        beta = 0.1,
        max_length = 2048,
        max_prompt_length = 1024,
    ),
    lambda_kl = 0.01,
    train_dataset = dataset,
    processing_class = tokenizer,
)

# Note: GRIT will automatically instrument via LoRA module hooks
print("🎁 GRIT preconditioning active via instrumented LoRA modules")

Applying chat template to train dataset: 100%|██████████| 10813/10813 [00:07<00:00, 1521.94 examples/s]
Tokenizing train dataset: 100%|██████████| 10813/10813 [00:10<00:00, 1031.83 examples/s]


🎁 GRIT preconditioning active via instrumented LoRA modules
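
For context on `beta = 0.1`: in the standard DPO objective, `beta` scales the difference between the policy's and the reference model's log-probability margins. The sketch below computes that per-example loss from four sequence log-probabilities. It covers only the standard DPO term; how `CITATrainer` combines it with `lambda_kl` is not shown in this notebook, and the function and argument names here are illustrative.

```python
import math


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """Standard DPO loss: -log(sigmoid(beta * (chosen margin - rejected margin)))."""
    margin = ((policy_chosen_logp - ref_chosen_logp)
              - (policy_rejected_logp - ref_rejected_logp))
    x = beta * margin
    # -log(sigmoid(x)) == log(1 + exp(-x)), written in a numerically stable form.
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))


# When the policy equals the reference, the margin is 0 and the loss is log(2).
at_start = dpo_loss(-12.0, -15.0, -12.0, -15.0)
# Raising the chosen log-probability relative to the reference lowers the loss.
improved = dpo_loss(-8.0, -15.0, -12.0, -15.0)
print(at_start, improved)
```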


In [12]:
# Show current memory stats (match SFT notebook)
import torch

gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.495 GB.
11.166 GB of memory reserved.


In [13]:
cita_trainer.train()



Step,Training Loss
1,5.1229
2,3.2176
3,1.9655
4,3.4331
5,4.6296
6,2.2848
7,2.0209
8,0.4083
9,0.172
10,2.1321


TrainOutput(global_step=300, training_loss=2.78987518923978, metrics={'train_runtime': 944.3949, 'train_samples_per_second': 2.541, 'train_steps_per_second': 0.318, 'total_flos': 0.0, 'train_loss': 2.78987518923978, 'epoch': 0.2219345293138524})
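
The throughput figures in `TrainOutput` follow from the training configuration. The short calculation below reproduces them, assuming a single GPU (the device map above is `{'': 0}`); all numbers are copied from the cells and outputs above.

```python
per_device_batch = 2       # per_device_train_batch_size
grad_accum = 4             # gradient_accumulation_steps
max_steps = 300            # max_steps
num_samples = 10_813       # filtered PKU-SafeRLHF training samples
train_runtime_s = 944.3949 # train_runtime from TrainOutput

effective_batch = per_device_batch * grad_accum   # samples per optimizer step
samples_seen = effective_batch * max_steps        # samples processed in total

print(effective_batch, samples_seen)              # 8 2400
print(round(samples_seen / train_runtime_s, 3))   # 2.541 samples per second
print(round(max_steps / train_runtime_s, 3))      # 0.318 steps per second
print(round(samples_seen / num_samples, 2))       # 0.22 of one epoch
```

The run therefore sees roughly a fifth of the dataset once, in line with the reported `epoch` of about 0.22.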

In [14]:
# Show final stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)

print(f"\n{'='*80}")
print(f"Training complete!")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"{'='*80}")


Training complete!
Peak reserved memory = 22.385 GB.
Peak reserved memory for training = 11.219 GB.
Peak reserved memory % of max memory = 56.678 %.


<a name="Inference"></a>
### Inference
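This run uses standard HuggingFace and PEFT loading rather than Unsloth, and it does not include a generation cell. The cell below saves the trained LoRA adapters and tokenizer so they can be reloaded for inference later. You should use prompts which are similar to the ones you had finetuned on, otherwise you might get bad results!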

In [15]:
# Save LoRA adapters (match SFT notebook)
model.save_pretrained("lora_model_CITA_GRIT")
tokenizer.save_pretrained("lora_model_CITA_GRIT")

print("✅ CITA model saved to: lora_model_CITA_GRIT/")

✅ CITA model saved to: lora_model_CITA_GRIT/
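
A possible way to reload the saved adapters and generate a reply is sketched below. It was not run as part of this notebook. It assumes the adapter directory `lora_model_CITA_GRIT` from the cell above, the same base model and 4-bit settings used for training, and that the saved tokenizer carries the chat template set earlier. The function names `build_messages` and `generate_reply` are hypothetical.

```python
from typing import Dict, List, Optional


def build_messages(system_prompt: str, user_prompt: str) -> List[Dict[str, str]]:
    """Mirror the system + user layout of the training conversations."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


def generate_reply(messages: List[Dict[str, str]],
                   adapter_dir: str = "lora_model_CITA_GRIT",
                   base_id: str = "meta-llama/Meta-Llama-3-8B",
                   max_new_tokens: int = 128,
                   token: Optional[str] = None) -> str:
    # Heavy imports live inside the function so build_messages works without a GPU.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    base = AutoModelForCausalLM.from_pretrained(
        base_id, device_map="auto", quantization_config=bnb_config, token=token
    )
    model = PeftModel.from_pretrained(base, adapter_dir)  # attach the LoRA weights
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(adapter_dir)

    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The template already prepends BOS, so do not add special tokens again.
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(base.device)
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=[tokenizer.eos_token_id, eot_id],
            pad_token_id=tokenizer.eos_token_id,
        )
    new_tokens = output[0][inputs["input_ids"].shape[1]:]  # drop the prompt tokens
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


messages = build_messages("You are a helpful AI assistant.", "How can I stay calm?")
# reply = generate_reply(messages)  # requires a GPU and access to the base model
```

Stopping on `<|eot_id|>` as well as the tokenizer's EOS token matters here because the chat template ends every turn with `<|eot_id|>`.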


And we're done! If you have any questions about Unsloth, join our [Discord](https://discord.gg/unsloth) channel. It is also the place to report bugs, keep up with the latest LLM news, get help, or join projects.

Some other links:
1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!

