From 58191c5f297af5a86613f5a7092ca7b57e3de4df Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Thu, 18 Sep 2025 16:06:33 -0400 Subject: [PATCH 1/8] Update logic to only track policy version in the main controller --- apps/grpo/main.py | 9 ++++----- src/forge/actors/policy.py | 13 +++++++------ src/forge/interfaces.py | 8 ++++++-- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 7f31c26c9..6c338bdc3 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -286,6 +286,7 @@ async def continuous_rollouts(): return prompt, target = sample["request"], sample["target"] responses = await policy.generate.choose(prompt) + # TODO: this shall be part of the responses metadata instead of a separate call version = await policy.get_version.choose() group = Group.new_group( group_id=rollout_count, @@ -343,10 +344,9 @@ async def continuous_rollouts(): async def continuous_training(): training_step = 0 - policy_version = 0 while True: batch = await replay_buffer.sample.choose( - curr_policy_version=policy_version + curr_policy_version=training_step ) if batch is None: await asyncio.sleep(0.1) @@ -355,9 +355,8 @@ async def continuous_training(): loss = await trainer.train_step.choose(inputs, targets) training_step += 1 mlogger.log("loss/training_step", loss, training_step) - await trainer.push_weights.call(policy_version) - policy_version += 1 - await policy.update_weights.call() + await trainer.push_weights.call(training_step) + await policy.update_weights.call(training_step) print("Starting GRPO training loops...") # TODO: Start multiple rollouts once all serivces support it diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 070c00798..7e2f181c7 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -123,12 +123,12 @@ class Policy(PolicyInterface): lora_request: LoRARequest | None = None tokenization_kwargs: dict = field(default_factory=dict) policy_worker: "PolicyWorker" = None + policy_version: int | None = None def __post_init__(self): self._run_task: asyncio.Task | None = None self._policy_proc: ProcMesh | None = None self._worker_procs: ProcMesh | None = None - self.weights_version: int = 0 self.running = False if isinstance(self.engine_config, Mapping): self.engine_config = EngineConfig.from_dict(self.engine_config) @@ -212,6 +212,7 @@ async def setup(self): await self.policy_worker.setup.call() self.request_id = 0 + self.policy_version = 0 self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {} self.vllm_config: VllmConfig = self.engine_config.create_vllm_config() @@ -364,7 +365,7 @@ async def run(self): fut.set_result(request_output) @endpoint - async def update_weights(self): + async def update_weights(self, policy_version: int): # TODO: If generating long sequences, this might be long and will block policy weight updates curr_requests = [fut for _, fut in self.requests.values()] if curr_requests: @@ -372,9 +373,9 @@ async def update_weights(self): await asyncio.gather(*curr_requests) self.logger.debug(f"Starting weight update on {self.__class__.__name__}") - await self.policy_worker.update.call(version=self.weights_version) - self.weights_version += 1 - self.logger.info(f"Weight update completed (now v{self.weights_version})") + await self.policy_worker.update.call(version=policy_version) + self.policy_version = policy_version + self.logger.info(f"Weight update completed (now v{self.policy_version})") @endpoint async def _get_model_params(self) -> dict[str, torch.Tensor]: @@ -388,7 +389,7 @@ async def _get_model_params(self) -> dict[str, torch.Tensor]: @endpoint async def get_version(self) -> int: """Get the current policy version.""" - return self.weights_version + return self.policy_version @endpoint async def stop(self): diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index 3dbbd560e..df79c302e 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -85,8 +85,12 @@ async def generate(self, request: Observation) -> Action: @endpoint @abstractmethod - async def update_weights(self): - """Update the policy weights.""" + async def update_weights(self, policy_version: int): + """Update the policy weights. + + Args: + policy_version: The version number to update to. + """ pass From c3240788ead243727568f25cc811820a7f59412c Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Thu, 18 Sep 2025 16:06:33 -0400 Subject: [PATCH 2/8] Update logic to only track policy version in the main controller --- apps/grpo/main.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 6c338bdc3..b568705e4 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -345,9 +345,7 @@ async def continuous_rollouts(): async def continuous_training(): training_step = 0 while True: - batch = await replay_buffer.sample.choose( - curr_policy_version=training_step - ) + batch = await replay_buffer.sample.choose(curr_policy_version=training_step) if batch is None: await asyncio.sleep(0.1) else: From 14f77aeee4abd6fdb2af1fe66d1a5ac3b3a6446f Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 19 Sep 2025 10:49:41 -0400 Subject: [PATCH 3/8] update the sumdigits app --- apps/toy_rl/sumdigits.py | 21 ++++++++++++--------- src/forge/actors/replay_buffer.py | 12 ++++++++++-- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/apps/toy_rl/sumdigits.py b/apps/toy_rl/sumdigits.py index 6b1d8d763..6097bd697 100644 --- a/apps/toy_rl/sumdigits.py +++ b/apps/toy_rl/sumdigits.py @@ -397,7 +397,13 @@ async def main(cfg: DictConfig): # ---- Setup services ---- # await ts.initialize() - (dataloader, policy, trainer, replay_buffer, reward_actor,) = await asyncio.gather( + ( + dataloader, + policy, + trainer, + replay_buffer, + reward_actor, + ) = await asyncio.gather( DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset), Policy.options(**cfg.services.policy).as_service(**cfg.policy), Trainer.options(**cfg.services.trainer).as_service(**cfg.trainer), @@ -464,21 +470,18 @@ async def continuous_rollouts(): async def continuous_training(): training_step = 0 - policy_version = 0 while True: - batch = await replay_buffer.sample.choose( - curr_policy_version=policy_version - ) + batch = await replay_buffer.sample.choose(curr_policy_version=training_step) if batch is None: await asyncio.sleep(0.1) else: loss = await trainer.train_step.choose(batch[0]) training_step += 1 mlogger.log("loss/training_step", loss, training_step) - print(f"loss/training_step: {loss} at {training_step}") - await trainer.push_weights.call(policy_version) - policy_version += 1 - await policy.update_weights.call() + print(f"loss/training_step: {loss} at training step {training_step}") + await trainer.push_weights.call(training_step) + await policy.update_weights.call(training_step) + # TODO: remove this line? this clears the buffer so it's always on-policy await replay_buffer.clear.call() print("Starting training loop.") diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index ca7d487bc..15c89ccbf 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -8,10 +8,10 @@ from dataclasses import dataclass from typing import Any, Callable -from monarch.actor import endpoint - from forge.controller import ForgeActor +from monarch.actor import endpoint + @dataclass class ReplayBuffer(ForgeActor): @@ -87,11 +87,18 @@ async def evict(self, curr_policy_version: int) -> None: self._evict(curr_policy_version) def _evict(self, curr_policy_version: int) -> None: + buffer_len_before_evict = len(self.buffer) self.buffer = [ trajectory for trajectory in self.buffer if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age ] + buffer_len_after_evict = len(self.buffer) + + print( + f"maximum policy age: {self.max_policy_age}, current policy version: {curr_policy_version}, " + f"{buffer_len_before_evict - buffer_len_after_evict} episodes expired, {buffer_len_after_evict} episodes left" + ) @endpoint async def _getitem(self, idx: int): @@ -106,6 +113,7 @@ async def _numel(self) -> int: async def clear(self) -> None: """Clear the replay buffer immediately - dropping all episodes.""" self.buffer.clear() + print("replay buffer cleared") @endpoint async def state_dict(self) -> dict[str, Any]: From ca4161d50bc087e646bf6636a4a0d1e1bbc6c7c4 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 19 Sep 2025 10:57:51 -0400 Subject: [PATCH 4/8] format --- apps/toy_rl/sumdigits.py | 8 +------- src/forge/actors/replay_buffer.py | 4 ++-- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/apps/toy_rl/sumdigits.py b/apps/toy_rl/sumdigits.py index 6097bd697..210bef8d3 100644 --- a/apps/toy_rl/sumdigits.py +++ b/apps/toy_rl/sumdigits.py @@ -397,13 +397,7 @@ async def main(cfg: DictConfig): # ---- Setup services ---- # await ts.initialize() - ( - dataloader, - policy, - trainer, - replay_buffer, - reward_actor, - ) = await asyncio.gather( + (dataloader, policy, trainer, replay_buffer, reward_actor,) = await asyncio.gather( DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset), Policy.options(**cfg.services.policy).as_service(**cfg.policy), Trainer.options(**cfg.services.trainer).as_service(**cfg.trainer), diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index 15c89ccbf..af934c8d2 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -8,10 +8,10 @@ from dataclasses import dataclass from typing import Any, Callable -from forge.controller import ForgeActor - from monarch.actor import endpoint +from forge.controller import ForgeActor + @dataclass class ReplayBuffer(ForgeActor): From 664c5d7861d6ea04c04850ce0fcb1c5335d4b5c3 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 19 Sep 2025 11:22:47 -0400 Subject: [PATCH 5/8] use logger --- src/forge/actors/replay_buffer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index af934c8d2..e0fb438bb 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -4,13 +4,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging import random from dataclasses import dataclass from typing import Any, Callable +from forge.controller import ForgeActor + from monarch.actor import endpoint -from forge.controller import ForgeActor +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) @dataclass @@ -95,7 +99,7 @@ def _evict(self, curr_policy_version: int) -> None: ] buffer_len_after_evict = len(self.buffer) - print( + logger.info( f"maximum policy age: {self.max_policy_age}, current policy version: {curr_policy_version}, " f"{buffer_len_before_evict - buffer_len_after_evict} episodes expired, {buffer_len_after_evict} episodes left" ) @@ -113,7 +117,7 @@ async def _numel(self) -> int: async def clear(self) -> None: """Clear the replay buffer immediately - dropping all episodes.""" self.buffer.clear() - print("replay buffer cleared") + logger.info("replay buffer cleared") @endpoint async def state_dict(self) -> dict[str, Any]: @@ -127,3 +131,6 @@ async def state_dict(self) -> dict[str, Any]: async def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.buffer = state_dict["buffer"] random.setstate(state_dict["rng_state"]) + + def __post_init__(self): + super().__init__() From e1f3e5673c7d0459591cbc0c016074594c3954b4 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 19 Sep 2025 11:25:09 -0400 Subject: [PATCH 6/8] lint --- src/forge/actors/replay_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index e0fb438bb..12c11a8e1 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -9,10 +9,10 @@ from dataclasses import dataclass from typing import Any, Callable -from forge.controller import ForgeActor - from monarch.actor import endpoint +from forge.controller import ForgeActor + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) From a0ccc7e3a30696d449b21c63c3c86ea077519e2d Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 19 Sep 2025 14:34:37 -0400 Subject: [PATCH 7/8] address comments --- apps/toy_rl/sumdigits.py | 10 ++++++++-- src/forge/actors/replay_buffer.py | 10 +++++----- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/apps/toy_rl/sumdigits.py b/apps/toy_rl/sumdigits.py index 210bef8d3..ede00b235 100644 --- a/apps/toy_rl/sumdigits.py +++ b/apps/toy_rl/sumdigits.py @@ -397,7 +397,13 @@ async def main(cfg: DictConfig): # ---- Setup services ---- # await ts.initialize() - (dataloader, policy, trainer, replay_buffer, reward_actor,) = await asyncio.gather( + ( + dataloader, + policy, + trainer, + replay_buffer, + reward_actor, + ) = await asyncio.gather( DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset), Policy.options(**cfg.services.policy).as_service(**cfg.policy), Trainer.options(**cfg.services.trainer).as_service(**cfg.trainer), @@ -475,7 +481,7 @@ async def continuous_training(): print(f"loss/training_step: {loss} at training step {training_step}") await trainer.push_weights.call(training_step) await policy.update_weights.call(training_step) - # TODO: remove this line? this clears the buffer so it's always on-policy + # NOTE: hard-coded to be on-policy for faster convergence await replay_buffer.clear.call() print("Starting training loop.") diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index 12c11a8e1..fd60ce35c 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -27,6 +27,9 @@ class ReplayBuffer(ForgeActor): seed: int | None = None collate: Callable = lambda batch: batch + def __post_init__(self): + super().__init__() + @endpoint async def setup(self) -> None: self.buffer: list = [] @@ -99,7 +102,7 @@ def _evict(self, curr_policy_version: int) -> None: ] buffer_len_after_evict = len(self.buffer) - logger.info( + logger.debug( f"maximum policy age: {self.max_policy_age}, current policy version: {curr_policy_version}, " f"{buffer_len_before_evict - buffer_len_after_evict} episodes expired, {buffer_len_after_evict} episodes left" ) @@ -117,7 +120,7 @@ async def _numel(self) -> int: async def clear(self) -> None: """Clear the replay buffer immediately - dropping all episodes.""" self.buffer.clear() - logger.info("replay buffer cleared") + logger.debug("replay buffer cleared") @endpoint async def state_dict(self) -> dict[str, Any]: @@ -131,6 +134,3 @@ async def state_dict(self) -> dict[str, Any]: async def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.buffer = state_dict["buffer"] random.setstate(state_dict["rng_state"]) - - def __post_init__(self): - super().__init__() From 0f8b999957f268f691557d33ed74b598b7cf508d Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 19 Sep 2025 14:47:06 -0400 Subject: [PATCH 8/8] lint --- apps/toy_rl/sumdigits.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/apps/toy_rl/sumdigits.py b/apps/toy_rl/sumdigits.py index ede00b235..d4780f2e6 100644 --- a/apps/toy_rl/sumdigits.py +++ b/apps/toy_rl/sumdigits.py @@ -397,13 +397,7 @@ async def main(cfg: DictConfig): # ---- Setup services ---- # await ts.initialize() - ( - dataloader, - policy, - trainer, - replay_buffer, - reward_actor, - ) = await asyncio.gather( + (dataloader, policy, trainer, replay_buffer, reward_actor,) = await asyncio.gather( DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset), Policy.options(**cfg.services.policy).as_service(**cfg.policy), Trainer.options(**cfg.services.trainer).as_service(**cfg.trainer),