diff --git a/examples/rl/grpo/gsm8k/run_qwen3_8b.sh b/examples/rl/grpo/gsm8k/run_qwen3_8b.sh
new file mode 100755
index 000000000..e40ed7c44
--- /dev/null
+++ b/examples/rl/grpo/gsm8k/run_qwen3_8b.sh
@@ -0,0 +1,164 @@
+#!/bin/bash
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Agentic GSM8K GRPO launcher for Qwen3 8B using
+# tunix/cli/base_agentic_config.yaml plus explicit CLI overrides.
+#
+# Usage:
+#   bash examples/rl/grpo/gsm8k/run_qwen3_8b.sh
+#
+# Run from the tunix repo root.
+
+set -euo pipefail
+
+export SKIP_JAX_PRECOMPILE=true
+
+model_name="${model_name:-Qwen3-8B}"
+model_id="${model_id:-Qwen/Qwen3-8B}"
+tokenizer_path="${tokenizer_path:-$model_id}"
+
+batch_size="${batch_size:-8}"
+num_batches="${num_batches:-934}"
+num_train_epochs="${num_train_epochs:-1}"
+train_fraction="${train_fraction:-1.0}"
+warmup_ratio="${warmup_ratio:-0.1}"
+
+mini_batch_size="${mini_batch_size:-8}"
+train_micro_batch_size="${train_micro_batch_size:-1}"
+rollout_micro_batch_size="${rollout_micro_batch_size:-8}"
+compute_logps_micro_batch_size="${compute_logps_micro_batch_size:-1}"
+
+num_generations="${num_generations:-4}"
+
+train_mesh="${train_mesh:-(8,1)}"
+rollout_mesh="${rollout_mesh:-(1,8)}"
+
+max_steps=$(awk "BEGIN {
+ value = $num_batches * $num_train_epochs * $train_fraction;
+ if (value < 1) value = 1;
+ printf \"%.0f\", value;
+}")
+warmup_steps=$(awk "BEGIN {
+ value = $warmup_ratio * $max_steps;
+ if (value < 1) value = 1;
+ printf \"%.0f\", value;
+}")
+vllm_max_num_seqs=$(awk "BEGIN {
+ value = $rollout_micro_batch_size * $num_generations;
+ if (value < 1) value = 1;
+ printf \"%.0f\", value;
+}")
+
+python -m tunix.cli.grpo_main \
+ tunix/cli/base_agentic_config.yaml \
+ \
+ `# -- Model ------------------------------------------------------------` \
+ model_config.model_name="$model_name" \
+ model_config.model_id="$model_id" \
+ model_config.model_source="huggingface" \
+ model_config.rng_seed=42 \
+ model_config.model_display=false \
+ model_config.remat_config=3 \
+ actor_model_config.mesh.shape="$train_mesh" \
+ actor_model_config.mesh.axis_names="('fsdp','tp')" \
+ reference_model_config.mesh=null \
+ reference_model_config.same_mesh_as="actor" \
+ rollout_model_config.mesh.shape="$rollout_mesh" \
+ rollout_model_config.mesh.axis_names="('fsdp','tp')" \
+ \
+ `# -- Data -------------------------------------------------------------` \
+ data_source="huggingface" \
+ dataset_name="openai/gsm8k:main" \
+ \
+ `# -- Training loop ----------------------------------------------------` \
+ training_mode="agentic_grpo" \
+ batch_size="$batch_size" \
+ num_batches="$num_batches" \
+ num_test_batches=100 \
+ num_train_epochs="$num_train_epochs" \
+ train_fraction="$train_fraction" \
+ reward_functions=["tunix/cli/reward_fn/gsm8k.py"] \
+ verl_compatible=false \
+ \
+ `# -- Rollout engine (vanilla | vllm | sglang_jax) ---------------------` \
+ rollout_engine="vllm" \
+ offload_to_cpu=false \
+ \
+ `# -- Rollout config ---------------------------------------------------` \
+ rollout_config.max_prompt_length=256 \
+ rollout_config.total_generation_steps=768 \
+ rollout_config.max_tokens_to_generate=768 \
+ rollout_config.temperature=0.9 \
+ rollout_config.top_p=1.0 \
+ rollout_config.top_k=50 \
+ rollout_config.return_logprobs=true \
+ \
+ `# -- vLLM (used when rollout_engine=vllm) -----------------------------` \
+ vllm_config.hbm_utilization=0.4 \
+ vllm_config.tpu_backend_type="jax" \
+ vllm_config.server_mode=true \
+ vllm_config.async_scheduling=true \
+ vllm_config.max_num_seqs="$vllm_max_num_seqs" \
+ vllm_config.kwargs.kv_cache_metrics=true \
+ vllm_config.kwargs.disable_log_stats=false \
+ vllm_config.kwargs.enable_prefix_caching=true \
+ \
+ `# -- Tokenizer / chat parsing ----------------------------------------` \
+ chat_parser_config.type="qwen" \
+ tokenizer_config.tokenizer_type="huggingface" \
+ tokenizer_config.tokenizer_path="$tokenizer_path" \
+ tokenizer_config.add_bos=false \
+ tokenizer_config.add_eos=false \
+ \
+ `# -- GRPO algorithm ---------------------------------------------------` \
+ agentic_grpo_config.num_generations="$num_generations" \
+ agentic_grpo_config.num_iterations=1 \
+ agentic_grpo_config.beta=0.08 \
+ agentic_grpo_config.epsilon=0.2 \
+  agentic_grpo_config.system_prompt="You are given a grade school math problem. Think step by step and respond using <think>...</think> followed by <answer>...</answer> with only the final numeric answer inside <answer>." \
+ agentic_grpo_config.max_concurrency=128 \
+ agentic_grpo_config.max_response_length=768 \
+ agentic_grpo_config.max_turns=1 \
+ agentic_grpo_config.context_ratio=1 \
+ \
+ `# -- Optimizer --------------------------------------------------------` \
+ rl_training_config.actor_optimizer_config.opt_type="adamw" \
+ rl_training_config.actor_optimizer_config.learning_rate=3e-6 \
+ rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
+ rl_training_config.actor_optimizer_config.init_value=0.0 \
+ rl_training_config.actor_optimizer_config.peak_value=3e-6 \
+ rl_training_config.actor_optimizer_config.end_value=0.0 \
+ rl_training_config.actor_optimizer_config.warmup_ratio="$warmup_ratio" \
+ rl_training_config.actor_optimizer_config.warmup_steps="$warmup_steps" \
+ rl_training_config.actor_optimizer_config.decay_steps="$max_steps" \
+ rl_training_config.actor_optimizer_config.b1=0.9 \
+ rl_training_config.actor_optimizer_config.b2=0.99 \
+ rl_training_config.actor_optimizer_config.weight_decay=0.1 \
+ rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
+ \
+ `# -- RL training ------------------------------------------------------` \
+ rl_training_config.eval_every_n_steps=10 \
+ rl_training_config.max_steps="$max_steps" \
+ rl_training_config.mini_batch_size="$mini_batch_size" \
+ rl_training_config.train_micro_batch_size="$train_micro_batch_size" \
+ rl_training_config.rollout_micro_batch_size="$rollout_micro_batch_size" \
+ rl_training_config.compute_logps_micro_batch_size="$compute_logps_micro_batch_size" \
+ rl_training_config.checkpoint_root_directory="/tmp/tunix/checkpoints/gsm8k_qwen3_8b" \
+ rl_training_config.checkpointing_options.save_interval_steps=250 \
+ rl_training_config.checkpointing_options.max_to_keep=4 \
+ rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/gsm8k_qwen3_8b" \
+ rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
+ \
+ "$@"
diff --git a/tests/cli/grpo_main_test.py b/tests/cli/grpo_main_test.py
index 80d828396..5d8a8b663 100644
--- a/tests/cli/grpo_main_test.py
+++ b/tests/cli/grpo_main_test.py
@@ -357,10 +357,9 @@ def test_standard_grpo_dispatches_to_standard(self):
"""
pipeline = _make_pipeline(extra)
self.assertEqual(pipeline.config.get("training_mode", "grpo"), "grpo")
- # _run_standard_grpo should be called; we verify no AttributeError on dispatch
- with mock.patch.object(pipeline, "_run_standard_grpo") as mock_std:
+ with mock.patch.object(pipeline, "_run") as mock_run:
pipeline.run_grpo_trainer()
- mock_std.assert_called_once()
+ mock_run.assert_called_once_with(mode="grpo")
def test_agentic_grpo_dispatches_to_agentic(self):
extra = """
@@ -398,9 +397,9 @@ def test_agentic_grpo_dispatches_to_agentic(self):
"""
pipeline = _make_pipeline(extra)
self.assertEqual(pipeline.config["training_mode"], "agentic_grpo")
- with mock.patch.object(pipeline, "_run_agentic_grpo") as mock_ag:
+ with mock.patch.object(pipeline, "_run") as mock_run:
pipeline.run_grpo_trainer()
- mock_ag.assert_called_once()
+ mock_run.assert_called_once_with(mode="agentic_grpo")
def test_unknown_mode_raises(self):
# Build pipeline with standard config then manually set bad mode
@@ -418,8 +417,35 @@ def test_unknown_mode_raises(self):
"""
pipeline = _make_pipeline(extra)
pipeline.config["training_mode"] = "bad_mode"
- with self.assertRaisesRegex(ValueError, "Unknown training_mode"):
- pipeline.run_grpo_trainer()
+ raw_dataset = mock.Mock()
+ raw_dataset.__len__ = mock.Mock(return_value=1)
+ with mock.patch.object(pipeline, "_setup_kubernetes"):
+ with mock.patch.object(pipeline, "_get_tokenizer", return_value=mock.sentinel.tokenizer):
+ with mock.patch.object(
+ pipeline,
+ "_create_chat_parser",
+ return_value=mock.sentinel.chat_parser,
+ ):
+ with mock.patch.object(
+ pipeline,
+ "_load_raw_dataset",
+ return_value=(raw_dataset, None),
+ ):
+ with mock.patch.object(pipeline, "compute_params"):
+ with mock.patch.object(
+ grpo_main.data_lib,
+ "post_init_dataset",
+ return_value=(mock.sentinel.dataset, None),
+ ):
+ with mock.patch.object(
+ pipeline,
+ "create_rl_cluster",
+ return_value=mock.sentinel.rl_cluster,
+ ):
+ with self.assertRaisesRegex(
+ ValueError, "Unsupported training_mode 'bad_mode'"
+ ):
+ pipeline.run_grpo_trainer()
# ---------------------------------------------------------------------------
diff --git a/tests/examples/data/math_dataset_test.py b/tests/examples/data/math_dataset_test.py
index ed903baed..f9eb2e9a0 100644
--- a/tests/examples/data/math_dataset_test.py
+++ b/tests/examples/data/math_dataset_test.py
@@ -148,6 +148,14 @@ def test_parse_huggingface_dataset_name_supports_gsm8k_alias(self):
self.assertEqual(dataset_name, "openai/gsm8k")
self.assertEqual(config_name, "default")
+ def test_parse_huggingface_dataset_name_supports_explicit_config(self):
+ dataset_name, config_name = math_dataset._parse_huggingface_dataset_name(
+ "openai/gsm8k:main"
+ )
+
+ self.assertEqual(dataset_name, "openai/gsm8k")
+ self.assertEqual(config_name, "main")
+
def test_create_dataset_uses_huggingface_loader(self):
raw_dataset = _BaseDataset([
{"question": "Q3", "answer": "#### 42"},
diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py
index 20e466f8e..0a9af1a28 100644
--- a/tunix/cli/grpo_main.py
+++ b/tunix/cli/grpo_main.py
@@ -34,6 +34,7 @@
import dataclasses
import importlib
import os
+from types import ModuleType
from typing import Any
from absl import app
@@ -51,10 +52,8 @@
from tunix.perf import metrics as perf_metrics
from tunix.perf.experimental import export as perf_export_v2
from tunix.rl import rl_cluster as rl_cluster_lib
-from tunix.rl.grpo import grpo_learner
from tunix.rl.rollout import base_rollout
-GrpoConfig = grpo_learner.GrpoConfig
_PATHWAYS_BNS = flags.DEFINE_string(
"pathways_bns", None, "BNS address of the Pathways server."
@@ -86,6 +85,10 @@ class GrpoPipeline(config.HyperParameters):
* ``kubernetes_config``: optional Kubernetes env-var and kube-config setup.
"""
+ def __init__(self, argv: list[str], **kwargs):
+ self.data_module: ModuleType | None = None
+ super().__init__(argv, **kwargs)
+
# ------------------------------------------------------------------
# Mesh
# ------------------------------------------------------------------
@@ -573,16 +576,22 @@ def compute_params(self, dataset):
# Standard GRPO training
# ------------------------------------------------------------------
- def _run_standard_grpo(self):
- """Execute standard (non-agentic) GRPO training."""
- tokenizer = model_lib.create_tokenizer(
+ def _get_tokenizer(self):
+ return model_lib.create_tokenizer(
self.config["tokenizer_config"],
self.config["tokenizer_config"]["tokenizer_path"],
)
+  def _get_data_module(self):
+ if self.data_module is None:
+ self.data_module = importlib.import_module(self.config["data_module"])
+ return self.data_module
+
+ def _get_dataset(self, tokenizer):
if self.config.get("data_module", None):
+ data_module = self.config.get("data_module", None)
dataset = data_lib.get_dataset_from_module(
- self.config["data_module"],
+ data_module,
tokenizer,
)
elif self.config["data_source"] == "local":
@@ -608,23 +617,7 @@ def _run_standard_grpo(self):
else:
raise ValueError(f"Unsupported data_source {self.config['data_source']}")
- self.compute_params(dataset)
- dataset, _ = data_lib.post_init_dataset(
- dataset,
- tokenizer,
- batch_size=self.config["batch_size"],
- num_batches=self.config.get("num_batches", None),
- max_prompt_length=self.config["rollout_config"].get(
- "max_prompt_length", None
- ),
- )
- rl_cluster = self.create_rl_cluster(tokenizer)
- grpo_trainer = grpo_learner.GrpoLearner(
- rl_cluster=rl_cluster,
- reward_fns=self.obtain_reward_fn(),
- algo_config=GrpoConfig(**self.config["grpo_config"]),
- )
- grpo_trainer.train(dataset)
+ return dataset
# ------------------------------------------------------------------
# Agentic GRPO helpers
@@ -671,16 +664,17 @@ def _load_class_from_path(self, dotted_path: str) -> type:
module_path, class_name = dotted_path.rsplit(".", 1)
return getattr(importlib.import_module(module_path), class_name)
- def _load_raw_dataset(self):
+ def _load_raw_dataset(self, tokenizer):
"""Load a raw grain.MapDataset from data_module.
The module must expose ``create_dataset(**data_config) -> grain.MapDataset``
and optionally a ``batch_fn`` used as ``custom_batch_fn``.
"""
- module = importlib.import_module(self.config["data_module"])
- data_config = dict(self.config.get("data_config", {}))
- dataset = module.create_dataset(**data_config)
- batch_fn = getattr(module, "batch_fn", None)
+ dataset = self._get_dataset(tokenizer)
+ data_module = (
+ self._get_data_module() if self.config.get("data_module", None) else None
+ )
+ batch_fn = getattr(data_module, "batch_fn", None) if data_module else None
return dataset, batch_fn
def _setup_kubernetes(self) -> None:
@@ -707,19 +701,15 @@ def _setup_kubernetes(self) -> None:
# Agentic GRPO training
# ------------------------------------------------------------------
- def _run_agentic_grpo(self):
+ def _run(self, mode: str = "grpo"):
"""Execute agentic GRPO training (DeepScaleR, DeepSWE, etc.)."""
- from tunix.rl.agentic.agentic_grpo_learner import GRPOLearner # pylint: disable=g-import-not-at-top
-
self._setup_kubernetes()
- tokenizer = model_lib.create_tokenizer(
- self.config["tokenizer_config"],
- self.config["tokenizer_config"]["tokenizer_path"],
- )
+ tokenizer = self._get_tokenizer()
+
chat_parser = self._create_chat_parser(tokenizer)
- raw_dataset, custom_batch_fn = self._load_raw_dataset()
+ raw_dataset, custom_batch_fn = self._load_raw_dataset(tokenizer)
self.compute_params(raw_dataset)
dataset, _ = data_lib.post_init_dataset(
@@ -737,6 +727,23 @@ def _run_agentic_grpo(self):
)
rl_cluster = self.create_rl_cluster(tokenizer)
+
+ if mode == "grpo":
+      from tunix.rl.grpo import grpo_learner  # pylint: disable=g-import-not-at-top
+
+ grpo_trainer = grpo_learner.GrpoLearner(
+ rl_cluster=rl_cluster,
+ reward_fns=self.obtain_reward_fn(),
+ algo_config=grpo_learner.GrpoConfig(**self.config["grpo_config"]),
+ )
+ grpo_trainer.train(dataset)
+ return
+
+ # agentic GRPO
+ if mode != "agentic_grpo":
+ raise ValueError(f"Unsupported training_mode {mode!r}")
+
+ from tunix.rl.agentic.agentic_grpo_learner import GRPOLearner # pylint: disable=g-import-not-at-top
algo_config = self._create_agentic_grpo_config()
reward_fns = (
@@ -774,14 +781,7 @@ def _run_agentic_grpo(self):
def run_grpo_trainer(self):
"""Dispatch to standard or agentic GRPO based on training_mode."""
mode = self.config.get("training_mode", "grpo")
- if mode == "agentic_grpo":
- self._run_agentic_grpo()
- elif mode == "grpo":
- self._run_standard_grpo()
- else:
- raise ValueError(
- f"Unknown training_mode: {mode!r}. Expected 'grpo' or 'agentic_grpo'."
- )
+ self._run(mode=mode)
def _setup_jax_pathways(pathways_bns: str):
diff --git a/tunix/examples/data/math_dataset.py b/tunix/examples/data/math_dataset.py
index f5e2abd71..ea8d76bc9 100644
--- a/tunix/examples/data/math_dataset.py
+++ b/tunix/examples/data/math_dataset.py
@@ -124,13 +124,13 @@ def _parse_huggingface_dataset_name(
dataset_name: str,
) -> tuple[str, str | None]:
"""Parses a Hugging Face dataset name into dataset/config components."""
- if "/" in dataset_name:
- return dataset_name, "default"
-
if ":" in dataset_name:
name, config_name = dataset_name.split(":", maxsplit=1)
return name, config_name or None
+ if "/" in dataset_name:
+ return dataset_name, "default"
+
return dataset_name, None