Skip to content

Commit 94d2239

Browse files
committed
improve the concurrency of event handling
1 parent 7f1672a commit 94d2239

File tree

2 files changed

+117
-6
lines changed

2 files changed

+117
-6
lines changed

src/agents/agent.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -457,12 +457,12 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any:
457457
conversation_id=conversation_id,
458458
session=session,
459459
)
460-
async for event in run_result.stream_events():
461-
payload: AgentToolStreamEvent = {
462-
"event": event,
463-
"agent": self,
464-
"tool_call": getattr(context, "tool_call", None),
465-
}
460+
# Dispatch callbacks in the background so slow handlers do not block
461+
# event consumption.
462+
event_queue: asyncio.Queue[AgentToolStreamEvent | None] = asyncio.Queue()
463+
464+
async def _run_handler(payload: AgentToolStreamEvent) -> None:
465+
"""Execute the user callback while capturing exceptions."""
466466
try:
467467
maybe_result = on_stream(payload)
468468
if inspect.isawaitable(maybe_result):
@@ -472,6 +472,34 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any:
472472
"Error while handling on_stream event for agent tool %s.",
473473
self.name,
474474
)
475+
476+
async def dispatch_stream_events() -> None:
477+
while True:
478+
payload = await event_queue.get()
479+
is_sentinel = payload is None # None marks the end of the stream.
480+
try:
481+
if payload is not None:
482+
await _run_handler(payload)
483+
finally:
484+
event_queue.task_done()
485+
486+
if is_sentinel:
487+
break
488+
489+
dispatch_task = asyncio.create_task(dispatch_stream_events())
490+
491+
try:
492+
async for event in run_result.stream_events():
493+
payload: AgentToolStreamEvent = {
494+
"event": event,
495+
"agent": self,
496+
"tool_call": getattr(context, "tool_call", None),
497+
}
498+
await event_queue.put(payload)
499+
finally:
500+
await event_queue.put(None)
501+
await event_queue.join()
502+
await dispatch_task
475503
else:
476504
run_result = await Runner.run(
477505
starting_agent=self,

tests/test_agent_as_tool.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from typing import Any, cast
45

56
import pytest
@@ -612,6 +613,88 @@ def sync_handler(event: AgentToolStreamEvent) -> None:
612613
assert calls == ["raw_response_event"]
613614

614615

616+
@pytest.mark.asyncio
617+
async def test_agent_as_tool_streaming_dispatches_without_blocking(
618+
monkeypatch: pytest.MonkeyPatch,
619+
) -> None:
620+
"""on_stream handlers should not block streaming iteration."""
621+
agent = Agent(name="nonblocking_agent")
622+
623+
first_handler_started = asyncio.Event()
624+
allow_handler_to_continue = asyncio.Event()
625+
second_event_yielded = asyncio.Event()
626+
second_event_handled = asyncio.Event()
627+
628+
first_event = RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))
629+
second_event = RawResponsesStreamEvent(
630+
data=cast(Any, {"type": "output_text_delta", "delta": "hi"})
631+
)
632+
633+
class DummyStreamingResult:
634+
def __init__(self) -> None:
635+
self.final_output = "ok"
636+
637+
async def stream_events(self):
638+
yield first_event
639+
second_event_yielded.set()
640+
yield second_event
641+
642+
dummy_result = DummyStreamingResult()
643+
644+
monkeypatch.setattr(Runner, "run_streamed", classmethod(lambda *args, **kwargs: dummy_result))
645+
monkeypatch.setattr(
646+
Runner,
647+
"run",
648+
classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))),
649+
)
650+
651+
async def on_stream(payload: AgentToolStreamEvent) -> None:
652+
if payload["event"] is first_event:
653+
first_handler_started.set()
654+
await allow_handler_to_continue.wait()
655+
else:
656+
second_event_handled.set()
657+
658+
tool_call = ResponseFunctionToolCall(
659+
id="call_nonblocking",
660+
arguments='{"input": "go"}',
661+
call_id="call-nonblocking",
662+
name="nonblocking_tool",
663+
type="function_call",
664+
)
665+
666+
tool = cast(
667+
FunctionTool,
668+
agent.as_tool(
669+
tool_name="nonblocking_tool",
670+
tool_description="Uses non-blocking streaming handler",
671+
on_stream=on_stream,
672+
),
673+
)
674+
tool_context = ToolContext(
675+
context=None,
676+
tool_name="nonblocking_tool",
677+
tool_call_id=tool_call.call_id,
678+
tool_arguments=tool_call.arguments,
679+
tool_call=tool_call,
680+
)
681+
682+
async def _invoke_tool() -> Any:
683+
return await tool.on_invoke_tool(tool_context, '{"input": "go"}')
684+
685+
invoke_task: asyncio.Task[Any] = asyncio.create_task(_invoke_tool())
686+
687+
await asyncio.wait_for(first_handler_started.wait(), timeout=1.0)
688+
await asyncio.wait_for(second_event_yielded.wait(), timeout=1.0)
689+
assert invoke_task.done() is False
690+
691+
allow_handler_to_continue.set()
692+
await asyncio.wait_for(second_event_handled.wait(), timeout=1.0)
693+
output = await asyncio.wait_for(invoke_task, timeout=1.0)
694+
695+
assert output == "ok"
696+
697+
615698
@pytest.mark.asyncio
616699
async def test_agent_as_tool_streaming_handler_exception_does_not_fail_call(
617700
monkeypatch: pytest.MonkeyPatch,

0 commit comments

Comments
 (0)