# Two Step GRPO Trained Model Interactive Notebook

# Environment & Setup

In [1]:
import os
os.environ["HF_HUB_DISABLE_XET"] = "1"

import jax
print("JAX backend:", jax.default_backend())
print("JAX devices:", jax.devices())

# --- Colab installs (leave as-is) ---
!pip install -q kagglehub
!pip install -q ipywidgets

!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain
!pip install "google-tunix[prod]==0.1.3"

!pip uninstall -q -y flax
!pip install flax==0.12.0

!pip install -q datasets



JAX backend: tpu
JAX devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]
Collecting flax==0.12.0
  Using cached flax-0.12.0-py3-none-any.whl.metadata (11 kB)
Using cached flax-0.12.0-py3-none-any.whl (466 kB)
Installing collected packages: flax
Successfully installed flax-0.12.0


# Imports & Constants

In [2]:
import gc
import re
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np

from flax import nnx
from orbax import checkpoint as ocp
from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions
from orbax.checkpoint.args import StandardRestore

from tunix.generate import sampler as sampler_lib
from tunix.models.gemma3 import params, model

import qwix
import tensorflow_datasets as tfds
from datasets import load_dataset
from tqdm import tqdm


In [3]:
# --------------------
#   Global constants
# --------------------

TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR  = "./data/test"
TRAIN_FRACTION = 1.0

RANK  = 64
ALPHA = 64.0

MESH = [(1, 1), ("fsdp", "tp")]

MAX_PROMPT_LENGTH = 256

TEMPERATURE = 0.9
TOP_P       = 1.0
TOP_K       = 50

TRAIN_MICRO_BATCH_SIZE = 2
NUM_BATCHES      = 1000
NUM_TEST_BATCHES = 100
NUM_EPOCHS       = 1

# These are only used for deriving max steps in training; here, mainly kept
# to keep directory structure consistent with RL runs.
NUM_ITERATIONS = 1
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# --------------------
#   Checkpoint paths
# --------------------

from google.colab import drive
drive.mount('/content/drive')

CKPT_ROOT = "/content/drive/MyDrive/tunix_ckpts_modeB"

PLANNER_CKPT_ROOT = f"{CKPT_ROOT}/planner"
SOLVER_CKPT_ROOT  = f"{CKPT_ROOT}/solver"

INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt"

PLANNER_ACTOR_CKPT_DIR = os.path.join(PLANNER_CKPT_ROOT, "actor")
SOLVER_ACTOR_CKPT_DIR  = os.path.join(SOLVER_CKPT_ROOT, "actor")

for d in [INTERMEDIATE_CKPT_DIR, PLANNER_CKPT_ROOT, SOLVER_CKPT_ROOT,
          PLANNER_ACTOR_CKPT_DIR, SOLVER_ACTOR_CKPT_DIR]:
    os.makedirs(d, exist_ok=True)

SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# --------------------
#   Prompt constants
# --------------------

PLAN_START = "<plan>"
PLAN_END   = "</plan>"

reasoning_start = "<reasoning>"
reasoning_end   = "</reasoning>"
solution_start  = "<answer>"
solution_end    = "</answer>"

SYSTEM_PROMPT = (
    "Follow the plan. Show each step of following the plan between "
    "<reasoning> and </reasoning>. Then output the final number between "
    "<answer> and </answer>."
)

PLANNER_TEMPLATE = f"""
<start_of_turn>user
You are a planning assistant. Produce a short numbered plan (3–5 steps)
for solving the problem. Do NOT solve the problem.

Problem:
{{question}}
<end_of_turn>

<start_of_turn>planner
{PLAN_START}
"""

# IMPORTANT: This matches the *actual* training usage:
# SOLVER_TEMPLATE.format(question=q, plan=p)
SOLVER_TEMPLATE = f"""
You are a mathematical reasoning agent.

Your task:
1. Follow the provided plan EXACTLY.
2. Write detailed reasoning in the <reasoning>...</reasoning> block.
3. Place ONLY the final numeric answer in <answer>...</answer>.
4. After </solution>, STOP immediately.

Response format:
<solution>
<reasoning>
[step-by-step reasoning following the plan; do NOT skip steps]
</reasoning>
<answer>
[FINAL NUMERIC ANSWER ONLY]
</answer>
</solution>

Problem:
{{question}}

Plan:
{{plan}}

Begin.
"""

# Base Gemma Setup

