Skip to content

Add tool call parameters for on_tool_start hook #253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions .vscode/launch.json

This file was deleted.

7 changes: 0 additions & 7 deletions .vscode/settings.json

This file was deleted.

6 changes: 3 additions & 3 deletions examples/basic/agent_lifecycle_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel

from agents import Agent, AgentHooks, RunContextWrapper, Runner, Tool, function_tool
from agents import Action, Agent, AgentHooks, RunContextWrapper, Runner, Tool, function_tool


class CustomAgentHooks(AgentHooks):
Expand All @@ -28,10 +28,10 @@ async def on_handoff(self, context: RunContextWrapper, agent: Agent, source: Age
f"### ({self.display_name}) {self.event_counter}: Agent {source.name} handed off to {agent.name}"
)

async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None:
async def on_tool_start(self, context: RunContextWrapper, agent: Agent, action: Action) -> None:
self.event_counter += 1
print(
f"### ({self.display_name}) {self.event_counter}: Agent {agent.name} started tool {tool.name}"
f"### ({self.display_name}) {self.event_counter}: Agent {agent.name} started tool {action.function_tool.name} with arguments {action.tool_call.arguments}"
)

async def on_tool_end(
Expand Down
6 changes: 3 additions & 3 deletions examples/basic/lifecycle_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel

from agents import Agent, RunContextWrapper, RunHooks, Runner, Tool, Usage, function_tool
from agents import Action, Agent, RunContextWrapper, RunHooks, Runner, Tool, Usage, function_tool


class ExampleHooks(RunHooks):
Expand All @@ -26,10 +26,10 @@ async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: A
f"### {self.event_counter}: Agent {agent.name} ended with output {output}. Usage: {self._usage_to_str(context.usage)}"
)

async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None:
async def on_tool_start(self, context: RunContextWrapper, agent: Agent, action: Action) -> None:
self.event_counter += 1
print(
f"### {self.event_counter}: Tool {tool.name} started. Usage: {self._usage_to_str(context.usage)}"
f"### {self.event_counter}: Tool {action.function_tool.tool.name} started. Usage: {self._usage_to_str(context.usage)}"
)

