Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
from .models.interface import Model
from .prompts import DynamicPromptFunction, Prompt, PromptUtil
from .run_context import RunContextWrapper, TContext
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
from .tool import (
FunctionTool,
FunctionToolResult,
Tool,
ToolErrorFunction,
default_tool_error_function,
function_tool,
)
from .tool_context import ToolContext
from .util import _transforms
from .util._types import MaybeAwaitable
Expand Down Expand Up @@ -411,6 +418,7 @@ def as_tool(
previous_response_id: str | None = None,
conversation_id: str | None = None,
session: Session | None = None,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
) -> Tool:
"""Transform this agent into a tool, callable by other agents.

Expand All @@ -433,12 +441,15 @@ def as_tool(
agent run. The callback receives an `AgentToolStreamEvent` containing the nested
agent, the originating tool call (when available), and each stream event. When
provided, the nested agent is executed in streaming mode.
failure_error_function: If provided, generate an error message when the tool (agent) run
fails. The message is sent to the LLM. If None, the exception is raised instead.
"""

@function_tool(
name_override=tool_name or _transforms.transform_string_function_style(self.name),
description_override=tool_description or "",
is_enabled=is_enabled,
failure_error_function=failure_error_function,
)
async def run_agent(context: ToolContext, input: str) -> Any:
from .run import DEFAULT_MAX_TURNS, Runner
Expand Down
95 changes: 95 additions & 0 deletions tests/test_agent_as_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,3 +947,98 @@ async def on_stream(event: AgentToolStreamEvent) -> None:

assert output == "ok"
assert captured[0]["tool_call"] is tool_call


@pytest.mark.asyncio
async def test_agent_as_tool_failure_error_function_none_reraises(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""If failure_error_function=None, exceptions should propagate to the caller."""
agent = Agent(name="failing_agent")

async def fake_run(
cls,
starting_agent,
input,
*,
context,
max_turns,
hooks,
run_config,
previous_response_id,
conversation_id,
session,
):
assert starting_agent is agent
assert input == "hello"
raise RuntimeError("test failure")

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

tool = agent.as_tool(
tool_name="failing_agent_tool",
tool_description="Agent tool that raises",
is_enabled=True,
failure_error_function=None,
)

assert isinstance(tool, FunctionTool)

tool_context = ToolContext(
context=None,
tool_name="failing_agent_tool",
tool_call_id="call_1",
tool_arguments='{"input": "hello"}',
)

with pytest.raises(RuntimeError, match="test failure"):
await tool.on_invoke_tool(tool_context, '{"input": "hello"}')


@pytest.mark.asyncio
async def test_agent_as_tool_failure_error_function_custom_handler(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Custom failure_error_function should be used to convert exceptions into tool output."""
agent = Agent(name="failing_agent")

async def fake_run(
cls,
starting_agent,
input,
*,
context,
max_turns,
hooks,
run_config,
previous_response_id,
conversation_id,
session,
):
assert starting_agent is agent
assert input == "hello"
raise ValueError("test failure")

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

def custom_failure_handler(ctx: RunContextWrapper[Any], error: Exception) -> str:
return f"handled:{type(error).__name__}:{error}"

tool = agent.as_tool(
tool_name="failing_agent_tool",
tool_description="Agent tool that raises",
is_enabled=True,
failure_error_function=custom_failure_handler,
)

assert isinstance(tool, FunctionTool)

tool_context = ToolContext(
context=None,
tool_name="failing_agent_tool",
tool_call_id="call_1",
tool_arguments='{"input": "hello"}',
)

result = await tool.on_invoke_tool(tool_context, '{"input": "hello"}')
assert result == "handled:ValueError:test failure"