# GRPO Demo – Mode A (Planner + Solver)

# Overview

This notebook implements **Mode A: Planner + Solver with RL only on the Solver**.

Key idea:
- A **Planner model** (Gemma3-1B + LoRA, SFT-only / frozen) generates a *plan*.
- A **Solver model** (Gemma3-1B + LoRA, trained with GRPO) receives *(question + plan)* and generates the final reasoning + answer.
- **Only the Solver is trained with GRPO.**
- All existing GSM8K reward functions (format + correctness) are preserved.

This is the safest way to introduce multi-step reasoning without destabilizing RL.

## Environment Setup

In [None]:
import os
# Set before any Hugging Face download happens in later cells.
# NOTE(review): presumably this turns off the Xet transfer backend of
# huggingface_hub — confirm against the installed hub version.
os.environ["HF_HUB_DISABLE_XET"] = "1"

import jax
# Show which backend and devices JAX selected for this runtime.
print("JAX backend:", jax.default_backend())
print("JAX devices:", jax.devices())

JAX backend: tpu
JAX devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]


## Install Dependencies

In [None]:
# Runtime dependencies for this notebook, installed through shell escapes.
# Only google-tunix, flax and wandb are version-pinned; the other packages
# resolve to whatever is current at install time.
!pip install -q kagglehub

!pip install -q ipywidgets

!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain
!pip install "google-tunix[prod]==0.1.3"

# Alternative: install tunix / qwix from source instead of the pinned release.
# !pip install -q git+https://github.com/google/tunix
# !pip install -q git+https://github.com/google/qwix

# Replace whatever flax version the installs above pulled in with a pinned one.
!pip uninstall -q -y flax
# !pip install -U flax
!pip install flax==0.12.0

!pip install -q datasets wandb==0.22.0

Collecting flax==0.12.0
  Using cached flax-0.12.0-py3-none-any.whl.metadata (11 kB)
Using cached flax-0.12.0-py3-none-any.whl (466 kB)
Installing collected packages: flax
Successfully installed flax-0.12.0


## Core Imports

In [None]:
import os, gc, re, csv, shutil, functools
from pprint import pprint
from pathlib import Path
from tqdm import tqdm

import jax
import jax.numpy as jnp
import optax
import wandb
import grain
import humanize

from flax import nnx
from orbax import checkpoint as ocp
from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions
from orbax.checkpoint.args import StandardRestore

from tunix.generate import sampler as sampler_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from tunix.models.gemma3 import params, model

import qwix
import tensorflow_datasets as tfds
from datasets import load_dataset

In [None]:
import wandb, os
import getpass

# Credentials must never be stored in notebook source or outputs. Use the key
# from the environment when it is already set; otherwise prompt for it so the
# value only lives in this kernel's process environment.
# NOTE: any key that was ever committed in this cell must be treated as leaked
# and revoked in the Weights & Biases account settings.
if not os.environ.get("WANDB_API_KEY"):
    os.environ["WANDB_API_KEY"] = getpass.getpass("Weights & Biases API key: ")

## Hyperparameters (Unchanged from Base GRPO)

In [None]:
# --- Data locations and the fraction of training batches that counts toward MAX_STEPS ---
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

# --- LoRA adapter settings (passed to qwix.LoraProvider in get_lora_model) ---
RANK = 64
ALPHA = 64.0

# --- Device mesh: (shape, axis names), unpacked into jax.make_mesh(*MESH) ---
MESH = [(1, 1), ("fsdp", "tp")]

# --- Sampling / rollout settings ---
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 512
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
# Not referenced by the GRPO configs in the cells shown; the planner/solver
# specific values below are what GRPOConfig receives.
NUM_GENERATIONS = 4

# For planner RL (format-only training), keep it simple:
PLANNER_NUM_GENERATIONS = 2

# For solver RL (format + correctness), you can keep multiple generations:
SOLVER_NUM_GENERATIONS = 2  # lighter than NUM_GENERATIONS (4)


# --- GRPO settings, forwarded to GRPOConfig(num_iterations=, beta=, epsilon=) ---
NUM_ITERATIONS = 1
BETA = 0.08
EPSILON = 0.2

# --- Batching / schedule length ---
TRAIN_MICRO_BATCH_SIZE = 2
NUM_BATCHES = 1000
NUM_TEST_BATCHES = 100
NUM_EPOCHS = 1

# Total optimizer steps: batches * iterations * fraction * epochs, truncated to int.
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# --- Optimizer (AdamW with warmup + cosine decay, built in the GRPO setup cells) ---
LEARNING_RATE = 3e-6
B1, B2 = 0.9, 0.99
WEIGHT_DECAY = 0.1
# This is a float (0.1 * int). NOTE(review): confirm the optax schedule
# accepts a non-integer warmup step count.
WARMUP_STEPS = 0.1 * MAX_STEPS
# Threshold passed to optax.clip_by_global_norm.
MAX_GRAD_NORM = 0.1

# Checkpoints are written to Google Drive, so this cell only runs on Colab.
from google.colab import drive
drive.mount('/content/drive')

# NOTE(review): the directory name says "modeB" while this notebook is titled
# Mode A — confirm this is the intended location.
CKPT_ROOT = "/content/drive/MyDrive/tunix_ckpts_modeB"

PLANNER_CKPT_ROOT = f"{CKPT_ROOT}/planner"
SOLVER_CKPT_ROOT  = f"{CKPT_ROOT}/solver"

# Scratch location for the base-model checkpoint that is re-saved in NNX format.
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt"
# CKPT_ROOT = "/content/working/ckpts"
# Not created by the makedirs loop below and not used by the cells shown.
ACTOR_CKPT_DIR = os.path.join(CKPT_ROOT, "actor")

# Forwarded to the orbax CheckpointManagerOptions of both GRPO trainers.
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# Planner and solver RL use separate roots (defined above):
# PLANNER_CKPT_ROOT = os.path.join(CKPT_ROOT, "planner")
# SOLVER_CKPT_ROOT = os.path.join(CKPT_ROOT, "solver")

PLANNER_ACTOR_CKPT_DIR = os.path.join(PLANNER_CKPT_ROOT, "actor")
SOLVER_ACTOR_CKPT_DIR = os.path.join(SOLVER_CKPT_ROOT, "actor")

# Create every directory up front; exist_ok makes the cell safe to re-run.
for d in [INTERMEDIATE_CKPT_DIR, PLANNER_CKPT_ROOT, SOLVER_CKPT_ROOT,
          PLANNER_ACTOR_CKPT_DIR, SOLVER_ACTOR_CKPT_DIR]:
    os.makedirs(d, exist_ok=True)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Start the Weights & Biases run that the logged_* reward wrappers write to
# via wandb.log. The config dict records the solver-side hyperparameters.
wandb.init(
        project="gemma3-grpo-planner-solver",
        name="solver-grpo-v4",
        group="solver",
        config={
            "model": "gemma3-1b-solver",
            "rank": RANK,
            "alpha": ALPHA,
            "max_steps": MAX_STEPS,
            "learning_rate": LEARNING_RATE,
            "beta": BETA,
            "epsilon": EPSILON,
            "num_generations": SOLVER_NUM_GENERATIONS,
        },
    )

0,1
jax/orbax/write/sharded_array_gb,‚ñÅ

0,1
jax/orbax/write/sharded_array_gb,0.00082


## Planner + Solver Prompt Templates

In [None]:
# Tag strings shared by the templates, the output post-processing and the
# reward functions.
PLAN_START = "<plan>"
PLAN_END = "</plan>"

reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"

# Passed as `system_prompt=` to SOLVER_TEMPLATE.format(...). SOLVER_TEMPLATE
# has no {system_prompt} field, so str.format ignores it and this text does
# not appear in the solver prompt.
SYSTEM_PROMPT = f"""
Follow the plan. Show each step of following the plan between <reasoning> and </reasoning>.
Then output the final number between <answer> and </answer>.
"""

# Not referenced by the cells shown.
NULL_PLANNER_TEMPLATE = f"""
"""

# Planner prompt. `{{question}}` is escaped so it survives the f-string and is
# filled in later with .format(question=...). The prompt ends with PLAN_START,
# so the opening <plan> tag is part of the prompt, not of the completion.
# NOTE(review): the reply turn is labelled "planner"; Gemma chat formatting
# normally uses the "model" role — confirm this is intentional.
PLANNER_TEMPLATE = f"""
<start_of_turn>user
You are a planning assistant. Produce a short numbered plan (3‚Äì5 steps)
for solving the problem. Do NOT solve the problem.

Problem:
{{question}}
<end_of_turn>

<start_of_turn>planner
{PLAN_START}
"""

# Solver prompt. Filled with .format(question=..., plan=...); it contains no
# <start_of_turn>/<end_of_turn> markers. The requested <solution> wrapper is
# not checked by the reward regexes in this notebook, which only look at the
# <reasoning> and <answer> tags.
SOLVER_TEMPLATE = f"""
You are a mathematical reasoning agent.

Your task:
1. Follow the provided plan EXACTLY.
2. Write detailed reasoning in the <reasoning>...</reasoning> block.
3. Place ONLY the final numeric answer in <answer>...</answer>.
4. After </answer>, STOP immediately.

Response format:
<solution>
<reasoning>
[step-by-step reasoning following the plan; do NOT skip steps]
</reasoning>
<answer>
[FINAL NUMERIC ANSWER ONLY]
</answer>
</solution>

Problem:
{{question}}

Plan:
{{plan}}

Begin.
"""

## Dataset Loader (Unchanged GSM8K)

In [None]:
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()

def get_dataset(data_dir, split="train") -> grain.MapDataset:
  """Load a GSM8K split as a shuffled grain dataset of question/answer dicts.

  Args:
    data_dir: Directory that is created if missing. It is not passed to
      `load_dataset`, so it does not control where the data is cached.
    split: Name of the GSM8K split to load.

  Returns:
    A `grain.MapDataset`, shuffled with seed 42, whose elements are
    `{"question": str, "answer": str | None}`; "answer" is the text after the
    "####" marker as extracted by `extract_hash_answer`.
  """
  os.makedirs(data_dir, exist_ok=True)
  raw = load_dataset("gsm8k", "main", split=split)

  def _as_text(value):
    # Fields may arrive as bytes; normalise them to str.
    if isinstance(value, str):
      return value
    return value.decode("utf-8")

  def _to_example(row):
    return {
        "question": _as_text(row["question"]),
        "answer": extract_hash_answer(_as_text(row["answer"])),
    }

  return grain.MapDataset.source(raw).shuffle(seed=42).map(_to_example)

## Train / Test Split