In [5]:
MODEL_CP_PATH = params.GEMMA3_1B_IT
config        = model.ModelConfig.gemma3_1b()
tokenizer     = params.create_tokenizer()

# Save a one-time intermediate checkpoint for Gemma base weights
gemma = params.create_model_from_checkpoint(MODEL_CP_PATH, config)
checkpointer = ocp.StandardCheckpointer()
_, state = nnx.split(gemma)

ckpt_manager = CheckpointManager(
    INTERMEDIATE_CKPT_DIR,
    checkpointers=checkpointer,
    options=CheckpointManagerOptions(save_interval_steps=1, max_to_keep=1),
)
ckpt_manager.save(0, state)
ckpt_manager.wait_until_finished()

del gemma, state, params
gc.collect()



0

# Model Helpers

In [6]:
from tunix.models.gemma3 import params, model

def enforce_stop_strings(text, stops=("</answer>",)):
    """Cut off generation at the first stop string (inclusive)."""
    for s in stops:
        if s in text:
            return text.split(s)[0] + s
    return text

def get_gemma_ref_model(ckpt_root):
    """
    Loads the *base Gemma model weights* (no LoRA).
    Returns: (restored_model, mesh)
    """
    mesh = jax.make_mesh(*MESH)

    # 1. Build abstract model
    abs_gemma = nnx.eval_shape(
        lambda: params.create_model_from_checkpoint(MODEL_CP_PATH, config)
    )

    # 2. Abstract state with sharding info
    abs_state = nnx.state(abs_gemma)
    abs_state = jax.tree.map(
        lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
        abs_state,
        nnx.get_named_sharding(abs_state, mesh),
    )

    # 3. Restore base checkpoint
    ckpt_manager = CheckpointManager(
        ckpt_root,
        checkpointers=ocp.StandardCheckpointer(),
        options=CheckpointManagerOptions(save_interval_steps=1, max_to_keep=1),
    )

    latest_step = ckpt_manager.latest_step()
    if latest_step is None:
        raise FileNotFoundError(f"No checkpoints found in {ckpt_root}")

    restored_state = ckpt_manager.restore(
        latest_step,
        args=StandardRestore(abs_state),
    )

    # 4. Merge abstract graph + restored state
    graph_def, _ = nnx.split(abs_gemma)
    gemma = nnx.merge(graph_def, restored_state)

    print(f"✅ Loaded base Gemma model from step {latest_step}")
    return gemma, mesh


def get_lora_model(base_model, mesh):
    """
    Create a fresh LoRA-adapted policy from the base Gemma weights.
    """
    lora_provider = qwix.LoraProvider(
        module_path=(
            ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|"
            ".*up_proj|.*attn_vec_einsum"
        ),
        rank=RANK,
        alpha=ALPHA,
    )

    lora_model = qwix.apply_lora_to_model(
        base_model,
        lora_provider,
        **base_model.get_model_input(),
    )

    # Re-shard after LoRA insertion
    with mesh:
        state   = nnx.state(lora_model)
        pspecs  = nnx.get_partition_spec(state)
        sharded = jax.lax.with_sharding_constraint(state, pspecs)
        nnx.update(lora_model, sharded)

    return lora_model


def load_lora_exact(policy, ckpt_dir, label=""):
    """
    Load LoRA-only weights from a Tunix RL actor checkpoint directory into `policy`.
    """
    print(f"\n=== Loading {label} LoRA ===")

    steps = sorted([int(x) for x in os.listdir(ckpt_dir) if x.isdigit()])
    if not steps:
        raise ValueError(f"No valid checkpoint folders in {ckpt_dir}")
    latest = steps[-1]

    ckpt_path = os.path.join(ckpt_dir, str(latest), "model_params")
    print(f"→ Using checkpoint: {ckpt_path}")

    # Extract the LoRA tree
    lora_tree = nnx.state(policy, nnx.LoRAParam)

    # Abstract target
    abs_target = jax.tree.map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
        lora_tree,
    )

    checkpointer = ocp.StandardCheckpointer()
    restored = checkpointer.restore(ckpt_path, target=abs_target)

    # Inject into policy
    nnx.update(
        policy,
        jax.tree.map(lambda old, new: new, lora_tree, restored),
    )

    print(f"✓ {label} LoRA loaded successfully.")


# Optional debug helpers
def lora_l2(model):
    lora_state = nnx.state(model, nnx.LoRAParam)
    leaves = jax.tree.leaves(lora_state)
    return sum(float(jnp.linalg.norm(x)) for x in leaves)

