Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ Use discovery when you want one agent module per file. OpenRTC will import each
module, find a local `Agent` subclass, and optionally read overrides from the
`@agent_config(...)` decorator.

Discovered agents are safe to run under `livekit dev`, including spawn-based
worker runtimes such as macOS. For direct `add()` registration, define agent
classes at module scope so worker processes can reload them.

```python
from pathlib import Path

Expand Down
2 changes: 2 additions & 0 deletions docs/api/pool.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ Registers a named LiveKit `Agent` subclass.
- `name` must be a non-empty string after trimming whitespace
- names must be unique
- `agent_cls` must be a subclass of `livekit.agents.Agent`
- `agent_cls` must be defined at module scope for spawn-based worker runtimes

### Session options

Expand Down Expand Up @@ -134,6 +135,7 @@ Discovery behavior:
- uses `@agent_config(...)` metadata when present
- otherwise uses the filename stem as the agent name
- falls back to pool defaults for omitted provider and greeting fields
- preserves file-backed agent loading so discovered agents work with `livekit dev`

### Raises

Expand Down
6 changes: 4 additions & 2 deletions examples/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

from pathlib import Path

from dotenv import load_dotenv

from openrtc import AgentPool

load_dotenv()


def main() -> None:
pool = AgentPool(
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ dev = [
"pre-commit>=4.5.1",
"pytest>=9.0.2",
"pytest-asyncio>=1.2.0",
"python-dotenv>=1.2.2",
"ruff>=0.15.6",
]

Expand Down
144 changes: 134 additions & 10 deletions src/openrtc/pool.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import importlib
import importlib.util
import inspect
Comment thread
mahimairaja marked this conversation as resolved.
import json
import logging
import sys
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from functools import partial
from hashlib import sha1
from pathlib import Path
from types import ModuleType
from typing import Any, TypeVar
Expand All @@ -27,6 +30,15 @@ class _PoolRuntimeState:
agents: dict[str, AgentConfig]


@dataclass(frozen=True, slots=True)
class _AgentClassRef:
"""Serializable reference to an agent class."""

module_name: str
qualname: str
module_path: str | None = None


def _prewarm_worker(
runtime_state: _PoolRuntimeState,
proc: JobProcess,
Expand Down Expand Up @@ -85,6 +97,31 @@ class AgentConfig:
tts: Any = None
greeting: str | None = None
session_kwargs: dict[str, Any] = field(default_factory=dict)
_agent_ref: _AgentClassRef = field(init=False, repr=False, compare=False)

def __post_init__(self) -> None:
self._agent_ref = _build_agent_class_ref(self.agent_cls)

def __getstate__(self) -> dict[str, Any]:
return {
"name": self.name,
"stt": self.stt,
"llm": self.llm,
"tts": self.tts,
"greeting": self.greeting,
"session_kwargs": dict(self.session_kwargs),
"agent_ref": self._agent_ref,
}

def __setstate__(self, state: Mapping[str, Any]) -> None:
self.name = state["name"]
self.stt = state["stt"]
self.llm = state["llm"]
self.tts = state["tts"]
self.greeting = state["greeting"]
self.session_kwargs = dict(state["session_kwargs"])
self._agent_ref = state["agent_ref"]
self.agent_cls = _resolve_agent_class(self._agent_ref)


@dataclass(slots=True)
Expand Down Expand Up @@ -400,21 +437,13 @@ def _resolve_discovery_metadata(
return AgentDiscoveryConfig()

def _load_agent_module(self, module_path: Path) -> ModuleType:
module_name = f"openrtc_discovered_{module_path.stem}"
spec = importlib.util.spec_from_file_location(module_name, module_path)
if spec is None or spec.loader is None:
raise RuntimeError(f"Could not create import spec for {module_path}.")

module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
module_name = _discovered_module_name(module_path)
try:
spec.loader.exec_module(module)
return _load_module_from_path(module_name, module_path)
except Exception as exc:
sys.modules.pop(module_name, None)
raise RuntimeError(
f"Failed to import agent module '{module_path.name}': {exc}"
) from exc
return module

def _find_local_agent_subclass(self, module: ModuleType) -> type[Agent]:
for value in vars(module).values():
Expand Down Expand Up @@ -520,6 +549,101 @@ def _get_registered_agent(
return config


def _build_agent_class_ref(agent_cls: type[Agent]) -> _AgentClassRef:
module_name = agent_cls.__module__
qualname = agent_cls.__qualname__
if "<locals>" in qualname:
raise ValueError(
"agent_cls must be defined at module scope so spawned workers can "
"reload it safely."
)

module_path = _try_get_module_path(agent_cls)
if module_name == "__main__" and module_path is None:
raise ValueError(
"agent_cls defined in __main__ must come from a real Python file so "
"spawned workers can reload it."
)

return _AgentClassRef(
module_name=module_name,
qualname=qualname,
module_path=None if module_path is None else str(module_path),
)


def _resolve_agent_class(agent_ref: _AgentClassRef) -> type[Agent]:
module: ModuleType | None = None
module_path = (
None if agent_ref.module_path is None else Path(agent_ref.module_path).resolve()
)

if module_path is not None and agent_ref.module_name.startswith(
"openrtc_discovered_"
):
module = _load_module_from_path(agent_ref.module_name, module_path)
else:
try:
module = importlib.import_module(agent_ref.module_name)
except ModuleNotFoundError:
if module_path is None:
raise
module = _load_module_from_path(agent_ref.module_name, module_path)

agent_cls = _resolve_qualname(module, agent_ref.qualname)
if not isinstance(agent_cls, type) or not issubclass(agent_cls, Agent):
raise TypeError(
f"{agent_ref.qualname!r} in module {module.__name__!r} is not a "
"livekit.agents.Agent subclass."
)
return agent_cls


def _resolve_qualname(module: ModuleType, qualname: str) -> Any:
value: Any = module
for part in qualname.split("."):
value = getattr(value, part)
return value


def _try_get_module_path(agent_cls: type[Agent]) -> Path | None:
try:
source_path = inspect.getsourcefile(agent_cls)
except (OSError, TypeError):
source_path = None
if source_path is None:
return None
return Path(source_path).resolve()


def _discovered_module_name(module_path: Path) -> str:
resolved_path = module_path.resolve()
digest = sha1(str(resolved_path).encode("utf-8")).hexdigest()[:12]
return f"openrtc_discovered_{resolved_path.stem}_{digest}"


def _load_module_from_path(module_name: str, module_path: Path) -> ModuleType:
resolved_path = module_path.resolve()
existing_module = sys.modules.get(module_name)
if existing_module is not None:
existing_file = getattr(existing_module, "__file__", None)
if existing_file is not None and Path(existing_file).resolve() == resolved_path:
return existing_module

spec = importlib.util.spec_from_file_location(module_name, resolved_path)
if spec is None or spec.loader is None:
raise RuntimeError(f"Could not create import spec for {resolved_path}.")

module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
try:
spec.loader.exec_module(module)
except Exception:
sys.modules.pop(module_name, None)
raise
return module


def _load_shared_runtime_dependencies() -> tuple[Any, type[Any]]:
"""Load the optional LiveKit runtime dependencies used during prewarm."""
try:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_discovery.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pickle
import sys
from pathlib import Path

import pytest
Expand Down Expand Up @@ -159,3 +161,27 @@ def test_discover_ignores_imported_agent_subclasses(tmp_path: Path) -> None:

assert [config.name for config in discovered] == ["local"]
assert discovered[0].agent_cls.__name__ == "LocalAgent"


def test_discovered_agent_config_is_pickleable_across_module_reload(
tmp_path: Path,
) -> None:
_write_agent_module(
tmp_path,
"dental.py",
class_name="DentalAgent",
decorator='@agent_config(name="dental")\n',
)

pool = AgentPool()
discovered = pool.discover(tmp_path)

config = discovered[0]
module_name = config.agent_cls.__module__
sys.modules.pop(module_name, None)

restored = pickle.loads(pickle.dumps(config))

assert restored.name == "dental"
assert restored.agent_cls.__name__ == "DentalAgent"
assert restored.agent_cls.__module__ == module_name
Loading
Loading