In [None]:
# Batch the shuffled training split and keep only the first NUM_BATCHES batches.
dataset = get_dataset(TRAIN_DATA_DIR, "train").batch(TRAIN_MICRO_BATCH_SIZE)[:NUM_BATCHES]

# Repeat the truncated training set NUM_EPOCHS times; the test set is batched
# the same way and truncated to NUM_TEST_BATCHES batches.
train_dataset = dataset.repeat(NUM_EPOCHS)
test_dataset = get_dataset(TEST_DATA_DIR, "test").batch(TRAIN_MICRO_BATCH_SIZE)[:NUM_TEST_BATCHES]

In [None]:
# base `train_dataset` is already defined as batches of {"question","answer"}

def to_planner_view(batch):
    """Turn a batched GSM8K example into the planner's training view.

    Args:
        batch: Mapping with a "question" sequence and an "answer" entry.

    Returns:
        A dict with "prompt" (one PLANNER_TEMPLATE-formatted string per
        question) and "answer" (passed through unchanged for possible use by
        future planner rewards).

    NOTE(review): the key is "prompt" (singular) — confirm this is the key the
    GRPO learner reads its input prompts from.
    """
    view = {}
    view["prompt"] = [
        PLANNER_TEMPLATE.format(question=text) for text in batch["question"]
    ]
    view["answer"] = batch["answer"]
    return view

planner_train_dataset = train_dataset.map(to_planner_view)

## Save Original Gemma Checkpoint into NNX Format

In [None]:
# Load the pretrained Gemma3-1B checkpoint once, re-save its state with orbax
# under INTERMEDIATE_CKPT_DIR, then free the in-memory copy. Later cells
# restore the model from that intermediate checkpoint.
from tunix.models.gemma3 import params, model

MODEL_CP_PATH = params.GEMMA3_1B_IT
config = model.ModelConfig.gemma3_1b()

gemma = params.create_model_from_checkpoint(MODEL_CP_PATH, config)
tokenizer = params.create_tokenizer()

checkpointer = ocp.StandardCheckpointer()
# Only the state half of the split is saved; the graph definition is discarded.
_, state = nnx.split(gemma)

ckpt_manager = CheckpointManager(
    INTERMEDIATE_CKPT_DIR,
    checkpointers=checkpointer,
    options=CheckpointManagerOptions(save_interval_steps=1, max_to_keep=1),
)
ckpt_manager.save(0, state)
# Block until the save has finished before dropping the references below.
ckpt_manager.wait_until_finished()

# `params` here is the tunix module imported at the top of this cell, so this
# `del` also removes the module binding; the next cell imports it again.
del gemma, state, params
gc.collect()



1264

## Load Reference Model + Apply LoRA Separately for Planner and Solver

In [None]:
from tunix.models.gemma3 import params, model

def get_gemma_ref_model(ckpt_root):
    """Restore the base Gemma model from an orbax checkpoint directory.

    Args:
        ckpt_root: Directory managed by an orbax CheckpointManager; the latest
            step found there is restored.

    Returns:
        (gemma, mesh): the restored model and the device mesh built from MESH.
    """
    mesh = jax.make_mesh(*MESH)
    # eval_shape gives an abstract model used as a structure/shape template.
    abs_gemma = nnx.eval_shape(lambda: params.create_model_from_checkpoint(MODEL_CP_PATH, config))

    abs_state = nnx.state(abs_gemma)
    # Restore target: same shapes, bfloat16 dtype, sharded per the named
    # sharding derived from the state and the mesh.
    abs_state = jax.tree.map(
        lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
        abs_state,
        nnx.get_named_sharding(abs_state, mesh),
    )

    ckpt_manager = CheckpointManager(
        ckpt_root,
        checkpointers=ocp.StandardCheckpointer(),
        options=CheckpointManagerOptions(save_interval_steps=1, max_to_keep=1),
    )

    # NOTE(review): latest_step() result is used unchecked — confirm how
    # restore behaves when ckpt_root contains no checkpoint.
    latest_step = ckpt_manager.latest_step()
    restored = ckpt_manager.restore(latest_step, args=StandardRestore(abs_state))
    # Re-attach the restored state to the abstract model's graph definition.
    graph_def, _ = nnx.split(abs_gemma)
    gemma = nnx.merge(graph_def, restored)
    return gemma, mesh


def get_lora_model(base_model, mesh):
    """Wrap `base_model` with LoRA adapters and apply sharding constraints.

    Args:
        base_model: Model to adapt; its `get_model_input()` provides the
            example inputs passed to qwix.
        mesh: Device mesh used as the context for the sharding constraint.

    Returns:
        The LoRA-adapted model produced by `qwix.apply_lora_to_model`.
    """
    # Adapters are attached to modules whose path matches this regex, using
    # the notebook-level RANK and ALPHA.
    lora_provider = qwix.LoraProvider(
        module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum",
        rank=RANK,
        alpha=ALPHA,
    )

    lora_model = qwix.apply_lora_to_model(base_model, lora_provider, **base_model.get_model_input())

    with mesh:
        # Constrain the model state to its partition specs, then write the
        # constrained state back into the model.
        state = nnx.state(lora_model)
        pspecs = nnx.get_partition_spec(state)
        sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
        nnx.update(lora_model, sharded_state)

    return lora_model

# Restore the base model from the intermediate checkpoint; it is also passed
# as the `reference` model to both RL clusters below.
ref_model, mesh = get_gemma_ref_model(INTERMEDIATE_CKPT_DIR)

# Separate planner and solver LoRA policies, both built from the same ref_model.
planner_policy = get_lora_model(ref_model, mesh)
solver_policy = get_lora_model(ref_model, mesh)


  mesh = jax.make_mesh(*MESH)


## Samplers for Planner and Solver

In [None]:
# Sampler used by generate_plan(); it wraps the planner LoRA policy.
planner_sampler = sampler_lib.Sampler(
    transformer=planner_policy,
    tokenizer=tokenizer, # tokenizer created alongside the base model
    cache_config=sampler_lib.CacheConfig(
        # Prompt budget (256) + 128 generated plan tokens + 32 spare = 416.
        cache_size=MAX_PROMPT_LENGTH + 128 + 32,
        num_layers=config.num_layers,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
    ),
)

# Sampler used by generate_with_plan() and evaluate(); it wraps the solver
# LoRA policy.
solver_sampler = sampler_lib.Sampler(
    transformer=solver_policy,
    tokenizer=tokenizer, # tokenizer created alongside the base model
    cache_config=sampler_lib.CacheConfig(
        # 256 + 1024 + 256 = 1536 slots; generate_with_plan requests up to
        # 768 generated tokens on top of the 256-token prompt budget.
        cache_size=MAX_PROMPT_LENGTH + 1024 + 256,
        num_layers=config.num_layers,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
    ),
)


## Planner → Solver Generation Pipeline

In [None]:
def _extract_plan_body(raw_text):
    """Reduce one raw planner completion to the bare plan text."""
    # Drop everything after the first closing tag (text without the tag is kept whole).
    head, closing, _ = raw_text.partition(PLAN_END)
    truncated = head + closing

    # Prefer the content of a complete <plan>...</plan> pair; otherwise fall
    # back to the whole (truncated) text.
    found = re.search(r"<plan>(.*?)</plan>", truncated, re.DOTALL)
    body = truncated.strip() if found is None else found.group(1).strip()

    # Remove stray plan tags and anything from the end-of-turn marker onwards.
    body = re.sub(r"</?plan>", "", body)
    return body.split("<end_of_turn>")[0]


def generate_plan(questions):
    """Sample one plan per question from the planner policy.

    Args:
        questions: Iterable of question strings.

    Returns:
        A list of cleaned plan strings, one per planner output, with plan tags
        and end-of-turn residue removed.
    """
    planner_inputs = [PLANNER_TEMPLATE.format(question=q) for q in questions]
    sampled = planner_sampler(
        input_strings=planner_inputs,
        max_generation_steps=128,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        echo=False,
    )
    return [_extract_plan_body(raw_text) for raw_text in sampled.text]

def enforce_stop_strings(text, stops=("</answer>",)):
    """Truncate `text` right after the earliest occurrence of any stop string.

    Args:
        text: The generated text to truncate.
        stops: Iterable of stop strings. Empty strings are ignored. The
            default is an immutable tuple so it cannot be mutated between calls.

    Returns:
        `text` up to and including the stop string that appears first in the
        text (regardless of its position in `stops`), or `text` unchanged
        when no stop string occurs.
    """
    earliest = None  # (start, end) of the stop that occurs first in `text`
    for stop in stops:
        if not stop:
            # An empty stop would match everywhere and cannot be split on.
            continue
        start = text.find(stop)
        if start != -1 and (earliest is None or start < earliest[0]):
            earliest = (start, start + len(stop))

    if earliest is None:
        return text
    return text[:earliest[1]]

def generate_with_plan(questions):
    """Run the full planner -> solver pipeline for a batch of questions.

    Args:
        questions: Sequence of question strings.

    Returns:
        (plans, solver_texts): the cleaned plans from `generate_plan` and the
        raw solver completions (no stop-string truncation is applied here).

    NOTE(review): the formatted solver prompt (template + question + plan) is
    capped at MAX_PROMPT_LENGTH tokens by the sampler call — confirm that
    typical prompts fit within that budget.
    """
    plans = generate_plan(questions)

    # `system_prompt` is passed for completeness; SOLVER_TEMPLATE has no such
    # field, so str.format ignores it.
    solver_inputs = []
    for question, plan in zip(questions, plans):
        solver_inputs.append(
            SOLVER_TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                question=question,
                plan=plan,
            )
        )

    sampled = solver_sampler(
        input_strings=solver_inputs,
        max_generation_steps=768,
        max_prompt_length=MAX_PROMPT_LENGTH,
        temperature=TEMPERATURE,
        top_k=TOP_K,
        top_p=TOP_P,
        echo=False,
    )

    return plans, sampled.text


## Reward Functions (New – Applied Only to Planner Output)

In [None]:
import jax.numpy as jnp

# A numbered step is a line that starts (after optional whitespace) with
# digits followed by "." or ")". The separator is matched explicitly: an
# unescaped "." would match any character, so lines such as "2024 was ..." or
# "10 apples" would be counted as plan steps.
step_pattern = re.compile(r"^\s*\d+[.)]", re.MULTILINE)

