Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,16 @@ openrtc list \

For backward compatibility, discovery still supports legacy module-level
`AGENT_*` variables, but the decorator is the preferred pattern.

## Session routing and greetings

`AgentPool` resolves sessions in this order:

1. `ctx.job.metadata["agent"]` or `ctx.job.metadata["demo"]`
2. `ctx.room.metadata["agent"]` or `ctx.room.metadata["demo"]`
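3. room name prefix matching: a room named `restaurant-call-123` resolves to the registered agent `restaurant` (the room name must start with `<agent name>-`)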
4. the first registered agent

When an agent has a greeting, it is generated with `session.generate_reply()`
after `ctx.connect()` completes. Advanced `AgentSession` options can be passed
per agent with `session_kwargs`.

97 changes: 84 additions & 13 deletions src/openrtc/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import logging
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from types import ModuleType
from typing import Any, TypeVar
Expand All @@ -15,6 +15,7 @@

_AgentType = TypeVar("_AgentType", bound=type[Agent])
_AGENT_METADATA_ATTR = "__openrtc_agent_config__"
_METADATA_AGENT_KEYS = ("agent", "demo")


@dataclass(slots=True)
Expand All @@ -27,7 +28,8 @@ class AgentConfig:
stt: Speech-to-text provider string or provider instance.
llm: Large language model provider string or provider instance.
tts: Text-to-speech provider string or provider instance.
greeting: Optional initial greeting reserved for future use.
greeting: Optional initial greeting played after the session connects.
session_kwargs: Additional keyword arguments forwarded to ``AgentSession``.
"""

name: str
Expand All @@ -36,6 +38,7 @@ class AgentConfig:
llm: Any = None
tts: Any = None
greeting: str | None = None
session_kwargs: dict[str, Any] = field(default_factory=dict)


@dataclass(slots=True)
Expand Down Expand Up @@ -139,6 +142,7 @@ def add(
llm: Any = None,
tts: Any = None,
greeting: str | None = None,
session_kwargs: Mapping[str, Any] | None = None,
) -> AgentConfig:
"""Register an agent in the pool.

Expand All @@ -148,7 +152,11 @@ def add(
stt: STT provider string or instance.
llm: LLM provider string or instance.
tts: TTS provider string or instance.
greeting: Optional greeting reserved for later milestones.
greeting: Optional greeting played after the room connection completes.
session_kwargs: Extra keyword arguments forwarded to ``AgentSession``.
Common examples include ``preemptive_generation``,
``allow_interruptions``, ``min_endpointing_delay``,
``max_endpointing_delay``, and ``max_tool_steps``.

Returns:
The created agent configuration.
Expand All @@ -172,6 +180,7 @@ def add(
llm=self._resolve_provider(llm, self._default_llm),
tts=self._resolve_provider(tts, self._default_tts),
greeting=self._resolve_greeting(greeting),
session_kwargs=self._copy_session_kwargs(session_kwargs),
)
self._agents[normalized_name] = config
logger.debug("Registered agent '%s'.", normalized_name)
Expand Down Expand Up @@ -230,6 +239,42 @@ def list_agents(self) -> list[str]:
"""Return registered agent names in registration order."""
return list(self._agents)

def get(self, name: str) -> AgentConfig:
    """Look up the configuration registered under ``name``.

    Args:
        name: The registered agent name.

    Returns:
        The configuration stored for that agent.

    Raises:
        KeyError: If no agent is registered under ``name``.
    """
    try:
        config = self._agents[name]
    except KeyError as lookup_error:
        # Re-raise with a readable message; the bare key alone is not helpful.
        message = f"Unknown agent '{name}'."
        raise KeyError(message) from lookup_error
    return config

def remove(self, name: str) -> AgentConfig:
    """Unregister an agent and hand back its configuration.

    Args:
        name: The registered agent name.

    Returns:
        The configuration that was registered under ``name``.

    Raises:
        KeyError: If no agent is registered under ``name``.
    """
    try:
        config = self._agents.pop(name)
    except KeyError as lookup_error:
        message = f"Unknown agent '{name}'."
        raise KeyError(message) from lookup_error
    else:
        # Only reached when the pop succeeded.
        logger.debug("Removed agent '%s'.", name)
        return config

def run(self) -> None:
"""Run the LiveKit worker for the registered agents.

Expand Down Expand Up @@ -274,10 +319,19 @@ def _resolve_agent(self, ctx: JobContext) -> AgentConfig:
if selected_name is not None:
return self._get_registered_agent(selected_name, source="room metadata")

room_name = getattr(ctx.room, "name", None)
if isinstance(room_name, str):
for agent_name, config in self._agents.items():
if room_name.startswith(f"{agent_name}-"):
logger.info(
"Resolved agent '%s' via room name prefix from room '%s'.",
agent_name,
room_name,
)
return config

default_agent = next(iter(self._agents.values()))
logger.debug(
"No routing metadata found; defaulting to agent '%s'.", default_agent.name
)
logger.info("Resolved agent '%s' via default fallback.", default_agent.name)
return default_agent

async def _handle_session(self, ctx: JobContext) -> None:
Expand All @@ -289,17 +343,21 @@ async def _handle_session(self, ctx: JobContext) -> None:
tts=config.tts,
vad=ctx.proc.userdata["vad"],
turn_detection=ctx.proc.userdata["turn_detection"],
**config.session_kwargs,
)

await session.start(agent=config.agent_cls(), room=ctx.room)
await ctx.connect()

if config.greeting is not None:
logger.debug("Generating greeting for agent '%s'.", config.name)
await session.generate_reply(instructions=config.greeting)

def _agent_name_from_metadata(self, metadata: Any) -> str | None:
if metadata is None:
return None
if isinstance(metadata, Mapping):
value = metadata.get("agent")
return value.strip() if isinstance(value, str) and value.strip() else None
return self._agent_name_from_mapping(metadata)
if isinstance(metadata, str):
stripped = metadata.strip()
if not stripped:
Expand All @@ -310,18 +368,24 @@ def _agent_name_from_metadata(self, metadata: Any) -> str | None:
logger.debug("Ignoring non-JSON metadata: %s", stripped)
return None
if isinstance(decoded, Mapping):
value = decoded.get("agent")
return (
value.strip() if isinstance(value, str) and value.strip() else None
)
return self._agent_name_from_mapping(decoded)
return None

def _agent_name_from_mapping(self, metadata: Mapping[str, Any]) -> str | None:
    """Return the first usable agent name found in a metadata mapping.

    The keys listed in ``_METADATA_AGENT_KEYS`` are tried in order. A value
    counts only if it is a string that is non-empty after stripping
    surrounding whitespace.

    Args:
        metadata: Decoded job or room metadata.

    Returns:
        The stripped agent name, or ``None`` when no key holds a usable value.
    """
    # Lazy generator: each key is only looked up if the previous ones failed.
    candidates = (metadata.get(key) for key in _METADATA_AGENT_KEYS)
    for candidate in candidates:
        if not isinstance(candidate, str):
            continue
        cleaned = candidate.strip()
        if cleaned:
            return cleaned
    return None

def _get_registered_agent(self, name: str, *, source: str) -> AgentConfig:
try:
config = self._agents[name]
except KeyError as exc:
raise ValueError(f"Unknown agent '{name}' requested via {source}.") from exc
logger.debug("Resolved agent '%s' via %s.", name, source)
logger.info("Resolved agent '%s' via %s.", name, source)
return config

def _resolve_provider(self, value: Any, default_value: Any) -> Any:
Expand All @@ -330,6 +394,13 @@ def _resolve_provider(self, value: Any, default_value: Any) -> Any:
def _resolve_greeting(self, greeting: str | None) -> str | None:
return self._default_greeting if greeting is None else greeting

def _copy_session_kwargs(
self, session_kwargs: Mapping[str, Any] | None
) -> dict[str, Any]:
if session_kwargs is None:
return {}
return dict(session_kwargs)

def _resolve_discovery_metadata(
self,
module: ModuleType,
Expand Down
47 changes: 47 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,22 @@ def test_add_uses_pool_defaults_when_agent_values_are_omitted() -> None:
assert config.greeting == "Hello from OpenRTC."


def test_add_stores_session_kwargs_copy() -> None:
    """``add`` snapshots ``session_kwargs``; later caller mutations must not leak in."""
    expected = {
        "preemptive_generation": True,
        "min_endpointing_delay": 0.5,
    }
    caller_kwargs = dict(expected)
    registry = AgentPool()

    stored = registry.add("test", DemoAgent, session_kwargs=caller_kwargs)
    # Mutate the caller's dict after registration.
    caller_kwargs["preemptive_generation"] = False

    assert stored.session_kwargs == expected


def test_add_duplicate_name_raises() -> None:
pool = AgentPool()
pool.add("test", DemoAgent)
Expand All @@ -66,6 +82,37 @@ def test_list_agents_returns_registration_order() -> None:
assert pool.list_agents() == ["restaurant", "dental"]


def test_get_returns_registered_agent() -> None:
    """``get`` returns the very object that ``add`` produced."""
    registry = AgentPool()
    registered = registry.add("restaurant", DemoAgent)

    looked_up = registry.get("restaurant")

    assert looked_up is registered


def test_get_unknown_agent_raises_key_error() -> None:
    """Looking up an unregistered name raises ``KeyError`` naming the agent."""
    empty_pool = AgentPool()
    expected_message = "Unknown agent 'missing'"

    with pytest.raises(KeyError, match=expected_message):
        empty_pool.get("missing")


def test_remove_returns_removed_agent() -> None:
    """``remove`` returns the registered config object and empties the pool."""
    registry = AgentPool()
    registered = registry.add("restaurant", DemoAgent)

    popped = registry.remove("restaurant")
    remaining = registry.list_agents()

    assert popped is registered
    assert remaining == []


def test_remove_unknown_agent_raises_key_error() -> None:
    """Removing an unregistered name raises ``KeyError`` naming the agent."""
    empty_pool = AgentPool()
    expected_message = "Unknown agent 'missing'"

    with pytest.raises(KeyError, match=expected_message):
        empty_pool.remove("missing")


def test_run_without_agents_raises() -> None:
pool = AgentPool()

Expand Down
Loading
Loading