Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions dimos/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def __getattr__(name: str) -> object:
if name == "Dimos":
from dimos.porcelain.dimos import Dimos

return Dimos
raise AttributeError(f"module 'dimos' has no attribute {name!r}")
2 changes: 1 addition & 1 deletion dimos/agents/mcp/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _start_server(self, port: int | None = None) -> None:
from dimos.core.global_config import global_config

_port = port if port is not None else global_config.mcp_port
_host = global_config.mcp_host
_host = global_config.listen_host
config = uvicorn.Config(app, host=_host, port=_port, log_level="info")
server = uvicorn.Server(config)
self._uvicorn_server = server
Expand Down
80 changes: 69 additions & 11 deletions dimos/core/coordination/module_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import threading
from typing import TYPE_CHECKING, Any, cast

from dimos.core.coordination.rpyc_server import RpycServer
from dimos.core.coordination.worker_manager import WorkerManager
from dimos.core.coordination.worker_manager_docker import WorkerManagerDocker
from dimos.core.coordination.worker_manager_python import WorkerManagerPython
Expand Down Expand Up @@ -52,16 +53,16 @@ def __init__(
) -> None:
self._global_config = g
manager_types: list[type[WorkerManager]] = [WorkerManagerDocker, WorkerManagerPython]
self._managers: dict[str, WorkerManager] = {
cls.deployment_identifier: cls(g=g) for cls in manager_types
}
self._managers = {cls.deployment_identifier: cls(g=g) for cls in manager_types}
self._deployed_modules = {}
self._deployed_atoms: dict[type[ModuleBase], BlueprintAtom] = {}
self._resolved_module_refs: dict[tuple[type[ModuleBase], str], type[ModuleBase]] = {}
self._transport_registry: dict[tuple[str, type], PubSubTransport[Any]] = {}
self._class_aliases: dict[type[ModuleBase], type[ModuleBase]] = {}
self._module_transports: dict[type[ModuleBase], dict[str, PubSubTransport[Any]]] = {}
self._started = False
self._modules_lock = threading.RLock()
self._rpyc = RpycServer(self)
Comment thread
paul-nechifor marked this conversation as resolved.

def start(self) -> None:
from dimos.core.o3dpickle import register_picklers
Expand All @@ -72,6 +73,8 @@ def start(self) -> None:
self._started = True

def stop(self) -> None:
self._rpyc.stop()

for module_class, module in reversed(self._deployed_modules.items()):
logger.info("Stopping module...", module=module_class.__name__)
try:
Expand All @@ -88,6 +91,26 @@ def _stop_manager(m: WorkerManager) -> None:

safe_thread_map(tuple(self._managers.values()), _stop_manager)

def start_rpyc_service(self) -> int:
return self._rpyc.start()

def list_module_names(self) -> list[str]:
with self._modules_lock:
return [cls.__name__ for cls in self._deployed_modules]

def get_module_endpoint(self, class_name: str) -> tuple[str, int, int]:
"""Return (host, worker_rpyc_port, module_id) for the given class name.

Lazily starts the worker-side RPyC server on first use.
"""
with self._modules_lock:
for cls, proxy in self._deployed_modules.items():
if cls.__name__ == class_name:
actor = cast("ModuleProxy", proxy).actor_instance
port = actor.start_rpyc()
return ("localhost", int(port), int(actor._module_id))
raise KeyError(class_name)

def health_check(self) -> bool:
return all(m.health_check() for m in self._managers.values())

