diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 32e3039ad..145b6cd48 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import asyncio
+import logging
 import time
 from dataclasses import dataclass
 from typing import Callable
@@ -14,12 +15,15 @@
 from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.controller.actor import ForgeActor
-from forge.controller.service import ServiceConfig, spawn_service
+from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
 
 def compute_sequence_logprobs(
     model: torch.nn.Module,
@@ -314,18 +318,18 @@ async def forward(self, token_ids: list[int]) -> torch.Tensor:
 class DatasetActor(ForgeActor):
     """Actor wrapper for HuggingFace dataset to provide async interface."""
 
-    def __init__(self, *args, **kwargs):
+    def __init__(
+        self, path: str, config_name: str, split: str, streaming: bool, **kwargs
+    ):
         super().__init__()
-        self._setup_dataset(*args, **kwargs)
 
-    def _setup_dataset(self, *args, **kwargs):
         def gsm8k_to_messages(sample):
             question = sample["question"]
             full_answer: str = sample["answer"]
             answer = full_answer.split("#### ")[1]
             return {"question": question, "answer": answer}
 
-        ds = load_dataset(*args, **kwargs)
+        ds = load_dataset(path, config_name, split=split, streaming=streaming)
         ds = ds.map(gsm8k_to_messages)
         ds = ds.shuffle()
         self._iterator = iter(ds)
@@ -351,57 +355,62 @@ async def main():
     )
 
     # ---- Setup services ---- #
-    policy = await spawn_service(
-        ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
-        Policy,
-        PolicyConfig(
-            num_workers=1,
-            worker_params=WorkerConfig(model=model),
-            sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
+    (
+        dataloader,
+        policy,
+        trainer,
+        replay_buffer,
+        compute_advantages,
+        ref_model,
+        reward_actor,
+    ) = await asyncio.gather(
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            DatasetActor,
+            path="openai/gsm8k",
+            config_name="main",
+            split="train",
+            streaming=True,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+            Policy,
+            config=PolicyConfig(
+                worker_params=WorkerConfig(model=model),
+                sampling_params=SamplingOverrides(
+                    num_samples=group_size, max_tokens=16
+                ),
+            ),
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+            Trainer,
+            learning_rate=1e-5,
+            beta=0.1,
+            model_name=model,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ReplayBuffer,
+            batch_size=4,
+            max_policy_age=1,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ComputeAdvantages,
+            gamma=0.99,
+            lambda_=0.95,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
+            RefModel,
+            model_name=model,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            RewardActor,
+            reward_functions=[MathReward(), ThinkingReward()],
         ),
-    )
-
-    trainer = await spawn_service(
-        ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
-        Trainer,
-        learning_rate=1e-5,
-        beta=0.1,
-        model_name=model,
-    )
-
-    replay_buffer = await spawn_service(
-        ServiceConfig(procs_per_replica=1, num_replicas=1),
-        ReplayBuffer,
-        batch_size=4,
-        max_policy_age=1,
-    )
-
-    dataloader = await spawn_service(
-        ServiceConfig(procs_per_replica=1, num_replicas=1),
-        DatasetActor,
-        "openai/gsm8k",
-        "main",
-        split="train",
-        streaming=True,
-    )
-
-    compute_advantages = await spawn_service(
-        ServiceConfig(procs_per_replica=1, num_replicas=1),
-        ComputeAdvantages,
-        gamma=0.99,
-        lambda_=0.95,
-    )
-
-    ref_model = await spawn_service(
-        ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
-        RefModel,
-        model_name=model,
-    )
-
-    reward_actor = await spawn_service(
-        ServiceConfig(procs_per_replica=1, num_replicas=1),
-        RewardActor,
-        reward_functions=[MathReward(), ThinkingReward()],
     )
 
     print("All services initialized successfully!")
@@ -409,8 +418,6 @@ async def main():
     # ---- Core RL loops ---- #
     async def continuous_rollouts():
         rollout_count = 0
-        # TODO: Move this into setup
-        asyncio.create_task(policy.run_processing.call())
         while True:
             sample = await dataloader.__next__.choose()
             if sample is None:
@@ -481,6 +488,17 @@ async def continuous_training():
         print("Training interrupted by user")
         rollout_task.cancel()
         training_task.cancel()
+    finally:
+        print("Shutting down...")
+        await asyncio.gather(
+            shutdown_service(policy),
+            shutdown_service(trainer),
+            shutdown_service(replay_buffer),
+            shutdown_service(dataloader),
+            shutdown_service(compute_advantages),
+            shutdown_service(ref_model),
+            shutdown_service(reward_actor),
+        )
 
 
 if __name__ == "__main__":
diff --git a/apps/vllm/main.py b/apps/vllm/main.py
index 1e296b78c..2d3c81ad9 100644
--- a/apps/vllm/main.py
+++ b/apps/vllm/main.py
@@ -16,7 +16,7 @@
 from typing import List
 
 from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
-from forge.controller.service import ServiceConfig, spawn_service
+from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 
 from vllm.outputs import CompletionOutput
 
@@ -58,9 +58,11 @@ def parse_args() -> Namespace:
 
 
 def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
+
+    worker_size = 2
     worker_params = WorkerConfig(
         model=args.model,
-        tensor_parallel_size=2,
+        tensor_parallel_size=worker_size,
         pipeline_parallel_size=1,
         enforce_eager=True,
         vllm_args=None,
@@ -72,9 +74,11 @@ def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
     )
 
     policy_config = PolicyConfig(
-        num_workers=2, worker_params=worker_params, sampling_params=sampling_params
+        worker_params=worker_params, sampling_params=sampling_params
+    )
+    service_config = ServiceConfig(
+        procs_per_replica=worker_size, num_replicas=1, with_gpus=True
     )
-    service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)
 
     return policy_config, service_config
 
@@ -82,25 +86,22 @@ def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
 async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
     print("Spawning service...")
     policy = await spawn_service(service_config, Policy, config=config)
-    session_id = await policy.start_session()
 
-    print("Starting background processing...")
-    processing_task = asyncio.create_task(policy.run_processing.call())
+    async with policy.session():
+        print("Requesting generation...")
+        responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
 
-    print("Requesting generation...")
-    responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
+        print("\nGeneration Results:")
+        print("=" * 80)
+        for batch, response in enumerate(responses):
+            print(f"Sample {batch + 1}:")
+            print(f"User: {prompt}")
+            print(f"Assistant: {response.text}")
+            print("-" * 80)
 
-    print("\nGeneration Results:")
-    print("=" * 80)
-    for batch, response in enumerate(responses):
-        print(f"Sample {batch + 1}:")
-        print(f"User: {prompt}")
-        print(f"Assistant: {response.text}")
-        print("-" * 80)
+        print("\nShutting down...")
 
-    print("\nShutting down...")
-    await policy.shutdown.call()
-    await policy.terminate_session(session_id)
+    await shutdown_service(policy)
 
 
 if __name__ == "__main__":
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 5ac24c006..4a51f7225 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -9,11 +9,11 @@
 import os
 import sys
 from copy import copy
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from typing import Dict, List
 
 import torch
-from monarch.actor import Actor, current_rank, endpoint, proc_mesh
+from monarch.actor import current_rank, endpoint, ProcMesh
 from torchstore import MultiProcessStore
 
 from torchstore._state_dict_utils import DELIM
@@ -25,7 +25,7 @@
 from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import get_distributed_init_method, get_loopback_ip, get_open_port
+from vllm.utils import get_distributed_init_method
 from vllm.v1.core.kv_cache_utils import get_kv_cache_config
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -37,8 +37,11 @@
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase
 
+from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
+
 from forge.data.sharding import VLLMSharding
 from forge.interfaces import Policy as PolicyInterface
+from forge.types import ProcessConfig
 
 logger = logging.getLogger(__name__)
 
@@ -85,7 +88,6 @@ class WorkerConfig:
 
 @dataclass
 class PolicyConfig:
-    num_workers: int
     worker_params: WorkerConfig
     sampling_params: SamplingOverrides
     available_devices: str = None
@@ -95,24 +97,73 @@ class PolicyConfig:
 class Policy(PolicyInterface):
     config: PolicyConfig
     # Gets set up by setup
-    policy_worker: Actor = None
+    sampling_params: SamplingParams | None = None
+    lora_request: LoRARequest | None = None
+    tokenization_kwargs: dict = field(default_factory=dict)
+    policy_worker: "PolicyWorker" = None
+
+    def __post_init__(self):
+        self._run_task: asyncio.Task | None = None
+        self._policy_proc: ProcMesh | None = None
+        self._worker_procs: ProcMesh | None = None
+
+    @classmethod
+    async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
+        cls: type["Policy"],
+        *,
+        process_config: ProcessConfig,
+        config: PolicyConfig,
+        **kwargs,
+    ) -> "Policy":
+        # Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
+        # automatically.
+        worker_procs = await get_proc_mesh(process_config=process_config)
+
+        # TODO - we will want to ensure colocation with workers
+        policy_proc_config = copy(process_config)
+        policy_proc_config.num_procs = 1
+        policy_proc_config.with_gpus = False
+
+        policy_proc = await get_proc_mesh(process_config=policy_proc_config)
+        workers = await worker_procs.spawn(
+            "vllm_worker", PolicyWorker, **asdict(config.worker_params)
+        )
 
-    sampling_params: SamplingParams = None
-    lora_request: LoRARequest = None
-    tokenization_kwargs: dict = None
+        # TODO - expand support so name can stick within kwargs
+        actor_name = kwargs.pop("name", cls.__name__)
+        policy = await policy_proc.spawn(
+            actor_name, cls, config=config, policy_worker=workers
+        )
+        policy._policy_proc = policy_proc
+        policy._worker_procs = worker_procs
+        await policy.setup.call()
+        return policy
+
+    @classmethod
+    async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
+        cls: type["Policy"], actor: "Policy"
+    ):
+        assert (
+            actor._policy_proc is not None
+        ), "Tried to shutdown a policy that was not initialized correctly"
+        assert (
+            actor._worker_procs is not None
+        ), "Tried to shutdown a policy that was not initialized correctly"
+
+        # TODO - may want to expand stop to gracefully respond to
+        # ongoing requests.
+        await actor.stop.call()
+        await stop_proc_mesh(actor._worker_procs)
+        await stop_proc_mesh(actor._policy_proc)
 
     @endpoint
     async def setup(self):
         # Set up policy_worker
-        self.available_devices = (
-            self.config.available_devices
-            if self.config.available_devices is not None
-            else ",".join(str(i) for i in range(torch.cuda.device_count()))
-        )
-        await self.spawn_workers()
+        assert self.policy_worker is not None, "Policy worker should not be None"
+        await self.policy_worker.setup.call()
 
         self.request_id = 0
-        self.requests: Dict[str, Tuple[None | ParentRequest, asyncio.Future]] = {}
+        self.requests: Dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
 
         self.vllm_args = await self.policy_worker.get_vllm_args.choose()
 
         # Setup sampling params
@@ -157,20 +208,12 @@ async def setup(self):
             include_finished_set=False,
             log_stats=None,
         )
+        self.start_processing()
 
-    async def spawn_workers(self):
-        self.worker_mesh = await proc_mesh(
-            gpus=self.config.num_workers,
-            env={
-                "MASTER_ADDR": str(get_loopback_ip()),
-                "MASTER_PORT": str(get_open_port()),
-                "CUDA_VISIBLE_DEVICES": self.available_devices,
-            },
-        )
-        self.policy_worker = await self.worker_mesh.spawn(
-            "policy_worker", PolicyWorker, **asdict(self.config.worker_params)
-        )
-        await self.policy_worker.setup.call()
+    def start_processing(self):
+        """Start the replica's processing loop if not already running."""
+        if self._run_task is None or self._run_task.done():
+            self._run_task = asyncio.create_task(self.run())
 
     @endpoint
     async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]:
@@ -243,8 +286,7 @@ def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, i
         return request, 0  # Unused Arg: Current Wave
 
-    @endpoint
-    async def run_processing(self):
+    async def run(self):
         # TODO: add support for `iteration_stats`
         # TODO: move postprocessing out of loop to not block
         parallel_config = self.vllm_args.parallel_config
@@ -276,12 +318,12 @@ async def update_weights(self):
         pass
 
     @endpoint
-    async def shutdown(self):
+    async def stop(self):
         self.running = False
 
 
 @dataclass
-class PolicyWorker(Actor):
+class PolicyWorker(ForgeActor):
     model: str
     tensor_parallel_size: int = 1
     pipeline_parallel_size: int = 1
diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py
index b0269df41..71d35c433 100644
--- a/src/forge/controller/__init__.py
+++ b/src/forge/controller/__init__.py
@@ -3,8 +3,26 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
 from .actor import ForgeActor
-from .proc_mesh import get_proc_mesh, spawn_actors, stop_proc_mesh
+from .proc_mesh import get_proc_mesh, stop_proc_mesh
+
+
+# TODO - remove this once everything has moved to
+# service
+async def spawn_actors(
+    name: str,
+    actor_cls: ForgeActor,
+    cfg,
+    processes,
+    set_address: bool = False,
+):
+    """Setup process Mesh and spawn Actors."""
+    mesh = await get_proc_mesh(processes)
+    actors = await mesh.spawn(name, actor_cls, **cfg)
+    actors.mesh = mesh
+    return actors
+
 
 __all__ = [
     "spawn_actors",
diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py
index ed5d41c85..77ac324ee 100644
--- a/src/forge/controller/actor.py
+++ b/src/forge/controller/actor.py
@@ -9,7 +9,13 @@
 import math
 import sys
 
-from monarch.actor import Actor, current_rank, current_size
+from monarch.actor import Actor, current_rank, current_size, endpoint
+
+from forge.controller.proc_mesh import get_proc_mesh, stop_proc_mesh
+from forge.types import ProcessConfig
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
 
 
 class ForgeActor(Actor):
@@ -30,6 +36,56 @@ def __init__(self, *args, **kwargs):
         stdout_handler.setLevel(logging.INFO)
         stdout_handler.setFormatter(formatter)
 
+        self._proc_mesh = None
         self.logger.root.setLevel(logging.INFO)
         self.logger.root.addHandler(stdout_handler)
         super().__init__(*args, **kwargs)
+
+    @endpoint
+    async def setup(self):
+        """Sets up the actor.
+
+        We assume a specific setup function for all actors. The
+        best practice for actor deployment is to:
+        1. Pass all data to the actor via the constructor.
+        2. Call setup() for heavyweight initializations.
+
+        This is to ensure that any failures during initialization
+        can be propagated back to the caller.
+
+        """
+        pass
+
+    @classmethod
+    async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor":
+        """Provisions and deploys a new actor.
+
+        This method is used by `Service` to provision a new replica.
+
+        We implement it this way because special actors like inference servers
+        may be composed of multiple actors spawned across multiple processes.
+        This allows you to specify how your actor gets launched together.
+
+        This implementation is basic, assuming that we're spawning
+        a homogeneous set of actors on a single proc mesh.
+
+        """
+        proc_mesh = await get_proc_mesh(process_config=process_config)
+
+        # TODO - expand support so name can stick within kwargs
+        actor_name = kwargs.pop("name", cls.__name__)
+        actor = await proc_mesh.spawn(actor_name, cls, **kwargs)
+        actor._proc_mesh = proc_mesh
+
+        await actor.setup.call()
+        return actor
+
+    @classmethod
+    async def shutdown(cls, actor: "ForgeActor"):
+        """Shuts down an actor.
+
+        This method is used by `Service` to tear down a replica.
+ """ + if actor._proc_mesh is None: + raise AssertionError("Called shutdown on a replica with no proc_mesh.") + await stop_proc_mesh(actor._proc_mesh) diff --git a/src/forge/controller/proc_mesh.py b/src/forge/controller/proc_mesh.py index 8b18d4cad..d48769987 100644 --- a/src/forge/controller/proc_mesh.py +++ b/src/forge/controller/proc_mesh.py @@ -16,9 +16,6 @@ from monarch.actor import proc_mesh, ProcMesh from monarch.tools import commands from monarch.tools.config import Config -from omegaconf import DictConfig - -from forge.controller import ForgeActor from forge.controller.system_controllers.gpu_manager import get_gpu_ids, release_gpus from forge.types import ProcessConfig @@ -43,20 +40,6 @@ ) -async def spawn_actors( - name: str, - actor_cls: ForgeActor, - cfg: DictConfig, - processes: ProcessConfig, - set_address: bool = False, -): - """Setup process Mesh and spawn Actors.""" - mesh = await get_proc_mesh(processes) - actors = await mesh.spawn(name, actor_cls, **cfg) - actors.mesh = mesh - return actors - - async def get_proc_mesh(process_config: ProcessConfig) -> ProcMesh: """Returns a proc mesh with the given process config.""" # TODO - modify this to work with multi-host diff --git a/src/forge/controller/service/__init__.py b/src/forge/controller/service/__init__.py index e38adff0a..0fbce6a22 100644 --- a/src/forge/controller/service/__init__.py +++ b/src/forge/controller/service/__init__.py @@ -8,7 +8,7 @@ from .metrics import ServiceMetrics from .replica import Replica, ReplicaMetrics from .service import Service, ServiceConfig -from .spawn import spawn_service +from .spawn import shutdown_service, spawn_service __all__ = [ "Replica", @@ -20,4 +20,5 @@ "Session", "SessionContext", "spawn_service", + "shutdown_service", ] diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 3b6b4b687..32e8df468 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -150,14 +150,6 @@ def session(self) -> "SessionContext": """Returns a context manager for session-based calls.""" return SessionContext(self) - # Service control methods - forwarded to Service Actor - async def stop(self): - """Stops the service gracefully.""" - # First stop the service - await self._service.stop.call_one() - # Then stop its underlying proc - await self._proc_mesh.stop() - # Metrics methods - forwarded to Service Actor async def get_metrics(self): """Get comprehensive service metrics for monitoring and analysis.""" diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index 16fa3a866..b84e5eec7 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -13,9 +13,9 @@ from enum import Enum from typing import Optional -from monarch.actor import Actor, ActorError, ProcMesh +from monarch.actor import ActorError -from forge.controller import get_proc_mesh, stop_proc_mesh +from forge.controller import ForgeActor from forge.types import ProcessConfig logger = logging.getLogger(__name__) @@ -102,13 +102,11 @@ class Replica: # Configuration for the underlying ProcMesh (scheduler, hosts, GPUs) proc_config: ProcessConfig - actor_def: type[Actor] - actor_args: tuple + actor_def: type[ForgeActor] actor_kwargs: dict - # The proc_mesh and actor_mesh that this replica is running - proc_mesh: Optional[ProcMesh] = None - actor: Optional[Actor] = None + # The Actor that this replica is running + actor: Optional[ForgeActor] = None # Async queue for 
     request_queue: asyncio.Queue[ServiceRequest] = field(default_factory=asyncio.Queue)
@@ -155,23 +153,15 @@ async def initialize(self):
         - Transitions to healthy state
         - Starts the processing loop
         """
-        assert self.proc_mesh is None, "Proc mesh should not be set yet"
+        assert self.actor is None, "Actor should not be set yet"
         try:
-            # Create proc_mesh
-            await self.create_proc_mesh()
-
-            # Ensure we have a healthy proc_mesh
-            if not self.proc_mesh:
-                raise RuntimeError(
-                    f"Replica {self.idx}: proc_mesh is None after creation"
-                )
-
-            # Spawn the actor
-            await self.spawn_actor(
-                self.actor_def,
-                *self.actor_args,
+            # Deploy the actor and its underlying resources
+            logger.debug(f"Launching actor for replica {self.idx}")
+            self.actor = await self.actor_def.launch(
+                process_config=self.proc_config,
                 **self.actor_kwargs,
             )
+
             # Transition to healthy state and start processing
             self.state = ReplicaState.HEALTHY
             self.start_processing()
@@ -191,24 +181,17 @@ async def recover(self):
             return
 
         async def _do_recovery():
-            old_proc_mesh = self.proc_mesh
-            self.proc_mesh = None
-            self.actor = None
-
-            # Stop old proc_mesh if it exists
-            if old_proc_mesh is not None:
-                try:
-                    await stop_proc_mesh(old_proc_mesh)
-                    logger.debug(f"Old proc_mesh stopped for replica {self.idx}")
-                except Exception as e:
-                    logger.warning(
-                        f"Error stopping old proc_mesh for replica {self.idx}: {e}"
-                    )
+            try:
+                await self.actor_def.shutdown(self.actor)
+                self.actor = None
+            except Exception as e:
+                logger.warning(f"Error shutting down actor for replica {self.idx}: {e}")
+                self.state = ReplicaState.UNHEALTHY
 
+            # Re-create the actor
             try:
-                logger.debug(f"Creating new proc_mesh for replica {self.idx}")
+                logger.debug(f"Re-launching actor for replica {self.idx}")
                 await self.initialize()
-                logger.debug(f"Recovery completed successfully for replica {self.idx}")
             except Exception as e:
                 logger.error(f"Recovery failed for replica {self.idx}: {e}")
                 self.state = ReplicaState.UNHEALTHY
@@ -219,57 +202,6 @@ async def _do_recovery():
         self._recovery_task = asyncio.create_task(_do_recovery())
         await self._recovery_task
 
-    async def create_proc_mesh(self):
-        """Creates the proc_mesh using the stored proc_config."""
-        # TODO - for policy replica, we would override this method to
-        # include multiple proc_meshes
-        if self.proc_mesh is not None:
-            logger.warning(f"Proc mesh already initialized for replica {self.idx}")
-            return
-
-        logger.debug(f"Creating proc_mesh for replica {self.idx}")
-        try:
-            self.proc_mesh = await get_proc_mesh(process_config=self.proc_config)
-            logger.debug(f"Proc mesh created successfully for replica {self.idx}")
-        except Exception as e:
-            logger.error(f"Failed to create proc_mesh for replica {self.idx}: {e}")
-            self.state = ReplicaState.UNHEALTHY
-            raise
-
-    async def spawn_actor(self, actor_def, *actor_args, **actor_kwargs):
-        """
-        Spawn an actor on this replica's proc_mesh.
-
-        This method handles the complete actor spawning process including
-        recovery if the proc_mesh has failed.
- """ - if not self.proc_mesh: - raise RuntimeError( - f"Replica {self.idx}: proc_mesh is None after recovery attempt" - ) - - try: - # TODO - expand support so name can stick within kwargs - actor_name = actor_kwargs.pop("name", actor_def.__name__) - - # Spawn the actor - self.actor = await self.proc_mesh.spawn( - actor_name, - actor_def, - *actor_args, - **actor_kwargs, - ) - # Call actor setup if it exists - if setup_method := getattr(self.actor, "setup", None): - await setup_method.call() - - logger.debug(f"Actor spawned successfully on replica {self.idx}") - - except Exception as e: - logger.error(f"Failed to spawn actor on replica {self.idx}: {e}") - self.mark_failed() - raise - # Request handling / processing related functionality def start_processing(self): @@ -465,10 +397,10 @@ async def stop(self): len(failed_requests), ) - # Stop the proc_mesh - if self.proc_mesh: + # Stop the actor + if self.actor: try: - await stop_proc_mesh(self.proc_mesh) + await self.actor_def.shutdown(self.actor) except Exception as e: logger.warning( "Error stopping proc_mesh for replica %d: %s", self.idx, e diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 5afe4c35e..439d357dd 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -71,13 +71,10 @@ class Service(Actor): _endpoints: Dynamically registered actor endpoints """ - def __init__( - self, cfg: ServiceConfig, actor_def, actor_args: tuple, actor_kwargs: dict - ): + def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict): self._cfg = cfg self._replicas = [] self._actor_def = actor_def - self._actor_args = actor_args self._actor_kwargs = actor_kwargs self._active_sessions = [] @@ -106,7 +103,6 @@ async def __initialize__(self): max_concurrent_requests=self._cfg.replica_max_concurrent_requests, return_first_rank_result=self._cfg.return_first_rank_result, actor_def=self._actor_def, - actor_args=self._actor_args, actor_kwargs=self._actor_kwargs, ) replicas.append(replica) diff --git a/src/forge/controller/service/spawn.py b/src/forge/controller/service/spawn.py index 3b498b4f0..b8bd3b5bc 100644 --- a/src/forge/controller/service/spawn.py +++ b/src/forge/controller/service/spawn.py @@ -8,37 +8,53 @@ import logging from typing import Type -from monarch.actor import Actor, proc_mesh +from monarch.actor import proc_mesh +from forge.controller import ForgeActor from forge.controller.service import Service, ServiceConfig from forge.controller.service.interface import ServiceInterface logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) +logger.setLevel(logging.INFO) async def spawn_service( - service_cfg: ServiceConfig, actor_def: Type[Actor], *actor_args, **actor_kwargs + service_cfg: ServiceConfig, actor_def: Type[ForgeActor], **actor_kwargs ) -> ServiceInterface: """Spawns a service based on the actor class. 
 
     Args:
         service_cfg: Service configuration
        actor_def: Actor class definition
-        *actor_args: Arguments to pass to actor constructor
         **actor_kwargs: Keyword arguments to pass to actor constructor
 
     Returns:
         A ServiceInterface that provides access to the Service Actor
     """
+    # Assert that actor_def is a subclass of ForgeActor
+    if not issubclass(actor_def, ForgeActor):
+        raise TypeError(
+            f"actor_def must be a subclass of ForgeActor, got {type(actor_def).__name__}"
+        )
+
     # Create a single-node proc_mesh and actor_mesh for the Service Actor
     logger.info("Spawning Service Actor for %s", actor_def.__name__)
     m = await proc_mesh(gpus=1)
     service_actor = await m.spawn(
-        "service", Service, service_cfg, actor_def, actor_args, actor_kwargs
+        "service", Service, service_cfg, actor_def, actor_kwargs
     )
     await service_actor.__initialize__.call_one()
 
     # Return the ServiceInterface that wraps the proc_mesh, actor_mesh, and actor_def
     return ServiceInterface(m, service_actor, actor_def)
+
+
+async def shutdown_service(service: ServiceInterface) -> None:
+    """Shuts down the service.
+
+    Implemented in this way to avoid actors overriding stop() unintentionally.
+
+    """
+    await service._service.stop.call_one()
+    await service._proc_mesh.stop()
diff --git a/src/forge/controller/system_controllers/gpu_manager.py b/src/forge/controller/system_controllers/gpu_manager.py
index 66128ac88..cbb1bccda 100644
--- a/src/forge/controller/system_controllers/gpu_manager.py
+++ b/src/forge/controller/system_controllers/gpu_manager.py
@@ -8,15 +8,13 @@
 
 import logging
 
-from monarch.actor import ActorError, endpoint, get_or_spawn_controller
-
-from forge.controller import ForgeActor
+from monarch.actor import Actor, ActorError, endpoint, get_or_spawn_controller
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
 
-class GpuManager(ForgeActor):
+class GpuManager(Actor):
     """An actor that tracks and assigns GPU devices on given HostMeshes."""
 
     def __init__(self):
diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py
index 4bd2d4bbe..3dbbd560e 100644
--- a/src/forge/interfaces.py
+++ b/src/forge/interfaces.py
@@ -7,7 +7,9 @@
 from abc import ABC, abstractmethod
 from typing import Any, Mapping
 
-from monarch.actor import Actor, endpoint
+from monarch.actor import endpoint
+
+from forge.controller import ForgeActor
 
 from forge.types import Action, Message, Observation, Scalar, State
 
@@ -72,7 +74,7 @@ def _apply_transform(self, observation: Observation) -> Observation:
         return observation
 
 
-class Policy(Actor, ABC):
+class Policy(ForgeActor, ABC):
     """Abstract interface for policies."""
 
     @endpoint
diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py
index 150949468..06f4c40ce 100644
--- a/tests/unit_tests/test_service.py
+++ b/tests/unit_tests/test_service.py
@@ -12,14 +12,15 @@
 import logging
 
 import pytest
-from forge.controller.service import ServiceConfig, spawn_service
+from forge.controller import ForgeActor
+from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from monarch.actor import Actor, endpoint
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
 
-class Counter(Actor):
+class Counter(ForgeActor):
     """Test actor that maintains a counter with various endpoints."""
 
     def __init__(self, v: int):
@@ -57,6 +58,23 @@ async def add_to_value(self, amount: int, multiplier: int = 1) -> int:
 
 # Core Functionality Tests
 
+
+@pytest.mark.timeout(10)
+@pytest.mark.asyncio
+async def test_actor_def_type_validation():
+    """Test that spawn_service validates actor_def is a subclass of ForgeActor."""
+
+    # Service can only spawn ForgeActor subclasses
+    class InvalidActor(Actor):
+        def __init__(self):
+            pass
+
+    cfg = ServiceConfig(procs_per_replica=1, num_replicas=1)
+
+    # Test that TypeError is raised when actor_def is not a ForgeActor subclass
+    with pytest.raises(TypeError, match="actor_def must be a subclass of ForgeActor"):
+        await spawn_service(service_cfg=cfg, actor_def=InvalidActor)
+
+
 @pytest.mark.timeout(10)
 @pytest.mark.asyncio
 async def test_basic_service_operations():
@@ -86,7 +104,7 @@ async def test_basic_service_operations():
         assert session1 not in state["session_replica_map"]
 
     finally:
-        await service.stop()
+        await shutdown_service(service)
 
 
 @pytest.mark.timeout(10)
@@ -121,7 +139,7 @@ async def test_sessionless_calls():
         assert result == 11  # 1 + 10
 
     finally:
-        await service.stop()
+        await shutdown_service(service)
 
 
 @pytest.mark.timeout(15)
@@ -160,7 +178,7 @@ async def worker(increments: int):
         assert len(state["session_replica_map"]) == 0
 
     finally:
-        await service.stop()
+        await shutdown_service(service)
 
 
 # Fault Tolerance Tests
@@ -244,7 +262,7 @@ async def test_recovery_state_transitions():
         logger.info(f"Final replica state: {replica_state['state']}")
 
     finally:
-        await service.stop()
+        await shutdown_service(service)
 
 
 @pytest.mark.timeout(15)
@@ -285,7 +303,7 @@ async def test_replica_failure_and_recovery():
         assert assigned_replica["healthy"]
 
     finally:
-        await service.stop()
+        await shutdown_service(service)
 
 
 # Metrics and Monitoring Tests
@@ -338,7 +356,7 @@ async def test_metrics_collection():
         assert total_failed == 1
 
     finally:
-        await service.stop()
+        await shutdown_service(service)
 
 
 # Load Balancing and Session Management Tests
@@ -372,7 +390,7 @@ async def test_session_stickiness():
         assert result == 4
 
     finally:
-        await service.stop()
+        await shutdown_service(service)
 
 
 @pytest.mark.timeout(10)
@@ -421,7 +439,7 @@ async def test_load_balancing_multiple_sessions():
         assert total_requests == 4  # All requests processed
 
     finally:
-        await service.stop()
+        await shutdown_service(service)
 
 
 @pytest.mark.timeout(10)
@@ -461,7 +479,7 @@ async def test_concurrent_operations():
         assert total_requests == 4
 
     finally:
-        await service.stop()
+        await shutdown_service(service)
 
 
 # `call` endpoint tests
@@ -494,7 +512,7 @@ async def test_broadcast_call_basic():
         assert all(value == 11 for value in values)
 
     finally:
-        await service.stop()
+        await shutdown_service(service)
 
 
 @pytest.mark.timeout(15)
@@ -533,7 +551,7 @@ async def test_broadcast_call_with_failed_replica():
         assert all(value == 1 for value in values)
 
     finally:
-        await service.stop()
+        await shutdown_service(service)
 
 
 @pytest.mark.timeout(10)
@@ -570,4 +588,4 @@ async def test_broadcast_call_vs_choose():
         assert total_requests == 10
 
     finally:
-        await service.stop()
+        await shutdown_service(service)
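
Usage sketch of the new service lifecycle (not part of the patch): the snippet below shows how an actor is expected to flow through `spawn_service` -> `ForgeActor.launch` -> `setup()` -> `shutdown_service` after this change. The `spawn_service`/`shutdown_service` signatures, `ServiceConfig` fields, `.choose()` routing, and `session()` context manager are taken from the diff above; `EchoActor` and its `echo` endpoint are hypothetical, introduced only for illustration.

import asyncio

from forge.controller import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from monarch.actor import endpoint


class EchoActor(ForgeActor):
    """Hypothetical actor, shown only to illustrate the new lifecycle."""

    def __init__(self, prefix: str):
        # Lightweight construction only; per the ForgeActor docstring, heavy
        # initialization belongs in setup(), which launch() calls after spawning.
        super().__init__()
        self.prefix = prefix

    @endpoint
    async def echo(self, msg: str) -> str:
        return f"{self.prefix}{msg}"


async def main():
    # spawn_service now takes keyword-only actor arguments (actor_args was
    # removed) and rejects actor classes that are not ForgeActor subclasses.
    service = await spawn_service(
        ServiceConfig(procs_per_replica=1, num_replicas=2),
        EchoActor,
        prefix="> ",
    )
    try:
        # Sessionless call: routed to one healthy replica.
        print(await service.echo.choose(msg="hello"))
        # Session-based calls: sticky to a single replica for the duration.
        async with service.session():
            print(await service.echo.choose(msg="again"))
    finally:
        # Replaces the old ServiceInterface.stop(): stops the Service actor
        # and its proc mesh, tearing down each replica via actor_def.shutdown().
        await shutdown_service(service)


if __name__ == "__main__":
    asyncio.run(main())

Note the design choice this exercises: replica provisioning and teardown are now class-level hooks (`launch`/`shutdown`) on the actor definition itself, so composite actors such as `Policy` can own multiple proc meshes while `Replica` stays agnostic about how its actor is deployed.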