def lora_diff(model_a, model_b):
    a = nnx.state(model_a, nnx.LoRAParam)
    b = nnx.state(model_b, nnx.LoRAParam)
    diffs = jax.tree.map(lambda x, y: float(jnp.linalg.norm(x - y)), a, b)
    leaves = jax.tree.leaves(diffs)
    return sum(leaves)

# =========================
#   4. Load Policies & Samplers
# =========================

def load_policies_and_samplers():
    """
    Returns:
      baseline_sampler: baseline LoRA-injected model (no RL LoRA)
      planner_sampler:  planner with TRAINED LoRA
      solver_sampler:   solver with TRAINED LoRA
    """

    # 1) Load base Gemma
    ref_model, mesh = get_gemma_ref_model(INTERMEDIATE_CKPT_DIR)

    # 2) Inject LoRA heads
    baseline_policy = get_lora_model(ref_model, mesh)
    planner_policy  = get_lora_model(ref_model, mesh)
    solver_policy   = get_lora_model(ref_model, mesh)

    # 3) Load TRAINED LoRA weights into planner/solver
    load_lora_exact(planner_policy, PLANNER_ACTOR_CKPT_DIR, "planner")
    load_lora_exact(solver_policy,  SOLVER_ACTOR_CKPT_DIR,  "solver")

    print("baseline L2:", lora_l2(baseline_policy))
    print("planner  L2:", lora_l2(planner_policy))
    print("solver   L2:", lora_l2(solver_policy))
    print("Δ(planner-baseline):", lora_diff(planner_policy, baseline_policy))
    print("Δ(solver-baseline): ", lora_diff(solver_policy, baseline_policy))
    print("Δ(solver-planner):  ", lora_diff(solver_policy, planner_policy))

    # 4) Build samplers AFTER LoRA is loaded.
    #
    # Planner: short generations, small-ish cache.
    baseline_sampler = sampler_lib.Sampler(
        transformer=baseline_policy,
        tokenizer=tokenizer,
        cache_config=sampler_lib.CacheConfig(
            cache_size=MAX_PROMPT_LENGTH + 256,  # prompt + ~128-200 tokens
            num_layers=config.num_layers,
            num_kv_heads=config.num_kv_heads,
            head_dim=config.head_dim,
        ),
    )

    planner_sampler = sampler_lib.Sampler(
        transformer=planner_policy,
        tokenizer=tokenizer,
        cache_config=sampler_lib.CacheConfig(
            cache_size=MAX_PROMPT_LENGTH + 256,  # same as baseline
            num_layers=config.num_layers,
            num_kv_heads=config.num_kv_heads,
            head_dim=config.head_dim,
        ),
    )

    # Solver: longer generations; make sure cache_size >= MAX_PROMPT_LENGTH + max_gen
    solver_max_gen = 768
    solver_sampler = sampler_lib.Sampler(
        transformer=solver_policy,
        tokenizer=tokenizer,
        cache_config=sampler_lib.CacheConfig(
            cache_size=MAX_PROMPT_LENGTH + solver_max_gen + 256,
            num_layers=config.num_layers,
            num_kv_heads=config.num_kv_heads,
            head_dim=config.head_dim,
        ),
    )

    return baseline_sampler, planner_sampler, solver_sampler

baseline_sampler, planner_sampler, solver_sampler = load_policies_and_samplers()

  mesh = jax.make_mesh(*MESH)


✅ Loaded base Gemma model from step 0

=== Loading planner LoRA ===
→ Using checkpoint: /content/drive/MyDrive/tunix_ckpts_modeB/planner/actor/1000/model_params




✓ planner LoRA loaded successfully.

=== Loading solver LoRA ===
→ Using checkpoint: /content/drive/MyDrive/tunix_ckpts_modeB/solver/actor/1000/model_params
✓ solver LoRA loaded successfully.
baseline L2: 1766.375
planner  L2: 1768.9099426269531
solver   L2: 1769.4983825683594
Δ(planner-baseline): 3.2076034545898438
Δ(solver-baseline):  3.5033798217773438
Δ(solver-planner):   4.628963470458984


# Dataset Loader (GSM8K)

