#### HuggingFace Authentication & Configuration

In [None]:
# HuggingFace Authentication & Configuration
import os
from huggingface_hub import login
from dotenv import load_dotenv
from datetime import datetime

# Read HF_TOKEN from the project's .env file and log in to the HuggingFace Hub.
load_dotenv('/lambda/nfs/DiskUsEast1/finetuning_evaluation/.env')
hf_token = os.environ.get('HF_TOKEN')
login(token=hf_token)

# Hub repository of every run in the comparative study, keyed by
# "<METHOD>_<VARIANT>" (e.g. "CITA_GRIT" -> "kapilw25/llama3-8b-pku-cita-grit").
MODEL_NAME_MAP = {
    f"{method}_{variant}": f"kapilw25/llama3-8b-pku-{method.lower()}-{variant.lower()}"
    for method in ("SFT", "DPO", "CITA")
    for variant in ("Baseline", "GRIT")
}

# Identifier of the run performed by this notebook.
RUN_NAME = "CITA_Baseline"
# Target repository for this run: the mapped name plus a "-bf16" suffix.
HF_REPO = f"{MODEL_NAME_MAP[RUN_NAME]}-bf16"

print("✅ HuggingFace authenticated")
print("📦 Model will be pushed to: " + HF_REPO)

### Unsloth

In [None]:
from unsloth import FastLanguageModel
import torch

# Precision setup: load the base model in BF16 with 4-bit quantization disabled.
max_seq_length = 2048    # same value as the SFT notebook
dtype = torch.bfloat16
load_in_4bit = False

# `hf_token` comes from the HuggingFace authentication cell.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Meta-Llama-3-8B",
    token=hf_token,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    max_seq_length=max_seq_length,
)

print("✅ Model loaded in {} (no quantization)".format(dtype))
print("📊 Max sequence length: {}".format(max_seq_length))

In [None]:
# @title Alignment Handbook utils
import os
import re
from typing import List, Literal, Optional

from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError


# Jinja chat template: each user/system/assistant message is rendered as a
# "<|role|>" tag, a newline, the message content and the EOS token; a bare
# "<|assistant|>" tag is appended after the last message when
# `add_generation_prompt` is set.
DEFAULT_CHAT_TEMPLATE = (
    "{% for message in messages %}\n"
    "{% if message['role'] == 'user' %}\n"
    "{{ '<|user|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'system' %}\n"
    "{{ '<|system|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'assistant' %}\n"
    "{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n"
    "{% endif %}\n"
    "{% if loop.last and add_generation_prompt %}\n"
    "{{ '<|assistant|>' }}\n"
    "{% endif %}\n"
    "{% endfor %}"
)


