Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions dimos/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def __getattr__(name: str) -> object:
if name == "Dimos":
from dimos.porcelain import Dimos

return Dimos
raise AttributeError(f"module 'dimos' has no attribute {name!r}")
135 changes: 98 additions & 37 deletions dimos/core/coordination/python_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@
import traceback
from typing import TYPE_CHECKING, Any

import rpyc
from rpyc.utils.server import ThreadedServer

from dimos.core.coordination.worker_messages import (
CallMethodRequest,
DeployModuleRequest,
GetAttrRequest,
SetRefRequest,
ShutdownRequest,
StartRpycRequest,
SuppressConsoleRequest,
UndeployModuleRequest,
WorkerRequest,
Expand Down Expand Up @@ -128,6 +132,11 @@ def set_ref(self, ref: Any) -> ActorFuture:
result = self._send_request_to_worker(SetRefRequest(module_id=self._module_id, ref=ref))
return ActorFuture(result)

def start_rpyc(self) -> int:
"""Start an RPyC server in the worker process and return its port."""
port: int = self._send_request_to_worker(StartRpycRequest())
return port

def __getattr__(self, name: str) -> Any:
"""Proxy attribute access to the worker process."""
if name.startswith("_"):
Expand Down Expand Up @@ -300,6 +309,36 @@ def shutdown(self) -> None:
self._process = None


class _WorkerRpycService(rpyc.Service): # type: ignore[misc]
_instances: dict[int, Any] = {}

def on_connect(self, conn: Any) -> None:
conn._config.update(
{
"allow_all_attrs": True,
"allow_public_attrs": True,
"allow_setattr": True,
}
)

def exposed_get_module(self, module_id: int) -> Any:
return self._instances[module_id]


def _start_rpyc_server(instances: dict[int, Any]) -> ThreadedServer:
_WorkerRpycService._instances = instances
server = ThreadedServer(
_WorkerRpycService,
port=0,
protocol_config={
"allow_all_attrs": True,
"allow_public_attrs": True,
},
)
threading.Thread(target=server.start, daemon=True).start()
return server


def _suppress_console_output() -> None:
"""Redirect stdout/stderr to /dev/null and strip console handlers."""
devnull = open(os.devnull, "w")
Expand Down Expand Up @@ -353,7 +392,58 @@ def _worker_entrypoint(conn: Connection, worker_id: int) -> None:
logger.error("Error during worker shutdown", exc_info=True)


def _handle_request(
request: Any,
instances: dict[int, Any],
worker_id: int,
rpyc_server: ThreadedServer | None,
) -> tuple[WorkerResponse, ThreadedServer | None, bool]:
match request:
case DeployModuleRequest(module_id=module_id, module_class=module_class, kwargs=kwargs):
instances[module_id] = module_class(**kwargs)
return WorkerResponse(result=module_id), rpyc_server, False

case SetRefRequest(module_id=module_id, ref=ref):
instances[module_id].ref = ref
return WorkerResponse(result=worker_id), rpyc_server, False

case GetAttrRequest(module_id=module_id, name=name):
return WorkerResponse(result=getattr(instances[module_id], name)), rpyc_server, False

case CallMethodRequest(module_id=module_id, name=name, args=args, kwargs=kwargs):
method = getattr(instances[module_id], name)
return WorkerResponse(result=method(*args, **kwargs)), rpyc_server, False

case UndeployModuleRequest(module_id=module_id):
instance = instances.pop(module_id, None)
if instance is not None:
instance.stop()
return WorkerResponse(result=True), rpyc_server, False

case SuppressConsoleRequest():
_suppress_console_output()
return WorkerResponse(result=True), rpyc_server, False

case StartRpycRequest():
server = rpyc_server or _start_rpyc_server(instances)
return WorkerResponse(result=server.port), server, False

case ShutdownRequest():
if rpyc_server is not None:
rpyc_server.close()
return WorkerResponse(result=True), rpyc_server, True

case _:
return (
WorkerResponse(error=f"Unknown request type: {type(request)}"),
rpyc_server,
False,
)


def _worker_loop(conn: Connection, instances: dict[int, Any], worker_id: int) -> None:
rpyc_server: ThreadedServer | None = None

while True:
try:
if not conn.poll(timeout=0.1):
Expand All @@ -362,44 +452,12 @@ def _worker_loop(conn: Connection, instances: dict[int, Any], worker_id: int) ->
except (EOFError, KeyboardInterrupt):
break

response: WorkerResponse
try:
match request:
case DeployModuleRequest(
module_id=module_id, module_class=module_class, kwargs=kwargs
):
instance = module_class(**kwargs)
instances[module_id] = instance
response = WorkerResponse(result=module_id)

case SetRefRequest(module_id=module_id, ref=ref):
instances[module_id].ref = ref
response = WorkerResponse(result=worker_id)

case GetAttrRequest(module_id=module_id, name=name):
response = WorkerResponse(result=getattr(instances[module_id], name))

case CallMethodRequest(module_id=module_id, name=name, args=args, kwargs=kwargs):
method = getattr(instances[module_id], name)
response = WorkerResponse(result=method(*args, **kwargs))

case UndeployModuleRequest(module_id=module_id):
instance = instances.pop(module_id, None)
if instance is not None:
instance.stop()
response = WorkerResponse(result=True)

case SuppressConsoleRequest():
_suppress_console_output()
response = WorkerResponse(result=True)

case ShutdownRequest():
conn.send(WorkerResponse(result=True))
break

case _:
response = WorkerResponse(error=f"Unknown request type: {type(request)}")
should_stop = False

try:
response, rpyc_server, should_stop = _handle_request(
request, instances, worker_id, rpyc_server
)
except Exception as e:
response = WorkerResponse(
error=f"{e.__class__.__name__}: {e}\n{traceback.format_exc()}"
Expand All @@ -409,3 +467,6 @@ def _worker_loop(conn: Connection, instances: dict[int, Any], worker_id: int) ->
conn.send(response)
except (BrokenPipeError, EOFError):
break

if should_stop:
break
6 changes: 6 additions & 0 deletions dimos/core/coordination/worker_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ class SuppressConsoleRequest:
pass


@dataclass(frozen=True)
class StartRpycRequest:
pass


@dataclass(frozen=True)
class ShutdownRequest:
pass
Expand All @@ -69,6 +74,7 @@ class ShutdownRequest:
| CallMethodRequest
| UndeployModuleRequest
| SuppressConsoleRequest
| StartRpycRequest
| ShutdownRequest
)

Expand Down
Loading
Loading