From f4cea7f1f7020e5aca4b13e90df25e8ed07b95f8 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Mon, 22 Jul 2024 23:33:05 +0530
Subject: [PATCH 1/2] fix(litellm): call LiteLLMCallable instead of
 OpenAIChatCallable when llm_api is not passed

---
 guardrails/llm_providers.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py
index b11905415..cc33559f9 100644
--- a/guardrails/llm_providers.py
+++ b/guardrails/llm_providers.py
@@ -596,6 +596,15 @@ def get_llm_ask(
 ) -> Optional[PromptCallableBase]:
     if "temperature" not in kwargs:
         kwargs.update({"temperature": 0})
+
+    try:
+        from litellm import completion
+
+        if llm_api == completion or (llm_api is None and kwargs.get("model")):
+            return LiteLLMCallable(*args, **kwargs)
+    except ImportError:
+        pass
+
     if llm_api == get_static_openai_create_func():
         return OpenAICallable(*args, **kwargs)
     if llm_api == get_static_openai_chat_create_func():
@@ -668,14 +677,6 @@
     except ImportError:
         pass
 
-    try:
-        from litellm import completion  # noqa: F401 # type: ignore
-
-        if llm_api == completion or (llm_api is None and kwargs.get("model")):
-            return LiteLLMCallable(*args, **kwargs)
-    except ImportError:
-        pass
-
     # Let the user pass in an arbitrary callable.
     if llm_api is not None:
         return ArbitraryCallable(*args, llm_api=llm_api, **kwargs)

From 7c3b02018c260294331eac34956004cccf06d46e Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Tue, 23 Jul 2024 01:50:36 +0530
Subject: [PATCH 2/2] add tests

---
 tests/integration_tests/test_litellm.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/tests/integration_tests/test_litellm.py b/tests/integration_tests/test_litellm.py
index 9ca06aec5..47514cba8 100644
--- a/tests/integration_tests/test_litellm.py
+++ b/tests/integration_tests/test_litellm.py
@@ -12,6 +12,12 @@
 from typing import List
 
 from pydantic import BaseModel
 
+from guardrails.llm_providers import (
+    get_llm_ask,
+    LiteLLMCallable,
+    get_async_llm_ask,
+    AsyncLiteLLMCallable,
+)
 
 @pytest.mark.skipif(
@@ -145,3 +151,21 @@ def test_litellm_openai_async_messages():
     assert res.validated_output
     assert res.validated_output == res.raw_llm_output
     assert len(res.validated_output.split("\n")) == 10
+
+
+@pytest.mark.skipif(
+    not importlib.util.find_spec("litellm"),
+    reason="`litellm` is not installed",
+)
+def test_get_llm_ask_returns_litellm_callable_without_llm_api():
+    result = get_llm_ask(llm_api=None, model="azure/gpt-4")
+    assert isinstance(result, LiteLLMCallable)
+
+
+@pytest.mark.skipif(
+    not importlib.util.find_spec("litellm"),
+    reason="`litellm` is not installed",
+)
+def test_get_async_llm_ask_returns_async_litellm_callable_without_llm_api():
+    result = get_async_llm_ask(llm_api=None, model="azure/gpt-4")
+    assert isinstance(result, AsyncLiteLLMCallable)
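
Note (not part of either patch): a minimal sketch of the dispatch order
the first commit establishes, mirroring the new tests. It assumes
`guardrails` and `litellm` are importable; "azure/gpt-4" is only an
illustrative model string, and no API credentials should be needed,
since get_llm_ask only selects and constructs the callable (the new
isinstance-based tests rely on the same property):

    from guardrails.llm_providers import get_llm_ask, LiteLLMCallable
    from litellm import completion

    # llm_api omitted but a model given: the hoisted check matches
    # `llm_api is None and kwargs.get("model")` and returns a
    # LiteLLMCallable before any of the static OpenAI branches run.
    assert isinstance(get_llm_ask(llm_api=None, model="azure/gpt-4"),
                      LiteLLMCallable)

    # Passing litellm.completion explicitly hits the same branch
    # via the first clause, `llm_api == completion`.
    assert isinstance(get_llm_ask(llm_api=completion),
                      LiteLLMCallable)

Why the block had to move rather than stay near the bottom of the
function: per the commit subject, `llm_api=None` previously resolved to
OpenAIChatCallable before the litellm condition was ever evaluated, so
the check is only effective if it runs ahead of the OpenAI equality
tests.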