Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .meta/mast/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from apps.grpo.main import main as grpo_main
from forge.cli.config import parse
from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
from forge.controller.provisioner import init_provisioner
from forge.controller.provisioner import get_or_create_provisioner

from forge.types import (
Launcher,
Expand Down Expand Up @@ -45,7 +45,7 @@ async def main(cfg: DictConfig):
print(f"Overriding checkpoint folder to {cfg[DEFAULT_CHECKPOINT_FOLDER_KEY]}")

# init mast provisioner
await init_provisioner(
await get_or_create_provisioner(
ProvisionerConfig(
launcher_config=LauncherConfig(
launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.MAST.value)),
Expand Down
8 changes: 4 additions & 4 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from forge.actors.trainer import RLTrainer
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import init_provisioner, shutdown
from forge.controller.provisioner import get_or_create_provisioner, shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
Expand Down Expand Up @@ -298,11 +298,11 @@ async def main(cfg: DictConfig):
# ---- Global setups ---- #
provisioner = None
if cfg.get("provisioner", None) is not None:
provisioner = await init_provisioner(
provisioner = await get_or_create_provisioner(
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
)
else:
provisioner = await init_provisioner()
provisioner = await get_or_create_provisioner()

metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
Expand Down Expand Up @@ -346,7 +346,7 @@ async def main(cfg: DictConfig):
# TODO: support multiple host meshes
trainer_num_procs = cfg.actors.trainer["procs"]
trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
trainer_hosts = provisioner.get_host_mesh.call_one(trainer_host_mesh_name)
await ts.initialize(
mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
strategy=ts.LocalRankStrategy(),
Expand Down
6 changes: 3 additions & 3 deletions src/forge/actors/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
worker_procs = await get_proc_mesh(process_config=process_config)

# Then, grab a single host from the workers...
host_mesh = await host_mesh_from_proc(worker_procs)
host_mesh = await host_mesh_from_proc.call_one(worker_procs)
singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()}
host_mesh = host_mesh.slice(**singleton_slice)

Expand Down Expand Up @@ -488,8 +488,8 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
# TODO - may want to expand stop to gracefully respond to
# ongoing requests.
await actor.stop.call()
await stop_proc_mesh(actor._worker_procs)
await stop_proc_mesh(actor._generator_proc)
await stop_proc_mesh.call_one(actor._worker_procs)
await stop_proc_mesh.call_one(actor._generator_proc)

@endpoint
async def save_model_params(self):
Expand Down
4 changes: 2 additions & 2 deletions src/forge/controller/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
# LICENSE file in the root directory of this source tree.
from .actor import ForgeActor
from .provisioner import (
get_or_create_provisioner,
get_proc_mesh,
host_mesh_from_proc,
init_provisioner,
shutdown,
stop_proc_mesh,
)
Expand All @@ -16,7 +16,7 @@
"ForgeActor",
"get_proc_mesh",
"stop_proc_mesh",
"init_provisioner",
"shutdown",
"host_mesh_from_proc",
"get_or_create_provisioner",
]
6 changes: 3 additions & 3 deletions src/forge/controller/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ async def as_service(
await service.__initialize__()
service_interface = ServiceInterface(service, cls)
# Register this service with the provisioner so it can cleanly shut this down
await register_service(service_interface)
await register_service.call_one(service_interface)
return service_interface

@endpoint
Expand Down Expand Up @@ -234,7 +234,7 @@ async def as_actor(cls: Type[T], *args, **actor_kwargs) -> T:
logger.info(f"Spawning actor {cls.__name__}")
actor = await cls.launch(*args, **actor_kwargs)
# Register this actor with the provisioner so it can cleanly shut this down
await register_actor(actor)
await register_actor.call_one(actor)
return actor

@classmethod
Expand All @@ -244,4 +244,4 @@ async def shutdown(cls, actor: "ForgeActor"):
"""
if actor._proc_mesh is None:
raise AssertionError("Called shutdown on a replica with no proc_mesh.")
await stop_proc_mesh(actor._proc_mesh)
await stop_proc_mesh.call_one(actor._proc_mesh)
85 changes: 59 additions & 26 deletions src/forge/controller/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@
from monarch._src.actor.actor_mesh import ActorMesh
from monarch._src.actor.shape import Extent

from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host
from monarch.actor import (
Actor,
endpoint,
get_or_spawn_controller,
HostMesh,
ProcMesh,
this_host,
)

from monarch.tools import commands

Expand Down Expand Up @@ -95,7 +102,7 @@ def release_gpus(self, gpu_ids: list[str]) -> None:
self.available_gpus.add(int(gpu_id))


class Provisioner:
class Provisioner(Actor):
"""A global resource provisioner."""

def __init__(self, cfg: ProvisionerConfig | None = None):
Expand Down Expand Up @@ -138,11 +145,13 @@ def __init__(self, cfg: ProvisionerConfig | None = None):
self._registered_actors: list["ForgeActor"] = []
self._registered_services: list["ServiceInterface"] = []

@endpoint
async def initialize(self):
    """Finish setup after construction by initializing the launcher, if any."""
    launcher = self.launcher
    if launcher is None:
        # Nothing to set up when no launcher was configured.
        return
    await launcher.initialize()

@endpoint
async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh:
"""Creates a remote server and a HostMesh on it."""
# no need to lock here because this is already locked behind `get_proc_mesh`
Expand Down Expand Up @@ -172,6 +181,7 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh:
)
return host_mesh, server_name

@endpoint
def get_host_mesh(self, name: str) -> HostMesh:
"""Returns the host mesh given its associated name.

Expand All @@ -181,6 +191,7 @@ def get_host_mesh(self, name: str) -> HostMesh:
"""
return self._host_mesh_map[name]

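# TODO(review): `@endpoint` is left disabled here, so `get_proc_mesh` is a plain
# method while its siblings are endpoints; the module-level `get_proc_mesh`
# still calls it directly on the provisioner handle — confirm that is reachable.
# @endpoint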
async def get_proc_mesh(
self,
num_procs: int,
Expand Down Expand Up @@ -225,7 +236,7 @@ async def get_proc_mesh(
created_hosts = len(self._server_names)
mesh_name = f"alloc_{created_hosts}"
if host_mesh is None:
host_mesh, server_name = await self.create_host_mesh(
host_mesh, server_name = await self.create_host_mesh.call_one(
name=mesh_name,
num_hosts=num_hosts,
)
Expand Down Expand Up @@ -317,13 +328,15 @@ def bootstrap(env: dict[str, str]):
_ = await get_or_create_metric_logger(procs)
return procs

@endpoint
async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
    """Return the host mesh recorded for ``proc_mesh``.

    Raises:
        ValueError: If ``proc_mesh`` has no entry in ``_proc_host_map``.
    """
    host_by_proc = self._proc_host_map
    if proc_mesh in host_by_proc:
        return host_by_proc[proc_mesh]
    raise ValueError(
        "The proc mesh was not allocated with an associated hostmesh."
    )

@endpoint
async def stop_proc_mesh(self, proc_mesh: ProcMesh):
"""Stops a proc mesh."""
if proc_mesh not in self._proc_host_map:
Expand Down Expand Up @@ -351,6 +364,7 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
commands.kill(server_name)
del self._proc_host_map[proc_mesh]

@endpoint
def register_service(self, service: "ServiceInterface") -> None:
"""Registers a service allocation for cleanup."""
# Import ServiceInterface here instead of at top-level to avoid circular import
Expand All @@ -363,6 +377,7 @@ def register_service(self, service: "ServiceInterface") -> None:

self._registered_services.append(service)

@endpoint
def register_actor(self, actor: "ForgeActor") -> None:
"""Registers a single actor allocation for cleanup."""

Expand All @@ -371,6 +386,7 @@ def register_actor(self, actor: "ForgeActor") -> None:

self._registered_actors.append(actor)

@endpoint
async def shutdown_all_allocations(self):
"""Gracefully shut down all tracked actors and services."""
logger.info(
Expand All @@ -397,29 +413,46 @@ async def shutdown_all_allocations(self):
self._registered_actors.clear()
self._registered_services.clear()

@endpoint
async def shutdown(self):
    """Tears down all remaining remote allocations.

    Shuts down the tracked actors and services first, then kills every
    server recorded in ``_server_names`` while holding ``_lock``.
    """
    # Invoke the cleanup exactly once; the earlier direct call was left next
    # to the endpoint-style call, which would have run the cleanup twice.
    # NOTE(review): this dispatches to the actor's own endpoint from inside
    # another endpoint — confirm Monarch supports that self-call here.
    await self.shutdown_all_allocations.call_one()
    async with self._lock:
        for server_name in self._server_names:
            commands.kill(server_name)


_provisioner: Provisioner | None = None
_global_provisioner: Provisioner | None = None


async def get_or_create_provisioner(
    cfg: ProvisionerConfig | None = None,
) -> Provisioner:
    """Return this process's handle to the global Provisioner controller.

    On the first call the controller is obtained via
    ``get_or_spawn_controller`` under the name ``"provisioner_controller"``,
    its ``initialize`` endpoint is invoked, and the handle is cached in
    ``_global_provisioner``. Later calls return the cached handle and do not
    look at ``cfg``.

    Args:
        cfg: Configuration forwarded to ``Provisioner`` on the first call.

    Returns:
        The cached controller handle.
    """
    global _global_provisioner
    if _global_provisioner is not None:
        return _global_provisioner

    handle = await get_or_spawn_controller(
        "provisioner_controller", Provisioner, cfg
    )
    _global_provisioner = handle
    await handle.initialize.call_one()
    return _global_provisioner


async def init_provisioner(cfg: ProvisionerConfig | None = None):
    """Backward-compatible alias for ``get_or_create_provisioner``.

    Kept so existing imports of ``init_provisioner`` continue to work now that
    the provisioner is obtained as a controller actor instead of being
    constructed directly; new code should call ``get_or_create_provisioner``.

    Args:
        cfg: Configuration forwarded unchanged to ``get_or_create_provisioner``.

    Returns:
        The global provisioner handle.
    """
    return await get_or_create_provisioner(cfg)


async def _get_provisioner():
    """Return the global provisioner handle, creating it on first use."""
    return await get_or_create_provisioner()


async def get_proc_mesh(
Expand All @@ -444,7 +477,7 @@ async def get_proc_mesh(
A proc mesh.

"""
provisioner = await _get_provisioner()
provisioner = await get_or_create_provisioner()
return await provisioner.get_proc_mesh(
num_procs=process_config.procs,
with_gpus=process_config.with_gpus,
Expand All @@ -464,25 +497,25 @@ async def host_mesh_from_proc(proc_mesh: ProcMesh):
API.

"""
provisioner = await _get_provisioner()
return await provisioner.host_mesh_from_proc(proc_mesh)
provisioner = await get_or_create_provisioner()
return await provisioner.host_mesh_from_proc.call_one(proc_mesh)


async def register_service(service: "ServiceInterface") -> None:
    """Registers a service allocation with the global provisioner.

    Args:
        service: The service interface to hand to the provisioner's
            ``register_service`` endpoint.
    """
    provisioner = await get_or_create_provisioner()
    # `call_one` hands back a future; await it so the registration has been
    # processed (and any failure is raised here) before this function returns.
    await provisioner.register_service.call_one(service)


async def register_actor(actor: "ForgeActor") -> None:
    """Registers an actor allocation with the global provisioner.

    Args:
        actor: The actor to hand to the provisioner's ``register_actor``
            endpoint.
    """
    provisioner = await get_or_create_provisioner()
    # `call_one` hands back a future; await it so the registration has been
    # processed (and any failure is raised here) before this function returns.
    await provisioner.register_actor.call_one(actor)


async def stop_proc_mesh(proc_mesh: ProcMesh):
    """Stops ``proc_mesh`` through the global provisioner.

    Args:
        proc_mesh: The proc mesh to stop.

    Returns:
        Whatever the provisioner's ``stop_proc_mesh`` endpoint returns.
    """
    # Single lookup + single endpoint call: the pre-change `_get_provisioner`
    # path is gone, so only the controller-handle form remains.
    provisioner = await get_or_create_provisioner()
    return await provisioner.stop_proc_mesh.call_one(proc_mesh=proc_mesh)


async def shutdown_metric_logger():
Expand All @@ -503,8 +536,8 @@ async def shutdown():

logger.info("Shutting down provisioner..")

provisioner = await _get_provisioner()
result = await provisioner.shutdown()
provisioner = await get_or_create_provisioner()
result = await provisioner.shutdown.call_one()

logger.info("Shutdown completed successfully")
return result
4 changes: 2 additions & 2 deletions tests/integration_tests/test_policy_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from forge.actors.trainer import RLTrainer
from forge.cli.config import resolve_hf_hub_paths
from forge.controller.provisioner import init_provisioner
from forge.controller.provisioner import get_or_create_provisioner

from forge.controller.service.service import uuid
from forge.types import LauncherConfig, ProvisionerConfig
Expand Down Expand Up @@ -194,7 +194,7 @@ async def _setup_and_teardown(request):
logger.info(f"`trainer.use_dcp` is overriden to {use_dcp_override}")

if cfg.get("provisioner", None) is not None:
await init_provisioner(
await get_or_create_provisioner(
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
)
await ts.initialize(strategy=ts.ControllerStorageVolumes())
Expand Down
4 changes: 2 additions & 2 deletions tests/sandbox/rl_trainer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from forge.actors.trainer import RLTrainer
from forge.cli.config import parse
from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
from forge.controller.provisioner import init_provisioner, shutdown
from forge.controller.provisioner import get_or_create_provisioner, shutdown
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.perf_tracker import Tracer
from forge.types import (
Expand Down Expand Up @@ -164,7 +164,7 @@ async def main(cfg: DictConfig):
trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1)
dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1

await init_provisioner(
await get_or_create_provisioner(
ProvisionerConfig(
launcher_config=LauncherConfig(
launcher=cfg.get(LAUNCHER_KEY, Launcher.SLURM.value),
Expand Down
4 changes: 2 additions & 2 deletions tests/sandbox/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from forge.actors.generator import Generator
from forge.cli.config import parse

from forge.controller.provisioner import init_provisioner, shutdown
from forge.controller.provisioner import get_or_create_provisioner, shutdown

from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
Expand All @@ -29,7 +29,7 @@

async def run(cfg: DictConfig):
if cfg.get("provisioner", None) is not None:
await init_provisioner(
await get_or_create_provisioner(
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
)
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
Expand Down