@function_tool(
    name_override="billing_status_checker",
    description_override="Answer questions about customer billing status.",
)
def billing_status_checker(customer_id: str | None = None, question: str = "") -> str:
    """Return a canned billing answer or a fallback when the question is unrelated.

    Args:
        customer_id: Optional customer identifier echoed back in the answer.
        question: The user's question; only billing-related questions are answered.

    Returns:
        A canned dollar amount for billing questions, otherwise a polite refusal.
    """
    normalized = question.lower()
    # "billing" contains "bill", so a single substring check covers both words
    # (the original `or "billing" in normalized` disjunct was redundant).
    if "bill" in normalized:
        return f"This customer (ID: {customer_id})'s bill is $100"
    return "I can only answer questions about billing."
def handle_stream(event: AgentToolStreamEvent) -> None:
    """Log every streaming event surfaced by the nested billing agent run."""
    nested = event["event"]
    who = event["agent_name"]
    print(f"[stream] agent={who} type={nested.type} {nested}")


async def main() -> None:
    """Run a support agent that delegates billing questions to a nested agent tool."""
    # Trace the whole interaction so the nested run appears under one workflow.
    with trace("Agents as tools streaming example"):
        billing = Agent(
            name="Billing Agent",
            instructions="You are a billing agent that answers billing questions.",
            model_settings=ModelSettings(tool_choice="required"),
            tools=[billing_status_checker],
        )

        # Exposing the billing agent as a tool with `on_stream` forwards its
        # nested streaming events to `handle_stream` as they happen.
        billing_tool = billing.as_tool(
            tool_name="billing_agent",
            tool_description="You are a billing agent that answers billing questions.",
            on_stream=handle_stream,
        )

        support = Agent(
            name="Customer Support Agent",
            instructions=(
                "You are a customer support agent. Always call the billing agent to answer billing "
                "questions and return the billing agent response to the user."
            ),
            tools=[billing_tool],
        )

        run = await Runner.run(
            support,
            "Hello, my customer ID is ABC123. How much is my bill for this month?",
        )

    print(f"\nFinal response:\n{run.final_output}")


if __name__ == "__main__":
    asyncio.run(main())
class AgentToolStreamEvent(TypedDict):
    """Payload delivered to an `on_stream` callback while an agent runs as a tool.

    Wraps each event from the nested agent's streamed run together with enough
    context to attribute it to the agent and the originating tool call.
    """

    event: StreamEvent
    """The raw streaming event produced by the nested agent run."""

    agent_name: str
    """Name of the nested agent that emitted the event."""

    tool_call_id: str | None
    """ID of the tool call that triggered the nested run, when known."""
