
# Dual-Stream Gemma with DSA-GRPO (Dual-Stream Aware GRPO)

This notebook fine-tunes **Gemma 3 1B** on a single Kaggle TPU using **Tunix**, but instead of plain GRPO it uses an **extension I call DSA-GRPO (Dual-Stream Aware GRPO)**.

The goal is not just to get good answers, but to train a model that:

- **Thinks in a structured, machine-readable way** (the *Monologue Stream*), and  
- **Stays coherent with its own thinking** when it produces the final answer (the *Answer Stream*).

All model outputs follow the competition format:

```text
<reasoning>
[plan] ...
[evidence] ...
[sanity_check] ...
[risk_scan] ...
[final_internal_action] ...
</reasoning>
<answer> final_answer_here </answer>
```

On the competition side, this maps directly to:

* `model_thinking_trace`  ← the contents of `<reasoning>...</reasoning>`
* `model_answer`          ← the contents of `<answer>...</answer>`
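
As a rough illustration of that mapping, here is a minimal sketch that splits one completion into the two fields. The helper name `split_streams` is hypothetical (it is not the notebook's own `parse_dual_stream`), and returning empty strings when a tag pair is missing is an assumption made for this example.

```python
import re

REASONING_RE = re.compile(r"<reasoning>(.*?)</reasoning>", re.DOTALL)
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def split_streams(completion: str) -> tuple[str, str]:
  """Return (model_thinking_trace, model_answer); "" if a tag pair is missing."""
  reasoning = REASONING_RE.search(completion)
  answer = ANSWER_RE.search(completion)
  trace = reasoning.group(1).strip() if reasoning else ""
  final = answer.group(1).strip() if answer else ""
  return trace, final


example = (
    "<reasoning>\n[plan] add the two prices\n[evidence] 3 + 4 = 7\n</reasoning>\n"
    "<answer> 7 </answer>"
)
trace, final = split_streams(example)
print(final)  # 7
```
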

---

## What is DSA-GRPO?

Standard **GRPO** optimizes a policy using a group of sampled completions per prompt and a set of reward functions (format, correctness, etc.) that are usually combined in a fairly simple way.
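
The "group" part of GRPO can be sketched as follows: each completion's scalar reward is normalized against the other completions sampled for the same prompt. This is a minimal sketch of the general technique, not Tunix's code; the use of the population standard deviation and the value of `eps` are assumptions.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
  """Normalize each reward against the mean and std of its own group."""
  mu = mean(rewards)
  sigma = pstdev(rewards)
  return [(r - mu) / (sigma + eps) for r in rewards]


# Four sampled completions for one prompt (G = 4), with made-up rewards.
advantages = group_relative_advantages([3.0, 1.0, 1.0, -1.0])
print(advantages)
```

Completions that beat their group's average get a positive advantage and the rest get a negative one, so only the relative ranking within the group drives the update.
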

**DSA-GRPO** keeps the same *infrastructure* (Tunix’s `GRPOLearner`, TPUs, Gemma3 + LoRA), but changes the **reward shaping and aggregation** so that it explicitly understands the **Dual-Stream Architecture (DSA)**:

* **Monologue Stream**: everything inside `<reasoning>...</reasoning>`.
* **Answer Stream**: the final answer inside `<answer>...</answer>`.

Instead of treating each reward independently, DSA-GRPO uses a **composite reward function** `dsa_grpo_reward` that aggregates several components:

* **Format rewards**

  * `match_format_exactly` – strict tag/layout compliance.
  * `match_format_approximately` – softer structural check.

* **Answer correctness rewards**

  * `check_answer` – string-level match against the target solution.
  * `check_numbers` – numeric correctness / sanity.

* **Dual-Stream / DSA rewards**

  * `dsa_monologue_structure` – rewards structured internal monologues with sections like `[plan]`, `[evidence]`, `[sanity_check]`, `[risk_scan]`, `[final_internal_action]`, plus multi-step reasoning.
  * `dsa_stream_coherence` – rewards agreement between the monologue’s final numeric conclusion and the `<answer>` value, and penalizes contradictions.

These signals are then combined inside **one DSA-aware scalar reward**:

* Each component is weighted (e.g. correctness gets more weight than formatting, coherence has its own weight, etc.).
* If the **format is broken** (no proper `<reasoning>` / `<answer>`), the total reward is heavily down-weighted.
* If **coherence is negative** (monologue and answer disagree), the combined reward is penalized and shrunk — discouraging “pretty explanations” that don’t match what the model actually answers.

In other words, DSA-GRPO doesn’t just say “do many good things at once”; it encodes **how those things should relate** in a Dual-Stream setting:

> A response is only truly good if the *format is valid*, the *answer is correct*, the *monologue is structured*, **and** the *monologue and answer agree*.
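
To make the gating idea concrete, here is a minimal sketch of a weighted sum followed by format and coherence gates. The weights mirror the `DSA_REWARD_WEIGHTS` table defined later in the notebook, but the function name `combine`, the gating factors, and the exact order of operations are assumptions for illustration, not the notebook's `dsa_grpo_reward`.

```python
WEIGHTS = {
    "format_exact": 0.5,
    "format_approx": 0.25,
    "answer_exact": 1.5,
    "answer_numeric": 1.0,
    "monologue_structure": 0.5,
    "stream_coherence": 1.0,
}

# Illustrative gating factors (assumptions, not the notebook's values).
BROKEN_FORMAT_SCALE = 0.25
INCOHERENT_SCALE = 0.5
INCOHERENT_PENALTY = 1.0


def shrink(value: float, scale: float) -> float:
  """Scale down positive rewards only, so a gate never softens a penalty."""
  return value * scale if value > 0 else value


def combine(components: dict[str, float]) -> float:
  """Weighted sum of component rewards, then format and coherence gates."""
  total = sum(WEIGHTS[name] * value for name, value in components.items())
  if components["format_exact"] <= 0.0:
    total = shrink(total, BROKEN_FORMAT_SCALE)
  if components["stream_coherence"] < 0.0:
    total = shrink(total, INCOHERENT_SCALE) - INCOHERENT_PENALTY
  return total


good = {
    "format_exact": 3.0, "format_approx": 2.0, "answer_exact": 3.0,
    "answer_numeric": 1.5, "monologue_structure": 1.75, "stream_coherence": 1.0,
}
incoherent = dict(good, stream_coherence=-1.0)
print(combine(good), combine(incoherent))
```

With these made-up factors, a response that is correct and well formatted but whose monologue contradicts its answer keeps well under half of the reward of the coherent one.
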

---

## What this notebook actually does

* **Model & framework**

  * Loads **Gemma 3 1B (Flax)** via **Tunix** on a Kaggle TPU.
  * Attaches a **LoRA policy head** (via Qwix) so we can fine-tune with RL while keeping the base weights mostly frozen.

* **Dual-Stream prompting**

  * Uses a DSA-aware `SYSTEM_PROMPT` that instructs the model to separate:

    * `<reasoning>...</reasoning>` (Monologue Stream) and
    * `<answer>...</answer>` (Answer Stream),
      with explicit structure in the monologue.

* **DSA-GRPO training loop**

  * Instantiates an `RLCluster` with:

    * `actor`   = LoRA-adapted Gemma3,
    * `reference` = frozen Gemma3,
    * `tokenizer` = Gemma tokenizer,
    * `cluster_config` tuned for a single Kaggle TPU session.
  * Wraps this in `GRPOLearner`, but sets:

    ```python
    reward_fns = [dsa_grpo_reward]
    ```

    so **all** learning is driven by the DSA-aware composite reward.
  * Trains on a reasoning dataset (e.g. GSM8K or similar) within a 9-hour Kaggle TPU budget.

* **Generation & submission**

  * Provides:

    * `parse_dual_stream(completion)` → `(monologue, answer)`
    * `generate_dual_stream(question, sampler, ...)` → `(monologue, answer)`
  * These are wired to produce the exact competition output format:

    ```text
    <reasoning>model_thinking_trace</reasoning>
    <answer>model_answer</answer>
    ```
  * A small demo cell shows how to generate and print:

    * `model_thinking_trace`
    * `model_answer`
      for a sample question.

* **Checkpoints & reproducibility**

  * Uses Tunix’s built-in checkpointing, so the **fine-tuned LoRA checkpoint**:

    * Is produced inside the same notebook / single TPU session.
    * Can be reloaded via the **Gemma2/3 modeling code in Tunix** on Kaggle.

---

## Why this is different from plain GRPO

Traditional GRPO on this competition might:

* Reward format,
* Reward correctness,
* Maybe reward “length” or “presence of chain-of-thought”.

But it typically doesn’t care whether the model’s *inner story* and *outer answer* actually line up.

**DSA-GRPO adds two critical properties:**

1. **Monologue structure as a first-class objective**
   The model is explicitly rewarded for producing monologues that look like real internal deliberation rather than a single line of algebra.

2. **Coherence as a first-class objective**
   The model is punished when its Monologue Stream and Answer Stream disagree. This pushes it away from post-hoc rationalizations and towards **actual alignment between thought and answer**.

That makes this notebook not just “Gemma with GRPO that shows its work”, but a **Dual-Stream Gemma** trained to keep its internal monologue and final answer in sync. That is exactly what the competition is trying to probe: models that don’t just talk about their reasoning, but **are trained under a signal that requires that reasoning to be coherent.**


*All content is based on the whitepaper "The Inner Monologue: A Dual-Stream Architecture for Verifiable Inner Alignment" by Daniel Wycoff.*

https://docs.google.com/document/d/1np-I9zEKArodlDhQzfydhloCXIVK9O72g3OJSuo_-Wk/edit?usp=sharing

*contact: daniel[dot]w[at]eorumyoung[dot]com*

In [None]:
# W&B compatibility patch for Tunix on Kaggle.
# Some Tunix versions expect wandb.util._has_internet, which is missing
# in newer wandb releases. This patch adds a no-op implementation so
# wandb.init() does not crash inside Tunix metrics_logger.
#!pip install wandb
#try:
#  import wandb
 # if not hasattr(wandb.util, "_has_internet"):
 #   def _has_internet():
      # Treat environment as offline; prevents crashes in wandb login.
 #     return False
 #   wandb.util._has_internet = _has_internet
 #   print("Patched wandb.util._has_internet for Tunix.")
#except Exception as e:
# print("W&B compatibility patch failed:", e)


## Install necessary libraries

In [None]:
!pip install -q kagglehub

!pip install -q ipywidgets

!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain
!pip install "google-tunix[prod]==0.1.3"

# !pip install -q git+https://github.com/google/tunix
# !pip install -q git+https://github.com/google/qwix

# !pip uninstall -q -y flax  # disabled: Tunix pins flax/nnx
# !pip install -q git+https://github.com/google/flax.git
# !pip install -U flax  # disabled: Tunix pins flax/nnx


!pip install -q datasets



## Imports

In [None]:
import functools
import gc
import os
from pprint import pprint
import re

import csv
import shutil

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
from orbax import checkpoint as ocp
import optax
from pathlib import Path
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
# from tunix.models.gemma3 import model as gemma_lib
# from tunix.models.gemma3 import params as params_lib
from tunix.models.gemma3 import params
from tunix.models.gemma3 import model
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from datasets import load_dataset



In [None]:
# Compatibility patch: make flax.nnx.Variable.set_metadata accept
# the legacy (name, value) calling pattern used by QWIX LoRA provider.
from flax import nnx

if getattr(getattr(nnx, "Variable", None), "_dsagrpo_patched", False) is False:
  _orig_set_metadata = nnx.Variable.set_metadata

  def _dsagrpo_set_metadata(self, *args, **kwargs):
    # Legacy patterns:
    #   set_metadata({'sharding_names': axes})
    #   set_metadata('sharding_names', axes)
    if args and not kwargs:
      if len(args) == 1 and isinstance(args[0], dict):
        return _orig_set_metadata(self, **args[0])
      if len(args) == 2 and isinstance(args[0], str):
        return _orig_set_metadata(self, **{args[0]: args[1]})
    # Modern usage: direct kwargs (e.g. set_metadata(sharding_names=axes)).
    # Any other positional pattern is forwarded unchanged instead of dropped.
    return _orig_set_metadata(self, *args, **kwargs)

  nnx.Variable.set_metadata = _dsagrpo_set_metadata
  nnx.Variable._dsagrpo_patched = True
  print("Patched nnx.Variable.set_metadata for QWIX compatibility.")


Patched nnx.Variable.set_metadata for QWIX compatibility.


## Hyperparameters

Let's define the configuration we are going to use. Note that this is by no
means a "perfect" set of hyperparameters. To get good results, you might have
to train the model for longer.

In [None]:
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

# ====== LoRA ======
RANK = 64
ALPHA = 64.0
# If you still hit OOM, you can try:
# RANK = 32

# ====== Sharding ======
MESH = [(1, 4), ("fsdp", "tp")]

# ====== GRPO ======
# === Generation during GRPO training ===
# These control how many tokens are seen per GRPO rollout during training
# (not the sampler cache, which we'll handle separately).
MAX_PROMPT_LENGTH = 256          # was 256
TOTAL_GENERATION_STEPS = 512     # was 512

# Important to keep a high-ish temperature for varied, diverse responses during
# training.
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50

# The number of times the policy generates multiple responses for a given prompt
# within a single training step (G in GRPO).
NUM_GENERATIONS = 4              # was 4

# === other GRPO configs ===
NUM_ITERATIONS = 1
BETA = 0.08
EPSILON = 0.2

# ====== Training ======
# Per-device micro-batch size (this is the main memory lever).
TRAIN_MICRO_BATCH_SIZE = 4       # was 4

# Increase NUM_BATCHES / MAX_STEPS for better results if you have time.
NUM_BATCHES = 3738
NUM_TEST_BATCHES = 100

EVAL_EVERY_N_STEPS = 10
NUM_EPOCHS = 1

MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1

WARMUP_STEPS = 0.1 * MAX_STEPS

# == Grad clipping ==
MAX_GRAD_NORM = 0.1

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/ckpts/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# ====== Rollout length for GRPO training ======
# The sampler cache_size is 704. We choose a safe split:
#   256 tokens for prompt + 256 tokens for generation = 512 < 704
ROLLOUT_MAX_PROMPT_LENGTH = 256
ROLLOUT_GENERATION_STEPS = 256

# ====== Evaluation sampling limits ======
# For eval (and manual generate()), we keep the same safe envelope.
EVAL_MAX_PROMPT_LENGTH = 256
EVAL_GENERATION_STEPS = 256

# ====== Inference ======
GENERATION_CONFIGS = {
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}
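
A few of these values depend on each other, so it can help to check the arithmetic before starting a long run. The sketch below recomputes the derived step counts and tests both prompt/generation splits against the cache size of 704 mentioned in the comments above. Treating that number as a hard limit on prompt plus generated tokens is an assumption, and `fits_in_cache` is a hypothetical helper.

```python
# Values copied from the hyperparameter cell above.
NUM_BATCHES, NUM_ITERATIONS, TRAIN_FRACTION, NUM_EPOCHS = 3738, 1, 1.0, 1
MAX_PROMPT_LENGTH, TOTAL_GENERATION_STEPS = 256, 512
ROLLOUT_MAX_PROMPT_LENGTH, ROLLOUT_GENERATION_STEPS = 256, 256
CACHE_SIZE = 704  # from the comment above; assumed to bound prompt + generation

max_steps = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)
warmup_steps = 0.1 * max_steps  # roughly 10% of training


def fits_in_cache(prompt_len: int, gen_steps: int, cache_size: int = CACHE_SIZE) -> bool:
  """True if prompt tokens plus generated tokens fit in the cache."""
  return prompt_len + gen_steps <= cache_size


rollout_ok = fits_in_cache(ROLLOUT_MAX_PROMPT_LENGTH, ROLLOUT_GENERATION_STEPS)
grpo_ok = fits_in_cache(MAX_PROMPT_LENGTH, TOTAL_GENERATION_STEPS)
print(max_steps, round(warmup_steps), rollout_ok, grpo_ok)
```

Under that assumption the 256 + 256 rollout split fits, while 256 + 512 does not, so it is worth confirming which pair of constants is actually passed to the sampler.
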


## Utility functions

In [None]:
def show_hbm_usage():
  """Displays memory usage per device."""
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    used = stats["bytes_in_use"]
    limit = stats["bytes_limit"]
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

## Data preprocessing

First, let's define some special tokens. We instruct the model to first reason
between the `<reasoning>` and `</reasoning>` tokens. After
reasoning, we expect it to provide the answer between the `<answer>` and
`</answer>` tokens.

In [None]:
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"


SYSTEM_PROMPT = f"""You are a Dual-Stream Architecture reasoning model.

You must produce two distinct output streams:

1. A Monologue Stream between {reasoning_start} and {reasoning_end}
   containing your full internal chain-of-thought: planning,
   evidence, sanity checks, and any risk scans. Structure it with
   markers like [plan], [evidence], [sanity_check], [risk_scan],
   and [final_internal_action].

2. An Answer Stream between {solution_start} and {solution_end}
   containing only the final numeric answer (no explanation).

The Monologue Stream and Answer Stream must be coherent with each
other and with the question.

Always follow exactly this format:

{reasoning_start}
[plan] ...
[evidence] ...
[sanity_check] ...
[risk_scan] ...
[final_internal_action] ...
{reasoning_end}
{solution_start} <single numeric answer> {solution_end}
"""

# Gemma chat format: the system prompt and the question share the user turn,
# and the model turn is left open so that the policy generates the completion.
TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model
"""


We use OpenAI's [GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k), which comprises grade school math word problems.

In [None]:
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


def _load_from_tfds(data_dir: str, split: str):
  import tensorflow_datasets.text.gsm8k
  return tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
      download=True,
  )


def download_kaggle_dataset(target_dir="./data/gsm8k"):
  os.makedirs(target_dir, exist_ok=True)
  src = kagglehub.dataset_download("thedevastator/grade-school-math-8k-q-a")
  src = Path(src)
  dst = Path(target_dir)

  for csv_file in src.glob("*.csv"):  # match all CSV files
    shutil.copy2(csv_file, dst / csv_file.name)
    print(f"Copied {csv_file.name} → {dst/csv_file.name}")
  return target_dir


def get_dataset(data_dir, split="train", source="tfds") -> grain.MapDataset:
  # Download data
  if not os.path.exists(data_dir):
    os.makedirs(data_dir)

  if source == "tfds":
    data = _load_from_tfds(data_dir, split)

  elif source == "kaggle":
    kaggle_dir = download_kaggle_dataset(data_dir)
    file_name = "main_" + split + ".csv"
    csv_path = os.path.join(kaggle_dir, file_name)  # adjust filename if needed

    data = []
    with open(csv_path, newline="", encoding="utf-8") as csvfile:
      reader = csv.DictReader(csvfile)
      for row in reader:
        data.append({
            "question": row["question"],
            "answer": row["answer"],
        })

  elif source == "huggingface":
    os.environ["HF_HUB_DISABLE_XET"] = "1"
    data = load_dataset("gsm8k", "main", split=split)

  else:
    raise ValueError(f"Unknown source: {source}")

  def _as_text(v):
    return v if isinstance(v, str) else v.decode("utf-8")

  dataset = (
      grain.MapDataset.source(data)
      .shuffle(seed=42)
      .map(
          lambda x: {
              # passed to model forward pass
              "prompts": TEMPLATE.format(
                  system_prompt=SYSTEM_PROMPT,
                  question=_as_text(x["question"]),
              ),
              # passed to reward functions
              "question": _as_text(x["question"]),
              # passed to reward functions
              "answer": extract_hash_answer(_as_text(x["answer"])),
          }
      )
  )
  return dataset

We split the dataset into train and test sets as usual.

In [None]:
# source = input("Choose data source [tfds/kaggle]: ").strip().lower()
source = "huggingface"

if source not in ("tfds", "kaggle", "huggingface"):
  print("Invalid choice. Defaulting to 'tfds'.")
  source = "tfds"

print(f"Using data source: {source}")

dataset = get_dataset(TRAIN_DATA_DIR, "train", source).batch(TRAIN_MICRO_BATCH_SIZE)[
    :NUM_BATCHES
]

if TRAIN_FRACTION == 1.0:
  train_dataset = dataset.repeat(NUM_EPOCHS)
  val_dataset = None
else:
  train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]
  train_dataset = train_dataset.repeat(NUM_EPOCHS)

  val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)

test_dataset = get_dataset(TEST_DATA_DIR, "test", source).batch(TRAIN_MICRO_BATCH_SIZE)[
    :NUM_TEST_BATCHES
]

dataset_lengths = (
    len(train_dataset),
    len(val_dataset) if val_dataset is not None else 0,
    len(test_dataset),
)
print(f"dataset contains {dataset_lengths} of batches")

Using data source: huggingface



dataset contains (3738, 0, 100) of batches


Let's see what one batch of the training dataset looks like!


In [None]:
for ele in train_dataset[:1]:
  pprint(ele)

{'answer': array(['3'], dtype='<U1'),
 'prompts': array(['<start_of_turn>user\nYou are a Dual-Stream Architecture reasoning model.\n\nYou must produce two distinct output streams:\n\n1. A Monologue Stream between <reasoning> and </reasoning>\n   containing your full internal chain-of-thought: planning,\n   evidence, sanity checks, and any risk scans. Structure it with\n   markers like [plan], [evidence], [sanity_check], [risk_scan],\n   and [final_internal_action].\n\n2. An Answer Stream between <answer> and </answer>\n   containing only the final numeric answer (no explanation).\n\nThe Monologue Stream and Answer Stream must be coherent with each\nother and with the question.\n\nAlways follow exactly this format:\n\n<reasoning>\n[plan] ...\n[evidence] ...\n[sanity_check] ...\n[risk_scan] ...\n[final_internal_action] ...\n</reasoning>\n<answer> <single numeric answer> </answer>\n\n\nMaria has 4 dimes, 4 quarters, and 7 nickels in her piggy bank. Her mom gives her 5 quarters. How much m

## Load the policy model and the reference model

The policy model is the model which is actually trained and whose weights are
updated. The reference model is the model with which we compute KL divergence.
This is to ensure that the policy updates are not huge and that it does not
deviate too much from the reference model.

Typically, the reference model is the base model, and the policy model is the
same base model, but with LoRA parameters. Only the LoRA parameters are updated.
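
For intuition, here is a minimal sketch of a per-token KL penalty between the policy and the reference model, using the non-negative estimator described in the GRPO paper. Whether Tunix computes the term in exactly this form is an assumption; `BETA` reuses the value from the hyperparameter cell.

```python
import math


def kl_penalty(logp_policy: float, logp_ref: float) -> float:
  """Per-token estimate exp(d) - d - 1 with d = logp_ref - logp_policy (>= 0)."""
  d = logp_ref - logp_policy
  return math.exp(d) - d - 1.0


BETA = 0.08  # same value as in the hyperparameter cell

same = kl_penalty(-1.2, -1.2)     # policy agrees with the reference
drifted = kl_penalty(-0.3, -1.2)  # policy is much more confident than the reference
print(same, BETA * drifted)
```

The penalty is zero when both models assign the same log-probability to a token and grows as the policy drifts, which is what keeps the LoRA updates close to the base model.
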

Note: We perform full precision (fp32) training. You can, however, leverage
Qwix for QAT.

To load the model, you need to be on [Kaggle](https://www.kaggle.com/) and need
to have agreed to the Gemma license
[here](https://www.kaggle.com/models/google/gemma/flax/).

In [None]:
# Log in
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
  kagglehub.login()


The original Tunix GRPO demos convert the downloaded weights into an
intermediate Orbax checkpoint before building the models. This **DSA-GRPO**
version of the notebook keeps that step:

* We load the Gemma 3 1B-IT weights into an NNX model using
  `tunix.models.gemma3.params`.
* We save that model's state to `INTERMEDIATE_CKPT_DIR` with Orbax and
  delete the temporary in-memory copy to free memory.
* Later, the reference model is restored from this checkpoint with the
  target sharding in bfloat16, and the LoRA policy and reinforcement
  learning stack are built on top of it.
* Orbax is also used internally by Tunix's checkpoint manager during RL
  training.

This keeps peak memory usage lower on Kaggle while remaining compatible
with the competition requirements.


In [None]:
from tunix.models.gemma3 import params

# Load Gemma3 base model and tokenizer with an intermediate Orbax checkpoint
# to keep peak memory usage lower on Kaggle.
model_family = "gemma3"
if model_family != "gemma3":
  raise ValueError("This notebook currently only supports Gemma3-1B-IT.")

MODEL_CP_PATH = params.GEMMA3_1B_IT
model_config = model.ModelConfig.gemma3_1b()

# Clean out any prior intermediate / training checkpoints.
!rm /tmp/content/intermediate_ckpt/* -rf
!rm /tmp/content/ckpts/* -rf

print("Loading Gemma3 1B-IT base model from checkpoint...")
base_model_tmp = params.create_model_from_checkpoint(MODEL_CP_PATH, model_config)
checkpointer = ocp.StandardCheckpointer()
_, state = nnx.split(base_model_tmp)
checkpointer.save(os.path.join(INTERMEDIATE_CKPT_DIR, "state"), state)
checkpointer.wait_until_finished()
print("Saved intermediate checkpoint to", INTERMEDIATE_CKPT_DIR)

# Delete the temporary in-memory base model to free memory.
del base_model_tmp
del state
gc.collect()

# Tokenizer is lightweight, so we can keep it resident in memory.
tokenizer = params.create_tokenizer()


Loading Gemma3 1B-IT base model from checkpoint...




Saved intermediate checkpoint to /tmp/content/intermediate_ckpt/


Click **Run > Run current and after** on the next cell:

### Model Loading and LoRA Application

These two functions work together to load the base model from a checkpoint and apply LoRA (Low-Rank Adaptation) adapters to it.

* `get_gemma_ref_model`: Restores the complete Gemma model from the intermediate checkpoint. It uses **JAX sharding** to distribute the model parameters across the devices in the mesh.
* `get_lora_model`: Takes the base model and applies LoRA layers to it. It uses a `LoraProvider` to select the layers to adapt (the attention and MLP projections) and returns the LoRA-infused policy model.
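
To get a feel for what `RANK` and `ALPHA` imply, here is a back-of-the-envelope sketch based on the standard LoRA formulation, where the low-rank update is scaled by `alpha / rank`. That Qwix applies the same scaling is an assumption, and the projection shape below is hypothetical, chosen only to illustrate the arithmetic.

```python
RANK = 64
ALPHA = 64.0


def lora_scale(rank: int, alpha: float) -> float:
  """Scaling applied to the low-rank update: alpha / rank."""
  return alpha / rank


def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
  """Trainable parameters added by factors A (d_in x rank) and B (rank x d_out)."""
  return rank * (d_in + d_out)


# Hypothetical projection shape, not taken from the Gemma 3 config.
d_in, d_out = 1024, 4096
full = d_in * d_out
added = lora_param_count(d_in, d_out, RANK)
print(lora_scale(RANK, ALPHA), added, added / full)
```

With `RANK = 64` and `ALPHA = 64.0` the scale is 1.0, and halving the rank (as suggested for OOM cases) halves the number of trainable adapter parameters.
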

In [None]:
from tunix.models.gemma3 import params
from tunix.models.gemma3 import model
import jax
import jax.numpy as jnp
from flax import nnx
import orbax.checkpoint as ocp
import os

# MESH comes from the hyperparameter cell, e.g. [(1, 4), ("fsdp", "tp")].

def get_gemma_ref_model(ckpt_path: str | None = None):
  """Return the reference Gemma3 1B model plus mesh and config."""
  if ckpt_path is None:
    ckpt_path = os.path.join(INTERMEDIATE_CKPT_DIR, "state")

  # Build the device mesh from MESH, which holds (axis_shapes, axis_names).
  mesh = jax.make_mesh(*MESH)

  model_config = model.ModelConfig.gemma3_1b()

  # 1. Build an abstract module to get the state structure.
  abs_gemma: nnx.Module = nnx.eval_shape(
      lambda: model.Gemma3(model_config, rngs=nnx.Rngs(params=0))
  )

  abs_state = nnx.state(abs_gemma)
  abs_state = jax.tree.map(
      lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
      abs_state,
      nnx.get_named_sharding(abs_state, mesh),
  )

  # 2. Restore parameters into the abstract state structure
  checkpointer = ocp.StandardCheckpointer()
  restored_state = checkpointer.restore(ckpt_path, target=abs_state)

  # 3. Reconstruct the model by merging the graph definition with the
  # restored state.
  graph_def, _ = nnx.split(abs_gemma)

  base_model = nnx.merge(graph_def, restored_state)

  return base_model, mesh, model_config

def get_lora_model(base_model, mesh):
  """Apply LoRA adapters to the base model to obtain the policy model."""
  lora_provider = qwix.LoraProvider(
      module_path=(
          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
          ".*attn_vec_einsum"
      ),
      rank=RANK,
      alpha=ALPHA,
  )

  model_input = base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(
      base_model, lora_provider, **model_input
  )

  return lora_model

Now we load reference and policy Gemma models using the Flax NNX library and display their structures.

In [None]:
dir(model)

['Attention',
 'AttentionType',
 'Block',
 'Cache',
 'Einsum',
 'Embedder',
 'FeedForward',
 'GEMMA3_ATTENTION_PATTERN',
 'Gemma3',
 'K_MASK',
 'LayerCache',
 'ModelConfig',
 'QueryPreAttentionNormalisation',
 'RMSNorm',
 'RematConfig',
 'ShardingConfig',
 'Tuple',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 'apply_rope',
 'create_sliding_window_mask',
 'dataclasses',
 'enum',
 'find_last_one_index',
 'flax',
 'itertools',
 'jax',
 'jaxtyping',
 'jnp',
 'nnx',
 'pxla',
 'shard',
 'shd']

In [None]:
# Reference model
if model_family == "gemma3":
  # get_gemma_ref_model already builds and returns the mesh.
  ref_model, mesh, model_config = get_gemma_ref_model()


In [None]:
# Policy model
lora_policy = get_lora_model(ref_model, mesh=mesh)
# nnx.display(lora_policy)

## Define reward functions

We first define four base reward functions:

- reward if the format of the output exactly matches the format requested in
  the prompt;
- reward if the format of the output approximately matches the format requested
  in the prompt;
- reward if the answer is correct or partially correct;
- reward if the number extracted from between `<answer>` and `</answer>` is
  correct, since the text there might not be a single number.

The reward functions are inspired by
[this gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb).

First off, let's define a RegEx for checking whether the format matches.

In [None]:
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

match_format.search(
    f"{reasoning_start}Let me"
    f" think!{reasoning_end}{solution_start}2{solution_end}",
)

<re.Match object; span=(0, 54), match='<reasoning>Let me think!</reasoning><answer>2</an>
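
It is worth probing this pattern on a few more strings. The sketch below rebuilds the same regex with the tag strings inlined and checks four cases. One thing to note is that, because of `re.MULTILINE`, `^` and `$` anchor at line boundaries, so extra lines after the answer block still match.

```python
import re

match_format = re.compile(
    r"^[\s]{0,}<reasoning>.+?</reasoning>.*?<answer>(.+?)</answer>[\s]{0,}$",
    flags=re.MULTILINE | re.DOTALL,
)

ok = match_format.search(
    "<reasoning>\n[plan] add\n</reasoning>\n<answer> 72 </answer>\n"
)
no_answer = match_format.search("<reasoning>[plan] add</reasoning>")
inline_prefix = match_format.search("Sure! <reasoning>x</reasoning><answer>2</answer>")
extra_line = match_format.search("<reasoning>x</reasoning><answer>2</answer>\nP.S. done")

print(repr(ok.group(1)))         # ' 72 '
print(no_answer, inline_prefix)  # None None
print(extra_line is not None)    # True
```

The captured group keeps the surrounding spaces (`' 72 '`), which is why the answer check later distinguishes an exact match from a match after stripping whitespace.
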

Give the model a reward of 3 points if the format matches exactly.

In [None]:
def match_format_exactly(prompts, completions, **kwargs):
  return [
      0 if match_format.search(response) is None else 3.0
      for response in completions
  ]

We also reward the model if the format of the output matches partially.

In [None]:
def match_format_approximately(prompts, completions, **kwargs):
  scores = []

  for completion in completions:
    score = 0
    response = completion
    # Each tag should appear exactly once: +0.5 if it does, -0.5 if it is
    # missing or repeated.
    score += 0.5 if response.count(reasoning_start) == 1 else -0.5
    score += 0.5 if response.count(reasoning_end) == 1 else -0.5
    score += 0.5 if response.count(solution_start) == 1 else -0.5
    score += 0.5 if response.count(solution_end) == 1 else -0.5
    scores.append(score)
  return scores

Reward the model if the answer is correct. Partial credit is given when the
answer does not match exactly but is numerically close to the correct value.

In [None]:
def check_answer(prompts, completions, answer, **kwargs):
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_format.search(r)) is not None else None
      for r in responses
  ]

  scores = []
  assert len(extracted_responses) == len(
      answer
  ), f"{extracted_responses} and {answer} have mismatching length"
  for guess, true_answer in zip(extracted_responses, answer):
    score = 0
    if guess is None:
      scores.append(0)
      continue
    # Correct answer gets 3 points!
    if guess == true_answer:
      score += 3.0
    # Match if spaces are seen
    elif guess.strip() == true_answer.strip():
      score += 1.5
    else:
      # We also reward it if the answer is close via ratios!
      # Ie if the answer is within some range, reward it!
      try:
        ratio = float(guess) / float(true_answer)
        if ratio >= 0.9 and ratio <= 1.1:
          score += 0.5
        elif ratio >= 0.8 and ratio <= 1.2:
          score += 0.25
        else:
          score -= 1.0  # Penalize wrong answers
      except (ValueError, ZeroDivisionError):
        score -= 0.5  # Penalize answers that cannot be compared numerically
    scores.append(score)
  return scores

Sometimes, the text between `<answer>` and `</answer>` might not be one
number; it can be a sentence. So, we extract the number and compare the answer.

In [None]:
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)
match_numbers.findall(f"{solution_start}  0.34  {solution_end}")

['0.34']

In [None]:
def check_numbers(prompts, completions, answer, **kwargs):
  question = kwargs["question"]
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_numbers.search(r)) is not None else None
      for r in responses
  ]

  scores = []
  print("START ============================")
  print(f"Question: {question[0]}")
  print(f"Answer: {answer[0]}")
  print(f"Response: {responses[0]}")
  print(f"Extracted: {extracted_responses[0]}")
  print("END ==============================")
  for guess, true_answer in zip(extracted_responses, answer):
    if guess is None:
      scores.append(0)
      continue
    # Convert to numbers
    try:
      true_answer = float(true_answer.strip())
      guess = float(guess.strip())
      scores.append(1.5 if guess == true_answer else 0.0)
    except (ValueError, AttributeError):
      scores.append(0)
  return scores

In [None]:
# === Dual-Stream / DSA rewards ===

# Heuristic markers that encourage structured internal monologue.
section_markers = [
    "[plan]",
    "[evidence]",
    "[sanity_check]",
    "[risk_scan]",
    "[final_internal_action]",
]


def dsa_monologue_structure(prompts, completions, **kwargs):
  """Reward structured monologues inside <reasoning>...</reasoning>."""
  scores = []
  for completion in completions:
    score = 0.0

    # Extract monologue block.
    m = re.search(
        rf"{re.escape(reasoning_start)}(.*?){re.escape(reasoning_end)}",
        completion,
        flags=re.DOTALL | re.MULTILINE,
    )
    if not m:
      scores.append(-1.0)
      continue

    mono = m.group(1)

    # Reward presence of structural markers.
    for marker in section_markers:
      if marker in mono:
        score += 0.25  # up to +1.25 from markers alone

    # Reward multi-step reasoning (numbered or bulleted steps).
    step_pattern = re.compile(
        r"(^\s*(\d+[)\.:-]|-\s+))",
        flags=re.MULTILINE,
    )
    num_steps = len(step_pattern.findall(mono))
    if num_steps >= 3:
      score += 0.5
    elif num_steps >= 1:
      score += 0.25
    else:
      score -= 0.25

    # Clamp to a reasonable range.
    score = max(-1.0, min(score, 2.0))
    scores.append(score)

  return scores


def dsa_stream_coherence(prompts, completions, answer, **kwargs):
  """Reward internal coherence between Monologue and Answer streams."""
  scores = []
  for completion in completions:
    score = 0.0

    # Extract last numeric guess from the monologue.
    m_mono = re.search(
        rf"{re.escape(reasoning_start)}(.*?){re.escape(reasoning_end)}",
        completion,
        flags=re.DOTALL | re.MULTILINE,
    )
    monologue_guess = None
    if m_mono:
      mono = m_mono.group(1)
      nums = re.findall(r"[-+]?\d*\.?\d+", mono)
      if nums:
        monologue_guess = nums[-1].strip()

    # Extract numeric answer from the <answer> block.
    m_ans = match_numbers.search(completion)
    answer_guess = m_ans.group(1).strip() if m_ans else None

    if monologue_guess is None or answer_guess is None:
      scores.append(0.0)
      continue

    try:
      mono_val = float(monologue_guess)
      ans_val = float(answer_guess)
    except Exception:
      scores.append(0.0)
      continue

    if mono_val == ans_val:
      score += 1.0
    else:
      # If they strongly disagree, penalize.
      if ans_val != 0:
        ratio = mono_val / ans_val
        if 0.9 <= ratio <= 1.1:
          score += 0.5  # numerically close
        else:
          score -= 1.0
      else:
        score -= 0.5

    scores.append(score)

  return scores
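
To make the two heuristics above concrete, here is a minimal standalone sketch that mirrors the marker/step scoring of `dsa_monologue_structure` and the "last number in the monologue" extraction of `dsa_stream_coherence`. The helper names `score_monologue` and `last_number`, and the sample monologue, are made up for illustration and are not part of the notebook.

```python
import re

SECTION_MARKERS = [
    "[plan]",
    "[evidence]",
    "[sanity_check]",
    "[risk_scan]",
    "[final_internal_action]",
]
# Same idea as step_pattern above: lines starting with "1.", "2)", "-", etc.
STEP_PATTERN = re.compile(r"^\s*(?:\d+[).:-]|-\s+)", flags=re.MULTILINE)
NUMBER_PATTERN = re.compile(r"[-+]?\d*\.?\d+")


def score_monologue(mono: str) -> float:
    """+0.25 per marker present, then a bonus or penalty for step count."""
    score = 0.25 * sum(marker in mono for marker in SECTION_MARKERS)
    num_steps = len(STEP_PATTERN.findall(mono))
    if num_steps >= 3:
        score += 0.5
    elif num_steps >= 1:
        score += 0.25
    else:
        score -= 0.25
    return max(-1.0, min(score, 2.0))


def last_number(mono: str):
    """Last numeric token in the monologue, or None if there is none."""
    nums = NUMBER_PATTERN.findall(mono)
    return float(nums[-1]) if nums else None


SAMPLE = """[plan] Count the coins, then convert to dollars.
1. 4 dimes = 0.40
2. 9 quarters = 2.25
3. 7 nickels = 0.35
[sanity_check] 0.40 + 2.25 + 0.35 = 3.00
"""

print(score_monologue(SAMPLE))             # 2 markers + 3 steps -> 1.0
print(score_monologue("The answer is 3"))  # no markers, no steps -> -0.25
print(last_number(SAMPLE))                 # 3.0
```

Because the coherence check compares only the last number in the monologue with the answer, a monologue that ends on an intermediate value (or a step index) would be scored as disagreeing even when the reasoning is sound.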


In [None]:
# === DSA-GRPO composite reward ===

DSA_REWARD_WEIGHTS = {
    "format_exact": 0.5,
    "format_approx": 0.25,
    "answer_exact": 1.5,
    "answer_numeric": 1.0,
    "monologue_structure": 0.5,
    "stream_coherence": 1.0,
}


def dsa_grpo_reward(prompts, completions, answer, **kwargs):
  """Composite reward used by DSA-GRPO.

  This aggregates:
    - format rewards (exact/approx)
    - answer correctness (string / numeric)
    - monologue structure
    - monologue/answer coherence

  and gates some of them on basic format + coherence to discourage
  'pretty' monologues that disagree with the final answer.
  """
  # Component rewards
  r_format_exact = match_format_exactly(prompts, completions, **kwargs)
  r_format_approx = match_format_approximately(prompts, completions, **kwargs)
  r_answer = check_answer(prompts, completions, answer, **kwargs)
  r_numbers = check_numbers(prompts, completions, answer, **kwargs)
  r_mono = dsa_monologue_structure(prompts, completions, **kwargs)
  r_coh = dsa_stream_coherence(prompts, completions, answer, **kwargs)

  rewards = []
  for fe, fa, ans, num, mono, coh in zip(
      r_format_exact, r_format_approx, r_answer, r_numbers, r_mono, r_coh
  ):
    total = 0.0

    total += DSA_REWARD_WEIGHTS["format_exact"] * fe
    total += DSA_REWARD_WEIGHTS["format_approx"] * fa
    total += DSA_REWARD_WEIGHTS["answer_exact"] * ans
    total += DSA_REWARD_WEIGHTS["answer_numeric"] * num
    total += DSA_REWARD_WEIGHTS["monologue_structure"] * mono
    total += DSA_REWARD_WEIGHTS["stream_coherence"] * coh

    # If format is badly broken, strongly downweight the rest.
    if fe <= 0.0 and fa <= 0.0:
      total *= 0.25

    # If coherence is negative, treat it as a red flag and penalize harder.
    if coh < 0.0:
      total += coh  # extra penalty
      total *= 0.5

    rewards.append(float(total))

  return rewards
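
As a worked example of the aggregation, the sketch below applies the same weights and gates to three hand-picked sets of component scores. The component values are hypothetical (the real ranges come from `match_format_exactly`, `check_answer`, and the other functions defined earlier), and `combine` is an illustrative helper, not a function from the notebook.

```python
# Weights copied from DSA_REWARD_WEIGHTS above.
WEIGHTS = {
    "format_exact": 0.5,
    "format_approx": 0.25,
    "answer_exact": 1.5,
    "answer_numeric": 1.0,
    "monologue_structure": 0.5,
    "stream_coherence": 1.0,
}


def combine(fe, fa, ans, num, mono, coh):
    """Return (plain weighted sum, gated total) for one completion."""
    components = {
        "format_exact": fe,
        "format_approx": fa,
        "answer_exact": ans,
        "answer_numeric": num,
        "monologue_structure": mono,
        "stream_coherence": coh,
    }
    weighted = sum(WEIGHTS[k] * v for k, v in components.items())

    total = weighted
    if fe <= 0.0 and fa <= 0.0:  # format gate
        total *= 0.25
    if coh < 0.0:  # coherence gate
        total += coh
        total *= 0.5
    return weighted, total


# Hypothetical component scores, for illustration only.
print(combine(3.0, 1.5, 3.0, 1.5, 1.5, 1.0))     # (9.625, 9.625)
print(combine(3.0, 1.5, 3.0, 1.5, 1.5, -1.0))    # (7.625, 3.3125)
print(combine(0.0, -1.0, 0.0, 0.0, -1.0, -1.0))  # (-1.75, -0.71875)
```

In the third case the gated total (-0.71875) is closer to zero than the plain weighted sum (-1.75). Multiplying a negative total by 0.25 or 0.5 shrinks it, so the two gates only act as penalties while the running total is positive.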


## Evaluate


Before we train the model, let's evaluate the model on the test set so we can
see the improvement post training.

We evaluate it in two ways:

**Quantitative**

* **Answer Accuracy**: percentage of samples for which the model predicts the
correct final numerical answer  
* **Answer (Partial) Accuracy**: percentage of samples for which the model
predicts a final numerical answer such that the `model answer / answer`
ratio lies between 0.9 and 1.1.  
* **Format Accuracy**: percentage of samples for which the model outputs the
correct format, i.e., reasoning between the `<reasoning>` and `</reasoning>` tags, and the
final answer between the `<answer>` and `</answer>` tags.
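
The two answer metrics can be sketched on plain strings as follows. The helper names `is_exact` and `is_partial` are illustrative; returning `False` for a reference answer of zero is a choice made for this sketch, since the ratio is undefined in that case.

```python
def is_exact(guess: str, answer: str) -> bool:
    """Exact numeric match, so "3.0" and "3" count as equal."""
    return float(guess.strip()) == float(answer.strip())


def is_partial(guess: str, answer: str) -> bool:
    """True when guess / answer lies in [0.9, 1.1]."""
    ref = float(answer.strip())
    if ref == 0:
        return False  # ratio undefined; treated as a miss in this sketch
    return 0.9 <= float(guess.strip()) / ref <= 1.1


print(is_exact("3.0", "3"))      # True
print(is_partial("95", "100"))   # True  (ratio 0.95)
print(is_partial("80", "100"))   # False (ratio 0.80)
```

Every exact match with a nonzero reference is also a partial match, so partial accuracy is normally at least as high as answer accuracy.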

**Qualitative**

We'll also print outputs for a few given questions so that we can compare the generated output later.


We define a helper function to generate an answer, given a prompt.

In [None]:
def generate(
    question,
    sampler,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.95,
    seed: int | None = None,
):
  """Given prompt(s), generates text with cache-safe settings."""

  # Normalize to a list of questions
  if isinstance(question, str):
    questions = [question]
    single_input = True
  else:
    questions = list(question)
    single_input = False

  # Build the actual text prompts using your TEMPLATE / SYSTEM_PROMPT
  input_batch = [
      TEMPLATE.format(
          system_prompt=SYSTEM_PROMPT,
          question=q,
      )
      for q in questions
  ]

  # ---- Cache-aware length control ----
  # Get cache capacity from the sampler
  cache_limit = int(sampler.cache_config.cache_size)

  # Tokenize prompts to estimate their lengths
  # (we use the global `tokenizer` that was used to build the sampler)
  tokenized = [tokenizer.encode(s) for s in input_batch]
  prompt_lengths = [len(t) for t in tokenized]
  max_prompt_len_seen = max(prompt_lengths)

  # Desired eval lengths (can be defined in your hyperparam cell, e.g. 256/256)
  # If you already have these globals, this will just reuse them.
  # Otherwise you can set them here.
  try:
    desired_prompt_len = EVAL_MAX_PROMPT_LENGTH
    desired_gen_steps = EVAL_GENERATION_STEPS
  except NameError:
    # Fallback if those aren't defined elsewhere
    desired_prompt_len = 256
    desired_gen_steps = 256

  # Never let generation alone exceed half the cache
  safe_gen_steps = min(desired_gen_steps, cache_limit // 2)

  # Bound the prompt length so that prompt + generation <= cache_limit
  # 1. Respect both what we see and the desired cap
  tentative_prompt_len = min(max_prompt_len_seen, desired_prompt_len)
  # 2. Enforce cache constraint
  safe_prompt_len = min(tentative_prompt_len, cache_limit - safe_gen_steps)

  # Final safety clamp: if we somehow still exceed cache, shrink generation
  if safe_prompt_len + safe_gen_steps > cache_limit:
    safe_gen_steps = max(1, cache_limit - safe_prompt_len)

  # Debugging prints (optional)
  print(
      f"[generate] cache_limit={cache_limit}, "
      f"max_prompt_len_seen={max_prompt_len_seen}, "
      f"using max_prompt_length={safe_prompt_len}, "
      f"max_generation_steps={safe_gen_steps}"
  )

  # ---- Call Tunix sampler with safe lengths ----
  out_data = sampler(
      input_strings=input_batch,
      max_prompt_length=int(safe_prompt_len),
      max_generation_steps=int(safe_gen_steps),
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
      seed=seed,
      eos_tokens=[1, 106],
  )

  output = out_data.text
  if single_input:
    return output[0]
  return output
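
The length-clamping arithmetic inside `generate` can be checked in isolation. The sketch below extracts it into a pure function (the name `safe_lengths` is made up here) and uses the same 256/256 fallback values as the function above.

```python
def safe_lengths(
    max_prompt_len_seen: int,
    cache_limit: int,
    desired_prompt_len: int = 256,
    desired_gen_steps: int = 256,
):
    """Return (max_prompt_length, max_generation_steps) that fit the cache."""
    # Generation alone never takes more than half the cache.
    gen_steps = min(desired_gen_steps, cache_limit // 2)
    # Prompt is bounded by what was seen, the desired cap, and the cache.
    prompt_len = min(
        max_prompt_len_seen, desired_prompt_len, cache_limit - gen_steps
    )
    # Final clamp, mirroring the safety check in generate().
    if prompt_len + gen_steps > cache_limit:
        gen_steps = max(1, cache_limit - prompt_len)
    return prompt_len, gen_steps


print(safe_lengths(312, 960))  # (256, 256)
print(safe_lengths(250, 960))  # (250, 256)
print(safe_lengths(312, 300))  # (150, 150)
```

The first two results match the `[generate]` debug lines printed during evaluation. Note that a 312-token prompt is longer than the 256-token cap it is given; how the sampler treats the excess tokens is not shown in this notebook.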


In [None]:
# === Dual-Stream parsing & generation helpers ===

def parse_dual_stream(completion: str):
  """Split a completion into (monologue, answer) strings."""
  mono_match = re.search(
      rf"{re.escape(reasoning_start)}(.*?){re.escape(reasoning_end)}",
      completion,
      flags=re.DOTALL | re.MULTILINE,
  )
  ans_match = re.search(
      rf"{re.escape(solution_start)}(.*?){re.escape(solution_end)}",
      completion,
      flags=re.DOTALL | re.MULTILINE,
  )

  monologue = mono_match.group(1).strip() if mono_match else completion.strip()
  answer = ans_match.group(1).strip() if ans_match else ""
  return monologue, answer


def generate_dual_stream(
    question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None
):
  """Wrapper around `generate` that returns (monologue, answer)."""
  raw = generate(
      question,
      sampler,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      seed=seed,
  )

  if isinstance(raw, str):
    return parse_dual_stream(raw)

  return [parse_dual_stream(c) for c in raw]
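
A quick standalone check of the parsing behaviour, including the fallback when tags are missing. This sketch assumes `reasoning_start`/`reasoning_end` and `solution_start`/`solution_end` are the literal `<reasoning>` and `<answer>` tags from the competition format shown at the top of the notebook; `split_streams` is an illustrative name.

```python
import re

# Assumed literal tags; the notebook builds these from variables instead.
REASONING_RE = re.compile(r"<reasoning>(.*?)</reasoning>", flags=re.DOTALL)
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", flags=re.DOTALL)


def split_streams(completion: str):
    """Same contract as parse_dual_stream: returns (monologue, answer)."""
    mono = REASONING_RE.search(completion)
    ans = ANSWER_RE.search(completion)
    monologue = mono.group(1).strip() if mono else completion.strip()
    answer = ans.group(1).strip() if ans else ""
    return monologue, answer


well_formed = "<reasoning>\n[plan] 5 - 2 = 3\n</reasoning>\n<answer> 3 </answer>"
print(split_streams(well_formed))  # ('[plan] 5 - 2 = 3', '3')
print(split_streams("3 apples"))   # ('3 apples', '')
```

When the reasoning tags are missing, the whole completion is returned as the monologue and the answer is empty, so downstream code should treat an empty answer as a format failure rather than as a valid result.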


Another helper function for evaluation.

In [None]:
def evaluate(
    dataset,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """Computes accuracy and percentage of outputs matching the format."""

  response_lst = []
  corr = 0
  partially_corr = 0
  corr_format = 0
  total = 0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]

    multiple_call_responses = [[] for _ in range(len(questions))]
    for p in range(num_passes):
      responses = generate(
          questions, sampler, temperature, top_k, top_p, seed=p
      )
      for idx, response in enumerate(responses):
        multiple_call_responses[idx].append(response)

    for question, multiple_call_response, answer in zip(
        questions, multiple_call_responses, answers
    ):
      # check answer
      corr_ctr_per_question = 0
      partially_corr_per_question = 0
      corr_format_per_question = 0
      for response in multiple_call_response:
        extracted_response = (
            guess.group(1)
            if (guess := match_numbers.search(response)) is not None
            else "-1000000"
        )
        try:
          if float(extracted_response.strip()) == float(answer.strip()):
            corr_ctr_per_question += 1

          ratio = float(extracted_response.strip()) / float(answer.strip())
          if 0.9 <= ratio <= 1.1:
            partially_corr_per_question += 1
        except Exception:
          # Non-numeric text or a zero reference answer: skip this response.
          print("SKIPPED")

        # check format
        if match_format.search(response) is not None:
          corr_format_per_question += 1

        if (
            corr_ctr_per_question > 0
            and partially_corr_per_question > 0
            and corr_format_per_question > 0
        ):
          break

      if corr_ctr_per_question > 0:
        corr += 1
        if corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      else:
        if not corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      if partially_corr_per_question > 0:
        partially_corr += 1
      if corr_format_per_question > 0:
        corr_format += 1

      total += 1
      if total % 10 == 0:
        print(
            f"===> {corr=}, {total=}, {corr / total * 100=}, "
            f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
        )

  to_return = (
      corr,
      total,
      corr / total * 100,
      partially_corr / total * 100,
      corr_format / total * 100,
  )
  if make_lst:
    return to_return, response_lst
  return to_return

In [None]:
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 512,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

Now let's see how the original model does on the test set. You can see the percentages of the model outputs that are fully correct, partially correct, and correct in format only. The following step might take a couple of minutes to finish.

In [None]:
# The evaluation might take up to a couple of minutes to finish. Please be patient.

(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["greedy"],
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
    f" {format_accuracy=}%"
)

  0%|          | 0/100 [00:00<?, ?it/s]

[generate] cache_limit=960, max_prompt_len_seen=312, using max_prompt_length=256, max_generation_steps=256
[generate] cache_limit=960, max_prompt_len_seen=290, using max_prompt_length=256, max_generation_steps=256
[generate] cache_limit=960, max_prompt_len_seen=332, using max_prompt_length=256, max_generation_steps=256
[generate] cache_limit=960, max_prompt_len_seen=258, using max_prompt_length=256, max_generation_steps=256
[generate] cache_limit=960, max_prompt_len_seen=250, using max_prompt_length=250, max_generation_steps=256
[generate] cache_limit=960, max_prompt_len_seen=252, using max_prompt_length=252, max_generation_steps=256
[generate] cache_limit=960, max_prompt_len_seen=247, using max_prompt_length=247, max_generation_steps=256
[generate] cache_limit=960, max_prompt_len_seen=250, using max_prompt_length=250, max_generation_steps=256
[generate] cache_limit=960, max_prompt_len_seen=290, using max_prompt_length=256, max_generation_steps=256

## Train

Let's set up all the configs first - checkpointing, metric logging and training.
We then train the model.

In [None]:
# Metrics logger
# We rely on Tunix's built-in checkpointing defaults and only configure
# metrics logging explicitly here.
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tmp/tensorboard/grpo", flush_every_n_steps=20
)

In [None]:
# Optimizer, learning rate scheduler, gradient clipping
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)
if MAX_GRAD_NORM is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
      optimizer,
  )
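
To see the shape of the learning-rate curve without running the trainer, here is a pure-Python sketch of a linear warmup followed by cosine decay, which is our understanding of what `optax.schedules.warmup_cosine_decay_schedule` computes with its default exponent. The function name `warmup_cosine_lr` and the numeric values are hypothetical and are not the notebook's `LEARNING_RATE`, `WARMUP_STEPS`, or `MAX_STEPS`.

```python
import math


def warmup_cosine_lr(
    step: int,
    peak_value: float,
    warmup_steps: int,
    decay_steps: int,
    init_value: float = 0.0,
    end_value: float = 0.0,
) -> float:
    """Linear warmup to peak_value, then cosine decay to end_value."""
    if step < warmup_steps:
        frac = step / warmup_steps
        return init_value + (peak_value - init_value) * frac
    # decay_steps counts from step 0, so the cosine phase spans
    # (decay_steps - warmup_steps) steps.
    span = decay_steps - warmup_steps
    progress = min(step - warmup_steps, span) / span
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak_value - end_value) * cosine


# Hypothetical values, for illustration only.
for step in (0, 50, 100, 550, 1000):
    print(step, warmup_cosine_lr(step, 3e-6, 100, 1000))
```

With `init_value` and `end_value` both set to 0.0, as in the cell above, the rate starts at zero, peaks at the end of warmup, and returns to zero at `decay_steps`.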

In [None]:
# Training config
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        # metrics logging
        metrics_logging_options=metrics_logging_options,
        # checkpoint saving
        checkpoint_root_directory=CKPT_DIR,
        #checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        eos_tokens=[1,106],
    ),
)

grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

In [None]:
# Ensure RL rollout config uses a safe total sampling length
# so we don't exceed sampler.cache_config.cache_size (704).
cluster_config.rollout_config.max_prompt_length = ROLLOUT_MAX_PROMPT_LENGTH
cluster_config.rollout_config.max_tokens_to_generate = ROLLOUT_GENERATION_STEPS

print("RolloutConfig.max_prompt_length =",
      cluster_config.rollout_config.max_prompt_length)
print("RolloutConfig.max_tokens_to_generate =",
      cluster_config.rollout_config.max_tokens_to_generate)


RolloutConfig.max_prompt_length = 256
RolloutConfig.max_tokens_to_generate = 256
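
A small pre-flight check can surface cache-budget problems before training starts. The sketch below encodes the constraint exactly as the sampler's error message words it (total sampling steps strictly less than the cache size); the helper name `check_sampling_budget` is made up here and is not part of Tunix.

```python
def check_sampling_budget(
    max_prompt_length: int,
    max_tokens_to_generate: int,
    kv_cache_size: int,
) -> int:
    """Return the remaining headroom, or raise if the budget is exceeded."""
    total = max_prompt_length + max_tokens_to_generate
    if total >= kv_cache_size:
        raise ValueError(
            f"Total sampling steps {total} must be less than "
            f"the cache size {kv_cache_size}."
        )
    return kv_cache_size - total


print(check_sampling_budget(256, 256, 704))  # 192 tokens of headroom

try:
    check_sampling_budget(512, 256, 704)
except ValueError as err:
    print(err)
```

The nominal 256 + 256 configuration fits a 704-token cache. The training run below nevertheless reports 768 total sampling steps, so the sampler appears to count more steps than this nominal sum; where the extra steps come from would need to be checked against Tunix itself.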


### Setting Up the GRPO Trainer

Now we initialize our system for training. First, we create an `RLCluster` instance, which brings together the **policy model (`actor`)**, a **reference model (`reference`)**, and a **tokenizer**. Our `actor` is a trainable LoRA model, while the `reference` is a fixed base model that we use to guide the training.

We then create a `GRPOLearner`, the specialized trainer that uses a list of **reward functions** to evaluate and optimize the model's output, completing the RL training setup.

Tunix trainers are integrated with [Weights & Biases](https://wandb.ai/) to help you visualize the training progress. You can choose how you want to use it:

**Option 1 (Type 1)**: If you're running a quick experiment or just testing things out, choose this. It creates a temporary, private dashboard right in your browser without requiring you to log in or create an account.

**Option 2 (Type 2)**: If you have an existing W&B account and want to save your project's history to your personal dashboard, choose this. You'll be prompted to enter your API key or log in.

In [None]:
# RL cluster
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

# DSA-GRPO Trainer: use composite DSA-aware reward
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[dsa_grpo_reward],
    grpo_config=grpo_config,
)

The first couple of training steps might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. >10 minutes per step, please open a bug. Really appreciated!

In [None]:
with mesh:
  grpo_trainer.train(train_dataset)

Question: Maria has 4 dimes, 4 quarters, and 7 nickels in her piggy bank. Her mom gives her 5 quarters. How much money, in dollars, does Maria have now?
Answer: 3
Response: Please provide the question! I need the question to generate the Monologue and Answer streams as specified.

Extracted: None


Actor Training:   0%|          | 0/3738 [00:00<?, ?step/s]

ValueError: Total sampling steps 768 must be less than the cache size 704.

## Evaluate

Let's evaluate our finetuned model!

In [None]:
# For evaluation we directly use the in-memory `lora_policy`.
# After the GRPO training loop above finishes, `lora_policy` already
# contains the updated LoRA parameters, so we can construct a sampler
# and run evaluation without performing an explicit Orbax restore here.
#
# If you later want to resume from disk, you could add restore logic
# using Tunix's checkpoint manager and Orbax, but that is not required
# for the competition workflow.


In [None]:
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 512,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

In [None]:
# The evaluation might take up to a couple of minutes to finish. Please be patient.
(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["greedy"],
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
    f" {format_accuracy=}%"
)

With sufficient training, you should see that the percentages of correct model outputs have clearly gone up, which means our training worked.

In [None]:
# Dual-Stream demo: shows Kaggle-style output fields.
demo_question = "If I have 5 apples and eat 2, how many are left? Answer with a number."

with mesh:
  demo_monologue, demo_answer = generate_dual_stream(
      demo_question,
      sampler,
      **GENERATION_CONFIGS["greedy"],
  )

print("model_thinking_trace:", demo_monologue)
print("model_answer:", demo_answer)


In [None]:
# Export the latest trained actor checkpoint for submission/use.
# This copies the most recent actor checkpoint directory from CKPT_DIR
# into /kaggle/working so it is preserved as a Kaggle artifact.

import os
import shutil

actor_root = os.path.join(CKPT_DIR, "actor")
if not os.path.isdir(actor_root):
  raise ValueError(f"No actor checkpoints found under {actor_root!r}.")

# Find the latest step_* directory.
step_dirs = [
    d for d in os.listdir(actor_root)
    if d.startswith("step_") and os.path.isdir(os.path.join(actor_root, d))
]
if not step_dirs:
  raise ValueError(f"No step_* subdirectories found under {actor_root!r}.")

latest_step = max(int(d.replace("step_", "")) for d in step_dirs)
latest_ckpt_path = os.path.join(actor_root, f"step_{latest_step}")
print(f"Latest actor checkpoint: {latest_ckpt_path}")

export_dir = "/kaggle/working/dsagrpo_gemma3_1b_final"
if os.path.exists(export_dir):
  shutil.rmtree(export_dir)

shutil.copytree(latest_ckpt_path, export_dir)
print(f"Copied final checkpoint directory to: {export_dir}")

# Optionally, also zip it so there's a single file artifact.
zip_base = "/kaggle/working/dsagrpo_gemma3_1b_final"
shutil.make_archive(zip_base, "zip", root_dir=export_dir)
print(f"Zipped final checkpoint to: {zip_base}.zip")


