diff --git a/examples/realtime/app/agent.py b/examples/realtime/app/agent.py
index 81d8db7c1..ee906dbb8 100644
--- a/examples/realtime/app/agent.py
+++ b/examples/realtime/app/agent.py
@@ -1,3 +1,5 @@
+import asyncio
+
 from agents import function_tool
 from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX
 from agents.realtime import RealtimeAgent, realtime_handoff
@@ -13,20 +15,26 @@
     name_override="faq_lookup_tool", description_override="Lookup frequently asked questions."
 )
 async def faq_lookup_tool(question: str) -> str:
-    if "bag" in question or "baggage" in question:
+    print("faq_lookup_tool called with question:", question)
+
+    # Simulate a slow API call
+    await asyncio.sleep(3)
+
+    q = question.lower()
+    if "wifi" in q or "wi-fi" in q:
+        return "We have free wifi on the plane, join Airline-Wifi"
+    elif "bag" in q or "baggage" in q:
         return (
             "You are allowed to bring one bag on the plane. "
             "It must be under 50 pounds and 22 inches x 14 inches x 9 inches."
         )
-    elif "seats" in question or "plane" in question:
+    elif "seats" in q or "plane" in q:
         return (
             "There are 120 seats on the plane. "
             "There are 22 business class seats and 98 economy seats. "
             "Exit rows are rows 4 and 16. "
             "Rows 5-8 are Economy Plus, with extra legroom. "
         )
-    elif "wifi" in question:
-        return "We have free wifi on the plane, join Airline-Wifi"
     return "I'm sorry, I don't know the answer to that question."
 
 
diff --git a/examples/realtime/app/server.py b/examples/realtime/app/server.py
index 3a3d58674..6082fe8d2 100644
--- a/examples/realtime/app/server.py
+++ b/examples/realtime/app/server.py
@@ -47,6 +47,9 @@ async def connect(self, websocket: WebSocket, session_id: str):
 
         agent = get_starting_agent()
         runner = RealtimeRunner(agent)
