diff --git a/examples/rl/grpo/gsm8k/run_qwen3_8b.sh b/examples/rl/grpo/gsm8k/run_qwen3_8b.sh
new file mode 100755
index 000000000..e40ed7c44
--- /dev/null
+++ b/examples/rl/grpo/gsm8k/run_qwen3_8b.sh
@@ -0,0 +1,171 @@
+#!/bin/bash
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Agentic GSM8K GRPO launcher for Qwen3 8B using
+# tunix/cli/base_agentic_config.yaml plus explicit CLI overrides.
+#
+# Usage:
+#   bash examples/rl/grpo/gsm8k/run_qwen3_8b.sh
+#
+# Run from the tunix repo root.
+#
+# The lowercase variables below take their defaults from this script but can
+# be overridden from the environment. Extra arguments are appended to the
+# command line after all overrides set here.
+
+set -euo pipefail
+
+export SKIP_JAX_PRECOMPILE=true
+
+model_name="${model_name:-Qwen3-8B}"
+model_id="${model_id:-Qwen/Qwen3-8B}"
+tokenizer_path="${tokenizer_path:-$model_id}"
+
+batch_size="${batch_size:-8}"
+num_batches="${num_batches:-934}"
+num_train_epochs="${num_train_epochs:-1}"
+train_fraction="${train_fraction:-1.0}"
+warmup_ratio="${warmup_ratio:-0.1}"
+
+mini_batch_size="${mini_batch_size:-8}"
+train_micro_batch_size="${train_micro_batch_size:-1}"
+rollout_micro_batch_size="${rollout_micro_batch_size:-8}"
+compute_logps_micro_batch_size="${compute_logps_micro_batch_size:-1}"
+
+num_generations="${num_generations:-4}"
+
+train_mesh="${train_mesh:-(8,1)}"
+rollout_mesh="${rollout_mesh:-(1,8)}"
+
+# Derived values, each rounded to an integer and clamped to at least 1.
+# Inputs are passed to awk with -v so they are treated as data rather than
+# spliced into the awk program text.
+max_steps=$(awk -v b="$num_batches" -v e="$num_train_epochs" -v f="$train_fraction" 'BEGIN {
+  value = b * e * f;
+  if (value < 1) value = 1;
+  printf "%.0f", value;
+}')
+warmup_steps=$(awk -v r="$warmup_ratio" -v s="$max_steps" 'BEGIN {
+  value = r * s;
+  if (value < 1) value = 1;
+  printf "%.0f", value;
+}')
+vllm_max_num_seqs=$(awk -v m="$rollout_micro_batch_size" -v g="$num_generations" 'BEGIN {
+  value = m * g;
+  if (value < 1) value = 1;
+  printf "%.0f", value;
+}')
+
+python -m tunix.cli.grpo_main \
+  tunix/cli/base_agentic_config.yaml \
+  \
+  `# -- Model ------------------------------------------------------------` \
+  model_config.model_name="$model_name" \
+  model_config.model_id="$model_id" \
+  model_config.model_source="huggingface" \
+  model_config.rng_seed=42 \
+  model_config.model_display=false \
+  model_config.remat_config=3 \
+  actor_model_config.mesh.shape="$train_mesh" \
+  actor_model_config.mesh.axis_names="('fsdp','tp')" \
+  reference_model_config.mesh=null \
+  reference_model_config.same_mesh_as="actor" \
+  rollout_model_config.mesh.shape="$rollout_mesh" \
+  rollout_model_config.mesh.axis_names="('fsdp','tp')" \
+  \
+  `# -- Data -------------------------------------------------------------` \
+  data_source="huggingface" \
+  dataset_name="openai/gsm8k:main" \
+  \
+  `# -- Training loop ----------------------------------------------------` \
+  training_mode="agentic_grpo" \
+  batch_size="$batch_size" \
+  num_batches="$num_batches" \
+  num_test_batches=100 \
+  num_train_epochs="$num_train_epochs" \
+  train_fraction="$train_fraction" \
+  reward_functions="['tunix/cli/reward_fn/gsm8k.py']" \
+  verl_compatible=false \
+  \
+  `# -- Rollout engine (vanilla | vllm | sglang_jax) ---------------------` \
+  rollout_engine="vllm" \
+  offload_to_cpu=false \
+  \
+  `# -- Rollout config ---------------------------------------------------` \
+  rollout_config.max_prompt_length=256 \
+  rollout_config.total_generation_steps=768 \
+  rollout_config.max_tokens_to_generate=768 \
+  rollout_config.temperature=0.9 \
+  rollout_config.top_p=1.0 \
+  rollout_config.top_k=50 \
+  rollout_config.return_logprobs=true \
+  \
+  `# -- vLLM (used when rollout_engine=vllm) -----------------------------` \
+  vllm_config.hbm_utilization=0.4 \
+  vllm_config.tpu_backend_type="jax" \
+  vllm_config.server_mode=true \
+  vllm_config.async_scheduling=true \
+  vllm_config.max_num_seqs="$vllm_max_num_seqs" \
+  vllm_config.kwargs.kv_cache_metrics=true \
+  vllm_config.kwargs.disable_log_stats=false \
+  vllm_config.kwargs.enable_prefix_caching=true \
+  \
+  `# -- Tokenizer / chat parsing ----------------------------------------` \
+  chat_parser_config.type="qwen" \
+  tokenizer_config.tokenizer_type="huggingface" \
+  tokenizer_config.tokenizer_path="$tokenizer_path" \
+  tokenizer_config.add_bos=false \
+  tokenizer_config.add_eos=false \
+  \
+  `# -- GRPO algorithm ---------------------------------------------------` \
+  agentic_grpo_config.num_generations="$num_generations" \
+  agentic_grpo_config.num_iterations=1 \
+  agentic_grpo_config.beta=0.08 \
+  agentic_grpo_config.epsilon=0.2 \
+  agentic_grpo_config.system_prompt="You are given a grade school math problem. Think step by step and respond using ... followed by ... with only the final numeric answer inside ." \
+  agentic_grpo_config.max_concurrency=128 \
+  agentic_grpo_config.max_response_length=768 \
+  agentic_grpo_config.max_turns=1 \
+  agentic_grpo_config.context_ratio=1 \
+  \
+  `# -- Optimizer --------------------------------------------------------` \
+  rl_training_config.actor_optimizer_config.opt_type="adamw" \
+  rl_training_config.actor_optimizer_config.learning_rate=3e-6 \
+  rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
+  rl_training_config.actor_optimizer_config.init_value=0.0 \
+  rl_training_config.actor_optimizer_config.peak_value=3e-6 \
+  rl_training_config.actor_optimizer_config.end_value=0.0 \
+  rl_training_config.actor_optimizer_config.warmup_ratio="$warmup_ratio" \
+  rl_training_config.actor_optimizer_config.warmup_steps="$warmup_steps" \
+  rl_training_config.actor_optimizer_config.decay_steps="$max_steps" \
+  rl_training_config.actor_optimizer_config.b1=0.9 \
+  rl_training_config.actor_optimizer_config.b2=0.99 \
+  rl_training_config.actor_optimizer_config.weight_decay=0.1 \
+  rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
+  \
+  `# -- RL training ------------------------------------------------------` \
+  rl_training_config.eval_every_n_steps=10 \
+  rl_training_config.max_steps="$max_steps" \
+  rl_training_config.mini_batch_size="$mini_batch_size" \
+  rl_training_config.train_micro_batch_size="$train_micro_batch_size" \
+  rl_training_config.rollout_micro_batch_size="$rollout_micro_batch_size" \
+  rl_training_config.compute_logps_micro_batch_size="$compute_logps_micro_batch_size" \
+  rl_training_config.checkpoint_root_directory="/tmp/tunix/checkpoints/gsm8k_qwen3_8b" \
+  rl_training_config.checkpointing_options.save_interval_steps=250 \
+  rl_training_config.checkpointing_options.max_to_keep=4 \
+  rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/gsm8k_qwen3_8b" \
+  rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
+  \
+  "$@"
diff --git
a/tests/cli/grpo_main_test.py b/tests/cli/grpo_main_test.py
index 80d828396..5d8a8b663 100644
--- a/tests/cli/grpo_main_test.py
+++ b/tests/cli/grpo_main_test.py
@@ -357,10 +357,9 @@ def test_standard_grpo_dispatches_to_standard(self):
     """
     pipeline = _make_pipeline(extra)
     self.assertEqual(pipeline.config.get("training_mode", "grpo"), "grpo")
-    # _run_standard_grpo should be called; we verify no AttributeError on dispatch
-    with mock.patch.object(pipeline, "_run_standard_grpo") as mock_std:
+    with mock.patch.object(pipeline, "_run") as mock_run:
       pipeline.run_grpo_trainer()
-      mock_std.assert_called_once()
+      mock_run.assert_called_once_with(mode="grpo")
 
   def test_agentic_grpo_dispatches_to_agentic(self):
     extra = """
