From 084d4fd23e21fccfd6747704a0ff076ebad583f2 Mon Sep 17 00:00:00 2001 From: Steven C Date: Wed, 8 Oct 2025 17:48:51 -0400 Subject: [PATCH] Adding unit tests --- pyproject.toml | 17 + tests/conftest.py | 129 ++++ tests/integration/test_suite.py | 46 +- tests/unit/checks/test_keywords.py | 68 ++ tests/unit/checks/test_llm_base.py | 158 +++++ tests/unit/checks/test_moderation.py | 58 ++ .../checks/test_prompt_injection_detection.py | 122 ++++ tests/unit/checks/test_secret_keys.py | 35 + tests/unit/checks/test_urls.py | 92 +++ tests/unit/test_agents.py | 603 +++++++++++++++++ tests/unit/test_base_client.py | 402 ++++++++++++ tests/unit/test_cli.py | 72 ++ tests/unit/test_client_async.py | 419 ++++++++++++ tests/unit/test_client_sync.py | 616 ++++++++++++++++++ tests/unit/test_context.py | 37 ++ tests/unit/test_registry.py | 6 +- tests/unit/test_resources_chat.py | 277 ++++++++ tests/unit/test_resources_responses.py | 338 ++++++++++ tests/unit/test_runtime.py | 39 -- tests/unit/test_streaming.py | 162 +++++ tests/unit/utils/test_create_vector_store.py | 69 ++ tests/unit/utils/test_output.py | 38 ++ tests/unit/utils/test_parsing.py | 47 ++ tests/unit/utils/test_schema.py | 46 ++ 24 files changed, 3815 insertions(+), 81 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/unit/checks/test_keywords.py create mode 100644 tests/unit/checks/test_llm_base.py create mode 100644 tests/unit/checks/test_moderation.py create mode 100644 tests/unit/checks/test_prompt_injection_detection.py create mode 100644 tests/unit/checks/test_secret_keys.py create mode 100644 tests/unit/checks/test_urls.py create mode 100644 tests/unit/test_agents.py create mode 100644 tests/unit/test_base_client.py create mode 100644 tests/unit/test_cli.py create mode 100644 tests/unit/test_client_async.py create mode 100644 tests/unit/test_client_sync.py create mode 100644 tests/unit/test_context.py create mode 100644 tests/unit/test_resources_chat.py create mode 100644 tests/unit/test_resources_responses.py create mode 100644 tests/unit/test_streaming.py create mode 100644 tests/unit/utils/test_create_vector_store.py create mode 100644 tests/unit/utils/test_output.py create mode 100644 tests/unit/utils/test_parsing.py create mode 100644 tests/unit/utils/test_schema.py diff --git a/pyproject.toml b/pyproject.toml index 6e5827c..52993c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ dev = [ "pymdown-extensions>=10.0.0", "coverage>=7.8.0", "hypothesis>=6.131.20", + "pytest-cov>=6.3.0", ] [tool.uv.workspace] @@ -103,8 +104,24 @@ convention = "google" [tool.ruff.format] docstring-code-format = true +[tool.coverage.run] +source = ["guardrails"] +omit = [ + "src/guardrails/evals/*", +] + [tool.mypy] strict = true disallow_incomplete_defs = false disallow_untyped_defs = false disallow_untyped_calls = false +exclude = [ + "examples", + "src/guardrails/evals", +] + +[tool.pyright] +ignore = [ + "examples", + "src/guardrails/evals", +] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..7cf4555 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,129 @@ +"""Shared pytest fixtures for guardrails tests. + +These fixtures provide deterministic test environments by stubbing the OpenAI +client library, seeding environment variables, and preventing accidental live +network activity during the suite. 
+""" + +from __future__ import annotations + +import logging +import sys +import types +from collections.abc import Iterator +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any + +import pytest + + +class _StubOpenAIBase: + """Base stub with attribute bag behaviour for OpenAI client classes.""" + + def __init__(self, **kwargs: Any) -> None: + self._client_kwargs = kwargs + self.chat = SimpleNamespace() + self.responses = SimpleNamespace() + self.api_key = kwargs.get("api_key", "test-key") + self.base_url = kwargs.get("base_url") + self.organization = kwargs.get("organization") + self.timeout = kwargs.get("timeout") + self.max_retries = kwargs.get("max_retries") + + def __getattr__(self, item: str) -> Any: + """Return None for unknown attributes to emulate real client laziness.""" + return None + + +class _StubAsyncOpenAI(_StubOpenAIBase): + """Stub asynchronous OpenAI client.""" + + +class _StubSyncOpenAI(_StubOpenAIBase): + """Stub synchronous OpenAI client.""" + + +@dataclass(frozen=True, slots=True) +class _DummyResponse: + """Minimal response type with choices and output.""" + + choices: list[Any] | None = None + output: list[Any] | None = None + output_text: str | None = None + type: str | None = None + delta: str | None = None + + +_STUB_OPENAI_MODULE = types.ModuleType("openai") +_STUB_OPENAI_MODULE.AsyncOpenAI = _StubAsyncOpenAI +_STUB_OPENAI_MODULE.OpenAI = _StubSyncOpenAI +_STUB_OPENAI_MODULE.AsyncAzureOpenAI = _StubAsyncOpenAI +_STUB_OPENAI_MODULE.AzureOpenAI = _StubSyncOpenAI +_STUB_OPENAI_MODULE.NOT_GIVEN = object() + + +class APITimeoutError(Exception): + """Stub API timeout error.""" + + +_STUB_OPENAI_MODULE.APITimeoutError = APITimeoutError + +_OPENAI_TYPES_MODULE = types.ModuleType("openai.types") +_OPENAI_TYPES_MODULE.Completion = _DummyResponse +_OPENAI_TYPES_MODULE.Response = _DummyResponse + +_OPENAI_CHAT_MODULE = types.ModuleType("openai.types.chat") +_OPENAI_CHAT_MODULE.ChatCompletion = _DummyResponse +_OPENAI_CHAT_MODULE.ChatCompletionChunk = _DummyResponse + +_OPENAI_RESPONSES_MODULE = types.ModuleType("openai.types.responses") +_OPENAI_RESPONSES_MODULE.Response = _DummyResponse +_OPENAI_RESPONSES_MODULE.ResponseInputItemParam = dict # type: ignore[attr-defined] +_OPENAI_RESPONSES_MODULE.ResponseOutputItem = dict # type: ignore[attr-defined] +_OPENAI_RESPONSES_MODULE.ResponseStreamEvent = dict # type: ignore[attr-defined] + + +_OPENAI_RESPONSES_RESPONSE_MODULE = types.ModuleType("openai.types.responses.response") +_OPENAI_RESPONSES_RESPONSE_MODULE.Response = _DummyResponse + + +class _ResponseTextConfigParam(dict): + """Stub config param used for response formatting.""" + + +_OPENAI_RESPONSES_MODULE.ResponseTextConfigParam = _ResponseTextConfigParam + +sys.modules["openai"] = _STUB_OPENAI_MODULE +sys.modules["openai.types"] = _OPENAI_TYPES_MODULE +sys.modules["openai.types.chat"] = _OPENAI_CHAT_MODULE +sys.modules["openai.types.responses"] = _OPENAI_RESPONSES_MODULE +sys.modules["openai.types.responses.response"] = _OPENAI_RESPONSES_RESPONSE_MODULE + + +@pytest.fixture(autouse=True) +def stub_openai_module(monkeypatch: pytest.MonkeyPatch) -> Iterator[types.ModuleType]: + """Provide stub OpenAI module so tests avoid real network-bound clients.""" + # Patch imported symbols in guardrails modules + from guardrails import _base_client, client, types as guardrail_types # type: ignore + + monkeypatch.setattr(_base_client, "AsyncOpenAI", _StubAsyncOpenAI, raising=False) + monkeypatch.setattr(_base_client, "OpenAI", 
_StubSyncOpenAI, raising=False) + monkeypatch.setattr(client, "AsyncOpenAI", _StubAsyncOpenAI, raising=False) + monkeypatch.setattr(client, "OpenAI", _StubSyncOpenAI, raising=False) + monkeypatch.setattr(client, "AsyncAzureOpenAI", _StubAsyncOpenAI, raising=False) + monkeypatch.setattr(client, "AzureOpenAI", _StubSyncOpenAI, raising=False) + monkeypatch.setattr(guardrail_types, "AsyncOpenAI", _StubAsyncOpenAI, raising=False) + monkeypatch.setattr(guardrail_types, "OpenAI", _StubSyncOpenAI, raising=False) + monkeypatch.setattr(guardrail_types, "AsyncAzureOpenAI", _StubAsyncOpenAI, raising=False) + monkeypatch.setattr(guardrail_types, "AzureOpenAI", _StubSyncOpenAI, raising=False) + + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + + yield _STUB_OPENAI_MODULE + + +@pytest.fixture(autouse=True) +def configure_logging() -> None: + """Ensure logging defaults to DEBUG for deterministic assertions.""" + logging.basicConfig(level=logging.DEBUG) diff --git a/tests/integration/test_suite.py b/tests/integration/test_suite.py index 5141c86..e40d2b6 100644 --- a/tests/integration/test_suite.py +++ b/tests/integration/test_suite.py @@ -378,11 +378,7 @@ async def run_test( else: # Find the triggered result triggered_result = next( - ( - r - for r in response.guardrail_results.all_results - if r.tripwire_triggered - ), + (r for r in response.guardrail_results.all_results if r.tripwire_triggered), None, ) info = triggered_result.info if triggered_result else None @@ -394,9 +390,7 @@ async def run_test( "details": {"result": info}, }, ) - print( - f"❌ {test.name} - Passing case {idx} triggered when it shouldn't" - ) + print(f"❌ {test.name} - Passing case {idx} triggered when it shouldn't") if info: print(f" Info: {info}") @@ -427,11 +421,7 @@ async def run_test( if tripwire_triggered: # Find the triggered result triggered_result = next( - ( - r - for r in response.guardrail_results.all_results - if r.tripwire_triggered - ), + (r for r in response.guardrail_results.all_results if r.tripwire_triggered), None, ) info = triggered_result.info if triggered_result else None @@ -517,17 +507,9 @@ async def run_test_suite( results["tests"].append(outcome) # Calculate test status - passing_fails = sum( - 1 for c in outcome["passing_cases"] if c["status"] == "FAIL" - ) - failing_fails = sum( - 1 for c in outcome["failing_cases"] if c["status"] == "FAIL" - ) - errors = sum( - 1 - for c in outcome["passing_cases"] + outcome["failing_cases"] - if c["status"] == "ERROR" - ) + passing_fails = sum(1 for c in outcome["passing_cases"] if c["status"] == "FAIL") + failing_fails = sum(1 for c in outcome["failing_cases"] if c["status"] == "FAIL") + errors = sum(1 for c in outcome["passing_cases"] + outcome["failing_cases"] if c["status"] == "ERROR") if errors > 0: results["summary"]["error_tests"] += 1 @@ -538,16 +520,8 @@ async def run_test_suite( # Count case results total_cases = len(outcome["passing_cases"]) + len(outcome["failing_cases"]) - passed_cases = sum( - 1 - for c in outcome["passing_cases"] + outcome["failing_cases"] - if c["status"] == "PASS" - ) - failed_cases = sum( - 1 - for c in outcome["passing_cases"] + outcome["failing_cases"] - if c["status"] == "FAIL" - ) + passed_cases = sum(1 for c in outcome["passing_cases"] + outcome["failing_cases"] if c["status"] == "PASS") + failed_cases = sum(1 for c in outcome["passing_cases"] + outcome["failing_cases"] if c["status"] == "FAIL") error_cases = errors results["summary"]["total_cases"] += total_cases @@ -564,9 +538,7 @@ def print_summary(results: dict[str, Any]) -> 
None: print("GUARDRAILS TEST SUMMARY") print("=" * 50) print( - f"Tests: {summary['passed_tests']} passed, " - f"{summary['failed_tests']} failed, " - f"{summary['error_tests']} errors", + f"Tests: {summary['passed_tests']} passed, " f"{summary['failed_tests']} failed, " f"{summary['error_tests']} errors", ) print( f"Cases: {summary['total_cases']} total, " diff --git a/tests/unit/checks/test_keywords.py b/tests/unit/checks/test_keywords.py new file mode 100644 index 0000000..b9175dd --- /dev/null +++ b/tests/unit/checks/test_keywords.py @@ -0,0 +1,68 @@ +"""Tests for keyword-based guardrail helpers.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from guardrails.checks.text.competitors import CompetitorCfg, competitors +from guardrails.checks.text.keywords import KeywordCfg, keywords, match_keywords +from guardrails.types import GuardrailResult + + +def test_match_keywords_sanitizes_trailing_punctuation() -> None: + """Ensure keyword sanitization strips trailing punctuation before matching.""" + config = KeywordCfg(keywords=["token.", "secret!", "KEY?"]) + result = match_keywords("Leaked token appears here.", config, guardrail_name="Test Guardrail") + + assert result.tripwire_triggered is True # noqa: S101 + assert result.info["sanitized_keywords"] == ["token", "secret", "KEY"] # noqa: S101 + assert result.info["matched"] == ["token"] # noqa: S101 + assert result.info["guardrail_name"] == "Test Guardrail" # noqa: S101 + assert result.info["checked_text"] == "Leaked token appears here." # noqa: S101 + + +def test_match_keywords_deduplicates_case_insensitive_matches() -> None: + """Repeated matches differing by case should be deduplicated.""" + config = KeywordCfg(keywords=["Alert"]) + result = match_keywords("alert ALERT Alert", config, guardrail_name="Keyword Filter") + + assert result.tripwire_triggered is True # noqa: S101 + assert result.info["matched"] == ["alert"] # noqa: S101 + + +@pytest.mark.asyncio +async def test_keywords_guardrail_wraps_match_keywords() -> None: + """Async guardrail should mirror match_keywords behaviour.""" + config = KeywordCfg(keywords=["breach"]) + result = await keywords(ctx=None, data="Potential breach detected", config=config) + + assert isinstance(result, GuardrailResult) # noqa: S101 + assert result.tripwire_triggered is True # noqa: S101 + assert result.info["guardrail_name"] == "Keyword Filter" # noqa: S101 + + +@pytest.mark.asyncio +async def test_competitors_uses_keyword_matching() -> None: + """Competitors guardrail delegates to keyword matching with distinct name.""" + config = CompetitorCfg(keywords=["ACME Corp"]) + result = await competitors(ctx=None, data="Comparing against ACME Corp today", config=config) + + assert result.tripwire_triggered is True # noqa: S101 + assert result.info["guardrail_name"] == "Competitors" # noqa: S101 + assert result.info["matched"] == ["ACME Corp"] # noqa: S101 + + +def test_keyword_cfg_requires_non_empty_keywords() -> None: + """KeywordCfg should enforce at least one keyword.""" + with pytest.raises(ValidationError): + KeywordCfg(keywords=[]) + + +@pytest.mark.asyncio +async def test_keywords_does_not_trigger_on_benign_text() -> None: + """Guardrail should not trigger when no keywords are present.""" + config = KeywordCfg(keywords=["restricted"]) + result = await keywords(ctx=None, data="Safe content", config=config) + + assert result.tripwire_triggered is False # noqa: S101 diff --git a/tests/unit/checks/test_llm_base.py b/tests/unit/checks/test_llm_base.py new 
file mode 100644 index 0000000..907f523 --- /dev/null +++ b/tests/unit/checks/test_llm_base.py @@ -0,0 +1,158 @@ +"""Tests for LLM-based guardrail helpers.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from guardrails.checks.text import llm_base +from guardrails.checks.text.llm_base import ( + LLMConfig, + LLMErrorOutput, + LLMOutput, + _build_full_prompt, + _strip_json_code_fence, + create_llm_check_fn, + run_llm, +) +from guardrails.types import GuardrailResult + + +class _FakeCompletions: + def __init__(self, content: str | None) -> None: + self._content = content + + async def create(self, **kwargs: Any) -> Any: + _ = kwargs + return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content=self._content))]) + + +class _FakeAsyncClient: + def __init__(self, content: str | None) -> None: + self.chat = SimpleNamespace(completions=_FakeCompletions(content)) + + +def test_strip_json_code_fence_removes_wrapping() -> None: + """Valid JSON code fences should be removed.""" + fenced = """```json +{"flagged": false, "confidence": 0.2} +```""" + assert _strip_json_code_fence(fenced) == '{"flagged": false, "confidence": 0.2}' # noqa: S101 + + +def test_build_full_prompt_includes_instructions() -> None: + """Generated prompt should embed system instructions and schema guidance.""" + prompt = _build_full_prompt("Analyze text") + assert "Analyze text" in prompt # noqa: S101 + assert "Respond with a json object" in prompt # noqa: S101 + + +@pytest.mark.asyncio +async def test_run_llm_returns_valid_output() -> None: + """run_llm should parse the JSON response into the provided output model.""" + client = _FakeAsyncClient('{"flagged": true, "confidence": 0.9}') + result = await run_llm( + text="Sensitive text", + system_prompt="Detect problems.", + client=client, # type: ignore[arg-type] + model="gpt-test", + output_model=LLMOutput, + ) + assert isinstance(result, LLMOutput) # noqa: S101 + assert result.flagged is True and result.confidence == 0.9 # noqa: S101 + + +@pytest.mark.asyncio +async def test_run_llm_handles_content_filter_error(monkeypatch: pytest.MonkeyPatch) -> None: + """Content filter errors should return LLMErrorOutput with flagged=True.""" + + class _FailingClient: + class _Chat: + class _Completions: + async def create(self, **kwargs: Any) -> Any: + raise RuntimeError("content_filter triggered by provider") + + completions = _Completions() + + chat = _Chat() + + result = await run_llm( + text="Sensitive", + system_prompt="Detect.", + client=_FailingClient(), # type: ignore[arg-type] + model="gpt-test", + output_model=LLMOutput, + ) + + assert isinstance(result, LLMErrorOutput) # noqa: S101 + assert result.flagged is True # noqa: S101 + assert result.info["third_party_filter"] is True # noqa: S101 + + +@pytest.mark.asyncio +async def test_create_llm_check_fn_triggers_on_confident_flag(monkeypatch: pytest.MonkeyPatch) -> None: + """Generated guardrail function should trip when confidence exceeds the threshold.""" + + async def fake_run_llm( + text: str, + system_prompt: str, + client: Any, + model: str, + output_model: type[LLMOutput], + ) -> LLMOutput: + assert system_prompt == "Check with details" # noqa: S101 + return LLMOutput(flagged=True, confidence=0.95) + + monkeypatch.setattr(llm_base, "run_llm", fake_run_llm) + + class DetailedConfig(LLMConfig): + system_prompt_details: str = "details" + + guardrail_fn = create_llm_check_fn( + name="HighConfidence", + description="Test guardrail", + 
system_prompt="Check with {system_prompt_details}", + output_model=LLMOutput, + config_model=DetailedConfig, + ) + + config = DetailedConfig(model="gpt-test", confidence_threshold=0.9) + context = SimpleNamespace(guardrail_llm="fake-client") + + result = await guardrail_fn(context, "content", config) + + assert isinstance(result, GuardrailResult) # noqa: S101 + assert result.tripwire_triggered is True # noqa: S101 + assert result.info["threshold"] == 0.9 # noqa: S101 + + +@pytest.mark.asyncio +async def test_create_llm_check_fn_handles_llm_error(monkeypatch: pytest.MonkeyPatch) -> None: + """LLM error results should mark execution_failed without triggering tripwire.""" + + async def fake_run_llm( + text: str, + system_prompt: str, + client: Any, + model: str, + output_model: type[LLMOutput], + ) -> LLMErrorOutput: + return LLMErrorOutput(flagged=False, confidence=0.0, info={"error_message": "timeout"}) + + monkeypatch.setattr(llm_base, "run_llm", fake_run_llm) + + guardrail_fn = create_llm_check_fn( + name="Resilient", + description="Test guardrail", + system_prompt="Prompt", + ) + + config = LLMConfig(model="gpt-test", confidence_threshold=0.5) + context = SimpleNamespace(guardrail_llm="fake-client") + result = await guardrail_fn(context, "text", config) + + assert result.tripwire_triggered is False # noqa: S101 + assert result.execution_failed is True # noqa: S101 + assert "timeout" in str(result.original_exception) # noqa: S101 diff --git a/tests/unit/checks/test_moderation.py b/tests/unit/checks/test_moderation.py new file mode 100644 index 0000000..389dd09 --- /dev/null +++ b/tests/unit/checks/test_moderation.py @@ -0,0 +1,58 @@ +"""Tests for moderation guardrail.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from guardrails.checks.text.moderation import Category, ModerationCfg, moderation + + +class _StubModerationClient: + """Stub moderations client that returns prerecorded results.""" + + def __init__(self, categories: dict[str, bool]) -> None: + self._categories = categories + + async def create(self, model: str, input: str) -> Any: + _ = (model, input) + + class _Result: + def model_dump(self_inner) -> dict[str, Any]: + return {"categories": self._categories} + + return SimpleNamespace(results=[_Result()]) + + +@pytest.mark.asyncio +async def test_moderation_triggers_on_flagged_categories(monkeypatch: pytest.MonkeyPatch) -> None: + """Requested categories flagged True should trigger the guardrail.""" + stub_client = SimpleNamespace(moderations=_StubModerationClient({"hate": True, "violence": False})) + + monkeypatch.setattr("guardrails.checks.text.moderation._get_moderation_client", lambda: stub_client) + + cfg = ModerationCfg(categories=[Category.HATE, Category.VIOLENCE]) + result = await moderation(None, "text", cfg) + + assert result.tripwire_triggered is True # noqa: S101 + assert result.info["flagged_categories"] == ["hate"] # noqa: S101 + + +@pytest.mark.asyncio +async def test_moderation_handles_empty_results(monkeypatch: pytest.MonkeyPatch) -> None: + """Missing results should return an informative error.""" + + async def create_empty(**_: Any) -> Any: + return SimpleNamespace(results=[]) + + stub_client = SimpleNamespace(moderations=SimpleNamespace(create=create_empty)) + + monkeypatch.setattr("guardrails.checks.text.moderation._get_moderation_client", lambda: stub_client) + + cfg = ModerationCfg(categories=[Category.HARASSMENT]) + result = await moderation(None, "text", cfg) + + assert 
result.tripwire_triggered is False # noqa: S101 + assert result.info["error"] == "No moderation results returned" # noqa: S101 diff --git a/tests/unit/checks/test_prompt_injection_detection.py b/tests/unit/checks/test_prompt_injection_detection.py new file mode 100644 index 0000000..1cda87c --- /dev/null +++ b/tests/unit/checks/test_prompt_injection_detection.py @@ -0,0 +1,122 @@ +"""Tests for prompt injection detection guardrail.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from guardrails.checks.text import prompt_injection_detection as pid_module +from guardrails.checks.text.llm_base import LLMConfig +from guardrails.checks.text.prompt_injection_detection import ( + PromptInjectionDetectionOutput, + _should_analyze, + prompt_injection_detection, +) + + +class _FakeContext: + """Context stub providing conversation history accessors.""" + + def __init__(self, history: list[Any]) -> None: + self._history = history + self.guardrail_llm = SimpleNamespace() # unused due to monkeypatch + self._last_index = 0 + + def get_conversation_history(self) -> list[Any]: + return self._history + + def get_injection_last_checked_index(self) -> int: + return self._last_index + + def update_injection_last_checked_index(self, new_index: int) -> None: + self._last_index = new_index + + +def _make_history(action: dict[str, Any]) -> list[Any]: + return [ + {"role": "user", "content": "Retrieve the weather for Paris"}, + action, + ] + + +@pytest.mark.parametrize( + "message, expected", + [ + ({"type": "function_call"}, True), + ({"role": "tool", "content": "Tool output"}, True), + ({"role": "assistant", "content": "hello"}, False), + ], +) +def test_should_analyze(message: dict[str, Any], expected: bool) -> None: + """Verify _should_analyze matches only tool-related messages.""" + assert _should_analyze(message) is expected # noqa: S101 + + +@pytest.mark.asyncio +async def test_prompt_injection_detection_triggers(monkeypatch: pytest.MonkeyPatch) -> None: + """Guardrail should trigger when analysis flags misalignment above threshold.""" + history = _make_history({"type": "function_call", "tool_name": "delete_files", "arguments": "{}"}) + context = _FakeContext(history) + + async def fake_call_llm(ctx: Any, prompt: str, config: LLMConfig) -> PromptInjectionDetectionOutput: + assert "delete_files" in prompt # noqa: S101 + assert hasattr(ctx, "guardrail_llm") # noqa: S101 + return PromptInjectionDetectionOutput(flagged=True, confidence=0.95, observation="Deletes user files") + + monkeypatch.setattr(pid_module, "_call_prompt_injection_detection_llm", fake_call_llm) + + config = LLMConfig(model="gpt-test", confidence_threshold=0.9) + result = await prompt_injection_detection(context, data="{}", config=config) + + assert result.tripwire_triggered is True # noqa: S101 + assert context.get_injection_last_checked_index() == len(history) # noqa: S101 + + +@pytest.mark.asyncio +async def test_prompt_injection_detection_no_trigger(monkeypatch: pytest.MonkeyPatch) -> None: + """Low confidence results should not trigger the guardrail.""" + history = _make_history({"type": "function_call", "tool_name": "get_weather", "arguments": "{}"}) + context = _FakeContext(history) + + async def fake_call_llm(ctx: Any, prompt: str, config: LLMConfig) -> PromptInjectionDetectionOutput: + return PromptInjectionDetectionOutput(flagged=True, confidence=0.3, observation="Aligned") + + monkeypatch.setattr(pid_module, "_call_prompt_injection_detection_llm", 
fake_call_llm) + + config = LLMConfig(model="gpt-test", confidence_threshold=0.9) + result = await prompt_injection_detection(context, data="{}", config=config) + + assert result.tripwire_triggered is False # noqa: S101 + assert "Aligned" in result.info["observation"] # noqa: S101 + + +@pytest.mark.asyncio +async def test_prompt_injection_detection_skips_without_history(monkeypatch: pytest.MonkeyPatch) -> None: + """When no conversation history is present, guardrail should skip.""" + context = _FakeContext([]) + config = LLMConfig(model="gpt-test", confidence_threshold=0.9) + + result = await prompt_injection_detection(context, data="{}", config=config) + + assert result.tripwire_triggered is False # noqa: S101 + assert result.info["observation"] == "No conversation history available" # noqa: S101 + + +@pytest.mark.asyncio +async def test_prompt_injection_detection_handles_analysis_error(monkeypatch: pytest.MonkeyPatch) -> None: + """Exceptions during analysis should return a skip result.""" + history = _make_history({"type": "function_call", "tool_name": "get_weather", "arguments": "{}"}) + context = _FakeContext(history) + + async def failing_llm(*_args: Any, **_kwargs: Any) -> PromptInjectionDetectionOutput: + raise RuntimeError("LLM failed") + + monkeypatch.setattr(pid_module, "_call_prompt_injection_detection_llm", failing_llm) + + config = LLMConfig(model="gpt-test", confidence_threshold=0.7) + result = await prompt_injection_detection(context, data="{}", config=config) + + assert result.tripwire_triggered is False # noqa: S101 + assert "Error during prompt injection detection check" in result.info["observation"] # noqa: S101 diff --git a/tests/unit/checks/test_secret_keys.py b/tests/unit/checks/test_secret_keys.py new file mode 100644 index 0000000..3187c61 --- /dev/null +++ b/tests/unit/checks/test_secret_keys.py @@ -0,0 +1,35 @@ +"""Tests for secret key detection guardrail.""" + +from __future__ import annotations + +import pytest + +from guardrails.checks.text.secret_keys import SecretKeysCfg, _detect_secret_keys, secret_keys + + +def test_detect_secret_keys_flags_high_entropy_strings() -> None: + """High entropy tokens should be detected as potential secrets.""" + text = "API key sk-AAAABBBBCCCCDDDD" + result = _detect_secret_keys(text, cfg={"min_length": 10, "min_entropy": 3.5, "min_diversity": 2, "strict_mode": True}) + + assert result.tripwire_triggered is True # noqa: S101 + assert "sk-AAAABBBBCCCCDDDD" in result.info["detected_secrets"] # noqa: S101 + + +@pytest.mark.asyncio +async def test_secret_keys_with_custom_regex() -> None: + """Custom regex patterns should trigger detection.""" + config = SecretKeysCfg(threshold="balanced", custom_regex=[r"internal-[a-z0-9]{4}"]) + result = await secret_keys(None, "internal-ab12 leaked", config) + + assert result.tripwire_triggered is True # noqa: S101 + assert "internal-ab12" in result.info["detected_secrets"] # noqa: S101 + + +@pytest.mark.asyncio +async def test_secret_keys_ignores_non_matching_input() -> None: + """Benign inputs should not trigger the guardrail.""" + config = SecretKeysCfg(threshold="permissive") + result = await secret_keys(None, "Hello world", config) + + assert result.tripwire_triggered is False # noqa: S101 diff --git a/tests/unit/checks/test_urls.py b/tests/unit/checks/test_urls.py new file mode 100644 index 0000000..2ef7a71 --- /dev/null +++ b/tests/unit/checks/test_urls.py @@ -0,0 +1,92 @@ +"""Tests for URL guardrail helpers.""" + +from __future__ import annotations + +import pytest + +from 
guardrails.checks.text.urls import ( + URLConfig, + _detect_urls, + _is_url_allowed, + _validate_url_security, + urls, +) + + +def test_detect_urls_deduplicates_scheme_and_domain() -> None: + """Ensure detection removes trailing punctuation and avoids duplicate domains.""" + text = "Visit https://example.com/, http://example.com/path, " "example.com should not duplicate, and 192.168.1.10:8080." + detected = _detect_urls(text) + + assert "https://example.com/" in detected # noqa: S101 + assert "http://example.com/path" in detected # noqa: S101 + assert "example.com" not in detected # noqa: S101 + assert "192.168.1.10:8080" in detected # noqa: S101 + + +def test_validate_url_security_blocks_bad_scheme() -> None: + """Disallowed schemes should produce an error.""" + config = URLConfig() + parsed, reason = _validate_url_security("http://blocked.com", config) + + assert parsed is None # noqa: S101 + assert "Blocked scheme" in reason # noqa: S101 + + +def test_validate_url_security_blocks_userinfo_when_configured() -> None: + """URLs with embedded credentials should be rejected when block_userinfo=True.""" + config = URLConfig(allowed_schemes={"https"}, block_userinfo=True) + parsed, reason = _validate_url_security("https://user:pass@example.com", config) + + assert parsed is None # noqa: S101 + assert "userinfo" in reason # noqa: S101 + + +def test_is_url_allowed_supports_subdomains_and_cidr() -> None: + """Allow list should support subdomains and CIDR ranges.""" + config = URLConfig( + url_allow_list=["example.com", "10.0.0.0/8"], + allow_subdomains=True, + ) + https_result, _ = _validate_url_security("https://api.example.com", config) + ip_result, _ = _validate_url_security("https://10.1.2.3", config) + + assert https_result is not None # noqa: S101 + assert ip_result is not None # noqa: S101 + assert _is_url_allowed(https_result, config.url_allow_list, config.allow_subdomains) is True # noqa: S101 + assert _is_url_allowed(ip_result, config.url_allow_list, config.allow_subdomains) is True # noqa: S101 + + +@pytest.mark.asyncio +async def test_urls_guardrail_reports_allowed_and_blocked() -> None: + """Urls guardrail should classify detected URLs based on config.""" + config = URLConfig( + url_allow_list=["example.com"], + allowed_schemes={"https", "data"}, + block_userinfo=True, + allow_subdomains=False, + ) + text = ( + "Inline data URI data:text/plain;base64,QUJD. " + "Use https://example.com/docs. " + "Avoid http://attacker.com/login and https://sub.example.com." 
+ ) + + result = await urls(ctx=None, data=text, config=config) + + assert result.tripwire_triggered is True # noqa: S101 + assert "https://example.com/docs" in result.info["allowed"] # noqa: S101 + assert "data:text/plain;base64,QUJD" in result.info["allowed"] # noqa: S101 + assert "http://attacker.com/login" in result.info["blocked"] # noqa: S101 + assert "https://sub.example.com" in result.info["blocked"] # noqa: S101 + assert any("Blocked scheme" in reason for reason in result.info["blocked_reasons"]) # noqa: S101 + assert any("Not in allow list" in reason for reason in result.info["blocked_reasons"]) # noqa: S101 + + +@pytest.mark.asyncio +async def test_urls_guardrail_allows_benign_input() -> None: + """Benign text without URLs should not trigger.""" + config = URLConfig() + result = await urls(ctx=None, data="No links here", config=config) + + assert result.tripwire_triggered is False # noqa: S101 diff --git a/tests/unit/test_agents.py b/tests/unit/test_agents.py new file mode 100644 index 0000000..2e90877 --- /dev/null +++ b/tests/unit/test_agents.py @@ -0,0 +1,603 @@ +"""Tests covering guardrails.agents helper functions.""" + +from __future__ import annotations + +import sys +import types +from collections.abc import Callable +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any + +import pytest + +from guardrails.types import GuardrailResult + +# --------------------------------------------------------------------------- +# Stub agents SDK module so guardrails.agents can import required symbols. +# --------------------------------------------------------------------------- + +agents_module = types.ModuleType("agents") + + +@dataclass +class ToolContext: + """Stub tool context carrying name and arguments.""" + + tool_name: str + tool_arguments: dict[str, Any] + + +@dataclass +class ToolInputGuardrailData: + """Stub input guardrail payload.""" + + context: ToolContext + + +@dataclass +class ToolOutputGuardrailData: + """Stub output guardrail payload.""" + + context: ToolContext + output: Any + + +@dataclass +class GuardrailFunctionOutput: + """Minimal guardrail function output stub.""" + + output_info: Any + tripwire_triggered: bool + + +@dataclass +class ToolGuardrailFunctionOutput: + """Stub for tool guardrail responses.""" + + message: str | None = None + output_info: Any | None = None + tripwire_triggered: bool = False + + @classmethod + def raise_exception(cls, output_info: Any) -> ToolGuardrailFunctionOutput: + """Return a response indicating an exception should be raised.""" + return cls(message="raise", output_info=output_info, tripwire_triggered=True) + + @classmethod + def reject_content( + cls, + message: str, + output_info: Any, + ) -> ToolGuardrailFunctionOutput: + """Return a response rejecting tool content.""" + return cls(message=message, output_info=output_info, tripwire_triggered=True) + + +def _decorator_passthrough(func: Callable) -> Callable: + """Return the function unchanged (stand-in for agents decorators).""" + return func + + +class RunContextWrapper: + """Minimal RunContextWrapper stub.""" + + def __init__(self, value: Any | None = None) -> None: + """Store wrapped value.""" + self.value = value + + +@dataclass +class Agent: + """Trivial Agent stub storing initialization args for assertions.""" + + name: str + instructions: str + input_guardrails: list[Callable] | None = None + output_guardrails: list[Callable] | None = None + tools: list[Any] | None = None + + +agents_module.ToolGuardrailFunctionOutput = 
ToolGuardrailFunctionOutput +agents_module.ToolInputGuardrailData = ToolInputGuardrailData +agents_module.ToolOutputGuardrailData = ToolOutputGuardrailData +agents_module.tool_input_guardrail = _decorator_passthrough +agents_module.tool_output_guardrail = _decorator_passthrough +agents_module.RunContextWrapper = RunContextWrapper +agents_module.Agent = Agent +agents_module.GuardrailFunctionOutput = GuardrailFunctionOutput +agents_module.input_guardrail = _decorator_passthrough +agents_module.output_guardrail = _decorator_passthrough + +sys.modules.setdefault("agents", agents_module) + +import guardrails.agents as agents # noqa: E402 (import after stubbing) +import guardrails.runtime as runtime_module # noqa: E402 + + +def _make_guardrail(name: str) -> Any: + class _DummyCtxModel: + model_fields: dict[str, Any] = {} + + @staticmethod + def model_validate(value: Any, **_: Any) -> Any: + return value + + return SimpleNamespace( + definition=SimpleNamespace( + name=name, + media_type="text/plain", + ctx_requirements=_DummyCtxModel, + ), + ctx_requirements=[], + ) + + +@pytest.fixture(autouse=True) +def reset_user_messages() -> None: + """Ensure user message context is reset for each test.""" + agents._user_messages.set([]) + + +def test_get_user_messages_initializes_list() -> None: + """_get_user_messages should return the same list instance across calls.""" + msgs1 = agents._get_user_messages() + msgs1.append("hello") + msgs2 = agents._get_user_messages() + + assert msgs2 == ["hello"] # noqa: S101 + assert msgs1 is msgs2 # noqa: S101 + + +def test_build_conversation_with_tool_call_includes_user_messages() -> None: + """Conversation builder should include stored user messages and tool call details.""" + agents._user_messages.set(["Hi there"]) + data = SimpleNamespace(context=ToolContext(tool_name="math", tool_arguments={"x": 1})) + + conversation = agents._build_conversation_with_tool_call(data) + + assert conversation[0] == {"role": "user", "content": "Hi there"} # noqa: S101 + assert conversation[1]["tool_name"] == "math" # noqa: S101 + assert conversation[1]["arguments"] == {"x": 1} # noqa: S101 + + +def test_build_conversation_with_tool_output_includes_output() -> None: + """Tool output conversation should include function output payload.""" + agents._user_messages.set(["User request"]) + data = SimpleNamespace( + context=ToolContext(tool_name="calc", tool_arguments={"y": 2}), + output={"result": 4}, + ) + + conversation = agents._build_conversation_with_tool_output(data) + + assert conversation[1]["output"] == "{'result': 4}" # noqa: S101 + + +def test_create_conversation_context_tracks_index() -> None: + """Conversation context should proxy index accessors.""" + base_context = SimpleNamespace(guardrail_llm="client") + context = agents._create_conversation_context(["msg"], base_context) + + assert context.get_conversation_history() == ["msg"] # noqa: S101 + assert context.get_injection_last_checked_index() == 0 # noqa: S101 + context.update_injection_last_checked_index(3) + assert context.get_injection_last_checked_index() == 0 # noqa: S101 + + +def test_create_default_tool_context_provides_async_client() -> None: + """Default tool context should return stubbed AsyncOpenAI client.""" + context = agents._create_default_tool_context() + assert hasattr(context, "guardrail_llm") # noqa: S101 + + +def test_attach_guardrail_to_tools_initializes_lists() -> None: + """Attaching guardrails should create input/output lists when missing.""" + tool = SimpleNamespace() + fn = lambda data: data # noqa: 
E731 + + agents._attach_guardrail_to_tools([tool], fn, "input") + agents._attach_guardrail_to_tools([tool], fn, "output") + + assert tool.tool_input_guardrails == [fn] # type: ignore[attr-defined] # noqa: S101 + assert tool.tool_output_guardrails == [fn] # type: ignore[attr-defined] # noqa: S101 + + +def test_needs_conversation_history() -> None: + """Guardrails requiring conversation history should be detected.""" + assert agents._needs_conversation_history(_make_guardrail("Prompt Injection Detection")) is True # noqa: S101 + assert agents._needs_conversation_history(_make_guardrail("Other Guard")) is False # noqa: S101 + + +def test_separate_tool_level_from_agent_level() -> None: + """Prompt injection guardrails should be classified as tool-level.""" + tool, agent_level = agents._separate_tool_level_from_agent_level([_make_guardrail("Prompt Injection Detection"), _make_guardrail("Other Guard")]) + + assert [g.definition.name for g in tool] == ["Prompt Injection Detection"] # noqa: S101 + assert [g.definition.name for g in agent_level] == ["Other Guard"] # noqa: S101 + + +@pytest.mark.asyncio +async def test_create_tool_guardrail_rejects_on_tripwire(monkeypatch: pytest.MonkeyPatch) -> None: + """Tool guardrail should reject content when run_guardrails flags a violation.""" + guardrail = _make_guardrail("Test Guardrail") + expected_info = {"observation": "violation"} + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + assert kwargs["stage_name"] == "tool_input_test_guardrail" # noqa: S101 + return [GuardrailResult(tripwire_triggered=True, info=expected_info)] + + monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + + tool_fn = agents._create_tool_guardrail( + guardrail=guardrail, + guardrail_type="input", + needs_conv_history=False, + context=SimpleNamespace(), + raise_guardrail_errors=False, + block_on_violations=False, + ) + + data = agents_module.ToolInputGuardrailData(context=ToolContext(tool_name="weather", tool_arguments={"city": "Paris"})) + result = await tool_fn(data) + + assert result.tripwire_triggered is True # noqa: S101 + assert result.output_info == expected_info # noqa: S101 + assert "blocked by Test Guardrail" in result.message # noqa: S101 + + +@pytest.mark.asyncio +async def test_create_tool_guardrail_blocks_on_violation(monkeypatch: pytest.MonkeyPatch) -> None: + """When block_on_violations is True, the guardrail should raise an exception output.""" + guardrail = _make_guardrail("Test Guardrail") + + async def fake_run_guardrails(**_: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=True, info={})] + + monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + + tool_fn = agents._create_tool_guardrail( + guardrail=guardrail, + guardrail_type="input", + needs_conv_history=False, + context=SimpleNamespace(), + raise_guardrail_errors=False, + block_on_violations=True, + ) + + data = agents_module.ToolInputGuardrailData(context=ToolContext(tool_name="weather", tool_arguments={})) + result = await tool_fn(data) + + assert result.message == "raise" # noqa: S101 + + +@pytest.mark.asyncio +async def test_create_tool_guardrail_propagates_errors(monkeypatch: pytest.MonkeyPatch) -> None: + """Guardrail errors should raise when raise_guardrail_errors is True.""" + guardrail = _make_guardrail("Failing Guardrail") + + async def failing_run_guardrails(**_: Any) -> list[GuardrailResult]: + raise RuntimeError("guardrail failure") + + monkeypatch.setattr(runtime_module, "run_guardrails", 
failing_run_guardrails) + + tool_fn = agents._create_tool_guardrail( + guardrail=guardrail, + guardrail_type="input", + needs_conv_history=False, + context=SimpleNamespace(), + raise_guardrail_errors=True, + block_on_violations=False, + ) + + data = agents_module.ToolInputGuardrailData(context=ToolContext(tool_name="weather", tool_arguments={})) + result = await tool_fn(data) + + assert result.message == "raise" # noqa: S101 + + +@pytest.mark.asyncio +async def test_create_tool_guardrail_skips_without_user_messages(monkeypatch: pytest.MonkeyPatch) -> None: + """Conversation-aware tool guardrails should skip when no user intent is recorded.""" + guardrail = _make_guardrail("Prompt Injection Detection") + agents._user_messages.set([]) # Reset stored messages + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + raise AssertionError("run_guardrails should not be called when skipping") + + monkeypatch.setattr(agents, "run_guardrails", fake_run_guardrails, raising=False) + + tool_fn = agents._create_tool_guardrail( + guardrail=guardrail, + guardrail_type="output", + needs_conv_history=True, + context=SimpleNamespace(), + raise_guardrail_errors=False, + block_on_violations=False, + ) + + data = agents_module.ToolOutputGuardrailData( + context=ToolContext(tool_name="math", tool_arguments={"value": 1}), + output="ok", + ) + result = await tool_fn(data) + + assert "Skipped" in result.output_info # noqa: S101 + assert result.tripwire_triggered is False # noqa: S101 + + +@pytest.mark.asyncio +async def test_create_agents_guardrails_from_config_success(monkeypatch: pytest.MonkeyPatch) -> None: + """Agent-level guardrail functions should execute run_guardrails and capture user messages.""" + pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None) + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline) + monkeypatch.setattr( + runtime_module, + "instantiate_guardrails", + lambda stage, registry=None: [_make_guardrail("Input Guard")] if stage is pipeline.input else [], + ) + + captured: dict[str, Any] = {} + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + captured.update(kwargs) + return [GuardrailResult(tripwire_triggered=False, info={})] + + monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + + guardrails = agents._create_agents_guardrails_from_config( + config={}, + stages=["input"], + guardrail_type="input", + context=None, + raise_guardrail_errors=False, + ) + + assert len(guardrails) == 1 # noqa: S101 + output = await guardrails[0](agents_module.RunContextWrapper(None), Agent("a", "b"), "hello") + + assert output.tripwire_triggered is False # noqa: S101 + assert captured["stage_name"] == "input" # noqa: S101 + assert agents._get_user_messages()[-1] == "hello" # noqa: S101 + + +@pytest.mark.asyncio +async def test_create_agents_guardrails_from_config_tripwire(monkeypatch: pytest.MonkeyPatch) -> None: + """Tripwire results should propagate to guardrail function output.""" + pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None) + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline) + monkeypatch.setattr( + runtime_module, + "instantiate_guardrails", + lambda stage, registry=None: [_make_guardrail("Input Guard")] if stage is pipeline.input else [], + ) + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=True, info={"reason": "blocked"})] + + 
monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + + guardrails = agents._create_agents_guardrails_from_config( + config={}, + stages=["input"], + guardrail_type="input", + context=SimpleNamespace(guardrail_llm="llm"), + raise_guardrail_errors=False, + ) + + result = await guardrails[0](agents_module.RunContextWrapper(None), Agent("a", "b"), "hi") + + assert result.tripwire_triggered is True # noqa: S101 + assert result.output_info == "Guardrail unknown triggered tripwire" # noqa: S101 + + +@pytest.mark.asyncio +async def test_create_agents_guardrails_from_config_error(monkeypatch: pytest.MonkeyPatch) -> None: + """Errors should be converted to tripwire when raise_guardrail_errors=False.""" + pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None) + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline) + monkeypatch.setattr( + runtime_module, + "instantiate_guardrails", + lambda stage, registry=None: [_make_guardrail("Input Guard")] if stage is pipeline.input else [], + ) + + async def failing_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + raise RuntimeError("boom") + + monkeypatch.setattr(runtime_module, "run_guardrails", failing_run_guardrails) + + guardrails = agents._create_agents_guardrails_from_config( + config={}, + stages=["input"], + guardrail_type="input", + context=SimpleNamespace(guardrail_llm="llm"), + raise_guardrail_errors=False, + ) + + result = await guardrails[0](agents_module.RunContextWrapper(None), Agent("name", "instr"), "msg") + + assert result.tripwire_triggered is True # noqa: S101 + assert "Error running input guardrails" in result.output_info # noqa: S101 + + +@pytest.mark.asyncio +async def test_create_agents_guardrails_from_config_error_raises(monkeypatch: pytest.MonkeyPatch) -> None: + """Errors should bubble when raise_guardrail_errors=True.""" + pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None) + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline) + monkeypatch.setattr( + runtime_module, + "instantiate_guardrails", + lambda stage, registry=None: [_make_guardrail("Input Guard")] if stage is pipeline.input else [], + ) + + async def failing_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + raise RuntimeError("failure") + + monkeypatch.setattr(runtime_module, "run_guardrails", failing_run_guardrails) + + guardrails = agents._create_agents_guardrails_from_config( + config={}, + stages=["input"], + guardrail_type="input", + context=SimpleNamespace(guardrail_llm="llm"), + raise_guardrail_errors=True, + ) + + with pytest.raises(RuntimeError): + await guardrails[0](agents_module.RunContextWrapper(None), Agent("n", "i"), "msg") + + +@pytest.mark.asyncio +async def test_create_agents_guardrails_from_config_output_stage(monkeypatch: pytest.MonkeyPatch) -> None: + """Output guardrails should not capture user messages.""" + pipeline = SimpleNamespace(pre_flight=None, input=None, output=SimpleNamespace()) + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline) + monkeypatch.setattr( + runtime_module, + "instantiate_guardrails", + lambda stage, registry=None: [_make_guardrail("Output Guard")] if stage is pipeline.output else [], + ) + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=False, info={})] + + monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + + guardrails = 
agents._create_agents_guardrails_from_config( + config={}, + stages=["output"], + guardrail_type="output", + context=SimpleNamespace(guardrail_llm="llm"), + raise_guardrail_errors=False, + ) + + result = await guardrails[0](agents_module.RunContextWrapper(None), Agent("n", "i"), "response") + + assert result.tripwire_triggered is False # noqa: S101 + assert agents._get_user_messages() == [] # noqa: S101 + + +def test_guardrail_agent_attaches_tool_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """GuardrailAgent should attach tool-level guardrails and return an Agent.""" + tool_guard = _make_guardrail("Prompt Injection Detection") + agent_guard = _make_guardrail("Sensitive Data Check") + + class FakePipeline: + def __init__(self) -> None: + self.pre_flight = SimpleNamespace() + self.input = SimpleNamespace() + self.output = SimpleNamespace() + + def stages(self) -> list[Any]: + return [self.pre_flight, self.input, self.output] + + pipeline = FakePipeline() + + def fake_load_pipeline_bundles(config: Any) -> FakePipeline: + assert config == {"version": 1} # noqa: S101 + return pipeline + + def fake_instantiate_guardrails(stage: Any, registry: Any | None = None) -> list[Any]: + if stage is pipeline.pre_flight: + return [tool_guard] + if stage is pipeline.input: + return [agent_guard] + if stage is pipeline.output: + return [] + return [] + + from guardrails import runtime as runtime_module + + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", fake_load_pipeline_bundles) + monkeypatch.setattr(runtime_module, "instantiate_guardrails", fake_instantiate_guardrails) + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", fake_load_pipeline_bundles, raising=False) + monkeypatch.setattr(runtime_module, "instantiate_guardrails", fake_instantiate_guardrails, raising=False) + + tool = SimpleNamespace() + agent_instance = agents.GuardrailAgent( + config={"version": 1}, + name="Test Agent", + instructions="Help users.", + tools=[tool], + ) + + assert isinstance(agent_instance, agents_module.Agent) # noqa: S101 + assert len(tool.tool_input_guardrails) == 1 # type: ignore[attr-defined] # noqa: S101 + # Agent-level guardrails should be attached (one for Sensitive Data Check) + assert len(agent_instance.input_guardrails or []) >= 1 # noqa: S101 + + +@pytest.mark.asyncio +async def test_guardrail_agent_captures_user_messages(monkeypatch: pytest.MonkeyPatch) -> None: + """GuardrailAgent should capture user messages and invoke tool guardrails.""" + prompt_guard = _make_guardrail("Prompt Injection Detection") + input_guard = _make_guardrail("Agent Guard") + + class FakePipeline: + def __init__(self) -> None: + self.pre_flight = SimpleNamespace() + self.input = SimpleNamespace() + self.output = None + + def stages(self) -> list[Any]: + return [self.pre_flight, self.input] + + pipeline = FakePipeline() + + def fake_load_pipeline_bundles(config: Any) -> FakePipeline: + return pipeline + + def fake_instantiate_guardrails(stage: Any, registry: Any | None = None) -> list[Any]: + if stage is pipeline.pre_flight: + return [prompt_guard] + if stage is pipeline.input: + return [input_guard] + return [] + + calls: list[str] = [] + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + calls.append(kwargs["stage_name"]) + return [GuardrailResult(tripwire_triggered=False, info={})] + + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", fake_load_pipeline_bundles, raising=False) + monkeypatch.setattr(runtime_module, "instantiate_guardrails", fake_instantiate_guardrails, 
raising=False) + monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + + tool = SimpleNamespace() + agent_instance = agents.GuardrailAgent( + config={"version": 1}, + name="Test", + instructions="Help", + tools=[tool], + ) + + # Call the first input guardrail (capture function) + capture_fn = agent_instance.input_guardrails[0] + await capture_fn(agents_module.RunContextWrapper(None), agent_instance, "user question") + + assert agents._get_user_messages() == ["user question"] # noqa: S101 + + # Run actual agent guardrail + guard_fn = agent_instance.input_guardrails[1] + await guard_fn(agents_module.RunContextWrapper(None), agent_instance, "user question") + + # Tool guardrail should be attached and callable + data = agents_module.ToolInputGuardrailData(context=ToolContext("tool", {})) + await tool.tool_input_guardrails[0](data) # type: ignore[attr-defined] + + assert any(name.startswith("tool_input") for name in calls) # noqa: S101 + + +def test_guardrail_agent_without_tools(monkeypatch: pytest.MonkeyPatch) -> None: + """Agent with no tools should not attach tool guardrails.""" + pipeline = SimpleNamespace(pre_flight=None, input=None, output=None) + + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline, raising=False) + monkeypatch.setattr(runtime_module, "instantiate_guardrails", lambda *args, **kwargs: [], raising=False) + + agent_instance = agents.GuardrailAgent(config={}, name="NoTools", instructions="None") + + assert getattr(agent_instance, "input_guardrails", []) == [] # noqa: S101 diff --git a/tests/unit/test_base_client.py b/tests/unit/test_base_client.py new file mode 100644 index 0000000..1d97db1 --- /dev/null +++ b/tests/unit/test_base_client.py @@ -0,0 +1,402 @@ +"""Unit tests covering core GuardrailsBaseClient helper methods.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +import guardrails.context as guardrails_context +from guardrails._base_client import GuardrailResults, GuardrailsBaseClient, GuardrailsResponse +from guardrails.context import GuardrailsContext +from guardrails.types import GuardrailResult + + +def test_extract_latest_user_message_dicts() -> None: + """Ensure latest user message and index are returned for dict inputs.""" + client = GuardrailsBaseClient() + messages = [ + {"role": "system", "content": "hello"}, + {"role": "user", "content": " hi there "}, + ] + + text, index = client._extract_latest_user_message(messages) + + assert text == "hi there" # noqa: S101 + assert index == 1 # noqa: S101 + + +def test_extract_latest_user_message_content_parts() -> None: + """Support Responses API content part lists.""" + client = GuardrailsBaseClient() + messages = [ + {"role": "assistant", "content": "prev"}, + { + "role": "user", + "content": [ + {"type": "input_text", "text": "first"}, + {"type": "summary_text", "text": "second"}, + ], + }, + ] + + text, index = client._extract_latest_user_message(messages) + + assert text == "first second" # noqa: S101 + assert index == 1 # noqa: S101 + + +def test_extract_latest_user_message_missing_user() -> None: + """Return empty payload when no user role is present.""" + client = GuardrailsBaseClient() + + text, index = client._extract_latest_user_message([{"role": "assistant", "content": "x"}]) + + assert text == "" # noqa: S101 + assert index == -1 # noqa: S101 + + +def test_apply_preflight_modifications_masks_user_message() -> None: + """Mask PII tokens for the most recent user message.""" + 
client = GuardrailsBaseClient() + guardrail_results = [ + GuardrailResult( + tripwire_triggered=False, + info={"detected_entities": {"PERSON": ["Alice Smith"]}}, + ) + ] + messages = [ + {"role": "user", "content": "My name is Alice Smith."}, + {"role": "assistant", "content": "Hi Alice."}, + ] + + modified = client._apply_preflight_modifications(messages, guardrail_results) + + assert modified[0]["content"] == "My name is ." # noqa: S101 + assert messages[0]["content"] == "My name is Alice Smith." # noqa: S101 + + +def test_apply_preflight_modifications_handles_strings() -> None: + """Apply masking for string payloads.""" + client = GuardrailsBaseClient() + guardrail_results = [ + GuardrailResult( + tripwire_triggered=False, + info={"detected_entities": {"PHONE": ["+1-555-0100"]}}, + ) + ] + + masked = client._apply_preflight_modifications("+1-555-0100", guardrail_results) + + assert masked == "" # noqa: S101 + + +def test_apply_preflight_modifications_skips_when_no_entities() -> None: + """Return original data when no guardrail metadata exists.""" + client = GuardrailsBaseClient() + messages = [{"role": "user", "content": "Nothing to mask"}] + guardrail_results = [GuardrailResult(tripwire_triggered=False)] + + modified = client._apply_preflight_modifications(messages, guardrail_results) + + assert modified is messages # noqa: S101 + + +def test_apply_preflight_modifications_structured_content() -> None: + """Structured content parts should be masked individually.""" + client = GuardrailsBaseClient() + guardrail_results = [ + GuardrailResult( + tripwire_triggered=False, + info={"detected_entities": {"PHONE": ["123-456"]}}, + ) + ] + messages = [ + { + "role": "user", + "content": [ + {"type": "input_text", "text": "Call 123-456"}, + {"type": "json", "value": {"raw": "no change"}}, + ], + } + ] + + modified = client._apply_preflight_modifications(messages, guardrail_results) + + assert modified[0]["content"][0]["text"] == "Call " # noqa: S101 + assert modified[0]["content"][1]["value"] == {"raw": "no change"} # noqa: S101 + + +def test_apply_preflight_modifications_object_message_handles_failure() -> None: + """If object content cannot be updated, original data should be returned.""" + client = GuardrailsBaseClient() + guardrail_results = [ + GuardrailResult( + tripwire_triggered=False, + info={"detected_entities": {"NAME": ["Alice"]}}, + ) + ] + + class Message: + def __init__(self) -> None: + self.role = "user" + self.content = "Alice" + + def __setattr__(self, key: str, value: Any) -> None: + if key == "content" and hasattr(self, key): + raise RuntimeError("cannot set") + super().__setattr__(key, value) + + msg = Message() + messages = [msg] + + modified = client._apply_preflight_modifications(messages, guardrail_results) + + assert modified is messages # noqa: S101 + + +def test_apply_preflight_modifications_no_user_message() -> None: + """When no user message exists, data should be returned unchanged.""" + client = GuardrailsBaseClient() + guardrail_results = [GuardrailResult(tripwire_triggered=False, info={"detected_entities": {"NAME": ["Alice"]}})] + messages = [{"role": "assistant", "content": "hi"}] + + modified = client._apply_preflight_modifications(messages, guardrail_results) + + assert modified is messages # noqa: S101 + + +def test_apply_preflight_modifications_non_dict_part_preserved() -> None: + """Non-dict content parts should be preserved as-is.""" + client = GuardrailsBaseClient() + guardrail_results = [GuardrailResult(tripwire_triggered=False, 
info={"detected_entities": {"NAME": ["Alice"]}})] + messages = [ + { + "role": "user", + "content": ["raw text"], + } + ] + + modified = client._apply_preflight_modifications(messages, guardrail_results) + + assert modified[0]["content"][0] == "raw text" # noqa: S101 + + +def test_create_guardrails_response_wraps_results() -> None: + """Combine guardrail results by stage for response.""" + client = GuardrailsBaseClient() + preflight = [GuardrailResult(tripwire_triggered=True)] + input_stage = [GuardrailResult(tripwire_triggered=False)] + output_stage = [GuardrailResult(tripwire_triggered=True)] + + response = client._create_guardrails_response( + llm_response=SimpleNamespace(choices=[]), + preflight_results=preflight, + input_results=input_stage, + output_results=output_stage, + ) + + assert isinstance(response, GuardrailsResponse) # noqa: S101 + assert response.guardrail_results.tripwires_triggered is True # noqa: S101 + assert len(response.guardrail_results.all_results) == 3 # noqa: S101 + + +def test_extract_response_text_prefers_choice_message() -> None: + """Extract message content from chat-style responses.""" + client = GuardrailsBaseClient() + response = SimpleNamespace( + choices=[ + SimpleNamespace( + message=SimpleNamespace(content="hello"), + delta=SimpleNamespace(content=None), + ) + ], + output_text=None, + delta=None, + ) + + text = client._extract_response_text(response) + + assert text == "hello" # noqa: S101 + + +def test_extract_response_text_handles_delta_type() -> None: + """Special delta responses should return delta text.""" + client = GuardrailsBaseClient() + response = SimpleNamespace(type="response.output_text.delta", delta="partial") + + assert client._extract_response_text(response) == "partial" # noqa: S101 + + +class _DummyResourceClient: + """Stub OpenAI resource client used during initialization tests.""" + + def __init__(self, **kwargs: Any) -> None: + self.kwargs = kwargs + + +class _TestableClient(GuardrailsBaseClient): + """Concrete subclass exposing _initialize_client for testing.""" + + def __init__(self) -> None: + self.override_called = False + + def _instantiate_all_guardrails(self) -> dict[str, list]: + return {"pre_flight": [], "input": [], "output": []} + + def _create_default_context(self) -> SimpleNamespace: + return SimpleNamespace(guardrail_llm="stub") + + def _override_resources(self) -> None: + self.override_called = True + + +def test_initialize_client_sets_pipeline_and_context() -> None: + """Ensure _initialize_client produces pipeline, guardrails, and context.""" + client = _TestableClient() + + client._initialize_client( + config={"version": 1, "output": {"version": 1, "guardrails": []}}, + openai_kwargs={"api_key": "abc"}, + client_class=_DummyResourceClient, + ) + + assert client.pipeline.pre_flight is None # type: ignore[attr-defined] # noqa: S101 + assert client.pipeline.output.guardrails == [] # type: ignore[attr-defined] # noqa: S101 + assert client.guardrails == {"pre_flight": [], "input": [], "output": []} # noqa: S101 + assert client.context.guardrail_llm == "stub" # type: ignore[attr-defined] # noqa: S101 + assert client._resource_client.kwargs["api_key"] == "abc" # type: ignore[attr-defined] # noqa: S101 + assert client.override_called is True # noqa: S101 + + +def test_instantiate_all_guardrails_uses_registry(monkeypatch: pytest.MonkeyPatch) -> None: + """_instantiate_all_guardrails should instantiate guardrails for each stage.""" + client = GuardrailsBaseClient() + client.pipeline = SimpleNamespace( + 
pre_flight=SimpleNamespace(), + input=None, + output=SimpleNamespace(), + ) + + instantiated: list[str] = [] + + def fake_instantiate(stage: Any, registry: Any) -> list[str]: + instantiated.append(str(stage)) + return ["g"] + + monkeypatch.setattr("guardrails.runtime.instantiate_guardrails", fake_instantiate) + + guardrails = client._instantiate_all_guardrails() + + assert guardrails["pre_flight"] == ["g"] # noqa: S101 + assert guardrails["input"] == [] # noqa: S101 + assert guardrails["output"] == ["g"] # noqa: S101 + assert len(instantiated) == 2 # noqa: S101 + + +def test_validate_context_invokes_validator(monkeypatch: pytest.MonkeyPatch) -> None: + """_validate_context should call validate_guardrail_context for each guardrail.""" + client = GuardrailsBaseClient() + guardrail = SimpleNamespace() + client.guardrails = {"pre_flight": [guardrail]} + + called: list[Any] = [] + + def fake_validate(gr: Any, ctx: Any) -> None: + called.append((gr, ctx)) + + monkeypatch.setattr("guardrails._base_client.validate_guardrail_context", fake_validate) + + client._validate_context(context="ctx") + + assert called == [(guardrail, "ctx")] # noqa: S101 + + +def test_apply_preflight_modifications_leaves_unknown_content() -> None: + """Unknown content types should remain untouched.""" + client = GuardrailsBaseClient() + result = GuardrailResult(tripwire_triggered=False, info={"detected_entities": {"NAME": ["Alice"]}}) + messages = [{"role": "user", "content": {"unknown": "value"}}] + + modified = client._apply_preflight_modifications(messages, [result]) + + assert modified is messages # noqa: S101 + + +def test_apply_preflight_modifications_non_string_text_retained() -> None: + """Content parts without string text should remain unchanged.""" + client = GuardrailsBaseClient() + result = GuardrailResult(tripwire_triggered=False, info={"detected_entities": {"PHONE": ["123"]}}) + messages = [ + { + "role": "user", + "content": [ + {"type": "input_text", "text": 123}, + ], + } + ] + + modified = client._apply_preflight_modifications(messages, [result]) + + assert modified[0]["content"][0]["text"] == 123 # noqa: S101 + + +def test_extract_latest_user_message_object_parts() -> None: + """Object messages with attribute content should be handled.""" + client = GuardrailsBaseClient() + + class Msg: + def __init__(self, role: str, content: Any) -> None: + self.role = role + self.content = content + + messages = [ + Msg("assistant", "ignored"), + Msg("user", [SimpleNamespace(type="input_text", text="obj text")]), + ] + + text, index = client._extract_latest_user_message(messages) + + assert text == "obj text" # noqa: S101 + assert index == 1 # noqa: S101 + + +def test_extract_response_text_fallback_returns_empty() -> None: + """Unknown response types should return empty string.""" + client = GuardrailsBaseClient() + response = SimpleNamespace(choices=[], output_text=None, delta=None) + + assert client._extract_response_text(response) == "" # noqa: S101 + + +def test_guardrail_results_properties() -> None: + """GuardrailResults should aggregate and report tripwires.""" + results = GuardrailResults( + preflight=[GuardrailResult(tripwire_triggered=False)], + input=[GuardrailResult(tripwire_triggered=True)], + output=[GuardrailResult(tripwire_triggered=False)], + ) + + assert len(results.all_results) == 3 # noqa: S101 + assert results.tripwires_triggered is True # noqa: S101 + assert results.triggered_results == [results.input[0]] # noqa: S101 + + +def test_create_default_context_raises_without_subclass() -> None: + 
"""Base implementation should raise when no context available.""" + client = GuardrailsBaseClient() + + with pytest.raises(NotImplementedError): + client._create_default_context() + + +def test_create_default_context_uses_existing_context() -> None: + """Existing context var should be returned.""" + existing = GuardrailsContext(guardrail_llm="ctx") + guardrails_context.set_context(existing) + try: + client = GuardrailsBaseClient() + assert client._create_default_context() is existing # noqa: S101 + finally: + guardrails_context.clear_context() diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py new file mode 100644 index 0000000..e9f8e26 --- /dev/null +++ b/tests/unit/test_cli.py @@ -0,0 +1,72 @@ +"""Tests for guardrails CLI entry points.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from guardrails import cli + + +def _make_guardrail(media_type: str) -> Any: + return SimpleNamespace(definition=SimpleNamespace(media_type=media_type)) + + +def test_cli_validate_success(capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch) -> None: + """Validate command should report total and matching guardrails.""" + + class FakeStage: + pass + + class FakePipeline: + def __init__(self) -> None: + self.pre_flight = FakeStage() + self.input = FakeStage() + self.output = FakeStage() + + def stages(self) -> list[FakeStage]: + return [self.pre_flight, self.input, self.output] + + pipeline = FakePipeline() + + def fake_load_pipeline_bundles(path: Any) -> FakePipeline: + assert str(path).endswith("config.json") # noqa: S101 + return pipeline + + def fake_instantiate_guardrails(stage: Any, registry: Any | None = None) -> list[Any]: + if stage is pipeline.pre_flight: + return [_make_guardrail("text/plain")] + if stage is pipeline.input: + return [_make_guardrail("application/json")] + if stage is pipeline.output: + return [_make_guardrail("text/plain")] + return [] + + monkeypatch.setattr(cli, "load_pipeline_bundles", fake_load_pipeline_bundles) + monkeypatch.setattr(cli, "instantiate_guardrails", fake_instantiate_guardrails) + + with pytest.raises(SystemExit) as excinfo: + cli.main(["validate", "config.json", "--media-type", "text/plain"]) + + assert excinfo.value.code == 0 # noqa: S101 + stdout = capsys.readouterr().out + assert "Config valid" in stdout # noqa: S101 + assert "2 matching media-type 'text/plain'" in stdout # noqa: S101 + + +def test_cli_validate_handles_errors(capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch) -> None: + """Validation errors should print to stderr and exit with status 1.""" + + def fake_load_pipeline_bundles(path: Any) -> None: + raise ValueError("failed to load") + + monkeypatch.setattr(cli, "load_pipeline_bundles", fake_load_pipeline_bundles) + + with pytest.raises(SystemExit) as excinfo: + cli.main(["validate", "bad.json"]) + + assert excinfo.value.code == 1 # noqa: S101 + stderr = capsys.readouterr().err + assert "ERROR: failed to load" in stderr # noqa: S101 diff --git a/tests/unit/test_client_async.py b/tests/unit/test_client_async.py new file mode 100644 index 0000000..ecfcb56 --- /dev/null +++ b/tests/unit/test_client_async.py @@ -0,0 +1,419 @@ +"""Tests for GuardrailsAsyncOpenAI core behaviour.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +import guardrails.client as client_module +from guardrails.client import GuardrailsAsyncAzureOpenAI, GuardrailsAsyncOpenAI +from 
guardrails.exceptions import GuardrailTripwireTriggered +from guardrails.types import GuardrailResult + + +def _minimal_config() -> dict[str, Any]: + """Return minimal pipeline config with no guardrails.""" + return {"version": 1, "output": {"version": 1, "guardrails": []}} + + +def _build_client(**kwargs: Any) -> GuardrailsAsyncOpenAI: + """Instantiate GuardrailsAsyncOpenAI with deterministic defaults.""" + return GuardrailsAsyncOpenAI(config=_minimal_config(), **kwargs) + + +def _guardrail(name: str) -> Any: + return SimpleNamespace(definition=SimpleNamespace(name=name), ctx_requirements=SimpleNamespace()) + + +@pytest.mark.asyncio +async def test_default_context_uses_distinct_guardrail_client() -> None: + """Default context should hold a fresh AsyncOpenAI instance mirroring config.""" + client = _build_client(api_key="secret-key", base_url="http://example.com") + + assert client.context is not None # noqa: S101 + assert client.context.guardrail_llm is not client # type: ignore[attr-defined] # noqa: S101 + assert client.context.guardrail_llm.api_key == "secret-key" # type: ignore[attr-defined] # noqa: S101 + assert client.context.guardrail_llm.base_url == "http://example.com" # type: ignore[attr-defined] # noqa: S101 + + +@pytest.mark.asyncio +async def test_conversation_context_tracks_injection_indices() -> None: + """Conversation-aware context exposes history and propagates index updates.""" + client = _build_client() + conversation = [{"role": "user", "content": "Hello"}] + + conv_ctx = client._create_context_with_conversation(conversation) + + assert conv_ctx.get_conversation_history() == conversation # noqa: S101 + assert conv_ctx.get_injection_last_checked_index() == 0 # noqa: S101 + + conv_ctx.update_injection_last_checked_index(3) + assert client._injection_last_checked_index == 3 # noqa: S101 + + +def test_append_llm_response_handles_string_history() -> None: + """String conversation history should be normalized before appending.""" + client = _build_client() + response = SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content="assistant reply"))], + output=None, + ) + + updated_history = client._append_llm_response_to_conversation("hi there", response) + + assert updated_history[0]["content"] == "hi there" # noqa: S101 + assert updated_history[1].message.content == "assistant reply" # type: ignore[union-attr] # noqa: S101 + + +def test_append_llm_response_handles_response_output() -> None: + """Responses API output should be appended as-is.""" + client = _build_client() + response = SimpleNamespace( + choices=None, + output=[{"role": "assistant", "content": "streamed"}], + ) + + updated_history = client._append_llm_response_to_conversation([], response) + + assert updated_history == [{"role": "assistant", "content": "streamed"}] # noqa: S101 + + +def _guardrail(name: str) -> Any: + """Create a guardrail stub with a definition name.""" + return SimpleNamespace(definition=SimpleNamespace(name=name), ctx_requirements=SimpleNamespace()) + + +@pytest.mark.asyncio +async def test_run_stage_guardrails_raises_on_tripwire(monkeypatch: pytest.MonkeyPatch) -> None: + """Tripwire results should raise unless suppressed.""" + client = _build_client() + client.guardrails["output"] = [_guardrail("basic guardrail")] + captured_ctx = {} + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + captured_ctx.update(kwargs) + return [GuardrailResult(tripwire_triggered=True)] + + monkeypatch.setattr("guardrails.client.run_guardrails", fake_run_guardrails) + + 
with pytest.raises(GuardrailTripwireTriggered): + await client._run_stage_guardrails("output", "payload") + + assert captured_ctx["ctx"] is client.context # noqa: S101 + assert captured_ctx["stage_name"] == "output" # noqa: S101 + + +@pytest.mark.asyncio +async def test_run_stage_guardrails_uses_conversation_context(monkeypatch: pytest.MonkeyPatch) -> None: + """Prompt injection guardrail should trigger conversation-aware context.""" + client = _build_client() + client.guardrails["output"] = [_guardrail("Prompt Injection Detection")] + conversation = [{"role": "user", "content": "Hi"}] + captured_ctx = {} + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + captured_ctx.update(kwargs) + return [GuardrailResult(tripwire_triggered=False)] + + monkeypatch.setattr("guardrails.client.run_guardrails", fake_run_guardrails) + + results = await client._run_stage_guardrails("output", "payload", conversation_history=conversation) + + assert results == [GuardrailResult(tripwire_triggered=False)] # noqa: S101 + ctx = captured_ctx["ctx"] + assert ctx.get_conversation_history() == conversation # noqa: S101 + + +@pytest.mark.asyncio +async def test_run_stage_guardrails_suppresses_tripwire(monkeypatch: pytest.MonkeyPatch) -> None: + """Suppress flag should return results even when tripwire fires.""" + client = _build_client() + client.guardrails["output"] = [_guardrail("basic guardrail")] + captured_kwargs = {} + result = GuardrailResult(tripwire_triggered=True) + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + captured_kwargs.update(kwargs) + return [result] + + monkeypatch.setattr("guardrails.client.run_guardrails", fake_run_guardrails) + + results = await client._run_stage_guardrails("output", "payload", suppress_tripwire=True) + + assert results == [result] # noqa: S101 + assert captured_kwargs["suppress_tripwire"] is True # noqa: S101 + + +@pytest.mark.asyncio +async def test_handle_llm_response_runs_output_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """_handle_llm_response should append conversation and return response wrapper.""" + client = _build_client() + output_result = GuardrailResult(tripwire_triggered=False) + captured_text: list[str] = [] + captured_history: list[list[Any]] = [] + + async def fake_run_stage( + stage_name: str, + text: str, + conversation_history: list | None = None, + suppress_tripwire: bool = False, + ) -> list[GuardrailResult]: + captured_text.append(text) + if conversation_history is not None: + captured_history.append(conversation_history) + return [output_result] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + + llm_response = SimpleNamespace( + choices=[ + SimpleNamespace( + message=SimpleNamespace(content="LLM response"), + delta=SimpleNamespace(content=None), + ) + ], + output_text=None, + ) + + response = await client._handle_llm_response( + llm_response, + preflight_results=[GuardrailResult(tripwire_triggered=False)], + input_results=[], + conversation_history=[{"role": "user", "content": "hello"}], + ) + + assert captured_text == ["LLM response"] # noqa: S101 + assert captured_history[-1][-1].message.content == "LLM response" # type: ignore[index] # noqa: S101 + assert response.guardrail_results.output == [output_result] # noqa: S101 + + +@pytest.mark.asyncio +async def test_chat_completions_create_runs_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """chat.completions.create should execute guardrail stages.""" + client = _build_client() + 
client.guardrails = { + "pre_flight": [_guardrail("Prompt Injection Detection")], + "input": [_guardrail("Input Guard")], + "output": [_guardrail("Output Guard")], + } + stage_calls: list[str] = [] + + async def fake_run_stage(stage_name: str, text: str, **kwargs: Any) -> list[GuardrailResult]: + stage_calls.append(stage_name) + return [GuardrailResult(tripwire_triggered=False, info={"stage": stage_name})] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + monkeypatch.setattr(client, "_apply_preflight_modifications", lambda messages, results: messages) # type: ignore[attr-defined] + + async def fake_llm(**kwargs: Any) -> Any: + return SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content="ok"), delta=SimpleNamespace(content=None))], + output=None, + output_text=None, + ) + + client._resource_client.chat = SimpleNamespace(completions=SimpleNamespace(create=fake_llm)) # type: ignore[attr-defined] + + response = await client.chat.completions.create(messages=[{"role": "user", "content": "hi"}], model="gpt") + + assert stage_calls[:2] == ["pre_flight", "input"] # noqa: S101 + assert response.guardrail_results.output[0].info["stage"] == "output" # noqa: S101 + + +@pytest.mark.asyncio +async def test_chat_completions_create_streaming(monkeypatch: pytest.MonkeyPatch) -> None: + """Streaming path should defer to _stream_with_guardrails.""" + client = _build_client() + client.guardrails = {"pre_flight": [], "input": [], "output": []} + + def fake_stream_with_guardrails(*args: Any, **kwargs: Any): + async def _gen(): + yield "chunk" + + return _gen() + + monkeypatch.setattr(client, "_stream_with_guardrails", fake_stream_with_guardrails) # type: ignore[attr-defined] + + async def fake_llm(**kwargs: Any) -> Any: + async def _aiter(): + yield SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content="c"))]) + + return _aiter() + + client._resource_client.chat = SimpleNamespace(completions=SimpleNamespace(create=fake_llm)) # type: ignore[attr-defined] + + stream = await client.chat.completions.create(messages=[{"role": "user", "content": "hi"}], model="gpt", stream=True) + + chunks = [] + async for value in stream: + chunks.append(value) + + assert chunks == ["chunk"] # noqa: S101 + + +@pytest.mark.asyncio +async def test_responses_create_runs_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """responses.create should run guardrail stages and handle output.""" + client = _build_client() + client.guardrails = {"pre_flight": [], "input": [_guardrail("Input Guard")], "output": [_guardrail("Output Guard")]} + stage_calls: list[str] = [] + + async def fake_run_stage(stage_name: str, text: str, **kwargs: Any) -> list[GuardrailResult]: + stage_calls.append(stage_name) + return [GuardrailResult(tripwire_triggered=False)] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + monkeypatch.setattr(client, "_apply_preflight_modifications", lambda messages, results: messages) # type: ignore[attr-defined] + + async def fake_llm(**kwargs: Any) -> Any: + return SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content="ok"), delta=SimpleNamespace(content=None))], + output=None, + output_text=None, + ) + + client._resource_client.responses = SimpleNamespace(create=fake_llm) # type: ignore[attr-defined] + + result = await client.responses.create(input=[{"role": "user", "content": "hi"}], model="gpt") + + assert "input" in stage_calls # noqa: S101 + assert 
result.guardrail_results.output[0].tripwire_triggered is False # noqa: S101 + + +@pytest.mark.asyncio +async def test_responses_parse_runs_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """responses.parse should invoke guardrails and return wrapped response.""" + client = _build_client() + client.guardrails = {"pre_flight": [], "input": [_guardrail("Input Guard")], "output": []} + + async def fake_run_stage(stage_name: str, text: str, **kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=False)] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + monkeypatch.setattr(client, "_apply_preflight_modifications", lambda messages, results: messages) # type: ignore[attr-defined] + + async def fake_llm(**kwargs: Any) -> Any: + return SimpleNamespace(output_text="{}", output=[{"type": "message", "content": "parsed"}]) + + client._resource_client.responses = SimpleNamespace(parse=fake_llm) # type: ignore[attr-defined] + + result = await client.responses.parse(input=[{"role": "user", "content": "hi"}], model="gpt", text_format=dict) + + assert result.guardrail_results.input[0].tripwire_triggered is False # noqa: S101 + + +@pytest.mark.asyncio +async def test_responses_retrieve_runs_output_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """responses.retrieve should execute output guardrails.""" + client = _build_client() + client.guardrails = {"pre_flight": [], "input": [], "output": [_guardrail("Output Guard")]} + + async def fake_run_stage(stage_name: str, text: str, **kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=False, info={"stage": stage_name})] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + + async def retrieve_response(*args: Any, **kwargs: Any) -> Any: + return SimpleNamespace(output_text="hi") + + client._resource_client.responses = SimpleNamespace(retrieve=retrieve_response) # type: ignore[attr-defined] + + result = await client.responses.retrieve("resp") + + assert result.guardrail_results.output[0].info["stage"] == "output" # noqa: S101 + + +@pytest.mark.asyncio +async def test_async_azure_run_stage_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """Azure async client should reuse conversation context.""" + client = GuardrailsAsyncAzureOpenAI(config=_minimal_config(), api_key="key") + client.guardrails = {"output": [_guardrail("Prompt Injection Detection")]} + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=False)] + + monkeypatch.setattr(client_module, "run_guardrails", fake_run_guardrails) + + results = await client._run_stage_guardrails("output", "payload", conversation_history=[{"role": "user", "content": "hi"}]) + + assert results[0].tripwire_triggered is False # noqa: S101 + + +@pytest.mark.asyncio +async def test_async_azure_default_context() -> None: + """Azure async client should provide default context when needed.""" + client = GuardrailsAsyncAzureOpenAI(config=_minimal_config(), api_key="key") + context = client._create_default_context() + + assert hasattr(context, "guardrail_llm") # noqa: S101 + + +@pytest.mark.asyncio +async def test_async_azure_append_response() -> None: + """Azure async append helper should merge responses.""" + client = GuardrailsAsyncAzureOpenAI(config=_minimal_config(), api_key="key") + history = client._append_llm_response_to_conversation(None, SimpleNamespace(output=[{"role": 
"assistant", "content": "data"}], choices=None)) + + assert history[-1]["content"] == "data" # type: ignore[index] # noqa: S101 + + +@pytest.mark.asyncio +async def test_async_azure_handle_llm_response(monkeypatch: pytest.MonkeyPatch) -> None: + """Azure async _handle_llm_response should call output guardrails.""" + client = GuardrailsAsyncAzureOpenAI(config=_minimal_config(), api_key="key") + client.guardrails = {"output": [_guardrail("Output")], "pre_flight": [], "input": []} + + async def fake_run_stage(stage_name: str, text: str, **kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=False)] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + + sentinel = object() + + def fake_create_response(*args: Any, **kwargs: Any) -> Any: + return sentinel + + monkeypatch.setattr(client, "_create_guardrails_response", fake_create_response) # type: ignore[attr-defined] + + result = await client._handle_llm_response( + llm_response=SimpleNamespace(output_text="value", choices=[]), + preflight_results=[], + input_results=[], + conversation_history=[], + ) + + assert result is sentinel # noqa: S101 + + +@pytest.mark.asyncio +async def test_async_azure_context_with_conversation() -> None: + """Azure async conversation context should track indices.""" + client = GuardrailsAsyncAzureOpenAI(config=_minimal_config(), api_key="key") + ctx = client._create_context_with_conversation([{"role": "user", "content": "hi"}]) + + assert ctx.get_conversation_history()[0]["content"] == "hi" # type: ignore[index] # noqa: S101 + ctx.update_injection_last_checked_index(3) + assert client._injection_last_checked_index == 3 # noqa: S101 + + +@pytest.mark.asyncio +async def test_async_azure_run_stage_guardrails_suppressed(monkeypatch: pytest.MonkeyPatch) -> None: + """Tripwire should be suppressed when requested.""" + client = GuardrailsAsyncAzureOpenAI(config=_minimal_config(), api_key="key") + client.guardrails = {"output": [_guardrail("Prompt Injection Detection")]} + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=True)] + + monkeypatch.setattr(client_module, "run_guardrails", fake_run_guardrails) + + results = await client._run_stage_guardrails( + "output", + "payload", + conversation_history=[{"role": "user", "content": "hi"}], + suppress_tripwire=True, + ) + + assert results[0].tripwire_triggered is True # noqa: S101 diff --git a/tests/unit/test_client_sync.py b/tests/unit/test_client_sync.py new file mode 100644 index 0000000..0540d45 --- /dev/null +++ b/tests/unit/test_client_sync.py @@ -0,0 +1,616 @@ +"""Tests for GuardrailsOpenAI synchronous client behaviour.""" + +from __future__ import annotations + +import asyncio +from types import SimpleNamespace +from typing import Any + +import pytest + +import guardrails.client as client_module +import guardrails.context as guardrails_context +from guardrails._base_client import GuardrailsResponse +from guardrails.client import ( + GuardrailsAsyncAzureOpenAI, + GuardrailsAzureOpenAI, + GuardrailsOpenAI, +) +from guardrails.context import GuardrailsContext +from guardrails.exceptions import GuardrailTripwireTriggered +from guardrails.types import GuardrailResult + + +def _minimal_config() -> dict[str, Any]: + """Return minimal pipeline config with no guardrails.""" + return {"version": 1, "output": {"version": 1, "guardrails": []}} + + +def _build_client(**kwargs: Any) -> GuardrailsOpenAI: + """Instantiate GuardrailsOpenAI 
with deterministic defaults.""" + return GuardrailsOpenAI(config=_minimal_config(), **kwargs) + + +def _guardrail(name: str) -> Any: + """Create a guardrail stub with a definition name.""" + return SimpleNamespace(definition=SimpleNamespace(name=name), ctx_requirements=SimpleNamespace()) + + +@pytest.fixture(autouse=True) +def reset_context() -> None: + guardrails_context.clear_context() + yield + guardrails_context.clear_context() + + +def test_default_context_uses_distinct_guardrail_client() -> None: + """Default context should hold a fresh OpenAI instance mirroring config.""" + client = _build_client(api_key="secret-key", base_url="http://example.com") + + assert client.context is not None # noqa: S101 + assert client.context.guardrail_llm is not client # type: ignore[attr-defined] # noqa: S101 + assert client.context.guardrail_llm.api_key == "secret-key" # type: ignore[attr-defined] # noqa: S101 + assert client.context.guardrail_llm.base_url == "http://example.com" # type: ignore[attr-defined] # noqa: S101 + + +def test_conversation_context_tracks_injection_indices() -> None: + """Conversation-aware context exposes history and propagates index updates.""" + client = _build_client() + conversation = [{"role": "user", "content": "Hello"}] + + conv_ctx = client._create_context_with_conversation(conversation) + + assert conv_ctx.get_conversation_history() == conversation # noqa: S101 + assert conv_ctx.get_injection_last_checked_index() == 0 # noqa: S101 + + conv_ctx.update_injection_last_checked_index(5) + assert client._injection_last_checked_index == 5 # noqa: S101 + + +def test_create_default_context_uses_contextvar() -> None: + """Existing context should be reused by derived client.""" + existing = GuardrailsContext(guardrail_llm="existing") + guardrails_context.set_context(existing) + try: + client = _build_client() + assert client._create_default_context() is existing # noqa: S101 + finally: + guardrails_context.clear_context() + + +def test_append_llm_response_handles_string_history() -> None: + """String conversation history should be normalized before appending.""" + client = _build_client() + response = SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content="assistant reply"))], + output=None, + ) + + updated_history = client._append_llm_response_to_conversation("hi there", response) + + assert updated_history[0]["content"] == "hi there" # noqa: S101 + assert updated_history[1].message.content == "assistant reply" # type: ignore[union-attr] # noqa: S101 + + +def test_append_llm_response_handles_response_output() -> None: + """Responses API output should be appended as-is.""" + client = _build_client() + response = SimpleNamespace( + choices=None, + output=[{"role": "assistant", "content": "streamed"}], + ) + + updated_history = client._append_llm_response_to_conversation([], response) + + assert updated_history == [{"role": "assistant", "content": "streamed"}] # noqa: S101 + + +def test_append_llm_response_handles_none_history() -> None: + """None conversation history should be converted to list.""" + client = _build_client() + response = SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content="assistant reply"))], + output=None, + ) + + history = client._append_llm_response_to_conversation(None, response) + + assert history[-1].message.content == "assistant reply" # type: ignore[union-attr] # noqa: S101 + + +def test_run_stage_guardrails_raises_on_tripwire(monkeypatch: pytest.MonkeyPatch) -> None: + """Tripwire results should raise unless 
suppressed.""" + client = _build_client() + client.guardrails["output"] = [_guardrail("basic guardrail")] + captured_kwargs: dict[str, Any] = {} + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + captured_kwargs.update(kwargs) + return [GuardrailResult(tripwire_triggered=True)] + + monkeypatch.setattr("guardrails.client.run_guardrails", fake_run_guardrails) + + with pytest.raises(GuardrailTripwireTriggered): + client._run_stage_guardrails("output", "payload") + + assert captured_kwargs["ctx"] is client.context # noqa: S101 + assert captured_kwargs["stage_name"] == "output" # noqa: S101 + + +def test_run_stage_guardrails_uses_conversation_context(monkeypatch: pytest.MonkeyPatch) -> None: + """Prompt injection guardrail should trigger conversation-aware context.""" + client = _build_client() + client.guardrails["output"] = [_guardrail("Prompt Injection Detection")] + conversation = [{"role": "user", "content": "Hi"}] + captured_kwargs: dict[str, Any] = {} + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + captured_kwargs.update(kwargs) + return [GuardrailResult(tripwire_triggered=False)] + + monkeypatch.setattr("guardrails.client.run_guardrails", fake_run_guardrails) + + results = client._run_stage_guardrails("output", "payload", conversation_history=conversation) + + assert results == [GuardrailResult(tripwire_triggered=False)] # noqa: S101 + ctx = captured_kwargs["ctx"] + assert ctx.get_conversation_history() == conversation # noqa: S101 + + +def test_run_stage_guardrails_suppresses_tripwire(monkeypatch: pytest.MonkeyPatch) -> None: + """Suppress flag should return results even when tripwire fires.""" + client = _build_client() + client.guardrails["output"] = [_guardrail("basic guardrail")] + result = GuardrailResult(tripwire_triggered=True) + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + return [result] + + monkeypatch.setattr("guardrails.client.run_guardrails", fake_run_guardrails) + + results = client._run_stage_guardrails("output", "payload", suppress_tripwire=True) + + assert results == [result] # noqa: S101 + + +def test_run_stage_guardrails_handles_empty_guardrails() -> None: + """If no guardrails are configured for the stage, return empty list.""" + client = _build_client() + client.guardrails["input"] = [] + + assert client._run_stage_guardrails("input", "text") == [] # noqa: S101 + + +def test_run_stage_guardrails_raises_on_error(monkeypatch: pytest.MonkeyPatch) -> None: + """Exceptions should propagate when raise_guardrail_errors is True.""" + client = _build_client() + client.guardrails["output"] = [_guardrail("guard")] + client.raise_guardrail_errors = True + + async def failing_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + raise RuntimeError("boom") + + monkeypatch.setattr(client_module, "run_guardrails", failing_run_guardrails) + + with pytest.raises(RuntimeError): + client._run_stage_guardrails("output", "payload") + + +def test_run_stage_guardrails_updates_conversation_index(monkeypatch: pytest.MonkeyPatch) -> None: + """Prompt injection guardrail should update injection index after run.""" + client = _build_client() + guardrail = _guardrail("Prompt Injection Detection") + client.guardrails["output"] = [guardrail] + client._injection_last_checked_index = 0 + + captured_ctx: list[Any] = [] + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + captured_ctx.append(kwargs["ctx"]) + return [GuardrailResult(tripwire_triggered=False)] + + 
monkeypatch.setattr(client_module, "run_guardrails", fake_run_guardrails) + + client._run_stage_guardrails("output", "payload", conversation_history=[{"role": "user", "content": "hi"}]) + + assert captured_ctx[0].get_conversation_history() == [{"role": "user", "content": "hi"}] # noqa: S101 + + +def test_run_stage_guardrails_creates_event_loop(monkeypatch: pytest.MonkeyPatch) -> None: + """GuardrailsOpenAI should create a new loop when none is running.""" + client = _build_client() + client.guardrails["output"] = [_guardrail("guard")] + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=False)] + + monkeypatch.setattr(client_module, "run_guardrails", fake_run_guardrails) + + original_new_event_loop = asyncio.new_event_loop + loops: list[asyncio.AbstractEventLoop] = [] + + def fake_get_event_loop() -> asyncio.AbstractEventLoop: + raise RuntimeError + + def fake_new_event_loop() -> asyncio.AbstractEventLoop: + loop = original_new_event_loop() + loops.append(loop) + return loop + + monkeypatch.setattr(asyncio, "get_event_loop", fake_get_event_loop) + monkeypatch.setattr(asyncio, "new_event_loop", fake_new_event_loop) + monkeypatch.setattr(asyncio, "set_event_loop", lambda loop: None) + + try: + result = client._run_stage_guardrails("output", "payload") + assert result[0].tripwire_triggered is False # noqa: S101 + finally: + for loop in loops: + loop.close() + + +def test_handle_llm_response_runs_output_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """_handle_llm_response should append conversation and return response wrapper.""" + client = _build_client() + output_result = GuardrailResult(tripwire_triggered=False) + captured_text: list[str] = [] + captured_history: list[list[Any]] = [] + + def fake_run_stage( + stage_name: str, + text: str, + conversation_history: list | None = None, + suppress_tripwire: bool = False, + ) -> list[GuardrailResult]: + captured_text.append(text) + if conversation_history is not None: + captured_history.append(conversation_history) + return [output_result] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + + llm_response = SimpleNamespace( + choices=[ + SimpleNamespace( + message=SimpleNamespace(content="LLM response"), + delta=SimpleNamespace(content=None), + ) + ], + output_text=None, + ) + + response = client._handle_llm_response( + llm_response, + preflight_results=[GuardrailResult(tripwire_triggered=False)], + input_results=[], + conversation_history=[{"role": "user", "content": "hello"}], + ) + + assert captured_text == ["LLM response"] # noqa: S101 + assert captured_history[-1][-1].message.content == "LLM response" # type: ignore[index] # noqa: S101 + assert response.guardrail_results.output == [output_result] # noqa: S101 + + +def test_handle_llm_response_suppresses_tripwire(monkeypatch: pytest.MonkeyPatch) -> None: + """Suppress flag should return results even when output guardrail trips.""" + client = _build_client() + + def fake_run_stage( + stage_name: str, + text: str, + conversation_history: list | None = None, + suppress_tripwire: bool = False, + ) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=True)] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + + response = client._handle_llm_response( + llm_response=SimpleNamespace(output_text="value", choices=[]), + preflight_results=[], + input_results=[], + conversation_history=[], + 
suppress_tripwire=True, + ) + + assert response.guardrail_results.output[0].tripwire_triggered is True # noqa: S101 + + +def test_chat_completions_create_executes_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """chat.completions.create should execute guardrail stages.""" + client = _build_client() + client.guardrails = {"pre_flight": [_guardrail("Prompt Injection Detection")], "input": [_guardrail("Input")], "output": [_guardrail("Output")]} + stages: list[str] = [] + + def fake_run_stage(stage_name: str, text: str, **kwargs: Any) -> list[GuardrailResult]: + stages.append(stage_name) + return [GuardrailResult(tripwire_triggered=False)] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + monkeypatch.setattr(client, "_apply_preflight_modifications", lambda messages, results: messages) # type: ignore[attr-defined] + + class _InlineExecutor: + def __init__(self, *args: Any, **kwargs: Any) -> None: + _ = (args, kwargs) + + def __enter__(self) -> _InlineExecutor: + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def submit(self, fn, *args, **kwargs): + class _ImmediateFuture: + def __init__(self) -> None: + self._result = fn(*args, **kwargs) + + def result(self) -> Any: + return self._result + + return _ImmediateFuture() + + monkeypatch.setattr("guardrails.resources.chat.chat.ThreadPoolExecutor", _InlineExecutor) + + def fake_llm(**kwargs: Any) -> Any: + return SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content="ok"), delta=SimpleNamespace(content=None))], + output_text=None, + ) + + client._resource_client.chat = SimpleNamespace(completions=SimpleNamespace(create=fake_llm)) # type: ignore[attr-defined] + + sentinel = object() + + def fake_handle_response(llm_response: Any, preflight_results: list[GuardrailResult], input_results: list[GuardrailResult], **kwargs: Any) -> Any: + return sentinel + + monkeypatch.setattr(client, "_handle_llm_response", fake_handle_response) # type: ignore[attr-defined] + + result = client.chat.completions.create(messages=[{"role": "user", "content": "hi"}], model="gpt") + + assert "pre_flight" in stages and "input" in stages # noqa: S101 + assert result is sentinel # noqa: S101 + + +def test_chat_completions_create_stream(monkeypatch: pytest.MonkeyPatch) -> None: + """Streaming mode should use _stream_with_guardrails_sync.""" + client = _build_client() + client.guardrails = {"pre_flight": [], "input": [], "output": []} + + class _InlineExecutor: + def __init__(self, *args: Any, **kwargs: Any) -> None: + _ = (args, kwargs) + + def __enter__(self) -> _InlineExecutor: + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def submit(self, fn, *args, **kwargs): + class _ImmediateFuture: + def __init__(self) -> None: + self._result = fn(*args, **kwargs) + + def result(self) -> Any: + return self._result + + return _ImmediateFuture() + + monkeypatch.setattr("guardrails.resources.chat.chat.ThreadPoolExecutor", _InlineExecutor) + + def fake_llm(**kwargs: Any) -> Any: + return iter([SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content="c"))])]) + + client._resource_client.chat = SimpleNamespace(completions=SimpleNamespace(create=fake_llm)) # type: ignore[attr-defined] + monkeypatch.setattr(client, "_stream_with_guardrails_sync", lambda *args, **kwargs: ["chunk"]) # type: ignore[attr-defined] + + result = client.chat.completions.create(messages=[{"role": "user", "content": "hi"}], model="gpt", stream=True) + + 
assert result == ["chunk"] # noqa: S101 + + +def test_responses_create_executes_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """responses.create should run stages and wrap response.""" + client = _build_client() + client.guardrails = {"pre_flight": [], "input": [_guardrail("Input")], "output": [_guardrail("Output")]} + stages: list[str] = [] + + def fake_run_stage(stage_name: str, text: str, **kwargs: Any) -> list[GuardrailResult]: + stages.append(stage_name) + return [GuardrailResult(tripwire_triggered=False)] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + monkeypatch.setattr(client, "_apply_preflight_modifications", lambda messages, results: messages) # type: ignore[attr-defined] + + def fake_llm(**kwargs: Any) -> Any: + return SimpleNamespace(output_text="text", choices=[SimpleNamespace(message=SimpleNamespace(content="ok"))]) + + client._resource_client.responses = SimpleNamespace(create=fake_llm) # type: ignore[attr-defined] + + response = client.responses.create(input=[{"role": "user", "content": "hi"}], model="gpt") + + assert "input" in stages and "output" in stages # noqa: S101 + assert isinstance(response, GuardrailsResponse) # noqa: S101 + + +def test_responses_parse_executes_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """responses.parse should run guardrails and return wrapper.""" + client = _build_client() + client.guardrails = {"pre_flight": [], "input": [_guardrail("Input")], "output": []} + + def fake_run_stage(stage_name: str, text: str, **kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=False)] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + monkeypatch.setattr(client, "_apply_preflight_modifications", lambda messages, results: messages) # type: ignore[attr-defined] + + def fake_parse(**kwargs: Any) -> Any: + return SimpleNamespace(output_text="{}", output=[{"type": "message", "content": "parsed"}]) + + client._resource_client.responses = SimpleNamespace(parse=fake_parse) # type: ignore[attr-defined] + + sentinel = object() + + def fake_handle_parse(llm_response: Any, preflight_results: list[GuardrailResult], input_results: list[GuardrailResult], **kwargs: Any) -> Any: + return sentinel + + monkeypatch.setattr(client, "_handle_llm_response", fake_handle_parse) # type: ignore[attr-defined] + + response = client.responses.parse(input=[{"role": "user", "content": "hi"}], model="gpt", text_format=dict) + + assert response is sentinel # noqa: S101 + + +def test_responses_retrieve_executes_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """responses.retrieve should run output guardrails.""" + client = _build_client() + client.guardrails = {"pre_flight": [], "input": [], "output": [_guardrail("Output")]} + + def fake_run_stage(stage_name: str, text: str, **kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=False)] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + + client._resource_client.responses = SimpleNamespace(retrieve=lambda *args, **kwargs: SimpleNamespace(output_text="hi")) # type: ignore[attr-defined] + + sentinel = object() + + def fake_create_response( + response: Any, preflight: list[GuardrailResult], input_results: list[GuardrailResult], output_results: list[GuardrailResult] + ) -> Any: + return sentinel + + monkeypatch.setattr(client, "_create_guardrails_response", fake_create_response) # type: 
ignore[attr-defined] + + response = client.responses.retrieve("resp") + + assert response is sentinel # noqa: S101 + + +def test_azure_clients_initialize() -> None: + """Azure variants should initialize using azure kwargs.""" + async_client = GuardrailsAsyncAzureOpenAI(config=_minimal_config(), api_key="key", azure_param=1) + sync_client = GuardrailsAzureOpenAI(config=_minimal_config(), api_key="key", azure_param=1) + + assert async_client._azure_kwargs["azure_param"] == 1 # type: ignore[attr-defined] # noqa: S101 + assert sync_client._azure_kwargs["azure_param"] == 1 # type: ignore[attr-defined] # noqa: S101 + + +def test_azure_sync_run_stage_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """Azure sync client should run guardrails with conversation context.""" + client = GuardrailsAzureOpenAI(config=_minimal_config(), api_key="key") + client.guardrails = {"output": [_guardrail("Prompt Injection Detection")]} + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=False)] + + monkeypatch.setattr(client_module, "run_guardrails", fake_run_guardrails) + + result = client._run_stage_guardrails("output", "payload", conversation_history=[{"role": "user", "content": "hi"}]) + + assert result[0].tripwire_triggered is False # noqa: S101 + + +def test_azure_sync_append_response() -> None: + """Azure sync append helper should handle string history.""" + client = GuardrailsAzureOpenAI(config=_minimal_config(), api_key="key") + history = client._append_llm_response_to_conversation( + "hi", SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content="reply"))], output=None) + ) + + assert history[-1].message.content == "reply" # type: ignore[union-attr] # noqa: S101 + + +def test_azure_sync_handle_llm_response(monkeypatch: pytest.MonkeyPatch) -> None: + """Azure sync _handle_llm_response should call output guardrails.""" + client = GuardrailsAzureOpenAI(config=_minimal_config(), api_key="key") + client.guardrails = {"output": [_guardrail("Output")], "pre_flight": [], "input": []} + + def fake_run_stage(stage_name: str, text: str, **kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=False)] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + + sentinel = object() + + def fake_create_response(*args: Any, **kwargs: Any) -> Any: + return sentinel + + monkeypatch.setattr(client, "_create_guardrails_response", fake_create_response) # type: ignore[attr-defined] + + result = client._handle_llm_response( + llm_response=SimpleNamespace(output_text="text", choices=[]), + preflight_results=[], + input_results=[], + conversation_history=[], + ) + + assert result is sentinel # noqa: S101 + + +def test_azure_sync_context_with_conversation() -> None: + """Azure sync conversation context should track indices.""" + client = GuardrailsAzureOpenAI(config=_minimal_config(), api_key="key") + context = client._create_context_with_conversation([{"role": "user", "content": "hi"}]) + + assert context.get_conversation_history()[0]["content"] == "hi" # type: ignore[index] # noqa: S101 + context.update_injection_last_checked_index(4) + assert client._injection_last_checked_index == 4 # noqa: S101 + + +def test_azure_sync_run_stage_guardrails_suppressed(monkeypatch: pytest.MonkeyPatch) -> None: + """Tripwire should be suppressed when requested for Azure sync client.""" + client = GuardrailsAzureOpenAI(config=_minimal_config(), api_key="key") + client.guardrails = 
{"output": [_guardrail("Prompt Injection Detection")]} + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=True)] + + monkeypatch.setattr(client_module, "run_guardrails", fake_run_guardrails) + + results = client._run_stage_guardrails( + "output", + "payload", + conversation_history=[{"role": "user", "content": "hi"}], + suppress_tripwire=True, + ) + + assert results[0].tripwire_triggered is True # noqa: S101 + + +def test_handle_llm_response_suppresses_tripwire_output(monkeypatch: pytest.MonkeyPatch) -> None: + """Suppressed output guardrails should return triggered result.""" + client = _build_client() + + def fake_run_stage( + stage_name: str, + text: str, + conversation_history: list | None = None, + suppress_tripwire: bool = False, + ) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=True)] + + monkeypatch.setattr(client, "_run_stage_guardrails", fake_run_stage) # type: ignore[attr-defined] + + response = SimpleNamespace(output_text="text", choices=[]) + + result = client._handle_llm_response( + response, + preflight_results=[], + input_results=[], + conversation_history=[], + suppress_tripwire=True, + ) + + assert result.guardrail_results.output[0].tripwire_triggered is True # noqa: S101 + + +def test_override_resources_replaces_chat_and_responses() -> None: + """_override_resources should swap chat and responses objects.""" + client = _build_client() + # Manually call override to ensure replacement occurs + client._override_resources() + + assert hasattr(client.chat, "completions") # noqa: S101 + assert hasattr(client.responses, "create") # noqa: S101 diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py new file mode 100644 index 0000000..bd34790 --- /dev/null +++ b/tests/unit/test_context.py @@ -0,0 +1,37 @@ +"""Tests for guardrails.context helpers.""" + +from __future__ import annotations + +from dataclasses import FrozenInstanceError + +import pytest + +from guardrails.context import GuardrailsContext, clear_context, get_context, has_context, set_context + + +class _StubClient: + """Minimal client placeholder for GuardrailsContext.""" + + api_key = "stub" + + +def test_set_and_get_context_roundtrip() -> None: + """set_context should make context available via get_context.""" + context = GuardrailsContext(guardrail_llm=_StubClient()) + set_context(context) + + retrieved = get_context() + assert retrieved is context # noqa: S101 + assert has_context() is True # noqa: S101 + + clear_context() + assert get_context() is None # noqa: S101 + assert has_context() is False # noqa: S101 + + +def test_context_is_immutable() -> None: + """GuardrailsContext should be frozen.""" + context = GuardrailsContext(guardrail_llm=_StubClient()) + + with pytest.raises(FrozenInstanceError): + context.guardrail_llm = None # type: ignore[misc] diff --git a/tests/unit/test_registry.py b/tests/unit/test_registry.py index c6fefe0..903f70e 100644 --- a/tests/unit/test_registry.py +++ b/tests/unit/test_registry.py @@ -38,11 +38,7 @@ def check(_ctx: CtxProto, _value: str, _config: int) -> GuardrailResult: model = _resolve_ctx_requirements(check) # Prefer Pydantic v2 API without eagerly touching deprecated v1 attributes - fields = ( - model.model_fields - if hasattr(model, "model_fields") - else getattr(model, "__fields__", {}) - ) + fields = model.model_fields if hasattr(model, "model_fields") else getattr(model, "__fields__", {}) assert issubclass(model, BaseModel) # noqa: S101 assert set(fields) == 
{"foo"} # noqa: S101 diff --git a/tests/unit/test_resources_chat.py b/tests/unit/test_resources_chat.py new file mode 100644 index 0000000..2a73ca3 --- /dev/null +++ b/tests/unit/test_resources_chat.py @@ -0,0 +1,277 @@ +"""Tests for chat resource wrappers.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from guardrails.resources.chat.chat import AsyncChatCompletions, ChatCompletions + + +class _InlineExecutor: + """Minimal executor that runs submitted callables synchronously.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + _ = (args, kwargs) + + def __enter__(self) -> _InlineExecutor: + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def submit(self, fn, *args, **kwargs): + class _ImmediateFuture: + def __init__(self) -> None: + self._result = fn(*args, **kwargs) + + def result(self) -> Any: + return self._result + + return _ImmediateFuture() + + +class _SyncClient: + """Fake synchronous guardrails client for ChatCompletions tests.""" + + def __init__(self) -> None: + self.preflight_calls: list[dict[str, Any]] = [] + self.input_calls: list[dict[str, Any]] = [] + self.applied: list[Any] = [] + self.handle_calls: list[dict[str, Any]] = [] + self.stream_calls: list[dict[str, Any]] = [] + self.latest_messages: list[Any] = [] + self._resource_client = SimpleNamespace( + chat=SimpleNamespace( + completions=SimpleNamespace(create=self._llm_call), + ) + ) + self._llm_response = SimpleNamespace(type="llm") + self._stream_result = "stream" + self._handle_result = "handled" + + def _llm_call(self, **kwargs: Any) -> Any: + self.llm_kwargs = kwargs + return self._llm_response + + def _extract_latest_user_message(self, messages: list[dict[str, str]]) -> tuple[str, int]: + self.latest_messages.append(messages) + return messages[-1]["content"], len(messages) - 1 + + def _run_stage_guardrails( + self, + stage_name: str, + text: str, + conversation_history: list | None = None, + suppress_tripwire: bool = False, + ) -> list[Any]: + call = { + "stage": stage_name, + "text": text, + "history": conversation_history, + "suppress": suppress_tripwire, + } + if stage_name == "pre_flight": + self.preflight_calls.append(call) + return ["preflight"] + self.input_calls.append(call) + return ["input"] + + def _apply_preflight_modifications(self, messages: list[dict[str, str]], results: list[Any]) -> list[Any]: + self.applied.append((messages, results)) + return [{"role": "user", "content": "modified"}] + + def _handle_llm_response( + self, + llm_response: Any, + preflight_results: list[Any], + input_results: list[Any], + conversation_history: list | None = None, + suppress_tripwire: bool = False, + ) -> Any: + self.handle_calls.append( + { + "response": llm_response, + "preflight": preflight_results, + "input": input_results, + "history": conversation_history, + } + ) + return self._handle_result + + def _stream_with_guardrails_sync( + self, + llm_stream: Any, + preflight_results: list[Any], + input_results: list[Any], + suppress_tripwire: bool = False, + ) -> Any: + self.stream_calls.append( + { + "stream": llm_stream, + "preflight": preflight_results, + "input": input_results, + "suppress": suppress_tripwire, + } + ) + return self._stream_result + + +class _AsyncClient: + """Fake asynchronous client for AsyncChatCompletions tests.""" + + def __init__(self) -> None: + self.preflight_calls: list[dict[str, Any]] = [] + self.input_calls: list[dict[str, Any]] = [] + self.applied: list[Any] = [] + 
self.handle_calls: list[dict[str, Any]] = [] + self.stream_calls: list[dict[str, Any]] = [] + self.latest_messages: list[Any] = [] + self._resource_client = SimpleNamespace( + chat=SimpleNamespace( + completions=SimpleNamespace(create=self._llm_call), + ) + ) + self._llm_response = SimpleNamespace(type="llm") + self._stream_result = "async-stream" + self._handle_result = "async-handled" + + async def _llm_call(self, **kwargs: Any) -> Any: + self.llm_kwargs = kwargs + return self._llm_response + + def _extract_latest_user_message(self, messages: list[dict[str, str]]) -> tuple[str, int]: + self.latest_messages.append(messages) + return messages[-1]["content"], len(messages) - 1 + + async def _run_stage_guardrails( + self, + stage_name: str, + text: str, + conversation_history: list | None = None, + suppress_tripwire: bool = False, + ) -> list[Any]: + call = { + "stage": stage_name, + "text": text, + "history": conversation_history, + "suppress": suppress_tripwire, + } + if stage_name == "pre_flight": + self.preflight_calls.append(call) + return ["preflight"] + self.input_calls.append(call) + return ["input"] + + def _apply_preflight_modifications(self, messages: list[dict[str, str]], results: list[Any]) -> list[Any]: + self.applied.append((messages, results)) + return [{"role": "user", "content": "modified"}] + + async def _handle_llm_response( + self, + llm_response: Any, + preflight_results: list[Any], + input_results: list[Any], + conversation_history: list | None = None, + suppress_tripwire: bool = False, + ) -> Any: + self.handle_calls.append( + { + "response": llm_response, + "preflight": preflight_results, + "input": input_results, + "history": conversation_history, + } + ) + return self._handle_result + + def _stream_with_guardrails( + self, + llm_stream: Any, + preflight_results: list[Any], + input_results: list[Any], + suppress_tripwire: bool = False, + ) -> Any: + self.stream_calls.append( + { + "stream": llm_stream, + "preflight": preflight_results, + "input": input_results, + "suppress": suppress_tripwire, + } + ) + return self._stream_result + + +def _messages() -> list[dict[str, str]]: + return [ + {"role": "system", "content": "rules"}, + {"role": "user", "content": "hello"}, + ] + + +def test_chat_completions_create_invokes_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """ChatCompletions.create should run guardrails and forward modified messages.""" + client = _SyncClient() + completions = ChatCompletions(client) + + monkeypatch.setattr("guardrails.resources.chat.chat.ThreadPoolExecutor", _InlineExecutor) + + result = completions.create(messages=_messages(), model="gpt-test") + + assert result == "handled" # noqa: S101 + assert client.preflight_calls[0]["stage"] == "pre_flight" # noqa: S101 + assert client.input_calls[0]["stage"] == "input" # noqa: S101 + assert client.llm_kwargs["messages"][0]["content"] == "modified" # noqa: S101 + assert client.handle_calls[0]["preflight"] == ["preflight"] # noqa: S101 + + +def test_chat_completions_stream_returns_streaming_wrapper(monkeypatch: pytest.MonkeyPatch) -> None: + """Streaming mode should defer to _stream_with_guardrails_sync.""" + client = _SyncClient() + completions = ChatCompletions(client) + + monkeypatch.setattr("guardrails.resources.chat.chat.ThreadPoolExecutor", _InlineExecutor) + + result = completions.create(messages=_messages(), model="gpt-test", stream=True, suppress_tripwire=True) + + assert result == "stream" # noqa: S101 + stream_call = client.stream_calls[0] + assert stream_call["suppress"] is True # 
noqa: S101 + assert stream_call["preflight"] == ["preflight"] # noqa: S101 + + +@pytest.mark.asyncio +async def test_async_chat_completions_create_invokes_guardrails() -> None: + """AsyncChatCompletions.create should await guardrails and LLM call.""" + client = _AsyncClient() + completions = AsyncChatCompletions(client) + + result = await completions.create(messages=_messages(), model="gpt-test") + + assert result == "async-handled" # noqa: S101 + assert client.preflight_calls[0]["stage"] == "pre_flight" # noqa: S101 + assert client.input_calls[0]["stage"] == "input" # noqa: S101 + assert client.llm_kwargs["messages"][0]["content"] == "modified" # noqa: S101 + assert client.handle_calls[0]["preflight"] == ["preflight"] # noqa: S101 + + +@pytest.mark.asyncio +async def test_async_chat_completions_stream_returns_wrapper() -> None: + """Async streaming mode should defer to _stream_with_guardrails.""" + client = _AsyncClient() + completions = AsyncChatCompletions(client) + + result = await completions.create( + messages=_messages(), + model="gpt-test", + stream=True, + suppress_tripwire=False, + ) + + assert result == "async-stream" # noqa: S101 + stream_call = client.stream_calls[0] + assert stream_call["preflight"] == ["preflight"] # noqa: S101 + assert stream_call["input"] == ["input"] # noqa: S101 diff --git a/tests/unit/test_resources_responses.py b/tests/unit/test_resources_responses.py new file mode 100644 index 0000000..88adbe3 --- /dev/null +++ b/tests/unit/test_resources_responses.py @@ -0,0 +1,338 @@ +"""Tests for responses resource wrappers.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest +from pydantic import BaseModel + +from guardrails.resources.responses.responses import AsyncResponses, Responses + + +class _SyncResponsesClient: + """Fake synchronous guardrails client for Responses tests.""" + + def __init__(self) -> None: + self.preflight_calls: list[dict[str, Any]] = [] + self.input_calls: list[dict[str, Any]] = [] + self.output_calls: list[dict[str, Any]] = [] + self.applied: list[Any] = [] + self.handle_calls: list[dict[str, Any]] = [] + self.stream_calls: list[dict[str, Any]] = [] + self.create_calls: list[dict[str, Any]] = [] + self.parse_calls: list[dict[str, Any]] = [] + self.retrieve_calls: list[dict[str, Any]] = [] + self._llm_response = SimpleNamespace(output_text="result", type="llm") + self._stream_result = "stream" + self._handle_result = "handled" + self._resource_client = SimpleNamespace( + responses=SimpleNamespace( + create=self._llm_create, + parse=self._llm_parse, + retrieve=self._llm_retrieve, + ) + ) + + def _llm_create(self, **kwargs: Any) -> Any: + self.create_calls.append(kwargs) + return self._llm_response + + def _llm_parse(self, **kwargs: Any) -> Any: + self.parse_calls.append(kwargs) + return self._llm_response + + def _llm_retrieve(self, response_id: str, **kwargs: Any) -> Any: + self.retrieve_calls.append({"id": response_id, "kwargs": kwargs}) + return self._llm_response + + def _extract_latest_user_message(self, messages: list[dict[str, str]]) -> tuple[str, int]: + return messages[-1]["content"], len(messages) - 1 + + def _run_stage_guardrails( + self, + stage: str, + text: str, + conversation_history: list | str | None = None, + suppress_tripwire: bool = False, + ) -> list[str]: + call = { + "stage": stage, + "text": text, + "history": conversation_history, + "suppress": suppress_tripwire, + } + if stage == "pre_flight": + self.preflight_calls.append(call) + return 
["preflight"] + if stage == "input": + self.input_calls.append(call) + return ["input"] + self.output_calls.append(call) + return ["output"] + + def _apply_preflight_modifications(self, data: Any, results: list[Any]) -> Any: + self.applied.append((data, results)) + if isinstance(data, list): + return [{"role": "user", "content": "modified"}] + return "modified" + + def _handle_llm_response( + self, + llm_response: Any, + preflight_results: list[Any], + input_results: list[Any], + conversation_history: Any = None, + suppress_tripwire: bool = False, + **kwargs: Any, + ) -> Any: + self.handle_calls.append( + { + "response": llm_response, + "preflight": preflight_results, + "input": input_results, + "history": conversation_history, + "extra": kwargs, + } + ) + return self._handle_result + + def _stream_with_guardrails_sync( + self, + llm_stream: Any, + preflight_results: list[Any], + input_results: list[Any], + suppress_tripwire: bool = False, + ) -> Any: + self.stream_calls.append( + { + "stream": llm_stream, + "preflight": preflight_results, + "input": input_results, + "suppress": suppress_tripwire, + } + ) + return self._stream_result + + def _create_guardrails_response( + self, + response: Any, + preflight_results: list[Any], + input_results: list[Any], + output_results: list[Any], + ) -> Any: + self.output_calls.append({"stage": "output", "results": output_results}) + return { + "response": response, + "preflight": preflight_results, + "input": input_results, + "output": output_results, + } + + +class _AsyncResponsesClient: + """Fake asynchronous guardrails client for AsyncResponses tests.""" + + def __init__(self) -> None: + self.preflight_calls: list[dict[str, Any]] = [] + self.input_calls: list[dict[str, Any]] = [] + self.output_calls: list[dict[str, Any]] = [] + self.applied: list[Any] = [] + self.handle_calls: list[dict[str, Any]] = [] + self.stream_calls: list[dict[str, Any]] = [] + self.create_calls: list[dict[str, Any]] = [] + self._llm_response = SimpleNamespace(output_text="async", type="llm") + self._stream_result = "async-stream" + self._handle_result = "async-handled" + self._resource_client = SimpleNamespace( + responses=SimpleNamespace( + create=self._llm_create, + ) + ) + + async def _llm_create(self, **kwargs: Any) -> Any: + self.create_calls.append(kwargs) + return self._llm_response + + def _extract_latest_user_message(self, messages: list[dict[str, str]]) -> tuple[str, int]: + return messages[-1]["content"], len(messages) - 1 + + async def _run_stage_guardrails( + self, + stage: str, + text: str, + conversation_history: list | str | None = None, + suppress_tripwire: bool = False, + ) -> list[str]: + call = { + "stage": stage, + "text": text, + "history": conversation_history, + "suppress": suppress_tripwire, + } + if stage == "pre_flight": + self.preflight_calls.append(call) + return ["preflight"] + if stage == "input": + self.input_calls.append(call) + return ["input"] + self.output_calls.append(call) + return ["output"] + + def _apply_preflight_modifications(self, data: Any, results: list[Any]) -> Any: + self.applied.append((data, results)) + if isinstance(data, list): + return [{"role": "user", "content": "modified"}] + return "modified" + + async def _handle_llm_response( + self, + llm_response: Any, + preflight_results: list[Any], + input_results: list[Any], + conversation_history: Any = None, + suppress_tripwire: bool = False, + ) -> Any: + self.handle_calls.append( + { + "response": llm_response, + "preflight": preflight_results, + "input": input_results, + 
"history": conversation_history, + } + ) + return self._handle_result + + def _stream_with_guardrails( + self, + llm_stream: Any, + preflight_results: list[Any], + input_results: list[Any], + suppress_tripwire: bool = False, + ) -> Any: + self.stream_calls.append( + { + "stream": llm_stream, + "preflight": preflight_results, + "input": input_results, + "suppress": suppress_tripwire, + } + ) + return self._stream_result + + +def _messages() -> list[dict[str, str]]: + return [ + {"role": "system", "content": "rules"}, + {"role": "user", "content": "hello"}, + ] + + +def _inline_executor(monkeypatch: pytest.MonkeyPatch) -> None: + class _InlineExecutor: + def __init__(self, *args: Any, **kwargs: Any) -> None: + _ = (args, kwargs) + + def __enter__(self) -> _InlineExecutor: + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def submit(self, fn, *args, **kwargs): + class _ImmediateFuture: + def __init__(self) -> None: + self._result = fn(*args, **kwargs) + + def result(self) -> Any: + return self._result + + return _ImmediateFuture() + + monkeypatch.setattr("guardrails.resources.responses.responses.ThreadPoolExecutor", _InlineExecutor) + + +def test_responses_create_runs_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """Responses.create should apply guardrails and forward modified input.""" + client = _SyncResponsesClient() + responses = Responses(client) + _inline_executor(monkeypatch) + + result = responses.create(input=_messages(), model="gpt-test") + + assert result == "handled" # noqa: S101 + assert client.preflight_calls[0]["stage"] == "pre_flight" # noqa: S101 + assert client.input_calls[0]["stage"] == "input" # noqa: S101 + assert client.create_calls[0]["input"][0]["content"] == "modified" # noqa: S101 + + +def test_responses_create_stream_returns_stream(monkeypatch: pytest.MonkeyPatch) -> None: + """Streaming mode should call _stream_with_guardrails_sync.""" + client = _SyncResponsesClient() + responses = Responses(client) + _inline_executor(monkeypatch) + + result = responses.create(input=_messages(), model="gpt-test", stream=True, suppress_tripwire=True) + + assert result == "stream" # noqa: S101 + stream_call = client.stream_calls[0] + assert stream_call["suppress"] is True # noqa: S101 + assert stream_call["preflight"] == ["preflight"] # noqa: S101 + + +def test_responses_parse_runs_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: + """Responses.parse should run guardrails and pass modified input.""" + client = _SyncResponsesClient() + responses = Responses(client) + _inline_executor(monkeypatch) + + class _Schema(BaseModel): + text: str + + messages = _messages() + result = responses.parse(input=messages, model="gpt-test", text_format=_Schema) + + assert result == "handled" # noqa: S101 + assert client.parse_calls[0]["input"][0]["content"] == "modified" # noqa: S101 + assert client.handle_calls[0]["extra"]["conversation_data"] == messages # noqa: S101 + + +def test_responses_retrieve_wraps_output() -> None: + """Responses.retrieve should run output guardrails and wrap the response.""" + client = _SyncResponsesClient() + responses = Responses(client) + + wrapped = responses.retrieve("resp-1", suppress_tripwire=False) + + assert wrapped["response"].output_text == "result" # noqa: S101 + assert wrapped["output"] == ["output"] # noqa: S101 + assert client.retrieve_calls[0]["id"] == "resp-1" # noqa: S101 + + +@pytest.mark.asyncio +async def test_async_responses_create_runs_guardrails() -> None: + """AsyncResponses.create should await 
guardrails and modify input.""" + client = _AsyncResponsesClient() + responses = AsyncResponses(client) + + result = await responses.create(input=_messages(), model="gpt-test") + + assert result == "async-handled" # noqa: S101 + assert client.preflight_calls[0]["stage"] == "pre_flight" # noqa: S101 + assert client.input_calls[0]["stage"] == "input" # noqa: S101 + assert client.create_calls[0]["input"][0]["content"] == "modified" # noqa: S101 + + +@pytest.mark.asyncio +async def test_async_responses_stream_returns_wrapper() -> None: + """AsyncResponses streaming mode should defer to _stream_with_guardrails.""" + client = _AsyncResponsesClient() + responses = AsyncResponses(client) + + result = await responses.create(input=_messages(), model="gpt-test", stream=True) + + assert result == "async-stream" # noqa: S101 + stream_call = client.stream_calls[0] + assert stream_call["preflight"] == ["preflight"] # noqa: S101 + assert stream_call["input"] == ["input"] # noqa: S101 diff --git a/tests/unit/test_runtime.py b/tests/unit/test_runtime.py index 3eb196b..662447f 100644 --- a/tests/unit/test_runtime.py +++ b/tests/unit/test_runtime.py @@ -1,8 +1,5 @@ """Tests for the runtime module.""" -import sys -import types -from collections.abc import Iterator from dataclasses import dataclass from typing import Any, Protocol @@ -27,42 +24,6 @@ THRESHOLD = 2 -@pytest.fixture(autouse=True) -def stub_openai_module(monkeypatch: pytest.MonkeyPatch) -> Iterator[types.ModuleType]: - """Provide a stub ``openai.AsyncOpenAI`` and patch imports in guardrails.*. - - Ensures tests don't require real OPENAI_API_KEY or networked clients. - """ - module = types.ModuleType("openai") - - class AsyncOpenAI: # noqa: D401 - simple stub - """Stubbed AsyncOpenAI client.""" - - pass - - module.__dict__["AsyncOpenAI"] = AsyncOpenAI - # Ensure any downstream import finds our stub module - monkeypatch.setitem(sys.modules, "openai", module) - # Also patch already-imported symbols on guardrails modules - try: - import guardrails.runtime as gr_runtime # type: ignore - - monkeypatch.setattr(gr_runtime, "AsyncOpenAI", AsyncOpenAI, raising=False) - except Exception: - pass - try: - import guardrails.types as gr_types # type: ignore - - monkeypatch.setattr(gr_types, "AsyncOpenAI", AsyncOpenAI, raising=False) - except Exception: - pass - # Provide dummy API key to satisfy any code paths that inspect env - monkeypatch.setenv("OPENAI_API_KEY", "test-key") - - yield module - monkeypatch.delitem(sys.modules, "openai", raising=False) - - class LenCfg(BaseModel): """Configuration specifying length threshold.""" diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py new file mode 100644 index 0000000..6bb4f58 --- /dev/null +++ b/tests/unit/test_streaming.py @@ -0,0 +1,162 @@ +"""Tests for StreamingMixin behaviour.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Iterator +from dataclasses import dataclass +from typing import Any + +import pytest + +from guardrails._base_client import GuardrailsBaseClient, GuardrailsResponse +from guardrails._streaming import StreamingMixin +from guardrails.exceptions import GuardrailTripwireTriggered +from guardrails.types import GuardrailResult + + +@dataclass(frozen=True, slots=True) +class _Chunk: + """Simple chunk carrying text content.""" + + text: str + + +class _StreamingCollector(StreamingMixin, GuardrailsBaseClient): + """Minimal client exposing hooks required by StreamingMixin.""" + + def __init__(self) -> None: + self.run_calls: list[tuple[str, 
bool]] = [] + self.responses: list[GuardrailsResponse] = [] + self._next_results: list[GuardrailResult] = [] + self._should_raise = False + + def set_results(self, results: list[GuardrailResult]) -> None: + self._next_results = results + + def trigger_tripwire(self) -> None: + self._should_raise = True + + def _extract_response_text(self, chunk: _Chunk) -> str: + return chunk.text + + def _run_stage_guardrails( + self, + stage_name: str, + text: str, + suppress_tripwire: bool = False, + **kwargs: Any, + ) -> list[GuardrailResult]: + self.run_calls.append((text, suppress_tripwire)) + if self._should_raise: + from guardrails.exceptions import GuardrailTripwireTriggered + + raise GuardrailTripwireTriggered(GuardrailResult(tripwire_triggered=True)) + return self._next_results + + async def _run_stage_guardrails_async( + self, + stage_name: str, + text: str, + suppress_tripwire: bool = False, + **kwargs: Any, + ) -> list[GuardrailResult]: + return self._run_stage_guardrails(stage_name, text, suppress_tripwire=suppress_tripwire) + + def _create_guardrails_response( + self, + llm_response: Any, + preflight_results: list[GuardrailResult], + input_results: list[GuardrailResult], + output_results: list[GuardrailResult], + ) -> GuardrailsResponse: + response = super()._create_guardrails_response(llm_response, preflight_results, input_results, output_results) + self.responses.append(response) + return response + + +async def _aiter(items: list[_Chunk]) -> AsyncIterator[_Chunk]: + for item in items: + yield item + + +def test_stream_with_guardrails_sync_emits_results() -> None: + """Synchronous streaming should yield GuardrailsResponse objects with accumulated results.""" + client = _StreamingCollector() + client.set_results([GuardrailResult(tripwire_triggered=False)]) + chunks: Iterator[_Chunk] = iter([_Chunk("a"), _Chunk("b")]) + + responses = list( + client._stream_with_guardrails_sync( + chunks, + preflight_results=[GuardrailResult(tripwire_triggered=False)], + input_results=[], + check_interval=1, + ) + ) + + assert [resp.guardrail_results.output for resp in responses] == [[], []] # noqa: S101 + assert client.run_calls == [("a", False), ("ab", False), ("ab", False)] # noqa: S101 + + +@pytest.mark.asyncio +async def test_stream_with_guardrails_async_emits_results() -> None: + """Async streaming should yield GuardrailsResponse objects and run final checks.""" + client = _StreamingCollector() + + async def fake_run_stage( + stage_name: str, + text: str, + suppress_tripwire: bool = False, + **kwargs: Any, + ) -> list[GuardrailResult]: + client.run_calls.append((text, suppress_tripwire)) + return [] + + client._run_stage_guardrails = fake_run_stage # type: ignore[assignment] + + responses = [ + response + async for response in client._stream_with_guardrails( + _aiter([_Chunk("a"), _Chunk("b")]), + preflight_results=[], + input_results=[], + check_interval=2, + ) + ] + + assert len(responses) == 2 # noqa: S101 + # Final guardrail run should consume aggregated text "ab" + assert client.run_calls[-1][0] == "ab" # noqa: S101 + + +@pytest.mark.asyncio +async def test_stream_with_guardrails_async_raises_on_tripwire() -> None: + """Tripwire should abort streaming and clear accumulated text.""" + client = _StreamingCollector() + + async def fake_run_stage( + stage_name: str, + text: str, + suppress_tripwire: bool = False, + **kwargs: Any, + ) -> list[GuardrailResult]: + raise_guardrail = text == "chunk" + if raise_guardrail: + from guardrails.exceptions import GuardrailTripwireTriggered + + raise 
GuardrailTripwireTriggered(GuardrailResult(tripwire_triggered=True)) + return [] + + client._run_stage_guardrails = fake_run_stage # type: ignore[assignment] + + async def chunk_stream() -> AsyncIterator[_Chunk]: + yield _Chunk("chunk") + + with pytest.raises(GuardrailTripwireTriggered): + async for _ in client._stream_with_guardrails( + chunk_stream(), + preflight_results=[], + input_results=[], + check_interval=1, + ): + pass diff --git a/tests/unit/utils/test_create_vector_store.py b/tests/unit/utils/test_create_vector_store.py new file mode 100644 index 0000000..29f6a43 --- /dev/null +++ b/tests/unit/utils/test_create_vector_store.py @@ -0,0 +1,69 @@ +"""Tests for create_vector_store helper.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from guardrails.utils.create_vector_store import SUPPORTED_FILE_TYPES, create_vector_store_from_path + + +class _FakeAsyncOpenAI: + def __init__(self) -> None: + self._vector_store_id = "vs_123" + self._file_counter = 0 + self._file_status: list[str] = [] + + async def create_vector_store(name: str) -> SimpleNamespace: + _ = name + return SimpleNamespace(id=self._vector_store_id) + + async def add_file(vector_store_id: str, file_id: str) -> None: + self._file_status.append("processing") + + async def list_files(vector_store_id: str) -> SimpleNamespace: + if self._file_status: + self._file_status = ["completed" for _ in self._file_status] + return SimpleNamespace(data=[SimpleNamespace(status=s) for s in self._file_status]) + + async def create_file(file, purpose: str) -> SimpleNamespace: # noqa: ANN001 + _ = (file, purpose) + self._file_counter += 1 + return SimpleNamespace(id=f"file_{self._file_counter}") + + self.vector_stores = SimpleNamespace( + create=create_vector_store, + files=SimpleNamespace(create=add_file, list=list_files), + ) + self.files = SimpleNamespace(create=create_file) + + +@pytest.mark.asyncio +async def test_create_vector_store_from_directory(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """Supported files inside directory should be uploaded and vector store id returned.""" + sample_file = tmp_path / "doc.txt" + sample_file.write_text("data") + + client = _FakeAsyncOpenAI() + + vector_store_id = await asyncio.wait_for(create_vector_store_from_path(tmp_path, client), timeout=1) + + assert vector_store_id == "vs_123" # noqa: S101 + + +@pytest.mark.asyncio +async def test_create_vector_store_no_supported_files(tmp_path: Path) -> None: + """Directory without supported files should raise ValueError.""" + (tmp_path / "ignored.bin").write_text("data") + client = _FakeAsyncOpenAI() + + with pytest.raises(ValueError): + await create_vector_store_from_path(tmp_path, client) + + +def test_supported_file_types_contains_common_extensions() -> None: + """Ensure supported extensions include basic formats.""" + assert ".txt" in SUPPORTED_FILE_TYPES # noqa: S101 diff --git a/tests/unit/utils/test_output.py b/tests/unit/utils/test_output.py new file mode 100644 index 0000000..a757e19 --- /dev/null +++ b/tests/unit/utils/test_output.py @@ -0,0 +1,38 @@ +"""Tests for guardrails.utils.output module.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + +from guardrails.exceptions import ModelBehaviorError, UserError +from guardrails.utils.output import OutputSchema + + +@dataclass(frozen=True, slots=True) +class _Payload: + message: str + count: int + + +def test_output_schema_wraps_non_text_types() -> 
None: + schema = OutputSchema(_Payload) + json_schema = schema.json_schema() + assert json_schema["type"] == "object" # noqa: S101 + + validated = schema.validate_json('{"response": {"message": "hi", "count": 2}}') + assert validated == _Payload(message="hi", count=2) # noqa: S101 + + +def test_output_schema_plain_text() -> None: + schema = OutputSchema(str) + assert schema.is_plain_text() is True # noqa: S101 + with pytest.raises(UserError): + schema.json_schema() + + +def test_output_schema_invalid_json_raises() -> None: + schema = OutputSchema(_Payload) + with pytest.raises(ModelBehaviorError): + schema.validate_json("not-json") diff --git a/tests/unit/utils/test_parsing.py b/tests/unit/utils/test_parsing.py new file mode 100644 index 0000000..a10df07 --- /dev/null +++ b/tests/unit/utils/test_parsing.py @@ -0,0 +1,47 @@ +"""Tests for guardrails.utils.parsing.""" + +from __future__ import annotations + +from guardrails.utils.parsing import Entry, format_entries, parse_response_items, parse_response_items_as_json + + +def test_parse_response_items_handles_messages() -> None: + items = [ + { + "type": "message", + "role": "user", + "content": [ + {"type": "input_text", "text": "Hello"}, + "!", + ], + }, + { + "type": "function_call", + "arguments": "{}", + }, + ] + + entries = parse_response_items(items) + + assert entries == [Entry(role="user", content="Hello!"), Entry(role="function_call", content="{}")] + + +def test_parse_response_items_filters_by_role() -> None: + items = [{"type": "message", "role": "assistant", "content": "Hi"}, {"type": "message", "role": "user", "content": "Bye"}] + entries = parse_response_items(items, filter_role="assistant") + + assert entries == [Entry(role="assistant", content="Hi")] + + +def test_parse_response_items_as_json() -> None: + entries_json = parse_response_items_as_json( + [{"type": "message", "role": "assistant", "content": "Hi"}], + ) + + assert "assistant" in entries_json # noqa: S101 + + +def test_format_entries_text() -> None: + text = format_entries([Entry("assistant", "Hi"), Entry("user", "Bye")], fmt="text") + + assert text == "assistant: Hi\nuser: Bye" diff --git a/tests/unit/utils/test_schema.py b/tests/unit/utils/test_schema.py new file mode 100644 index 0000000..fb75d4f --- /dev/null +++ b/tests/unit/utils/test_schema.py @@ -0,0 +1,46 @@ +"""Tests for guardrails.utils.schema utilities.""" + +from __future__ import annotations + +import pytest +from pydantic import BaseModel, TypeAdapter + +from guardrails.exceptions import ModelBehaviorError, UserError +from guardrails.utils.schema import ensure_strict_json_schema, validate_json + + +class _Payload(BaseModel): + message: str + + +def test_validate_json_success() -> None: + adapter = TypeAdapter(_Payload) + result = validate_json('{"message": "hi"}', adapter, partial=False) + + assert result.message == "hi" # noqa: S101 + + +def test_validate_json_error() -> None: + adapter = TypeAdapter(_Payload) + with pytest.raises(ModelBehaviorError): + validate_json('{"message": 5}', adapter, partial=False) + + +def test_ensure_strict_json_schema_enforces_constraints() -> None: + schema = { + "type": "object", + "properties": { + "message": {"type": "string"}, + }, + } + + strict = ensure_strict_json_schema(schema) + + assert strict["additionalProperties"] is False # noqa: S101 + assert strict["required"] == ["message"] # noqa: S101 + + +def test_ensure_strict_json_schema_rejects_additional_properties() -> None: + schema = {"type": "object", "additionalProperties": True} + with 
pytest.raises(UserError): + ensure_strict_json_schema(schema)
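
A minimal way to exercise the new suite locally is sketched below. It is illustrative only and not part of the patch: it assumes pytest and pytest-cov are installed in the active environment and that the tests live under tests/unit as added here; the exact invocation used by CI may differ.

    # run_unit_tests.py -- illustrative helper, not included in this patch.
    # Runs the new unit tests and prints a line-coverage summary via pytest-cov.
    import sys

    import pytest

    if __name__ == "__main__":
        # pytest.main accepts the same arguments as the CLI and returns an exit code.
        sys.exit(
            pytest.main(
                [
                    "tests/unit",
                    "--cov=guardrails",
                    "--cov-report=term-missing",
                ]
            )
        )

Running the module directly (python run_unit_tests.py) should mirror a plain `pytest tests/unit --cov=guardrails` invocation; the stubbed openai module in tests/conftest.py keeps the run fully offline.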