Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@
output_guardrail,
)
from .handoffs import Handoff, HandoffInputData, HandoffInputFilter, handoff
from .tool_guardrails import (
ToolGuardrailFunctionOutput,
ToolInputGuardrail,
ToolInputGuardrailData,
ToolOutputGuardrail,
ToolOutputGuardrailData,
tool_input_guardrail,
tool_output_guardrail,
)
from .items import (
HandoffCallItem,
HandoffOutputItem,
Expand Down Expand Up @@ -204,6 +213,13 @@ def enable_verbose_stdout_logging():
"GuardrailFunctionOutput",
"input_guardrail",
"output_guardrail",
"ToolInputGuardrail",
"ToolOutputGuardrail",
"ToolGuardrailFunctionOutput",
"ToolInputGuardrailData",
"ToolOutputGuardrailData",
"tool_input_guardrail",
"tool_output_guardrail",
"handoff",
"Handoff",
"HandoffInputData",
Expand Down
78 changes: 61 additions & 17 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@
Tool,
)
from .tool_context import ToolContext
from .tool_guardrails import (
ToolInputGuardrailData,
ToolOutputGuardrailData,
)
from .tracing import (
SpanError,
Trace,
Expand Down Expand Up @@ -556,24 +560,64 @@ async def run_single_tool(
if config.trace_include_sensitive_data:
span_fn.span_data.input = tool_call.arguments
try:
_, _, result = await asyncio.gather(
hooks.on_tool_start(tool_context, agent, func_tool),
(
agent.hooks.on_tool_start(tool_context, agent, func_tool)
if agent.hooks
else _coro.noop_coroutine()
),
func_tool.on_invoke_tool(tool_context, tool_call.arguments),
)
# 1) Run input tool guardrails, if any
final_result: Any | None = None
if func_tool.tool_input_guardrails:
for guardrail in func_tool.tool_input_guardrails:
gr_out = await guardrail.run(
ToolInputGuardrailData(
context=tool_context,
agent=agent,
tool_call=tool_call,
)
)
if gr_out.tripwire_triggered:
# Use the provided model message as the tool output
final_result = str(gr_out.model_message or "")
break

if final_result is None:
# 2) Actually run the tool
await asyncio.gather(
hooks.on_tool_start(tool_context, agent, func_tool),
(
agent.hooks.on_tool_start(tool_context, agent, func_tool)
if agent.hooks
else _coro.noop_coroutine()
),
)
real_result = await func_tool.on_invoke_tool(
tool_context, tool_call.arguments
)

await asyncio.gather(
hooks.on_tool_end(tool_context, agent, func_tool, result),
(
agent.hooks.on_tool_end(tool_context, agent, func_tool, result)
if agent.hooks
else _coro.noop_coroutine()
),
)
# 3) Run output tool guardrails, if any
final_result = real_result
if func_tool.tool_output_guardrails:
for guardrail in func_tool.tool_output_guardrails:
gr_out = await guardrail.run(
ToolOutputGuardrailData(
context=tool_context,
agent=agent,
tool_call=tool_call,
output=real_result,
)
)
if gr_out.tripwire_triggered:
final_result = str(gr_out.model_message or "")
break

# 4) Tool end hooks (with final result, which may have been overridden)
await asyncio.gather(
hooks.on_tool_end(tool_context, agent, func_tool, final_result),
(
agent.hooks.on_tool_end(
tool_context, agent, func_tool, final_result
)
if agent.hooks
else _coro.noop_coroutine()
),
)
result = final_result
except Exception as e:
_error_tracing.attach_error_to_current_span(
SpanError(
Expand Down
7 changes: 7 additions & 0 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ class FunctionTool:
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
based on your context/state."""

# Tool-specific guardrails
tool_input_guardrails: list["ToolInputGuardrail[Any]"] | None = None
"""Optional list of input guardrails to run before invoking this tool."""

tool_output_guardrails: list["ToolOutputGuardrail[Any]"] | None = None
"""Optional list of output guardrails to run after invoking this tool."""

def __post_init__(self):
if self.strict_json_schema:
self.params_json_schema = ensure_strict_json_schema(self.params_json_schema)
Expand Down
163 changes: 163 additions & 0 deletions src/agents/tool_guardrails.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from __future__ import annotations

import inspect
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import Any, Callable, Generic, Optional, overload

from typing_extensions import TypeVar

from .agent import Agent
from .tool_context import ToolContext
from .util._types import MaybeAwaitable
from openai.types.responses import ResponseFunctionToolCall


@dataclass
class ToolGuardrailFunctionOutput:
"""The output of a tool guardrail function.

- `output_info`: Optional data about checks performed.
- `tripwire_triggered`: Whether the guardrail was tripped.
- `model_message`: Message to send back to the model as the tool output if tripped.
"""

output_info: Any
tripwire_triggered: bool
model_message: Optional[str] = None


@dataclass
class ToolInputGuardrailData:
"""Input data passed to a tool input guardrail function."""

context: ToolContext[Any]
agent: Agent[Any]
tool_call: ResponseFunctionToolCall


@dataclass
class ToolOutputGuardrailData(ToolInputGuardrailData):
"""Input data passed to a tool output guardrail function.

Extends input data with the tool's output.
"""

output: Any


TContext_co = TypeVar("TContext_co", bound=Any, covariant=True)


@dataclass
class ToolInputGuardrail(Generic[TContext_co]):
"""A guardrail that runs before a function tool is invoked."""

guardrail_function: Callable[[ToolInputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput]]
name: str | None = None

def get_name(self) -> str:
return self.name or self.guardrail_function.__name__

async def run(
self, data: ToolInputGuardrailData
) -> ToolGuardrailFunctionOutput:
result = self.guardrail_function(data)
if inspect.isawaitable(result):
return await result # type: ignore[return-value]
return result # type: ignore[return-value]


@dataclass
class ToolOutputGuardrail(Generic[TContext_co]):
"""A guardrail that runs after a function tool is invoked."""

guardrail_function: Callable[[ToolOutputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput]]
name: str | None = None

def get_name(self) -> str:
return self.name or self.guardrail_function.__name__

async def run(
self, data: ToolOutputGuardrailData
) -> ToolGuardrailFunctionOutput:
result = self.guardrail_function(data)
if inspect.isawaitable(result):
return await result # type: ignore[return-value]
return result # type: ignore[return-value]


# Decorators
_ToolInputFuncSync = Callable[[ToolInputGuardrailData], ToolGuardrailFunctionOutput]
_ToolInputFuncAsync = Callable[[ToolInputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]]


@overload
def tool_input_guardrail(func: _ToolInputFuncSync): # type: ignore[overload-overlap]
...


@overload
def tool_input_guardrail(func: _ToolInputFuncAsync): # type: ignore[overload-overlap]
...


@overload
def tool_input_guardrail(*, name: str | None = None) -> Callable[[
_ToolInputFuncSync | _ToolInputFuncAsync
], ToolInputGuardrail[Any]]: ...


def tool_input_guardrail(
func: _ToolInputFuncSync | _ToolInputFuncAsync | None = None,
*,
name: str | None = None,
) -> ToolInputGuardrail[Any] | Callable[[
_ToolInputFuncSync | _ToolInputFuncAsync
], ToolInputGuardrail[Any]]:
"""Decorator to create a ToolInputGuardrail from a function."""

def decorator(f: _ToolInputFuncSync | _ToolInputFuncAsync) -> ToolInputGuardrail[Any]:
return ToolInputGuardrail(guardrail_function=f, name=name or f.__name__)

if func is not None:
return decorator(func)
return decorator


_ToolOutputFuncSync = Callable[[ToolOutputGuardrailData], ToolGuardrailFunctionOutput]
_ToolOutputFuncAsync = Callable[[ToolOutputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]]


@overload
def tool_output_guardrail(func: _ToolOutputFuncSync): # type: ignore[overload-overlap]
...


@overload
def tool_output_guardrail(func: _ToolOutputFuncAsync): # type: ignore[overload-overlap]
...


@overload
def tool_output_guardrail(*, name: str | None = None) -> Callable[[
_ToolOutputFuncSync | _ToolOutputFuncAsync
], ToolOutputGuardrail[Any]]: ...


def tool_output_guardrail(
func: _ToolOutputFuncSync | _ToolOutputFuncAsync | None = None,
*,
name: str | None = None,
) -> ToolOutputGuardrail[Any] | Callable[[
_ToolOutputFuncSync | _ToolOutputFuncAsync
], ToolOutputGuardrail[Any]]:
"""Decorator to create a ToolOutputGuardrail from a function."""

def decorator(f: _ToolOutputFuncSync | _ToolOutputFuncAsync) -> ToolOutputGuardrail[Any]:
return ToolOutputGuardrail(guardrail_function=f, name=name or f.__name__)

if func is not None:
return decorator(func)
return decorator

Loading