diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index 9149de06ec..1e1d7d9942 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -180,8 +180,7 @@ def start(self) -> None: def stop(self) -> None: if self._uvicorn_server: self._uvicorn_server.should_exit = True - loop = self._loop - if loop is not None and self._serve_future is not None: + if self._serve_future is not None: self._serve_future.result(timeout=5.0) self._uvicorn_server = None self._serve_future = None @@ -250,6 +249,5 @@ def _start_server(self, port: int | None = None) -> None: config = uvicorn.Config(app, host=_host, port=_port, log_level="info") server = uvicorn.Server(config) self._uvicorn_server = server - loop = self._loop - assert loop is not None + loop = self._async_thread.loop self._serve_future = asyncio.run_coroutine_threadsafe(server.serve(), loop) diff --git a/dimos/agents/mcp/test_mcp_server.py b/dimos/agents/mcp/test_mcp_server.py index 1cbca9e3e4..0424a2b16f 100644 --- a/dimos/agents/mcp/test_mcp_server.py +++ b/dimos/agents/mcp/test_mcp_server.py @@ -16,9 +16,13 @@ import asyncio import json +import socket +import time from unittest.mock import MagicMock -from dimos.agents.mcp.mcp_server import handle_request +import requests + +from dimos.agents.mcp.mcp_server import McpServer, handle_request from dimos.core.module import SkillInfo @@ -111,3 +115,56 @@ def test_mcp_module_initialize_and_unknown() -> None: response = asyncio.run(handle_request({"method": "unknown/method", "id": 2}, [], {})) assert response["error"]["code"] == -32601 + + +def _free_port() -> int: + with socket.socket() as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def test_mcp_server_lifecycle() -> None: + """Start a real McpServer, hit the HTTP endpoint, then stop it. + + This exercises the AsyncModuleThread event loop integration that the + unit tests above do not cover. + """ + port = _free_port() + + server = McpServer() + server._start_server(port=port) + url = f"http://127.0.0.1:{port}/mcp" + + # Wait for the server to be ready + for _ in range(40): + try: + resp = requests.post( + url, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + timeout=0.5, + ) + if resp.status_code == 200: + break + except requests.ConnectionError: + time.sleep(0.1) + else: + server.stop() + raise AssertionError("McpServer did not become ready") + + # Verify it responds + data = resp.json() + assert data["result"]["serverInfo"]["name"] == "dimensional" + + # Stop and verify it shuts down + server.stop() + time.sleep(0.3) + + with socket.socket() as s: + # Port should be released after stop + try: + s.connect(("127.0.0.1", port)) + s.close() + # If we could connect, the server is still up — that's a bug + raise AssertionError("McpServer still listening after stop()") + except ConnectionRefusedError: + pass # expected — server is down diff --git a/dimos/core/module.py b/dimos/core/module.py index 1c5b311883..83dfeddcd4 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -11,17 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import asyncio from collections.abc import Callable from dataclasses import dataclass from functools import partial import inspect import json import sys -import threading from typing import ( TYPE_CHECKING, Any, + Literal, Protocol, get_args, get_origin, @@ -45,6 +44,9 @@ from dimos.protocol.tf.tf import LCMTF, TFSpec from dimos.utils import colors from dimos.utils.generic import classproperty +from dimos.utils.thread_utils import AsyncModuleThread, ThreadSafeVal + +ModState = Literal["init", "started", "stopped"] if TYPE_CHECKING: from dimos.core.blueprints import Blueprint @@ -64,19 +66,6 @@ class SkillInfo: args_schema: str -def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: - try: - running_loop = asyncio.get_running_loop() - return running_loop, None - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - thr = threading.Thread(target=loop.run_forever, daemon=True) - thr.start() - return loop, thr - - class ModuleConfig(BaseConfig): rpc_transport: type[RPCSpec] = LCMRPC tf_transport: type[TFSpec] = LCMTF # type: ignore[type-arg] @@ -98,20 +87,22 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): _rpc: RPCSpec | None = None _tf: TFSpec[Any] | None = None - _loop: asyncio.AbstractEventLoop | None = None - _loop_thread: threading.Thread | None + _async_thread: AsyncModuleThread _disposables: CompositeDisposable _bound_rpc_calls: dict[str, RpcCall] = {} - _module_closed: bool = False - _module_closed_lock: threading.Lock + mod_state: ThreadSafeVal[ModState] rpc_calls: list[str] = [] def __init__(self, config_args: dict[str, Any]): + self.mod_state = ThreadSafeVal[ModState]("init") super().__init__(**config_args) - self._module_closed_lock = threading.Lock() - self._loop, self._loop_thread = get_loop() self._disposables = CompositeDisposable() + self._async_thread = ( + AsyncModuleThread( # NEEDS to be created after self._disposables exists + module=self + ) + ) try: self.rpc = self.config.rpc_transport() self.rpc.serve_module_rpc(self) @@ -128,38 +119,30 @@ def frame_id(self) -> str: @rpc def start(self) -> None: - pass + with self.mod_state as state: + if state == "stopped": + raise RuntimeError(f"{type(self).__name__} cannot be restarted after stop") + self.mod_state.set("started") + self._async_thread.start() @rpc def stop(self) -> None: - self._close_module() - - def _close_module(self) -> None: - with self._module_closed_lock: - if self._module_closed: + with self.mod_state as state: + if state == "stopped": return - self._module_closed = True - - self._close_rpc() + self.mod_state.set("stopped") - # Save into local variables to avoid race when stopping concurrently - # (from RPC and worker shutdown) - loop_thread = getattr(self, "_loop_thread", None) - loop = getattr(self, "_loop", None) + # dispose of things BEFORE making aspects like rpc and _tf invalid + if hasattr(self, "_disposables"): + self._disposables.dispose() # stops _async_thread via disposable - if loop_thread: - if loop_thread.is_alive(): - if loop: - loop.call_soon_threadsafe(loop.stop) - loop_thread.join(timeout=2) - self._loop = None - self._loop_thread = None + if self.rpc: + self.rpc.stop() # type: ignore[attr-defined] + self.rpc = None # type: ignore[assignment] if hasattr(self, "_tf") and self._tf is not None: self._tf.stop() self._tf = None - if hasattr(self, "_disposables"): - self._disposables.dispose() # Break the In/Out -> owner -> self reference cycle so the instance # can be freed by refcount instead of waiting for GC. @@ -167,19 +150,12 @@ def _close_module(self) -> None: if isinstance(attr, (In, Out)): attr.owner = None - def _close_rpc(self) -> None: - if self.rpc: - self.rpc.stop() # type: ignore[attr-defined] - self.rpc = None # type: ignore[assignment] - def __getstate__(self): # type: ignore[no-untyped-def] """Exclude unpicklable runtime attributes when serializing.""" state = self.__dict__.copy() # Remove unpicklable attributes state.pop("_disposables", None) - state.pop("_module_closed_lock", None) - state.pop("_loop", None) - state.pop("_loop_thread", None) + state.pop("_async_thread", None) state.pop("_rpc", None) state.pop("_tf", None) return state @@ -189,9 +165,7 @@ def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] self.__dict__.update(state) # Reinitialize runtime attributes self._disposables = CompositeDisposable() - self._module_closed_lock = threading.Lock() - self._loop = None - self._loop_thread = None + self._async_thread = None # type: ignore[assignment] self._rpc = None self._tf = None diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 74471f34d5..080ad7df13 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -40,23 +40,20 @@ class MyCppModule(NativeModule): from __future__ import annotations -import collections import enum import inspect -import json import os from pathlib import Path -import signal import subprocess import sys -import threading -from typing import IO, Any +from typing import Any from pydantic import Field from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.utils.logging_config import setup_logger +from dimos.utils.thread_utils import ModuleProcess if sys.version_info < (3, 13): from typing_extensions import TypeVar @@ -129,136 +126,35 @@ class NativeModule(Module[_NativeConfig]): """ default_config: type[_NativeConfig] = NativeModuleConfig # type: ignore[assignment] - _process: subprocess.Popen[bytes] | None = None - _watchdog: threading.Thread | None = None - _stopping: bool = False - _last_stderr_lines: collections.deque[str] + _proc: ModuleProcess | None = None def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - self._last_stderr_lines = collections.deque(maxlen=50) self._resolve_paths() @rpc def start(self) -> None: - if self._process is not None and self._process.poll() is None: - logger.warning("Native process already running", pid=self._process.pid) - return - + super().start() self._maybe_build() - - topics = self._collect_topics() - - cmd = [self.config.executable] - for name, topic_str in topics.items(): - cmd.extend([f"--{name}", topic_str]) - cmd.extend(self.config.to_cli_args()) - cmd.extend(self.config.extra_args) - - env = {**os.environ, **self.config.extra_env} - cwd = self.config.cwd or str(Path(self.config.executable).resolve().parent) - - module_name = type(self).__name__ - logger.info( - f"Starting native process: {module_name}", - module=module_name, - cmd=" ".join(cmd), - cwd=cwd, - ) - self._process = subprocess.Popen( - cmd, - env=env, - cwd=cwd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - logger.info( - f"Native process started: {module_name}", - module=module_name, - pid=self._process.pid, - ) - - self._stopping = False - self._watchdog = threading.Thread(target=self._watch_process, daemon=True) - self._watchdog.start() - - @rpc - def stop(self) -> None: - self._stopping = True - if self._process is not None and self._process.poll() is None: - logger.info("Stopping native process", pid=self._process.pid) - self._process.send_signal(signal.SIGTERM) - try: - self._process.wait(timeout=self.config.shutdown_timeout) - except subprocess.TimeoutExpired: - logger.warning( - "Native process did not exit, sending SIGKILL", pid=self._process.pid - ) - self._process.kill() - self._process.wait(timeout=5) - if self._watchdog is not None and self._watchdog is not threading.current_thread(): - self._watchdog.join(timeout=2) - self._watchdog = None - self._process = None - super().stop() - - def _watch_process(self) -> None: - """Block until the native process exits; trigger stop() if it crashed.""" - if self._process is None: - return - - stdout_t = self._start_reader(self._process.stdout, "info") - stderr_t = self._start_reader(self._process.stderr, "warning") - rc = self._process.wait() - stdout_t.join(timeout=2) - stderr_t.join(timeout=2) - - if self._stopping: - return - - module_name = type(self).__name__ - exe_name = Path(self.config.executable).name if self.config.executable else "unknown" - - # Use buffered stderr lines from the reader thread for the crash report. - last_stderr = "\n".join(self._last_stderr_lines) - - logger.error( - f"Native process crashed: {module_name} ({exe_name})", - module=module_name, - executable=exe_name, - pid=self._process.pid, - returncode=rc, - last_stderr=last_stderr[:500] if last_stderr else None, + self._proc = ModuleProcess( + module=self, + args=[ + self.config.executable, + *[ + arg + for name, topic in self._collect_topics().items() + for arg in (f"--{name}", topic) + ], + *self.config.to_cli_args(), + *self.config.extra_args, + ], + env={**os.environ, **self.config.extra_env}, + cwd=self.config.cwd or str(Path(self.config.executable).resolve().parent), + on_exit=self.stop, + shutdown_timeout=self.config.shutdown_timeout, + log_json=self.config.log_format == LogFormat.JSON, ) - self.stop() - - def _start_reader(self, stream: IO[bytes] | None, level: str) -> threading.Thread: - """Spawn a daemon thread that pipes a subprocess stream through the logger.""" - t = threading.Thread(target=self._read_log_stream, args=(stream, level), daemon=True) - t.start() - return t - - def _read_log_stream(self, stream: IO[bytes] | None, level: str) -> None: - if stream is None: - return - log_fn = getattr(logger, level) - is_stderr = level == "warning" - for raw in stream: - line = raw.decode("utf-8", errors="replace").rstrip() - if not line: - continue - if is_stderr: - self._last_stderr_lines.append(line) - if self.config.log_format == LogFormat.JSON: - try: - data = json.loads(line) - event = data.pop("event", line) - log_fn(event, **data) - continue - except (json.JSONDecodeError, TypeError): - logger.warning("malformed JSON from native module", raw=line) - log_fn(line, pid=self._process.pid if self._process else None) - stream.close() + self._proc.start() def _resolve_paths(self) -> None: """Resolve relative ``cwd`` and ``executable`` against the subclass's source file.""" @@ -300,16 +196,12 @@ def _maybe_build(self) -> None: if line.strip(): logger.warning(line) if proc.returncode != 0: - stderr_tail = stderr.decode("utf-8", errors="replace").strip()[-1000:] raise RuntimeError( - f"Build command failed (exit {proc.returncode}): {self.config.build_command}\n" - f"stderr: {stderr_tail}" + f"Build command failed (exit {proc.returncode}): {self.config.build_command}" ) if not exe.exists(): raise FileNotFoundError( - f"Build command succeeded but executable still not found: {exe}\n" - f"Build output may have been written to a different path. " - f"Check that build_command produces the executable at the expected location." + f"Build command succeeded but executable still not found: {exe}" ) def _collect_topics(self) -> dict[str, str]: diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index f9a89829d5..858cc81849 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -89,7 +89,7 @@ def test_classmethods() -> None: ) assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" - nav._close_module() + nav.stop() @pytest.mark.slow diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index 5d57c42854..7005498ebb 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -96,16 +96,22 @@ def test_process_crash_triggers_stop() -> None: mod.pointcloud.transport = LCMTransport("/pc", PointCloud2) mod.start() - assert mod._process is not None - pid = mod._process.pid + assert mod._proc is not None + assert mod._proc.is_alive + pid = mod._proc.pid - # Wait for the process to die and the watchdog to call stop() + # Wait for the process to die and the on_exit callback to call stop() for _ in range(30): time.sleep(0.1) - if mod._process is None: + if mod._proc is None or not mod._proc.is_alive: break - assert mod._process is None, f"Watchdog did not clean up after process {pid} died" + assert mod._proc is None or not mod._proc.is_alive, ( + f"Watchdog did not clean up after process {pid} died" + ) + + # stop() is idempotent + mod.stop() # Wait for background threads (run_forever, _lcm_loop, _watch_process) to finish # after the watchdog-triggered stop(). Without this, monitor_threads catches them. diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index 5f8f1bc4b9..cb20e3d06a 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -221,7 +221,7 @@ def moment_provider(**kwargs) -> Moment2D: yield moment_provider moment_provider.cache_clear() - module._close_module() + module.stop() @pytest.fixture(scope="session") @@ -256,7 +256,7 @@ def moment_provider(**kwargs) -> Moment3D: yield moment_provider moment_provider.cache_clear() if module is not None: - module._close_module() + module.stop() @pytest.fixture(scope="session") @@ -290,9 +290,9 @@ def object_db_module(get_moment): yield moduleDB - module2d._close_module() - module3d._close_module() - moduleDB._close_module() + module2d.stop() + module3d.stop() + moduleDB.stop() @pytest.fixture(scope="session") diff --git a/dimos/perception/detection/reid/test_module.py b/dimos/perception/detection/reid/test_module.py index aac6ba11d1..89d752aae3 100644 --- a/dimos/perception/detection/reid/test_module.py +++ b/dimos/perception/detection/reid/test_module.py @@ -40,5 +40,5 @@ def test_reid_ingress(imageDetections2d) -> None: print("Processing detections through ReidModule...") reid_module.annotations._transport = LCMTransport("/annotations", ImageAnnotations) reid_module.ingress(imageDetections2d) - reid_module._close_module() + reid_module.stop() print("✓ ReidModule ingress test completed successfully") diff --git a/dimos/robot/unitree/b1/test_connection.py b/dimos/robot/unitree/b1/test_connection.py index 9c7a2867fa..977554eb59 100644 --- a/dimos/robot/unitree/b1/test_connection.py +++ b/dimos/robot/unitree/b1/test_connection.py @@ -82,7 +82,7 @@ def test_watchdog_actually_zeros_commands(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=1.0) conn.watchdog_thread.join(timeout=1.0) - conn._close_module() + conn.stop() def test_watchdog_resets_on_new_command(self) -> None: """Test that watchdog timeout resets when new command arrives.""" @@ -114,7 +114,7 @@ def test_watchdog_resets_on_new_command(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=1.0) conn.watchdog_thread.join(timeout=1.0) - conn._close_module() + conn.stop() def test_watchdog_thread_efficiency(self) -> None: """Test that watchdog uses only one thread regardless of command rate.""" @@ -148,7 +148,7 @@ def test_watchdog_thread_efficiency(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn.stop() def test_watchdog_with_send_loop_blocking(self) -> None: """Test that watchdog still works if send loop blocks.""" @@ -200,7 +200,7 @@ def blocking_send_loop() -> None: conn.watchdog_running = False conn.send_thread.join(timeout=1.0) conn.watchdog_thread.join(timeout=1.0) - conn._close_module() + conn.stop() def test_continuous_commands_prevent_timeout(self) -> None: """Test that continuous commands prevent watchdog timeout.""" @@ -238,7 +238,7 @@ def test_continuous_commands_prevent_timeout(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=1.0) conn.watchdog_thread.join(timeout=1.0) - conn._close_module() + conn.stop() def test_watchdog_timing_accuracy(self) -> None: """Test that watchdog zeros commands at approximately 200ms.""" @@ -280,7 +280,7 @@ def test_watchdog_timing_accuracy(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn.stop() def test_mode_changes_with_watchdog(self) -> None: """Test that mode changes work correctly with watchdog.""" @@ -323,7 +323,7 @@ def test_mode_changes_with_watchdog(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn.stop() def test_watchdog_stops_movement_when_commands_stop(self) -> None: """Verify watchdog zeros commands when packets stop being sent.""" @@ -390,7 +390,7 @@ def test_watchdog_stops_movement_when_commands_stop(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=1.0) conn.watchdog_thread.join(timeout=1.0) - conn._close_module() + conn.stop() def test_rapid_command_thread_safety(self) -> None: """Test thread safety with rapid commands from multiple threads.""" @@ -439,4 +439,4 @@ def send_commands(thread_id) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn.stop() diff --git a/dimos/simulation/mujoco/direct_cmd_vel_explorer.py b/dimos/simulation/mujoco/direct_cmd_vel_explorer.py index 58dc91f6b1..f81c644d0a 100644 --- a/dimos/simulation/mujoco/direct_cmd_vel_explorer.py +++ b/dimos/simulation/mujoco/direct_cmd_vel_explorer.py @@ -99,7 +99,7 @@ def _drive_to(self, target_x: float, target_y: float) -> None: None, Twist(linear=Vector3(linear, 0, 0), angular=Vector3(0, 0, angular)), ) - self._stop() + self.stop() def follow_points(self, waypoints: list[tuple[float, float]]) -> None: self._wait_for_pose() diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py new file mode 100644 index 0000000000..ac12936fe0 --- /dev/null +++ b/dimos/utils/test_thread_utils.py @@ -0,0 +1,892 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exhaustive tests for dimos/utils/thread_utils.py + +Covers: ThreadSafeVal, ModuleThread, AsyncModuleThread, ModuleProcess, safe_thread_map. +Focuses on deadlocks, race conditions, idempotency, and edge cases under load. +""" + +from __future__ import annotations + +import asyncio +import os +import pickle +import sys +import threading +import time +from unittest import mock + +import pytest +from reactivex.disposable import CompositeDisposable + +from dimos.utils.thread_utils import ( + AsyncModuleThread, + ModuleProcess, + ModuleThread, + ThreadSafeVal, + safe_thread_map, +) + +# Helpers: fake ModuleBase for testing ModuleThread / AsyncModuleThread / ModuleProcess + + +class FakeModule: + """Minimal stand-in for ModuleBase — just needs _disposables.""" + + def __init__(self) -> None: + self._disposables = CompositeDisposable() + + def dispose(self) -> None: + self._disposables.dispose() + + +# ThreadSafeVal Tests + + +class TestThreadSafeVal: + def test_basic_get_set(self) -> None: + v = ThreadSafeVal(42) + assert v.get() == 42 + v.set(99) + assert v.get() == 99 + + def test_context_manager_returns_value(self) -> None: + v = ThreadSafeVal("hello") + with v as val: + assert val == "hello" + + def test_set_inside_context_manager_no_deadlock(self) -> None: + """The critical test: set() inside a with block must NOT deadlock. + + This was a confirmed bug when using threading.Lock (non-reentrant). + Fixed by using threading.RLock. + """ + v = ThreadSafeVal(0) + result = threading.Event() + + def do_it() -> None: + with v as val: + v.set(val + 1) + result.set() + + t = threading.Thread(target=do_it) + t.start() + t.join(timeout=2) + assert result.is_set(), "Deadlocked! set() inside with block hung" + assert v.get() == 1 + + def test_get_inside_context_manager_no_deadlock(self) -> None: + v = ThreadSafeVal(10) + result = threading.Event() + + def do_it() -> None: + with v: + _ = v.get() + result.set() + + t = threading.Thread(target=do_it) + t.start() + t.join(timeout=2) + assert result.is_set(), "Deadlocked! get() inside with block hung" + + def test_context_manager_blocks_other_threads(self) -> None: + """While one thread holds the lock via `with`, others should block on set().""" + v = ThreadSafeVal(0) + gate = threading.Event() + other_started = threading.Event() + other_finished = threading.Event() + + def holder() -> None: + with v: + gate.wait(timeout=5) # hold the lock until signaled + + def setter() -> None: + other_started.set() + v.set(42) # should block until holder releases + other_finished.set() + + t1 = threading.Thread(target=holder) + t2 = threading.Thread(target=setter) + t1.start() + time.sleep(0.05) # let holder acquire lock + t2.start() + other_started.wait(timeout=2) + time.sleep(0.1) + # setter should be blocked + assert not other_finished.is_set(), "set() did not block while lock was held" + gate.set() # release holder + t1.join(timeout=2) + t2.join(timeout=2) + assert other_finished.is_set() + assert v.get() == 42 + + def test_concurrent_increments(self) -> None: + """Many threads doing atomic read-modify-write should not lose updates.""" + v = ThreadSafeVal(0) + n_threads = 50 + n_increments = 100 + + def incrementer() -> None: + for _ in range(n_increments): + with v as val: + v.set(val + 1) + + threads = [threading.Thread(target=incrementer) for _ in range(n_threads)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + assert v.get() == n_threads * n_increments + + def test_concurrent_increments_stress(self) -> None: + """Run the concurrent increment test multiple times to catch races.""" + for _ in range(10): + self.test_concurrent_increments() + + def test_pickle_roundtrip(self) -> None: + v = ThreadSafeVal({"key": [1, 2, 3]}) + data = pickle.dumps(v) + v2 = pickle.loads(data) + assert v2.get() == {"key": [1, 2, 3]} + # Verify the new instance has a working lock + with v2 as val: + v2.set({**val, "new": True}) + assert v2.get()["new"] is True + + def test_repr(self) -> None: + v = ThreadSafeVal("test") + assert repr(v) == "ThreadSafeVal('test')" + + def test_dict_type(self) -> None: + v = ThreadSafeVal({"running": False, "count": 0}) + with v as s: + v.set({**s, "running": True}) + assert v.get() == {"running": True, "count": 0} + + def test_string_literal_type(self) -> None: + """Simulates the ModState pattern from module.py.""" + v = ThreadSafeVal("init") + with v as state: + if state == "init": + v.set("started") + assert v.get() == "started" + + with v as state: + if state == "stopped": + pass # no-op + else: + v.set("stopped") + assert v.get() == "stopped" + + def test_nested_with_no_deadlock(self) -> None: + """RLock should allow the same thread to nest with blocks.""" + v = ThreadSafeVal(0) + result = threading.Event() + + def do_it() -> None: + with v: + with v as val2: + v.set(val2 + 1) + result.set() + + t = threading.Thread(target=do_it) + t.start() + t.join(timeout=2) + assert result.is_set(), "Nested with blocks deadlocked!" + + +# ModuleThread Tests + + +class TestModuleThread: + def test_basic_lifecycle(self) -> None: + mod = FakeModule() + ran = threading.Event() + + def target() -> None: + ran.set() + + mt = ModuleThread(module=mod, target=target, name="test-basic") + mt.start() + ran.wait(timeout=2) + assert ran.is_set() + mt.stop() + assert not mt.is_alive + + def test_auto_start(self) -> None: + mod = FakeModule() + started = threading.Event() + mt = ModuleThread(module=mod, target=started.set, name="test-autostart") + mt.start() + started.wait(timeout=2) + assert started.is_set() + mt.stop() + + def test_deferred_start(self) -> None: + mod = FakeModule() + started = threading.Event() + mt = ModuleThread(module=mod, target=started.set, name="test-deferred") + time.sleep(0.1) + assert not started.is_set() + mt.start() + started.wait(timeout=2) + assert started.is_set() + mt.stop() + + def test_stopping_property(self) -> None: + mod = FakeModule() + saw_stopping = threading.Event() + holder: list[ModuleThread] = [] + + def target() -> None: + while holder[0].status.get() == "running": + time.sleep(0.01) + saw_stopping.set() + + mt = ModuleThread(module=mod, target=target, name="test-stopping") + holder.append(mt) + mt.start() + time.sleep(0.05) + mt.stop() + saw_stopping.wait(timeout=2) + assert saw_stopping.is_set() + + def test_stop_idempotent(self) -> None: + mod = FakeModule() + mt = ModuleThread(module=mod, target=lambda: time.sleep(0.01), name="test-idem") + mt.start() + time.sleep(0.05) + mt.stop() + mt.stop() # second call should not raise + mt.stop() # third call should not raise + + def test_stop_from_managed_thread_no_deadlock(self) -> None: + """The thread calling stop() on itself should not deadlock.""" + mod = FakeModule() + result = threading.Event() + holder: list[ModuleThread] = [] + + def target() -> None: + holder[0].stop() # stop ourselves — should not deadlock + result.set() + + mt = ModuleThread(module=mod, target=target, name="test-self-stop") + holder.append(mt) + mt.start() + result.wait(timeout=3) + assert result.is_set(), "Deadlocked when thread called stop() on itself" + + def test_dispose_stops_thread(self) -> None: + """Module dispose should stop the thread via the registered Disposable.""" + mod = FakeModule() + running = threading.Event() + holder: list[ModuleThread] = [] + + def target() -> None: + running.set() + while holder[0].status.get() == "running": + time.sleep(0.01) + + mt = ModuleThread(module=mod, target=target, name="test-dispose") + holder.append(mt) + mt.start() + running.wait(timeout=2) + mod.dispose() + time.sleep(0.1) + assert not mt.is_alive + + def test_concurrent_stop_calls(self) -> None: + """Multiple threads calling stop() concurrently should not crash.""" + mod = FakeModule() + holder: list[ModuleThread] = [] + + def target() -> None: + while holder[0].status.get() == "running": + time.sleep(0.01) + + mt = ModuleThread(module=mod, target=target, name="test-concurrent-stop") + holder.append(mt) + mt.start() + time.sleep(0.05) + + errors = [] + + def stop_it() -> None: + try: + mt.stop() + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=stop_it) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=5) + assert not errors, f"Concurrent stop() raised: {errors}" + + def test_close_timeout_respected(self) -> None: + """If the thread ignores the stop signal, stop() should return after close_timeout.""" + mod = FakeModule() + cancel = threading.Event() + + def stubborn_target() -> None: + cancel.wait(10) # blocks but can be released for cleanup + + mt = ModuleThread( + module=mod, target=stubborn_target, name="test-timeout", close_timeout=0.2 + ) + mt.start() + start = time.monotonic() + mt.stop() + elapsed = time.monotonic() - start + assert elapsed < 1.0, f"stop() took {elapsed}s, expected ~0.2s" + # Release the thread so it doesn't leak + cancel.set() + mt._thread.join(timeout=1.0) + + def test_stop_concurrent_with_dispose(self) -> None: + """Calling stop() and dispose() concurrently should not crash.""" + for _ in range(20): + mod = FakeModule() + holder: list[ModuleThread] = [] + + def target(h: list[ModuleThread] = holder) -> None: + while h[0].status.get() == "running": + time.sleep(0.001) + + mt = ModuleThread(module=mod, target=target, name="test-stop-dispose") + holder.append(mt) + mt.start() + time.sleep(0.02) + # Race: stop and dispose from different threads + t1 = threading.Thread(target=mt.stop) + t2 = threading.Thread(target=mod.dispose) + t1.start() + t2.start() + t1.join(timeout=3) + t2.join(timeout=3) + + +# AsyncModuleThread Tests + + +class TestAsyncModuleThread: + def test_creates_loop_and_thread(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + amt.start() + assert amt.loop is not None + assert amt.loop.is_running() + assert amt.is_alive + amt.stop() + assert not amt.is_alive + + def test_stop_idempotent(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + amt.start() + amt.stop() + amt.stop() # should not raise + amt.stop() + + def test_dispose_stops_loop(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + amt.start() + assert amt.is_alive + mod.dispose() + time.sleep(0.1) + assert not amt.is_alive + + def test_can_schedule_coroutine(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + amt.start() + result = [] + + async def coro() -> None: + result.append(42) + + future = asyncio.run_coroutine_threadsafe(coro(), amt.loop) + future.result(timeout=2) + assert result == [42] + amt.stop() + + def test_stop_with_pending_work(self) -> None: + """Stop should succeed even with long-running tasks on the loop.""" + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + amt.start() + started = threading.Event() + + async def slow_coro() -> None: + started.set() + await asyncio.sleep(10) + + asyncio.run_coroutine_threadsafe(slow_coro(), amt.loop) + started.wait(timeout=2) + # stop() should not hang waiting for the coroutine + start = time.monotonic() + amt.stop() + elapsed = time.monotonic() - start + assert elapsed < 5.0, f"stop() hung for {elapsed}s with pending coroutine" + + def test_concurrent_stop(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + amt.start() + errors = [] + + def stop_it() -> None: + try: + amt.stop() + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=stop_it) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=5) + assert not errors + + +# ModuleProcess Tests + + +# Helper: path to a python that sleeps or echoes +PYTHON = sys.executable + + +class TestModuleProcess: + def test_basic_lifecycle(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + shutdown_timeout=2.0, + ) + mp.start() + assert mp.is_alive + assert mp.pid is not None + mp.stop() + assert not mp.is_alive + assert mp.pid is None + + def test_stop_idempotent(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + shutdown_timeout=1.0, + ) + mp.start() + mp.stop() + mp.stop() # should not raise + mp.stop() + + def test_dispose_stops_process(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + shutdown_timeout=2.0, + ) + mp.start() + mod.dispose() + time.sleep(0.5) + assert not mp.is_alive + + def test_on_exit_fires_on_natural_exit(self) -> None: + """on_exit should fire when the process exits on its own.""" + mod = FakeModule() + exit_called = threading.Event() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "print('done')"], + on_exit=exit_called.set, + ) + mp.start() + exit_called.wait(timeout=5) + assert exit_called.is_set(), "on_exit was not called after natural process exit" + + def test_on_exit_fires_on_crash(self) -> None: + mod = FakeModule() + exit_called = threading.Event() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import sys; sys.exit(1)"], + on_exit=exit_called.set, + ) + mp.start() + exit_called.wait(timeout=5) + assert exit_called.is_set(), "on_exit was not called after process crash" + + def test_on_exit_not_fired_on_stop(self) -> None: + """on_exit should NOT fire when stop() kills the process.""" + mod = FakeModule() + exit_called = threading.Event() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + on_exit=exit_called.set, + shutdown_timeout=2.0, + ) + mp.start() + time.sleep(0.2) # let watchdog start + mp.stop() + time.sleep(1.0) # give watchdog time to potentially fire + assert not exit_called.is_set(), "on_exit fired after intentional stop()" + + def test_stdout_logged(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "print('hello from subprocess')"], + ) + mp.start() + time.sleep(1.0) # let output be read + mp.stop() + + def test_stderr_logged(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import sys; sys.stderr.write('error msg\\n')"], + ) + mp.start() + time.sleep(1.0) + mp.stop() + + def test_log_json_mode(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[ + PYTHON, + "-c", + """import json; print(json.dumps({"event": "test", "key": "val"}))""", + ], + log_json=True, + ) + mp.start() + time.sleep(1.0) + mp.stop() + + def test_log_json_malformed(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "print('not json')"], + log_json=True, + ) + mp.start() + time.sleep(1.0) + mp.stop() + + def test_stop_process_that_ignores_sigterm(self) -> None: + """Process that ignores SIGTERM should be killed with SIGKILL.""" + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[ + PYTHON, + "-c", + "import signal, time; signal.signal(signal.SIGTERM, signal.SIG_IGN); time.sleep(60)", + ], + shutdown_timeout=0.5, + kill_timeout=2.0, + ) + mp.start() + time.sleep(0.2) + start = time.monotonic() + mp.stop() + elapsed = time.monotonic() - start + assert not mp.is_alive + # Should take roughly shutdown_timeout (0.5) + a bit for SIGKILL + assert elapsed < 5.0 + + def test_stop_already_dead_process(self) -> None: + """stop() on a process that already exited should not raise.""" + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "pass"], # exits immediately + ) + mp.start() + time.sleep(1.0) # let it die + mp.stop() # should not raise + + def test_concurrent_stop(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + shutdown_timeout=2.0, + ) + mp.start() + errors = [] + + def stop_it() -> None: + try: + mp.stop() + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=stop_it) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + assert not errors, f"Concurrent stop() raised: {errors}" + + def test_on_exit_calls_module_stop_no_deadlock(self) -> None: + """Simulate the real pattern: on_exit=module.stop, which disposes the + ModuleProcess, which tries to stop its watchdog from inside the watchdog. + Must not deadlock. + """ + mod = FakeModule() + stop_called = threading.Event() + + def fake_module_stop() -> None: + """Simulates module.stop() -> _stop() -> dispose()""" + mod.dispose() + stop_called.set() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "pass"], # exits immediately + on_exit=fake_module_stop, + ) + mp.start() + stop_called.wait(timeout=5) + assert stop_called.is_set(), "Deadlocked! on_exit -> dispose -> stop chain hung" + + def test_on_exit_calls_module_stop_no_deadlock_stress(self) -> None: + """Run the deadlock test multiple times under load.""" + for _i in range(10): + self.test_on_exit_calls_module_stop_no_deadlock() + + def test_deferred_start(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + ) + assert not mp.is_alive + mp.start() + assert mp.is_alive + mp.stop() + + def test_env_passed(self) -> None: + mod = FakeModule() + exit_called = threading.Event() + + mp = ModuleProcess( + module=mod, + args=[ + PYTHON, + "-c", + "import os, sys; sys.exit(0 if os.environ.get('MY_VAR') == '42' else 1)", + ], + env={**os.environ, "MY_VAR": "42"}, + on_exit=exit_called.set, + ) + mp.start() + exit_called.wait(timeout=5) + # Process should have exited with 0 (our on_exit fires for all unmanaged exits) + assert exit_called.is_set() + + def test_cwd_passed(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import os; print(os.getcwd())"], + cwd="/tmp", + ) + mp.start() + time.sleep(1.0) + mp.stop() + + +# safe_thread_map Tests + + +class TestSafeThreadMap: + def test_empty_input(self) -> None: + assert safe_thread_map([], lambda x: x) == [] + + def test_all_succeed(self) -> None: + result = safe_thread_map([1, 2, 3], lambda x: x * 2) + assert result == [2, 4, 6] + + def test_preserves_order(self) -> None: + def slow(x: int) -> int: + time.sleep(0.01 * (10 - x)) + return x + + result = safe_thread_map(list(range(10)), slow) + assert result == list(range(10)) + + def test_all_fail_raises_exception_group(self) -> None: + def fail(x: int) -> int: + raise ValueError(f"fail-{x}") + + with pytest.raises(ExceptionGroup) as exc_info: + safe_thread_map([1, 2, 3], fail) + assert len(exc_info.value.exceptions) == 3 + + def test_partial_failure(self) -> None: + def maybe_fail(x: int) -> int: + if x == 2: + raise ValueError("fail") + return x + + with pytest.raises(ExceptionGroup) as exc_info: + safe_thread_map([1, 2, 3], maybe_fail) + assert len(exc_info.value.exceptions) == 1 + + def test_on_errors_callback(self) -> None: + def fail(x: int) -> int: + if x == 2: + raise ValueError("boom") + return x * 10 + + cleanup_called = False + + def on_errors(outcomes, successes, errors): + nonlocal cleanup_called + cleanup_called = True + assert len(errors) == 1 + assert len(successes) == 2 + return successes # return successful results + + result = safe_thread_map([1, 2, 3], fail, on_errors) + assert cleanup_called + assert sorted(result) == [10, 30] + + def test_on_errors_can_raise(self) -> None: + def fail(x: int) -> int: + raise ValueError("boom") + + def on_errors(outcomes, successes, errors): + raise RuntimeError("custom error") + + with pytest.raises(RuntimeError, match="custom error"): + safe_thread_map([1], fail, on_errors) + + def test_waits_for_all_before_raising(self) -> None: + """Even if one fails fast, all others should complete.""" + completed = [] + + def work(x: int) -> int: + if x == 0: + raise ValueError("fast fail") + time.sleep(0.2) + completed.append(x) + return x + + with pytest.raises(ExceptionGroup): + safe_thread_map([0, 1, 2, 3], work) + # All non-failing items should have completed + assert sorted(completed) == [1, 2, 3] + + +# Integration: ModuleProcess on_exit -> dispose chain (the CI bug scenario) + + +class TestModuleProcessDisposeChain: + """Tests the exact pattern that caused the CI bug: + process exits -> watchdog fires on_exit -> module.stop() -> dispose -> + ModuleProcess.stop() -> tries to stop watchdog from inside watchdog thread. + """ + + @staticmethod + def _make_fake_stop(mod: FakeModule, done: threading.Event) -> Callable: + def fake_stop() -> None: + mod.dispose() + done.set() + + return fake_stop + + def test_chain_no_deadlock_fast_exit(self) -> None: + """Process exits immediately.""" + for _ in range(20): + mod = FakeModule() + done = threading.Event() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "pass"], + on_exit=self._make_fake_stop(mod, done), + ) + mp.start() + assert done.wait(timeout=5), "Deadlock in dispose chain (fast exit)" + + def test_chain_no_deadlock_slow_exit(self) -> None: + """Process runs briefly then exits.""" + for _ in range(10): + mod = FakeModule() + done = threading.Event() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(0.1)"], + on_exit=self._make_fake_stop(mod, done), + ) + mp.start() + assert done.wait(timeout=5), "Deadlock in dispose chain (slow exit)" + + def test_chain_concurrent_with_external_stop(self) -> None: + """Process exits naturally while external code calls stop().""" + for _ in range(20): + mod = FakeModule() + done = threading.Event() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(0.05)"], + on_exit=self._make_fake_stop(mod, done), + shutdown_timeout=1.0, + ) + mp.start() + # Race: the process might exit naturally or we might stop it + time.sleep(0.03) + mp.stop() + # Either way, should not deadlock + time.sleep(1.0) + + def test_dispose_with_artificial_delay(self) -> None: + """Add artificial delay near cleanup to simulate heavy CPU load.""" + original_stop = ModuleThread.stop + + def slow_stop(self_mt: ModuleThread) -> None: + time.sleep(0.05) # simulate load + original_stop(self_mt) + + for _ in range(10): + mod = FakeModule() + done = threading.Event() + with mock.patch.object(ModuleThread, "stop", slow_stop): + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "pass"], + on_exit=self._make_fake_stop(mod, done), + ) + mp.start() + assert done.wait(timeout=10), "Deadlock with slow ModuleThread.stop()" + + +from dimos.utils.typing_utils import ExceptionGroup diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py new file mode 100644 index 0000000000..3e3fe6ba11 --- /dev/null +++ b/dimos/utils/thread_utils.py @@ -0,0 +1,559 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Thread utilities: safe values, managed threads, safe parallel map.""" + +from __future__ import annotations + +import asyncio +import collections +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +import json +import signal +import subprocess +import threading +from typing import IO, TYPE_CHECKING, Any, Generic, Literal + +from reactivex.disposable import Disposable + +from dimos.utils.logging_config import setup_logger +from dimos.utils.typing_utils import ExceptionGroup, TypeVar + +logger = setup_logger() + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from dimos.core.module import ModuleBase + +T = TypeVar("T") +R = TypeVar("R") + + +class ThreadSafeVal(Generic[T]): + """A thread-safe value wrapper. + + Forces lock usage in order to get access to a value (reduces unsafe value access) + Three ways to use: + 1. `.set` + 2. `.get` + 3. via a context manager:: + with counter as value: + # Lock is held for the entire block. + # Other threads block on get/set/with until this exits. + if value < 100: + counter.set(value + 1) + + # Works with any type: + status = ThreadSafeVal({"running": False, "count": 0}) + with status as s: + status.set({**s, "running": True}) + + # Bool check (for flag-like usage): + stopping = ThreadSafeVal(False) + stopping.set(True) + if stopping.get(): + print("stopping!") + """ + + def __init__(self, initial: T) -> None: + self._lock = threading.RLock() + self._value = initial + + def get(self) -> T: + """Return the current value (acquires the lock briefly).""" + with self._lock: + return self._value + + def set(self, value: T) -> None: + """Replace the value (acquires the lock briefly).""" + with self._lock: + self._value = value + + def __enter__(self) -> T: + self._lock.acquire() + return self._value + + def __exit__(self, *exc: object) -> None: + self._lock.release() + + def __getstate__(self) -> dict[str, Any]: + return {"_value": self._value} + + def __setstate__(self, state: dict[str, Any]) -> None: + self._lock = threading.RLock() + self._value = state["_value"] + + def __repr__(self) -> str: + return f"ThreadSafeVal({self._value!r})" + + +class ModuleThread: + """A thread that registers cleanup with a module's disposables. + + Passes most kwargs through to ``threading.Thread``. On construction, + registers a disposable with the module so that when the module stops, + the thread is automatically joined. Cleanup is idempotent — safe to + call ``stop()`` manually even if the module also disposes it. + + Example:: + + class MyModule(Module): + @rpc + def start(self) -> None: + self._worker = ModuleThread( + module=self, + target=self._run_loop, + name="my-worker", + ) + self._worker.start() + + def _run_loop(self) -> None: + while self._worker.status.get() == "running": + do_work() + """ + + def __init__( + self, + module: ModuleBase[Any], + *, + close_timeout: float = 2.0, + **thread_kwargs: Any, + ) -> None: + thread_kwargs.setdefault("daemon", True) + thread_kwargs.setdefault("name", f"{type(module).__name__}-thread") + self._thread = threading.Thread(**thread_kwargs) + self._close_timeout = close_timeout + self.status: ThreadSafeVal[Literal["not_started", "running", "stopping", "stopped"]] = ( + ThreadSafeVal("not_started") + ) + module._disposables.add(Disposable(self.stop)) + + def start(self) -> None: + self.status.set("running") + self._thread.start() + + def stop(self) -> None: + """Signal the thread to stop and join it. + + Safe to call multiple times, from any thread (including the + managed thread itself — it will skip the join in that case). + """ + with self.status as s: + if s in ("stopping", "stopped"): + return + self.status.set("stopping") + if self._thread.is_alive() and self._thread is not threading.current_thread(): + self._thread.join(timeout=self._close_timeout) + self.status.set("stopped") + + def join(self, timeout: float | None = None) -> None: + """Join the underlying thread.""" + self._thread.join(timeout=timeout) + + @property + def is_alive(self) -> bool: + return self._thread.is_alive() + + +class AsyncModuleThread: + """A thread running an asyncio event loop, registered with a module's disposables. + + If a loop is already running in the current context, reuses it (no thread + created). Otherwise creates a new loop and drives it in a daemon thread. + + On stop (or module dispose), the loop is shut down gracefully and the + thread is joined. Idempotent — safe to call ``stop()`` multiple times. + + Example:: + + class MyModule(Module): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._async = AsyncModuleThread(module=self) + + @rpc + def start(self) -> None: + future = asyncio.run_coroutine_threadsafe( + self._do_work(), self._async.loop + ) + + async def _do_work(self) -> None: + ... + """ + + def __init__( + self, + module: ModuleBase[Any], + *, + close_timeout: float = 2.0, + ) -> None: + self._close_timeout = close_timeout + self._stopped = ThreadSafeVal(False) + self._owns_loop = False + self._thread: threading.Thread | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._module_name = type(module).__name__ + + module._disposables.add(Disposable(self.stop)) + + def start(self) -> None: + """Create (or reuse) the event loop and start the driver thread. + + Safe to call multiple times — subsequent calls are no-ops. + """ + if self._loop is not None: + return + + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._owns_loop = True + self._thread = threading.Thread( + target=self._loop.run_forever, + daemon=True, + name=f"{self._module_name}-event-loop", + ) + self._thread.start() + + @property + def loop(self) -> asyncio.AbstractEventLoop: + """The managed event loop.""" + return self._loop + + @property + def is_alive(self) -> bool: + return self._thread is not None and self._thread.is_alive() + + def stop(self) -> None: + """Stop the event loop and join the thread. + + No-op if the loop was not created by this instance (reused an + existing running loop). Safe to call multiple times. + """ + with self._stopped as stopped: + if stopped: + return + self._stopped.set(True) + + if self._owns_loop and self._loop is not None and self._loop.is_running(): + self._loop.call_soon_threadsafe(self._loop.stop) + + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=self._close_timeout) + + +class ModuleProcess: + """A managed subprocess that pipes stdout/stderr through the logger. + + Registers with a module's disposables so the process is automatically + stopped on module teardown. A watchdog thread monitors the process and + calls ``on_exit`` if the process exits on its own (i.e. not via + ``ModuleProcess.stop()``). + + Most constructor kwargs mirror ``subprocess.Popen``. ``stdout`` and + ``stderr`` are always captured (set to ``PIPE`` internally). + + Example:: + + class MyModule(Module): + @rpc + def start(self) -> None: + self._proc = ModuleProcess( + module=self, + args=["./my_binary", "--flag"], + cwd="/opt/bin", + on_exit=self.stop, # stops the whole module if process exits on its own + ) + self._proc.start() + + @rpc + def stop(self) -> None: + super().stop() + """ + + def __init__( + self, + module: ModuleBase[Any], + args: list[str] | str, + *, + env: dict[str, str] | None = None, + cwd: str | None = None, + shell: bool = False, + on_exit: Callable[[], Any] | None = None, + shutdown_timeout: float = 10.0, + kill_timeout: float = 5.0, + log_json: bool = False, + log_tail_lines: int = 50, + **popen_kwargs: Any, + ) -> None: + self._args = args + self._env = env + self._cwd = cwd + self._shell = shell + self._on_exit = on_exit + self._shutdown_timeout = shutdown_timeout + self._kill_timeout = kill_timeout + self._log_json = log_json + self._log_tail_lines = log_tail_lines + self._popen_kwargs = popen_kwargs + self._process: subprocess.Popen[bytes] | None = None + self._watchdog: ModuleThread | None = None + self._module = module + self._stopped = ThreadSafeVal(False) + self.last_stdout: collections.deque[str] = collections.deque(maxlen=log_tail_lines) + self.last_stderr: collections.deque[str] = collections.deque(maxlen=log_tail_lines) + + module._disposables.add(Disposable(self.stop)) + + @property + def pid(self) -> int | None: + return self._process.pid if self._process is not None else None + + @property + def returncode(self) -> int | None: + if self._process is None: + return None + return self._process.poll() + + @property + def is_alive(self) -> bool: + return self._process is not None and self._process.poll() is None + + def start(self) -> None: + """Launch the subprocess and start the watchdog.""" + if self._process is not None and self._process.poll() is None: + logger.warning("Process already running", pid=self._process.pid) + return + + self._stopped.set(False) + + self.last_stdout = collections.deque(maxlen=self._log_tail_lines) + self.last_stderr = collections.deque(maxlen=self._log_tail_lines) + + logger.info( + "Starting process", + cmd=self._args if isinstance(self._args, str) else " ".join(self._args), + cwd=self._cwd, + ) + self._process = subprocess.Popen( + self._args, + env=self._env, + cwd=self._cwd, + shell=self._shell, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + **self._popen_kwargs, + ) + logger.info("Process started", pid=self._process.pid) + + self._watchdog = ModuleThread( + module=self._module, + target=self._watch, + name=f"proc-{self._process.pid}-watchdog", + ) + self._watchdog.start() + + def stop(self) -> None: + """Send SIGTERM, wait, escalate to SIGKILL if needed. Idempotent.""" + with self._stopped as stopped: + if stopped: + return + self._stopped.set(True) + + if self._process is not None and self._process.poll() is None: + logger.info("Stopping process", pid=self._process.pid) + try: + self._process.send_signal(signal.SIGTERM) + except OSError: + pass # process already dead (PID recycled or exited between poll and signal) + else: + try: + self._process.wait(timeout=self._shutdown_timeout) + except subprocess.TimeoutExpired: + logger.warning( + "Process did not exit, sending SIGKILL", + pid=self._process.pid, + ) + self._process.kill() + try: + self._process.wait(timeout=self._kill_timeout) + except subprocess.TimeoutExpired: + logger.error( + "Process did not exit after SIGKILL", + pid=self._process.pid, + ) + self._process = None + + def _watch(self) -> None: + """Watchdog: pipe logs, detect crashes.""" + proc = self._process + if proc is None: + return + + stdout_t = self._start_reader(proc.stdout, "info") + stderr_t = self._start_reader(proc.stderr, "warning") + rc = proc.wait() + stdout_t.join(timeout=2) + stderr_t.join(timeout=2) + + if self._stopped.get(): + return + + last_stdout = "\n".join(self.last_stdout) or None + last_stderr = "\n".join(self.last_stderr) or None + logger.error( + "Process died unexpectedly", + pid=proc.pid, + returncode=rc, + last_stdout=last_stdout, + last_stderr=last_stderr, + ) + if self._on_exit is not None: + self._on_exit() + + def _start_reader(self, stream: IO[bytes] | None, level: str) -> threading.Thread: + t = threading.Thread(target=self._read_stream, args=(stream, level), daemon=True) + t.start() + return t + + def _read_stream(self, stream: IO[bytes] | None, level: str) -> None: + if stream is None: + return + log_fn = getattr(logger, level) + is_stderr = level == "warning" + buf = self.last_stderr if is_stderr else self.last_stdout + for raw in stream: + line = raw.decode("utf-8", errors="replace").rstrip() + if not line: + continue + buf.append(line) + if self._log_json: + try: + data = json.loads(line) + event = data.pop("event", line) + log_fn(event, **data) + continue + except (json.JSONDecodeError, TypeError): + logger.warning("malformed JSON from process", raw=line) + proc = self._process + log_fn(line, pid=proc.pid if proc else None) + stream.close() + + +def safe_thread_map( + items: Sequence[T], + fn: Callable[[T], R], + on_errors: Callable[[list[tuple[T, R | Exception]], list[R], list[Exception]], Any] + | None = None, +) -> list[R]: + """Thread-pool map that waits for all items to finish before raising and a cleanup handler + + - Empty *items* → returns ``[]`` immediately. + - All succeed → returns results in input order. + - Any fail → calls ``on_errors(outcomes, successes, errors)`` where + *outcomes* is a list of ``(input, result_or_exception)`` pairs in input + order, *successes* is the list of successful results, and *errors* is + the list of exceptions. If *on_errors* raises, that exception propagates. + If *on_errors* returns normally, its return value is returned from + ``safe_thread_map``. If *on_errors* is ``None``, raises an + ``ExceptionGroup``. + + Example:: + + def start_service(name: str) -> Connection: + return connect(name) + + def cleanup( + outcomes: list[tuple[str, Connection | Exception]], + successes: list[Connection], + errors: list[Exception], + ) -> None: + for conn in successes: + conn.close() + raise ExceptionGroup("failed to start services", errors) + + connections = safe_thread_map( + ["db", "cache", "queue"], + start_service, + cleanup, # called only if any start_service() raises + ) + """ + if not items: + return [] + + outcomes: dict[int, R | Exception] = {} + + with ThreadPoolExecutor(max_workers=len(items)) as pool: + futures: dict[Future[R], int] = {pool.submit(fn, item): i for i, item in enumerate(items)} + for fut in as_completed(futures): + idx = futures[fut] + try: + outcomes[idx] = fut.result() + except Exception as e: + outcomes[idx] = e + + successes: list[R] = [] + errors: list[Exception] = [] + for v in outcomes.values(): + if isinstance(v, Exception): + errors.append(v) + else: + successes.append(v) + + if errors: + if on_errors is not None: + zipped = [(items[i], outcomes[i]) for i in range(len(items))] + return on_errors(zipped, successes, errors) # type: ignore[return-value, no-any-return] + raise ExceptionGroup("safe_thread_map failed", errors) + + return [outcomes[i] for i in range(len(items))] # type: ignore[misc] + + +def thread_start( + module: ModuleBase[Any], + *, + close_timeout: float = 2.0, + **thread_kwargs: Any, +) -> ModuleThread: + """Create a :class:`ModuleThread`, start it immediately, and return it. + + Convenience wrapper equivalent to:: + + t = ModuleThread(module, close_timeout=close_timeout, **thread_kwargs) + t.start() + return t + + Accepts the same arguments as :class:`ModuleThread`. + + Example:: + + class MyModule(Module): + @rpc + def start(self) -> None: + self._worker = thread_start( + self, + target=self._run_loop, + name="my-worker", + ) + + def _run_loop(self) -> None: + while self._worker.status.get() == "running": + do_work() + """ + t = ModuleThread(module, close_timeout=close_timeout, **thread_kwargs) + t.start() + return t diff --git a/dimos/utils/typing_utils.py b/dimos/utils/typing_utils.py new file mode 100644 index 0000000000..3592d5fdbb --- /dev/null +++ b/dimos/utils/typing_utils.py @@ -0,0 +1,45 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unify typing compatibility across multiple Python versions.""" + +from __future__ import annotations + +from collections.abc import Sequence +import sys + +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + +if sys.version_info < (3, 11): + + class ExceptionGroup(Exception): # type: ignore[no-redef] # noqa: N818 + """Minimal ExceptionGroup polyfill for Python 3.10.""" + + exceptions: tuple[BaseException, ...] + + def __init__(self, message: str, exceptions: Sequence[BaseException]) -> None: + super().__init__(message) + self.exceptions = tuple(exceptions) +else: + import builtins + + ExceptionGroup = builtins.ExceptionGroup # type: ignore[misc] + +__all__ = [ + "ExceptionGroup", + "TypeVar", +]