From b08e85a953fca64b9bf1500bb415d455704ccc81 Mon Sep 17 00:00:00 2001 From: joecummings Date: Tue, 14 Oct 2025 09:04:57 -0700 Subject: [PATCH 1/8] Policy cleaner launch / setup --- src/forge/actors/policy.py | 131 +++++++++++++++++++------------------ 1 file changed, 67 insertions(+), 64 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 3a1b3e86e..0029b3a2f 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -22,7 +22,6 @@ from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.utils import _validate_truncation_size from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs -from vllm.lora.request import LoRARequest from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs @@ -53,7 +52,6 @@ from forge.data_models.completion import Completion from forge.data_models.prompt import to_prompt from forge.env import TORCHSTORE_USE_RDMA -from forge.interfaces import Policy as PolicyInterface from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer from forge.types import ProcessConfig @@ -63,7 +61,7 @@ @dataclass -class Policy(PolicyInterface): +class Policy(ForgeActor): """Instance of a vLLM-based Policy. This class manually recreates a vLLM engine that mirrors the design of AsyncLLMEngine in v1. The @@ -72,8 +70,8 @@ class Policy(PolicyInterface): Args: engine_args (EngineArgs): The engine arguments to use for the vLLM engine. sampling_params (SamplingParams): The sampling parameters to use for the vLLM engine. - available_devices (str): The available devices to use for the vLLM engine. - use_dcp (bool): Whether to use DCP for NFS-based weight sync. + use_dcp (bool): Whether to use DCP for NFS-based weight sync. Default depends on whether or not + RDMA is enabled in torchstore. Example: >>> policy = await Policy.options(procs=1, num_replicas=1, with_gpus=True).as_service( @@ -88,19 +86,13 @@ class Policy(PolicyInterface): engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs) sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams) - available_devices: str | None = None - use_dcp: bool = ( - TORCHSTORE_USE_RDMA.get_value() == 0 - ) # torchstore currently only accepts 0 or 1 - # Remaining variables are initialized in self.setup() - lora_request: LoRARequest | None = None - tokenization_kwargs: dict = field(default_factory=dict) - policy_worker: PolicyWorker | None = None + use_dcp: bool | None = None def __post_init__(self): super().__init__() self._run_task: asyncio.Task | None = None self._policy_proc: ProcMesh | None = None + self.worker: PolicyWorker | None = None self._worker_procs: ProcMesh | None = None self.running = False self.policy_version: int = 0 @@ -113,16 +105,18 @@ def __post_init__(self): self.sampling_params = SamplingParams.from_optional(**self.sampling_params) self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY + if self.use_dcp is None: + self.use_dcp = TORCHSTORE_USE_RDMA.get_value() == 0 + + @endpoint + async def register_worker(self, worker: PolicyWorker) -> None: + self.worker = worker + logger.debug("Registered PolicyWorker on Policy.") + @classmethod async def launch( # pyright: ignore[reportIncompatibleMethodOverride] cls: type["Policy"], - *, - engine_args: EngineArgs | Mapping = EngineArgs(), - sampling_params: SamplingParams | Mapping = SamplingParams(), - available_devices: str | None = None, - use_dcp: bool = ( - TORCHSTORE_USE_RDMA.get_value() == 0 - ), # torchstore currently only accepts 0 or 1 + *args, **kwargs, ) -> "Policy": """Launch the policy with its workers. @@ -154,45 +148,47 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] policy_proc_config.with_gpus = False policy_proc = await get_proc_mesh(process_config=policy_proc_config) - if isinstance(engine_args, Mapping): - engine_args = EngineArgs(**engine_args) - engine_args._is_v1_supported_oracle = lambda *_: True # Always default on - logger.debug(f"Resolved engine args: {engine_args}") + # TODO - expand support so name can stick within kwargs + actor_name = kwargs.pop("name", cls.__name__) - vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS) - workers = worker_procs.spawn( - "vllm_worker", PolicyWorker, vllm_config=vllm_config, use_dcp=use_dcp - ) + # if isinstance(engine_args, Mapping): + # engine_args = EngineArgs(**engine_args) + # engine_args._is_v1_supported_oracle = lambda *_: True # Always default on + # logger.debug(f"Resolved engine args: {engine_args}") - if isinstance(sampling_params, Mapping): - sampling_params = SamplingParams.from_optional(**sampling_params) - sampling_params.output_kind = RequestOutputKind.FINAL_ONLY - logger.debug(f"Resolved sampling params: {sampling_params}") + # vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS) + # engine_args = kwargs["engine_args"] + # if isinstance(engine_args, Mapping): + # engine_args = EngineArgs(**engine_args) + # engine_args._is_v1_supported_oracle = lambda *_: True # Always default on + # vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS) + worker = worker_procs.spawn("vllm_worker", PolicyWorker, *args, **kwargs) + + # if isinstance(sampling_params, Mapping): + # sampling_params = SamplingParams.from_optional(**sampling_params) + # sampling_params.output_kind = RequestOutputKind.FINAL_ONLY + # logger.debug(f"Resolved sampling params: {sampling_params}") - # TODO - expand support so name can stick within kwargs - actor_name = kwargs.pop("name", cls.__name__) policy = policy_proc.spawn( actor_name, cls, - engine_args=engine_args, - sampling_params=sampling_params, - available_devices=available_devices, - policy_worker=workers, + *args, **kwargs, ) policy._policy_proc = policy_proc policy._worker_procs = worker_procs + await policy.register_worker.call(worker) await policy.setup.call() return policy @endpoint async def setup(self): """Mirrors the __init__ of vLLM's LLMEngine.""" - if self.policy_worker is None: + if self.worker is None: raise RuntimeError( "Policy worker should not be None. Usually it would be attached to Policy in the ``launch`` method." ) - await self.policy_worker.setup.call() + await self.worker.setup.call() self.request_id = 0 self.requests: dict[str, tuple[ParentRequest | None, asyncio.Future]] = {} @@ -203,11 +199,6 @@ async def setup(self): self.request_lock = asyncio.Condition() # Guard for accepting_requests self.update_lock = asyncio.Condition() # Guard for updating requests - vllm_config: VllmConfig = self.engine_args.create_engine_config( - UsageContext.LLM_CLASS - ) - self.max_model_len = vllm_config.model_config.max_model_len - # Setup processors # TODO: move all processing to the Environment # TODO: add support for `log_stats` and `mm_registry` @@ -222,7 +213,7 @@ async def setup(self): self.output_processor = OutputProcessor(tokenizer, log_stats=None) # Configure KV caches - kv_cache_configs = await self.policy_worker.setup_kv_cache.call() + kv_cache_configs = await self.worker.setup_kv_cache.call() _, kv_cache_config = next(kv_cache_configs.items()) vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks vllm_config.cache_config.num_cpu_blocks = 0 @@ -261,7 +252,7 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: self.request_id += 1 % sys.maxsize request_id = str(self.request_id) - tokenization_kwargs = self.tokenization_kwargs or {} + tokenization_kwargs = {} # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507 truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens _validate_truncation_size( @@ -274,7 +265,6 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: prompt={"prompt": prompt}, params=self.sampling_params, arrival_time=None, - lora_request=self.lora_request, tokenization_kwargs=tokenization_kwargs, trace_headers=None, priority=priority, @@ -341,8 +331,9 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: def _preprocess_add_request( self, request: EngineCoreRequest ) -> tuple[Request, int]: - """ (forge/issues/332) Will require attention when we bump vllm versions - https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419""" + """(forge/issues/332) Will require attention when we bump vllm versions + https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419 + """ if request.mm_hashes is not None: raise NotImplementedError("Support for mm_hash is not implemented yet.") req = Request.from_engine_core_request(request) @@ -358,9 +349,7 @@ async def run(self) -> None: self.running = True while self.running: scheduler_output = self.scheduler.schedule() - worker_outputs = await self.policy_worker.execute_model.call( - scheduler_output - ) + worker_outputs = await self.worker.execute_model.call(scheduler_output) # The results of `execute_model` are gathered on the driver rank (rank 0) _, worker_output = next(worker_outputs.items()) @@ -427,8 +416,8 @@ async def update_weights(self, policy_version: int) -> None: record_metric("policy/update_weights/count_weight_updates", 1, Reduce.SUM) logger.debug(f"Starting weight update on {self.__class__.__name__}") - # Call update_weights on every policy_worker - await self.policy_worker.update_weights.call(policy_version) + # Call update_weights on every policy worker + await self.worker.update_weights.call(policy_version) self.policy_version = policy_version # After updating the weights, we need to reset the KV cache @@ -507,13 +496,16 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] async def _test_save_model_params(self): """Save model parameters before weight update, used for tesing purposes only.""" logger.info("[Policy] save model parameters for testing.") - await self.policy_worker._test_save_model_params.call() + await self.worker._test_save_model_params.call() @endpoint async def _test_validate_model_params(self, validate_fn): """Validate updated model params using validate_fn.""" logger.info("[Policy] start validating model parameters.") - return await self.policy_worker._test_validate_model_params.call(validate_fn) + return await self.worker._test_validate_model_params.call(validate_fn) + + +from typing import Any @dataclass @@ -525,20 +517,31 @@ class PolicyWorker(ForgeActor): the creation and invocation of all PolicyWorkers. """ - vllm_config: VllmConfig - state_dict_key: str = "model_state_dict" - # TODO: remove this later since no plumbing exists to change this value. - # Also, whether to use dcp or not can be inferred from torchstore get() call. - use_dcp: bool = True - - # used for tesing purposes only + engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs) + sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams) + use_dcp: bool | None = None + # TODO: Remove below param _test_prev_params = {} def __post_init__(self): super().__init__() + if isinstance(self.engine_args, Mapping): + self.engine_args = EngineArgs(**self.engine_args) + self.engine_args._is_v1_supported_oracle = lambda *_: True + # Note: vllm_config creation is deferred to setup() method to avoid + # model inspection issues during remote actor initialization + self.vllm_config = None + print("HELLLOOOO") + + if self.use_dcp is None: + self.use_dcp = TORCHSTORE_USE_RDMA.get_value() == 0 @endpoint async def setup(self): + # Create vllm_config here instead of during initialization to avoid + # model inspection issues during remote actor initialization + self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS) + self.rank = current_rank().rank os.environ["RANK"] = str(self.rank) parallel_config = self.vllm_config.parallel_config From a1823afdf414937b47006eaa260edac58b08e534 Mon Sep 17 00:00:00 2001 From: joecummings Date: Wed, 15 Oct 2025 08:59:29 -0700 Subject: [PATCH 2/8] Address comments; move things around --- src/forge/actors/policy.py | 86 +++++++++++++------------------------- 1 file changed, 30 insertions(+), 56 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 0029b3a2f..88e47794e 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -100,6 +100,7 @@ def __post_init__(self): if isinstance(self.engine_args, Mapping): self.engine_args = EngineArgs(**self.engine_args) self.engine_args._is_v1_supported_oracle = lambda *_: True + self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS) if isinstance(self.sampling_params, Mapping): self.sampling_params = SamplingParams.from_optional(**self.sampling_params) @@ -107,6 +108,11 @@ def __post_init__(self): if self.use_dcp is None: self.use_dcp = TORCHSTORE_USE_RDMA.get_value() == 0 + logger.debug(f"{self.use_dcp=}") + + @endpoint + async def get_vllm_config(self) -> VllmConfig: + return self.vllm_config @endpoint async def register_worker(self, worker: PolicyWorker) -> None: @@ -134,14 +140,12 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] with_gpus=cls.with_gpus, mesh_name=cls.mesh_name, ) - worker_procs = await get_proc_mesh(process_config=process_config) # TODO - issues/144 we will want to ensure colocation with workers # We're currently locating the Policy on the local host proc mesh # vLLM initialization without setting env variables at proc_mesh creation - # level leads to issues. - # Once we can create multiple proc meshes on a host mesh, we can ensure - # host colocation + # level leads to issues. Once we can create multiple proc meshes on a host mesh, + # we can ensure host colocation policy_proc_config = copy(process_config) policy_proc_config.procs = 1 policy_proc_config.hosts = None @@ -150,46 +154,32 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] # TODO - expand support so name can stick within kwargs actor_name = kwargs.pop("name", cls.__name__) - - # if isinstance(engine_args, Mapping): - # engine_args = EngineArgs(**engine_args) - # engine_args._is_v1_supported_oracle = lambda *_: True # Always default on - # logger.debug(f"Resolved engine args: {engine_args}") - - # vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS) - # engine_args = kwargs["engine_args"] - # if isinstance(engine_args, Mapping): - # engine_args = EngineArgs(**engine_args) - # engine_args._is_v1_supported_oracle = lambda *_: True # Always default on - # vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS) - worker = worker_procs.spawn("vllm_worker", PolicyWorker, *args, **kwargs) - - # if isinstance(sampling_params, Mapping): - # sampling_params = SamplingParams.from_optional(**sampling_params) - # sampling_params.output_kind = RequestOutputKind.FINAL_ONLY - # logger.debug(f"Resolved sampling params: {sampling_params}") - policy = policy_proc.spawn( actor_name, cls, *args, **kwargs, ) + + worker_procs = await get_proc_mesh(process_config=process_config) + vllm_config = ( + await policy.get_vllm_config.call_one() + ) # Config should be the same across all policy actors + worker = worker_procs.spawn( + "vllm_worker", PolicyWorker, vllm_config=vllm_config + ) + await worker.setup.call() + await policy.register_worker.call(worker) + policy._policy_proc = policy_proc policy._worker_procs = worker_procs - await policy.register_worker.call(worker) await policy.setup.call() + return policy @endpoint async def setup(self): """Mirrors the __init__ of vLLM's LLMEngine.""" - if self.worker is None: - raise RuntimeError( - "Policy worker should not be None. Usually it would be attached to Policy in the ``launch`` method." - ) - await self.worker.setup.call() - self.request_id = 0 self.requests: dict[str, tuple[ParentRequest | None, asyncio.Future]] = {} @@ -203,26 +193,26 @@ async def setup(self): # TODO: move all processing to the Environment # TODO: add support for `log_stats` and `mm_registry` tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config, + model_config=self.vllm_config.model_config, + scheduler_config=self.vllm_config.scheduler_config, + lora_config=self.vllm_config.lora_config, ) self.processor = Processor( - vllm_config=vllm_config, tokenizer=tokenizer, mm_registry=None + vllm_config=self.vllm_config, tokenizer=tokenizer, mm_registry=None ) self.output_processor = OutputProcessor(tokenizer, log_stats=None) # Configure KV caches kv_cache_configs = await self.worker.setup_kv_cache.call() _, kv_cache_config = next(kv_cache_configs.items()) - vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks - vllm_config.cache_config.num_cpu_blocks = 0 + self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks + self.vllm_config.cache_config.num_cpu_blocks = 0 # Setup scheduler # TODO: Add support for `log_stats` - structured_output_manager = StructuredOutputManager(vllm_config) + structured_output_manager = StructuredOutputManager(self.vllm_config) self.scheduler = Scheduler( - vllm_config=vllm_config, + vllm_config=self.vllm_config, kv_cache_config=kv_cache_config, structured_output_manager=structured_output_manager, include_finished_set=False, @@ -256,7 +246,7 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507 truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens _validate_truncation_size( - self.max_model_len, + self.vllm_config.model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs, ) @@ -505,9 +495,6 @@ async def _test_validate_model_params(self, validate_fn): return await self.worker._test_validate_model_params.call(validate_fn) -from typing import Any - - @dataclass class PolicyWorker(ForgeActor): """Mirrors a vLLM GPUWorker @@ -517,31 +504,18 @@ class PolicyWorker(ForgeActor): the creation and invocation of all PolicyWorkers. """ - engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs) - sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams) + vllm_config: VllmConfig use_dcp: bool | None = None # TODO: Remove below param _test_prev_params = {} def __post_init__(self): super().__init__() - if isinstance(self.engine_args, Mapping): - self.engine_args = EngineArgs(**self.engine_args) - self.engine_args._is_v1_supported_oracle = lambda *_: True - # Note: vllm_config creation is deferred to setup() method to avoid - # model inspection issues during remote actor initialization - self.vllm_config = None - print("HELLLOOOO") - if self.use_dcp is None: self.use_dcp = TORCHSTORE_USE_RDMA.get_value() == 0 @endpoint async def setup(self): - # Create vllm_config here instead of during initialization to avoid - # model inspection issues during remote actor initialization - self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS) - self.rank = current_rank().rank os.environ["RANK"] = str(self.rank) parallel_config = self.vllm_config.parallel_config From da772af26151d5c75fac197cf2257f6d282ceb30 Mon Sep 17 00:00:00 2001 From: joecummings Date: Wed, 15 Oct 2025 09:30:25 -0700 Subject: [PATCH 3/8] Final merge --- src/forge/actors/generator.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index fc7af071c..d96e82936 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -63,10 +63,10 @@ @dataclass class Generator(ForgeActor): """Instance of a vLLM-based generator. - + This class manually recreates a vLLM engine that mirrors the design of AsyncLLMEngine in v1. The main difference is that all communications are controlled here via Monarch's proc meshes. - + Args: engine_args (EngineArgs): The engine arguments to use for the vLLM engine. sampling_params (SamplingParams): The sampling parameters to use for the vLLM engine. @@ -90,7 +90,10 @@ class Generator(ForgeActor): def __post_init__(self): super().__init__() + self._run_task: asyncio.Task | None = None + self._policy_proc: ProcMesh | None = None self._worker_procs: ProcMesh | None = None + self.worker: GeneratorWorker | None = None self.running = False self.generator_version: int = 0 @@ -161,7 +164,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] worker_procs = await get_proc_mesh(process_config=process_config) vllm_config = ( await generator.get_vllm_config.call_one() - ) # Config should be the same across all policy actors + ) # Config should be the same across all actors worker = worker_procs.spawn( "vllm_worker", GeneratorWorker, vllm_config=vllm_config ) From 17dede24e33d710010576b1ff2369074cecf322a Mon Sep 17 00:00:00 2001 From: joecummings Date: Wed, 15 Oct 2025 10:42:03 -0700 Subject: [PATCH 4/8] Remove dcp stuff --- src/forge/actors/generator.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index d96e82936..10c41a87b 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -70,8 +70,8 @@ class Generator(ForgeActor): Args: engine_args (EngineArgs): The engine arguments to use for the vLLM engine. sampling_params (SamplingParams): The sampling parameters to use for the vLLM engine. - use_dcp (bool): Whether to use DCP for NFS-based weight sync. Default depends on whether or not - RDMA is enabled in torchstore. + use_dcp_for_weight_sync (bool): Whether to use DCP for NFS-based weight sync. Default depends on + whether or not RDMA is enabled in torchstore. If it is, then DCP is disabled. Otherwise, DCP is enabled. Example: >>> generator = await Generator.options(procs=1, num_replicas=1, with_gpus=True).as_service( @@ -86,7 +86,7 @@ class Generator(ForgeActor): engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs) sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams) - use_dcp: bool | None = None + use_dcp_for_weight_sync: bool | None = None def __post_init__(self): super().__init__() @@ -106,9 +106,9 @@ def __post_init__(self): self.sampling_params = SamplingParams.from_optional(**self.sampling_params) self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY - if self.use_dcp is None: - self.use_dcp = TORCHSTORE_USE_RDMA.get_value() == 0 - logger.debug(f"{self.use_dcp=}") + if self.use_dcp_for_weight_sync is None: + self.use_dcp_for_weight_sync = TORCHSTORE_USE_RDMA.get_value() == 0 + logger.debug(f"{self.use_dcp_for_weight_sync=}") @endpoint async def get_vllm_config(self) -> VllmConfig: @@ -130,8 +130,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] We overwrite the default Service launch method in order to setup Actors (GeneratorWorker) within this "coordinating" Actor. We first create a proc_mesh for the workers, then a proc_mesh for the generator, and then we spawn the workers and the generator in setup. - - The args here generally should match those in the `__init__` method of the Generator class. """ # Note: get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES process_config: ProcessConfig = ProcessConfig( @@ -507,15 +505,9 @@ class GeneratorWorker(ForgeActor): """ vllm_config: VllmConfig - use_dcp: bool | None = None # TODO: Remove below param _test_prev_params = {} - def __post_init__(self): - super().__init__() - if self.use_dcp is None: - self.use_dcp = TORCHSTORE_USE_RDMA.get_value() == 0 - @endpoint async def setup(self): self.rank = current_rank().rank @@ -577,11 +569,12 @@ async def update_weights(self, version: int) -> None: prefix = get_param_prefix(version) matching_keys = await ts.keys(prefix) dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) + use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys loaded_weights = set() t = Tracer("worker_perf/update_weights", timer="gpu") t.start() - # Entire state dict is stored in a single DCP handle - if dcp_whole_state_dict_key in matching_keys: + + if use_dcp_for_weight_sync: dcp_handle = await ts.get(dcp_whole_state_dict_key) hf_param_names = dcp_handle.param_names for name in hf_param_names: @@ -589,7 +582,7 @@ async def update_weights(self, version: int) -> None: loaded = model.load_weights([(name, param)]) del param loaded_weights.update(loaded) - else: # Load each parameter from torchstore directly without DCP + else: hf_param_names = [extract_param_name(key) for key in matching_keys] # We can't pass a generator since vllm load_weights is not async. # Instead, we just call load_weights with one parameter at a time. @@ -599,6 +592,7 @@ async def update_weights(self, version: int) -> None: loaded = model.load_weights([(name, param)]) del param loaded_weights.update(loaded) + t.stop() @endpoint From 1c84d19dd0e136560fa3260eeedcbacb0df523ef Mon Sep 17 00:00:00 2001 From: joecummings Date: Wed, 15 Oct 2025 10:57:10 -0700 Subject: [PATCH 5/8] Remove testing stuff --- tests/unit_tests/test_generator_config.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/unit_tests/test_generator_config.py b/tests/unit_tests/test_generator_config.py index 1c64e42e2..bff2e3e6e 100644 --- a/tests/unit_tests/test_generator_config.py +++ b/tests/unit_tests/test_generator_config.py @@ -39,7 +39,6 @@ def test_generator_default_initialization(self): # Default factories self.assertIsInstance(generator.engine_args, EngineArgs) self.assertIsInstance(generator.sampling_params, SamplingParams) - self.assertIsNone(generator.available_devices) # Worker defaults self.assertEqual(generator.engine_args.model, "Qwen/Qwen3-0.6B") @@ -81,7 +80,6 @@ def test_generator_with_dict_configs(self): generator = Generator( engine_args=engine_dict, sampling_params=sampling_dict, - available_devices="test-gpu-device-abcd", ) self.assertIsInstance(generator.engine_args, EngineArgs) @@ -117,8 +115,6 @@ def test_generator_yaml_config_loading(self): sampling_params: n: 2468 max_tokens: 1357 - - available_devices: "yaml-test-device-xyz" """ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: @@ -138,8 +134,6 @@ def test_generator_yaml_config_loading(self): self.assertEqual(generator.sampling_params.n, 2468) self.assertEqual(generator.sampling_params.max_tokens, 1357) - self.assertEqual(generator.available_devices, "yaml-test-device-xyz") - if __name__ == "__main__": unittest.main() From c1cafa37de5740c5cd322c1ea9e2ea8f82377d6d Mon Sep 17 00:00:00 2001 From: joecummings Date: Wed, 15 Oct 2025 11:00:02 -0700 Subject: [PATCH 6/8] Last comments --- src/forge/actors/generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 10c41a87b..ca934127e 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -91,7 +91,7 @@ class Generator(ForgeActor): def __post_init__(self): super().__init__() self._run_task: asyncio.Task | None = None - self._policy_proc: ProcMesh | None = None + self._generator_proc: ProcMesh | None = None self._worker_procs: ProcMesh | None = None self.worker: GeneratorWorker | None = None self.running = False @@ -107,7 +107,7 @@ def __post_init__(self): self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY if self.use_dcp_for_weight_sync is None: - self.use_dcp_for_weight_sync = TORCHSTORE_USE_RDMA.get_value() == 0 + self.use_dcp_for_weight_sync = not TORCHSTORE_USE_RDMA.get_value() logger.debug(f"{self.use_dcp_for_weight_sync=}") @endpoint From ca842b25562b97d0c8ae71655cc9e4aee2b87f22 Mon Sep 17 00:00:00 2001 From: joecummings Date: Wed, 15 Oct 2025 11:18:29 -0700 Subject: [PATCH 7/8] Update test --- tests/unit_tests/test_generator_config.py | 30 +++++++++++------------ 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/tests/unit_tests/test_generator_config.py b/tests/unit_tests/test_generator_config.py index bff2e3e6e..9f9076895 100644 --- a/tests/unit_tests/test_generator_config.py +++ b/tests/unit_tests/test_generator_config.py @@ -57,24 +57,22 @@ def test_generator_default_initialization(self): reason="Import error, likely due to missing dependencies on CI.", ) def test_generator_with_dict_configs(self): - """Generator accepts dicts for engine_config and sampling_config, including nested dicts.""" from forge.actors.generator import Generator from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams - # Test with nested dict structure engine_dict = { - "model": "test-model-6789", - "tensor_parallel_size": 7777, - "pipeline_parallel_size": 8888, + "model": "Qwen/Qwen3-0.6B", + "tensor_parallel_size": 1, + "pipeline_parallel_size": 1, "enforce_eager": True, - "gpu_memory_utilization": 0.9, - "max_model_len": 4096, + "gpu_memory_utilization": 0.1, + "max_model_len": 1024, } sampling_dict = { - "n": 1357, - "max_tokens": 2468, + "n": 2, + "max_tokens": 32, } generator = Generator( @@ -86,16 +84,16 @@ def test_generator_with_dict_configs(self): self.assertIsInstance(generator.sampling_params, SamplingParams) # Test basic fields - self.assertEqual(generator.engine_args.model, "test-model-6789") - self.assertEqual(generator.engine_args.tensor_parallel_size, 7777) - self.assertEqual(generator.engine_args.pipeline_parallel_size, 8888) - self.assertEqual(generator.engine_args.gpu_memory_utilization, 0.9) - self.assertEqual(generator.engine_args.max_model_len, 4096) + self.assertEqual(generator.engine_args.model, "Qwen/Qwen3-0.6B") + self.assertEqual(generator.engine_args.tensor_parallel_size, 1) + self.assertEqual(generator.engine_args.pipeline_parallel_size, 1) + self.assertEqual(generator.engine_args.gpu_memory_utilization, 0.1) + self.assertEqual(generator.engine_args.max_model_len, 1024) self.assertTrue(generator.engine_args.enforce_eager) self.assertTrue(generator.engine_args._is_v1_supported_oracle()) - self.assertEqual(generator.sampling_params.n, 1357) - self.assertEqual(generator.sampling_params.max_tokens, 2468) + self.assertEqual(generator.sampling_params.n, 2) + self.assertEqual(generator.sampling_params.max_tokens, 32) @pytest.mark.skipif( _import_error(), From ecf7d7372529fff08850f459ae734eeb7f9b0c08 Mon Sep 17 00:00:00 2001 From: joecummings Date: Wed, 15 Oct 2025 11:34:07 -0700 Subject: [PATCH 8/8] Fix test --- tests/unit_tests/test_generator_config.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/unit_tests/test_generator_config.py b/tests/unit_tests/test_generator_config.py index 9f9076895..94cb58859 100644 --- a/tests/unit_tests/test_generator_config.py +++ b/tests/unit_tests/test_generator_config.py @@ -105,14 +105,14 @@ def test_generator_yaml_config_loading(self): yaml_content = """ engine_args: - model: "yaml-test-model-9876" - tensor_parallel_size: 1234 - pipeline_parallel_size: 5678 + model: "Qwen/Qwen3-0.6B" + tensor_parallel_size: 1 + pipeline_parallel_size: 1 enforce_eager: true sampling_params: - n: 2468 - max_tokens: 1357 + n: 2 + max_tokens: 32 """ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: @@ -123,14 +123,14 @@ def test_generator_yaml_config_loading(self): config = yaml.safe_load(yaml_file) generator = Generator(**config) - self.assertEqual(generator.engine_args.model, "yaml-test-model-9876") - self.assertEqual(generator.engine_args.tensor_parallel_size, 1234) - self.assertEqual(generator.engine_args.pipeline_parallel_size, 5678) + self.assertEqual(generator.engine_args.model, "Qwen/Qwen3-0.6B") + self.assertEqual(generator.engine_args.tensor_parallel_size, 1) + self.assertEqual(generator.engine_args.pipeline_parallel_size, 1) self.assertTrue(generator.engine_args.enforce_eager) self.assertTrue(generator.engine_args._is_v1_supported_oracle()) - self.assertEqual(generator.sampling_params.n, 2468) - self.assertEqual(generator.sampling_params.max_tokens, 1357) + self.assertEqual(generator.sampling_params.n, 2) + self.assertEqual(generator.sampling_params.max_tokens, 32) if __name__ == "__main__":