hooks: RunHooks[TContext] | None = None, @@ -409,6 +426,8 @@ def as_tool( is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run context and agent and returns whether the tool is enabled. Disabled tools are hidden from the LLM at runtime. + on_stream: Optional callback (sync or async) to receive streaming events from the nested + agent run. When provided, the nested agent is executed in streaming mode. """ @function_tool( @@ -420,22 +439,51 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any: from .run import DEFAULT_MAX_TURNS, Runner resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS - - output = await Runner.run( - starting_agent=self, - input=input, - context=context.context, - run_config=run_config, - max_turns=resolved_max_turns, - hooks=hooks, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - session=session, - ) + run_result: RunResult | RunResultStreaming + + if on_stream is not None: + run_result = Runner.run_streamed( + starting_agent=self, + input=input, + context=context.context, + run_config=run_config, + max_turns=resolved_max_turns, + hooks=hooks, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) + async for event in run_result.stream_events(): + payload: AgentToolStreamEvent = { + "event": event, + "agent_name": self.name, + "tool_call_id": getattr(context, "tool_call_id", None), + } + try: + maybe_result = on_stream(payload) + if inspect.isawaitable(maybe_result): + await maybe_result + except Exception: + logger.exception( + "Error while handling on_stream event for agent tool %s.", + self.name, + ) + else: + run_result = await Runner.run( + starting_agent=self, + input=input, + context=context.context, + run_config=run_config, + max_turns=resolved_max_turns, + hooks=hooks, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) if 
@pytest.mark.asyncio
async def test_agent_as_tool_streams_events_with_on_stream(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With `on_stream` set, the nested agent runs via `Runner.run_streamed` and
    every stream event is forwarded to the callback with agent/tool metadata."""
    agent = Agent(name="streamer")
    stream_events = [
        RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})),
        RawResponsesStreamEvent(data=cast(Any, {"type": "output_text_delta", "delta": "hi"})),
    ]

    class DummyStreamingResult:
        def __init__(self) -> None:
            self.final_output = "streamed output"

        async def stream_events(self):
            for ev in stream_events:
                yield ev

    run_calls: list[dict[str, Any]] = []

    def fake_run_streamed(
        cls,
        starting_agent,
        input,
        *,
        context,
        max_turns,
        hooks,
        run_config,
        previous_response_id,
        auto_previous_response_id=False,
        conversation_id,
        session,
    ):
        # Record the forwarded kwargs so we can assert the tool passed them through.
        run_calls.append(
            {
                "starting_agent": starting_agent,
                "input": input,
                "context": context,
                "max_turns": max_turns,
                "hooks": hooks,
                "run_config": run_config,
                "previous_response_id": previous_response_id,
                "conversation_id": conversation_id,
                "session": session,
            }
        )
        return DummyStreamingResult()

    async def unexpected_run(*args: Any, **kwargs: Any) -> None:
        raise AssertionError("Runner.run should not be called when on_stream is provided.")

    monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed))
    monkeypatch.setattr(Runner, "run", classmethod(unexpected_run))

    received_events: list[AgentToolStreamEvent] = []

    async def on_stream(payload: AgentToolStreamEvent) -> None:
        received_events.append(payload)

    tool = cast(
        FunctionTool,
        agent.as_tool(
            tool_name="stream_tool",
            tool_description="Streams events",
            on_stream=on_stream,
        ),
    )

    tool_context = ToolContext(
        context=None,
        tool_name="stream_tool",
        tool_call_id="call-123",
        tool_arguments='{"input": "run streaming"}',
    )
    output = await tool.on_invoke_tool(tool_context, '{"input": "run streaming"}')

    assert output == "streamed output"
    assert len(received_events) == len(stream_events)
    assert received_events[0]["agent_name"] == "streamer"
    assert received_events[0]["tool_call_id"] == "call-123"
    assert received_events[0]["event"] == stream_events[0]
    assert run_calls[0]["input"] == "run streaming"


@pytest.mark.asyncio
async def test_agent_as_tool_streaming_works_with_custom_extractor(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """`custom_output_extractor` receives the streaming run result and its return
    value becomes the tool output, while events still reach `on_stream`."""
    agent = Agent(name="streamer")
    stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))]

    class DummyStreamingResult:
        def __init__(self) -> None:
            self.final_output = "raw output"

        async def stream_events(self):
            for ev in stream_events:
                yield ev

    streamed_instance = DummyStreamingResult()

    def fake_run_streamed(
        cls,
        starting_agent,
        input,
        *,
        context,
        max_turns,
        hooks,
        run_config,
        previous_response_id,
        auto_previous_response_id=False,
        conversation_id,
        session,
    ):
        return streamed_instance

    async def unexpected_run(*args: Any, **kwargs: Any) -> None:
        raise AssertionError("Runner.run should not be called when on_stream is provided.")

    monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed))
    monkeypatch.setattr(Runner, "run", classmethod(unexpected_run))

    received: list[Any] = []

    async def extractor(result) -> str:
        received.append(result)
        return "custom value"

    callbacks: list[Any] = []

    async def on_stream(payload: AgentToolStreamEvent) -> None:
        callbacks.append(payload["event"])

    tool = cast(
        FunctionTool,
        agent.as_tool(
            tool_name="stream_tool",
            tool_description="Streams events",
            custom_output_extractor=extractor,
            on_stream=on_stream,
        ),
    )

    tool_context = ToolContext(
        context=None,
        tool_name="stream_tool",
        tool_call_id="call-abc",
        tool_arguments='{"input": "stream please"}',
    )
    output = await tool.on_invoke_tool(tool_context, '{"input": "stream please"}')

    assert output == "custom value"
    assert received == [streamed_instance]
    assert callbacks == stream_events
