From ff2becde188d41ba4b4c891aa1b415ae3b65917b Mon Sep 17 00:00:00 2001 From: Summer Yang Date: Wed, 11 Mar 2026 17:30:00 -0700 Subject: [PATCH 01/21] fix(test): resolve thread leak failures in CI - Add mod.stop() to test_process_crash_triggers_stop so watchdog, LCM, and event-loop threads are properly joined from the test thread - Filter third-party daemon threads with generic names (Thread-\d+) in conftest monitor_threads to ignore torch/HF background threads that have no cleanup API --- dimos/conftest.py | 11 +++++++++++ dimos/core/test_native_module.py | 4 ++++ 2 files changed, 15 insertions(+) diff --git a/dimos/conftest.py b/dimos/conftest.py index 4ab8a401f8..5f7f30e882 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -14,6 +14,7 @@ import asyncio import os +import re import threading from dotenv import load_dotenv @@ -160,6 +161,16 @@ def monitor_threads(request): if not any(t.name.startswith(prefix) for prefix in expected_persistent_thread_prefixes) ] + # Filter out third-party daemon threads with generic names (e.g. "Thread-109"). + # On Python 3.12+ our own threads include the target function name in parens + # (e.g. "Thread-166 (run_forever)"), so this only matches unnamed threads + # from libraries like torch/HuggingFace that have no cleanup API. + new_threads = [ + t + for t in new_threads + if not (t.daemon and re.fullmatch(r"Thread-\d+", t.name)) + ] + # Filter out threads we've already seen (from previous tests) truly_new = [t for t in new_threads if t.ident not in _seen_threads] diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index e77b8f9a53..c9556493f5 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -106,6 +106,10 @@ def test_process_crash_triggers_stop() -> None: break assert mod._process is None, f"Watchdog did not clean up after process {pid} died" + # Explicitly stop to join watchdog, LCM, and event-loop threads from the + # test thread. 
The watchdog's self.stop() can't join itself, so these + # threads would otherwise leak. stop() is idempotent. + mod.stop() @pytest.mark.slow From 6bd7ad60c3a1e73e96793f8e5355404045b34bd3 Mon Sep 17 00:00:00 2001 From: SUMMERxYANG <69720581+SUMMERxYANG@users.noreply.github.com> Date: Thu, 12 Mar 2026 00:35:29 +0000 Subject: [PATCH 02/21] CI code cleanup --- dimos/conftest.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dimos/conftest.py b/dimos/conftest.py index 5f7f30e882..1a7a4f943b 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -166,9 +166,7 @@ def monitor_threads(request): # (e.g. "Thread-166 (run_forever)"), so this only matches unnamed threads # from libraries like torch/HuggingFace that have no cleanup API. new_threads = [ - t - for t in new_threads - if not (t.daemon and re.fullmatch(r"Thread-\d+", t.name)) + t for t in new_threads if not (t.daemon and re.fullmatch(r"Thread-\d+", t.name)) ] # Filter out threads we've already seen (from previous tests) From ee752c023dd6f20f8afbc6969a26de298210d9b8 Mon Sep 17 00:00:00 2001 From: Summer Yang Date: Thu, 12 Mar 2026 14:08:39 -0700 Subject: [PATCH 03/21] fix(test): use fixture for native module crash test cleanup Convert test_process_crash_triggers_stop to use a fixture that calls mod.stop() in teardown. The watchdog thread calls self.stop() but can't join itself, so an explicit stop() from the test thread is needed to properly clean up all threads. Drop the broad conftest regex filter for generic daemon thread names per review feedback. 
--- dimos/conftest.py | 8 -------- dimos/core/test_native_module.py | 33 ++++++++++++++++++++------------ 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/dimos/conftest.py b/dimos/conftest.py index 1a7a4f943b..29eaf05567 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -14,7 +14,6 @@ import asyncio import os -import re import threading from dotenv import load_dotenv @@ -161,13 +160,6 @@ def monitor_threads(request): if not any(t.name.startswith(prefix) for prefix in expected_persistent_thread_prefixes) ] - # Filter out third-party daemon threads with generic names (e.g. "Thread-109"). - # On Python 3.12+ our own threads include the target function name in parens - # (e.g. "Thread-166 (run_forever)"), so this only matches unnamed threads - # from libraries like torch/HuggingFace that have no cleanup API. - new_threads = [ - t for t in new_threads if not (t.daemon and re.fullmatch(r"Thread-\d+", t.name)) - ] # Filter out threads we've already seen (from previous tests) truly_new = [t for t in new_threads if t.ident not in _seen_threads] diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index c9556493f5..600795b031 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -18,8 +18,11 @@ The echo script writes received CLI args to a temp file for assertions. 
""" +from collections.abc import Generator +from dataclasses import dataclass import json from pathlib import Path +import threading import time import pytest @@ -90,26 +93,32 @@ def start(self) -> None: pass -def test_process_crash_triggers_stop() -> None: - """When the native process dies unexpectedly, the watchdog calls stop().""" +@pytest.fixture +def crash_module() -> Generator[StubNativeModule, None, None]: + """Create a StubNativeModule that dies after 0.2s, ensuring cleanup.""" mod = StubNativeModule(die_after=0.2) - mod.pointcloud.transport = LCMTransport("/pc", PointCloud2) - mod.start() + yield mod + # Join watchdog, LCM, and event-loop threads from the test thread. + # The watchdog's self.stop() can't join itself, so without this the + # threads leak. stop() is idempotent. + mod.stop() - assert mod._process is not None - pid = mod._process.pid + +def test_process_crash_triggers_stop(crash_module: StubNativeModule) -> None: + """When the native process dies unexpectedly, the watchdog calls stop().""" + crash_module.pointcloud.transport = LCMTransport("/pc", PointCloud2) + crash_module.start() + + assert crash_module._process is not None + pid = crash_module._process.pid # Wait for the process to die and the watchdog to call stop() for _ in range(30): time.sleep(0.1) - if mod._process is None: + if crash_module._process is None: break - assert mod._process is None, f"Watchdog did not clean up after process {pid} died" - # Explicitly stop to join watchdog, LCM, and event-loop threads from the - # test thread. The watchdog's self.stop() can't join itself, so these - # threads would otherwise leak. stop() is idempotent. 
- mod.stop() + assert crash_module._process is None, f"Watchdog did not clean up after process {pid} died" @pytest.mark.slow From e316626b09f2ef5b08a80973b2df2d1251251dec Mon Sep 17 00:00:00 2001 From: SUMMERxYANG <69720581+SUMMERxYANG@users.noreply.github.com> Date: Thu, 12 Mar 2026 21:09:29 +0000 Subject: [PATCH 04/21] CI code cleanup --- dimos/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dimos/conftest.py b/dimos/conftest.py index 29eaf05567..4ab8a401f8 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -160,7 +160,6 @@ def monitor_threads(request): if not any(t.name.startswith(prefix) for prefix in expected_persistent_thread_prefixes) ] - # Filter out threads we've already seen (from previous tests) truly_new = [t for t in new_threads if t.ident not in _seen_threads] From f13b2b325244a86ccc2a9a9d6568162723dcb826 Mon Sep 17 00:00:00 2001 From: Summer Yang Date: Thu, 12 Mar 2026 14:13:08 -0700 Subject: [PATCH 05/21] chore: retrigger CI From 3197ad3307a1a9e2f6c9d57589b14ad90d5e99cb Mon Sep 17 00:00:00 2001 From: Summer Yang Date: Thu, 12 Mar 2026 15:58:14 -0700 Subject: [PATCH 06/21] fix(test): join threads directly in crash_module fixture mod.stop() is a no-op when the watchdog already called it, so capture thread IDs before the test and join new ones in teardown. --- dimos/core/test_native_module.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index 600795b031..d5609003ba 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -96,12 +96,15 @@ def start(self) -> None: @pytest.fixture def crash_module() -> Generator[StubNativeModule, None, None]: """Create a StubNativeModule that dies after 0.2s, ensuring cleanup.""" + before = {t.ident for t in threading.enumerate()} mod = StubNativeModule(die_after=0.2) yield mod - # Join watchdog, LCM, and event-loop threads from the test thread. 
- # The watchdog's self.stop() can't join itself, so without this the - # threads leak. stop() is idempotent. - mod.stop() + # The watchdog calls stop() from its own thread, which sets + # _module_closed=True. A second stop() from here is then a no-op, + # so we explicitly join any threads the test created. + for t in threading.enumerate(): + if t.ident not in before and t is not threading.current_thread(): + t.join(timeout=5) def test_process_crash_triggers_stop(crash_module: StubNativeModule) -> None: From 43d5434b0779e5e92e6a4fd16fe61a74e14f94d3 Mon Sep 17 00:00:00 2001 From: SUMMERxYANG <69720581+SUMMERxYANG@users.noreply.github.com> Date: Thu, 12 Mar 2026 23:09:08 +0000 Subject: [PATCH 07/21] CI code cleanup --- dimos/core/test_native_module.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index d5609003ba..5811da4b08 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -19,7 +19,6 @@ """ from collections.abc import Generator -from dataclasses import dataclass import json from pathlib import Path import threading From 1ff8769015e7e9cae1d14fcddab506488e80b227 Mon Sep 17 00:00:00 2001 From: Summer Yang Date: Mon, 23 Mar 2026 16:58:44 -0700 Subject: [PATCH 08/21] fix(native_module): preserve watchdog reference so second stop() can join it --- dimos/core/native_module.py | 12 ++++++++--- dimos/core/test_native_module.py | 35 +++++++++++--------------------- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index f4a674cb5d..00cc2d77d0 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -183,11 +183,17 @@ def stop(self) -> None: ) self._process.kill() self._process.wait(timeout=5) - if self._watchdog is not None and self._watchdog is not threading.current_thread(): - self._watchdog.join(timeout=2) - self._watchdog = None self._process = None super().stop() + # Join 
the watchdog AFTER super().stop() so all module threads are + # cleaned up first. When the watchdog itself is the caller (crash + # path), it skips joining itself — but the thread exits naturally + # right after this returns. A second stop() from external code + # (e.g. test teardown) will reach here and join the now-finished + # watchdog thread, preventing monitor_threads from seeing a leak. + if self._watchdog is not None and self._watchdog is not threading.current_thread(): + self._watchdog.join(timeout=2) + self._watchdog = None def _watch_process(self) -> None: """Block until the native process exits; trigger stop() if it crashed.""" diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index 5811da4b08..a7e6bd2b9a 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -18,10 +18,8 @@ The echo script writes received CLI args to a temp file for assertions. """ -from collections.abc import Generator import json from pathlib import Path -import threading import time import pytest @@ -92,35 +90,26 @@ def start(self) -> None: pass -@pytest.fixture -def crash_module() -> Generator[StubNativeModule, None, None]: - """Create a StubNativeModule that dies after 0.2s, ensuring cleanup.""" - before = {t.ident for t in threading.enumerate()} - mod = StubNativeModule(die_after=0.2) - yield mod - # The watchdog calls stop() from its own thread, which sets - # _module_closed=True. A second stop() from here is then a no-op, - # so we explicitly join any threads the test created. 
- for t in threading.enumerate(): - if t.ident not in before and t is not threading.current_thread(): - t.join(timeout=5) - - -def test_process_crash_triggers_stop(crash_module: StubNativeModule) -> None: +def test_process_crash_triggers_stop() -> None: """When the native process dies unexpectedly, the watchdog calls stop().""" - crash_module.pointcloud.transport = LCMTransport("/pc", PointCloud2) - crash_module.start() + mod = StubNativeModule(die_after=0.2) + mod.pointcloud.transport = LCMTransport("/pc", PointCloud2) + mod.start() - assert crash_module._process is not None - pid = crash_module._process.pid + assert mod._process is not None + pid = mod._process.pid # Wait for the process to die and the watchdog to call stop() for _ in range(30): time.sleep(0.1) - if crash_module._process is None: + if mod._process is None: break - assert crash_module._process is None, f"Watchdog did not clean up after process {pid} died" + assert mod._process is None, f"Watchdog did not clean up after process {pid} died" + + # Join the watchdog thread. stop() is idempotent but will now join the + # watchdog on the second call since the reference is preserved. 
+ mod.stop() @pytest.mark.slow From c202c57a504fe4ac3a1c50ff3aa78b7e228403b6 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Tue, 24 Mar 2026 10:39:47 -0700 Subject: [PATCH 09/21] minimal fix --- dimos/core/native_module.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 00cc2d77d0..1a591772f4 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -169,8 +169,8 @@ def start(self) -> None: self._watchdog = threading.Thread(target=self._watch_process, daemon=True) self._watchdog.start() - @rpc - def stop(self) -> None: + def _clean_all_but_watchdog(self) -> None: + """A cleanup helper designed to be called inside of the watchdog and outside of the watchdong""" self._stopping = True if self._process is not None and self._process.poll() is None: logger.info("Stopping native process", pid=self._process.pid) @@ -182,21 +182,23 @@ def stop(self) -> None: "Native process did not exit, sending SIGKILL", pid=self._process.pid ) self._process.kill() - self._process.wait(timeout=5) + try: + self._process.wait(timeout=5) + except Exception as error: + print(f'''error = {error}''') self._process = None + + + @rpc + def stop(self) -> None: + self._clean_all_but_watchdog() super().stop() - # Join the watchdog AFTER super().stop() so all module threads are - # cleaned up first. When the watchdog itself is the caller (crash - # path), it skips joining itself — but the thread exits naturally - # right after this returns. A second stop() from external code - # (e.g. test teardown) will reach here and join the now-finished - # watchdog thread, preventing monitor_threads from seeing a leak. 
- if self._watchdog is not None and self._watchdog is not threading.current_thread(): + if self._watchdog is not None: self._watchdog.join(timeout=2) self._watchdog = None def _watch_process(self) -> None: - """Block until the native process exits; trigger stop() if it crashed.""" + """Block until the native process exits; trigger cleanup if it crashed.""" if self._process is None: return @@ -213,7 +215,7 @@ def _watch_process(self) -> None: pid=self._process.pid, returncode=rc, ) - self.stop() + self._clean_all_but_watchdog() def _start_reader(self, stream: IO[bytes] | None, level: str) -> threading.Thread: """Spawn a daemon thread that pipes a subprocess stream through the logger.""" From fe787bbdc6e574f78f228590f4f8ae2b7beea555 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Tue, 24 Mar 2026 22:43:50 -0700 Subject: [PATCH 10/21] fully ideal approach, untested --- dimos/agents/mcp/mcp_server.py | 5 +- dimos/core/module.py | 76 +- dimos/core/native_module.py | 125 +-- dimos/core/test_core.py | 2 +- dimos/perception/detection/conftest.py | 10 +- .../perception/detection/reid/test_module.py | 2 +- dimos/robot/unitree/b1/test_connection.py | 18 +- dimos/utils/test_thread_utils.py | 897 ++++++++++++++++++ dimos/utils/thread_utils.py | 542 +++++++++++ dimos/utils/typing_utils.py | 45 + 10 files changed, 1544 insertions(+), 178 deletions(-) create mode 100644 dimos/utils/test_thread_utils.py create mode 100644 dimos/utils/thread_utils.py create mode 100644 dimos/utils/typing_utils.py diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index e5697542fb..e1179abe55 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -195,7 +195,7 @@ def start(self) -> None: def stop(self) -> None: if self._uvicorn_server: self._uvicorn_server.should_exit = True - loop = self._loop + loop = self._async_thread.loop if loop is not None and self._serve_future is not None: self._serve_future.result(timeout=5.0) self._uvicorn_server = 
None @@ -269,6 +269,5 @@ def _start_server(self, port: int | None = None) -> None: config = uvicorn.Config(app, host=_host, port=_port, log_level="info") server = uvicorn.Server(config) self._uvicorn_server = server - loop = self._loop - assert loop is not None + loop = self._async_thread.loop self._serve_future = asyncio.run_coroutine_threadsafe(server.serve(), loop) diff --git a/dimos/core/module.py b/dimos/core/module.py index ab21ce17a9..40fd34f5be 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -22,6 +22,7 @@ from typing import ( TYPE_CHECKING, Any, + Literal, Protocol, get_args, get_origin, @@ -43,6 +44,9 @@ from dimos.protocol.tf import LCMTF, TFSpec from dimos.utils import colors from dimos.utils.generic import classproperty +from dimos.utils.thread_utils import AsyncModuleThread, ThreadSafeVal + +ModState = Literal["init", "started", "stopping", "stopped"] if TYPE_CHECKING: from dimos.core.blueprints import Blueprint @@ -62,19 +66,6 @@ class SkillInfo: args_schema: str -def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: - try: - running_loop = asyncio.get_running_loop() - return running_loop, None - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - thr = threading.Thread(target=loop.run_forever, daemon=True) - thr.start() - return loop, thr - - class ModuleConfig(BaseConfig): rpc_transport: type[RPCSpec] = LCMRPC tf_transport: type[TFSpec] = LCMTF # type: ignore[type-arg] @@ -96,20 +87,20 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): _rpc: RPCSpec | None = None _tf: TFSpec[Any] | None = None - _loop: asyncio.AbstractEventLoop | None = None - _loop_thread: threading.Thread | None + _async_thread: AsyncModuleThread _disposables: CompositeDisposable _bound_rpc_calls: dict[str, RpcCall] = {} - _module_closed: bool = False - _module_closed_lock: threading.Lock + mod_state: ThreadSafeVal[ModState] rpc_calls: list[str] = [] def __init__(self, config_args: dict[str, 
Any]): super().__init__(**config_args) - self._module_closed_lock = threading.Lock() - self._loop, self._loop_thread = get_loop() self._disposables = CompositeDisposable() + self.mod_state = ThreadSafeVal[ModState]("init") + self._async_thread = AsyncModuleThread( # NEEDS to be created after self._disposables exists + module=self + ) try: self.rpc = self.config.rpc_transport() self.rpc.serve_module_rpc(self) @@ -126,38 +117,30 @@ def frame_id(self) -> str: @rpc def start(self) -> None: - pass + with self.mod_state as state: + if state == "stopped": + raise RuntimeError(f"{type(self).__name__} cannot be restarted after stop") + self.mod_state.set("started") @rpc def stop(self) -> None: - self._close_module() + self._stop() - def _close_module(self) -> None: - with self._module_closed_lock: - if self._module_closed: + def _stop(self) -> None: + with self.mod_state as state: + if state in ("stopping", "stopped"): return - self._module_closed = True - - self._close_rpc() - - # Save into local variables to avoid race when stopping concurrently - # (from RPC and worker shutdown) - loop_thread = getattr(self, "_loop_thread", None) - loop = getattr(self, "_loop", None) + self.mod_state.set("stopping") - if loop_thread: - if loop_thread.is_alive(): - if loop: - loop.call_soon_threadsafe(loop.stop) - loop_thread.join(timeout=2) - self._loop = None - self._loop_thread = None + if self.rpc: + self.rpc.stop() # type: ignore[attr-defined] + self.rpc = None # type: ignore[assignment] if hasattr(self, "_tf") and self._tf is not None: self._tf.stop() self._tf = None if hasattr(self, "_disposables"): - self._disposables.dispose() + self._disposables.dispose() # stops _async_thread via disposable # Break the In/Out -> owner -> self reference cycle so the instance # can be freed by refcount instead of waiting for GC. 
@@ -165,19 +148,12 @@ def _close_module(self) -> None: if isinstance(attr, (In, Out)): attr.owner = None - def _close_rpc(self) -> None: - if self.rpc: - self.rpc.stop() # type: ignore[attr-defined] - self.rpc = None # type: ignore[assignment] - def __getstate__(self): # type: ignore[no-untyped-def] """Exclude unpicklable runtime attributes when serializing.""" state = self.__dict__.copy() # Remove unpicklable attributes state.pop("_disposables", None) - state.pop("_module_closed_lock", None) - state.pop("_loop", None) - state.pop("_loop_thread", None) + state.pop("_async_thread", None) state.pop("_rpc", None) state.pop("_tf", None) return state @@ -187,9 +163,7 @@ def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] self.__dict__.update(state) # Reinitialize runtime attributes self._disposables = CompositeDisposable() - self._module_closed_lock = threading.Lock() - self._loop = None - self._loop_thread = None + self._async_thread = None # type: ignore[assignment] self._rpc = None self._tf = None diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 1a591772f4..bdc46e2cab 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -42,20 +42,18 @@ class MyCppModule(NativeModule): import enum import inspect -import json import os from pathlib import Path -import signal import subprocess import sys -import threading -from typing import IO, Any +from typing import Any from pydantic import Field from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.utils.logging_config import setup_logger +from dimos.utils.thread_utils import ModuleProcess if sys.version_info < (3, 13): from typing_extensions import TypeVar @@ -128,9 +126,7 @@ class NativeModule(Module[_NativeConfig]): """ default_config: type[_NativeConfig] = NativeModuleConfig # type: ignore[assignment] - _process: subprocess.Popen[bytes] | None = None - _watchdog: threading.Thread | None = None - _stopping: bool = False 
+ _proc: ModuleProcess | None = None def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -138,109 +134,22 @@ def __init__(self, **kwargs: Any) -> None: @rpc def start(self) -> None: - if self._process is not None and self._process.poll() is None: - logger.warning("Native process already running", pid=self._process.pid) - return - + super().start() self._maybe_build() - - topics = self._collect_topics() - - cmd = [self.config.executable] - for name, topic_str in topics.items(): - cmd.extend([f"--{name}", topic_str]) - cmd.extend(self.config.to_cli_args()) - cmd.extend(self.config.extra_args) - - env = {**os.environ, **self.config.extra_env} - cwd = self.config.cwd or str(Path(self.config.executable).resolve().parent) - - logger.info("Starting native process", cmd=" ".join(cmd), cwd=cwd) - self._process = subprocess.Popen( - cmd, - env=env, - cwd=cwd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - logger.info("Native process started", pid=self._process.pid) - - self._stopping = False - self._watchdog = threading.Thread(target=self._watch_process, daemon=True) - self._watchdog.start() - - def _clean_all_but_watchdog(self) -> None: - """A cleanup helper designed to be called inside of the watchdog and outside of the watchdong""" - self._stopping = True - if self._process is not None and self._process.poll() is None: - logger.info("Stopping native process", pid=self._process.pid) - self._process.send_signal(signal.SIGTERM) - try: - self._process.wait(timeout=self.config.shutdown_timeout) - except subprocess.TimeoutExpired: - logger.warning( - "Native process did not exit, sending SIGKILL", pid=self._process.pid - ) - self._process.kill() - try: - self._process.wait(timeout=5) - except Exception as error: - print(f'''error = {error}''') - self._process = None - - - @rpc - def stop(self) -> None: - self._clean_all_but_watchdog() - super().stop() - if self._watchdog is not None: - self._watchdog.join(timeout=2) - self._watchdog = None - - 
def _watch_process(self) -> None: - """Block until the native process exits; trigger cleanup if it crashed.""" - if self._process is None: - return - - stdout_t = self._start_reader(self._process.stdout, "info") - stderr_t = self._start_reader(self._process.stderr, "warning") - rc = self._process.wait() - stdout_t.join(timeout=2) - stderr_t.join(timeout=2) - - if self._stopping: - return - logger.error( - "Native process died unexpectedly", - pid=self._process.pid, - returncode=rc, + self._proc = ModuleProcess( + module=self, + args=[ + self.config.executable, + *[arg for name, topic in self._collect_topics().items() for arg in (f"--{name}", topic)], + *self.config.to_cli_args(), + *self.config.extra_args, + ], + env={**os.environ, **self.config.extra_env}, + cwd=self.config.cwd or str(Path(self.config.executable).resolve().parent), + on_exit=self.stop, + shutdown_timeout=self.config.shutdown_timeout, + log_json=self.config.log_format == LogFormat.JSON, ) - self._clean_all_but_watchdog() - - def _start_reader(self, stream: IO[bytes] | None, level: str) -> threading.Thread: - """Spawn a daemon thread that pipes a subprocess stream through the logger.""" - t = threading.Thread(target=self._read_log_stream, args=(stream, level), daemon=True) - t.start() - return t - - def _read_log_stream(self, stream: IO[bytes] | None, level: str) -> None: - if stream is None: - return - log_fn = getattr(logger, level) - for raw in stream: - line = raw.decode("utf-8", errors="replace").rstrip() - if not line: - continue - if self.config.log_format == LogFormat.JSON: - try: - data = json.loads(line) - event = data.pop("event", line) - log_fn(event, **data) - continue - except (json.JSONDecodeError, TypeError): - logger.warning("malformed JSON from native module", raw=line) - log_fn(line, pid=self._process.pid if self._process else None) - stream.close() def _resolve_paths(self) -> None: """Resolve relative ``cwd`` and ``executable`` against the subclass's source file.""" diff --git 
a/dimos/core/test_core.py b/dimos/core/test_core.py index 3bd1383761..1400ae1ebb 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -89,7 +89,7 @@ def test_classmethods() -> None: ) assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" - nav._close_module() + nav._stop() @pytest.mark.slow diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index 8c1a65eb8b..81f6d1805b 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -221,7 +221,7 @@ def moment_provider(**kwargs) -> Moment2D: yield moment_provider moment_provider.cache_clear() - module._close_module() + module._stop() @pytest.fixture(scope="session") @@ -256,7 +256,7 @@ def moment_provider(**kwargs) -> Moment3D: yield moment_provider moment_provider.cache_clear() if module is not None: - module._close_module() + module._stop() @pytest.fixture(scope="session") @@ -290,9 +290,9 @@ def object_db_module(get_moment): yield moduleDB - module2d._close_module() - module3d._close_module() - moduleDB._close_module() + module2d._stop() + module3d._stop() + moduleDB._stop() @pytest.fixture(scope="session") diff --git a/dimos/perception/detection/reid/test_module.py b/dimos/perception/detection/reid/test_module.py index f5672c1f67..e2c603c307 100644 --- a/dimos/perception/detection/reid/test_module.py +++ b/dimos/perception/detection/reid/test_module.py @@ -40,5 +40,5 @@ def test_reid_ingress(imageDetections2d) -> None: print("Processing detections through ReidModule...") reid_module.annotations._transport = LCMTransport("/annotations", ImageAnnotations) reid_module.ingress(imageDetections2d) - reid_module._close_module() + reid_module._stop() print("✓ ReidModule ingress test completed successfully") diff --git a/dimos/robot/unitree/b1/test_connection.py b/dimos/robot/unitree/b1/test_connection.py index e43a3124dc..011853d172 100644 --- a/dimos/robot/unitree/b1/test_connection.py +++ 
b/dimos/robot/unitree/b1/test_connection.py @@ -73,7 +73,7 @@ def test_watchdog_actually_zeros_commands(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_watchdog_resets_on_new_command(self) -> None: """Test that watchdog timeout resets when new command arrives.""" @@ -121,7 +121,7 @@ def test_watchdog_resets_on_new_command(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_watchdog_thread_efficiency(self) -> None: """Test that watchdog uses only one thread regardless of command rate.""" @@ -155,7 +155,7 @@ def test_watchdog_thread_efficiency(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_watchdog_with_send_loop_blocking(self) -> None: """Test that watchdog still works if send loop blocks.""" @@ -202,7 +202,7 @@ def blocking_send_loop() -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_continuous_commands_prevent_timeout(self) -> None: """Test that continuous commands prevent watchdog timeout.""" @@ -237,7 +237,7 @@ def test_continuous_commands_prevent_timeout(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_watchdog_timing_accuracy(self) -> None: """Test that watchdog zeros commands at approximately 200ms.""" @@ -278,7 +278,7 @@ def test_watchdog_timing_accuracy(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_mode_changes_with_watchdog(self) -> None: """Test that mode changes work 
correctly with watchdog.""" @@ -321,7 +321,7 @@ def test_mode_changes_with_watchdog(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_watchdog_stops_movement_when_commands_stop(self) -> None: """Verify watchdog zeros commands when packets stop being sent.""" @@ -379,7 +379,7 @@ def test_watchdog_stops_movement_when_commands_stop(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_rapid_command_thread_safety(self) -> None: """Test thread safety with rapid commands from multiple threads.""" @@ -428,4 +428,4 @@ def send_commands(thread_id) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py new file mode 100644 index 0000000000..87c4a883e2 --- /dev/null +++ b/dimos/utils/test_thread_utils.py @@ -0,0 +1,897 @@ +"""Exhaustive tests for dimos/utils/thread_utils.py + +Covers: ThreadSafeVal, ModuleThread, AsyncModuleThread, ModuleProcess, safe_thread_map. +Focuses on deadlocks, race conditions, idempotency, and edge cases under load. 
+""" + +from __future__ import annotations + +import asyncio +import os +import pickle +import signal +import subprocess +import sys +import threading +import time +from unittest import mock + +import pytest +from reactivex.disposable import CompositeDisposable + +from dimos.utils.thread_utils import ( + AsyncModuleThread, + ModuleProcess, + ModuleThread, + ThreadSafeVal, + safe_thread_map, +) + + +# --------------------------------------------------------------------------- +# Helpers: fake ModuleBase for testing ModuleThread / AsyncModuleThread / ModuleProcess +# --------------------------------------------------------------------------- + + +class FakeModule: + """Minimal stand-in for ModuleBase — just needs _disposables.""" + + def __init__(self) -> None: + self._disposables = CompositeDisposable() + + def dispose(self) -> None: + self._disposables.dispose() + + +# =================================================================== +# ThreadSafeVal Tests +# =================================================================== + + +class TestThreadSafeVal: + def test_basic_get_set(self) -> None: + v = ThreadSafeVal(42) + assert v.get() == 42 + v.set(99) + assert v.get() == 99 + + def test_bool_truthy(self) -> None: + v = ThreadSafeVal(True) + assert bool(v) is True + v.set(False) + assert bool(v) is False + + def test_bool_zero(self) -> None: + v = ThreadSafeVal(0) + assert bool(v) is False + v.set(1) + assert bool(v) is True + + def test_context_manager_returns_value(self) -> None: + v = ThreadSafeVal("hello") + with v as val: + assert val == "hello" + + def test_set_inside_context_manager_no_deadlock(self) -> None: + """The critical test: set() inside a with block must NOT deadlock. + + This was a confirmed bug when using threading.Lock (non-reentrant). + Fixed by using threading.RLock. 
+ """ + v = ThreadSafeVal(0) + result = threading.Event() + + def do_it() -> None: + with v as val: + v.set(val + 1) + result.set() + + t = threading.Thread(target=do_it) + t.start() + t.join(timeout=2) + assert result.is_set(), "Deadlocked! set() inside with block hung" + assert v.get() == 1 + + def test_get_inside_context_manager_no_deadlock(self) -> None: + v = ThreadSafeVal(10) + result = threading.Event() + + def do_it() -> None: + with v as val: + _ = v.get() + result.set() + + t = threading.Thread(target=do_it) + t.start() + t.join(timeout=2) + assert result.is_set(), "Deadlocked! get() inside with block hung" + + def test_bool_inside_context_manager_no_deadlock(self) -> None: + v = ThreadSafeVal(True) + result = threading.Event() + + def do_it() -> None: + with v as val: + _ = bool(v) + result.set() + + t = threading.Thread(target=do_it) + t.start() + t.join(timeout=2) + assert result.is_set(), "Deadlocked! bool() inside with block hung" + + def test_context_manager_blocks_other_threads(self) -> None: + """While one thread holds the lock via `with`, others should block on set().""" + v = ThreadSafeVal(0) + gate = threading.Event() + other_started = threading.Event() + other_finished = threading.Event() + + def holder() -> None: + with v as val: + gate.wait(timeout=5) # hold the lock until signaled + + def setter() -> None: + other_started.set() + v.set(42) # should block until holder releases + other_finished.set() + + t1 = threading.Thread(target=holder) + t2 = threading.Thread(target=setter) + t1.start() + time.sleep(0.05) # let holder acquire lock + t2.start() + other_started.wait(timeout=2) + time.sleep(0.1) + # setter should be blocked + assert not other_finished.is_set(), "set() did not block while lock was held" + gate.set() # release holder + t1.join(timeout=2) + t2.join(timeout=2) + assert other_finished.is_set() + assert v.get() == 42 + + def test_concurrent_increments(self) -> None: + """Many threads doing atomic read-modify-write should not lose 
updates.""" + v = ThreadSafeVal(0) + n_threads = 50 + n_increments = 100 + + def incrementer() -> None: + for _ in range(n_increments): + with v as val: + v.set(val + 1) + + threads = [threading.Thread(target=incrementer) for _ in range(n_threads)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + assert v.get() == n_threads * n_increments + + def test_concurrent_increments_stress(self) -> None: + """Run the concurrent increment test multiple times to catch races.""" + for _ in range(10): + self.test_concurrent_increments() + + def test_pickle_roundtrip(self) -> None: + v = ThreadSafeVal({"key": [1, 2, 3]}) + data = pickle.dumps(v) + v2 = pickle.loads(data) + assert v2.get() == {"key": [1, 2, 3]} + # Verify the new instance has a working lock + with v2 as val: + v2.set({**val, "new": True}) + assert v2.get()["new"] is True + + def test_repr(self) -> None: + v = ThreadSafeVal("test") + assert repr(v) == "ThreadSafeVal('test')" + + def test_dict_type(self) -> None: + v = ThreadSafeVal({"running": False, "count": 0}) + with v as s: + v.set({**s, "running": True}) + assert v.get() == {"running": True, "count": 0} + + def test_string_literal_type(self) -> None: + """Simulates the ModState pattern from module.py.""" + v = ThreadSafeVal("init") + with v as state: + if state == "init": + v.set("started") + assert v.get() == "started" + + with v as state: + if state in ("stopping", "stopped"): + pass # no-op + else: + v.set("stopping") + assert v.get() == "stopping" + + def test_nested_with_no_deadlock(self) -> None: + """RLock should allow the same thread to nest with blocks.""" + v = ThreadSafeVal(0) + result = threading.Event() + + def do_it() -> None: + with v as val1: + with v as val2: + v.set(val2 + 1) + result.set() + + t = threading.Thread(target=do_it) + t.start() + t.join(timeout=2) + assert result.is_set(), "Nested with blocks deadlocked!" 
+ + +# =================================================================== +# ModuleThread Tests +# =================================================================== + + +class TestModuleThread: + def test_basic_lifecycle(self) -> None: + mod = FakeModule() + ran = threading.Event() + + def target() -> None: + ran.set() + + mt = ModuleThread(module=mod, target=target, name="test-basic") + ran.wait(timeout=2) + assert ran.is_set() + mt.stop() + assert not mt.is_alive + + def test_auto_start(self) -> None: + mod = FakeModule() + started = threading.Event() + mt = ModuleThread(module=mod, target=started.set, name="test-autostart") + started.wait(timeout=2) + assert started.is_set() + mt.stop() + + def test_deferred_start(self) -> None: + mod = FakeModule() + started = threading.Event() + mt = ModuleThread(module=mod, target=started.set, name="test-deferred", start=False) + time.sleep(0.1) + assert not started.is_set() + mt.start() + started.wait(timeout=2) + assert started.is_set() + mt.stop() + + def test_stopping_property(self) -> None: + mod = FakeModule() + saw_stopping = threading.Event() + holder: list[ModuleThread] = [] + + def target() -> None: + while not holder[0].stopping: + time.sleep(0.01) + saw_stopping.set() + + mt = ModuleThread(module=mod, target=target, name="test-stopping", start=False) + holder.append(mt) + mt.start() + time.sleep(0.05) + mt.stop() + saw_stopping.wait(timeout=2) + assert saw_stopping.is_set() + + def test_stop_idempotent(self) -> None: + mod = FakeModule() + mt = ModuleThread(module=mod, target=lambda: time.sleep(0.01), name="test-idem") + time.sleep(0.05) + mt.stop() + mt.stop() # second call should not raise + mt.stop() # third call should not raise + + def test_stop_from_managed_thread_no_deadlock(self) -> None: + """The thread calling stop() on itself should not deadlock.""" + mod = FakeModule() + result = threading.Event() + holder: list[ModuleThread] = [] + + def target() -> None: + holder[0].stop() # stop ourselves — 
should not deadlock + result.set() + + mt = ModuleThread(module=mod, target=target, name="test-self-stop", start=False) + holder.append(mt) + mt.start() + result.wait(timeout=3) + assert result.is_set(), "Deadlocked when thread called stop() on itself" + + def test_dispose_stops_thread(self) -> None: + """Module dispose should stop the thread via the registered Disposable.""" + mod = FakeModule() + running = threading.Event() + holder: list[ModuleThread] = [] + + def target() -> None: + running.set() + while not holder[0].stopping: + time.sleep(0.01) + + mt = ModuleThread(module=mod, target=target, name="test-dispose", start=False) + holder.append(mt) + mt.start() + running.wait(timeout=2) + mod.dispose() + time.sleep(0.1) + assert not mt.is_alive + + def test_concurrent_stop_calls(self) -> None: + """Multiple threads calling stop() concurrently should not crash.""" + mod = FakeModule() + holder: list[ModuleThread] = [] + + def target() -> None: + while not holder[0].stopping: + time.sleep(0.01) + + mt = ModuleThread(module=mod, target=target, name="test-concurrent-stop", start=False) + holder.append(mt) + mt.start() + time.sleep(0.05) + + errors = [] + + def stop_it() -> None: + try: + mt.stop() + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=stop_it) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=5) + assert not errors, f"Concurrent stop() raised: {errors}" + + def test_close_timeout_respected(self) -> None: + """If the thread ignores the stop signal, stop() should return after close_timeout.""" + mod = FakeModule() + + def stubborn_target() -> None: + time.sleep(10) # ignores stopping signal + + mt = ModuleThread( + module=mod, target=stubborn_target, name="test-timeout", close_timeout=0.2 + ) + start = time.monotonic() + mt.stop() + elapsed = time.monotonic() - start + assert elapsed < 1.0, f"stop() took {elapsed}s, expected ~0.2s" + + def test_stop_concurrent_with_dispose(self) -> 
None: + """Calling stop() and dispose() concurrently should not crash.""" + for _ in range(20): + mod = FakeModule() + holder: list[ModuleThread] = [] + + def target() -> None: + while not holder[0].stopping: + time.sleep(0.001) + + mt = ModuleThread(module=mod, target=target, name="test-stop-dispose", start=False) + holder.append(mt) + mt.start() + time.sleep(0.02) + # Race: stop and dispose from different threads + t1 = threading.Thread(target=mt.stop) + t2 = threading.Thread(target=mod.dispose) + t1.start() + t2.start() + t1.join(timeout=3) + t2.join(timeout=3) + + +# =================================================================== +# AsyncModuleThread Tests +# =================================================================== + + +class TestAsyncModuleThread: + def test_creates_loop_and_thread(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + assert amt.loop is not None + assert amt.loop.is_running() + assert amt.is_alive + amt.stop() + assert not amt.is_alive + + def test_stop_idempotent(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + amt.stop() + amt.stop() # should not raise + amt.stop() + + def test_dispose_stops_loop(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + assert amt.is_alive + mod.dispose() + time.sleep(0.1) + assert not amt.is_alive + + def test_can_schedule_coroutine(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + result = [] + + async def coro() -> None: + result.append(42) + + future = asyncio.run_coroutine_threadsafe(coro(), amt.loop) + future.result(timeout=2) + assert result == [42] + amt.stop() + + def test_stop_with_pending_work(self) -> None: + """Stop should succeed even with long-running tasks on the loop.""" + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + started = threading.Event() + + async def slow_coro() -> None: + started.set() + await asyncio.sleep(10) + + asyncio.run_coroutine_threadsafe(slow_coro(), 
amt.loop) + started.wait(timeout=2) + # stop() should not hang waiting for the coroutine + start = time.monotonic() + amt.stop() + elapsed = time.monotonic() - start + assert elapsed < 5.0, f"stop() hung for {elapsed}s with pending coroutine" + + def test_concurrent_stop(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + errors = [] + + def stop_it() -> None: + try: + amt.stop() + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=stop_it) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=5) + assert not errors + + +# =================================================================== +# ModuleProcess Tests +# =================================================================== + + +# Helper: path to a python that sleeps or echoes +PYTHON = sys.executable + + +class TestModuleProcess: + def test_basic_lifecycle(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + shutdown_timeout=2.0, + ) + assert mp.is_alive + assert mp.pid is not None + mp.stop() + assert not mp.is_alive + assert mp.pid is None + + def test_stop_idempotent(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + shutdown_timeout=1.0, + ) + mp.stop() + mp.stop() # should not raise + mp.stop() + + def test_dispose_stops_process(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + shutdown_timeout=2.0, + ) + pid = mp.pid + mod.dispose() + time.sleep(0.5) + assert not mp.is_alive + + def test_on_exit_fires_on_natural_exit(self) -> None: + """on_exit should fire when the process exits on its own.""" + mod = FakeModule() + exit_called = threading.Event() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "print('done')"], + on_exit=exit_called.set, + ) + exit_called.wait(timeout=5) + 
assert exit_called.is_set(), "on_exit was not called after natural process exit" + + def test_on_exit_fires_on_crash(self) -> None: + mod = FakeModule() + exit_called = threading.Event() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import sys; sys.exit(1)"], + on_exit=exit_called.set, + ) + exit_called.wait(timeout=5) + assert exit_called.is_set(), "on_exit was not called after process crash" + + def test_on_exit_not_fired_on_stop(self) -> None: + """on_exit should NOT fire when stop() kills the process.""" + mod = FakeModule() + exit_called = threading.Event() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + on_exit=exit_called.set, + shutdown_timeout=2.0, + ) + time.sleep(0.2) # let watchdog start + mp.stop() + time.sleep(1.0) # give watchdog time to potentially fire + assert not exit_called.is_set(), "on_exit fired after intentional stop()" + + def test_stdout_logged(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "print('hello from subprocess')"], + ) + time.sleep(1.0) # let output be read + mp.stop() + + def test_stderr_logged(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import sys; sys.stderr.write('error msg\\n')"], + ) + time.sleep(1.0) + mp.stop() + + def test_log_json_mode(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", """import json; print(json.dumps({"event": "test", "key": "val"}))"""], + log_json=True, + ) + time.sleep(1.0) + mp.stop() + + def test_log_json_malformed(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "print('not json')"], + log_json=True, + ) + time.sleep(1.0) + mp.stop() + + def test_stop_process_that_ignores_sigterm(self) -> None: + """Process that ignores SIGTERM should be killed with SIGKILL.""" + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[ + PYTHON, + "-c", + "import 
signal, time; signal.signal(signal.SIGTERM, signal.SIG_IGN); time.sleep(60)", + ], + shutdown_timeout=0.5, + kill_timeout=2.0, + ) + time.sleep(0.2) + start = time.monotonic() + mp.stop() + elapsed = time.monotonic() - start + assert not mp.is_alive + # Should take roughly shutdown_timeout (0.5) + a bit for SIGKILL + assert elapsed < 5.0 + + def test_stop_already_dead_process(self) -> None: + """stop() on a process that already exited should not raise.""" + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "pass"], # exits immediately + ) + time.sleep(1.0) # let it die + mp.stop() # should not raise + + def test_concurrent_stop(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + shutdown_timeout=2.0, + ) + errors = [] + + def stop_it() -> None: + try: + mp.stop() + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=stop_it) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + assert not errors, f"Concurrent stop() raised: {errors}" + + def test_on_exit_calls_module_stop_no_deadlock(self) -> None: + """Simulate the real pattern: on_exit=module.stop, which disposes the + ModuleProcess, which tries to stop its watchdog from inside the watchdog. + Must not deadlock. + """ + mod = FakeModule() + stop_called = threading.Event() + + def fake_module_stop() -> None: + """Simulates module.stop() -> _stop() -> dispose()""" + mod.dispose() + stop_called.set() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "pass"], # exits immediately + on_exit=fake_module_stop, + ) + stop_called.wait(timeout=5) + assert stop_called.is_set(), "Deadlocked! 
on_exit -> dispose -> stop chain hung" + + def test_on_exit_calls_module_stop_no_deadlock_stress(self) -> None: + """Run the deadlock test multiple times under load.""" + for i in range(10): + self.test_on_exit_calls_module_stop_no_deadlock() + + def test_deferred_start(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + start=False, + ) + assert not mp.is_alive + mp.start() + assert mp.is_alive + mp.stop() + + def test_env_passed(self) -> None: + mod = FakeModule() + exit_called = threading.Event() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import os, sys; sys.exit(0 if os.environ.get('MY_VAR') == '42' else 1)"], + env={**os.environ, "MY_VAR": "42"}, + on_exit=exit_called.set, + ) + exit_called.wait(timeout=5) + # Process should have exited with 0 (our on_exit fires for all unmanaged exits) + assert exit_called.is_set() + + def test_cwd_passed(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import os; print(os.getcwd())"], + cwd="/tmp", + ) + time.sleep(1.0) + mp.stop() + + +# =================================================================== +# safe_thread_map Tests +# =================================================================== + + +class TestSafeThreadMap: + def test_empty_input(self) -> None: + assert safe_thread_map([], lambda x: x) == [] + + def test_all_succeed(self) -> None: + result = safe_thread_map([1, 2, 3], lambda x: x * 2) + assert result == [2, 4, 6] + + def test_preserves_order(self) -> None: + def slow(x: int) -> int: + time.sleep(0.01 * (10 - x)) + return x + + result = safe_thread_map(list(range(10)), slow) + assert result == list(range(10)) + + def test_all_fail_raises_exception_group(self) -> None: + def fail(x: int) -> int: + raise ValueError(f"fail-{x}") + + with pytest.raises(ExceptionGroup) as exc_info: + safe_thread_map([1, 2, 3], fail) + assert len(exc_info.value.exceptions) == 3 + + def 
test_partial_failure(self) -> None: + def maybe_fail(x: int) -> int: + if x == 2: + raise ValueError("fail") + return x + + with pytest.raises(ExceptionGroup) as exc_info: + safe_thread_map([1, 2, 3], maybe_fail) + assert len(exc_info.value.exceptions) == 1 + + def test_on_errors_callback(self) -> None: + def fail(x: int) -> int: + if x == 2: + raise ValueError("boom") + return x * 10 + + cleanup_called = False + + def on_errors(outcomes, successes, errors): + nonlocal cleanup_called + cleanup_called = True + assert len(errors) == 1 + assert len(successes) == 2 + return successes # return successful results + + result = safe_thread_map([1, 2, 3], fail, on_errors) + assert cleanup_called + assert sorted(result) == [10, 30] + + def test_on_errors_can_raise(self) -> None: + def fail(x: int) -> int: + raise ValueError("boom") + + def on_errors(outcomes, successes, errors): + raise RuntimeError("custom error") + + with pytest.raises(RuntimeError, match="custom error"): + safe_thread_map([1], fail, on_errors) + + def test_waits_for_all_before_raising(self) -> None: + """Even if one fails fast, all others should complete.""" + completed = [] + + def work(x: int) -> int: + if x == 0: + raise ValueError("fast fail") + time.sleep(0.2) + completed.append(x) + return x + + with pytest.raises(ExceptionGroup): + safe_thread_map([0, 1, 2, 3], work) + # All non-failing items should have completed + assert sorted(completed) == [1, 2, 3] + + +# =================================================================== +# Integration: ModuleProcess on_exit -> dispose chain (the CI bug scenario) +# =================================================================== + + +class TestModuleProcessDisposeChain: + """Tests the exact pattern that caused the CI bug: + process exits -> watchdog fires on_exit -> module.stop() -> dispose -> + ModuleProcess.stop() -> tries to stop watchdog from inside watchdog thread. 
+ """ + + def test_chain_no_deadlock_fast_exit(self) -> None: + """Process exits immediately.""" + for _ in range(20): + mod = FakeModule() + done = threading.Event() + + def fake_stop() -> None: + mod.dispose() + done.set() + + ModuleProcess( + module=mod, + args=[PYTHON, "-c", "pass"], + on_exit=fake_stop, + ) + assert done.wait(timeout=5), "Deadlock in dispose chain (fast exit)" + + def test_chain_no_deadlock_slow_exit(self) -> None: + """Process runs briefly then exits.""" + for _ in range(10): + mod = FakeModule() + done = threading.Event() + + def fake_stop() -> None: + mod.dispose() + done.set() + + ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(0.1)"], + on_exit=fake_stop, + ) + assert done.wait(timeout=5), "Deadlock in dispose chain (slow exit)" + + def test_chain_concurrent_with_external_stop(self) -> None: + """Process exits naturally while external code calls stop().""" + for _ in range(20): + mod = FakeModule() + done = threading.Event() + + def fake_stop() -> None: + mod.dispose() + done.set() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(0.05)"], + on_exit=fake_stop, + shutdown_timeout=1.0, + ) + # Race: the process might exit naturally or we might stop it + time.sleep(0.03) + mp.stop() + # Either way, should not deadlock + time.sleep(1.0) + + def test_dispose_with_artificial_delay(self) -> None: + """Add artificial delay near cleanup to simulate heavy CPU load.""" + original_stop = ModuleThread.stop + + def slow_stop(self_mt: ModuleThread) -> None: + time.sleep(0.05) # simulate load + original_stop(self_mt) + + for _ in range(10): + mod = FakeModule() + done = threading.Event() + + def fake_stop() -> None: + mod.dispose() + done.set() + + with mock.patch.object(ModuleThread, "stop", slow_stop): + ModuleProcess( + module=mod, + args=[PYTHON, "-c", "pass"], + on_exit=fake_stop, + ) + assert done.wait(timeout=10), "Deadlock with slow ModuleThread.stop()" + + +# We need ExceptionGroup 
for safe_thread_map tests +try: + ExceptionGroup +except NameError: + from dimos.utils.typing_utils import ExceptionGroup diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py new file mode 100644 index 0000000000..e83047ca5b --- /dev/null +++ b/dimos/utils/thread_utils.py @@ -0,0 +1,542 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Thread utilities: safe values, managed threads, safe parallel map.""" + +from __future__ import annotations + +import asyncio +import json +import signal +import subprocess +import threading +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from typing import IO, TYPE_CHECKING, Any, Generic + +from reactivex.disposable import Disposable + +from dimos.utils.logging_config import setup_logger +from dimos.utils.typing_utils import ExceptionGroup, TypeVar + +logger = setup_logger() + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from dimos.core.module import ModuleBase + +T = TypeVar("T") +R = TypeVar("R") + + +# --------------------------------------------------------------------------- +# ThreadSafeVal: a lock-protected value with context-manager support +# --------------------------------------------------------------------------- + + +class ThreadSafeVal(Generic[T]): + """A thread-safe value wrapper. 
+ + Wraps any value with a lock and provides atomic read-modify-write + via a context manager:: + + counter = ThreadSafeVal(0) + + # Simple get/set (each acquires the lock briefly): + counter.set(10) + print(counter.get()) # 10 + + # Atomic read-modify-write: + with counter as value: + # Lock is held for the entire block. + # Other threads block on get/set/with until this exits. + if value < 100: + counter.set(value + 1) + + # Works with any type: + status = ThreadSafeVal({"running": False, "count": 0}) + with status as s: + status.set({**s, "running": True}) + + # Bool check (for flag-like usage): + stopping = ThreadSafeVal(False) + stopping.set(True) + if stopping: + print("stopping!") + """ + + def __init__(self, initial: T) -> None: + self._lock = threading.RLock() + self._value = initial + + def get(self) -> T: + """Return the current value (acquires the lock briefly).""" + with self._lock: + return self._value + + def set(self, value: T) -> None: + """Replace the value (acquires the lock briefly).""" + with self._lock: + self._value = value + + def __bool__(self) -> bool: + with self._lock: + return bool(self._value) + + def __enter__(self) -> T: + self._lock.acquire() + return self._value + + def __exit__(self, *exc: object) -> None: + self._lock.release() + + def __getstate__(self) -> dict[str, Any]: + return {"_value": self._value} + + def __setstate__(self, state: dict[str, Any]) -> None: + self._lock = threading.RLock() + self._value = state["_value"] + + def __repr__(self) -> str: + return f"ThreadSafeVal({self._value!r})" + + +# --------------------------------------------------------------------------- +# ModuleThread: a thread that auto-registers with a module's disposables +# --------------------------------------------------------------------------- + + +class ModuleThread: + """A thread that registers cleanup with a module's disposables. + + Passes most kwargs through to ``threading.Thread``. 
On construction, + registers a disposable with the module so that when the module stops, + the thread is automatically joined. Cleanup is idempotent — safe to + call ``stop()`` manually even if the module also disposes it. + + Example:: + + class MyModule(Module): + @rpc + def start(self) -> None: + self._worker = ModuleThread( + module=self, + target=self._run_loop, + name="my-worker", + ) + + def _run_loop(self) -> None: + while not self._worker.stopping: + do_work() + """ + + def __init__( + self, + module: ModuleBase, + *, + start: bool = True, + close_timeout: float = 2.0, + **thread_kwargs: Any, + ) -> None: + thread_kwargs.setdefault("daemon", True) + self._thread = threading.Thread(**thread_kwargs) + self._stop_event = threading.Event() + self._close_timeout = close_timeout + self._stopped = False + self._stop_lock = threading.Lock() + module._disposables.add(Disposable(self.stop)) + if start: + self.start() + + @property + def stopping(self) -> bool: + """True after ``stop()`` has been called.""" + return self._stop_event.is_set() + + def start(self) -> None: + """Start the underlying thread.""" + self._stop_event.clear() + self._thread.start() + + def stop(self) -> None: + """Signal the thread to stop and join it. + + Safe to call multiple times, from any thread (including the + managed thread itself — it will skip the join in that case). 
+ """ + with self._stop_lock: + if self._stopped: + return + self._stopped = True + + self._stop_event.set() + if self._thread.is_alive() and self._thread is not threading.current_thread(): + self._thread.join(timeout=self._close_timeout) + + def join(self, timeout: float | None = None) -> None: + """Join the underlying thread.""" + self._thread.join(timeout=timeout) + + @property + def is_alive(self) -> bool: + return self._thread.is_alive() + + +# --------------------------------------------------------------------------- +# AsyncModuleThread: a thread running an asyncio event loop, auto-registered +# --------------------------------------------------------------------------- + + +class AsyncModuleThread: + """A thread running an asyncio event loop, registered with a module's disposables. + + If a loop is already running in the current context, reuses it (no thread + created). Otherwise creates a new loop and drives it in a daemon thread. + + On stop (or module dispose), the loop is shut down gracefully and the + thread is joined. Idempotent — safe to call ``stop()`` multiple times. + + Example:: + + class MyModule(Module): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._async = AsyncModuleThread(module=self) + + @rpc + def start(self) -> None: + future = asyncio.run_coroutine_threadsafe( + self._do_work(), self._async.loop + ) + + async def _do_work(self) -> None: + ... 
+ """ + + def __init__( + self, + module: ModuleBase, + *, + close_timeout: float = 2.0, + ) -> None: + self._close_timeout = close_timeout + self._stopped = False + self._stop_lock = threading.Lock() + self._owns_loop = False + self._thread: threading.Thread | None = None + + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._owns_loop = True + self._thread = threading.Thread( + target=self._loop.run_forever, + daemon=True, + name=f"{type(module).__name__}-event-loop", + ) + self._thread.start() + + module._disposables.add(Disposable(self.stop)) + + @property + def loop(self) -> asyncio.AbstractEventLoop: + """The managed event loop.""" + return self._loop + + @property + def is_alive(self) -> bool: + return self._thread is not None and self._thread.is_alive() + + def stop(self) -> None: + """Stop the event loop and join the thread. + + No-op if the loop was not created by this instance (reused an + existing running loop). Safe to call multiple times. + """ + with self._stop_lock: + if self._stopped: + return + self._stopped = True + + if self._owns_loop and self._loop.is_running(): + self._loop.call_soon_threadsafe(self._loop.stop) + + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=self._close_timeout) + + +# --------------------------------------------------------------------------- +# ModuleProcess: managed subprocess with log piping, auto-registered cleanup +# --------------------------------------------------------------------------- + + +class ModuleProcess: + """A managed subprocess that pipes stdout/stderr through the logger. + + Registers with a module's disposables so the process is automatically + stopped on module teardown. A watchdog thread monitors the process and + calls ``on_exit`` if the process exits on its own (i.e. not via + ``ModuleProcess.stop()``). 
+ + Most constructor kwargs mirror ``subprocess.Popen``. ``stdout`` and + ``stderr`` are always captured (set to ``PIPE`` internally). + + Example:: + + class MyModule(Module): + @rpc + def start(self) -> None: + self._proc = ModuleProcess( + module=self, + args=["./my_binary", "--flag"], + cwd="/opt/bin", + on_exit=self.stop, # stops the whole module if process exits on its own + ) + + @rpc + def stop(self) -> None: + # ModuleProcess.stop() is also called automatically via disposables + super().stop() + """ + + def __init__( + self, + module: ModuleBase, + args: list[str] | str, + *, + env: dict[str, str] | None = None, + cwd: str | None = None, + shell: bool = False, + on_exit: Callable[[], Any] | None = None, + shutdown_timeout: float = 10.0, + kill_timeout: float = 5.0, + log_json: bool = False, + start: bool = True, + **popen_kwargs: Any, + ) -> None: + self._args = args + self._env = env + self._cwd = cwd + self._shell = shell + self._on_exit = on_exit + self._shutdown_timeout = shutdown_timeout + self._kill_timeout = kill_timeout + self._log_json = log_json + self._popen_kwargs = popen_kwargs + self._process: subprocess.Popen[bytes] | None = None + self._watchdog: ModuleThread | None = None + self._module = module + self._stopped = False + self._stop_lock = threading.Lock() + + module._disposables.add(Disposable(self.stop)) + if start: + self.start() + + @property + def pid(self) -> int | None: + return self._process.pid if self._process is not None else None + + @property + def returncode(self) -> int | None: + if self._process is None: + return None + return self._process.poll() + + @property + def is_alive(self) -> bool: + return self._process is not None and self._process.poll() is None + + def start(self) -> None: + """Launch the subprocess and start the watchdog.""" + if self._process is not None and self._process.poll() is None: + logger.warning("Process already running", pid=self._process.pid) + return + + with self._stop_lock: + self._stopped = 
False + + logger.info( + "Starting process", + cmd=self._args if isinstance(self._args, str) else " ".join(self._args), + cwd=self._cwd, + ) + self._process = subprocess.Popen( + self._args, + env=self._env, + cwd=self._cwd, + shell=self._shell, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + **self._popen_kwargs, + ) + logger.info("Process started", pid=self._process.pid) + + self._watchdog = ModuleThread( + module=self._module, + target=self._watch, + name=f"proc-{self._process.pid}-watchdog", + ) + + def stop(self) -> None: + """Send SIGTERM, wait, escalate to SIGKILL if needed. Idempotent.""" + with self._stop_lock: + if self._stopped: + return + self._stopped = True + + if self._process is not None and self._process.poll() is None: + logger.info("Stopping process", pid=self._process.pid) + try: + self._process.send_signal(signal.SIGTERM) + except OSError: + pass # process already dead (PID recycled or exited between poll and signal) + else: + try: + self._process.wait(timeout=self._shutdown_timeout) + except subprocess.TimeoutExpired: + logger.warning( + "Process did not exit, sending SIGKILL", + pid=self._process.pid, + ) + self._process.kill() + try: + self._process.wait(timeout=self._kill_timeout) + except subprocess.TimeoutExpired: + logger.error( + "Process did not exit after SIGKILL", + pid=self._process.pid, + ) + self._process = None + + def _watch(self) -> None: + """Watchdog: pipe logs, detect crashes.""" + proc = self._process + if proc is None: + return + + stdout_t = self._start_reader(proc.stdout, "info") + stderr_t = self._start_reader(proc.stderr, "warning") + rc = proc.wait() + stdout_t.join(timeout=2) + stderr_t.join(timeout=2) + + with self._stop_lock: + if self._stopped: + return + + logger.error("Process died unexpectedly", pid=proc.pid, returncode=rc) + if self._on_exit is not None: + self._on_exit() + + def _start_reader(self, stream: IO[bytes] | None, level: str) -> threading.Thread: + t = 
threading.Thread(target=self._read_stream, args=(stream, level), daemon=True)
+        t.start()
+        return t
+
+    def _read_stream(self, stream: IO[bytes] | None, level: str) -> None:
+        if stream is None:
+            return
+        log_fn = getattr(logger, level)
+        for raw in stream:
+            line = raw.decode("utf-8", errors="replace").rstrip()
+            if not line:
+                continue
+            if self._log_json:
+                try:
+                    data = json.loads(line)
+                    event = data.pop("event", line)
+                    log_fn(event, **data)
+                    continue
+                except (json.JSONDecodeError, TypeError):
+                    logger.warning("malformed JSON from process", raw=line)
+            proc = self._process
+            log_fn(line, pid=proc.pid if proc else None)
+        stream.close()
+
+
+# ---------------------------------------------------------------------------
+# safe_thread_map: parallel map that collects all results before raising
+# ---------------------------------------------------------------------------
+
+
+def safe_thread_map(
+    items: Sequence[T],
+    fn: Callable[[T], R],
+    on_errors: Callable[[list[tuple[T, R | Exception]], list[R], list[Exception]], Any]
+    | None = None,
+) -> list[R]:
+    """Thread-pool map that waits for all items to finish before raising, with an optional cleanup handler.
+
+    - Empty *items* → returns ``[]`` immediately.
+    - All succeed → returns results in input order.
+    - Any fail → calls ``on_errors(outcomes, successes, errors)`` where
+      *outcomes* is a list of ``(input, result_or_exception)`` pairs in input
+      order, *successes* is the list of successful results, and *errors* is
+      the list of exceptions. If *on_errors* raises, that exception propagates.
+      If *on_errors* returns normally, its return value is returned from
+      ``safe_thread_map``. If *on_errors* is ``None``, raises an
+      ``ExceptionGroup``.
+ + Example:: + + def start_service(name: str) -> Connection: + return connect(name) + + def cleanup( + outcomes: list[tuple[str, Connection | Exception]], + successes: list[Connection], + errors: list[Exception], + ) -> None: + for conn in successes: + conn.close() + raise ExceptionGroup("failed to start services", errors) + + connections = safe_thread_map( + ["db", "cache", "queue"], + start_service, + cleanup, # called only if any start_service() raises + ) + """ + if not items: + return [] + + outcomes: dict[int, R | Exception] = {} + + with ThreadPoolExecutor(max_workers=len(items)) as pool: + futures: dict[Future[R], int] = {pool.submit(fn, item): i for i, item in enumerate(items)} + for fut in as_completed(futures): + idx = futures[fut] + try: + outcomes[idx] = fut.result() + except Exception as e: + outcomes[idx] = e + + successes: list[R] = [] + errors: list[Exception] = [] + for v in outcomes.values(): + if isinstance(v, Exception): + errors.append(v) + else: + successes.append(v) + + if errors: + if on_errors is not None: + zipped = [(items[i], outcomes[i]) for i in range(len(items))] + return on_errors(zipped, successes, errors) # type: ignore[return-value, no-any-return] + raise ExceptionGroup("safe_thread_map failed", errors) + + return [outcomes[i] for i in range(len(items))] # type: ignore[misc] diff --git a/dimos/utils/typing_utils.py b/dimos/utils/typing_utils.py new file mode 100644 index 0000000000..aa32fff47f --- /dev/null +++ b/dimos/utils/typing_utils.py @@ -0,0 +1,45 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unify typing compatibility across multiple Python versions.""" + +from __future__ import annotations + +import sys +from collections.abc import Sequence + +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + +if sys.version_info < (3, 11): + + class ExceptionGroup(Exception): # type: ignore[no-redef] # noqa: N818 + """Minimal ExceptionGroup polyfill for Python 3.10.""" + + exceptions: tuple[BaseException, ...] + + def __init__(self, message: str, exceptions: Sequence[BaseException]) -> None: + super().__init__(message) + self.exceptions = tuple(exceptions) +else: + import builtins + + ExceptionGroup = builtins.ExceptionGroup # type: ignore[misc] + +__all__ = [ + "ExceptionGroup", + "TypeVar", +] From 8ae6282d677c91dfb97ec1adf5ded4ec79641afb Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Tue, 24 Mar 2026 23:03:09 -0700 Subject: [PATCH 11/21] ideal approach, not tested --- dimos/agents/mcp/mcp_server.py | 5 +- dimos/core/module.py | 76 +- dimos/core/native_module.py | 125 +-- dimos/core/test_core.py | 2 +- dimos/perception/detection/conftest.py | 10 +- .../perception/detection/reid/test_module.py | 2 +- dimos/robot/unitree/b1/test_connection.py | 18 +- dimos/utils/test_thread_utils.py | 897 ++++++++++++++++++ dimos/utils/thread_utils.py | 542 +++++++++++ dimos/utils/typing_utils.py | 45 + 10 files changed, 1544 insertions(+), 178 deletions(-) create mode 100644 dimos/utils/test_thread_utils.py create mode 100644 dimos/utils/thread_utils.py create mode 100644 dimos/utils/typing_utils.py diff --git 
a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index e5697542fb..e1179abe55 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -195,7 +195,7 @@ def start(self) -> None: def stop(self) -> None: if self._uvicorn_server: self._uvicorn_server.should_exit = True - loop = self._loop + loop = self._async_thread.loop if loop is not None and self._serve_future is not None: self._serve_future.result(timeout=5.0) self._uvicorn_server = None @@ -269,6 +269,5 @@ def _start_server(self, port: int | None = None) -> None: config = uvicorn.Config(app, host=_host, port=_port, log_level="info") server = uvicorn.Server(config) self._uvicorn_server = server - loop = self._loop - assert loop is not None + loop = self._async_thread.loop self._serve_future = asyncio.run_coroutine_threadsafe(server.serve(), loop) diff --git a/dimos/core/module.py b/dimos/core/module.py index ab21ce17a9..40fd34f5be 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -22,6 +22,7 @@ from typing import ( TYPE_CHECKING, Any, + Literal, Protocol, get_args, get_origin, @@ -43,6 +44,9 @@ from dimos.protocol.tf import LCMTF, TFSpec from dimos.utils import colors from dimos.utils.generic import classproperty +from dimos.utils.thread_utils import AsyncModuleThread, ThreadSafeVal + +ModState = Literal["init", "started", "stopping", "stopped"] if TYPE_CHECKING: from dimos.core.blueprints import Blueprint @@ -62,19 +66,6 @@ class SkillInfo: args_schema: str -def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: - try: - running_loop = asyncio.get_running_loop() - return running_loop, None - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - thr = threading.Thread(target=loop.run_forever, daemon=True) - thr.start() - return loop, thr - - class ModuleConfig(BaseConfig): rpc_transport: type[RPCSpec] = LCMRPC tf_transport: type[TFSpec] = LCMTF # type: ignore[type-arg] @@ -96,20 +87,20 @@ class 
ModuleBase(Configurable[ModuleConfigT], Resource): _rpc: RPCSpec | None = None _tf: TFSpec[Any] | None = None - _loop: asyncio.AbstractEventLoop | None = None - _loop_thread: threading.Thread | None + _async_thread: AsyncModuleThread _disposables: CompositeDisposable _bound_rpc_calls: dict[str, RpcCall] = {} - _module_closed: bool = False - _module_closed_lock: threading.Lock + mod_state: ThreadSafeVal[ModState] rpc_calls: list[str] = [] def __init__(self, config_args: dict[str, Any]): super().__init__(**config_args) - self._module_closed_lock = threading.Lock() - self._loop, self._loop_thread = get_loop() self._disposables = CompositeDisposable() + self.mod_state = ThreadSafeVal[ModState]("init") + self._async_thread = AsyncModuleThread( # NEEDS to be created after self._disposables exists + module=self + ) try: self.rpc = self.config.rpc_transport() self.rpc.serve_module_rpc(self) @@ -126,38 +117,30 @@ def frame_id(self) -> str: @rpc def start(self) -> None: - pass + with self.mod_state as state: + if state == "stopped": + raise RuntimeError(f"{type(self).__name__} cannot be restarted after stop") + self.mod_state.set("started") @rpc def stop(self) -> None: - self._close_module() + self._stop() - def _close_module(self) -> None: - with self._module_closed_lock: - if self._module_closed: + def _stop(self) -> None: + with self.mod_state as state: + if state in ("stopping", "stopped"): return - self._module_closed = True - - self._close_rpc() - - # Save into local variables to avoid race when stopping concurrently - # (from RPC and worker shutdown) - loop_thread = getattr(self, "_loop_thread", None) - loop = getattr(self, "_loop", None) + self.mod_state.set("stopping") - if loop_thread: - if loop_thread.is_alive(): - if loop: - loop.call_soon_threadsafe(loop.stop) - loop_thread.join(timeout=2) - self._loop = None - self._loop_thread = None + if self.rpc: + self.rpc.stop() # type: ignore[attr-defined] + self.rpc = None # type: ignore[assignment] if hasattr(self, 
"_tf") and self._tf is not None: self._tf.stop() self._tf = None if hasattr(self, "_disposables"): - self._disposables.dispose() + self._disposables.dispose() # stops _async_thread via disposable # Break the In/Out -> owner -> self reference cycle so the instance # can be freed by refcount instead of waiting for GC. @@ -165,19 +148,12 @@ def _close_module(self) -> None: if isinstance(attr, (In, Out)): attr.owner = None - def _close_rpc(self) -> None: - if self.rpc: - self.rpc.stop() # type: ignore[attr-defined] - self.rpc = None # type: ignore[assignment] - def __getstate__(self): # type: ignore[no-untyped-def] """Exclude unpicklable runtime attributes when serializing.""" state = self.__dict__.copy() # Remove unpicklable attributes state.pop("_disposables", None) - state.pop("_module_closed_lock", None) - state.pop("_loop", None) - state.pop("_loop_thread", None) + state.pop("_async_thread", None) state.pop("_rpc", None) state.pop("_tf", None) return state @@ -187,9 +163,7 @@ def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] self.__dict__.update(state) # Reinitialize runtime attributes self._disposables = CompositeDisposable() - self._module_closed_lock = threading.Lock() - self._loop = None - self._loop_thread = None + self._async_thread = None # type: ignore[assignment] self._rpc = None self._tf = None diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 1a591772f4..bdc46e2cab 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -42,20 +42,18 @@ class MyCppModule(NativeModule): import enum import inspect -import json import os from pathlib import Path -import signal import subprocess import sys -import threading -from typing import IO, Any +from typing import Any from pydantic import Field from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.utils.logging_config import setup_logger +from dimos.utils.thread_utils import ModuleProcess if sys.version_info < 
(3, 13): from typing_extensions import TypeVar @@ -128,9 +126,7 @@ class NativeModule(Module[_NativeConfig]): """ default_config: type[_NativeConfig] = NativeModuleConfig # type: ignore[assignment] - _process: subprocess.Popen[bytes] | None = None - _watchdog: threading.Thread | None = None - _stopping: bool = False + _proc: ModuleProcess | None = None def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -138,109 +134,22 @@ def __init__(self, **kwargs: Any) -> None: @rpc def start(self) -> None: - if self._process is not None and self._process.poll() is None: - logger.warning("Native process already running", pid=self._process.pid) - return - + super().start() self._maybe_build() - - topics = self._collect_topics() - - cmd = [self.config.executable] - for name, topic_str in topics.items(): - cmd.extend([f"--{name}", topic_str]) - cmd.extend(self.config.to_cli_args()) - cmd.extend(self.config.extra_args) - - env = {**os.environ, **self.config.extra_env} - cwd = self.config.cwd or str(Path(self.config.executable).resolve().parent) - - logger.info("Starting native process", cmd=" ".join(cmd), cwd=cwd) - self._process = subprocess.Popen( - cmd, - env=env, - cwd=cwd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - logger.info("Native process started", pid=self._process.pid) - - self._stopping = False - self._watchdog = threading.Thread(target=self._watch_process, daemon=True) - self._watchdog.start() - - def _clean_all_but_watchdog(self) -> None: - """A cleanup helper designed to be called inside of the watchdog and outside of the watchdong""" - self._stopping = True - if self._process is not None and self._process.poll() is None: - logger.info("Stopping native process", pid=self._process.pid) - self._process.send_signal(signal.SIGTERM) - try: - self._process.wait(timeout=self.config.shutdown_timeout) - except subprocess.TimeoutExpired: - logger.warning( - "Native process did not exit, sending SIGKILL", pid=self._process.pid - ) - 
self._process.kill() - try: - self._process.wait(timeout=5) - except Exception as error: - print(f'''error = {error}''') - self._process = None - - - @rpc - def stop(self) -> None: - self._clean_all_but_watchdog() - super().stop() - if self._watchdog is not None: - self._watchdog.join(timeout=2) - self._watchdog = None - - def _watch_process(self) -> None: - """Block until the native process exits; trigger cleanup if it crashed.""" - if self._process is None: - return - - stdout_t = self._start_reader(self._process.stdout, "info") - stderr_t = self._start_reader(self._process.stderr, "warning") - rc = self._process.wait() - stdout_t.join(timeout=2) - stderr_t.join(timeout=2) - - if self._stopping: - return - logger.error( - "Native process died unexpectedly", - pid=self._process.pid, - returncode=rc, + self._proc = ModuleProcess( + module=self, + args=[ + self.config.executable, + *[arg for name, topic in self._collect_topics().items() for arg in (f"--{name}", topic)], + *self.config.to_cli_args(), + *self.config.extra_args, + ], + env={**os.environ, **self.config.extra_env}, + cwd=self.config.cwd or str(Path(self.config.executable).resolve().parent), + on_exit=self.stop, + shutdown_timeout=self.config.shutdown_timeout, + log_json=self.config.log_format == LogFormat.JSON, ) - self._clean_all_but_watchdog() - - def _start_reader(self, stream: IO[bytes] | None, level: str) -> threading.Thread: - """Spawn a daemon thread that pipes a subprocess stream through the logger.""" - t = threading.Thread(target=self._read_log_stream, args=(stream, level), daemon=True) - t.start() - return t - - def _read_log_stream(self, stream: IO[bytes] | None, level: str) -> None: - if stream is None: - return - log_fn = getattr(logger, level) - for raw in stream: - line = raw.decode("utf-8", errors="replace").rstrip() - if not line: - continue - if self.config.log_format == LogFormat.JSON: - try: - data = json.loads(line) - event = data.pop("event", line) - log_fn(event, **data) - 
continue - except (json.JSONDecodeError, TypeError): - logger.warning("malformed JSON from native module", raw=line) - log_fn(line, pid=self._process.pid if self._process else None) - stream.close() def _resolve_paths(self) -> None: """Resolve relative ``cwd`` and ``executable`` against the subclass's source file.""" diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 3bd1383761..1400ae1ebb 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -89,7 +89,7 @@ def test_classmethods() -> None: ) assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" - nav._close_module() + nav._stop() @pytest.mark.slow diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index 8c1a65eb8b..81f6d1805b 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -221,7 +221,7 @@ def moment_provider(**kwargs) -> Moment2D: yield moment_provider moment_provider.cache_clear() - module._close_module() + module._stop() @pytest.fixture(scope="session") @@ -256,7 +256,7 @@ def moment_provider(**kwargs) -> Moment3D: yield moment_provider moment_provider.cache_clear() if module is not None: - module._close_module() + module._stop() @pytest.fixture(scope="session") @@ -290,9 +290,9 @@ def object_db_module(get_moment): yield moduleDB - module2d._close_module() - module3d._close_module() - moduleDB._close_module() + module2d._stop() + module3d._stop() + moduleDB._stop() @pytest.fixture(scope="session") diff --git a/dimos/perception/detection/reid/test_module.py b/dimos/perception/detection/reid/test_module.py index f5672c1f67..e2c603c307 100644 --- a/dimos/perception/detection/reid/test_module.py +++ b/dimos/perception/detection/reid/test_module.py @@ -40,5 +40,5 @@ def test_reid_ingress(imageDetections2d) -> None: print("Processing detections through ReidModule...") reid_module.annotations._transport = LCMTransport("/annotations", ImageAnnotations) 
reid_module.ingress(imageDetections2d) - reid_module._close_module() + reid_module._stop() print("✓ ReidModule ingress test completed successfully") diff --git a/dimos/robot/unitree/b1/test_connection.py b/dimos/robot/unitree/b1/test_connection.py index e43a3124dc..011853d172 100644 --- a/dimos/robot/unitree/b1/test_connection.py +++ b/dimos/robot/unitree/b1/test_connection.py @@ -73,7 +73,7 @@ def test_watchdog_actually_zeros_commands(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_watchdog_resets_on_new_command(self) -> None: """Test that watchdog timeout resets when new command arrives.""" @@ -121,7 +121,7 @@ def test_watchdog_resets_on_new_command(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_watchdog_thread_efficiency(self) -> None: """Test that watchdog uses only one thread regardless of command rate.""" @@ -155,7 +155,7 @@ def test_watchdog_thread_efficiency(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_watchdog_with_send_loop_blocking(self) -> None: """Test that watchdog still works if send loop blocks.""" @@ -202,7 +202,7 @@ def blocking_send_loop() -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_continuous_commands_prevent_timeout(self) -> None: """Test that continuous commands prevent watchdog timeout.""" @@ -237,7 +237,7 @@ def test_continuous_commands_prevent_timeout(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_watchdog_timing_accuracy(self) -> None: """Test that watchdog 
zeros commands at approximately 200ms.""" @@ -278,7 +278,7 @@ def test_watchdog_timing_accuracy(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_mode_changes_with_watchdog(self) -> None: """Test that mode changes work correctly with watchdog.""" @@ -321,7 +321,7 @@ def test_mode_changes_with_watchdog(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_watchdog_stops_movement_when_commands_stop(self) -> None: """Verify watchdog zeros commands when packets stop being sent.""" @@ -379,7 +379,7 @@ def test_watchdog_stops_movement_when_commands_stop(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_rapid_command_thread_safety(self) -> None: """Test thread safety with rapid commands from multiple threads.""" @@ -428,4 +428,4 @@ def send_commands(thread_id) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py new file mode 100644 index 0000000000..87c4a883e2 --- /dev/null +++ b/dimos/utils/test_thread_utils.py @@ -0,0 +1,897 @@ +"""Exhaustive tests for dimos/utils/thread_utils.py + +Covers: ThreadSafeVal, ModuleThread, AsyncModuleThread, ModuleProcess, safe_thread_map. +Focuses on deadlocks, race conditions, idempotency, and edge cases under load. 
+""" + +from __future__ import annotations + +import asyncio +import os +import pickle +import signal +import subprocess +import sys +import threading +import time +from unittest import mock + +import pytest +from reactivex.disposable import CompositeDisposable + +from dimos.utils.thread_utils import ( + AsyncModuleThread, + ModuleProcess, + ModuleThread, + ThreadSafeVal, + safe_thread_map, +) + + +# --------------------------------------------------------------------------- +# Helpers: fake ModuleBase for testing ModuleThread / AsyncModuleThread / ModuleProcess +# --------------------------------------------------------------------------- + + +class FakeModule: + """Minimal stand-in for ModuleBase — just needs _disposables.""" + + def __init__(self) -> None: + self._disposables = CompositeDisposable() + + def dispose(self) -> None: + self._disposables.dispose() + + +# =================================================================== +# ThreadSafeVal Tests +# =================================================================== + + +class TestThreadSafeVal: + def test_basic_get_set(self) -> None: + v = ThreadSafeVal(42) + assert v.get() == 42 + v.set(99) + assert v.get() == 99 + + def test_bool_truthy(self) -> None: + v = ThreadSafeVal(True) + assert bool(v) is True + v.set(False) + assert bool(v) is False + + def test_bool_zero(self) -> None: + v = ThreadSafeVal(0) + assert bool(v) is False + v.set(1) + assert bool(v) is True + + def test_context_manager_returns_value(self) -> None: + v = ThreadSafeVal("hello") + with v as val: + assert val == "hello" + + def test_set_inside_context_manager_no_deadlock(self) -> None: + """The critical test: set() inside a with block must NOT deadlock. + + This was a confirmed bug when using threading.Lock (non-reentrant). + Fixed by using threading.RLock. 
+ """ + v = ThreadSafeVal(0) + result = threading.Event() + + def do_it() -> None: + with v as val: + v.set(val + 1) + result.set() + + t = threading.Thread(target=do_it) + t.start() + t.join(timeout=2) + assert result.is_set(), "Deadlocked! set() inside with block hung" + assert v.get() == 1 + + def test_get_inside_context_manager_no_deadlock(self) -> None: + v = ThreadSafeVal(10) + result = threading.Event() + + def do_it() -> None: + with v as val: + _ = v.get() + result.set() + + t = threading.Thread(target=do_it) + t.start() + t.join(timeout=2) + assert result.is_set(), "Deadlocked! get() inside with block hung" + + def test_bool_inside_context_manager_no_deadlock(self) -> None: + v = ThreadSafeVal(True) + result = threading.Event() + + def do_it() -> None: + with v as val: + _ = bool(v) + result.set() + + t = threading.Thread(target=do_it) + t.start() + t.join(timeout=2) + assert result.is_set(), "Deadlocked! bool() inside with block hung" + + def test_context_manager_blocks_other_threads(self) -> None: + """While one thread holds the lock via `with`, others should block on set().""" + v = ThreadSafeVal(0) + gate = threading.Event() + other_started = threading.Event() + other_finished = threading.Event() + + def holder() -> None: + with v as val: + gate.wait(timeout=5) # hold the lock until signaled + + def setter() -> None: + other_started.set() + v.set(42) # should block until holder releases + other_finished.set() + + t1 = threading.Thread(target=holder) + t2 = threading.Thread(target=setter) + t1.start() + time.sleep(0.05) # let holder acquire lock + t2.start() + other_started.wait(timeout=2) + time.sleep(0.1) + # setter should be blocked + assert not other_finished.is_set(), "set() did not block while lock was held" + gate.set() # release holder + t1.join(timeout=2) + t2.join(timeout=2) + assert other_finished.is_set() + assert v.get() == 42 + + def test_concurrent_increments(self) -> None: + """Many threads doing atomic read-modify-write should not lose 
updates.""" + v = ThreadSafeVal(0) + n_threads = 50 + n_increments = 100 + + def incrementer() -> None: + for _ in range(n_increments): + with v as val: + v.set(val + 1) + + threads = [threading.Thread(target=incrementer) for _ in range(n_threads)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + assert v.get() == n_threads * n_increments + + def test_concurrent_increments_stress(self) -> None: + """Run the concurrent increment test multiple times to catch races.""" + for _ in range(10): + self.test_concurrent_increments() + + def test_pickle_roundtrip(self) -> None: + v = ThreadSafeVal({"key": [1, 2, 3]}) + data = pickle.dumps(v) + v2 = pickle.loads(data) + assert v2.get() == {"key": [1, 2, 3]} + # Verify the new instance has a working lock + with v2 as val: + v2.set({**val, "new": True}) + assert v2.get()["new"] is True + + def test_repr(self) -> None: + v = ThreadSafeVal("test") + assert repr(v) == "ThreadSafeVal('test')" + + def test_dict_type(self) -> None: + v = ThreadSafeVal({"running": False, "count": 0}) + with v as s: + v.set({**s, "running": True}) + assert v.get() == {"running": True, "count": 0} + + def test_string_literal_type(self) -> None: + """Simulates the ModState pattern from module.py.""" + v = ThreadSafeVal("init") + with v as state: + if state == "init": + v.set("started") + assert v.get() == "started" + + with v as state: + if state in ("stopping", "stopped"): + pass # no-op + else: + v.set("stopping") + assert v.get() == "stopping" + + def test_nested_with_no_deadlock(self) -> None: + """RLock should allow the same thread to nest with blocks.""" + v = ThreadSafeVal(0) + result = threading.Event() + + def do_it() -> None: + with v as val1: + with v as val2: + v.set(val2 + 1) + result.set() + + t = threading.Thread(target=do_it) + t.start() + t.join(timeout=2) + assert result.is_set(), "Nested with blocks deadlocked!" 
+ + +# =================================================================== +# ModuleThread Tests +# =================================================================== + + +class TestModuleThread: + def test_basic_lifecycle(self) -> None: + mod = FakeModule() + ran = threading.Event() + + def target() -> None: + ran.set() + + mt = ModuleThread(module=mod, target=target, name="test-basic") + ran.wait(timeout=2) + assert ran.is_set() + mt.stop() + assert not mt.is_alive + + def test_auto_start(self) -> None: + mod = FakeModule() + started = threading.Event() + mt = ModuleThread(module=mod, target=started.set, name="test-autostart") + started.wait(timeout=2) + assert started.is_set() + mt.stop() + + def test_deferred_start(self) -> None: + mod = FakeModule() + started = threading.Event() + mt = ModuleThread(module=mod, target=started.set, name="test-deferred", start=False) + time.sleep(0.1) + assert not started.is_set() + mt.start() + started.wait(timeout=2) + assert started.is_set() + mt.stop() + + def test_stopping_property(self) -> None: + mod = FakeModule() + saw_stopping = threading.Event() + holder: list[ModuleThread] = [] + + def target() -> None: + while not holder[0].stopping: + time.sleep(0.01) + saw_stopping.set() + + mt = ModuleThread(module=mod, target=target, name="test-stopping", start=False) + holder.append(mt) + mt.start() + time.sleep(0.05) + mt.stop() + saw_stopping.wait(timeout=2) + assert saw_stopping.is_set() + + def test_stop_idempotent(self) -> None: + mod = FakeModule() + mt = ModuleThread(module=mod, target=lambda: time.sleep(0.01), name="test-idem") + time.sleep(0.05) + mt.stop() + mt.stop() # second call should not raise + mt.stop() # third call should not raise + + def test_stop_from_managed_thread_no_deadlock(self) -> None: + """The thread calling stop() on itself should not deadlock.""" + mod = FakeModule() + result = threading.Event() + holder: list[ModuleThread] = [] + + def target() -> None: + holder[0].stop() # stop ourselves — 
should not deadlock + result.set() + + mt = ModuleThread(module=mod, target=target, name="test-self-stop", start=False) + holder.append(mt) + mt.start() + result.wait(timeout=3) + assert result.is_set(), "Deadlocked when thread called stop() on itself" + + def test_dispose_stops_thread(self) -> None: + """Module dispose should stop the thread via the registered Disposable.""" + mod = FakeModule() + running = threading.Event() + holder: list[ModuleThread] = [] + + def target() -> None: + running.set() + while not holder[0].stopping: + time.sleep(0.01) + + mt = ModuleThread(module=mod, target=target, name="test-dispose", start=False) + holder.append(mt) + mt.start() + running.wait(timeout=2) + mod.dispose() + time.sleep(0.1) + assert not mt.is_alive + + def test_concurrent_stop_calls(self) -> None: + """Multiple threads calling stop() concurrently should not crash.""" + mod = FakeModule() + holder: list[ModuleThread] = [] + + def target() -> None: + while not holder[0].stopping: + time.sleep(0.01) + + mt = ModuleThread(module=mod, target=target, name="test-concurrent-stop", start=False) + holder.append(mt) + mt.start() + time.sleep(0.05) + + errors = [] + + def stop_it() -> None: + try: + mt.stop() + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=stop_it) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=5) + assert not errors, f"Concurrent stop() raised: {errors}" + + def test_close_timeout_respected(self) -> None: + """If the thread ignores the stop signal, stop() should return after close_timeout.""" + mod = FakeModule() + + def stubborn_target() -> None: + time.sleep(10) # ignores stopping signal + + mt = ModuleThread( + module=mod, target=stubborn_target, name="test-timeout", close_timeout=0.2 + ) + start = time.monotonic() + mt.stop() + elapsed = time.monotonic() - start + assert elapsed < 1.0, f"stop() took {elapsed}s, expected ~0.2s" + + def test_stop_concurrent_with_dispose(self) -> 
None: + """Calling stop() and dispose() concurrently should not crash.""" + for _ in range(20): + mod = FakeModule() + holder: list[ModuleThread] = [] + + def target() -> None: + while not holder[0].stopping: + time.sleep(0.001) + + mt = ModuleThread(module=mod, target=target, name="test-stop-dispose", start=False) + holder.append(mt) + mt.start() + time.sleep(0.02) + # Race: stop and dispose from different threads + t1 = threading.Thread(target=mt.stop) + t2 = threading.Thread(target=mod.dispose) + t1.start() + t2.start() + t1.join(timeout=3) + t2.join(timeout=3) + + +# =================================================================== +# AsyncModuleThread Tests +# =================================================================== + + +class TestAsyncModuleThread: + def test_creates_loop_and_thread(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + assert amt.loop is not None + assert amt.loop.is_running() + assert amt.is_alive + amt.stop() + assert not amt.is_alive + + def test_stop_idempotent(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + amt.stop() + amt.stop() # should not raise + amt.stop() + + def test_dispose_stops_loop(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + assert amt.is_alive + mod.dispose() + time.sleep(0.1) + assert not amt.is_alive + + def test_can_schedule_coroutine(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + result = [] + + async def coro() -> None: + result.append(42) + + future = asyncio.run_coroutine_threadsafe(coro(), amt.loop) + future.result(timeout=2) + assert result == [42] + amt.stop() + + def test_stop_with_pending_work(self) -> None: + """Stop should succeed even with long-running tasks on the loop.""" + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + started = threading.Event() + + async def slow_coro() -> None: + started.set() + await asyncio.sleep(10) + + asyncio.run_coroutine_threadsafe(slow_coro(), 
amt.loop) + started.wait(timeout=2) + # stop() should not hang waiting for the coroutine + start = time.monotonic() + amt.stop() + elapsed = time.monotonic() - start + assert elapsed < 5.0, f"stop() hung for {elapsed}s with pending coroutine" + + def test_concurrent_stop(self) -> None: + mod = FakeModule() + amt = AsyncModuleThread(module=mod) + errors = [] + + def stop_it() -> None: + try: + amt.stop() + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=stop_it) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=5) + assert not errors + + +# =================================================================== +# ModuleProcess Tests +# =================================================================== + + +# Helper: path to a python that sleeps or echoes +PYTHON = sys.executable + + +class TestModuleProcess: + def test_basic_lifecycle(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + shutdown_timeout=2.0, + ) + assert mp.is_alive + assert mp.pid is not None + mp.stop() + assert not mp.is_alive + assert mp.pid is None + + def test_stop_idempotent(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + shutdown_timeout=1.0, + ) + mp.stop() + mp.stop() # should not raise + mp.stop() + + def test_dispose_stops_process(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + shutdown_timeout=2.0, + ) + pid = mp.pid + mod.dispose() + time.sleep(0.5) + assert not mp.is_alive + + def test_on_exit_fires_on_natural_exit(self) -> None: + """on_exit should fire when the process exits on its own.""" + mod = FakeModule() + exit_called = threading.Event() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "print('done')"], + on_exit=exit_called.set, + ) + exit_called.wait(timeout=5) + 
assert exit_called.is_set(), "on_exit was not called after natural process exit" + + def test_on_exit_fires_on_crash(self) -> None: + mod = FakeModule() + exit_called = threading.Event() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import sys; sys.exit(1)"], + on_exit=exit_called.set, + ) + exit_called.wait(timeout=5) + assert exit_called.is_set(), "on_exit was not called after process crash" + + def test_on_exit_not_fired_on_stop(self) -> None: + """on_exit should NOT fire when stop() kills the process.""" + mod = FakeModule() + exit_called = threading.Event() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + on_exit=exit_called.set, + shutdown_timeout=2.0, + ) + time.sleep(0.2) # let watchdog start + mp.stop() + time.sleep(1.0) # give watchdog time to potentially fire + assert not exit_called.is_set(), "on_exit fired after intentional stop()" + + def test_stdout_logged(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "print('hello from subprocess')"], + ) + time.sleep(1.0) # let output be read + mp.stop() + + def test_stderr_logged(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import sys; sys.stderr.write('error msg\\n')"], + ) + time.sleep(1.0) + mp.stop() + + def test_log_json_mode(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", """import json; print(json.dumps({"event": "test", "key": "val"}))"""], + log_json=True, + ) + time.sleep(1.0) + mp.stop() + + def test_log_json_malformed(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "print('not json')"], + log_json=True, + ) + time.sleep(1.0) + mp.stop() + + def test_stop_process_that_ignores_sigterm(self) -> None: + """Process that ignores SIGTERM should be killed with SIGKILL.""" + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[ + PYTHON, + "-c", + "import 
signal, time; signal.signal(signal.SIGTERM, signal.SIG_IGN); time.sleep(60)", + ], + shutdown_timeout=0.5, + kill_timeout=2.0, + ) + time.sleep(0.2) + start = time.monotonic() + mp.stop() + elapsed = time.monotonic() - start + assert not mp.is_alive + # Should take roughly shutdown_timeout (0.5) + a bit for SIGKILL + assert elapsed < 5.0 + + def test_stop_already_dead_process(self) -> None: + """stop() on a process that already exited should not raise.""" + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "pass"], # exits immediately + ) + time.sleep(1.0) # let it die + mp.stop() # should not raise + + def test_concurrent_stop(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + shutdown_timeout=2.0, + ) + errors = [] + + def stop_it() -> None: + try: + mp.stop() + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=stop_it) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + assert not errors, f"Concurrent stop() raised: {errors}" + + def test_on_exit_calls_module_stop_no_deadlock(self) -> None: + """Simulate the real pattern: on_exit=module.stop, which disposes the + ModuleProcess, which tries to stop its watchdog from inside the watchdog. + Must not deadlock. + """ + mod = FakeModule() + stop_called = threading.Event() + + def fake_module_stop() -> None: + """Simulates module.stop() -> _stop() -> dispose()""" + mod.dispose() + stop_called.set() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "pass"], # exits immediately + on_exit=fake_module_stop, + ) + stop_called.wait(timeout=5) + assert stop_called.is_set(), "Deadlocked! 
on_exit -> dispose -> stop chain hung" + + def test_on_exit_calls_module_stop_no_deadlock_stress(self) -> None: + """Run the deadlock test multiple times under load.""" + for i in range(10): + self.test_on_exit_calls_module_stop_no_deadlock() + + def test_deferred_start(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(30)"], + start=False, + ) + assert not mp.is_alive + mp.start() + assert mp.is_alive + mp.stop() + + def test_env_passed(self) -> None: + mod = FakeModule() + exit_called = threading.Event() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import os, sys; sys.exit(0 if os.environ.get('MY_VAR') == '42' else 1)"], + env={**os.environ, "MY_VAR": "42"}, + on_exit=exit_called.set, + ) + exit_called.wait(timeout=5) + # Process should have exited with 0 (our on_exit fires for all unmanaged exits) + assert exit_called.is_set() + + def test_cwd_passed(self) -> None: + mod = FakeModule() + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import os; print(os.getcwd())"], + cwd="/tmp", + ) + time.sleep(1.0) + mp.stop() + + +# =================================================================== +# safe_thread_map Tests +# =================================================================== + + +class TestSafeThreadMap: + def test_empty_input(self) -> None: + assert safe_thread_map([], lambda x: x) == [] + + def test_all_succeed(self) -> None: + result = safe_thread_map([1, 2, 3], lambda x: x * 2) + assert result == [2, 4, 6] + + def test_preserves_order(self) -> None: + def slow(x: int) -> int: + time.sleep(0.01 * (10 - x)) + return x + + result = safe_thread_map(list(range(10)), slow) + assert result == list(range(10)) + + def test_all_fail_raises_exception_group(self) -> None: + def fail(x: int) -> int: + raise ValueError(f"fail-{x}") + + with pytest.raises(ExceptionGroup) as exc_info: + safe_thread_map([1, 2, 3], fail) + assert len(exc_info.value.exceptions) == 3 + + def 
test_partial_failure(self) -> None: + def maybe_fail(x: int) -> int: + if x == 2: + raise ValueError("fail") + return x + + with pytest.raises(ExceptionGroup) as exc_info: + safe_thread_map([1, 2, 3], maybe_fail) + assert len(exc_info.value.exceptions) == 1 + + def test_on_errors_callback(self) -> None: + def fail(x: int) -> int: + if x == 2: + raise ValueError("boom") + return x * 10 + + cleanup_called = False + + def on_errors(outcomes, successes, errors): + nonlocal cleanup_called + cleanup_called = True + assert len(errors) == 1 + assert len(successes) == 2 + return successes # return successful results + + result = safe_thread_map([1, 2, 3], fail, on_errors) + assert cleanup_called + assert sorted(result) == [10, 30] + + def test_on_errors_can_raise(self) -> None: + def fail(x: int) -> int: + raise ValueError("boom") + + def on_errors(outcomes, successes, errors): + raise RuntimeError("custom error") + + with pytest.raises(RuntimeError, match="custom error"): + safe_thread_map([1], fail, on_errors) + + def test_waits_for_all_before_raising(self) -> None: + """Even if one fails fast, all others should complete.""" + completed = [] + + def work(x: int) -> int: + if x == 0: + raise ValueError("fast fail") + time.sleep(0.2) + completed.append(x) + return x + + with pytest.raises(ExceptionGroup): + safe_thread_map([0, 1, 2, 3], work) + # All non-failing items should have completed + assert sorted(completed) == [1, 2, 3] + + +# =================================================================== +# Integration: ModuleProcess on_exit -> dispose chain (the CI bug scenario) +# =================================================================== + + +class TestModuleProcessDisposeChain: + """Tests the exact pattern that caused the CI bug: + process exits -> watchdog fires on_exit -> module.stop() -> dispose -> + ModuleProcess.stop() -> tries to stop watchdog from inside watchdog thread. 
+ """ + + def test_chain_no_deadlock_fast_exit(self) -> None: + """Process exits immediately.""" + for _ in range(20): + mod = FakeModule() + done = threading.Event() + + def fake_stop() -> None: + mod.dispose() + done.set() + + ModuleProcess( + module=mod, + args=[PYTHON, "-c", "pass"], + on_exit=fake_stop, + ) + assert done.wait(timeout=5), "Deadlock in dispose chain (fast exit)" + + def test_chain_no_deadlock_slow_exit(self) -> None: + """Process runs briefly then exits.""" + for _ in range(10): + mod = FakeModule() + done = threading.Event() + + def fake_stop() -> None: + mod.dispose() + done.set() + + ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(0.1)"], + on_exit=fake_stop, + ) + assert done.wait(timeout=5), "Deadlock in dispose chain (slow exit)" + + def test_chain_concurrent_with_external_stop(self) -> None: + """Process exits naturally while external code calls stop().""" + for _ in range(20): + mod = FakeModule() + done = threading.Event() + + def fake_stop() -> None: + mod.dispose() + done.set() + + mp = ModuleProcess( + module=mod, + args=[PYTHON, "-c", "import time; time.sleep(0.05)"], + on_exit=fake_stop, + shutdown_timeout=1.0, + ) + # Race: the process might exit naturally or we might stop it + time.sleep(0.03) + mp.stop() + # Either way, should not deadlock + time.sleep(1.0) + + def test_dispose_with_artificial_delay(self) -> None: + """Add artificial delay near cleanup to simulate heavy CPU load.""" + original_stop = ModuleThread.stop + + def slow_stop(self_mt: ModuleThread) -> None: + time.sleep(0.05) # simulate load + original_stop(self_mt) + + for _ in range(10): + mod = FakeModule() + done = threading.Event() + + def fake_stop() -> None: + mod.dispose() + done.set() + + with mock.patch.object(ModuleThread, "stop", slow_stop): + ModuleProcess( + module=mod, + args=[PYTHON, "-c", "pass"], + on_exit=fake_stop, + ) + assert done.wait(timeout=10), "Deadlock with slow ModuleThread.stop()" + + +# We need ExceptionGroup 
for safe_thread_map tests +try: + ExceptionGroup +except NameError: + from dimos.utils.typing_utils import ExceptionGroup diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py new file mode 100644 index 0000000000..e83047ca5b --- /dev/null +++ b/dimos/utils/thread_utils.py @@ -0,0 +1,542 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Thread utilities: safe values, managed threads, safe parallel map.""" + +from __future__ import annotations + +import asyncio +import json +import signal +import subprocess +import threading +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from typing import IO, TYPE_CHECKING, Any, Generic + +from reactivex.disposable import Disposable + +from dimos.utils.logging_config import setup_logger +from dimos.utils.typing_utils import ExceptionGroup, TypeVar + +logger = setup_logger() + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from dimos.core.module import ModuleBase + +T = TypeVar("T") +R = TypeVar("R") + + +# --------------------------------------------------------------------------- +# ThreadSafeVal: a lock-protected value with context-manager support +# --------------------------------------------------------------------------- + + +class ThreadSafeVal(Generic[T]): + """A thread-safe value wrapper. 
+ + Wraps any value with a lock and provides atomic read-modify-write + via a context manager:: + + counter = ThreadSafeVal(0) + + # Simple get/set (each acquires the lock briefly): + counter.set(10) + print(counter.get()) # 10 + + # Atomic read-modify-write: + with counter as value: + # Lock is held for the entire block. + # Other threads block on get/set/with until this exits. + if value < 100: + counter.set(value + 1) + + # Works with any type: + status = ThreadSafeVal({"running": False, "count": 0}) + with status as s: + status.set({**s, "running": True}) + + # Bool check (for flag-like usage): + stopping = ThreadSafeVal(False) + stopping.set(True) + if stopping: + print("stopping!") + """ + + def __init__(self, initial: T) -> None: + self._lock = threading.RLock() + self._value = initial + + def get(self) -> T: + """Return the current value (acquires the lock briefly).""" + with self._lock: + return self._value + + def set(self, value: T) -> None: + """Replace the value (acquires the lock briefly).""" + with self._lock: + self._value = value + + def __bool__(self) -> bool: + with self._lock: + return bool(self._value) + + def __enter__(self) -> T: + self._lock.acquire() + return self._value + + def __exit__(self, *exc: object) -> None: + self._lock.release() + + def __getstate__(self) -> dict[str, Any]: + return {"_value": self._value} + + def __setstate__(self, state: dict[str, Any]) -> None: + self._lock = threading.RLock() + self._value = state["_value"] + + def __repr__(self) -> str: + return f"ThreadSafeVal({self._value!r})" + + +# --------------------------------------------------------------------------- +# ModuleThread: a thread that auto-registers with a module's disposables +# --------------------------------------------------------------------------- + + +class ModuleThread: + """A thread that registers cleanup with a module's disposables. + + Passes most kwargs through to ``threading.Thread``. 
On construction, + registers a disposable with the module so that when the module stops, + the thread is automatically joined. Cleanup is idempotent — safe to + call ``stop()`` manually even if the module also disposes it. + + Example:: + + class MyModule(Module): + @rpc + def start(self) -> None: + self._worker = ModuleThread( + module=self, + target=self._run_loop, + name="my-worker", + ) + + def _run_loop(self) -> None: + while not self._worker.stopping: + do_work() + """ + + def __init__( + self, + module: ModuleBase, + *, + start: bool = True, + close_timeout: float = 2.0, + **thread_kwargs: Any, + ) -> None: + thread_kwargs.setdefault("daemon", True) + self._thread = threading.Thread(**thread_kwargs) + self._stop_event = threading.Event() + self._close_timeout = close_timeout + self._stopped = False + self._stop_lock = threading.Lock() + module._disposables.add(Disposable(self.stop)) + if start: + self.start() + + @property + def stopping(self) -> bool: + """True after ``stop()`` has been called.""" + return self._stop_event.is_set() + + def start(self) -> None: + """Start the underlying thread.""" + self._stop_event.clear() + self._thread.start() + + def stop(self) -> None: + """Signal the thread to stop and join it. + + Safe to call multiple times, from any thread (including the + managed thread itself — it will skip the join in that case). 
+ """ + with self._stop_lock: + if self._stopped: + return + self._stopped = True + + self._stop_event.set() + if self._thread.is_alive() and self._thread is not threading.current_thread(): + self._thread.join(timeout=self._close_timeout) + + def join(self, timeout: float | None = None) -> None: + """Join the underlying thread.""" + self._thread.join(timeout=timeout) + + @property + def is_alive(self) -> bool: + return self._thread.is_alive() + + +# --------------------------------------------------------------------------- +# AsyncModuleThread: a thread running an asyncio event loop, auto-registered +# --------------------------------------------------------------------------- + + +class AsyncModuleThread: + """A thread running an asyncio event loop, registered with a module's disposables. + + If a loop is already running in the current context, reuses it (no thread + created). Otherwise creates a new loop and drives it in a daemon thread. + + On stop (or module dispose), the loop is shut down gracefully and the + thread is joined. Idempotent — safe to call ``stop()`` multiple times. + + Example:: + + class MyModule(Module): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._async = AsyncModuleThread(module=self) + + @rpc + def start(self) -> None: + future = asyncio.run_coroutine_threadsafe( + self._do_work(), self._async.loop + ) + + async def _do_work(self) -> None: + ... 
+ """ + + def __init__( + self, + module: ModuleBase, + *, + close_timeout: float = 2.0, + ) -> None: + self._close_timeout = close_timeout + self._stopped = False + self._stop_lock = threading.Lock() + self._owns_loop = False + self._thread: threading.Thread | None = None + + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._owns_loop = True + self._thread = threading.Thread( + target=self._loop.run_forever, + daemon=True, + name=f"{type(module).__name__}-event-loop", + ) + self._thread.start() + + module._disposables.add(Disposable(self.stop)) + + @property + def loop(self) -> asyncio.AbstractEventLoop: + """The managed event loop.""" + return self._loop + + @property + def is_alive(self) -> bool: + return self._thread is not None and self._thread.is_alive() + + def stop(self) -> None: + """Stop the event loop and join the thread. + + No-op if the loop was not created by this instance (reused an + existing running loop). Safe to call multiple times. + """ + with self._stop_lock: + if self._stopped: + return + self._stopped = True + + if self._owns_loop and self._loop.is_running(): + self._loop.call_soon_threadsafe(self._loop.stop) + + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=self._close_timeout) + + +# --------------------------------------------------------------------------- +# ModuleProcess: managed subprocess with log piping, auto-registered cleanup +# --------------------------------------------------------------------------- + + +class ModuleProcess: + """A managed subprocess that pipes stdout/stderr through the logger. + + Registers with a module's disposables so the process is automatically + stopped on module teardown. A watchdog thread monitors the process and + calls ``on_exit`` if the process exits on its own (i.e. not via + ``ModuleProcess.stop()``). 
+ + Most constructor kwargs mirror ``subprocess.Popen``. ``stdout`` and + ``stderr`` are always captured (set to ``PIPE`` internally). + + Example:: + + class MyModule(Module): + @rpc + def start(self) -> None: + self._proc = ModuleProcess( + module=self, + args=["./my_binary", "--flag"], + cwd="/opt/bin", + on_exit=self.stop, # stops the whole module if process exits on its own + ) + + @rpc + def stop(self) -> None: + # ModuleProcess.stop() is also called automatically via disposables + super().stop() + """ + + def __init__( + self, + module: ModuleBase, + args: list[str] | str, + *, + env: dict[str, str] | None = None, + cwd: str | None = None, + shell: bool = False, + on_exit: Callable[[], Any] | None = None, + shutdown_timeout: float = 10.0, + kill_timeout: float = 5.0, + log_json: bool = False, + start: bool = True, + **popen_kwargs: Any, + ) -> None: + self._args = args + self._env = env + self._cwd = cwd + self._shell = shell + self._on_exit = on_exit + self._shutdown_timeout = shutdown_timeout + self._kill_timeout = kill_timeout + self._log_json = log_json + self._popen_kwargs = popen_kwargs + self._process: subprocess.Popen[bytes] | None = None + self._watchdog: ModuleThread | None = None + self._module = module + self._stopped = False + self._stop_lock = threading.Lock() + + module._disposables.add(Disposable(self.stop)) + if start: + self.start() + + @property + def pid(self) -> int | None: + return self._process.pid if self._process is not None else None + + @property + def returncode(self) -> int | None: + if self._process is None: + return None + return self._process.poll() + + @property + def is_alive(self) -> bool: + return self._process is not None and self._process.poll() is None + + def start(self) -> None: + """Launch the subprocess and start the watchdog.""" + if self._process is not None and self._process.poll() is None: + logger.warning("Process already running", pid=self._process.pid) + return + + with self._stop_lock: + self._stopped = 
False + + logger.info( + "Starting process", + cmd=self._args if isinstance(self._args, str) else " ".join(self._args), + cwd=self._cwd, + ) + self._process = subprocess.Popen( + self._args, + env=self._env, + cwd=self._cwd, + shell=self._shell, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + **self._popen_kwargs, + ) + logger.info("Process started", pid=self._process.pid) + + self._watchdog = ModuleThread( + module=self._module, + target=self._watch, + name=f"proc-{self._process.pid}-watchdog", + ) + + def stop(self) -> None: + """Send SIGTERM, wait, escalate to SIGKILL if needed. Idempotent.""" + with self._stop_lock: + if self._stopped: + return + self._stopped = True + + if self._process is not None and self._process.poll() is None: + logger.info("Stopping process", pid=self._process.pid) + try: + self._process.send_signal(signal.SIGTERM) + except OSError: + pass # process already dead (PID recycled or exited between poll and signal) + else: + try: + self._process.wait(timeout=self._shutdown_timeout) + except subprocess.TimeoutExpired: + logger.warning( + "Process did not exit, sending SIGKILL", + pid=self._process.pid, + ) + self._process.kill() + try: + self._process.wait(timeout=self._kill_timeout) + except subprocess.TimeoutExpired: + logger.error( + "Process did not exit after SIGKILL", + pid=self._process.pid, + ) + self._process = None + + def _watch(self) -> None: + """Watchdog: pipe logs, detect crashes.""" + proc = self._process + if proc is None: + return + + stdout_t = self._start_reader(proc.stdout, "info") + stderr_t = self._start_reader(proc.stderr, "warning") + rc = proc.wait() + stdout_t.join(timeout=2) + stderr_t.join(timeout=2) + + with self._stop_lock: + if self._stopped: + return + + logger.error("Process died unexpectedly", pid=proc.pid, returncode=rc) + if self._on_exit is not None: + self._on_exit() + + def _start_reader(self, stream: IO[bytes] | None, level: str) -> threading.Thread: + t = 
threading.Thread(target=self._read_stream, args=(stream, level), daemon=True) + t.start() + return t + + def _read_stream(self, stream: IO[bytes] | None, level: str) -> None: + if stream is None: + return + log_fn = getattr(logger, level) + for raw in stream: + line = raw.decode("utf-8", errors="replace").rstrip() + if not line: + continue + if self._log_json: + try: + data = json.loads(line) + event = data.pop("event", line) + log_fn(event, **data) + continue + except (json.JSONDecodeError, TypeError): + logger.warning("malformed JSON from process", raw=line) + proc = self._process + log_fn(line, pid=proc.pid if proc else None) + stream.close() + + +# --------------------------------------------------------------------------- +# safe_thread_map: parallel map that collects all results before raising +# --------------------------------------------------------------------------- + + +def safe_thread_map( + items: Sequence[T], + fn: Callable[[T], R], + on_errors: Callable[[list[tuple[T, R | Exception]], list[R], list[Exception]], Any] + | None = None, +) -> list[R]: + """Thread-pool map that waits for all items to finish before raising and a cleanup handler + + - Empty *items* → returns ``[]`` immediately. + - All succeed → returns results in input order. + - Any fail → calls ``on_errors(outcomes, successes, errors)`` where + *outcomes* is a list of ``(input, result_or_exception)`` pairs in input + order, *successes* is the list of successful results, and *errors* is + the list of exceptions. If *on_errors* raises, that exception propagates. + If *on_errors* returns normally, its return value is returned from + ``safe_thread_map``. If *on_errors* is ``None``, raises an + ``ExceptionGroup``. 
+ + Example:: + + def start_service(name: str) -> Connection: + return connect(name) + + def cleanup( + outcomes: list[tuple[str, Connection | Exception]], + successes: list[Connection], + errors: list[Exception], + ) -> None: + for conn in successes: + conn.close() + raise ExceptionGroup("failed to start services", errors) + + connections = safe_thread_map( + ["db", "cache", "queue"], + start_service, + cleanup, # called only if any start_service() raises + ) + """ + if not items: + return [] + + outcomes: dict[int, R | Exception] = {} + + with ThreadPoolExecutor(max_workers=len(items)) as pool: + futures: dict[Future[R], int] = {pool.submit(fn, item): i for i, item in enumerate(items)} + for fut in as_completed(futures): + idx = futures[fut] + try: + outcomes[idx] = fut.result() + except Exception as e: + outcomes[idx] = e + + successes: list[R] = [] + errors: list[Exception] = [] + for v in outcomes.values(): + if isinstance(v, Exception): + errors.append(v) + else: + successes.append(v) + + if errors: + if on_errors is not None: + zipped = [(items[i], outcomes[i]) for i in range(len(items))] + return on_errors(zipped, successes, errors) # type: ignore[return-value, no-any-return] + raise ExceptionGroup("safe_thread_map failed", errors) + + return [outcomes[i] for i in range(len(items))] # type: ignore[misc] diff --git a/dimos/utils/typing_utils.py b/dimos/utils/typing_utils.py new file mode 100644 index 0000000000..aa32fff47f --- /dev/null +++ b/dimos/utils/typing_utils.py @@ -0,0 +1,45 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unify typing compatibility across multiple Python versions.""" + +from __future__ import annotations + +import sys +from collections.abc import Sequence + +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + +if sys.version_info < (3, 11): + + class ExceptionGroup(Exception): # type: ignore[no-redef] # noqa: N818 + """Minimal ExceptionGroup polyfill for Python 3.10.""" + + exceptions: tuple[BaseException, ...] + + def __init__(self, message: str, exceptions: Sequence[BaseException]) -> None: + super().__init__(message) + self.exceptions = tuple(exceptions) +else: + import builtins + + ExceptionGroup = builtins.ExceptionGroup # type: ignore[misc] + +__all__ = [ + "ExceptionGroup", + "TypeVar", +] From b8fcd083724c9042b6ae64890b6822ff3ef44daa Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Tue, 24 Mar 2026 23:04:44 -0700 Subject: [PATCH 12/21] improve tests --- dimos/core/test_native_module.py | 16 +++++++++------- dimos/utils/test_thread_utils.py | 5 ++++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index a7e6bd2b9a..6a743ac378 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -96,19 +96,21 @@ def test_process_crash_triggers_stop() -> None: mod.pointcloud.transport = LCMTransport("/pc", PointCloud2) mod.start() - assert mod._process is not None - pid = mod._process.pid + assert mod._proc is not None + assert mod._proc.is_alive + pid = mod._proc.pid - # Wait for the process 
to die and the watchdog to call stop() + # Wait for the process to die and the on_exit callback to call stop() for _ in range(30): time.sleep(0.1) - if mod._process is None: + if mod._proc is None or not mod._proc.is_alive: break - assert mod._process is None, f"Watchdog did not clean up after process {pid} died" + assert mod._proc is None or not mod._proc.is_alive, ( + f"Watchdog did not clean up after process {pid} died" + ) - # Join the watchdog thread. stop() is idempotent but will now join the - # watchdog on the second call since the reference is preserved. + # stop() is idempotent mod.stop() diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py index 87c4a883e2..28152d9242 100644 --- a/dimos/utils/test_thread_utils.py +++ b/dimos/utils/test_thread_utils.py @@ -357,9 +357,10 @@ def stop_it() -> None: def test_close_timeout_respected(self) -> None: """If the thread ignores the stop signal, stop() should return after close_timeout.""" mod = FakeModule() + bail = threading.Event() def stubborn_target() -> None: - time.sleep(10) # ignores stopping signal + bail.wait(timeout=10) # ignores stopping signal, but we can bail it out mt = ModuleThread( module=mod, target=stubborn_target, name="test-timeout", close_timeout=0.2 @@ -368,6 +369,8 @@ def stubborn_target() -> None: mt.stop() elapsed = time.monotonic() - start assert elapsed < 1.0, f"stop() took {elapsed}s, expected ~0.2s" + bail.set() # let the thread exit so conftest thread-leak detector is happy + mt.join(timeout=2) def test_stop_concurrent_with_dispose(self) -> None: """Calling stop() and dispose() concurrently should not crash.""" From 1d06db973cd8d3ba0709e533d470721e787b4c77 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Tue, 24 Mar 2026 23:28:16 -0700 Subject: [PATCH 13/21] formatting --- dimos/core/module.py | 8 +-- dimos/utils/test_thread_utils.py | 108 ++++++++++++++----------------- dimos/utils/thread_utils.py | 25 +++---- dimos/utils/typing_utils.py | 2 +- 4 files 
changed, 66 insertions(+), 77 deletions(-) diff --git a/dimos/core/module.py b/dimos/core/module.py index d9f8356f38..af733a19b7 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import asyncio from collections.abc import Callable from dataclasses import dataclass from functools import partial import inspect import json import sys -import threading from typing import ( TYPE_CHECKING, Any, @@ -100,8 +98,10 @@ def __init__(self, config_args: dict[str, Any]): super().__init__(**config_args) self._disposables = CompositeDisposable() self.mod_state = ThreadSafeVal[ModState]("init") - self._async_thread = AsyncModuleThread( # NEEDS to be created after self._disposables exists - module=self + self._async_thread = ( + AsyncModuleThread( # NEEDS to be created after self._disposables exists + module=self + ) ) try: self.rpc = self.config.rpc_transport() diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py index 28152d9242..d8274c8b7f 100644 --- a/dimos/utils/test_thread_utils.py +++ b/dimos/utils/test_thread_utils.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ """Exhaustive tests for dimos/utils/thread_utils.py Covers: ThreadSafeVal, ModuleThread, AsyncModuleThread, ModuleProcess, safe_thread_map. @@ -9,8 +23,6 @@ import asyncio import os import pickle -import signal -import subprocess import sys import threading import time @@ -27,10 +39,7 @@ safe_thread_map, ) - -# --------------------------------------------------------------------------- # Helpers: fake ModuleBase for testing ModuleThread / AsyncModuleThread / ModuleProcess -# --------------------------------------------------------------------------- class FakeModule: @@ -43,9 +52,7 @@ def dispose(self) -> None: self._disposables.dispose() -# =================================================================== # ThreadSafeVal Tests -# =================================================================== class TestThreadSafeVal: @@ -97,7 +104,7 @@ def test_get_inside_context_manager_no_deadlock(self) -> None: result = threading.Event() def do_it() -> None: - with v as val: + with v: _ = v.get() result.set() @@ -111,7 +118,7 @@ def test_bool_inside_context_manager_no_deadlock(self) -> None: result = threading.Event() def do_it() -> None: - with v as val: + with v: _ = bool(v) result.set() @@ -128,7 +135,7 @@ def test_context_manager_blocks_other_threads(self) -> None: other_finished = threading.Event() def holder() -> None: - with v as val: + with v: gate.wait(timeout=5) # hold the lock until signaled def setter() -> None: @@ -215,7 +222,7 @@ def test_nested_with_no_deadlock(self) -> None: result = threading.Event() def do_it() -> None: - with v as val1: + with v: with v as val2: v.set(val2 + 1) result.set() @@ -226,9 +233,7 @@ def do_it() -> None: assert result.is_set(), "Nested with blocks deadlocked!" 
-# =================================================================== # ModuleThread Tests -# =================================================================== class TestModuleThread: @@ -378,8 +383,8 @@ def test_stop_concurrent_with_dispose(self) -> None: mod = FakeModule() holder: list[ModuleThread] = [] - def target() -> None: - while not holder[0].stopping: + def target(h: list[ModuleThread] = holder) -> None: + while not h[0].stopping: time.sleep(0.001) mt = ModuleThread(module=mod, target=target, name="test-stop-dispose", start=False) @@ -395,9 +400,7 @@ def target() -> None: t2.join(timeout=3) -# =================================================================== # AsyncModuleThread Tests -# =================================================================== class TestAsyncModuleThread: @@ -475,9 +478,7 @@ def stop_it() -> None: assert not errors -# =================================================================== # ModuleProcess Tests -# =================================================================== # Helper: path to a python that sleeps or echoes @@ -516,7 +517,6 @@ def test_dispose_stops_process(self) -> None: args=[PYTHON, "-c", "import time; time.sleep(30)"], shutdown_timeout=2.0, ) - pid = mp.pid mod.dispose() time.sleep(0.5) assert not mp.is_alive @@ -526,7 +526,7 @@ def test_on_exit_fires_on_natural_exit(self) -> None: mod = FakeModule() exit_called = threading.Event() - mp = ModuleProcess( + ModuleProcess( module=mod, args=[PYTHON, "-c", "print('done')"], on_exit=exit_called.set, @@ -538,7 +538,7 @@ def test_on_exit_fires_on_crash(self) -> None: mod = FakeModule() exit_called = threading.Event() - mp = ModuleProcess( + ModuleProcess( module=mod, args=[PYTHON, "-c", "import sys; sys.exit(1)"], on_exit=exit_called.set, @@ -584,7 +584,11 @@ def test_log_json_mode(self) -> None: mod = FakeModule() mp = ModuleProcess( module=mod, - args=[PYTHON, "-c", """import json; print(json.dumps({"event": "test", "key": "val"}))"""], + args=[ + PYTHON, + 
"-c", + """import json; print(json.dumps({"event": "test", "key": "val"}))""", + ], log_json=True, ) time.sleep(1.0) @@ -666,7 +670,7 @@ def fake_module_stop() -> None: mod.dispose() stop_called.set() - mp = ModuleProcess( + ModuleProcess( module=mod, args=[PYTHON, "-c", "pass"], # exits immediately on_exit=fake_module_stop, @@ -676,7 +680,7 @@ def fake_module_stop() -> None: def test_on_exit_calls_module_stop_no_deadlock_stress(self) -> None: """Run the deadlock test multiple times under load.""" - for i in range(10): + for _i in range(10): self.test_on_exit_calls_module_stop_no_deadlock() def test_deferred_start(self) -> None: @@ -695,9 +699,13 @@ def test_env_passed(self) -> None: mod = FakeModule() exit_called = threading.Event() - mp = ModuleProcess( + ModuleProcess( module=mod, - args=[PYTHON, "-c", "import os, sys; sys.exit(0 if os.environ.get('MY_VAR') == '42' else 1)"], + args=[ + PYTHON, + "-c", + "import os, sys; sys.exit(0 if os.environ.get('MY_VAR') == '42' else 1)", + ], env={**os.environ, "MY_VAR": "42"}, on_exit=exit_called.set, ) @@ -716,9 +724,7 @@ def test_cwd_passed(self) -> None: mp.stop() -# =================================================================== # safe_thread_map Tests -# =================================================================== class TestSafeThreadMap: @@ -801,9 +807,7 @@ def work(x: int) -> int: assert sorted(completed) == [1, 2, 3] -# =================================================================== # Integration: ModuleProcess on_exit -> dispose chain (the CI bug scenario) -# =================================================================== class TestModuleProcessDisposeChain: @@ -812,20 +816,23 @@ class TestModuleProcessDisposeChain: ModuleProcess.stop() -> tries to stop watchdog from inside watchdog thread. 
""" + @staticmethod + def _make_fake_stop(mod: FakeModule, done: threading.Event) -> Callable: + def fake_stop() -> None: + mod.dispose() + done.set() + + return fake_stop + def test_chain_no_deadlock_fast_exit(self) -> None: """Process exits immediately.""" for _ in range(20): mod = FakeModule() done = threading.Event() - - def fake_stop() -> None: - mod.dispose() - done.set() - ModuleProcess( module=mod, args=[PYTHON, "-c", "pass"], - on_exit=fake_stop, + on_exit=self._make_fake_stop(mod, done), ) assert done.wait(timeout=5), "Deadlock in dispose chain (fast exit)" @@ -834,15 +841,10 @@ def test_chain_no_deadlock_slow_exit(self) -> None: for _ in range(10): mod = FakeModule() done = threading.Event() - - def fake_stop() -> None: - mod.dispose() - done.set() - ModuleProcess( module=mod, args=[PYTHON, "-c", "import time; time.sleep(0.1)"], - on_exit=fake_stop, + on_exit=self._make_fake_stop(mod, done), ) assert done.wait(timeout=5), "Deadlock in dispose chain (slow exit)" @@ -851,15 +853,10 @@ def test_chain_concurrent_with_external_stop(self) -> None: for _ in range(20): mod = FakeModule() done = threading.Event() - - def fake_stop() -> None: - mod.dispose() - done.set() - mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "import time; time.sleep(0.05)"], - on_exit=fake_stop, + on_exit=self._make_fake_stop(mod, done), shutdown_timeout=1.0, ) # Race: the process might exit naturally or we might stop it @@ -879,22 +876,13 @@ def slow_stop(self_mt: ModuleThread) -> None: for _ in range(10): mod = FakeModule() done = threading.Event() - - def fake_stop() -> None: - mod.dispose() - done.set() - with mock.patch.object(ModuleThread, "stop", slow_stop): ModuleProcess( module=mod, args=[PYTHON, "-c", "pass"], - on_exit=fake_stop, + on_exit=self._make_fake_stop(mod, done), ) assert done.wait(timeout=10), "Deadlock with slow ModuleThread.stop()" -# We need ExceptionGroup for safe_thread_map tests -try: - ExceptionGroup -except NameError: - from dimos.utils.typing_utils 
import ExceptionGroup +from dimos.utils.typing_utils import ExceptionGroup diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py index e83047ca5b..3a53386c50 100644 --- a/dimos/utils/thread_utils.py +++ b/dimos/utils/thread_utils.py @@ -17,11 +17,12 @@ from __future__ import annotations import asyncio +import collections +from concurrent.futures import Future, ThreadPoolExecutor, as_completed import json import signal import subprocess import threading -from concurrent.futures import Future, ThreadPoolExecutor, as_completed from typing import IO, TYPE_CHECKING, Any, Generic from reactivex.disposable import Disposable @@ -40,9 +41,7 @@ R = TypeVar("R") -# --------------------------------------------------------------------------- # ThreadSafeVal: a lock-protected value with context-manager support -# --------------------------------------------------------------------------- class ThreadSafeVal(Generic[T]): @@ -112,9 +111,7 @@ def __repr__(self) -> str: return f"ThreadSafeVal({self._value!r})" -# --------------------------------------------------------------------------- # ModuleThread: a thread that auto-registers with a module's disposables -# --------------------------------------------------------------------------- class ModuleThread: @@ -193,9 +190,7 @@ def is_alive(self) -> bool: return self._thread.is_alive() -# --------------------------------------------------------------------------- # AsyncModuleThread: a thread running an asyncio event loop, auto-registered -# --------------------------------------------------------------------------- class AsyncModuleThread: @@ -278,9 +273,7 @@ def stop(self) -> None: self._thread.join(timeout=self._close_timeout) -# --------------------------------------------------------------------------- # ModuleProcess: managed subprocess with log piping, auto-registered cleanup -# --------------------------------------------------------------------------- class ModuleProcess: @@ -341,6 +334,7 @@ def __init__( 
self._module = module self._stopped = False self._stop_lock = threading.Lock() + self.last_stderr: collections.deque[str] = collections.deque(maxlen=50) module._disposables.add(Disposable(self.stop)) if start: @@ -438,7 +432,13 @@ def _watch(self) -> None: if self._stopped: return - logger.error("Process died unexpectedly", pid=proc.pid, returncode=rc) + last_stderr = "\n".join(self.last_stderr) + logger.error( + "Process died unexpectedly", + pid=proc.pid, + returncode=rc, + last_stderr=last_stderr[:500] if last_stderr else None, + ) if self._on_exit is not None: self._on_exit() @@ -451,10 +451,13 @@ def _read_stream(self, stream: IO[bytes] | None, level: str) -> None: if stream is None: return log_fn = getattr(logger, level) + is_stderr = level == "warning" for raw in stream: line = raw.decode("utf-8", errors="replace").rstrip() if not line: continue + if is_stderr: + self.last_stderr.append(line) if self._log_json: try: data = json.loads(line) @@ -468,9 +471,7 @@ def _read_stream(self, stream: IO[bytes] | None, level: str) -> None: stream.close() -# --------------------------------------------------------------------------- # safe_thread_map: parallel map that collects all results before raising -# --------------------------------------------------------------------------- def safe_thread_map( diff --git a/dimos/utils/typing_utils.py b/dimos/utils/typing_utils.py index aa32fff47f..3592d5fdbb 100644 --- a/dimos/utils/typing_utils.py +++ b/dimos/utils/typing_utils.py @@ -16,8 +16,8 @@ from __future__ import annotations -import sys from collections.abc import Sequence +import sys if sys.version_info < (3, 13): from typing_extensions import TypeVar From 75708ffe723856453f41f17f322654ff1fdd8c26 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Tue, 24 Mar 2026 23:32:54 -0700 Subject: [PATCH 14/21] misc improve --- dimos/utils/thread_utils.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/dimos/utils/thread_utils.py 
b/dimos/utils/thread_utils.py index 3a53386c50..e9844326b5 100644 --- a/dimos/utils/thread_utils.py +++ b/dimos/utils/thread_utils.py @@ -140,7 +140,7 @@ def _run_loop(self) -> None: def __init__( self, - module: ModuleBase, + module: ModuleBase[Any], *, start: bool = True, close_timeout: float = 2.0, @@ -221,7 +221,7 @@ async def _do_work(self) -> None: def __init__( self, - module: ModuleBase, + module: ModuleBase[Any], *, close_timeout: float = 2.0, ) -> None: @@ -307,7 +307,7 @@ def stop(self) -> None: def __init__( self, - module: ModuleBase, + module: ModuleBase[Any], args: list[str] | str, *, env: dict[str, str] | None = None, @@ -317,6 +317,7 @@ def __init__( shutdown_timeout: float = 10.0, kill_timeout: float = 5.0, log_json: bool = False, + log_tail_lines: int = 50, start: bool = True, **popen_kwargs: Any, ) -> None: @@ -328,13 +329,15 @@ def __init__( self._shutdown_timeout = shutdown_timeout self._kill_timeout = kill_timeout self._log_json = log_json + self._log_tail_lines = log_tail_lines self._popen_kwargs = popen_kwargs self._process: subprocess.Popen[bytes] | None = None self._watchdog: ModuleThread | None = None self._module = module self._stopped = False self._stop_lock = threading.Lock() - self.last_stderr: collections.deque[str] = collections.deque(maxlen=50) + self.last_stdout: collections.deque[str] = collections.deque(maxlen=log_tail_lines) + self.last_stderr: collections.deque[str] = collections.deque(maxlen=log_tail_lines) module._disposables.add(Disposable(self.stop)) if start: @@ -363,6 +366,9 @@ def start(self) -> None: with self._stop_lock: self._stopped = False + self.last_stdout = collections.deque(maxlen=self._log_tail_lines) + self.last_stderr = collections.deque(maxlen=self._log_tail_lines) + logger.info( "Starting process", cmd=self._args if isinstance(self._args, str) else " ".join(self._args), @@ -432,12 +438,14 @@ def _watch(self) -> None: if self._stopped: return - last_stderr = "\n".join(self.last_stderr) + last_stdout = 
"\n".join(self.last_stdout) or None + last_stderr = "\n".join(self.last_stderr) or None logger.error( "Process died unexpectedly", pid=proc.pid, returncode=rc, - last_stderr=last_stderr[:500] if last_stderr else None, + last_stdout=last_stdout, + last_stderr=last_stderr, ) if self._on_exit is not None: self._on_exit() @@ -452,12 +460,12 @@ def _read_stream(self, stream: IO[bytes] | None, level: str) -> None: return log_fn = getattr(logger, level) is_stderr = level == "warning" + buf = self.last_stderr if is_stderr else self.last_stdout for raw in stream: line = raw.decode("utf-8", errors="replace").rstrip() if not line: continue - if is_stderr: - self.last_stderr.append(line) + buf.append(line) if self._log_json: try: data = json.loads(line) From aff62bc452fdb4f58f3b3fffe5e59c05d7e03fd8 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Wed, 25 Mar 2026 00:09:21 -0700 Subject: [PATCH 15/21] Apply suggestions from code review Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- dimos/agents/mcp/mcp_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index d09202532c..f858571fac 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -181,7 +181,7 @@ def stop(self) -> None: if self._uvicorn_server: self._uvicorn_server.should_exit = True loop = self._async_thread.loop - if loop is not None and self._serve_future is not None: + if self._serve_future is not None: self._serve_future.result(timeout=5.0) self._uvicorn_server = None self._serve_future = None From d5ae028dd4ff231ab4173caa8c7690911aa1d19f Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Wed, 25 Mar 2026 00:10:43 -0700 Subject: [PATCH 16/21] cleanup --- dimos/agents/mcp/test_mcp_server.py | 59 ++++++++++++++++++++++++++++- dimos/core/module.py | 6 +-- dimos/utils/test_thread_utils.py | 6 +-- dimos/utils/thread_utils.py | 1 - 4 files changed, 64 
insertions(+), 8 deletions(-) diff --git a/dimos/agents/mcp/test_mcp_server.py b/dimos/agents/mcp/test_mcp_server.py index 1cbca9e3e4..0424a2b16f 100644 --- a/dimos/agents/mcp/test_mcp_server.py +++ b/dimos/agents/mcp/test_mcp_server.py @@ -16,9 +16,13 @@ import asyncio import json +import socket +import time from unittest.mock import MagicMock -from dimos.agents.mcp.mcp_server import handle_request +import requests + +from dimos.agents.mcp.mcp_server import McpServer, handle_request from dimos.core.module import SkillInfo @@ -111,3 +115,56 @@ def test_mcp_module_initialize_and_unknown() -> None: response = asyncio.run(handle_request({"method": "unknown/method", "id": 2}, [], {})) assert response["error"]["code"] == -32601 + + +def _free_port() -> int: + with socket.socket() as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def test_mcp_server_lifecycle() -> None: + """Start a real McpServer, hit the HTTP endpoint, then stop it. + + This exercises the AsyncModuleThread event loop integration that the + unit tests above do not cover. 
+ """ + port = _free_port() + + server = McpServer() + server._start_server(port=port) + url = f"http://127.0.0.1:{port}/mcp" + + # Wait for the server to be ready + for _ in range(40): + try: + resp = requests.post( + url, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + timeout=0.5, + ) + if resp.status_code == 200: + break + except requests.ConnectionError: + time.sleep(0.1) + else: + server.stop() + raise AssertionError("McpServer did not become ready") + + # Verify it responds + data = resp.json() + assert data["result"]["serverInfo"]["name"] == "dimensional" + + # Stop and verify it shuts down + server.stop() + time.sleep(0.3) + + with socket.socket() as s: + # Port should be released after stop + try: + s.connect(("127.0.0.1", port)) + s.close() + # If we could connect, the server is still up — that's a bug + raise AssertionError("McpServer still listening after stop()") + except ConnectionRefusedError: + pass # expected — server is down diff --git a/dimos/core/module.py b/dimos/core/module.py index af733a19b7..d96133ec7d 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -46,7 +46,7 @@ from dimos.utils.generic import classproperty from dimos.utils.thread_utils import AsyncModuleThread, ThreadSafeVal -ModState = Literal["init", "started", "stopping", "stopped"] +ModState = Literal["init", "started", "stopped"] if TYPE_CHECKING: from dimos.core.blueprints import Blueprint @@ -130,9 +130,9 @@ def stop(self) -> None: def _stop(self) -> None: with self.mod_state as state: - if state in ("stopping", "stopped"): + if state == "stopped": return - self.mod_state.set("stopping") + self.mod_state.set("stopped") if self.rpc: self.rpc.stop() # type: ignore[attr-defined] diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py index d8274c8b7f..07047c6d92 100644 --- a/dimos/utils/test_thread_utils.py +++ b/dimos/utils/test_thread_utils.py @@ -210,11 +210,11 @@ def test_string_literal_type(self) -> None: assert v.get() == 
"started" with v as state: - if state in ("stopping", "stopped"): + if state == "stopped": pass # no-op else: - v.set("stopping") - assert v.get() == "stopping" + v.set("stopped") + assert v.get() == "stopped" def test_nested_with_no_deadlock(self) -> None: """RLock should allow the same thread to nest with blocks.""" diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py index e9844326b5..6d9b7a9e7f 100644 --- a/dimos/utils/thread_utils.py +++ b/dimos/utils/thread_utils.py @@ -301,7 +301,6 @@ def start(self) -> None: @rpc def stop(self) -> None: - # ModuleProcess.stop() is also called automatically via disposables super().stop() """ From 4595390eda7bea07bd01e74e9fe99a9445665837 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Wed, 25 Mar 2026 00:13:01 -0700 Subject: [PATCH 17/21] - --- dimos/core/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/core/module.py b/dimos/core/module.py index d96133ec7d..39800fe2d1 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -95,9 +95,9 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): rpc_calls: list[str] = [] def __init__(self, config_args: dict[str, Any]): + self.mod_state = ThreadSafeVal[ModState]("init") super().__init__(**config_args) self._disposables = CompositeDisposable() - self.mod_state = ThreadSafeVal[ModState]("init") self._async_thread = ( AsyncModuleThread( # NEEDS to be created after self._disposables exists module=self From 3b5c4fd359a0672b3528b8b6180e3ae60393c3f8 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Wed, 25 Mar 2026 00:24:25 -0700 Subject: [PATCH 18/21] fix order of _disposables --- dimos/agents/mcp/mcp_server.py | 1 - dimos/core/module.py | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index f858571fac..1e1d7d9942 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -180,7 +180,6 @@ def start(self) -> None: def 
stop(self) -> None: if self._uvicorn_server: self._uvicorn_server.should_exit = True - loop = self._async_thread.loop if self._serve_future is not None: self._serve_future.result(timeout=5.0) self._uvicorn_server = None diff --git a/dimos/core/module.py b/dimos/core/module.py index 39800fe2d1..4fbffe07b9 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -134,6 +134,10 @@ def _stop(self) -> None: return self.mod_state.set("stopped") + # dispose of things BEFORE making aspects like rpc and _tf invalid + if hasattr(self, "_disposables"): + self._disposables.dispose() # stops _async_thread via disposable + if self.rpc: self.rpc.stop() # type: ignore[attr-defined] self.rpc = None # type: ignore[assignment] @@ -141,8 +145,6 @@ def _stop(self) -> None: if hasattr(self, "_tf") and self._tf is not None: self._tf.stop() self._tf = None - if hasattr(self, "_disposables"): - self._disposables.dispose() # stops _async_thread via disposable # Break the In/Out -> owner -> self reference cycle so the instance # can be freed by refcount instead of waiting for GC. 
From f9b6d04fc0f1a1f7b739f6ba13c41b853ffbeb40 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Wed, 25 Mar 2026 23:19:13 -0700 Subject: [PATCH 19/21] pr feedback --- dimos/core/native_module.py | 1 + dimos/utils/test_thread_utils.py | 85 ++++++++++++++-------------- dimos/utils/thread_utils.py | 97 ++++++++++---------------------- 3 files changed, 71 insertions(+), 112 deletions(-) diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 2c7e685150..3a6236045a 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -154,6 +154,7 @@ def start(self) -> None: shutdown_timeout=self.config.shutdown_timeout, log_json=self.config.log_format == LogFormat.JSON, ) + self._proc.start() def _resolve_paths(self) -> None: """Resolve relative ``cwd`` and ``executable`` against the subclass's source file.""" diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py index 07047c6d92..973884ea66 100644 --- a/dimos/utils/test_thread_utils.py +++ b/dimos/utils/test_thread_utils.py @@ -62,18 +62,6 @@ def test_basic_get_set(self) -> None: v.set(99) assert v.get() == 99 - def test_bool_truthy(self) -> None: - v = ThreadSafeVal(True) - assert bool(v) is True - v.set(False) - assert bool(v) is False - - def test_bool_zero(self) -> None: - v = ThreadSafeVal(0) - assert bool(v) is False - v.set(1) - assert bool(v) is True - def test_context_manager_returns_value(self) -> None: v = ThreadSafeVal("hello") with v as val: @@ -113,20 +101,6 @@ def do_it() -> None: t.join(timeout=2) assert result.is_set(), "Deadlocked! get() inside with block hung" - def test_bool_inside_context_manager_no_deadlock(self) -> None: - v = ThreadSafeVal(True) - result = threading.Event() - - def do_it() -> None: - with v: - _ = bool(v) - result.set() - - t = threading.Thread(target=do_it) - t.start() - t.join(timeout=2) - assert result.is_set(), "Deadlocked! 
bool() inside with block hung" - def test_context_manager_blocks_other_threads(self) -> None: """While one thread holds the lock via `with`, others should block on set().""" v = ThreadSafeVal(0) @@ -245,6 +219,7 @@ def target() -> None: ran.set() mt = ModuleThread(module=mod, target=target, name="test-basic") + mt.start() ran.wait(timeout=2) assert ran.is_set() mt.stop() @@ -254,6 +229,7 @@ def test_auto_start(self) -> None: mod = FakeModule() started = threading.Event() mt = ModuleThread(module=mod, target=started.set, name="test-autostart") + mt.start() started.wait(timeout=2) assert started.is_set() mt.stop() @@ -261,7 +237,7 @@ def test_auto_start(self) -> None: def test_deferred_start(self) -> None: mod = FakeModule() started = threading.Event() - mt = ModuleThread(module=mod, target=started.set, name="test-deferred", start=False) + mt = ModuleThread(module=mod, target=started.set, name="test-deferred") time.sleep(0.1) assert not started.is_set() mt.start() @@ -275,11 +251,11 @@ def test_stopping_property(self) -> None: holder: list[ModuleThread] = [] def target() -> None: - while not holder[0].stopping: + while holder[0].status.get() == "running": time.sleep(0.01) saw_stopping.set() - mt = ModuleThread(module=mod, target=target, name="test-stopping", start=False) + mt = ModuleThread(module=mod, target=target, name="test-stopping") holder.append(mt) mt.start() time.sleep(0.05) @@ -290,6 +266,7 @@ def target() -> None: def test_stop_idempotent(self) -> None: mod = FakeModule() mt = ModuleThread(module=mod, target=lambda: time.sleep(0.01), name="test-idem") + mt.start() time.sleep(0.05) mt.stop() mt.stop() # second call should not raise @@ -305,7 +282,7 @@ def target() -> None: holder[0].stop() # stop ourselves — should not deadlock result.set() - mt = ModuleThread(module=mod, target=target, name="test-self-stop", start=False) + mt = ModuleThread(module=mod, target=target, name="test-self-stop") holder.append(mt) mt.start() result.wait(timeout=3) @@ -319,10 
+296,10 @@ def test_dispose_stops_thread(self) -> None: def target() -> None: running.set() - while not holder[0].stopping: + while holder[0].status.get() == "running": time.sleep(0.01) - mt = ModuleThread(module=mod, target=target, name="test-dispose", start=False) + mt = ModuleThread(module=mod, target=target, name="test-dispose") holder.append(mt) mt.start() running.wait(timeout=2) @@ -336,10 +313,10 @@ def test_concurrent_stop_calls(self) -> None: holder: list[ModuleThread] = [] def target() -> None: - while not holder[0].stopping: + while holder[0].status.get() == "running": time.sleep(0.01) - mt = ModuleThread(module=mod, target=target, name="test-concurrent-stop", start=False) + mt = ModuleThread(module=mod, target=target, name="test-concurrent-stop") holder.append(mt) mt.start() time.sleep(0.05) @@ -370,6 +347,7 @@ def stubborn_target() -> None: mt = ModuleThread( module=mod, target=stubborn_target, name="test-timeout", close_timeout=0.2 ) + mt.start() start = time.monotonic() mt.stop() elapsed = time.monotonic() - start @@ -384,10 +362,10 @@ def test_stop_concurrent_with_dispose(self) -> None: holder: list[ModuleThread] = [] def target(h: list[ModuleThread] = holder) -> None: - while not h[0].stopping: + while h[0].status.get() == "running": time.sleep(0.001) - mt = ModuleThread(module=mod, target=target, name="test-stop-dispose", start=False) + mt = ModuleThread(module=mod, target=target, name="test-stop-dispose") holder.append(mt) mt.start() time.sleep(0.02) @@ -493,6 +471,7 @@ def test_basic_lifecycle(self) -> None: args=[PYTHON, "-c", "import time; time.sleep(30)"], shutdown_timeout=2.0, ) + mp.start() assert mp.is_alive assert mp.pid is not None mp.stop() @@ -506,6 +485,7 @@ def test_stop_idempotent(self) -> None: args=[PYTHON, "-c", "import time; time.sleep(30)"], shutdown_timeout=1.0, ) + mp.start() mp.stop() mp.stop() # should not raise mp.stop() @@ -517,6 +497,7 @@ def test_dispose_stops_process(self) -> None: args=[PYTHON, "-c", "import time; 
time.sleep(30)"], shutdown_timeout=2.0, ) + mp.start() mod.dispose() time.sleep(0.5) assert not mp.is_alive @@ -526,11 +507,12 @@ def test_on_exit_fires_on_natural_exit(self) -> None: mod = FakeModule() exit_called = threading.Event() - ModuleProcess( + mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "print('done')"], on_exit=exit_called.set, ) + mp.start() exit_called.wait(timeout=5) assert exit_called.is_set(), "on_exit was not called after natural process exit" @@ -538,11 +520,12 @@ def test_on_exit_fires_on_crash(self) -> None: mod = FakeModule() exit_called = threading.Event() - ModuleProcess( + mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "import sys; sys.exit(1)"], on_exit=exit_called.set, ) + mp.start() exit_called.wait(timeout=5) assert exit_called.is_set(), "on_exit was not called after process crash" @@ -557,6 +540,7 @@ def test_on_exit_not_fired_on_stop(self) -> None: on_exit=exit_called.set, shutdown_timeout=2.0, ) + mp.start() time.sleep(0.2) # let watchdog start mp.stop() time.sleep(1.0) # give watchdog time to potentially fire @@ -568,6 +552,7 @@ def test_stdout_logged(self) -> None: module=mod, args=[PYTHON, "-c", "print('hello from subprocess')"], ) + mp.start() time.sleep(1.0) # let output be read mp.stop() @@ -577,6 +562,7 @@ def test_stderr_logged(self) -> None: module=mod, args=[PYTHON, "-c", "import sys; sys.stderr.write('error msg\\n')"], ) + mp.start() time.sleep(1.0) mp.stop() @@ -591,6 +577,7 @@ def test_log_json_mode(self) -> None: ], log_json=True, ) + mp.start() time.sleep(1.0) mp.stop() @@ -601,6 +588,7 @@ def test_log_json_malformed(self) -> None: args=[PYTHON, "-c", "print('not json')"], log_json=True, ) + mp.start() time.sleep(1.0) mp.stop() @@ -617,6 +605,7 @@ def test_stop_process_that_ignores_sigterm(self) -> None: shutdown_timeout=0.5, kill_timeout=2.0, ) + mp.start() time.sleep(0.2) start = time.monotonic() mp.stop() @@ -632,6 +621,7 @@ def test_stop_already_dead_process(self) -> None: module=mod, args=[PYTHON, 
"-c", "pass"], # exits immediately ) + mp.start() time.sleep(1.0) # let it die mp.stop() # should not raise @@ -642,6 +632,7 @@ def test_concurrent_stop(self) -> None: args=[PYTHON, "-c", "import time; time.sleep(30)"], shutdown_timeout=2.0, ) + mp.start() errors = [] def stop_it() -> None: @@ -670,11 +661,12 @@ def fake_module_stop() -> None: mod.dispose() stop_called.set() - ModuleProcess( + mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "pass"], # exits immediately on_exit=fake_module_stop, ) + mp.start() stop_called.wait(timeout=5) assert stop_called.is_set(), "Deadlocked! on_exit -> dispose -> stop chain hung" @@ -688,7 +680,6 @@ def test_deferred_start(self) -> None: mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "import time; time.sleep(30)"], - start=False, ) assert not mp.is_alive mp.start() @@ -699,7 +690,7 @@ def test_env_passed(self) -> None: mod = FakeModule() exit_called = threading.Event() - ModuleProcess( + mp = ModuleProcess( module=mod, args=[ PYTHON, @@ -709,6 +700,7 @@ def test_env_passed(self) -> None: env={**os.environ, "MY_VAR": "42"}, on_exit=exit_called.set, ) + mp.start() exit_called.wait(timeout=5) # Process should have exited with 0 (our on_exit fires for all unmanaged exits) assert exit_called.is_set() @@ -720,6 +712,7 @@ def test_cwd_passed(self) -> None: args=[PYTHON, "-c", "import os; print(os.getcwd())"], cwd="/tmp", ) + mp.start() time.sleep(1.0) mp.stop() @@ -829,11 +822,12 @@ def test_chain_no_deadlock_fast_exit(self) -> None: for _ in range(20): mod = FakeModule() done = threading.Event() - ModuleProcess( + mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "pass"], on_exit=self._make_fake_stop(mod, done), ) + mp.start() assert done.wait(timeout=5), "Deadlock in dispose chain (fast exit)" def test_chain_no_deadlock_slow_exit(self) -> None: @@ -841,11 +835,12 @@ def test_chain_no_deadlock_slow_exit(self) -> None: for _ in range(10): mod = FakeModule() done = threading.Event() - ModuleProcess( + mp = ModuleProcess( 
module=mod, args=[PYTHON, "-c", "import time; time.sleep(0.1)"], on_exit=self._make_fake_stop(mod, done), ) + mp.start() assert done.wait(timeout=5), "Deadlock in dispose chain (slow exit)" def test_chain_concurrent_with_external_stop(self) -> None: @@ -859,6 +854,7 @@ def test_chain_concurrent_with_external_stop(self) -> None: on_exit=self._make_fake_stop(mod, done), shutdown_timeout=1.0, ) + mp.start() # Race: the process might exit naturally or we might stop it time.sleep(0.03) mp.stop() @@ -877,11 +873,12 @@ def slow_stop(self_mt: ModuleThread) -> None: mod = FakeModule() done = threading.Event() with mock.patch.object(ModuleThread, "stop", slow_stop): - ModuleProcess( + mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "pass"], on_exit=self._make_fake_stop(mod, done), ) + mp.start() assert done.wait(timeout=10), "Deadlock with slow ModuleThread.stop()" diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py index 6d9b7a9e7f..adf6751333 100644 --- a/dimos/utils/thread_utils.py +++ b/dimos/utils/thread_utils.py @@ -23,7 +23,7 @@ import signal import subprocess import threading -from typing import IO, TYPE_CHECKING, Any, Generic +from typing import IO, TYPE_CHECKING, Any, Generic, Literal from reactivex.disposable import Disposable @@ -41,22 +41,14 @@ R = TypeVar("R") -# ThreadSafeVal: a lock-protected value with context-manager support - - class ThreadSafeVal(Generic[T]): """A thread-safe value wrapper. - Wraps any value with a lock and provides atomic read-modify-write - via a context manager:: - - counter = ThreadSafeVal(0) - - # Simple get/set (each acquires the lock briefly): - counter.set(10) - print(counter.get()) # 10 - - # Atomic read-modify-write: + Forces lock usage in order to get access to a value (reduces unsafe value access) + Three ways to use: + 1. `.set` + 2. `.get` + 3. via a context manager:: with counter as value: # Lock is held for the entire block. # Other threads block on get/set/with until this exits. 
@@ -71,7 +63,7 @@ class ThreadSafeVal(Generic[T]): # Bool check (for flag-like usage): stopping = ThreadSafeVal(False) stopping.set(True) - if stopping: + if stopping.get(): print("stopping!") """ @@ -89,10 +81,6 @@ def set(self, value: T) -> None: with self._lock: self._value = value - def __bool__(self) -> bool: - with self._lock: - return bool(self._value) - def __enter__(self) -> T: self._lock.acquire() return self._value @@ -111,9 +99,6 @@ def __repr__(self) -> str: return f"ThreadSafeVal({self._value!r})" -# ModuleThread: a thread that auto-registers with a module's disposables - - class ModuleThread: """A thread that registers cleanup with a module's disposables. @@ -132,9 +117,10 @@ def start(self) -> None: target=self._run_loop, name="my-worker", ) + self._worker.start() def _run_loop(self) -> None: - while not self._worker.stopping: + while self._worker.status.get() == "running": do_work() """ @@ -142,28 +128,18 @@ def __init__( self, module: ModuleBase[Any], *, - start: bool = True, close_timeout: float = 2.0, **thread_kwargs: Any, ) -> None: thread_kwargs.setdefault("daemon", True) + thread_kwargs.setdefault("name", f"{type(module).__name__}-thread") self._thread = threading.Thread(**thread_kwargs) - self._stop_event = threading.Event() self._close_timeout = close_timeout - self._stopped = False - self._stop_lock = threading.Lock() + self.status: ThreadSafeVal[Literal["not_started", "running", "stopping", "stopped"]] = ThreadSafeVal("not_started") module._disposables.add(Disposable(self.stop)) - if start: - self.start() - - @property - def stopping(self) -> bool: - """True after ``stop()`` has been called.""" - return self._stop_event.is_set() def start(self) -> None: - """Start the underlying thread.""" - self._stop_event.clear() + self.status.set("running") self._thread.start() def stop(self) -> None: @@ -172,14 +148,13 @@ def stop(self) -> None: Safe to call multiple times, from any thread (including the managed thread itself — it will skip the join 
in that case). """ - with self._stop_lock: - if self._stopped: + with self.status as s: + if s in ("stopping", "stopped"): return - self._stopped = True - - self._stop_event.set() + self.status.set("stopping") if self._thread.is_alive() and self._thread is not threading.current_thread(): self._thread.join(timeout=self._close_timeout) + self.status.set("stopped") def join(self, timeout: float | None = None) -> None: """Join the underlying thread.""" @@ -190,9 +165,6 @@ def is_alive(self) -> bool: return self._thread.is_alive() -# AsyncModuleThread: a thread running an asyncio event loop, auto-registered - - class AsyncModuleThread: """A thread running an asyncio event loop, registered with a module's disposables. @@ -226,8 +198,7 @@ def __init__( close_timeout: float = 2.0, ) -> None: self._close_timeout = close_timeout - self._stopped = False - self._stop_lock = threading.Lock() + self._stopped = ThreadSafeVal(False) self._owns_loop = False self._thread: threading.Thread | None = None @@ -261,10 +232,10 @@ def stop(self) -> None: No-op if the loop was not created by this instance (reused an existing running loop). Safe to call multiple times. """ - with self._stop_lock: - if self._stopped: + with self._stopped as stopped: + if stopped: return - self._stopped = True + self._stopped.set(True) if self._owns_loop and self._loop.is_running(): self._loop.call_soon_threadsafe(self._loop.stop) @@ -273,9 +244,6 @@ def stop(self) -> None: self._thread.join(timeout=self._close_timeout) -# ModuleProcess: managed subprocess with log piping, auto-registered cleanup - - class ModuleProcess: """A managed subprocess that pipes stdout/stderr through the logger. 
@@ -298,6 +266,7 @@ def start(self) -> None: cwd="/opt/bin", on_exit=self.stop, # stops the whole module if process exits on its own ) + self._proc.start() @rpc def stop(self) -> None: @@ -317,7 +286,6 @@ def __init__( kill_timeout: float = 5.0, log_json: bool = False, log_tail_lines: int = 50, - start: bool = True, **popen_kwargs: Any, ) -> None: self._args = args @@ -333,14 +301,11 @@ def __init__( self._process: subprocess.Popen[bytes] | None = None self._watchdog: ModuleThread | None = None self._module = module - self._stopped = False - self._stop_lock = threading.Lock() + self._stopped = ThreadSafeVal(False) self.last_stdout: collections.deque[str] = collections.deque(maxlen=log_tail_lines) self.last_stderr: collections.deque[str] = collections.deque(maxlen=log_tail_lines) module._disposables.add(Disposable(self.stop)) - if start: - self.start() @property def pid(self) -> int | None: @@ -362,8 +327,7 @@ def start(self) -> None: logger.warning("Process already running", pid=self._process.pid) return - with self._stop_lock: - self._stopped = False + self._stopped.set(False) self.last_stdout = collections.deque(maxlen=self._log_tail_lines) self.last_stderr = collections.deque(maxlen=self._log_tail_lines) @@ -389,13 +353,14 @@ def start(self) -> None: target=self._watch, name=f"proc-{self._process.pid}-watchdog", ) + self._watchdog.start() def stop(self) -> None: """Send SIGTERM, wait, escalate to SIGKILL if needed. 
Idempotent.""" - with self._stop_lock: - if self._stopped: + with self._stopped as stopped: + if stopped: return - self._stopped = True + self._stopped.set(True) if self._process is not None and self._process.poll() is None: logger.info("Stopping process", pid=self._process.pid) @@ -433,9 +398,8 @@ def _watch(self) -> None: stdout_t.join(timeout=2) stderr_t.join(timeout=2) - with self._stop_lock: - if self._stopped: - return + if self._stopped.get(): + return last_stdout = "\n".join(self.last_stdout) or None last_stderr = "\n".join(self.last_stderr) or None @@ -478,9 +442,6 @@ def _read_stream(self, stream: IO[bytes] | None, level: str) -> None: stream.close() -# safe_thread_map: parallel map that collects all results before raising - - def safe_thread_map( items: Sequence[T], fn: Callable[[T], R], From e5d739f166864d37400231f98a6b02005b4c94e2 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Thu, 26 Mar 2026 03:40:18 -0700 Subject: [PATCH 20/21] refactor: merge _stop into stop, add thread_start helper - Merge _stop() into stop() in ModuleBase (removes unnecessary indirection) - Update all callers of _stop() to use stop() directly - Add thread_start() convenience function that creates + starts a ModuleThread --- dimos/core/module.py | 3 -- dimos/core/test_core.py | 2 +- dimos/perception/detection/conftest.py | 10 +++--- .../perception/detection/reid/test_module.py | 2 +- dimos/robot/unitree/b1/test_connection.py | 18 +++++----- .../mujoco/direct_cmd_vel_explorer.py | 2 +- dimos/utils/thread_utils.py | 36 +++++++++++++++++++ 7 files changed, 53 insertions(+), 20 deletions(-) diff --git a/dimos/core/module.py b/dimos/core/module.py index 4fbffe07b9..1ad934b5ae 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -126,9 +126,6 @@ def start(self) -> None: @rpc def stop(self) -> None: - self._stop() - - def _stop(self) -> None: with self.mod_state as state: if state == "stopped": return diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 
aae167d8c6..858cc81849 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -89,7 +89,7 @@ def test_classmethods() -> None: ) assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" - nav._stop() + nav.stop() @pytest.mark.slow diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index 2040d687be..cb20e3d06a 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -221,7 +221,7 @@ def moment_provider(**kwargs) -> Moment2D: yield moment_provider moment_provider.cache_clear() - module._stop() + module.stop() @pytest.fixture(scope="session") @@ -256,7 +256,7 @@ def moment_provider(**kwargs) -> Moment3D: yield moment_provider moment_provider.cache_clear() if module is not None: - module._stop() + module.stop() @pytest.fixture(scope="session") @@ -290,9 +290,9 @@ def object_db_module(get_moment): yield moduleDB - module2d._stop() - module3d._stop() - moduleDB._stop() + module2d.stop() + module3d.stop() + moduleDB.stop() @pytest.fixture(scope="session") diff --git a/dimos/perception/detection/reid/test_module.py b/dimos/perception/detection/reid/test_module.py index 5fa0eead8d..89d752aae3 100644 --- a/dimos/perception/detection/reid/test_module.py +++ b/dimos/perception/detection/reid/test_module.py @@ -40,5 +40,5 @@ def test_reid_ingress(imageDetections2d) -> None: print("Processing detections through ReidModule...") reid_module.annotations._transport = LCMTransport("/annotations", ImageAnnotations) reid_module.ingress(imageDetections2d) - reid_module._stop() + reid_module.stop() print("✓ ReidModule ingress test completed successfully") diff --git a/dimos/robot/unitree/b1/test_connection.py b/dimos/robot/unitree/b1/test_connection.py index 8443e9e161..977554eb59 100644 --- a/dimos/robot/unitree/b1/test_connection.py +++ b/dimos/robot/unitree/b1/test_connection.py @@ -82,7 +82,7 @@ def test_watchdog_actually_zeros_commands(self) -> None: 
conn.watchdog_running = False conn.send_thread.join(timeout=1.0) conn.watchdog_thread.join(timeout=1.0) - conn._stop() + conn.stop() def test_watchdog_resets_on_new_command(self) -> None: """Test that watchdog timeout resets when new command arrives.""" @@ -114,7 +114,7 @@ def test_watchdog_resets_on_new_command(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=1.0) conn.watchdog_thread.join(timeout=1.0) - conn._stop() + conn.stop() def test_watchdog_thread_efficiency(self) -> None: """Test that watchdog uses only one thread regardless of command rate.""" @@ -148,7 +148,7 @@ def test_watchdog_thread_efficiency(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._stop() + conn.stop() def test_watchdog_with_send_loop_blocking(self) -> None: """Test that watchdog still works if send loop blocks.""" @@ -200,7 +200,7 @@ def blocking_send_loop() -> None: conn.watchdog_running = False conn.send_thread.join(timeout=1.0) conn.watchdog_thread.join(timeout=1.0) - conn._stop() + conn.stop() def test_continuous_commands_prevent_timeout(self) -> None: """Test that continuous commands prevent watchdog timeout.""" @@ -238,7 +238,7 @@ def test_continuous_commands_prevent_timeout(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=1.0) conn.watchdog_thread.join(timeout=1.0) - conn._stop() + conn.stop() def test_watchdog_timing_accuracy(self) -> None: """Test that watchdog zeros commands at approximately 200ms.""" @@ -280,7 +280,7 @@ def test_watchdog_timing_accuracy(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._stop() + conn.stop() def test_mode_changes_with_watchdog(self) -> None: """Test that mode changes work correctly with watchdog.""" @@ -323,7 +323,7 @@ def test_mode_changes_with_watchdog(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) 
conn.watchdog_thread.join(timeout=0.5) - conn._stop() + conn.stop() def test_watchdog_stops_movement_when_commands_stop(self) -> None: """Verify watchdog zeros commands when packets stop being sent.""" @@ -390,7 +390,7 @@ def test_watchdog_stops_movement_when_commands_stop(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=1.0) conn.watchdog_thread.join(timeout=1.0) - conn._stop() + conn.stop() def test_rapid_command_thread_safety(self) -> None: """Test thread safety with rapid commands from multiple threads.""" @@ -439,4 +439,4 @@ def send_commands(thread_id) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._stop() + conn.stop() diff --git a/dimos/simulation/mujoco/direct_cmd_vel_explorer.py b/dimos/simulation/mujoco/direct_cmd_vel_explorer.py index 58dc91f6b1..f81c644d0a 100644 --- a/dimos/simulation/mujoco/direct_cmd_vel_explorer.py +++ b/dimos/simulation/mujoco/direct_cmd_vel_explorer.py @@ -99,7 +99,7 @@ def _drive_to(self, target_x: float, target_y: float) -> None: None, Twist(linear=Vector3(linear, 0, 0), angular=Vector3(0, 0, angular)), ) - self._stop() + self.stop() def follow_points(self, waypoints: list[tuple[float, float]]) -> None: self._wait_for_pose() diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py index adf6751333..5b3f6648f2 100644 --- a/dimos/utils/thread_utils.py +++ b/dimos/utils/thread_utils.py @@ -509,3 +509,39 @@ def cleanup( raise ExceptionGroup("safe_thread_map failed", errors) return [outcomes[i] for i in range(len(items))] # type: ignore[misc] + + +def thread_start( + module: "ModuleBase[Any]", + *, + close_timeout: float = 2.0, + **thread_kwargs: Any, +) -> ModuleThread: + """Create a :class:`ModuleThread`, start it immediately, and return it. 
+ + Convenience wrapper equivalent to:: + + t = ModuleThread(module, close_timeout=close_timeout, **thread_kwargs) + t.start() + return t + + Accepts the same arguments as :class:`ModuleThread`. + + Example:: + + class MyModule(Module): + @rpc + def start(self) -> None: + self._worker = thread_start( + self, + target=self._run_loop, + name="my-worker", + ) + + def _run_loop(self) -> None: + while self._worker.status.get() == "running": + do_work() + """ + t = ModuleThread(module, close_timeout=close_timeout, **thread_kwargs) + t.start() + return t From 530680f5acb160a6bf53b5f87854ba5396360c04 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Thu, 26 Mar 2026 03:48:05 -0700 Subject: [PATCH 21/21] refactor: defer AsyncModuleThread loop creation to start() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AsyncModuleThread no longer spawns the event loop thread in __init__. The loop is created on the first call to start(), which ModuleBase.start() now calls. This means module construction no longer has side effects — no threads are spawned until the module is explicitly started. 
--- dimos/core/module.py | 1 + dimos/utils/test_thread_utils.py | 6 ++++++ dimos/utils/thread_utils.py | 18 ++++++++++++++---- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/dimos/core/module.py b/dimos/core/module.py index 1ad934b5ae..83dfeddcd4 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -123,6 +123,7 @@ def start(self) -> None: if state == "stopped": raise RuntimeError(f"{type(self).__name__} cannot be restarted after stop") self.mod_state.set("started") + self._async_thread.start() @rpc def stop(self) -> None: diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py index 973884ea66..f6b03c3542 100644 --- a/dimos/utils/test_thread_utils.py +++ b/dimos/utils/test_thread_utils.py @@ -385,6 +385,7 @@ class TestAsyncModuleThread: def test_creates_loop_and_thread(self) -> None: mod = FakeModule() amt = AsyncModuleThread(module=mod) + amt.start() assert amt.loop is not None assert amt.loop.is_running() assert amt.is_alive @@ -394,6 +395,7 @@ def test_creates_loop_and_thread(self) -> None: def test_stop_idempotent(self) -> None: mod = FakeModule() amt = AsyncModuleThread(module=mod) + amt.start() amt.stop() amt.stop() # should not raise amt.stop() @@ -401,6 +403,7 @@ def test_stop_idempotent(self) -> None: def test_dispose_stops_loop(self) -> None: mod = FakeModule() amt = AsyncModuleThread(module=mod) + amt.start() assert amt.is_alive mod.dispose() time.sleep(0.1) @@ -409,6 +412,7 @@ def test_dispose_stops_loop(self) -> None: def test_can_schedule_coroutine(self) -> None: mod = FakeModule() amt = AsyncModuleThread(module=mod) + amt.start() result = [] async def coro() -> None: @@ -423,6 +427,7 @@ def test_stop_with_pending_work(self) -> None: """Stop should succeed even with long-running tasks on the loop.""" mod = FakeModule() amt = AsyncModuleThread(module=mod) + amt.start() started = threading.Event() async def slow_coro() -> None: @@ -440,6 +445,7 @@ async def slow_coro() -> None: def 
test_concurrent_stop(self) -> None: mod = FakeModule() amt = AsyncModuleThread(module=mod) + amt.start() errors = [] def stop_it() -> None: diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py index 5b3f6648f2..e1767988c0 100644 --- a/dimos/utils/thread_utils.py +++ b/dimos/utils/thread_utils.py @@ -201,6 +201,18 @@ def __init__( self._stopped = ThreadSafeVal(False) self._owns_loop = False self._thread: threading.Thread | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._module_name = type(module).__name__ + + module._disposables.add(Disposable(self.stop)) + + def start(self) -> None: + """Create (or reuse) the event loop and start the driver thread. + + Safe to call multiple times — subsequent calls are no-ops. + """ + if self._loop is not None: + return try: self._loop = asyncio.get_running_loop() @@ -211,12 +223,10 @@ def __init__( self._thread = threading.Thread( target=self._loop.run_forever, daemon=True, - name=f"{type(module).__name__}-event-loop", + name=f"{self._module_name}-event-loop", ) self._thread.start() - module._disposables.add(Disposable(self.stop)) - @property def loop(self) -> asyncio.AbstractEventLoop: """The managed event loop.""" @@ -237,7 +247,7 @@ def stop(self) -> None: return self._stopped.set(True) - if self._owns_loop and self._loop.is_running(): + if self._owns_loop and self._loop is not None and self._loop.is_running(): self._loop.call_soon_threadsafe(self._loop.stop) if self._thread is not None and self._thread.is_alive():