<a href="https://colab.research.google.com/github/divyapalaniswamy/notebooks/blob/main/Huggingface_GRPO_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Finetune LLMs with GRPO
This notebook shows how to fine-tune an LLM with GRPO using the `trl` library.

It's by Ben Burtenshaw and Maxime Labonne.

This is a minimal example. For a complete example, refer to the GRPO chapter in the course.

This notebook follows `grpo_finetune.ipynb` from the Hugging Face course, which explains how to perform Group Relative Policy Optimization (GRPO) fine-tuning.

The key concepts and code steps are:

### What is GRPO?
GRPO is a reinforcement learning method for fine-tuning a large language model (LLM). Instead of training on a single, preferred response, GRPO generates a "group" of multiple completions for a given prompt and scores each one with a reward function. Each completion is then judged relative to the rest of its group: completions that score above the group average are reinforced, and those below it are discouraged, which steers the model toward more desirable outputs.
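The group comparison can be sketched in a few lines of plain Python. This is a minimal illustration rather than the `trl` implementation: the helper name `group_relative_advantages`, the `eps` value, and the use of the population standard deviation are assumptions made for the example.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-4):
    """Normalize each reward against the mean and spread of its own group.

    Hypothetical helper for illustration only; `eps` guards against a
    zero standard deviation when every completion gets the same reward.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# Rewards for a group of 4 completions sampled from the same prompt
rewards = [-30, 0, -10, -20]
advantages = group_relative_advantages(rewards)

# The completion with the highest reward gets the largest advantage,
# and completions below the group mean get negative advantages.
for reward, advantage in zip(rewards, advantages):
    print(f"reward={reward:>4}  advantage={advantage:+.3f}")
```

Because the advantages are centered on the group mean, a group in which every completion earns the same reward contributes no learning signal.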

### Key Components of the Code

1.  **Environment and Library Setup:** The first cell installs pinned versions of the required libraries: `datasets`, `transformers`, `trl` (Transformer Reinforcement Learning), `peft` (Parameter-Efficient Fine-Tuning), `accelerate`, `bitsandbytes`, and `wandb`. A second command installs `flash-attn`, which provides the FlashAttention-2 attention implementation used when loading the model.

2.  **Dataset Preparation:** The code loads the `mlabonne/smoltldr` dataset, a "tldr" (too long; didn't read) dataset where the goal is to train the model to generate a summary from a longer text. Each example has a `prompt` and a `completion` column, and the dataset is split into 2,000 training, 200 validation, and 200 test examples.

3.  **Model Loading:** The pre-trained `HuggingFaceTB/SmolLM-135M-Instruct` model is loaded with FlashAttention-2, then wrapped with LoRA (Low-Rank Adaptation) adapters (`r=16`, `lora_alpha=32`) on all linear layers. Only about 3.5% of the parameters are trained, which saves memory and speeds up the process.

4.  **Reward Function Definition:** This is a crucial part of the GRPO process. The reward function assigns a score to each completion in a group. Reward functions can be rule-based (e.g., rewarding the model for generating a specific answer format) or based on other criteria, such as length. This notebook defines a single length-based reward, `reward_len`, which scores a completion higher the closer its length is to 50 characters.

5.  **Training Configuration:** The code sets up the training process using `GRPOConfig` from the `trl` library. Key parameters configured here include:
    * `learning_rate`: How quickly the model's weights are updated.
    * `num_generations`: This is a key parameter for GRPO; it defines the size of the "group" of completions to generate for each prompt. A typical value is between 4 and 16.
    * `max_steps` or `num_train_epochs`: The total number of training steps or epochs.
    * `per_device_train_batch_size`: The batch size for training.

6.  **Training the Model:** A `GRPOTrainer` is initialized with the model, the reward function(s), the training arguments, and the dataset. The `trainer.train()` method then executes the fine-tuning process. During training, the model generates groups of completions, the reward function evaluates them, and the model updates its parameters to learn how to produce higher-quality outputs.

7.  **Saving and Testing:** After training, the notebook merges the LoRA weights into the base model and pushes the merged model and tokenizer to the Hugging Face Hub. It then tests the fine-tuned model on a new prompt to see how it performs.
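Step 4 mentions rule-based rewards that check for a specific answer format. A minimal sketch of one is shown here, using the same `(completions, **kwargs)` calling convention as the notebook's reward function. The name `reward_format` and the `<answer>...</answer>` tag format are hypothetical, chosen only for illustration, and the sketch assumes completions arrive as plain strings.

```python
import re

# Hypothetical target format: the whole completion wrapped in <answer> tags
ANSWER_PATTERN = re.compile(r"^<answer>.+</answer>$", re.DOTALL)


def reward_format(completions, **kwargs):
    """Return 1.0 for completions matching the tag format, else 0.0."""
    return [
        1.0 if ANSWER_PATTERN.match(completion.strip()) else 0.0
        for completion in completions
    ]


samples = ["<answer>Cats are small domesticated carnivores.</answer>", "Cats."]
print(reward_format(samples))
```

Several such functions can be passed together in the `reward_funcs` list given to `GRPOTrainer`.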

In [None]:
!pip install -qqq datasets==3.2.0 transformers==4.47.1 trl==0.14.0 peft==0.14.0 accelerate==1.2.1 bitsandbytes==0.45.2 wandb==0.19.7 --progress-bar off
!pip install -qqq flash-attn --no-build-isolation --progress-bar off

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.9.0 which is incompatible.
  Preparing metadata (setup.py) ... done
  Building wheel for flash-attn (setup.py) ... done


In [None]:
import torch
import wandb
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

# Log to Weights & Biases
wandb.login()

# Load dataset
dataset = load_dataset("mlabonne/smoltldr")
print(dataset)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: divyaswamy87 (divyaswamy87-the-george-washington-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


DatasetDict({
    train: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
    test: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
})


In [None]:
# Load model
model_id = "HuggingFaceTB/SmolLM-135M-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load LoRA
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)
model = get_peft_model(model, lora_config)
# print_trainable_parameters() prints the summary itself and returns None
model.print_trainable_parameters()

trainable params: 4,884,480 || all params: 139,399,488 || trainable%: 3.5039
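The trainable-parameter count reported above can be reproduced by hand. A LoRA adapter on a linear layer with `d_in` inputs and `d_out` outputs adds `r * (d_in + d_out)` weights. The sketch below assumes the SmolLM-135M layer shapes (hidden size 576, MLP size 1536, 30 layers, 9 query heads and 3 key/value heads of dimension 64). These shapes are not stated in this notebook, so treat them as assumptions. It also assumes that `target_modules="all-linear"` leaves the output head without an adapter.

```python
# Assumed SmolLM-135M shapes (not stated in this notebook)
HIDDEN = 576
MLP = 1536
LAYERS = 30
KV_DIM = 3 * 64  # 3 key/value heads of dimension 64

# (d_in, d_out) for each linear layer inside one transformer block
linear_shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}


def lora_param_count(shapes, rank, num_layers):
    """Count LoRA weights: each adapter holds A (rank x d_in) and B (d_out x rank)."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * num_layers


trainable = lora_param_count(linear_shapes, rank=16, num_layers=LAYERS)
print(f"{trainable:,}")  # 4,884,480
```

Under these assumptions the result matches the `trainable params` figure printed by `peft`.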


In [None]:
# Install a specific version of flash-attn that might be more compatible
!pip uninstall -y flash-attn
!pip install -qqq flash-attn==2.3.6 --no-build-isolation --progress-bar off

Found existing installation: flash_attn 2.8.2
Uninstalling flash_attn-2.8.2:
  Successfully uninstalled flash_attn-2.8.2
  Preparing metadata (setup.py) ... done
  Building wheel for flash-attn (setup.py) ... done


In [None]:
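# Reward function
def reward_len(completions, **kwargs):
    # Score each completion by its distance from a 50-character target:
    # 0 is the best possible reward; shorter or longer text scores lower.
    return [-abs(50 - len(completion)) for completion in completions]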
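To see how this reward behaves, it can be called directly on a few sample strings. The function is repeated so the snippet runs on its own, and the sample completions are made up for illustration. Note that `len` counts characters here, not tokens, assuming completions are passed as plain strings.

```python
def reward_len(completions, **kwargs):
    return [-abs(50 - len(completion)) for completion in completions]


# Made-up completions of 50, 20, and 80 characters
samples = ["a" * 50, "a" * 20, "a" * 80]
print(reward_len(samples))  # [0, -30, -30]
```

A completion of exactly 50 characters gets the maximum reward of 0, and the penalty grows by one for every character above or below that target.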


In [19]:
# Training arguments
training_args = GRPOConfig(
    output_dir="GRPO",
    learning_rate=2e-5,
    per_device_train_batch_size=2, # Reduced from 4
    gradient_accumulation_steps=2, # Reduced from 4
    max_prompt_length=512,
    max_completion_length=64, # Reduced from 96
    num_generations=4,
    optim="adamw_8bit",
    num_train_epochs=1,
    bf16=True,
    report_to=["wandb"],
    remove_unused_columns=False,
    logging_steps=1,
)

# Trainer
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[reward_len],
    args=training_args,
    train_dataset=dataset["train"],
)

# Train model
wandb.init(project="GRPO")
trainer.train()

Step,Training Loss
1,-0.0
2,0.0001
3,0.0
4,0.0001
5,0.0001
6,0.0001
7,0.0
8,0.0
9,0.0001
10,0.0001


TrainOutput(global_step=500, training_loss=0.014921467357780785, metrics={'train_runtime': 4763.9571, 'train_samples_per_second': 0.42, 'train_steps_per_second': 0.105, 'total_flos': 0.0, 'train_loss': 0.014921467357780785})
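The figures in `TrainOutput` can be cross-checked against the training arguments. This sketch assumes a single device and that each of the 2,000 training prompts is seen once per epoch; the variable names are illustrative.

```python
train_rows = 2000                 # size of dataset["train"]
per_device_train_batch_size = 2
gradient_accumulation_steps = 2
num_devices = 1                   # assumption: a single GPU
train_runtime = 4763.9571         # seconds, from TrainOutput

# Prompts consumed per optimizer step
prompts_per_step = (
    per_device_train_batch_size * gradient_accumulation_steps * num_devices
)
steps = train_rows // prompts_per_step

print(steps)                                  # 500
print(round(steps / train_runtime, 3))        # 0.105
print(round(train_rows / train_runtime, 2))   # 0.42
```

These values agree with `global_step=500`, `train_steps_per_second`, and `train_samples_per_second` in the output above.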

In [18]:
!pip install -qqq fsspec==2024.9.0 --progress-bar off

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.9.0 which is incompatible.

In [26]:
from huggingface_hub import notebook_login

notebook_login()


In [29]:
# Save model
merged_model = trainer.model.merge_and_unload()
merged_model.push_to_hub("divyaswamy87/SmolLM-135M-Instruct-GRPO", private=False) # Replace with your desired repo ID

# Push the tokenizer to the same repository
tokenizer.push_to_hub("divyaswamy87/SmolLM-135M-Instruct-GRPO", private=False)


No files have been modified since last commit. Skipping to prevent empty commit.


CommitInfo(commit_url='https://huggingface.co/divyaswamy87/SmolLM-135M-Instruct-GRPO/commit/718172b6d4fa44881dde6889d9a748658cdd60c8', commit_message='Upload tokenizer', commit_description='', oid='718172b6d4fa44881dde6889d9a748658cdd60c8', pr_url=None, repo_url=RepoUrl('https://huggingface.co/divyaswamy87/SmolLM-135M-Instruct-GRPO', endpoint='https://huggingface.co', repo_type='model', repo_id='divyaswamy87/SmolLM-135M-Instruct-GRPO'), pr_revision=None, pr_num=None)

In [22]:
prompt = """
# A long document about the Cat

The cat (Felis catus), also referred to as the domestic cat or house cat, is a small
domesticated carnivorous mammal. It is the only domesticated species of the family Felidae.
Advances in archaeology and genetics have shown that the domestication of the cat occurred
in the Near East around 7500 BC. It is commonly kept as a pet and farm cat, but also ranges
freely as a feral cat avoiding human contact. It is valued by humans for companionship and
its ability to kill vermin. Its retractable claws are adapted to killing small prey species
such as mice and rats. It has a strong, flexible body, quick reflexes, and sharp teeth,
and its night vision and sense of smell are well developed. It is a social species,
but a solitary hunter and a crepuscular predator. Cat communication includes
vocalizations—including meowing, purring, trilling, hissing, growling, and grunting—as
well as body language. It can hear sounds too faint or too high in frequency for human ears,
such as those made by small mammals. It secretes and perceives pheromones.
"""

messages = [
    {"role": "user", "content": prompt},
]


In [31]:
# Generate text
from transformers import pipeline

generator = pipeline("text-generation", model="divyaswamy87/SmolLM-135M-Instruct-GRPO")

## Or use the model and tokenizer we defined earlier
# generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

generate_kwargs = {
    "max_new_tokens": 256,
    "do_sample": True,
    "temperature": 0.5,
    "min_p": 0.1,
}

generated_text = generator(messages, **generate_kwargs)

print(generated_text)

Device set to use cuda:0


[{'generated_text': [{'role': 'user', 'content': '\n# A long document about the Cat\n\nThe cat (Felis catus), also referred to as the domestic cat or house cat, is a small \ndomesticated carnivorous mammal. It is the only domesticated species of the family Felidae.\nAdvances in archaeology and genetics have shown that the domestication of the cat occurred\nin the Near East around 7500 BC. It is commonly kept as a pet and farm cat, but also ranges\nfreely as a feral cat avoiding human contact. It is valued by humans for companionship and\nits ability to kill vermin. Its retractable claws are adapted to killing small prey species\nsuch as mice and rats. It has a strong, flexible body, quick reflexes, and sharp teeth,\nand its night vision and sense of smell are well developed. It is a social species,\nbut a solitary hunter and a crepuscular predator. Cat communication includes\nvocalizations—including meowing, purring, trilling, hissing, growling, and grunting—as\nwell as body language. 