From 6566c592051d601425b458ce0c68164ce0bc25d4 Mon Sep 17 00:00:00 2001
From: Shreya Rajpal
Date: Thu, 22 Feb 2024 21:18:47 -0800
Subject: [PATCH 1/6] update Runner call to check for supports_base_model
 attribute instead of try/except

---
 guardrails/llm_providers.py                   |  8 +++++
 guardrails/run.py                             | 33 ++++++-------------
 .../validators/test_competitor_check.py       | 10 +-----
 3 files changed, 19 insertions(+), 32 deletions(-)

diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py
index c8fce31d5..4f17d9263 100644
--- a/guardrails/llm_providers.py
+++ b/guardrails/llm_providers.py
@@ -32,6 +32,8 @@ class PromptCallableBase:
     failed, and how to fix it.
     """
 
+    supports_base_model = False
+
     def __init__(self, *args, **kwargs):
         self.init_args = args
         self.init_kwargs = kwargs
@@ -119,6 +121,9 @@ def _invoke_llm(
 
 
 class OpenAIChatCallable(PromptCallableBase):
+
+    supports_base_model = True
+
     def _invoke_llm(
         self,
         text: Optional[str] = None,
@@ -588,6 +593,9 @@ async def invoke_llm(
 
 
 class AsyncOpenAIChatCallable(AsyncPromptCallableBase):
+
+    supports_base_model = True
+
     async def invoke_llm(
         self,
         text: Optional[str] = None,
diff --git a/guardrails/run.py b/guardrails/run.py
index 7bf0b7445..13525d6d4 100644
--- a/guardrails/run.py
+++ b/guardrails/run.py
@@ -1,4 +1,5 @@
 import copy
+from functools import partial
 import json
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
 
@@ -549,36 +550,22 @@ def call(
             3. Log the output
         """
 
+        # If the API supports a base model, pass it in.
+        supports_base_model = getattr(api, "supports_base_model", False)
+        if supports_base_model:
+            api = partial(api, base_model=self.base_model)
+
         with start_action(action_type="call", index=index, prompt=prompt) as action:
             if output is not None:
-                llm_response = LLMResponse(
-                    output=output,
-                )
+                llm_response = LLMResponse(output=output)
             elif api is None:
                 raise ValueError("API or output must be provided.")
             elif msg_history:
-                try:
-                    llm_response = api(
-                        msg_history=msg_history_source(msg_history),
-                        base_model=self.base_model,
-                    )
-                except Exception:
-                    # If the API call fails, try calling again without the base model.
-                    llm_response = api(msg_history=msg_history_source(msg_history))
+                llm_response = api(msg_history=msg_history_source(msg_history))
             elif prompt and instructions:
-                try:
-                    llm_response = api(
-                        prompt.source,
-                        instructions=instructions.source,
-                        base_model=self.base_model,
-                    )
-                except Exception:
-                    llm_response = api(prompt.source, instructions=instructions.source)
+                llm_response = api(prompt.source, instructions=instructions.source)
             elif prompt:
-                try:
-                    llm_response = api(prompt.source, base_model=self.base_model)
-                except Exception:
-                    llm_response = api(prompt.source)
+                llm_response = api(prompt.source)
             else:
                 raise ValueError("'prompt' or 'msg_history' must be provided.")
 
diff --git a/tests/unit_tests/validators/test_competitor_check.py b/tests/unit_tests/validators/test_competitor_check.py
index 8c00c6448..bedfeb944 100644
--- a/tests/unit_tests/validators/test_competitor_check.py
+++ b/tests/unit_tests/validators/test_competitor_check.py
@@ -4,15 +4,7 @@
 from guardrails.validators import CompetitorCheck, FailResult
 
 
-class TestCompetitorCheck:  # (unittest.TestCase):
-    def setUp(self):
-        # You can initialize any required resources or configurations here.
-        pass
-
-    def tearDown(self):
-        # You can clean up any resources used during the tests here.
-        pass
-
+class TestCompetitorCheck:
     def test_perform_ner(self, mocker):
         # Create a mock NLP object
         mock_util_is_package = mocker.patch("spacy.util.is_package")

From d7fe2a9532b088a03000ebe5f52d555958e0ad02 Mon Sep 17 00:00:00 2001
From: Shreya Rajpal
Date: Thu, 22 Feb 2024 21:23:22 -0800
Subject: [PATCH 2/6] lint

---
 guardrails/llm_providers.py | 2 --
 guardrails/run.py           | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py
index 4f17d9263..63940ac8d 100644
--- a/guardrails/llm_providers.py
+++ b/guardrails/llm_providers.py
@@ -121,7 +121,6 @@ def _invoke_llm(
 
 
 class OpenAIChatCallable(PromptCallableBase):
-
     supports_base_model = True
 
     def _invoke_llm(
@@ -593,7 +592,6 @@ async def invoke_llm(
 
 
 class AsyncOpenAIChatCallable(AsyncPromptCallableBase):
-
     supports_base_model = True
 
     async def invoke_llm(
diff --git a/guardrails/run.py b/guardrails/run.py
index 13525d6d4..d0de3e276 100644
--- a/guardrails/run.py
+++ b/guardrails/run.py
@@ -1,6 +1,6 @@
 import copy
-from functools import partial
 import json
+from functools import partial
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
 
 from eliot import add_destinations, start_action

From 24e7d69a3b8277b1d628e69e08cdf489c983986e Mon Sep 17 00:00:00 2001
From: Shreya Rajpal
Date: Thu, 22 Feb 2024 21:35:24 -0800
Subject: [PATCH 3/6] fix type hints

---
 guardrails/run.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/guardrails/run.py b/guardrails/run.py
index d0de3e276..de975ad60 100644
--- a/guardrails/run.py
+++ b/guardrails/run.py
@@ -540,7 +540,7 @@ def call(
         instructions: Optional[Instructions],
         prompt: Optional[Prompt],
         msg_history: Optional[List[Dict[str, str]]],
-        api: Optional[PromptCallableBase],
+        api: Optional[PromptCallableBase | partial[LLMResponse]],
         output: Optional[str] = None,
     ) -> LLMResponse:
         """Run a step.

From 0797be71969365910f11ad4c1722998c5c94ba88 Mon Sep 17 00:00:00 2001
From: Shreya Rajpal
Date: Thu, 22 Feb 2024 23:31:26 -0800
Subject: [PATCH 4/6] fix type hints

---
 guardrails/run.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/guardrails/run.py b/guardrails/run.py
index de975ad60..05eea8b09 100644
--- a/guardrails/run.py
+++ b/guardrails/run.py
@@ -551,9 +551,10 @@ def call(
         """
 
         # If the API supports a base model, pass it in.
-        supports_base_model = getattr(api, "supports_base_model", False)
-        if supports_base_model:
-            api = partial(api, base_model=self.base_model)
+        if api is not None:
+            supports_base_model = getattr(api, "supports_base_model", False)
+            if supports_base_model:
+                api = partial(api, base_model=self.base_model)
 
         with start_action(action_type="call", index=index, prompt=prompt) as action:
             if output is not None:

From 524634141ac37b6d74806f59cb43cfed410f357d Mon Sep 17 00:00:00 2001
From: Shreya Rajpal
Date: Thu, 22 Feb 2024 23:36:09 -0800
Subject: [PATCH 5/6] fix type hints

---
 guardrails/run.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/guardrails/run.py b/guardrails/run.py
index 05eea8b09..dc5ae4471 100644
--- a/guardrails/run.py
+++ b/guardrails/run.py
@@ -540,7 +540,7 @@ def call(
         instructions: Optional[Instructions],
         prompt: Optional[Prompt],
         msg_history: Optional[List[Dict[str, str]]],
-        api: Optional[PromptCallableBase | partial[LLMResponse]],
+        api: Optional[Union[PromptCallableBase, partial[LLMResponse]]],
         output: Optional[str] = None,
     ) -> LLMResponse:
         """Run a step.
From 866c9b613f9c03c6abd0f42fe7664f6fb4e4f6ca Mon Sep 17 00:00:00 2001
From: Shreya Rajpal
Date: Fri, 23 Feb 2024 00:15:48 -0800
Subject: [PATCH 6/6] fix type hints

---
 guardrails/run.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/guardrails/run.py b/guardrails/run.py
index dc5ae4471..4af15f76b 100644
--- a/guardrails/run.py
+++ b/guardrails/run.py
@@ -540,7 +540,7 @@ def call(
         instructions: Optional[Instructions],
         prompt: Optional[Prompt],
         msg_history: Optional[List[Dict[str, str]]],
-        api: Optional[Union[PromptCallableBase, partial[LLMResponse]]],
+        api: Optional[PromptCallableBase],
         output: Optional[str] = None,
     ) -> LLMResponse:
         """Run a step.
@@ -551,22 +551,23 @@ def call(
         """
 
         # If the API supports a base model, pass it in.
+        api_fn = api
         if api is not None:
             supports_base_model = getattr(api, "supports_base_model", False)
             if supports_base_model:
-                api = partial(api, base_model=self.base_model)
+                api_fn = partial(api, base_model=self.base_model)
 
         with start_action(action_type="call", index=index, prompt=prompt) as action:
             if output is not None:
                 llm_response = LLMResponse(output=output)
-            elif api is None:
+            elif api_fn is None:
                 raise ValueError("API or output must be provided.")
             elif msg_history:
-                llm_response = api(msg_history=msg_history_source(msg_history))
+                llm_response = api_fn(msg_history=msg_history_source(msg_history))
             elif prompt and instructions:
-                llm_response = api(prompt.source, instructions=instructions.source)
+                llm_response = api_fn(prompt.source, instructions=instructions.source)
             elif prompt:
-                llm_response = api(prompt.source)
+                llm_response = api_fn(prompt.source)
             else:
                 raise ValueError("'prompt' or 'msg_history' must be provided.")
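
The pattern this series converges on is worth seeing end to end: a callable advertises an optional capability through a class attribute, the caller probes for it with getattr() instead of wrapping the call in a try/except fallback, and functools.partial binds the extra keyword argument so every downstream call site stays uniform. Below is a minimal, self-contained sketch of that pattern, not the actual Guardrails API; the names PromptCallable, ChatCallable, run_step, and Pizza are hypothetical stand-ins.

from functools import partial
from typing import Any, Callable, Optional


class PromptCallable:
    # Capability flag on the base class: by default a callable does not
    # accept a base_model keyword (mirrors PromptCallableBase in patch 1).
    supports_base_model = False

    def __call__(self, prompt: str, **kwargs: Any) -> str:
        return f"completion: {prompt}"


class ChatCallable(PromptCallable):
    # Subclasses that understand structured output opt in by overriding
    # the class attribute, as OpenAIChatCallable does in patch 1.
    supports_base_model = True

    def __call__(
        self, prompt: str, base_model: Optional[type] = None, **kwargs: Any
    ) -> str:
        schema = f" [schema: {base_model.__name__}]" if base_model else ""
        return f"chat completion: {prompt}{schema}"


def run_step(
    api: Optional[Callable[..., str]], prompt: str, base_model: Optional[type]
) -> str:
    # Mirrors the shape of the call step after patch 6: rebind to a new
    # name (api_fn) instead of mutating the parameter, so the declared
    # type hint stays accurate, and guard the capability probe behind an
    # explicit None check, as patch 4 does. getattr() finds the class
    # attribute through the instance.
    api_fn = api
    if api is not None and getattr(api, "supports_base_model", False):
        api_fn = partial(api, base_model=base_model)
    if api_fn is None:
        raise ValueError("API or output must be provided.")
    return api_fn(prompt)


class Pizza:
    """Hypothetical stand-in for a Pydantic output schema."""


print(run_step(PromptCallable(), "hello", Pizza))  # completion: hello
print(run_step(ChatCallable(), "hello", Pizza))    # chat completion: hello [schema: Pizza]

Compared with the try/except approach this series removes, the attribute probe never swallows unrelated exceptions raised inside the API call itself, and it avoids paying for a failed LLM invocation just to discover that the callable does not accept base_model.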