-
Notifications
You must be signed in to change notification settings - Fork 354
Jeff/fix/native threading #1663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
ff2becd
6bd7ad6
ee752c0
e316626
f13b2b3
3197ad3
43d5434
1ff8769
c202c57
fe787bb
8ae6282
b8fcd08
d930cb3
fb8b40b
1d06db9
75708ff
aff62bc
d5ae028
66eef1f
4595390
3b5c4fd
f9b6d04
e5d739f
530680f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -180,8 +180,7 @@ def start(self) -> None: | |
| def stop(self) -> None: | ||
| if self._uvicorn_server: | ||
| self._uvicorn_server.should_exit = True | ||
| loop = self._loop | ||
| if loop is not None and self._serve_future is not None: | ||
| if self._serve_future is not None: | ||
| self._serve_future.result(timeout=5.0) | ||
| self._uvicorn_server = None | ||
| self._serve_future = None | ||
|
|
@@ -250,6 +249,5 @@ def _start_server(self, port: int | None = None) -> None: | |
| config = uvicorn.Config(app, host=_host, port=_port, log_level="info") | ||
| server = uvicorn.Server(config) | ||
| self._uvicorn_server = server | ||
| loop = self._loop | ||
| assert loop is not None | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The loop is always there until stop() is called. |
||
| loop = self._async_thread.loop | ||
| self._serve_future = asyncio.run_coroutine_threadsafe(server.serve(), loop) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,9 +16,13 @@ | |
|
|
||
| import asyncio | ||
| import json | ||
| import socket | ||
| import time | ||
| from unittest.mock import MagicMock | ||
|
|
||
| from dimos.agents.mcp.mcp_server import handle_request | ||
| import requests | ||
|
|
||
| from dimos.agents.mcp.mcp_server import McpServer, handle_request | ||
| from dimos.core.module import SkillInfo | ||
|
|
||
|
|
||
|
|
@@ -111,3 +115,56 @@ def test_mcp_module_initialize_and_unknown() -> None: | |
|
|
||
| response = asyncio.run(handle_request({"method": "unknown/method", "id": 2}, [], {})) | ||
| assert response["error"]["code"] == -32601 | ||
|
|
||
|
|
||
| def _free_port() -> int: | ||
| with socket.socket() as s: | ||
| s.bind(("", 0)) | ||
| return s.getsockname()[1] | ||
|
|
||
|
|
||
| def test_mcp_server_lifecycle() -> None: | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. New test. |
||
| """Start a real McpServer, hit the HTTP endpoint, then stop it. | ||
|
|
||
| This exercises the AsyncModuleThread event loop integration that the | ||
| unit tests above do not cover. | ||
| """ | ||
| port = _free_port() | ||
|
|
||
| server = McpServer() | ||
| server._start_server(port=port) | ||
| url = f"http://127.0.0.1:{port}/mcp" | ||
|
|
||
| # Wait for the server to be ready | ||
| for _ in range(40): | ||
| try: | ||
| resp = requests.post( | ||
| url, | ||
| json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, | ||
| timeout=0.5, | ||
| ) | ||
| if resp.status_code == 200: | ||
| break | ||
| except requests.ConnectionError: | ||
| time.sleep(0.1) | ||
| else: | ||
| server.stop() | ||
| raise AssertionError("McpServer did not become ready") | ||
|
|
||
| # Verify it responds | ||
| data = resp.json() | ||
| assert data["result"]["serverInfo"]["name"] == "dimensional" | ||
|
|
||
| # Stop and verify it shuts down | ||
| server.stop() | ||
| time.sleep(0.3) | ||
|
|
||
| with socket.socket() as s: | ||
| # Port should be released after stop | ||
| try: | ||
| s.connect(("127.0.0.1", port)) | ||
| s.close() | ||
| # If we could connect, the server is still up — that's a bug | ||
| raise AssertionError("McpServer still listening after stop()") | ||
| except ConnectionRefusedError: | ||
| pass # expected — server is down | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,17 +11,16 @@ | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import asyncio | ||
| from collections.abc import Callable | ||
| from dataclasses import dataclass | ||
| from functools import partial | ||
| import inspect | ||
| import json | ||
| import sys | ||
| import threading | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| Any, | ||
| Literal, | ||
| Protocol, | ||
| get_args, | ||
| get_origin, | ||
|
|
@@ -45,6 +44,9 @@ | |
| from dimos.protocol.tf.tf import LCMTF, TFSpec | ||
| from dimos.utils import colors | ||
| from dimos.utils.generic import classproperty | ||
| from dimos.utils.thread_utils import AsyncModuleThread, ThreadSafeVal | ||
|
|
||
| ModState = Literal["init", "started", "stopped"] | ||
|
|
||
| if TYPE_CHECKING: | ||
| from dimos.core.blueprints import Blueprint | ||
|
|
@@ -64,19 +66,6 @@ class SkillInfo: | |
| args_schema: str | ||
|
|
||
|
|
||
| def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: | ||
| try: | ||
| running_loop = asyncio.get_running_loop() | ||
| return running_loop, None | ||
| except RuntimeError: | ||
| loop = asyncio.new_event_loop() | ||
| asyncio.set_event_loop(loop) | ||
|
|
||
| thr = threading.Thread(target=loop.run_forever, daemon=True) | ||
| thr.start() | ||
| return loop, thr | ||
|
|
||
|
|
||
| class ModuleConfig(BaseConfig): | ||
| rpc_transport: type[RPCSpec] = LCMRPC | ||
| tf_transport: type[TFSpec] = LCMTF # type: ignore[type-arg] | ||
|
|
@@ -98,20 +87,22 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): | |
|
|
||
| _rpc: RPCSpec | None = None | ||
| _tf: TFSpec[Any] | None = None | ||
| _loop: asyncio.AbstractEventLoop | None = None | ||
| _loop_thread: threading.Thread | None | ||
| _async_thread: AsyncModuleThread | ||
| _disposables: CompositeDisposable | ||
| _bound_rpc_calls: dict[str, RpcCall] = {} | ||
| _module_closed: bool = False | ||
| _module_closed_lock: threading.Lock | ||
| mod_state: ThreadSafeVal[ModState] | ||
|
|
||
| rpc_calls: list[str] = [] | ||
|
|
||
| def __init__(self, config_args: dict[str, Any]): | ||
| self.mod_state = ThreadSafeVal[ModState]("init") | ||
| super().__init__(**config_args) | ||
| self._module_closed_lock = threading.Lock() | ||
| self._loop, self._loop_thread = get_loop() | ||
| self._disposables = CompositeDisposable() | ||
| self._async_thread = ( | ||
| AsyncModuleThread( # NEEDS to be created after self._disposables exists | ||
| module=self | ||
| ) | ||
| ) | ||
| try: | ||
| self.rpc = self.config.rpc_transport() | ||
| self.rpc.serve_module_rpc(self) | ||
|
|
@@ -128,58 +119,43 @@ def frame_id(self) -> str: | |
|
|
||
| @rpc | ||
| def start(self) -> None: | ||
| pass | ||
| with self.mod_state as state: | ||
| if state == "stopped": | ||
| raise RuntimeError(f"{type(self).__name__} cannot be restarted after stop") | ||
| self.mod_state.set("started") | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I know lots of modules don't call super().start(), but they also wouldn't be using mod_state because it's a new thing. Different/off-topic discussion, but I think core2 should have ModuleBase as a class decorator instead of an inherited class (we can basically wrap methods instead of saying "please remember to call super"). |
||
| self._async_thread.start() | ||
|
|
||
| @rpc | ||
| def stop(self) -> None: | ||
| self._close_module() | ||
|
|
||
| def _close_module(self) -> None: | ||
| with self._module_closed_lock: | ||
| if self._module_closed: | ||
| with self.mod_state as state: | ||
| if state == "stopped": | ||
| return | ||
| self._module_closed = True | ||
|
|
||
| self._close_rpc() | ||
| self.mod_state.set("stopped") | ||
|
|
||
| # Save into local variables to avoid race when stopping concurrently | ||
| # (from RPC and worker shutdown) | ||
| loop_thread = getattr(self, "_loop_thread", None) | ||
| loop = getattr(self, "_loop", None) | ||
| # dispose of things BEFORE making aspects like rpc and _tf invalid | ||
| if hasattr(self, "_disposables"): | ||
| self._disposables.dispose() # stops _async_thread via disposable | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it's important to move the disposables up, before the rpc stop and the tf stop. |
||
|
|
||
| if loop_thread: | ||
| if loop_thread.is_alive(): | ||
| if loop: | ||
| loop.call_soon_threadsafe(loop.stop) | ||
| loop_thread.join(timeout=2) | ||
| self._loop = None | ||
| self._loop_thread = None | ||
| if self.rpc: | ||
| self.rpc.stop() # type: ignore[attr-defined] | ||
| self.rpc = None # type: ignore[assignment] | ||
|
|
||
| if hasattr(self, "_tf") and self._tf is not None: | ||
| self._tf.stop() | ||
| self._tf = None | ||
| if hasattr(self, "_disposables"): | ||
| self._disposables.dispose() | ||
|
|
||
| # Break the In/Out -> owner -> self reference cycle so the instance | ||
| # can be freed by refcount instead of waiting for GC. | ||
| for attr in list(vars(self).values()): | ||
| if isinstance(attr, (In, Out)): | ||
| attr.owner = None | ||
|
|
||
| def _close_rpc(self) -> None: | ||
| if self.rpc: | ||
| self.rpc.stop() # type: ignore[attr-defined] | ||
| self.rpc = None # type: ignore[assignment] | ||
|
|
||
| def __getstate__(self): # type: ignore[no-untyped-def] | ||
| """Exclude unpicklable runtime attributes when serializing.""" | ||
| state = self.__dict__.copy() | ||
| # Remove unpicklable attributes | ||
| state.pop("_disposables", None) | ||
| state.pop("_module_closed_lock", None) | ||
| state.pop("_loop", None) | ||
| state.pop("_loop_thread", None) | ||
| state.pop("_async_thread", None) | ||
| state.pop("_rpc", None) | ||
| state.pop("_tf", None) | ||
| return state | ||
|
|
@@ -189,9 +165,7 @@ def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] | |
| self.__dict__.update(state) | ||
| # Reinitialize runtime attributes | ||
| self._disposables = CompositeDisposable() | ||
| self._module_closed_lock = threading.Lock() | ||
| self._loop = None | ||
| self._loop_thread = None | ||
| self._async_thread = None # type: ignore[assignment] | ||
| self._rpc = None | ||
| self._tf = None | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the loop is always there until super().stop() is called