# Train a Reasoning LLM with GRPO (R1-Zero Replication)

This notebook trains **Qwen2.5-3B** (or **Qwen2.5-1.5B** on smaller GPUs, see the table below) using pure RL (GRPO, no supervised fine-tuning) to develop reasoning capabilities,
replicating the DeepSeek-R1-Zero experiment.

**Supported setups** (auto-detected; G = number of completions sampled per prompt, i.e. the GRPO group size):
| Setup | Config | Time | Cost |
|-------|--------|------|------|
| 1× H200 141GB | 3B, G=16, batch=2 | ~5h | — |
| 3× A100 80GB (Jarvis Labs) | 3B, G=12, data-parallel | ~4.3h | ~$17 |
| 2× A100 80GB | 3B, G=12, data-parallel | ~5.5h | ~$14 |
| 1× H100 80GB | 3B, G=12 | ~6h | ~$18 |
| 1× A100 80GB | 1.5B, G=16 | ~8h | ~$10 |
| 1× A100 40GB | 1.5B, G=8 | ~5.5h | ~$7 |

## 1. Setup

In [None]:
# Install dependencies (TRL >= 0.16 requires PyTorch >= 2.6 for FSDPModule)
!pip install -q -U torch torchvision torchaudio
!pip install -q transformers trl accelerate datasets math-verify wandb tensorboard pyyaml

In [None]:
# Platform + clone repo
import os

PLATFORM = "jarvis"  # Change to "colab" if running on Google Colab
REPO_URL = "https://github.com/manojkgorle/smol-reason"

if PLATFORM == "colab":
    PROJECT_DIR = "/content/smol-reason"
else:
    # Jarvis Labs — use home directory
    PROJECT_DIR = os.path.expanduser("~/smol-reason")

if REPO_URL and not os.path.exists(PROJECT_DIR):
    !git clone {REPO_URL} {PROJECT_DIR}
elif not os.path.exists(PROJECT_DIR):
    print("Please upload the reson-llm project files or set REPO_URL")

os.chdir(PROJECT_DIR)
print(f"Platform: {PLATFORM}")
print(f"Working directory: {os.getcwd()}")

In [None]:
# GPU check — detect count, type, and memory to auto-select config
import torch

if not torch.cuda.is_available():
    raise RuntimeError("No GPU detected! This notebook requires a GPU.")

# Only device 0 is inspected: the tier logic below assumes every visible GPU
# matches it.
NUM_GPUS = torch.cuda.device_count()
gpu_name = torch.cuda.get_device_name(0)
gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9  # bytes -> decimal GB

print(
    f"GPUs found: {NUM_GPUS}× {gpu_name}\n"
    f"Memory per GPU: {gpu_mem_gb:.1f} GB\n"
    f"Total GPU memory: {NUM_GPUS * gpu_mem_gb:.1f} GB"
)

# Tier rules, checked in order (first match wins):
#   H200 (>=130 GB)       -> 3B model, G=16, batch=2 (best single-GPU config)
#   3+ x A100 80GB        -> 3B model, data-parallel
#   2 x A100 80GB         -> 3B model, data-parallel
#   H100/H200 (>=70 GB)   -> 3B model, single-GPU config
#   any other >=70 GB     -> 1.5B model, G=16
#   >=35 GB               -> 1.5B model, G=8 (overrides applied in the config cell)
#   anything smaller      -> 1.5B model, minimal (overrides applied in the config cell)
# NOTE(review): a multi-GPU host that is not A100 80GB (e.g. 2x H100) still
# gets a single-GPU config file, while the launch cell starts one process per
# GPU based on NUM_GPUS — confirm that combination is intended.
is_h200 = "H200" in gpu_name
is_h100 = "H100" in gpu_name
is_a100_80 = gpu_mem_gb >= 70 and "A100" in gpu_name

if is_h200 and gpu_mem_gb >= 130:
    GPU_TIER, CONFIG_FILE = "h200_141gb", "configs/grpo_qwen2.5_3b_h200.yaml"
    tier_banner = "\nConfig: Qwen2.5-3B on H200 (G=16, batch=2, max_completion=2048)"
elif is_a100_80 and NUM_GPUS >= 3:
    GPU_TIER, CONFIG_FILE = "3xa100_80gb", "configs/grpo_qwen2.5_3b_a100x3.yaml"
    tier_banner = f"\nConfig: Qwen2.5-3B on {NUM_GPUS}× A100 80GB (data-parallel, G=12)"
elif is_a100_80 and NUM_GPUS >= 2:
    GPU_TIER, CONFIG_FILE = "2xa100_80gb", "configs/grpo_qwen2.5_3b_a100x2.yaml"
    tier_banner = f"\nConfig: Qwen2.5-3B on {NUM_GPUS}× A100 80GB (data-parallel, G=12)"
