17 changes: 9 additions & 8 deletions — guardrails/llm_providers.py

@@ -596,6 +596,15 @@ def get_llm_ask(
 ) -> Optional[PromptCallableBase]:
     if "temperature" not in kwargs:
         kwargs.update({"temperature": 0})
+
+    try:
+        from litellm import completion
+
+        if llm_api == completion or (llm_api is None and kwargs.get("model")):
+            return LiteLLMCallable(*args, **kwargs)
+    except ImportError:
+        pass
+
     if llm_api == get_static_openai_create_func():
         return OpenAICallable(*args, **kwargs)
     if llm_api == get_static_openai_chat_create_func():
@@ -668,14 +677,6 @@ def get_llm_ask(
     except ImportError:
         pass

-    try:
-        from litellm import completion  # noqa: F401 # type: ignore
-
-        if llm_api == completion or (llm_api is None and kwargs.get("model")):
-            return LiteLLMCallable(*args, **kwargs)
-    except ImportError:
-        pass
-
     # Let the user pass in an arbitrary callable.
     if llm_api is not None:
         return ArbitraryCallable(*args, llm_api=llm_api, **kwargs)
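The net effect of the two hunks above is a reordering: the litellm dispatch now sits at the top of `get_llm_ask`, so a call that passes `litellm.completion` directly, or supplies only a `model` kwarg, resolves to `LiteLLMCallable` before any of the provider-specific checks run. A minimal sketch of the resulting behavior, assuming `guardrails` and `litellm` are installed (the model name is illustrative):

```python
# Minimal sketch of the new dispatch order; assumes `guardrails` and
# `litellm` are installed. The model name is illustrative.
from guardrails.llm_providers import LiteLLMCallable, get_llm_ask

# With no llm_api, a bare `model` kwarg now short-circuits to LiteLLM
# before the OpenAI-specific checks are reached.
prompt_callable = get_llm_ask(llm_api=None, model="azure/gpt-4")
assert isinstance(prompt_callable, LiteLLMCallable)
```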
24 changes: 24 additions & 0 deletions — tests/integration_tests/test_litellm.py

@@ -12,6 +12,12 @@

 from typing import List
 from pydantic import BaseModel
+from guardrails.llm_providers import (
+    get_llm_ask,
+    LiteLLMCallable,
+    get_async_llm_ask,
+    AsyncLiteLLMCallable,
+)


 @pytest.mark.skipif(
@@ -145,3 +151,21 @@ def test_litellm_openai_async_messages():
     assert res.validated_output
     assert res.validated_output == res.raw_llm_output
     assert len(res.validated_output.split("\n")) == 10
+
+
+@pytest.mark.skipif(
+    not importlib.util.find_spec("litellm"),
+    reason="`litellm` is not installed",
+)
+def test_get_llm_ask_returns_litellm_callable_without_llm_api():
+    result = get_llm_ask(llm_api=None, model="azure/gpt-4")
+    assert isinstance(result, LiteLLMCallable)
+
+
+@pytest.mark.skipif(
+    not importlib.util.find_spec("litellm"),
+    reason="`litellm` is not installed",
+)
+def test_get_async_llm_ask_returns_async_litellm_callable_without_llm_api():
+    result = get_async_llm_ask(llm_api=None, model="azure/gpt-4")
+    assert isinstance(result, AsyncLiteLLMCallable)
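For context, a hedged sketch of the kind of call `LiteLLMCallable` ultimately delegates to: `litellm.completion` routes provider-prefixed model names such as `azure/gpt-4` to the matching backend. This is not part of the PR, and running it requires real provider credentials:

```python
# Illustrative only; assumes litellm is installed and provider credentials
# (e.g. Azure OpenAI environment variables) are configured.
from litellm import completion

response = completion(
    model="azure/gpt-4",  # the "provider/model" prefix selects the backend
    messages=[{"role": "user", "content": "Write a haiku about testing."}],
)
print(response.choices[0].message.content)
```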