diff --git a/src/agents/agent.py b/src/agents/agent.py index a061926b1..c479cc697 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -13,7 +13,6 @@ from .agent_output import AgentOutputSchemaBase from .guardrail import InputGuardrail, OutputGuardrail from .handoffs import Handoff -from .items import ItemHelpers from .logger import logger from .mcp import MCPUtil from .model_settings import ModelSettings @@ -417,7 +416,7 @@ def as_tool( description_override=tool_description or "", is_enabled=is_enabled, ) - async def run_agent(context: RunContextWrapper, input: str) -> str: + async def run_agent(context: RunContextWrapper, input: str) -> Any: from .run import DEFAULT_MAX_TURNS, Runner resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS @@ -436,7 +435,7 @@ async def run_agent(context: RunContextWrapper, input: str) -> str: if custom_output_extractor: return await custom_output_extractor(output) - return ItemHelpers.text_message_outputs(output.new_items) + return output.final_output return run_agent diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index 0f5736b03..51d8edf20 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -225,30 +225,15 @@ async def custom_extractor(result): @pytest.mark.asyncio -async def test_agent_as_tool_returns_concatenated_text(monkeypatch: pytest.MonkeyPatch) -> None: - """Agent tool should use default text aggregation when no custom extractor is provided.""" +async def test_agent_as_tool_returns_final_output(monkeypatch: pytest.MonkeyPatch) -> None: + """Agent tool should return final_output when no custom extractor is provided.""" agent = Agent(name="storyteller") - message = ResponseOutputMessage( - id="msg_1", - role="assistant", - status="completed", - type="message", - content=[ - ResponseOutputText( - annotations=[], - text="Hello world", - type="output_text", - logprobs=[], - ) - ], - ) - result = type( "DummyResult", (), - {"new_items": [MessageOutputItem(agent=agent, raw_item=message)]}, + {"final_output": "Hello world"}, )() async def fake_run(