diff --git a/README.md b/README.md index d044e47..72aa63b 100644 --- a/README.md +++ b/README.md @@ -53,5 +53,18 @@ For backward compatibility, discovery still supports legacy module-level 4. the first registered agent Registered greetings are generated after `ctx.connect()`, and advanced -`AgentSession` options can be passed per agent with `session_kwargs`. +`AgentSession` options can be passed either with `session_kwargs` or directly as +keyword arguments to `add()`. Direct keyword arguments take precedence when the +same option is provided both ways. + +```python +pool.add( + "restaurant", + RestaurantAgent, + greeting="Welcome to reservations.", + session_kwargs={"allow_interruptions": False}, + max_tool_steps=4, + preemptive_generation=True, +) +``` diff --git a/src/openrtc/pool.py b/src/openrtc/pool.py index 23e1328..391df8c 100644 --- a/src/openrtc/pool.py +++ b/src/openrtc/pool.py @@ -143,6 +143,7 @@ def add( tts: Any = None, greeting: str | None = None, session_kwargs: Mapping[str, Any] | None = None, + **session_options: Any, ) -> AgentConfig: """Register an agent in the pool. @@ -157,6 +158,10 @@ def add( Common examples include ``preemptive_generation``, ``allow_interruptions``, ``min_endpointing_delay``, ``max_endpointing_delay``, and ``max_tool_steps``. + **session_options: Additional ``AgentSession`` options passed + directly to ``add()``. When the same option appears in both + ``session_kwargs`` and direct keyword arguments, the direct + keyword argument takes precedence. Returns: The created agent configuration. @@ -180,7 +185,10 @@ def add( llm=self._resolve_provider(llm, self._default_llm), tts=self._resolve_provider(tts, self._default_tts), greeting=self._resolve_greeting(greeting), - session_kwargs=self._copy_session_kwargs(session_kwargs), + session_kwargs=self._merge_session_kwargs( + session_kwargs=session_kwargs, + direct_session_kwargs=session_options, + ), ) self._agents[normalized_name] = config logger.debug("Registered agent '%s'.", normalized_name) @@ -394,12 +402,17 @@ def _resolve_provider(self, value: Any, default_value: Any) -> Any: def _resolve_greeting(self, greeting: str | None) -> str | None: return self._default_greeting if greeting is None else greeting - def _copy_session_kwargs( - self, session_kwargs: Mapping[str, Any] | None + def _merge_session_kwargs( + self, + session_kwargs: Mapping[str, Any] | None, + direct_session_kwargs: Mapping[str, Any] | None = None, ) -> dict[str, Any]: - if session_kwargs is None: - return {} - return dict(session_kwargs) + merged_kwargs: dict[str, Any] = {} + if session_kwargs is not None: + merged_kwargs.update(session_kwargs) + if direct_session_kwargs is not None: + merged_kwargs.update(direct_session_kwargs) + return merged_kwargs def _resolve_discovery_metadata( self, diff --git a/tests/test_pool.py b/tests/test_pool.py index 308a5fe..50f4b56 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -58,6 +58,29 @@ def test_add_stores_session_kwargs_copy() -> None: } +def test_add_merges_direct_session_kwargs_with_mapping() -> None: + pool = AgentPool() + session_kwargs = { + "preemptive_generation": False, + "allow_interruptions": False, + } + + config = pool.add( + "test", + DemoAgent, + session_kwargs=session_kwargs, + preemptive_generation=True, + max_tool_steps=3, + ) + session_kwargs["allow_interruptions"] = True + + assert config.session_kwargs == { + "preemptive_generation": True, + "allow_interruptions": False, + "max_tool_steps": 3, + } + + def test_add_duplicate_name_raises() -> None: pool = AgentPool() pool.add("test", DemoAgent) diff --git a/tests/test_routing.py b/tests/test_routing.py index 14dba6f..debd1d8 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -194,6 +194,27 @@ def test_handle_session_passes_session_kwargs_and_provider_objects( assert session.kwargs["turn_detection"] is ctx.proc.userdata["turn_detection"] +def test_handle_session_supports_direct_session_kwargs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr("openrtc.pool.AgentSession", FakeSession) + pool = AgentPool() + pool.add( + "dental", + DentalAgent, + session_kwargs={"allow_interruptions": False}, + allow_interruptions=True, + max_tool_steps=6, + ) + ctx = FakeJobContext(job_metadata={"agent": "dental"}) + + asyncio.run(pool._handle_session(ctx)) + + session = FakeSession.instances[0] + assert session.kwargs["allow_interruptions"] is True + assert session.kwargs["max_tool_steps"] == 6 + + def test_handle_session_generates_greeting_after_connect( monkeypatch: pytest.MonkeyPatch, pool: AgentPool,