diff --git a/src/guardrails/checks/text/hallucination_detection.py b/src/guardrails/checks/text/hallucination_detection.py
index e775625..a175076 100644
--- a/src/guardrails/checks/text/hallucination_detection.py
+++ b/src/guardrails/checks/text/hallucination_detection.py
@@ -52,10 +52,7 @@
 from guardrails.spec import GuardrailSpecMetadata
 from guardrails.types import GuardrailLLMContextProto, GuardrailResult
 
-from .llm_base import (
-    LLMConfig,
-    LLMOutput,
-)
+from .llm_base import LLMConfig, LLMOutput, _invoke_openai_callable
 
 logger = logging.getLogger(__name__)
 
@@ -210,9 +207,10 @@ async def hallucination_detection(
         validation_query = f"{VALIDATION_PROMPT}\n\nText to validate:\n{candidate}"
 
         # Use the Responses API with file search and structured output
-        response = await ctx.guardrail_llm.responses.parse(
-            model=config.model,
+        response = await _invoke_openai_callable(
+            ctx.guardrail_llm.responses.parse,
             input=validation_query,
+            model=config.model,
             text_format=HallucinationDetectionOutput,
             tools=[{"type": "file_search", "vector_store_ids": [config.knowledge_source]}],
         )
diff --git a/src/guardrails/checks/text/llm_base.py b/src/guardrails/checks/text/llm_base.py
index c125cea..9ab2077 100644
--- a/src/guardrails/checks/text/llm_base.py
+++ b/src/guardrails/checks/text/llm_base.py
@@ -31,12 +31,16 @@ class MyLLMOutput(LLMOutput):
 
 from __future__ import annotations
 
+import asyncio
+import functools
+import inspect
 import json
 import logging
 import textwrap
-from typing import TYPE_CHECKING, TypeVar
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, TypeVar
 
-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, OpenAI
 from pydantic import BaseModel, ConfigDict, Field
 
 from guardrails.registry import default_spec_registry
@@ -45,7 +49,13 @@ class MyLLMOutput(LLMOutput):
 from guardrails.utils.output import OutputSchema
 
 if TYPE_CHECKING:
-    from openai import AsyncOpenAI
+    from openai import AsyncAzureOpenAI, AzureOpenAI  # type: ignore[unused-import]
+else:
+    try:
+        from openai import AsyncAzureOpenAI, AzureOpenAI  # type: ignore
+    except Exception:  # pragma: no cover - optional dependency
+        AsyncAzureOpenAI = object  # type: ignore[assignment]
+        AzureOpenAI = object  # type: ignore[assignment]
 
 logger = logging.getLogger(__name__)
 
@@ -165,10 +175,46 @@ def _strip_json_code_fence(text: str) -> str:
     return candidate
 
 
+async def _invoke_openai_callable(
+    method: Callable[..., Any],
+    /,
+    *args: Any,
+    **kwargs: Any,
+) -> Any:
+    """Invoke OpenAI SDK methods that may be sync or async."""
+    if inspect.iscoroutinefunction(method):
+        return await method(*args, **kwargs)
+
+    loop = asyncio.get_running_loop()
+    result = await loop.run_in_executor(
+        None,
+        functools.partial(method, *args, **kwargs),
+    )
+    if inspect.isawaitable(result):
+        return await result
+    return result
+
+
+async def _request_chat_completion(
+    client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI,
+    *,
+    messages: list[dict[str, str]],
+    model: str,
+    response_format: dict[str, Any],
+) -> Any:
+    """Invoke chat.completions.create on sync or async OpenAI clients."""
+    return await _invoke_openai_callable(
+        client.chat.completions.create,
+        messages=messages,
+        model=model,
+        response_format=response_format,
+    )
+
+
 async def run_llm(
     text: str,
     system_prompt: str,
-    client: AsyncOpenAI,
+    client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI,
     model: str,
     output_model: type[LLMOutput],
 ) -> LLMOutput:
@@ -180,7 +226,7 @@ async def run_llm(
     Args:
         text (str): Text to analyze.
         system_prompt (str): Prompt instructions for the LLM.
-        client (AsyncOpenAI): OpenAI client for LLM inference.
+        client (AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI): OpenAI client used for guardrails.
         model (str): Identifier for which LLM model to use.
         output_model (type[LLMOutput]): Model for parsing and validating the LLM's response.
 
@@ -190,7 +236,8 @@ async def run_llm(
     full_prompt = _build_full_prompt(system_prompt)
 
     try:
-        response = await client.chat.completions.create(
+        response = await _request_chat_completion(
+            client=client,
             messages=[
                 {"role": "system", "content": full_prompt},
                 {"role": "user", "content": f"# Text\n\n{text}"},
diff --git a/src/guardrails/checks/text/prompt_injection_detection.py b/src/guardrails/checks/text/prompt_injection_detection.py
index fcef929..3e2b905 100644
--- a/src/guardrails/checks/text/prompt_injection_detection.py
+++ b/src/guardrails/checks/text/prompt_injection_detection.py
@@ -36,7 +36,7 @@
 from guardrails.spec import GuardrailSpecMetadata
 from guardrails.types import GuardrailLLMContextProto, GuardrailResult
 
-from .llm_base import LLMConfig, LLMOutput
+from .llm_base import LLMConfig, LLMOutput, _invoke_openai_callable
 
 __all__ = ["prompt_injection_detection", "PromptInjectionDetectionOutput"]
 
@@ -341,9 +341,10 @@ def _create_skip_result(
 
 async def _call_prompt_injection_detection_llm(ctx: GuardrailLLMContextProto, prompt: str, config: LLMConfig) -> PromptInjectionDetectionOutput:
     """Call LLM for prompt injection detection analysis."""
-    parsed_response = await ctx.guardrail_llm.responses.parse(
-        model=config.model,
+    parsed_response = await _invoke_openai_callable(
+        ctx.guardrail_llm.responses.parse,
         input=prompt,
+        model=config.model,
         text_format=PromptInjectionDetectionOutput,
     )
 
diff --git a/tests/unit/checks/test_llm_base.py b/tests/unit/checks/test_llm_base.py
index 907f523..1d77a5f 100644
--- a/tests/unit/checks/test_llm_base.py
+++ b/tests/unit/checks/test_llm_base.py
@@ -34,6 +34,20 @@ def __init__(self, content: str | None) -> None:
         self.chat = SimpleNamespace(completions=_FakeCompletions(content))
 
 
+class _FakeSyncCompletions:
+    def __init__(self, content: str | None) -> None:
+        self._content = content
+
+    def create(self, **kwargs: Any) -> Any:
+        _ = kwargs
+        return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content=self._content))])
+
+
+class _FakeSyncClient:
+    def __init__(self, content: str | None) -> None:
+        self.chat = SimpleNamespace(completions=_FakeSyncCompletions(content))
+
+
 def test_strip_json_code_fence_removes_wrapping() -> None:
     """Valid JSON code fences should be removed."""
     fenced = """```json
@@ -64,6 +78,23 @@ async def test_run_llm_returns_valid_output() -> None:
     assert result.flagged is True and result.confidence == 0.9  # noqa: S101
 
 
+@pytest.mark.asyncio
+async def test_run_llm_supports_sync_clients() -> None:
+    """run_llm should invoke synchronous clients without awaiting them."""
+    client = _FakeSyncClient('{"flagged": false, "confidence": 0.25}')
+
+    result = await run_llm(
+        text="General text",
+        system_prompt="Assess text.",
+        client=client,  # type: ignore[arg-type]
+        model="gpt-test",
+        output_model=LLMOutput,
+    )
+
+    assert isinstance(result, LLMOutput)  # noqa: S101
+    assert result.flagged is False and result.confidence == 0.25  # noqa: S101
+
+
 @pytest.mark.asyncio
 async def test_run_llm_handles_content_filter_error(monkeypatch: pytest.MonkeyPatch) -> None:
     """Content filter errors should return LLMErrorOutput with flagged=True."""
diff --git a/tests/unit/checks/test_prompt_injection_detection.py b/tests/unit/checks/test_prompt_injection_detection.py
index 1cda87c..e0d90f5 100644
--- a/tests/unit/checks/test_prompt_injection_detection.py
+++ b/tests/unit/checks/test_prompt_injection_detection.py
@@ -120,3 +120,21 @@ async def failing_llm(*_args: Any, **_kwargs: Any) -> PromptInjectionDetectionOu
 
     assert result.tripwire_triggered is False  # noqa: S101
     assert "Error during prompt injection detection check" in result.info["observation"]  # noqa: S101
+
+
+@pytest.mark.asyncio
+async def test_prompt_injection_detection_llm_supports_sync_responses() -> None:
+    """Underlying responses.parse may be synchronous for some clients."""
+    analysis = PromptInjectionDetectionOutput(flagged=True, confidence=0.4, observation="Action summary")
+
+    class _SyncResponses:
+        def parse(self, **kwargs: Any) -> Any:
+            _ = kwargs
+            return SimpleNamespace(output_parsed=analysis)
+
+    context = SimpleNamespace(guardrail_llm=SimpleNamespace(responses=_SyncResponses()))
+    config = LLMConfig(model="gpt-test", confidence_threshold=0.5)
+
+    parsed = await pid_module._call_prompt_injection_detection_llm(context, "prompt", config)
+
+    assert parsed is analysis  # noqa: S101
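
A quick way to sanity-check the new dispatch helper outside the test suite is the sketch below. It assumes the package is importable as `guardrails` (matching the `src/guardrails/...` paths in this diff); the `_SyncParse`/`_AsyncParse` stand-ins are illustrative placeholders, not part of the change.

```python
import asyncio
from types import SimpleNamespace

from guardrails.checks.text.llm_base import _invoke_openai_callable


class _SyncParse:
    def parse(self, **kwargs):
        # Mimics a synchronous client surface (OpenAI / AzureOpenAI).
        return SimpleNamespace(output_parsed=kwargs["input"])


class _AsyncParse:
    async def parse(self, **kwargs):
        # Mimics an asynchronous client surface (AsyncOpenAI / AsyncAzureOpenAI).
        return SimpleNamespace(output_parsed=kwargs["input"])


async def main() -> None:
    # The helper awaits async callables directly and off-loads sync callables
    # to the default executor, so both paths share a single `await` call site.
    sync_result = await _invoke_openai_callable(_SyncParse().parse, input="hello")
    async_result = await _invoke_openai_callable(_AsyncParse().parse, input="hello")
    assert sync_result.output_parsed == async_result.output_parsed == "hello"


asyncio.run(main())
```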