From 5c68d21541829472fe11d054e4ef90d3e6642c79 Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 17 Sep 2025 12:37:34 -0700 Subject: [PATCH 1/8] move rr router to class --- src/forge/controller/service/interface.py | 14 +++++++++++ src/forge/controller/service/router.py | 29 +++++++++++++++++++++++ src/forge/controller/service/service.py | 24 ++++++------------- tests/unit_tests/test_service.py | 26 ++++++++++++++++++++ 4 files changed, 76 insertions(+), 17 deletions(-) create mode 100644 src/forge/controller/service/router.py diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 4c8718a03..cd13e771e 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -12,11 +12,14 @@ import contextvars import logging +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Generic, List, ParamSpec, TypeVar from monarch._src.actor.endpoint import EndpointProperty +from .replica import Replica + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -277,3 +280,14 @@ def __getattr__(self, name: str): raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{name}'" ) + + +class Router(ABC): + """Abstract base class for routing logic.""" + + @abstractmethod + def get_replica( + self, replicas: List[Replica], sess_id: str | None = None + ) -> Replica: + """Select a replica from the list based on routing logic.""" + pass diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py new file mode 100644 index 000000000..8767a7209 --- /dev/null +++ b/src/forge/controller/service/router.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import List + +from .interface import Router +from .replica import Replica + + +class RoundRobinRouter(Router): + """Round-robin router for stateless requests.""" + + def __init__(self): + self._next_idx = 0 + + def get_replica( + self, replicas: List[Replica], sess_id: str | None = None + ) -> Replica: + healthy_replicas = [r for r in replicas if r.healthy] + if not healthy_replicas: + raise RuntimeError("No healthy replicas available for load balancing") + + self._next_idx = (self._next_idx + 1) % len(healthy_replicas) + replica = healthy_replicas[self._next_idx] + + return replica diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index ede58c821..f98cb254f 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -38,14 +38,16 @@ import uuid from typing import Dict, List -from monarch.actor import Actor, endpoint - from forge.controller.service.interface import _session_context, Session from forge.controller.service.metrics import ServiceMetrics from forge.controller.service.replica import Replica, ServiceRequest + +from forge.controller.service.router import RoundRobinRouter from forge.types import ServiceConfig +from monarch.actor import Actor, endpoint + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -80,7 +82,7 @@ def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict): self._active_sessions = [] self._id_session_map = {} self._session_replica_map: Dict[str, int] = {} - self._next_replica_idx = 0 # For round-robin load balancing + self._router = RoundRobinRouter() # Initialize metrics collection self._metrics = ServiceMetrics() @@ -455,16 +457,6 @@ async def _health_loop(self, poll_rate_s: float): await asyncio.sleep(poll_rate_s) - def _get_next_replica(self) -> "Replica": - """Get the next replica using round-robin selection.""" - healthy_replicas = [r for r in self._replicas if r.healthy] - if not healthy_replicas: - 
raise RuntimeError("No healthy replicas available for load balancing") - - # Simple round-robin - self._next_replica_idx = (self._next_replica_idx + 1) % len(healthy_replicas) - return healthy_replicas[self._next_replica_idx] - def _get_least_loaded_replica(self) -> "Replica": """Get the replica with the lowest load.""" healthy_replicas = [r for r in self._replicas if r.healthy] @@ -477,9 +469,8 @@ def _get_least_loaded_replica(self) -> "Replica": async def _get_replica(self, sess_id: str | None) -> "Replica": """Get a replica for the given session ID.""" if sess_id is None: - # No session, use round-robin load balancing - replica = self._get_next_replica() - return replica + # No session, use the default router + return self._router.get_replica(self._replicas) # Session-based routing if sess_id in self._session_replica_map: @@ -592,7 +583,6 @@ async def _get_internal_state(self) -> dict: for replica in self._replicas ], # Load balancing state - "next_replica_idx": self._next_replica_idx, # Service-level state "total_replicas": len(self._replicas), "healthy_replica_count": sum(1 for r in self._replicas if r.healthy), diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index fb8504ed2..f2a9847c6 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -634,3 +634,29 @@ async def test_broadcast_call_vs_choose(): finally: await service.shutdown() + + +# Rounter Tests +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_round_robin_router_distribution(): + """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas.""" + service = await Counter.options(procs_per_replica=1, num_replicas=3).as_service(v=0) + + try: + # Make multiple sessionless calls using choose() + results = [] + for _ in range(6): + await service.incr.choose() + values = await service.value.call() + print(values) + results.append(values) + print("results: ", results) + # Verify that requests were 
distributed round-robin + # Each call increments a single replica, so after 6 calls we expect: + # - 2 increments per replica (since 3 replicas, 6 calls) + final_values = results[-1] # last snapshot + assert sorted(final_values) == [2, 2, 2] + + finally: + await service.shutdown() From 4a5ba51d8fd485e7f4960773545dda6de5a5e785 Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 17 Sep 2025 14:14:39 -0700 Subject: [PATCH 2/8] fix lint --- src/forge/controller/service/service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index f98cb254f..05e05ea76 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -38,6 +38,8 @@ import uuid from typing import Dict, List +from monarch.actor import Actor, endpoint + from forge.controller.service.interface import _session_context, Session from forge.controller.service.metrics import ServiceMetrics @@ -46,8 +48,6 @@ from forge.controller.service.router import RoundRobinRouter from forge.types import ServiceConfig -from monarch.actor import Actor, endpoint - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) From 11e76b486f4b8be40bc34fcd9193668982e0a906 Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 17 Sep 2025 15:36:48 -0700 Subject: [PATCH 3/8] add sessionloader and leastloadedloader and tests --- src/forge/controller/service/__init__.py | 7 +- src/forge/controller/service/interface.py | 7 +- src/forge/controller/service/router.py | 65 +++++++- src/forge/controller/service/service.py | 64 ++++---- tests/unit_tests/test_service.py | 185 +++++++++++++++++++++- 5 files changed, 291 insertions(+), 37 deletions(-) diff --git a/src/forge/controller/service/__init__.py b/src/forge/controller/service/__init__.py index aa79a48df..f0d8fca7b 100644 --- a/src/forge/controller/service/__init__.py +++ b/src/forge/controller/service/__init__.py @@ -6,12 +6,14 @@ from .interface 
import ServiceInterface, Session, SessionContext from .metrics import ServiceMetrics -from .replica import Replica, ReplicaMetrics +from .replica import Replica, ReplicaMetrics, ReplicaState +from .router import LeastLoadedRouter, RoundRobinRouter, SessionRouter from .service import Service, ServiceActor, ServiceConfig __all__ = [ "Replica", "ReplicaMetrics", + "ReplicaState", "Service", "ServiceConfig", "ServiceInterface", @@ -19,4 +21,7 @@ "Session", "SessionContext", "ServiceActor", + "LeastLoadedRouter", + "RoundRobinRouter", + "SessionRouter", ] diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index cd13e771e..a41596c20 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -14,7 +14,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Generic, List, ParamSpec, TypeVar +from typing import Dict, Generic, List, ParamSpec, TypeVar from monarch._src.actor.endpoint import EndpointProperty @@ -287,7 +287,10 @@ class Router(ABC): @abstractmethod def get_replica( - self, replicas: List[Replica], sess_id: str | None = None + self, + replicas: List[Replica], + sess_id: str | None = None, + session_map: Dict[str, int] | None = None, ) -> Replica: """Select a replica from the list based on routing logic.""" pass diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 8767a7209..44c78f4bf 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -4,11 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import List +import logging +from typing import Dict, List from .interface import Router from .replica import Replica +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + class RoundRobinRouter(Router): """Round-robin router for stateless requests.""" @@ -17,7 +21,10 @@ def __init__(self): self._next_idx = 0 def get_replica( - self, replicas: List[Replica], sess_id: str | None = None + self, + replicas: List[Replica], + sess_id: str | None = None, + session_map: Dict[str, int] | None = None, ) -> Replica: healthy_replicas = [r for r in replicas if r.healthy] if not healthy_replicas: @@ -27,3 +34,57 @@ def get_replica( replica = healthy_replicas[self._next_idx] return replica + + +class LeastLoadedRouter(Router): + """Always routes to the replica with the lowest current load.""" + + def get_replica( + self, + replicas: List["Replica"], + sess_id: str | None = None, + session_map: Dict[str, int] | None = None, + ) -> "Replica": + healthy_replicas = [r for r in replicas if r.healthy] + if not healthy_replicas: + raise RuntimeError("No healthy replicas available for session assignment") + return min(healthy_replicas, key=lambda r: r.current_load) + + +class SessionRouter(Router): + """Session-based routing: sticky sessions with a fallback router.""" + + def __init__(self, fallback_router: Router): + self.fallback_router = fallback_router + + def get_replica( + self, + replicas: List["Replica"], + sess_id: str | None = None, + session_map: Dict[str, int] | None = None, + ) -> "Replica": + if sess_id is None: + raise ValueError("SessionRouter requires a session ID") + + if session_map is None: + raise ValueError("Session map must be provided for SessionRouter") + + # Check if session already has a replica + if sess_id in session_map: + replica_idx = session_map[sess_id] + # Find the replica with this index + for r in replicas: + if r.idx == replica_idx and r.healthy: + return r + # If the replica is no longer healthy, remove from session 
map and reassign + del session_map[sess_id] + + # Use fallback router to assign a new replica + replica = self.fallback_router.get_replica(replicas, sess_id, session_map) + session_map[sess_id] = replica.idx + logger.debug( + "Assigning session %s to replica %d", + sess_id, + replica.idx, + ) + return replica diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 05e05ea76..61e279d2a 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -36,16 +36,20 @@ import logging import pprint import uuid -from typing import Dict, List +from typing import Dict, List, Type from monarch.actor import Actor, endpoint -from forge.controller.service.interface import _session_context, Session +from forge.controller.service.interface import _session_context, Router, Session from forge.controller.service.metrics import ServiceMetrics from forge.controller.service.replica import Replica, ServiceRequest -from forge.controller.service.router import RoundRobinRouter +from forge.controller.service.router import ( + LeastLoadedRouter, + RoundRobinRouter, + SessionRouter, +) from forge.types import ServiceConfig logger = logging.getLogger(__name__) @@ -64,6 +68,13 @@ class Service: actor_def: Actor class definition to instantiate on each replica *actor_args: Positional arguments passed to actor constructor **actor_kwargs: Keyword arguments passed to actor constructor + router_cls (Type[Router], optional): Router class used for non-session + calls. Defaults to RoundRobinRouter. Examples include RoundRobinRouter + or LeastLoadedRouter. The router is instantiated internally. + fallback_router_cls: Router class used as a fallback when a session + cannot be mapped to an existing replica. Defaults + to LeastLoadedRouter. 
+ Attributes: _cfg: Service configuration @@ -73,16 +84,24 @@ class Service: _endpoints: Dynamically registered actor endpoints """ - def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict): + def __init__( + self, + cfg: ServiceConfig, + actor_def, + actor_kwargs: dict, + router_cls: Type["Router"] = RoundRobinRouter, + fallback_router_cls: Type["Router"] = LeastLoadedRouter, + ): self._cfg = cfg self._replicas = [] self._actor_def = actor_def self._actor_kwargs = actor_kwargs + self.router_cls = router_cls + self.fallback_router_cls = fallback_router_cls self._active_sessions = [] self._id_session_map = {} self._session_replica_map: Dict[str, int] = {} - self._router = RoundRobinRouter() # Initialize metrics collection self._metrics = ServiceMetrics() @@ -95,6 +114,12 @@ def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict): async def __initialize__(self): """Initializes the service and starts the health loop.""" logger.debug(f"Starting service up with {self._cfg.num_replicas} replicas.") + + # Initialize the routers + self._default_router = self.router_cls() + self._session_router = SessionRouter(fallback_router=self.fallback_router_cls()) + + # Initialize all replicas replicas = [] num_replicas = self._cfg.num_replicas for i in range(num_replicas): @@ -457,36 +482,15 @@ async def _health_loop(self, poll_rate_s: float): await asyncio.sleep(poll_rate_s) - def _get_least_loaded_replica(self) -> "Replica": - """Get the replica with the lowest load.""" - healthy_replicas = [r for r in self._replicas if r.healthy] - if not healthy_replicas: - raise RuntimeError("No healthy replicas available for session assignment") - - # Use the replica's current_load property - return min(healthy_replicas, key=lambda replica: replica.current_load) - async def _get_replica(self, sess_id: str | None) -> "Replica": """Get a replica for the given session ID.""" if sess_id is None: # No session, use the default router - return 
self._router.get_replica(self._replicas) - - # Session-based routing - if sess_id in self._session_replica_map: - replica_idx = self._session_replica_map[sess_id] - # Find the replica with this index - for replica in self._replicas: - if replica.idx == replica_idx and replica.healthy: - return replica - # If the replica is no longer healthy, remove from session map and reassign - del self._session_replica_map[sess_id] + return self._default_router.get_replica(self._replicas) - # New session, assign to least loaded replica - replica = self._get_least_loaded_replica() - self._session_replica_map[sess_id] = replica.idx - logger.debug("Assigning session %s to replica %d", sess_id, replica.idx) - return replica + return self._session_router.get_replica( + self._replicas, sess_id, self._session_replica_map + ) async def stop(self): logger.debug("Stopping service...") diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index f2a9847c6..ad898e3f8 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -13,8 +13,15 @@ import pytest from forge.controller import ForgeActor - -from forge.controller.service import ServiceConfig +from forge.controller.service import ( + LeastLoadedRouter, + Replica, + ReplicaState, + RoundRobinRouter, + ServiceConfig, + SessionRouter, +) +from forge.types import ProcessConfig from monarch.actor import Actor, endpoint logger = logging.getLogger(__name__) @@ -56,6 +63,19 @@ async def add_to_value(self, amount: int, multiplier: int = 1) -> int: return self.v +def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: + """Helper to build a replica with specified state and load.""" + replica = Replica( + idx=idx, + proc_config=ProcessConfig(), + actor_def=Counter, + actor_kwargs={}, + ) + replica.state = ReplicaState.HEALTHY if healthy else ReplicaState.UNHEALTHY + replica.active_requests = load + return replica + + # Core Functionality Tests @@ -637,6 +657,82 @@ async def 
test_broadcast_call_vs_choose(): # Rounter Tests + + +@pytest.mark.asyncio +async def test_least_loaded_router_basic(): + """LeastLoadedRouter picks the replica with lowest load.""" + replicas = [ + make_replica(0, load=5), + make_replica(1, load=1), + make_replica(2, load=3), + ] + router = LeastLoadedRouter() + chosen = router.get_replica(replicas) + assert chosen.idx == 1 # lowest load + + +@pytest.mark.asyncio +async def test_session_router_assigns_and_updates_session_map(): + """SessionRouter updates session_map and preserves sticky sessions.""" + replicas = [make_replica(0), make_replica(1)] + session_map = {} + fallback = LeastLoadedRouter() + router = SessionRouter(fallback) + + # First request assigns via fallback + r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) + assert session_map["sess1"] == r1.idx + + # Second request should stick + r2 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) + assert r1.idx == r2.idx + + +@pytest.mark.asyncio +async def test_session_router_removes_unhealthy_session_mapping(): + """If mapped replica becomes unhealthy, SessionRouter deletes entry and reassigns.""" + replicas = [make_replica(0, healthy=False), make_replica(1, healthy=True)] + session_map = {"sess1": 0} + fallback = LeastLoadedRouter() + router = SessionRouter(fallback) + + # Replica 0 unhealthy → deleted from session_map → reassigned to 1 + r = router.get_replica(replicas, sess_id="sess1", session_map=session_map) + assert r.idx == 1 + assert session_map["sess1"] == 1 + + +@pytest.mark.asyncio +async def test_session_router_with_round_robin_fallback(): + """Switch fallback router to round-robin and verify assignment order.""" + # Choose RoundRobinRouter as fallback, r1 and r2 should be assigned to different replicas + replicas = [make_replica(0, load=0), make_replica(1, load=5)] + session_map = {} + fallback = RoundRobinRouter() + router = SessionRouter(fallback) + + r1 = router.get_replica(replicas, 
sess_id="sess1", session_map=session_map) + r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) + + assert r1.idx != r2.idx + assert set(session_map.values()) == {0, 1} + + # If LeastLoadedRouter as fallback, r1 and r2 should be assigned to same replicas + replicas = [make_replica(0, load=0), make_replica(1, load=5)] + session_map = {} + fallback = LeastLoadedRouter() + router = SessionRouter(fallback) + + r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) + r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) + + assert r1.idx == r2.idx == 0 + + +# Router integration tests + + @pytest.mark.timeout(10) @pytest.mark.asyncio async def test_round_robin_router_distribution(): @@ -660,3 +756,88 @@ async def test_round_robin_router_distribution(): finally: await service.shutdown() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_session_router_assigns_and_updates_session_map_in_service(): + """Integration: Service with SessionRouter preserves sticky sessions.""" + # Use LeastLoaded as default, SessionRouter (with fallback) is always active + service = await Counter.options( + procs_per_replica=1, + num_replicas=2, + ).as_service(v=0) + + try: + # First call with sess_id -> assign a replica + await service.incr.choose(sess_id="sess1") + values1 = await service.value.call() + + # Second call with same sess_id -> must hit same replica + await service.incr.choose(sess_id="sess1") + values2 = await service.value.call() + + # Difference should only be on one replica (sticky session) + diffs = [v2 - v1 for v1, v2 in zip(values1, values2)] + assert ( + sum(diffs) == 1 + ), f"Expected exactly one replica to increment, got {diffs}" + assert max(diffs) == 1 and min(diffs) == 0 + + # Session map in service should reflect assigned replica + assigned_idx = service._session_replica_map["sess1"] + assert values2[assigned_idx] == values1[assigned_idx] + 1 + + finally: + await service.shutdown() +
+ +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def integration_router_test_default_setting(): + """Verify SessionRouter falls back to RoundRobinRouter correctly for new sessions.""" + # Initialize Counter as a real service with 3 replicas + service = await Counter.options( + procs_per_replica=1, + num_replicas=3, + ).as_service(v=0) + + try: + session_map = service._session_replica_map + + # Create new sessions and assign replicas via session router + sess_ids = ["sess1", "sess2", "sess3"] + assigned_indices = [] + + for sess_id in sess_ids: + replica = await service._session_router.get_replica( + service._replicas, sess_id=sess_id, session_map=session_map + ) + assigned_indices.append(replica.idx) + + # Should assign in round-robin order: 0,1,2 + assert assigned_indices == [ + 0, + 1, + 2, + ], f"Expected round-robin assignment, got {assigned_indices}" + + # Reuse a session, should stick to the same replica + replica_again = await service._session_router.get_replica( + service._replicas, sess_id="sess2", session_map=session_map + ) + assert replica_again.idx == 1 + + # Test that making calls through the service works + for sess_id in sess_ids: + await service.incr.choose(sess_id=sess_id) + + # Verify counters updated correctly per session + values = [] + for sess_id in sess_ids: + val = await service.value.choose(sess_id=sess_id) + values.append(val) + assert sorted(values) == [1, 1, 1] + + finally: + await service.shutdown() From 977276478248c3240c652ee156b20e649488dcce Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 17 Sep 2025 16:46:03 -0700 Subject: [PATCH 4/8] remove customization --- src/forge/controller/service/service.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 61e279d2a..530fea340 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -36,11 +36,11 @@ import logging 
import pprint import uuid -from typing import Dict, List, Type +from typing import Dict, List from monarch.actor import Actor, endpoint -from forge.controller.service.interface import _session_context, Router, Session +from forge.controller.service.interface import _session_context, Session from forge.controller.service.metrics import ServiceMetrics from forge.controller.service.replica import Replica, ServiceRequest @@ -68,12 +68,6 @@ class Service: actor_def: Actor class definition to instantiate on each replica *actor_args: Positional arguments passed to actor constructor **actor_kwargs: Keyword arguments passed to actor constructor - router_cls (Type[Router], optional): Router class used for non-session - calls. Defaults to RoundRobinRouter. Examples include RoundRobinRouter - or LeastLoadedRouter. The router is instantiated internally. - fallback_router_cls: Router class used as a fallback when a session - cannot be mapped to an existing replica. Defaults - to LeastLoadedRouter. Attributes: @@ -89,15 +83,11 @@ def __init__( cfg: ServiceConfig, actor_def, actor_kwargs: dict, - router_cls: Type["Router"] = RoundRobinRouter, - fallback_router_cls: Type["Router"] = LeastLoadedRouter, ): self._cfg = cfg self._replicas = [] self._actor_def = actor_def self._actor_kwargs = actor_kwargs - self.router_cls = router_cls - self.fallback_router_cls = fallback_router_cls self._active_sessions = [] self._id_session_map = {} @@ -116,8 +106,8 @@ async def __initialize__(self): logger.debug(f"Starting service up with {self._cfg.num_replicas} replicas.") # Initialize the routers - self._default_router = self.router_cls() - self._session_router = SessionRouter(fallback_router=self.fallback_router_cls()) + self._default_router = RoundRobinRouter() + self._session_router = SessionRouter(fallback_router=LeastLoadedRouter()) # Initialize all replicas replicas = [] From 9c7360b75c1787052708815738967c2b49fef4ee Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 17 Sep 2025 20:08:33 -0700 
Subject: [PATCH 5/8] resolve comment --- src/forge/controller/service/interface.py | 2 +- src/forge/controller/service/router.py | 20 ++++++++++---------- src/forge/controller/service/service.py | 5 +++-- tests/unit_tests/test_service.py | 16 +--------------- 4 files changed, 15 insertions(+), 28 deletions(-) diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index a41596c20..a70ec8ad2 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -288,7 +288,7 @@ class Router(ABC): @abstractmethod def get_replica( self, - replicas: List[Replica], + healthy_replicas: List[Replica], sess_id: str | None = None, session_map: Dict[str, int] | None = None, ) -> Replica: diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 44c78f4bf..502402e36 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -22,11 +22,10 @@ def __init__(self): def get_replica( self, - replicas: List[Replica], + healthy_replicas: List[Replica], sess_id: str | None = None, session_map: Dict[str, int] | None = None, ) -> Replica: - healthy_replicas = [r for r in replicas if r.healthy] if not healthy_replicas: raise RuntimeError("No healthy replicas available for load balancing") @@ -41,11 +40,10 @@ class LeastLoadedRouter(Router): def get_replica( self, - replicas: List["Replica"], + healthy_replicas: List[Replica], sess_id: str | None = None, session_map: Dict[str, int] | None = None, - ) -> "Replica": - healthy_replicas = [r for r in replicas if r.healthy] + ) -> Replica: if not healthy_replicas: raise RuntimeError("No healthy replicas available for session assignment") return min(healthy_replicas, key=lambda r: r.current_load) @@ -59,10 +57,10 @@ def __init__(self, fallback_router: Router): def get_replica( self, - replicas: List["Replica"], + healthy_replicas: List[Replica], sess_id: str | None = None, session_map: 
Dict[str, int] | None = None, - ) -> "Replica": + ) -> Replica: if sess_id is None: raise ValueError("SessionRouter requires a session ID") @@ -73,14 +71,16 @@ def get_replica( if sess_id in session_map: replica_idx = session_map[sess_id] # Find the replica with this index - for r in replicas: - if r.idx == replica_idx and r.healthy: + for r in healthy_replicas: + if r.idx == replica_idx: return r # If the replica is no longer healthy, remove from session map and reassign del session_map[sess_id] # Use fallback router to assign a new replica - replica = self.fallback_router.get_replica(replicas, sess_id, session_map) + replica = self.fallback_router.get_replica( + healthy_replicas, sess_id, session_map + ) session_map[sess_id] = replica.idx logger.debug( "Assigning session %s to replica %d", diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 530fea340..2b8d8ab9c 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -474,12 +474,13 @@ async def _health_loop(self, poll_rate_s: float): async def _get_replica(self, sess_id: str | None) -> "Replica": """Get a replica for the given session ID.""" + healthy_replicas = [r for r in self._replicas if r.healthy] if sess_id is None: # No session, use the default router - return self._default_router.get_replica(self._replicas) + return self._default_router.get_replica(healthy_replicas) return self._session_router.get_replica( - self._replicas, sess_id, self._session_replica_map + healthy_replicas, sess_id, self._session_replica_map ) async def stop(self): diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index ad898e3f8..4c8813ae2 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -656,7 +656,7 @@ async def test_broadcast_call_vs_choose(): await service.shutdown() -# Rounter Tests +# Router Tests @pytest.mark.asyncio @@ -689,20 +689,6 @@ async def 
test_session_router_assigns_and_updates_session_map(): assert r1.idx == r2.idx -@pytest.mark.asyncio -async def test_session_router_removes_unhealthy_session_mapping(): - """If mapped replica becomes unhealthy, SessionRouter deletes entry and reassigns.""" - replicas = [make_replica(0, healthy=False), make_replica(1, healthy=True)] - session_map = {"sess1": 0} - fallback = LeastLoadedRouter() - router = SessionRouter(fallback) - - # Replica 0 unhealthy → deleted from session_map → reassigned to 1 - r = router.get_replica(replicas, sess_id="sess1", session_map=session_map) - assert r.idx == 1 - assert session_map["sess1"] == 1 - - @pytest.mark.asyncio async def test_session_router_with_round_robin_fallback(): """Switch fallback router to round-robin and verify assignment order.""" From 35d9a121748f8006394893f328f6cdb27faf05d6 Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 17 Sep 2025 20:11:04 -0700 Subject: [PATCH 6/8] minor --- tests/unit_tests/test_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index 4c8813ae2..f0e44c34f 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -736,7 +736,7 @@ async def test_round_robin_router_distribution(): print("results: ", results) # Verify that requests were distributed round-robin # Each call increments a single replica, so after 6 calls we expect: - # - 2 increments per replica (since 3 replicas, 6 calls) + # 2 increments per replica (since 3 replicas, 6 calls) final_values = results[-1] # last snapshot assert sorted(final_values) == [2, 2, 2] From d2cd10579d4c629e1d1f21f6f3400f14f82f53ea Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 18 Sep 2025 13:38:55 -0700 Subject: [PATCH 7/8] remove a test case --- tests/unit_tests/test_service.py | 51 -------------------------------- 1 file changed, 51 deletions(-) diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index 
f0e44c34f..08c4f5c4e 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -776,54 +776,3 @@ async def test_session_router_assigns_and_updates_session_map_in_service(): finally: await service.shutdown() - - -@pytest.mark.timeout(10) -@pytest.mark.asyncio -async def integration_router_test_default_setting(): - """Verify SessionRouter falls back to RoundRobinRouter correctly for new sessions.""" - # Initialize Counter as a real service with 3 replicas - service = await Counter.options( - procs_per_replica=1, - num_replicas=3, - ).as_service(v=0) - - try: - session_map = service._session_replica_map - - # Create new sessions and assign replicas via session router - sess_ids = ["sess1", "sess2", "sess3"] - assigned_indices = [] - - for sess_id in sess_ids: - replica = await service._session_router.get_replica( - service._replicas, sess_id=sess_id, session_map=session_map - ) - assigned_indices.append(replica.idx) - - # Should assign in round-robin order: 0,1,2 - assert assigned_indices == [ - 0, - 1, - 2, - ], f"Expected round-robin assignment, got {assigned_indices}" - - # Reuse a session, should stick to the same replica - replica_again = await service._session_router.get_replica( - service._replicas, sess_id="sess2", session_map=session_map - ) - assert replica_again.idx == 1 - - # Test that making calls through the service works - for sess_id in sess_ids: - await service.incr.choose(sess_id=sess_id) - - # Verify counters updated correctly per session - values = [] - for sess_id in sess_ids: - val = await service.value.choose(sess_id=sess_id) - values.append(val) - assert sorted(values) == [1, 1, 1] - - finally: - await service.shutdown() From eb32e55a745a3f6358e702b7107fa025236b4099 Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 18 Sep 2025 13:43:59 -0700 Subject: [PATCH 8/8] remove redundant tests --- tests/unit_tests/test_service.py | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git 
a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index 08c4f5c4e..ee3f39eb0 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -659,36 +659,6 @@ async def test_broadcast_call_vs_choose(): # Router Tests -@pytest.mark.asyncio -async def test_least_loaded_router_basic(): - """LeastLoadedRouter picks the replica with lowest load.""" - replicas = [ - make_replica(0, load=5), - make_replica(1, load=1), - make_replica(2, load=3), - ] - router = LeastLoadedRouter() - chosen = router.get_replica(replicas) - assert chosen.idx == 1 # lowest load - - -@pytest.mark.asyncio -async def test_session_router_assigns_and_updates_session_map(): - """SessionRouter updates session_map and preserves sticky sessions.""" - replicas = [make_replica(0), make_replica(1)] - session_map = {} - fallback = LeastLoadedRouter() - router = SessionRouter(fallback) - - # First request assigns via fallback - r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) - assert session_map["sess1"] == r1.idx - - # Second request should stick - r2 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) - assert r1.idx == r2.idx - - @pytest.mark.asyncio async def test_session_router_with_round_robin_fallback(): """Switch fallback router to round-robin and verify assignment order."""