@@ -398,9 +397,9 @@ def test_agentic_grpo_dispatches_to_agentic(self):
     """
     pipeline = _make_pipeline(extra)
     self.assertEqual(pipeline.config["training_mode"], "agentic_grpo")
-    with mock.patch.object(pipeline, "_run_agentic_grpo") as mock_ag:
+    with mock.patch.object(pipeline, "_run") as mock_run:
       pipeline.run_grpo_trainer()
-      mock_ag.assert_called_once()
+      mock_run.assert_called_once_with(mode="agentic_grpo")
 
   def test_unknown_mode_raises(self):
     # Build pipeline with standard config then manually set bad mode
@@ -418,8 +417,38 @@ def test_unknown_mode_raises(self):
     """
     pipeline = _make_pipeline(extra)
     pipeline.config["training_mode"] = "bad_mode"
-    with self.assertRaisesRegex(ValueError, "Unknown training_mode"):
-      pipeline.run_grpo_trainer()
+    raw_dataset = mock.Mock()
+    raw_dataset.__len__ = mock.Mock(return_value=1)
+    # Stub the setup steps so only the training_mode check can raise.
+    with (
+        mock.patch.object(pipeline, "_setup_kubernetes"),
+        mock.patch.object(
+            pipeline, "_get_tokenizer", return_value=mock.sentinel.tokenizer
+        ),
+        mock.patch.object(
+            pipeline,
+            "_create_chat_parser",
+            return_value=mock.sentinel.chat_parser,
+        ),
+        mock.patch.object(
+            pipeline, "_load_raw_dataset", return_value=(raw_dataset, None)
+        ),
+        mock.patch.object(pipeline, "compute_params"),
+        mock.patch.object(
+            grpo_main.data_lib,
+            "post_init_dataset",
+            return_value=(mock.sentinel.dataset, None),
+        ),
+        mock.patch.object(
+            pipeline,
+            "create_rl_cluster",
+            return_value=mock.sentinel.rl_cluster,
+        ),
+        self.assertRaisesRegex(
+            ValueError, "Unsupported training_mode 'bad_mode'"
+        ),
+    ):
+      pipeline.run_grpo_trainer()
 
 
 # ---------------------------------------------------------------------------