Expand All @@ -111,7 +134,8 @@ def deploy(
deployed_module = self._managers[module_class.deployment].deploy(
module_class, global_config, kwargs
)
self._deployed_modules[module_class] = deployed_module
with self._modules_lock:
self._deployed_modules[module_class] = deployed_module
return deployed_module # type: ignore[return-value]

def deploy_parallel(
Expand Down Expand Up @@ -142,13 +166,14 @@ def _deploy_group(dep: str) -> None:
self.stop()
raise

self._deployed_modules.update(
{
cls: mod
for (cls, _, _), mod in zip(module_specs, results, strict=True)
if mod is not None
}
)
with self._modules_lock:
self._deployed_modules.update(
{
cls: mod
for (cls, _, _), mod in zip(module_specs, results, strict=True)
if mod is not None
}
)
return results

def build_all_modules(self) -> None:
Expand Down Expand Up @@ -264,6 +289,14 @@ def load_blueprint(
if not self._started:
raise RuntimeError("ModuleCoordinator not started; call start() first")

with self._modules_lock:
self._load_blueprint(blueprint, blueprint_args)

def _load_blueprint(
self,
blueprint: Blueprint,
blueprint_args: MutableMapping[str, Mapping[str, Any]] | None = None,
) -> None:
# Apply config overrides.
self._global_config.update(**dict(blueprint.global_config_overrides))
blueprint_args = blueprint_args or {}
Expand Down Expand Up @@ -320,6 +353,10 @@ def unload_module(self, module_class: type[ModuleBase]) -> None:
callers that expect the module to come back (e.g. ``restart_module``)
are responsible for rewiring.
"""
with self._modules_lock:
self._unload_module(module_class)

def _unload_module(self, module_class: type[ModuleBase]) -> None:
module_class = self._resolve_class(module_class)
if module_class not in self._deployed_modules:
raise ValueError(f"{module_class.__name__} is not deployed")
Expand Down Expand Up @@ -361,6 +398,18 @@ def unload_module(self, module_class: type[ModuleBase]) -> None:
if key[0] is not module_class and target is not module_class
}

def restart_module_by_class_name(
self,
class_name: str,
*,
reload_source: bool = True,
) -> ModuleProxyProtocol:
with self._modules_lock:
for cls in self._deployed_modules:
if cls.__name__ == class_name:
return self._restart_module(cls, reload_source=reload_source)
raise ValueError(f"No deployed module with class name {class_name!r}")

def restart_module(
self,
module_class: type[ModuleBase],
Expand All @@ -375,6 +424,15 @@ def restart_module(
transports, and re-injects the new proxy into every other module that
held a reference to it.
"""
with self._modules_lock:
return self._restart_module(module_class, reload_source=reload_source)

def _restart_module(
self,
module_class: type[ModuleBase],
*,
reload_source: bool = True,
) -> ModuleProxyProtocol:
module_class = self._resolve_class(module_class)
if module_class not in self._deployed_modules:
raise ValueError(f"{module_class.__name__} is not deployed")
Expand Down
123 changes: 82 additions & 41 deletions dimos/core/coordination/python_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations

from dataclasses import dataclass
import logging
import multiprocessing
from multiprocessing.connection import Connection
Expand All @@ -22,12 +23,16 @@
import traceback
from typing import TYPE_CHECKING, Any

from rpyc.utils.server import ThreadedServer

from dimos.core.coordination.rpyc_services import WorkerRpycService
from dimos.core.coordination.worker_messages import (
CallMethodRequest,
DeployModuleRequest,
GetAttrRequest,
SetRefRequest,
ShutdownRequest,
StartRpycRequest,
SuppressConsoleRequest,
UndeployModuleRequest,
WorkerRequest,
Expand Down Expand Up @@ -128,6 +133,10 @@ def set_ref(self, ref: Any) -> ActorFuture:
result = self._send_request_to_worker(SetRefRequest(module_id=self._module_id, ref=ref))
return ActorFuture(result)

def start_rpyc(self) -> int:
port: int = self._send_request_to_worker(StartRpycRequest())
return port

def __getattr__(self, name: str) -> Any:
"""Proxy attribute access to the worker process."""
if name.startswith("_"):
Expand Down Expand Up @@ -317,18 +326,27 @@ def _suppress_console_output() -> None:
]


@dataclass
class _WorkerState:
instances: dict[int, Any]
worker_id: int
rpyc_server: ThreadedServer | None = None
rpyc_thread: threading.Thread | None = None
should_stop: bool = False


def _worker_entrypoint(conn: Connection, worker_id: int) -> None:
apply_library_config()
instances: dict[int, Any] = {}
state = _WorkerState(instances={}, worker_id=worker_id)

try:
_worker_loop(conn, instances, worker_id)
_worker_loop(conn, state)
except KeyboardInterrupt:
logger.info("Worker got KeyboardInterrupt.", worker_id=worker_id)
except Exception as e:
logger.error(f"Worker process error: {e}", exc_info=True)
finally:
for module_id, instance in reversed(list(instances.items())):
for module_id, instance in reversed(list(state.instances.items())):
try:
logger.info(
"Worker stopping module...",
Expand All @@ -353,7 +371,63 @@ def _worker_entrypoint(conn: Connection, worker_id: int) -> None:
logger.error("Error during worker shutdown", exc_info=True)


def _worker_loop(conn: Connection, instances: dict[int, Any], worker_id: int) -> None:
def _handle_request(request: Any, state: _WorkerState) -> WorkerResponse:
match request:
case DeployModuleRequest(module_id=module_id, module_class=module_class, kwargs=kwargs):
state.instances[module_id] = module_class(**kwargs)
return WorkerResponse(result=module_id)

case SetRefRequest(module_id=module_id, ref=ref):
state.instances[module_id].ref = ref
return WorkerResponse(result=state.worker_id)

case GetAttrRequest(module_id=module_id, name=name):
return WorkerResponse(result=getattr(state.instances[module_id], name))

case CallMethodRequest(module_id=module_id, name=name, args=args, kwargs=kwargs):
method = getattr(state.instances[module_id], name)
return WorkerResponse(result=method(*args, **kwargs))

case UndeployModuleRequest(module_id=module_id):
instance = state.instances.pop(module_id, None)
if instance is not None:
instance.stop()
return WorkerResponse(result=True)

case SuppressConsoleRequest():
_suppress_console_output()
return WorkerResponse(result=True)

case StartRpycRequest():
if state.rpyc_server is not None:
return WorkerResponse(result=state.rpyc_server.port)
WorkerRpycService._instances = state.instances
state.rpyc_server = ThreadedServer(
Comment thread
paul-nechifor marked this conversation as resolved.
WorkerRpycService,
port=0,
hostname=global_config.listen_host,
protocol_config={
"allow_all_attrs": True,
"allow_public_attrs": True,
},
)
state.rpyc_thread = threading.Thread(target=state.rpyc_server.start, daemon=True)
state.rpyc_thread.start()
return WorkerResponse(result=state.rpyc_server.port)

case ShutdownRequest():
if state.rpyc_server is not None:
state.rpyc_server.close()
if state.rpyc_thread is not None:
state.rpyc_thread.join(timeout=5)
state.should_stop = True
return WorkerResponse(result=True)

case _:
return WorkerResponse(error=f"Unknown request type: {type(request)}")


def _worker_loop(conn: Connection, state: _WorkerState) -> None:
while True:
try:
if not conn.poll(timeout=0.1):
Expand All @@ -362,44 +436,8 @@ def _worker_loop(conn: Connection, instances: dict[int, Any], worker_id: int) ->
except (EOFError, KeyboardInterrupt):
break

response: WorkerResponse
try:
match request:
case DeployModuleRequest(
module_id=module_id, module_class=module_class, kwargs=kwargs
):
instance = module_class(**kwargs)
instances[module_id] = instance
response = WorkerResponse(result=module_id)

case SetRefRequest(module_id=module_id, ref=ref):
instances[module_id].ref = ref
response = WorkerResponse(result=worker_id)

case GetAttrRequest(module_id=module_id, name=name):
response = WorkerResponse(result=getattr(instances[module_id], name))

case CallMethodRequest(module_id=module_id, name=name, args=args, kwargs=kwargs):
method = getattr(instances[module_id], name)
response = WorkerResponse(result=method(*args, **kwargs))

case UndeployModuleRequest(module_id=module_id):
instance = instances.pop(module_id, None)
if instance is not None:
instance.stop()
response = WorkerResponse(result=True)

case SuppressConsoleRequest():
_suppress_console_output()
response = WorkerResponse(result=True)

case ShutdownRequest():
conn.send(WorkerResponse(result=True))
break

case _:
response = WorkerResponse(error=f"Unknown request type: {type(request)}")

response = _handle_request(request, state)
except Exception as e:
response = WorkerResponse(
error=f"{e.__class__.__name__}: {e}\n{traceback.format_exc()}"
Expand All @@ -409,3 +447,6 @@ def _worker_loop(conn: Connection, instances: dict[int, Any], worker_id: int) ->
conn.send(response)
except (BrokenPipeError, EOFError):
break

if state.should_stop:
break
Loading
Loading