@pytest.mark.asyncio
async def test_agent_as_tool_streaming_handler_exception_does_not_fail_call(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An exception raised inside `on_stream` is logged and swallowed; the tool
    call itself still completes and returns the nested run's final output."""
    agent = Agent(name="handler_error_agent")

    class DummyStreamingResult:
        def __init__(self) -> None:
            self.final_output = "ok"

        async def stream_events(self):
            yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))

    # Force the streaming path; the non-streaming path must never be taken.
    monkeypatch.setattr(
        Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult())
    )
    monkeypatch.setattr(
        Runner,
        "run",
        classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))),
    )

    def bad_handler(event: AgentToolStreamEvent) -> None:
        # Deliberately fail on every event to prove the tool call is resilient.
        raise RuntimeError("boom")

    tool = cast(
        FunctionTool,
        agent.as_tool(
            tool_name="error_tool",
            tool_description="Handler throws",
            on_stream=bad_handler,
        ),
    )
    tool_context = ToolContext(
        context=None,
        tool_name="error_tool",
        tool_call_id="call-bad",
        tool_arguments='{"input": "go"}',
    )

    output = await tool.on_invoke_tool(tool_context, '{"input": "go"}')

    # The handler's RuntimeError did not propagate; the tool still returned.
    assert output == "ok"


@pytest.mark.asyncio
async def test_agent_as_tool_without_stream_uses_run(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Without `on_stream`, the tool uses the plain `Runner.run` path and never
    touches `Runner.run_streamed`."""
    agent = Agent(name="nostream_agent")

    class DummyResult:
        def __init__(self) -> None:
            self.final_output = "plain"

    run_calls: list[dict[str, Any]] = []

    # Signature mirrors Runner.run so keyword forwarding from as_tool matches.
    async def fake_run(
        cls,
        starting_agent,
        input,
        *,
        context,
        max_turns,
        hooks,
        run_config,
        previous_response_id,
        auto_previous_response_id=False,
        conversation_id,
        session,
    ):
        run_calls.append({"input": input})
        return DummyResult()

    monkeypatch.setattr(Runner, "run", classmethod(fake_run))
    # run_streamed raises if invoked, proving the streaming path is skipped.
    monkeypatch.setattr(
        Runner,
        "run_streamed",
        classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no stream"))),
    )

    tool = cast(
        FunctionTool,
        agent.as_tool(
            tool_name="nostream_tool",
            tool_description="No streaming path",
        ),
    )
    tool_context = ToolContext(
        context=None,
        tool_name="nostream_tool",
        tool_call_id="call-no",
        tool_arguments='{"input": "plain"}',
    )

    output = await tool.on_invoke_tool(tool_context, '{"input": "plain"}')

    assert output == "plain"
    assert run_calls == [{"input": "plain"}]


@pytest.mark.asyncio
async def test_agent_as_tool_streaming_sets_tool_call_id_from_context(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The `tool_call_id` in each `on_stream` payload comes from the invoking
    `ToolContext`, so events can be attributed to the originating tool call."""
    agent = Agent(name="direct_invocation_agent")

    class DummyStreamingResult:
        def __init__(self) -> None:
            self.final_output = "ok"

        async def stream_events(self):
            yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))

    monkeypatch.setattr(
        Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult())
    )
    monkeypatch.setattr(
        Runner,
        "run",
        classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))),
    )

    captured: list[AgentToolStreamEvent] = []

    async def on_stream(event: AgentToolStreamEvent) -> None:
        captured.append(event)

    tool = cast(
        FunctionTool,
        agent.as_tool(
            tool_name="direct_stream_tool",
            tool_description="Direct invocation",
            on_stream=on_stream,
        ),
    )
    tool_context = ToolContext(
        context=None,
        tool_name="direct_stream_tool",
        tool_call_id="direct-call-id",
        tool_arguments='{"input": "hi"}',
    )

    output = await tool.on_invoke_tool(tool_context, '{"input": "hi"}')

    assert output == "ok"
    # The payload echoes the ToolContext's tool_call_id, not a synthesized one.
    assert captured[0]["tool_call_id"] == "direct-call-id"