Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,18 @@ For backward compatibility, discovery still supports legacy module-level
4. the first registered agent

Registered greetings are generated after `ctx.connect()`, and advanced
`AgentSession` options can be passed per agent with `session_kwargs`.
`AgentSession` options can be passed either with `session_kwargs` or directly as
keyword arguments to `add()`. Direct keyword arguments take precedence when the
same option is provided both ways.

```python
pool.add(
"restaurant",
RestaurantAgent,
greeting="Welcome to reservations.",
session_kwargs={"allow_interruptions": False},
max_tool_steps=4,
preemptive_generation=True,
)
```

25 changes: 19 additions & 6 deletions src/openrtc/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def add(
tts: Any = None,
greeting: str | None = None,
session_kwargs: Mapping[str, Any] | None = None,
**session_options: Any,
) -> AgentConfig:
"""Register an agent in the pool.

Expand All @@ -157,6 +158,10 @@ def add(
Common examples include ``preemptive_generation``,
``allow_interruptions``, ``min_endpointing_delay``,
``max_endpointing_delay``, and ``max_tool_steps``.
**session_options: Additional ``AgentSession`` options passed
directly to ``add()``. When the same option appears in both
``session_kwargs`` and direct keyword arguments, the direct
keyword argument takes precedence.

Returns:
The created agent configuration.
Expand All @@ -180,7 +185,10 @@ def add(
llm=self._resolve_provider(llm, self._default_llm),
tts=self._resolve_provider(tts, self._default_tts),
greeting=self._resolve_greeting(greeting),
session_kwargs=self._copy_session_kwargs(session_kwargs),
session_kwargs=self._merge_session_kwargs(
session_kwargs=session_kwargs,
direct_session_kwargs=session_options,
),
)
self._agents[normalized_name] = config
logger.debug("Registered agent '%s'.", normalized_name)
Expand Down Expand Up @@ -394,12 +402,17 @@ def _resolve_provider(self, value: Any, default_value: Any) -> Any:
def _resolve_greeting(self, greeting: str | None) -> str | None:
return self._default_greeting if greeting is None else greeting

def _copy_session_kwargs(
self, session_kwargs: Mapping[str, Any] | None
def _merge_session_kwargs(
self,
session_kwargs: Mapping[str, Any] | None,
direct_session_kwargs: Mapping[str, Any] | None = None,
) -> dict[str, Any]:
if session_kwargs is None:
return {}
return dict(session_kwargs)
merged_kwargs: dict[str, Any] = {}
if session_kwargs is not None:
merged_kwargs.update(session_kwargs)
if direct_session_kwargs is not None:
merged_kwargs.update(direct_session_kwargs)
return merged_kwargs

def _resolve_discovery_metadata(
self,
Expand Down
23 changes: 23 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,29 @@ def test_add_stores_session_kwargs_copy() -> None:
}


def test_add_merges_direct_session_kwargs_with_mapping() -> None:
pool = AgentPool()
session_kwargs = {
"preemptive_generation": False,
"allow_interruptions": False,
}

config = pool.add(
"test",
DemoAgent,
session_kwargs=session_kwargs,
preemptive_generation=True,
max_tool_steps=3,
)
session_kwargs["allow_interruptions"] = True

assert config.session_kwargs == {
"preemptive_generation": True,
"allow_interruptions": False,
"max_tool_steps": 3,
}


def test_add_duplicate_name_raises() -> None:
pool = AgentPool()
pool.add("test", DemoAgent)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,27 @@ def test_handle_session_passes_session_kwargs_and_provider_objects(
assert session.kwargs["turn_detection"] is ctx.proc.userdata["turn_detection"]


def test_handle_session_supports_direct_session_kwargs(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr("openrtc.pool.AgentSession", FakeSession)
pool = AgentPool()
pool.add(
"dental",
DentalAgent,
session_kwargs={"allow_interruptions": False},
allow_interruptions=True,
max_tool_steps=6,
)
ctx = FakeJobContext(job_metadata={"agent": "dental"})

asyncio.run(pool._handle_session(ctx))

session = FakeSession.instances[0]
assert session.kwargs["allow_interruptions"] is True
assert session.kwargs["max_tool_steps"] == 6


def test_handle_session_generates_greeting_after_connect(
monkeypatch: pytest.MonkeyPatch,
pool: AgentPool,
Expand Down
Loading