From 066452458d78885edf18db9e05ef51e1255a4113 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Thu, 2 Oct 2025 16:56:42 -0700 Subject: [PATCH 1/4] Initial Commit renaming Policy to Generator --- apps/grpo/main.py | 4 +- apps/toy_rl/sumdigits.py | 4 +- apps/vllm/main.py | 4 +- src/forge/actors/__init__.py | 8 +- src/forge/actors/{policy.py => generator.py} | 130 +++++++++--------- src/forge/observability/metric_actors.py | 2 +- tests/integration_tests/test_policy_update.py | 8 +- ...icy_config.py => test_generator_config.py} | 94 ++++++------- 8 files changed, 129 insertions(+), 125 deletions(-) rename src/forge/actors/{policy.py => generator.py} (87%) rename tests/unit_tests/{test_policy_config.py => test_generator_config.py} (54%) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 7545aa561..5f3ac06f9 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -21,7 +21,7 @@ get_dcp_whole_state_dict_key, get_param_prefix, ) -from forge.actors.policy import Policy +from forge.actors.generator import Generator from forge.actors.reference_model import ReferenceModel from forge.actors.replay_buffer import ReplayBuffer from forge.actors.trainer import RLTrainer @@ -329,7 +329,7 @@ async def main(cfg: DictConfig): reward_actor, ) = await asyncio.gather( DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset), - Policy.options(**cfg.services.policy).as_service(**cfg.policy), + Generator.options(**cfg.services.policy).as_service(**cfg.policy), RLTrainer.options(**cfg.actors.trainer).as_actor( **cfg.trainer, loss=simple_grpo_loss ), diff --git a/apps/toy_rl/sumdigits.py b/apps/toy_rl/sumdigits.py index 57971e9b9..13440511c 100644 --- a/apps/toy_rl/sumdigits.py +++ b/apps/toy_rl/sumdigits.py @@ -17,7 +17,7 @@ import torch.nn.functional as F import torchstore as ts from forge.actors._torchstore_utils import get_param_key -from forge.actors.policy import Policy +from forge.actors.generator import Generator from forge.actors.replay_buffer import ReplayBuffer from 
forge.actors.trainer import _qwen3_hf_to_vllm from forge.cli.config import parse @@ -482,7 +482,7 @@ async def main(cfg: DictConfig): ref_model, ) = await asyncio.gather( DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset), - Policy.options(**cfg.services.policy).as_service(**cfg.policy), + Generator.options(**cfg.services.policy).as_service(**cfg.policy), Trainer.options(**cfg.actors.trainer).as_actor(**cfg.trainer), ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(**cfg.replay_buffer), RewardActor.options(**cfg.services.reward_actor).as_service(), diff --git a/apps/vllm/main.py b/apps/vllm/main.py index 3167817c7..a6a6647f6 100644 --- a/apps/vllm/main.py +++ b/apps/vllm/main.py @@ -13,7 +13,7 @@ import os -from forge.actors.policy import Policy +from forge.actors.generator import Generator from forge.cli.config import parse from forge.controller.provisioner import shutdown @@ -35,7 +35,7 @@ async def run(cfg: DictConfig): prompt = "What is 3+5?" if gd else "Tell me a joke" print("Spawning service...") - policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy) + policy = await Generator.options(**cfg.services.policy).as_service(**cfg.policy) import time diff --git a/src/forge/actors/__init__.py b/src/forge/actors/__init__.py index 54e450cd7..02536ae3f 100644 --- a/src/forge/actors/__init__.py +++ b/src/forge/actors/__init__.py @@ -4,14 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer", "TitanRefModel"] +__all__ = ["Generator", "PolicyRouter", "RLTrainer", "ReplayBuffer", "TitanRefModel"] def __getattr__(name): - if name == "Policy": - from .policy import Policy + if name == "Generator": + from .generator import Generator - return Policy + return Generator elif name == "PolicyRouter": from .policy import PolicyRouter diff --git a/src/forge/actors/policy.py b/src/forge/actors/generator.py similarity index 87% rename from src/forge/actors/policy.py rename to src/forge/actors/generator.py index 464674f2c..54f5711b8 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/generator.py @@ -55,7 +55,7 @@ from forge.data.sharding import VLLMSharding from forge.data_models.completion import Completion from forge.data_models.prompt import to_prompt -from forge.interfaces import Policy as PolicyInterface +from forge.interfaces import Policy as GeneratorInterface from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer from forge.types import ProcessConfig @@ -132,7 +132,7 @@ def create_vllm_config(self) -> VllmConfig: @dataclass -class Policy(PolicyInterface): +class Generator(GeneratorInterface): engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig) sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig) use_vllm_builtin_load: bool = True @@ -142,13 +142,13 @@ class Policy(PolicyInterface): sampling_params: SamplingParams | None = None lora_request: LoRARequest | None = None tokenization_kwargs: dict = field(default_factory=dict) - policy_worker: "PolicyWorker" = None - policy_version: int | None = None + generator_worker: "GeneratorWorker" = None + generator_version: int | None = None def __post_init__(self): super().__init__() self._run_task: asyncio.Task | None = None - self._policy_proc: ProcMesh | None = None + self._generator_proc: ProcMesh | None = None self._worker_procs:
ProcMesh | None = None self.running = False if isinstance(self.engine_config, Mapping): @@ -159,14 +159,14 @@ def __post_init__(self): @classmethod async def launch( # pyright: ignore[reportIncompatibleMethodOverride] - cls: type["Policy"], + cls: type["Generator"], *, engine_config: EngineConfig | Mapping = EngineConfig(), sampling_config: SamplingConfig | Mapping = SamplingConfig(), available_devices: str | None = None, use_dcp: bool = True, **kwargs, - ) -> "Policy": + ) -> "Generator": # Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES # automatically. process_config: ProcessConfig = ProcessConfig( @@ -182,21 +182,21 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] # level leads to issues. # Once we can create multiple proc meshes on a host mesh, we can ensure # host colocation - policy_proc_config = copy(process_config) - policy_proc_config.procs = 1 - policy_proc_config.hosts = None - policy_proc_config.with_gpus = False + generator_proc_config = copy(process_config) + generator_proc_config.procs = 1 + generator_proc_config.hosts = None + generator_proc_config.with_gpus = False - policy_proc = await get_proc_mesh(process_config=policy_proc_config) + generator_proc = await get_proc_mesh(process_config=generator_proc_config) if isinstance(engine_config, Mapping): engine_config = EngineConfig.from_dict(engine_config) vllm_config = engine_config.create_vllm_config() # TODO (felipemello): LocalFetcherActor doesnt spawn with this, so cannot - # do logging within PolicyWorker + # do logging within GeneratorWorker workers = worker_procs.spawn( - "vllm_worker", PolicyWorker, vllm_config=vllm_config, use_dcp=use_dcp + "vllm_worker", GeneratorWorker, vllm_config=vllm_config, use_dcp=use_dcp ) if isinstance(sampling_config, Mapping): @@ -204,49 +204,49 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] # TODO - expand support so name can stick within kwargs actor_name = kwargs.pop("name", 
cls.__name__) - policy = policy_proc.spawn( + generator = generator_proc.spawn( actor_name, cls, engine_config=engine_config, sampling_config=sampling_config, available_devices=available_devices, - policy_worker=workers, + generator_worker=workers, **kwargs, ) - policy._policy_proc = policy_proc - policy._worker_procs = worker_procs - await policy.setup.call() - return policy + generator._generator_proc = generator_proc + generator._worker_procs = worker_procs + await generator.setup.call() + return generator @classmethod async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] - cls: type["Policy"], actor: "Policy" + cls: type["Generator"], actor: "Generator" ): assert ( - actor._policy_proc is not None - ), "Tried to shutdown a policy that was not initialized correctly" + actor._generator_proc is not None + ), "Tried to shutdown a generator that was not initialized correctly" assert ( actor._worker_procs is not None - ), "Tried to shutdown a policy that was not initialized correctly" + ), "Tried to shutdown a generator that was not initialized correctly" # TODO - may want to expand stop to gracefully respond to # ongoing requests. await actor.stop.call() await stop_proc_mesh(actor._worker_procs) - await stop_proc_mesh(actor._policy_proc) + await stop_proc_mesh(actor._generator_proc) @endpoint async def setup(self): - # Set up policy_worker - assert self.policy_worker is not None, "Policy worker should not be None" - await self.policy_worker.setup.call() + # Set up generator_worker + assert self.generator_worker is not None, "Generator worker should not be None" + await self.generator_worker.setup.call() self.request_id = 0 - self.policy_version = 0 + self.generator_version = 0 self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {} - # TODO: Investigate whether this can be combined with `policy.running` - # Whether this policy is accepting requests.
+ # TODO: Investigate whether this can be combined with `generator.running` + # Whether this generator is accepting requests. self.accepting_requests = True # Guard for accepting_requests self.request_lock = asyncio.Condition() @@ -275,7 +275,7 @@ async def setup(self): # Setup scheduler # TODO: Add support for `log_stats` - kv_cache_configs = await self.policy_worker.setup_kv_cache.call() + kv_cache_configs = await self.generator_worker.setup_kv_cache.call() _, kv_cache_config = next(kv_cache_configs.items()) self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks self.vllm_config.cache_config.num_cpu_blocks = 0 @@ -306,10 +306,10 @@ async def generate(self, prompt: str, priority: int = 0) -> list[Completion]: Returns: RequestOutput: vLLM class with the generated response. """ - t = Tracer("policy_perf/generate", timer="gpu") + t = Tracer("generator_perf/generate", timer="gpu") t.start() - record_metric("policy/generate/count_requests", 1, Reduce.SUM) + record_metric("generator/generate/count_requests", 1, Reduce.SUM) self.request_id += 1 % sys.maxsize request_id = str(self.request_id) # implement from a counter @@ -380,7 +380,7 @@ async def generate(self, prompt: str, priority: int = 0) -> list[Completion]: t.step("generate") record_metric( - "policy/generate/count_sequences_completed", + "generator/generate/count_sequences_completed", len(completions), Reduce.SUM, ) @@ -388,13 +388,13 @@ async def generate(self, prompt: str, priority: int = 0) -> list[Completion]: for completion in completions: num_generated_tokens = len(completion.token_ids) record_metric( - "policy/generate/sum_tokens_generated", + "generator/generate/sum_tokens_generated", num_generated_tokens, Reduce.SUM, ) record_metric( - "policy/generate/avg_tokens_generated", + "generator/generate/avg_tokens_generated", num_generated_tokens, Reduce.MEAN, ) @@ -422,7 +422,7 @@ async def run(self): scheduler_output = self.scheduler.schedule() - worker_outputs = await 
self.policy_worker.execute_model.call( + worker_outputs = await self.generator_worker.execute_model.call( scheduler_output ) @@ -450,7 +450,7 @@ async def run(self): self.request_lock.notify_all() @endpoint - async def update_weights(self, policy_version: int): + async def update_weights(self, version: int): # Serialize updates (only one update at a time) async with self.update_lock: # Grab the lock to stop accepting requests and wait on pending requests @@ -461,12 +461,12 @@ async def update_weights(self, policy_version: int): if curr_requests: # Record pending requests metrics record_metric( - "policy_perf/update_weights/avg_pending_requests", + "generator_perf/update_weights/avg_pending_requests", len(curr_requests), Reduce.MEAN, ) record_metric( - "policy_perf/update_weights/max_pending_requests", + "generator_perf/update_weights/max_pending_requests", len(curr_requests), Reduce.MAX, ) @@ -474,18 +474,20 @@ async def update_weights(self, policy_version: int): # Wait until all pending requests have been processed # TODO: If generating long sequences, this might be long and will block - # policy weight updates + # generator weight updates await self.request_lock.wait_for(lambda: len(self.requests) == 0) # Record weight update metrics - record_metric("policy/update_weights/count_weight_updates", 1, Reduce.SUM) + record_metric( + "generator/update_weights/count_weight_updates", 1, Reduce.SUM + ) logger.debug(f"Starting weight update on {self.__class__.__name__}") if self.use_vllm_builtin_load: - await self.policy_worker.update.call(version=policy_version) + await self.generator_worker.update.call(version=version) else: - await self.policy_worker.update_DEPRECATED.call(version=policy_version) - self.policy_version = policy_version + await self.generator_worker.update_DEPRECATED.call(version=version) + self.generator_version = version # After updating the weights, we need to reset the KV cache self.scheduler.kv_cache_manager.reset_prefix_cache() @@ -495,24 +497,24 @@ 
async def update_weights(self, policy_version: int): self.accepting_requests = True self.request_lock.notify_all() - logger.info(f"Weight update completed (now v{self.policy_version})") + logger.info(f"Weight update completed (now v{self.generator_version})") @endpoint - async def update_weights_DEPRECATED(self, policy_version: int): # noqa: N802 - # TODO: If generating long sequences, this might be long and will block policy weight updates + async def update_weights_DEPRECATED(self, version: int): # noqa: N802 + # TODO: If generating long sequences, this might be long and will block generator weight updates curr_requests = [fut for _, fut in self.requests.values()] if curr_requests: logger.debug(f"Waiting for {len(curr_requests)} pending requests") await asyncio.gather(*curr_requests) - await self.policy_worker.update_DEPRECATED.call(version=policy_version) - self.policy_version = policy_version - logger.info(f"Weight update completed (now v{self.policy_version})") + await self.generator_worker.update_DEPRECATED.call(version=version) + self.generator_version = version + logger.info(f"Weight update completed (now v{self.generator_version})") @endpoint async def get_version(self) -> int: - """Get the current policy version.""" - return self.policy_version + """Get the current generator version.""" + return self.generator_version @endpoint async def stop(self): @@ -522,13 +524,13 @@ async def stop(self): async def _test_save_model_params(self): """Save model parameters before weight update, used for tesing purposes only.""" logger.info("[Policy] save model parameters for testing.") - await self.policy_worker._test_save_model_params.call() + await self.generator_worker._test_save_model_params.call() @endpoint async def _test_validate_model_params(self, validate_fn): """Validate updated model params using validate_fn.""" logger.info("[Policy] start validating model parameters.") - return await self.policy_worker._test_validate_model_params.call(validate_fn) + return 
await self.generator_worker._test_validate_model_params.call(validate_fn) def _to_completions(self, request_output: RequestOutput) -> list[Completion]: """Convert a RequestOutput to a list of Completion objects.""" @@ -546,7 +548,7 @@ def _to_completions(self, request_output: RequestOutput) -> list[Completion]: prompt_ids=torch.tensor(prompt_token_ids), token_ids=torch.tensor(output.token_ids), logprobs=self._extract_logprobs(output), - generator_version=self.policy_version, + generator_version=self.generator_version, ) ) @@ -569,7 +571,7 @@ def _extract_logprobs(self, one_sample: CompletionOutput) -> torch.Tensor | None @dataclass -class PolicyWorker(ForgeActor): +class GeneratorWorker(ForgeActor): vllm_config: VllmConfig state_dict_key: str = "model_state_dict" # TODO: remove this later since no plumbing exists to change this value. @@ -645,7 +647,7 @@ async def update_DEPRECATED(self, version: int): # noqa: N802 async def update(self, version: int): """Update model weights by reading state dict from torchstore""" logger.info( - f"[PolicyWorker::update] start updating weights to version {version}" + f"[GeneratorWorker::update] start updating weights to version {version}" ) model = self.worker.model_runner.model prefix = get_param_prefix(version) @@ -654,7 +656,7 @@ async def update(self, version: int): logger.debug(f"{matching_keys=}") dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) loaded_weights = set() - t = Tracer("policy_worker_perf/update", timer="gpu") + t = Tracer("generator_worker_perf/update", timer="gpu") t.start() # Entire state dict is stored in a single DCP handle if dcp_whole_state_dict_key in matching_keys: @@ -679,7 +681,7 @@ async def update(self, version: int): del param loaded_weights.update(loaded) t.stop() - logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}") + logger.debug(f"[GeneratorWorker::update] Loaded weights: {loaded_weights}") @endpoint async def setup_kv_cache(self): @@ -714,18 +716,18 @@ async 
def setup_kv_cache(self): @endpoint async def _test_save_model_params(self): """Save model parameters before weight update, used for tesing purposes only.""" - logger.info("[PolicyWorker] save model parameters for testing.") + logger.info("[GeneratorWorker] save model parameters for testing.") for name, param in self.worker.model_runner.model.named_parameters(): self._test_prev_params[name] = param.detach().cpu() logger.info( - "[PolicyWorker] finished saving model parameters, len = %d", + "[GeneratorWorker] finished saving model parameters, len = %d", len(self._test_prev_params), ) @endpoint async def _test_validate_model_params(self, validate_fn): """Validate updated model params using validate_fn.""" - logger.info("[PolicyWorker] start validating model parameters.") + logger.info("[GeneratorWorker] start validating model parameters.") return validate_fn( self._test_prev_params, self.worker.model_runner.model, logger ) diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index d67a66a83..2c791a8bf 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -59,7 +59,7 @@ async def get_or_create_metric_logger( }) # Initialize services... - policy = await Policy.as_service(...) + policy = await Generator.as_service(...) 
# Training loop for step in range(max_steps): diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 506fc5553..6c681f4dc 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -12,7 +12,7 @@ import torch import torchstore as ts -from forge.actors.policy import Policy +from forge.actors.generator import Generator from forge.actors.trainer import RLTrainer from forge.cli.config import resolve_hf_hub_paths @@ -203,7 +203,7 @@ async def test_sanity_check(self, request): trainer_cfg["dcp_path"] = tmpdir policy, rl_trainer = await asyncio.gather( *[ - Policy.options(**services_policy_cfg).as_service(**cfg.policy), + Generator.options(**services_policy_cfg).as_service(**cfg.policy), MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg), ] ) @@ -224,7 +224,7 @@ async def test_sanity_check(self, request): for _, e in errs.items(): assert not e, f"Validation failed with exception: {e}" - await policy.update_weights.fanout(policy_version=v1) + await policy.update_weights.fanout(version=v1) all_errs = await policy._test_validate_model_params.fanout( validate_fn_all_zeros ) @@ -233,7 +233,7 @@ async def test_sanity_check(self, request): assert not e, f"Validation failed with exception: {e}" # Reloading v0, getting back original weights - await policy.update_weights.fanout(policy_version=v0) + await policy.update_weights.fanout(version=v0) all_errs = await policy._test_validate_model_params.fanout(validate_fn) for errs in all_errs: for _, e in errs.items(): diff --git a/tests/unit_tests/test_policy_config.py b/tests/unit_tests/test_generator_config.py similarity index 54% rename from tests/unit_tests/test_policy_config.py rename to tests/unit_tests/test_generator_config.py index 08de4f907..a9ec4430f 100644 --- a/tests/unit_tests/test_policy_config.py +++ b/tests/unit_tests/test_generator_config.py @@ -14,50 +14,52 @@ def _import_error(): """Check if 
there are import errors that would cause CI failures.""" try: - import forge.actors.policy # noqa: F401 + import forge.actors.generator # noqa: F401 return False except ImportError: return True -class TestPolicyConfig(unittest.TestCase): - """Test suite for Policy configuration handling after PolicyConfig removal.""" +class TestGeneratorConfig(unittest.TestCase): + """Test suite for Generator configuration handling after PolicyConfig removal.""" @pytest.mark.skipif( _import_error(), reason="Import error, likely due to missing dependencies on CI.", ) - def test_policy_default_initialization(self): - """Policy initializes with default values.""" - from forge.actors.policy import EngineConfig, Policy, SamplingConfig + def test_generator_default_initialization(self): + """Generator initializes with default values.""" + from forge.actors.generator import EngineConfig, Generator, SamplingConfig - policy = Policy() + generator = Generator() # Default factories - self.assertIsInstance(policy.engine_config, EngineConfig) - self.assertIsInstance(policy.sampling_config, SamplingConfig) - self.assertIsNone(policy.available_devices) + self.assertIsInstance(generator.engine_config, EngineConfig) + self.assertIsInstance(generator.sampling_config, SamplingConfig) + self.assertIsNone(generator.available_devices) # Worker defaults - self.assertEqual(policy.engine_config.model, "meta-llama/Llama-3.1-8B-Instruct") - self.assertEqual(policy.engine_config.tensor_parallel_size, 1) - self.assertEqual(policy.engine_config.pipeline_parallel_size, 1) - self.assertFalse(policy.engine_config.enforce_eager) - self.assertTrue(policy.engine_config._is_v1_supported_oracle()) + self.assertEqual( + generator.engine_config.model, "meta-llama/Llama-3.1-8B-Instruct" + ) + self.assertEqual(generator.engine_config.tensor_parallel_size, 1) + self.assertEqual(generator.engine_config.pipeline_parallel_size, 1) + self.assertFalse(generator.engine_config.enforce_eager) + 
self.assertTrue(generator.engine_config._is_v1_supported_oracle()) # Sampling defaults - self.assertEqual(policy.sampling_config.n, 1) - self.assertFalse(policy.sampling_config.guided_decoding) - self.assertEqual(policy.sampling_config.max_tokens, 512) + self.assertEqual(generator.sampling_config.n, 1) + self.assertFalse(generator.sampling_config.guided_decoding) + self.assertEqual(generator.sampling_config.max_tokens, 512) @pytest.mark.skipif( _import_error(), reason="Import error, likely due to missing dependencies on CI.", ) - def test_policy_with_dict_configs(self): - """Policy accepts dicts for engine_config and sampling_config, including nested dicts.""" - from forge.actors.policy import EngineConfig, Policy, SamplingConfig + def test_generator_with_dict_configs(self): + """Generator accepts dicts for engine_config and sampling_config, including nested dicts.""" + from forge.actors.generator import EngineConfig, Generator, SamplingConfig # Test with nested dict structure engine_dict = { @@ -78,26 +80,26 @@ def test_policy_with_dict_configs(self): "max_tokens": 2468, } - policy = Policy( + generator = Generator( engine_config=engine_dict, sampling_config=sampling_dict, available_devices="test-gpu-device-abcd", ) - self.assertIsInstance(policy.engine_config, EngineConfig) - self.assertIsInstance(policy.sampling_config, SamplingConfig) + self.assertIsInstance(generator.engine_config, EngineConfig) + self.assertIsInstance(generator.sampling_config, SamplingConfig) # Test basic fields - self.assertEqual(policy.engine_config.model, "test-model-6789") - self.assertEqual(policy.engine_config.tensor_parallel_size, 7777) - self.assertEqual(policy.engine_config.pipeline_parallel_size, 8888) - self.assertTrue(policy.engine_config.enforce_eager) - self.assertTrue(policy.engine_config._is_v1_supported_oracle()) + self.assertEqual(generator.engine_config.model, "test-model-6789") + self.assertEqual(generator.engine_config.tensor_parallel_size, 7777) + 
self.assertEqual(generator.engine_config.pipeline_parallel_size, 8888) + self.assertTrue(generator.engine_config.enforce_eager) + self.assertTrue(generator.engine_config._is_v1_supported_oracle()) - self.assertEqual(policy.sampling_config.n, 1357) + self.assertEqual(generator.sampling_config.n, 1357) # After __post_init__, guided_decoding becomes GuidedDecodingParams object when True - self.assertIsNotNone(policy.sampling_config.guided_decoding) - self.assertEqual(policy.sampling_config.max_tokens, 2468) + self.assertIsNotNone(generator.sampling_config.guided_decoding) + self.assertEqual(generator.sampling_config.max_tokens, 2468) # Test that engine_dict accepts and preserves nested dict structure # The original engine_dict should remain unchanged and accessible @@ -113,9 +115,9 @@ def test_policy_with_dict_configs(self): _import_error(), reason="Import error, likely due to missing dependencies on CI.", ) - def test_policy_yaml_config_loading(self): - """Policy can be constructed from a YAML config file.""" - from forge.actors.policy import Policy + def test_generator_yaml_config_loading(self): + """Generator can be constructed from a YAML config file.""" + from forge.actors.generator import Generator yaml_content = """ engine_config: @@ -139,20 +141,20 @@ def test_policy_yaml_config_loading(self): with open(f.name, "r") as yaml_file: config = yaml.safe_load(yaml_file) - policy = Policy(**config) + generator = Generator(**config) - self.assertEqual(policy.engine_config.model, "yaml-test-model-9876") - self.assertEqual(policy.engine_config.tensor_parallel_size, 1234) - self.assertEqual(policy.engine_config.pipeline_parallel_size, 5678) - self.assertTrue(policy.engine_config.enforce_eager) - self.assertTrue(policy.engine_config._is_v1_supported_oracle()) + self.assertEqual(generator.engine_config.model, "yaml-test-model-9876") + self.assertEqual(generator.engine_config.tensor_parallel_size, 1234) + self.assertEqual(generator.engine_config.pipeline_parallel_size, 
5678) + self.assertTrue(generator.engine_config.enforce_eager) + self.assertTrue(generator.engine_config._is_v1_supported_oracle()) - self.assertEqual(policy.sampling_config.n, 2468) + self.assertEqual(generator.sampling_config.n, 2468) # After __post_init__, guided_decoding becomes GuidedDecodingParams object when True - self.assertIsNotNone(policy.sampling_config.guided_decoding) - self.assertEqual(policy.sampling_config.max_tokens, 1357) + self.assertIsNotNone(generator.sampling_config.guided_decoding) + self.assertEqual(generator.sampling_config.max_tokens, 1357) - self.assertEqual(policy.available_devices, "yaml-test-device-xyz") + self.assertEqual(generator.available_devices, "yaml-test-device-xyz") @pytest.mark.skipif( _import_error(), @@ -160,7 +162,7 @@ def test_policy_yaml_config_loading(self): ) def test_engineconfig_ignores_invalid_keys(self): """EngineConfig.from_dict ignores unexpected keys.""" - from forge.actors.policy import EngineConfig + from forge.actors.generator import EngineConfig engine_config = { "model": "custom-model", From d79aa65f5b74baab468105f892aefecfdf58ed4e Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Tue, 14 Oct 2025 12:50:38 -0700 Subject: [PATCH 2/4] Rebase cleanup --- src/forge/actors/generator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 7f346b70e..8f8cf8fc7 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -128,11 +128,11 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] ) -> "Generator": """Launch the Generator with its workers. - We overwrite the default Service launch method in order to setup Actors (PolicyWorker) within this "coordinating" Actor. + We overwrite the default Service launch method in order to setup Actors (GeneratorWorker) within this "coordinating" Actor. 
We first create a proc_mesh for the workers, then a proc_mesh for the generator, and then we spawn the workers and the generator in setup. - The args here generally should match those in the `__init__` method of the Policy class. + The args here generally should match those in the `__init__` method of the Generator class. """ # Note: get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES process_config: ProcessConfig = ProcessConfig( @@ -144,7 +144,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] worker_procs = await get_proc_mesh(process_config=process_config) # TODO - issues/144 we will want to ensure colocation with workers - # We're currently locating the Policy on the local host proc mesh + # We're currently locating the Generator on the local host proc mesh # vLLM initialization without setting env variables at proc_mesh creation # level leads to issues. # Once we can create multiple proc meshes on a host mesh, we can ensure @@ -492,7 +492,7 @@ def _extract_logprobs(self, sample: CompletionOutput) -> torch.Tensor | None: @classmethod async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] - cls: type["Policy"], actor: "Policy" + cls: type["Generator"], actor: "Generator" ): assert ( actor._generator_proc is not None From 37995939b7ae34b21022e21bbcc38d7b5a62253a Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Tue, 14 Oct 2025 13:31:17 -0700 Subject: [PATCH 3/4] Whats up doc --- docs/source/api_generator.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/api_generator.md b/docs/source/api_generator.md index c5aee3eec..31b67c03c 100644 --- a/docs/source/api_generator.md +++ b/docs/source/api_generator.md @@ -1,26 +1,26 @@ # Generator ```{eval-rst} -.. currentmodule:: forge.actors.policy +.. currentmodule:: forge.actors.generator ``` The Generator (Policy) is the core inference engine in TorchForge, built on top of [vLLM](https://docs.vllm.ai/en/latest/). 
It manages model serving, text generation, and weight updates for reinforcement learning workflows. -## Policy +## Generator ```{eval-rst} -.. autoclass:: Policy +.. autoclass:: Generator :members: generate, update_weights, get_version, stop :exclude-members: __init__, launch :no-inherited-members: ``` -## PolicyWorker +## GeneratorWorker ```{eval-rst} -.. autoclass:: PolicyWorker +.. autoclass:: GeneratorWorker :members: execute_model, update, setup_kv_cache :show-inheritance: :exclude-members: __init__ From 4dd1025866c0a6115279129656b47196bbb6e1b3 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Tue, 14 Oct 2025 15:02:49 -0700 Subject: [PATCH 4/4] Alias Generator as Policy --- apps/grpo/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 63209b3b7..3c1b9e28f 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -79,6 +79,9 @@ def response_tensor(self) -> torch.Tensor: # Represents the group (G) of episodes in GRPO Group = list[Episode] +# Represents the Policy Model to collect data from +Policy = Generator + def collate( batches: list[Group], @@ -317,7 +320,7 @@ async def main(cfg: DictConfig): reward_actor, ) = await asyncio.gather( DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset), - Generator.options(**cfg.services.policy).as_service(**cfg.policy), + Policy.options(**cfg.services.policy).as_service(**cfg.policy), RLTrainer.options(**cfg.actors.trainer).as_actor( **cfg.trainer, loss=simple_grpo_loss ),