def planner_match_format(prompts, completions, **kwargs):
    """
    GRPO-compatible planner format reward.

    Args:
        prompts: Planner prompts (unused; accepted for the reward-fn signature).
        completions: Planner completions to score.
        **kwargs: Ignored extra fields.

    Returns:
        A float32 jnp.ndarray with one score PER COMPLETION
        (len == len(completions)). An array is returned rather than a Python
        list so that the learner's reward check does not enforce
        len(r) == len(prompts).

    Scoring per completion:
        +1.0 / -1.0  both plan tags present / otherwise
        +1.0         3 to 6 numbered steps; +0.5 for any other non-zero
                     count; -0.5 for none
        +0.5 / -0.5  at most 200 whitespace-separated words / longer

    NOTE(review): PLANNER_TEMPLATE already ends with the opening plan tag, so
    completions that simply continue the prompt will not contain it and take
    the -1.0 branch — confirm this is the intended behaviour.
    """

    scores = []

    for r in completions:
        score = 0.0

        # 1) Has <plan> and </plan> tags
        if PLAN_START in r and PLAN_END in r:
            score += 1.0
        else:
            score -= 1.0

        # 2) Count numbered steps
        num_steps = len(step_pattern.findall(r))
        if 3 <= num_steps <= 6:
            score += 1.0
        elif num_steps > 0:
            score += 0.5
        else:
            score -= 0.5

        # 3) Penalize extremely long plans
        if len(r.split()) <= 200:
            score += 0.5
        else:
            score -= 0.5

        scores.append(float(score))

    # CRITICAL: return an array, not a list
    scores = jnp.asarray(scores, dtype=jnp.float32)

    # Sanity check: exactly one score per completion.
    if scores.shape[0] != len(completions):
        raise RuntimeError(
            f"planner_match_format: scores len {scores.shape[0]} "
            f"!= completions len {len(completions)}"
        )

    return scores

def logged_planner_match_format(prompts, completions, **kwargs):
    """Compute the planner format reward and log its mean/max/min to W&B.

    Returns the reward array from `planner_match_format` unchanged.
    """
    import numpy as np

    rewards = planner_match_format(prompts, completions, **kwargs)

    summary = {
        "reward/planner_format_mean": float(np.mean(rewards)),
        "reward/planner_format_max": float(np.max(rewards)),
        "reward/planner_format_min": float(np.min(rewards)),
    }
    wandb.log(summary)
    return rewards


In [None]:
# Smoke test: the planner reward must return one score per completion even
# when there are fewer prompts than completions.
dummy_prompts = ["p1", "p2", "p3", "p4"]
dummy_completions = [f"fake plan {i}" for i in range(16)]  # 4 prompts * 4 gens

test_rewards = planner_match_format(dummy_prompts, dummy_completions)
print("len(prompts)  =", len(dummy_prompts))
print("len(rewards)  =", len(test_rewards))
print("rewards:", test_rewards)

len(prompts)  = 4
len(rewards)  = 16
rewards: [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.]


## Reward Functions (Unchanged – Applied Only to Solver Output)

In [None]:
# Full-format pattern: a <reasoning> block followed (optionally after
# whitespace) by an <answer> block. Also used by evaluate().
match_format = re.compile(
    r"<reasoning>[\s\S]*?</reasoning>\s*<answer>[\s\S]*?</answer>",
    re.MULTILINE
)

import jax.numpy as jnp

def match_format_exactly(prompts, completions, **kwargs):
    """
    Return one reward per PROMPT (length == B).

    Completions are split into B consecutive groups of
    G = len(completions) // B. A prompt's reward is 3.0 if any completion in
    its group matches the full <reasoning>...<answer> pattern, else 0.0.

    Args:
        prompts: Sequence of prompts; must be non-empty.
        completions: Sequence of completion strings, grouped by prompt.
        **kwargs: Ignored extra fields.

    Returns:
        float32 jnp.ndarray of shape (B,). A prompt whose group is empty
        (fewer completions than prompts) scores 0.0 instead of raising.
    """
    B = len(prompts)
    assert B > 0, "No prompts provided."
    G = len(completions) // B

    group_scores = []
    for i in range(B):
        group = completions[i * G:(i + 1) * G]
        # Plain Python max: no device round-trip per group, and `default`
        # covers the empty-group case.
        best = max(
            (3.0 if match_format.search(r) else 0.0 for r in group),
            default=0.0,
        )
        group_scores.append(best)

    return jnp.asarray(group_scores, dtype=jnp.float32)

def match_format_approximately(prompts, completions, **kwargs):
    """
    Return one approximate-format reward per PROMPT (length == B).

    Each completion earns +0.5 for every tag (<reasoning>, </reasoning>,
    <answer>, </answer>) that occurs exactly once and -0.5 otherwise; a
    prompt's reward is the mean over its group of
    G = len(completions) // B completions.
    """
    B = len(prompts)
    assert B > 0, "No prompts provided."
    G = len(completions) // B

    tags = (reasoning_start, reasoning_end, solution_start, solution_end)

    def _tag_score(text):
        total = 0.0
        for tag in tags:
            total += 0.5 if text.count(tag) == 1 else -0.5
        return total

    per_prompt = []
    for i in range(B):
        members = completions[i * G:(i + 1) * G]
        member_scores = [_tag_score(text) for text in members]
        # Group reward is the average of its members' scores.
        per_prompt.append(float(jnp.mean(jnp.asarray(member_scores))))

    return jnp.asarray(per_prompt, dtype=jnp.float32)


def clean_completions_for_rl(completions):
    """Apply `enforce_stop_strings` (default stops) to every completion."""
    return list(map(enforce_stop_strings, completions))

def logged_match_format_exactly(prompts, completions, **kwargs):
    """Exact-format reward on stop-string-truncated completions, logged to W&B.

    Extra keyword arguments are accepted but not forwarded. Returns the reward
    array from `match_format_exactly`.
    """
    import numpy as np

    truncated = clean_completions_for_rl(completions)
    rewards = match_format_exactly(prompts, truncated)

    summary = {
        "reward/format_exact_mean": float(np.mean(rewards)),
        "reward/format_exact_max": float(np.max(rewards)),
        "reward/format_exact_min": float(np.min(rewards)),
    }
    wandb.log(summary)

    return rewards

def logged_match_format_approximately(prompts, completions, **kwargs):
    """Approximate-format reward, logged to W&B.

    Unlike the other logged solver rewards, completions are scored as given
    (no stop-string truncation). Returns the reward array unchanged.
    """
    import numpy as np

    rewards = match_format_approximately(prompts, completions, **kwargs)

    summary = {
        "reward/format_approx_mean": float(np.mean(rewards)),
        "reward/format_approx_max": float(np.max(rewards)),
    }
    wandb.log(summary)

    return rewards

In [None]:
def check_answer(prompts, completions, answer, **kwargs):
    """Correctness reward: one score per PROMPT (length == B).

    Completions are truncated with `clean_completions_for_rl`, then split into
    B consecutive groups of G = len(completions) // B. Each completion is
    scored on its own and the prompt receives the best score in its group:

        3.0  extracted number equals the reference answer (string comparison)
        2.0  as above, but text remains after the last "</answer>" — after
             truncation this means the closing tag is missing
        0.0  otherwise

    Scoring each completion independently keeps the result independent of the
    order of completions within a group: a wrong or unterminated completion
    can no longer reduce the credit earned by a correct one.

    Args:
        prompts: Sequence of prompts.
        completions: Sequence of completion strings, grouped by prompt.
        answer: Sequence of reference answers, indexed by prompt.
        **kwargs: Ignored extra fields.

    Returns:
        float32 jnp.ndarray of shape (B,); empty when there are no prompts.
    """
    completions = clean_completions_for_rl(completions)

    B = len(prompts)
    if B == 0:
        return jnp.asarray([], dtype=jnp.float32)
    G = len(completions) // B

    group_scores = []
    for i in range(B):
        group = completions[i * G:(i + 1) * G]
        true = str(answer[i]).strip()
        best = 0.0

        for r in group:
            m = match_numbers.search(r)
            if not (m and m.group(1).strip() == true):
                continue

            score = 3.0  # full credit on exact match
            # Penalise trailing text, but only on the completion that earned
            # the credit; the score never drops below zero.
            after = r.split("</answer>")[-1].strip()
            if after:
                score = max(score - 1.0, 0.0)

            best = max(best, score)

        group_scores.append(best)

    return jnp.asarray(group_scores, dtype=jnp.float32)



def logged_check_answer(prompts, completions, answer, **kwargs):
    """Correctness reward plus W&B logging of its mean, max and hit rate.

    Returns the reward array from `check_answer` unchanged.
    """
    import numpy as np

    rewards = check_answer(prompts, completions, answer=answer, **kwargs)
    as_numpy = np.asarray(rewards, dtype=np.float32)

    # Any positive reward counts as a hit; the mean of the hits is the
    # fraction of prompts in the batch with a credited answer.
    hit_mask = (as_numpy > 0).astype(np.float32)
    accuracy = float(hit_mask.mean())

    summary = {
        "reward/answer_mean": float(as_numpy.mean()),
        "reward/answer_max": float(as_numpy.max()),
        "metric/answer_accuracy": accuracy,
    }
    wandb.log(summary)

    return rewards

In [None]:
import numpy as np

# Answer extractor used by check_answer() and evaluate(): after the opening
# answer tag, lazily skip any characters (DOTALL, so newlines too) and capture
# the first run of digits and dots. The character class contains no "-" or
# ",", so a sign or thousands separator is not part of the captured value.
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)
# Quick manual check of the pattern; the result is not assigned.
match_numbers.findall(f"{solution_start}  0.34  {solution_end}")

# Not referenced by the cells shown.
match_ans = re.compile(r"(.+?)", re.DOTALL)

# The earlier "**Final Answer:**" extraction pattern was dropped in favour of
# match_numbers:
# final_answer_extraction_pattern = re.compile(r"\*\*Final Answer:\*\*\s*([\d\.\-]+)", re.DOTALL)
def validate_reward_fn(reward_fn, name):
    """Smoke-test a reward function's output shape and numeric sanity.

    Calls `reward_fn` with B=2 dummy prompts and B*G=8 dummy completions and
    asserts that the result has shape (B,) and contains only finite values.
    An `answer` keyword is supplied only when `reward_fn` declares a parameter
    with that name.

    Args:
        reward_fn: Callable taking `prompts` and `completions` keywords.
        name: Label used in the printed report and assertion messages.

    Raises:
        AssertionError: on a wrong shape or NaN/Inf values.
    """
    import inspect

    print(f"\n\U0001f50d Validating reward fn: {name}")

    B = 2   # batch size
    G = 4   # num_generations - set to a value >1 for validation

    test_prompts = [f"prompt {i}" for i in range(B)]
    test_completions = [f"completion {i}" for i in range(B * G)]  # B*G completions
    test_answers = ["3", "7"]

    # Inspect the declared parameters rather than __code__.co_varnames: the
    # latter also lists local variables and is missing on callables such as
    # functools.partial objects.
    try:
        accepts_answer = "answer" in inspect.signature(reward_fn).parameters
    except (TypeError, ValueError):
        accepts_answer = False

    kwargs = {}
    if accepts_answer:
        kwargs["answer"] = test_answers

    r = reward_fn(
        prompts=test_prompts,
        completions=test_completions,
        **kwargs
    )

    r = jnp.asarray(r)

    print("Returned shape:", r.shape)
    print("Returned values:", r)

    assert r.shape == (B,), (  # one reward per prompt, not per completion
        f"\u274c {name} returned shape {r.shape}, expected {(B,)}"
    )
    assert jnp.isfinite(r).all(), (
        f"\u274c {name} returned NaN/Inf values"
    )

    print(f"\u2705 {name} PASSED shape + numeric checks")



