From 1e265a634a88e1c34a6f0d61622da89dc7e2666c Mon Sep 17 00:00:00 2001
From: Steven C
Date: Thu, 30 Oct 2025 15:43:17 -0400
Subject: [PATCH 1/2] Correctly passing api key to moderation

---
 src/guardrails/checks/text/moderation.py | 61 +++++++++++----------
 tests/conftest.py                        |  5 ++
 tests/unit/checks/test_moderation.py     | 70 ++++++++++++++++++++++++
 3 files changed, 107 insertions(+), 29 deletions(-)

diff --git a/src/guardrails/checks/text/moderation.py b/src/guardrails/checks/text/moderation.py
index 66f8bb1..fc8115f 100644
--- a/src/guardrails/checks/text/moderation.py
+++ b/src/guardrails/checks/text/moderation.py
@@ -32,7 +32,7 @@
 from functools import cache
 from typing import Any
 
-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, NotFoundError
 from pydantic import BaseModel, ConfigDict, Field
 
 from guardrails.registry import default_spec_registry
@@ -152,35 +152,38 @@ async def moderation(
         GuardrailResult: Indicates if tripwire was triggered, and details of flagged categories.
     """
-    # Prefer reusing an existing OpenAI client from context ONLY if it targets the
-    # official OpenAI API. If it's any other provider (e.g., Ollama via base_url),
-    # fall back to the default OpenAI moderation client.
-    def _maybe_reuse_openai_client_from_ctx(context: Any) -> AsyncOpenAI | None:
+    client = None
+    if ctx is not None:
+        candidate = getattr(ctx, "guardrail_llm", None)
+        if isinstance(candidate, AsyncOpenAI):
+            client = candidate
+
+    # Try the context client first, fall back if moderation endpoint doesn't exist
+    if client is not None:
         try:
-            candidate = getattr(context, "guardrail_llm", None)
-            if not isinstance(candidate, AsyncOpenAI):
-                return None
-
-            # Attempt to discover the effective base URL in a best-effort way
-            base_url = getattr(candidate, "base_url", None)
-            if base_url is None:
-                inner = getattr(candidate, "_client", None)
-                base_url = getattr(inner, "base_url", None) or getattr(inner, "_base_url", None)
-
-            # Reuse only when clearly the official OpenAI endpoint
-            if base_url is None:
-                return candidate
-            if isinstance(base_url, str) and "api.openai.com" in base_url:
-                return candidate
-            return None
-        except Exception:
-            return None
-
-    client = _maybe_reuse_openai_client_from_ctx(ctx) or _get_moderation_client()
-    resp = await client.moderations.create(
-        model="omni-moderation-latest",
-        input=data,
-    )
+            resp = await client.moderations.create(
+                model="omni-moderation-latest",
+                input=data,
+            )
+        except NotFoundError as e:
+            # Moderation endpoint doesn't exist on this provider (e.g., third-party)
+            # Fall back to the OpenAI client
+            logger.debug(
+                "Moderation endpoint not available on context client, falling back to OpenAI: %s",
+                e,
+            )
+            client = _get_moderation_client()
+            resp = await client.moderations.create(
+                model="omni-moderation-latest",
+                input=data,
+            )
+    else:
+        # No context client, use fallback
+        client = _get_moderation_client()
+        resp = await client.moderations.create(
+            model="omni-moderation-latest",
+            input=data,
+        )
     results = resp.results or []
     if not results:
         return GuardrailResult(
diff --git a/tests/conftest.py b/tests/conftest.py
index 7cf4555..b8c5830 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -67,7 +67,12 @@ class APITimeoutError(Exception):
     """Stub API timeout error."""
 
 
+class NotFoundError(Exception):
+    """Stub 404 not found error."""
+
+
 _STUB_OPENAI_MODULE.APITimeoutError = APITimeoutError
+_STUB_OPENAI_MODULE.NotFoundError = NotFoundError
 _OPENAI_TYPES_MODULE = types.ModuleType("openai.types")
 _OPENAI_TYPES_MODULE.Completion = _DummyResponse
diff --git a/tests/unit/checks/test_moderation.py b/tests/unit/checks/test_moderation.py
index 389dd09..ae6bcac 100644
--- a/tests/unit/checks/test_moderation.py
+++ b/tests/unit/checks/test_moderation.py
@@ -56,3 +56,73 @@ async def create_empty(**_: Any) -> Any:
 
     assert result.tripwire_triggered is False  # noqa: S101
     assert result.info["error"] == "No moderation results returned"  # noqa: S101
+
+
+@pytest.mark.asyncio
+async def test_moderation_uses_context_client() -> None:
+    """Moderation should use the client from context when available."""
+    from openai import AsyncOpenAI
+
+    # Track whether context client was used
+    context_client_used = False
+
+    async def track_create(**_: Any) -> Any:
+        nonlocal context_client_used
+        context_client_used = True
+
+        class _Result:
+            def model_dump(self) -> dict[str, Any]:
+                return {"categories": {"hate": False, "violence": False}}
+
+        return SimpleNamespace(results=[_Result()])
+
+    # Create a context with a guardrail_llm client
+    context_client = AsyncOpenAI(api_key="test-context-key", base_url="https://api.openai.com/v1")
+    context_client.moderations = SimpleNamespace(create=track_create)
+
+    ctx = SimpleNamespace(guardrail_llm=context_client)
+
+    cfg = ModerationCfg(categories=[Category.HATE])
+    result = await moderation(ctx, "test text", cfg)
+
+    # Verify the context client was used
+    assert context_client_used is True  # noqa: S101
+    assert result.tripwire_triggered is False  # noqa: S101
+
+
+@pytest.mark.asyncio
+async def test_moderation_falls_back_for_third_party_provider(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Moderation should fall back to environment client for third-party providers."""
+    from openai import AsyncOpenAI, NotFoundError
+
+    # Create fallback client that tracks usage
+    fallback_used = False
+
+    async def track_fallback_create(**_: Any) -> Any:
+        nonlocal fallback_used
+        fallback_used = True
+
+        class _Result:
+            def model_dump(self) -> dict[str, Any]:
+                return {"categories": {"hate": False}}
+
+        return SimpleNamespace(results=[_Result()])
+
+    fallback_client = SimpleNamespace(moderations=SimpleNamespace(create=track_fallback_create))
+    monkeypatch.setattr("guardrails.checks.text.moderation._get_moderation_client", lambda: fallback_client)
+
+    # Create a context client that simulates a third-party provider
+    # When moderation is called, it should raise NotFoundError
+    async def raise_not_found(**_: Any) -> Any:
+        raise NotFoundError("404 page not found")
+
+    third_party_client = AsyncOpenAI(api_key="third-party-key", base_url="https://localhost:8080/v1")
+    third_party_client.moderations = SimpleNamespace(create=raise_not_found)
+    ctx = SimpleNamespace(guardrail_llm=third_party_client)
+
+    cfg = ModerationCfg(categories=[Category.HATE])
+    result = await moderation(ctx, "test text", cfg)
+
+    # Verify the fallback client was used (not the third-party one)
+    assert fallback_used is True  # noqa: S101
+    assert result.tripwire_triggered is False  # noqa: S101

From 707f842f4b313d791e3d3425d33da3439159654d Mon Sep 17 00:00:00 2001
From: Steven C
Date: Thu, 30 Oct 2025 15:58:41 -0400
Subject: [PATCH 2/2] extract call_moderation helper

---
 src/guardrails/checks/text/moderation.py | 32 ++++++++++++++----------
 tests/conftest.py                        |  6 +++++
 tests/unit/checks/test_moderation.py     | 14 ++++++++---
 3 files changed, 36 insertions(+), 16 deletions(-)

diff --git a/src/guardrails/checks/text/moderation.py b/src/guardrails/checks/text/moderation.py
index fc8115f..e883377 100644
--- a/src/guardrails/checks/text/moderation.py
+++ b/src/guardrails/checks/text/moderation.py
@@ -132,6 +132,22 @@ def _get_moderation_client() -> AsyncOpenAI:
     return AsyncOpenAI(**prepare_openai_kwargs({}))
 
 
+async def _call_moderation_api(client: AsyncOpenAI, data: str) -> Any:
+    """Call the OpenAI moderation API.
+
+    Args:
+        client: The OpenAI client to use.
+        data: The text to analyze.
+
+    Returns:
+        The moderation API response.
+    """
+    return await client.moderations.create(
+        model="omni-moderation-latest",
+        input=data,
+    )
+
+
 async def moderation(
     ctx: Any,
     data: str,
@@ -151,7 +167,6 @@ async def moderation(
     Returns:
         GuardrailResult: Indicates if tripwire was triggered, and details of flagged categories.
     """
-    client = None
     if ctx is not None:
         candidate = getattr(ctx, "guardrail_llm", None)
@@ -161,10 +176,7 @@
     # Try the context client first, fall back if moderation endpoint doesn't exist
     if client is not None:
         try:
-            resp = await client.moderations.create(
-                model="omni-moderation-latest",
-                input=data,
-            )
+            resp = await _call_moderation_api(client, data)
         except NotFoundError as e:
             # Moderation endpoint doesn't exist on this provider (e.g., third-party)
             # Fall back to the OpenAI client
@@ -173,17 +185,11 @@
                 e,
             )
             client = _get_moderation_client()
-            resp = await client.moderations.create(
-                model="omni-moderation-latest",
-                input=data,
-            )
+            resp = await _call_moderation_api(client, data)
     else:
         # No context client, use fallback
         client = _get_moderation_client()
-        resp = await client.moderations.create(
-            model="omni-moderation-latest",
-            input=data,
-        )
+        resp = await _call_moderation_api(client, data)
     results = resp.results or []
     if not results:
         return GuardrailResult(
diff --git a/tests/conftest.py b/tests/conftest.py
index b8c5830..2c226f5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -70,6 +70,12 @@ class APITimeoutError(Exception):
 class NotFoundError(Exception):
     """Stub 404 not found error."""
 
+    def __init__(self, message: str, *, response: Any = None, body: Any = None) -> None:
+        """Initialize NotFoundError with OpenAI-compatible signature."""
+        super().__init__(message)
+        self.response = response
+        self.body = body
+
 
 _STUB_OPENAI_MODULE.APITimeoutError = APITimeoutError
 _STUB_OPENAI_MODULE.NotFoundError = NotFoundError
diff --git a/tests/unit/checks/test_moderation.py b/tests/unit/checks/test_moderation.py
index ae6bcac..11351ea 100644
--- a/tests/unit/checks/test_moderation.py
+++ b/tests/unit/checks/test_moderation.py
@@ -78,7 +78,7 @@ def model_dump(self) -> dict[str, Any]:
 
     # Create a context with a guardrail_llm client
     context_client = AsyncOpenAI(api_key="test-context-key", base_url="https://api.openai.com/v1")
-    context_client.moderations = SimpleNamespace(create=track_create)
+    context_client.moderations = SimpleNamespace(create=track_create)  # type: ignore[assignment]
 
     ctx = SimpleNamespace(guardrail_llm=context_client)
 
@@ -111,13 +111,21 @@ def model_dump(self) -> dict[str, Any]:
     fallback_client = SimpleNamespace(moderations=SimpleNamespace(create=track_fallback_create))
     monkeypatch.setattr("guardrails.checks.text.moderation._get_moderation_client", lambda: fallback_client)
 
+    # Create a mock httpx.Response for NotFoundError
+    mock_response = SimpleNamespace(
+        status_code=404,
+        headers={},
+        text="404 page not found",
+        json=lambda: {"error": {"message": "Not found", "type": "invalid_request_error"}},
+    )
+
     # Create a context client that simulates a third-party provider
     # When moderation is called, it should raise NotFoundError
     async def raise_not_found(**_: Any) -> Any:
-        raise NotFoundError("404 page not found")
+        raise NotFoundError("404 page not found", response=mock_response, body=None)  # type: ignore[arg-type]
 
     third_party_client = AsyncOpenAI(api_key="third-party-key", base_url="https://localhost:8080/v1")
-    third_party_client.moderations = SimpleNamespace(create=raise_not_found)
+    third_party_client.moderations = SimpleNamespace(create=raise_not_found)  # type: ignore[assignment]
     ctx = SimpleNamespace(guardrail_llm=third_party_client)
 
     cfg = ModerationCfg(categories=[Category.HATE])
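
Usage sketch (not part of the patches above): the fallback path exercised by test_moderation_falls_back_for_third_party_provider can be driven from application code roughly as follows. The imports mirror the names the patched tests rely on (moderation, ModerationCfg, Category in guardrails.checks.text.moderation); the local base_url, API keys, and the shape of the context object are illustrative assumptions, and the default moderation client built by _get_moderation_client() is expected to pick up OPENAI_API_KEY from the environment.

# Illustrative sketch only; names not shown in the patches above are assumptions.
import asyncio
from types import SimpleNamespace

from openai import AsyncOpenAI

from guardrails.checks.text.moderation import Category, ModerationCfg, moderation


async def main() -> None:
    # Guardrail LLM pointed at an OpenAI-compatible third-party server (hypothetical URL).
    # If that server returns 404 for the moderations endpoint, the check catches
    # NotFoundError and retries with the default OpenAI moderation client.
    guardrail_llm = AsyncOpenAI(api_key="local-key", base_url="http://localhost:8080/v1")
    ctx = SimpleNamespace(guardrail_llm=guardrail_llm)

    cfg = ModerationCfg(categories=[Category.HATE])
    result = await moderation(ctx, "some user-supplied text", cfg)
    print(result.tripwire_triggered, result.info)


if __name__ == "__main__":
    asyncio.run(main())

When the context's guardrail_llm already targets the official OpenAI API, the same call simply reuses that client and never hits the fallback branch.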