def apply_chat_template(
    example,
    tokenizer,
    task: Literal["sft", "generation", "rm", "dpo"] = "sft",
    assistant_prefix="<|assistant|>\n",
):
    """Render one dataset example to plain text with the tokenizer's chat template.

    Args:
        example: Mapping describing a single example. The ``sft`` and
            ``generation`` tasks read ``example["messages"]``; the ``rm`` and
            ``dpo`` tasks read ``example["chosen"]`` and ``example["rejected"]``.
            Each of these is a list of ``{"role": ..., "content": ...}`` dicts.
        tokenizer: Object exposing
            ``apply_chat_template(messages, tokenize=..., add_generation_prompt=...)``.
        task: Which formatting to apply; one of ``"sft"``, ``"generation"``,
            ``"rm"`` or ``"dpo"``.
        assistant_prefix: Text removed from the start of the rendered
            chosen/rejected completions in the ``dpo`` task.

    Returns:
        The same ``example`` object with extra keys:
        ``text`` (``sft`` / ``generation``),
        ``text_chosen`` and ``text_rejected`` (``rm``), or
        ``text_prompt``, ``text_chosen`` and ``text_rejected`` (``dpo``).
        For ``sft``, ``generation`` and ``rm`` an empty system message is
        inserted *in place* into any message list that does not start with one.

    Raises:
        ValueError: If ``task`` is not supported, or if ``rm``/``dpo`` is
            requested and the example lacks a ``chosen`` or ``rejected`` key.
    """

    def _strip_prefix(text, prefix):
        # Drop `prefix` only when it sits at the very start of `text`.
        return text[len(prefix):] if text.startswith(prefix) else text

    def _ensure_system_first(messages):
        # Mutates the caller's list on purpose: the returned example keeps the
        # inserted empty system message, exactly as before.
        if messages[0]["role"] != "system":
            messages.insert(0, {"role": "system", "content": ""})
        return messages

    def _completion_turns(messages):
        # Everything after the prompt: skip the optional leading system message
        # and then the first (prompt) message.
        offset = 1 if messages[0]["role"] == "system" else 0
        return messages[offset + 1:]

    if task in ("sft", "generation"):
        messages = _ensure_system_first(example["messages"])
        example["text"] = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=(task == "generation"),
        )
    elif task in ("rm", "dpo"):
        if not all(k in example.keys() for k in ("chosen", "rejected")):
            raise ValueError(
                f"Could not format example as dialogue for `{task}` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )

        if task == "rm":
            chosen_messages = _ensure_system_first(example["chosen"])
            rejected_messages = _ensure_system_first(example["rejected"])
            example["text_chosen"] = tokenizer.apply_chat_template(
                chosen_messages, tokenize=False
            )
            example["text_rejected"] = tokenizer.apply_chat_template(
                rejected_messages, tokenize=False
            )
        else:
            # Compared to reward modeling, the prompt is rendered separately and
            # the chosen/rejected texts hold only what follows it.
            chosen = example["chosen"]
            rejected = example["rejected"]
            first_user = [msg for msg in chosen if msg["role"] == "user"][0]
            if chosen[0]["role"] == "system":
                system_message = chosen[0]
            else:
                system_message = {"role": "system", "content": ""}
            prompt_messages = [system_message, first_user]

            # A leading system message used to be left unhandled here: slicing
            # with a fixed [1:] kept the user prompt inside the completion text
            # and prevented the assistant prefix from being stripped.
            example["text_chosen"] = tokenizer.apply_chat_template(
                _completion_turns(chosen), tokenize=False
            )
            example["text_rejected"] = tokenizer.apply_chat_template(
                _completion_turns(rejected), tokenize=False
            )
            example["text_prompt"] = tokenizer.apply_chat_template(
                prompt_messages, tokenize=False, add_generation_prompt=True
            )
            example["text_chosen"] = _strip_prefix(
                example["text_chosen"], assistant_prefix
            )
            example["text_rejected"] = _strip_prefix(
                example["text_rejected"], assistant_prefix
            )
    else:
        raise ValueError(
            f"Task {task} not supported, please ensure that the provided task is one of {['sft', 'generation', 'rm', 'dpo']}"
        )
    return example


def get_datasets(
    data_config: dict,
    splits: Optional[List[str]] = None,
    shuffle: bool = True,
) -> DatasetDict:
    """
    Loads one or more datasets with varying training set proportions.

    Args:
        data_config (`dict`):
            Mapping of dataset name (or path) to the fraction of its training
            split to keep, e.g.
            `{"dataset1": 0.5, "dataset2": 0.3, "dataset3": 0.2}`.
        splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
            Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training and testing/validation data.

    Returns:
        [`DatasetDict`]: The dataset dictionary containing the loaded datasets.

    Raises:
        ValueError: If `data_config` is not a dict.
    """
    # isinstance also accepts dict subclasses (e.g. OrderedDict).
    if not isinstance(data_config, dict):
        raise ValueError(f"Data config {data_config} not recognized.")

    # A fresh list per call instead of a shared mutable default argument.
    if splits is None:
        splits = ["train", "test"]

    return mix_datasets(data_config, splits=splits, shuffle=shuffle)


def mix_datasets(
    dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True
) -> DatasetDict:
    """
    Loads and mixes datasets according to proportions specified in `dataset_mixer`.

    Args:
        dataset_mixer (`dict`):
            Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
        splits (Optional[List[str]], *optional*, defaults to `None`):
            Dataset splits to load and mix; `None` means `['train', 'test']`.
            Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training and testing/validation data.

    Returns:
        [`DatasetDict`]: Contains a `train` entry when any train split was
        loaded and a `test` entry when any test split was loaded.

    Raises:
        ValueError: If a fraction is negative, if a split name contains
            neither `train` nor `test`, or if nothing was loaded.
    """
    # The documented default used to be iterated as-is, which failed on None.
    if splits is None:
        splits = ["train", "test"]

    # Validate the proportions before paying for any dataset download.
    if any(frac < 0 for frac in dataset_mixer.values()):
        raise ValueError("Dataset fractions cannot be negative.")

    raw_datasets = DatasetDict()
    train_subsets = []
    raw_val_datasets = []
    for ds, frac in dataset_mixer.items():
        for split in splits:
            try:
                # Try first if dataset on a Hub repo
                dataset = load_dataset(ds, split=split)
            except DatasetGenerationError:
                # If not, check local dataset
                dataset = load_from_disk(os.path.join(ds, split))

            if "train" in split:
                # Subsample with the fraction of the dataset this split came
                # from, so several train splits per dataset stay aligned with
                # their own proportion.
                train_subsets.append(
                    dataset.select(range(int(frac * len(dataset))))
                )
            elif "test" in split:
                # No subsampling for test datasets to enable fair comparison across models
                raw_val_datasets.append(dataset)
            else:
                raise ValueError(
                    f"Split type {split} not recognized as one of test or train."
                )

    if train_subsets:
        train_dataset = concatenate_datasets(train_subsets)
        raw_datasets["train"] = (
            train_dataset.shuffle(seed=42) if shuffle else train_dataset
        )

    if raw_val_datasets:
        test_dataset = concatenate_datasets(raw_val_datasets)
        raw_datasets["test"] = (
            test_dataset.shuffle(seed=42) if shuffle else test_dataset
        )

    if len(raw_datasets) == 0:
        # Report the requested splits; the loop variable is unbound when the
        # mixer or the split list is empty.
        raise ValueError(
            f"Dataset {dataset_mixer} not recognized with splits {splits}. Check the dataset has been correctly formatted."
        )

    return raw_datasets

<a name="Data"></a>
### Data Prep
We follow the data-formatting approach of Huggingface's [Alignment Handbook](https://github.com/huggingface/alignment-handbook) for [Zephyr](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta). Instead of the [Ultra Feedback dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) used there, this notebook loads the PKU-SafeRLHF training split, filtered to samples with a clear safety contrast, and formats it for CITA. The loader is called with `max_samples=None`; pass an integer cap instead for a quicker run.

In [None]:
import sys

# The shared data-prep utilities live outside this notebook's directory, so
# their folder has to be on sys.path before the import below.
sys.path.insert(
    0,
    '/lambda/nfs/DiskUsEast1/finetuning_evaluation/comparative_study/0c_DATA_PREP_utils/src',
)

from data_prep import load_pku_filtered, format_dataset

# PKU-SafeRLHF training split, restricted to samples with a clear safety contrast.
dataset = load_pku_filtered(max_samples=None, split="train")
print("📊 Loaded {} training samples with clear safety contrast".format(len(dataset)))

In [None]:
# Convert the PKU-SafeRLHF samples into the layout used for CITA training.
dataset = format_dataset(dataset, method="cita")

print("✅ Formatted {} samples for CITA".format(len(dataset)))
print("✅ Sample structure: {}".format(list(dataset[0].keys())))
print("✅ Chosen messages: {} turns".format(len(dataset[0]['chosen'])))
# The chosen trajectory is expected to open with the system message.
print("✅ First message role: {}".format(dataset[0]['chosen'][0]['role']))

# Preview the first sample: one line per turn, content cut to 80 characters.
print("\n📋 Sample chosen trajectory:")
for turn_no, turn in enumerate(dataset[0]['chosen'], start=1):
    print("   {}. {}: {}...".format(turn_no, turn['role'], turn['content'][:80]))

We now add LoRA adapters, so only about 1 to 10% of all parameters need to be updated!

In [None]:
# Wrap the base model with LoRA adapters on the listed projection modules.
model = FastLanguageModel.get_peft_model(
    model,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    r=16,             # kept equal to the SFT run for a fair comparison
    lora_alpha=16,    # kept equal to the SFT run
    lora_dropout=0,
    bias="none",
    use_rslora=False,
    loftq_config=None,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

<a name="Train"></a>
### Train the CITA model
Now let's train our model. We run 1000 training steps (`max_steps = 1000`) with the custom `CITATrainer`, which is configured through TRL's `DPOConfig`.

In [None]:
from transformers import TrainingArguments
from trl import DPOConfig  # ← Keep DPOConfig (CITA uses same config structure)
import os
import sys

# Apply Unsloth's DPO-trainer patch before the custom trainer is imported
# further down — the patch has to come first.
from unsloth import PatchDPOTrainer
PatchDPOTrainer()

# Switch the tokenizer to the Alpaca instruction template (the "llama-3" chat
# template is deliberately NOT used in this run).
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template= "alpaca",  # Alpaca instruction-style template
    #   Author's rationale for choosing "alpaca" over "llama-3":
    #     - the "llama-3" template is a chat format built around special role tokens
    #     - the "alpaca" template uses a simpler instruction format:
    #     Below are some instructions...
    #     ### Instruction:
    #     {prompt}
    #     ### Response:
    #     {response}
    #     - CITA's instruction conditioning P(Y+|I,X) may work better with explicit
    #       "Instruction:" markers (a hypothesis, not something verified in this notebook)
)

# Make the folder holding the custom CITA trainer importable, then import it.
# NOTE(review): this path is relative, so it resolves against the kernel's
# current working directory — confirm the notebook is started from the repo root.
sys.path.insert(0, './comparative_study/03a_CITA_Baseline')
from cita_trainer import CITATrainer  # project-local CITA trainer

In [None]:
# TensorBoard setup (match SFT notebook)
tensorboard_base_dir = "/home/ubuntu/DiskUsEast1/finetuning_evaluation/tensorboard_logs"
os.makedirs(tensorboard_base_dir, exist_ok=True)

# Reuse the notebook-wide run identifier instead of repeating the literal, so
# the log directory follows RUN_NAME when this notebook is copied for another run.
run_name = RUN_NAME
# One directory per launch: "<run name>_<YYYYmmdd_HHMMSS>".
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
tensorboard_run_dir = os.path.join(tensorboard_base_dir, f"{run_name}_{timestamp}")

print(f"📊 TensorBoard logs: {tensorboard_run_dir}")

In [None]:
# Build the CITA trainer. Hyper-parameters are carried by TRL's DPOConfig;
# `lambda_kl` is an extra argument of the custom CITATrainer.
cita_trainer = CITATrainer(
    model=model,
    ref_model=None,
    train_dataset=dataset,
    processing_class=tokenizer,
    lambda_kl=0.01,
    max_length=2048,
    max_prompt_length=1024,
    args=DPOConfig(
        # Run bookkeeping
        output_dir=f"outputs/{RUN_NAME}_BF16",
        seed=3407,
        # Batching: 1 sample per device x 8 accumulation steps
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        # Schedule and optimizer
        max_steps=1000,
        warmup_steps=5,
        learning_rate=2e-5,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        weight_decay=0.01,
        beta=0.1,
        # Precision: BF16 only
        bf16=True,
        fp16=False,
        # Logging: every step, to the TensorBoard run directory
        report_to="tensorboard",
        logging_dir=tensorboard_run_dir,
        logging_steps=1,
        logging_first_step=True,
    ),
)

print("✅ Training config: batch_size=1, grad_accum=8, effective_batch=8")
print("✅ Precision: BF16 (fp16=False, bf16=True)")
print("✅ Training steps: 1000 (3.3x longer than original)")
print("✅ Learning rate: 2e-5 (4x higher, standard DPO rate)")

In [None]:
# Show current memory stats (match SFT notebook)
import torch

# Record GPU 0's properties and the peak reserved memory before training,
# both converted from bytes to GB (1024**3 bytes) and rounded to 3 decimals.
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 ** 3, 3)
max_memory = round(gpu_stats.total_memory / 1024 ** 3, 3)
print("GPU = {}. Max memory = {} GB.".format(gpu_stats.name, max_memory))
print("{} GB of memory reserved.".format(start_gpu_memory))