# Run these once before training. The logged_* wrappers call wandb.log, so
# each validation call also writes its dummy metrics to the active W&B run.
validate_reward_fn(logged_match_format_exactly, "format_exact")
validate_reward_fn(logged_match_format_approximately, "format_approx")
validate_reward_fn(logged_check_answer, "answer")


üîç Validating reward fn: format_exact
Returned shape: (2,)
Returned values: [0. 0.]
‚úÖ format_exact PASSED shape + numeric checks

üîç Validating reward fn: format_approx
Returned shape: (2,)
Returned values: [-2. -2.]
‚úÖ format_approx PASSED shape + numeric checks

üîç Validating reward fn: answer
Returned shape: (2,)
Returned values: [0. 0.]
‚úÖ answer PASSED shape + numeric checks


## GRPO Setup (Planner Only)

In [None]:
# ---------- Planner GRPO setup ----------

# Gradient clipping followed by AdamW on a warmup + cosine-decay schedule
# (0 -> LEARNING_RATE over WARMUP_STEPS, decaying until MAX_STEPS).
planner_optimizer = optax.chain(
    optax.clip_by_global_norm(MAX_GRAD_NORM),
    optax.adamw(
        learning_rate=optax.schedules.warmup_cosine_decay_schedule(
            0.0, LEARNING_RATE, WARMUP_STEPS, MAX_STEPS
        ),
        b1=B1,
        b2=B2,
        weight_decay=WEIGHT_DECAY,
    ),
)

# Actor, reference and rollout roles all share the same mesh.
planner_cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine="vanilla",
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=planner_optimizer,
        max_steps=MAX_STEPS,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        checkpoint_root_directory=PLANNER_CKPT_ROOT,
        checkpointing_options=ocp.CheckpointManagerOptions(
            save_interval_steps=SAVE_INTERVAL_STEPS,
            max_to_keep=MAX_TO_KEEP,
        ),
        metrics_logging_options=metrics_logger.MetricsLoggerOptions(
            log_dir="/tmp/content/tmp/tensorboard/grpo_planner",
            flush_every_n_steps=20,
        ),
        eval_every_n_steps=SAVE_INTERVAL_STEPS,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=128,               # plans are short
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + 512 + 256,
        temperature=0.7,
        top_p=0.95,
        top_k=50,
        # NOTE(review): presumably the tokenizer's end-of-sequence and
        # end-of-turn ids — confirm against the Gemma tokenizer.
        eos_tokens=(1, 106),
    ),
)

# The planner LoRA policy is the actor; the base model is the reference.
planner_rl_cluster = rl_cluster_lib.RLCluster(
    actor=planner_policy,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=planner_cluster_config,
)

# The planner is rewarded on plan format only.
planner_grpo_trainer = GRPOLearner(
    rl_cluster=planner_rl_cluster,
    reward_fns=[logged_planner_match_format],
    grpo_config=GRPOConfig(
        num_generations=PLANNER_NUM_GENERATIONS,   # currently 2
        num_iterations=NUM_ITERATIONS,
        beta=BETA,
        epsilon=EPSILON,
    ),
)




0,1
jax/core/compile/jaxpr_trace_duration,‚ñÅ
jax/orbax/write/sharded_array_gb,‚ñÅ
metric/answer_accuracy,‚ñÅ
reward/answer_max,‚ñÅ
reward/answer_mean,‚ñÅ
reward/format_approx_max,‚ñÅ
reward/format_approx_mean,‚ñÅ
reward/format_exact_max,‚ñÅ
reward/format_exact_mean,‚ñÅ
reward/format_exact_min,‚ñÅ

0,1
jax/core/compile/jaxpr_trace_duration,1764787696.89261
jax/orbax/write/sharded_array_gb,0.0
metric/answer_accuracy,0.0
reward/answer_max,0.0
reward/answer_mean,0.0
reward/format_approx_max,-2.0
reward/format_approx_mean,-2.0
reward/format_exact_max,0.0
reward/format_exact_mean,0.0
reward/format_exact_min,0.0


0,1
jax/orbax/write/sharded_array_gb,‚ñÅ

0,1
jax/orbax/write/sharded_array_gb,0.00082


## GRPO Setup (Solver Only)

In [None]:
# ---------------------------------------------------------
# Solver Prompt Function (injects Planner→Solver pipeline)
# ---------------------------------------------------------

def solver_prompt_fn(batch):
    """
    Build solver prompts for a GSM8K batch.

    For every question in `batch["question"]` a plan is sampled from the
    current planner policy via `generate_plan`, and the (question, plan) pair
    is rendered with SOLVER_TEMPLATE.

    Returns:
        A list of solver prompt strings, one per question.

    NOTE(review): this function is not passed to the cluster or learner in
    the cells shown — confirm where it is wired into the rollout loop.
    """
    questions = batch["question"]

    # Plans come from the *current* planner policy.
    plans = generate_plan(questions)

    # SOLVER_TEMPLATE only has {question} and {plan} fields; the remaining
    # keyword arguments are ignored by str.format.
    prompts = []
    for question, plan in zip(questions, plans):
        rendered = SOLVER_TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=question,
            plan=plan,
            reasoning_start=reasoning_start,
            reasoning_end=reasoning_end,
            solution_start=solution_start,
            solution_end=solution_end
        )
        prompts.append(rendered)

    return prompts


# ---------------------------------------------------------
# Solver Optimizer
# ---------------------------------------------------------

# Same recipe as the planner optimizer: gradient clipping, then AdamW on a
# warmup + cosine-decay schedule.
solver_optimizer = optax.chain(
    optax.clip_by_global_norm(MAX_GRAD_NORM),
    optax.adamw(
        learning_rate=optax.schedules.warmup_cosine_decay_schedule(
            0.0, LEARNING_RATE, WARMUP_STEPS, MAX_STEPS
        ),
        b1=B1,
        b2=B2,
        weight_decay=WEIGHT_DECAY,
    ),
)


# ---------------------------------------------------------
# Solver ClusterConfig
# NOTE(review): solver_prompt_fn is not referenced in this config — confirm
# where it is meant to be injected.
# ---------------------------------------------------------

solver_cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine="vanilla",
    offload_to_cpu=False,

    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=solver_optimizer,
        max_steps=MAX_STEPS,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        checkpoint_root_directory=SOLVER_CKPT_ROOT,
        checkpointing_options=ocp.CheckpointManagerOptions(
            save_interval_steps=SAVE_INTERVAL_STEPS,
            max_to_keep=MAX_TO_KEEP,
        ),
        metrics_logging_options=metrics_logger.MetricsLoggerOptions(
            log_dir="/tmp/content/tmp/tensorboard/grpo_solver",
            flush_every_n_steps=20,
        ),
        eval_every_n_steps=SAVE_INTERVAL_STEPS,
    ),

    rollout_config=base_rollout.RolloutConfig(
        # Earlier variant driven by the notebook-level constants:
        # max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        # max_prompt_length=MAX_PROMPT_LENGTH,
        # kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 512,
        max_tokens_to_generate=512,       # same value as TOTAL_GENERATION_STEPS
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + 1024,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        # NOTE(review): presumably the tokenizer's end-of-sequence and
        # end-of-turn ids — confirm against the Gemma tokenizer.
        eos_tokens=(1, 106)
    ),
)


# ---------------------------------------------------------
# Solver RLCluster + GRPOLearner
# ---------------------------------------------------------

# The solver LoRA policy is the actor; the base model is the reference.
solver_rl_cluster = rl_cluster_lib.RLCluster(
    actor=solver_policy,
    reference=ref_model,
    tokenizer=tokenizer, # tokenizer created alongside the base model
    cluster_config=solver_cluster_config,
)

# Reward = format_exact + format_approx + correctness (check_answer)
solver_grpo_trainer = GRPOLearner(
    rl_cluster=solver_rl_cluster,
    reward_fns=[
        logged_match_format_exactly,
        logged_match_format_approximately,
        logged_check_answer,
    ],
    grpo_config=GRPOConfig(
        num_generations=SOLVER_NUM_GENERATIONS,   # currently 2
        num_iterations=NUM_ITERATIONS,
        beta=BETA,
        epsilon=EPSILON,
    ),
)



0,1
jax/orbax/write/sharded_array_gb,‚ñÅ

0,1
jax/orbax/write/sharded_array_gb,0.00082


# Pretraining Evaluation

