diff --git a/examples/agent_patterns/README.md b/examples/agent_patterns/README.md index 96b48920c..2bdadce0d 100644 --- a/examples/agent_patterns/README.md +++ b/examples/agent_patterns/README.md @@ -28,6 +28,7 @@ The mental model for handoffs is that the new agent "takes over". It sees the pr For example, you could model the translation task above as tool calls instead: rather than handing over to the language-specific agent, you could call the agent as a tool, and then use the result in the next step. This enables things like translating multiple languages at once. See the [`agents_as_tools.py`](./agents_as_tools.py) file for an example of this. +See the [`agents_as_tools_streaming.py`](./agents_as_tools_streaming.py) file for a streaming variant that taps into nested agent events via `on_stream`. ## LLM-as-a-judge diff --git a/examples/agent_patterns/agents_as_tools_streaming.py b/examples/agent_patterns/agents_as_tools_streaming.py new file mode 100644 index 000000000..030e6f1ba --- /dev/null +++ b/examples/agent_patterns/agents_as_tools_streaming.py @@ -0,0 +1,81 @@ +import asyncio + +from agents import Agent, ModelSettings, Runner, function_tool, trace +from agents.stream_events import StreamEvent +from agents.tool_context import ToolContext + + +@function_tool( + name_override="billing_status_checker", + description_override="Answer questions about customer billing status.", +) +def billing_status_checker(customer_id: str | None = None, question: str = "") -> str: + """Return a canned billing answer or a fallback when the question is unrelated.""" + normalized = question.lower() + if "bill" in normalized or "billing" in normalized: + return f"This customer (ID: {customer_id})'s bill is $100" + return "I can only answer questions about billing." + + +def handle_billing_agent_stream( + event: StreamEvent, caller_tool_context: ToolContext | None = None +) -> None: + """ + Stream handler that works in both scenarios: + 1. 
Direct streaming: Runner.run_streamed() - caller_tool_context is None + 2. Agent-as-tool streaming: as_tool(on_stream=...) - caller_tool_context contains caller info + """ + if caller_tool_context: + print( + f"[stream from caller agent={caller_tool_context.caller_agent.name} " + f"tool_name={caller_tool_context.tool_name} " + f"tool_id={caller_tool_context.tool_call_id}] {event.type}" + ) + else: + print(f"[stream] {event.type}") + + +async def main() -> None: + with trace("Agents as tools streaming example"): + billing_agent = Agent( + name="Billing Agent", + instructions="You are a billing agent that answers billing questions.", + model_settings=ModelSettings(tool_choice="required"), + tools=[billing_status_checker], + ) + + # Scenario 1: Run the billing agent directly with streaming + result1 = Runner.run_streamed( + billing_agent, + "Hello, my customer ID is ABC123. How much is my bill for this month?", + ) + + async for event in result1.stream_events(): + handle_billing_agent_stream(event) + + # Scenario 2: Use billing agent as a tool with streaming via on_stream callback + billing_agent_tool = billing_agent.as_tool( + tool_name="billing_agent", + tool_description="You are a billing agent that answers billing questions.", + on_stream=handle_billing_agent_stream, + ) + + main_agent = Agent( + name="Customer Support Agent", + instructions=( + "You are a customer support agent. Always call the billing agent to answer billing " + "questions and return the billing agent response to the user." + ), + tools=[billing_agent_tool], + ) + + result2 = await Runner.run( + main_agent, + "Hello, my customer ID is ABC123. 
How much is my bill for this month?", + ) + + print(f"\nResponse:\n{result2.final_output}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/financial_research_agent/manager.py b/examples/financial_research_agent/manager.py index 58ec11bf2..6dfc631aa 100644 --- a/examples/financial_research_agent/manager.py +++ b/examples/financial_research_agent/manager.py @@ -6,7 +6,7 @@ from rich.console import Console -from agents import Runner, RunResult, custom_span, gen_trace_id, trace +from agents import Runner, RunResult, RunResultStreaming, custom_span, gen_trace_id, trace from .agents.financials_agent import financials_agent from .agents.planner_agent import FinancialSearchItem, FinancialSearchPlan, planner_agent @@ -17,7 +17,7 @@ from .printer import Printer -async def _summary_extractor(run_result: RunResult) -> str: +async def _summary_extractor(run_result: RunResult | RunResultStreaming) -> str: """Custom output extractor for sub‑agents that return an AnalysisSummary.""" # The financial/risk analyst agents emit an AnalysisSummary with a `summary` field. # We want the tool call to return just that summary text so the writer can drop it inline. 
diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 48e8eebdf..4601a781b 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -926,6 +926,7 @@ async def run_single_tool( context_wrapper, tool_call.call_id, tool_call=tool_call, + caller_agent=agent, ) if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments diff --git a/src/agents/agent.py b/src/agents/agent.py index c479cc697..640a43793 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -32,8 +32,10 @@ from .lifecycle import AgentHooks, RunHooks from .mcp import MCPServer from .memory.session import Session - from .result import RunResult + from .result import RunResult, RunResultStreaming from .run import RunConfig + from .stream_events import StreamEvent + from .tool_context import ToolContext @dataclass @@ -382,9 +384,12 @@ def as_tool( self, tool_name: str | None, tool_description: str | None, - custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None, + custom_output_extractor: ( + Callable[[RunResult | RunResultStreaming], Awaitable[str]] | None + ) = None, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True, + on_stream: Callable[[StreamEvent, ToolContext], MaybeAwaitable[None]] | None = None, run_config: RunConfig | None = None, max_turns: int | None = None, hooks: RunHooks[TContext] | None = None, @@ -409,6 +414,10 @@ def as_tool( is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run context and agent and returns whether the tool is enabled. Disabled tools are hidden from the LLM at runtime. + on_stream: Optional callback (sync or async) to receive streaming events from the nested + agent run. When provided, the nested agent is executed in streaming mode. + The callback receives (event, caller_tool_context) where caller_tool_context + provides access to the calling agent via caller_agent field. 
""" @function_tool( @@ -420,22 +429,46 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any: from .run import DEFAULT_MAX_TURNS, Runner resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS - - output = await Runner.run( - starting_agent=self, - input=input, - context=context.context, - run_config=run_config, - max_turns=resolved_max_turns, - hooks=hooks, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - session=session, - ) + run_result: RunResult | RunResultStreaming + + if on_stream is not None: + run_result = Runner.run_streamed( + starting_agent=self, + input=input, + context=context.context, + run_config=run_config, + max_turns=resolved_max_turns, + hooks=hooks, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) + async for event in run_result.stream_events(): + try: + maybe_result = on_stream(event, cast("ToolContext", context)) + if inspect.isawaitable(maybe_result): + await maybe_result + except Exception: + logger.exception( + "Error while handling on_stream event for agent tool %s.", + self.name, + ) + else: + run_result = await Runner.run( + starting_agent=self, + input=input, + context=context.context, + run_config=run_config, + max_turns=resolved_max_turns, + hooks=hooks, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) if custom_output_extractor: - return await custom_output_extractor(output) + return await custom_output_extractor(run_result) - return output.final_output + return run_result.final_output return run_agent diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index a3cd1d3ea..c555f409e 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -422,6 +422,7 @@ async def _handle_tool_call( tool_name=event.name, tool_call_id=event.call_id, tool_arguments=event.arguments, + caller_agent=agent, ) result = await 
func_tool.on_invoke_tool(tool_context, event.arguments) @@ -448,6 +449,7 @@ async def _handle_tool_call( tool_name=event.name, tool_call_id=event.call_id, tool_arguments=event.arguments, + caller_agent=agent, ) # Execute the handoff to get the new agent diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index 5b81239f6..ce0b3014d 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -1,10 +1,13 @@ from dataclasses import dataclass, field, fields -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from openai.types.responses import ResponseFunctionToolCall from .run_context import RunContextWrapper, TContext +if TYPE_CHECKING: + from .agent import AgentBase + def _assert_must_pass_tool_call_id() -> str: raise ValueError("tool_call_id must be passed to ToolContext") @@ -18,6 +21,10 @@ def _assert_must_pass_tool_arguments() -> str: raise ValueError("tool_arguments must be passed to ToolContext") +def _assert_must_pass_caller_agent() -> "AgentBase": + raise ValueError("caller_agent must be passed to ToolContext") + + @dataclass class ToolContext(RunContextWrapper[TContext]): """The context of a tool call.""" @@ -31,11 +38,15 @@ class ToolContext(RunContextWrapper[TContext]): tool_arguments: str = field(default_factory=_assert_must_pass_tool_arguments) """The raw arguments string of the tool call.""" + caller_agent: "AgentBase" = field(default_factory=_assert_must_pass_caller_agent) + """The agent that called this tool.""" + @classmethod def from_agent_context( cls, context: RunContextWrapper[TContext], tool_call_id: str, + caller_agent: "AgentBase", tool_call: Optional[ResponseFunctionToolCall] = None, ) -> "ToolContext": """ @@ -51,5 +62,9 @@ def from_agent_context( ) return cls( - tool_name=tool_name, tool_call_id=tool_call_id, tool_arguments=tool_args, **base_values + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_arguments=tool_args, + caller_agent=caller_agent, + **base_values, ) diff 
--git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index 51d8edf20..7cb4c67ea 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast import pytest from openai.types.responses import ResponseOutputMessage, ResponseOutputText @@ -18,8 +18,11 @@ Session, TResponseInputItem, ) +from agents.stream_events import RawResponsesStreamEvent, StreamEvent from agents.tool_context import ToolContext +_caller_agent = Agent(name="test_caller_agent") + class BoolCtx(BaseModel): enable_tools: bool @@ -267,6 +270,7 @@ async def fake_run( tool_name="story_tool", tool_call_id="call_1", tool_arguments='{"input": "hello"}', + caller_agent=_caller_agent, ) output = await tool.on_invoke_tool(tool_context, '{"input": "hello"}') @@ -369,7 +373,379 @@ async def extractor(result) -> str: tool_name="summary_tool", tool_call_id="call_2", tool_arguments='{"input": "summarize this"}', + caller_agent=_caller_agent, ) output = await tool.on_invoke_tool(tool_context, '{"input": "summarize this"}') assert output == "custom output" + + +@pytest.mark.asyncio +async def test_agent_as_tool_streams_events_with_on_stream( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="streamer") + stream_events = [ + RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})), + RawResponsesStreamEvent(data=cast(Any, {"type": "output_text_delta", "delta": "hi"})), + ] + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "streamed output" + + async def stream_events(self): + for ev in stream_events: + yield ev + + run_calls: list[dict[str, Any]] = [] + + def fake_run_streamed( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + run_calls.append( + { + "starting_agent": starting_agent, + "input": 
input, + "context": context, + "max_turns": max_turns, + "hooks": hooks, + "run_config": run_config, + "previous_response_id": previous_response_id, + "conversation_id": conversation_id, + "session": session, + } + ) + return DummyStreamingResult() + + async def unexpected_run(*args: Any, **kwargs: Any) -> None: + raise AssertionError("Runner.run should not be called when on_stream is provided.") + + monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) + monkeypatch.setattr(Runner, "run", classmethod(unexpected_run)) + + received_events: list[tuple[StreamEvent, ToolContext | None]] = [] + + async def on_stream(event: StreamEvent, caller_tool_context: ToolContext | None) -> None: + received_events.append((event, caller_tool_context)) + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + on_stream=on_stream, + ), + ) + + caller_agent = Agent(name="caller_agent") + tool_context = ToolContext( + context=None, + tool_name="stream_tool", + tool_call_id="call-123", + tool_arguments='{"input": "run streaming"}', + caller_agent=caller_agent, + ) + output = await tool.on_invoke_tool(tool_context, '{"input": "run streaming"}') + + assert output == "streamed output" + assert len(received_events) == len(stream_events) + assert received_events[0][1] is not None + assert received_events[0][1].tool_call_id == "call-123" + assert received_events[0][1].tool_name == "stream_tool" + assert received_events[0][1].caller_agent is not None + assert received_events[0][1].caller_agent.name == "caller_agent" + assert received_events[0][0] == stream_events[0] + assert run_calls[0]["input"] == "run streaming" + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_works_with_custom_extractor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="streamer") + stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))] + stream_events = 
[RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))] + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "raw output" + + async def stream_events(self): + for ev in stream_events: + yield ev + + streamed_instance = DummyStreamingResult() + + def fake_run_streamed( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + return streamed_instance + + async def unexpected_run(*args: Any, **kwargs: Any) -> None: + raise AssertionError("Runner.run should not be called when on_stream is provided.") + + monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) + monkeypatch.setattr(Runner, "run", classmethod(unexpected_run)) + + received: list[Any] = [] + + async def extractor(result) -> str: + received.append(result) + return "custom value" + + callbacks: list[Any] = [] + + async def on_stream(event: StreamEvent, caller_tool_context: ToolContext | None) -> None: + callbacks.append(event) + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + custom_output_extractor=extractor, + on_stream=on_stream, + ), + ) + + tool_context = ToolContext( + context=None, + tool_name="stream_tool", + tool_call_id="call-abc", + tool_arguments='{"input": "stream please"}', + caller_agent=_caller_agent, + ) + output = await tool.on_invoke_tool(tool_context, '{"input": "stream please"}') + + assert output == "custom value" + assert received == [streamed_instance] + assert callbacks == stream_events + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_accepts_sync_handler( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="sync_handler_agent") + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + + async def stream_events(self): + yield RawResponsesStreamEvent(data=cast(Any, 
{"type": "response_started"})) + + monkeypatch.setattr( + Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) + ) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + calls: list[str] = [] + + def sync_handler(event: StreamEvent, caller_tool_context: ToolContext | None) -> None: + calls.append(event.type) + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="sync_tool", + tool_description="Uses sync handler", + on_stream=sync_handler, + ), + ) + tool_context = ToolContext( + context=None, + tool_name="sync_tool", + tool_call_id="call-sync", + tool_arguments='{"input": "go"}', + caller_agent=_caller_agent, + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "go"}') + + assert output == "ok" + assert calls == ["raw_response_event"] + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_handler_exception_does_not_fail_call( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="handler_error_agent") + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + + async def stream_events(self): + yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) + + monkeypatch.setattr( + Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) + ) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + def bad_handler(event: StreamEvent, caller_tool_context: ToolContext | None) -> None: + raise RuntimeError("boom") + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="error_tool", + tool_description="Handler throws", + on_stream=bad_handler, + ), + ) + tool_context = ToolContext( + context=None, + tool_name="error_tool", + tool_call_id="call-bad", + tool_arguments='{"input": "go"}', + caller_agent=_caller_agent, + ) + + output = await 
tool.on_invoke_tool(tool_context, '{"input": "go"}') + + assert output == "ok" + + +@pytest.mark.asyncio +async def test_agent_as_tool_without_stream_uses_run( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="nostream_agent") + + class DummyResult: + def __init__(self) -> None: + self.final_output = "plain" + + run_calls: list[dict[str, Any]] = [] + + async def fake_run( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + run_calls.append({"input": input}) + return DummyResult() + + monkeypatch.setattr(Runner, "run", classmethod(fake_run)) + monkeypatch.setattr( + Runner, + "run_streamed", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no stream"))), + ) + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="nostream_tool", + tool_description="No streaming path", + ), + ) + tool_context = ToolContext( + context=None, + tool_name="nostream_tool", + tool_call_id="call-no", + tool_arguments='{"input": "plain"}', + caller_agent=_caller_agent, + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "plain"}') + + assert output == "plain" + assert run_calls == [{"input": "plain"}] + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_sets_tool_call_id_from_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="direct_invocation_agent") + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + + async def stream_events(self): + yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) + + monkeypatch.setattr( + Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) + ) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + captured: list[tuple[StreamEvent, ToolContext | None]] = [] 
+ + async def on_stream(event: StreamEvent, caller_tool_context: ToolContext | None) -> None: + captured.append((event, caller_tool_context)) + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="direct_stream_tool", + tool_description="Direct invocation", + on_stream=on_stream, + ), + ) + caller_agent = Agent(name="caller_agent") + tool_context = ToolContext( + context=None, + tool_name="direct_stream_tool", + tool_call_id="direct-call-id", + tool_arguments='{"input": "hi"}', + caller_agent=caller_agent, + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "hi"}') + + assert output == "ok" + assert captured[0][1] is not None + assert captured[0][1].tool_call_id == "direct-call-id" + assert captured[0][1].tool_name == "direct_stream_tool" + assert captured[0][1].caller_agent is not None + assert captured[0][1].caller_agent.name == "caller_agent" diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index 18107773d..d0f9420a4 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -16,6 +16,8 @@ from agents.tool import default_tool_error_function from agents.tool_context import ToolContext +_test_agent = Agent(name="test_agent") + def argless_function() -> str: return "ok" @@ -27,7 +29,14 @@ async def test_argless_function(): assert tool.name == "argless_function" result = await tool.on_invoke_tool( - ToolContext(context=None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" + ToolContext( + context=None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments="", + caller_agent=_test_agent, + ), + "", ) assert result == "ok" @@ -42,13 +51,22 @@ async def test_argless_with_context(): assert tool.name == "argless_with_context" result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" + ToolContext( + None, tool_name=tool.name, tool_call_id="1", tool_arguments="", caller_agent=_test_agent + ), + "", ) assert result == "ok" # Extra 
JSON should not raise an error result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'), + ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments='{"a": 1}', + caller_agent=_test_agent, + ), '{"a": 1}', ) assert result == "ok" @@ -64,13 +82,25 @@ async def test_simple_function(): assert tool.name == "simple_function" result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'), + ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments='{"a": 1}', + caller_agent=_test_agent, + ), '{"a": 1}', ) assert result == 6 result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1, "b": 2}'), + ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments='{"a": 1, "b": 2}', + caller_agent=_test_agent, + ), '{"a": 1, "b": 2}', ) assert result == 3 @@ -78,7 +108,14 @@ async def test_simple_function(): # Missing required argument should raise an error with pytest.raises(ModelBehaviorError): await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" + ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments="", + caller_agent=_test_agent, + ), + "", ) @@ -108,7 +145,13 @@ async def test_complex_args_function(): } ) result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json), + ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments=valid_json, + caller_agent=_test_agent, + ), valid_json, ) assert result == "6 hello10 hello" @@ -120,7 +163,13 @@ async def test_complex_args_function(): } ) result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json), + ToolContext( + None, + tool_name=tool.name, + 
tool_call_id="1", + tool_arguments=valid_json, + caller_agent=_test_agent, + ), valid_json, ) assert result == "3 hello10 hello" @@ -133,7 +182,13 @@ async def test_complex_args_function(): } ) result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json), + ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments=valid_json, + caller_agent=_test_agent, + ), valid_json, ) assert result == "3 hello10 world" @@ -142,7 +197,11 @@ async def test_complex_args_function(): with pytest.raises(ModelBehaviorError): await tool.on_invoke_tool( ToolContext( - None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"foo": {"a": 1}}' + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments='{"foo": {"a": 1}}', + caller_agent=_test_agent, ), '{"foo": {"a": 1}}', ) @@ -207,7 +266,11 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: result = await tool.on_invoke_tool( ToolContext( - None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"data": "hello"}' + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments='{"data": "hello"}', + caller_agent=_test_agent, ), '{"data": "hello"}', ) @@ -230,6 +293,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: tool_name=tool_not_strict.name, tool_call_id="1", tool_arguments='{"data": "hello", "bar": "baz"}', + caller_agent=_test_agent, ), '{"data": "hello", "bar": "baz"}', ) @@ -242,7 +306,13 @@ def my_func(a: int, b: int = 5): raise ValueError("test") tool = function_tool(my_func) - ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="") + ctx = ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments="", + caller_agent=_test_agent, + ) result = await tool.on_invoke_tool(ctx, "") assert "Invalid JSON" in str(result) @@ -266,7 +336,13 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> return 
f"error_{error.__class__.__name__}" tool = function_tool(my_func, failure_error_function=custom_sync_error_function) - ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="") + ctx = ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments="", + caller_agent=_test_agent, + ) result = await tool.on_invoke_tool(ctx, "") assert result == "error_ModelBehaviorError" @@ -290,7 +366,13 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> return f"error_{error.__class__.__name__}" tool = function_tool(my_func, failure_error_function=custom_sync_error_function) - ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="") + ctx = ToolContext( + None, + tool_name=tool.name, + tool_call_id="1", + tool_arguments="", + caller_agent=_test_agent, + ) result = await tool.on_invoke_tool(ctx, "") assert result == "error_ModelBehaviorError" @@ -356,6 +438,12 @@ def boom() -> None: """Always raises to trigger the failure handler.""" raise RuntimeError("kapow") - ctx = ToolContext(None, tool_name=boom.name, tool_call_id="boom", tool_arguments="{}") + ctx = ToolContext( + None, + tool_name=boom.name, + tool_call_id="boom", + tool_arguments="{}", + caller_agent=_test_agent, + ) result = await boom.on_invoke_tool(ctx, "{}") assert result.startswith("handled:") diff --git a/tests/test_function_tool_decorator.py b/tests/test_function_tool_decorator.py index 2f5a38223..31f7177b3 100644 --- a/tests/test_function_tool_decorator.py +++ b/tests/test_function_tool_decorator.py @@ -5,10 +5,12 @@ import pytest from inline_snapshot import snapshot -from agents import function_tool +from agents import Agent, function_tool from agents.run_context import RunContextWrapper from agents.tool_context import ToolContext +_test_agent = Agent(name="test_agent") + class DummyContext: def __init__(self): @@ -17,7 +19,11 @@ def __init__(self): def ctx_wrapper() -> ToolContext[DummyContext]: return ToolContext( 
- context=DummyContext(), tool_name="dummy", tool_call_id="1", tool_arguments="" + context=DummyContext(), + tool_name="dummy", + tool_call_id="1", + tool_arguments="", + caller_agent=_test_agent, ) diff --git a/tests/test_tool_guardrails.py b/tests/test_tool_guardrails.py index 8ccaec0ad..81268adc5 100644 --- a/tests/test_tool_guardrails.py +++ b/tests/test_tool_guardrails.py @@ -19,6 +19,8 @@ from agents.tool_context import ToolContext from agents.tool_guardrails import tool_input_guardrail, tool_output_guardrail +_test_agent = Agent(name="test_agent") + def get_mock_tool_context(tool_arguments: str = '{"param": "value"}') -> ToolContext: """Helper to create a mock tool context for testing.""" @@ -27,6 +29,7 @@ def get_mock_tool_context(tool_arguments: str = '{"param": "value"}') -> ToolCon tool_name="test_tool", tool_call_id="call_123", tool_arguments=tool_arguments, + caller_agent=_test_agent, ) diff --git a/tests/test_tool_metadata.py b/tests/test_tool_metadata.py index ad6395e9b..db8f6848f 100644 --- a/tests/test_tool_metadata.py +++ b/tests/test_tool_metadata.py @@ -4,6 +4,7 @@ from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp +from agents import Agent from agents.computer import Computer from agents.run_context import RunContextWrapper from agents.tool import ( @@ -57,6 +58,7 @@ def test_shell_command_output_status_property() -> None: def test_tool_context_from_agent_context() -> None: ctx = RunContextWrapper(context={"foo": "bar"}) + test_agent = Agent(name="test_agent") tool_call = ToolContext.from_agent_context( ctx, tool_call_id="123", @@ -68,5 +70,6 @@ def test_tool_context_from_agent_context() -> None: "arguments": "{}", }, )(), + caller_agent=test_agent, ) assert tool_call.tool_name == "demo"