Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions examples/realtime/app/agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

from agents import function_tool
from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX
from agents.realtime import RealtimeAgent, realtime_handoff
Expand All @@ -13,20 +15,26 @@
name_override="faq_lookup_tool", description_override="Lookup frequently asked questions."
)
async def faq_lookup_tool(question: str) -> str:
if "bag" in question or "baggage" in question:
print("faq_lookup_tool called with question:", question)

# Simulate a slow API call
await asyncio.sleep(3)

q = question.lower()
if "wifi" in q or "wi-fi" in q:
return "We have free wifi on the plane, join Airline-Wifi"
elif "bag" in q or "baggage" in q:
return (
"You are allowed to bring one bag on the plane. "
"It must be under 50 pounds and 22 inches x 14 inches x 9 inches."
)
elif "seats" in question or "plane" in question:
elif "seats" in q or "plane" in q:
return (
"There are 120 seats on the plane. "
"There are 22 business class seats and 98 economy seats. "
"Exit rows are rows 4 and 16. "
"Rows 5-8 are Economy Plus, with extra legroom. "
)
elif "wifi" in question:
return "We have free wifi on the plane, join Airline-Wifi"
return "I'm sorry, I don't know the answer to that question."


Expand Down
3 changes: 3 additions & 0 deletions examples/realtime/app/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ async def connect(self, websocket: WebSocket, session_id: str):

agent = get_starting_agent()
runner = RealtimeRunner(agent)
# If you want to customize the runner behavior, you can pass options:
# runner_config = RealtimeRunConfig(async_tool_calls=False)
# runner = RealtimeRunner(agent, config=runner_config)
model_config: RealtimeModelConfig = {
"initial_model_settings": {
"turn_detection": {
Expand Down
3 changes: 3 additions & 0 deletions src/agents/realtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ class RealtimeRunConfig(TypedDict):
tracing_disabled: NotRequired[bool]
"""Whether tracing is disabled for this run."""

async_tool_calls: NotRequired[bool]
"""Whether function tool calls should run asynchronously. Defaults to True."""

# TODO (rm) Add history audio storage config


Expand Down
67 changes: 59 additions & 8 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
}
self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue()
self._closed = False
self._stored_exception: Exception | None = None
self._stored_exception: BaseException | None = None

# Guardrails state tracking
self._interrupted_response_ids: set[str] = set()
Expand All @@ -123,6 +123,8 @@ def __init__(
)

self._guardrail_tasks: set[asyncio.Task[Any]] = set()
self._tool_call_tasks: set[asyncio.Task[Any]] = set()
self._async_tool_calls: bool = bool(self._run_config.get("async_tool_calls", True))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I mentioned in the PR description, the default value is true. If we'd like to make this opt-in, we can change this part for it, and switch the default behavior to true in future minor/major versions. But I personally think setting true should be beneficial for everyone.


@property
def model(self) -> RealtimeModel:
Expand Down Expand Up @@ -216,7 +218,11 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
if event.type == "error":
await self._put_event(RealtimeError(info=self._event_info, error=event.error))
elif event.type == "function_call":
await self._handle_tool_call(event)
agent_snapshot = self._current_agent
if self._async_tool_calls:
self._enqueue_tool_call_task(event, agent_snapshot)
else:
await self._handle_tool_call(event, agent_snapshot=agent_snapshot)
elif event.type == "audio":
await self._put_event(
RealtimeAudio(
Expand Down Expand Up @@ -384,11 +390,17 @@ async def _put_event(self, event: RealtimeSessionEvent) -> None:
"""Put an event into the queue."""
await self._event_queue.put(event)

async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
async def _handle_tool_call(
self,
event: RealtimeModelToolCallEvent,
*,
agent_snapshot: RealtimeAgent | None = None,
) -> None:
"""Handle a tool call event."""
agent = agent_snapshot or self._current_agent
tools, handoffs = await asyncio.gather(
self._current_agent.get_all_tools(self._context_wrapper),
self._get_handoffs(self._current_agent, self._context_wrapper),
agent.get_all_tools(self._context_wrapper),
self._get_handoffs(agent, self._context_wrapper),
)
function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
Expand All @@ -398,7 +410,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
RealtimeToolStart(
info=self._event_info,
tool=function_map[event.name],
agent=self._current_agent,
agent=agent,
)
)

Expand All @@ -423,7 +435,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
info=self._event_info,
tool=func_tool,
output=result,
agent=self._current_agent,
agent=agent,
)
)
elif event.name in handoff_map:
Expand All @@ -444,7 +456,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
)