elif gpu_mem_gb >= 70 and (is_h100 or is_h200):
    GPU_TIER, CONFIG_FILE = "h100_80gb", "configs/grpo_qwen2.5_3b_h100.yaml"
    tier_banner = "\nConfig: Qwen2.5-3B on H100 (G=12, max_completion=1536)"
elif gpu_mem_gb >= 70:
    GPU_TIER, CONFIG_FILE = "a100_80gb", "configs/grpo_qwen2.5_1.5b.yaml"
    tier_banner = "\nConfig: Qwen2.5-1.5B full (G=16, max_completion=2048)"
elif gpu_mem_gb >= 35:
    GPU_TIER, CONFIG_FILE = "40gb", "configs/grpo_qwen2.5_1.5b.yaml"
    tier_banner = "\nConfig: Qwen2.5-1.5B reduced (G=8, max_completion=1024)"
else:
    GPU_TIER, CONFIG_FILE = "small", "configs/grpo_qwen2.5_1.5b.yaml"
    tier_banner = "\nWARNING: Limited GPU memory. Using minimal config."
print(tier_banner)

In [None]:
# W&B login — authenticate with Weights & Biases before training starts.
# This call is interactive and may prompt for an API key.
# NOTE(review): whether the training run actually logs to W&B depends on the
# `report_to` setting in the YAML config — confirm.
import wandb
wandb.login()

In [None]:
# Checkpoint storage setup
#
# Where checkpoints are written depends on the platform:
#   * Jarvis Labs / generic GPU server — local storage persists across session
#     stop/restart, so a path relative to the working directory (the project
#     directory after the setup cell) is enough.
#   * Google Colab — the VM is ephemeral, so Google Drive is mounted and used
#     to keep checkpoints across session disconnects.
#
# NOTE(review): neither path depends on which model the GPU tier selected, and
# the resume cell picks up whatever checkpoint is newest in CHECKPOINT_DIR —
# confirm runs with different models never share this directory.

if PLATFORM != "colab":
    CHECKPOINT_DIR = "outputs/grpo-qwen2.5-3b"
else:
    from google.colab import drive

    drive.mount("/content/drive")
    CHECKPOINT_DIR = "/content/drive/MyDrive/reson-llm-checkpoints"

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"Checkpoints: {CHECKPOINT_DIR}")

## 2. Configure Training

In [None]:
import os

import yaml

# Load the config file chosen by the GPU check cell.
print(f"Loading config: {CONFIG_FILE}")
with open(CONFIG_FILE, encoding="utf-8") as f:
    config = yaml.safe_load(f)

# safe_load returns None for an empty file (or a scalar for a malformed one);
# fail clearly here rather than with a TypeError on the first key assignment.
if not isinstance(config, dict):
    raise ValueError(f"{CONFIG_FILE} must contain a YAML mapping of training options")

# Send checkpoints to the directory picked in the storage cell.
config["output_dir"] = CHECKPOINT_DIR

# The "40gb" and "small" tiers reuse the full 1.5B config file, so their
# reduced generation / batch settings are applied here.
if GPU_TIER == "40gb":
    config["num_generations"] = 8
    config["max_completion_length"] = 1024
    config["per_device_train_batch_size"] = 1
    print("Adjusted config for 40GB GPU")
elif GPU_TIER == "small":
    config["num_generations"] = 4
    config["max_completion_length"] = 512
    config["per_device_train_batch_size"] = 1
    config["gradient_accumulation_steps"] = 4
    print("Adjusted config for small GPU")

# Write the effective config; the training cell passes this file via --config.
RUNTIME_CONFIG = os.path.join(PROJECT_DIR, "runtime_config.yaml")
with open(RUNTIME_CONFIG, "w", encoding="utf-8") as f:
    yaml.dump(config, f, default_flow_style=False)

print(f"\nModel: {config['model_name_or_path']}")
print(f"GPU tier: {GPU_TIER} ({NUM_GPUS}× GPU)")
print("\nTraining config:")
for k, v in config.items():
    print(f"  {k}: {v}")

## 3. Train

In [None]:
# Check for existing checkpoint to resume from
import glob
import os
import re
import shlex


def _checkpoint_step(path):
    """Return N for a ``checkpoint-N`` path, or None if the name does not match."""
    match = re.fullmatch(r"checkpoint-(\d+)", os.path.basename(path))
    return int(match.group(1)) if match else None


