Skip to content
24 changes: 8 additions & 16 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller import ServiceConfig, spawn_service
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, spawn_service
from forge.data.rewards import MathReward, ThinkingReward
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
Expand Down Expand Up @@ -351,40 +351,33 @@ async def main():
)

# ---- Setup services ---- #
default_service_cfg = ServiceConfig(
procs_per_replica=1,
num_replicas=1,
)

policy = await spawn_service(
default_service_cfg,
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
Policy,
PolicyConfig(
num_workers=1,
worker_params=WorkerConfig(model=model),
sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
available_devices="3",
),
)

trainer = await spawn_service(
default_service_cfg,
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
Trainer,
learning_rate=1e-5,
beta=0.1,
model_name=model,
device=torch.device("cuda:1"),
)

replay_buffer = await spawn_service(
default_service_cfg,
ServiceConfig(procs_per_replica=1, num_replicas=1),
ReplayBuffer,
batch_size=4,
max_policy_age=1,
)

dataloader = await spawn_service(
default_service_cfg,
ServiceConfig(procs_per_replica=1, num_replicas=1),
DatasetActor,
"openai/gsm8k",
"main",
Expand All @@ -393,21 +386,20 @@ async def main():
)

compute_advantages = await spawn_service(
default_service_cfg,
ServiceConfig(procs_per_replica=1, num_replicas=1),
ComputeAdvantages,
gamma=0.99,
lambda_=0.95,
)

ref_model = await spawn_service(
default_service_cfg,
ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
RefModel,
model_name=model,
device=torch.device("cuda:2"),
)

reward_actor = await spawn_service(
default_service_cfg,
ServiceConfig(procs_per_replica=1, num_replicas=1),
RewardActor,
reward_functions=[MathReward(), ThinkingReward()],
)
Expand Down
2 changes: 2 additions & 0 deletions apps/rl/llama3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ trainer:
processes:
scheduler: local # local | mast (not supported yet)
num_hosts: 1
num_gpus: 4
num_procs: 4

optimizer:
Expand Down Expand Up @@ -65,6 +66,7 @@ replay_buffer:
processes:
scheduler: local # local | mast (not supported yet)
num_hosts: 1
num_gpus: 0
num_procs: 1

# policy:
Expand Down
3 changes: 2 additions & 1 deletion apps/sft_v2/llama3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ comm:
model:
name: llama3
flavor: 8B
tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct
tokenizer_path: /tmp/Llama-3.1-8B-Instruct

processes:
scheduler: local # local | mast (not supported yet)
num_hosts: 1
num_procs: 8
num_gpus: 8

optimizer:
name: AdamW
Expand Down
2 changes: 1 addition & 1 deletion apps/sft_v2/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

"""To run:

python -m apps.sft.main --config apps/sft/llama3_8b.yaml
python -m apps.sft_v2.main --config apps/sft_v2/llama3_8b.yaml

"""

Expand Down
3 changes: 1 addition & 2 deletions apps/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
from typing import List

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service
from forge.controller.service import ServiceConfig, spawn_service
from vllm.outputs import CompletionOutput


Expand Down
12 changes: 2 additions & 10 deletions src/forge/controller/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .actor import ForgeActor
from .interface import ServiceInterface, Session, SessionContext
from .proc_mesh import get_proc_mesh, spawn_actors
from .service import Service, ServiceConfig
from .spawn import spawn_service
from .proc_mesh import get_proc_mesh, spawn_actors, stop_proc_mesh

__all__ = [
"Service",
"ServiceConfig",
"ServiceInterface",
"Session",
"SessionContext",
"spawn_service",
"spawn_actors",
"stop_proc_mesh",
"get_proc_mesh",
"ForgeActor",
]
62 changes: 47 additions & 15 deletions src/forge/controller/proc_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@

import os
import socket
from functools import partial

from monarch.actor import proc_mesh, ProcMesh
from monarch.tools import commands
from monarch.tools.config import Config
from omegaconf import DictConfig

from forge.controller import ForgeActor

from forge.controller.system_controllers.gpu_manager import get_gpu_ids, release_gpus
from forge.types import ProcessConfig

logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -48,27 +51,52 @@ async def spawn_actors(
set_address: bool = False,
):
"""Setup process Mesh and spawn Actors."""
mesh = await get_proc_mesh(processes, set_address)
mesh = await get_proc_mesh(processes)
actors = await mesh.spawn(name, actor_cls, **cfg)
actors.mesh = mesh
return actors


async def get_proc_mesh(process_config: ProcessConfig, set_address=False) -> ProcMesh:
env = None
if set_address:
env = {
"MASTER_ADDR": str(socket.gethostname()),
"MASTER_PORT": str(_find_free_port()),
}
async def get_proc_mesh(process_config: ProcessConfig) -> ProcMesh:
"""Returns a proc mesh with the given process config."""
# TODO - modify this to work with multi-host
env = {
"MASTER_ADDR": str(socket.gethostname()),
"MASTER_PORT": str(_find_free_port()),
}
gpu_ids = None

