From 4abf66c67323e2b1f32b72b67eb1dbe3b012c11c Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 11 Dec 2025 11:36:10 +0900 Subject: [PATCH 1/3] feat: Add on_stream to agents as tools --- examples/agent_patterns/README.md | 1 + .../agents_as_tools_streaming.py | 57 +++ src/agents/__init__.py | 2 + src/agents/agent.py | 77 +++- tests/test_agent_as_tool.py | 338 ++++++++++++++++++ 5 files changed, 460 insertions(+), 15 deletions(-) create mode 100644 examples/agent_patterns/agents_as_tools_streaming.py diff --git a/examples/agent_patterns/README.md b/examples/agent_patterns/README.md index 96b48920c..2bdadce0d 100644 --- a/examples/agent_patterns/README.md +++ b/examples/agent_patterns/README.md @@ -28,6 +28,7 @@ The mental model for handoffs is that the new agent "takes over". It sees the pr For example, you could model the translation task above as tool calls instead: rather than handing over to the language-specific agent, you could call the agent as a tool, and then use the result in the next step. This enables things like translating multiple languages at once. See the [`agents_as_tools.py`](./agents_as_tools.py) file for an example of this. +See the [`agents_as_tools_streaming.py`](./agents_as_tools_streaming.py) file for a streaming variant that taps into nested agent events via `on_stream`. ## LLM-as-a-judge diff --git a/examples/agent_patterns/agents_as_tools_streaming.py b/examples/agent_patterns/agents_as_tools_streaming.py new file mode 100644 index 000000000..846593c81 --- /dev/null +++ b/examples/agent_patterns/agents_as_tools_streaming.py @@ -0,0 +1,57 @@ +import asyncio + +from agents import Agent, AgentToolStreamEvent, ModelSettings, Runner, function_tool, trace + + +@function_tool( + name_override="billing_status_checker", + description_override="Answer questions about customer billing status.", +) +def billing_status_checker(customer_id: str | None = None, question: str = "") -> str: + """Return a canned billing answer or a fallback when the question is unrelated.""" + normalized = question.lower() + if "bill" in normalized or "billing" in normalized: + return f"This customer (ID: {customer_id})'s bill is $100" + return "I can only answer questions about billing." + + +def handle_stream(event: AgentToolStreamEvent) -> None: + """Print streaming events emitted by the nested billing agent.""" + stream = event["event"] + print(f"[stream] agent={event['agent_name']} type={stream.type} {stream}") + + +async def main() -> None: + with trace("Agents as tools streaming example"): + billing_agent = Agent( + name="Billing Agent", + instructions="You are a billing agent that answers billing questions.", + model_settings=ModelSettings(tool_choice="required"), + tools=[billing_status_checker], + ) + + billing_agent_tool = billing_agent.as_tool( + tool_name="billing_agent", + tool_description="You are a billing agent that answers billing questions.", + on_stream=handle_stream, + ) + + main_agent = Agent( + name="Customer Support Agent", + instructions=( + "You are a customer support agent. Always call the billing agent to answer billing " + "questions and return the billing agent response to the user." + ), + tools=[billing_agent_tool], + ) + + result = await Runner.run( + main_agent, + "Hello, my customer ID is ABC123. 
How much is my bill for this month?", + ) + + print(f"\nFinal response:\n{result.final_output}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 6f4d0815d..00a5ca21e 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -8,6 +8,7 @@ from .agent import ( Agent, AgentBase, + AgentToolStreamEvent, StopAtTools, ToolsToFinalOutputFunction, ToolsToFinalOutputResult, @@ -214,6 +215,7 @@ def enable_verbose_stdout_logging(): __all__ = [ "Agent", "AgentBase", + "AgentToolStreamEvent", "StopAtTools", "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", diff --git a/src/agents/agent.py b/src/agents/agent.py index c479cc697..6d8206ab6 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -32,8 +32,9 @@ from .lifecycle import AgentHooks, RunHooks from .mcp import MCPServer from .memory.session import Session - from .result import RunResult + from .result import RunResult, RunResultStreaming from .run import RunConfig + from .stream_events import StreamEvent @dataclass @@ -58,6 +59,19 @@ class ToolsToFinalOutputResult: """ +class AgentToolStreamEvent(TypedDict): + """Streaming event emitted when an agent is invoked as a tool.""" + + event: StreamEvent + """The streaming event from the nested agent run.""" + + agent_name: str + """The name of the nested agent emitting the event.""" + + tool_call_id: str | None + """The originating tool call ID, if available.""" + + class StopAtTools(TypedDict): stop_at_tool_names: list[str] """A list of tool names, any of which will stop the agent from running further.""" @@ -382,9 +396,12 @@ def as_tool( self, tool_name: str | None, tool_description: str | None, - custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None, + custom_output_extractor: ( + Callable[[RunResult | RunResultStreaming], Awaitable[str]] | None + ) = None, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True, + on_stream: Callable[[AgentToolStreamEvent], MaybeAwaitable[None]] | None = None, run_config: RunConfig | None = None, max_turns: int | None = None, hooks: RunHooks[TContext] | None = None, @@ -409,6 +426,8 @@ def as_tool( is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run context and agent and returns whether the tool is enabled. Disabled tools are hidden from the LLM at runtime. + on_stream: Optional callback (sync or async) to receive streaming events from the nested + agent run. When provided, the nested agent is executed in streaming mode. 
""" @function_tool( @@ -421,21 +440,49 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any: resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS - output = await Runner.run( - starting_agent=self, - input=input, - context=context.context, - run_config=run_config, - max_turns=resolved_max_turns, - hooks=hooks, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - session=session, - ) + if on_stream is not None: + run_result = Runner.run_streamed( + starting_agent=self, + input=input, + context=context.context, + run_config=run_config, + max_turns=resolved_max_turns, + hooks=hooks, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) + async for event in run_result.stream_events(): + payload: AgentToolStreamEvent = { + "event": event, + "agent_name": self.name, + "tool_call_id": getattr(context, "tool_call_id", None), + } + try: + maybe_result = on_stream(payload) + if inspect.isawaitable(maybe_result): + await maybe_result + except Exception: + logger.exception( + "Error while handling on_stream event for agent tool %s.", + self.name, + ) + else: + run_result = await Runner.run( + starting_agent=self, + input=input, + context=context.context, + run_config=run_config, + max_turns=resolved_max_turns, + hooks=hooks, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) if custom_output_extractor: - return await custom_output_extractor(output) + return await custom_output_extractor(run_result) - return output.final_output + return run_result.final_output return run_agent diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index 51d8edf20..2f6b38a4d 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -18,6 +18,7 @@ Session, TResponseInputItem, ) +from agents.stream_events import RawResponsesStreamEvent from agents.tool_context import ToolContext @@ -373,3 +374,340 @@ async def extractor(result) -> str: output = await tool.on_invoke_tool(tool_context, '{"input": "summarize this"}') assert output == "custom output" + + +@pytest.mark.asyncio +async def test_agent_as_tool_streams_events_with_on_stream( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="streamer") + stream_events = [ + RawResponsesStreamEvent(data={"type": "response_started"}), + RawResponsesStreamEvent(data={"type": "output_text_delta", "delta": "hi"}), + ] + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "streamed output" + + async def stream_events(self): + for ev in stream_events: + yield ev + + run_calls: list[dict[str, Any]] = [] + + def fake_run_streamed( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + run_calls.append( + { + "starting_agent": starting_agent, + "input": input, + "context": context, + "max_turns": max_turns, + "hooks": hooks, + "run_config": run_config, + "previous_response_id": previous_response_id, + "conversation_id": conversation_id, + "session": session, + } + ) + return DummyStreamingResult() + + async def unexpected_run(*args: Any, **kwargs: Any) -> None: + raise AssertionError("Runner.run should not be called when on_stream is provided.") + + monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) + monkeypatch.setattr(Runner, "run", classmethod(unexpected_run)) + + received_events: 
list[dict[str, Any]] = [] + + async def on_stream(payload: dict[str, Any]) -> None: + received_events.append(payload) + + tool = agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + on_stream=on_stream, + ) + + tool_context = ToolContext( + context=None, + tool_name="stream_tool", + tool_call_id="call-123", + tool_arguments='{"input": "run streaming"}', + ) + output = await tool.on_invoke_tool(tool_context, '{"input": "run streaming"}') + + assert output == "streamed output" + assert len(received_events) == len(stream_events) + assert received_events[0]["agent_name"] == "streamer" + assert received_events[0]["tool_call_id"] == "call-123" + assert received_events[0]["event"] == stream_events[0] + assert run_calls[0]["input"] == "run streaming" + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_works_with_custom_extractor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="streamer") + stream_events = [RawResponsesStreamEvent(data={"type": "response_started"})] + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "raw output" + + async def stream_events(self): + for ev in stream_events: + yield ev + + streamed_instance = DummyStreamingResult() + + def fake_run_streamed( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + return streamed_instance + + async def unexpected_run(*args: Any, **kwargs: Any) -> None: + raise AssertionError("Runner.run should not be called when on_stream is provided.") + + monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) + monkeypatch.setattr(Runner, "run", classmethod(unexpected_run)) + + received: list[Any] = [] + + async def extractor(result) -> str: + received.append(result) + return "custom value" + + callbacks: list[Any] = [] + + async def on_stream(payload: dict[str, Any]) -> None: + callbacks.append(payload["event"]) + + tool = agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + custom_output_extractor=extractor, + on_stream=on_stream, + ) + + tool_context = ToolContext( + context=None, + tool_name="stream_tool", + tool_call_id="call-abc", + tool_arguments='{"input": "stream please"}', + ) + output = await tool.on_invoke_tool(tool_context, '{"input": "stream please"}') + + assert output == "custom value" + assert received == [streamed_instance] + assert callbacks == stream_events + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_accepts_sync_handler( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="sync_handler_agent") + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + + async def stream_events(self): + yield RawResponsesStreamEvent(data={"type": "response_started"}) + + monkeypatch.setattr( + Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) + ) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + calls: list[str] = [] + + def sync_handler(event: dict[str, Any]) -> None: + calls.append(event["event"].type) + + tool = agent.as_tool( + tool_name="sync_tool", + tool_description="Uses sync handler", + on_stream=sync_handler, + ) + tool_context = ToolContext( + context=None, + tool_name="sync_tool", + tool_call_id="call-sync", + tool_arguments='{"input": "go"}', + ) + + output = await 
tool.on_invoke_tool(tool_context, '{"input": "go"}') + + assert output == "ok" + assert calls == ["raw_response_event"] + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_handler_exception_does_not_fail_call( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="handler_error_agent") + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + + async def stream_events(self): + yield RawResponsesStreamEvent(data={"type": "response_started"}) + + monkeypatch.setattr( + Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) + ) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + def bad_handler(event: dict[str, Any]) -> None: + raise RuntimeError("boom") + + tool = agent.as_tool( + tool_name="error_tool", + tool_description="Handler throws", + on_stream=bad_handler, + ) + tool_context = ToolContext( + context=None, + tool_name="error_tool", + tool_call_id="call-bad", + tool_arguments='{"input": "go"}', + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "go"}') + + assert output == "ok" + + +@pytest.mark.asyncio +async def test_agent_as_tool_without_stream_uses_run( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="nostream_agent") + + class DummyResult: + def __init__(self) -> None: + self.final_output = "plain" + + run_calls: list[dict[str, Any]] = [] + + async def fake_run( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + run_calls.append({"input": input}) + return DummyResult() + + monkeypatch.setattr(Runner, "run", classmethod(fake_run)) + monkeypatch.setattr( + Runner, + "run_streamed", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no stream"))), + ) + + tool = agent.as_tool( + tool_name="nostream_tool", + tool_description="No streaming path", + ) + tool_context = ToolContext( + context=None, + tool_name="nostream_tool", + tool_call_id="call-no", + tool_arguments='{"input": "plain"}', + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "plain"}') + + assert output == "plain" + assert run_calls == [{"input": "plain"}] + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_sets_tool_call_id_none_for_direct_invocation( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="direct_invocation_agent") + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + + async def stream_events(self): + yield RawResponsesStreamEvent(data={"type": "response_started"}) + + monkeypatch.setattr( + Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) + ) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + captured: list[dict[str, Any]] = [] + + async def on_stream(event: dict[str, Any]) -> None: + captured.append(event) + + tool = agent.as_tool( + tool_name="direct_stream_tool", + tool_description="Direct invocation", + on_stream=on_stream, + ) + tool_context = ToolContext( + context=None, + tool_name="direct_stream_tool", + tool_call_id=None, # Direct invoke path does not have a tool call ID. 
+ tool_arguments='{"input": "hi"}', + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "hi"}') + + assert output == "ok" + assert captured[0]["tool_call_id"] is None From a9133665bfed2ceeef0fe62fe78bedad8c307284 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 11 Dec 2025 11:46:01 +0900 Subject: [PATCH 2/3] fix mypy errors --- examples/financial_research_agent/manager.py | 4 +- src/agents/agent.py | 1 + tests/test_agent_as_tool.py | 102 +++++++++++-------- 3 files changed, 64 insertions(+), 43 deletions(-) diff --git a/examples/financial_research_agent/manager.py b/examples/financial_research_agent/manager.py index 58ec11bf2..6dfc631aa 100644 --- a/examples/financial_research_agent/manager.py +++ b/examples/financial_research_agent/manager.py @@ -6,7 +6,7 @@ from rich.console import Console -from agents import Runner, RunResult, custom_span, gen_trace_id, trace +from agents import Runner, RunResult, RunResultStreaming, custom_span, gen_trace_id, trace from .agents.financials_agent import financials_agent from .agents.planner_agent import FinancialSearchItem, FinancialSearchPlan, planner_agent @@ -17,7 +17,7 @@ from .printer import Printer -async def _summary_extractor(run_result: RunResult) -> str: +async def _summary_extractor(run_result: RunResult | RunResultStreaming) -> str: """Custom output extractor for sub‑agents that return an AnalysisSummary.""" # The financial/risk analyst agents emit an AnalysisSummary with a `summary` field. # We want the tool call to return just that summary text so the writer can drop it inline. diff --git a/src/agents/agent.py b/src/agents/agent.py index 6d8206ab6..d449fa3ae 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -439,6 +439,7 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any: from .run import DEFAULT_MAX_TURNS, Runner resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS + run_result: RunResult | RunResultStreaming if on_stream is not None: run_result = Runner.run_streamed( diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index 2f6b38a4d..ab5f57660 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast import pytest from openai.types.responses import ResponseOutputMessage, ResponseOutputText @@ -9,6 +9,7 @@ from agents import ( Agent, AgentBase, + AgentToolStreamEvent, FunctionTool, MessageOutputItem, RunConfig, @@ -382,8 +383,8 @@ async def test_agent_as_tool_streams_events_with_on_stream( ) -> None: agent = Agent(name="streamer") stream_events = [ - RawResponsesStreamEvent(data={"type": "response_started"}), - RawResponsesStreamEvent(data={"type": "output_text_delta", "delta": "hi"}), + RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})), + RawResponsesStreamEvent(data=cast(Any, {"type": "output_text_delta", "delta": "hi"})), ] class DummyStreamingResult: @@ -431,15 +432,18 @@ async def unexpected_run(*args: Any, **kwargs: Any) -> None: monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) monkeypatch.setattr(Runner, "run", classmethod(unexpected_run)) - received_events: list[dict[str, Any]] = [] + received_events: list[AgentToolStreamEvent] = [] - async def on_stream(payload: dict[str, Any]) -> None: + async def on_stream(payload: AgentToolStreamEvent) -> None: received_events.append(payload) - tool = agent.as_tool( - tool_name="stream_tool", - tool_description="Streams events", - 
on_stream=on_stream, + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + on_stream=on_stream, + ), ) tool_context = ToolContext( @@ -463,7 +467,8 @@ async def test_agent_as_tool_streaming_works_with_custom_extractor( monkeypatch: pytest.MonkeyPatch, ) -> None: agent = Agent(name="streamer") - stream_events = [RawResponsesStreamEvent(data={"type": "response_started"})] + stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))] + stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))] class DummyStreamingResult: def __init__(self) -> None: @@ -505,14 +510,17 @@ async def extractor(result) -> str: callbacks: list[Any] = [] - async def on_stream(payload: dict[str, Any]) -> None: + async def on_stream(payload: AgentToolStreamEvent) -> None: callbacks.append(payload["event"]) - tool = agent.as_tool( - tool_name="stream_tool", - tool_description="Streams events", - custom_output_extractor=extractor, - on_stream=on_stream, + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + custom_output_extractor=extractor, + on_stream=on_stream, + ), ) tool_context = ToolContext( @@ -539,7 +547,7 @@ def __init__(self) -> None: self.final_output = "ok" async def stream_events(self): - yield RawResponsesStreamEvent(data={"type": "response_started"}) + yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) monkeypatch.setattr( Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) @@ -552,13 +560,16 @@ async def stream_events(self): calls: list[str] = [] - def sync_handler(event: dict[str, Any]) -> None: + def sync_handler(event: AgentToolStreamEvent) -> None: calls.append(event["event"].type) - tool = agent.as_tool( - tool_name="sync_tool", - tool_description="Uses sync handler", - on_stream=sync_handler, + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="sync_tool", + tool_description="Uses sync handler", + on_stream=sync_handler, + ), ) tool_context = ToolContext( context=None, @@ -584,7 +595,7 @@ def __init__(self) -> None: self.final_output = "ok" async def stream_events(self): - yield RawResponsesStreamEvent(data={"type": "response_started"}) + yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) monkeypatch.setattr( Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) @@ -595,13 +606,16 @@ async def stream_events(self): classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), ) - def bad_handler(event: dict[str, Any]) -> None: + def bad_handler(event: AgentToolStreamEvent) -> None: raise RuntimeError("boom") - tool = agent.as_tool( - tool_name="error_tool", - tool_description="Handler throws", - on_stream=bad_handler, + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="error_tool", + tool_description="Handler throws", + on_stream=bad_handler, + ), ) tool_context = ToolContext( context=None, @@ -651,9 +665,12 @@ async def fake_run( classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no stream"))), ) - tool = agent.as_tool( - tool_name="nostream_tool", - tool_description="No streaming path", + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="nostream_tool", + tool_description="No streaming path", + ), ) tool_context = ToolContext( context=None, @@ -669,7 +686,7 @@ async def fake_run( @pytest.mark.asyncio -async def 
test_agent_as_tool_streaming_sets_tool_call_id_none_for_direct_invocation( +async def test_agent_as_tool_streaming_sets_tool_call_id_from_context( monkeypatch: pytest.MonkeyPatch, ) -> None: agent = Agent(name="direct_invocation_agent") @@ -679,7 +696,7 @@ def __init__(self) -> None: self.final_output = "ok" async def stream_events(self): - yield RawResponsesStreamEvent(data={"type": "response_started"}) + yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) monkeypatch.setattr( Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) @@ -690,24 +707,27 @@ async def stream_events(self): classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), ) - captured: list[dict[str, Any]] = [] + captured: list[AgentToolStreamEvent] = [] - async def on_stream(event: dict[str, Any]) -> None: + async def on_stream(event: AgentToolStreamEvent) -> None: captured.append(event) - tool = agent.as_tool( - tool_name="direct_stream_tool", - tool_description="Direct invocation", - on_stream=on_stream, + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="direct_stream_tool", + tool_description="Direct invocation", + on_stream=on_stream, + ), ) tool_context = ToolContext( context=None, tool_name="direct_stream_tool", - tool_call_id=None, # Direct invoke path does not have a tool call ID. + tool_call_id="direct-call-id", tool_arguments='{"input": "hi"}', ) output = await tool.on_invoke_tool(tool_context, '{"input": "hi"}') assert output == "ok" - assert captured[0]["tool_call_id"] is None + assert captured[0]["tool_call_id"] == "direct-call-id" From e384c289e61bc0ba8fb7bc863224e7b145801657 Mon Sep 17 00:00:00 2001 From: Wen-Tien Chang Date: Thu, 11 Dec 2025 22:23:34 +0800 Subject: [PATCH 3/3] Update agent-as-tool streaming API to pass StreamEvent and caller tool context --- .../agents_as_tools_streaming.py | 40 ++++-- src/agents/__init__.py | 2 - src/agents/_run_impl.py | 1 + src/agents/agent.py | 25 +--- src/agents/realtime/session.py | 2 + src/agents/tool_context.py | 19 ++- tests/test_agent_as_tool.py | 52 +++++--- tests/test_function_tool.py | 118 +++++++++++++++--- tests/test_function_tool_decorator.py | 10 +- tests/test_tool_guardrails.py | 3 + tests/test_tool_metadata.py | 3 + 11 files changed, 209 insertions(+), 66 deletions(-) diff --git a/examples/agent_patterns/agents_as_tools_streaming.py b/examples/agent_patterns/agents_as_tools_streaming.py index 846593c81..030e6f1ba 100644 --- a/examples/agent_patterns/agents_as_tools_streaming.py +++ b/examples/agent_patterns/agents_as_tools_streaming.py @@ -1,6 +1,8 @@ import asyncio -from agents import Agent, AgentToolStreamEvent, ModelSettings, Runner, function_tool, trace +from agents import Agent, ModelSettings, Runner, function_tool, trace +from agents.stream_events import StreamEvent +from agents.tool_context import ToolContext @function_tool( @@ -15,10 +17,22 @@ def billing_status_checker(customer_id: str | None = None, question: str = "") - return "I can only answer questions about billing." -def handle_stream(event: AgentToolStreamEvent) -> None: - """Print streaming events emitted by the nested billing agent.""" - stream = event["event"] - print(f"[stream] agent={event['agent_name']} type={stream.type} {stream}") +def handle_billing_agent_stream( + event: StreamEvent, caller_tool_context: ToolContext | None = None +) -> None: + """ + Stream handler that works in both scenarios: + 1. Direct streaming: Runner.run_streamed() - caller_tool_context is None + 2. 
Agent-as-tool streaming: as_tool(on_stream=...) - caller_tool_context contains caller info + """ + if caller_tool_context: + print( + f"[stream from caller agent={caller_tool_context.caller_agent.name} " + f"tool_name={caller_tool_context.tool_name} " + f"tool_id={caller_tool_context.tool_call_id}] {event.type}" + ) + else: + print(f"[stream] {event.type}") async def main() -> None: @@ -30,10 +44,20 @@ async def main() -> None: tools=[billing_status_checker], ) + # Scenario 1: Run the billing agent directly with streaming + result1 = Runner.run_streamed( + billing_agent, + "Hello, my customer ID is ABC123. How much is my bill for this month?", + ) + + async for event in result1.stream_events(): + handle_billing_agent_stream(event) + + # Scenario 2: Use billing agent as a tool with streaming via on_stream callback billing_agent_tool = billing_agent.as_tool( tool_name="billing_agent", tool_description="You are a billing agent that answers billing questions.", - on_stream=handle_stream, + on_stream=handle_billing_agent_stream, ) main_agent = Agent( @@ -45,12 +69,12 @@ async def main() -> None: tools=[billing_agent_tool], ) - result = await Runner.run( + result2 = await Runner.run( main_agent, "Hello, my customer ID is ABC123. How much is my bill for this month?", ) - print(f"\nFinal response:\n{result.final_output}") + print(f"\response:\n{result2.final_output}") if __name__ == "__main__": diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 00a5ca21e..6f4d0815d 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -8,7 +8,6 @@ from .agent import ( Agent, AgentBase, - AgentToolStreamEvent, StopAtTools, ToolsToFinalOutputFunction, ToolsToFinalOutputResult, @@ -215,7 +214,6 @@ def enable_verbose_stdout_logging(): __all__ = [ "Agent", "AgentBase", - "AgentToolStreamEvent", "StopAtTools", "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 48e8eebdf..4601a781b 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -926,6 +926,7 @@ async def run_single_tool( context_wrapper, tool_call.call_id, tool_call=tool_call, + caller_agent=agent, ) if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments diff --git a/src/agents/agent.py b/src/agents/agent.py index d449fa3ae..640a43793 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -35,6 +35,7 @@ from .result import RunResult, RunResultStreaming from .run import RunConfig from .stream_events import StreamEvent + from .tool_context import ToolContext @dataclass @@ -59,19 +60,6 @@ class ToolsToFinalOutputResult: """ -class AgentToolStreamEvent(TypedDict): - """Streaming event emitted when an agent is invoked as a tool.""" - - event: StreamEvent - """The streaming event from the nested agent run.""" - - agent_name: str - """The name of the nested agent emitting the event.""" - - tool_call_id: str | None - """The originating tool call ID, if available.""" - - class StopAtTools(TypedDict): stop_at_tool_names: list[str] """A list of tool names, any of which will stop the agent from running further.""" @@ -401,7 +389,7 @@ def as_tool( ) = None, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True, - on_stream: Callable[[AgentToolStreamEvent], MaybeAwaitable[None]] | None = None, + on_stream: Callable[[StreamEvent, ToolContext], MaybeAwaitable[None]] | None = None, run_config: RunConfig | None = None, max_turns: int | None = None, hooks: RunHooks[TContext] | None 
= None, @@ -428,6 +416,8 @@ def as_tool( from the LLM at runtime. on_stream: Optional callback (sync or async) to receive streaming events from the nested agent run. When provided, the nested agent is executed in streaming mode. + The callback receives (event, caller_tool_context) where caller_tool_context + provides access to the calling agent via caller_agent field. """ @function_tool( @@ -454,13 +444,8 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any: session=session, ) async for event in run_result.stream_events(): - payload: AgentToolStreamEvent = { - "event": event, - "agent_name": self.name, - "tool_call_id": getattr(context, "tool_call_id", None), - } try: - maybe_result = on_stream(payload) + maybe_result = on_stream(event, cast("ToolContext", context)) if inspect.isawaitable(maybe_result): await maybe_result except Exception: diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index a3cd1d3ea..c555f409e 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -422,6 +422,7 @@ async def _handle_tool_call( tool_name=event.name, tool_call_id=event.call_id, tool_arguments=event.arguments, + caller_agent=agent, ) result = await func_tool.on_invoke_tool(tool_context, event.arguments) @@ -448,6 +449,7 @@ async def _handle_tool_call( tool_name=event.name, tool_call_id=event.call_id, tool_arguments=event.arguments, + caller_agent=agent, ) # Execute the handoff to get the new agent diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index 5b81239f6..ce0b3014d 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -1,10 +1,13 @@ from dataclasses import dataclass, field, fields -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from openai.types.responses import ResponseFunctionToolCall from .run_context import RunContextWrapper, TContext +if TYPE_CHECKING: + from .agent import AgentBase + def _assert_must_pass_tool_call_id() -> str: raise ValueError("tool_call_id must be passed to ToolContext") @@ -18,6 +21,10 @@ def _assert_must_pass_tool_arguments() -> str: raise ValueError("tool_arguments must be passed to ToolContext") +def _assert_must_pass_caller_agent() -> "AgentBase": + raise ValueError("caller_agent must be passed to ToolContext") + + @dataclass class ToolContext(RunContextWrapper[TContext]): """The context of a tool call.""" @@ -31,11 +38,15 @@ class ToolContext(RunContextWrapper[TContext]): tool_arguments: str = field(default_factory=_assert_must_pass_tool_arguments) """The raw arguments string of the tool call.""" + caller_agent: "AgentBase" = field(default_factory=_assert_must_pass_caller_agent) + """The agent that called this tool.""" + @classmethod def from_agent_context( cls, context: RunContextWrapper[TContext], tool_call_id: str, + caller_agent: "AgentBase", tool_call: Optional[ResponseFunctionToolCall] = None, ) -> "ToolContext": """ @@ -51,5 +62,9 @@ def from_agent_context( ) return cls( - tool_name=tool_name, tool_call_id=tool_call_id, tool_arguments=tool_args, **base_values + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_arguments=tool_args, + caller_agent=caller_agent, + **base_values, ) diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index ab5f57660..7cb4c67ea 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -9,7 +9,6 @@ from agents import ( Agent, AgentBase, - AgentToolStreamEvent, FunctionTool, MessageOutputItem, RunConfig, @@ -19,9 +18,11 @@ Session, 
TResponseInputItem, ) -from agents.stream_events import RawResponsesStreamEvent +from agents.stream_events import RawResponsesStreamEvent, StreamEvent from agents.tool_context import ToolContext +_caller_agent = Agent(name="test_caller_agent") + class BoolCtx(BaseModel): enable_tools: bool @@ -269,6 +270,7 @@ async def fake_run( tool_name="story_tool", tool_call_id="call_1", tool_arguments='{"input": "hello"}', + caller_agent=_caller_agent, ) output = await tool.on_invoke_tool(tool_context, '{"input": "hello"}') @@ -371,6 +373,7 @@ async def extractor(result) -> str: tool_name="summary_tool", tool_call_id="call_2", tool_arguments='{"input": "summarize this"}', + caller_agent=_caller_agent, ) output = await tool.on_invoke_tool(tool_context, '{"input": "summarize this"}') @@ -432,10 +435,10 @@ async def unexpected_run(*args: Any, **kwargs: Any) -> None: monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) monkeypatch.setattr(Runner, "run", classmethod(unexpected_run)) - received_events: list[AgentToolStreamEvent] = [] + received_events: list[tuple[StreamEvent, ToolContext | None]] = [] - async def on_stream(payload: AgentToolStreamEvent) -> None: - received_events.append(payload) + async def on_stream(event: StreamEvent, caller_tool_context: ToolContext | None) -> None: + received_events.append((event, caller_tool_context)) tool = cast( FunctionTool, @@ -446,19 +449,24 @@ async def on_stream(payload: AgentToolStreamEvent) -> None: ), ) + caller_agent = Agent(name="caller_agent") tool_context = ToolContext( context=None, tool_name="stream_tool", tool_call_id="call-123", tool_arguments='{"input": "run streaming"}', + caller_agent=caller_agent, ) output = await tool.on_invoke_tool(tool_context, '{"input": "run streaming"}') assert output == "streamed output" assert len(received_events) == len(stream_events) - assert received_events[0]["agent_name"] == "streamer" - assert received_events[0]["tool_call_id"] == "call-123" - assert received_events[0]["event"] == stream_events[0] + assert received_events[0][1] is not None + assert received_events[0][1].tool_call_id == "call-123" + assert received_events[0][1].tool_name == "stream_tool" + assert received_events[0][1].caller_agent is not None + assert received_events[0][1].caller_agent.name == "caller_agent" + assert received_events[0][0] == stream_events[0] assert run_calls[0]["input"] == "run streaming" @@ -510,8 +518,8 @@ async def extractor(result) -> str: callbacks: list[Any] = [] - async def on_stream(payload: AgentToolStreamEvent) -> None: - callbacks.append(payload["event"]) + async def on_stream(event: StreamEvent, caller_tool_context: ToolContext | None) -> None: + callbacks.append(event) tool = cast( FunctionTool, @@ -528,6 +536,7 @@ async def on_stream(payload: AgentToolStreamEvent) -> None: tool_name="stream_tool", tool_call_id="call-abc", tool_arguments='{"input": "stream please"}', + caller_agent=_caller_agent, ) output = await tool.on_invoke_tool(tool_context, '{"input": "stream please"}') @@ -560,8 +569,8 @@ async def stream_events(self): calls: list[str] = [] - def sync_handler(event: AgentToolStreamEvent) -> None: - calls.append(event["event"].type) + def sync_handler(event: StreamEvent, caller_tool_context: ToolContext | None) -> None: + calls.append(event.type) tool = cast( FunctionTool, @@ -576,6 +585,7 @@ def sync_handler(event: AgentToolStreamEvent) -> None: tool_name="sync_tool", tool_call_id="call-sync", tool_arguments='{"input": "go"}', + caller_agent=_caller_agent, ) output = await 
tool.on_invoke_tool(tool_context, '{"input": "go"}') @@ -606,7 +616,7 @@ async def stream_events(self): classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), ) - def bad_handler(event: AgentToolStreamEvent) -> None: + def bad_handler(event: StreamEvent, caller_tool_context: ToolContext | None) -> None: raise RuntimeError("boom") tool = cast( @@ -622,6 +632,7 @@ def bad_handler(event: AgentToolStreamEvent) -> None: tool_name="error_tool", tool_call_id="call-bad", tool_arguments='{"input": "go"}', + caller_agent=_caller_agent, ) output = await tool.on_invoke_tool(tool_context, '{"input": "go"}') @@ -677,6 +688,7 @@ async def fake_run( tool_name="nostream_tool", tool_call_id="call-no", tool_arguments='{"input": "plain"}', + caller_agent=_caller_agent, ) output = await tool.on_invoke_tool(tool_context, '{"input": "plain"}') @@ -707,10 +719,10 @@ async def stream_events(self): classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), ) - captured: list[AgentToolStreamEvent] = [] + captured: list[tuple[StreamEvent, ToolContext | None]] = [] - async def on_stream(event: AgentToolStreamEvent) -> None: - captured.append(event) + async def on_stream(event: StreamEvent, caller_tool_context: ToolContext | None) -> None: + captured.append((event, caller_tool_context)) tool = cast( FunctionTool, @@ -720,14 +732,20 @@ async def on_stream(event: AgentToolStreamEvent) -> None: on_stream=on_stream, ), ) + caller_agent = Agent(name="caller_agent") tool_context = ToolContext( context=None, tool_name="direct_stream_tool", tool_call_id="direct-call-id", tool_arguments='{"input": "hi"}', + caller_agent=caller_agent, ) output = await tool.on_invoke_tool(tool_context, '{"input": "hi"}') assert output == "ok" - assert captured[0]["tool_call_id"] == "direct-call-id" + assert captured[0][1] is not None + assert captured[0][1].tool_call_id == "direct-call-id" + assert captured[0][1].tool_name == "direct_stream_tool" + assert captured[0][1].caller_agent is not None + assert captured[0][1].caller_agent.name == "caller_agent" diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index 18107773d..d0f9420a4 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -16,6 +16,8 @@ from agents.tool import default_tool_error_function from agents.tool_context import ToolContext +_test_agent = Agent(name="test_agent") + def argless_function() -> str: return "ok" @@ -27,7 +29,14 @@ async def test_argless_function(): assert tool.name == "argless_function" result = await tool.on_invoke_tool( - ToolContext(context=None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" + ToolContext( + context=None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments="", + caller_agent=_test_agent, + ), + "", ) assert result == "ok" @@ -42,13 +51,22 @@ async def test_argless_with_context(): assert tool.name == "argless_with_context" result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" + ToolContext( + None, tool_name=tool.name, tool_call_id="1", tool_arguments="", caller_agent=_test_agent + ), + "", ) assert result == "ok" # Extra JSON should not raise an error result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'), + ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments='{"a": 1}', + caller_agent=_test_agent, + ), '{"a": 1}', ) assert result == "ok" @@ -64,13 +82,25 @@ async 
def test_simple_function(): assert tool.name == "simple_function" result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'), + ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments='{"a": 1}', + caller_agent=_test_agent, + ), '{"a": 1}', ) assert result == 6 result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1, "b": 2}'), + ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments='{"a": 1, "b": 2}', + caller_agent=_test_agent, + ), '{"a": 1, "b": 2}', ) assert result == 3 @@ -78,7 +108,14 @@ async def test_simple_function(): # Missing required argument should raise an error with pytest.raises(ModelBehaviorError): await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" + ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments="", + caller_agent=_test_agent, + ), + "", ) @@ -108,7 +145,13 @@ async def test_complex_args_function(): } ) result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json), + ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments=valid_json, + caller_agent=_test_agent, + ), valid_json, ) assert result == "6 hello10 hello" @@ -120,7 +163,13 @@ async def test_complex_args_function(): } ) result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json), + ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments=valid_json, + caller_agent=_test_agent, + ), valid_json, ) assert result == "3 hello10 hello" @@ -133,7 +182,13 @@ async def test_complex_args_function(): } ) result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json), + ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments=valid_json, + caller_agent=_test_agent, + ), valid_json, ) assert result == "3 hello10 world" @@ -142,7 +197,11 @@ async def test_complex_args_function(): with pytest.raises(ModelBehaviorError): await tool.on_invoke_tool( ToolContext( - None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"foo": {"a": 1}}' + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments='{"foo": {"a": 1}}', + caller_agent=_test_agent, ), '{"foo": {"a": 1}}', ) @@ -207,7 +266,11 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: result = await tool.on_invoke_tool( ToolContext( - None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"data": "hello"}' + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments='{"data": "hello"}', + caller_agent=_test_agent, ), '{"data": "hello"}', ) @@ -230,6 +293,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: tool_name=tool_not_strict.name, tool_call_id="1", tool_arguments='{"data": "hello", "bar": "baz"}', + caller_agent=_test_agent, ), '{"data": "hello", "bar": "baz"}', ) @@ -242,7 +306,13 @@ def my_func(a: int, b: int = 5): raise ValueError("test") tool = function_tool(my_func) - ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="") + ctx = ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments="", + caller_agent=_test_agent, + ) result = await tool.on_invoke_tool(ctx, "") assert "Invalid JSON" in str(result) @@ -266,7 +336,13 @@ def 
custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> return f"error_{error.__class__.__name__}" tool = function_tool(my_func, failure_error_function=custom_sync_error_function) - ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="") + ctx = ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments="", + caller_agent=_test_agent, + ) result = await tool.on_invoke_tool(ctx, "") assert result == "error_ModelBehaviorError" @@ -290,7 +366,13 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> return f"error_{error.__class__.__name__}" tool = function_tool(my_func, failure_error_function=custom_sync_error_function) - ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="") + ctx = ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments="", + caller_agent=_test_agent, + ) result = await tool.on_invoke_tool(ctx, "") assert result == "error_ModelBehaviorError" @@ -356,6 +438,12 @@ def boom() -> None: """Always raises to trigger the failure handler.""" raise RuntimeError("kapow") - ctx = ToolContext(None, tool_name=boom.name, tool_call_id="boom", tool_arguments="{}") + ctx = ToolContext( + None, + tool_name=boom.name, + tool_call_id="boom", + tool_arguments="{}", + caller_agent=_test_agent, + ) result = await boom.on_invoke_tool(ctx, "{}") assert result.startswith("handled:") diff --git a/tests/test_function_tool_decorator.py b/tests/test_function_tool_decorator.py index 2f5a38223..31f7177b3 100644 --- a/tests/test_function_tool_decorator.py +++ b/tests/test_function_tool_decorator.py @@ -5,10 +5,12 @@ import pytest from inline_snapshot import snapshot -from agents import function_tool +from agents import Agent, function_tool from agents.run_context import RunContextWrapper from agents.tool_context import ToolContext +_test_agent = Agent(name="test_agent") + class DummyContext: def __init__(self): @@ -17,7 +19,11 @@ def __init__(self): def ctx_wrapper() -> ToolContext[DummyContext]: return ToolContext( - context=DummyContext(), tool_name="dummy", tool_call_id="1", tool_arguments="" + context=DummyContext(), + tool_name="dummy", + tool_call_id="1", + tool_arguments="", + caller_agent=_test_agent, ) diff --git a/tests/test_tool_guardrails.py b/tests/test_tool_guardrails.py index 8ccaec0ad..81268adc5 100644 --- a/tests/test_tool_guardrails.py +++ b/tests/test_tool_guardrails.py @@ -19,6 +19,8 @@ from agents.tool_context import ToolContext from agents.tool_guardrails import tool_input_guardrail, tool_output_guardrail +_test_agent = Agent(name="test_agent") + def get_mock_tool_context(tool_arguments: str = '{"param": "value"}') -> ToolContext: """Helper to create a mock tool context for testing.""" @@ -27,6 +29,7 @@ def get_mock_tool_context(tool_arguments: str = '{"param": "value"}') -> ToolCon tool_name="test_tool", tool_call_id="call_123", tool_arguments=tool_arguments, + caller_agent=_test_agent, ) diff --git a/tests/test_tool_metadata.py b/tests/test_tool_metadata.py index ad6395e9b..db8f6848f 100644 --- a/tests/test_tool_metadata.py +++ b/tests/test_tool_metadata.py @@ -4,6 +4,7 @@ from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp +from agents import Agent from agents.computer import Computer from agents.run_context import RunContextWrapper from agents.tool import ( @@ -57,6 +58,7 @@ def test_shell_command_output_status_property() -> None: def test_tool_context_from_agent_context() -> None: ctx = 
RunContextWrapper(context={"foo": "bar"}) + test_agent = Agent(name="test_agent") tool_call = ToolContext.from_agent_context( ctx, tool_call_id="123", @@ -68,5 +70,6 @@ def test_tool_context_from_agent_context() -> None: "arguments": "{}", }, )(), + caller_agent=test_agent, ) assert tool_call.tool_name == "demo"
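
For reference, a minimal end-to-end sketch (not part of the patches above) of the API as it stands after PATCH 2/3: `custom_output_extractor` accepts `RunResult | RunResultStreaming`, and `on_stream` receives the nested run's `StreamEvent` plus the caller's `ToolContext`. Agent names, instructions, and the prompt below are illustrative only, and running it assumes a configured default model and API key:

import asyncio

from agents import Agent, Runner, RunResult, RunResultStreaming
from agents.stream_events import StreamEvent
from agents.tool_context import ToolContext


def log_nested_event(event: StreamEvent, caller_tool_context: ToolContext) -> None:
    # PATCH 3 passes the raw StreamEvent plus the caller's ToolContext, so the
    # handler can tag output with the calling agent and originating tool call.
    print(
        f"[{caller_tool_context.caller_agent.name}:"
        f"{caller_tool_context.tool_call_id}] {event.type}"
    )


async def first_line(result: RunResult | RunResultStreaming) -> str:
    # After PATCH 2 the extractor is typed to accept both result flavors; with
    # on_stream set it receives the RunResultStreaming from the nested run.
    return str(result.final_output).splitlines()[0]


async def main() -> None:
    summarizer = Agent(name="Summarizer", instructions="Summarize the input in one sentence.")
    triage = Agent(
        name="Triage",
        instructions="Call the summarizer tool and reply with its output.",
        tools=[
            summarizer.as_tool(
                tool_name="summarizer",
                tool_description="Summarizes text.",
                on_stream=log_nested_event,          # nested agent runs in streaming mode
                custom_output_extractor=first_line,  # sees the streamed result
            )
        ],
    )
    result = await Runner.run(triage, "The billing system was migrated last week with no downtime.")
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())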