Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions src/guardrails/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,33 @@ async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data
return guardrail_functions


def _resolve_agent_instructions(instructions: str | None, prompt: Any | None) -> str | None:
"""Derive instructions from explicit input or prompt.

Args:
instructions: Explicit instructions provided by the caller.
prompt: Optional prompt object or string supplied to the agent.

Returns:
A string containing the agent instructions when available, otherwise ``None``.
"""
if instructions is not None:
return instructions

if prompt is None:
return None

if isinstance(prompt, str):
return prompt

for attr_name in ("instructions", "text", "content"):
candidate = getattr(prompt, attr_name, None)
if isinstance(candidate, str):
return candidate

return None


class GuardrailAgent:
"""Drop-in replacement for Agents SDK Agent with automatic guardrails integration.

Expand Down Expand Up @@ -492,7 +519,7 @@ def __new__(
cls,
config: str | Path | dict[str, Any],
name: str,
instructions: str,
instructions: str | None = None,
raise_guardrail_errors: bool = False,
block_on_tool_violations: bool = False,
**agent_kwargs: Any,
Expand All @@ -511,7 +538,7 @@ def __new__(
Args:
config: Pipeline configuration (file path, dict, or JSON string)
name: Agent name
instructions: Agent instructions
instructions: Agent instructions. When omitted, a ``prompt`` argument must be provided.
raise_guardrail_errors: If True, raise exceptions when guardrails fail to execute.
If False (default), treat guardrail errors as safe and continue execution.
block_on_tool_violations: If True, tool guardrail violations raise exceptions (halt execution).
Expand Down Expand Up @@ -614,5 +641,19 @@ def __new__(
)
_attach_guardrail_to_tools(tools, tool_output_gr, "output")

prompt_arg: Any | None = agent_kwargs.get("prompt")
resolved_instructions = _resolve_agent_instructions(instructions, prompt_arg)

if resolved_instructions is None and prompt_arg is None:
raise ValueError(
"GuardrailAgent requires either 'instructions' or 'prompt' to initialize the underlying Agent."
)

# Create and return a regular Agent instance with guardrails configured
return Agent(name=name, instructions=instructions, input_guardrails=input_guardrails, output_guardrails=output_guardrails, **agent_kwargs)
return Agent(
name=name,
instructions=resolved_instructions,
input_guardrails=input_guardrails,
output_guardrails=output_guardrails,
**agent_kwargs,
)
51 changes: 48 additions & 3 deletions tests/unit/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
import types
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import Any

import pytest

from guardrails._openai_utils import SAFETY_IDENTIFIER_HEADER, SAFETY_IDENTIFIER_VALUE
from guardrails.types import GuardrailResult
guardrails_pkg = types.ModuleType("guardrails")
guardrails_pkg.__path__ = [str(Path(__file__).resolve().parents[2] / "src" / "guardrails")]
sys.modules.setdefault("guardrails", guardrails_pkg)

from guardrails._openai_utils import SAFETY_IDENTIFIER_HEADER, SAFETY_IDENTIFIER_VALUE # noqa: E402
from guardrails.types import GuardrailResult # noqa: E402

# ---------------------------------------------------------------------------
# Stub agents SDK module so guardrails.agents can import required symbols.
Expand Down Expand Up @@ -94,7 +99,8 @@ class Agent:
"""Trivial Agent stub storing initialization args for assertions."""

name: str
instructions: str
instructions: str | None = None
prompt: Any | None = None
input_guardrails: list[Callable] | None = None
output_guardrails: list[Callable] | None = None
tools: list[Any] | None = None
Expand Down Expand Up @@ -597,3 +603,42 @@ def test_guardrail_agent_without_tools(monkeypatch: pytest.MonkeyPatch) -> None:
agent_instance = agents.GuardrailAgent(config={}, name="NoTools", instructions="None")

assert getattr(agent_instance, "input_guardrails", []) == [] # noqa: S101


def test_guardrail_agent_allows_prompt_without_instructions(monkeypatch: pytest.MonkeyPatch) -> None:
"""Prompt attribute text becomes derived instructions."""
pipeline = SimpleNamespace(pre_flight=None, input=None, output=None)

monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline, raising=False)
monkeypatch.setattr(runtime_module, "instantiate_guardrails", lambda *args, **kwargs: [], raising=False)

prompt = SimpleNamespace(text="Serve customers helpfully.")

agent_instance = agents.GuardrailAgent(config={}, name="PromptOnly", prompt=prompt)

assert agent_instance.prompt is prompt # noqa: S101
assert agent_instance.instructions == "Serve customers helpfully." # noqa: S101


def test_guardrail_agent_accepts_string_prompt(monkeypatch: pytest.MonkeyPatch) -> None:
"""String prompts populate missing instructions automatically."""
pipeline = SimpleNamespace(pre_flight=None, input=None, output=None)

monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline, raising=False)
monkeypatch.setattr(runtime_module, "instantiate_guardrails", lambda *args, **kwargs: [], raising=False)

agent_instance = agents.GuardrailAgent(config={}, name="PromptStr", prompt="Be concise.")

assert agent_instance.prompt == "Be concise." # noqa: S101
assert agent_instance.instructions == "Be concise." # noqa: S101


def test_guardrail_agent_requires_instructions_or_prompt(monkeypatch: pytest.MonkeyPatch) -> None:
"""GuardrailAgent requires instructions or prompt for construction."""
pipeline = SimpleNamespace(pre_flight=None, input=None, output=None)

monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline, raising=False)
monkeypatch.setattr(runtime_module, "instantiate_guardrails", lambda *args, **kwargs: [], raising=False)

with pytest.raises(ValueError):
agents.GuardrailAgent(config={}, name="Missing")
Loading