In [None]:
def evaluate(
    dataset,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """Computes accuracy and percentage of outputs matching the format.

  For every question, `num_passes` responses are generated. A question counts
  as correct / partially correct / well-formatted when at least one of its
  responses is.

  Args:
    dataset: Iterable of batches with "question" and "answer" sequences.
    sampler: Sampler forwarded to `generate`.
    temperature, top_k, top_p: Sampling settings forwarded to `generate`.
    num_passes: Number of generations per question; the pass index is used as
      the seed.
    corr_lst: With `make_lst`, collect correct (True) or incorrect (False)
      questions.
    make_lst: When True, also return the collected
      (question, answer, responses) tuples.

  Returns:
    (corr, total, accuracy %, partial accuracy %, format accuracy %), or that
    tuple plus the response list when `make_lst` is True. All percentages are
    0.0 when the dataset is empty.

  NOTE(review): relies on a module-level
  `generate(questions, sampler, temperature, top_k, top_p, seed=...)` helper
  that is not defined in the cells shown.
  """

  response_lst = []
  corr = 0
  partially_corr = 0
  corr_format = 0
  total = 0

  def _pct(count):
    # Guard against an empty dataset instead of dividing by zero.
    return count / total * 100 if total else 0.0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]

    multiple_call_responses = [[] for _ in range(len(questions))]
    for p in range(num_passes):
      responses = generate(
          questions, sampler, temperature, top_k, top_p, seed=p
      )
      for idx, response in enumerate(responses):
        multiple_call_responses[idx].append(response)

    for question, multiple_call_response, answer in zip(
        questions, multiple_call_responses, answers
    ):
      # check answer
      corr_ctr_per_question = 0
      partially_corr_per_question = 0
      corr_format_per_question = 0
      for response in multiple_call_response:
        guess = match_numbers.search(response)
        # Sentinel value when no number follows the answer tag.
        extracted_response = guess.group(1) if guess is not None else "-1000000"
        try:
          predicted = float(extracted_response.strip())
          expected = float(answer.strip())
        except (AttributeError, TypeError, ValueError):
          # Unparseable prediction (e.g. "1.2.3") or missing reference answer.
          print("SKIPPED")
        else:
          if predicted == expected:
            corr_ctr_per_question += 1

          # "Partially correct" means within 10% of the reference. A zero
          # reference has no ratio, so only an exact match qualifies.
          if expected == 0:
            within_tolerance = predicted == expected
          else:
            within_tolerance = 0.9 <= predicted / expected <= 1.1
          if within_tolerance:
            partially_corr_per_question += 1

        # check format
        if match_format.search(response) is not None:
          corr_format_per_question += 1

        # Stop early once this question is satisfied on every criterion.
        if (
            corr_ctr_per_question > 0
            and partially_corr_per_question > 0
            and corr_format_per_question > 0
        ):
          break

      if corr_ctr_per_question > 0:
        corr += 1
        if corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      else:
        if not corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      if partially_corr_per_question > 0:
        partially_corr += 1
      if corr_format_per_question > 0:
        corr_format += 1

      total += 1
      if total % 10 == 0:
        print(
            f"===> {corr=}, {total=}, {corr / total * 100=}, "
            f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
        )

  to_return = (
      corr,
      total,
      _pct(corr),
      _pct(partially_corr),
      _pct(corr_format),
  )
  if make_lst:
    return to_return, response_lst
  return to_return

In [None]:
# Baseline before any RL: evaluate the (untrained) solver sampler on the test
# batches. `metrics` is (correct, total, accuracy %, partial %, format %).
metrics = evaluate(test_dataset, solver_sampler)

print(f"Pre-training Evaluation Metrics:")
print(f"Correct: {metrics[0]}")
print(f"Total: {metrics[1]}")
print(f"Accuracy: {metrics[2]:.2f}%")
print(f"Partially Correct Accuracy: {metrics[3]:.2f}%")
print(f"Format Match Accuracy: {metrics[4]:.2f}%")

  5%|‚ñå         | 5/100 [01:04<13:55,  8.79s/it]

===> corr=1, total=10, corr / total * 100=10.0, partially_corr / total * 100=10.0, corr_format / total * 100=20.0


 10%|‚ñà         | 10/100 [01:37<08:37,  5.75s/it]

===> corr=3, total=20, corr / total * 100=15.0, partially_corr / total * 100=25.0, corr_format / total * 100=25.0


 15%|‚ñà‚ñå        | 15/100 [01:52<05:01,  3.54s/it]

===> corr=8, total=30, corr / total * 100=26.666666666666668, partially_corr / total * 100=33.33333333333333, corr_format / total * 100=20.0


 20%|‚ñà‚ñà        | 20/100 [02:08<04:15,  3.19s/it]

===> corr=12, total=40, corr / total * 100=30.0, partially_corr / total * 100=40.0, corr_format / total * 100=17.5


 25%|‚ñà‚ñà‚ñå       | 25/100 [02:24<03:52,  3.10s/it]

===> corr=14, total=50, corr / total * 100=28.000000000000004, partially_corr / total * 100=36.0, corr_format / total * 100=18.0


 30%|‚ñà‚ñà‚ñà       | 30/100 [02:39<03:39,  3.14s/it]

===> corr=18, total=60, corr / total * 100=30.0, partially_corr / total * 100=36.666666666666664, corr_format / total * 100=23.333333333333332


 35%|‚ñà‚ñà‚ñà‚ñå      | 35/100 [02:54<03:18,  3.06s/it]

===> corr=23, total=70, corr / total * 100=32.857142857142854, partially_corr / total * 100=38.57142857142858, corr_format / total * 100=24.285714285714285


 40%|‚ñà‚ñà‚ñà‚ñà      | 40/100 [03:10<03:04,  3.07s/it]

===> corr=27, total=80, corr / total * 100=33.75, partially_corr / total * 100=38.75, corr_format / total * 100=23.75


 45%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 45/100 [03:25<02:48,  3.06s/it]

SKIPPED
===> corr=32, total=90, corr / total * 100=35.55555555555556, partially_corr / total * 100=40.0, corr_format / total * 100=21.11111111111111


 50%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 50/100 [03:40<02:33,  3.07s/it]

===> corr=34, total=100, corr / total * 100=34.0, partially_corr / total * 100=40.0, corr_format / total * 100=21.0


 51%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 51/100 [03:43<02:29,  3.04s/it]

SKIPPED


 55%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå    | 55/100 [03:56<02:18,  3.09s/it]

===> corr=37, total=110, corr / total * 100=33.63636363636363, partially_corr / total * 100=39.09090909090909, corr_format / total * 100=20.909090909090907


 60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 60/100 [04:11<02:02,  3.07s/it]

===> corr=40, total=120, corr / total * 100=33.33333333333333, partially_corr / total * 100=38.333333333333336, corr_format / total * 100=20.0


 65%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå   | 65/100 [04:26<01:44,  2.99s/it]

===> corr=45, total=130, corr / total * 100=34.61538461538461, partially_corr / total * 100=39.23076923076923, corr_format / total * 100=18.461538461538463


 66%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå   | 66/100 [04:29<01:42,  3.01s/it]

SKIPPED


 67%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 67/100 [04:32<01:40,  3.06s/it]

SKIPPED


 70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 70/100 [04:41<01:32,  3.09s/it]

===> corr=45, total=140, corr / total * 100=32.142857142857146, partially_corr / total * 100=36.42857142857142, corr_format / total * 100=17.857142857142858


 72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 72/100 [04:47<01:25,  3.05s/it]

SKIPPED


 75%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå  | 75/100 [04:57<01:16,  3.05s/it]

===> corr=50, total=150, corr / total * 100=33.33333333333333, partially_corr / total * 100=37.333333333333336, corr_format / total * 100=18.0


 80%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 80/100 [05:12<01:01,  3.07s/it]

===> corr=52, total=160, corr / total * 100=32.5, partially_corr / total * 100=36.25, corr_format / total * 100=19.375


 85%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå | 85/100 [05:27<00:44,  2.97s/it]

===> corr=56, total=170, corr / total * 100=32.94117647058823, partially_corr / total * 100=37.05882352941177, corr_format / total * 100=18.823529411764707


 90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 90/100 [05:42<00:29,  2.97s/it]

===> corr=57, total=180, corr / total * 100=31.666666666666664, partially_corr / total * 100=35.55555555555556, corr_format / total * 100=18.88888888888889


 95%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå| 95/100 [05:57<00:14,  2.95s/it]

===> corr=60, total=190, corr / total * 100=31.57894736842105, partially_corr / total * 100=35.26315789473684, corr_format / total * 100=18.421052631578945


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [06:11<00:00,  3.72s/it]

===> corr=62, total=200, corr / total * 100=31.0, partially_corr / total * 100=34.5, corr_format / total * 100=18.5
Pre-training Evaluation Metrics:
Correct: 62
Total: 200
Accuracy: 31.00%
Partially Correct Accuracy: 34.50%
Format Match Accuracy: 18.50%





# Trace Entire Output Line

In [None]:
def trace_pipeline(question, true_answer):
    """Print each intermediate artifact of the planner -> solver pipeline for one question.

    In order, the trace shows: the planner prompt, the generated plan and its
    format score, the solver prompt built from that plan, the solver
    completion, both solver format scores, and the answer score.

    Args:
        question: Problem statement given to the planner and the solver.
        true_answer: Reference answer forwarded to `check_answer`.

    Returns:
        None. All results are reported with `print`.
    """
    print("\n--- Tracing Pipeline ---")
    print(f"Question: {question}")

    # Stage 1: the prompt the planner receives.
    planner_prompt = PLANNER_TEMPLATE.format(question=question)
    print(f"\nPlanner Input:\n{planner_prompt}")

    # Stage 2: a batch of one question, so only the first plan is kept.
    plan_batch = generate_plan([question])
    plan = plan_batch[0]
    print(f"\nGenerated Plan:\n{PLAN_START}\n{plan}\n{PLAN_END}")

    # Stage 2.1: the plan is scored wrapped in its start/end markers.
    planner_scores = planner_match_format(
        prompts=[planner_prompt],
        completions=[f"{PLAN_START}\n{plan}\n{PLAN_END}"],
    )
    print(f"\nPlanner Format Match Score: {planner_scores[0]}")

    # Stage 3: fill the solver template with the question and the plan above.
    template_fields = {
        "system_prompt": SYSTEM_PROMPT,
        "question": question,
        "plan": plan,
        "reasoning_start": reasoning_start,
        "reasoning_end": reasoning_end,
        "solution_start": solution_start,
        "solution_end": solution_end,
    }
    solver_prompt = SOLVER_TEMPLATE.format(**template_fields)
    print(f"\nSolver Input:\n{solver_prompt}")

    # Stage 4: solver completion.
    # NOTE(review): generate_with_plan receives the raw question, not `plan`;
    # whether it reuses the plan printed above cannot be told from here - confirm.
    _unused_plans, completions = generate_with_plan([question])
    completion = completions[0]
    print(f"\nSolver Output:\n{completion}")

    # Stage 5: strict and lenient format rewards for the completion.
    exact_scores = match_format_exactly(
        prompts=[solver_prompt],
        completions=[completion],
    )
    approx_scores = match_format_approximately(
        prompts=[solver_prompt],
        completions=[completion],
    )
    print(f"\nSolver Format Match (Exact): {exact_scores[0]}")
    print(f"Solver Format Match (Approx): {approx_scores[0]}")

    # Stage 6: answer reward against the reference answer.
    correctness_scores = check_answer(
        prompts=[solver_prompt],
        completions=[completion],
        answer=[true_answer],
    )
    print(f"\nCheck Answer Score: {correctness_scores[0]}")

    print("--- End Tracing Pipeline ---")


**Reasoning**:
Now that the `trace_pipeline` function is defined, I will select the first sample from the `test_dataset` and call the `trace_pipeline` function with its question and answer to demonstrate its functionality.



In [None]:
# Take the first batch of the test split and trace its first example end to end.
first_sample = next(iter(test_dataset))
# Alternative: draw from a shuffled split instead.
# first_sample = next(iter(test_dataset.shuffle(seed=60)))
first_question, first_answer = first_sample['question'][0], first_sample['answer'][0]

