From dd2cfa856c7940568b0623a8faf988bcf6eb17ac Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Thu, 21 Aug 2025 12:43:44 -0700 Subject: [PATCH 1/9] initial commit for replica --- src/forge/controller/replica.py | 289 +++++++++++ src/forge/controller/service_v2.py | 794 +++++++++++++++++++++++++++++ src/forge/controller/spawn.py | 25 + tests/test_service.py | 4 +- 4 files changed, 1110 insertions(+), 2 deletions(-) create mode 100644 src/forge/controller/replica.py create mode 100644 src/forge/controller/service_v2.py diff --git a/src/forge/controller/replica.py b/src/forge/controller/replica.py new file mode 100644 index 000000000..543bdf0c8 --- /dev/null +++ b/src/forge/controller/replica.py @@ -0,0 +1,289 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Replica for distributed actor service.""" + +import asyncio +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + +from monarch.actor import Actor, ActorError + +from forge.controller import RecoverableProcMesh + +logger = logging.getLogger(__name__) + + +class ReplicaState(Enum): + HEALTHY = "healthy" + RECOVERING = "recovering" + UNHEALTHY = "unhealthy" + STOPPED = "stopped" + UNINITIALIZED = "uninitialized" + + +@dataclass +class ServiceRequest: + session_id: Optional[str] + function: str + args: tuple + kwargs: dict + future: asyncio.Future + + +@dataclass +class Replica: + proc_mesh: RecoverableProcMesh + actor: Optional[Actor] + idx: int + request_queue: asyncio.Queue[ServiceRequest] = field(default_factory=asyncio.Queue) + active_requests: int = 0 + max_concurrent_requests: int = 10 + _running: bool = False + metadata: dict = field(default_factory=dict) + state: ReplicaState = ReplicaState.UNINITIALIZED + return_first_rank_result: bool = False + + async def enqueue_request(self, request: ServiceRequest): + """Enqueues a request for processing by this replica.""" + if self.state == ReplicaState.STOPPED: + raise RuntimeError(f"Replica {self.idx} is stopped") + + # Accept requests in all other states - let the processing loop handle the rest + await self.request_queue.put(request) + + async def _process_single_request(self, request: ServiceRequest) -> bool: + """ + Processes a single request and returns success status. + + Returns: + bool: True if request succeeded, False if it failed + """ + self.active_requests += 1 + + try: + # Get the actor and endpoint + actor = self.actor + endpoint_func = getattr(actor, request.function) + + # Execute the request + success = True + try: + result = await endpoint_func.call(*request.args, **request.kwargs) + # Unwrap ValueMesh if configured to return first rank result + if ( + self.return_first_rank_result + and hasattr(result, "_values") + and result._values + ): + result = result._values[0] + request.future.set_result(result) + except ActorError as e: + logger.debug("Got failure on replica %d. Error:\n%s", self.idx, e) + # Mark proc_mesh as failed and transition state + self.proc_mesh.mark_failed() + self.state = ReplicaState.RECOVERING + # Unwrap the ActorError into its raw exception + request.future.set_exception(e.exception) + success = False + except Exception as e: + logger.debug( + "Got unexpected error on replica %d. Error:\n%s", self.idx, e + ) + # Mark proc_mesh as failed and transition state + self.proc_mesh.mark_failed() + self.state = ReplicaState.RECOVERING + request.future.set_exception(e) + success = False + + # Mark task as done + self.request_queue.task_done() + return success + + finally: + self.active_requests -= 1 + + async def run(self): + """ + Main processing loop for the replica. This replaces _persistent_processor. + + Continuously processes requests from the queue while the replica is healthy. + Handles capacity management and graceful degradation on failures. + """ + self._running = True + + try: + while self.state in (ReplicaState.HEALTHY, ReplicaState.RECOVERING): + try: + # Wait for a request with timeout to check health periodically + request = await asyncio.wait_for( + self.request_queue.get(), timeout=1.0 + ) + + # Check if we have capacity + if self.active_requests >= self.max_concurrent_requests: + # Put the request back and wait + await self.request_queue.put(request) + await asyncio.sleep(0.1) + continue + + # Update state if proc_mesh recovered + if self.state == ReplicaState.RECOVERING and self.proc_mesh.healthy: + self.state = ReplicaState.HEALTHY + logger.debug("Replica %d recovered to healthy state", self.idx) + + # If we're still recovering and proc_mesh isn't healthy, reject request + if ( + self.state == ReplicaState.RECOVERING + and not self.proc_mesh.healthy + ): + request.future.set_exception( + RuntimeError(f"Replica {self.idx} is still recovering") + ) + self.request_queue.task_done() + continue + + # Process the request + asyncio.create_task(self._process_single_request(request)) + + except asyncio.TimeoutError: + # No requests, check for health state changes + if self.state == ReplicaState.RECOVERING and self.proc_mesh.healthy: + self.state = ReplicaState.HEALTHY + logger.debug("Replica %d recovered to healthy state", self.idx) + elif ( + self.state == ReplicaState.HEALTHY + and not self.proc_mesh.healthy + ): + self.state = ReplicaState.RECOVERING + logger.debug("Replica %d entered recovering state", self.idx) + continue + + except Exception as e: + logger.error( + "Error in replica %d processing loop: %s", + self.idx, + e, + ) + self.state = ReplicaState.UNHEALTHY + break + + finally: + self._running = False + logger.debug("Replica %d stopped processing", self.idx) + + @property + def healthy(self) -> bool: + return self.state == ReplicaState.HEALTHY + + @property + def load(self) -> int: + """Get current load (active requests + queue depth)""" + return self.active_requests + self.request_queue.qsize() + + @property + def capacity_utilization(self) -> float: + """Get current capacity utilization (0.0 to 1.0)""" + if self.max_concurrent_requests <= 0: + return 0.0 + return self.active_requests / self.max_concurrent_requests + + def can_accept_request(self) -> bool: + """Check if replica can accept a new request""" + return ( + self.state == ReplicaState.HEALTHY + and self.active_requests < self.max_concurrent_requests + ) + + def __repr__(self) -> str: + return ( + f"Replica(idx={self.idx}, state={self.state.value}, " + f"active={self.active_requests}/{self.max_concurrent_requests}, " + f"queue={self.request_queue.qsize()})" + ) + + async def setup(self): + """ + Sets up the replica and transitions to healthy state. + + This should be called after the proc_mesh has been initialized + and the actor has been spawned on it. + """ + if self.state != ReplicaState.UNINITIALIZED: + logger.warning( + "Attempting to setup replica %d that's already initialized", self.idx + ) + return + + if self.actor is None: + raise RuntimeError(f"Cannot setup replica {self.idx}: actor is None") + + try: + # Call actor setup if it exists + if hasattr(self.actor, "setup"): + await self.actor.setup.call() + + # Transition to healthy state + self.state = ReplicaState.HEALTHY + logger.debug("Replica %d setup complete", self.idx) + + except Exception as e: + logger.error("Failed to setup replica %d: %s", self.idx, e) + self.state = ReplicaState.UNHEALTHY + raise + + async def stop(self): + """ + Stops the replica gracefully. + + Transitions to STOPPED state, stops the processing loop, and cleans up. + Fails any remaining requests in the queue. + """ + logger.debug("Stopping replica %d", self.idx) + + # Transition to stopped state to signal the run loop to exit + self.state = ReplicaState.STOPPED + + # Wait for processor to finish if it's running + if self._running: + # Give it a moment to finish current request and exit gracefully + for _ in range(50): # Wait up to 5 seconds + if not self._running: + break + await asyncio.sleep(0.1) + + if self._running: + logger.warning("Replica %d processor didn't stop gracefully", self.idx) + + # Fail any remaining requests in the queue + failed_requests = [] + while not self.request_queue.empty(): + try: + request = self.request_queue.get_nowait() + failed_requests.append(request) + self.request_queue.task_done() + except asyncio.QueueEmpty: + break + + # Fail all the collected requests + for request in failed_requests: + if not request.future.done(): + request.future.set_exception( + RuntimeError(f"Replica {self.idx} is stopping") + ) + + logger.debug( + "Replica %d stopped, failed %d remaining requests", + self.idx, + len(failed_requests), + ) + + # Stop the proc_mesh + try: + await self.proc_mesh.stop() + except Exception as e: + logger.warning("Error stopping proc_mesh for replica %d: %s", self.idx, e) diff --git a/src/forge/controller/service_v2.py b/src/forge/controller/service_v2.py new file mode 100644 index 000000000..2fe8b4ed9 --- /dev/null +++ b/src/forge/controller/service_v2.py @@ -0,0 +1,794 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +""" +Distributed Actor Service Controller + +This module provides a robust service orchestration system for managing distributed +actor-based workloads with automatic scaling, fault tolerance, and intelligent load balancing. + +The main Service class acts as a singleton controller that handles: +- Fault tolerance with automatic replica recovery +- Autoscaling based on real-time metrics +- Load balancing across healthy replicas +- Session management with context propagation +- Comprehensive metrics collection and monitoring + +Example: + Basic service setup: + + >>> config = ServiceConfig( + ... gpus_per_replica=1, + ... num_replicas=3 + ... ) + >>> service = Service(config, MyActorClass, *args, **kwargs) + >>> await service.__initialize__() + + Session-based usage: + + >>> async with service.session(): + ... result = await service.my_endpoint(arg1, arg2) +""" + + +import asyncio +import contextvars +import logging +import pprint +import time +import uuid +from collections import defaultdict, deque +from dataclasses import dataclass, field +from typing import Any, Callable, Coroutine, Dict, List, Optional + +from monarch._src.actor.endpoint import EndpointProperty +from monarch.actor import ProcMesh + +from forge.controller import RecoverableProcMesh +from forge.controller.replica import Replica, ServiceRequest +from forge.types import ServiceConfig + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +# TODO - tie this into metric logger when it exists +@dataclass +class ReplicaMetrics: + """ + Metrics collection for a single replica instance. + + Tracks request counts, timing metrics, current state, and session assignments + for performance monitoring and autoscaling decisions. + + Attributes: + replica_idx: Unique identifier for this replica + total_requests: Total number of requests processed + successful_requests: Number of successfully completed requests + failed_requests: Number of failed requests + request_times: Sliding window of request start timestamps + request_latencies: Sliding window of request completion latencies + active_requests: Currently processing requests + queue_depth: Number of pending requests in queue + assigned_sessions: Number of sessions assigned to this replica + """ + + replica_idx: int + # Request metrics + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + # Timing metrics (sliding window) + request_times: deque = field(default_factory=lambda: deque(maxlen=100)) + request_latencies: deque = field(default_factory=lambda: deque(maxlen=100)) + # Current state + active_requests: int = 0 + queue_depth: int = 0 + # Session metrics + assigned_sessions: int = 0 + + def add_request_start(self, timestamp: float): + """Record when a request starts processing.""" + self.request_times.append(timestamp) + self.total_requests += 1 + + def add_request_completion(self, start_time: float, success: bool): + """Record when a request completes.""" + latency = time.time() - start_time + self.request_latencies.append(latency) + if success: + self.successful_requests += 1 + else: + self.failed_requests += 1 + + def get_request_rate(self, window_seconds: float = 60.0) -> float: + """Get requests per second over the last window_seconds.""" + now = time.time() + cutoff = now - window_seconds + recent_requests = [t for t in self.request_times if t >= cutoff] + return len(recent_requests) / window_seconds if window_seconds > 0 else 0.0 + + def get_avg_latency(self, window_requests: int = 50) -> float: + """Get average latency over the last N requests.""" + if not self.request_latencies: + return 0.0 + recent_latencies = list(self.request_latencies)[-window_requests:] + return sum(recent_latencies) / len(recent_latencies) + + def get_capacity_utilization(self, max_concurrent: int) -> float: + """Get current capacity utilization (0.0 to 1.0).""" + return self.active_requests / max_concurrent if max_concurrent > 0 else 0.0 + + +@dataclass +class ServiceMetrics: + """ + Aggregated metrics collection for the entire service. + + Provides service-wide visibility into performance, health, and scaling metrics + by aggregating data from all replica instances. + + Attributes: + replica_metrics: Per-replica metrics indexed by replica ID + total_sessions: Number of active sessions across all replicas + healthy_replicas: Number of currently healthy replicas + total_replicas: Total number of replicas (healthy + unhealthy) + last_scale_event: Timestamp of the last scaling operation + """ + + # Replica metrics + replica_metrics: Dict[int, ReplicaMetrics] = field(default_factory=dict) + # Service-level metrics + total_sessions: int = 0 + healthy_replicas: int = 0 + total_replicas: int = 0 + # Time-based metrics + last_scale_event: float = 0.0 + + def get_total_request_rate(self, window_seconds: float = 60.0) -> float: + """Get total requests per second across all replicas.""" + return sum( + metrics.get_request_rate(window_seconds) + for metrics in self.replica_metrics.values() + ) + + def get_avg_queue_depth(self) -> float: + """Get average queue depth across all healthy replicas.""" + healthy_metrics = [ + m + for m in self.replica_metrics.values() + if m.replica_idx < self.healthy_replicas + ] + if not healthy_metrics: + return 0.0 + return sum(m.queue_depth for m in healthy_metrics) / len(healthy_metrics) + + def get_avg_capacity_utilization(self, replicas: List) -> float: + """Get average capacity utilization across all healthy replicas.""" + healthy_replicas = [r for r in replicas if r.healthy] + if not healthy_replicas: + return 0.0 + + utilizations = [] + for replica in healthy_replicas: + if replica.idx in self.replica_metrics: + metrics = self.replica_metrics[replica.idx] + utilization = metrics.get_capacity_utilization( + replica.max_concurrent_requests + ) + utilizations.append(utilization) + + return sum(utilizations) / len(utilizations) if utilizations else 0.0 + + def get_sessions_per_replica(self) -> float: + """Get average sessions per healthy replica.""" + if self.healthy_replicas == 0: + return 0.0 + return self.total_sessions / self.healthy_replicas + + +@dataclass +class Session: + session_id: str + + +# Global context variable for session state +# This is used to propagate session state across async tasks +_session_context: contextvars.ContextVar[dict | None] = contextvars.ContextVar( + "session_context", default=None +) + + +class SessionContext: + """Context manager for service sessions using context variables.""" + + def __init__(self, service: "Service", **session_kwargs): + self.service = service + self.session_id: str | None = None + self.session_kwargs = session_kwargs + self._token = None + + async def __aenter__(self): + """Start a session and set context variables.""" + self.session_id = await self.service.start_session() + # Set context for this async task + context_value = {"session_id": self.session_id, "kwargs": self.session_kwargs} + self._token = _session_context.set(context_value) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Terminate the session and restore context.""" + if self._token: + _session_context.reset(self._token) + if self.session_id: + await self.service.terminate_session(self.session_id) + self.session_id = None + + +class Service: + """ + Distributed Actor Service Controller + + A sophisticated service orchestration system that manages multiple replicas of actor-based + services with automatic scaling, fault tolerance, and intelligent load balancing. + + The Service acts as a unified interface for distributed workloads, automatically handling: + - **Fault Tolerance**: Health monitoring, automatic replica recovery, request migration + - **Load Balancing**: Round-robin, least-loaded, and session-affinity routing + - **Session Management**: Stateful session handling with context propagation + - **Metrics Collection**: Comprehensive performance and health monitoring + + Args: + cfg: Service configuration including number of replicas, GPUs per replica, and health polling rate + actor_def: Actor class definition to instantiate on each replica + *actor_args: Positional arguments passed to actor constructor + **actor_kwargs: Keyword arguments passed to actor constructor + + Example: + Basic setup with autoscaling: + + >>> config = ServiceConfig( + ... gpus_per_replica=1, + ... num_replicas=3, + ... ) + >>> service = Service(config, MyActorClass, model_path="/path/to/model") + >>> await service.__initialize__() + + Session-based usage: + + >>> async with service.session(): + ... result1 = await service.my_endpoint(arg1, arg2) + ... result2 = await service.another_endpoint(arg3) + + Stateless usage: + + >>> result = await service.my_endpoint(arg1, arg2) # Uses round-robin + + Attributes: + _cfg: Service configuration + _replicas: List of managed replica instances + _active_sessions: Currently active sessions + _metrics: Aggregated service and replica metrics + _endpoints: Dynamically registered actor endpoints + """ + + def __init__(self, cfg: ServiceConfig, actor_def, *actor_args, **actor_kwargs): + self._cfg = cfg + self._replicas = [] + self._actor_def = actor_def + self._actor_args = actor_args + self._actor_kwargs = actor_kwargs + + self._active_sessions = [] + self._id_session_map = {} + self._session_replica_map: Dict[str, int] = {} + self._next_replica_idx = 0 # For round-robin load balancing + + # Initialize metrics collection + self._metrics = ServiceMetrics() + self._health_task = None + self._shutdown_requested = False + + # Replica initialization queue + self._replicas_to_init = [] + + # For all endpoints within the actor_def, create an interface from it + self._endpoints = [] + for func_name in dir(actor_def): + func = getattr(actor_def, func_name) + if isinstance(func, EndpointProperty): + logger.debug("Registering endpoint %s", func_name) + self._endpoints.append(func_name) + # Dynamically add this endpoint method to the Service class + self._add_endpoint_method(func_name) + + async def __initialize__(self): + logger.debug("Starting service up with %d replicas.", self._cfg.num_replicas) + replicas = [] + num_replicas = self._cfg.num_replicas + for i in range(num_replicas): + mesh = RecoverableProcMesh(proc_config=self._cfg.to_process_config()) + replica = Replica( + proc_mesh=mesh, + actor=None, + idx=len(self._replicas) + i, + max_concurrent_requests=self._cfg.replica_max_concurrent_requests, + return_first_rank_result=self._cfg.return_first_rank_result, + ) + replicas.append(replica) + + # Initializing should only happen in the health_loop + # and during the first initialization. + # If multiple parts of the code try to initialize replicas at + # the same time, it can cause nasty race conditions + # (e.g., double initialization, inconsistent state, or resource conflicts). + # By funneling all replica initialization through a single queue and the + # health loop, we ensure safe, serialized initialization. + logger.debug( + "Queued %d replicas for initialization. Total replicas: %d", + num_replicas, + len(self._replicas), + ) + self._replicas_to_init.extend(replicas) + await self._maybe_init_replicas() + self._replicas.extend(replicas) + + # Start the health loop in the background + self._health_task = asyncio.create_task( + self._health_loop(poll_rate_s=self._cfg.health_poll_rate) + ) + + def _add_endpoint_method(self, endpoint_name: str): + """Dynamically adds an endpoint method to this Service instance.""" + + async def endpoint_method(sess_id: str | None = None, *args, **kwargs): + return await self._call(sess_id, endpoint_name, *args, **kwargs) + + # Set the method on this instance + setattr(self, endpoint_name, endpoint_method) + + async def _call(self, sess_id: str | None, function: str, *args, **kwargs): + """ + Routes a function call to the appropriate replica with load balancing and fault tolerance. + + This is the core routing method that handles: + - Session-based routing for stateful calls + - Round-robin load balancing for stateless calls + - Custom routing based on context hints + - Automatic retry on replica failures + - Request queuing and processing + + Args: + sess_id: Optional session ID for stateful routing + function: Name of the actor endpoint to call + *args: Positional arguments to pass to the endpoint + **kwargs: Keyword arguments to pass to the endpoint + + Returns: + The result from the actor endpoint execution + + Raises: + RuntimeError: If no healthy replicas are available + Exception: Any exception raised by the actor endpoint + """ + # Check context variables for session state if no explicit sess_id + if sess_id is None: + ctx = _session_context.get() + if ctx: + sess_id = ctx["session_id"] + routing_hints = ctx["kwargs"] + else: + routing_hints = {} + else: + routing_hints = {} + + replica = await self._get_replica(sess_id, **routing_hints) + + # Create a ServiceRequest object to queue + request = ServiceRequest( + session_id=sess_id, + function=function, + args=args, + kwargs=kwargs, + future=asyncio.Future(), + ) + + # Queue the request using replica's method + await replica.enqueue_request(request) + + # Start the replica processing loop if not already running + if not replica._running: + asyncio.create_task(replica.run()) + + # Wait for the result + try: + return await request.future + except Exception as e: + # If the replica failed, try to retry once + if not replica.healthy: + logger.debug( + "Replica %d failed during request, retrying on healthy replica", + replica.idx, + ) + return await self._retry_request_on_healthy_replica( + sess_id, function, *args, **kwargs + ) + raise + + async def _retry_request_on_healthy_replica( + self, sess_id: str | None, function: str, *args, **kwargs + ): + """Retries a failed request on a healthy replica.""" + # Force reassignment to a healthy replica (only for session-based calls) + if sess_id is not None and sess_id in self._session_replica_map: + del self._session_replica_map[sess_id] + + # Retry the call (this will assign to a new healthy replica) + return await self._call(sess_id, function, *args, **kwargs) + + async def _migrate_remaining_requests(self, failed_replica: Replica): + """Migrates remaining requests from a failed replica to healthy replicas.""" + migrated_requests = [] + + # Collect all remaining requests + while not failed_replica.request_queue.empty(): + try: + request = failed_replica.request_queue.get_nowait() + migrated_requests.append(request) + except asyncio.QueueEmpty: + break + + if not migrated_requests: + return + + logger.debug( + "Migrating %d requests from failed replica %d", + len(migrated_requests), + failed_replica.idx, + ) + + # Find healthy replicas + healthy_replicas = [ + r for r in self._replicas if r.healthy and r != failed_replica + ] + + if not healthy_replicas: + # No healthy replicas, fail all requests + for request in migrated_requests: + request.future.set_exception( + RuntimeError("No healthy replicas available") + ) + return + + # Distribute requests among healthy replicas + for i, request in enumerate(migrated_requests): + target_replica = healthy_replicas[i % len(healthy_replicas)] + await target_replica.enqueue_request(request) + + # Start replica processing if not running + if not target_replica._running: + asyncio.create_task(target_replica.run()) + + # Update session mapping if needed + sess_id = request.session_id + if ( + sess_id in self._session_replica_map + and self._session_replica_map[sess_id] == failed_replica.idx + ): + self._session_replica_map[sess_id] = target_replica.idx + + async def start_session(self) -> str: + """ + Starts a new session for stateful request handling. + + Sessions enable request affinity to specific replicas, maintaining state + consistency for workloads that require it. Each session gets a unique ID + and is automatically assigned to the least loaded replica. + + Returns: + str: Unique session identifier for use in subsequent requests + + Example: + >>> session_id = await service.start_session() + >>> result = await service.my_endpoint(session_id, arg1, arg2) + >>> await service.terminate_session(session_id) + """ + sess_id = str(uuid.uuid4()) + session = Session(session_id=sess_id) + self._active_sessions.append(session) + + # Update metrics + self._update_service_metrics() + + return sess_id + + def session(self, **kwargs) -> SessionContext: + """Returns a context manager for session-based calls.""" + return SessionContext(self, **kwargs) + + def _update_service_metrics(self): + """Updates service-level metrics.""" + self._metrics.total_sessions = len(self._active_sessions) + self._metrics.total_replicas = len(self._replicas) + self._metrics.healthy_replicas = sum(1 for r in self._replicas if r.healthy) + + # Update queue depths for all replicas + for replica in self._replicas: + if replica.idx not in self._metrics.replica_metrics: + self._metrics.replica_metrics[replica.idx] = ReplicaMetrics(replica.idx) + + replica_metrics = self._metrics.replica_metrics[replica.idx] + replica_metrics.queue_depth = replica.request_queue.qsize() + replica_metrics.active_requests = replica.active_requests + + # Update session assignments per replica + session_counts = defaultdict(int) + for sess_id, replica_idx in self._session_replica_map.items(): + session_counts[replica_idx] += 1 + + for replica_idx, count in session_counts.items(): + if replica_idx in self._metrics.replica_metrics: + self._metrics.replica_metrics[replica_idx].assigned_sessions = count + + def get_metrics(self) -> ServiceMetrics: + """ + Get comprehensive service metrics for monitoring and analysis. + + Returns detailed metrics including per-replica performance data, + service-wide aggregations, and health status information. + + Returns: + ServiceMetrics: Complete metrics object with replica and service data + + Example: + >>> metrics = service.get_metrics() + >>> print(f"Request rate: {metrics.get_total_request_rate():.1f} req/s") + >>> print(f"Queue depth: {metrics.get_avg_queue_depth():.1f}") + """ + self._update_service_metrics() + return self._metrics + + def get_metrics_summary(self) -> dict: + """ + Get a summary of key metrics for monitoring and debugging. + + Provides a structured summary of service and replica metrics in a format + suitable for monitoring dashboards, logging, or debugging purposes. + + Returns: + dict: Structured metrics summary with service and per-replica data + + Example: + >>> summary = service.get_metrics_summary() + >>> print(f"Healthy replicas: {summary['service']['healthy_replicas']}") + >>> for idx, metrics in summary['replicas'].items(): + ... print(f"Replica {idx}: {metrics['request_rate']:.1f} req/s") + """ + self._update_service_metrics() + + summary = { + "service": { + "total_sessions": self._metrics.total_sessions, + "healthy_replicas": self._metrics.healthy_replicas, + "total_replicas": self._metrics.total_replicas, + "total_request_rate": self._metrics.get_total_request_rate(), + "avg_queue_depth": self._metrics.get_avg_queue_depth(), + "avg_capacity_utilization": self._metrics.get_avg_capacity_utilization( + self._replicas + ), + "sessions_per_replica": self._metrics.get_sessions_per_replica(), + }, + "replicas": {}, + } + + for replica_idx, metrics in self._metrics.replica_metrics.items(): + summary["replicas"][replica_idx] = { + "total_requests": metrics.total_requests, + "successful_requests": metrics.successful_requests, + "failed_requests": metrics.failed_requests, + "request_rate": metrics.get_request_rate(), + "avg_latency": metrics.get_avg_latency(), + "active_requests": metrics.active_requests, + "queue_depth": metrics.queue_depth, + "assigned_sessions": metrics.assigned_sessions, + "capacity_utilization": metrics.get_capacity_utilization(10), + } + + return summary + + async def terminate_session(self, sess_id: str): + """ + Terminates an active session and cleans up associated resources. + + Removes the session from active tracking, clears replica assignments, + and updates service metrics. Sessions should be terminated when no + longer needed to free up resources. + + Args: + sess_id: The unique session identifier to terminate + + Example: + >>> session_id = await service.start_session() + >>> # ... use session for requests ... + >>> await service.terminate_session(session_id) + """ + logger.debug("Terminating session %s", sess_id) + + # Remove from active sessions + self._active_sessions = [ + s for s in self._active_sessions if s.session_id != sess_id + ] + + # Remove from session-replica mapping + if sess_id in self._session_replica_map: + del self._session_replica_map[sess_id] + + # Update metrics + self._update_service_metrics() + + async def _health_loop(self, poll_rate_s: float): + """Runs the health loop to monitor and recover replicas. + + This loop continuously checks the health of replicas and recovers + failed replicas by reinitializing their proc_meshes. It also + periodically updates service metrics to reflect the current state. + + """ + while not self._shutdown_requested: + # Process any replicas that need initialization + await self._maybe_init_replicas() + + # Check for failed replicas and recover them + failed_replicas = [] + for replica in self._replicas: + if replica.proc_mesh.failed: + failed_replicas.append(replica) + + if any(failed_replicas): + logger.debug( + "[HEALTH LOOP] Detected %d failed replicas: %s", + len(failed_replicas), + pprint.pformat(failed_replicas), + ) + self._replicas_to_init.extend(failed_replicas) + + await asyncio.sleep(poll_rate_s) + + async def _custom_replica_routing( + self, sess_id: str | None, **kwargs + ) -> Optional[Replica]: + """Hook for custom routing logic. Override in subclasses to implement custom routing.""" + return None + + def _get_next_replica(self) -> "Replica": + """Get the next replica using round-robin selection.""" + healthy_replicas = [r for r in self._replicas if r.healthy] + if not healthy_replicas: + raise RuntimeError("No healthy replicas available for load balancing") + + # Simple round-robin + self._next_replica_idx = (self._next_replica_idx + 1) % len(healthy_replicas) + return healthy_replicas[self._next_replica_idx] + + def _get_least_loaded_replica(self) -> "Replica": + """Get the replica with the lowest load.""" + healthy_replicas = [r for r in self._replicas if r.healthy] + if not healthy_replicas: + raise RuntimeError("No healthy replicas available for session assignment") + + # Load = active_requests + queue_depth + def get_load(replica: "Replica") -> int: + return replica.active_requests + replica.request_queue.qsize() + + return min(healthy_replicas, key=get_load) + + async def _get_replica(self, sess_id: str | None, **kwargs) -> "Replica": + """Get a replica for the given session ID, with optional custom routing hints.""" + # Try custom routing first if hints are provided + if kwargs: + custom_result = await self._custom_replica_routing(sess_id, **kwargs) + if custom_result is not None: + return custom_result + + # Default routing logic + if sess_id is None: + # No session, use round-robin load balancing + replica = self._get_next_replica() + return replica + + # Session-based routing + if sess_id in self._session_replica_map: + replica_idx = self._session_replica_map[sess_id] + # Find the replica with this index + for replica in self._replicas: + if replica.idx == replica_idx and replica.healthy: + return replica + # If the replica is no longer healthy, remove from session map and reassign + del self._session_replica_map[sess_id] + + # New session, assign to least loaded replica + replica = self._get_least_loaded_replica() + self._session_replica_map[sess_id] = replica.idx + logger.debug("Assigning session %s to replica %d", sess_id, replica.idx) + return replica + + async def stop(self): + logger.debug("Stopping service...") + # Signal shutdown to health loop + self._shutdown_requested = True + + # Wait for health loop to finish gracefully + if self._health_task is not None: + try: + await asyncio.wait_for(self._health_task, timeout=5.0) + logger.info("Health loop stopped gracefully.") + except asyncio.TimeoutError: + logger.warning("Health loop didn't stop gracefully, cancelling...") + self._health_task.cancel() + try: + await self._health_task + except asyncio.CancelledError: + logger.info("Health loop task cancelled.") + + # Stop all replicas using their stop method + await asyncio.gather( + *[replica.stop() for replica in self._replicas], + return_exceptions=True, + ) + + async def _maybe_init_replicas(self): + """Initializes replicas that are queued for initialization.""" + if not self._replicas_to_init: + return + + logger.debug("Init replicas: %s", pprint.pformat(self._replicas_to_init)) + + def _recover_hook( + replica: Replica, + ) -> Callable[[ProcMesh], Coroutine[Any, Any, None]]: + async def inner_hook(proc_mesh: ProcMesh) -> None: + if "name" in self._actor_kwargs: + actor_name = self._actor_kwargs.pop("name") + else: + actor_name = self._actor_def.__name__ + # TODO - expand support so name can stick within kwargs + actor = await proc_mesh.spawn( + actor_name, + self._actor_def, + *self._actor_args, + **self._actor_kwargs, + ) + replica.actor = actor + # Use replica's setup method instead of inline setup + await replica.setup() + + return inner_hook + + await asyncio.gather( + *[ + replica.proc_mesh.spawn(_recover_hook(replica)) + for replica in self._replicas_to_init + ] + ) + self._replicas_to_init.clear() + + async def _migrate_replica_workload(self, replica_to_remove: Replica): + """Migrates all workload from a replica that's being removed.""" + # Migrate queued requests + await self._migrate_remaining_requests(replica_to_remove) + + # Reassign sessions to other replicas + sessions_to_reassign = [ + sess_id + for sess_id, replica_idx in self._session_replica_map.items() + if replica_idx == replica_to_remove.idx + ] + + for sess_id in sessions_to_reassign: + del self._session_replica_map[sess_id] + logger.debug("Session %s will be reassigned on next request", sess_id) + + def __repr__(self): + return f"Service(actor={self._actor_def.__name__})" diff --git a/src/forge/controller/spawn.py b/src/forge/controller/spawn.py index fe0512277..40b47eb5d 100644 --- a/src/forge/controller/spawn.py +++ b/src/forge/controller/spawn.py @@ -11,6 +11,10 @@ from monarch.actor import Actor from forge.controller import Service, ServiceConfig +from forge.controller.service_v2 import ( + Service as ServiceV2, + ServiceConfig as ServiceConfigV2, +) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -35,3 +39,24 @@ async def spawn_service( service = Service(service_cfg, actor_def, *actor_args, **actor_kwargs) await service.__initialize__() return service + + +async def spawn_service_v2( + service_cfg: ServiceConfigV2, actor_def: Type[Actor], *actor_args, **actor_kwargs +) -> ServiceV2: + """Spawns a service based on the actor class. + + Args: + service_cfg: Service configuration + actor_def: Actor class definition + *actor_args: Arguments to pass to actor constructor + **actor_kwargs: Keyword arguments to pass to actor constructor + + Returns: + The appropriate service type based on the actor class + """ + # Default to base Service + logger.info("Spawning base Service for %s", actor_def.__name__) + service = ServiceV2(service_cfg, actor_def, *actor_args, **actor_kwargs) + await service.__initialize__() + return service diff --git a/tests/test_service.py b/tests/test_service.py index 7283aeeee..0898de605 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -12,8 +12,8 @@ import logging import pytest -from forge.controller.service import ServiceConfig -from forge.controller.spawn import spawn_service +from forge.controller.service_v2 import ServiceConfig +from forge.controller.spawn import spawn_service_v2 as spawn_service from monarch.actor import Actor, endpoint logger = logging.getLogger(__name__) From d3677ba0fed2e86c3863b24f3adebb88c1c538d9 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Fri, 22 Aug 2025 08:23:57 -0700 Subject: [PATCH 2/9] clean up --- src/forge/controller/replica.py | 396 ++++++++++++++++++++++------- src/forge/controller/service_v2.py | 223 ++++------------ tests/test_service.py | 5 +- 3 files changed, 365 insertions(+), 259 deletions(-) diff --git a/src/forge/controller/replica.py b/src/forge/controller/replica.py index 543bdf0c8..a8cb89644 100644 --- a/src/forge/controller/replica.py +++ b/src/forge/controller/replica.py @@ -7,27 +7,79 @@ import asyncio import logging +import time +from collections import deque from dataclasses import dataclass, field from enum import Enum from typing import Optional -from monarch.actor import Actor, ActorError +from monarch.actor import Actor, ActorError, ProcMesh -from forge.controller import RecoverableProcMesh +from forge.controller import get_proc_mesh +from forge.types import ProcessConfig logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) class ReplicaState(Enum): - HEALTHY = "healthy" - RECOVERING = "recovering" - UNHEALTHY = "unhealthy" - STOPPED = "stopped" - UNINITIALIZED = "uninitialized" + HEALTHY = "HEALTHY" + RECOVERING = "RECOVERING" + UNHEALTHY = "UNHEALTHY" + STOPPED = "STOPPED" + UNINITIALIZED = "UNINITIALIZED" + + +@dataclass +class ReplicaMetrics: + """Simple metrics tracking for a replica.""" + + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + request_times: deque = field(default_factory=lambda: deque(maxlen=100)) + request_latencies: deque = field(default_factory=lambda: deque(maxlen=100)) + + def add_request_start(self, timestamp: float): + """Records when a request starts processing.""" + self.request_times.append(timestamp) + self.total_requests += 1 + + def add_request_completion(self, start_time: float, success: bool): + """Records when a request completes.""" + latency = time.time() - start_time + self.request_latencies.append(latency) + if success: + self.successful_requests += 1 + else: + self.failed_requests += 1 + + def get_request_rate(self, window_seconds: float = 60.0) -> float: + """Gets requests per second over the last window_seconds.""" + now = time.time() + cutoff = now - window_seconds + recent_requests = [t for t in self.request_times if t >= cutoff] + return len(recent_requests) / window_seconds if window_seconds > 0 else 0.0 + + def get_avg_latency(self, window_requests: int = 50) -> float: + """Gets average latency over the last N requests.""" + if not self.request_latencies: + return 0.0 + recent_latencies = list(self.request_latencies)[-window_requests:] + return sum(recent_latencies) / len(recent_latencies) @dataclass class ServiceRequest: + """Representation of a request to the service. + + A service request will typically be a call to an actor endpoint. + - The endpoint call is represented by function str/args/kwargs, + - The session_id is used for stateful routing, and + - The future is used to return the result of the call. + + """ + session_id: Optional[str] function: str args: tuple @@ -37,34 +89,170 @@ class ServiceRequest: @dataclass class Replica: - proc_mesh: RecoverableProcMesh - actor: Optional[Actor] + """ + A distributed replica that serves as the fundamental unit of work within a service. + + Handles process lifecycle, async request queuing, fault recovery, and load balancing. + Each replica runs independently and can be deployed across multiple hosts via Monarch + + """ + idx: int + + # Configuration for the underlying ProcMesh (scheduler, hosts, GPUs) + proc_config: ProcessConfig + + # The proc_mesh and actor_mesh that this replica is running + proc_mesh: Optional[ProcMesh] = None + actor: Optional[Actor] = None + + # Async queue for incoming requests request_queue: asyncio.Queue[ServiceRequest] = field(default_factory=asyncio.Queue) + # Number of currently processing requests active_requests: int = 0 + # Maximum number of simultaneous requests max_concurrent_requests: int = 10 + # Whether the processing loop is currently running _running: bool = False - metadata: dict = field(default_factory=dict) + # How often to check for new requests when idle + _run_poll_rate_s: float = 1.0 + # Current replica health state state: ReplicaState = ReplicaState.UNINITIALIZED + # Whether to auto-unwrap ValueMesh to first rank return_first_rank_result: bool = False + # Recovery-related state + _recovery_task: Optional[asyncio.Task] = None + + # Run task is the replica's event loop + _run_task: Optional[asyncio.Task] = None + + # Metrics tracking + metrics: ReplicaMetrics = field(default_factory=ReplicaMetrics) + + # Initialization related functionalities + + async def init_proc_mesh(self): + """Initializes the proc_mesh using the stored proc_config.""" + # TODO - for policy replica, we would override this method to + # include multiple proc_meshes + if self.proc_mesh is not None: + logger.warning("Proc mesh already initialized for replica %d", self.idx) + return + + logger.debug("Initializing proc_mesh for replica %d", self.idx) + try: + self.proc_mesh = await get_proc_mesh(process_config=self.proc_config) + logger.debug("Proc mesh initialized successfully for replica %d", self.idx) + except Exception as e: + logger.error( + "Failed to initialize proc_mesh for replica %d: %s", self.idx, e + ) + self.state = ReplicaState.UNHEALTHY + raise + + async def spawn_actor(self, actor_def, *actor_args, **actor_kwargs): + """ + Spawn an actor on this replica's proc_mesh. + + This method handles the complete actor spawning process including + recovery if the proc_mesh has failed. + """ + # Ensure we have a healthy proc_mesh + await self._ensure_healthy_proc_mesh() + + if not self.proc_mesh: + raise RuntimeError( + f"Replica {self.idx}: proc_mesh is None after recovery attempt" + ) + + try: + # Determine actor name + if "name" in actor_kwargs: + actor_name = actor_kwargs.pop("name") + else: + actor_name = actor_def.__name__ + + # Spawn the actor + self.actor = await self.proc_mesh.spawn( + actor_name, + actor_def, + *actor_args, + **actor_kwargs, + ) + + # Call setup if it exists + await self.setup() + + logger.debug("Actor spawned successfully on replica %d", self.idx) + + except Exception as e: + logger.error("Failed to spawn actor on replica %d: %s", self.idx, e) + self.mark_failed() + raise + + async def setup(self): + """ + Sets up the replica and transitions to healthy state. + + This should be called after the proc_mesh has been initialized + and the actor has been spawned on it. + """ + if self.state != ReplicaState.UNINITIALIZED: + logger.warning( + "Attempting to setup replica %d that's already initialized", self.idx + ) + return + + if self.actor is None: + raise RuntimeError(f"Cannot setup replica {self.idx}: actor is None") + + try: + # Call actor setup if it exists + if hasattr(self.actor, "setup"): + # TODO - should this be a standard in our Forge Actor(s)? + await self.actor.setup.call() + + # Transition to healthy state and start processing + self.state = ReplicaState.HEALTHY + self.start_processing() + logger.debug("Replica %d setup complete", self.idx) + + except Exception as e: + logger.error("Failed to setup replica %d: %s", self.idx, e) + self.state = ReplicaState.UNHEALTHY + raise + + # Request handling / processing related functionality + + def start_processing(self): + """Start the replica's processing loop if not already running.""" + if self._run_task is None or self._run_task.done(): + self._run_task = asyncio.create_task(self.run()) + logger.debug("Started processing loop for replica %d", self.idx) + async def enqueue_request(self, request: ServiceRequest): """Enqueues a request for processing by this replica.""" if self.state == ReplicaState.STOPPED: - raise RuntimeError(f"Replica {self.idx} is stopped") + raise RuntimeError( + f"Replica {self.idx} is stopped and therefore will not accept requests." + ) # Accept requests in all other states - let the processing loop handle the rest await self.request_queue.put(request) async def _process_single_request(self, request: ServiceRequest) -> bool: - """ - Processes a single request and returns success status. + """Processes a single request and returns success status. Returns: bool: True if request succeeded, False if it failed """ + start_time = time.time() self.active_requests += 1 + # Record request start for metrics + self.metrics.add_request_start(start_time) + try: # Get the actor and endpoint actor = self.actor @@ -83,23 +271,30 @@ async def _process_single_request(self, request: ServiceRequest) -> bool: result = result._values[0] request.future.set_result(result) except ActorError as e: - logger.debug("Got failure on replica %d. Error:\n%s", self.idx, e) - # Mark proc_mesh as failed and transition state - self.proc_mesh.mark_failed() - self.state = ReplicaState.RECOVERING - # Unwrap the ActorError into its raw exception - request.future.set_exception(e.exception) + logger.warning("Got failure on replica %d. Error:\n%s", self.idx, e) + # The exception came from the actor. It itself is + # returned to be propagated through the services + # back to the caller. + request.future.set_result(e.exception) + + # TODO: we may want to conditionally mark the + # replica as failed here - i.e. where the actor itself + # can be healthy but the request failed. + self.mark_failed() success = False except Exception as e: logger.debug( "Got unexpected error on replica %d. Error:\n%s", self.idx, e ) - # Mark proc_mesh as failed and transition state - self.proc_mesh.mark_failed() - self.state = ReplicaState.RECOVERING + self.mark_failed() + + # The exception was not from the actor - in this case + # we will signal back to the service (through set_exception) + # to retry on another healthy node. request.future.set_exception(e) success = False + self.metrics.add_request_completion(start_time, success) # Mark task as done self.request_queue.task_done() return success @@ -108,8 +303,7 @@ async def _process_single_request(self, request: ServiceRequest) -> bool: self.active_requests -= 1 async def run(self): - """ - Main processing loop for the replica. This replaces _persistent_processor. + """Runs the main processing loop for the replica. Continuously processes requests from the queue while the replica is healthy. Handles capacity management and graceful degradation on failures. @@ -121,26 +315,19 @@ async def run(self): try: # Wait for a request with timeout to check health periodically request = await asyncio.wait_for( - self.request_queue.get(), timeout=1.0 + self.request_queue.get(), timeout=self._run_poll_rate_s ) - # Check if we have capacity + # Check if we have capacity - if we have too many ongoing, + # we will put the request back and wait. if self.active_requests >= self.max_concurrent_requests: - # Put the request back and wait await self.request_queue.put(request) await asyncio.sleep(0.1) continue - # Update state if proc_mesh recovered - if self.state == ReplicaState.RECOVERING and self.proc_mesh.healthy: - self.state = ReplicaState.HEALTHY - logger.debug("Replica %d recovered to healthy state", self.idx) - - # If we're still recovering and proc_mesh isn't healthy, reject request - if ( - self.state == ReplicaState.RECOVERING - and not self.proc_mesh.healthy - ): + # If we're recovering, reject the request + if self.state == ReplicaState.RECOVERING: + # This signals to the service to retry on another replica request.future.set_exception( RuntimeError(f"Replica {self.idx} is still recovering") ) @@ -151,16 +338,7 @@ async def run(self): asyncio.create_task(self._process_single_request(request)) except asyncio.TimeoutError: - # No requests, check for health state changes - if self.state == ReplicaState.RECOVERING and self.proc_mesh.healthy: - self.state = ReplicaState.HEALTHY - logger.debug("Replica %d recovered to healthy state", self.idx) - elif ( - self.state == ReplicaState.HEALTHY - and not self.proc_mesh.healthy - ): - self.state = ReplicaState.RECOVERING - logger.debug("Replica %d entered recovering state", self.idx) + # No requests, just continue checking for new ones continue except Exception as e: @@ -176,63 +354,70 @@ async def run(self): self._running = False logger.debug("Replica %d stopped processing", self.idx) + # Replica state management + @property def healthy(self) -> bool: return self.state == ReplicaState.HEALTHY @property - def load(self) -> int: - """Get current load (active requests + queue depth)""" - return self.active_requests + self.request_queue.qsize() + def failed(self) -> bool: + """Check if the replica has failed and needs recovery.""" + return self.state in (ReplicaState.RECOVERING, ReplicaState.UNHEALTHY) - @property - def capacity_utilization(self) -> float: - """Get current capacity utilization (0.0 to 1.0)""" - if self.max_concurrent_requests <= 0: - return 0.0 - return self.active_requests / self.max_concurrent_requests + def mark_failed(self): + """Mark the replica as failed, triggering recovery.""" + logger.debug("Marking replica %d as failed", self.idx) + self.state = ReplicaState.RECOVERING - def can_accept_request(self) -> bool: - """Check if replica can accept a new request""" - return ( - self.state == ReplicaState.HEALTHY - and self.active_requests < self.max_concurrent_requests - ) - - def __repr__(self) -> str: - return ( - f"Replica(idx={self.idx}, state={self.state.value}, " - f"active={self.active_requests}/{self.max_concurrent_requests}, " - f"queue={self.request_queue.qsize()})" - ) + async def _ensure_healthy_proc_mesh(self): + """Ensure we have a healthy proc_mesh, recovering if necessary.""" + if self.failed: + await self._recover() - async def setup(self): + async def _recover(self): """ - Sets up the replica and transitions to healthy state. + Recover the replica by recreating the proc_mesh and respawning actors. - This should be called after the proc_mesh has been initialized - and the actor has been spawned on it. + This is the core recovery logic moved from RecoverableProcMesh. """ - if self.state != ReplicaState.UNINITIALIZED: - logger.warning( - "Attempting to setup replica %d that's already initialized", self.idx - ) + if self._recovery_task and not self._recovery_task.done(): + # Recovery already in progress, wait for it + await self._recovery_task return - if self.actor is None: - raise RuntimeError(f"Cannot setup replica {self.idx}: actor is None") + logger.debug("Starting recovery for replica %d", self.idx) + self.state = ReplicaState.RECOVERING - try: - # Call actor setup if it exists - if hasattr(self.actor, "setup"): - await self.actor.setup.call() + # Create the recovery task + self._recovery_task = asyncio.create_task(self._do_recovery()) + await self._recovery_task + + async def _do_recovery(self): + """Internal method that performs the actual recovery work.""" + old_proc_mesh = self.proc_mesh + self.proc_mesh = None + self.actor = None + + # Stop old proc_mesh if it exists + if old_proc_mesh is not None: + try: + await old_proc_mesh.stop() + logger.debug("Old proc_mesh stopped for replica %d", self.idx) + except Exception as e: + logger.warning( + "Error stopping old proc_mesh for replica %d: %s", self.idx, e + ) - # Transition to healthy state + # Create new proc_mesh + try: + logger.debug("Creating new proc_mesh for replica %d", self.idx) + self.proc_mesh = await get_proc_mesh(process_config=self.proc_config) self.state = ReplicaState.HEALTHY - logger.debug("Replica %d setup complete", self.idx) + logger.debug("Recovery completed successfully for replica %d", self.idx) except Exception as e: - logger.error("Failed to setup replica %d: %s", self.idx, e) + logger.error("Recovery failed for replica %d: %s", self.idx, e) self.state = ReplicaState.UNHEALTHY raise @@ -283,7 +468,38 @@ async def stop(self): ) # Stop the proc_mesh - try: - await self.proc_mesh.stop() - except Exception as e: - logger.warning("Error stopping proc_mesh for replica %d: %s", self.idx, e) + if self.proc_mesh: + try: + await self.proc_mesh.stop() + except Exception as e: + logger.warning( + "Error stopping proc_mesh for replica %d: %s", self.idx, e + ) + + # Metric-related getters + + @property + def load(self) -> int: + """Get current load (active requests + queue depth)""" + return self.active_requests + self.request_queue.qsize() + + @property + def capacity_utilization(self) -> float: + """Get current capacity utilization (0.0 to 1.0)""" + if self.max_concurrent_requests <= 0: + return 0.0 + return self.active_requests / self.max_concurrent_requests + + def can_accept_request(self) -> bool: + """Check if replica can accept a new request""" + return ( + self.state == ReplicaState.HEALTHY + and self.active_requests < self.max_concurrent_requests + ) + + def __repr__(self) -> str: + return ( + f"Replica(idx={self.idx}, state={self.state.value}, " + f"active={self.active_requests}/{self.max_concurrent_requests}, " + f"queue={self.request_queue.qsize()})" + ) diff --git a/src/forge/controller/service_v2.py b/src/forge/controller/service_v2.py index 2fe8b4ed9..bd9c6c791 100644 --- a/src/forge/controller/service_v2.py +++ b/src/forge/controller/service_v2.py @@ -37,91 +37,20 @@ import contextvars import logging import pprint -import time import uuid -from collections import defaultdict, deque from dataclasses import dataclass, field -from typing import Any, Callable, Coroutine, Dict, List, Optional +from typing import Dict, List from monarch._src.actor.endpoint import EndpointProperty -from monarch.actor import ProcMesh -from forge.controller import RecoverableProcMesh -from forge.controller.replica import Replica, ServiceRequest +from forge.controller.replica import Replica, ReplicaMetrics, ServiceRequest from forge.types import ServiceConfig logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -# TODO - tie this into metric logger when it exists -@dataclass -class ReplicaMetrics: - """ - Metrics collection for a single replica instance. - - Tracks request counts, timing metrics, current state, and session assignments - for performance monitoring and autoscaling decisions. - - Attributes: - replica_idx: Unique identifier for this replica - total_requests: Total number of requests processed - successful_requests: Number of successfully completed requests - failed_requests: Number of failed requests - request_times: Sliding window of request start timestamps - request_latencies: Sliding window of request completion latencies - active_requests: Currently processing requests - queue_depth: Number of pending requests in queue - assigned_sessions: Number of sessions assigned to this replica - """ - - replica_idx: int - # Request metrics - total_requests: int = 0 - successful_requests: int = 0 - failed_requests: int = 0 - # Timing metrics (sliding window) - request_times: deque = field(default_factory=lambda: deque(maxlen=100)) - request_latencies: deque = field(default_factory=lambda: deque(maxlen=100)) - # Current state - active_requests: int = 0 - queue_depth: int = 0 - # Session metrics - assigned_sessions: int = 0 - - def add_request_start(self, timestamp: float): - """Record when a request starts processing.""" - self.request_times.append(timestamp) - self.total_requests += 1 - - def add_request_completion(self, start_time: float, success: bool): - """Record when a request completes.""" - latency = time.time() - start_time - self.request_latencies.append(latency) - if success: - self.successful_requests += 1 - else: - self.failed_requests += 1 - - def get_request_rate(self, window_seconds: float = 60.0) -> float: - """Get requests per second over the last window_seconds.""" - now = time.time() - cutoff = now - window_seconds - recent_requests = [t for t in self.request_times if t >= cutoff] - return len(recent_requests) / window_seconds if window_seconds > 0 else 0.0 - - def get_avg_latency(self, window_requests: int = 50) -> float: - """Get average latency over the last N requests.""" - if not self.request_latencies: - return 0.0 - recent_latencies = list(self.request_latencies)[-window_requests:] - return sum(recent_latencies) / len(recent_latencies) - - def get_capacity_utilization(self, max_concurrent: int) -> float: - """Get current capacity utilization (0.0 to 1.0).""" - return self.active_requests / max_concurrent if max_concurrent > 0 else 0.0 - - +# TODO - tie this into metrics logger when it exists. @dataclass class ServiceMetrics: """ @@ -154,16 +83,13 @@ def get_total_request_rate(self, window_seconds: float = 60.0) -> float: for metrics in self.replica_metrics.values() ) - def get_avg_queue_depth(self) -> float: + def get_avg_queue_depth(self, replicas: List) -> float: """Get average queue depth across all healthy replicas.""" - healthy_metrics = [ - m - for m in self.replica_metrics.values() - if m.replica_idx < self.healthy_replicas - ] - if not healthy_metrics: + healthy_replicas = [r for r in replicas if r.healthy] + if not healthy_replicas: return 0.0 - return sum(m.queue_depth for m in healthy_metrics) / len(healthy_metrics) + total_queue_depth = sum(r.request_queue.qsize() for r in healthy_replicas) + return total_queue_depth / len(healthy_replicas) def get_avg_capacity_utilization(self, replicas: List) -> float: """Get average capacity utilization across all healthy replicas.""" @@ -171,16 +97,8 @@ def get_avg_capacity_utilization(self, replicas: List) -> float: if not healthy_replicas: return 0.0 - utilizations = [] - for replica in healthy_replicas: - if replica.idx in self.replica_metrics: - metrics = self.replica_metrics[replica.idx] - utilization = metrics.get_capacity_utilization( - replica.max_concurrent_requests - ) - utilizations.append(utilization) - - return sum(utilizations) / len(utilizations) if utilizations else 0.0 + total_utilization = sum(r.capacity_utilization for r in healthy_replicas) + return total_utilization / len(healthy_replicas) def get_sessions_per_replica(self) -> float: """Get average sessions per healthy replica.""" @@ -309,10 +227,8 @@ async def __initialize__(self): replicas = [] num_replicas = self._cfg.num_replicas for i in range(num_replicas): - mesh = RecoverableProcMesh(proc_config=self._cfg.to_process_config()) replica = Replica( - proc_mesh=mesh, - actor=None, + proc_config=self._cfg.to_process_config(), idx=len(self._replicas) + i, max_concurrent_requests=self._cfg.replica_max_concurrent_requests, return_first_rank_result=self._cfg.return_first_rank_result, @@ -398,10 +314,6 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): # Queue the request using replica's method await replica.enqueue_request(request) - # Start the replica processing loop if not already running - if not replica._running: - asyncio.create_task(replica.run()) - # Wait for the result try: return await request.future @@ -467,10 +379,6 @@ async def _migrate_remaining_requests(self, failed_replica: Replica): target_replica = healthy_replicas[i % len(healthy_replicas)] await target_replica.enqueue_request(request) - # Start replica processing if not running - if not target_replica._running: - asyncio.create_task(target_replica.run()) - # Update session mapping if needed sess_id = request.session_id if ( @@ -513,24 +421,11 @@ def _update_service_metrics(self): self._metrics.total_sessions = len(self._active_sessions) self._metrics.total_replicas = len(self._replicas) self._metrics.healthy_replicas = sum(1 for r in self._replicas if r.healthy) - - # Update queue depths for all replicas + # Store direct references to replica metrics for aggregation + self._metrics.replica_metrics = {} for replica in self._replicas: - if replica.idx not in self._metrics.replica_metrics: - self._metrics.replica_metrics[replica.idx] = ReplicaMetrics(replica.idx) - - replica_metrics = self._metrics.replica_metrics[replica.idx] - replica_metrics.queue_depth = replica.request_queue.qsize() - replica_metrics.active_requests = replica.active_requests - - # Update session assignments per replica - session_counts = defaultdict(int) - for sess_id, replica_idx in self._session_replica_map.items(): - session_counts[replica_idx] += 1 - - for replica_idx, count in session_counts.items(): - if replica_idx in self._metrics.replica_metrics: - self._metrics.replica_metrics[replica_idx].assigned_sessions = count + # Use the replica's own metrics directly + self._metrics.replica_metrics[replica.idx] = replica.metrics def get_metrics(self) -> ServiceMetrics: """ @@ -574,7 +469,7 @@ def get_metrics_summary(self) -> dict: "healthy_replicas": self._metrics.healthy_replicas, "total_replicas": self._metrics.total_replicas, "total_request_rate": self._metrics.get_total_request_rate(), - "avg_queue_depth": self._metrics.get_avg_queue_depth(), + "avg_queue_depth": self._metrics.get_avg_queue_depth(self._replicas), "avg_capacity_utilization": self._metrics.get_avg_capacity_utilization( self._replicas ), @@ -583,17 +478,26 @@ def get_metrics_summary(self) -> dict: "replicas": {}, } - for replica_idx, metrics in self._metrics.replica_metrics.items(): - summary["replicas"][replica_idx] = { + for replica in self._replicas: + metrics = replica.metrics + + # Count sessions assigned to this replica + assigned_sessions = sum( + 1 + for replica_idx in self._session_replica_map.values() + if replica_idx == replica.idx + ) + + summary["replicas"][replica.idx] = { "total_requests": metrics.total_requests, "successful_requests": metrics.successful_requests, "failed_requests": metrics.failed_requests, "request_rate": metrics.get_request_rate(), "avg_latency": metrics.get_avg_latency(), - "active_requests": metrics.active_requests, - "queue_depth": metrics.queue_depth, - "assigned_sessions": metrics.assigned_sessions, - "capacity_utilization": metrics.get_capacity_utilization(10), + "active_requests": replica.active_requests, # Get from replica + "queue_depth": replica.request_queue.qsize(), # Get from replica + "assigned_sessions": assigned_sessions, # Calculate from session map + "capacity_utilization": replica.capacity_utilization, # Get from replica } return summary @@ -643,7 +547,7 @@ async def _health_loop(self, poll_rate_s: float): # Check for failed replicas and recover them failed_replicas = [] for replica in self._replicas: - if replica.proc_mesh.failed: + if replica.failed: failed_replicas.append(replica) if any(failed_replicas): @@ -656,12 +560,6 @@ async def _health_loop(self, poll_rate_s: float): await asyncio.sleep(poll_rate_s) - async def _custom_replica_routing( - self, sess_id: str | None, **kwargs - ) -> Optional[Replica]: - """Hook for custom routing logic. Override in subclasses to implement custom routing.""" - return None - def _get_next_replica(self) -> "Replica": """Get the next replica using round-robin selection.""" healthy_replicas = [r for r in self._replicas if r.healthy] @@ -686,13 +584,6 @@ def get_load(replica: "Replica") -> int: async def _get_replica(self, sess_id: str | None, **kwargs) -> "Replica": """Get a replica for the given session ID, with optional custom routing hints.""" - # Try custom routing first if hints are provided - if kwargs: - custom_result = await self._custom_replica_routing(sess_id, **kwargs) - if custom_result is not None: - return custom_result - - # Default routing logic if sess_id is None: # No session, use round-robin load balancing replica = self._get_next_replica() @@ -745,35 +636,33 @@ async def _maybe_init_replicas(self): logger.debug("Init replicas: %s", pprint.pformat(self._replicas_to_init)) - def _recover_hook( - replica: Replica, - ) -> Callable[[ProcMesh], Coroutine[Any, Any, None]]: - async def inner_hook(proc_mesh: ProcMesh) -> None: - if "name" in self._actor_kwargs: - actor_name = self._actor_kwargs.pop("name") - else: - actor_name = self._actor_def.__name__ - # TODO - expand support so name can stick within kwargs - actor = await proc_mesh.spawn( - actor_name, - self._actor_def, - *self._actor_args, - **self._actor_kwargs, - ) - replica.actor = actor - # Use replica's setup method instead of inline setup - await replica.setup() - - return inner_hook + # Initialize each replica (proc_mesh and actor spawning) + initialization_tasks = [] + for replica in self._replicas_to_init: + task = asyncio.create_task(self._init_single_replica(replica)) + initialization_tasks.append(task) - await asyncio.gather( - *[ - replica.proc_mesh.spawn(_recover_hook(replica)) - for replica in self._replicas_to_init - ] - ) + await asyncio.gather(*initialization_tasks, return_exceptions=True) self._replicas_to_init.clear() + async def _init_single_replica(self, replica: Replica): + """Initialize a single replica with proc_mesh and actor.""" + try: + # Initialize the proc_mesh + await replica.init_proc_mesh() + + # Spawn the actor using replica's method + await replica.spawn_actor( + self._actor_def, *self._actor_args, **self._actor_kwargs + ) + + logger.debug("Successfully initialized replica %d", replica.idx) + + except Exception as e: + logger.error("Failed to initialize replica %d: %s", replica.idx, e) + # Mark as failed so it can be retried later + replica.mark_failed() + async def _migrate_replica_workload(self, replica_to_remove: Replica): """Migrates all workload from a replica that's being removed.""" # Migrate queued requests diff --git a/tests/test_service.py b/tests/test_service.py index 0898de605..793964503 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -90,6 +90,7 @@ async def test_sessionless_calls(): try: # Test sessionless calls + logger.info("Starting requests") await service.incr() await service.incr() result = await service.value() @@ -172,7 +173,7 @@ async def test_replica_failure_and_recovery(): # Replica should be marked as failed failed_replica = service._replicas[original_replica_idx] - assert not failed_replica.proc_mesh.healthy + assert not failed_replica.healthy # Session should be reassigned on next call await service.incr(session) @@ -183,7 +184,7 @@ async def test_replica_failure_and_recovery(): new_session = await service.start_session() await service.incr(new_session) assigned_replica = service._replicas[service._session_replica_map[new_session]] - assert assigned_replica.proc_mesh.healthy + assert assigned_replica.healthy finally: await service.stop() From d4f566018b27339100af00bf5c6025620ae6a081 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Fri, 22 Aug 2025 08:26:08 -0700 Subject: [PATCH 3/9] phase out service for service v2 --- src/forge/controller/__init__.py | 2 - src/forge/controller/recoverable_mesh.py | 289 -------------- src/forge/controller/service.py | 471 ++++------------------- tests/test_service.py | 4 +- 4 files changed, 77 insertions(+), 689 deletions(-) delete mode 100644 src/forge/controller/recoverable_mesh.py diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py index a191d931a..a800eb14c 100644 --- a/src/forge/controller/__init__.py +++ b/src/forge/controller/__init__.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. from .actor import ForgeActor from .proc_mesh import get_proc_mesh, spawn_actors -from .recoverable_mesh import RecoverableProcMesh from .service import Service, ServiceConfig from .spawn import spawn_service @@ -16,5 +15,4 @@ "spawn_actors", "get_proc_mesh", "ForgeActor", - "RecoverableProcMesh", ] diff --git a/src/forge/controller/recoverable_mesh.py b/src/forge/controller/recoverable_mesh.py deleted file mode 100644 index d352eab17..000000000 --- a/src/forge/controller/recoverable_mesh.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Recoverable Process Mesh - -This module provides a fault-tolerant wrapper around ProcMesh that automatically -recovers from crashes and failures. The RecoverableProcMesh class maintains the -same API as ProcMesh while adding automatic recovery capabilities. - -Key Features: -- **Automatic Recovery**: Detects mesh failures and automatically respawns processes -- **State Management**: Tracks mesh health and recovery status -- **Graceful Degradation**: Handles failures without losing the entire service -- **Context Management**: Supports async context manager for resource cleanup -- **Actor Respawning**: Automatically respawns actors after mesh recovery - -Example: - Basic usage with automatic recovery: - - >>> mesh = RecoverableProcMesh(num_gpus=2) - >>> - >>> async def spawn_actor(proc_mesh): - ... actor = await proc_mesh.spawn("MyActor", MyActorClass, *args) - ... return actor - >>> - >>> await mesh.spawn(spawn_actor) - >>> # Mesh will automatically recover if it fails - - Context manager usage: - - >>> async with RecoverableProcMesh(num_gpus=1) as mesh: - ... await mesh.spawn(spawn_actor) - ... # Mesh automatically cleaned up on exit -""" - -import asyncio -import logging -from enum import Enum -from typing import Any, Callable, Coroutine, Optional, TypeVar - -from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice -from monarch._src.actor.actor_mesh import Actor -from monarch._src.actor.proc_mesh import ProcMesh -from monarch._src.actor.shape import MeshTrait - -from forge.controller.proc_mesh import get_proc_mesh -from forge.types import ProcessConfig - -T = TypeVar("T", bound=Actor) -logger: logging.Logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class MeshState(Enum): - """ - Enumeration of possible mesh states for tracking recovery status. - - States: - HEALTHY: Mesh is operational and ready to handle requests - RECOVERING: Mesh is in the process of recovering from a failure - UNHEALTHY: Mesh has failed and needs recovery - STOPPED: Mesh has been explicitly stopped and cannot be used - """ - - HEALTHY = 0 - RECOVERING = 1 - UNHEALTHY = 2 - STOPPED = 3 - - -class RecoverableProcMesh(MeshTrait): - """ - A fault-tolerant wrapper around ProcMesh with automatic crash recovery. - - This class provides the same API as ProcMesh while adding robust failure detection - and automatic recovery capabilities. When the underlying mesh crashes or becomes - unresponsive, it automatically creates a new mesh and respawns all actors. - - The RecoverableProcMesh maintains state tracking to ensure proper recovery sequencing - and prevents resource leaks during failure scenarios. It's designed for long-running - services that need high availability. - - Args: - proc_config: ProcessConfig containing mesh configuration including num_procs - - Attributes: - num_procs: Number of processes allocated to this mesh - state: Current state of the mesh (HEALTHY, RECOVERING, UNHEALTHY, STOPPED) - healthy: True if the mesh is operational and ready for requests - failed: True if the mesh has failed and needs recovery - - Example: - Basic usage with automatic recovery: - - >>> proc_config = ProcessConfig(num_procs=2, scheduler="local") - >>> mesh = RecoverableProcMesh(proc_config) - >>> - >>> async def setup_actor(proc_mesh): - ... actor = await proc_mesh.spawn("MyActor", MyActorClass) - ... await actor.initialize.call() - >>> - >>> await mesh.spawn(setup_actor) - >>> # If mesh fails, it will automatically recover and re-run setup_actor - - Context manager for automatic cleanup: - - >>> proc_config = ProcessConfig(num_procs=1) - >>> async with RecoverableProcMesh(proc_config) as mesh: - ... await mesh.spawn(setup_actor) - ... # Use mesh for operations - ... # Mesh automatically stopped and cleaned up on exit - - Manual state checking: - - >>> if mesh.healthy: - ... # Safe to use mesh - ... pass - >>> elif mesh.failed: - ... # Mesh needs recovery - ... await mesh.spawn(setup_actor) # Triggers recovery - """ - - def __init__( - self, - proc_config: ProcessConfig, - ) -> None: - self._proc_config: ProcessConfig = proc_config - self.num_procs = proc_config.num_procs - self._proc_mesh: Optional[ProcMesh] = None - self._recovery_task: Optional[asyncio.Task[None]] = None - self.state: MeshState = MeshState.UNHEALTHY - - async def spawn( - self, hook: Callable[[ProcMesh], Coroutine[Any, Any, None]] - ) -> None: - """ - Spawn actors on the mesh with automatic recovery. - - This method ensures the mesh is healthy before spawning actors. If the mesh - has failed, it automatically triggers recovery and then executes the spawn hook. - The hook function receives the underlying ProcMesh and should handle actor - creation and initialization. - - Args: - hook: Async function that receives a ProcMesh and spawns/initializes actors - - Example: - >>> async def setup_actors(proc_mesh): - ... actor = await proc_mesh.spawn("MyActor", MyActorClass) - ... await actor.setup.call() - >>> - >>> await mesh.spawn(setup_actors) - """ - await self._background_spawn(hook) - - def trigger_spawn( - self, hook: Callable[[ProcMesh], Coroutine[Any, Any, None]] - ) -> None: - self._background_spawn(hook) - - def _background_spawn( - self, hook: Callable[[ProcMesh], Coroutine[Any, Any, None]] - ) -> asyncio.Task[None]: - if self.state == MeshState.STOPPED: - logger.warning("ProcMesh was already stopped when trying to spawn") - - self.state = MeshState.RECOVERING - self._recovery_task = asyncio.create_task(self._recover(hook)) - - return self._recovery_task - - def gpus(self) -> int: - return self.num_procs - - async def _recover( - self, hook: Callable[[ProcMesh], Coroutine[Any, Any, None]] - ) -> None: - self.state = MeshState.RECOVERING - - old_proc_mesh = self._proc_mesh - self._proc_mesh = None - - if old_proc_mesh is not None: - try: - await old_proc_mesh.stop() - except Exception as e: - logger.warning(f"Error stopping old ProcMesh: {e}") - - try: - self._proc_mesh = await get_proc_mesh(process_config=self._proc_config) - if self._proc_mesh is not None: - await hook(self._proc_mesh) - self.state = MeshState.HEALTHY - - except Exception as e: - logger.exception(f"Recovery attempt failed: {e}") - self.state = MeshState.UNHEALTHY - - @property - def healthy(self) -> bool: - return self.state == MeshState.HEALTHY - - @property - def failed(self) -> bool: - return self.state == MeshState.UNHEALTHY - - async def stop(self) -> None: - """ - Stop the mesh and clean up all resources. - - Gracefully shuts down the underlying ProcMesh and marks this recoverable - mesh as stopped. Once stopped, the mesh cannot be used for further operations. - - This method is idempotent - calling it multiple times is safe. - - Example: - >>> await mesh.stop() - >>> # Mesh is now stopped and cannot be used - """ - logger.info("Stopping RecoverableProcMesh") - if self.state == MeshState.STOPPED: - logger.info("RecoverableProcMesh was already stopped") - return - try: - if self._proc_mesh is not None: - await self._proc_mesh.stop() - except RuntimeError as e: - logger.warning("RecoverableProcMesh could not be stopped: %s", e) - - self.state = MeshState.STOPPED - - async def __aenter__(self) -> "RecoverableProcMesh": - """Enter the async context manager.""" - if self.state == MeshState.STOPPED: - raise RuntimeError("RecoverableProcMesh has already been stopped") - return self - - async def __aexit__( - self, exc_type: object, exc_val: object, exc_tb: object - ) -> None: - """Exit the async context manager.""" - # In case there are multiple nested "async with" statements, we only - # want it to close once. - if self.state != MeshState.STOPPED: - await self.stop() - - def mark_failed(self): - """ - Mark the mesh as failed, triggering recovery on next spawn. - - This method is typically called when an operation on the mesh fails - or when external monitoring detects that the mesh is unresponsive. - The next call to spawn() will trigger automatic recovery. - - Example: - >>> try: - ... # Some operation that might fail - ... await actor.some_method.call() - >>> except Exception: - ... mesh.mark_failed() # Mark for recovery - """ - self.state = MeshState.UNHEALTHY - - @property - def _shape(self) -> Shape: - if self._proc_mesh is None: - raise RuntimeError("ProcMesh not initialized") - return self._proc_mesh._shape - - @property - def _ndslice(self) -> Slice: - if self._proc_mesh is None: - raise RuntimeError("ProcMesh not initialized") - return self._proc_mesh._ndslice - - @property - def _labels(self) -> list[str]: - if self._proc_mesh is None: - raise RuntimeError("ProcMesh not initialized") - return self._proc_mesh._labels - - def _new_with_shape(self, shape: Shape) -> "RecoverableProcMesh": - raise NotImplementedError( - "RecoverableProcMesh does not support _new_with_shape" - ) diff --git a/src/forge/controller/service.py b/src/forge/controller/service.py index 13c58db36..bd9c6c791 100644 --- a/src/forge/controller/service.py +++ b/src/forge/controller/service.py @@ -37,90 +37,20 @@ import contextvars import logging import pprint -import time import uuid -from collections import defaultdict, deque from dataclasses import dataclass, field -from typing import Any, Callable, Coroutine, Dict, List, Optional +from typing import Dict, List from monarch._src.actor.endpoint import EndpointProperty -from monarch.actor import ActorError, ProcMesh -from forge.controller import RecoverableProcMesh +from forge.controller.replica import Replica, ReplicaMetrics, ServiceRequest from forge.types import ServiceConfig logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -# TODO - tie this into metric logger when it exists -@dataclass -class ReplicaMetrics: - """ - Metrics collection for a single replica instance. - - Tracks request counts, timing metrics, current state, and session assignments - for performance monitoring and autoscaling decisions. - - Attributes: - replica_idx: Unique identifier for this replica - total_requests: Total number of requests processed - successful_requests: Number of successfully completed requests - failed_requests: Number of failed requests - request_times: Sliding window of request start timestamps - request_latencies: Sliding window of request completion latencies - active_requests: Currently processing requests - queue_depth: Number of pending requests in queue - assigned_sessions: Number of sessions assigned to this replica - """ - - replica_idx: int - # Request metrics - total_requests: int = 0 - successful_requests: int = 0 - failed_requests: int = 0 - # Timing metrics (sliding window) - request_times: deque = field(default_factory=lambda: deque(maxlen=100)) - request_latencies: deque = field(default_factory=lambda: deque(maxlen=100)) - # Current state - active_requests: int = 0 - queue_depth: int = 0 - # Session metrics - assigned_sessions: int = 0 - - def add_request_start(self, timestamp: float): - """Record when a request starts processing.""" - self.request_times.append(timestamp) - self.total_requests += 1 - - def add_request_completion(self, start_time: float, success: bool): - """Record when a request completes.""" - latency = time.time() - start_time - self.request_latencies.append(latency) - if success: - self.successful_requests += 1 - else: - self.failed_requests += 1 - - def get_request_rate(self, window_seconds: float = 60.0) -> float: - """Get requests per second over the last window_seconds.""" - now = time.time() - cutoff = now - window_seconds - recent_requests = [t for t in self.request_times if t >= cutoff] - return len(recent_requests) / window_seconds if window_seconds > 0 else 0.0 - - def get_avg_latency(self, window_requests: int = 50) -> float: - """Get average latency over the last N requests.""" - if not self.request_latencies: - return 0.0 - recent_latencies = list(self.request_latencies)[-window_requests:] - return sum(recent_latencies) / len(recent_latencies) - - def get_capacity_utilization(self, max_concurrent: int) -> float: - """Get current capacity utilization (0.0 to 1.0).""" - return self.active_requests / max_concurrent if max_concurrent > 0 else 0.0 - - +# TODO - tie this into metrics logger when it exists. @dataclass class ServiceMetrics: """ @@ -153,33 +83,22 @@ def get_total_request_rate(self, window_seconds: float = 60.0) -> float: for metrics in self.replica_metrics.values() ) - def get_avg_queue_depth(self) -> float: + def get_avg_queue_depth(self, replicas: List) -> float: """Get average queue depth across all healthy replicas.""" - healthy_metrics = [ - m - for m in self.replica_metrics.values() - if m.replica_idx < self.healthy_replicas - ] - if not healthy_metrics: + healthy_replicas = [r for r in replicas if r.healthy] + if not healthy_replicas: return 0.0 - return sum(m.queue_depth for m in healthy_metrics) / len(healthy_metrics) + total_queue_depth = sum(r.request_queue.qsize() for r in healthy_replicas) + return total_queue_depth / len(healthy_replicas) def get_avg_capacity_utilization(self, replicas: List) -> float: """Get average capacity utilization across all healthy replicas.""" - healthy_replicas = [r for r in replicas if r.proc_mesh.healthy] + healthy_replicas = [r for r in replicas if r.healthy] if not healthy_replicas: return 0.0 - utilizations = [] - for replica in healthy_replicas: - if replica.idx in self.replica_metrics: - metrics = self.replica_metrics[replica.idx] - utilization = metrics.get_capacity_utilization( - replica.max_concurrent_requests - ) - utilizations.append(utilization) - - return sum(utilizations) / len(utilizations) if utilizations else 0.0 + total_utilization = sum(r.capacity_utilization for r in healthy_replicas) + return total_utilization / len(healthy_replicas) def get_sessions_per_replica(self) -> float: """Get average sessions per healthy replica.""" @@ -188,18 +107,6 @@ def get_sessions_per_replica(self) -> float: return self.total_sessions / self.healthy_replicas -@dataclass -class Replica: - proc_mesh: RecoverableProcMesh - actor: Any - idx: int - request_queue: asyncio.Queue[dict] = field(default_factory=asyncio.Queue) - active_requests: int = 0 - max_concurrent_requests: int = 10 - _processor_running: bool = False - metadata: dict = field(default_factory=dict) - - @dataclass class Session: session_id: str @@ -299,11 +206,6 @@ def __init__(self, cfg: ServiceConfig, actor_def, *actor_args, **actor_kwargs): # Initialize metrics collection self._metrics = ServiceMetrics() - - # Autoscaling state - self._last_scale_up_time = 0.0 - self._last_scale_down_time = 0.0 - self._low_utilization_start_time = None self._health_task = None self._shutdown_requested = False @@ -325,12 +227,11 @@ async def __initialize__(self): replicas = [] num_replicas = self._cfg.num_replicas for i in range(num_replicas): - mesh = RecoverableProcMesh(proc_config=self._cfg.to_process_config()) replica = Replica( - proc_mesh=mesh, - actor=None, + proc_config=self._cfg.to_process_config(), idx=len(self._replicas) + i, max_concurrent_requests=self._cfg.replica_max_concurrent_requests, + return_first_rank_result=self._cfg.return_first_rank_result, ) replicas.append(replica) @@ -401,26 +302,24 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): replica = await self._get_replica(sess_id, **routing_hints) - # Create a request object to queue - request = { - "sess_id": sess_id, - "function": function, - "args": args, - "kwargs": kwargs, - "future": asyncio.Future(), - } - # Queue the request - await replica.request_queue.put(request) + # Create a ServiceRequest object to queue + request = ServiceRequest( + session_id=sess_id, + function=function, + args=args, + kwargs=kwargs, + future=asyncio.Future(), + ) - # Ensure the replica has a processor running - self._ensure_processor_running(replica) + # Queue the request using replica's method + await replica.enqueue_request(request) # Wait for the result try: - return await request["future"] + return await request.future except Exception as e: # If the replica failed, try to retry once - if not replica.proc_mesh.healthy: + if not replica.healthy: logger.debug( "Replica %d failed during request, retrying on healthy replica", replica.idx, @@ -430,100 +329,6 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): ) raise - def _ensure_processor_running(self, replica: Replica): - """Ensures a persistent processor is running for this replica.""" - if not replica._processor_running: - replica._processor_running = True - asyncio.create_task(self._persistent_processor(replica)) - - async def _persistent_processor(self, replica: Replica): - """Persistent processor that continuously handles requests for a replica.""" - try: - while replica.proc_mesh.healthy: - try: - # Wait for a request with timeout to check health periodically - request = await asyncio.wait_for( - replica.request_queue.get(), timeout=1.0 - ) - - # Check if we have capacity - if replica.active_requests >= replica.max_concurrent_requests: - # Put the request back and wait - await replica.request_queue.put(request) - await asyncio.sleep(0.1) - continue - - # Process the request - asyncio.create_task(self._process_single_request(replica, request)) - - except asyncio.TimeoutError: - # No requests, continue to check health - continue - except Exception as e: - logger.error( - "Error in persistent processor for replica %d: %s", - replica.idx, - e, - ) - break - finally: - replica._processor_running = False - # Migrate any remaining requests to healthy replicas - await self._migrate_remaining_requests(replica) - - async def _process_single_request(self, replica: Replica, request: dict): - """Processes a single request.""" - start_time = time.time() - replica.active_requests += 1 - - # Get or create metrics for this replica - if replica.idx not in self._metrics.replica_metrics: - self._metrics.replica_metrics[replica.idx] = ReplicaMetrics(replica.idx) - - replica_metrics = self._metrics.replica_metrics[replica.idx] - replica_metrics.add_request_start(start_time) - replica_metrics.active_requests = replica.active_requests - - try: - # Get the actor and endpoint - actor = replica.actor - endpoint_func = getattr(actor, request["function"]) - - # Execute the request - success = True - try: - result = await endpoint_func.call(*request["args"], **request["kwargs"]) - if ( - self._cfg.return_first_rank_result - and hasattr(result, "_values") - and result._values - ): - result = result._values[0] - request["future"].set_result(result) - except ActorError as e: - logger.debug("Got failure on replica %d. Error:\n%s", replica.idx, e) - replica.proc_mesh.mark_failed() - # Unwrap the ActorError into its raw exception. - request["future"].set_result(e.exception) - success = False - except Exception as e: - logger.debug( - "Got unexpected error on replica %d. Error:\n%s", replica.idx, e - ) - replica.proc_mesh.mark_failed() - request["future"].set_result(e) - success = False - - # Record completion metrics - replica_metrics.add_request_completion(start_time, success) - - # Mark task as done - replica.request_queue.task_done() - - finally: - replica.active_requests -= 1 - replica_metrics.active_requests = replica.active_requests - async def _retry_request_on_healthy_replica( self, sess_id: str | None, function: str, *args, **kwargs ): @@ -558,13 +363,13 @@ async def _migrate_remaining_requests(self, failed_replica: Replica): # Find healthy replicas healthy_replicas = [ - r for r in self._replicas if r.proc_mesh.healthy and r != failed_replica + r for r in self._replicas if r.healthy and r != failed_replica ] if not healthy_replicas: # No healthy replicas, fail all requests for request in migrated_requests: - request["future"].set_exception( + request.future.set_exception( RuntimeError("No healthy replicas available") ) return @@ -572,11 +377,10 @@ async def _migrate_remaining_requests(self, failed_replica: Replica): # Distribute requests among healthy replicas for i, request in enumerate(migrated_requests): target_replica = healthy_replicas[i % len(healthy_replicas)] - await target_replica.request_queue.put(request) - self._ensure_processor_running(target_replica) + await target_replica.enqueue_request(request) # Update session mapping if needed - sess_id = request["sess_id"] + sess_id = request.session_id if ( sess_id in self._session_replica_map and self._session_replica_map[sess_id] == failed_replica.idx @@ -616,27 +420,12 @@ def _update_service_metrics(self): """Updates service-level metrics.""" self._metrics.total_sessions = len(self._active_sessions) self._metrics.total_replicas = len(self._replicas) - self._metrics.healthy_replicas = sum( - 1 for r in self._replicas if r.proc_mesh.healthy - ) - - # Update queue depths for all replicas + self._metrics.healthy_replicas = sum(1 for r in self._replicas if r.healthy) + # Store direct references to replica metrics for aggregation + self._metrics.replica_metrics = {} for replica in self._replicas: - if replica.idx not in self._metrics.replica_metrics: - self._metrics.replica_metrics[replica.idx] = ReplicaMetrics(replica.idx) - - replica_metrics = self._metrics.replica_metrics[replica.idx] - replica_metrics.queue_depth = replica.request_queue.qsize() - replica_metrics.active_requests = replica.active_requests - - # Update session assignments per replica - session_counts = defaultdict(int) - for sess_id, replica_idx in self._session_replica_map.items(): - session_counts[replica_idx] += 1 - - for replica_idx, count in session_counts.items(): - if replica_idx in self._metrics.replica_metrics: - self._metrics.replica_metrics[replica_idx].assigned_sessions = count + # Use the replica's own metrics directly + self._metrics.replica_metrics[replica.idx] = replica.metrics def get_metrics(self) -> ServiceMetrics: """ @@ -680,7 +469,7 @@ def get_metrics_summary(self) -> dict: "healthy_replicas": self._metrics.healthy_replicas, "total_replicas": self._metrics.total_replicas, "total_request_rate": self._metrics.get_total_request_rate(), - "avg_queue_depth": self._metrics.get_avg_queue_depth(), + "avg_queue_depth": self._metrics.get_avg_queue_depth(self._replicas), "avg_capacity_utilization": self._metrics.get_avg_capacity_utilization( self._replicas ), @@ -689,17 +478,26 @@ def get_metrics_summary(self) -> dict: "replicas": {}, } - for replica_idx, metrics in self._metrics.replica_metrics.items(): - summary["replicas"][replica_idx] = { + for replica in self._replicas: + metrics = replica.metrics + + # Count sessions assigned to this replica + assigned_sessions = sum( + 1 + for replica_idx in self._session_replica_map.values() + if replica_idx == replica.idx + ) + + summary["replicas"][replica.idx] = { "total_requests": metrics.total_requests, "successful_requests": metrics.successful_requests, "failed_requests": metrics.failed_requests, "request_rate": metrics.get_request_rate(), "avg_latency": metrics.get_avg_latency(), - "active_requests": metrics.active_requests, - "queue_depth": metrics.queue_depth, - "assigned_sessions": metrics.assigned_sessions, - "capacity_utilization": metrics.get_capacity_utilization(10), + "active_requests": replica.active_requests, # Get from replica + "queue_depth": replica.request_queue.qsize(), # Get from replica + "assigned_sessions": assigned_sessions, # Calculate from session map + "capacity_utilization": replica.capacity_utilization, # Get from replica } return summary @@ -749,7 +547,7 @@ async def _health_loop(self, poll_rate_s: float): # Check for failed replicas and recover them failed_replicas = [] for replica in self._replicas: - if replica.proc_mesh.failed: + if replica.failed: failed_replicas.append(replica) if any(failed_replicas): @@ -762,15 +560,9 @@ async def _health_loop(self, poll_rate_s: float): await asyncio.sleep(poll_rate_s) - async def _custom_replica_routing( - self, sess_id: str | None, **kwargs - ) -> Optional[Replica]: - """Hook for custom routing logic. Override in subclasses to implement custom routing.""" - return None - def _get_next_replica(self) -> "Replica": """Get the next replica using round-robin selection.""" - healthy_replicas = [r for r in self._replicas if r.proc_mesh.healthy] + healthy_replicas = [r for r in self._replicas if r.healthy] if not healthy_replicas: raise RuntimeError("No healthy replicas available for load balancing") @@ -780,7 +572,7 @@ def _get_next_replica(self) -> "Replica": def _get_least_loaded_replica(self) -> "Replica": """Get the replica with the lowest load.""" - healthy_replicas = [r for r in self._replicas if r.proc_mesh.healthy] + healthy_replicas = [r for r in self._replicas if r.healthy] if not healthy_replicas: raise RuntimeError("No healthy replicas available for session assignment") @@ -792,13 +584,6 @@ def get_load(replica: "Replica") -> int: async def _get_replica(self, sess_id: str | None, **kwargs) -> "Replica": """Get a replica for the given session ID, with optional custom routing hints.""" - # Try custom routing first if hints are provided - if kwargs: - custom_result = await self._custom_replica_routing(sess_id, **kwargs) - if custom_result is not None: - return custom_result - - # Default routing logic if sess_id is None: # No session, use round-robin load balancing replica = self._get_next_replica() @@ -809,7 +594,7 @@ async def _get_replica(self, sess_id: str | None, **kwargs) -> "Replica": replica_idx = self._session_replica_map[sess_id] # Find the replica with this index for replica in self._replicas: - if replica.idx == replica_idx and replica.proc_mesh.healthy: + if replica.idx == replica_idx and replica.healthy: return replica # If the replica is no longer healthy, remove from session map and reassign del self._session_replica_map[sess_id] @@ -838,8 +623,9 @@ async def stop(self): except asyncio.CancelledError: logger.info("Health loop task cancelled.") + # Stop all replicas using their stop method await asyncio.gather( - *[replica.proc_mesh.stop() for replica in self._replicas], + *[replica.stop() for replica in self._replicas], return_exceptions=True, ) @@ -850,139 +636,32 @@ async def _maybe_init_replicas(self): logger.debug("Init replicas: %s", pprint.pformat(self._replicas_to_init)) - def _recover_hook( - replica: Replica, - ) -> Callable[[ProcMesh], Coroutine[Any, Any, None]]: - async def inner_hook(proc_mesh: ProcMesh) -> None: - if "name" in self._actor_kwargs: - actor_name = self._actor_kwargs.pop("name") - else: - actor_name = self._actor_def.__name__ - # TODO - expand support so name can stick within kwargs - actor = await proc_mesh.spawn( - actor_name, - self._actor_def, - *self._actor_args, - **self._actor_kwargs, - ) - replica.actor = actor - if hasattr(actor, "setup"): - await actor.setup.call() - - return inner_hook + # Initialize each replica (proc_mesh and actor spawning) + initialization_tasks = [] + for replica in self._replicas_to_init: + task = asyncio.create_task(self._init_single_replica(replica)) + initialization_tasks.append(task) - await asyncio.gather( - *[ - replica.proc_mesh.spawn(_recover_hook(replica)) - for replica in self._replicas_to_init - ] - ) + await asyncio.gather(*initialization_tasks, return_exceptions=True) self._replicas_to_init.clear() - async def _scale_up(self, num_replicas: int = 1): - """ - Scales up the service by adding new replicas. - - Creates new replica instances with their own process meshes and queues them - for initialization. The replicas will be initialized asynchronously by the - health loop to avoid blocking the scaling operation. - - Args: - num_replicas: Number of replicas to add (default: 1) - - Note: - Replicas are queued for initialization rather than initialized immediately - to prevent blocking during scaling operations. - """ - logger.debug("Scaling up with %d replicas.", num_replicas) - new_replicas = [] - for i in range(num_replicas): - mesh = RecoverableProcMesh( - self._cfg.procs_per_replica, - ) - replica = Replica( - proc_mesh=mesh, - actor=None, - idx=len(self._replicas) + i, - max_concurrent_requests=self._cfg.replica_max_concurrent_requests, - ) - new_replicas.append(replica) - - # Add to the initialization queue instead of initializing immediately - self._replicas_to_init.extend(new_replicas) - self._replicas.extend(new_replicas) - logger.debug( - "Queued %d replicas for initialization. Total replicas: %d", - num_replicas, - len(self._replicas), - ) - - async def _scale_down_replicas(self, num_replicas: int = 1): - """ - Scales down the service by intelligently removing replicas. - - Prioritizes removal of unhealthy replicas first, then selects healthy replicas - with the lowest load. Migrates all workload (sessions and queued requests) - from removed replicas to remaining healthy replicas. + async def _init_single_replica(self, replica: Replica): + """Initialize a single replica with proc_mesh and actor.""" + try: + # Initialize the proc_mesh + await replica.init_proc_mesh() - Args: - num_replicas: Number of replicas to remove (default: 1) - - Note: - # Test context manager usage - async with service.session(): - await service.incr() - await service.incr() - result = await service.value() - assert result == 2 - - Sessions are reassigned on their next request rather than immediately - to avoid disrupting active workloads. - """ - logger.debug("Scaling down by %d replicas.", num_replicas) - - # Find replicas to remove (prefer unhealthy ones first, then least loaded) - replicas_to_remove = [] - - # First, try to remove unhealthy replicas - unhealthy_replicas = [r for r in self._replicas if not r.proc_mesh.healthy] - for replica in unhealthy_replicas[:num_replicas]: - replicas_to_remove.append(replica) - - # If we need more, remove healthy replicas with least load - remaining_to_remove = num_replicas - len(replicas_to_remove) - if remaining_to_remove > 0: - healthy_replicas = [ - r - for r in self._replicas - if r.proc_mesh.healthy and r not in replicas_to_remove - ] - # Sort by load (queue depth + active requests) - healthy_replicas.sort( - key=lambda r: r.request_queue.qsize() + r.active_requests + # Spawn the actor using replica's method + await replica.spawn_actor( + self._actor_def, *self._actor_args, **self._actor_kwargs ) - for replica in healthy_replicas[:remaining_to_remove]: - replicas_to_remove.append(replica) - - # Migrate sessions and requests from replicas being removed - for replica in replicas_to_remove: - await self._migrate_replica_workload(replica) + logger.debug("Successfully initialized replica %d", replica.idx) - # Stop the replica - try: - await replica.proc_mesh.stop() - except Exception as e: - logger.warning("Error stopping replica %d: %s", replica.idx, e) - - # Remove from replicas list - self._replicas.remove(replica) - - # Update replica indices - for i, replica in enumerate(self._replicas): - replica.idx = i - - logger.debug("Scale down complete. Remaining replicas: %d", len(self._replicas)) + except Exception as e: + logger.error("Failed to initialize replica %d: %s", replica.idx, e) + # Mark as failed so it can be retried later + replica.mark_failed() async def _migrate_replica_workload(self, replica_to_remove: Replica): """Migrates all workload from a replica that's being removed.""" diff --git a/tests/test_service.py b/tests/test_service.py index 793964503..1bfdaad38 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -12,8 +12,8 @@ import logging import pytest -from forge.controller.service_v2 import ServiceConfig -from forge.controller.spawn import spawn_service_v2 as spawn_service +from forge.controller.service import ServiceConfig +from forge.controller.spawn import spawn_service from monarch.actor import Actor, endpoint logger = logging.getLogger(__name__) From 42028831770631880499563f285fb16e68b48802 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Fri, 22 Aug 2025 08:26:34 -0700 Subject: [PATCH 4/9] remove v2 --- src/forge/controller/service_v2.py | 683 ----------------------------- 1 file changed, 683 deletions(-) delete mode 100644 src/forge/controller/service_v2.py diff --git a/src/forge/controller/service_v2.py b/src/forge/controller/service_v2.py deleted file mode 100644 index bd9c6c791..000000000 --- a/src/forge/controller/service_v2.py +++ /dev/null @@ -1,683 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -""" -Distributed Actor Service Controller - -This module provides a robust service orchestration system for managing distributed -actor-based workloads with automatic scaling, fault tolerance, and intelligent load balancing. - -The main Service class acts as a singleton controller that handles: -- Fault tolerance with automatic replica recovery -- Autoscaling based on real-time metrics -- Load balancing across healthy replicas -- Session management with context propagation -- Comprehensive metrics collection and monitoring - -Example: - Basic service setup: - - >>> config = ServiceConfig( - ... gpus_per_replica=1, - ... num_replicas=3 - ... ) - >>> service = Service(config, MyActorClass, *args, **kwargs) - >>> await service.__initialize__() - - Session-based usage: - - >>> async with service.session(): - ... result = await service.my_endpoint(arg1, arg2) -""" - - -import asyncio -import contextvars -import logging -import pprint -import uuid -from dataclasses import dataclass, field -from typing import Dict, List - -from monarch._src.actor.endpoint import EndpointProperty - -from forge.controller.replica import Replica, ReplicaMetrics, ServiceRequest -from forge.types import ServiceConfig - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - -# TODO - tie this into metrics logger when it exists. -@dataclass -class ServiceMetrics: - """ - Aggregated metrics collection for the entire service. - - Provides service-wide visibility into performance, health, and scaling metrics - by aggregating data from all replica instances. - - Attributes: - replica_metrics: Per-replica metrics indexed by replica ID - total_sessions: Number of active sessions across all replicas - healthy_replicas: Number of currently healthy replicas - total_replicas: Total number of replicas (healthy + unhealthy) - last_scale_event: Timestamp of the last scaling operation - """ - - # Replica metrics - replica_metrics: Dict[int, ReplicaMetrics] = field(default_factory=dict) - # Service-level metrics - total_sessions: int = 0 - healthy_replicas: int = 0 - total_replicas: int = 0 - # Time-based metrics - last_scale_event: float = 0.0 - - def get_total_request_rate(self, window_seconds: float = 60.0) -> float: - """Get total requests per second across all replicas.""" - return sum( - metrics.get_request_rate(window_seconds) - for metrics in self.replica_metrics.values() - ) - - def get_avg_queue_depth(self, replicas: List) -> float: - """Get average queue depth across all healthy replicas.""" - healthy_replicas = [r for r in replicas if r.healthy] - if not healthy_replicas: - return 0.0 - total_queue_depth = sum(r.request_queue.qsize() for r in healthy_replicas) - return total_queue_depth / len(healthy_replicas) - - def get_avg_capacity_utilization(self, replicas: List) -> float: - """Get average capacity utilization across all healthy replicas.""" - healthy_replicas = [r for r in replicas if r.healthy] - if not healthy_replicas: - return 0.0 - - total_utilization = sum(r.capacity_utilization for r in healthy_replicas) - return total_utilization / len(healthy_replicas) - - def get_sessions_per_replica(self) -> float: - """Get average sessions per healthy replica.""" - if self.healthy_replicas == 0: - return 0.0 - return self.total_sessions / self.healthy_replicas - - -@dataclass -class Session: - session_id: str - - -# Global context variable for session state -# This is used to propagate session state across async tasks -_session_context: contextvars.ContextVar[dict | None] = contextvars.ContextVar( - "session_context", default=None -) - - -class SessionContext: - """Context manager for service sessions using context variables.""" - - def __init__(self, service: "Service", **session_kwargs): - self.service = service - self.session_id: str | None = None - self.session_kwargs = session_kwargs - self._token = None - - async def __aenter__(self): - """Start a session and set context variables.""" - self.session_id = await self.service.start_session() - # Set context for this async task - context_value = {"session_id": self.session_id, "kwargs": self.session_kwargs} - self._token = _session_context.set(context_value) - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Terminate the session and restore context.""" - if self._token: - _session_context.reset(self._token) - if self.session_id: - await self.service.terminate_session(self.session_id) - self.session_id = None - - -class Service: - """ - Distributed Actor Service Controller - - A sophisticated service orchestration system that manages multiple replicas of actor-based - services with automatic scaling, fault tolerance, and intelligent load balancing. - - The Service acts as a unified interface for distributed workloads, automatically handling: - - **Fault Tolerance**: Health monitoring, automatic replica recovery, request migration - - **Load Balancing**: Round-robin, least-loaded, and session-affinity routing - - **Session Management**: Stateful session handling with context propagation - - **Metrics Collection**: Comprehensive performance and health monitoring - - Args: - cfg: Service configuration including number of replicas, GPUs per replica, and health polling rate - actor_def: Actor class definition to instantiate on each replica - *actor_args: Positional arguments passed to actor constructor - **actor_kwargs: Keyword arguments passed to actor constructor - - Example: - Basic setup with autoscaling: - - >>> config = ServiceConfig( - ... gpus_per_replica=1, - ... num_replicas=3, - ... ) - >>> service = Service(config, MyActorClass, model_path="/path/to/model") - >>> await service.__initialize__() - - Session-based usage: - - >>> async with service.session(): - ... result1 = await service.my_endpoint(arg1, arg2) - ... result2 = await service.another_endpoint(arg3) - - Stateless usage: - - >>> result = await service.my_endpoint(arg1, arg2) # Uses round-robin - - Attributes: - _cfg: Service configuration - _replicas: List of managed replica instances - _active_sessions: Currently active sessions - _metrics: Aggregated service and replica metrics - _endpoints: Dynamically registered actor endpoints - """ - - def __init__(self, cfg: ServiceConfig, actor_def, *actor_args, **actor_kwargs): - self._cfg = cfg - self._replicas = [] - self._actor_def = actor_def - self._actor_args = actor_args - self._actor_kwargs = actor_kwargs - - self._active_sessions = [] - self._id_session_map = {} - self._session_replica_map: Dict[str, int] = {} - self._next_replica_idx = 0 # For round-robin load balancing - - # Initialize metrics collection - self._metrics = ServiceMetrics() - self._health_task = None - self._shutdown_requested = False - - # Replica initialization queue - self._replicas_to_init = [] - - # For all endpoints within the actor_def, create an interface from it - self._endpoints = [] - for func_name in dir(actor_def): - func = getattr(actor_def, func_name) - if isinstance(func, EndpointProperty): - logger.debug("Registering endpoint %s", func_name) - self._endpoints.append(func_name) - # Dynamically add this endpoint method to the Service class - self._add_endpoint_method(func_name) - - async def __initialize__(self): - logger.debug("Starting service up with %d replicas.", self._cfg.num_replicas) - replicas = [] - num_replicas = self._cfg.num_replicas - for i in range(num_replicas): - replica = Replica( - proc_config=self._cfg.to_process_config(), - idx=len(self._replicas) + i, - max_concurrent_requests=self._cfg.replica_max_concurrent_requests, - return_first_rank_result=self._cfg.return_first_rank_result, - ) - replicas.append(replica) - - # Initializing should only happen in the health_loop - # and during the first initialization. - # If multiple parts of the code try to initialize replicas at - # the same time, it can cause nasty race conditions - # (e.g., double initialization, inconsistent state, or resource conflicts). - # By funneling all replica initialization through a single queue and the - # health loop, we ensure safe, serialized initialization. - logger.debug( - "Queued %d replicas for initialization. Total replicas: %d", - num_replicas, - len(self._replicas), - ) - self._replicas_to_init.extend(replicas) - await self._maybe_init_replicas() - self._replicas.extend(replicas) - - # Start the health loop in the background - self._health_task = asyncio.create_task( - self._health_loop(poll_rate_s=self._cfg.health_poll_rate) - ) - - def _add_endpoint_method(self, endpoint_name: str): - """Dynamically adds an endpoint method to this Service instance.""" - - async def endpoint_method(sess_id: str | None = None, *args, **kwargs): - return await self._call(sess_id, endpoint_name, *args, **kwargs) - - # Set the method on this instance - setattr(self, endpoint_name, endpoint_method) - - async def _call(self, sess_id: str | None, function: str, *args, **kwargs): - """ - Routes a function call to the appropriate replica with load balancing and fault tolerance. - - This is the core routing method that handles: - - Session-based routing for stateful calls - - Round-robin load balancing for stateless calls - - Custom routing based on context hints - - Automatic retry on replica failures - - Request queuing and processing - - Args: - sess_id: Optional session ID for stateful routing - function: Name of the actor endpoint to call - *args: Positional arguments to pass to the endpoint - **kwargs: Keyword arguments to pass to the endpoint - - Returns: - The result from the actor endpoint execution - - Raises: - RuntimeError: If no healthy replicas are available - Exception: Any exception raised by the actor endpoint - """ - # Check context variables for session state if no explicit sess_id - if sess_id is None: - ctx = _session_context.get() - if ctx: - sess_id = ctx["session_id"] - routing_hints = ctx["kwargs"] - else: - routing_hints = {} - else: - routing_hints = {} - - replica = await self._get_replica(sess_id, **routing_hints) - - # Create a ServiceRequest object to queue - request = ServiceRequest( - session_id=sess_id, - function=function, - args=args, - kwargs=kwargs, - future=asyncio.Future(), - ) - - # Queue the request using replica's method - await replica.enqueue_request(request) - - # Wait for the result - try: - return await request.future - except Exception as e: - # If the replica failed, try to retry once - if not replica.healthy: - logger.debug( - "Replica %d failed during request, retrying on healthy replica", - replica.idx, - ) - return await self._retry_request_on_healthy_replica( - sess_id, function, *args, **kwargs - ) - raise - - async def _retry_request_on_healthy_replica( - self, sess_id: str | None, function: str, *args, **kwargs - ): - """Retries a failed request on a healthy replica.""" - # Force reassignment to a healthy replica (only for session-based calls) - if sess_id is not None and sess_id in self._session_replica_map: - del self._session_replica_map[sess_id] - - # Retry the call (this will assign to a new healthy replica) - return await self._call(sess_id, function, *args, **kwargs) - - async def _migrate_remaining_requests(self, failed_replica: Replica): - """Migrates remaining requests from a failed replica to healthy replicas.""" - migrated_requests = [] - - # Collect all remaining requests - while not failed_replica.request_queue.empty(): - try: - request = failed_replica.request_queue.get_nowait() - migrated_requests.append(request) - except asyncio.QueueEmpty: - break - - if not migrated_requests: - return - - logger.debug( - "Migrating %d requests from failed replica %d", - len(migrated_requests), - failed_replica.idx, - ) - - # Find healthy replicas - healthy_replicas = [ - r for r in self._replicas if r.healthy and r != failed_replica - ] - - if not healthy_replicas: - # No healthy replicas, fail all requests - for request in migrated_requests: - request.future.set_exception( - RuntimeError("No healthy replicas available") - ) - return - - # Distribute requests among healthy replicas - for i, request in enumerate(migrated_requests): - target_replica = healthy_replicas[i % len(healthy_replicas)] - await target_replica.enqueue_request(request) - - # Update session mapping if needed - sess_id = request.session_id - if ( - sess_id in self._session_replica_map - and self._session_replica_map[sess_id] == failed_replica.idx - ): - self._session_replica_map[sess_id] = target_replica.idx - - async def start_session(self) -> str: - """ - Starts a new session for stateful request handling. - - Sessions enable request affinity to specific replicas, maintaining state - consistency for workloads that require it. Each session gets a unique ID - and is automatically assigned to the least loaded replica. - - Returns: - str: Unique session identifier for use in subsequent requests - - Example: - >>> session_id = await service.start_session() - >>> result = await service.my_endpoint(session_id, arg1, arg2) - >>> await service.terminate_session(session_id) - """ - sess_id = str(uuid.uuid4()) - session = Session(session_id=sess_id) - self._active_sessions.append(session) - - # Update metrics - self._update_service_metrics() - - return sess_id - - def session(self, **kwargs) -> SessionContext: - """Returns a context manager for session-based calls.""" - return SessionContext(self, **kwargs) - - def _update_service_metrics(self): - """Updates service-level metrics.""" - self._metrics.total_sessions = len(self._active_sessions) - self._metrics.total_replicas = len(self._replicas) - self._metrics.healthy_replicas = sum(1 for r in self._replicas if r.healthy) - # Store direct references to replica metrics for aggregation - self._metrics.replica_metrics = {} - for replica in self._replicas: - # Use the replica's own metrics directly - self._metrics.replica_metrics[replica.idx] = replica.metrics - - def get_metrics(self) -> ServiceMetrics: - """ - Get comprehensive service metrics for monitoring and analysis. - - Returns detailed metrics including per-replica performance data, - service-wide aggregations, and health status information. - - Returns: - ServiceMetrics: Complete metrics object with replica and service data - - Example: - >>> metrics = service.get_metrics() - >>> print(f"Request rate: {metrics.get_total_request_rate():.1f} req/s") - >>> print(f"Queue depth: {metrics.get_avg_queue_depth():.1f}") - """ - self._update_service_metrics() - return self._metrics - - def get_metrics_summary(self) -> dict: - """ - Get a summary of key metrics for monitoring and debugging. - - Provides a structured summary of service and replica metrics in a format - suitable for monitoring dashboards, logging, or debugging purposes. - - Returns: - dict: Structured metrics summary with service and per-replica data - - Example: - >>> summary = service.get_metrics_summary() - >>> print(f"Healthy replicas: {summary['service']['healthy_replicas']}") - >>> for idx, metrics in summary['replicas'].items(): - ... print(f"Replica {idx}: {metrics['request_rate']:.1f} req/s") - """ - self._update_service_metrics() - - summary = { - "service": { - "total_sessions": self._metrics.total_sessions, - "healthy_replicas": self._metrics.healthy_replicas, - "total_replicas": self._metrics.total_replicas, - "total_request_rate": self._metrics.get_total_request_rate(), - "avg_queue_depth": self._metrics.get_avg_queue_depth(self._replicas), - "avg_capacity_utilization": self._metrics.get_avg_capacity_utilization( - self._replicas - ), - "sessions_per_replica": self._metrics.get_sessions_per_replica(), - }, - "replicas": {}, - } - - for replica in self._replicas: - metrics = replica.metrics - - # Count sessions assigned to this replica - assigned_sessions = sum( - 1 - for replica_idx in self._session_replica_map.values() - if replica_idx == replica.idx - ) - - summary["replicas"][replica.idx] = { - "total_requests": metrics.total_requests, - "successful_requests": metrics.successful_requests, - "failed_requests": metrics.failed_requests, - "request_rate": metrics.get_request_rate(), - "avg_latency": metrics.get_avg_latency(), - "active_requests": replica.active_requests, # Get from replica - "queue_depth": replica.request_queue.qsize(), # Get from replica - "assigned_sessions": assigned_sessions, # Calculate from session map - "capacity_utilization": replica.capacity_utilization, # Get from replica - } - - return summary - - async def terminate_session(self, sess_id: str): - """ - Terminates an active session and cleans up associated resources. - - Removes the session from active tracking, clears replica assignments, - and updates service metrics. Sessions should be terminated when no - longer needed to free up resources. - - Args: - sess_id: The unique session identifier to terminate - - Example: - >>> session_id = await service.start_session() - >>> # ... use session for requests ... - >>> await service.terminate_session(session_id) - """ - logger.debug("Terminating session %s", sess_id) - - # Remove from active sessions - self._active_sessions = [ - s for s in self._active_sessions if s.session_id != sess_id - ] - - # Remove from session-replica mapping - if sess_id in self._session_replica_map: - del self._session_replica_map[sess_id] - - # Update metrics - self._update_service_metrics() - - async def _health_loop(self, poll_rate_s: float): - """Runs the health loop to monitor and recover replicas. - - This loop continuously checks the health of replicas and recovers - failed replicas by reinitializing their proc_meshes. It also - periodically updates service metrics to reflect the current state. - - """ - while not self._shutdown_requested: - # Process any replicas that need initialization - await self._maybe_init_replicas() - - # Check for failed replicas and recover them - failed_replicas = [] - for replica in self._replicas: - if replica.failed: - failed_replicas.append(replica) - - if any(failed_replicas): - logger.debug( - "[HEALTH LOOP] Detected %d failed replicas: %s", - len(failed_replicas), - pprint.pformat(failed_replicas), - ) - self._replicas_to_init.extend(failed_replicas) - - await asyncio.sleep(poll_rate_s) - - def _get_next_replica(self) -> "Replica": - """Get the next replica using round-robin selection.""" - healthy_replicas = [r for r in self._replicas if r.healthy] - if not healthy_replicas: - raise RuntimeError("No healthy replicas available for load balancing") - - # Simple round-robin - self._next_replica_idx = (self._next_replica_idx + 1) % len(healthy_replicas) - return healthy_replicas[self._next_replica_idx] - - def _get_least_loaded_replica(self) -> "Replica": - """Get the replica with the lowest load.""" - healthy_replicas = [r for r in self._replicas if r.healthy] - if not healthy_replicas: - raise RuntimeError("No healthy replicas available for session assignment") - - # Load = active_requests + queue_depth - def get_load(replica: "Replica") -> int: - return replica.active_requests + replica.request_queue.qsize() - - return min(healthy_replicas, key=get_load) - - async def _get_replica(self, sess_id: str | None, **kwargs) -> "Replica": - """Get a replica for the given session ID, with optional custom routing hints.""" - if sess_id is None: - # No session, use round-robin load balancing - replica = self._get_next_replica() - return replica - - # Session-based routing - if sess_id in self._session_replica_map: - replica_idx = self._session_replica_map[sess_id] - # Find the replica with this index - for replica in self._replicas: - if replica.idx == replica_idx and replica.healthy: - return replica - # If the replica is no longer healthy, remove from session map and reassign - del self._session_replica_map[sess_id] - - # New session, assign to least loaded replica - replica = self._get_least_loaded_replica() - self._session_replica_map[sess_id] = replica.idx - logger.debug("Assigning session %s to replica %d", sess_id, replica.idx) - return replica - - async def stop(self): - logger.debug("Stopping service...") - # Signal shutdown to health loop - self._shutdown_requested = True - - # Wait for health loop to finish gracefully - if self._health_task is not None: - try: - await asyncio.wait_for(self._health_task, timeout=5.0) - logger.info("Health loop stopped gracefully.") - except asyncio.TimeoutError: - logger.warning("Health loop didn't stop gracefully, cancelling...") - self._health_task.cancel() - try: - await self._health_task - except asyncio.CancelledError: - logger.info("Health loop task cancelled.") - - # Stop all replicas using their stop method - await asyncio.gather( - *[replica.stop() for replica in self._replicas], - return_exceptions=True, - ) - - async def _maybe_init_replicas(self): - """Initializes replicas that are queued for initialization.""" - if not self._replicas_to_init: - return - - logger.debug("Init replicas: %s", pprint.pformat(self._replicas_to_init)) - - # Initialize each replica (proc_mesh and actor spawning) - initialization_tasks = [] - for replica in self._replicas_to_init: - task = asyncio.create_task(self._init_single_replica(replica)) - initialization_tasks.append(task) - - await asyncio.gather(*initialization_tasks, return_exceptions=True) - self._replicas_to_init.clear() - - async def _init_single_replica(self, replica: Replica): - """Initialize a single replica with proc_mesh and actor.""" - try: - # Initialize the proc_mesh - await replica.init_proc_mesh() - - # Spawn the actor using replica's method - await replica.spawn_actor( - self._actor_def, *self._actor_args, **self._actor_kwargs - ) - - logger.debug("Successfully initialized replica %d", replica.idx) - - except Exception as e: - logger.error("Failed to initialize replica %d: %s", replica.idx, e) - # Mark as failed so it can be retried later - replica.mark_failed() - - async def _migrate_replica_workload(self, replica_to_remove: Replica): - """Migrates all workload from a replica that's being removed.""" - # Migrate queued requests - await self._migrate_remaining_requests(replica_to_remove) - - # Reassign sessions to other replicas - sessions_to_reassign = [ - sess_id - for sess_id, replica_idx in self._session_replica_map.items() - if replica_idx == replica_to_remove.idx - ] - - for sess_id in sessions_to_reassign: - del self._session_replica_map[sess_id] - logger.debug("Session %s will be reassigned on next request", sess_id) - - def __repr__(self): - return f"Service(actor={self._actor_def.__name__})" From efe1806813c38480f91665abcc89843745274ed5 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Fri, 22 Aug 2025 08:27:46 -0700 Subject: [PATCH 5/9] remove v2 from spawn --- src/forge/controller/spawn.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/src/forge/controller/spawn.py b/src/forge/controller/spawn.py index 40b47eb5d..fe0512277 100644 --- a/src/forge/controller/spawn.py +++ b/src/forge/controller/spawn.py @@ -11,10 +11,6 @@ from monarch.actor import Actor from forge.controller import Service, ServiceConfig -from forge.controller.service_v2 import ( - Service as ServiceV2, - ServiceConfig as ServiceConfigV2, -) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -39,24 +35,3 @@ async def spawn_service( service = Service(service_cfg, actor_def, *actor_args, **actor_kwargs) await service.__initialize__() return service - - -async def spawn_service_v2( - service_cfg: ServiceConfigV2, actor_def: Type[Actor], *actor_args, **actor_kwargs -) -> ServiceV2: - """Spawns a service based on the actor class. - - Args: - service_cfg: Service configuration - actor_def: Actor class definition - *actor_args: Arguments to pass to actor constructor - **actor_kwargs: Keyword arguments to pass to actor constructor - - Returns: - The appropriate service type based on the actor class - """ - # Default to base Service - logger.info("Spawning base Service for %s", actor_def.__name__) - service = ServiceV2(service_cfg, actor_def, *actor_args, **actor_kwargs) - await service.__initialize__() - return service From 7d6b247e71c78764b5b31fba18f51c44c4625b11 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Fri, 22 Aug 2025 08:44:10 -0700 Subject: [PATCH 6/9] more minor cleanups --- src/forge/controller/service.py | 57 ++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/src/forge/controller/service.py b/src/forge/controller/service.py index bd9c6c791..6fec6df2a 100644 --- a/src/forge/controller/service.py +++ b/src/forge/controller/service.py @@ -32,7 +32,6 @@ ... result = await service.my_endpoint(arg1, arg2) """ - import asyncio import contextvars import logging @@ -96,43 +95,54 @@ def get_avg_capacity_utilization(self, replicas: List) -> float: healthy_replicas = [r for r in replicas if r.healthy] if not healthy_replicas: return 0.0 - total_utilization = sum(r.capacity_utilization for r in healthy_replicas) return total_utilization / len(healthy_replicas) def get_sessions_per_replica(self) -> float: - """Get average sessions per healthy replica.""" - if self.healthy_replicas == 0: + """Get average sessions per replica.""" + if self.total_replicas == 0: return 0.0 - return self.total_sessions / self.healthy_replicas + return self.total_sessions / self.total_replicas + + +# Context variable for session state +_session_context = contextvars.ContextVar("session_context") @dataclass class Session: + """Simple session data holder.""" + session_id: str -# Global context variable for session state -# This is used to propagate session state across async tasks -_session_context: contextvars.ContextVar[dict | None] = contextvars.ContextVar( - "session_context", default=None -) +class SessionContext: + """ + Async context manager for stateful service sessions with automatic lifecycle management. + + Provides a convenient way to maintain stateful connections to replicas across multiple + requests. Sessions ensure that all requests within the context are routed to the same + replica, enabling stateful interactions while handling session lifecycle automatically. + + Example: + >>> async with service.session() as session: + ... # All calls within this block use the same replica + ... result1 = await service.my_endpoint(arg1) + ... result2 = await service.another_endpoint(result1) -class SessionContext: - """Context manager for service sessions using context variables.""" + """ - def __init__(self, service: "Service", **session_kwargs): + def __init__(self, service: "Service"): self.service = service self.session_id: str | None = None - self.session_kwargs = session_kwargs self._token = None async def __aenter__(self): """Start a session and set context variables.""" self.session_id = await self.service.start_session() # Set context for this async task - context_value = {"session_id": self.session_id, "kwargs": self.session_kwargs} + context_value = {"session_id": self.session_id} self._token = _session_context.set(context_value) return self @@ -228,8 +238,8 @@ async def __initialize__(self): num_replicas = self._cfg.num_replicas for i in range(num_replicas): replica = Replica( - proc_config=self._cfg.to_process_config(), idx=len(self._replicas) + i, + proc_config=self._cfg.to_process_config(), max_concurrent_requests=self._cfg.replica_max_concurrent_requests, return_first_rank_result=self._cfg.return_first_rank_result, ) @@ -294,13 +304,8 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): ctx = _session_context.get() if ctx: sess_id = ctx["session_id"] - routing_hints = ctx["kwargs"] - else: - routing_hints = {} - else: - routing_hints = {} - replica = await self._get_replica(sess_id, **routing_hints) + replica = await self._get_replica(sess_id) # Create a ServiceRequest object to queue request = ServiceRequest( @@ -412,9 +417,9 @@ async def start_session(self) -> str: return sess_id - def session(self, **kwargs) -> SessionContext: + def session(self) -> SessionContext: """Returns a context manager for session-based calls.""" - return SessionContext(self, **kwargs) + return SessionContext(self) def _update_service_metrics(self): """Updates service-level metrics.""" @@ -582,8 +587,8 @@ def get_load(replica: "Replica") -> int: return min(healthy_replicas, key=get_load) - async def _get_replica(self, sess_id: str | None, **kwargs) -> "Replica": - """Get a replica for the given session ID, with optional custom routing hints.""" + async def _get_replica(self, sess_id: str | None) -> "Replica": + """Get a replica for the given session ID.""" if sess_id is None: # No session, use round-robin load balancing replica = self._get_next_replica() From 8392ae6ab69c412290620a2d3d79bdb334afcaeb Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Fri, 22 Aug 2025 08:47:19 -0700 Subject: [PATCH 7/9] remove comment --- src/forge/controller/replica.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/controller/replica.py b/src/forge/controller/replica.py index a8cb89644..fec84d0e3 100644 --- a/src/forge/controller/replica.py +++ b/src/forge/controller/replica.py @@ -92,7 +92,7 @@ class Replica: """ A distributed replica that serves as the fundamental unit of work within a service. - Handles process lifecycle, async request queuing, fault recovery, and load balancing. + Handles process lifecycle, async request queuing and fault recovery. Each replica runs independently and can be deployed across multiple hosts via Monarch """ From 41d71da9775b8023fb069a5f0e2a8403bb5a966a Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Fri, 22 Aug 2025 08:48:21 -0700 Subject: [PATCH 8/9] remove comment --- tests/test_service.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_service.py b/tests/test_service.py index 1bfdaad38..dc6782baf 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -90,7 +90,6 @@ async def test_sessionless_calls(): try: # Test sessionless calls - logger.info("Starting requests") await service.incr() await service.incr() result = await service.value() From e6519eefcb550835ff7f11cb2139252c5d0f522a Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Fri, 22 Aug 2025 09:41:40 -0700 Subject: [PATCH 9/9] initial commit of ServiceEndpoint --- src/forge/controller/service.py | 96 ++++++++++++++++-- tests/test_service.py | 172 ++++++++++++++++++++++++++------ 2 files changed, 227 insertions(+), 41 deletions(-) diff --git a/src/forge/controller/service.py b/src/forge/controller/service.py index 6fec6df2a..f14c5b93c 100644 --- a/src/forge/controller/service.py +++ b/src/forge/controller/service.py @@ -38,7 +38,7 @@ import pprint import uuid from dataclasses import dataclass, field -from typing import Dict, List +from typing import Dict, Generic, List, ParamSpec, TypeVar from monarch._src.actor.endpoint import EndpointProperty @@ -48,6 +48,9 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) +P = ParamSpec("P") +R = TypeVar("R") + # TODO - tie this into metrics logger when it exists. @dataclass @@ -155,6 +158,34 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): self.session_id = None +class ServiceEndpoint(Generic[P, R]): + """An endpoint object specific to services. + + This loosely mimics the Endpoint APIs exposed in Monarch, with + a few key differences: + - Only choose and call are retained (dropping stream and call_one) + - Call returns a list directly rather than a ValueMesh. + + These changes are made with Forge use cases in mind, but can + certainly be expanded/adapted in the future. + + """ + + def __init__(self, service: "Service", endpoint_name: str): + self.service = service + self.endpoint_name = endpoint_name + + async def choose( + self, sess_id: str | None = None, *args: P.args, **kwargs: P.kwargs + ) -> R: + """Chooses a replica to call based on context and load balancing strategy.""" + return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs) + + async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + """Broadcasts a request to all healthy replicas and returns the results as a list.""" + return await self.service._call_all(self.endpoint_name, *args, **kwargs) + + class Service: """ Distributed Actor Service Controller @@ -267,13 +298,9 @@ async def __initialize__(self): ) def _add_endpoint_method(self, endpoint_name: str): - """Dynamically adds an endpoint method to this Service instance.""" - - async def endpoint_method(sess_id: str | None = None, *args, **kwargs): - return await self._call(sess_id, endpoint_name, *args, **kwargs) - - # Set the method on this instance - setattr(self, endpoint_name, endpoint_method) + """Dynamically adds a ServiceEndpoint instance to this Service instance.""" + endpoint = ServiceEndpoint(self, endpoint_name) + setattr(self, endpoint_name, endpoint) async def _call(self, sess_id: str | None, function: str, *args, **kwargs): """ @@ -301,7 +328,7 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): """ # Check context variables for session state if no explicit sess_id if sess_id is None: - ctx = _session_context.get() + ctx = _session_context.get(None) if ctx: sess_id = ctx["session_id"] @@ -334,6 +361,57 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): ) raise + async def _call_all(self, function: str, *args, **kwargs) -> List: + """ + Broadcasts a function call to all healthy replicas and returns results as a list. + + Args: + function: Name of the actor endpoint to call + *args: Positional arguments to pass to the endpoint + **kwargs: Keyword arguments to pass to the endpoint + + Returns: + List of results from all healthy replicas + + Raises: + RuntimeError: If no healthy replicas are available + """ + healthy_replicas = [r for r in self._replicas if r.healthy] + + if not healthy_replicas: + raise RuntimeError("No healthy replicas available for broadcast call") + + # Create requests for all healthy replicas + requests = [] + for replica in healthy_replicas: + request = ServiceRequest( + session_id=None, # Broadcast calls don't use sessions + function=function, + args=args, + kwargs=kwargs, + future=asyncio.Future(), + ) + requests.append((replica, request)) + + # Enqueue all requests + for replica, request in requests: + await replica.enqueue_request(request) + + # Wait for all results + results = [] + for replica, request in requests: + try: + result = await request.future + results.append(result) + except Exception as e: + logger.warning( + "Request to replica %d failed during broadcast: %s", replica.idx, e + ) + # Add None for failed replicas to maintain indexing + results.append(None) + + return results + async def _retry_request_on_healthy_replica( self, sess_id: str | None, function: str, *args, **kwargs ): diff --git a/tests/test_service.py b/tests/test_service.py index dc6782baf..791b6da23 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -66,8 +66,8 @@ async def test_basic_service_operations(): assert isinstance(session1, str) # Test endpoint calls - await service.incr(session1) - result = await service.value(session1) + await service.incr.choose(sess_id=session1) + result = await service.value.choose(sess_id=session1) assert result == 1 # Test session mapping @@ -90,9 +90,9 @@ async def test_sessionless_calls(): try: # Test sessionless calls - await service.incr() - await service.incr() - result = await service.value() + await service.incr.choose() + await service.incr.choose() + result = await service.value.choose() assert result is not None # No sessions should be created @@ -121,18 +121,18 @@ async def test_session_context_manager(): try: # Test context manager usage async with service.session(): - await service.incr() - await service.incr() - result = await service.value() + await service.incr.choose() + await service.incr.choose() + result = await service.value.choose() assert result == 2 # Test sequential context managers to avoid interference async def worker(increments: int): async with service.session(): - initial = await service.value() + initial = await service.value.choose() for _ in range(increments): - await service.incr() - final = await service.value() + await service.incr.choose() + final = await service.value.choose() return final - initial # Run sessions sequentially to avoid concurrent modification @@ -162,12 +162,12 @@ async def test_replica_failure_and_recovery(): try: # Create session and cause failure session = await service.start_session() - await service.incr(session) + await service.incr.choose(session) original_replica_idx = service._session_replica_map[session] # Cause failure - error_result = await service.fail_me(session) + error_result = await service.fail_me.choose(session) assert isinstance(error_result, RuntimeError) # Replica should be marked as failed @@ -175,13 +175,13 @@ async def test_replica_failure_and_recovery(): assert not failed_replica.healthy # Session should be reassigned on next call - await service.incr(session) + await service.incr.choose(session) new_replica_idx = service._session_replica_map[session] assert new_replica_idx != original_replica_idx # New sessions should avoid failed replica new_session = await service.start_session() - await service.incr(new_session) + await service.incr.choose(new_session) assigned_replica = service._replicas[service._session_replica_map[new_session]] assert assigned_replica.healthy @@ -204,12 +204,12 @@ async def test_metrics_collection(): session1 = await service.start_session() session2 = await service.start_session() - await service.incr(session1) - await service.incr(session1) - await service.incr(session2) + await service.incr.choose(session1) + await service.incr.choose(session1) + await service.incr.choose(session2) # Test failure metrics - error_result = await service.fail_me(session1) + error_result = await service.fail_me.choose(session1) assert isinstance(error_result, RuntimeError) # Get metrics @@ -256,18 +256,18 @@ async def test_session_stickiness(): session = await service.start_session() # Make multiple calls - await service.incr(session) - await service.incr(session) - await service.incr(session) + await service.incr.choose(session) + await service.incr.choose(session) + await service.incr.choose(session) # Should always route to same replica replica_idx = service._session_replica_map[session] - await service.incr(session) + await service.incr.choose(session) assert service._session_replica_map[session] == replica_idx # Verify counter was incremented correctly - result = await service.value(session) + result = await service.value.choose(session) assert result == 4 finally: @@ -284,16 +284,16 @@ async def test_load_balancing_multiple_sessions(): try: # Create sessions with some load to trigger distribution session1 = await service.start_session() - await service.incr(session1) # Load replica 0 + await service.incr.choose(session1) # Load replica 0 session2 = await service.start_session() - await service.incr(session2) # Should go to replica 1 (least loaded) + await service.incr.choose(session2) # Should go to replica 1 (least loaded) session3 = await service.start_session() - await service.incr(session3) # Should go to replica 0 or 1 based on load + await service.incr.choose(session3) # Should go to replica 0 or 1 based on load session4 = await service.start_session() - await service.incr(session4) # Should balance the load + await service.incr.choose(session4) # Should balance the load # Check that sessions are distributed (may not be perfectly even due to least-loaded logic) replica_assignments = [ @@ -333,10 +333,10 @@ async def test_concurrent_operations(): # Concurrent operations tasks = [ - service.incr(session), # Session call - service.incr(session), # Session call - service.incr(), # Sessionless call - service.incr(), # Sessionless call + service.incr.choose(session), # Session call + service.incr.choose(session), # Session call + service.incr.choose(), # Sessionless call + service.incr.choose(), # Sessionless call ] await asyncio.gather(*tasks) @@ -355,3 +355,111 @@ async def test_concurrent_operations(): finally: await service.stop() + + +# `call` endpoint tests + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_broadcast_call_basic(): + """Test basic broadcast call functionality.""" + cfg = ServiceConfig(procs_per_replica=1, num_replicas=3) + service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=10) + + try: + # Test broadcast call to all replicas + results = await service.incr.call() + + # Should get results from all healthy replicas + assert isinstance(results, list) + assert len(results) == 3 # All 3 replicas should respond + + # All results should be None (incr doesn't return anything) + assert all(result is None for result in results) + + # Test getting values from all replicas + values = await service.value.call() + assert isinstance(values, list) + assert len(values) == 3 + + # All replicas should have incremented from 10 to 11 + assert all(value == 11 for value in values) + + finally: + await service.stop() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_broadcast_call_with_failed_replica(): + """Test broadcast call behavior when some replicas fail.""" + cfg = ServiceConfig(procs_per_replica=1, num_replicas=3) + service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) + + try: + # First, cause one replica to fail by calling fail_me on a specific session + session = await service.start_session() + try: + await service.fail_me.choose(session) + except RuntimeError: + pass # Expected failure + + # Wait briefly for replica to be marked as failed + await asyncio.sleep(0.1) + + # Now test broadcast call - should only hit healthy replicas + results = await service.incr.call() + + # Should get results from healthy replicas only + assert isinstance(results, list) + # Results length should match number of healthy replicas (2 out of 3) + healthy_count = len([r for r in service._replicas if r.healthy]) + assert len(results) == healthy_count + + # Get values from all healthy replicas + values = await service.value.call() + assert len(values) == healthy_count + + # All healthy replicas should have incremented to 1 + assert all(value == 1 for value in values) + + finally: + await service.stop() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_broadcast_call_vs_choose(): + """Test that broadcast call hits all replicas while choose hits only one.""" + cfg = ServiceConfig(procs_per_replica=1, num_replicas=3) + service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) + + try: + # Use broadcast call to increment all replicas + await service.incr.call() + + # Get values from all replicas + values_after_broadcast = await service.value.call() + assert len(values_after_broadcast) == 3 + assert all(value == 1 for value in values_after_broadcast) + + # Use choose to increment only one replica + await service.incr.choose() + + # Get values again - one replica should be at 2, others at 1 + values_after_choose = await service.value.call() + assert len(values_after_choose) == 3 + assert sorted(values_after_choose) == [1, 1, 2] # One replica incremented twice + + # Verify metrics show the correct number of requests + metrics = service.get_metrics_summary() + total_requests = sum( + replica_metrics["total_requests"] + for replica_metrics in metrics["replicas"].values() + ) + # incr.call() (3 requests) + value.call() (3 requests) + incr.choose() (1 request) + value.call() (3 requests) = 10 total + assert total_requests == 10 + + finally: + await service.stop()