From aef6b7b9721c0f01b53b95cc1f7e5e7b17d5a705 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 24 Nov 2025 08:46:44 -0600 Subject: [PATCH] sync async consistency --- guardrails/async_guard.py | 9 ++++++++- guardrails/formatters/base_formatter.py | 11 ++++++++++- guardrails/formatters/json_formatter.py | 3 +++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py index c7ad495ea..e28a42f87 100644 --- a/guardrails/async_guard.py +++ b/guardrails/async_guard.py @@ -325,7 +325,14 @@ async def _exec( Returns: The raw text output from the LLM and the validated output. """ - api = get_async_llm_ask(llm_api, *args, **kwargs) # type: ignore + api = None + + if llm_api is not None or kwargs.get("model") is not None: + api = get_async_llm_ask(llm_api, *args, **kwargs) # type: ignore + + if self._output_formatter is not None: + api = self._output_formatter.wrap_async_callable(api) # type: ignore + if kwargs.get("stream", False): runner = AsyncStreamRunner( output_type=self._output_type, diff --git a/guardrails/formatters/base_formatter.py b/guardrails/formatters/base_formatter.py index dc409e8c4..e23ce1153 100644 --- a/guardrails/formatters/base_formatter.py +++ b/guardrails/formatters/base_formatter.py @@ -2,6 +2,7 @@ from guardrails.llm_providers import ( ArbitraryCallable, + AsyncPromptCallableBase, PromptCallableBase, ) @@ -17,7 +18,15 @@ class BaseFormatter(ABC): @abstractmethod def wrap_callable(self, llm_callable: PromptCallableBase) -> ArbitraryCallable: ... + @abstractmethod + def wrap_async_callable( + self, llm_callable: PromptCallableBase + ) -> AsyncPromptCallableBase: ... + class PassthroughFormatter(BaseFormatter): - def wrap_callable(self, llm_callable: PromptCallableBase): + def wrap_callable(self, llm_callable: PromptCallableBase): # type: ignore + return llm_callable # Noop + + def wrap_async_callable(self, llm_callable: PromptCallableBase): # type: ignore return llm_callable # Noop diff --git a/guardrails/formatters/json_formatter.py b/guardrails/formatters/json_formatter.py index adb1a5207..5d8b84b72 100644 --- a/guardrails/formatters/json_formatter.py +++ b/guardrails/formatters/json_formatter.py @@ -149,3 +149,6 @@ def fn( raise ValueError( "JsonFormatter can only be used with HuggingFace*Callable." ) + + def wrap_async_callable(self, llm_callable): + raise NotImplementedError()