diff --git a/tests/rl/rl_cluster_test.py b/tests/rl/rl_cluster_test.py index 0e3096372..1d5671596 100644 --- a/tests/rl/rl_cluster_test.py +++ b/tests/rl/rl_cluster_test.py @@ -32,6 +32,7 @@ from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl import utils from tunix.rl.rollout import base_rollout +from tunix.rl.rollout import mock_rollout from tunix.tests import test_common as tc # Some tests relying on SGLang and vLLM cannot run in run_prod environment. @@ -468,6 +469,48 @@ def __init__(self): cluster_config=cluster_config, ) + @parameterized.named_parameters( + dict( + testcase_name='single_config', + rollout_config=base_rollout.RolloutConfig( + max_tokens_to_generate=10, kv_cache_size=1024 + ), + expected_train_config=base_rollout.RolloutConfig( + max_tokens_to_generate=10, kv_cache_size=1024 + ), + ), + dict( + testcase_name='dict_config', + rollout_config={ + rl_cluster_lib.Mode.TRAIN: base_rollout.RolloutConfig( + max_tokens_to_generate=10, kv_cache_size=1024 + ), + rl_cluster_lib.Mode.EVAL: base_rollout.RolloutConfig( + max_tokens_to_generate=20, kv_cache_size=2048 + ), + }, + expected_train_config=base_rollout.RolloutConfig( + max_tokens_to_generate=10, kv_cache_size=1024 + ), + ), + ) + def test_init_mock_rollout_engine( + self, + rollout_config, + expected_train_config, + ): + with mock.patch.object( + mock_rollout.MockRollout, '__init__', autospec=True, return_value=None + ) as mock_init: + rl_cluster = self._create_test_rl_cluster( + mock_rollout.MockRollout, rollout_config + ) + + mock_init.assert_called_once() + self.assertIsInstance(rl_cluster.rollout, mock_rollout.MockRollout) + called_kwargs = mock_init.call_args.kwargs + self.assertEqual(called_kwargs['rollout_config'], expected_train_config) + @parameterized.named_parameters( dict( testcase_name='single_config', diff --git a/tunix/rl/rl_cluster.py b/tunix/rl/rl_cluster.py index 73d49b376..1f9619fec 100644 --- a/tunix/rl/rl_cluster.py +++ b/tunix/rl/rl_cluster.py @@ -38,9 +38,9 @@ 
import jaxtyping import numpy as np import optax +from tunix.generate import tokenizer_adapter # Internal placeholder for sglang_jax rollout worker stub, don't change this line. # Internal placeholder for vllm rollout worker stub, don't change this line. -from tunix.generate import tokenizer_adapter from tunix.perf import metrics as perf_metrics from tunix.perf import trace as perf_trace from tunix.perf.experimental import constants as perf_constants @@ -369,8 +369,8 @@ def _init_cluster(self): "sglang_jax", ]: raise ValueError( - "`cluster_config.rollout_engine` should be one of `'vanilla'` or" - " `'vllm'` or `'sglang_jax'`. Received:" + "`cluster_config.rollout_engine` should be one of `'vanilla'`, " + "`'vllm'`, or `'sglang_jax'`. Received:" f" '{self.cluster_config.rollout_engine}'." ) @@ -467,11 +467,16 @@ def _init_cluster(self): base_rollout.BaseRollout, ) ): + if isinstance(self.cluster_config.rollout_config, dict): + loaded_config = self.cluster_config.rollout_config[Mode.TRAIN] + else: + loaded_config = self.cluster_config.rollout_config + self._rollout = self.cluster_config.rollout_engine( rollout_actor=self.rollout_actor, tokenizer=self.tokenizer, mesh=self.r2m[Role.ROLLOUT], - rollout_config=self.cluster_config.rollout_config, + rollout_config=loaded_config, ) else: raise NotImplementedError( diff --git a/tunix/rl/rollout/mock_rollout.py b/tunix/rl/rollout/mock_rollout.py new file mode 100644 index 000000000..5826897be --- /dev/null +++ b/tunix/rl/rollout/mock_rollout.py @@ -0,0 +1,260 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
_DUMMY_WORDS = (
    "mock",
    "test",
    "token",
    "rollout",
    "random",
    "data",
    "output",
    "engine",
)


