diff --git a/src/guardrails/agents.py b/src/guardrails/agents.py index 0645081..a0c9168 100644 --- a/src/guardrails/agents.py +++ b/src/guardrails/agents.py @@ -451,6 +451,33 @@ async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data return guardrail_functions +def _resolve_agent_instructions(instructions: str | None, prompt: Any | None) -> str | None: + """Derive instructions from explicit input or prompt. + + Args: + instructions: Explicit instructions provided by the caller. + prompt: Optional prompt object or string supplied to the agent. + + Returns: + A string containing the agent instructions when available, otherwise ``None``. + """ + if instructions is not None: + return instructions + + if prompt is None: + return None + + if isinstance(prompt, str): + return prompt + + for attr_name in ("instructions", "text", "content"): + candidate = getattr(prompt, attr_name, None) + if isinstance(candidate, str): + return candidate + + return None + + class GuardrailAgent: """Drop-in replacement for Agents SDK Agent with automatic guardrails integration. @@ -492,7 +519,7 @@ def __new__( cls, config: str | Path | dict[str, Any], name: str, - instructions: str, + instructions: str | None = None, raise_guardrail_errors: bool = False, block_on_tool_violations: bool = False, **agent_kwargs: Any, @@ -511,7 +538,7 @@ def __new__( Args: config: Pipeline configuration (file path, dict, or JSON string) name: Agent name - instructions: Agent instructions + instructions: Agent instructions. When omitted, a ``prompt`` argument must be provided. raise_guardrail_errors: If True, raise exceptions when guardrails fail to execute. If False (default), treat guardrail errors as safe and continue execution. block_on_tool_violations: If True, tool guardrail violations raise exceptions (halt execution). @@ -614,5 +641,19 @@ def __new__( ) _attach_guardrail_to_tools(tools, tool_output_gr, "output") + prompt_arg: Any | None = agent_kwargs.get("prompt") + resolved_instructions = _resolve_agent_instructions(instructions, prompt_arg) + + if resolved_instructions is None and prompt_arg is None: + raise ValueError( + "GuardrailAgent requires either 'instructions' or 'prompt' to initialize the underlying Agent." + ) + # Create and return a regular Agent instance with guardrails configured - return Agent(name=name, instructions=instructions, input_guardrails=input_guardrails, output_guardrails=output_guardrails, **agent_kwargs) + return Agent( + name=name, + instructions=resolved_instructions, + input_guardrails=input_guardrails, + output_guardrails=output_guardrails, + **agent_kwargs, + ) diff --git a/tests/unit/test_agents.py b/tests/unit/test_agents.py index 0aa1c4d..5725e19 100644 --- a/tests/unit/test_agents.py +++ b/tests/unit/test_agents.py @@ -6,13 +6,18 @@ import types from collections.abc import Callable from dataclasses import dataclass +from pathlib import Path from types import SimpleNamespace from typing import Any import pytest -from guardrails._openai_utils import SAFETY_IDENTIFIER_HEADER, SAFETY_IDENTIFIER_VALUE -from guardrails.types import GuardrailResult +guardrails_pkg = types.ModuleType("guardrails") +guardrails_pkg.__path__ = [str(Path(__file__).resolve().parents[2] / "src" / "guardrails")] +sys.modules.setdefault("guardrails", guardrails_pkg) + +from guardrails._openai_utils import SAFETY_IDENTIFIER_HEADER, SAFETY_IDENTIFIER_VALUE # noqa: E402 +from guardrails.types import GuardrailResult # noqa: E402 # --------------------------------------------------------------------------- # Stub agents SDK module so guardrails.agents can import required symbols. @@ -94,7 +99,8 @@ class Agent: """Trivial Agent stub storing initialization args for assertions.""" name: str - instructions: str + instructions: str | None = None + prompt: Any | None = None input_guardrails: list[Callable] | None = None output_guardrails: list[Callable] | None = None tools: list[Any] | None = None @@ -597,3 +603,42 @@ def test_guardrail_agent_without_tools(monkeypatch: pytest.MonkeyPatch) -> None: agent_instance = agents.GuardrailAgent(config={}, name="NoTools", instructions="None") assert getattr(agent_instance, "input_guardrails", []) == [] # noqa: S101 + + +def test_guardrail_agent_allows_prompt_without_instructions(monkeypatch: pytest.MonkeyPatch) -> None: + """Prompt attribute text becomes derived instructions.""" + pipeline = SimpleNamespace(pre_flight=None, input=None, output=None) + + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline, raising=False) + monkeypatch.setattr(runtime_module, "instantiate_guardrails", lambda *args, **kwargs: [], raising=False) + + prompt = SimpleNamespace(text="Serve customers helpfully.") + + agent_instance = agents.GuardrailAgent(config={}, name="PromptOnly", prompt=prompt) + + assert agent_instance.prompt is prompt # noqa: S101 + assert agent_instance.instructions == "Serve customers helpfully." # noqa: S101 + + +def test_guardrail_agent_accepts_string_prompt(monkeypatch: pytest.MonkeyPatch) -> None: + """String prompts populate missing instructions automatically.""" + pipeline = SimpleNamespace(pre_flight=None, input=None, output=None) + + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline, raising=False) + monkeypatch.setattr(runtime_module, "instantiate_guardrails", lambda *args, **kwargs: [], raising=False) + + agent_instance = agents.GuardrailAgent(config={}, name="PromptStr", prompt="Be concise.") + + assert agent_instance.prompt == "Be concise." # noqa: S101 + assert agent_instance.instructions == "Be concise." # noqa: S101 + + +def test_guardrail_agent_requires_instructions_or_prompt(monkeypatch: pytest.MonkeyPatch) -> None: + """GuardrailAgent requires instructions or prompt for construction.""" + pipeline = SimpleNamespace(pre_flight=None, input=None, output=None) + + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline, raising=False) + monkeypatch.setattr(runtime_module, "instantiate_guardrails", lambda *args, **kwargs: [], raising=False) + + with pytest.raises(ValueError): + agents.GuardrailAgent(config={}, name="Missing")