In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import pipeline
from trl import SFTTrainer, SFTConfig
from trl import DPOTrainer, DPOConfig
from peft import LoraConfig


In this notebook we will continue working with the GSM8K dataset. So, let's dig in!

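**`TODO:`** Load the `"openai/gsm8k"` dataset (`"main"` configuration) directly, without passing a `split` argument. It already comes with `train` and `test` splits, so you don't need to split it yourself.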

In [None]:
data = load_dataset("openai/gsm8k", "main")

## The `messages` Format for Chat Data

When training modern large language models (LLMs) for **instruction following** or **chat**, data is usually represented in a structured format called `messages`.

### What is `messages`?

- `messages` is a **list of objects**.
- Each object represents **one turn** in a conversation.
- Every object has at least two fields:
  - `"role"` → who is speaking (`"user"`, `"assistant"`, or `"system"`)
  - `"content"` → the text spoken by that role

### Example

```json
{
  "messages": [
    {"role": "system", "content": "You are a helpful math tutor."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "The answer is 4."}
  ]
}
```


**`TODO:`**  
Write a function that reformats each sample by creating a new `messages` field.  

- The `messages` field of each sample should be a list of dictionaries, where each dictionary has a `"role"` and a `"content"` key.  
- The `question` becomes a message with `"role": "user"`.  
- The `answer` becomes a message with `"role": "assistant"`.  
- A `"system"` message is not required and should be omitted.  

After defining the function, apply it to the dataset using `.map()`.  


In [None]:
def format_gsm8k(sample):
    # Turn one GSM8K question/answer pair into a two-turn chat
    return {
        "messages": [
            {"role": "user", "content": sample["question"]},
            {"role": "assistant", "content": sample["answer"]},
        ]
    }

data_fmt = data.map(format_gsm8k)

assert "messages" in data_fmt["train"].features, "The 'messages' feature is missing."

### Qwen2-0.5B
- **Qwen2-0.5B** is part of the Qwen2 model family, released by Alibaba as a compact dense language model.  
- It is a **decoder-only** model designed for text generation and general language tasks.  
- An **instruction-tuned variant** (“Instruct”) is also available, optimized to follow prompts more reliably.  
- Qwen2-0.5B is useful for tasks like chat, text completion, and prompt-based generation, especially when a smaller yet capable model is preferred.  

You can check out the model card [here](https://huggingface.co/Qwen/Qwen2-0.5B).  

**`TODO:`**  Load the tokenizer for `Qwen/Qwen2-0.5B`, and if it does not have a `pad_token`, set it equal to the `eos_token`.  

In [None]:
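model_name = "Qwen/Qwen2-0.5B"  # or whichever base model
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Only fall back to the EOS token if the tokenizer has no padding token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token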

**`TODO:`** Tokenize the `messages` feature by defining a tokenization function and applying it with `.map()`. The tokenizer expects a string, so you first have to flatten the list of message dictionaries into a single text while keeping each speaker's role in it (e.g., with `User:` / `Assistant:` prefixes).

In [None]:
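def tokenize_fn(example):
    # Flatten the messages list into plain text, keeping the speaker roles
    full = ""
    for msg in example["messages"]:
        prefix = "User: " if msg["role"] == "user" else "Assistant: "
        full += prefix + msg["content"] + "\n"
    # Append EOS so the model learns where an answer ends
    full += tokenizer.eos_token
    # Truncate long samples; padding is left to the trainer's data collator,
    # which pads each batch dynamically instead of padding every sample to 512 tokens
    return tokenizer(
        full,
        truncation=True,
        max_length=512,
    )

tokenized = data_fmt.map(tokenize_fn, batched=False)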
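Because the model is trained on this flattened text, it helps to see exactly what one sample looks like. The same layout should also be reused for prompts at inference time. The sketch below mirrors the `User:` / `Assistant:` prefixes from `tokenize_fn` in plain Python, so no tokenizer is needed. `flatten_messages` and `build_prompt` are hypothetical helper names, not part of any library.

```python
def flatten_messages(messages):
    """Join chat turns into the 'User: ... / Assistant: ...' text used for training."""
    text = ""
    for msg in messages:
        # Same rule as tokenize_fn: any role other than "user" is labeled Assistant
        prefix = "User: " if msg["role"] == "user" else "Assistant: "
        text += prefix + msg["content"] + "\n"
    return text


def build_prompt(question):
    """Inference prompt in the same layout, stopping where the answer should begin."""
    # No trailing space: during training the space before the answer is tokenized with it
    return f"User: {question}\nAssistant:"


example = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "2 + 2 = 4\n#### 4"},
]
print(flatten_messages(example))
print(repr(build_prompt("What is 2 + 2?")))
```

Every training sample begins with exactly this prompt. As a result, generation at evaluation time continues from the point where the model learned to write the answer.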

## SFTTrainer
- The `trl` library (short for **Transformers Reinforcement Learning**) extends Hugging Face’s ecosystem with tools for fine-tuning and alignment.  
- **`SFTTrainer`** is a class in this library designed for **Supervised Fine-Tuning (SFT)** of language models.  
- It simplifies training by wrapping the Hugging Face `Trainer` with defaults tailored for instruction tuning and chat-like datasets.  
- You define your training setup with an `SFTConfig`, which specifies parameters like batch size, learning rate, number of epochs, evaluation steps, and more.  
- The `SFTTrainer` takes the model, your tokenized dataset, and the config, and runs the fine-tuning loop automatically.  
- By default, it can also handle chat-style formatting (with roles like *user* and *assistant*), which is useful when fine-tuning dialogue models.  

For further documentation regarding `SFTTrainer`, you are encouraged to explore the following [documentation](https://huggingface.co/docs/trl/en/sft_trainer).  


### SFTConfig
- **`SFTConfig`** is a configuration class used for **Supervised Fine-Tuning (SFT)** of language models.  
- It stores all the key training settings in one place (like batch size, learning rate, epochs, logging, saving, etc.).  
- This config is passed to the trainer to control *how* the fine-tuning process runs.  

**`TODO:`** Review the hyperparameters below and ensure you understand each; adjust values like `num_train_epochs`, `per_device_train_batch_size`, or `gradient_accumulation_steps` if training is too slow or your hardware can’t handle the defaults.  


In [None]:
config = SFTConfig(
    output_dir="gsm8k-instruct",      # folder where the trained model and checkpoints will be saved
    num_train_epochs=3,               # number of times the model will see the entire training dataset
    per_device_train_batch_size=32,   # how many examples each GPU processes at once
    gradient_accumulation_steps=2,    # accumulate gradients for 2 steps before updating weights (acts like a larger batch size)
    learning_rate=2e-5,               # how fast the model's weights are updated
    logging_steps=10,                 # log training metrics every 10 steps
    save_strategy="steps",            # choose when to save checkpoints (here: every few steps)
    save_steps=200,                   # save a checkpoint every 200 steps
    eval_strategy="steps",            # choose when to run evaluation (here: every few steps)
    eval_steps=200,                   # run evaluation every 200 steps
    bf16=True,                        # use bfloat16 precision for faster training and lower memory use (if supported by GPU)
    packing=False,                    # do not pack multiple sequences into one sample (requires special formatting if True)
    dataset_text_field=None,          # not needed here, since we already tokenized the data
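    # Assistant-only loss needs conversational (`messages`) data whose chat template marks
    # assistant turns; our dataset is pre-tokenized plain text, so loss covers the whole sequence.
    assistant_only_loss=False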
)
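The three batch-related settings interact. The effective batch size is `per_device_train_batch_size × gradient_accumulation_steps × number of GPUs`, and that determines how many optimizer steps training takes. The rough sketch below assumes a single GPU and the 3,000-example subset used in the training cell below. `optimizer_steps` is a hypothetical helper, and the Trainer's exact rounding may differ slightly.

```python
import math

def optimizer_steps(num_examples, per_device_batch, grad_accum, epochs, num_devices=1):
    """Approximate effective batch size and number of weight updates."""
    effective_batch = per_device_batch * grad_accum * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return effective_batch, steps_per_epoch * epochs

effective_batch, total_steps = optimizer_steps(
    num_examples=3000, per_device_batch=32, grad_accum=2, epochs=3
)
print(f"effective batch size: {effective_batch}, total optimizer steps: {total_steps}")
# save_steps / eval_steps are counted in optimizer steps
print("eval at step 200 reached?", total_steps >= 200)
```

With roughly 141 optimizer steps in total, `save_steps=200` and `eval_steps=200` would never trigger on this subset. You may want to lower them (e.g., to 50) to get intermediate checkpoints and evaluations.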

**`TODO:`** Using the documentation provided above, define an `SFTTrainer` with your model, config, and processed dataset, then train the chosen model (use `.select()` to reduce the dataset size if training is too slow).  


In [None]:
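# Full supervised fine-tuning: all model weights are updated
trainer = SFTTrainer(
    model=model_name,                  # TRL loads the model from this name
    train_dataset=tokenized["train"].shuffle(seed=42).select(range(3000)),  # 3,000-example subset for speed
    eval_dataset=tokenized["test"],
    args=config,
    processing_class=tokenizer,        # tokenizer with the pad token set above
)

trainer.train()
trainer.save_model("gsm8k-instruct")   # also saves the tokenizer for reloading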


### Evaluating Our Fine-Tuned Model
Let’s see how our model performs after Supervised Fine-Tuning!  
We’ll reload the trained model, run it on a few test questions, and compare its answers against the ground truth.  


In [None]:
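# Reload the fine-tuned model (the tokenizer was saved alongside it)
pipe = pipeline("text-generation", model="gsm8k-instruct", device_map="auto")

# Grab a few samples from the test set
samples = tokenized["test"].shuffle(seed=0).select(range(5))

for s in samples:
    q = s["question"]
    print("Q:", q)
    # Use the same "User: ... Assistant:" layout the model saw during training
    prompt = f"User: {q}\nAssistant:"
    # return_full_text=False prints only the newly generated answer, not the prompt
    out = pipe(prompt, max_new_tokens=200, do_sample=False, return_full_text=False)
    print("Model answer:", out[0]["generated_text"])
    print("Ground truth:", s["answer"])
    print("----")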
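Eyeballing five answers is a start. GSM8K reference solutions end with a line of the form `#### <number>`, which makes a simple exact-match score possible. The minimal sketch below assumes the fine-tuned model imitates that `####` ending. If it does not, the sketch falls back to the last number in the text. `extract_answer` and `exact_match` are hypothetical helpers, not part of `datasets` or `transformers`.

```python
import re

NUMBER = r"-?[\d,]*\.?\d+"  # integers, decimals, and numbers with thousands separators

def extract_answer(text):
    """Return the final numeric answer of a GSM8K-style solution as a string, or None."""
    # Prefer the number after the '####' marker used by GSM8K reference answers
    match = re.search(r"####\s*(" + NUMBER + ")", text)
    if match:
        number = match.group(1)
    else:
        # Fall back to the last number anywhere in the text
        numbers = re.findall(NUMBER, text)
        if not numbers:
            return None
        number = numbers[-1]
    # Normalize so that "1,000", "1000", and "1000.0" all compare equal
    value = float(number.replace(",", ""))
    return str(int(value)) if value.is_integer() else str(value)


def exact_match(predictions, references):
    """Fraction of predictions whose extracted answer equals the reference answer."""
    hits = 0
    for pred, ref in zip(predictions, references):
        pred_answer = extract_answer(pred)
        if pred_answer is not None and pred_answer == extract_answer(ref):
            hits += 1
    return hits / len(references)
```

Usage with the loop above: collect each `out[0]["generated_text"]` into a list and call `exact_match(predictions, [s["answer"] for s in samples])`.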


## PEFT
- PEFT (Parameter-Efficient Fine-Tuning) is a collection of methods for adapting large pre-trained models without updating all of their parameters.  
- Instead of retraining the entire model, PEFT techniques introduce or adjust a small set of additional parameters (e.g., adapters, prompts, low-rank matrices) while keeping most of the original model frozen.  
- This approach greatly reduces the computational cost, memory usage, and storage requirements of fine-tuning, making it practical to personalize and deploy large language models on smaller hardware.  
- Widely used in research and production for customizing large models to specific tasks, domains, or user data in a cost-effective way.  

You can explore Hugging Face’s [`peft`](https://huggingface.co/docs/peft/index) library, which implements popular PEFT methods like LoRA, prefix tuning, and prompt tuning.  


### LoRA (LoraConfig)
- LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning (PEFT) method that adapts large pre-trained models by injecting low-rank trainable matrices into specific layers (commonly the attention layers).  
- Instead of updating the full weight matrices, LoRA represents weight updates as the product of two smaller matrices, drastically reducing the number of trainable parameters.  
- This makes fine-tuning more memory-efficient, faster to train, and easier to deploy, while allowing multiple LoRA modules to be composed for different tasks.  
- Widely adopted in NLP and generative AI for customizing large language models without the heavy costs of full fine-tuning.  

In the Hugging Face `peft` library, LoRA is configured using [`LoraConfig`](https://huggingface.co/docs/peft/package_reference/lora).
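Written out, LoRA replaces a frozen weight $W \in \mathbb{R}^{d \times k}$ with $W + \frac{\alpha}{r} B A$, where $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$. Only $r(d + k)$ parameters are trained instead of $dk$. The NumPy sketch below implements one LoRA-adapted linear layer from scratch for illustration; it is not how `peft` is implemented internally. The 896 × 896 weight size is just an example. As in common LoRA initializations, $A$ starts random and $B$ starts at zero, so the adapted layer initially behaves exactly like the frozen one.

```python
import numpy as np

def lora_param_counts(d_out, d_in, r):
    """Trainable parameters: full update of W vs. LoRA factors B (d_out x r) and A (r x d_in)."""
    return d_out * d_in, r * (d_out + d_in)


class LoRALinear:
    """y = W x + (alpha / r) * B (A x), with W frozen and only A and B trainable."""

    def __init__(self, W, r, alpha, seed=0):
        rng = np.random.default_rng(seed)
        d_out, d_in = W.shape
        self.W = W                                      # frozen pre-trained weight
        self.A = rng.normal(0.0, 0.01, size=(r, d_in))  # random down-projection
        self.B = np.zeros((d_out, r))                   # zero up-projection: no change at start
        self.scale = alpha / r                          # alpha/r scaling, as in LoraConfig

    def __call__(self, x):
        return self.W @ x + self.scale * (self.B @ (self.A @ x))


full, lora = lora_param_counts(896, 896, r=16)
print(full, lora, f"{lora / full:.2%}")

layer = LoRALinear(np.random.default_rng(1).normal(size=(896, 896)), r=16, alpha=32)
x = np.ones(896)
print(np.allclose(layer(x), layer.W @ x))  # True: B = 0, so the update is zero initially
```

For this example size, the LoRA factors hold 28 times fewer trainable parameters than the full matrix (about 3.6%).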

In [None]:
lora_cfg = LoraConfig(
    r=16,  # (int) Rank of the low-rank update matrices. 
           # Controls the adaptation capacity. Typical values: 4–64. 
           # Higher = more expressive but more parameters.

    lora_alpha=32,  # (int) Scaling factor for the LoRA updates. 
                    # The effective weight update is scaled by alpha/r. 
                    # Larger values give stronger adaptation.

    lora_dropout=0.05,  # (float) Dropout probability applied to LoRA layers during training. 
                        # Helps regularize and prevent overfitting. 
                        # Set to 0.0 to disable.

    task_type="CAUSAL_LM"  # (str) Type of task for which LoRA is applied. 
                           # Must match the model architecture and training objective. 
                           # Common values:
                           #   "CAUSAL_LM"   → decoder-only LMs (e.g., GPT-style)
                           #   "SEQ_2_SEQ_LM" → encoder-decoder LMs (e.g., T5, BART)
                           #   "TOKEN_CLS"    → token classification
                           #   "SEQ_CLS"      → sequence classification
                           #   "QUESTION_ANS" → QA tasks
)

**`TODO:`** Retrain the model using **LoRA**. Compared to standard SFT, only two adjustments are needed:  
- Update your `SFTConfig` if you want to save the model in a new folder (and optionally raise the learning rate).  
- Pass a `LoraConfig` object to the trainer via the `peft_config` argument (check the documentation if unsure).  


In [None]:
# Update training config for LoRA fine-tuning -> No big changes
config = SFTConfig(
    output_dir="gsm8k-instruct-lora", # Change 1: Just to save in a different folder
    num_train_epochs=3,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,               # Change 2: Usually this is higher than full fine-tuning
    logging_steps=10,
    save_strategy="steps",
    save_steps=200,
    eval_strategy="steps",
    eval_steps=200,
    bf16=True,
    packing=False,
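    assistant_only_loss=False         # pre-tokenized plain text has no assistant mask, so loss covers the whole sequence
)

trainer = SFTTrainer(
    model=model_name,                   # base model (string or loaded)
    train_dataset=tokenized["train"].shuffle(seed=42).select(range(3000)),
    eval_dataset=tokenized["test"],
    args=config,
    processing_class=tokenizer,         # tokenizer with the pad token set above
    peft_config=lora_cfg                # ← enables LoRA
)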

trainer.train()
trainer.save_model("gsm8k-instruct-lora")

## DPO
- DPO (Direct Preference Optimization) is a training method for aligning language models with human preferences without the need for reinforcement learning.  
- Instead of using a reward model + reinforcement learning (like in RLHF), DPO directly optimizes the model on preference data (pairs of “chosen” vs “rejected” responses).  
- This makes alignment training simpler, more stable, and often more efficient while still steering models towards producing responses preferred by humans.  
- Widely used in instruction-tuning and alignment pipelines as a lightweight alternative to RLHF.  

You can explore Hugging Face’s [`DPO`](https://huggingface.co/docs/trl/main/en/dpo_trainer) documentation in the `trl` library.
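For reference, the standard DPO objective (TRL's `"sigmoid"` loss type) measures how much the policy $\pi_\theta$ raises the log-probability of the chosen response $y_w$ over the rejected one $y_l$. This gain is measured relative to the frozen reference model $\pi_{\text{ref}}$:

$$\mathcal{L}_{\text{DPO}} = -\log \sigma\Big(\beta\big[(\log\pi_\theta(y_w\mid x) - \log\pi_{\text{ref}}(y_w\mid x)) - (\log\pi_\theta(y_l\mid x) - \log\pi_{\text{ref}}(y_l\mid x))\big]\Big)$$

The plain-Python sketch below computes that formula for a single pair. It assumes the four sequence log-probabilities have already been computed; TRL does this internally. `dpo_sigmoid_loss` is a hypothetical name.

```python
import math

def dpo_sigmoid_loss(policy_chosen_logp, policy_rejected_logp,
                     ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """-log sigmoid(beta * (chosen log-ratio - rejected log-ratio)) for one preference pair."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(m)) == log(1 + exp(-m)); split by sign so exp() never overflows
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))

# Policy identical to the reference: margin 0, loss = log 2
print(dpo_sigmoid_loss(-12.0, -15.0, -12.0, -15.0))
# Policy shifted probability toward the chosen answer: loss drops below log 2
print(dpo_sigmoid_loss(-10.0, -17.0, -12.0, -15.0))
```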

In [None]:
dpo_cfg = DPOConfig(
    output_dir="gsm8k-instruct-dpo",  # folder where checkpoints and the final model are saved

    beta=0.1,  # (float) Controls how far the policy may drift from the reference model.
               # Higher beta → stays closer to the reference;
               # lower beta → larger deviations are allowed.
               # TRL's default is 0.1.

    loss_type="sigmoid",  # (str) Loss function for comparing chosen vs rejected outputs.
                          # Options include:
                          #   "sigmoid" (default) → original DPO logistic loss
                          #   "hinge" → margin-based loss
                          # TRL supports further variants (e.g., "ipo"); see the docs.

    label_smoothing=0.0,  # (float) Label smoothing for noisy preference labels
                          # (conservative / robust DPO variants).
                          # Common values: 0.0 (none) – 0.1.

    max_length=512,             # (int) Maximum total sequence length (prompt + completion).
                                # Longer sequences will be truncated.

    max_prompt_length=128,      # (int) Maximum length of the prompt portion only.

    max_completion_length=384,  # (int) Maximum length of each completion (chosen / rejected).

    truncation_mode="keep_end", # (str) Which tokens to keep when a sequence is too long:
                                #   "keep_end" → keep the last tokens (default)
                                #   "keep_start" → keep the first tokens

    generate_during_eval=False, # (bool) If True, also generates and logs sample completions
                                # during evaluation (requires a supported logger such as W&B).
                                # More informative but slower.

    # DPOConfig subclasses TrainingArguments, so standard options
    # (learning_rate, batch sizes, logging, etc.) can be set here as well.
)


### DPO Datasets
- DPO (Direct Preference Optimization) requires datasets where each example contains a **prompt** and two possible responses:  
  - **chosen** → the response preferred by humans (or a proxy, e.g. a reward model).  
  - **rejected** → the less-preferred response.  
- During training, the model is optimized to assign higher probability to the **chosen** response compared to the **rejected** one, given the same prompt.  
- Training on this format is much simpler than reinforcement learning with a reward model, since the training objective is defined directly over these pairs.  
- Widely used in alignment pipelines for instruction-tuned LLMs, where the dataset may come from human annotations, pairwise preference collection, or synthetic generation (e.g., using a reward model or rule-based scoring).  
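Concretely, one record in this format is just three strings. The minimal sketch below uses a made-up GSM8K-style example; the text is invented for illustration. `is_preference_record` is a hypothetical sanity check you could run before handing data to `DPOTrainer`, which accepts `prompt`, `chosen`, and `rejected` columns in this explicit-prompt layout.

```python
REQUIRED_KEYS = ("prompt", "chosen", "rejected")

def is_preference_record(record):
    """True if prompt/chosen/rejected are non-empty strings and the two responses differ."""
    if not all(isinstance(record.get(k), str) and record[k].strip() for k in REQUIRED_KEYS):
        return False
    # A pair only carries a training signal if the two responses are different
    return record["chosen"] != record["rejected"]

example = {
    "prompt": "Tom has 3 boxes with 4 pencils each. How many pencils does he have?",
    "chosen": "Each box has 4 pencils, so 3 * 4 = 12.\n#### 12",
    "rejected": "He has 3 + 4 = 7 pencils.\n#### 7",
}
print(is_preference_record(example))
```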

**`TODO:`** Load and explore the `xinlai/Math-Step-DPO-10K` dataset. What are its features? Print a few samples to get a good idea of what you're working with.

In [None]:
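dataset = load_dataset("xinlai/Math-Step-DPO-10K")
print(dataset, "\n")  # splits, feature names, and row counts

# Print every field of the first few training examples
for i in range(3):
    for k, v in dataset["train"][i].items():
        print(f"{k}: {v}\n")
    print("=" * 40)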

**`TODO:`** Load your trained model and further train it with DPO on the dataset loaded above. This time you need a `DPOTrainer` rather than an `SFTTrainer`, but everything else stays largely the same. For examples, have a look at the documentation linked above.

In [None]:
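model_name = "gsm8k-instruct-lora"  # or your full SFT checkpoint ("gsm8k-instruct")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The dataset may only ship a "train" split, so hold out a small evaluation set
splits = dataset["train"].train_test_split(test_size=0.05, seed=42)

trainer = DPOTrainer(
    model=model_name,
    ref_model=model_name,           # frozen reference (usually the SFT model)
    args=dpo_cfg,                   # DPO settings (beta, lengths, ...) live in DPOConfig
    train_dataset=splits["train"],  # uses the "prompt", "chosen", "rejected" columns
    eval_dataset=splits["test"],
    processing_class=tokenizer,     # replaces the older `tokenizer=` argument
)

trainer.train()
trainer.save_model("gsm8k-instruct-dpo")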