In [7]:
def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_dataset(data_dir, split="train"):
    os.makedirs(data_dir, exist_ok=True)
    data = load_dataset("gsm8k", "main", split=split)

    def _as_text(v):
        return v if isinstance(v, str) else v.decode("utf-8")

    import grain
    dataset = (
        grain.MapDataset.source(data)
        .shuffle(seed=42)
        .map(
            lambda x: {
                "question": _as_text(x["question"]),
                "answer": extract_hash_answer(_as_text(x["answer"])),
            }
        )
    )
    return dataset

In [8]:
train_dataset = (
    get_dataset(TRAIN_DATA_DIR, "train")
    .batch(TRAIN_MICRO_BATCH_SIZE)[:NUM_BATCHES]
    .repeat(NUM_EPOCHS)
)

test_dataset = (
    get_dataset(TEST_DATA_DIR, "test")
    .batch(TRAIN_MICRO_BATCH_SIZE)[:NUM_TEST_BATCHES]
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


# Planner & Solver Wrappers

In [9]:
def generate_plan_with_sampler(questions, planner_sampler):
    """
    Planner stage:
    Uses PLANNER_TEMPLATE and the given planner_sampler to generate <plan> text.
    """
    inputs = [PLANNER_TEMPLATE.format(question=q) for q in questions]

    out = planner_sampler(
        input_strings=inputs,
        max_generation_steps=128,     # plans are short
        temperature=0.3,              # more deterministic / obedient
        top_k=50,
        top_p=0.95,
        echo=False,
    )

    plans = []
    for txt in out.text:
        # Trim at PLAN_END if present
        if PLAN_END in txt:
            txt = txt.split(PLAN_END)[0] + PLAN_END

        m = re.search(r"<plan>(.*?)</plan>", txt, re.DOTALL)
        clean = m.group(1).strip() if m else txt.strip()

        # Strip tags & trailing junk
        clean = re.sub(r"</?plan>", "", clean)
        clean = clean.split("<end_of_turn>")[0]
        plans.append(clean)

    return plans


def generate_solution_with_sampler(
    questions,
    plans,
    solver_sampler,
    temperature=0.2,
    top_k=TOP_K,
    top_p=TOP_P,
    max_generation_steps=768,
):
    """
    Solver stage:
    Builds SOLVER_TEMPLATE(question, plan) and generates full CoT + answer.
    """
    solver_inputs = [
        SOLVER_TEMPLATE.format(question=q, plan=p)
        for q, p in zip(questions, plans)
    ]

    out = solver_sampler(
        input_strings=solver_inputs,
        max_generation_steps=max_generation_steps,
        max_prompt_length=MAX_PROMPT_LENGTH,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        echo=False,
    )

    outputs = [enforce_stop_strings(t) for t in out.text]
    return outputs

# Evaluation Helpers

In [10]:
# Regex for extracting numeric answer from within <answer>...</answer>
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})",
    flags=re.MULTILINE | re.DOTALL,
)

# Regex for format correctness: <reasoning>...</reasoning> then <answer>...</answer>
match_format = re.compile(
    r"<reasoning>[\s\S]*?</reasoning>\s*<answer>[\s\S]*?</answer>",
    re.MULTILINE,
)

def evaluate_pipeline(
    dataset,
    planner_sampler,
    solver_sampler,
    temperature=0.2,
    top_k=50,
    top_p=0.9,
    num_passes=1,
    record_incorrect=False,
    record_correct=False,
):
    """
    Full 2-stage pipeline evaluation:
      planner_sampler → generates plan
      solver_sampler  → generates reasoning + answer from that plan

    Returns a metrics dict and an optional list of collected examples.
    """

    total           = 0
    exact_correct   = 0
    partial_correct = 0
    format_correct  = 0

    collected = []

    for batch in tqdm(dataset):
        questions = batch["question"]
        answers   = batch["answer"]

        multi_outputs = [[] for _ in range(len(questions))]

        for _ in range(num_passes):
            # 1) Planner stage
            plans = generate_plan_with_sampler(questions, planner_sampler)

            # 2) Solver stage
            outputs = generate_solution_with_sampler(
                questions,
                plans,
                solver_sampler,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )

            for i, out in enumerate(outputs):
                multi_outputs[i].append(out)

        # 3) Evaluate each question
        for q, gold, responses in zip(questions, answers, multi_outputs):
            total += 1
            gold = gold.strip()

            is_exact  = False
            is_partial = False
            is_format  = False

            for resp in responses:
                # Extract numeric answer
                m = match_numbers.search(resp)
                pred_str = m.group(1).strip() if m else None

                try:
                    pred_val = float(pred_str) if pred_str is not None else None
                    gold_val = float(gold)
                except Exception:
                    pred_val = None

                # Exact correctness
                if pred_val is not None and pred_val == gold_val:
                    is_exact = True

                # Partial correctness (±10%)
                if pred_val is not None:
                    ratio = pred_val / gold_val
                    if 0.9 <= ratio <= 1.1:
                        is_partial = True

                # Format correctness: <reasoning>...</reasoning><answer>...</answer>
                if match_format.search(resp):
                    is_format = True

                if is_exact and is_partial and is_format:
                    break

            exact_correct   += int(is_exact)
            partial_correct += int(is_partial)
            format_correct  += int(is_format)

            if record_correct and is_exact:
                collected.append((q, gold, responses))
            if record_incorrect and not is_exact:
                collected.append((q, gold, responses))

    metrics = {
        "total": total,
        "exact_correct":   exact_correct,
        "partial_correct": partial_correct,
        "format_correct":  format_correct,
        "exact_accuracy":   100.0 * exact_correct   / total,
        "partial_accuracy": 100.0 * partial_correct / total,
        "format_accuracy":  100.0 * format_correct  / total,
    }

    return metrics, collected

