Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 39 additions & 30 deletions src/guardrails/checks/text/moderation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from functools import cache
from typing import Any

from openai import AsyncOpenAI
from openai import AsyncOpenAI, NotFoundError
from pydantic import BaseModel, ConfigDict, Field

from guardrails.registry import default_spec_registry
Expand Down Expand Up @@ -132,6 +132,22 @@ def _get_moderation_client() -> AsyncOpenAI:
return AsyncOpenAI(**prepare_openai_kwargs({}))


async def _call_moderation_api(client: AsyncOpenAI, data: str) -> Any:
    """Submit *data* to OpenAI's moderation endpoint and return the raw response.

    Args:
        client: Async OpenAI client the request is issued through.
        data: Text content to be classified by the moderation model.

    Returns:
        The unprocessed moderation API response object.
    """
    response = await client.moderations.create(model="omni-moderation-latest", input=data)
    return response


async def moderation(
ctx: Any,
data: str,
Expand All @@ -151,36 +167,29 @@ async def moderation(
Returns:
GuardrailResult: Indicates if tripwire was triggered, and details of flagged categories.
"""

# Prefer reusing an existing OpenAI client from context ONLY if it targets the
# official OpenAI API. If it's any other provider (e.g., Ollama via base_url),
# fall back to the default OpenAI moderation client.
def _maybe_reuse_openai_client_from_ctx(context: Any) -> AsyncOpenAI | None:
client = None
if ctx is not None:
candidate = getattr(ctx, "guardrail_llm", None)
if isinstance(candidate, AsyncOpenAI):
client = candidate

# Try the context client first, fall back if moderation endpoint doesn't exist
if client is not None:
try:
candidate = getattr(context, "guardrail_llm", None)
if not isinstance(candidate, AsyncOpenAI):
return None

# Attempt to discover the effective base URL in a best-effort way
base_url = getattr(candidate, "base_url", None)
if base_url is None:
inner = getattr(candidate, "_client", None)
base_url = getattr(inner, "base_url", None) or getattr(inner, "_base_url", None)

# Reuse only when clearly the official OpenAI endpoint
if base_url is None:
return candidate
if isinstance(base_url, str) and "api.openai.com" in base_url:
return candidate
return None
except Exception:
return None

client = _maybe_reuse_openai_client_from_ctx(ctx) or _get_moderation_client()
resp = await client.moderations.create(
model="omni-moderation-latest",
input=data,
)
resp = await _call_moderation_api(client, data)
except NotFoundError as e:
# Moderation endpoint doesn't exist on this provider (e.g., third-party)
# Fall back to the OpenAI client
logger.debug(
"Moderation endpoint not available on context client, falling back to OpenAI: %s",
e,
)
client = _get_moderation_client()
resp = await _call_moderation_api(client, data)
else:
# No context client, use fallback
client = _get_moderation_client()
resp = await _call_moderation_api(client, data)
Comment on lines +176 to +192
Copy link

Copilot AI Oct 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code duplication: Lines 187-188 and 191-192 duplicate the same logic (getting fallback client and calling moderation API). Consider refactoring to avoid this duplication. For example, you could set client = _get_moderation_client() once and call resp = await _call_moderation_api(client, data) after the if-else block.

Copilot uses AI. Check for mistakes.
results = resp.results or []
if not results:
return GuardrailResult(
Expand Down
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,18 @@ class APITimeoutError(Exception):
"""Stub API timeout error."""


class NotFoundError(Exception):
    """Stand-in for ``openai.NotFoundError`` (HTTP 404).

    Mirrors the real exception's keyword-only ``response``/``body``
    constructor so tests can raise it interchangeably with the SDK's.
    """

    def __init__(self, message: str, *, response: Any = None, body: Any = None) -> None:
        """Record the message plus the optional HTTP response and parsed body."""
        super().__init__(message)
        # Attributes exposed by openai's APIStatusError family.
        self.response = response
        self.body = body


_STUB_OPENAI_MODULE.APITimeoutError = APITimeoutError
_STUB_OPENAI_MODULE.NotFoundError = NotFoundError

_OPENAI_TYPES_MODULE = types.ModuleType("openai.types")
_OPENAI_TYPES_MODULE.Completion = _DummyResponse
Expand Down
78 changes: 78 additions & 0 deletions tests/unit/checks/test_moderation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,81 @@ async def create_empty(**_: Any) -> Any:

assert result.tripwire_triggered is False # noqa: S101
assert result.info["error"] == "No moderation results returned" # noqa: S101


@pytest.mark.asyncio
async def test_moderation_uses_context_client() -> None:
    """Moderation should use the client from context when available."""
    from openai import AsyncOpenAI

    # Record every call made against the context client's moderation endpoint.
    calls: list[dict[str, Any]] = []

    async def recording_create(**kwargs: Any) -> Any:
        calls.append(kwargs)

        class _Result:
            def model_dump(self) -> dict[str, Any]:
                return {"categories": {"hate": False, "violence": False}}

        return SimpleNamespace(results=[_Result()])

    # Context carries a guardrail_llm client pointed at the official endpoint.
    guardrail_client = AsyncOpenAI(api_key="test-context-key", base_url="https://api.openai.com/v1")
    guardrail_client.moderations = SimpleNamespace(create=recording_create)  # type: ignore[assignment]

    context = SimpleNamespace(guardrail_llm=guardrail_client)
    config = ModerationCfg(categories=[Category.HATE])

    outcome = await moderation(context, "test text", config)

    # The context client must have been used (at least one call recorded).
    assert calls  # noqa: S101
    assert outcome.tripwire_triggered is False  # noqa: S101


@pytest.mark.asyncio
async def test_moderation_falls_back_for_third_party_provider(monkeypatch: pytest.MonkeyPatch) -> None:
    """Moderation should fall back to environment client for third-party providers."""
    from openai import AsyncOpenAI, NotFoundError

    # Record every call routed to the fallback (environment) client.
    fallback_calls: list[dict[str, Any]] = []

    async def fallback_create(**kwargs: Any) -> Any:
        fallback_calls.append(kwargs)

        class _Result:
            def model_dump(self) -> dict[str, Any]:
                return {"categories": {"hate": False}}

        return SimpleNamespace(results=[_Result()])

    stub_fallback = SimpleNamespace(moderations=SimpleNamespace(create=fallback_create))
    monkeypatch.setattr("guardrails.checks.text.moderation._get_moderation_client", lambda: stub_fallback)

    # Minimal httpx.Response look-alike needed to construct NotFoundError.
    fake_response = SimpleNamespace(
        status_code=404,
        headers={},
        text="404 page not found",
        json=lambda: {"error": {"message": "Not found", "type": "invalid_request_error"}},
    )

    async def missing_endpoint(**_: Any) -> Any:
        # Simulate a provider (e.g. a local proxy) with no moderation endpoint.
        raise NotFoundError("404 page not found", response=fake_response, body=None)  # type: ignore[arg-type]

    provider_client = AsyncOpenAI(api_key="third-party-key", base_url="https://localhost:8080/v1")
    provider_client.moderations = SimpleNamespace(create=missing_endpoint)  # type: ignore[assignment]

    context = SimpleNamespace(guardrail_llm=provider_client)
    config = ModerationCfg(categories=[Category.HATE])

    result = await moderation(context, "test text", config)

    # The fallback client, not the third-party one, must have served the call.
    assert fallback_calls  # noqa: S101
    assert result.tripwire_triggered is False  # noqa: S101
Loading