# Show the raw question and reference answer before the trace.
print(first_question, first_answer, sep="\n")

trace_pipeline(first_question, first_answer)

Mr Hezekiah had 20 trucks from his store supplying fertiliser to different farmers in his hometown dispatched for delivery on a particular day. Each truck was carrying 20 tons of fertiliser packed in bags. Two hours after the trucks had departed for delivery, Mr Hezekiah got the news that a quarter of the number of lorries dispatched for delivery had mechanical failures on the road and could not deliver the fertilisers to the farmers. Calculate the total number of tons of fertiliser that reached the farmers that day?
300

--- Tracing Pipeline ---
Question: Mr Hezekiah had 20 trucks from his store supplying fertiliser to different farmers in his hometown dispatched for delivery on a particular day. Each truck was carrying 20 tons of fertiliser packed in bags. Two hours after the trucks had departed for delivery, Mr Hezekiah got the news that a quarter of the number of lorries dispatched for delivery had mechanical failures on the road and could not deliver the fertilisers to the farmers

# Task
I will now execute the training phase for both the planner and solver. This involves running the code in cell `O1OyzgJP9X4E` which initiates the two-phase training process using the `planner_grpo_trainer` and `solver_grpo_trainer` instances. This will log metrics to WandB and save checkpoints to `PLANNER_CKPT_ROOT` and `SOLVER_CKPT_ROOT` as configured.

## Verify Planner Training and Logging Configuration

### Subtask:
Inspect the `planner_cluster_config` to confirm that `checkpoint_root_directory`, `checkpointing_options`, and `metrics_logging_options` are correctly configured for the planner's training, ensuring model saving and WandB tracking are active.


**Reasoning**:
To verify the planner's checkpoint and logging configurations, I will print the relevant attributes from the `planner_cluster_config` object as specified in the instructions.



In [None]:
# Read-only inspection of the planner's training config: checkpoint root,
# checkpoint cadence/retention, and the metrics log directory.
print(f"Planner Checkpoint Root Directory: {planner_cluster_config.training_config.checkpoint_root_directory}")
print(f"Planner Checkpointing Options: save_interval_steps={planner_cluster_config.training_config.checkpointing_options.save_interval_steps}, max_to_keep={planner_cluster_config.training_config.checkpointing_options.max_to_keep}")
print(f"Planner Metrics Logging Directory: {planner_cluster_config.training_config.metrics_logging_options.log_dir}")

Planner Checkpoint Root Directory: /content/drive/MyDrive/tunix_ckpts_modeB/planner
Planner Checkpointing Options: save_interval_steps=500, max_to_keep=4
Planner Metrics Logging Directory: /tmp/content/tmp/tensorboard/grpo_planner


## Verify Solver Training and Logging Configuration

### Subtask:
Inspect the `solver_cluster_config` to confirm that `checkpoint_root_directory`, `checkpointing_options`, and `metrics_logging_options` are correctly configured for the solver's training, ensuring model saving and WandB tracking are active. This is crucial for verifying the reward values over time.


**Reasoning**:
To verify the solver's checkpoint and logging configurations, I will print the relevant attributes from the `solver_cluster_config` object as specified in the instructions.



In [None]:
# Read-only inspection of the solver's training config: checkpoint root,
# checkpoint cadence/retention, and the metrics log directory.
print(f"Solver Checkpoint Root Directory: {solver_cluster_config.training_config.checkpoint_root_directory}")
print(f"Solver Checkpointing Options: save_interval_steps={solver_cluster_config.training_config.checkpointing_options.save_interval_steps}, max_to_keep={solver_cluster_config.training_config.checkpointing_options.max_to_keep}")
print(f"Solver Metrics Logging Directory: {solver_cluster_config.training_config.metrics_logging_options.log_dir}")

Solver Checkpoint Root Directory: /content/drive/MyDrive/tunix_ckpts_modeB/solver
Solver Checkpointing Options: save_interval_steps=500, max_to_keep=4
Solver Metrics Logging Directory: /tmp/content/tmp/tensorboard/grpo_solver


## Execute Training Phase

### Subtask:
Run the two-phase training process for both the planner and solver using their respective `GRPOLearner` instances. This will initiate the learning, checkpointing, and metrics logging as configured.


**Reasoning**:
I need to execute the two-phase training process by running the code in the specified cell.



In [None]:
def to_planner_view(batch):
    """Convert a raw question/answer batch into the planner's prompt batch.

    Args:
        batch: Mapping with parallel sequences under "question" and "answer".

    Returns:
        A dict with "prompts" (each question rendered through
        PLANNER_TEMPLATE) and "answer" (each answer as text). Pairs are
        formed with zip, so the shorter sequence sets the length.
    """

    def _as_text(value):
        # bytes are decoded leniently, None becomes "", anything else goes through str().
        if isinstance(value, bytes):
            value = value.decode("utf-8", errors="ignore")
        return "" if value is None else str(value)

    text_pairs = [
        (_as_text(raw_q), _as_text(raw_a))
        for raw_q, raw_a in zip(batch["question"], batch["answer"])
    ]

    return {
        # "prompts" is the key the GRPO trainer reads (per the original note).
        "prompts": [PLANNER_TEMPLATE.format(question=q_text) for q_text, _ in text_pairs],
        "answer": [a_text for _, a_text in text_pairs],
    }

# Planner view of the training set: every batch is rewritten by to_planner_view.
planner_train_dataset = train_dataset.map(to_planner_view)

# Peek at the first mapped batch: its keys, the container type and size of
# "prompts", and the first 120 characters of the first prompt.
first = next(iter(planner_train_dataset))
print(first.keys())
print(type(first["prompts"]), len(first["prompts"]))
print(type(first["prompts"][0]), repr(first["prompts"][0][:120]))


dict_keys(['prompts', 'answer'])
<class 'list'> 2
<class 'str'> '\n<start_of_turn>user\nYou are a planning assistant. Produce a short numbered plan (3‚Äì5 steps)\nfor solving the problem. Do'


In [None]:
# ============================
# ‚úÖ SAFE PLANNER TRAIN STREAM
# ============================
import numpy as np

def make_safe_planner_stream():
    """Yield planner batches with each (prompt, answer) pair repeated per generation.

    For every batch of `planner_train_dataset`, prompts and answers are cast
    to `str`, paired with zip, and each pair is emitted
    PLANNER_NUM_GENERATIONS times in a row.

    Yields:
        dict with "prompts" and "answer", both plain lists of str of length
        (number of pairs) * PLANNER_NUM_GENERATIONS.
    """
    for batch in planner_train_dataset:
        prompt_texts = list(map(str, batch["prompts"]))
        answer_texts = list(map(str, batch["answer"]))

        # Consecutive copies: p0, p0, ..., p1, p1, ... (and likewise for answers).
        repeated_pairs = [
            pair
            for pair in zip(prompt_texts, answer_texts)
            for _ in range(PLANNER_NUM_GENERATIONS)
        ]

        yield {
            "prompts": [prompt for prompt, _ in repeated_pairs],
            "answer": [answer for _, answer in repeated_pairs],
        }
# Quick sanity check: pull one batch from a fresh stream and print its keys
# and the Python type of every prompt in it.
stream = make_safe_planner_stream()
b0 = next(stream)
print("stream keys:", b0.keys())
print("prompt types:", [type(p) for p in b0["prompts"]])


stream keys: dict_keys(['prompts', 'answer'])
prompt types: [<class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>]


In [None]:
# Stash the pristine Sampler.tokenize on the class exactly once. Without the
# guard, re-running this cell would bind `orig_tokenize` to the wrapper that is
# already installed, and the wrapper would then call itself without end.
if not hasattr(sampler_lib.Sampler, "_original_tunix_sampler_tokenize"):
    sampler_lib.Sampler._original_tunix_sampler_tokenize = sampler_lib.Sampler.tokenize
orig_tokenize = sampler_lib.Sampler._original_tunix_sampler_tokenize

def debug_tokenize(self, input_string):
    """Instrumented replacement for `Sampler.tokenize`.

    Args:
        input_string: The value the sampler is about to tokenize.

    Returns:
        Whatever the original `Sampler.tokenize` returns for `input_string`.

    Raises:
        TypeError: If `input_string` is not a `str`; its type and value are
            printed first.
    """
    if not isinstance(input_string, str):
        print("\n‚ùå NON-STRING HIT INSIDE SAMPLER")
        print("Type:", type(input_string))
        print("Value:", input_string)
        raise TypeError("Sampler received non-string input")

    # Empty or whitespace-only prompts are reported but still tokenized.
    if input_string.strip() == "":
        print("\n‚ö†Ô∏è EMPTY STRING HIT INSIDE SAMPLER")

    return orig_tokenize(self, input_string)

sampler_lib.Sampler.tokenize = debug_tokenize
print("‚úÖ Sampler tokenize is now instrumented")

‚úÖ Sampler tokenize is now instrumented


In [None]:
import numpy as np
from tunix.generate import sampler as sampler_lib

# Keep the original Sampler.__call__ on the class itself so the wrapper below can delegate to it.
# The hasattr guard stores it only once: re-running this cell must not capture the
# already-patched method, which would make the wrapper recurse into itself.
if not hasattr(sampler_lib.Sampler, '_original_tunix_sampler_call'):
    sampler_lib.Sampler._original_tunix_sampler_call = sampler_lib.Sampler.__call__
_orig_sampler_call = sampler_lib.Sampler._original_tunix_sampler_call

def safe_sampler_call(self, input_strings, *args, **kwargs):
    """Normalize `input_strings` to a flat list of `str`, then call the original sampler.

    Accepted shapes for `input_strings`:
      * a single `str` / `bytes` prompt (treated as a batch of one),
      * a NumPy array of prompts (any rank; flattened),
      * any iterable whose items are prompts or NumPy arrays of prompts.

    Args:
        input_strings: Prompt(s) in any of the shapes above.
        *args: Extra positional arguments, forwarded unchanged.
        **kwargs: Extra keyword arguments, forwarded unchanged.

    Returns:
        Whatever the original `Sampler.__call__` returns.
    """

    def _to_text(item):
        # bytes / np.bytes_ must be decoded; str() would produce the "b'...'" repr.
        if isinstance(item, bytes):
            return item.decode("utf-8", errors="ignore")
        return str(item)

    # A lone prompt is a batch of one; iterating it would split it into characters.
    if isinstance(input_strings, (str, bytes)):
        input_strings = [input_strings]
    elif isinstance(input_strings, np.ndarray):
        # ravel() also covers 0-d and multi-dimensional arrays.
        input_strings = input_strings.ravel().tolist()

    # Items may themselves be arrays of prompts; splice those in flat.
    flat = []
    for item in input_strings:
        if isinstance(item, np.ndarray):
            flat.extend(item.ravel().tolist())
        else:
            flat.append(item)

    clean_strings = [_to_text(item) for item in flat]

    # Forward the prompts positionally. Passing `input_strings=` as a keyword
    # ahead of *args let the first extra positional argument land in the same
    # parameter, which raised "got multiple values for argument".
    return _orig_sampler_call(self, clean_strings, *args, **kwargs)