class MockRollout(base_rollout.BaseRollout):
  """Mock rollout worker for testing RL pipelines.

  Simulates a real LLM rollout worker (vLLM, SG-Lang, ...) without needing a
  model, weights, or accelerator resources, which makes it suitable for
  exercising the training/communication side of an RL pipeline at scale —
  e.g. benchmarking optimizations or collecting mock perf traces.

  Mocked behaviors:
    * Text generation: sequences of random dummy words.
    * Tokenization: uses the provided tokenizer to encode/decode the dummy
      text, or falls back to random token ids when no tokenizer is given.
    * Latency: sleeps for a random duration between `min_generation_time`
      and `max_generation_time`.
    * Tensors (logits/logprobs/logps): zero-filled numpy arrays with the
      right shapes, kept in host memory.
    * Parameter updates: `update_params` is a no-op.
    * Reproducibility: seeding via `RolloutConfig.seed` yields deterministic
      outputs.

  Kwargs:
    min_generation_time (float): Minimum simulated generation delay in
      seconds (default: 1.0).
    max_generation_time (float): Maximum simulated generation delay in
      seconds (default: 10.0).
  """

  def __init__(
      self,
      rollout_actor: Any | None = None,
      tokenizer: Any | None = None,
      vocab_size: int | None = None,
      pad_id: int | None = None,
      eos_id: int | None = None,
      rollout_config: base_rollout.RolloutConfig | None = None,
      **kwargs,
  ):
    self._model = rollout_actor
    self._tokenizer = tokenizer
    # Generic fallbacks used when the caller supplies no tokenizer metadata.
    self._vocab_size = 32_000 if vocab_size is None else vocab_size
    self._pad_id = 0 if pad_id is None else pad_id
    self._eos_id = 1 if eos_id is None else eos_id
    self._min_generation_time = kwargs.get("min_generation_time", 1.0)
    self._max_generation_time = kwargs.get("max_generation_time", 10.0)

    seed = None
    if rollout_config is not None and rollout_config.seed is not None:
      raw_seed = rollout_config.seed
      # The seed may arrive as a jax scalar array; unwrap it to a python int.
      seed = int(raw_seed.item() if isinstance(raw_seed, jax.Array) else raw_seed)

    # `random.Random(None)` and `np.random.default_rng(None)` both seed from
    # OS entropy, exactly like the no-argument constructors, so one code path
    # covers the seeded and unseeded cases.
    self._rng = random.Random(seed)
    self._np_rng = np.random.default_rng(seed)

  def _encode_text(self, text: str) -> np.ndarray | None:
    """Encodes `text` with the tokenizer; returns None when that fails."""
    if self._tokenizer is None or not hasattr(self._tokenizer, "encode"):
      return None
    try:
      return np.array(self._tokenizer.encode(text), dtype=np.int32)
    except Exception as e:  # pylint: disable=broad-except
      # Rate-limit the warning: generation may call this per prompt.
      logging.log_every_n(
          logging.WARNING, "Tokenization failed in mock_rollout: %s", 100, e
      )
      return None

  def generate(
      self,
      prompts: str | Sequence[str],
      rollout_config: base_rollout.RolloutConfig,
      **kwargs,
  ) -> base_rollout.RolloutOutput:
    """Generates random samples and simulates a time delay.

    Args:
      prompts: A list of text prompts for generation.
      rollout_config: Configuration settings for generation and mock behavior.
      **kwargs: Additional generation arguments (ignored by the mock).

    Returns:
      A RolloutOutput containing the mock generated texts, tokens, and
      zero-filled tensors.
    """
    if isinstance(prompts, str):
      prompts = [prompts]

    # Simulated inference latency.
    time.sleep(
        self._rng.uniform(self._min_generation_time, self._max_generation_time)
    )

    # Always produce at least one token, even for degenerate configs.
    token_budget = max(1, rollout_config.max_tokens_to_generate)
    prompt_width = rollout_config.max_prompt_length

    generated_texts: list[str] = []
    per_sample_logits: list[np.ndarray] = []
    per_sample_tokens: list[np.ndarray] = []

    padded_prompts = np.full(
        (len(prompts), prompt_width), self.pad_id(), dtype=np.int32
    )

    for row, prompt in enumerate(prompts):
      word_count = self._rng.randint(1, token_budget)
      completion = " ".join(self._rng.choices(_DUMMY_WORDS, k=word_count))

      # Left-pad the (possibly truncated) prompt tokens into the batch array,
      # keeping the suffix when the prompt is too long.
      encoded_prompt = self._encode_text(prompt)
      if encoded_prompt is not None:
        if len(encoded_prompt) > prompt_width:
          encoded_prompt = encoded_prompt[-prompt_width:]
        padded_prompts[row, prompt_width - len(encoded_prompt) :] = (
            encoded_prompt
        )

      # Tokenize the generated completion, falling back to random ids when
      # no tokenizer is available or encoding failed.
      encoded = self._encode_text(completion)
      if encoded is not None:
        if len(encoded) > token_budget:
          encoded = encoded[:token_budget]
        elif len(encoded) == 0:
          encoded = self._np_rng.integers(
              0, self._vocab_size, size=(1,), dtype=np.int32
          )
        sample_length = len(encoded)
        # Round-trip through the tokenizer when possible so text and tokens
        # stay consistent with each other.
        if hasattr(self._tokenizer, "decode"):
          completion = self._tokenizer.decode(encoded.tolist())
      else:
        sample_length = word_count
        encoded = self._np_rng.integers(
            0, self._vocab_size, size=(sample_length,), dtype=np.int32
        )

      per_sample_tokens.append(encoded)
      generated_texts.append(completion)
      # float16 zeros keep the simulated logits cheap in host memory.
      per_sample_logits.append(
          np.zeros((sample_length, self._vocab_size), dtype=np.float16)
      )

    logprobs = (
        [np.zeros(len(t), dtype=np.float32) for t in per_sample_tokens]
        if rollout_config.return_logprobs
        else None
    )

    return base_rollout.RolloutOutput(
        text=generated_texts,
        logits=per_sample_logits,
        tokens=per_sample_tokens,
        left_padded_prompt_tokens=padded_prompts,
        logprobs=logprobs,
    )

  def get_per_token_logps(
      self,
      prompt_tokens: jax.Array,
      completion_tokens: jax.Array,
      completion_mask: jax.Array | None = None,
  ) -> jax.Array:
    """Returns mock per-token log probabilities (all zeros).

    Args:
      prompt_tokens: The tokens of the input prompts (unused by the mock).
      completion_tokens: The generated completion tokens; only the shape is
        consulted.
      completion_mask: An optional completion-token mask (unused by the mock).

    Returns:
      A zero-filled (batch_size, length) array of mock log probabilities.
    """
    batch_size, length = completion_tokens.shape
    # NOTE(review): intentionally a numpy array (host memory) despite the
    # jax.Array annotation inherited from the interface.
    return np.zeros((batch_size, length), dtype=np.float32)

  def update_params(
      self,
      params: jaxtyping.PyTree,
      filter_types: tuple[Any, ...] | None = None,
  ) -> None:
    """Mock update params — intentionally a no-op.

    Args:
      params: A PyTree of parameters to update.
      filter_types: Optional types to filter which parameters to update.
    """

  def pad_id(self) -> int:
    """Pad token id from the tokenizer when available, else the default."""
    tok = self._tokenizer
    if tok is None or not hasattr(tok, "pad_id"):
      return self._pad_id
    attr = tok.pad_id
    # Tokenizers expose pad_id either as a method or a plain attribute.
    return attr() if callable(attr) else attr

  def eos_id(self) -> int:
    """EOS token id from the tokenizer when available, else the default."""
    tok = self._tokenizer
    if tok is None or not hasattr(tok, "eos_id"):
      return self._eos_id
    attr = tok.eos_id
    return attr() if callable(attr) else attr

  def model(self) -> Any:
    """Returns the rollout actor passed at construction time."""
    return self._model