# Trained Evaluations

In [11]:
# Baseline pipeline: planner + solver both baseline
# baseline_metrics, baseline_collected = evaluate_pipeline(
#     test_dataset,
#     planner_sampler=baseline_sampler,
#     solver_sampler=baseline_sampler,
#     num_passes=1,
# )

# print("BASELINE METRICS:")
# for k, v in baseline_metrics.items():
#     print(f"  {k}: {v}")

# Trained pipeline: planner + solver both trained
trained_metrics, trained_collected = evaluate_pipeline(
    test_dataset,
    planner_sampler=planner_sampler,
    solver_sampler=solver_sampler,
    num_passes=1,
)

print("\nTRAINED METRICS:")
for k, v in trained_metrics.items():
    print(f"  {k}: {v}")


100%|██████████| 100/100 [06:06<00:00,  3.67s/it]


TRAINED METRICS:
  total: 200
  exact_correct: 77
  partial_correct: 84
  format_correct: 95
  exact_accuracy: 38.5
  partial_accuracy: 42.0
  format_accuracy: 47.5





# Optional: Trace One Example

In [12]:
def trace_pipeline(question, true_answer, planner_sampler, solver_sampler):
    print("\n--- Tracing Pipeline ---")
    print(f"Question: {question}")

    planner_input = PLANNER_TEMPLATE.format(question=question)
    print(f"\nPlanner Input:\n{planner_input}")

    plans = generate_plan_with_sampler([question], planner_sampler)
    plan  = plans[0]

    print(f"\nGenerated Plan:\n{PLAN_START}\n{plan}\n{PLAN_END}")

    outputs = generate_solution_with_sampler([question], [plan], solver_sampler)
    solver_output = outputs[0]

    print(f"\nSolver Output:\n{solver_output}")

    # crude check-answer view
    m = match_numbers.search(solver_output)
    pred_str = m.group(1).strip() if m else None
    print(f"\nExtracted Predicted Answer: {pred_str}")
    print(f"True Answer: {true_answer}")
    print("--- End Tracing Pipeline ---")

# Example: trace the first sample from test set with trained pipeline
first_sample   = next(iter(test_dataset.shuffle(seed=65)))
first_question = first_sample["question"][0]
first_answer   = first_sample["answer"][0]

trace_pipeline(first_question, first_answer, planner_sampler, solver_sampler)


--- Tracing Pipeline ---
Question: Bill is ordering a new truck. He has decided to purchase a two-ton truck with several added features: a king cab upgrade, a towing package, leather seats, running boards, and the upgraded exterior light package.  The base price of the truck is $30,000, and the other features are at extra cost. The king cab is an extra $7,500, leather seats are one-third the cost of the king cab upgrade, running boards are $500 less than the leather seats, and the upgraded exterior light package is $1500.  What is the total cost of Bill's new truck, in dollars?

Planner Input:

<start_of_turn>user
You are a planning assistant. Produce a short numbered plan (3–5 steps)
for solving the problem. Do NOT solve the problem.

Problem:
Bill is ordering a new truck. He has decided to purchase a two-ton truck with several added features: a king cab upgrade, a towing package, leather seats, running boards, and the upgraded exterior light package.  The base price of the truck is 