# 01 - GRPO Training with Tunix (Local RTX 3090)

This notebook implements GRPO (Group Relative Policy Optimization) training for Gemma 3 1B on GSM8K math problems.

Based on the official [Tunix GRPO Demo](https://github.com/google/tunix/blob/main/examples/grpo_demo.ipynb), adapted for local GPU execution.

**Target format:**
```
<reasoning>step-by-step solution</reasoning><answer>final answer</answer>
```

## Cell 1: Environment Setup & Imports

In [None]:
import os
import warnings
import logging

import functools
import gc
import re
import shutil
import time
from pathlib import Path
from pprint import pprint

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import optax
from orbax import checkpoint as ocp
import qwix
from tqdm.auto import tqdm
from huggingface_hub import snapshot_download
from datasets import load_dataset
import wandb

# Tunix imports
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import model as gemma3_model
from tunix.models.gemma3 import params_safetensors as gemma3_safetensors
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger

# Our library
from tunix_hack.utils.xml_parsing import extract_tag, has_valid_format
from tunix_hack.rewards.math_reward import normalize_math_answer
from tunix_hack.models import load_tokenizer
from tunix_hack.inference import create_sampler

# Report the JAX runtime this notebook is using (one line per fact).
print(
    f"JAX version: {jax.__version__}",
    f"Devices: {jax.devices()}",
    f"Default backend: {jax.default_backend()}",
    sep="\n",
)

## Cell 2: Configuration Constants

In [None]:
# Run configuration
# An empty (or whitespace-only) entry falls back to a timestamp-based name.
RUN_NAME = (
    input("Enter run name (e.g., 'run_1', 'baseline', 'rank16'): ").strip()
    or f"run_{int(time.time())}"
)
print(f"Run name: {RUN_NAME}")

# Model configuration
MODEL_FAMILY = "gemma3"
MODEL_VERSION = "gemma3-1b-it"  # Instruction-tuned variant

# Mesh configuration for single GPU
# Gemma 3 requires both fsdp and tp axes, even on single device
# Layout: (axis_shapes, axis_names), unpacked into jax.make_mesh(*MESH).
MESH = ((1, 1), ("fsdp", "tp"))

# Training configuration - aggressive memory optimization for RTX 3090 (24GB VRAM)
TRAIN_MICRO_BATCH_SIZE = 1   # Reduced from 4 for memory
NUM_BATCHES = 500            # Number of batches taken from the shuffled train set (before the train/val split)
NUM_EPOCHS = 3               # Number of epochs
NUM_ITERATIONS = 1           # Optimization iterations per batch
TRAIN_FRACTION = 0.9         # Train/val split
EVAL_EVERY_N_STEPS = 10      # Evaluation frequency; also reused as the checkpoint save interval

# Calculate training steps dynamically (like reference)
# With the values above: int(500 * 1 * 0.9 * 3) = 1350 steps.
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)
WARMUP_STEPS = int(0.1 * MAX_STEPS)  # 10% warmup (like reference)

# Generation configuration
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 512
TEMPERATURE = 0.9            # High-ish for diverse responses during training
TOP_P = 1.0
TOP_K = 50

# GRPO hyperparameters
NUM_GENERATIONS = 4          # G in GRPO algorithm
BETA = 0.08                  # KL divergence penalty coefficient
EPSILON = 0.2                # Clipping range for stable updates

# Optimizer hyperparameters
LEARNING_RATE = 3e-6         # Peak learning rate
WEIGHT_DECAY = 0.1           # AdamW weight decay
B1 = 0.9                     # Adam beta1
B2 = 0.99                    # Adam beta2
MAX_GRAD_NORM = 0.1          # Gradient clipping (global norm); None disables clipping

# LoRA configuration - reduced rank for memory
RANK = 16                    # Reduced from 64 for memory (still good quality)
ALPHA = 32                   # Usually alpha = rank; this run uses 2x RANK as an experiment

# Dataset configuration
NUM_TEST_BATCHES = 10        # Originally 100, reduced for memory constraints

# Inference generation configs (from reference)
# Each entry is splatted into the sampler call as keyword arguments.
GENERATION_CONFIGS = {
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}

# Paths - organized by run name (use project dir, not system /tmp)
PROJECT_ROOT = Path("/home/jimnix/gitrepos/tunix-hack")
CKPT_DIR = str(PROJECT_ROOT.joinpath("outputs", "checkpoints", "grpo", RUN_NAME))
INTERMEDIATE_CKPT_DIR = str(PROJECT_ROOT.joinpath("tmp", "tunix_intermediate", RUN_NAME))
TENSORBOARD_DIR = str(PROJECT_ROOT.joinpath("tmp", "tensorboard", "grpo", RUN_NAME))

# === SFT Warmup Checkpoint Loading ===
# When True, LoRA weights from the SFT warmup are restored before GRPO
# training so the policy starts out already knowing the XML format.
LOAD_SFT_CHECKPOINT = True
SFT_CKPT_DIR = str(PROJECT_ROOT.joinpath("outputs", "checkpoints", "sft"))
# Leave as None to auto-detect the latest SFT checkpoint, or pin one manually:
# SFT_CKPT_PATH = str(PROJECT_ROOT / "outputs" / "checkpoints" / "sft" / "warmup" / "500" / "model_params")
SFT_CKPT_PATH = None

# Make sure every output directory exists before anything writes to it.
os.makedirs(CKPT_DIR, exist_ok=True)
os.makedirs(INTERMEDIATE_CKPT_DIR, exist_ok=True)
os.makedirs(TENSORBOARD_DIR, exist_ok=True)

# Configuration summary, one line per setting.
print(
    "\nConfiguration loaded (memory-optimized for RTX 3090).",
    f"  LoRA rank: {RANK}",
    f"  Batch size: {TRAIN_MICRO_BATCH_SIZE}",
    f"  Generations: {NUM_GENERATIONS}",
    f"  MAX_STEPS: {MAX_STEPS} (calculated)",
    f"  WARMUP_STEPS: {WARMUP_STEPS} (10% of max)",
    f"  Load SFT checkpoint: {LOAD_SFT_CHECKPOINT}",
    f"Checkpoint directory: {CKPT_DIR}",
    f"TensorBoard directory: {TENSORBOARD_DIR}",
    sep="\n",
)


## Cell 3: HuggingFace Model Download

In [None]:
# Download Gemma 3 1B instruction-tuned model from HuggingFace
MODEL_ID = "google/gemma-3-1b-it"

print(f"Downloading {MODEL_ID} from HuggingFace...")
# model_path is reused later by load_tokenizer() and get_base_model().
model_path = snapshot_download(MODEL_ID)

print(f"Model downloaded to: {model_path}")

# List downloaded files (quick sanity check of the snapshot contents)
for f in Path(model_path).iterdir():
    print(f"  {f.name}")

## Cell 4: JAX Mesh Setup

In [None]:
# Create JAX mesh for single GPU
# For multi-GPU, you would use (num_gpus,) or (fsdp, tp) configuration
# NOTE: get_base_model() builds its own mesh from MESH, and the
# "Load Base Model" cell rebinds `mesh` to that instance.
mesh = jax.make_mesh(*MESH)
print(f"Mesh created: {mesh}")
print(f"Mesh devices: {mesh.devices}")

## Cell 5: Load Tokenizer

In [None]:
# Load tokenizer from the downloaded model directory
tokenizer = load_tokenizer(model_path)
print("Tokenizer loaded.")

# Test tokenization: encode/decode round trip on a short string so the
# result can be inspected by eye.
test_text = "Hello, world!"
tokens = tokenizer.encode(test_text)
decoded = tokenizer.decode(tokens)
print(f"Test: '{test_text}' -> {tokens} -> '{decoded}'")

## Cell 6: Helper Functions for Model Loading

In [None]:
def get_model_config():
    """Return the ``ModelConfig`` preset for Gemma 3 1B."""
    config = gemma3_model.ModelConfig.gemma3_1b()
    return config


def get_base_model(model_path: str):
    """Build the base Gemma 3 model from HuggingFace safetensors.

    A fresh mesh is created from the module-level ``MESH`` setting and handed
    to the safetensors loader together with the Gemma 3 1B config.

    Args:
        model_path: Path to the downloaded model directory.

    Returns:
        Tuple of (model, mesh, model_config).
    """
    device_mesh = jax.make_mesh(*MESH)
    config = get_model_config()

    # The loader receives the mesh so it can place the weights itself.
    loaded_model = gemma3_safetensors.create_model_from_safe_tensors(
        model_path, config, device_mesh
    )
    return loaded_model, device_mesh, config


def get_lora_model(base_model, mesh):
    """Wrap ``base_model`` with LoRA adapters to obtain the trainable policy.

    Adapters are attached to the attention and MLP projections whose module
    path matches the pattern below, using the module-level ``RANK`` and
    ``ALPHA`` settings.

    Note: the original author's notes state that LoRA is applied in place so
    the policy and the reference model share one backbone; this cannot be
    confirmed from this file alone. The result is deliberately not re-sharded.

    Args:
        base_model: The base Gemma model.
        mesh: JAX mesh for sharding (unused, kept for API compatibility).

    Returns:
        The LoRA-augmented model.
    """
    # Target attention and MLP layers.
    target_modules = "|".join(
        (
            ".*q_einsum",
            ".*kv_einsum",
            ".*gate_proj",
            ".*down_proj",
            ".*up_proj",
            ".*attn_vec_einsum",
        )
    )
    provider = qwix.LoraProvider(
        module_path=target_modules,
        rank=RANK,
        alpha=ALPHA,
    )

    sample_inputs = base_model.get_model_input()
    return qwix.apply_lora_to_model(base_model, provider, **sample_inputs)


print("Helper functions defined.")

## Cell 7: Load and Convert Base Model

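**Note:** Earlier versions of this notebook loaded the HuggingFace safetensors and saved them as an intermediate checkpoint for NNX compatibility. That step is now skipped: the model is loaded directly from safetensors in Cell 8.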

In [None]:
# Note: We now load directly from safetensors in the next cell,
# so intermediate checkpoint is not needed for this workflow.
# The get_base_model function handles sharding automatically.
# This cell is kept only as a placeholder so the cell numbering stays stable.
print("Intermediate checkpoint step skipped - loading directly from safetensors.")

## Cell 8: Load Base Model

Load directly from HuggingFace safetensors. This model will be used as both the reference (frozen) and the backbone for LoRA (shared weights for memory efficiency).

In [None]:
# Load base model directly from safetensors (used as both reference and LoRA backbone)
# This enables backbone sharing between actor and reference for memory efficiency
print("Loading base model...")
# NOTE: this rebinds `mesh` to the mesh created inside get_base_model();
# every later cell (cluster config, training) uses this instance.
base_model, mesh, model_config = get_base_model(model_path)
print("Base model loaded.")
nnx.display(base_model)

## Cell 9: Create LoRA Policy Model

In [None]:
# Create LoRA-augmented policy model (trainable)
# Note: This shares the backbone with base_model for memory efficiency
print("Creating LoRA policy model...")
# `mesh` is passed for API compatibility only; get_lora_model() does not use it.
lora_policy = get_lora_model(base_model, mesh=mesh)
print("LoRA policy model created.")
print(f"LoRA rank: {RANK}, alpha: {ALPHA}")
nnx.display(lora_policy)

In [None]:
# === Load SFT Checkpoint (if enabled) ===
# This loads the LoRA weights from SFT warmup to give the model XML format knowledge

def find_latest_sft_checkpoint(sft_dir: str) -> str | None:
    """Find the latest SFT checkpoint in the directory."""
    sft_path = Path(sft_dir)
    if not sft_path.exists():
        return None
    
    # Look for run directories
    run_dirs = [d for d in sft_path.iterdir() if d.is_dir()]
    if not run_dirs:
        return None
    
    # Find latest checkpoint across all runs
    latest_ckpt = None
    latest_step = -1
    
    for run_dir in run_dirs:
        # Look for step directories (numeric names)
        for step_dir in run_dir.iterdir():
            if step_dir.is_dir() and step_dir.name.isdigit():
                step = int(step_dir.name)
                model_params = step_dir / "model_params"
                if model_params.exists() and step > latest_step:
                    latest_step = step
                    latest_ckpt = str(model_params)
    
    return latest_ckpt


if LOAD_SFT_CHECKPOINT:
    # An explicit SFT_CKPT_PATH wins; otherwise fall back to the newest
    # checkpoint discovered under SFT_CKPT_DIR.
    ckpt_path = SFT_CKPT_PATH or find_latest_sft_checkpoint(SFT_CKPT_DIR)

    if ckpt_path and os.path.exists(ckpt_path):
        print(f"Loading SFT checkpoint from: {ckpt_path}")

        # Shape/dtype-only description of the current LoRA parameters, used
        # as the restore target.
        abs_params = jax.tree.map(
            lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
            nnx.state(lora_policy, nnx.LoRAParam),
        )

        # Load checkpoint
        checkpointer = ocp.StandardCheckpointer()
        sft_lora_params = checkpointer.restore(ckpt_path, target=abs_params)

        # Overwrite every LoRA leaf of the policy with its restored value.
        nnx.update(
            lora_policy,
            jax.tree.map(
                lambda a, b: b,
                nnx.state(lora_policy, nnx.LoRAParam),
                sft_lora_params,
            ),
        )
        print("SFT checkpoint loaded successfully!")
        print("Model now has XML format knowledge from SFT warmup.")
        # wandb.run is None until wandb.init() has been called, and no init
        # happens before this cell, so record the path only when a run
        # exists instead of crashing right after a successful restore.
        if wandb.run is not None:
            wandb.run.summary["sft_checkpoint"] = ckpt_path
    else:
        print("WARNING: LOAD_SFT_CHECKPOINT is True but no checkpoint found!")
        print(f"  Searched in: {SFT_CKPT_DIR}")
        print("  Run 00_sft_warmup.ipynb first to create an SFT checkpoint.")
        print("  Continuing with fresh LoRA weights...")
else:
    print("Skipping SFT checkpoint loading (LOAD_SFT_CHECKPOINT=False).")

## Cell 10: Define Reward Functions

Tunix GRPO expects reward functions with signature:
```python
def reward_fn(prompts, completions, **kwargs) -> List[float]
```

In [None]:
def match_format_exactly(prompts, completions, **kwargs):
    """Score each completion on whether it passes ``has_valid_format``.

    Args:
        prompts: Unused; part of the reward-function signature.
        completions: Generated texts to score.
        **kwargs: Ignored.

    Returns:
        One score per completion: 3.0 when ``has_valid_format`` accepts it
        (presumably the full <reasoning>/<answer> layout), 0.0 otherwise.
    """
    return [3.0 if has_valid_format(text) else 0.0 for text in completions]


def match_format_approximately(prompts, completions, **kwargs):
    """Give partial credit for getting close to the target tag layout.

    Each of the four tags (<reasoning>, </reasoning>, <answer>, </answer>)
    contributes +0.5 when it occurs exactly once in the completion and -0.5
    otherwise (missing or repeated).

    Args:
        prompts: Unused; part of the reward-function signature.
        completions: Generated texts to score.
        **kwargs: Ignored.

    Returns:
        One score per completion, from -2.0 (no tag occurs exactly once) to
        +2.0 (all four tags occur exactly once).
    """
    expected_tags = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")
    return [
        sum(0.5 if text.count(tag) == 1 else -0.5 for tag in expected_tags)
        for text in completions
    ]


def check_answer(prompts, completions, answer, **kwargs):
    """Score the content of the <answer> tag against the reference answer.

    Per completion:
    - 0.0 when ``extract_tag`` yields nothing for the "answer" tag
    - 3.0 for an exact string match
    - 1.5 for a match after stripping surrounding whitespace
    - 0.5 when extracted/reference, as floats, has a ratio in [0.9, 1.1]
    - 0.25 when that ratio is in [0.8, 1.2]
    - -1.0 for any other numeric answer
    - -0.5 when either value is not a float or the reference is zero

    Args:
        prompts: Unused; part of the reward-function signature.
        completions: Generated texts to score.
        answer: Reference answers, aligned with ``completions``.
        **kwargs: Ignored.

    Returns:
        One score per (completion, reference) pair.
    """

    def _grade(predicted, expected):
        # String comparisons first, then numeric partial credit.
        if predicted == expected:
            return 3.0
        if predicted.strip() == expected.strip():
            return 1.5
        try:
            ratio = float(predicted) / float(expected)
        except (ValueError, ZeroDivisionError):
            return -0.5  # Penalize unparseable
        if 0.9 <= ratio <= 1.1:
            return 0.5
        if 0.8 <= ratio <= 1.2:
            return 0.25
        return -1.0  # Penalize wrong answers

    results = []
    for text, reference in zip(completions, answer):
        predicted = extract_tag(text, "answer")
        results.append(_grade(predicted, reference) if predicted else 0.0)
    return results


def check_numbers(prompts, completions, answer, **kwargs):
    """Reward completions that mention the reference answer anywhere.

    The reference is first passed through ``normalize_math_answer``; the
    result is then looked up as a plain substring of the completion.

    Args:
        prompts: Unused; part of the reward-function signature.
        completions: Generated texts to score.
        answer: Reference answers, aligned with ``completions``.
        **kwargs: Ignored.

    Returns:
        One score per pair: 1.0 if the normalized answer occurs in the
        completion, else 0.0.
    """
    return [
        1.0 if normalize_math_answer(reference) in text else 0.0
        for text, reference in zip(completions, answer)
    ]


# Test reward functions
# Smoke test on a hand-written, well-formed completion; the scores are only
# printed for manual inspection, nothing is asserted.
test_prompts = ["What is 2+2?"]
test_completions = ["<reasoning>2+2=4</reasoning><answer>4</answer>"]
test_answers = ["4"]

print("Testing reward functions:")
print(f"  format_exact: {match_format_exactly(test_prompts, test_completions)}")
print(f"  format_approx: {match_format_approximately(test_prompts, test_completions)}")
print(f"  check_answer: {check_answer(test_prompts, test_completions, test_answers)}")
print(f"  check_numbers: {check_numbers(test_prompts, test_completions, test_answers)}")

# Test with bad format (what the model was outputting)
# No tag is present, so match_format_approximately yields -2.0.
test_bad = ["reasoning:\n2+2=4\nFinal answer: 4"]
print("\nTesting with bad format (model's actual output):")
print(f"  format_exact: {match_format_exactly(test_prompts, test_bad)}")
print(f"  format_approx: {match_format_approximately(test_prompts, test_bad)}  # Now gives -2.0 penalty!")

# Test with wrong answer
# 5/4 = 1.25 lies outside the [0.8, 1.2] partial-credit band, hence -1.0.
test_wrong = ["<reasoning>2+2=5</reasoning><answer>5</answer>"]
print("\nTesting with wrong answer:")
print(f"  check_answer: {check_answer(test_prompts, test_wrong, test_answers)}  # Now gives -1.0 penalty!")

## Cell 11: Load GSM8K Dataset

In [None]:
# System prompt for math problems
# It spells out the same <reasoning>/<answer> layout the reward functions score.
SYSTEM_PROMPT = """You are a math tutor. Solve the problem step by step.
Format your response EXACTLY as:
<reasoning>your step-by-step solution</reasoning><answer>final numerical answer only</answer>"""

# Turn template: the system prompt is folded into the user turn, and the
# string ends with an open model turn for the model to complete.
TEMPLATE = "<start_of_turn>user\n{system_prompt}\n\n{question}<end_of_turn>\n<start_of_turn>model\n"


def extract_hash_answer(answer_text: str) -> str:
    """Return the final answer portion of a GSM8K solution string.

    GSM8K solutions end with ``####`` followed by the answer. Everything after
    the last ``####`` split point is returned, stripped of surrounding
    whitespace; text without the marker is returned stripped as a whole.
    """
    # split() returns [answer_text] when the marker is absent, so taking the
    # last piece covers both cases.
    *_, final_part = answer_text.split("####")
    return final_part.strip()


def format_gsm8k_example(example):
    """Turn one GSM8K record into the prompt/answer pair used for GRPO.

    Args:
        example: Mapping with "question" and "answer" entries.

    Returns:
        Dict with "prompts" (the question rendered through ``TEMPLATE`` with
        ``SYSTEM_PROMPT``) and "answer" (the text after ``####``).
    """
    rendered_prompt = TEMPLATE.format(
        system_prompt=SYSTEM_PROMPT,
        question=example["question"],
    )
    return {
        "prompts": rendered_prompt,
        "answer": extract_hash_answer(example["answer"]),
    }


# Load GSM8K dataset from HuggingFace (more reliable than tfds)
print("Loading GSM8K dataset from HuggingFace...")
gsm8k_train = load_dataset("gsm8k", "main", split="train")
gsm8k_test = load_dataset("gsm8k", "main", split="test")

print(f"Train examples: {len(gsm8k_train)}")
print(f"Test examples: {len(gsm8k_test)}")

# Format every record into a plain Python dict. format_gsm8k_example() only
# reads the "question" and "answer" fields, so each record is passed as-is.
train_data = [
    format_gsm8k_example(ex)
    for ex in tqdm(gsm8k_train, desc="Formatting train data")
]

test_data = [
    format_gsm8k_example(ex)
    for ex in tqdm(gsm8k_test, desc="Formatting test data")
]

# Show example
print("\nExample formatted data:")
pprint(train_data[0])

## Cell 12: Create Training Dataset

In [None]:
# Create grain MapDataset from our formatted data
# (Tunix expects grain.MapDataset, not plain Python lists)
# Order matters: shuffle, then batch, then keep the first NUM_BATCHES batches.
train_grain = (
    grain.MapDataset.source(train_data)
    .shuffle(seed=42)
    .batch(TRAIN_MICRO_BATCH_SIZE)[:NUM_BATCHES]
)

# Test data is not shuffled; only the first NUM_TEST_BATCHES batches are kept.
test_grain = (
    grain.MapDataset.source(test_data)
    .batch(TRAIN_MICRO_BATCH_SIZE)[:NUM_TEST_BATCHES]
)

# Split train into train/val
# The split happens at batch granularity; only the training part is repeated
# for NUM_EPOCHS.
split_idx = int(len(train_grain) * TRAIN_FRACTION)
if TRAIN_FRACTION == 1.0:
    train_dataset = train_grain.repeat(NUM_EPOCHS)
    val_dataset = None
else:
    train_dataset = train_grain[:split_idx].repeat(NUM_EPOCHS)
    # NOTE(review): val_dataset is built here but the training cell only
    # passes train_dataset to grpo_trainer.train() - confirm whether
    # evaluation data was meant to be supplied as well.
    val_dataset = train_grain[split_idx:]

test_dataset = test_grain

print(f"Train batches: {len(train_dataset)}")
print(f"Val batches: {len(val_dataset) if val_dataset else 0}")
print(f"Test batches: {len(test_dataset)}")

# Show first batch
print("\nFirst training batch:")
pprint(train_dataset[0])

## Cell 13: Configure Optimizer

In [None]:
# AdamW optimizer with warmup cosine decay schedule
# The learning rate ramps from 0 to LEARNING_RATE over WARMUP_STEPS and the
# schedule ends at 0 after MAX_STEPS.
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)

# Add gradient clipping
# Clipping is chained in front of AdamW so it sees the raw gradients.
if MAX_GRAD_NORM is not None:
    optimizer = optax.chain(
        optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
        optimizer,
    )

print("Optimizer configured.")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Weight decay: {WEIGHT_DECAY}")
print(f"  Gradient clipping: {MAX_GRAD_NORM}")

## Cell 14: Configure GRPO Training

In [None]:
# Metrics logging options
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir=TENSORBOARD_DIR,
    flush_every_n_steps=20,
)

# Checkpointing options
# Checkpoints are saved on the same cadence as evaluation; only the 3 most
# recent are kept (see the preservation callback for best/epoch copies).
checkpointing_options = ocp.CheckpointManagerOptions(
    max_to_keep=3,
    save_interval_steps=EVAL_EVERY_N_STEPS,
)

# RL Cluster configuration
# All three roles are mapped onto the same single-device mesh.
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',  # Use vanilla for single GPU
    offload_to_cpu=True,       # Enable CPU offloading for RTX 3090 memory
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        metrics_logging_options=metrics_logging_options,
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        # Prompt + generation budget plus 256 extra slots; the same expression
        # sizes the inference sampler's cache later in the notebook.
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        eos_tokens=[1, 106],  # Gemma EOS tokens - critical for proper generation stopping
    ),
)

# GRPO algorithm configuration
grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,  # Generate N responses per prompt
    num_iterations=NUM_ITERATIONS,    # Optimization iterations per batch
    beta=BETA,                        # KL divergence penalty
    epsilon=EPSILON,                  # Clipping range
)

print("GRPO configuration complete.")
print(f"  Generations per prompt: {NUM_GENERATIONS}")
print(f"  KL beta: {BETA}")
print(f"  Clip epsilon: {EPSILON}")
print(f"  CPU offloading: enabled")

## Cell 15: Initialize RLCluster and GRPOLearner

In [None]:
# Create RL Cluster
# Note: Using base_model as reference enables backbone sharing with lora_policy
print("Creating RL cluster...")
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=base_model,  # Shares backbone with lora_policy for memory efficiency
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)
print("RL cluster created.")

# Create GRPO Trainer
# The four reward functions are all handed to the learner. check_answer and
# check_numbers additionally take an `answer` argument - presumably filled
# from the dataset's "answer" field; verify against the Tunix learner.
print("Creating GRPO learner...")
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    algo_config=grpo_config,
)
print("GRPO learner created.")

## Cell 15b: Periodic Completion Logging

Register a callback that prints up to two sample prompts and completions every 50 steps, so we can see what format the model is outputting.

In [None]:
# Periodic completion logging
def log_completions(metrics_buffer):
    """Print up to two sample prompts/completions on every 50th step.

    Args:
        metrics_buffer: Object exposing ``global_steps`` and a ``metrics``
            mapping whose "completions" (and optional "prompts") entries are
            ``(values, extra)`` pairs.

    Nothing is printed when the step is not a multiple of 50 or when the
    metrics contain no "completions" entry. Completions longer than 800
    characters are truncated in the output.
    """
    current_step = metrics_buffer.global_steps
    if current_step % 50 != 0:
        return

    reported = metrics_buffer.metrics
    if "completions" not in reported:
        return

    completions, _ = reported["completions"]
    prompts, _ = reported.get("prompts", ([], None))

    banner = "=" * 80
    print(f"\n{banner}")
    print(f"Step {current_step} - Sample Prompts & Completions")
    print(f"{banner}")

    shown = min(2, len(completions))
    for position in range(shown):
        label = position + 1
        text = completions[position]
        # Prompts are optional and may be fewer than the completions.
        if prompts and position < len(prompts):
            print(f"\n[Prompt {label}]")
            print(prompts[position])
        print(f"\n[Completion {label}]")
        print(text[:800])
        if len(text) > 800:
            print(f"... [truncated, {len(text)} total chars]")
    print()

# Hook the completion printer into the cluster's metrics stream.
# NOTE(review): a later cell calls with_external_metrics_logger() again with
# preserve_checkpoints - confirm whether that call adds a second callback or
# replaces this one.
rl_cluster.with_external_metrics_logger(log_completions)
print("Completion logging registered (every 50 steps).")

In [None]:
# Checkpoint preservation callback - saves best and epoch-end checkpoints
# These won't be deleted by Orbax's max_to_keep policy

# Mutable state shared across callback invocations: the highest summed reward
# seen so far and the step at which it was observed.
_ckpt_state = {
    "best_reward": float("-inf"),
    "best_step": None,
}

def preserve_checkpoints(metrics_buffer):
    """Preserve best and epoch-end checkpoints, and mirror metrics to wandb.

    Sums the per-reward-function metrics reported for the step, logs them
    (plus every other scalar-like metric) to wandb when a run is active, and
    copies the matching checkpoint to ``preserved/best`` on a new best reward
    and to ``preserved/epoch_<n>`` at each epoch boundary.

    Args:
        metrics_buffer: Object exposing ``global_steps`` and a ``metrics``
            mapping whose values are ``(value, extra)`` pairs.
    """
    step = metrics_buffer.global_steps
    metrics = metrics_buffer.metrics

    # Calculate total reward from the reward-function metrics present.
    total_reward = 0.0
    reward_dict = {}
    for key in ["match_format_exactly", "match_format_approximately",
                "check_answer", "check_numbers"]:
        if key in metrics:
            val, _ = metrics[key]
            if isinstance(val, (int, float)):
                reward_val = float(val)
            elif hasattr(val, "mean"):
                reward_val = float(val.mean())
            else:
                reward_val = 0.0
            total_reward += reward_val
            reward_dict[f"reward/{key}"] = reward_val

    # wandb.run is None until wandb.init() has been called. Logging without
    # an active run raises, which would abort training from inside this
    # callback, so all wandb calls are skipped in that case.
    wandb_active = wandb.run is not None

    if wandb_active:
        log_dict = {
            "step": step,
            "reward/total": total_reward,
            **reward_dict,
        }

        # Add other metrics that can be reduced to a scalar.
        for k, v in metrics.items():
            if isinstance(v, tuple) and len(v) == 2:
                val, _ = v
                if isinstance(val, (int, float)):
                    log_dict[k] = val
                elif hasattr(val, "mean"):
                    log_dict[k] = float(val.mean())

        wandb.log(log_dict, step=step)

    # Check for new best. Steps that reported no reward metric at all are
    # skipped: their total of 0.0 is a placeholder, not a measured reward.
    if reward_dict and total_reward > _ckpt_state["best_reward"]:
        _ckpt_state["best_reward"] = total_reward
        _ckpt_state["best_step"] = step
        _copy_checkpoint(step, "best")
        print(f"[Checkpoint] New best at step {step}: reward={total_reward:.3f}")
        if wandb_active:
            wandb.run.summary["best_step"] = step
            wandb.run.summary["best_reward"] = total_reward

    # Epoch-end preservation (steps_per_epoch = NUM_BATCHES * TRAIN_FRACTION).
    # The > 0 guard avoids a modulo-by-zero for very small configurations.
    steps_per_epoch = int(NUM_BATCHES * TRAIN_FRACTION)
    if steps_per_epoch > 0 and step > 0 and step % steps_per_epoch == 0:
        epoch = step // steps_per_epoch
        _copy_checkpoint(step, f"epoch_{epoch}")
        print(f"[Checkpoint] Saved epoch {epoch} checkpoint at step {step}")


def _copy_checkpoint(step: int, name: str):
    """Copy the actor checkpoint for ``step`` to ``<CKPT_DIR>/preserved/<name>``.

    If no checkpoint directory exists for exactly ``step``, nearby steps are
    probed in a fixed order and the first existing one is used. When none is
    found a warning is printed and nothing is copied. An existing destination
    is replaced.

    Args:
        step: Training step whose checkpoint should be preserved.
        name: Destination directory name under ``preserved``.
    """
    actor_root = Path(CKPT_DIR) / "actor"
    target = Path(CKPT_DIR) / "preserved" / name
    source = actor_root / str(step)

    if not source.exists():
        # Checkpoint may not be saved yet, try nearby steps.
        nearby = (
            actor_root / str(step + delta)
            for delta in [0, -1, 1, -2, 2, -5, 5, -10, 10]
        )
        source = next((path for path in nearby if path.exists()), None)
        if source is None:
            print(f"[Checkpoint] Warning: No checkpoint found near step {step}")
            return

    target.parent.mkdir(parents=True, exist_ok=True)
    # copytree needs a fresh destination, so clear any previous copy first.
    if target.exists():
        shutil.rmtree(target)
    shutil.copytree(source, target)
    print(f"[Checkpoint] Copied {source.name} -> preserved/{name}")


# Hook the checkpoint-preservation callback into the cluster's metrics stream.
# NOTE(review): log_completions was registered through the same method in an
# earlier cell - confirm that both callbacks stay active.
rl_cluster.with_external_metrics_logger(preserve_checkpoints)
print("Checkpoint preservation callback registered.")

## Cell 16: TensorBoard Setup (Optional)

In [None]:
# # Load TensorBoard extension for monitoring
# %load_ext tensorboard
# %tensorboard --logdir {TENSORBOARD_DIR} --port=0
# TENSORBOARD_DIR
# The in-notebook extension is left disabled above; the bare expression below
# just evaluates to the shell command for launching TensorBoard manually.
f"tensorboard --logdir {TENSORBOARD_DIR} --port=6006"


## Cell 17: Run GRPO Training

In [None]:
# Run training!
print("Starting GRPO training...")
print(f"  Total steps: {MAX_STEPS}")
print(f"  Batch size: {TRAIN_MICRO_BATCH_SIZE}")
print(f"  Generations per prompt: {NUM_GENERATIONS}")
print()

# Training runs inside the mesh context that the models were loaded with.
with mesh:
    grpo_trainer.train(train_dataset)

print("\nTraining complete!")

# Create sampler for inference.
# NOTE(review): no wandb.init() call is visible earlier in this notebook; the
# next cell initializes wandb in disabled mode if needed and builds the same
# sampler again.
sampler = create_sampler(
    lora_policy,
    tokenizer,
    model_config,
    max_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
)
print("Sampler created.")

In [None]:
# JAX's monitoring callbacks need wandb initialized (disabled mode = no actual logging)
# Only initializes when no run exists, so an active run is left untouched.
if wandb.run is None:
    wandb.init(mode="disabled")

# Create sampler for inference
# The cache size matches the rollout config: prompt + generation budget + 256.
sampler = create_sampler(
    lora_policy,
    tokenizer,
    model_config,
    max_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
)
print("Sampler created.")

In [None]:
# Test the trained model on a few examples
# Qualitative check only: responses are printed, not scored.
print("Testing trained model...\n")

test_questions = [
    "What is 15 + 27?",
    "If Mary has 5 apples and gives 2 to John, how many apples does Mary have?",
    "A store sells 3 books for $12. How much does one book cost?",
]

for question in test_questions:
    # Same prompt template and system prompt as used for the training data.
    prompt = TEMPLATE.format(
        system_prompt=SYSTEM_PROMPT,
        question=question
    )
    
    # Generate using greedy config for deterministic evaluation
    result = sampler(
        input_strings=[prompt],
        max_generation_steps=TOTAL_GENERATION_STEPS,
        **GENERATION_CONFIGS["greedy"],  # Use greedy for eval (deterministic)
        echo=False,
        eos_tokens=[1, 106],  # EOS tokens for Gemma 3
    )
    
    # One prompt was submitted, so take the first (only) generated text.
    response = result.text[0]
    
    print(f"Question: {question}")
    print(f"Response: {response}")
    print("-" * 60)

print("\nCheckpoints saved to:", CKPT_DIR)

In [None]:
# Clean up intermediate checkpoints (optional)
# Uncomment to remove intermediate files
# shutil.rmtree(INTERMEDIATE_CKPT_DIR, ignore_errors=True)
# print("Intermediate checkpoints cleaned up.")

# Finish wandb run
wandb.finish()
print("WandB run finished.")

# Free GPU memory
# Drop the notebook-level references by name. pop() with a default tolerates
# names that are already gone, so re-running this cell (or running it after
# another cleanup cell) no longer raises NameError the way `del` does.
for _name in ("grpo_trainer", "rl_cluster", "lora_policy", "base_model"):
    globals().pop(_name, None)
del _name
gc.collect()
print("Memory freed.")

In [None]:
# Clean up intermediate checkpoints (optional)
# Uncomment to remove intermediate files
# shutil.rmtree(INTERMEDIATE_CKPT_DIR, ignore_errors=True)
# print("Intermediate checkpoints cleaned up.")

# Free GPU memory
# The previous cell removes the same names, so a plain `del` here raises
# NameError when the notebook is run top to bottom. pop() with a default
# makes the cleanup idempotent.
for _name in ("grpo_trainer", "rl_cluster", "lora_policy", "base_model"):
    globals().pop(_name, None)
del _name
gc.collect()
print("Memory freed.")