+ """ + + def __init__(self, vocab_size=100, pad_id=0, eos_id=1, fail_encode=False): + self.vocab_size = vocab_size + self._pad_id = pad_id + self._eos_id = eos_id + self._fail_encode = fail_encode + + def encode(self, text): + if self._fail_encode: + raise ValueError("Encode failed") + # Return dummy token IDs based on string length to simulate tokenization + return [min(i, self.vocab_size - 1) for i in range(len(text))] + + def decode(self, tokens): + return "decoded_" + "_".join(str(t) for t in tokens) + + def pad_id(self): + return self._pad_id + + def eos_id(self): + return self._eos_id + + +class FakeTokenizerProperties: + # Some tokenizers have these as properties rather than functions + def __init__(self, pad_id=0, eos_id=1): + self.pad_id = pad_id + self.eos_id = eos_id + + +class MockRolloutTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.base_rc = base_rollout.RolloutConfig( + max_prompt_length=10, + max_tokens_to_generate=15, + return_logprobs=True, + seed=42, + ) + + def _create_mock_rollout(self, **kwargs): + kwargs.setdefault("vocab_size", 100) + kwargs.setdefault("pad_id", 0) + kwargs.setdefault("eos_id", 1) + kwargs.setdefault("min_generation_time", 0.01) + kwargs.setdefault("max_generation_time", 0.02) + return mock_rollout.MockRollout(**kwargs) + + @mock.patch.object(mock_rollout.time, "sleep", autospec=True) + def test_generate_basic(self, mock_sleep): + m = self._create_mock_rollout() + out = m.generate(["prompt 1", "prompt 2"], rollout_config=self.base_rc) + + self.assertLen(out.text, 2) + self.assertLen(out.logits, 2) + self.assertLen(out.tokens, 2) + self.assertLen(out.logprobs, 2) + self.assertEqual(out.left_padded_prompt_tokens.shape, (2, 10)) + self.assertTrue(mock_sleep.called) + + @mock.patch.object(mock_rollout.time, "sleep", autospec=True) + def test_generate_single_prompt(self, mock_sleep): + m = self._create_mock_rollout() + out = m.generate("single prompt", rollout_config=self.base_rc) + + 
self.assertLen(out.text, 1) + self.assertLen(out.logits, 1) + self.assertLen(out.tokens, 1) + self.assertLen(out.logprobs, 1) + self.assertEqual(out.left_padded_prompt_tokens.shape, (1, 10)) + + @mock.patch.object(mock_rollout.time, "sleep", autospec=True) + def test_generate_no_logprobs(self, mock_sleep): + m = self._create_mock_rollout() + rc = dataclasses.replace(self.base_rc, return_logprobs=False) + out = m.generate(["prompt 1"], rollout_config=rc) + + self.assertIsNone(out.logprobs) + + @mock.patch.object(mock_rollout.time, "sleep", autospec=True) + def test_generate_with_tokenizer(self, mock_sleep): + tokenizer = FakeTokenizer(vocab_size=100) + m = self._create_mock_rollout(tokenizer=tokenizer) + out = m.generate(["prompt"], rollout_config=self.base_rc) + + self.assertLen(out.text, 1) + self.assertNotEmpty(out.tokens[0]) + self.assertTrue(out.text[0].startswith("decoded_")) + + @mock.patch.object(mock_rollout.time, "sleep", autospec=True) + def test_generate_tokenizer_encode_exception(self, mock_sleep): + tokenizer = FakeTokenizer(vocab_size=100, fail_encode=True) + m = self._create_mock_rollout(tokenizer=tokenizer) + out = m.generate(["prompt"], rollout_config=self.base_rc) + + self.assertLen(out.tokens[0], len(out.text[0].split())) + + @mock.patch.object(mock_rollout.time, "sleep", autospec=True) + def test_generate_reproducibility_with_seed(self, mock_sleep): + rollout_config = dataclasses.replace(self.base_rc, seed=42) + m1 = self._create_mock_rollout(rollout_config=rollout_config) + m2 = self._create_mock_rollout(rollout_config=rollout_config) + + out1 = m1.generate(["prompt 1", "prompt 2"], rollout_config=rollout_config) + out2 = m2.generate(["prompt 1", "prompt 2"], rollout_config=rollout_config) + + self.assertEqual(out1.text, out2.text) + np.testing.assert_array_equal(out1.tokens[0], out2.tokens[0]) + np.testing.assert_array_equal(out1.tokens[1], out2.tokens[1]) + np.testing.assert_array_equal(out1.logits[0], out2.logits[0]) + 
np.testing.assert_array_equal(out1.logits[1], out2.logits[1]) + + @mock.patch.object(mock_rollout.time, "sleep", autospec=True) + def test_generate_reproducibility_with_jax_seed(self, mock_sleep): + rollout_config = dataclasses.replace(self.base_rc, seed=jnp.array(42)) + m1 = self._create_mock_rollout(rollout_config=rollout_config) + m2 = self._create_mock_rollout(rollout_config=rollout_config) + + out1 = m1.generate(["prompt"], rollout_config=rollout_config) + out2 = m2.generate(["prompt"], rollout_config=rollout_config) + + self.assertEqual(out1.text, out2.text) + np.testing.assert_array_equal(out1.tokens[0], out2.tokens[0]) + np.testing.assert_array_equal(out1.logits[0], out2.logits[0]) + + @mock.patch.object(mock_rollout.time, "sleep", autospec=True) + def test_sleep_time_bounds(self, mock_sleep): + m = self._create_mock_rollout() + m.generate(["prompt"], rollout_config=self.base_rc) + + mock_sleep.assert_called_once() + sleep_time = mock_sleep.call_args[0][0] + self.assertBetween(sleep_time, 0.01, 0.02) + + def test_get_per_token_logps(self): + m = self._create_mock_rollout() + prompt_tokens = jnp.zeros((2, 5)) + completion_tokens = jnp.ones((2, 10)) + + logps = m.get_per_token_logps(prompt_tokens, completion_tokens) + + self.assertEqual(logps.shape, (2, 10)) + np.testing.assert_array_equal(logps, np.zeros((2, 10), dtype=np.float32)) + + def test_update_params(self): + m = self._create_mock_rollout() + # update_params is a no-op, just ensure it doesn't raise + m.update_params({"dummy": "tree"}) + + @parameterized.named_parameters( + ("without_tokenizer", None, 5, 10), + ( + "with_callable_tokenizer_methods", + FakeTokenizer(pad_id=7, eos_id=14), + 7, + 14, + ), + ( + "with_property_tokenizer_attributes", + FakeTokenizerProperties(pad_id=8, eos_id=16), + 8, + 16, + ), + ) + def test_pad_id_eos_id(self, tokenizer, expected_pad_id, expected_eos_id): + if tokenizer is None: + m = self._create_mock_rollout( + pad_id=expected_pad_id, eos_id=expected_eos_id + ) + 
else: + m = self._create_mock_rollout(tokenizer=tokenizer) + self.assertEqual(m.pad_id(), expected_pad_id) + self.assertEqual(m.eos_id(), expected_eos_id) + + def test_model_property(self): + dummy_model = {"weights": [1, 2, 3]} + m = self._create_mock_rollout(rollout_actor=dummy_model) + self.assertEqual(m.model(), dummy_model) + + +if __name__ == "__main__": + absltest.main()