From e45cf93217da51d7044691047990fa349a3bd842 Mon Sep 17 00:00:00 2001 From: Karan Acharya Date: Thu, 14 Mar 2024 12:43:09 -0400 Subject: [PATCH 1/8] Modify use() with kwargs "on" --- guardrails/guard.py | 85 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 71 insertions(+), 14 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index 22120a51f..f80170ce6 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1117,32 +1117,89 @@ def with_msg_history_validation( return self @overload - def use(self, validator: Validator) -> "Guard": + def use(self, validator: Validator, on="output") -> "Guard": ... @overload - def use(self, validator: Type[Validator], *args, **kwargs) -> "Guard": + def use(self, validator: Type[Validator], *args, on="output", **kwargs) -> "Guard": ... def use( - self, validator: Union[Validator, Type[Validator]], *args, **kwargs + self, validator: Union[Validator, Type[Validator]], *args, on="output", **kwargs ) -> "Guard": - """Use a validator to validate results of an LLM request. + """Use a validator to validate either of the following: + - The output of an LLM request + - The prompt + - The instructions + - The message history + + *Note*: For on="output", `use` is only available for string output types. - *Note*: `use` is only available for string output types. + Args: + validator: The validator to use. Either the class or an instance. + on: The part of the LLM request to validate. Defaults to "output". """ + if on == "prompt": + if validator: + hydrated_validator = get_validator(validator, *args, **kwargs) + # If the prompt schema exists, add the validator to it + if self.rail.prompt_schema: + self.rail.prompt_schema.root_datatype.validators.append( + hydrated_validator + ) + else: + # Otherwise, create a new schema with the validator + schema = StringSchema.from_string( + validators=[hydrated_validator], + ) + self.rail.prompt_schema = schema + elif on == "instructions": + if validator: + hydrated_validator = get_validator(validator, *args, **kwargs) + # If the instructions schema exists, add the validator to it + if self.rail.instructions_schema: + self.rail.instructions_schema.root_datatype.validators.append( + hydrated_validator + ) + else: + # Otherwise, create a new schema with the validator + schema = StringSchema.from_string( + validators=[hydrated_validator], + ) + self.rail.instructions_schema = schema + elif on == "msg_history": + if validator: + hydrated_validator = get_validator(validator, *args, **kwargs) + # If the msg_history schema exists, add the validator to it + if self.rail.msg_history_schema: + self.rail.msg_history_schema.root_datatype.validators.append( + hydrated_validator + ) + else: + # Otherwise, create a new schema with the validator + schema = StringSchema.from_string( + validators=[hydrated_validator], + ) + self.rail.msg_history_schema = schema + elif on == "output": + # Only available for string output types + if self.rail.output_type != "str": + raise RuntimeError( + "The `use` method is only available for string output types." + ) - if self.rail.output_type != "str": - raise RuntimeError( - "The `use` method is only available for string output types." + if validator: + hydrated_validator = get_validator(validator, *args, **kwargs) + self._validators.append(hydrated_validator) + self.rail.output_schema.root_datatype.validators.append( + hydrated_validator + ) + else: + raise ValueError( + """Invalid value for `on`. 
Must be one of + 'output', 'prompt', 'instructions', 'msg_history'.""" ) - if validator: - hydrated_validator = get_validator(validator, *args, **kwargs) - self._validators.append(hydrated_validator) - - self.rail.output_schema.root_datatype.validators.append(hydrated_validator) - return self @overload From 441f7f51cce5c8178525eef2f3598345bc4e78b4 Mon Sep 17 00:00:00 2001 From: Karan Acharya Date: Thu, 14 Mar 2024 14:45:59 -0400 Subject: [PATCH 2/8] Refactor code; update use() and use_many() --- guardrails/guard.py | 131 +++++++++++++++++++++----------------------- 1 file changed, 62 insertions(+), 69 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index f80170ce6..677891a73 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1116,16 +1116,68 @@ def with_msg_history_validation( self.rail.msg_history_schema = schema return self + def __add_validator(self, validator: Validator, on: str = "output"): + # Only available for string output types + if self.rail.output_type != "str": + raise RuntimeError( + "The `use` method is only available for string output types." + ) + + if on == "prompt": + # If the prompt schema exists, add the validator to it + if self.rail.prompt_schema: + self.rail.prompt_schema.root_datatype.validators.append(validator) + else: + # Otherwise, create a new schema with the validator + schema = StringSchema.from_string( + validators=[validator], + ) + self.rail.prompt_schema = schema + elif on == "instructions": + # If the instructions schema exists, add the validator to it + if self.rail.instructions_schema: + self.rail.instructions_schema.root_datatype.validators.append(validator) + else: + # Otherwise, create a new schema with the validator + schema = StringSchema.from_string( + validators=[validator], + ) + self.rail.instructions_schema = schema + elif on == "msg_history": + # If the msg_history schema exists, add the validator to it + if self.rail.msg_history_schema: + self.rail.msg_history_schema.root_datatype.validators.append(validator) + else: + # Otherwise, create a new schema with the validator + schema = StringSchema.from_string( + validators=[validator], + ) + self.rail.msg_history_schema = schema + elif on == "output": + self._validators.append(validator) + self.rail.output_schema.root_datatype.validators.append(validator) + else: + raise ValueError( + """Invalid value for `on`. Must be one of the following: + 'output', 'prompt', 'instructions', 'msg_history'.""" + ) + @overload - def use(self, validator: Validator, on="output") -> "Guard": + def use(self, validator: Validator, on: str = "output") -> "Guard": ... @overload - def use(self, validator: Type[Validator], *args, on="output", **kwargs) -> "Guard": + def use( + self, validator: Type[Validator], *args, on: str = "output", **kwargs + ) -> "Guard": ... def use( - self, validator: Union[Validator, Type[Validator]], *args, on="output", **kwargs + self, + validator: Union[Validator, Type[Validator]], + *args, + on: str = "output", + **kwargs, ) -> "Guard": """Use a validator to validate either of the following: - The output of an LLM request @@ -1139,71 +1191,12 @@ def use( validator: The validator to use. Either the class or an instance. on: The part of the LLM request to validate. Defaults to "output". 
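+
+        Example (illustrative sketch; uses the ``LowerCase`` and ``TwoWords``
+        validators that already ship with the library):
+
+            guard = Guard()
+            guard.use(LowerCase, on="prompt")   # validate the prompt before the LLM call
+            guard.use(TwoWords, on="output")    # validate the LLM output (the default target)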
""" - if on == "prompt": - if validator: - hydrated_validator = get_validator(validator, *args, **kwargs) - # If the prompt schema exists, add the validator to it - if self.rail.prompt_schema: - self.rail.prompt_schema.root_datatype.validators.append( - hydrated_validator - ) - else: - # Otherwise, create a new schema with the validator - schema = StringSchema.from_string( - validators=[hydrated_validator], - ) - self.rail.prompt_schema = schema - elif on == "instructions": - if validator: - hydrated_validator = get_validator(validator, *args, **kwargs) - # If the instructions schema exists, add the validator to it - if self.rail.instructions_schema: - self.rail.instructions_schema.root_datatype.validators.append( - hydrated_validator - ) - else: - # Otherwise, create a new schema with the validator - schema = StringSchema.from_string( - validators=[hydrated_validator], - ) - self.rail.instructions_schema = schema - elif on == "msg_history": - if validator: - hydrated_validator = get_validator(validator, *args, **kwargs) - # If the msg_history schema exists, add the validator to it - if self.rail.msg_history_schema: - self.rail.msg_history_schema.root_datatype.validators.append( - hydrated_validator - ) - else: - # Otherwise, create a new schema with the validator - schema = StringSchema.from_string( - validators=[hydrated_validator], - ) - self.rail.msg_history_schema = schema - elif on == "output": - # Only available for string output types - if self.rail.output_type != "str": - raise RuntimeError( - "The `use` method is only available for string output types." - ) - - if validator: - hydrated_validator = get_validator(validator, *args, **kwargs) - self._validators.append(hydrated_validator) - self.rail.output_schema.root_datatype.validators.append( - hydrated_validator - ) - else: - raise ValueError( - """Invalid value for `on`. Must be one of - 'output', 'prompt', 'instructions', 'msg_history'.""" - ) - + hydrated_validator = get_validator(validator, *args, **kwargs) + self.__add_validator(hydrated_validator, on=on) return self @overload - def use_many(self, *validators: Validator) -> "Guard": + def use_many(self, *validators: Validator, on: str = "output") -> "Guard": ... @overload @@ -1214,6 +1207,7 @@ def use_many( Optional[Union[List[Any], Dict[str, Any]]], Optional[Dict[str, Any]], ], + on: str = "output", ) -> "Guard": ... @@ -1227,22 +1221,21 @@ def use_many( Optional[Dict[str, Any]], ], ], + on: str = "output", ) -> "Guard": """Use a validator to validate results of an LLM request. *Note*: `use_many` is only available for string output types. """ - if self.rail.output_type != "str": raise RuntimeError( "The `use_many` method is only available for string output types." 
) + # Loop through the validators for v in validators: hydrated_validator = get_validator(v) - self._validators.append(hydrated_validator) - self.rail.output_schema.root_datatype.validators.append(hydrated_validator) - + self.__add_validator(hydrated_validator, on=on) return self def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]: From 549605e9e0390d02f5c5aaae074bd8131c5c5910 Mon Sep 17 00:00:00 2001 From: Karan Acharya Date: Thu, 14 Mar 2024 18:04:45 -0400 Subject: [PATCH 3/8] Add unit tests and integration tests --- tests/integration_tests/test_guard.py | 2 +- tests/integration_tests/test_litellm.py | 169 +++++++++++++++++++++--- tests/unit_tests/test_guard.py | 107 ++++++++++++++- 3 files changed, 260 insertions(+), 18 deletions(-) diff --git a/tests/integration_tests/test_guard.py b/tests/integration_tests/test_guard.py index 78a85a354..7ba5c074e 100644 --- a/tests/integration_tests/test_guard.py +++ b/tests/integration_tests/test_guard.py @@ -841,7 +841,7 @@ def invoke( model = MockModel() guard = ( Guard() - .use(RegexMatch("Ice cream", match_type="search")) + .use(RegexMatch("Ice cream", match_type="search"), on="output") .use(ReadingTime(0.05)) # 3 seconds ) output_parser = StrOutputParser() diff --git a/tests/integration_tests/test_litellm.py b/tests/integration_tests/test_litellm.py index 52c63909b..fe0abae7a 100644 --- a/tests/integration_tests/test_litellm.py +++ b/tests/integration_tests/test_litellm.py @@ -5,8 +5,13 @@ import pytest import guardrails as gd -from guardrails.validators import LowerCase - +from guardrails.validators import ( + LowerCase, + UpperCase, + OneLine, + EndsWith, + ValidLength +) # Mock the litellm.completion function and # the classes it returns @@ -14,47 +19,179 @@ class Message: content: str - @dataclass class Choice: message: Message - @dataclass class Usage: prompt_tokens: int completion_tokens: int - @dataclass class MockResponse: choices: List[Choice] usage: Usage - class MockCompletion: @staticmethod - def create() -> MockResponse: + def create(output) -> MockResponse: return MockResponse( - choices=[Choice(message=Message(content="GUARDRAILS AI"))], + choices=[Choice(message=Message(content=output))], usage=Usage(prompt_tokens=10, completion_tokens=20), ) +@pytest.mark.skipif( + not importlib.util.find_spec("litellm"), + reason="`litellm` is not installed", +) +@pytest.mark.parametrize( + "input_text, expected", + [ + ("Suggestions for a name for an AI company. The name should be short and catchy.", "GUARDRAILS AI"), + ("What is the capital of France?", "PARIS"), + ] +) +def test_litellm_completion(mocker, input_text, expected): + """Test that Guardrails can use litellm for completions.""" + import litellm -TEST_PROMPT = "Suggest a name for an AI company. The name should be short and catchy." 
-guard = gd.Guard.from_string(validators=[LowerCase(on_fail="fix")], prompt=TEST_PROMPT) + mocker.patch("litellm.completion", return_value=MockCompletion.create(expected)) + guard = gd.Guard.from_string(validators=[LowerCase(on_fail="fix")], prompt=input_text) + + raw, validated, *rest = guard(litellm.completion) + assert raw == expected + assert validated == expected.lower() +# Test Guard().use() with just output validators @pytest.mark.skipif( not importlib.util.find_spec("litellm"), reason="`litellm` is not installed", ) -def test_litellm_completion(mocker): - """Test that Guardrails can use litellm for completions.""" +@pytest.mark.parametrize( + "input_text, raw_response, pass_output", + [ + ("Name one Oscar-nominated film", "may december", True), + ("Name one Oscar-nominated film", "PAST LIVES", False), + ] +) +def test_guard_use_output_validators(mocker, input_text, raw_response, pass_output): + """Test Guard().use() with just output validators.""" import litellm + mocker.patch("litellm.completion", return_value=MockCompletion.create(raw_response)) - mocker.patch("litellm.completion", return_value=MockCompletion.create()) + guard = gd.Guard().use(LowerCase, on="output", on_fail="fix").use(OneLine, on="output", on_fail="noop") + raw, validated, *rest = guard(litellm.completion, prompt=input_text) - raw, validated, *rest = guard(litellm.completion) - assert raw == "GUARDRAILS AI" - assert validated == "guardrails ai" + assert raw == raw_response + if pass_output: + assert validated == raw_response + else: + assert validated == raw_response.lower() + +# Test Guard().use() with a combination of prompt and output validators +@pytest.mark.skipif( + not importlib.util.find_spec("litellm"), + reason="`litellm` is not installed", +) +@pytest.mark.parametrize( + "input_text, pass_input, raw_response, pass_output", + [ + ("name one oscar-nominated film", True, "MAY DECEMBER", True), + ("Name one Oscar-nominated film", False, "PAST LIVES", True), + ("Name one Oscar-nominated film", False, "past lives", False), + ("name one oscar-nominated film", True, "past lives", False), + ] +) +def test_guard_use_combination_validators(mocker, input_text, pass_input, raw_response, pass_output): + """Test Guard().use() with a combination of prompt and output validators.""" + import litellm + mocker.patch("litellm.completion", return_value=MockCompletion.create(raw_response)) + + guard = gd.Guard().use(LowerCase, on="prompt", on_fail="exception").use(UpperCase, on="output", on_fail="fix") + + if pass_input: + raw, validated, *rest = guard(litellm.completion, prompt=input_text) + + assert raw == raw_response + if pass_output: + assert validated == raw_response + else: + assert validated == raw_response.upper() + else: + with pytest.raises(Exception): + raw, validated, *rest = guard(litellm.completion, prompt=input_text) + + + +# Test Guard().use_many() with just output validators +@pytest.mark.skipif( + not importlib.util.find_spec("litellm"), + reason="`litellm` is not installed", +) +@pytest.mark.parametrize( + "input_text, raw_response, pass_output", + [ + ("Name one Oscar-nominated film", "may december", True), + ("Name one Oscar-nominated film", "PAST LIVES", False), + ] +) +def test_guard_use_many_output_validators(mocker, input_text, raw_response, pass_output): + """Test Guard().use_many() with just output validators.""" + import litellm + mocker.patch("litellm.completion", return_value=MockCompletion.create(raw_response)) + + guard = gd.Guard().use_many( + LowerCase(on_fail="fix"), + 
OneLine(on_fail="noop"), + on="output" + ) + raw, validated, *rest = guard(litellm.completion, prompt=input_text) + + assert raw == raw_response + if pass_output: + assert validated == raw_response + else: + assert validated == raw_response.lower() + + +# Test Guard().use_many() with a combination of prompt and output validators +@pytest.mark.skipif( + not importlib.util.find_spec("litellm"), + reason="`litellm` is not installed", +) +@pytest.mark.parametrize( + "input_text, pass_input, raw_response, pass_output", + [ + ("name one oscar-nominated film", True, "MAY DECEMBER", True), + ("Name one Oscar-nominated film", False, "PAST LIVES", True), + ("Name one Oscar-nominated film", False, "past lives", False), + ("name one oscar-nominated film", True, "past lives", False), + ] +) +def test_guard_use_many_combination_validators(mocker, input_text, pass_input, raw_response, pass_output): + """Test Guard().use() with a combination of prompt and output validators.""" + import litellm + mocker.patch("litellm.completion", return_value=MockCompletion.create(raw_response)) + + # guard = gd.Guard().use(LowerCase, on="prompt", on_fail="exception").use(UpperCase, on="output", on_fail="fix") + guard = gd.Guard().use_many( + LowerCase(on_fail="exception"), + on="prompt" + ).use_many( + UpperCase(on_fail="fix"), + on="output" + ) + + if pass_input: + raw, validated, *rest = guard(litellm.completion, prompt=input_text) + + assert raw == raw_response + if pass_output: + assert validated == raw_response + else: + assert validated == raw_response.upper() + else: + with pytest.raises(Exception): + raw, validated, *rest = guard(litellm.completion, prompt=input_text) diff --git a/tests/unit_tests/test_guard.py b/tests/unit_tests/test_guard.py index 729593e1a..0f355d6c9 100644 --- a/tests/unit_tests/test_guard.py +++ b/tests/unit_tests/test_guard.py @@ -12,6 +12,7 @@ OneLine, PassResult, TwoWords, + UpperCase, ValidLength, register_validator, ) @@ -223,6 +224,34 @@ class TestClass(BaseModel): py_guard = Guard.from_pydantic(output_class=TestClass) py_guard.use(EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail="reask")) + # Use a combination of prompt, instructions, msg_history and output validators + # Should only have the output validators in the guard, everything else is in the schema + guard: Guard = ( + Guard() + .use(LowerCase, on="prompt") + .use(OneLine, on="prompt") + .use(UpperCase, on="instructions") + .use(EndsWith("a"), on="msg_history") + .use(EndsWith("a"), on="output") # default on="output", still explicitly set + .use(TwoWords, on_fail="reask") # default on="output", implicitly set + ) + + assert len(guard._validators) == 2 # only 2 output validators, hence 2 + + assert isinstance(guard._validators[0], EndsWith) + assert guard._validators[0]._kwargs["end"] == "a" + assert guard._validators[0].on_fail_descriptor == "fix" # bc this is the default + + assert isinstance(guard._validators[1], TwoWords) + assert guard._validators[1].on_fail_descriptor == "reask" # bc we set it + + # Test with an invalid "on" parameter, should raise a ValueError + with pytest.raises(ValueError): + guard: Guard = ( + Guard() + .use(EndsWith("a"), on="response") # invalid on parameter + .use(OneLine, on="prompt") # valid on parameter + ) def test_use_many_instances(): guard: Guard = Guard().use_many( @@ -257,6 +286,33 @@ class TestClass(BaseModel): [EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail="reask")] ) + # Test with explicitly setting the "on" parameter + guard: Guard = Guard().use_many( + EndsWith("a"), 
OneLine(), LowerCase(), TwoWords(on_fail="reask"), on="output" + ) + + assert len(guard._validators) == 4 # still 4 output validators, hence 4 + + assert isinstance(guard._validators[0], EndsWith) + assert guard._validators[0]._end == "a" + assert guard._validators[0]._kwargs["end"] == "a" + assert guard._validators[0].on_fail_descriptor == "fix" # bc this is the default + + assert isinstance(guard._validators[1], OneLine) + assert guard._validators[1].on_fail_descriptor == "noop" # bc this is the default + + assert isinstance(guard._validators[2], LowerCase) + assert guard._validators[2].on_fail_descriptor == "noop" # bc this is the default + + assert isinstance(guard._validators[3], TwoWords) + assert guard._validators[3].on_fail_descriptor == "reask" # bc we set it + + # Test with an invalid "on" parameter, should raise a ValueError + with pytest.raises(ValueError): + guard: Guard = Guard().use_many( + EndsWith("a", on_fail="exception"), OneLine(), on="response" + ) + def test_use_many_tuple(): guard: Guard = Guard().use_many( @@ -294,12 +350,36 @@ def test_use_many_tuple(): assert guard._validators[4]._kwargs["max"] == 12 assert guard._validators[4].on_fail_descriptor == "refrain" # bc we set it + # Test with explicitly setting the "on" parameter + guard: Guard = Guard().use_many( + (EndsWith, ["a"], {"on_fail": "exception"}), + OneLine, + on="output", + ) + + assert len(guard._validators) == 2 # only 2 output validators, hence 2 + + assert isinstance(guard._validators[0], EndsWith) + assert guard._validators[0]._end == "a" + assert guard._validators[0]._kwargs["end"] == "a" + assert guard._validators[0].on_fail_descriptor == "exception" # bc we set it + + assert isinstance(guard._validators[1], OneLine) + assert guard._validators[1].on_fail_descriptor == "noop" # bc this is the default + + # Test with an invalid "on" parameter, should raise a ValueError + with pytest.raises(ValueError): + guard: Guard = Guard().use_many( + (EndsWith, ["a"], {"on_fail": "exception"}), + OneLine, + on="response", + ) def test_validate(): guard: Guard = ( Guard() .use(OneLine) - .use(LowerCase(on_fail="fix")) + .use(LowerCase(on_fail="fix"), on="output") # default on="output", still explicitly set .use(TwoWords) .use(ValidLength, 0, 12, on_fail="refrain") ) @@ -318,6 +398,31 @@ def test_validate(): assert response_2.validation_passed is False assert response_2.validated_output is None + # Test with a combination of prompt, output, instructions and msg_history validators + # Should still only use the output validators to validate the output + guard: Guard = ( + Guard() + .use(OneLine, on="prompt") + .use(LowerCase, on="instructions") + .use(UpperCase, on="msg_history") + .use(LowerCase, on="output", on_fail="fix") + .use(TwoWords, on="output") + .use(ValidLength, 0, 12, on="output") + ) + + llm_output: str = "Oh Canada" # bc it meets our criteria + + response = guard.validate(llm_output) + + assert response.validation_passed is True + assert response.validated_output == llm_output.lower() + + llm_output_2 = "Star Spangled Banner" # to stick with the theme + + response_2 = guard.validate(llm_output_2) + + assert response_2.validation_passed is False + assert response_2.validated_output is None # def test_call(): # five_seconds = 5 / 60 From 6c9ca9c3e77a32c7a4b5277a3c1f1a3f7cc84784 Mon Sep 17 00:00:00 2001 From: Karan Acharya Date: Thu, 14 Mar 2024 18:07:34 -0400 Subject: [PATCH 4/8] Lint --- tests/integration_tests/test_litellm.py | 104 +++++++++++++++--------- tests/unit_tests/test_guard.py | 10 ++- 2 
files changed, 72 insertions(+), 42 deletions(-) diff --git a/tests/integration_tests/test_litellm.py b/tests/integration_tests/test_litellm.py index fe0abae7a..ecffec234 100644 --- a/tests/integration_tests/test_litellm.py +++ b/tests/integration_tests/test_litellm.py @@ -5,13 +5,8 @@ import pytest import guardrails as gd -from guardrails.validators import ( - LowerCase, - UpperCase, - OneLine, - EndsWith, - ValidLength -) +from guardrails.validators import LowerCase, OneLine, UpperCase + # Mock the litellm.completion function and # the classes it returns @@ -19,20 +14,24 @@ class Message: content: str + @dataclass class Choice: message: Message + @dataclass class Usage: prompt_tokens: int completion_tokens: int + @dataclass class MockResponse: choices: List[Choice] usage: Usage + class MockCompletion: @staticmethod def create(output) -> MockResponse: @@ -41,16 +40,23 @@ def create(output) -> MockResponse: usage=Usage(prompt_tokens=10, completion_tokens=20), ) + @pytest.mark.skipif( not importlib.util.find_spec("litellm"), reason="`litellm` is not installed", ) @pytest.mark.parametrize( - "input_text, expected", + "input_text, expected", [ - ("Suggestions for a name for an AI company. The name should be short and catchy.", "GUARDRAILS AI"), + ( + """ + Suggestions for a name for an AI company. + The name should be short and catchy. + """, + "GUARDRAILS AI", + ), ("What is the capital of France?", "PARIS"), - ] + ], ) def test_litellm_completion(mocker, input_text, expected): """Test that Guardrails can use litellm for completions.""" @@ -58,30 +64,38 @@ def test_litellm_completion(mocker, input_text, expected): mocker.patch("litellm.completion", return_value=MockCompletion.create(expected)) - guard = gd.Guard.from_string(validators=[LowerCase(on_fail="fix")], prompt=input_text) - + guard = gd.Guard.from_string( + validators=[LowerCase(on_fail="fix")], prompt=input_text + ) + raw, validated, *rest = guard(litellm.completion) assert raw == expected assert validated == expected.lower() + # Test Guard().use() with just output validators @pytest.mark.skipif( not importlib.util.find_spec("litellm"), reason="`litellm` is not installed", ) @pytest.mark.parametrize( - "input_text, raw_response, pass_output", + "input_text, raw_response, pass_output", [ ("Name one Oscar-nominated film", "may december", True), ("Name one Oscar-nominated film", "PAST LIVES", False), - ] + ], ) def test_guard_use_output_validators(mocker, input_text, raw_response, pass_output): """Test Guard().use() with just output validators.""" import litellm + mocker.patch("litellm.completion", return_value=MockCompletion.create(raw_response)) - guard = gd.Guard().use(LowerCase, on="output", on_fail="fix").use(OneLine, on="output", on_fail="noop") + guard = ( + gd.Guard() + .use(LowerCase, on="output", on_fail="fix") + .use(OneLine, on="output", on_fail="noop") + ) raw, validated, *rest = guard(litellm.completion, prompt=input_text) assert raw == raw_response @@ -90,27 +104,36 @@ def test_guard_use_output_validators(mocker, input_text, raw_response, pass_outp else: assert validated == raw_response.lower() + # Test Guard().use() with a combination of prompt and output validators @pytest.mark.skipif( not importlib.util.find_spec("litellm"), reason="`litellm` is not installed", ) @pytest.mark.parametrize( - "input_text, pass_input, raw_response, pass_output", + "input_text, pass_input, raw_response, pass_output", [ ("name one oscar-nominated film", True, "MAY DECEMBER", True), ("Name one Oscar-nominated film", False, "PAST LIVES", 
True), ("Name one Oscar-nominated film", False, "past lives", False), ("name one oscar-nominated film", True, "past lives", False), - ] + ], ) -def test_guard_use_combination_validators(mocker, input_text, pass_input, raw_response, pass_output): - """Test Guard().use() with a combination of prompt and output validators.""" +def test_guard_use_combination_validators( + mocker, input_text, pass_input, raw_response, pass_output +): + """Test Guard().use() with a combination of prompt and output + validators.""" import litellm + mocker.patch("litellm.completion", return_value=MockCompletion.create(raw_response)) - guard = gd.Guard().use(LowerCase, on="prompt", on_fail="exception").use(UpperCase, on="output", on_fail="fix") - + guard = ( + gd.Guard() + .use(LowerCase, on="prompt", on_fail="exception") + .use(UpperCase, on="output", on_fail="fix") + ) + if pass_input: raw, validated, *rest = guard(litellm.completion, prompt=input_text) @@ -124,28 +147,28 @@ def test_guard_use_combination_validators(mocker, input_text, pass_input, raw_re raw, validated, *rest = guard(litellm.completion, prompt=input_text) - # Test Guard().use_many() with just output validators @pytest.mark.skipif( not importlib.util.find_spec("litellm"), reason="`litellm` is not installed", ) @pytest.mark.parametrize( - "input_text, raw_response, pass_output", + "input_text, raw_response, pass_output", [ ("Name one Oscar-nominated film", "may december", True), ("Name one Oscar-nominated film", "PAST LIVES", False), - ] + ], ) -def test_guard_use_many_output_validators(mocker, input_text, raw_response, pass_output): +def test_guard_use_many_output_validators( + mocker, input_text, raw_response, pass_output +): """Test Guard().use_many() with just output validators.""" import litellm + mocker.patch("litellm.completion", return_value=MockCompletion.create(raw_response)) guard = gd.Guard().use_many( - LowerCase(on_fail="fix"), - OneLine(on_fail="noop"), - on="output" + LowerCase(on_fail="fix"), OneLine(on_fail="noop"), on="output" ) raw, validated, *rest = guard(litellm.completion, prompt=input_text) @@ -162,28 +185,29 @@ def test_guard_use_many_output_validators(mocker, input_text, raw_response, pass reason="`litellm` is not installed", ) @pytest.mark.parametrize( - "input_text, pass_input, raw_response, pass_output", + "input_text, pass_input, raw_response, pass_output", [ ("name one oscar-nominated film", True, "MAY DECEMBER", True), ("Name one Oscar-nominated film", False, "PAST LIVES", True), ("Name one Oscar-nominated film", False, "past lives", False), ("name one oscar-nominated film", True, "past lives", False), - ] + ], ) -def test_guard_use_many_combination_validators(mocker, input_text, pass_input, raw_response, pass_output): - """Test Guard().use() with a combination of prompt and output validators.""" +def test_guard_use_many_combination_validators( + mocker, input_text, pass_input, raw_response, pass_output +): + """Test Guard().use() with a combination of prompt and output + validators.""" import litellm + mocker.patch("litellm.completion", return_value=MockCompletion.create(raw_response)) - # guard = gd.Guard().use(LowerCase, on="prompt", on_fail="exception").use(UpperCase, on="output", on_fail="fix") - guard = gd.Guard().use_many( - LowerCase(on_fail="exception"), - on="prompt" - ).use_many( - UpperCase(on_fail="fix"), - on="output" + guard = ( + gd.Guard() + .use_many(LowerCase(on_fail="exception"), on="prompt") + .use_many(UpperCase(on_fail="fix"), on="output") ) - + if pass_input: raw, validated, *rest = 
guard(litellm.completion, prompt=input_text) diff --git a/tests/unit_tests/test_guard.py b/tests/unit_tests/test_guard.py index 0f355d6c9..109ca388f 100644 --- a/tests/unit_tests/test_guard.py +++ b/tests/unit_tests/test_guard.py @@ -225,7 +225,8 @@ class TestClass(BaseModel): py_guard.use(EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail="reask")) # Use a combination of prompt, instructions, msg_history and output validators - # Should only have the output validators in the guard, everything else is in the schema + # Should only have the output validators in the guard, + # everything else is in the schema guard: Guard = ( Guard() .use(LowerCase, on="prompt") @@ -253,6 +254,7 @@ class TestClass(BaseModel): .use(OneLine, on="prompt") # valid on parameter ) + def test_use_many_instances(): guard: Guard = Guard().use_many( EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail="reask") @@ -375,11 +377,14 @@ def test_use_many_tuple(): on="response", ) + def test_validate(): guard: Guard = ( Guard() .use(OneLine) - .use(LowerCase(on_fail="fix"), on="output") # default on="output", still explicitly set + .use( + LowerCase(on_fail="fix"), on="output" + ) # default on="output", still explicitly set .use(TwoWords) .use(ValidLength, 0, 12, on_fail="refrain") ) @@ -424,6 +429,7 @@ def test_validate(): assert response_2.validation_passed is False assert response_2.validated_output is None + # def test_call(): # five_seconds = 5 / 60 # response = Guard().use_many( From 321b7c843c814d4143c4f3f2c5c58e6b1a09dd5a Mon Sep 17 00:00:00 2001 From: Karan Acharya Date: Thu, 14 Mar 2024 18:14:57 -0400 Subject: [PATCH 5/8] Add FutureWarnings to old methods --- guardrails/guard.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/guardrails/guard.py b/guardrails/guard.py index 677891a73..a048637e6 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1074,6 +1074,12 @@ def with_prompt_validation( Args: validators: The validators to add to the prompt. """ + warnings.warn( + """The `with_prompt_validation` method is deprecated, + and will be removed in 0.5.x. Instead, please use + `Guard().use(YourValidator, on='prompt')`.""", + FutureWarning, + ) if self.rail.prompt_schema: warnings.warn("Overriding existing prompt validators.") schema = StringSchema.from_string( @@ -1091,6 +1097,12 @@ def with_instructions_validation( Args: validators: The validators to add to the instructions. """ + warnings.warn( + """The `with_instructions_validation` method is deprecated, + and will be removed in 0.5.x. Instead, please use + `Guard().use(YourValidator, on='instructions')`.""", + FutureWarning, + ) if self.rail.instructions_schema: warnings.warn("Overriding existing instructions validators.") schema = StringSchema.from_string( @@ -1108,6 +1120,12 @@ def with_msg_history_validation( Args: validators: The validators to add to the msg_history. """ + warnings.warn( + """The `with_msg_history_validation` method is deprecated, + and will be removed in 0.5.x. 
Instead, please use + `Guard().use(YourValidator, on='msg_history')`.""", + FutureWarning, + ) if self.rail.msg_history_schema: warnings.warn("Overriding existing msg_history validators.") schema = StringSchema.from_string( From 48c9c4c980ae819fb04d23a3256f5db497cc6e22 Mon Sep 17 00:00:00 2001 From: Karan Acharya Date: Fri, 15 Mar 2024 10:49:24 -0400 Subject: [PATCH 6/8] Refactor warnings --- guardrails/guard.py | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index a048637e6..5d764beb3 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -34,6 +34,7 @@ from langchain_core.messages import BaseMessage from langchain_core.runnables import Runnable, RunnableConfig from pydantic import BaseModel +from typing_extensions import deprecated from guardrails.api_client import GuardrailsApiClient from guardrails.classes import OT, InputType, ValidationOutcome @@ -1065,6 +1066,13 @@ async def _async_parse( return ValidationOutcome[OT].from_guard_history(call) + @deprecated( + """The `with_prompt_validation` method is deprecated, + and will be removed in 0.5.x. Instead, please use + `Guard().use(YourValidator, on='prompt')`.""", + FutureWarning, + stacklevel=2, + ) def with_prompt_validation( self, validators: Sequence[Validator], @@ -1074,12 +1082,6 @@ def with_prompt_validation( Args: validators: The validators to add to the prompt. """ - warnings.warn( - """The `with_prompt_validation` method is deprecated, - and will be removed in 0.5.x. Instead, please use - `Guard().use(YourValidator, on='prompt')`.""", - FutureWarning, - ) if self.rail.prompt_schema: warnings.warn("Overriding existing prompt validators.") schema = StringSchema.from_string( @@ -1088,6 +1090,13 @@ def with_prompt_validation( self.rail.prompt_schema = schema return self + @deprecated( + """The `with_instructions_validation` method is deprecated, + and will be removed in 0.5.x. Instead, please use + `Guard().use(YourValidator, on='instructions')`.""", + FutureWarning, + stacklevel=2, + ) def with_instructions_validation( self, validators: Sequence[Validator], @@ -1097,12 +1106,6 @@ def with_instructions_validation( Args: validators: The validators to add to the instructions. """ - warnings.warn( - """The `with_instructions_validation` method is deprecated, - and will be removed in 0.5.x. Instead, please use - `Guard().use(YourValidator, on='instructions')`.""", - FutureWarning, - ) if self.rail.instructions_schema: warnings.warn("Overriding existing instructions validators.") schema = StringSchema.from_string( @@ -1111,6 +1114,13 @@ def with_instructions_validation( self.rail.instructions_schema = schema return self + @deprecated( + """The `with_msg_history_validation` method is deprecated, + and will be removed in 0.5.x. Instead, please use + `Guard().use(YourValidator, on='msg_history')`.""", + FutureWarning, + stacklevel=2, + ) def with_msg_history_validation( self, validators: Sequence[Validator], @@ -1120,12 +1130,6 @@ def with_msg_history_validation( Args: validators: The validators to add to the msg_history. """ - warnings.warn( - """The `with_msg_history_validation` method is deprecated, - and will be removed in 0.5.x. 
Instead, please use - `Guard().use(YourValidator, on='msg_history')`.""", - FutureWarning, - ) if self.rail.msg_history_schema: warnings.warn("Overriding existing msg_history validators.") schema = StringSchema.from_string( From 95db2b2ae6e6d94dfa34ad4c7b54fcfe3d93429a Mon Sep 17 00:00:00 2001 From: Karan Acharya Date: Fri, 15 Mar 2024 11:49:42 -0400 Subject: [PATCH 7/8] Refactor: Add kwarg and make on mandatory kwarg --- guardrails/guard.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index 5d764beb3..22e9748c3 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1070,7 +1070,7 @@ async def _async_parse( """The `with_prompt_validation` method is deprecated, and will be removed in 0.5.x. Instead, please use `Guard().use(YourValidator, on='prompt')`.""", - FutureWarning, + category=FutureWarning, stacklevel=2, ) def with_prompt_validation( @@ -1094,7 +1094,7 @@ def with_prompt_validation( """The `with_instructions_validation` method is deprecated, and will be removed in 0.5.x. Instead, please use `Guard().use(YourValidator, on='instructions')`.""", - FutureWarning, + category=FutureWarning, stacklevel=2, ) def with_instructions_validation( @@ -1118,7 +1118,7 @@ def with_instructions_validation( """The `with_msg_history_validation` method is deprecated, and will be removed in 0.5.x. Instead, please use `Guard().use(YourValidator, on='msg_history')`.""", - FutureWarning, + category=FutureWarning, stacklevel=2, ) def with_msg_history_validation( @@ -1185,7 +1185,7 @@ def __add_validator(self, validator: Validator, on: str = "output"): ) @overload - def use(self, validator: Validator, on: str = "output") -> "Guard": + def use(self, validator: Validator, *, on: str = "output") -> "Guard": ... 
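+    # NOTE: the bare `*` above makes `on` keyword-only for this overload, so callers
+    # pass it by name, e.g. `guard.use(LowerCase(), on="prompt")`, rather than positionally.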
@overload From b0f2cd8903c035baedfa3d9431af3b7ac293c478 Mon Sep 17 00:00:00 2001 From: Karan Acharya Date: Fri, 15 Mar 2024 11:50:29 -0400 Subject: [PATCH 8/8] Add and update unit tests to cover more edge cases --- tests/unit_tests/test_guard.py | 114 ++++++++++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/test_guard.py b/tests/unit_tests/test_guard.py index 109ca388f..3f64f2293 100644 --- a/tests/unit_tests/test_guard.py +++ b/tests/unit_tests/test_guard.py @@ -232,11 +232,28 @@ class TestClass(BaseModel): .use(LowerCase, on="prompt") .use(OneLine, on="prompt") .use(UpperCase, on="instructions") - .use(EndsWith("a"), on="msg_history") - .use(EndsWith("a"), on="output") # default on="output", still explicitly set + .use(LowerCase, on="msg_history") + .use( + EndsWith, end="a", on="output" + ) # default on="output", still explicitly set .use(TwoWords, on_fail="reask") # default on="output", implicitly set ) + # Check schemas for prompt, instructions and msg_history validators + prompt_validators = guard.rail.prompt_schema.root_datatype.validators + assert len(prompt_validators) == 2 + assert prompt_validators[0].__class__.__name__ == "LowerCase" + assert prompt_validators[1].__class__.__name__ == "OneLine" + + instructions_validators = guard.rail.instructions_schema.root_datatype.validators + assert len(instructions_validators) == 1 + assert instructions_validators[0].__class__.__name__ == "UpperCase" + + msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators + assert len(msg_history_validators) == 1 + assert msg_history_validators[0].__class__.__name__ == "LowerCase" + + # Check guard for output validators assert len(guard._validators) == 2 # only 2 output validators, hence 2 assert isinstance(guard._validators[0], EndsWith) @@ -288,7 +305,7 @@ class TestClass(BaseModel): [EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail="reask")] ) - # Test with explicitly setting the "on" parameter + # Test with explicitly setting the "on" parameter = "output" guard: Guard = Guard().use_many( EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail="reask"), on="output" ) @@ -309,6 +326,42 @@ class TestClass(BaseModel): assert isinstance(guard._validators[3], TwoWords) assert guard._validators[3].on_fail_descriptor == "reask" # bc we set it + # Test with explicitly setting the "on" parameter = "prompt" + guard: Guard = Guard().use_many( + OneLine(), LowerCase(), TwoWords(on_fail="reask"), on="prompt" + ) + + prompt_validators = guard.rail.prompt_schema.root_datatype.validators + assert len(prompt_validators) == 3 + assert prompt_validators[0].__class__.__name__ == "OneLine" + assert prompt_validators[1].__class__.__name__ == "LowerCase" + assert prompt_validators[2].__class__.__name__ == "TwoWords" + assert len(guard._validators) == 0 # no output validators, hence 0 + + # Test with explicitly setting the "on" parameter = "instructions" + guard: Guard = Guard().use_many( + OneLine(), LowerCase(), TwoWords(on_fail="reask"), on="instructions" + ) + + instructions_validators = guard.rail.instructions_schema.root_datatype.validators + assert len(instructions_validators) == 3 + assert instructions_validators[0].__class__.__name__ == "OneLine" + assert instructions_validators[1].__class__.__name__ == "LowerCase" + assert instructions_validators[2].__class__.__name__ == "TwoWords" + assert len(guard._validators) == 0 # no output validators, hence 0 + + # Test with explicitly setting the "on" parameter = "msg_history" + 
guard: Guard = Guard().use_many( + OneLine(), LowerCase(), TwoWords(on_fail="reask"), on="msg_history" + ) + + msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators + assert len(msg_history_validators) == 3 + assert msg_history_validators[0].__class__.__name__ == "OneLine" + assert msg_history_validators[1].__class__.__name__ == "LowerCase" + assert msg_history_validators[2].__class__.__name__ == "TwoWords" + assert len(guard._validators) == 0 # no output validators, hence 0 + # Test with an invalid "on" parameter, should raise a ValueError with pytest.raises(ValueError): guard: Guard = Guard().use_many( @@ -430,6 +483,61 @@ def test_validate(): assert response_2.validated_output is None +def test_use_and_use_many(): + guard: Guard = ( + Guard() + .use_many(OneLine(), LowerCase(), on="prompt") + .use(UpperCase, on="instructions") + .use(LowerCase, on="msg_history") + .use_many( + TwoWords(on_fail="reask"), + ValidLength(0, 12, on_fail="refrain"), + on="output", + ) + ) + + # Check schemas for prompt, instructions and msg_history validators + prompt_validators = guard.rail.prompt_schema.root_datatype.validators + assert len(prompt_validators) == 2 + assert prompt_validators[0].__class__.__name__ == "OneLine" + assert prompt_validators[1].__class__.__name__ == "LowerCase" + + instructions_validators = guard.rail.instructions_schema.root_datatype.validators + assert len(instructions_validators) == 1 + assert instructions_validators[0].__class__.__name__ == "UpperCase" + + msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators + assert len(msg_history_validators) == 1 + assert msg_history_validators[0].__class__.__name__ == "LowerCase" + + # Check guard for output validators + assert len(guard._validators) == 2 # only 2 output validators, hence 2 + + assert isinstance(guard._validators[0], TwoWords) + assert guard._validators[0].on_fail_descriptor == "reask" # bc we set it + + assert isinstance(guard._validators[1], ValidLength) + assert guard._validators[1]._min == 0 + assert guard._validators[1]._kwargs["min"] == 0 + assert guard._validators[1]._max == 12 + assert guard._validators[1]._kwargs["max"] == 12 + assert guard._validators[1].on_fail_descriptor == "refrain" # bc we set it + + # Test with an invalid "on" parameter, should raise a ValueError + with pytest.raises(ValueError): + guard: Guard = ( + Guard() + .use_many(OneLine(), LowerCase(), on="prompt") + .use(UpperCase, on="instructions") + .use(LowerCase, on="msg_history") + .use_many( + TwoWords(on_fail="reask"), + ValidLength(0, 12, on_fail="refrain"), + on="response", # invalid "on" parameter + ) + ) + + # def test_call(): # five_seconds = 5 / 60 # response = Guard().use_many(