diff --git a/tests/examples/data/math_dataset_test.py b/tests/examples/data/math_dataset_test.py
index ed903baed..f9eb2e9a0 100644
--- a/tests/examples/data/math_dataset_test.py
+++ b/tests/examples/data/math_dataset_test.py
@@ -148,6 +148,14 @@ def test_parse_huggingface_dataset_name_supports_gsm8k_alias(self):
     self.assertEqual(dataset_name, "openai/gsm8k")
     self.assertEqual(config_name, "default")
 
+  def test_parse_huggingface_dataset_name_supports_explicit_config(self):
+    dataset_name, config_name = math_dataset._parse_huggingface_dataset_name(
+        "openai/gsm8k:main"
+    )
+
+    self.assertEqual(dataset_name, "openai/gsm8k")
+    self.assertEqual(config_name, "main")
+
   def test_create_dataset_uses_huggingface_loader(self):
     raw_dataset = _BaseDataset([
         {"question": "Q3", "answer": "#### 42"},
diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py
index 20e466f8e..0a9af1a28 100644
--- a/tunix/cli/grpo_main.py
+++ b/tunix/cli/grpo_main.py
@@ -34,6 +34,7 @@
 import dataclasses
 import importlib
 import os
+from types import ModuleType
 from typing import Any
 
 from absl import app
@@ -51,10 +52,8 @@
 from tunix.perf import metrics as perf_metrics
 from tunix.perf.experimental import export as perf_export_v2
 from tunix.rl import rl_cluster as rl_cluster_lib
-from tunix.rl.grpo import grpo_learner
 from tunix.rl.rollout import base_rollout
 
-GrpoConfig = grpo_learner.GrpoConfig
 
 _PATHWAYS_BNS = flags.DEFINE_string(
     "pathways_bns", None, "BNS address of the Pathways server."
@@ -86,6 +85,10 @@ class GrpoPipeline(config.HyperParameters):
   * ``kubernetes_config``: optional Kubernetes env-var and kube-config setup.
   """
 
+  def __init__(self, argv: list[str], **kwargs):
+    self.data_module: ModuleType | None = None
+    super().__init__(argv, **kwargs)
+
   # ------------------------------------------------------------------
   # Mesh
   # ------------------------------------------------------------------
@@ -573,16 +576,24 @@ def compute_params(self, dataset):
   # Standard GRPO training
   # ------------------------------------------------------------------
 
-  def _run_standard_grpo(self):
-    """Execute standard (non-agentic) GRPO training."""
-    tokenizer = model_lib.create_tokenizer(
+  def _get_tokenizer(self):
+    """Create the tokenizer configured by ``tokenizer_config``."""
+    return model_lib.create_tokenizer(
         self.config["tokenizer_config"],
         self.config["tokenizer_config"]["tokenizer_path"],
     )
 
+  def _get_data_module(self) -> ModuleType:
+    """Import ``config["data_module"]`` on first use and cache the module."""
+    if self.data_module is None:
+      self.data_module = importlib.import_module(self.config["data_module"])
+    return self.data_module
+
+  def _get_dataset(self, tokenizer):
+    """Load the dataset selected by ``data_module`` or ``data_source``."""
     if self.config.get("data_module", None):
       dataset = data_lib.get_dataset_from_module(
           self.config["data_module"],
           tokenizer,
       )
     elif self.config["data_source"] == "local":
@@ -608,23 +619,7 @@ def _run_standard_grpo(self):
     else:
       raise ValueError(f"Unsupported data_source {self.config['data_source']}")
 
-    self.compute_params(dataset)
-    dataset, _ = data_lib.post_init_dataset(
-        dataset,
-        tokenizer,
-        batch_size=self.config["batch_size"],
-        num_batches=self.config.get("num_batches", None),
-        max_prompt_length=self.config["rollout_config"].get(
-            "max_prompt_length", None
-        ),
-    )
-    rl_cluster = self.create_rl_cluster(tokenizer)
-    grpo_trainer = grpo_learner.GrpoLearner(
-        rl_cluster=rl_cluster,
-        reward_fns=self.obtain_reward_fn(),
-        algo_config=GrpoConfig(**self.config["grpo_config"]),
-    )
-    grpo_trainer.train(dataset)
+    return dataset
 
   # ------------------------------------------------------------------
   # Agentic GRPO helpers
@@ -671,16 +666,23 @@ def _load_class_from_path(self, dotted_path: str) -> type:
     module_path, class_name = dotted_path.rsplit(".", 1)
     return getattr(importlib.import_module(module_path), class_name)
 
-  def _load_raw_dataset(self):
-    """Load a raw grain.MapDataset from data_module.
+  def _load_raw_dataset(self, tokenizer):
+    """Load the raw dataset and the optional custom batch function.
 
-    The module must expose ``create_dataset(**data_config) -> grain.MapDataset``
-    and optionally a ``batch_fn`` used as ``custom_batch_fn``.
+    The dataset comes from ``_get_dataset``. If ``data_module`` is configured,
+    that module may also expose a ``batch_fn`` used as ``custom_batch_fn``.
+
+    Args:
+      tokenizer: Tokenizer forwarded to ``_get_dataset``.
+
+    Returns:
+      A ``(dataset, batch_fn)`` tuple; ``batch_fn`` is ``None`` when no
+      ``data_module`` is configured or the module defines no ``batch_fn``.
     """
-    module = importlib.import_module(self.config["data_module"])
-    data_config = dict(self.config.get("data_config", {}))
-    dataset = module.create_dataset(**data_config)
-    batch_fn = getattr(module, "batch_fn", None)
+    dataset = self._get_dataset(tokenizer)
+    batch_fn = None
+    if self.config.get("data_module", None):
+      batch_fn = getattr(self._get_data_module(), "batch_fn", None)
     return dataset, batch_fn
 
   def _setup_kubernetes(self) -> None:
@@ -707,19 +709,26 @@ def _setup_kubernetes(self) -> None:
   # Agentic GRPO training
   # ------------------------------------------------------------------
 
-  def _run_agentic_grpo(self):
-    """Execute agentic GRPO training (DeepScaleR, DeepSWE, etc.)."""
-    from tunix.rl.agentic.agentic_grpo_learner import GRPOLearner  # pylint: disable=g-import-not-at-top
-
+  def _run(self, mode: str = "grpo"):
+    """Execute standard or agentic GRPO training.
+
+    Args:
+      mode: Training mode, either ``"grpo"`` or ``"agentic_grpo"``.
+
+    Raises:
+      ValueError: If ``mode`` is neither of the supported values.
+    """
+    # Fail fast: reject a bad mode before tokenizer, dataset and cluster setup.
+    if mode not in ("grpo", "agentic_grpo"):
+      raise ValueError(f"Unsupported training_mode {mode!r}")
+
     self._setup_kubernetes()
 
-    tokenizer = model_lib.create_tokenizer(
-        self.config["tokenizer_config"],
-        self.config["tokenizer_config"]["tokenizer_path"],
-    )
+    tokenizer = self._get_tokenizer()
+
     chat_parser = self._create_chat_parser(tokenizer)
 
-    raw_dataset, custom_batch_fn = self._load_raw_dataset()
+    raw_dataset, custom_batch_fn = self._load_raw_dataset(tokenizer)
     self.compute_params(raw_dataset)
 
     dataset, _ = data_lib.post_init_dataset(
@@ -737,6 +746,20 @@ def _run_agentic_grpo(self):
     )
 
     rl_cluster = self.create_rl_cluster(tokenizer)
+
+    if mode == "grpo":
+      from tunix.rl.grpo import grpo_learner  # pylint: disable=g-import-not-at-top
+
+      grpo_trainer = grpo_learner.GrpoLearner(
+          rl_cluster=rl_cluster,
+          reward_fns=self.obtain_reward_fn(),
+          algo_config=grpo_learner.GrpoConfig(**self.config["grpo_config"]),
+      )
+      grpo_trainer.train(dataset)
+      return
+
+    # Agentic GRPO: ``mode`` was validated above, so it is "agentic_grpo" here.
+    from tunix.rl.agentic.agentic_grpo_learner import GRPOLearner  # pylint: disable=g-import-not-at-top
     algo_config = self._create_agentic_grpo_config()
 
     reward_fns = (
@@ -774,14 +797,7 @@ def _run_agentic_grpo(self):
   def run_grpo_trainer(self):
     """Dispatch to standard or agentic GRPO based on training_mode."""
     mode = self.config.get("training_mode", "grpo")
-    if mode == "agentic_grpo":
-      self._run_agentic_grpo()
-    elif mode == "grpo":
-      self._run_standard_grpo()
-    else:
-      raise ValueError(
-          f"Unknown training_mode: {mode!r}. Expected 'grpo' or 'agentic_grpo'."
-      )
+    self._run(mode=mode)
 
 
 def _setup_jax_pathways(pathways_bns: str):
diff --git a/tunix/examples/data/math_dataset.py b/tunix/examples/data/math_dataset.py
index f5e2abd71..ea8d76bc9 100644
--- a/tunix/examples/data/math_dataset.py
+++ b/tunix/examples/data/math_dataset.py
@@ -124,11 +124,11 @@ def _parse_huggingface_dataset_name(
     dataset_name: str,
 ) -> tuple[str, str | None]:
   """Parses a Hugging Face dataset name into dataset/config components."""
-  if "/" in dataset_name:
-    return dataset_name, "default"
-
   if ":" in dataset_name:
     name, config_name = dataset_name.split(":", maxsplit=1)
     return name, config_name or None
 
+  if "/" in dataset_name:
+    return dataset_name, "default"
+
   return dataset_name, None