async def on_tool_end(
Expand Down
2 changes: 2 additions & 0 deletions src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
StreamEvent,
)
from .tool import (
Action,
CodeInterpreterTool,
ComputerTool,
FileSearchTool,
Expand Down Expand Up @@ -239,6 +240,7 @@ def enable_verbose_stdout_logging():
"MCPToolApprovalRequest",
"MCPToolApprovalFunctionResult",
"function_tool",
"Action",
"Usage",
"add_trace_processor",
"agent_span",
Expand Down
32 changes: 10 additions & 22 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@
LocalShellTool,
MCPToolApprovalRequest,
Tool,
ToolRunComputerAction,
ToolRunFunction,
)
from .tool_context import ToolContext
from .tracing import (
Expand Down Expand Up @@ -126,19 +128,6 @@ class ToolRunHandoff:
handoff: Handoff
tool_call: ResponseFunctionToolCall


@dataclass
class ToolRunFunction:
tool_call: ResponseFunctionToolCall
function_tool: FunctionTool


@dataclass
class ToolRunComputerAction:
tool_call: ResponseComputerToolCall
computer_tool: ComputerTool


@dataclass
class ToolRunMCPApprovalRequest:
request_item: McpApprovalRequest
Expand Down Expand Up @@ -544,9 +533,9 @@ async def execute_function_tool_calls(
context_wrapper: RunContextWrapper[TContext],
config: RunConfig,
) -> list[FunctionToolResult]:
async def run_single_tool(
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
) -> Any:
async def run_single_tool(action: ToolRunFunction) -> Any:
func_tool = action.function_tool
tool_call = action.tool_call
with function_span(func_tool.name) as span_fn:
tool_context = ToolContext.from_agent_context(
context_wrapper,
Expand All @@ -557,9 +546,9 @@ async def run_single_tool(
span_fn.span_data.input = tool_call.arguments
try:
_, _, result = await asyncio.gather(
hooks.on_tool_start(tool_context, agent, func_tool),
hooks.on_tool_start(tool_context, agent, action),
(
agent.hooks.on_tool_start(tool_context, agent, func_tool)
agent.hooks.on_tool_start(tool_context, agent, action)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -591,8 +580,7 @@ async def run_single_tool(

tasks = []
for tool_run in tool_runs:
function_tool = tool_run.function_tool
tasks.append(run_single_tool(function_tool, tool_run.tool_call))
tasks.append(run_single_tool(tool_run))

results = await asyncio.gather(*tasks)

Expand Down Expand Up @@ -1039,9 +1027,9 @@ async def execute(
)

_, _, output = await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
hooks.on_tool_start(context_wrapper, agent, action),
(
agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
agent.hooks.on_tool_start(context_wrapper, agent, action)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down
10 changes: 5 additions & 5 deletions src/agents/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .agent import Agent, AgentBase
from .run_context import RunContextWrapper, TContext
from .tool import Tool
from .tool import Action, Tool

TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase)

Expand Down Expand Up @@ -39,8 +39,8 @@ async def on_handoff(
async def on_tool_start(
self,
context: RunContextWrapper[TContext],
agent: TAgent,
tool: Tool,
agent: Agent[TContext],
action: Action,
) -> None:
"""Called before a tool is invoked."""
pass
Expand Down Expand Up @@ -90,8 +90,8 @@ async def on_handoff(
async def on_tool_start(
self,
context: RunContextWrapper[TContext],
agent: TAgent,
tool: Tool,
agent: Agent[TContext],
action: Action,
) -> None:
"""Called before a tool is invoked."""
pass
Expand Down
14 changes: 14 additions & 0 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload

from openai.types.responses import ResponseFunctionToolCall
from openai.types.responses.file_search_tool_param import Filters, RankingOptions
from openai.types.responses.response_computer_tool_call import (
PendingSafetyCheck,
Expand Down Expand Up @@ -281,6 +282,19 @@ def name(self):
]
"""A tool that can be used in an agent."""

@dataclass
class ToolRunFunction:
tool_call: ResponseFunctionToolCall
function_tool: FunctionTool


@dataclass
class ToolRunComputerAction:
tool_call: ResponseComputerToolCall
computer_tool: ComputerTool

Action = Union[ToolRunFunction, ToolRunComputerAction]
"""An action that can be performed by an agent. It contains the tool call and the tool"""

def default_tool_error_function(ctx: RunContextWrapper[Any], error: Exception) -> str:
"""The default tool error function, which just returns a generic error message."""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from typing_extensions import TypedDict

from agents.agent import Agent
from agents.agent import Action, Agent
from agents.lifecycle import AgentHooks
from agents.run import Runner
from agents.run_context import RunContextWrapper, TContext
Expand Down Expand Up @@ -53,7 +53,7 @@ async def on_tool_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
action: Action,
) -> None:
self.events["on_tool_start"] += 1

Expand Down
13 changes: 7 additions & 6 deletions tests/test_computer_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)

from agents import (
Action,
Agent,
AgentHooks,
AsyncComputer,
Expand All @@ -32,9 +33,9 @@
RunContextWrapper,
RunHooks,
)
from agents._run_impl import ComputerAction, RunImpl, ToolRunComputerAction
from agents._run_impl import ComputerAction, RunImpl
from agents.items import ToolCallOutputItem
from agents.tool import ComputerToolSafetyCheckData
from agents.tool import ComputerToolSafetyCheckData, ToolRunComputerAction


class LoggingComputer(Computer):
Expand Down Expand Up @@ -224,9 +225,9 @@ def __init__(self) -> None:
self.ended: list[tuple[Agent[Any], Any, str]] = []

async def on_tool_start(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any
self, context: RunContextWrapper[Any], agent: Agent[Any], action: Action,
) -> None:
self.started.append((agent, tool))
self.started.append((agent, action.computer_tool))

async def on_tool_end(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str
Expand All @@ -243,9 +244,9 @@ def __init__(self) -> None:
self.ended: list[tuple[Agent[Any], Any, str]] = []

async def on_tool_start(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any
self, context: RunContextWrapper[Any], agent: Agent[Any], action: Action,
) -> None:
self.started.append((agent, tool))
self.started.append((agent, action.computer_tool))

async def on_tool_end(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str
Expand Down