From 63e7e78b31a04025ffd203286c06055894f5bcf8 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 20 Oct 2025 17:17:23 -0700 Subject: [PATCH 1/3] provisioner as actor --- apps/grpo/main.py | 2 +- src/forge/controller/__init__.py | 4 +- src/forge/controller/provisioner.py | 82 ++++++++++++++++++----------- 3 files changed, 54 insertions(+), 34 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 85872681f..54a98f72f 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -346,7 +346,7 @@ async def main(cfg: DictConfig): # TODO: support multiple host meshes trainer_num_procs = cfg.actors.trainer["procs"] trainer_host_mesh_name = cfg.actors.trainer["mesh_name"] - trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name) + trainer_hosts = await provisioner.get_host_mesh.call_one(trainer_host_mesh_name) await ts.initialize( mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}), strategy=ts.LocalRankStrategy(), diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py index a579200e9..992eb71a6 100644 --- a/src/forge/controller/__init__.py +++ b/src/forge/controller/__init__.py @@ -5,9 +5,9 @@ # LICENSE file in the root directory of this source tree.
from .actor import ForgeActor from .provisioner import ( + get_or_create_provisioner, get_proc_mesh, host_mesh_from_proc, - init_provisioner, shutdown, stop_proc_mesh, ) @@ -16,7 +16,7 @@ "ForgeActor", "get_proc_mesh", "stop_proc_mesh", - "init_provisioner", + "get_or_create_provisioner", "shutdown", "host_mesh_from_proc", ] diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index cb36b2568..2cda9f82e 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -16,7 +16,14 @@ from monarch._src.actor.actor_mesh import ActorMesh from monarch._src.actor.shape import Extent -from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host +from monarch.actor import ( + Actor, + endpoint, + get_or_spawn_controller, + HostMesh, + ProcMesh, + this_host, +) from monarch.tools import commands @@ -95,7 +102,7 @@ def release_gpus(self, gpu_ids: list[str]) -> None: self.available_gpus.add(int(gpu_id)) -class Provisioner: +class Provisioner(Actor): """A global resource provisioner.""" def __init__(self, cfg: ProvisionerConfig | None = None): @@ -138,11 +145,13 @@ def __init__(self, cfg: ProvisionerConfig | None = None): self._registered_actors: list["ForgeActor"] = [] self._registered_services: list["ServiceInterface"] = [] + @endpoint async def initialize(self): """Call this after creating the instance""" if self.launcher is not None: await self.launcher.initialize() + @endpoint async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh: """Creates a remote server and a HostMesh on it.""" # no need to lock here because this is already locked behind `get_proc_mesh` @@ -172,6 +181,7 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh: ) return host_mesh, server_name + @endpoint def get_host_mesh(self, name: str) -> HostMesh: """Returns the host mesh given its associated name. 
@@ -181,6 +191,7 @@ def get_host_mesh(self, name: str) -> HostMesh: """ return self._host_mesh_map[name] + @endpoint async def get_proc_mesh( self, num_procs: int, @@ -225,7 +236,7 @@ async def get_proc_mesh( created_hosts = len(self._server_names) mesh_name = f"alloc_{created_hosts}" if host_mesh is None: - host_mesh, server_name = await self.create_host_mesh( + host_mesh, server_name = await self.create_host_mesh.call_one( name=mesh_name, num_hosts=num_hosts, ) @@ -318,6 +329,7 @@ def bootstrap(env: dict[str, str]): _ = await get_or_create_metric_logger(procs, process_name=mesh_name) return procs + @endpoint async def host_mesh_from_proc(self, proc_mesh: ProcMesh): if proc_mesh not in self._proc_host_map: raise ValueError( @@ -325,6 +337,7 @@ async def host_mesh_from_proc(self, proc_mesh: ProcMesh): ) return self._proc_host_map[proc_mesh] + @endpoint async def stop_proc_mesh(self, proc_mesh: ProcMesh): """Stops a proc mesh.""" if proc_mesh not in self._proc_host_map: @@ -352,6 +365,7 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): commands.kill(server_name) del self._proc_host_map[proc_mesh] + @endpoint def register_service(self, service: "ServiceInterface") -> None: """Registers a service allocation for cleanup.""" # Import ServiceInterface here instead of at top-level to avoid circular import @@ -364,6 +378,7 @@ def register_service(self, service: "ServiceInterface") -> None: self._registered_services.append(service) + @endpoint def register_actor(self, actor: "ForgeActor") -> None: """Registers a single actor allocation for cleanup.""" @@ -372,13 +387,15 @@ def register_actor(self, actor: "ForgeActor") -> None: self._registered_actors.append(actor) + @endpoint async def shutdown_all_allocations(self): """Gracefully shut down all tracked actors and services.""" + global _global_registered_services logger.info( - f"Shutting down {len(self._registered_services)} service(s) and {len(self._registered_actors)} actor(s)..." 
+ f"Shutting down {len(_global_registered_services)} service(s) and {len(self._registered_actors)} actor(s)..." ) # --- ServiceInterface --- - for service in reversed(self._registered_services): + for service in reversed(_global_registered_services): try: await service.shutdown() @@ -398,29 +415,30 @@ async def shutdown_all_allocations(self): self._registered_actors.clear() self._registered_services.clear() + @endpoint async def shutdown(self): """Tears down all remaining remote allocations.""" - await self.shutdown_all_allocations() + await self.shutdown_all_allocations.call_one() async with self._lock: for server_name in self._server_names: commands.kill(server_name) -_provisioner: Provisioner | None = None - +_global_provisioner: Provisioner | None = None +_global_registered_services: list["ServiceInterface"] = [] -async def init_provisioner(cfg: ProvisionerConfig | None = None): - global _provisioner - if not _provisioner: - _provisioner = Provisioner(cfg) - await _provisioner.initialize() - return _provisioner - -async def _get_provisioner(): - if not _provisioner: - await init_provisioner() - return _provisioner +async def get_or_create_provisioner( + cfg: ProvisionerConfig | None = None, +) -> Provisioner: + """Gets or spawns the global Provisioner controller actor.""" + global _global_provisioner + if _global_provisioner is None: + _global_provisioner = await get_or_spawn_controller( + "provisioner_controller", Provisioner, cfg + ) + await _global_provisioner.initialize.call_one() + return _global_provisioner async def get_proc_mesh( @@ -445,8 +463,8 @@ async def get_proc_mesh( A proc mesh. """ - provisioner = await _get_provisioner() - return await provisioner.get_proc_mesh( + provisioner = await get_or_create_provisioner() + return await provisioner.get_proc_mesh.call_one( num_procs=process_config.procs, with_gpus=process_config.with_gpus, num_hosts=process_config.hosts, @@ -465,25 +483,27 @@ async def host_mesh_from_proc(proc_mesh: ProcMesh): API. 
""" - provisioner = await _get_provisioner() - return await provisioner.host_mesh_from_proc(proc_mesh) + provisioner = await get_or_create_provisioner() + return await provisioner.host_mesh_from_proc.call_one(proc_mesh) async def register_service(service: "ServiceInterface") -> None: """Registers a service allocation with the global provisioner.""" - provisioner = await _get_provisioner() - provisioner.register_service(service) + + # TODO: This is a temporary hack. Change this back once Services are actors + global _global_registered_services + _global_registered_services.append(service) async def register_actor(actor: "ForgeActor") -> None: """Registers an actor allocation with the global provisioner.""" - provisioner = await _get_provisioner() - provisioner.register_actor(actor) + provisioner = await get_or_create_provisioner() + provisioner.register_actor.call_one(actor) async def stop_proc_mesh(proc_mesh: ProcMesh): - provisioner = await _get_provisioner() - return await provisioner.stop_proc_mesh(proc_mesh=proc_mesh) + provisioner = await get_or_create_provisioner() + return await provisioner.stop_proc_mesh.call_one(proc_mesh=proc_mesh) async def shutdown_metric_logger(): @@ -504,8 +524,8 @@ async def shutdown(): logger.info("Shutting down provisioner..") - provisioner = await _get_provisioner() - result = await provisioner.shutdown() + provisioner = await get_or_create_provisioner() + result = await provisioner.shutdown.call_one() logger.info("Shutdown completed successfully") return result From a668434ac61a447ecdc24ee8964db67b2cf0b393 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 20 Oct 2025 17:22:21 -0700 Subject: [PATCH 2/3] update all mains --- .meta/mast/main.py | 6 +- apps/grpo/main.py | 6 +- apps/grpo/notebook.ipynb | 1335 +++++++++-------- assets/versions.sh | 2 +- src/forge/controller/provisioner.py | 2 +- tests/integration_tests/test_policy_update.py | 4 +- tests/sandbox/rl_trainer/main.py | 4 +- tests/sandbox/vllm/main.py | 4 +- 8 files changed, 
684 insertions(+), 679 deletions(-) diff --git a/.meta/mast/main.py b/.meta/mast/main.py index 513d96fc6..b764a7630 100644 --- a/.meta/mast/main.py +++ b/.meta/mast/main.py @@ -15,7 +15,7 @@ MastLauncher, mount_mnt_directory, ) -from forge.controller.provisioner import init_provisioner +from forge.controller.provisioner import get_or_create_provisioner from forge.types import ( Launcher, @@ -68,7 +68,9 @@ async def main(cfg: DictConfig, mode: str = "detached", extra_args: list = None) else: # In remote mode, we're already running inside MAST, so mount directory, init provisioner and run training mount_mnt_directory("/mnt/wsfuse") - await init_provisioner(ProvisionerConfig(launcher_config=launcher_config)) + await get_or_create_provisioner( + ProvisionerConfig(launcher_config=launcher_config) + ) await grpo_main(cfg) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 54a98f72f..bbd64f415 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -25,7 +25,7 @@ from forge.actors.replay_buffer import ReplayBuffer from forge.actors.trainer import RLTrainer from forge.controller.actor import ForgeActor -from forge.controller.provisioner import init_provisioner, shutdown +from forge.controller.provisioner import get_or_create_provisioner, shutdown from forge.data.rewards import MathReward, ThinkingReward from forge.data_models.completion import Completion from forge.observability.metric_actors import get_or_create_metric_logger @@ -298,11 +298,11 @@ async def main(cfg: DictConfig): # ---- Global setups ---- # provisioner = None if cfg.get("provisioner", None) is not None: - provisioner = await init_provisioner( + provisioner = await get_or_create_provisioner( ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) else: - provisioner = await init_provisioner() + provisioner = await get_or_create_provisioner() metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) mlogger = await 
get_or_create_metric_logger(process_name="Controller") diff --git a/apps/grpo/notebook.ipynb b/apps/grpo/notebook.ipynb index 8d9fbc75a..ceebb997b 100644 --- a/apps/grpo/notebook.ipynb +++ b/apps/grpo/notebook.ipynb @@ -1,669 +1,672 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "46c66f45-4be3-4674-a870-3849c1048ddb", - "metadata": {}, - "source": [ - "# GRPO for Math (GSM8k)\n", - "\n", - "## Import modules" - ] + "cells": [ + { + "cell_type": "markdown", + "id": "46c66f45-4be3-4674-a870-3849c1048ddb", + "metadata": {}, + "source": [ + "# GRPO for Math (GSM8k)\n", + "\n", + "## Import modules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97d9ca00-92a8-4bd3-9b2b-ab8856f5acce", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Meta Platforms, Inc. and affiliates.\n", + "# All rights reserved.\n", + "#\n", + "# This source code is licensed under the BSD-style license found in the\n", + "# LICENSE file in the root directory of this source tree.\n", + "\n", + "import asyncio\n", + "import time\n", + "import uuid\n", + "from dataclasses import dataclass\n", + "from typing import Any, Callable\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import torchstore as ts\n", + "from datasets import load_dataset\n", + "from forge.actors._torchstore_utils import (\n", + " get_dcp_whole_state_dict_key,\n", + " get_param_prefix,\n", + ")\n", + "from forge.actors.generator import Generator as Policy\n", + "from forge.actors.reference_model import ReferenceModel\n", + "from forge.actors.replay_buffer import ReplayBuffer\n", + "from forge.actors.trainer import RLTrainer\n", + "from forge.cli.config import parse\n", + "from forge.controller.actor import ForgeActor\n", + "from forge.controller.provisioner import get_or_create_provisioner, shutdown\n", + "from forge.data.rewards import MathReward, ThinkingReward\n", + "from forge.observability.metric_actors import get_or_create_metric_logger\n", + "from 
forge.observability.metrics import record_metric, Reduce\n", + "from forge.observability.perf_tracker import Tracer\n", + "\n", + "from forge.types import LauncherConfig, ProvisionerConfig\n", + "from forge.util.ops import compute_logprobs\n", + "from monarch.actor import endpoint\n", + "from omegaconf import DictConfig\n", + "from vllm.transformers_utils.tokenizer import get_tokenizer\n", + "\n", + "import os\n", + "os.environ[\"MONARCH_HOSTMESH_V1\"] = \"1\"\n", + "os.environ[\"TORCHSTORE_RDMA_ENABLED\"] = \"1\"" + ] + }, + { + "cell_type": "markdown", + "id": "34d4319f-e6c9-4f4b-9b92-c572de08f0b2", + "metadata": {}, + "source": [ + "## Define Data Structures" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4a25e9d-e1dd-4ea7-a80c-383a2c04656a", + "metadata": {}, + "outputs": [], + "source": [ + "@dataclass\n", + "class Episode:\n", + " # TODO: add adtional layer for multi-turn\n", + " episode_id: str\n", + " request: str\n", + " policy_version: int\n", + " pad_id: int\n", + " request_len: int\n", + " response_len: int\n", + " target: Any | None = None\n", + " # processed data\n", + " response: str | None = None\n", + " request_tokens: list[int] | None = None\n", + " response_tokens: list[int] | None = None\n", + " ref_logprobs: torch.Tensor | None = None\n", + " reward: float | None = None\n", + " advantage: float | None = None\n", + "\n", + " @property\n", + " def request_tensor(self):\n", + " tensor = torch.tensor(self.request_tokens, dtype=torch.long)\n", + " if tensor.shape[0] < self.request_len: # left pad\n", + " diff = self.request_len - tensor.shape[0]\n", + " tensor = F.pad(tensor, (diff, 0), value=self.pad_id)\n", + " return tensor\n", + "\n", + " @property\n", + " def response_tensor(self):\n", + " tensor = torch.tensor(self.response_tokens, dtype=torch.long)\n", + " if tensor.shape[0] < self.response_len: # right pad\n", + " diff = self.response_len - tensor.shape[0]\n", + " tensor = F.pad(tensor, (0, diff), 
value=self.pad_id)\n", + " return tensor\n", + "\n", + "\n", + "@dataclass\n", + "class Group:\n", + " group_id: str\n", + " episodes: list[Episode]\n", + "\n", + " @classmethod\n", + " def new_group(\n", + " cls,\n", + " group_id: int,\n", + " group_size: int,\n", + " request: str,\n", + " policy_version: int,\n", + " pad_id: int,\n", + " request_len: int,\n", + " response_len: int,\n", + " target: Any = None,\n", + " ):\n", + " episodes = []\n", + " for _ in range(group_size):\n", + " episodes.append(\n", + " Episode(\n", + " episode_id=str(uuid.uuid4()),\n", + " request=request,\n", + " policy_version=policy_version,\n", + " pad_id=pad_id,\n", + " request_len=request_len,\n", + " response_len=response_len,\n", + " target=target,\n", + " )\n", + " )\n", + " return cls(str(group_id), episodes)\n", + "\n", + "\n", + "def collate(batches: list[list[Episode]]):\n", + " inputs = []\n", + " targets = []\n", + " for batch in batches:\n", + " request = [e.request_tensor for e in batch]\n", + " request = torch.stack(request) # [b x s]\n", + "\n", + " response = [e.response_tensor for e in batch]\n", + " response = torch.stack(response) # [b x s]\n", + "\n", + " ref_logprobs = [e.ref_logprobs for e in batch]\n", + " ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]\n", + "\n", + " advantages = [e.advantage for e in batch]\n", + " advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]\n", + "\n", + " pad_id = batch[0].pad_id\n", + " mask = response != pad_id\n", + "\n", + " input = {\"tokens\": torch.cat([request, response], dim=1)}\n", + " target = {\n", + " \"response\": response,\n", + " \"ref_logprobs\": ref_logprobs,\n", + " \"advantages\": advantages,\n", + " \"padding_mask\": mask,\n", + " }\n", + " inputs.append(input)\n", + " targets.append(target)\n", + " return inputs, targets\n", + "\n", + "@dataclass\n", + "class DatasetActor(ForgeActor):\n", + " \"\"\"Actor wrapper for HuggingFace dataset to provide async interface.\"\"\"\n", + "\n", + " 
path: str = \"openai/gsm8k\"\n", + " revision: str = \"main\"\n", + " data_split: str = \"train\"\n", + " streaming: bool = True\n", + " model: str = \"Qwen/Qwen3-1.7B\"\n", + "\n", + " @endpoint\n", + " def setup(self):\n", + " self._tokenizer = get_tokenizer(self.model)\n", + "\n", + " def gsm8k_transform(sample):\n", + " system_prompt = \"\"\"\n", + " Put all your scratchpad work between and tags.\n", + " Your final answer should be between and tags otherwise it will not be scored.\n", + " \"\"\"\n", + " request: str = sample[\"question\"]\n", + " as_chat = [\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": request},\n", + " ]\n", + " formatted_request = self._tokenizer.apply_chat_template(\n", + " as_chat,\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + " )\n", + " target: str = sample[\"answer\"]\n", + " formatted_target = target.split(\"#### \")[1]\n", + " return {\"request\": formatted_request, \"target\": formatted_target}\n", + "\n", + " ds = load_dataset(\n", + " self.path, self.revision, split=self.data_split, streaming=self.streaming\n", + " )\n", + " ds = ds.map(gsm8k_transform)\n", + " ds = ds.shuffle()\n", + " self._iterator = iter(ds)\n", + "\n", + " @endpoint\n", + " async def sample(self) -> dict[str, str] | None:\n", + " try:\n", + " sample = next(self._iterator)\n", + "\n", + " # Record dataset metrics\n", + " record_metric(\"dataset/sample/count_samples_generated\", 1, Reduce.SUM)\n", + " record_metric(\n", + " \"dataset/sample/avg_sample_len\",\n", + " len(sample[\"request\"]),\n", + " Reduce.MEAN,\n", + " )\n", + "\n", + " return sample\n", + " except StopIteration:\n", + " return None\n", + "\n", + " @endpoint\n", + " async def pad_token(self):\n", + " return self._tokenizer.pad_token_id" + ] + }, + { + "cell_type": "markdown", + "id": "901b3d1d-7eba-4464-b881-48c11ff6e0ef", + "metadata": {}, + "source": [ + "## Define loss" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "934aca32-0953-4945-9f99-e7b34804443b", + "metadata": {}, + "outputs": [], + "source": [ + "def simple_grpo_loss(\n", + " logits: torch.Tensor,\n", + " response: torch.Tensor,\n", + " ref_logprobs: torch.Tensor,\n", + " advantages: torch.Tensor,\n", + " padding_mask: torch.Tensor,\n", + " beta: float = 0.1,\n", + ") -> torch.Tensor:\n", + " \"\"\"\n", + " Example GRPO Loss Function for RLTrainer\n", + " \"\"\"\n", + " logprobs: torch.Tensor = compute_logprobs(logits, response)\n", + "\n", + " # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`\n", + " kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1\n", + " per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages\n", + " per_token_loss = -(per_token_policy_loss - beta * kl)\n", + " loss = (\n", + " ((per_token_loss * padding_mask).sum(dim=1))\n", + " / (padding_mask.sum(dim=1).clamp(min=1.0))\n", + " ).mean()\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "id": "d4f8bbe3-b7ac-4905-b197-f10990f9a104", + "metadata": {}, + "source": [ + "## Define Reward" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "163e98bf-e0f5-4ec3-9690-9839e687f9b3", + "metadata": {}, + "outputs": [], + "source": [ + "@dataclass\n", + "class RewardActor(ForgeActor):\n", + " \"\"\"Reward actor that uses a list of scoring functions.\"\"\"\n", + "\n", + " reward_functions: list[Callable]\n", + "\n", + " @endpoint\n", + " async def evaluate_response(self, prompt: str, response: str, target: str) -> float:\n", + " total_rewards = 0.0\n", + " for reward_fn in self.reward_functions:\n", + " reward = reward_fn(prompt, response, target)\n", + " total_rewards += reward\n", + "\n", + " # Get a name for the reward function (works for classes, functions, lambdas)\n", + " reward_fn_name = getattr(\n", + " reward_fn, \"__name__\", reward_fn.__class__.__name__\n", + " )\n", + " # per function reward\n", + " record_metric(\n", 
+ " f\"reward/evaluate_response/sum_{reward_fn_name}_reward\",\n", + " reward,\n", + " Reduce.SUM,\n", + " )\n", + " record_metric(\n", + " f\"reward/evaluate_response/avg_{reward_fn_name}_reward\",\n", + " reward,\n", + " Reduce.MEAN,\n", + " )\n", + " record_metric(\n", + " f\"reward/evaluate_response/std_{reward_fn_name}_reward\",\n", + " reward,\n", + " Reduce.STD,\n", + " )\n", + "\n", + " # avg total reward\n", + " record_metric(\n", + " \"reward/evaluate_response/avg_total_reward\",\n", + " reward,\n", + " Reduce.MEAN,\n", + " )\n", + "\n", + " # count fn calls\n", + " record_metric(\n", + " f\"reward/evaluate_response/count_{reward_fn_name}_calls\",\n", + " 1,\n", + " Reduce.SUM,\n", + " )\n", + "\n", + " avg_reward = total_rewards / len(self.reward_functions)\n", + " return avg_reward\n", + "\n", + "\n", + "@dataclass\n", + "class ComputeAdvantages(ForgeActor):\n", + " \"\"\"Compute advantages for GRPO using reward signals.\"\"\"\n", + "\n", + " @endpoint\n", + " async def compute(self, group: Group) -> list[float]:\n", + " # TODO: add batch processing\n", + " rewards = torch.tensor([[e.reward for e in group.episodes]])\n", + " mean = rewards.mean(1, keepdim=True)\n", + " std = rewards.std(1, keepdim=True)\n", + " advantages = (rewards - mean) / (std + 1e-4)\n", + " return advantages.squeeze(0).tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88523484-b414-41db-bd3f-0d8dbf881a85", + "metadata": {}, + "outputs": [], + "source": [ + "async def drop_weights(version: int):\n", + " print(f\"Dropping weights @ version {version}\")\n", + " start_time = time.perf_counter()\n", + " prefix = get_param_prefix(version)\n", + " matching_keys = await ts.keys(prefix)\n", + " # TODO: once we have something like `get_meta()` in torchstore, we can just\n", + " # query the type of the object instead of relying on keys.\n", + " dcp_key = get_dcp_whole_state_dict_key(version)\n", + " if dcp_key in matching_keys:\n", + " dcp_handle = await 
ts.get(dcp_key)\n", + " dcp_handle.drop()\n", + " for key in matching_keys:\n", + " await ts.delete(key)\n", + " elapsed = time.perf_counter() - start_time\n", + " print(f\"Dropped weights @ version {version}, took {elapsed:.2f} seconds\")" + ] + }, + { + "cell_type": "markdown", + "id": "95d4fef3-180b-4b7e-8871-ecbe113cde72", + "metadata": {}, + "source": [ + "## Setup Services" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c811974-cd6b-40ed-a179-4511a7a6c489", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf\n", + "from forge.cli.config import resolve_hf_hub_paths\n", + "\n", + "cfg = OmegaConf.load('apps/grpo/qwen3_1_7b.yaml')\n", + "cfg = resolve_hf_hub_paths(cfg)\n", + "OmegaConf.resolve(cfg)\n", + "\n", + "group_size = cfg.group_size # 8\n", + "max_req_tokens = cfg.max_req_tokens # 512\n", + "max_res_tokens = cfg.max_res_tokens # 512\n", + "\n", + "metric_logging_cfg = cfg.get(\"metric_logging\", {\"console\": {\"log_per_rank\": False}})\n", + "mlogger = await get_or_create_metric_logger()\n", + "await mlogger.init_backends.call_one(metric_logging_cfg)\n", + "await ts.initialize(strategy=ts.ControllerStorageVolumes())\n", + "\n", + "dataloader, policy, trainer, replay_buffer, compute_advantages, ref_model, reward_actor = await asyncio.gather(\n", + " DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),\n", + " Policy.options(**cfg.services.policy).as_service(**cfg.policy),\n", + " RLTrainer.options(**cfg.actors.trainer).as_actor(\n", + " **cfg.trainer, loss=simple_grpo_loss\n", + " ),\n", + " ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(\n", + " **cfg.replay_buffer, collate=collate\n", + " ),\n", + " ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),\n", + " ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),\n", + " RewardActor.options(**cfg.services.reward_actor).as_service(\n", + " 
reward_functions=[MathReward(), ThinkingReward()]\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3f2a305f-b1e2-4eac-803c-71bf3225fed7", + "metadata": {}, + "source": [ + "## Rollout Loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1c676fb-2cd6-4c2c-87d4-e1b8cd0b87af", + "metadata": {}, + "outputs": [], + "source": [ + "async def continuous_rollouts():\n", + " rollout_count = 0\n", + " pad_id = await dataloader.pad_token.call_one()\n", + " while True:\n", + " t = Tracer(\"main_perf/continuous_rollouts\")\n", + " t.start()\n", + " sample = await dataloader.sample.call_one()\n", + " if sample is None:\n", + " print(\"Dataloader is empty, exiting continuous rollout\")\n", + " return\n", + "\n", + " t.step(\"data_loading\")\n", + "\n", + " prompt, target = sample[\"request\"], sample[\"target\"]\n", + " responses = await policy.generate.route(prompt)\n", + " # TODO: this shall be part of the responses metadata instead of a separate call\n", + " version = await policy.get_version.route()\n", + "\n", + " t.step(\"policy_generation\")\n", + "\n", + " assert (\n", + " len(responses) > 0\n", + " ), \"Sanity check: Responses should NEVER return empty\"\n", + " assert (\n", + " version := responses[0].generator_version\n", + " ) is not None, \"Response must indicate a version\"\n", + " group = Group.new_group(\n", + " group_id=rollout_count,\n", + " group_size=group_size,\n", + " request=prompt,\n", + " policy_version=version,\n", + " pad_id=pad_id,\n", + " request_len=max_req_tokens,\n", + " response_len=max_res_tokens,\n", + " target=target,\n", + " )\n", + "\n", + " input_ids = torch.ones(\n", + " (group_size, max_req_tokens + max_res_tokens),\n", + " dtype=torch.long,\n", + " device=\"cuda\",\n", + " )\n", + " # Populate episode info and calculate rewards\n", + " for i, (episode, response) in enumerate(zip(group.episodes, responses)):\n", + " episode.request_tokens = response.prompt_ids\n", + " episode.response_tokens = 
response.token_ids\n", + " episode.response = response.text\n", + " input_ids[i, :max_req_tokens] = episode.request_tensor\n", + " input_ids[i, max_req_tokens:] = episode.response_tensor\n", + " episode.reward = await reward_actor.evaluate_response.route(\n", + " prompt=prompt, response=response.text, target=target\n", + " )\n", + "\n", + " t.step(\"reward_evaluation\")\n", + "\n", + " ref_logprobs = await ref_model.forward.route(\n", + " input_ids, max_req_tokens, return_logprobs=True\n", + " )\n", + " t.step(\"reference_model_calculate_logprobs\")\n", + "\n", + " for i, episode in enumerate(group.episodes):\n", + " episode.ref_logprobs = ref_logprobs[i]\n", + " del ref_logprobs, input_ids\n", + " t.step(\"compute_logprobs\")\n", + "\n", + " # Calculate advantages and add to replay buffer\n", + " advantages = await compute_advantages.compute.call_one(group)\n", + " for episode, advantage in zip(group.episodes, advantages):\n", + " episode.advantage = advantage\n", + " await replay_buffer.add.call_one(episode)\n", + "\n", + " # Log metrics\n", + " rollout_count += 1\n", + " record_metric(\n", + " \"main/continuous_rollouts/count_rollout_iterations\", 1, Reduce.SUM\n", + " )\n", + " t.stop()" + ] + }, + { + "cell_type": "markdown", + "id": "57c316dc-11b5-48ea-8b03-e1bb9d9d1f2b", + "metadata": {}, + "source": [ + "## Training Loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "916a0e79-aded-4ee3-b1a8-db0e772996c9", + "metadata": {}, + "outputs": [], + "source": [ + "async def continuous_training():\n", + " training_step = 0\n", + " restart_tracer = True # Flag to control when to restart tracer\n", + " while True:\n", + " # Restart tracer when needed (initial start or after completing a training step)\n", + " # Otherwise, we cannot measure time waiting for buffer\n", + " if restart_tracer:\n", + " t = Tracer(\"main_perf/continuous_training\")\n", + " t.start()\n", + " restart_tracer = False\n", + "\n", + " batch = await 
replay_buffer.sample.call_one(\n", + " curr_policy_version=training_step\n", + " )\n", + " if batch is None:\n", + " await asyncio.sleep(0.1)\n", + " else:\n", + " t.step(\"waiting_for_buffer\")\n", + "\n", + " inputs, targets = batch\n", + " await trainer.train_step.call(inputs, targets)\n", + " training_step += 1\n", + " t.step(\"train_step\")\n", + "\n", + " await trainer.push_weights.call(training_step)\n", + " t.step(\"push_weights\")\n", + "\n", + " await policy.update_weights.fanout(training_step)\n", + " update_task = asyncio.create_task(policy.update_weights.fanout(training_step))\n", + " t.step(\"update_weights\")\n", + "\n", + " if training_step >= 2:\n", + " await drop_weights(training_step - 1)\n", + " t.step(\"drop_weights\")\n", + "\n", + " t.stop()\n", + " restart_tracer = True\n", + "\n", + " # Flush metrics every training step to WandB\n", + " await mlogger.flush.call_one(training_step)" + ] + }, + { + "cell_type": "markdown", + "id": "4542863b-59c5-40bc-896c-6d8d44ada00f", + "metadata": {}, + "source": [ + "## Run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58194c13-b75e-405d-ab11-18cbe1874d92", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "num_rollout_threads = 1\n", + "num_training_threads = 1\n", + "\n", + "rollout_tasks = [\n", + " asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads)\n", + "]\n", + "training_task = asyncio.create_task(continuous_training())\n", + "\n", + "try:\n", + " await asyncio.gather(*rollout_tasks, training_task)\n", + "except KeyboardInterrupt:\n", + " print(\"Training interrupted by user\")\n", + " for rollout_task in rollout_tasks:\n", + " rollout_task.cancel()\n", + " training_task.cancel()" + ] + }, + { + "cell_type": "markdown", + "id": "b4603b80-1f25-49a1-920e-d24f38dfc687", + "metadata": {}, + "source": [ + "## Shutdown" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d74e781-c253-4bd0-929f-bd4ad516ba81", 
+ "metadata": {}, + "outputs": [], + "source": [ + "await mlogger.shutdown.call_one()\n", + "await asyncio.sleep(2)\n", + "\n", + "await asyncio.gather(\n", + " DatasetActor.shutdown(dataloader),\n", + " policy.shutdown(),\n", + " RLTrainer.shutdown(trainer),\n", + " ReplayBuffer.shutdown(replay_buffer),\n", + " ComputeAdvantages.shutdown(compute_advantages),\n", + " ref_model.shutdown(),\n", + " reward_actor.shutdown(),\n", + ")\n", + "await shutdown()" + ] + } + ], + "metadata": { + "fileHeader": "", + "fileUid": "33e39445-9f70-425a-b663-fe2c861ef07b", + "isAdHoc": false, + "kernelspec": { + "display_name": "forge", + "language": "python", + "name": "forge" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } }, - { - "cell_type": "code", - "execution_count": null, - "id": "97d9ca00-92a8-4bd3-9b2b-ab8856f5acce", - "metadata": {}, - "outputs": [], - "source": [ - "# Copyright (c) Meta Platforms, Inc. 
and affiliates.\n", - "# All rights reserved.\n", - "#\n", - "# This source code is licensed under the BSD-style license found in the\n", - "# LICENSE file in the root directory of this source tree.\n", - "\n", - "import asyncio\n", - "import time\n", - "import uuid\n", - "from dataclasses import dataclass\n", - "from typing import Any, Callable\n", - "\n", - "import torch\n", - "import torch.nn.functional as F\n", - "import torchstore as ts\n", - "from datasets import load_dataset\n", - "from forge.actors._torchstore_utils import (\n", - " get_dcp_whole_state_dict_key,\n", - " get_param_prefix,\n", - ")\n", - "from forge.actors.generator import Generator as Policy\n", - "from forge.actors.reference_model import ReferenceModel\n", - "from forge.actors.replay_buffer import ReplayBuffer\n", - "from forge.actors.trainer import RLTrainer\n", - "from forge.cli.config import parse\n", - "from forge.controller.actor import ForgeActor\n", - "from forge.controller.provisioner import init_provisioner, shutdown\n", - "from forge.data.rewards import MathReward, ThinkingReward\n", - "from forge.observability.metric_actors import get_or_create_metric_logger\n", - "from forge.observability.metrics import record_metric, Reduce\n", - "from forge.observability.perf_tracker import Tracer\n", - "\n", - "from forge.types import LauncherConfig, ProvisionerConfig\n", - "from forge.util.ops import compute_logprobs\n", - "from monarch.actor import endpoint\n", - "from omegaconf import DictConfig\n", - "from vllm.transformers_utils.tokenizer import get_tokenizer\n", - "\n", - "import os\n", - "os.environ[\"MONARCH_HOSTMESH_V1\"] = \"1\"\n", - "os.environ[\"TORCHSTORE_RDMA_ENABLED\"] = \"1\"" - ] - }, - { - "cell_type": "markdown", - "id": "34d4319f-e6c9-4f4b-9b92-c572de08f0b2", - "metadata": {}, - "source": [ - "## Define Data Structures" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b4a25e9d-e1dd-4ea7-a80c-383a2c04656a", - "metadata": {}, - "outputs": [], - 
"source": [ - "@dataclass\n", - "class Episode:\n", - " # TODO: add adtional layer for multi-turn\n", - " episode_id: str\n", - " request: str\n", - " policy_version: int\n", - " pad_id: int\n", - " request_len: int\n", - " response_len: int\n", - " target: Any | None = None\n", - " # processed data\n", - " response: str | None = None\n", - " request_tokens: list[int] | None = None\n", - " response_tokens: list[int] | None = None\n", - " ref_logprobs: torch.Tensor | None = None\n", - " reward: float | None = None\n", - " advantage: float | None = None\n", - "\n", - " @property\n", - " def request_tensor(self):\n", - " tensor = torch.tensor(self.request_tokens, dtype=torch.long)\n", - " if tensor.shape[0] < self.request_len: # left pad\n", - " diff = self.request_len - tensor.shape[0]\n", - " tensor = F.pad(tensor, (diff, 0), value=self.pad_id)\n", - " return tensor\n", - "\n", - " @property\n", - " def response_tensor(self):\n", - " tensor = torch.tensor(self.response_tokens, dtype=torch.long)\n", - " if tensor.shape[0] < self.response_len: # right pad\n", - " diff = self.response_len - tensor.shape[0]\n", - " tensor = F.pad(tensor, (0, diff), value=self.pad_id)\n", - " return tensor\n", - "\n", - "\n", - "@dataclass\n", - "class Group:\n", - " group_id: str\n", - " episodes: list[Episode]\n", - "\n", - " @classmethod\n", - " def new_group(\n", - " cls,\n", - " group_id: int,\n", - " group_size: int,\n", - " request: str,\n", - " policy_version: int,\n", - " pad_id: int,\n", - " request_len: int,\n", - " response_len: int,\n", - " target: Any = None,\n", - " ):\n", - " episodes = []\n", - " for _ in range(group_size):\n", - " episodes.append(\n", - " Episode(\n", - " episode_id=str(uuid.uuid4()),\n", - " request=request,\n", - " policy_version=policy_version,\n", - " pad_id=pad_id,\n", - " request_len=request_len,\n", - " response_len=response_len,\n", - " target=target,\n", - " )\n", - " )\n", - " return cls(str(group_id), episodes)\n", - "\n", - "\n", - "def 
collate(batches: list[list[Episode]]):\n", - " inputs = []\n", - " targets = []\n", - " for batch in batches:\n", - " request = [e.request_tensor for e in batch]\n", - " request = torch.stack(request) # [b x s]\n", - "\n", - " response = [e.response_tensor for e in batch]\n", - " response = torch.stack(response) # [b x s]\n", - "\n", - " ref_logprobs = [e.ref_logprobs for e in batch]\n", - " ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]\n", - "\n", - " advantages = [e.advantage for e in batch]\n", - " advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]\n", - "\n", - " pad_id = batch[0].pad_id\n", - " mask = response != pad_id\n", - "\n", - " input = {\"tokens\": torch.cat([request, response], dim=1)}\n", - " target = {\n", - " \"response\": response,\n", - " \"ref_logprobs\": ref_logprobs,\n", - " \"advantages\": advantages,\n", - " \"padding_mask\": mask,\n", - " }\n", - " inputs.append(input)\n", - " targets.append(target)\n", - " return inputs, targets\n", - "\n", - "@dataclass\n", - "class DatasetActor(ForgeActor):\n", - " \"\"\"Actor wrapper for HuggingFace dataset to provide async interface.\"\"\"\n", - "\n", - " path: str = \"openai/gsm8k\"\n", - " revision: str = \"main\"\n", - " data_split: str = \"train\"\n", - " streaming: bool = True\n", - " model: str = \"Qwen/Qwen3-1.7B\"\n", - "\n", - " @endpoint\n", - " def setup(self):\n", - " self._tokenizer = get_tokenizer(self.model)\n", - "\n", - " def gsm8k_transform(sample):\n", - " system_prompt = \"\"\"\n", - " Put all your scratchpad work between and tags.\n", - " Your final answer should be between and tags otherwise it will not be scored.\n", - " \"\"\"\n", - " request: str = sample[\"question\"]\n", - " as_chat = [\n", - " {\"role\": \"system\", \"content\": system_prompt},\n", - " {\"role\": \"user\", \"content\": request},\n", - " ]\n", - " formatted_request = self._tokenizer.apply_chat_template(\n", - " as_chat,\n", - " tokenize=False,\n", - " add_generation_prompt=True,\n", 
- " )\n", - " target: str = sample[\"answer\"]\n", - " formatted_target = target.split(\"#### \")[1]\n", - " return {\"request\": formatted_request, \"target\": formatted_target}\n", - "\n", - " ds = load_dataset(\n", - " self.path, self.revision, split=self.data_split, streaming=self.streaming\n", - " )\n", - " ds = ds.map(gsm8k_transform)\n", - " ds = ds.shuffle()\n", - " self._iterator = iter(ds)\n", - "\n", - " @endpoint\n", - " async def sample(self) -> dict[str, str] | None:\n", - " try:\n", - " sample = next(self._iterator)\n", - "\n", - " # Record dataset metrics\n", - " record_metric(\"dataset/sample/count_samples_generated\", 1, Reduce.SUM)\n", - " record_metric(\n", - " \"dataset/sample/avg_sample_len\",\n", - " len(sample[\"request\"]),\n", - " Reduce.MEAN,\n", - " )\n", - "\n", - " return sample\n", - " except StopIteration:\n", - " return None\n", - "\n", - " @endpoint\n", - " async def pad_token(self):\n", - " return self._tokenizer.pad_token_id" - ] - }, - { - "cell_type": "markdown", - "id": "901b3d1d-7eba-4464-b881-48c11ff6e0ef", - "metadata": {}, - "source": [ - "## Define loss" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "934aca32-0953-4945-9f99-e7b34804443b", - "metadata": {}, - "outputs": [], - "source": [ - "def simple_grpo_loss(\n", - " logits: torch.Tensor,\n", - " response: torch.Tensor,\n", - " ref_logprobs: torch.Tensor,\n", - " advantages: torch.Tensor,\n", - " padding_mask: torch.Tensor,\n", - " beta: float = 0.1,\n", - ") -> torch.Tensor:\n", - " \"\"\"\n", - " Example GRPO Loss Function for RLTrainer\n", - " \"\"\"\n", - " logprobs: torch.Tensor = compute_logprobs(logits, response)\n", - "\n", - " # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`\n", - " kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1\n", - " per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages\n", - " per_token_loss = -(per_token_policy_loss - beta * kl)\n", - " loss = 
(\n", - " ((per_token_loss * padding_mask).sum(dim=1))\n", - " / (padding_mask.sum(dim=1).clamp(min=1.0))\n", - " ).mean()\n", - " return loss" - ] - }, - { - "cell_type": "markdown", - "id": "d4f8bbe3-b7ac-4905-b197-f10990f9a104", - "metadata": {}, - "source": [ - "## Define Reward" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "163e98bf-e0f5-4ec3-9690-9839e687f9b3", - "metadata": {}, - "outputs": [], - "source": [ - "@dataclass\n", - "class RewardActor(ForgeActor):\n", - " \"\"\"Reward actor that uses a list of scoring functions.\"\"\"\n", - "\n", - " reward_functions: list[Callable]\n", - "\n", - " @endpoint\n", - " async def evaluate_response(self, prompt: str, response: str, target: str) -> float:\n", - " total_rewards = 0.0\n", - " for reward_fn in self.reward_functions:\n", - " reward = reward_fn(prompt, response, target)\n", - " total_rewards += reward\n", - "\n", - " # Get a name for the reward function (works for classes, functions, lambdas)\n", - " reward_fn_name = getattr(\n", - " reward_fn, \"__name__\", reward_fn.__class__.__name__\n", - " )\n", - " # per function reward\n", - " record_metric(\n", - " f\"reward/evaluate_response/sum_{reward_fn_name}_reward\",\n", - " reward,\n", - " Reduce.SUM,\n", - " )\n", - " record_metric(\n", - " f\"reward/evaluate_response/avg_{reward_fn_name}_reward\",\n", - " reward,\n", - " Reduce.MEAN,\n", - " )\n", - " record_metric(\n", - " f\"reward/evaluate_response/std_{reward_fn_name}_reward\",\n", - " reward,\n", - " Reduce.STD,\n", - " )\n", - "\n", - " # avg total reward\n", - " record_metric(\n", - " \"reward/evaluate_response/avg_total_reward\",\n", - " reward,\n", - " Reduce.MEAN,\n", - " )\n", - "\n", - " # count fn calls\n", - " record_metric(\n", - " f\"reward/evaluate_response/count_{reward_fn_name}_calls\",\n", - " 1,\n", - " Reduce.SUM,\n", - " )\n", - "\n", - " avg_reward = total_rewards / len(self.reward_functions)\n", - " return avg_reward\n", - "\n", - "\n", - "@dataclass\n", - 
"class ComputeAdvantages(ForgeActor):\n", - " \"\"\"Compute advantages for GRPO using reward signals.\"\"\"\n", - "\n", - " @endpoint\n", - " async def compute(self, group: Group) -> list[float]:\n", - " # TODO: add batch processing\n", - " rewards = torch.tensor([[e.reward for e in group.episodes]])\n", - " mean = rewards.mean(1, keepdim=True)\n", - " std = rewards.std(1, keepdim=True)\n", - " advantages = (rewards - mean) / (std + 1e-4)\n", - " return advantages.squeeze(0).tolist()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "88523484-b414-41db-bd3f-0d8dbf881a85", - "metadata": {}, - "outputs": [], - "source": [ - "async def drop_weights(version: int):\n", - " print(f\"Dropping weights @ version {version}\")\n", - " start_time = time.perf_counter()\n", - " prefix = get_param_prefix(version)\n", - " matching_keys = await ts.keys(prefix)\n", - " # TODO: once we have something like `get_meta()` in torchstore, we can just\n", - " # query the type of the object instead of relying on keys.\n", - " dcp_key = get_dcp_whole_state_dict_key(version)\n", - " if dcp_key in matching_keys:\n", - " dcp_handle = await ts.get(dcp_key)\n", - " dcp_handle.drop()\n", - " for key in matching_keys:\n", - " await ts.delete(key)\n", - " elapsed = time.perf_counter() - start_time\n", - " print(f\"Dropped weights @ version {version}, took {elapsed:.2f} seconds\")" - ] - }, - { - "cell_type": "markdown", - "id": "95d4fef3-180b-4b7e-8871-ecbe113cde72", - "metadata": {}, - "source": [ - "## Setup Services" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6c811974-cd6b-40ed-a179-4511a7a6c489", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "from omegaconf import OmegaConf\n", - "from forge.cli.config import resolve_hf_hub_paths\n", - "\n", - "cfg = OmegaConf.load('apps/grpo/qwen3_1_7b.yaml')\n", - "cfg = resolve_hf_hub_paths(cfg)\n", - "OmegaConf.resolve(cfg)\n", - "\n", - "group_size = cfg.group_size # 8\n", - 
"max_req_tokens = cfg.max_req_tokens # 512\n", - "max_res_tokens = cfg.max_res_tokens # 512\n", - "\n", - "metric_logging_cfg = cfg.get(\"metric_logging\", {\"console\": {\"log_per_rank\": False}})\n", - "mlogger = await get_or_create_metric_logger()\n", - "await mlogger.init_backends.call_one(metric_logging_cfg)\n", - "await ts.initialize(strategy=ts.ControllerStorageVolumes())\n", - "\n", - "dataloader, policy, trainer, replay_buffer, compute_advantages, ref_model, reward_actor = await asyncio.gather(\n", - " DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),\n", - " Policy.options(**cfg.services.policy).as_service(**cfg.policy),\n", - " RLTrainer.options(**cfg.actors.trainer).as_actor(\n", - " **cfg.trainer, loss=simple_grpo_loss\n", - " ),\n", - " ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(\n", - " **cfg.replay_buffer, collate=collate\n", - " ),\n", - " ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),\n", - " ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),\n", - " RewardActor.options(**cfg.services.reward_actor).as_service(\n", - " reward_functions=[MathReward(), ThinkingReward()]\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "3f2a305f-b1e2-4eac-803c-71bf3225fed7", - "metadata": {}, - "source": [ - "## Rollout Loop" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c1c676fb-2cd6-4c2c-87d4-e1b8cd0b87af", - "metadata": {}, - "outputs": [], - "source": [ - "async def continuous_rollouts():\n", - " rollout_count = 0\n", - " pad_id = await dataloader.pad_token.call_one()\n", - " while True:\n", - " t = Tracer(\"main_perf/continuous_rollouts\")\n", - " t.start()\n", - " sample = await dataloader.sample.call_one()\n", - " if sample is None:\n", - " print(\"Dataloader is empty, exiting continuous rollout\")\n", - " return\n", - "\n", - " t.step(\"data_loading\")\n", - "\n", - " prompt, target = sample[\"request\"], sample[\"target\"]\n", - " 
responses = await policy.generate.route(prompt)\n", - " # TODO: this shall be part of the responses metadata instead of a separate call\n", - " version = await policy.get_version.route()\n", - "\n", - " t.step(\"policy_generation\")\n", - "\n", - " assert (\n", - " len(responses) > 0\n", - " ), \"Sanity check: Responses should NEVER return empty\"\n", - " assert (\n", - " version := responses[0].generator_version\n", - " ) is not None, \"Response must indicate a version\"\n", - " group = Group.new_group(\n", - " group_id=rollout_count,\n", - " group_size=group_size,\n", - " request=prompt,\n", - " policy_version=version,\n", - " pad_id=pad_id,\n", - " request_len=max_req_tokens,\n", - " response_len=max_res_tokens,\n", - " target=target,\n", - " )\n", - "\n", - " input_ids = torch.ones(\n", - " (group_size, max_req_tokens + max_res_tokens),\n", - " dtype=torch.long,\n", - " device=\"cuda\",\n", - " )\n", - " # Populate episode info and calculate rewards\n", - " for i, (episode, response) in enumerate(zip(group.episodes, responses)):\n", - " episode.request_tokens = response.prompt_ids\n", - " episode.response_tokens = response.token_ids\n", - " episode.response = response.text\n", - " input_ids[i, :max_req_tokens] = episode.request_tensor\n", - " input_ids[i, max_req_tokens:] = episode.response_tensor\n", - " episode.reward = await reward_actor.evaluate_response.route(\n", - " prompt=prompt, response=response.text, target=target\n", - " )\n", - "\n", - " t.step(\"reward_evaluation\")\n", - "\n", - " ref_logprobs = await ref_model.forward.route(\n", - " input_ids, max_req_tokens, return_logprobs=True\n", - " )\n", - " t.step(\"reference_model_calculate_logprobs\")\n", - "\n", - " for i, episode in enumerate(group.episodes):\n", - " episode.ref_logprobs = ref_logprobs[i]\n", - " del ref_logprobs, input_ids\n", - " t.step(\"compute_logprobs\")\n", - "\n", - " # Calculate advantages and add to replay buffer\n", - " advantages = await 
compute_advantages.compute.call_one(group)\n", - " for episode, advantage in zip(group.episodes, advantages):\n", - " episode.advantage = advantage\n", - " await replay_buffer.add.call_one(episode)\n", - "\n", - " # Log metrics\n", - " rollout_count += 1\n", - " record_metric(\n", - " \"main/continuous_rollouts/count_rollout_iterations\", 1, Reduce.SUM\n", - " )\n", - " t.stop()" - ] - }, - { - "cell_type": "markdown", - "id": "57c316dc-11b5-48ea-8b03-e1bb9d9d1f2b", - "metadata": {}, - "source": [ - "## Training Loop" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "916a0e79-aded-4ee3-b1a8-db0e772996c9", - "metadata": {}, - "outputs": [], - "source": [ - "async def continuous_training():\n", - " training_step = 0\n", - " restart_tracer = True # Flag to control when to restart tracer\n", - " while True:\n", - " # Restart tracer when needed (initial start or after completing a training step)\n", - " # Otherwise, we cannot measure time waiting for buffer\n", - " if restart_tracer:\n", - " t = Tracer(\"main_perf/continuous_training\")\n", - " t.start()\n", - " restart_tracer = False\n", - "\n", - " batch = await replay_buffer.sample.call_one(\n", - " curr_policy_version=training_step\n", - " )\n", - " if batch is None:\n", - " await asyncio.sleep(0.1)\n", - " else:\n", - " t.step(\"waiting_for_buffer\")\n", - "\n", - " inputs, targets = batch\n", - " await trainer.train_step.call(inputs, targets)\n", - " training_step += 1\n", - " t.step(\"train_step\")\n", - "\n", - " await trainer.push_weights.call(training_step)\n", - " t.step(\"push_weights\")\n", - "\n", - " await policy.update_weights.fanout(training_step)\n", - " update_task = asyncio.create_task(policy.update_weights.fanout(training_step))\n", - " t.step(\"update_weights\")\n", - "\n", - " if training_step >= 2:\n", - " await drop_weights(training_step - 1)\n", - " t.step(\"drop_weights\")\n", - "\n", - " t.stop()\n", - " restart_tracer = True\n", - "\n", - " # Flush metrics every 
training step to WandB\n", - " await mlogger.flush.call_one(training_step)" - ] - }, - { - "cell_type": "markdown", - "id": "4542863b-59c5-40bc-896c-6d8d44ada00f", - "metadata": {}, - "source": [ - "## Run" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "58194c13-b75e-405d-ab11-18cbe1874d92", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "num_rollout_threads = 1\n", - "num_training_threads = 1\n", - "\n", - "rollout_tasks = [\n", - " asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads)\n", - "]\n", - "training_task = asyncio.create_task(continuous_training())\n", - "\n", - "try:\n", - " await asyncio.gather(*rollout_tasks, training_task)\n", - "except KeyboardInterrupt:\n", - " print(\"Training interrupted by user\")\n", - " for rollout_task in rollout_tasks:\n", - " rollout_task.cancel()\n", - " training_task.cancel()" - ] - }, - { - "cell_type": "markdown", - "id": "b4603b80-1f25-49a1-920e-d24f38dfc687", - "metadata": {}, - "source": [ - "## Shutdown" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0d74e781-c253-4bd0-929f-bd4ad516ba81", - "metadata": {}, - "outputs": [], - "source": [ - "await mlogger.shutdown.call_one()\n", - "await asyncio.sleep(2)\n", - "\n", - "await asyncio.gather(\n", - " DatasetActor.shutdown(dataloader),\n", - " policy.shutdown(),\n", - " RLTrainer.shutdown(trainer),\n", - " ReplayBuffer.shutdown(replay_buffer),\n", - " ComputeAdvantages.shutdown(compute_advantages),\n", - " ref_model.shutdown(),\n", - " reward_actor.shutdown(),\n", - ")\n", - "await shutdown()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "forge", - "language": "python", - "name": "forge" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.18" - } - }, - 
"nbformat": 4, - "nbformat_minor": 5 + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/assets/versions.sh b/assets/versions.sh index 49a755dc0..29ee8f0b5 100644 --- a/assets/versions.sh +++ b/assets/versions.sh @@ -14,6 +14,6 @@ PYTORCH_VERSION="2.9.0.dev20250905" VLLM_BRANCH="v0.10.0" # Commit hashes -MONARCH_COMMIT="195503223b5c2896846171f60ac99dc6868f8f2c" +MONARCH_COMMIT="2f14096083b1cc1dac6ae15220e4135bc23f9dd3" TORCHTITAN_COMMIT="d0e25450bcac2332359b13fbda430dc701f073d4" TORCHSTORE_COMMIT="662299faf4fd50ee30bd9aa3f4ce8c0e2db1d310" diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 2cda9f82e..a50df1f55 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -498,7 +498,7 @@ async def register_service(service: "ServiceInterface") -> None: async def register_actor(actor: "ForgeActor") -> None: """Registers an actor allocation with the global provisioner.""" provisioner = await get_or_create_provisioner() - provisioner.register_actor.call_one(actor) + await provisioner.register_actor.call_one(actor) async def stop_proc_mesh(proc_mesh: ProcMesh): diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 01f01a390..2d08cbe87 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -17,7 +17,7 @@ from forge.actors.generator import Generator from forge.actors.trainer import RLTrainer -from forge.controller.provisioner import init_provisioner +from forge.controller.provisioner import get_or_create_provisioner from forge.controller.service.service import uuid from forge.types import LauncherConfig, ProvisionerConfig @@ -194,7 +194,7 @@ async def _setup_and_teardown(request): logger.info(f"`trainer.use_dcp` is overriden to {use_dcp_override}") if cfg.get("provisioner", None) is not None: - await init_provisioner( + await get_or_create_provisioner( 
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) await ts.initialize(strategy=ts.ControllerStorageVolumes()) diff --git a/tests/sandbox/rl_trainer/main.py b/tests/sandbox/rl_trainer/main.py index 55714c49d..34dd13107 100644 --- a/tests/sandbox/rl_trainer/main.py +++ b/tests/sandbox/rl_trainer/main.py @@ -12,7 +12,7 @@ import torchstore as ts from forge.actors.trainer import RLTrainer from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY -from forge.controller.provisioner import init_provisioner, shutdown +from forge.controller.provisioner import get_or_create_provisioner, shutdown from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.perf_tracker import Tracer from forge.types import ( @@ -164,7 +164,7 @@ async def main(cfg: DictConfig): trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1) dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1 - await init_provisioner( + await get_or_create_provisioner( ProvisionerConfig( launcher_config=LauncherConfig( launcher=cfg.get(LAUNCHER_KEY, Launcher.SLURM.value), diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py index 425352340..2c1a0d5e4 100644 --- a/tests/sandbox/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -15,7 +15,7 @@ from forge.actors.generator import Generator -from forge.controller.provisioner import init_provisioner, shutdown +from forge.controller.provisioner import get_or_create_provisioner, shutdown from forge.data_models.completion import Completion from forge.observability.metric_actors import get_or_create_metric_logger @@ -29,7 +29,7 @@ async def run(cfg: DictConfig): if cfg.get("provisioner", None) is not None: - await init_provisioner( + await get_or_create_provisioner( ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) From a68e8311c356fe1edd35fc4f7fed36aaef7bd02e 
Mon Sep 17 00:00:00 2001 From: DNXie Date: Tue, 21 Oct 2025 14:16:07 -0700 Subject: [PATCH 3/3] fix mapping issue + a repro test --- assets/versions.sh | 2 +- src/forge/actors/generator.py | 2 +- src/forge/controller/provisioner.py | 35 +++++++++++++++++++-------- test.py | 37 +++++++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 12 deletions(-) create mode 100644 test.py diff --git a/assets/versions.sh b/assets/versions.sh index 29ee8f0b5..4defaa2a6 100644 --- a/assets/versions.sh +++ b/assets/versions.sh @@ -14,6 +14,6 @@ PYTORCH_VERSION="2.9.0.dev20250905" VLLM_BRANCH="v0.10.0" # Commit hashes -MONARCH_COMMIT="2f14096083b1cc1dac6ae15220e4135bc23f9dd3" +MONARCH_COMMIT="main" TORCHTITAN_COMMIT="d0e25450bcac2332359b13fbda430dc701f073d4" TORCHSTORE_COMMIT="662299faf4fd50ee30bd9aa3f4ce8c0e2db1d310" diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 6c2efd5e6..a618e8805 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -154,7 +154,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] worker_procs = await get_proc_mesh(process_config=process_config) # Then, grab a single host from the workers... 
- host_mesh = await host_mesh_from_proc(worker_procs) + host_mesh = await host_mesh_from_proc(worker_procs._uid) singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()} host_mesh = host_mesh.slice(**singleton_slice) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index a50df1f55..32e1ef2a3 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -294,6 +294,11 @@ def bootstrap(env: dict[str, str]): per_host={"procs": num_procs}, bootstrap=functools.partial(bootstrap, env=env_vars), ) + uid = str(uuid.uuid4()) + # Generate a unique ID to map procmesh to hostmesh + procs._uid = uid + print(f"Allocating procmesh with uid={uid}") + print(f"Allocating procs._uid: {procs._uid}") if with_gpus: # Set up environment variables for PyTorch distributed... @@ -319,7 +324,7 @@ def bootstrap(env: dict[str, str]): self._server_names.append(server_name) self._proc_server_map[procs] = server_name - self._proc_host_map[procs] = host_mesh + self._proc_host_map[uid] = host_mesh # Spawn LocalFetcherActor for this ProcMesh and register with GlobalLoggingActor. # When called, the LocalFetcherActor is broadcast by Monarch to all ranks in the ProcMesh. 
@@ -327,20 +332,27 @@ def bootstrap(env: dict[str, str]): from forge.observability.metric_actors import get_or_create_metric_logger _ = await get_or_create_metric_logger(procs, process_name=mesh_name) - return procs + + print(f"Returning procmesh with uid={uid}") + print(f"Returning procs._uid: {procs._uid}") + return procs, uid @endpoint - async def host_mesh_from_proc(self, proc_mesh: ProcMesh): - if proc_mesh not in self._proc_host_map: + async def host_mesh_from_proc(self, uid: str | None): + # uid: str | None = getattr(proc_mesh, "_uid", None) + print(f"self._proc_host_map: {self._proc_host_map}") + print(f"proc_mesh._uid: {uid}") + if uid is None or uid not in self._proc_host_map: raise ValueError( "The proc mesh was not allocated with an associated hostmesh." ) - return self._proc_host_map[proc_mesh] + return self._proc_host_map[uid] @endpoint async def stop_proc_mesh(self, proc_mesh: ProcMesh): """Stops a proc mesh.""" - if proc_mesh not in self._proc_host_map: + uid: str | None = getattr(proc_mesh, "_uid", None) + if uid is None or uid not in self._proc_host_map: logger.warning( f"proc mesh {proc_mesh} was requested to be stopped, but was either already stopped or " "was never registered with the provisioner." 
@@ -363,7 +375,7 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): if proc_mesh in self._proc_server_map: server_name = self._proc_server_map[proc_mesh] commands.kill(server_name) - del self._proc_host_map[proc_mesh] + del self._proc_host_map[uid] @endpoint def register_service(self, service: "ServiceInterface") -> None: @@ -464,7 +476,7 @@ async def get_proc_mesh( """ provisioner = await get_or_create_provisioner() - return await provisioner.get_proc_mesh.call_one( + procs, uid = await provisioner.get_proc_mesh.call_one( num_procs=process_config.procs, with_gpus=process_config.with_gpus, num_hosts=process_config.hosts, @@ -474,9 +486,12 @@ async def get_proc_mesh( port=port, addr=addr, ) + setattr(procs, "_uid", uid) + print(f"Setting procs._uid: {procs._uid}") + return procs -async def host_mesh_from_proc(proc_mesh: ProcMesh): +async def host_mesh_from_proc(uid: str | None): """Returns the host mesh that allocated the original proc_mesh. This functionality will be enabled in Monarch, so this is a temporary @@ -484,7 +499,7 @@ async def host_mesh_from_proc(proc_mesh: ProcMesh): """ provisioner = await get_or_create_provisioner() - return await provisioner.host_mesh_from_proc.call_one(proc_mesh) + return await provisioner.host_mesh_from_proc.call_one(uid) async def register_service(service: "ServiceInterface") -> None: diff --git a/test.py b/test.py new file mode 100644 index 000000000..680c2007a --- /dev/null +++ b/test.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. +# All rights reserved. 
+# +# Minimal repro: Provisioner host_mesh_from_proc() UID mapping bug +# +# Run this with: +# python -m forge.tests.test_provisioner_uid_mapping + +import asyncio + +# import pytest + +from forge.controller.provisioner import ( + get_or_create_provisioner, + get_proc_mesh, + stop_proc_mesh, +) +from forge.types import ProcessConfig + + +# @pytest.mark.asyncio +async def test_provisioner_host_mesh_lookup_uid_mapping(): + prov = await get_or_create_provisioner() + pm = await get_proc_mesh( + ProcessConfig(procs=1, with_gpus=False, hosts=None, mesh_name="uid_repro") + ) + # UID is attached locally by the helper + assert hasattr(pm, "_uid") and pm._uid, "missing _uid on returned ProcMesh" + print(f"✅ got ProcMesh with UID {pm._uid}") + hm = await prov.host_mesh_from_proc.call_one(pm._uid) # if pass pm, _uid is None + assert hm is not None + await stop_proc_mesh(pm) + print("✅ repro passed") + + +if __name__ == "__main__": + asyncio.run(test_provisioner_host_mesh_lookup_uid_mapping())