# Monkey-patch Sampler.__call__ with the normalizing wrapper defined above.
sampler_lib.Sampler.__call__ = safe_sampler_call
print("‚úÖ Patched Sampler.__call__ to normalize NumPy arrays ‚Üí List[str] and avoid recursion.")

‚úÖ Patched Sampler.__call__ to normalize NumPy arrays ‚Üí List[str] and avoid recursion.


In [None]:
# ---------- NEW: Two-phase RL training (safe) ----------

with mesh:
    # One W&B run for the planner stage; the config records its hyperparameters.
    planner_run = wandb.init(
        project="gemma3-grpo-planner-solver",
        name="planner-grpo",
        group="planner",
        config={
            "model": "gemma3-1b-planner",
            "rank": RANK,
            "alpha": ALPHA,
            "max_steps": MAX_STEPS,
            "learning_rate": LEARNING_RATE,
            "beta": BETA,
            "epsilon": EPSILON,
            # Log the planner-specific value: the planner stream expands prompts
            # by PLANNER_NUM_GENERATIONS, and the solver run logs
            # SOLVER_NUM_GENERATIONS in the same slot.
            "num_generations": PLANNER_NUM_GENERATIONS,
        },
    )

    try:
        print("üöÄ Stage 1: Training PLANNER with GRPO...")
        planner_grpo_trainer.train(make_safe_planner_stream())
        print("‚úÖ Planner training complete.")
    finally:
        # Close the planner run even if training raises, so the next
        # wandb.init starts from a clean state (mirrors the solver stage).
        wandb.finish()


0,1
jax/core/compile/backend_compile_duration,‚ñÅ
jax/core/compile/jaxpr_to_mlir_module_duration,‚ñÅ
jax/core/compile/jaxpr_trace_duration,‚ñÅ

0,1
jax/core/compile/backend_compile_duration,1764782861.13301
jax/core/compile/jaxpr_to_mlir_module_duration,1764782861.12474
jax/core/compile/jaxpr_trace_duration,1764782861.12316


üöÄ Stage 1: Training PLANNER with GRPO...


Actor Training:   0%|          | 0/1000 [00:00<?, ?step/s]



0,1
actor/train/kl,‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÅ‚ñÇ‚ñÅ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñà‚ñÇ‚ñÇ‚ñÅ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÇ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÇ‚ñÅ‚ñÅ‚ñÇ‚ñÅ‚ñÇ
actor/train/loss,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÜ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñà‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
actor/train/perplexity,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÜ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÜ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñà‚ñÅ‚ñÅ‚ñÅ
actor/train/step_time_sec,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñà‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñà‚ñÅ
actor/train/steps_per_sec,‚ñà‚ñà‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñá‚ñá‚ñà‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñÅ‚ñà‚ñá‚ñà‚ñà‚ñà‚ñà
actor/train/tflops_per_step,‚ñÅ
jax/core/compile/backend_compile_duration,‚ñÅ
jax/core/compile/jaxpr_to_mlir_module_duration,‚ñÅ
jax/core/compile/jaxpr_trace_duration,‚ñÅ
jax/orbax/write/sharded_array_gb,‚ñÅ

0,1
actor/train/kl,0.00149
actor/train/loss,0.00012
actor/train/perplexity,1.00012
actor/train/step_time_sec,0.05397
actor/train/steps_per_sec,18.52759
actor/train/tflops_per_step,6.20142
jax/core/compile/backend_compile_duration,1764782948.99968
jax/core/compile/jaxpr_to_mlir_module_duration,1764782948.99568
jax/core/compile/jaxpr_trace_duration,1764782948.99372
jax/orbax/write/sharded_array_gb,0.00082


‚úÖ Planner training complete.


# Evaluate Trained Model

In [None]:
def load_latest_planner_lora_from_actor_ckpts(
    planner_policy,
    planner_actor_ckpt_dir,
):
    """Restore the newest planner LoRA checkpoint into `planner_policy`.

    The checkpoint is read from `<planner_actor_ckpt_dir>/<step>/model_params`,
    where `<step>` is the largest all-digit sub-directory name.

    Args:
        planner_policy: Model whose `nnx.LoRAParam` state is overwritten in place.
        planner_actor_ckpt_dir: Directory holding one sub-directory per saved step.

    Raises:
        FileNotFoundError: If the directory is missing or has no numeric
            step sub-directory.
    """

    import os, re
    from orbax import checkpoint as ocp

    # --- Step discovery: collect every all-digit sub-directory as an int ---
    found_steps = []
    if os.path.exists(planner_actor_ckpt_dir):
        found_steps = [
            int(entry)
            for entry in os.listdir(planner_actor_ckpt_dir)
            if os.path.isdir(os.path.join(planner_actor_ckpt_dir, entry))
            and re.match(r"^\d+$", entry)
        ]
    latest_step = max(found_steps, default=-1)

    if latest_step == -1:
        raise FileNotFoundError(
            f"No checkpoints found in {planner_actor_ckpt_dir}"
        )

    print(f"‚úÖ Latest planner checkpoint step: {latest_step}")

    # --- Location of the saved parameters for that step ---
    trained_ckpt_path = os.path.join(
        planner_actor_ckpt_dir, str(latest_step), "model_params"
    )

    print(f"üìÇ Loading planner LoRA from:\n{trained_ckpt_path}")

    # --- Shape/dtype skeleton of the current LoRA state, used as restore target ---
    def _abstract_leaf(leaf):
        return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype)

    lora_skeleton = jax.tree.map(
        _abstract_leaf,
        nnx.state(planner_policy, nnx.LoRAParam),
    )

    # --- Read the checkpoint with Orbax ---
    trained_lora_params = ocp.StandardCheckpointer().restore(
        trained_ckpt_path,
        target=lora_skeleton,
    )

    # --- Replace every current LoRA leaf with its restored counterpart ---
    def _take_restored(_current, restored):
        return restored

    replacement_state = jax.tree.map(
        _take_restored,
        nnx.state(planner_policy, nnx.LoRAParam),
        trained_lora_params,
    )
    nnx.update(planner_policy, replacement_state)

    print("‚úÖ Planner LoRA successfully loaded into planner_policy.")


# Pull the newest planner LoRA checkpoint from PLANNER_ACTOR_CKPT_DIR into planner_policy.
load_latest_planner_lora_from_actor_ckpts(
    planner_policy,
    PLANNER_ACTOR_CKPT_DIR,
)


‚úÖ Latest planner checkpoint step: 1000
üìÇ Loading planner LoRA from:
/content/drive/MyDrive/tunix_ckpts_modeB/planner/actor/1000/model_params




‚úÖ Planner LoRA successfully loaded into planner_policy.


In [None]:
# Spot-check the planner: take the first example of the first batch from a
# seeded shuffle of the test split and generate a plan for it.
random_sample = next(iter(test_dataset.shuffle(seed=3)))
random_question = random_sample['question'][0]

print(f"Random Question from Test Set:\n{random_question}\n")

generated_plans = generate_plan([random_question])

# generate_plan is given a one-element batch; an empty/falsy result is reported instead of indexed.
if generated_plans:
    print(f"Generated Plan:\n{generated_plans[0]}")
else:
    print("Could not generate a plan for the question.")

Random Question from Test Set:
Felix notices that kids in the neighborhood are always getting things stuck in trees. Since he is an expert tree climber, he decided to start charging kids to get their stuff out. He charges based on how high he has to climb. Every branch he has to climb up costs $.25. During the week he made $105. On average, how many branches did he climb per day?

Generated Plan:
1.  Calculate the total number of branches climbed during the week.
2.  Calculate the total cost of the climbs.
3.  Calculate the daily climb cost.
4.  Divide the total weekly income by the total daily income to determine the number of branches climbed per day.
5.  Present the answer in a clear and concise format.


# Train the Solver with the Loaded Planner

In [None]:
def make_solver_train_stream():
    """Yield solver training batches whose prompts embed a generated plan.

    For each batch of `train_dataset`, questions and answers are cast to
    `str`, `generate_plan` produces one plan per question, and each
    (question, plan) pair is rendered through SOLVER_TEMPLATE.

    Yields:
        dict with "prompts" (rendered solver prompts), "answer" (answers as
        str) and "question" (questions as str).
    """

    def _render_prompt(question_text, plan_text):
        # Only the question and plan vary; the rest comes from notebook-level constants.
        return SOLVER_TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=question_text,
            plan=plan_text,
            reasoning_start=reasoning_start,
            reasoning_end=reasoning_end,
            solution_start=solution_start,
            solution_end=solution_end,
        )

    for batch in train_dataset:
        safe_qs = list(map(str, batch["question"]))
        safe_as = list(map(str, batch["answer"]))

        # Step 1: one plan per question, produced by the planner.
        plans = generate_plan(safe_qs)

        # Step 2: pair questions with plans (zip stops at the shorter list).
        solver_prompts = [
            _render_prompt(question_text, plan_text)
            for question_text, plan_text in zip(safe_qs, plans)
        ]

        yield {
            "prompts": solver_prompts,
            "answer":  safe_as,
            "question": safe_qs,
        }

from tunix.rl.grpo import grpo_helpers

# _orig_adv = grpo_helpers.compute_advantages

# def debug_advantages(rewards, num_generations):
#     print("\nüß™ Advantage Debug:")
#     print("rewards shape:", jnp.asarray(rewards).shape)
#     print("num_generations:", num_generations)

#     adv = _orig_adv(rewards, num_generations)

#     print("advantages shape:", adv.shape)
#     return adv

# grpo_helpers.compute_advantages = debug_advantages


In [None]:
import jax.numpy as jnp
from tunix.rl.grpo import grpo_helpers

# Handle on the stock grpo_helpers.compute_advantages, taken before it is
# replaced further down; the replacement itself does not call it.
_orig_compute_adv = grpo_helpers.compute_advantages