In [None]:
# Run training with the CITATrainer instance built in the trainer cell.
cita_trainer.train()

In [None]:
# Peak reserved GPU memory after training, in GB, compared with the value
# recorded before training and with the device's total memory.
used_memory = round(torch.cuda.max_memory_reserved() / 1024 ** 3, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)

print("\n" + "=" * 80)
print("Training complete!")
print("Peak reserved memory = {} GB.".format(used_memory))
print("Peak reserved memory for training = {} GB.".format(used_memory_for_lora))
print("Peak reserved memory % of max memory = {} %.".format(used_percentage))
print("=" * 80)

In [None]:
# ===================================================================
# Save Model & Push to HuggingFace
# ===================================================================

def _last_logged_loss(log_history):
    """Return the most recent training loss in `log_history`, or None.

    Per-step log entries carry a `loss` key, while the summary entry that the
    HF Trainer appends when training finishes carries `train_loss` instead.
    Reading only `log_history[-1]['loss']` therefore usually finds nothing.
    The latest per-step `loss` is preferred; `train_loss` is the fallback.
    """
    for key in ("loss", "train_loss"):
        for entry in reversed(log_history):
            if key in entry:
                return entry[key]
    return None


final_loss = _last_logged_loss(cita_trainer.state.log_history)

# Attach training metadata to the model config so it is saved with the model.
# NOTE: dataset size, max_steps and learning_rate are hard-coded here and must
# be kept in sync with the data-prep and trainer cells.
model.config.update({
    "training_date": datetime.now().strftime("%Y%m%d_%H%M%S"),
    "training_steps": cita_trainer.state.global_step,
    "training_loss": final_loss if final_loss is not None else 'N/A',
    "dataset": "PKU-SafeRLHF",
    "filtered_samples": 10813,
    "max_steps": 1000,
    "method": "CITA",
    "precision": "BF16",
    "run_name": RUN_NAME,
    "chat_template": "alpaca",
    "learning_rate": 2e-5,
})

# Save locally first so a failed upload never loses the trained adapters.
local_path = f"lora_model_{RUN_NAME}_BF16"
model.save_pretrained(local_path)
tokenizer.save_pretrained(local_path)
print(f"✅ Saved locally: {local_path}/")

# Push to HuggingFace (best effort: a failure is reported, not raised, because
# the local copy above already exists).
print(f"\n📤 Pushing to HuggingFace: {HF_REPO}")
try:
    # Build the commit message up front; 0.0 stands in when no loss was logged.
    loss_value = final_loss if final_loss is not None else 0.0
    commit_msg = f"CITA BF16 Training: {cita_trainer.state.global_step} steps, Loss: {loss_value:.4f}"

    model.push_to_hub(
        HF_REPO,
        token=hf_token,
        commit_message=commit_msg,
        private=True,
    )
    tokenizer.push_to_hub(HF_REPO, token=hf_token, private=True)
    print(f"✅ Model available at: https://huggingface.co/{HF_REPO}")
except Exception as e:
    print(f"⚠️ HuggingFace push failed: {e}")
    print(f"   Model saved locally at: {local_path}/")