def _setup_env(env: dict[str, str]):
"""Sets up the environment on proc mesh creation."""
for k, v in env.items():
os.environ[k] = v

if process_config.scheduler == "local":
if process_config.num_hosts != 1:
raise ValueError("Local scheduler only supports 1 host")
return await proc_mesh(gpus=process_config.num_procs, env=env)

if process_config.with_gpus:
gpu_ids = await get_gpu_ids(process_config.num_procs)
env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))

# TODO - update to use this_host() whenever it supports
# being run within actors:
# AttributeError: NYI: attempting to get ProcMesh attribute `slice` on object that's
# actually a ProcMeshRef
# return this_host().spawn_procs(
# per_host={"procs": process_config.num_procs},
# bootstrap=partial(_setup_env, env=env),
# )
m = proc_mesh(gpus=process_config.num_procs, env=env)
m._gpu_ids = gpu_ids
return m
elif process_config.scheduler == "mast":
if not MAST_SUPPORTED:
raise ValueError("MAST is not supported on this platform")

if process_config.with_gpus:
raise ValueError("NYI - need to add HostMesh tracking in GpuManager")

logging.info("Scheduling on MAST with: ", process_config)
jobname = f"monarch-{getpass.getuser()}"
config = Config(
Expand Down Expand Up @@ -104,12 +132,7 @@ async def get_proc_mesh(process_config: ProcessConfig, set_address=False) -> Pro
)
alloc = await allocator.allocate(AllocSpec(constraints, **mesh_dimensions))
if env:

def setup(): # noqa: FB811
for k, v in env.items():
os.environ[k] = v

p = await ProcMesh.from_alloc(alloc, setup=setup)
p = await ProcMesh.from_alloc(alloc, setup=partial(_setup_env, env=env))
else:
p = await ProcMesh.from_alloc(alloc)
await p.logging_option(stream_to_client=True, aggregate_window_sec=3)
Expand All @@ -118,6 +141,15 @@ def setup(): # noqa: FB811
raise ValueError("Unsupported scheduler: {}".format(process_config.scheduler))


async def stop_proc_mesh(mesh: ProcMesh) -> None:
"""Stops the given proc mesh."""
if hasattr(mesh, "_gpu_ids") and mesh._gpu_ids is not None:
gpu_ids = mesh._gpu_ids
logger.debug("Releasing GPUs: %s", gpu_ids)
await release_gpus(gpu_ids)
await mesh.stop()


def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", 0))
Expand Down
23 changes: 23 additions & 0 deletions src/forge/controller/service/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .interface import ServiceInterface, Session, SessionContext
from .metrics import ServiceMetrics
from .replica import Replica, ReplicaMetrics
from .service import Service, ServiceConfig
from .spawn import spawn_service

__all__ = [
"Replica",
"ReplicaMetrics",
"Service",
"ServiceConfig",
"ServiceInterface",
"ServiceMetrics",
"Session",
"SessionContext",
"spawn_service",
]
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dataclasses import dataclass, field
from typing import Dict, List

from forge.controller.replica import ReplicaMetrics
from forge.controller.service.replica import ReplicaMetrics


# TODO - tie this into metrics logger when it exists.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from enum import Enum
from typing import Optional

from forge.controller import get_proc_mesh
from forge.types import ProcessConfig

from monarch.actor import Actor, ActorError, ProcMesh

from forge.controller import get_proc_mesh, stop_proc_mesh
from forge.types import ProcessConfig

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

Expand Down Expand Up @@ -198,7 +198,7 @@ async def _do_recovery():
# Stop old proc_mesh if it exists
if old_proc_mesh is not None:
try:
await old_proc_mesh.stop()
await stop_proc_mesh(old_proc_mesh)
logger.debug(f"Old proc_mesh stopped for replica {self.idx}")
except Exception as e:
logger.warning(
Expand Down Expand Up @@ -468,7 +468,7 @@ async def stop(self):
# Stop the proc_mesh
if self.proc_mesh:
try:
await self.proc_mesh.stop()
await stop_proc_mesh(self.proc_mesh)
except Exception as e:
logger.warning(
"Error stopping proc_mesh for replica %d: %s", self.idx, e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@

from monarch.actor import Actor, endpoint

from forge.controller.interface import _session_context, Session
from forge.controller.metrics import ServiceMetrics
from forge.controller.replica import Replica, ServiceRequest
from forge.controller.service.interface import _session_context, Session

from forge.controller.service.metrics import ServiceMetrics
from forge.controller.service.replica import Replica, ServiceRequest
from forge.types import ServiceConfig

logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

from monarch.actor import Actor, proc_mesh

from forge.controller import Service, ServiceConfig
from forge.controller.interface import ServiceInterface
from forge.controller.service import Service, ServiceConfig

from forge.controller.service.interface import ServiceInterface

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
Expand Down
12 changes: 12 additions & 0 deletions src/forge/controller/system_controllers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .gpu_manager import get_gpu_ids, release_gpus

__all__ = [
"get_gpu_ids",
"release_gpus",
]
Loading
Loading