def compute_advantages_patched(rewards, num_generations):
    """Standardize rewards over the whole input, then repeat each value per generation.

    Args:
        rewards: 1-D array-like of rewards, length B.
        num_generations: Number of consecutive copies G made of each advantage.

    Returns:
        1-D array of length B * G: `(rewards - mean) / std` (sample std,
        ddof=1), each entry repeated G times in a row. If the std is zero
        (constant rewards) or undefined (a single reward), it is replaced by
        1.0, so the advantages are all zero instead of NaN.

    NOTE(review): statistics are taken over the entire input vector, not per
    prompt group - confirm this is the intended GRPO normalization.
    """
    rewards = jnp.asarray(rewards)

    mean = rewards.mean()
    std = rewards.std(ddof=1)
    # `std > 0` is False for both 0 and NaN. The old `std == 0` test missed
    # NaN, which ddof=1 produces for a single reward (division by n - 1 == 0).
    std = jnp.where(std > 0, std, 1.0)
    adv_prompt = (rewards - mean) / std  # shape (B,)

    # Expand to shape (B*G,): a0, a0, ..., a1, a1, ...
    return jnp.repeat(adv_prompt, num_generations)

# Install the replacement: grpo_helpers.compute_advantages now resolves to the patched function.
grpo_helpers.compute_advantages = compute_advantages_patched
print("‚úÖ Patched GRPO compute_advantages for per-prompt ‚Üí per-completion expansion")


‚úÖ Patched GRPO compute_advantages for per-prompt ‚Üí per-completion expansion


In [None]:
import jax
import jax.numpy as jnp
from tunix.sft import metrics_logger

# Handle on the stock MetricsLogger.log, taken before the class attribute is
# reassigned further down; the replacement does not delegate to it.
_original_metrics_logger_log = metrics_logger.MetricsLogger.log

def patched_metrics_logger_log(self, metric_name, scalar_value, mode, step):
    """Buffer every metric, but forward only numeric scalars to `jax.monitoring`.

    Args:
        metric_name: Name of the metric.
        scalar_value: Value to log. Any value is buffered; only numeric
            scalars are recorded.
        mode: Key selecting the per-mode bucket in `self._metrics`.
        step: Step passed through to `jax.monitoring.record_scalar`.

    A value counts as a numeric scalar when it is a Python `int`/`float`, or
    a JAX array, NumPy array or NumPy scalar with zero dimensions and a
    numeric dtype. NumPy values were previously skipped even when they were
    plain scalars such as `np.float32`.
    """
    import numpy as np  # local import keeps this function self-contained

    # Always append to the internal buffer, regardless of type.
    self._metrics[mode][metric_name].append(scalar_value)

    if isinstance(scalar_value, (int, float)):
        is_scalar_numeric = True
    elif isinstance(scalar_value, (jax.Array, np.ndarray, np.generic)):
        # Zero-dimensional and numeric dtype => a true scalar.
        is_scalar_numeric = bool(
            scalar_value.ndim == 0
            and jnp.issubdtype(scalar_value.dtype, jnp.number)
        )
    else:
        is_scalar_numeric = False

    if is_scalar_numeric:
        jax.monitoring.record_scalar(
            f"{self.metric_prefix}{mode}/{metric_name}", scalar_value, step=step
        )

# Install the replacement on the class so every MetricsLogger instance uses it.
metrics_logger.MetricsLogger.log = patched_metrics_logger_log
print("‚úÖ Patched tunix.sft.metrics_logger.MetricsLogger.log to only record scalar numerics.")

‚úÖ Patched tunix.sft.metrics_logger.MetricsLogger.log to only record scalar numerics.


In [None]:
with mesh:
    # One W&B run for the solver stage; the config records its hyperparameters.
    solver_run = wandb.init(
        project="gemma3-grpo-planner-solver",
        name="solver-grpo-v5",
        group="solver",
        config={
            "model": "gemma3-1b-solver",
            "rank": RANK,
            "alpha": ALPHA,
            "max_steps": MAX_STEPS,
            "learning_rate": LEARNING_RATE,
            "beta": BETA,
            "epsilon": EPSILON,
            "num_generations": SOLVER_NUM_GENERATIONS,
        },
    )

    try:
        print("üöÄ Stage 2: Training SOLVER with GRPO (Planner frozen)...")
        solver_grpo_trainer.train(make_solver_train_stream())
        print("‚úÖ Solver GRPO training complete.")
    finally:
        # Close the run even if training raises, so it is not left open.
        wandb.finish()

0,1
jax/core/compile/backend_compile_duration,‚ñÅ
jax/core/compile/jaxpr_to_mlir_module_duration,‚ñÅ
jax/core/compile/jaxpr_trace_duration,‚ñÅ
jax/orbax/write/sharded_array_gb,‚ñÅ

0,1
jax/core/compile/backend_compile_duration,1764783739.01391
jax/core/compile/jaxpr_to_mlir_module_duration,1764783739.00515
jax/core/compile/jaxpr_trace_duration,1764783739.00339
jax/orbax/write/sharded_array_gb,0.00082


üöÄ Stage 2: Training SOLVER with GRPO (Planner frozen)...


Actor Training:   0%|          | 0/1000 [00:00<?, ?step/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m


0,1
jax/core/compile/backend_compile_duration,‚ñÅ
jax/core/compile/jaxpr_to_mlir_module_duration,‚ñÅ
jax/core/compile/jaxpr_trace_duration,‚ñÅ
metric/answer_accuracy,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
reward/answer_max,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
reward/answer_mean,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
reward/format_approx_max,‚ñÅ‚ñà‚ñÜ‚ñÅ‚ñÉ‚ñà‚ñà‚ñÉ‚ñÉ‚ñà‚ñà‚ñÉ‚ñÜ‚ñÉ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÉ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÅ‚ñà‚ñà‚ñà‚ñÜ‚ñà‚ñÉ‚ñÜ‚ñÉ‚ñà‚ñÜ‚ñà‚ñà‚ñÅ‚ñÜ‚ñÜ‚ñÜ
reward/format_approx_mean,‚ñÖ‚ñÅ‚ñÅ‚ñÉ‚ñÖ‚ñÉ‚ñÜ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñà‚ñá‚ñÜ‚ñá‚ñÉ‚ñÜ‚ñÜ‚ñá‚ñá‚ñÜ‚ñÖ‚ñÖ‚ñÉ‚ñà‚ñá‚ñÜ‚ñÜ‚ñÜ‚ñÑ‚ñá‚ñá‚ñÜ‚ñà‚ñá‚ñÜ‚ñÜ‚ñÜ‚ñÑ‚ñá
reward/format_exact_max,‚ñÅ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñÅ‚ñÅ‚ñà‚ñà‚ñà‚ñà‚ñÅ‚ñÅ‚ñà‚ñà‚ñÅ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñÅ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
reward/format_exact_mean,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÖ‚ñÖ‚ñà‚ñÖ‚ñà‚ñà‚ñÖ‚ñÖ‚ñà‚ñÖ‚ñÖ‚ñÖ‚ñà‚ñà‚ñÖ‚ñÖ‚ñà‚ñÖ‚ñà‚ñÅ‚ñÖ‚ñÖ‚ñÖ‚ñÅ‚ñÅ‚ñÅ‚ñÖ‚ñà‚ñÖ‚ñà‚ñÖ‚ñà‚ñÖ‚ñà‚ñÅ‚ñà

0,1
jax/core/compile/backend_compile_duration,1764783790.2879
jax/core/compile/jaxpr_to_mlir_module_duration,1764783790.00263
jax/core/compile/jaxpr_trace_duration,1764783789.8875
metric/answer_accuracy,0
reward/answer_max,0
reward/answer_mean,0
reward/format_approx_max,0
reward/format_approx_mean,-0.75
reward/format_exact_max,0
reward/format_exact_mean,0


‚úÖ Solver GRPO training complete.


In [None]:
def load_latest_solver_lora_from_actor_ckpts(
    solver_policy,
    solver_actor_ckpt_dir,
):
    """Restore the newest solver LoRA checkpoint into `solver_policy`.

    The checkpoint is read from `<solver_actor_ckpt_dir>/<step>/model_params`,
    where `<step>` is the largest all-digit sub-directory name.

    Args:
        solver_policy: Model whose `nnx.LoRAParam` state is overwritten in place.
        solver_actor_ckpt_dir: Directory holding one sub-directory per saved step.

    Raises:
        FileNotFoundError: If the directory is missing or has no numeric
            step sub-directory.
    """

    import os, re

    # ---------------------------
    # 1. Find latest numeric step
    # ---------------------------
    latest_step = -1
    if os.path.exists(solver_actor_ckpt_dir):
        for item in os.listdir(solver_actor_ckpt_dir):
            full = os.path.join(solver_actor_ckpt_dir, item)
            # The pattern must be r"^\d+$". The previous r"^\\d+$" matched a
            # literal backslash followed by 'd' characters, so no step
            # directory was ever recognised and the lookup always failed.
            if os.path.isdir(full) and re.match(r"^\d+$", item):
                latest_step = max(latest_step, int(item))

    if latest_step == -1:
        raise FileNotFoundError(
            f"No checkpoints found in {solver_actor_ckpt_dir}"
        )

    print(f"‚úÖ Latest solver checkpoint step: {latest_step}")

    # ---------------------------
    # 2. Build full checkpoint path
    # ---------------------------
    trained_ckpt_path = os.path.join(
        solver_actor_ckpt_dir,
        str(latest_step),
        "model_params",
    )

    print(f"üìÇ Loading solver LoRA from:\n{trained_ckpt_path}")

    # ---------------------------
    # 3. Build abstract target for restore
    # ---------------------------
    # Shape/dtype skeleton of the current LoRA state.
    abs_params = jax.tree.map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
        nnx.state(solver_policy, nnx.LoRAParam),
    )

    # ---------------------------
    # 4. Restore with Orbax
    # ---------------------------
    # Imported here so a missing checkpoint is reported before Orbax is needed.
    from orbax import checkpoint as ocp

    checkpointer = ocp.StandardCheckpointer()
    trained_lora_params = checkpointer.restore(
        trained_ckpt_path,
        target=abs_params,
    )

    # ---------------------------
    # 5. Inject into solver policy
    # ---------------------------
    # Keep the structure of the current state; take every leaf from the restore.
    nnx.update(
        solver_policy,
        jax.tree.map(
            lambda _, b: b,
            nnx.state(solver_policy, nnx.LoRAParam),
            trained_lora_params,
        ),
    )

    print("‚úÖ Solver LoRA successfully loaded into solver_policy.")


# 1. Pull the newest solver LoRA checkpoint from SOLVER_ACTOR_CKPT_DIR into solver_policy.
load_latest_solver_lora_from_actor_ckpts(
    solver_policy,
    SOLVER_ACTOR_CKPT_DIR,
)

# 2. Re-evaluate with the same evaluate() loop
metrics = evaluate(test_dataset, solver_sampler)
print(metrics)


FileNotFoundError: No checkpoints found in /content/drive/MyDrive/tunix_ckpts_modeB/solver/actor