# Store previous agent for event
previous_agent = self._current_agent
previous_agent = agent

# Update current agent
self._current_agent = result
Expand Down Expand Up @@ -752,10 +764,49 @@ def _cleanup_guardrail_tasks(self) -> None:
task.cancel()
self._guardrail_tasks.clear()

def _enqueue_tool_call_task(
self, event: RealtimeModelToolCallEvent, agent_snapshot: RealtimeAgent
) -> None:
"""Run tool calls in the background to avoid blocking realtime transport."""
task = asyncio.create_task(self._handle_tool_call(event, agent_snapshot=agent_snapshot))
self._tool_call_tasks.add(task)
task.add_done_callback(self._on_tool_call_task_done)

def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
self._tool_call_tasks.discard(task)

if task.cancelled():
return

exception = task.exception()
if exception is None:
return

logger.exception("Realtime tool call task failed", exc_info=exception)

if self._stored_exception is None:
self._stored_exception = exception

asyncio.create_task(
self._put_event(
RealtimeError(
info=self._event_info,
error={"message": f"Tool call task failed: {exception}"},
)
)
)

def _cleanup_tool_call_tasks(self) -> None:
for task in self._tool_call_tasks:
if not task.done():
task.cancel()
self._tool_call_tasks.clear()

async def _cleanup(self) -> None:
"""Clean up all resources and mark session as closed."""
# Cancel and cleanup guardrail tasks
self._cleanup_guardrail_tasks()
self._cleanup_tool_call_tasks()

# Remove ourselves as a listener
self._model.remove_listener(self)
Expand Down
43 changes: 40 additions & 3 deletions tests/realtime/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,8 +561,13 @@ async def test_ignored_events_only_generate_raw_events(self, mock_model, mock_ag

@pytest.mark.asyncio
async def test_function_call_event_triggers_tool_handling(self, mock_model, mock_agent):
"""Test that function_call events trigger tool call handling"""
session = RealtimeSession(mock_model, mock_agent, None)
"""Test that function_call events trigger tool call handling synchronously when disabled"""
session = RealtimeSession(
mock_model,
mock_agent,
None,
run_config={"async_tool_calls": False},
)

# Create function call event
function_call_event = RealtimeModelToolCallEvent(
Expand All @@ -578,14 +583,46 @@ async def test_function_call_event_triggers_tool_handling(self, mock_model, mock
await session.on_event(function_call_event)

# Should have called the tool handler
handle_tool_call_mock.assert_called_once_with(function_call_event)
handle_tool_call_mock.assert_called_once_with(
function_call_event, agent_snapshot=mock_agent
)

# Should still have raw event
assert session._event_queue.qsize() == 1
raw_event = await session._event_queue.get()
assert isinstance(raw_event, RealtimeRawModelEvent)
assert raw_event.data == function_call_event

@pytest.mark.asyncio
async def test_function_call_event_runs_async_by_default(self, mock_model, mock_agent):
"""Function call handling should be scheduled asynchronously by default"""
session = RealtimeSession(mock_model, mock_agent, None)

function_call_event = RealtimeModelToolCallEvent(
name="test_function",
call_id="call_async",
arguments='{"param": "value"}',
)

with pytest.MonkeyPatch().context() as m:
handle_tool_call_mock = AsyncMock()
m.setattr(session, "_handle_tool_call", handle_tool_call_mock)

await session.on_event(function_call_event)

# Let the background task run
await asyncio.sleep(0)

handle_tool_call_mock.assert_awaited_once_with(
function_call_event, agent_snapshot=mock_agent
)

# Raw event still enqueued
assert session._event_queue.qsize() == 1
raw_event = await session._event_queue.get()
assert isinstance(raw_event, RealtimeRawModelEvent)
assert raw_event.data == function_call_event


class TestHistoryManagement:
"""Test suite for history management and audio transcription in
Expand Down