# GRPO Training for Qwen2.5-Math-1.5B on Google Colab

This notebook demonstrates how to train a math reasoning model using **Group Relative Policy Optimization (GRPO)** on the MATH dataset.

## Requirements
- Google Colab with GPU (T4 or better, A100 recommended)
- ~16GB GPU memory for training

## What is GRPO?
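GRPO (introduced in DeepSeekMath and later used to train DeepSeek-R1) is a policy gradient method that:
1. Generates multiple responses (a group) per question
2. Computes a reward for each response based on answer correctness
3. Normalizes rewards within each group (subtracting the group mean and, optionally, dividing by the group standard deviation) to get advantages
4. Updates the policy with a policy-gradient loss weighted by these advantages, so no separate value model (critic) is needed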
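Steps 2 and 3 can be illustrated with a few lines of plain Python. This is a minimal sketch, not the repository's implementation. The function name and the `normalize_by_std` and `eps` parameters are assumptions, and rewards are assumed to arrive as consecutive groups of `group_size` responses to the same question.

```python
import statistics


def group_normalized_advantages(rewards, group_size, normalize_by_std=True, eps=1e-6):
    """Return one advantage per response, normalized within each group.

    `rewards` is a flat list laid out as consecutive groups of `group_size`
    responses that were all sampled for the same question.
    """
    if group_size < 1 or len(rewards) % group_size != 0:
        raise ValueError("len(rewards) must be a multiple of a positive group_size")
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        group_mean = statistics.fmean(group)
        # Subtract the group mean: responses better than their siblings get A > 0.
        centered = [r - group_mean for r in group]
        if normalize_by_std and group_size > 1:
            # Sample standard deviation; eps avoids division by zero when all rewards tie.
            group_std = statistics.stdev(group)
            centered = [c / (group_std + eps) for c in centered]
        advantages.extend(centered)
    return advantages


# Two questions, four sampled responses each (1 = correct, 0 = incorrect).
print(group_normalized_advantages([1, 0, 0, 0, 1, 1, 1, 1], group_size=4))
```

For a group with rewards `[1, 0, 0, 0]`, the correct response gets an advantage of about +1.5 and each incorrect one about -0.5. A group whose responses all receive the same reward contributes zero advantage, and therefore no learning signal.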

## 1. Setup Environment

In [13]:
# Check GPU availability
!nvidia-smi

Sun Feb  8 03:52:55 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   34C    P0             54W /  400W |       5MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [14]:
# Clone the repository
!git clone https://github.com/bearbearyu1223/qwen_math_grpo.git
%cd qwen_math_grpo

Cloning into 'qwen_math_grpo'...
remote: Enumerating objects: 36, done.
remote: Counting objects: 100% (36/36), done.
remote: Compressing objects: 100% (30/30), done.
remote: Total 36 (delta 9), reused 32 (delta 5), pack-reused 0 (from 0)
Receiving objects: 100% (36/36), 204.80 KiB | 762.00 KiB/s, done.
Resolving deltas: 100% (9/9), done.
/content/qwen_math_grpo/qwen_math_grpo/qwen_math_grpo


In [27]:
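# Update the repo and re-sync dependencies (uses uv; on a fresh runtime, run the uv install cell below first)
!git pull origin main
!uv sync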

From https://github.com/bearbearyu1223/qwen_math_grpo
 * branch            main       -> FETCH_HEAD
Already up to date.
Resolved 196 packages in 1.51s
Prepared 2 packages in 365ms
Uninstalled 1 package in 0.37ms
Installed 3 packages in 3ms
 + accelerate==1.12.0
 + psutil==7.2.2
 ~ qwen-math-grpo==0.1.0 (from file:///content/qwen_math_grpo/qwen_math_grpo/qwen_math_grpo)


In [15]:
# Install uv package manager
!curl -LsSf https://astral.sh/uv/install.sh | sh

# Add uv to PATH (source doesn't work with ! in Colab)
import os
os.environ["PATH"] = f"{os.environ['HOME']}/.local/bin:{os.environ['PATH']}"

# Install base dependencies
!uv sync

# Install vLLM separately (needs system CUDA compatibility).
# Quote the requirement so the shell does not treat ">" as output redirection.
!uv pip install "vllm>=0.8.4"

downloading uv 0.10.0 x86_64-unknown-linux-gnu
no checksums to verify
installing to /usr/local/bin
  uv
  uvx
everything's installed!
Using CPython 3.12.12 interpreter at: /usr/bin/python3
Creating virtual environment at: .venv
Resolved 195 packages in 1ms
Prepared 1 package in 374ms
Installed 78 packages in 243ms
 + aiohappyeyeballs==2.6.1
 + aiohttp==3.13.3
 + aiosignal==1.4.0
 + annotated-types==0.7.0
 + antlr4-python3-runtime==4.13.2
 + anyio==4.12.1
 + attrs==25.4.0
 + certifi==2026.1.4
 + charset-normalizer==3.4.4
 + click==8.3.1
 + datasets==4.5.0
 + dill==0.4.0
 +

In [16]:
# Verify installation
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

PyTorch version: 2.9.1+cu128
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB
GPU Memory: 85.2 GB


## 2. Download Dataset and Model

In [17]:
# Download the MATH dataset
!uv run python scripts/download_dataset.py

Downloading dataset: nlile/hendrycks-MATH-benchmark
Output directory: /content/qwen_math_grpo/qwen_math_grpo/qwen_math_grpo/data/math
Splits: ['train', 'test']

Saving train split (12000 examples) to data/math/train.jsonl
  Saved 12000 examples
Saving test split (500 examples) to data/math/test.jsonl
  Saved 500 examples

Download complete!


In [18]:
# Verify dataset
!wc -l data/math/train.jsonl data/math/test.jsonl

   12000 data/math/train.jsonl
     500 data/math/test.jsonl
   12500 total


In [19]:
# Preview a sample from the dataset
import json

with open('data/math/train.jsonl') as f:
    sample = json.loads(f.readline())

print("Problem:")
print(sample['problem'][:500])
print("\nAnswer:", sample['answer'])

Problem:
How many vertical asymptotes does the graph of $y=\frac{2}{x^2+x-6}$ have?

Answer: 2


## 3. Run GRPO Training

### Training Configuration

Colab provides a single GPU, so pass `--single-gpu` and put the policy on `--policy-device cuda:0`. Adjust the batch sizes to your GPU memory:

| GPU | Recommended Settings |
|-----|---------------------|
| T4 (16GB) | `--rollout-batch-size 8 --train-batch-size 8` |
| A100 (40GB) | `--rollout-batch-size 32 --train-batch-size 32` |
| A100 (80GB) | `--rollout-batch-size 64 --train-batch-size 64` |
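How these flags interact is defined by `scripts/run_grpo.py`, which is not shown here. The sketch below assumes a common GRPO convention: each rollout batch holds `rollout_batch_size / group_size` questions with `group_size` responses each, and each train batch is split into `gradient_accumulation_steps` micro-batches. The helper name and its divisibility checks are hypothetical.

```python
def grpo_batch_layout(rollout_batch_size, group_size, train_batch_size,
                      gradient_accumulation_steps):
    """Hypothetical helper: derive per-step quantities from GRPO batch flags.

    Assumes rollouts are grouped by question and each train batch is split
    evenly into gradient-accumulation micro-batches.
    """
    checks = [
        (rollout_batch_size % group_size == 0,
         "rollout_batch_size must be divisible by group_size"),
        (train_batch_size % gradient_accumulation_steps == 0,
         "train_batch_size must be divisible by gradient_accumulation_steps"),
        (rollout_batch_size % train_batch_size == 0,
         "rollout_batch_size must be divisible by train_batch_size"),
    ]
    for ok, message in checks:
        if not ok:
            raise ValueError(message)
    return {
        # Distinct questions sampled per GRPO step.
        "questions_per_step": rollout_batch_size // group_size,
        # Sequences per forward/backward pass.
        "micro_batch_size": train_batch_size // gradient_accumulation_steps,
        # Optimizer updates per pass over one rollout batch.
        "optimizer_steps_per_epoch": rollout_batch_size // train_batch_size,
    }


# Settings from the quick test run (group size 8, as reported in its log).
print(grpo_batch_layout(32, 8, 32, 8))
```

Under these assumptions, the quick test run below samples 4 questions per GRPO step and processes micro-batches of 4 sequences. By the same assumption, raising `--gradient-accumulation-steps` shrinks the micro-batch while keeping the train batch size fixed, which is another lever if the T4 settings still run out of memory.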

In [22]:
# Quick test run (5 steps) to verify everything works
!uv run python scripts/run_grpo.py \
    --model-name-or-path Qwen/Qwen2.5-Math-1.5B \
    --single-gpu \
    --policy-device cuda:0 \
    --rollout-batch-size 32 \
    --train-batch-size 32 \
    --gradient-accumulation-steps 8 \
    --n-grpo-steps 5 \
    --output-dir outputs/grpo_test

2026-02-08 03:57:41,801 - __main__ - INFO - Loading training data...
2026-02-08 03:57:41,909 - __main__ - INFO - Loaded 12000 examples from data/math/train.jsonl
2026-02-08 03:57:41,909 - __main__ - INFO - Loading validation data...
2026-02-08 03:57:41,914 - __main__ - INFO - Loaded 500 examples from data/math/test.jsonl
2026-02-08 03:57:41,914 - __main__ - INFO - GRPO TRAINING CONFIGURATION
2026-02-08 03:57:41,914 - __main__ - INFO - Model: Qwen/Qwen2.5-Math-1.5B
2026-02-08 03:57:41,914 - __main__ - INFO - Training examples: 12000
2026-02-08 03:57:41,914 - __main__ - INFO - GRPO steps: 5
2026-02-08 03:57:41,914 - __main__ - INFO - Rollout batch size: 32
2026-02-08 03:57:41,914 - __main__ - INFO - Group size: 8
2026-02-08 03:57:41,914 - __main__ - INFO - Loss type: reinforce_with_baseline
2026-02-08 03:57:41,914 - __main__ - INFO - Learning rate: 1e-05
2026-02-08 03:57:41,914 - __main__ - INFO - Loading policy model from Qwen/Qwen2.5-Math-1.5B...
`torch_dtype` is deprecated! Use `dtype

## 4. Evaluate the Trained Model

In [24]:
# Check saved model
!ls -la outputs/grpo_test/

total 12
drwxr-xr-x 3 root root 4096 Feb  8 04:32 .
drwxr-xr-x 3 root root 4096 Feb  8 03:54 ..
drwxr-xr-x 2 root root 4096 Feb  8 04:32 final


In [30]:
# Install the optional "vllm" dependency group needed for --backend vllm below
!uv sync --extra vllm

Resolved 196 packages in 1ms

In [31]:
# Evaluate the GRPO-trained model
!uv run python scripts/run_math_eval.py \
    --model-name-or-path outputs/grpo_test/final \
    --input-path data/math/test.jsonl \
    --output-path outputs/grpo_eval_results.jsonl \
    --backend vllm \
    --num-samples 100


2026-02-08 04:39:09,755 - __main__ - INFO - Evaluating model: outputs/grpo_test/final
2026-02-08 04:39:09,755 - __main__ - INFO - Backend: vllm
2026-02-08 04:39:09,755 - __main__ - INFO - Input: data/math/test.jsonl
2026-02-08 04:39:09,755 - __main__ - INFO - Output: outputs/grpo_eval_results.jsonl
2026-02-08 04:39:09,759 - cs336_alignment.evaluate_math - INFO - Read 500 examples from data/math/test.jsonl
2026-02-08 04:39:09,760 - cs336_alignment.evaluate_math - INFO - Limiting evaluation to 100 samples
2026-02-08 04:39:16,035 - cs336_alignment.evaluate_math - INFO - Loading model outputs/grpo_test/final with vLLM backend...
INFO 02-08 04:39:16 [utils.py:261] non-default args: {'trust_remote_code': True, 'disable_log_stats': True, 'model': 'outputs/grpo_test/final'}
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.

In [32]:
# Evaluate the base model for comparison
!uv run python scripts/run_math_eval.py \
    --model-name-or-path Qwen/Qwen2.5-Math-1.5B \
    --input-path data/math/test.jsonl \
    --output-path outputs/base_eval_results.jsonl \
    --backend vllm \
    --num-samples 100

2026-02-08 04:41:55,196 - __main__ - INFO - Evaluating model: Qwen/Qwen2.5-Math-1.5B
2026-02-08 04:41:55,196 - __main__ - INFO - Backend: vllm
2026-02-08 04:41:55,196 - __main__ - INFO - Input: data/math/test.jsonl
2026-02-08 04:41:55,196 - __main__ - INFO - Output: outputs/base_eval_results.jsonl
2026-02-08 04:41:55,200 - cs336_alignment.evaluate_math - INFO - Read 500 examples from data/math/test.jsonl
2026-02-08 04:41:55,200 - cs336_alignment.evaluate_math - INFO - Limiting evaluation to 100 samples
2026-02-08 04:41:59,222 - cs336_alignment.evaluate_math - INFO - Loading model Qwen/Qwen2.5-Math-1.5B with vLLM backend...
INFO 02-08 04:41:59 [utils.py:261] non-default args: {'trust_remote_code': True, 'disable_log_stats': True, 'model': 'Qwen/Qwen2.5-Math-1.5B'}
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.

## 5. Compare Results

In [34]:
import json
from statistics import mean

def load_results(path):
    results = []
    with open(path) as f:
        for line in f:
            results.append(json.loads(line))
    return results

def compute_metrics(results):
    format_rewards = [r['metrics']['format_reward'] for r in results]
    answer_rewards = [r['metrics']['answer_reward'] for r in results]
    return {
        'format_accuracy': mean(format_rewards),
        'answer_accuracy': mean(answer_rewards),
        'n_samples': len(results)
    }

# Load and compare results
try:
    grpo_results = load_results('outputs/grpo_eval_results.jsonl')
    base_results = load_results('outputs/base_eval_results.jsonl')

    grpo_metrics = compute_metrics(grpo_results)
    base_metrics = compute_metrics(base_results)

    print("=" * 50)
    print("EVALUATION COMPARISON")
    print("=" * 50)
    print(f"\n{'Model':<25} {'Format Acc':<15} {'Answer Acc':<15}")
    print("-" * 55)
    print(f"{'Base (Qwen2.5-Math-1.5B)':<25} {base_metrics['format_accuracy']:<15.2%} {base_metrics['answer_accuracy']:<15.2%}")
    print(f"{'GRPO-Trained':<25} {grpo_metrics['format_accuracy']:<15.2%} {grpo_metrics['answer_accuracy']:<15.2%}")
    print("-" * 55)

    improvement = grpo_metrics['answer_accuracy'] - base_metrics['answer_accuracy']
    print(f"\nImprovement: {improvement:+.2%}")
except FileNotFoundError as e:
    print(f"Results file not found: {e}")
    print("Make sure to run the evaluation cells above first.")

EVALUATION COMPARISON

Model                     Format Acc      Answer Acc     
-------------------------------------------------------
Base (Qwen2.5-Math-1.5B)  38.00%          19.00%         
GRPO-Trained              59.00%          28.00%         
-------------------------------------------------------

Improvement: +9.00%
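With only 100 test problems, a 9-point difference is noisy. A rough check treats the two answer accuracies as independent binomial proportions and builds a normal-approximation 95% interval for their difference. This is an approximation: both models answered the same problems, so a paired analysis of the per-example results (for example, McNemar's test) is more appropriate. The helper name is hypothetical.

```python
import math


def accuracy_diff_ci(p_base, p_new, n, z=1.96):
    """Normal-approximation CI for p_new - p_base, each measured on n examples.

    Treats the two evaluations as independent samples, which ignores that both
    models were scored on the same problems.
    """
    se = math.sqrt(p_base * (1 - p_base) / n + p_new * (1 - p_new) / n)
    diff = p_new - p_base
    return diff, diff - z * se, diff + z * se


# Answer accuracies from the comparison above (100 problems each).
diff, lo, hi = accuracy_diff_ci(0.19, 0.28, n=100)
print(f"Answer accuracy change: {diff:+.2%} (95% CI {lo:+.2%} to {hi:+.2%})")
```

For these numbers the interval includes zero, so the gain from the 5-step test run is suggestive rather than conclusive. Evaluating on all 500 test problems (`--num-samples 500`) would narrow the interval.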


## 6. View Analysis Reports

In [35]:
# View GRPO model analysis report
!cat outputs/grpo_eval_results_analysis.txt | head -100

MATH EVALUATION ANALYSIS REPORT

SUMMARY STATISTICS
----------------------------------------
Total examples: 100

Category Breakdown:
  Correct (format=1, answer=1): 28 (28.0%)
  Format only (format=1, answer=0): 31 (31.0%)
  Neither (format=0): 41 (41.0%)

EXAMPLES: CORRECT

--- Example 1 ---
Problem: Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$...
Ground Truth: \left( 3, \frac{\pi}{2} \right)
Model Output (first 500 chars): The point $(0,3)$ is located on the positive y-axis. To convert it to polar coordinates, we need to find the radius $r$ and the angle $\theta$. Since the point is on the y-axis, $\theta$ is $\frac{\pi}{2}$ or $\frac{3\pi}{2}$ depending on the quadrant. Since the point is in the positive y-axis, $\theta = \frac{\pi}{2}$. The radius $r$ is the distance from the origin to the point, which is 3. So the polar coordinates are $(3, \frac{\pi}{2})$. </think

In [36]:
# View base model analysis report
!cat outputs/base_eval_results_analysis.txt | head -100

MATH EVALUATION ANALYSIS REPORT

SUMMARY STATISTICS
----------------------------------------
Total examples: 100

Category Breakdown:
  Correct (format=1, answer=1): 19 (19.0%)
  Format only (format=1, answer=0): 19 (19.0%)
  Neither (format=0): 62 (62.0%)

EXAMPLES: CORRECT

--- Example 1 ---
Problem: How many positive whole-number divisors does 196 have?...
Ground Truth: 9
Model Output (first 500 chars):  The prime factorization of 196 is 2^2 * 7^2. The number of divisors is given by (2+1)(2+1) = 9. </think> <answer> 9 </answer>...


--- Example 2 ---
Problem: What is the smallest positive perfect cube that can be written as the sum of three consecutive integers?...
Ground Truth: 27
Model Output (first 500 chars):  Let's denote the three consecutive integers as n, n+1, and n+2. The sum of these three integers is n + (n+1) + (n+2) = 3n + 3. For this sum to be a perfect cube, 3n + 3 must be a perfect cube. The smallest positive perfect cube is 1, but 3n + 3 cannot be 1 for any integer 
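The reports bucket each example using its two binary rewards. The counts can be reproduced from the JSONL results with a short helper, assuming each record carries 0/1 values in `metrics['format_reward']` and `metrics['answer_reward']` as read in Section 5. The `categorize` name is hypothetical, and, following the reports, any example with `format_reward == 0` is counted as Neither.

```python
from collections import Counter


def categorize(results):
    """Count examples per analysis-report bucket from evaluation records.

    Each record is assumed to look like {"metrics": {"format_reward": 0 or 1,
    "answer_reward": 0 or 1}, ...}, matching the fields read in Section 5.
    """
    counts = Counter({"correct": 0, "format_only": 0, "neither": 0})
    for record in results:
        metrics = record["metrics"]
        if metrics["format_reward"] == 0:
            counts["neither"] += 1        # bad format counts as Neither, as in the reports
        elif metrics["answer_reward"] == 1:
            counts["correct"] += 1        # format=1, answer=1
        else:
            counts["format_only"] += 1    # format=1, answer=0
    return counts


# Tiny synthetic example; in the notebook, pass load_results(...) from Section 5 instead.
sample = [
    {"metrics": {"format_reward": 1, "answer_reward": 1}},
    {"metrics": {"format_reward": 1, "answer_reward": 0}},
    {"metrics": {"format_reward": 0, "answer_reward": 0}},
]
print(categorize(sample))
```

Format accuracy in Section 5 equals Correct plus Format only: 28 + 31 = 59 for the GRPO model and 19 + 19 = 38 for the base model, matching the comparison table.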

## 7. Save Model to Google Drive (Optional)

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
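# Copy the trained model to Google Drive
# (the training run above wrote to outputs/grpo_test; change this if you used a different --output-dir)
!cp -r outputs/grpo_test /content/drive/MyDrive/grpo_model_backup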

## 8. Interactive Testing

In [None]:
# Load the trained model for interactive testing
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

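# Written by the training run above (--output-dir outputs/grpo_test); adjust if you trained elsewhere
model_path = "outputs/grpo_test/final"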

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)

print("Model loaded successfully!")

In [None]:
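# Test with a math problem.
# The prompt opens a <think> block; keep it consistent with the template used for training and evaluation.
def solve_math_problem(question):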
    prompt = f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
User: {question}
Assistant: <think>"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )

    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    return response

# Example problem with a known answer: n + 1 must divide 8, so n is 1, 3, or 7 and the sum is 11
question = "What is the sum of all positive integers n such that n + 1 divides n^2 + 7?"
print(f"Question: {question}\n")
print("Model's Response:")
print(solve_math_problem(question))

In [None]:
# Try your own math problem
your_question = "If x + y = 10 and xy = 21, what is x^2 + y^2?"
print(f"Question: {your_question}\n")
print("Model's Response:")
print(solve_math_problem(your_question))
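`solve_math_problem` returns the raw continuation after `<think>`. If the model follows the format seen in the analysis reports, the response ends with `</think> <answer> ... </answer>`, and the final answer can be pulled out with a regular expression. This helper is a hypothetical convenience; it does no math-aware normalization, so it is not a substitute for the evaluation grader.

```python
import re

# Non-greedy match between the tags; DOTALL lets the answer span multiple lines.
ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def extract_answer(response):
    """Return the text of the last <answer>...</answer> block, or None if absent."""
    matches = ANSWER_PATTERN.findall(response)
    return matches[-1].strip() if matches else None


example = "2^2 * 7^2 has (2+1)(2+1) = 9 divisors. </think> <answer> 9 </answer>"
print(extract_answer(example))
```

In the notebook, `extract_answer(solve_math_problem(your_question))` returns the answer string, or `None` if the model never closed an `<answer>` tag (for example, when generation stops at `max_new_tokens`).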

## Notes

### Training Tips
- Start with a small number of steps (5-10) to verify everything works
- Monitor GPU memory usage and adjust batch sizes accordingly
- Use Weights & Biases for experiment tracking: add `--wandb-project your-project-name`

### Expected Results
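- Base Qwen2.5-Math-1.5B: 38% format accuracy and 19% answer accuracy on the 100-problem evaluation above
- After GRPO training: both should improve; the 5-step test run above reached 59% format accuracy and 28% answer accuracy on the same 100 problems. With only 100 problems, treat small differences with caution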

### Troubleshooting
- **OOM Error**: Reduce `--rollout-batch-size` and `--train-batch-size`
- **Slow Training**: This is expected on T4; consider using A100 for faster training
- **Low Accuracy**: Try more training steps or adjust learning rate