diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py
index c8fce31d5..63940ac8d 100644
--- a/guardrails/llm_providers.py
+++ b/guardrails/llm_providers.py
@@ -32,6 +32,8 @@ class PromptCallableBase:
         failed, and how to fix it.
     """
 
+    supports_base_model = False
+
     def __init__(self, *args, **kwargs):
        self.init_args = args
        self.init_kwargs = kwargs
@@ -119,6 +121,8 @@ def _invoke_llm(
 
 
 class OpenAIChatCallable(PromptCallableBase):
+    supports_base_model = True
+
     def _invoke_llm(
         self,
         text: Optional[str] = None,
@@ -588,6 +592,8 @@ async def invoke_llm(
 
 
 class AsyncOpenAIChatCallable(AsyncPromptCallableBase):
+    supports_base_model = True
+
     async def invoke_llm(
         self,
         text: Optional[str] = None,
diff --git a/guardrails/run.py b/guardrails/run.py
index 7bf0b7445..4af15f76b 100644
--- a/guardrails/run.py
+++ b/guardrails/run.py
@@ -1,5 +1,6 @@
 import copy
 import json
+from functools import partial
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
 
 from eliot import add_destinations, start_action
@@ -549,36 +550,24 @@ def call(
             3. Log the output
         """
 
+        # If the API supports a base model, pass it in.
+        api_fn = api
+        if api is not None:
+            supports_base_model = getattr(api, "supports_base_model", False)
+            if supports_base_model:
+                api_fn = partial(api, base_model=self.base_model)
+
         with start_action(action_type="call", index=index, prompt=prompt) as action:
             if output is not None:
-                llm_response = LLMResponse(
-                    output=output,
-                )
-            elif api is None:
+                llm_response = LLMResponse(output=output)
+            elif api_fn is None:
                 raise ValueError("API or output must be provided.")
             elif msg_history:
-                try:
-                    llm_response = api(
-                        msg_history=msg_history_source(msg_history),
-                        base_model=self.base_model,
-                    )
-                except Exception:
-                    # If the API call fails, try calling again without the base model.
-                    llm_response = api(msg_history=msg_history_source(msg_history))
+                llm_response = api_fn(msg_history=msg_history_source(msg_history))
             elif prompt and instructions:
-                try:
-                    llm_response = api(
-                        prompt.source,
-                        instructions=instructions.source,
-                        base_model=self.base_model,
-                    )
-                except Exception:
-                    llm_response = api(prompt.source, instructions=instructions.source)
+                llm_response = api_fn(prompt.source, instructions=instructions.source)
             elif prompt:
-                try:
-                    llm_response = api(prompt.source, base_model=self.base_model)
-                except Exception:
-                    llm_response = api(prompt.source)
+                llm_response = api_fn(prompt.source)
             else:
                 raise ValueError("'prompt' or 'msg_history' must be provided.")
diff --git a/tests/unit_tests/validators/test_competitor_check.py b/tests/unit_tests/validators/test_competitor_check.py
index 8c00c6448..bedfeb944 100644
--- a/tests/unit_tests/validators/test_competitor_check.py
+++ b/tests/unit_tests/validators/test_competitor_check.py
@@ -4,15 +4,7 @@
 from guardrails.validators import CompetitorCheck, FailResult
 
 
-class TestCompetitorCheck:  # (unittest.TestCase):
-    def setUp(self):
-        # You can initialize any required resources or configurations here.
-        pass
-
-    def tearDown(self):
-        # You can clean up any resources used during the tests here.
-        pass
-
+class TestCompetitorCheck:
     def test_perform_ner(self, mocker):
         # Create a mock NLP object
         mock_util_is_package = mocker.patch("spacy.util.is_package")
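For context: the patch replaces the old try/except fallback (call the LLM API with `base_model`, retry without it on any exception) with an explicit capability flag. Each callable class declares `supports_base_model`, and the `call` function in `guardrails/run.py` binds the kwarg once via `functools.partial`, so every branch invokes the same `api_fn`. Below is a minimal standalone sketch of that pattern; the `PlainCallable`/`ChatCallable` classes and the toy `call` helper are hypothetical illustrations, not Guardrails APIs.

```python
from functools import partial


class PromptCallableBase:
    # Capability flag: subclasses that accept a `base_model` kwarg opt in,
    # mirroring the class attribute added in the diff above.
    supports_base_model = False


class PlainCallable(PromptCallableBase):
    def __call__(self, prompt, **kwargs):
        return f"plain: {prompt}"


class ChatCallable(PromptCallableBase):
    supports_base_model = True

    def __call__(self, prompt, base_model=None, **kwargs):
        return f"chat: {prompt} (base_model={base_model})"


def call(api, prompt, base_model=None):
    # Bind base_model up front instead of wrapping every call site in
    # try/except: one getattr check replaces three fallback blocks.
    api_fn = api
    if api is not None and getattr(api, "supports_base_model", False):
        api_fn = partial(api, base_model=base_model)
    return api_fn(prompt)


print(call(PlainCallable(), "hi"))                 # plain: hi
print(call(ChatCallable(), "hi", base_model=int))  # chat: hi (base_model=<class 'int'>)
```

One practical upside of this shape: a genuine API failure now propagates instead of being silently swallowed and retried without `base_model`.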