-
Notifications
You must be signed in to change notification settings - Fork 3k
feat: Add on_stream to agents as tools #2169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| import asyncio | ||
|
|
||
| from agents import Agent, AgentToolStreamEvent, ModelSettings, Runner, function_tool, trace | ||
|
|
||
|
|
||
| @function_tool( | ||
| name_override="billing_status_checker", | ||
| description_override="Answer questions about customer billing status.", | ||
| ) | ||
| def billing_status_checker(customer_id: str | None = None, question: str = "") -> str: | ||
| """Return a canned billing answer or a fallback when the question is unrelated.""" | ||
| normalized = question.lower() | ||
| if "bill" in normalized or "billing" in normalized: | ||
| return f"This customer (ID: {customer_id})'s bill is $100" | ||
| return "I can only answer questions about billing." | ||
|
|
||
|
|
||
| def handle_stream(event: AgentToolStreamEvent) -> None: | ||
| """Print streaming events emitted by the nested billing agent.""" | ||
| stream = event["event"] | ||
| print(f"[stream] agent={event['agent_name']} type={stream.type} {stream}") | ||
|
|
||
|
|
||
| async def main() -> None: | ||
| with trace("Agents as tools streaming example"): | ||
| billing_agent = Agent( | ||
| name="Billing Agent", | ||
| instructions="You are a billing agent that answers billing questions.", | ||
| model_settings=ModelSettings(tool_choice="required"), | ||
| tools=[billing_status_checker], | ||
| ) | ||
|
|
||
| billing_agent_tool = billing_agent.as_tool( | ||
| tool_name="billing_agent", | ||
| tool_description="You are a billing agent that answers billing questions.", | ||
| on_stream=handle_stream, | ||
| ) | ||
|
|
||
| main_agent = Agent( | ||
| name="Customer Support Agent", | ||
| instructions=( | ||
| "You are a customer support agent. Always call the billing agent to answer billing " | ||
| "questions and return the billing agent response to the user." | ||
| ), | ||
| tools=[billing_agent_tool], | ||
| ) | ||
|
|
||
| result = await Runner.run( | ||
| main_agent, | ||
| "Hello, my customer ID is ABC123. How much is my bill for this month?", | ||
| ) | ||
|
|
||
| print(f"\nFinal response:\n{result.final_output}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| asyncio.run(main()) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,8 +32,9 @@ | |
| from .lifecycle import AgentHooks, RunHooks | ||
| from .mcp import MCPServer | ||
| from .memory.session import Session | ||
| from .result import RunResult | ||
| from .result import RunResult, RunResultStreaming | ||
| from .run import RunConfig | ||
| from .stream_events import StreamEvent | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -58,6 +59,19 @@ class ToolsToFinalOutputResult: | |
| """ | ||
|
|
||
|
|
||
| class AgentToolStreamEvent(TypedDict): | ||
| """Streaming event emitted when an agent is invoked as a tool.""" | ||
|
|
||
| event: StreamEvent | ||
| """The streaming event from the nested agent run.""" | ||
|
|
||
| agent_name: str | ||
| """The name of the nested agent emitting the event.""" | ||
|
|
||
| tool_call_id: str | None | ||
| """The originating tool call ID, if available.""" | ||
|
|
||
|
|
||
| class StopAtTools(TypedDict): | ||
| stop_at_tool_names: list[str] | ||
| """A list of tool names, any of which will stop the agent from running further.""" | ||
|
|
@@ -382,9 +396,12 @@ def as_tool( | |
| self, | ||
| tool_name: str | None, | ||
| tool_description: str | None, | ||
| custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None, | ||
| custom_output_extractor: ( | ||
| Callable[[RunResult | RunResultStreaming], Awaitable[str]] | None | ||
| ) = None, | ||
| is_enabled: bool | ||
| | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True, | ||
| on_stream: Callable[[AgentToolStreamEvent], MaybeAwaitable[None]] | None = None, | ||
| run_config: RunConfig | None = None, | ||
| max_turns: int | None = None, | ||
| hooks: RunHooks[TContext] | None = None, | ||
|
|
@@ -409,6 +426,8 @@ def as_tool( | |
| is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run | ||
| context and agent and returns whether the tool is enabled. Disabled tools are hidden | ||
| from the LLM at runtime. | ||
| on_stream: Optional callback (sync or async) to receive streaming events from the nested | ||
| agent run. When provided, the nested agent is executed in streaming mode. | ||
| """ | ||
|
|
||
| @function_tool( | ||
|
|
@@ -420,22 +439,51 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any: | |
| from .run import DEFAULT_MAX_TURNS, Runner | ||
|
|
||
| resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS | ||
|
|
||
| output = await Runner.run( | ||
| starting_agent=self, | ||
| input=input, | ||
| context=context.context, | ||
| run_config=run_config, | ||
| max_turns=resolved_max_turns, | ||
| hooks=hooks, | ||
| previous_response_id=previous_response_id, | ||
| conversation_id=conversation_id, | ||
| session=session, | ||
| ) | ||
| run_result: RunResult | RunResultStreaming | ||
|
|
||
| if on_stream is not None: | ||
| run_result = Runner.run_streamed( | ||
| starting_agent=self, | ||
| input=input, | ||
| context=context.context, | ||
| run_config=run_config, | ||
| max_turns=resolved_max_turns, | ||
| hooks=hooks, | ||
| previous_response_id=previous_response_id, | ||
| conversation_id=conversation_id, | ||
| session=session, | ||
| ) | ||
| async for event in run_result.stream_events(): | ||
| payload: AgentToolStreamEvent = { | ||
| "event": event, | ||
| "agent_name": self.name, | ||
| "tool_call_id": getattr(context, "tool_call_id", None), | ||
| } | ||
| try: | ||
| maybe_result = on_stream(payload) | ||
| if inspect.isawaitable(maybe_result): | ||
| await maybe_result | ||
|
Comment on lines
+463
to
+465
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this can be kinda bad since
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good call; actually i made these executions async in the TS SDK. so will make this more efficient and consistent. |
||
| except Exception: | ||
| logger.exception( | ||
| "Error while handling on_stream event for agent tool %s.", | ||
| self.name, | ||
| ) | ||
| else: | ||
| run_result = await Runner.run( | ||
| starting_agent=self, | ||
| input=input, | ||
| context=context.context, | ||
| run_config=run_config, | ||
| max_turns=resolved_max_turns, | ||
| hooks=hooks, | ||
| previous_response_id=previous_response_id, | ||
| conversation_id=conversation_id, | ||
| session=session, | ||
| ) | ||
| if custom_output_extractor: | ||
| return await custom_output_extractor(output) | ||
| return await custom_output_extractor(run_result) | ||
|
|
||
| return output.final_output | ||
| return run_result.final_output | ||
|
|
||
| return run_agent | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
instead of
agent_name, why not pass theAgentobject?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah this is true. i will revisit the properties in this object