6 changes: 6 additions & 0 deletions guardrails/llm_providers.py
@@ -32,6 +32,8 @@ class PromptCallableBase:
     failed, and how to fix it.
     """
 
+    supports_base_model = False
+
     def __init__(self, *args, **kwargs):
         self.init_args = args
         self.init_kwargs = kwargs
@@ -119,6 +121,8 @@ def _invoke_llm(
 
 
 class OpenAIChatCallable(PromptCallableBase):
+    supports_base_model = True
+
     def _invoke_llm(
         self,
         text: Optional[str] = None,
@@ -588,6 +592,8 @@ async def invoke_llm(
 
 
 class AsyncOpenAIChatCallable(AsyncPromptCallableBase):
+    supports_base_model = True
+
     async def invoke_llm(
         self,
         text: Optional[str] = None,
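Note: supports_base_model acts as a class-level capability flag, so callers can probe it with getattr up front instead of calling with base_model and retrying on failure, as run.py did before. A minimal, runnable sketch of the pattern outside Guardrails (the EchoCallable and Person classes below are hypothetical stand-ins, not code from this PR):

from functools import partial
from typing import Optional, Type


class EchoCallable:
    # Hypothetical callable that opts in, mirroring OpenAIChatCallable.
    supports_base_model = True

    def __call__(self, prompt: str, base_model: Optional[Type] = None) -> str:
        suffix = f" (schema: {base_model.__name__})" if base_model is not None else ""
        return prompt + suffix


class Person:
    # Stand-in for a pydantic BaseModel used as the output schema.
    pass


api = EchoCallable()
api_fn = api
if getattr(api, "supports_base_model", False):
    # Bind base_model once; later call sites stay unchanged.
    api_fn = partial(api, base_model=Person)

print(api_fn("Extract a person"))  # "Extract a person (schema: Person)"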
37 changes: 13 additions & 24 deletions guardrails/run.py
@@ -1,5 +1,6 @@
 import copy
 import json
+from functools import partial
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
 
 from eliot import add_destinations, start_action
@@ -549,36 +550,24 @@ def call(
         3. Log the output
         """
 
+        # If the API supports a base model, pass it in.
+        api_fn = api
+        if api is not None:
+            supports_base_model = getattr(api, "supports_base_model", False)
+            if supports_base_model:
+                api_fn = partial(api, base_model=self.base_model)
+
         with start_action(action_type="call", index=index, prompt=prompt) as action:
             if output is not None:
-                llm_response = LLMResponse(
-                    output=output,
-                )
-            elif api is None:
+                llm_response = LLMResponse(output=output)
+            elif api_fn is None:
                 raise ValueError("API or output must be provided.")
             elif msg_history:
-                try:
-                    llm_response = api(
-                        msg_history=msg_history_source(msg_history),
-                        base_model=self.base_model,
-                    )
-                except Exception:
-                    # If the API call fails, try calling again without the base model.
-                    llm_response = api(msg_history=msg_history_source(msg_history))
+                llm_response = api_fn(msg_history=msg_history_source(msg_history))
             elif prompt and instructions:
-                try:
-                    llm_response = api(
-                        prompt.source,
-                        instructions=instructions.source,
-                        base_model=self.base_model,
-                    )
-                except Exception:
-                    llm_response = api(prompt.source, instructions=instructions.source)
+                llm_response = api_fn(prompt.source, instructions=instructions.source)
             elif prompt:
-                try:
-                    llm_response = api(prompt.source, base_model=self.base_model)
-                except Exception:
-                    llm_response = api(prompt.source)
+                llm_response = api_fn(prompt.source)
             else:
                 raise ValueError("'prompt' or 'msg_history' must be provided.")
 
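The partial binding means every branch above calls api_fn with exactly the arguments it passed before; base_model is attached once, up front, and only when the callable advertises support for it. A small, self-contained illustration of how the binding behaves (fake_api is a hypothetical stand-in for an LLM callable, not Guardrails code):

from functools import partial


def fake_api(prompt, instructions=None, base_model=None):
    # Hypothetical stand-in: just echoes what it was called with.
    return {"prompt": prompt, "instructions": instructions, "base_model": base_model}


api_fn = partial(fake_api, base_model=dict)

# Call sites keep their original signatures; base_model rides along via the partial.
print(api_fn("hi"))                           # base_model=dict is filled in
print(api_fn("hi", instructions="be brief"))  # keyword args still work as before

A side effect worth noting: the old except Exception blocks retried the call without base_model on any failure, masking the original error; with the capability flag the call is made once and failures propagate.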
10 changes: 1 addition & 9 deletions tests/unit_tests/validators/test_competitor_check.py
@@ -4,15 +4,7 @@
 from guardrails.validators import CompetitorCheck, FailResult
 
 
-class TestCompetitorCheck:  # (unittest.TestCase):
-    def setUp(self):
-        # You can initialize any required resources or configurations here.
-        pass
-
-    def tearDown(self):
-        # You can clean up any resources used during the tests here.
-        pass
-
+class TestCompetitorCheck:
     def test_perform_ner(self, mocker):
         # Create a mock NLP object
         mock_util_is_package = mocker.patch("spacy.util.is_package")
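With the empty setUp/tearDown hooks removed, TestCompetitorCheck is a plain pytest test class; per-test fixtures such as pytest-mock's mocker handle setup and teardown. A minimal sketch of that style, assuming pytest and pytest-mock are installed (the patched target here is illustrative, not taken from the Guardrails tests):

import os


class TestExample:
    def test_uses_mock(self, mocker):
        # mocker patches only for the duration of this test and undoes the
        # patch automatically afterwards, so no tearDown method is needed.
        fake_exists = mocker.patch("os.path.exists", return_value=True)
        assert os.path.exists("/nowhere")
        fake_exists.assert_called_once_with("/nowhere")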