# Keep only real checkpoint-<step> directories and order them by numeric step.
# A plain lexicographic sort ranks "checkpoint-500" after "checkpoint-1000"
# and would silently resume from the older checkpoint.
checkpoints = sorted(
    (
        path
        for path in glob.glob(f"{CHECKPOINT_DIR}/checkpoint-*")
        if os.path.isdir(path) and _checkpoint_step(path) is not None
    ),
    key=_checkpoint_step,
)
resume_arg = ""
if checkpoints:
    latest = checkpoints[-1]
    print(f"Found checkpoint: {latest}")
    # resume_arg is spliced into a shell command line, so quote the path.
    resume_arg = f"--resume_from_checkpoint {shlex.quote(latest)}"
else:
    print("No checkpoint found, starting fresh.")

In [None]:
# Launch training
# Multi-GPU: use accelerate launch with data parallelism
# Single-GPU: run directly with python
# PYTHONPATH ensures `from src.* import` works

import os
os.environ["PYTHONPATH"] = PROJECT_DIR

if NUM_GPUS > 1:
    print(f"Launching distributed training on {NUM_GPUS} GPUs...")
    !PYTHONPATH={PROJECT_DIR} accelerate launch --num_processes {NUM_GPUS} src/train_grpo.py \
        --config {RUNTIME_CONFIG} {resume_arg}
else:
    print("Launching single-GPU training...")
    !PYTHONPATH={PROJECT_DIR} python src/train_grpo.py --config {RUNTIME_CONFIG} {resume_arg}

## 4. Monitor (TensorBoard)

In [None]:
# Load the TensorBoard notebook extension and display the runs found under
# CHECKPOINT_DIR.
# NOTE(review): assumes the trainer writes TensorBoard event files somewhere
# under CHECKPOINT_DIR (the config's output_dir) — confirm `report_to` /
# `logging_dir` in the YAML config. The `!` training command above runs in the
# foreground, so this cell only executes after training has finished or been
# interrupted.
%load_ext tensorboard
%tensorboard --logdir {CHECKPOINT_DIR}

## 5. Evaluate

In [None]:
# Quick eval on 200 samples
!PYTHONPATH={PROJECT_DIR} python src/evaluate.py \
    --model_path {CHECKPOINT_DIR} \
    --num_samples 200 \
    --output_dir eval_results

print(f"\nModel evaluated: {config['model_name_or_path']} (trained)")

In [None]:
# Baseline comparison (raw base model, before RL training)
BASE_MODEL = config["model_name_or_path"]
print(f"Evaluating baseline: {BASE_MODEL}")
!PYTHONPATH={PROJECT_DIR} python src/evaluate.py \
    --model_path {BASE_MODEL} \
    --num_samples 200 \
    --output_dir eval_results_baseline

In [None]:
# Compare results: print baseline vs. trained metrics side by side.
import json

# Each summary.json is read as a mapping of dataset name -> metrics dict with
# 'accuracy', 'format_compliance' and 'avg_think_tokens' keys.
# NOTE(review): schema taken from the keys used below — confirm against
# src/evaluate.py.
with open("eval_results/summary.json", encoding="utf-8") as f:
    trained = json.load(f)
with open("eval_results_baseline/summary.json", encoding="utf-8") as f:
    baseline = json.load(f)

# (summary key, row label, number format) for each printed row.
_METRIC_ROWS = (
    ("accuracy", "accuracy", ".1%"),
    ("format_compliance", "format compliance", ".1%"),
    ("avg_think_tokens", "avg think tokens", ".0f"),
)

print("=" * 60)
print(f"{'Metric':<30} {'Baseline':>12} {'Trained':>12}")
print("=" * 60)
for dataset in ["gsm8k", "math"]:
    if dataset not in trained or dataset not in baseline:
        print(f"{dataset.upper()}: missing from one of the summaries, skipped")
        continue
    b = baseline[dataset]
    t = trained[dataset]
    for key, label, spec in _METRIC_ROWS:
        # Pad each label to the header's 30-char Metric column and each value
        # to 12 chars so the numbers sit under 'Baseline' / 'Trained'.
        row_label = f"{dataset.upper()} {label}:"
        print(f"{row_label:<30} {b[key]:>12{spec}} {t[key]:>12{spec}}")
    print("-" * 60)

## 6. Push to Hub (Optional)

In [None]:
# Uncomment and set your HF username to push
# from huggingface_hub import login
# login()

# HF_USERNAME = "your-username"
# MODEL_TAG = config["model_name_or_path"].split("/")[-1].lower()
# REPO_NAME = f"{HF_USERNAME}/{MODEL_TAG}-r1zero-grpo"
# from transformers import AutoModelForCausalLM, AutoTokenizer
# model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_DIR, torch_dtype=torch.bfloat16)
# tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
# model.push_to_hub(REPO_NAME)
# tokenizer.push_to_hub(REPO_NAME)
# print(f"Pushed to: https://huggingface.co/{REPO_NAME}")