14 changes: 7 additions & 7 deletions src/forge/controller/service/interface.py
@@ -14,7 +14,7 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
- from typing import Dict, Generic, List, ParamSpec, TypeVar
+ from typing import Generic, ParamSpec, TypeVar

from monarch._src.actor.endpoint import EndpointProperty

@@ -96,7 +96,7 @@ async def route(self, *args: P.args, **kwargs: P.kwargs) -> R:
sess_id = kwargs.pop("sess_id", None)
return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs)

- async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
+ async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> list[R]:
"""Broadcasts a request to all healthy replicas and returns the results as a list."""
result = await self.service.call_all(self.endpoint_name, *args, **kwargs)
return result
@@ -107,7 +107,7 @@ async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R:
"Services only support route() and fanout()."
)

- async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
+ async def call(self, *args: P.args, **kwargs: P.kwargs) -> list[R]:
raise NotImplementedError(
"You tried to use call() on a service, not an actor. "
"Services only support route() and fanout()."
@@ -119,7 +119,7 @@ async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R:
"Services only support route() and fanout()."
)

- async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
+ async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> list[R]:
raise NotImplementedError(
"You tried to use broadcast() on a service, not an actor. "
"Services only support route() and fanout()."
@@ -157,7 +157,7 @@ async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R:
sess_id, self.endpoint_name, *args, **kwargs
)

- async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
+ async def call(self, *args: P.args, **kwargs: P.kwargs) -> list[R]:
"""Broadcasts a request to all healthy replicas and returns the results as a list."""
result = await self.actor_mesh.call_all.call_one(
self.endpoint_name, *args, **kwargs
@@ -314,9 +314,9 @@ class Router(ABC):
@abstractmethod
def get_replica(
self,
- healthy_replicas: List[Replica],
+ healthy_replicas: list[Replica],
sess_id: str | None = None,
- session_map: Dict[str, int] | None = None,
+ session_map: dict[str, int] | None = None,
) -> Replica:
"""Select a replica from the list based on routing logic."""
pass
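For reference, here is what a custom router could look like against the updated Router.get_replica signature shown above, now written with the builtin generics. This is a minimal sketch, not part of the PR: RandomRouter is a hypothetical example, and it assumes only the Router and Replica classes that router.py imports in this diff.

import random

from forge.controller.service.interface import Router
from forge.controller.service.replica import Replica


class RandomRouter(Router):
    # Hypothetical router: picks a healthy replica uniformly at random.
    def get_replica(
        self,
        healthy_replicas: list[Replica],
        sess_id: str | None = None,
        session_map: dict[str, int] | None = None,
    ) -> Replica:
        if not healthy_replicas:
            raise RuntimeError("No healthy replicas available")
        # Session affinity is intentionally ignored here; SessionRouter handles that case.
        return random.choice(healthy_replicas)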
7 changes: 3 additions & 4 deletions src/forge/controller/service/metrics.py
@@ -12,7 +12,6 @@
"""

from dataclasses import dataclass, field
- from typing import Dict, List

from forge.controller.service.replica import ReplicaMetrics

@@ -35,7 +34,7 @@ class ServiceMetrics:
"""

# Replica metrics
- replica_metrics: Dict[int, ReplicaMetrics] = field(default_factory=dict)
+ replica_metrics: dict[int, ReplicaMetrics] = field(default_factory=dict)
# Service-level metrics
total_sessions: int = 0
healthy_replicas: int = 0
@@ -50,15 +49,15 @@ def get_total_request_rate(self, window_seconds: float = 60.0) -> float:
for metrics in self.replica_metrics.values()
)

- def get_avg_queue_depth(self, replicas: List) -> float:
+ def get_avg_queue_depth(self, replicas: list) -> float:
"""Get average queue depth across all healthy replicas."""
healthy_replicas = [r for r in replicas if r.healthy]
if not healthy_replicas:
return 0.0
total_queue_depth = sum(r.request_queue.qsize() for r in healthy_replicas)
return total_queue_depth / len(healthy_replicas)

- def get_avg_capacity_utilization(self, replicas: List) -> float:
+ def get_avg_capacity_utilization(self, replicas: list) -> float:
"""Get average capacity utilization across all healthy replicas."""
healthy_replicas = [r for r in replicas if r.healthy]
if not healthy_replicas:
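As a quick usage sketch of the metrics helpers touched above: the _FakeReplica stand-in and queue sizes below are invented for illustration, and the example assumes the remaining ServiceMetrics fields all have defaults; it relies only on the .healthy flag and .request_queue.qsize() accesses visible in the diff.

import asyncio
from dataclasses import dataclass, field

from forge.controller.service.metrics import ServiceMetrics


@dataclass
class _FakeReplica:
    healthy: bool
    request_queue: asyncio.Queue = field(default_factory=asyncio.Queue)


busy = _FakeReplica(healthy=True)
idle = _FakeReplica(healthy=True)
down = _FakeReplica(healthy=False)
for _ in range(4):
    busy.request_queue.put_nowait(object())

metrics = ServiceMetrics()
# Average over healthy replicas only: (4 + 0) / 2 == 2.0
print(metrics.get_avg_queue_depth([busy, idle, down]))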
13 changes: 6 additions & 7 deletions src/forge/controller/service/router.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import logging
- from typing import Dict, List

from .interface import Router
from .replica import Replica
@@ -22,9 +21,9 @@ def __init__(self):

def get_replica(
self,
- healthy_replicas: List[Replica],
+ healthy_replicas: list[Replica],
sess_id: str | None = None,
- session_map: Dict[str, int] | None = None,
+ session_map: dict[str, int] | None = None,
) -> Replica:
if not healthy_replicas:
raise RuntimeError("No healthy replicas available for load balancing")
@@ -40,9 +39,9 @@ class LeastLoadedRouter(Router):

def get_replica(
self,
- healthy_replicas: List[Replica],
+ healthy_replicas: list[Replica],
sess_id: str | None = None,
- session_map: Dict[str, int] | None = None,
+ session_map: dict[str, int] | None = None,
) -> Replica:
if not healthy_replicas:
raise RuntimeError("No healthy replicas available for session assignment")
@@ -57,9 +56,9 @@ def __init__(self, fallback_router: Router):

def get_replica(
self,
- healthy_replicas: List[Replica],
+ healthy_replicas: list[Replica],
sess_id: str | None = None,
- session_map: Dict[str, int] | None = None,
+ session_map: dict[str, int] | None = None,
) -> Replica:
if sess_id is None:
raise ValueError("SessionRouter requires a session ID")
9 changes: 4 additions & 5 deletions src/forge/controller/service/service.py
@@ -36,7 +36,6 @@
import logging
import pprint
import uuid
- from typing import Dict, List

from monarch.actor import Actor, endpoint

@@ -92,7 +91,7 @@ def __init__(

self._active_sessions = []
self._id_session_map = {}
- self._session_replica_map: Dict[str, int] = {}
+ self._session_replica_map: dict[str, int] = {}

# Initialize metrics collection
self._metrics = ServiceMetrics()
@@ -196,7 +195,7 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs):
)
raise

- async def call_all(self, function: str, *args, **kwargs) -> List:
+ async def call_all(self, function: str, *args, **kwargs) -> list:
"""
Broadcasts a function call to all healthy replicas and returns results as a list.

@@ -622,7 +621,7 @@ def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict):

self._active_sessions = []
self._id_session_map = {}
- self._session_replica_map: Dict[str, int] = {}
+ self._session_replica_map: dict[str, int] = {}
self._next_replica_idx = 0 # For round-robin load balancing

# Initialize metrics collection
@@ -726,7 +725,7 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs):
raise

@endpoint
- async def call_all(self, function: str, *args, **kwargs) -> List:
+ async def call_all(self, function: str, *args, **kwargs) -> list:
"""
Broadcasts a function call to all healthy replicas and returns results as a list.

22 changes: 11 additions & 11 deletions src/forge/observability/metric_actors.py
@@ -7,7 +7,7 @@
import asyncio
import logging
import os
- from typing import Any, Dict, Union
+ from typing import Any, Union

from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc

@@ -127,7 +127,7 @@ def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> N
@endpoint
async def flush(
self, global_step: int, return_state: bool = False
- ) -> Dict[str, Dict[str, Any]]:
+ ) -> dict[str, dict[str, Any]]:
"""Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True.
This should only ever be called by the global logger.

@@ -136,7 +136,7 @@ async def flush(
return_state (bool): Used by GlobalLoggingActor for reduction across all ranks.
If False, returns empty dict, else returns the state of all metrics collected.
Returns:
- Dict[str, Dict[str, Any]]: Dict of {metric_key: metric_state},
+ dict[str, dict[str, Any]]: Dict of {metric_key: metric_state},
e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}.
"""
collector = MetricCollector()
@@ -146,8 +146,8 @@
@endpoint
async def init_backends(
self,
- metadata_per_primary_backend: Dict[str, Dict[str, Any]],
- config: Dict[str, Any],
+ metadata_per_primary_backend: dict[str, dict[str, Any]],
+ config: dict[str, Any],
) -> None:
"""Init local (per-rank) logger backends and MetricCollector."""
collector = MetricCollector()
@@ -179,13 +179,13 @@ class GlobalLoggingActor(Actor):
"""

def __init__(self):
- self.fetchers: Dict[str, LocalFetcherActor] = {}
- self.config: Dict[str, Any] | None = None
- self.global_logger_backends: Dict[str, LoggerBackend] = {}
- self.metadata_per_primary_backend: Dict[str, Dict[str, Any]] = {}
+ self.fetchers: dict[str, LocalFetcherActor] = {}
+ self.config: dict[str, Any] | None = None
+ self.global_logger_backends: dict[str, LoggerBackend] = {}
+ self.metadata_per_primary_backend: dict[str, dict[str, Any]] = {}

@endpoint
- async def init_backends(self, config: Dict[str, Any]) -> None:
+ async def init_backends(self, config: dict[str, Any]) -> None:
"""
Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors
in all registered fetchers.
@@ -201,7 +201,7 @@ async def init_backends(self, config: Dict[str, Any]) -> None:
and reduce them to a single value, which will be logged by the primary backend in this controller.

Args:
- config (Dict[str, Any]): Config for metric logging where keys are backend names,
+ config (dict[str, Any]): Config for metric logging where keys are backend names,
e.g. {"console": {"reduce_across_ranks": True}, "wandb": {"reduce_across_ranks": False}}
"""
self.config = config
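Taken together, these changes swap the typing module's Dict and List aliases for the builtin dict and list generics (PEP 585, usable in annotations on Python 3.9+), which also lets the now-unused typing imports be dropped. A minimal before/after sketch of the pattern, with illustrative names not taken from this repo:

# Before: the typing aliases must be imported.
from typing import Dict, List

def summarize_before(counts: Dict[str, int]) -> List[str]:
    return [f"{name}={value}" for name, value in counts.items()]

# After: builtin generics, no typing import needed.
def summarize_after(counts: dict[str, int]) -> list[str]:
    return [f"{name}={value}" for name, value in counts.items()]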