+        # If you want to customize the runner behavior, you can pass options:
+        # runner_config = RealtimeRunConfig(async_tool_calls=False)
+        # runner = RealtimeRunner(agent, config=runner_config)
         model_config: RealtimeModelConfig = {
             "initial_model_settings": {
                 "turn_detection": {
diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py
index ddbf48bab..9b6712a28 100644
--- a/src/agents/realtime/config.py
+++ b/src/agents/realtime/config.py
@@ -184,6 +184,10 @@ class RealtimeRunConfig(TypedDict):
     tracing_disabled: NotRequired[bool]
     """Whether tracing is disabled for this run."""
 
+    async_tool_calls: NotRequired[bool]
+    """Whether function tool calls run as background tasks instead of being awaited
+    inline by the session's event handler. Defaults to True."""
+
     # TODO (rm) Add history audio storage config
 
 
diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py
index e10b48e53..6378382e1 100644
--- a/src/agents/realtime/session.py
+++ b/src/agents/realtime/session.py
@@ -112,7 +112,7 @@ def __init__(
         }
         self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue()
         self._closed = False
-        self._stored_exception: Exception | None = None
+        self._stored_exception: BaseException | None = None
 
         # Guardrails state tracking
         self._interrupted_response_ids: set[str] = set()
@@ -123,6 +123,8 @@ def __init__(
         )
 
         self._guardrail_tasks: set[asyncio.Task[Any]] = set()
+        self._tool_call_tasks: set[asyncio.Task[Any]] = set()
+        self._async_tool_calls: bool = bool(self._run_config.get("async_tool_calls", True))
 
     @property
     def model(self) -> RealtimeModel:
@@ -216,7 +218,11 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
         if event.type == "error":
             await self._put_event(RealtimeError(info=self._event_info, error=event.error))
         elif event.type == "function_call":
-            await self._handle_tool_call(event)
+            agent_snapshot = self._current_agent
+            if self._async_tool_calls:
+                self._enqueue_tool_call_task(event, agent_snapshot)
+            else:
+                await self._handle_tool_call(event, agent_snapshot=agent_snapshot)
         elif event.type == "audio":
             await self._put_event(
                 RealtimeAudio(
@@ -384,11 +390,18 @@ async def _put_event(self, event: RealtimeSessionEvent) -> None:
         """Put an event into the queue."""
         await self._event_queue.put(event)
 
-    async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
+    async def _handle_tool_call(
+        self,
+        event: RealtimeModelToolCallEvent,
+        *,
+        agent_snapshot: RealtimeAgent | None = None,
+    ) -> None:
         """Handle a tool call event."""
+        # Resolve the agent once so a concurrent handoff cannot change it mid-call.
+        agent = agent_snapshot if agent_snapshot is not None else self._current_agent
         tools, handoffs = await asyncio.gather(
-            self._current_agent.get_all_tools(self._context_wrapper),
-            self._get_handoffs(self._current_agent, self._context_wrapper),
+            agent.get_all_tools(self._context_wrapper),
+            self._get_handoffs(agent, self._context_wrapper),
         )
         function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
         handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
@@ -398,7 +411,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
                 RealtimeToolStart(
                     info=self._event_info,
                     tool=function_map[event.name],
-                    agent=self._current_agent,
+                    agent=agent,
                 )
             )
 
@@ -423,7 +436,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
                     info=self._event_info,
                     tool=func_tool,
                     output=result,
-                    agent=self._current_agent,
+                    agent=agent,
                 )
             )
         elif event.name in handoff_map:
@@ -444,7 +457,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
             )
 
             # Store previous agent for event
-            previous_agent = self._current_agent
+            previous_agent = agent
 
             # Update current agent
             self._current_agent = result
@@ -752,10 +765,52 @@ def _cleanup_guardrail_tasks(self) -> None:
                 task.cancel()
         self._guardrail_tasks.clear()
 
+    def _enqueue_tool_call_task(
+        self, event: RealtimeModelToolCallEvent, agent_snapshot: RealtimeAgent
+    ) -> None:
+        """Run tool calls in the background to avoid blocking realtime transport."""
+        task = asyncio.create_task(self._handle_tool_call(event, agent_snapshot=agent_snapshot))
+        # Keep a strong reference until the task finishes; the done callback drops it.
+        self._tool_call_tasks.add(task)
+        task.add_done_callback(self._on_tool_call_task_done)
+
+    def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
+        """Untrack a finished tool call task and surface its failure, if any."""
+        self._tool_call_tasks.discard(task)
+
+        if task.cancelled():
+            return
+
+        exception = task.exception()
+        if exception is None:
+            return
+
+        logger.exception("Realtime tool call task failed", exc_info=exception)
+
+        if self._stored_exception is None:
+            self._stored_exception = exception
+
+        # Enqueue directly instead of spawning an untracked task: the event queue is
+        # unbounded, so put_nowait cannot raise QueueFull, and no task is left dangling.
+        self._event_queue.put_nowait(
+            RealtimeError(
+                info=self._event_info,
+                error={"message": f"Tool call task failed: {exception}"},
+            )
+        )
+
+    def _cleanup_tool_call_tasks(self) -> None:
+        """Cancel any tool call tasks that are still running and forget them all."""
+        for task in self._tool_call_tasks:
+            if not task.done():
+                task.cancel()
+        self._tool_call_tasks.clear()
+
     async def _cleanup(self) -> None:
         """Clean up all resources and mark session as closed."""
         # Cancel and cleanup guardrail tasks
         self._cleanup_guardrail_tasks()
+        self._cleanup_tool_call_tasks()
 
         # Remove ourselves as a listener
         self._model.remove_listener(self)
diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py
index 7ffb6d981..8280de61c 100644
--- a/tests/realtime/test_session.py
+++ b/tests/realtime/test_session.py
@@ -561,8 +561,13 @@ async def test_ignored_events_only_generate_raw_events(self, mock_model, mock_ag
 
     @pytest.mark.asyncio
     async def test_function_call_event_triggers_tool_handling(self, mock_model, mock_agent):
-        """Test that function_call events trigger tool call handling"""
-        session = RealtimeSession(mock_model, mock_agent, None)
+        """Test that function_call events are handled inline when async_tool_calls is disabled"""
+        session = RealtimeSession(
+            mock_model,
+            mock_agent,
+            None,
+            run_config={"async_tool_calls": False},
+        )
 
         # Create function call event
         function_call_event = RealtimeModelToolCallEvent(
@@ -578,7 +583,9 @@ async def test_function_call_event_triggers_tool_handling(self, mock_model, mock
             await session.on_event(function_call_event)
 
             # Should have called the tool handler
-            handle_tool_call_mock.assert_called_once_with(function_call_event)
+            handle_tool_call_mock.assert_called_once_with(
+                function_call_event, agent_snapshot=mock_agent
+            )
 
             # Should still have raw event
             assert session._event_queue.qsize() == 1
@@ -586,6 +593,36 @@ async def test_function_call_event_triggers_tool_handling(self, mock_model, mock
             assert isinstance(raw_event, RealtimeRawModelEvent)
             assert raw_event.data == function_call_event
 
+    @pytest.mark.asyncio
+    async def test_function_call_event_runs_async_by_default(self, mock_model, mock_agent):
+        """Function call handling should be scheduled asynchronously by default"""
+        session = RealtimeSession(mock_model, mock_agent, None)
+
+        function_call_event = RealtimeModelToolCallEvent(
+            name="test_function",
+            call_id="call_async",
+            arguments='{"param": "value"}',
+        )
+
+        with pytest.MonkeyPatch().context() as m:
+            handle_tool_call_mock = AsyncMock()
+            m.setattr(session, "_handle_tool_call", handle_tool_call_mock)
+
+            await session.on_event(function_call_event)
+
+            # Let the background task run
+            await asyncio.sleep(0)
+
+            handle_tool_call_mock.assert_awaited_once_with(
+                function_call_event, agent_snapshot=mock_agent
+            )
+
+            # Raw event still enqueued
+            assert session._event_queue.qsize() == 1
+            raw_event = await session._event_queue.get()
+            assert isinstance(raw_event, RealtimeRawModelEvent)
+            assert raw_event.data == function_call_event
+
 
 class TestHistoryManagement:
     """Test suite for history management and audio transcription in