---
title: "GRPO Demo – Mode A (Planner + Solver)"
jupyter: python3
---

# Overview

This notebook implements **Mode A: Planner + Solver with RL only on the Solver**.

Key idea:
- A **Planner model** (Gemma3-1B + LoRA, SFT-only / frozen) generates a *plan*.
- A **Solver model** (Gemma3-1B + LoRA, trained with GRPO) receives *(question + plan)* and generates the final reasoning + answer.
- **Only the Solver is trained with GRPO.**
- The GSM8K format reward functions from the base GRPO notebook are reused unchanged and score only the Solver output.

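Because the Planner is never updated, the distribution of plans the Solver conditions on stays fixed throughout training. This makes Mode A a conservative way to add a planning stage without destabilizing RL.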

## Environment Setup

In [None]:
import os
os.environ["HF_HUB_DISABLE_XET"] = "1"

import jax
print("JAX backend:", jax.default_backend())
print("JAX devices:", jax.devices())

## Install Dependencies

In [None]:
!pip install -q kagglehub ipywidgets tensorflow tensorflow_datasets tensorboardX transformers grain
!pip install "google-tunix[prod]==0.1.3"
!pip uninstall -q -y flax
!pip install flax==0.12.0
!pip install -q datasets wandb==0.22.0

## Core Imports

In [None]:
import os, gc, re, csv, shutil, functools
from pprint import pprint
from pathlib import Path

import jax
import jax.numpy as jnp
import optax
import wandb
import grain
import humanize

from flax import nnx
from orbax import checkpoint as ocp
from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions
from orbax.checkpoint.args import StandardRestore

from tunix.generate import sampler as sampler_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from tunix.models.gemma3 import params, model

import qwix
import tensorflow_datasets as tfds
from datasets import load_dataset

## Hyperparameters (Unchanged from Base GRPO)

In [None]:
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

RANK = 64
ALPHA = 64.0

MESH = [(1, 4), ("fsdp", "tp")]

MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 512
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
NUM_GENERATIONS = 4

NUM_ITERATIONS = 1
BETA = 0.08
EPSILON = 0.2

TRAIN_MICRO_BATCH_SIZE = 4
NUM_BATCHES = 3738
NUM_TEST_BATCHES = 100
NUM_EPOCHS = 1

MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

LEARNING_RATE = 3e-6
B1, B2 = 0.9, 0.99
WEIGHT_DECAY = 0.1
WARMUP_STEPS = int(0.1 * MAX_STEPS)  # schedule boundaries are step counts, so keep this an integer
MAX_GRAD_NORM = 0.1

INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt"
CKPT_ROOT = "/kaggle/working/ckpts"
ACTOR_CKPT_DIR = os.path.join(CKPT_ROOT, "actor")

SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

os.makedirs(INTERMEDIATE_CKPT_DIR, exist_ok=True)
os.makedirs(ACTOR_CKPT_DIR, exist_ok=True)
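
The derived quantities in this cell can be checked with a few lines of arithmetic. The sketch below recomputes the step budget, the warmup length, and the KV-cache size from the constants above. The 7,473-example size of the GSM8K train split is an assumption taken from the dataset card, not something this notebook states, and `available_batches` / `effective_batches` are names introduced only for illustration.

```python
import math

# Constants copied from the hyperparameter cell.
TRAIN_MICRO_BATCH_SIZE = 4
NUM_BATCHES = 3738
NUM_ITERATIONS = 1
TRAIN_FRACTION = 1.0
NUM_EPOCHS = 1
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 512

# Assumption: number of examples in the GSM8K "main" train split.
GSM8K_TRAIN_EXAMPLES = 7473

max_steps = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)
warmup_steps = int(0.1 * max_steps)
kv_cache_size = MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256

# Batches one epoch can supply (the last batch may be partial).
available_batches = math.ceil(GSM8K_TRAIN_EXAMPLES / TRAIN_MICRO_BATCH_SIZE)
effective_batches = min(NUM_BATCHES, available_batches)

print(max_steps, warmup_steps, kv_cache_size, available_batches, effective_batches)
```

Under that dataset-size assumption, slicing to `NUM_BATCHES` does not shorten the data at micro-batch size 4, and a single epoch supplies about half of `MAX_STEPS`. If each batch corresponds to one optimizer step, the cosine schedule would stop partway through its decay unless `NUM_EPOCHS` is raised or `NUM_BATCHES` is lowered.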


## Planner + Solver Prompt Templates

In [None]:
PLAN_START = "<plan>"
PLAN_END = "</plan>"

reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"

PLANNER_TEMPLATE = f"""
<start_of_turn>user
You are a planning assistant. Produce a short numbered plan (3–5 steps) for solving the problem.
Do NOT solve the problem.

Problem:
{{question}}
<end_of_turn>
<start_of_turn>model
{PLAN_START}
"""

SYSTEM_PROMPT = f"""
Follow the plan. Show reasoning between {reasoning_start} and {reasoning_end}.
Then output the final number between {solution_start} and {solution_end}.
"""

SOLVER_TEMPLATE = """
<start_of_turn>user
{system_prompt}

Problem:
{question}

Proposed plan:
{plan}
<end_of_turn>
<start_of_turn>model
"""

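To see what the Solver actually receives, the templates can be rendered with a made-up example. The sketch below copies `SYSTEM_PROMPT` and `SOLVER_TEMPLATE` so that it runs on its own. The question, the plan, and `DEMO_PLANNER_TEMPLATE` are hypothetical and exist only to show how the doubled braces in an f-string leave a `{question}` placeholder for a later `.format` call.

```python
reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"
PLAN_START = "<plan>"

SYSTEM_PROMPT = f"""
Follow the plan. Show reasoning between {reasoning_start} and {reasoning_end}.
Then output the final number between {solution_start} and {solution_end}.
"""

SOLVER_TEMPLATE = """
<start_of_turn>user
{system_prompt}

Problem:
{question}

Proposed plan:
{plan}
<end_of_turn>
<start_of_turn>model
"""

# Hypothetical question and plan, used only for illustration.
example_question = "Tom has 3 boxes with 4 apples each. How many apples does he have?"
example_plan = "1. Find the apples per box.\n2. Multiply by the number of boxes."

solver_prompt = SOLVER_TEMPLATE.format(
    system_prompt=SYSTEM_PROMPT, question=example_question, plan=example_plan
)
print(solver_prompt)

# Abbreviated planner-style template: the f-string fills PLAN_START immediately,
# while {{question}} collapses to {question} and is filled later by .format().
DEMO_PLANNER_TEMPLATE = f"Problem:\n{{question}}\n{PLAN_START}\n"
planner_prompt = DEMO_PLANNER_TEMPLATE.format(question=example_question)
print(planner_prompt)
```
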
## Dataset Loader (Unchanged GSM8K)

In [None]:
def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def get_dataset(data_dir, split="train") -> grain.MapDataset:
    os.makedirs(data_dir, exist_ok=True)
    data = load_dataset("gsm8k", "main", split=split)

    def _as_text(v):
        return v if isinstance(v, str) else v.decode("utf-8")

    dataset = (
        grain.MapDataset.source(data)
        .shuffle(seed=42)
        .map(
            lambda x: {
                "question": _as_text(x["question"]),
                "answer": extract_hash_answer(_as_text(x["answer"])),
            }
        )
    )
    return dataset
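
GSM8K answers consist of a worked rationale followed by a `####` marker and the final number, which is why `extract_hash_answer` splits on that marker. A small check on a hand-written record in the same style (the sample text is illustrative, not taken from the dataset):

```python
from typing import Optional


def extract_hash_answer(text: str) -> Optional[str]:
    """Return the text after the '####' marker, or None if the marker is absent."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


# Hypothetical record in the GSM8K answer style: rationale, then '#### <number>'.
sample_answer = "She sells 16 - 3 - 4 = 9 eggs.\nShe makes 9 * 2 = 18 dollars.\n#### 18"

print(extract_hash_answer(sample_answer))  # 18
print(extract_hash_answer("no marker"))    # None
```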

## Train / Test Split

In [None]:
dataset = get_dataset(TRAIN_DATA_DIR, "train").batch(TRAIN_MICRO_BATCH_SIZE)[:NUM_BATCHES]
train_dataset = dataset.repeat(NUM_EPOCHS)
test_dataset = get_dataset(TEST_DATA_DIR, "test").batch(TRAIN_MICRO_BATCH_SIZE)[:NUM_TEST_BATCHES]

## Save Original Gemma Checkpoint into NNX Format

In [None]:
MODEL_CP_PATH = params.GEMMA3_1B_IT
config = model.ModelConfig.gemma3_1b()

gemma = params.create_model_from_checkpoint(MODEL_CP_PATH, config)
tokenizer = params.create_tokenizer()

checkpointer = ocp.StandardCheckpointer()
_, state = nnx.split(gemma)

ckpt_manager = CheckpointManager(
    INTERMEDIATE_CKPT_DIR,
    checkpointers=checkpointer,
    options=CheckpointManagerOptions(save_interval_steps=1, max_to_keep=1),
)
ckpt_manager.save(0, state)
ckpt_manager.wait_until_finished()

del gemma, state  # keep the `params` module: it is needed again when restoring below
gc.collect()

## Load Reference Model + Apply LoRA Separately for Planner and Solver

In [None]:
def get_gemma_ref_model(ckpt_root):
    mesh = jax.make_mesh(*MESH)
    abs_gemma = nnx.eval_shape(lambda: params.create_model_from_checkpoint(MODEL_CP_PATH, config))

    abs_state = nnx.state(abs_gemma)
    abs_state = jax.tree.map(
        lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
        abs_state,
        nnx.get_named_sharding(abs_state, mesh),
    )

    ckpt_manager = CheckpointManager(
        ckpt_root,
        checkpointers=ocp.StandardCheckpointer(),
        options=CheckpointManagerOptions(save_interval_steps=1, max_to_keep=1),
    )

    latest_step = ckpt_manager.latest_step()
    restored = ckpt_manager.restore(latest_step, args=StandardRestore(abs_state))
    graph_def, _ = nnx.split(abs_gemma)
    gemma = nnx.merge(graph_def, restored)
    return gemma, mesh


def get_lora_model(base_model, mesh):
    lora_provider = qwix.LoraProvider(
        module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum",
        rank=RANK,
        alpha=ALPHA,
    )

    lora_model = qwix.apply_lora_to_model(base_model, lora_provider, **base_model.get_model_input())

    with mesh:
        state = nnx.state(lora_model)
        pspecs = nnx.get_partition_spec(state)
        sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
        nnx.update(lora_model, sharded_state)

    return lora_model

ref_model, mesh = get_gemma_ref_model(INTERMEDIATE_CKPT_DIR)

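# Separate LoRA adapters for planner and solver. Only solver_policy is handed to the
# GRPO trainer; no SFT checkpoint is loaded in this notebook for the planner adapters.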
planner_policy = get_lora_model(ref_model, mesh)
solver_policy = get_lora_model(ref_model, mesh)
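
The `module_path` argument is a regular expression that selects which submodules receive LoRA adapters. The sketch below shows how such a pattern filters module names. The paths are hypothetical (the real names depend on the Gemma3 implementation), and full-match semantics are an assumption about how the pattern is applied rather than a statement about qwix internals.

```python
import re

# Pattern copied from the LoraProvider call above.
MODULE_PATH = ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"

# Hypothetical module paths, for illustration only.
candidate_paths = [
    "layers/0/attn/q_einsum",
    "layers/0/attn/kv_einsum",
    "layers/0/attn/attn_vec_einsum",
    "layers/0/mlp/gate_proj",
    "layers/0/pre_attention_norm",
    "embedder",
]

# Assumption: a module is targeted when the pattern matches its whole path.
targeted = [p for p in candidate_paths if re.fullmatch(MODULE_PATH, p)]
print(targeted)

# In the standard LoRA formulation the low-rank update is scaled by alpha / rank.
RANK, ALPHA = 64, 64.0
lora_scale = ALPHA / RANK
print(lora_scale)
```

With `RANK = 64` and `ALPHA = 64.0` the standard scaling factor is 1.0, so the adapter output is added to the frozen projection without rescaling.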


## Samplers for Planner and Solver

In [None]:
planner_sampler = sampler_lib.Sampler(
    transformer=planner_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(cache_size=512, num_layers=ref_model.model_config.num_layers),
)

solver_sampler = sampler_lib.Sampler(
    transformer=solver_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=ref_model.model_config.num_layers,
    ),
)

## Planner → Solver Generation Pipeline

In [None]:
def generate_plan(questions):
    inputs = [PLANNER_TEMPLATE.format(question=q) for q in questions]
    out = planner_sampler(
        input_strings=inputs,
        max_generation_steps=128,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        echo=False,
    )

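    plans = []
    for txt in out.text:
        # The prompt already ends with PLAN_START and echo=False drops it, so the
        # output rarely contains the opening tag: keep the text up to PLAN_END.
        body = txt.split(PLAN_START, 1)[-1]
        plans.append(body.split(PLAN_END, 1)[0].strip())
    return plans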


def generate_with_plan(questions):
    plans = generate_plan(questions)

    solver_inputs = [
        SOLVER_TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=q, plan=p)
        for q, p in zip(questions, plans)
    ]

    out = solver_sampler(
        input_strings=solver_inputs,
        max_generation_steps=TOTAL_GENERATION_STEPS,  # match the rollout budget; leaves cache room for the plan-augmented prompt
        temperature=TEMPERATURE,
        top_k=TOP_K,
        top_p=TOP_P,
        echo=False,
    )

    return plans, out.text
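
The dataset cell emits only `question` and `answer` fields, so the plan-augmented prompt still has to reach the trainer somehow. One possible way to bridge that gap is sketched below, with the planner call injected as a function so the logic can be tested without a model. `build_solver_batch`, `clean_plan`, `fake_planner`, the abbreviated template, and the `"prompts"` key are all assumptions made for illustration.

```python
from typing import Callable, Dict, List

PLAN_START, PLAN_END = "<plan>", "</plan>"

# Abbreviated stand-in for SOLVER_TEMPLATE so the snippet runs on its own.
DEMO_SOLVER_TEMPLATE = "Problem:\n{question}\n\nProposed plan:\n{plan}\n"


def clean_plan(raw: str) -> str:
    """Keep the text between the plan tags, tolerating a missing opening or closing tag."""
    body = raw.split(PLAN_START, 1)[-1]
    return body.split(PLAN_END, 1)[0].strip()


def build_solver_batch(
    questions: List[str],
    answers: List[str],
    plan_fn: Callable[[List[str]], List[str]],
) -> Dict[str, List[str]]:
    """Render one solver prompt per question using plans returned by `plan_fn`."""
    plans = [clean_plan(p) for p in plan_fn(questions)]
    prompts = [
        DEMO_SOLVER_TEMPLATE.format(question=q, plan=p)
        for q, p in zip(questions, plans)
    ]
    # Assumption: the trainer looks up rendered prompts under a "prompts" key.
    return {"prompts": prompts, "question": questions, "answer": answers}


def fake_planner(questions: List[str]) -> List[str]:
    """Hypothetical stand-in for generate_plan: one canned plan per question."""
    return ["1. Read the numbers.\n2. Add them.</plan> extra text" for _ in questions]


batch = build_solver_batch(["What is 2 + 3?"], ["5"], fake_planner)
print(batch["prompts"][0])
```

Passing the planner as `plan_fn` keeps the formatting logic independent of the sampler, so the real `generate_plan` can be swapped in once the models are loaded.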

## Reward Functions (Unchanged – Applied Only to Solver Output)

In [None]:
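# Single backslash: in a raw string, [\\s] would match a literal backslash or the letter "s", not whitespace.
match_format = re.compile(rf"^[\s]{{0,}}{reasoning_start}.+?{reasoning_end}.*?{solution_start}(.+?){solution_end}[\s]{{0,}}$", re.DOTALL)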

def match_format_exactly(prompts, completions, **kwargs):
    return [0 if match_format.search(r) is None else 3.0 for r in completions]


def match_format_approximately(prompts, completions, **kwargs):
    scores = []
    for r in completions:
        score = 0
        score += 0.5 if r.count(reasoning_start) == 1 else -0.5
        score += 0.5 if r.count(reasoning_end) == 1 else -0.5
        score += 0.5 if r.count(solution_start) == 1 else -0.5
        score += 0.5 if r.count(solution_end) == 1 else -0.5
        scores.append(score)
    return scores
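
Both functions above score formatting only. A correctness term could sit alongside them. The sketch below is one possible shape for it, not the reward used by the base notebook: `check_answer_sketch`, `_to_number`, and the 3.0 / 0.0 scores are hypothetical, and it assumes that extra dataset columns such as `answer` are forwarded to reward functions as keyword arguments.

```python
import re

solution_start, solution_end = "<answer>", "</answer>"

# Captures whatever sits between the answer tags.
answer_pattern = re.compile(
    rf"{re.escape(solution_start)}\s*(.+?)\s*{re.escape(solution_end)}", re.DOTALL
)


def _to_number(text):
    """Parse a number, ignoring thousands separators and '$'; return None on failure."""
    try:
        return float(text.replace(",", "").replace("$", "").strip())
    except ValueError:
        return None


def check_answer_sketch(prompts, completions, answer, **kwargs):
    """Hypothetical correctness reward: 3.0 for a numerically equal answer, else 0.0."""
    scores = []
    for completion, target in zip(completions, answer):
        match = answer_pattern.search(completion)
        guess = _to_number(match.group(1)) if match else None
        truth = _to_number(target) if target is not None else None
        scores.append(3.0 if guess is not None and guess == truth else 0.0)
    return scores


completions = [
    "<reasoning>3 * 4 = 12</reasoning><answer>12</answer>",
    "<reasoning>3 + 4 = 7</reasoning><answer>7</answer>",
    "The answer is 12.",
]
scores = check_answer_sketch(None, completions, answer=["12", "12", "12"])
print(scores)  # [3.0, 0.0, 0.0]
```

Comparing parsed numbers rather than raw strings means `1,200` and `1200` count as the same answer, while a completion without answer tags earns nothing from this term.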

## GRPO Setup (Solver Only)

In [None]:
optimizer = optax.chain(
    optax.clip_by_global_norm(MAX_GRAD_NORM),
    optax.adamw(
        learning_rate=optax.schedules.warmup_cosine_decay_schedule(
            0.0, LEARNING_RATE, WARMUP_STEPS, MAX_STEPS
        ),
        b1=B1,
        b2=B2,
        weight_decay=WEIGHT_DECAY,
    ),
)

cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine="vanilla",
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        max_steps=MAX_STEPS,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        checkpoint_root_directory=CKPT_ROOT,
        checkpointing_options=ocp.CheckpointManagerOptions(
            save_interval_steps=SAVE_INTERVAL_STEPS,
            max_to_keep=MAX_TO_KEEP,
        ),
        metrics_logging_options=metrics_logger.MetricsLoggerOptions(
            log_dir="/tmp/content/tmp/tensorboard/grpo",
            flush_every_n_steps=20,
        ),
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        eos_tokens=[1, 106],
    ),
)

rl_cluster = rl_cluster_lib.RLCluster(
    actor=solver_policy,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[match_format_exactly, match_format_approximately],
    grpo_config=GRPOConfig(
        num_generations=NUM_GENERATIONS,
        num_iterations=NUM_ITERATIONS,
        beta=BETA,
        epsilon=EPSILON,
    ),
)
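
`NUM_GENERATIONS` sets how many completions are sampled per prompt, and GRPO scores each completion relative to the others in its group. The plain-Python sketch below illustrates that normalization as it is usually described for GRPO. It is not Tunix's implementation: `group_relative_advantages`, the `eps` value, and the choice of population standard deviation are assumptions.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Center one prompt's rewards on the group mean and divide by the group spread."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# Hypothetical summed rewards for the 4 completions of one prompt, e.g.
# 3.0 + 2.0 (exact and approximate format both satisfied) down to 0 - 2.0.
group_rewards = [5.0, 1.0, 1.0, -2.0]
advantages = group_relative_advantages(group_rewards)
print([round(a, 3) for a in advantages])

# When every completion earns the same reward, all advantages are zero.
flat = group_relative_advantages([3.0, 3.0, 3.0, 3.0])
print(flat)
```

The second call shows why reward variety within a group matters: if all four completions score the same, this normalization yields no learning signal for that prompt.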

## Training (Solver Only, Planner Frozen)

In [None]:
with mesh:
    grpo_trainer.train(train_dataset)

# ✅ Summary

This notebook implements:
- **Planner (SFT-only, frozen)**
- **Solver (GRPO RL)**
- **Two-stage reasoning pipeline**
- **GSM8K format rewards reused unchanged**
- **No RL updates applied to the Planner**

This setup is a natural stepping stone toward full dual-policy RL, in which the Planner would also be trained.