Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .meta/mast/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
MastLauncher,
mount_mnt_directory,
)
from forge.controller.provisioner import init_provisioner
from forge.controller.provisioner import get_or_create_provisioner

from forge.types import (
Launcher,
Expand Down Expand Up @@ -68,7 +68,9 @@ async def main(cfg: DictConfig, mode: str = "detached", extra_args: list = None)
else:
# In remote mode, we're already running inside MAST, so mount directory, init provisioner and run training
mount_mnt_directory("/mnt/wsfuse")
await init_provisioner(ProvisionerConfig(launcher_config=launcher_config))
await get_or_create_provisioner(
ProvisionerConfig(launcher_config=launcher_config)
)
await grpo_main(cfg)


Expand Down
8 changes: 4 additions & 4 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import RLTrainer
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import init_provisioner, shutdown
from forge.controller.provisioner import get_or_create_provisioner, shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
Expand Down Expand Up @@ -298,11 +298,11 @@ async def main(cfg: DictConfig):
# ---- Global setups ---- #
provisioner = None
if cfg.get("provisioner", None) is not None:
provisioner = await init_provisioner(
provisioner = await get_or_create_provisioner(
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
)
else:
provisioner = await init_provisioner()
provisioner = await get_or_create_provisioner()

metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger(process_name="Controller")
Expand Down Expand Up @@ -346,7 +346,7 @@ async def main(cfg: DictConfig):
# TODO: support multiple host meshes
trainer_num_procs = cfg.actors.trainer["procs"]
trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
trainer_hosts = provisioner.get_host_mesh.call_one(trainer_host_mesh_name)
await ts.initialize(
mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
strategy=ts.LocalRankStrategy(),
Expand Down
Loading
Loading