# 🚀 GRPO Trainer Super Notebook

**Advanced GRPO/GSPO Training for Math & Reasoning Tasks**

This notebook provides a complete, production-ready pipeline for training LLMs using:
- **GRPO** (Group Relative Policy Optimization)
- **GSPO** (Group Sequence Policy Optimization, configured here as the `dr_grpo` loss with sequence-level importance sampling)
- **Multi-dataset support** (GSM8K, MATH, custom)
- **Memory-efficient training** (4-bit quantization, 8-bit optimizers)
- **Configurable reward functions**
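
To make "group relative" concrete: GRPO samples several completions per prompt, scores each one, and normalizes the scores within that group, so a completion is judged only against its siblings. The snippet below is a minimal sketch of that normalization, assuming the common mean and standard-deviation form. The function name `group_relative_advantages`, the `eps` value, and the use of the population standard deviation are illustrative choices, not taken from TRL.

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize one prompt's rewards against the group's own mean and spread."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std over the group (an assumption)
    # eps keeps the division finite when every completion gets the same reward
    return [(r - mu) / (sigma + eps) for r in rewards]

# Eight completions for one prompt: two earn a high reward, six a low one
rewards = [2.5, 0.5, 0.5, 2.5, 0.5, 0.5, 0.5, 0.5]
advantages = group_relative_advantages(rewards)
print([round(a, 2) for a in advantages])
```

If every completion in a group earns the same reward, all advantages are zero and that prompt contributes no learning signal, which is one reason to combine several graded reward functions.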

---

## Table of Contents
1. [Setup & Installation](#setup)
2. [Configuration](#config)
3. [Dataset Preparation](#data)
4. [Model Loading](#model)
5. [Reward Functions](#rewards)
6. [Training](#train)
7. [Evaluation](#eval)
8. [Inference](#inference)
9. [Export & Save](#save)

<a name="setup"></a>
## 1. Setup & Installation

Run this cell to install all dependencies. Works on Google Colab (free T4 GPU).

In [None]:
# @title Install Dependencies { display-mode: "form" }
# @markdown Check the boxes for optional features:
install_flash_attn = False  # @param {type:"boolean"}
install_deepspeed = False  # @param {type:"boolean"}
use_unsloth = True  # @param {type:"boolean"}

import subprocess
import sys

def run(cmd):
    subprocess.run(cmd, shell=True, check=True)

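# Core dependencies. Version specifiers are quoted so the shell does not
# treat ">" as an output redirect. GRPOTrainer needs a recent TRL release,
# so upgrade TRL if the import in the next cell fails.
run('pip install -q torch "transformers>=4.40.0" "trl>=0.8.0" "peft>=0.10.0"')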
run("pip install -q datasets accelerate bitsandbytes wandb")
run("pip install -q rich typer pyyaml omegaconf")

if use_unsloth:
    run("pip install -q unsloth")
    
if install_flash_attn:
    run("pip install flash-attn --no-build-isolation -q")
    
if install_deepspeed:
    run("pip install deepspeed -q")

print("✅ Installation complete!")

In [None]:
# Imports
import os
import re
import json
import logging
from dataclasses import dataclass, field
from typing import Optional, Callable

import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import GRPOConfig, GRPOTrainer

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Check GPU
print(f"🖥️ PyTorch version: {torch.__version__}")
print(f"🎮 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

<a name="config"></a>
## 2. Configuration

Configure your training run. All settings are in one place!
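
One interaction between these settings is easy to miss. The sketch below assumes the rule TRL documents for GRPO: the batch size counts completions, and the global batch (per-device batch × devices × accumulation steps) must be divisible by the number of generations per prompt. The helper name `prompts_per_step` is hypothetical, and the exact check can differ between TRL releases.

```python
def prompts_per_step(per_device_batch: int, grad_accum: int,
                     num_generations: int, num_devices: int = 1) -> int:
    """Return how many distinct prompts one optimizer step covers.

    Assumes the batch size counts completions, so the global batch must be
    a multiple of num_generations (assumption; verify against your TRL version).
    """
    global_batch = per_device_batch * num_devices * grad_accum
    if global_batch % num_generations != 0:
        raise ValueError(
            f"global batch {global_batch} is not divisible by "
            f"num_generations={num_generations}"
        )
    return global_batch // num_generations

print(prompts_per_step(per_device_batch=8, grad_accum=2, num_generations=8))  # 2

try:
    prompts_per_step(per_device_batch=2, grad_accum=2, num_generations=8)
except ValueError as err:
    print(f"would be rejected: {err}")
```

If that rule applies to your TRL version, a batch size of 2 with 2 accumulation steps and 8 generations needs adjusting, for example by raising `BATCH_SIZE` or lowering `NUM_GENERATIONS`.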

In [None]:
# @title Training Configuration { display-mode: "form" }

# === MODEL ===
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"  # @param {type:"string"}
LOAD_IN_4BIT = True  # @param {type:"boolean"}
USE_FLASH_ATTN = False  # @param {type:"boolean"}

# === LORA ===
LORA_ENABLED = True  # @param {type:"boolean"}
LORA_R = 16  # @param {type:"integer"}
LORA_ALPHA = 64  # @param {type:"integer"}
LORA_DROPOUT = 0.05  # @param {type:"number"}

# === DATASET ===
DATASET_NAME = "gsm8k"  # @param ["gsm8k", "math", "custom"]
USE_ONE_SHOT = True  # @param {type:"boolean"}
MAX_SAMPLES = None  # @param {type:"integer"}

# === TRAINING ===
OUTPUT_DIR = "outputs/grpo-run"  # @param {type:"string"}
RUN_NAME = "grpo-gsm8k"  # @param {type:"string"}
LEARNING_RATE = 5e-6  # @param {type:"number"}
BATCH_SIZE = 2  # @param {type:"integer"}
GRAD_ACCUM_STEPS = 2  # @param {type:"integer"}
NUM_GENERATIONS = 8  # @param {type:"integer"}
NUM_EPOCHS = 1  # @param {type:"number"}
MAX_STEPS = -1  # @param {type:"integer"}

# === GSPO OPTIONS ===
USE_GSPO = False  # @param {type:"boolean"}
LOSS_TYPE = "grpo"  # @param ["grpo", "dr_grpo", "bnpo"]
IMPORTANCE_SAMPLING = "token"  # @param ["token", "sequence"]

# === OPTIMIZER ===
USE_8BIT_OPTIMIZER = True  # @param {type:"boolean"}

# === REWARD WEIGHTS ===
CORRECTNESS_WEIGHT = 2.0  # @param {type:"number"}
FORMAT_WEIGHT = 0.5  # @param {type:"number"}
INTEGER_WEIGHT = 0.5  # @param {type:"number"}
XML_COUNT_WEIGHT = 0.5  # @param {type:"number"}

# === CUSTOM DELIMITERS ===
USE_CUSTOM_DELIMITERS = False  # @param {type:"boolean"}
REASONING_START = "<REASONING>"  # @param {type:"string"}
REASONING_END = "</REASONING>"  # @param {type:"string"}
ANSWER_START = "<SOLUTION>"  # @param {type:"string"}
ANSWER_END = "</SOLUTION>"  # @param {type:"string"}

# === LOGGING ===
REPORT_TO = "none"  # @param ["none", "wandb", "tensorboard"]
LOGGING_STEPS = 1  # @param {type:"integer"}
SAVE_STEPS = 100  # @param {type:"integer"}

# Apply GSPO settings
if USE_GSPO:
    LOSS_TYPE = "dr_grpo"
    IMPORTANCE_SAMPLING = "sequence"

print("📋 Configuration Summary:")
print(f"   Model: {MODEL_NAME}")
print(f"   Dataset: {DATASET_NAME}")
print(f"   LoRA: {'Enabled' if LORA_ENABLED else 'Disabled'} (r={LORA_R})")
print(f"   4-bit: {'Yes' if LOAD_IN_4BIT else 'No'}")
print(f"   Loss: {LOSS_TYPE}")
print(f"   Optimizer: {'8-bit AdamW' if USE_8BIT_OPTIMIZER else 'AdamW'}")

<a name="data"></a>
## 3. Dataset Preparation

Load and format the dataset with chat templates and one-shot prompting.
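
Each formatted prompt is a plain list of chat messages: the system prompt, an optional solved example as a user/assistant pair, then the real question. A minimal sketch of that structure, using shortened stand-in strings (`build_prompt` and its arguments are illustrative names, not part of the notebook):

```python
from typing import Optional

def build_prompt(system_prompt: str, question: str,
                 one_shot: Optional[dict] = None) -> list:
    """Assemble chat messages: system, optional worked example, then the question."""
    messages = [{"role": "system", "content": system_prompt}]
    if one_shot is not None:
        # The worked example is rendered in the same tagged format the model must produce
        worked = (
            f"<reasoning>\n{one_shot['reasoning']}\n</reasoning>\n"
            f"<answer>\n{one_shot['answer']}\n</answer>\n"
        )
        messages.append({"role": "user", "content": one_shot["question"]})
        messages.append({"role": "assistant", "content": worked})
    messages.append({"role": "user", "content": question})
    return messages

example = {"question": "What is 2 + 2?", "reasoning": "2 + 2 = 4.", "answer": "4"}
prompt = build_prompt("Respond in the required format.", "What is 3 * 5?", example)
print([m["role"] for m in prompt])  # ['system', 'user', 'assistant', 'user']
```

Keeping the prompt as a message list, rather than a pre-rendered string, leaves the chat template to the tokenizer.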

In [None]:
# Define prompts based on delimiter choice
if USE_CUSTOM_DELIMITERS:
    SYSTEM_PROMPT = f"""
Respond in the following format:

{REASONING_START}
...
{REASONING_END}
{ANSWER_START}
...
{ANSWER_END}
"""
    COT_FORMAT = f"""\
{REASONING_START}
{{reasoning}}
{REASONING_END}
{ANSWER_START}
{{answer}}
{ANSWER_END}
"""
else:
    SYSTEM_PROMPT = """
Respond in the following format:

<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
    COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
    REASONING_START, REASONING_END = "<reasoning>", "</reasoning>"
    ANSWER_START, ANSWER_END = "<answer>", "</answer>"

# One-shot examples
ONE_SHOT_EXAMPLES = {
    "gsm8k": {
        "question": "What is the largest single-digit prime number?",
        "reasoning": "9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
        "answer": "7"
    },
    "math": {
        "question": "Find the value of x if 2x + 5 = 13.",
        "reasoning": "Subtract 5 from both sides: 2x = 8. Divide by 2: x = 4.",
        "answer": "4"
    }
}

print(f"📝 System Prompt Preview:")
print(SYSTEM_PROMPT[:200] + "...")

In [None]:
def extract_hash_answer(text: str) -> str | None:
    """Extract answer from GSM8K format (#### answer)"""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

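def extract_boxed_answer(text: str) -> str | None:
    """Extract answer from MATH format (\\boxed{answer}), allowing nested braces."""
    start = text.find("\\boxed{")
    if start == -1:
        return None
    # Walk forward until the brace opened by \boxed{ is closed again,
    # so answers such as \boxed{\frac{1}{2}} are captured in full.
    begin, depth = start + len("\\boxed{"), 1
    for i in range(begin, len(text)):
        depth += {"{": 1, "}": -1}.get(text[i], 0)
        if depth == 0:
            return text[begin:i].strip()
    return None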

def load_and_prepare_dataset(dataset_name: str, use_one_shot: bool = True, max_samples: int | None = None):
    """Load and prepare dataset with chat formatting."""
    
    logger.info(f"Loading dataset: {dataset_name}")
    
    # Load dataset
    if dataset_name == "gsm8k":
        data = load_dataset("openai/gsm8k", "main")["train"]
        question_field = "question"
        answer_fn = lambda x: extract_hash_answer(x["answer"])
    elif dataset_name == "math":
        data = load_dataset("hendrycks/competition_math")["train"]
        question_field = "problem"
        answer_fn = lambda x: extract_boxed_answer(x["solution"]) or x.get("answer", "")
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    
    def format_example(x):
        prompt = [{"role": "system", "content": SYSTEM_PROMPT}]
        
        # Add one-shot example
        if use_one_shot and dataset_name in ONE_SHOT_EXAMPLES:
            ex = ONE_SHOT_EXAMPLES[dataset_name]
            prompt.extend([
                {"role": "user", "content": ex["question"]},
                {"role": "assistant", "content": COT_FORMAT.format(
                    reasoning=ex["reasoning"],
                    answer=ex["answer"]
                )}
            ])
        
        prompt.append({"role": "user", "content": x[question_field]})
        return {"prompt": prompt, "answer": answer_fn(x)}
    
    # Format dataset
    formatted = data.map(format_example)
    
    # Limit samples if specified
    if max_samples and len(formatted) > max_samples:
        formatted = formatted.shuffle(seed=42).select(range(max_samples))
    
    logger.info(f"Dataset size: {len(formatted)} samples")
    return formatted

# Load dataset
dataset = load_and_prepare_dataset(DATASET_NAME, USE_ONE_SHOT, MAX_SAMPLES)

print(f"\n📊 Dataset loaded: {len(dataset)} samples")
print(f"\n📝 Sample prompt structure:")
print(json.dumps(dataset[0]["prompt"][:2], indent=2))

<a name="model"></a>
## 4. Model Loading

Load the model with optional 4-bit quantization and LoRA.
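
LoRA keeps each targeted weight matrix frozen and trains two small matrices of rank `r` beside it, with the update scaled by `alpha / r`. The arithmetic below sketches how that plays out for a single projection. The 2048-wide square layer is a made-up size for illustration, not a figure read from any model config.

```python
def lora_extra_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in linear layer.

    LoRA learns A (r x d_in) and B (d_out x r); the frozen weight is untouched.
    """
    return r * d_in + d_out * r

d_in = d_out = 2048   # hypothetical layer size, for illustration only
r, alpha = 16, 64     # rank and alpha matching this notebook's LoRA defaults

full = d_in * d_out
extra = lora_extra_params(d_in, d_out, r)
scaling = alpha / r   # factor applied to the low-rank update B @ A

print(f"full layer:   {full:,} weights")
print(f"LoRA adapter: {extra:,} trainable weights ({extra / full:.2%} of the layer)")
print(f"update scale: {scaling}")
```

Raising `r` grows the adapter linearly, while raising `alpha` alone only changes how strongly the adapter's output is weighted.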

In [None]:
def load_model_and_tokenizer(model_name, load_in_4bit=True, use_flash_attn=False):
    """Load model with quantization and attention optimizations."""
    
    logger.info(f"Loading model: {model_name}")
    
    # Quantization config
    quant_config = None
    if load_in_4bit:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    
    # Model kwargs
    model_kwargs = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "trust_remote_code": True,
        "use_cache": False,
    }
    
    if quant_config:
        model_kwargs["quantization_config"] = quant_config
    
    if use_flash_attn:
        model_kwargs["attn_implementation"] = "flash_attention_2"
    
    # Load model
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    
    # Prepare for training if quantized
    if load_in_4bit:
        model = prepare_model_for_kbit_training(model)
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model loaded: {total_params/1e9:.2f}B parameters")
    
    return model, tokenizer

# Load model
model, tokenizer = load_model_and_tokenizer(MODEL_NAME, LOAD_IN_4BIT, USE_FLASH_ATTN)
print(f"✅ Model loaded successfully!")

In [None]:
# Apply LoRA
if LORA_ENABLED:
    peft_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
        task_type="CAUSAL_LM",
        bias="none",
    )
    print(f"✅ LoRA config ready (r={LORA_R}, alpha={LORA_ALPHA})")
else:
    peft_config = None
    print("⚠️ LoRA disabled - full fine-tuning")

<a name="rewards"></a>
## 5. Reward Functions

Define reward functions to guide GRPO training.
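
Each reward function receives the batch of completions and returns one float per completion. For chat-style prompts, a completion arrives as a list of message dicts, so the text sits at `completion[0]["content"]`. The sketch below exercises a format check on two hand-written completions. `toy_format_reward` and the sample strings are illustrative stand-ins, not the functions used for training.

```python
import re

# Reasoning block followed by an answer block; DOTALL lets "." span newlines
FORMAT_PATTERN = re.compile(
    r"<reasoning>.*?</reasoning>.*?<answer>.*?</answer>", re.DOTALL
)

def toy_format_reward(completions, weight: float = 0.5, **kwargs) -> list:
    """Return `weight` for completions that contain both tagged blocks, else 0.0."""
    texts = [completion[0]["content"] for completion in completions]
    return [weight if FORMAT_PATTERN.search(t) else 0.0 for t in texts]

completions = [
    [{"role": "assistant",
      "content": "<reasoning>\n60 * 2.5 = 150\n</reasoning>\n<answer>\n150\n</answer>\n"}],
    [{"role": "assistant", "content": "The answer is 150."}],
]
print(toy_format_reward(completions))  # [0.5, 0.0]
```

Accepting `**kwargs` lets a function ignore arguments it does not need, such as the prompts or extra dataset columns like `answer`.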

In [None]:
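def extract_xml_answer(text: str) -> str:
    """Extract the text between the answer tags; return "" if the opening tag is missing."""
    # str.split never raises IndexError here, so check for the tag explicitly;
    # otherwise an untagged response would be returned whole.
    if ANSWER_START not in text:
        return ""
    return text.split(ANSWER_START)[-1].split(ANSWER_END)[0].strip()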

# === REWARD FUNCTIONS ===

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward for correct answers."""
    responses = [completion[0]['content'] for completion in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    
    # Log first example
    q = prompts[0][-1]['content']
    logger.info(f"Q: {q[:100]}... | A: {answer[0]} | Got: {extracted[0]}")
    
    return [CORRECTNESS_WEIGHT if r == a else 0.0 for r, a in zip(extracted, answer)]

def integer_reward_func(completions, **kwargs) -> list[float]:
    """Reward for numeric answers."""
    responses = [completion[0]['content'] for completion in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    return [INTEGER_WEIGHT if r.lstrip('-').isdigit() else 0.0 for r in extracted]

def format_reward_func(completions, **kwargs) -> list[float]:
    """Reward for correct formatting."""
    pattern = f"{re.escape(REASONING_START)}.*?{re.escape(REASONING_END)}.*?{re.escape(ANSWER_START)}.*?{re.escape(ANSWER_END)}"
    responses = [completion[0]['content'] for completion in completions]
    return [FORMAT_WEIGHT if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]

def xml_count_reward_func(completions, **kwargs) -> list[float]:
    """Reward based on XML tag structure."""
    def count_tags(text):
        score = 0.0
        if text.count(f"{REASONING_START}\n") == 1:
            score += 0.125
        if text.count(f"\n{REASONING_END}\n") == 1:
            score += 0.125
        if text.count(f"\n{ANSWER_START}\n") == 1:
            score += 0.125
        if text.count(f"\n{ANSWER_END}") == 1:
            score += 0.125
            # Penalize content after closing tag
            extra = text.split(f"\n{ANSWER_END}")[-1]
            score -= len(extra) * 0.001
        return max(score, 0.0) * XML_COUNT_WEIGHT
    
    responses = [completion[0]['content'] for completion in completions]
    return [count_tags(r) for r in responses]

def gibberish_penalty_func(completions, **kwargs) -> list[float]:
    """Penalize gibberish outputs (VLM fix)."""
    patterns = ["addCriterion", "\n\n\n\n", "................"]
    responses = [completion[0]['content'] for completion in completions]
    
    rewards = []
    for r in responses:
        if len(r) == 0:
            rewards.append(0.0)
            continue
        cleaned = r
        for p in patterns:
            cleaned = cleaned.replace(p, "")
        if (len(r) - len(cleaned)) / len(r) >= 0.5:
            rewards.append(-2.0)
        else:
            rewards.append(0.0)
    return rewards

# Combine reward functions
reward_functions = [
    correctness_reward_func,
    format_reward_func,
    integer_reward_func,
]

if XML_COUNT_WEIGHT > 0:
    reward_functions.append(xml_count_reward_func)

print(f"✅ {len(reward_functions)} reward functions configured")
print(f"   Weights: correctness={CORRECTNESS_WEIGHT}, format={FORMAT_WEIGHT}, integer={INTEGER_WEIGHT}")

<a name="train"></a>
## 6. Training

Configure and run GRPO/GSPO training.
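
The `IMPORTANCE_SAMPLING` option decides how the ratio between the current policy and the sampling policy is formed. Below is a rough sketch of the two choices, assuming the length-normalized sequence ratio described for GSPO and using made-up per-token log-probabilities. Nothing here calls TRL, and both function names are illustrative.

```python
import math

def token_level_ratios(new_logps: list, old_logps: list) -> list:
    """One importance ratio per token: pi_new(token) / pi_old(token)."""
    return [math.exp(n - o) for n, o in zip(new_logps, old_logps)]

def sequence_level_ratio(new_logps: list, old_logps: list) -> float:
    """A single ratio per completion: exp of the mean per-token log-ratio."""
    diffs = [n - o for n, o in zip(new_logps, old_logps)]
    return math.exp(sum(diffs) / len(diffs))

# Hypothetical log-probabilities for a four-token completion
old = [-1.0, -0.5, -2.0, -0.1]
new = [-0.8, -0.5, -2.6, -0.1]

print([round(r, 3) for r in token_level_ratios(new, old)])
print(round(sequence_level_ratio(new, old), 3))
```

The sequence-level form averages out single-token swings (here the third token), which is the motivation usually given for it.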

In [None]:
# Training configuration
training_args = GRPOConfig(
    output_dir=OUTPUT_DIR,
    run_name=RUN_NAME,
    
    # Optimizer
    learning_rate=LEARNING_RATE,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    max_grad_norm=0.1,
    optim="adamw_8bit" if USE_8BIT_OPTIMIZER else "adamw_torch",
    
    # Batch sizes
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    
    # GRPO specific
    num_generations=NUM_GENERATIONS,
    max_prompt_length=256,
    max_completion_length=786,
    
    # Training duration
    num_train_epochs=NUM_EPOCHS,
    max_steps=MAX_STEPS if MAX_STEPS > 0 else -1,
    
    # Logging
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=2,
    
    # Precision
    bf16=True,
    
    # Reporting
    report_to=REPORT_TO,
    log_on_each_node=False,
)

# Apply non-default loss options (USE_GSPO selects dr_grpo with sequence-level sampling)
if USE_GSPO or LOSS_TYPE != "grpo" or IMPORTANCE_SAMPLING != "token":
    training_args.loss_type = LOSS_TYPE
    training_args.importance_sampling_level = IMPORTANCE_SAMPLING
    training_args.mask_truncated_completions = False
    print(f"🔧 Loss options applied: loss={LOSS_TYPE}, sampling={IMPORTANCE_SAMPLING}")

print(f"✅ Training config ready")
# TRL's GRPO batch size counts completions, not prompts (single-device figure shown)
print(f"   Completions per optimizer step: {BATCH_SIZE * GRAD_ACCUM_STEPS}")

In [None]:
# Initialize trainer
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_functions,
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
)

print("✅ Trainer initialized!")
print(f"   Ready to train on {len(dataset)} samples")

In [None]:
# @title Start Training { display-mode: "form" }
# @markdown Click to start training. Watch the reward column increase!

print("🚀 Starting training...")
print("="*50)

try:
    trainer.train()
    print("="*50)
    print("✅ Training completed successfully!")
except KeyboardInterrupt:
    print("\n⚠️ Training interrupted by user")
except Exception as e:
    print(f"\n❌ Training failed: {e}")
    raise

<a name="eval"></a>
## 7. Evaluation

Evaluate the trained model on the test set.
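
Exact string comparison is strict: `1,000` and `1000`, or `42.0` and `42`, count as different answers. If that is too harsh for your data, a light normalization step before comparing is one option. The helper below is a sketch with a hypothetical name (`normalize_numeric`) and is not used by the evaluation cell.

```python
def normalize_numeric(answer: str) -> str:
    """Canonicalize simple numeric answers; leave anything else unchanged."""
    cleaned = answer.strip().replace(",", "").replace("$", "")
    try:
        value = float(cleaned)
    except ValueError:
        return answer.strip()  # not a plain number, compare as text
    # Render whole numbers without a trailing ".0"
    return str(int(value)) if value.is_integer() else str(value)

print(normalize_numeric("1,000") == normalize_numeric("1000"))   # True
print(normalize_numeric("42.0") == normalize_numeric(" 42 "))    # True
print(normalize_numeric("x + 1"))                                # x + 1
```

Applying the same normalization to both the extracted answer and the gold answer keeps the comparison symmetric.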

In [None]:
def evaluate_model(model, tokenizer, dataset_name="gsm8k", num_samples=100):
    """Evaluate model accuracy on test set."""
    
    # Load test set
    if dataset_name == "gsm8k":
        test_data = load_dataset("openai/gsm8k", "main")["test"]
        q_field, a_fn = "question", lambda x: extract_hash_answer(x["answer"])
    else:
        test_data = load_dataset("hendrycks/competition_math")["test"]
        q_field, a_fn = "problem", lambda x: extract_boxed_answer(x["solution"])
    
    if num_samples:
        test_data = test_data.select(range(min(num_samples, len(test_data))))
    
    correct = 0
    total = len(test_data)
    
    model.eval()
    
    from tqdm import tqdm
    for example in tqdm(test_data, desc="Evaluating"):
        question = example[q_field]
        gold = a_fn(example)
        
        # Generate response
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": question}
        ]
        
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
        
        response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
        extracted = extract_xml_answer(response)
        
        if extracted == gold:
            correct += 1
    
    accuracy = correct / total
    print(f"\n📊 Evaluation Results:")
    print(f"   Accuracy: {accuracy:.2%} ({correct}/{total})")
    return accuracy

# Run evaluation
# accuracy = evaluate_model(model, tokenizer, DATASET_NAME, num_samples=50)
print("💡 Uncomment the line above to run evaluation")

<a name="inference"></a>
## 8. Inference

Test the trained model with custom prompts.

In [None]:
def generate_response(question: str, max_tokens: int = 512):
    """Generate a response for a given question."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": question}
    ]
    
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
        )
    
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    return response

# Test inference
test_question = "If a train travels at 60 mph for 2.5 hours, how far does it travel?"
print(f"❓ Question: {test_question}")
print(f"\n🤖 Response:")
print(generate_response(test_question))

<a name="save"></a>
## 9. Export & Save

Save the trained model for later use.

In [None]:
# @title Save Model { display-mode: "form" }
save_path = "outputs/final_model"  # @param {type:"string"}
push_to_hub = False  # @param {type:"boolean"}
hub_model_id = "your-username/model-name"  # @param {type:"string"}

print(f"💾 Saving model to {save_path}...")

# Save model and tokenizer
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)

# Save config
config_dict = {
    "model_name": MODEL_NAME,
    "dataset": DATASET_NAME,
    "lora_r": LORA_R,
    "lora_alpha": LORA_ALPHA,
    "learning_rate": LEARNING_RATE,
    "loss_type": LOSS_TYPE,
    "epochs": NUM_EPOCHS,
}

with open(f"{save_path}/training_config.json", "w") as f:
    json.dump(config_dict, f, indent=2)

print(f"✅ Model saved!")

# Push to Hub (trainer.model is the trained model, including the LoRA adapter when enabled)
if push_to_hub:
    print(f"📤 Pushing to Hub: {hub_model_id}")
    trainer.model.push_to_hub(hub_model_id)
    tokenizer.push_to_hub(hub_model_id)
    print("✅ Pushed to Hub!")

In [None]:
# @title Load Saved Model { display-mode: "form" }
# @markdown Use this to load a previously saved model

load_path = "outputs/final_model"  # @param {type:"string"}

from peft import PeftModel

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# Load LoRA weights
model = PeftModel.from_pretrained(base_model, load_path)
tokenizer = AutoTokenizer.from_pretrained(load_path)

print(f"✅ Model loaded from {load_path}")

---

## 🎉 Done!

You've successfully trained a model using GRPO! Next steps:

1. **Experiment** with different reward weights
2. **Try GSPO** (`USE_GSPO = True`), which switches to sequence-level importance sampling and may train more stably
3. **Scale up** with larger models and DeepSpeed
4. **Evaluate** on different test sets

For more info, see the [GRPO Trainer documentation](https://github.com/kossisoroyce/grpo-trainer).