diff --git a/guardrails/guard.py b/guardrails/guard.py
index 22120a51f..22e9748c3 100644
--- a/guardrails/guard.py
+++ b/guardrails/guard.py
@@ -34,6 +34,7 @@ from langchain_core.messages import BaseMessage
 from langchain_core.runnables import Runnable, RunnableConfig
 from pydantic import BaseModel
+from typing_extensions import deprecated

 from guardrails.api_client import GuardrailsApiClient
 from guardrails.classes import OT, InputType, ValidationOutcome
@@ -1065,6 +1066,13 @@ async def _async_parse(

         return ValidationOutcome[OT].from_guard_history(call)

+    @deprecated(
+        """The `with_prompt_validation` method is deprecated,
+        and will be removed in 0.5.x. Instead, please use
+        `Guard().use(YourValidator, on='prompt')`.""",
+        category=FutureWarning,
+        stacklevel=2,
+    )
     def with_prompt_validation(
         self,
         validators: Sequence[Validator],
@@ -1082,6 +1090,13 @@ def with_prompt_validation(
         self.rail.prompt_schema = schema
         return self

+    @deprecated(
+        """The `with_instructions_validation` method is deprecated,
+        and will be removed in 0.5.x. Instead, please use
+        `Guard().use(YourValidator, on='instructions')`.""",
+        category=FutureWarning,
+        stacklevel=2,
+    )
     def with_instructions_validation(
         self,
         validators: Sequence[Validator],
@@ -1099,6 +1114,13 @@ def with_instructions_validation(
         self.rail.instructions_schema = schema
         return self

+    @deprecated(
+        """The `with_msg_history_validation` method is deprecated,
+        and will be removed in 0.5.x. Instead, please use
+        `Guard().use(YourValidator, on='msg_history')`.""",
+        category=FutureWarning,
+        stacklevel=2,
+    )
     def with_msg_history_validation(
         self,
         validators: Sequence[Validator],
@@ -1116,37 +1138,87 @@ def with_msg_history_validation(
         self.rail.msg_history_schema = schema
         return self

+    def __add_validator(self, validator: Validator, on: str = "output"):
+        # Only available for string output types
+        if self.rail.output_type != "str":
+            raise RuntimeError(
+                "The `use` method is only available for string output types."
+            )
+
+        if on == "prompt":
+            # If the prompt schema exists, add the validator to it
+            if self.rail.prompt_schema:
+                self.rail.prompt_schema.root_datatype.validators.append(validator)
+            else:
+                # Otherwise, create a new schema with the validator
+                schema = StringSchema.from_string(
+                    validators=[validator],
+                )
+                self.rail.prompt_schema = schema
+        elif on == "instructions":
+            # If the instructions schema exists, add the validator to it
+            if self.rail.instructions_schema:
+                self.rail.instructions_schema.root_datatype.validators.append(validator)
+            else:
+                # Otherwise, create a new schema with the validator
+                schema = StringSchema.from_string(
+                    validators=[validator],
+                )
+                self.rail.instructions_schema = schema
+        elif on == "msg_history":
+            # If the msg_history schema exists, add the validator to it
+            if self.rail.msg_history_schema:
+                self.rail.msg_history_schema.root_datatype.validators.append(validator)
+            else:
+                # Otherwise, create a new schema with the validator
+                schema = StringSchema.from_string(
+                    validators=[validator],
+                )
+                self.rail.msg_history_schema = schema
+        elif on == "output":
+            self._validators.append(validator)
+            self.rail.output_schema.root_datatype.validators.append(validator)
+        else:
+            raise ValueError(
+                """Invalid value for `on`. Must be one of the following:
+                'output', 'prompt', 'instructions', 'msg_history'."""
+            )
+
     @overload
-    def use(self, validator: Validator) -> "Guard":
+    def use(self, validator: Validator, *, on: str = "output") -> "Guard":
         ...

     @overload
-    def use(self, validator: Type[Validator], *args, **kwargs) -> "Guard":
+    def use(
+        self, validator: Type[Validator], *args, on: str = "output", **kwargs
+    ) -> "Guard":
         ...

     def use(
-        self, validator: Union[Validator, Type[Validator]], *args, **kwargs
+        self,
+        validator: Union[Validator, Type[Validator]],
+        *args,
+        on: str = "output",
+        **kwargs,
     ) -> "Guard":
-        """Use a validator to validate results of an LLM request.
-
-        *Note*: `use` is only available for string output types.
-        """
-
-        if self.rail.output_type != "str":
-            raise RuntimeError(
-                "The `use` method is only available for string output types."
-            )
+        """Use a validator to validate either of the following:
+        - The output of an LLM request
+        - The prompt
+        - The instructions
+        - The message history

-        if validator:
-            hydrated_validator = get_validator(validator, *args, **kwargs)
-            self._validators.append(hydrated_validator)
-
-            self.rail.output_schema.root_datatype.validators.append(hydrated_validator)
+        *Note*: For on="output", `use` is only available for string output types.

+        Args:
+            validator: The validator to use. Either the class or an instance.
+            on: The part of the LLM request to validate. Defaults to "output".
+        """
+        hydrated_validator = get_validator(validator, *args, **kwargs)
+        self.__add_validator(hydrated_validator, on=on)
         return self

     @overload
-    def use_many(self, *validators: Validator) -> "Guard":
+    def use_many(self, *validators: Validator, on: str = "output") -> "Guard":
         ...

     @overload
@@ -1157,6 +1229,7 @@ def use_many(
             Optional[Union[List[Any], Dict[str, Any]]],
             Optional[Dict[str, Any]],
         ],
+        on: str = "output",
     ) -> "Guard":
         ...

@@ -1170,22 +1243,21 @@ def use_many(
                 Optional[Dict[str, Any]],
             ],
         ],
+        on: str = "output",
     ) -> "Guard":
         """Use a validator to validate results of an LLM request.

         *Note*: `use_many` is only available for string output types.
         """
-
         if self.rail.output_type != "str":
             raise RuntimeError(
                 "The `use_many` method is only available for string output types."
             )

+        # Loop through the validators
         for v in validators:
             hydrated_validator = get_validator(v)
-            self._validators.append(hydrated_validator)
-            self.rail.output_schema.root_datatype.validators.append(hydrated_validator)
-
+            self.__add_validator(hydrated_validator, on=on)
         return self

     def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]:
diff --git a/tests/integration_tests/test_guard.py b/tests/integration_tests/test_guard.py
index 78a85a354..7ba5c074e 100644
--- a/tests/integration_tests/test_guard.py
+++ b/tests/integration_tests/test_guard.py
@@ -841,7 +841,7 @@ def invoke(
     model = MockModel()
     guard = (
         Guard()
-        .use(RegexMatch("Ice cream", match_type="search"))
+        .use(RegexMatch("Ice cream", match_type="search"), on="output")
         .use(ReadingTime(0.05))  # 3 seconds
     )
     output_parser = StrOutputParser()
diff --git a/tests/integration_tests/test_litellm.py b/tests/integration_tests/test_litellm.py
index 52c63909b..ecffec234 100644
--- a/tests/integration_tests/test_litellm.py
+++ b/tests/integration_tests/test_litellm.py
@@ -5,7 +5,7 @@ import pytest

 import guardrails as gd
-from guardrails.validators import LowerCase
+from guardrails.validators import LowerCase, OneLine, UpperCase

 # Mock the litellm.completion function and
@@ -34,27 +34,188 @@ class MockResponse:

 class MockCompletion:
     @staticmethod
-    def create() -> MockResponse:
+    def create(output) -> MockResponse:
         return MockResponse(
-            choices=[Choice(message=Message(content="GUARDRAILS AI"))],
+            choices=[Choice(message=Message(content=output))],
             usage=Usage(prompt_tokens=10, completion_tokens=20),
         )

-TEST_PROMPT = "Suggest a name for an AI company. The name should be short and catchy."
-guard = gd.Guard.from_string(validators=[LowerCase(on_fail="fix")], prompt=TEST_PROMPT)
-
-
 @pytest.mark.skipif(
     not importlib.util.find_spec("litellm"),
     reason="`litellm` is not installed",
 )
-def test_litellm_completion(mocker):
+@pytest.mark.parametrize(
+    "input_text, expected",
+    [
+        (
+            """
+            Suggestions for a name for an AI company.
+            The name should be short and catchy.
+ """, + "GUARDRAILS AI", + ), + ("What is the capital of France?", "PARIS"), + ], +) +def test_litellm_completion(mocker, input_text, expected): """Test that Guardrails can use litellm for completions.""" import litellm - mocker.patch("litellm.completion", return_value=MockCompletion.create()) + mocker.patch("litellm.completion", return_value=MockCompletion.create(expected)) + + guard = gd.Guard.from_string( + validators=[LowerCase(on_fail="fix")], prompt=input_text + ) raw, validated, *rest = guard(litellm.completion) - assert raw == "GUARDRAILS AI" - assert validated == "guardrails ai" + assert raw == expected + assert validated == expected.lower() + + +# Test Guard().use() with just output validators +@pytest.mark.skipif( + not importlib.util.find_spec("litellm"), + reason="`litellm` is not installed", +) +@pytest.mark.parametrize( + "input_text, raw_response, pass_output", + [ + ("Name one Oscar-nominated film", "may december", True), + ("Name one Oscar-nominated film", "PAST LIVES", False), + ], +) +def test_guard_use_output_validators(mocker, input_text, raw_response, pass_output): + """Test Guard().use() with just output validators.""" + import litellm + + mocker.patch("litellm.completion", return_value=MockCompletion.create(raw_response)) + + guard = ( + gd.Guard() + .use(LowerCase, on="output", on_fail="fix") + .use(OneLine, on="output", on_fail="noop") + ) + raw, validated, *rest = guard(litellm.completion, prompt=input_text) + + assert raw == raw_response + if pass_output: + assert validated == raw_response + else: + assert validated == raw_response.lower() + + +# Test Guard().use() with a combination of prompt and output validators +@pytest.mark.skipif( + not importlib.util.find_spec("litellm"), + reason="`litellm` is not installed", +) +@pytest.mark.parametrize( + "input_text, pass_input, raw_response, pass_output", + [ + ("name one oscar-nominated film", True, "MAY DECEMBER", True), + ("Name one Oscar-nominated film", False, "PAST LIVES", True), + ("Name one Oscar-nominated film", False, "past lives", False), + ("name one oscar-nominated film", True, "past lives", False), + ], +) +def test_guard_use_combination_validators( + mocker, input_text, pass_input, raw_response, pass_output +): + """Test Guard().use() with a combination of prompt and output + validators.""" + import litellm + + mocker.patch("litellm.completion", return_value=MockCompletion.create(raw_response)) + + guard = ( + gd.Guard() + .use(LowerCase, on="prompt", on_fail="exception") + .use(UpperCase, on="output", on_fail="fix") + ) + + if pass_input: + raw, validated, *rest = guard(litellm.completion, prompt=input_text) + + assert raw == raw_response + if pass_output: + assert validated == raw_response + else: + assert validated == raw_response.upper() + else: + with pytest.raises(Exception): + raw, validated, *rest = guard(litellm.completion, prompt=input_text) + + +# Test Guard().use_many() with just output validators +@pytest.mark.skipif( + not importlib.util.find_spec("litellm"), + reason="`litellm` is not installed", +) +@pytest.mark.parametrize( + "input_text, raw_response, pass_output", + [ + ("Name one Oscar-nominated film", "may december", True), + ("Name one Oscar-nominated film", "PAST LIVES", False), + ], +) +def test_guard_use_many_output_validators( + mocker, input_text, raw_response, pass_output +): + """Test Guard().use_many() with just output validators.""" + import litellm + + mocker.patch("litellm.completion", return_value=MockCompletion.create(raw_response)) + + guard = gd.Guard().use_many( 
+        LowerCase(on_fail="fix"), OneLine(on_fail="noop"), on="output"
+    )
+    raw, validated, *rest = guard(litellm.completion, prompt=input_text)
+
+    assert raw == raw_response
+    if pass_output:
+        assert validated == raw_response
+    else:
+        assert validated == raw_response.lower()
+
+
+# Test Guard().use_many() with a combination of prompt and output validators
+@pytest.mark.skipif(
+    not importlib.util.find_spec("litellm"),
+    reason="`litellm` is not installed",
+)
+@pytest.mark.parametrize(
+    "input_text, pass_input, raw_response, pass_output",
+    [
+        ("name one oscar-nominated film", True, "MAY DECEMBER", True),
+        ("Name one Oscar-nominated film", False, "PAST LIVES", True),
+        ("Name one Oscar-nominated film", False, "past lives", False),
+        ("name one oscar-nominated film", True, "past lives", False),
+    ],
+)
+def test_guard_use_many_combination_validators(
+    mocker, input_text, pass_input, raw_response, pass_output
+):
+    """Test Guard().use_many() with a combination of prompt and output
+    validators."""
+    import litellm
+
+    mocker.patch("litellm.completion", return_value=MockCompletion.create(raw_response))
+
+    guard = (
+        gd.Guard()
+        .use_many(LowerCase(on_fail="exception"), on="prompt")
+        .use_many(UpperCase(on_fail="fix"), on="output")
+    )
+
+    if pass_input:
+        raw, validated, *rest = guard(litellm.completion, prompt=input_text)
+
+        assert raw == raw_response
+        if pass_output:
+            assert validated == raw_response
+        else:
+            assert validated == raw_response.upper()
+    else:
+        with pytest.raises(Exception):
+            raw, validated, *rest = guard(litellm.completion, prompt=input_text)
diff --git a/tests/unit_tests/test_guard.py b/tests/unit_tests/test_guard.py
index 729593e1a..3f64f2293 100644
--- a/tests/unit_tests/test_guard.py
+++ b/tests/unit_tests/test_guard.py
@@ -12,6 +12,7 @@
     OneLine,
     PassResult,
     TwoWords,
+    UpperCase,
     ValidLength,
     register_validator,
 )
@@ -223,6 +224,53 @@ class TestClass(BaseModel):
     py_guard = Guard.from_pydantic(output_class=TestClass)
     py_guard.use(EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail="reask"))

+    # Use a combination of prompt, instructions, msg_history and output validators
+    # Should only have the output validators in the guard,
+    # everything else is in the schema
+    guard: Guard = (
+        Guard()
+        .use(LowerCase, on="prompt")
+        .use(OneLine, on="prompt")
+        .use(UpperCase, on="instructions")
+        .use(LowerCase, on="msg_history")
+        .use(
+            EndsWith, end="a", on="output"
+        )  # default on="output", still explicitly set
+        .use(TwoWords, on_fail="reask")  # default on="output", implicitly set
+    )
+
+    # Check schemas for prompt, instructions and msg_history validators
+    prompt_validators = guard.rail.prompt_schema.root_datatype.validators
+    assert len(prompt_validators) == 2
+    assert prompt_validators[0].__class__.__name__ == "LowerCase"
+    assert prompt_validators[1].__class__.__name__ == "OneLine"
+
+    instructions_validators = guard.rail.instructions_schema.root_datatype.validators
+    assert len(instructions_validators) == 1
+    assert instructions_validators[0].__class__.__name__ == "UpperCase"
+
+    msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators
+    assert len(msg_history_validators) == 1
+    assert msg_history_validators[0].__class__.__name__ == "LowerCase"
+
+    # Check guard for output validators
+    assert len(guard._validators) == 2  # only 2 output validators, hence 2
+
+    assert isinstance(guard._validators[0], EndsWith)
+    assert guard._validators[0]._kwargs["end"] == "a"
+    assert guard._validators[0].on_fail_descriptor == "fix"  # bc this is the default
+
+    assert isinstance(guard._validators[1], TwoWords)
+    assert guard._validators[1].on_fail_descriptor == "reask"  # bc we set it
+
+    # Test with an invalid "on" parameter, should raise a ValueError
+    with pytest.raises(ValueError):
+        guard: Guard = (
+            Guard()
+            .use(EndsWith("a"), on="response")  # invalid on parameter
+            .use(OneLine, on="prompt")  # valid on parameter
+        )
+

 def test_use_many_instances():
     guard: Guard = Guard().use_many(
@@ -257,6 +305,69 @@ class TestClass(BaseModel):
         [EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail="reask")]
     )

+    # Test with explicitly setting the "on" parameter = "output"
+    guard: Guard = Guard().use_many(
+        EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail="reask"), on="output"
+    )
+
+    assert len(guard._validators) == 4  # still 4 output validators, hence 4
+
+    assert isinstance(guard._validators[0], EndsWith)
+    assert guard._validators[0]._end == "a"
+    assert guard._validators[0]._kwargs["end"] == "a"
+    assert guard._validators[0].on_fail_descriptor == "fix"  # bc this is the default
+
+    assert isinstance(guard._validators[1], OneLine)
+    assert guard._validators[1].on_fail_descriptor == "noop"  # bc this is the default
+
+    assert isinstance(guard._validators[2], LowerCase)
+    assert guard._validators[2].on_fail_descriptor == "noop"  # bc this is the default
+
+    assert isinstance(guard._validators[3], TwoWords)
+    assert guard._validators[3].on_fail_descriptor == "reask"  # bc we set it
+
+    # Test with explicitly setting the "on" parameter = "prompt"
+    guard: Guard = Guard().use_many(
+        OneLine(), LowerCase(), TwoWords(on_fail="reask"), on="prompt"
+    )
+
+    prompt_validators = guard.rail.prompt_schema.root_datatype.validators
+    assert len(prompt_validators) == 3
+    assert prompt_validators[0].__class__.__name__ == "OneLine"
+    assert prompt_validators[1].__class__.__name__ == "LowerCase"
+    assert prompt_validators[2].__class__.__name__ == "TwoWords"
+    assert len(guard._validators) == 0  # no output validators, hence 0
+
+    # Test with explicitly setting the "on" parameter = "instructions"
+    guard: Guard = Guard().use_many(
+        OneLine(), LowerCase(), TwoWords(on_fail="reask"), on="instructions"
+    )
+
+    instructions_validators = guard.rail.instructions_schema.root_datatype.validators
+    assert len(instructions_validators) == 3
+    assert instructions_validators[0].__class__.__name__ == "OneLine"
+    assert instructions_validators[1].__class__.__name__ == "LowerCase"
+    assert instructions_validators[2].__class__.__name__ == "TwoWords"
+    assert len(guard._validators) == 0  # no output validators, hence 0
+
+    # Test with explicitly setting the "on" parameter = "msg_history"
+    guard: Guard = Guard().use_many(
+        OneLine(), LowerCase(), TwoWords(on_fail="reask"), on="msg_history"
+    )
+
+    msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators
+    assert len(msg_history_validators) == 3
+    assert msg_history_validators[0].__class__.__name__ == "OneLine"
+    assert msg_history_validators[1].__class__.__name__ == "LowerCase"
+    assert msg_history_validators[2].__class__.__name__ == "TwoWords"
+    assert len(guard._validators) == 0  # no output validators, hence 0
+
+    # Test with an invalid "on" parameter, should raise a ValueError
+    with pytest.raises(ValueError):
+        guard: Guard = Guard().use_many(
+            EndsWith("a", on_fail="exception"), OneLine(), on="response"
+        )
+

 def test_use_many_tuple():
     guard: Guard = Guard().use_many(
@@ -294,12 +405,39 @@ def test_use_many_tuple():
     assert guard._validators[4]._kwargs["max"] == 12
     assert guard._validators[4].on_fail_descriptor == "refrain"  # bc we set it

+    # Test with explicitly setting the "on" parameter
+    guard: Guard = Guard().use_many(
+        (EndsWith, ["a"], {"on_fail": "exception"}),
+        OneLine,
+        on="output",
+    )
+
+    assert len(guard._validators) == 2  # only 2 output validators, hence 2
+
+    assert isinstance(guard._validators[0], EndsWith)
+    assert guard._validators[0]._end == "a"
+    assert guard._validators[0]._kwargs["end"] == "a"
+    assert guard._validators[0].on_fail_descriptor == "exception"  # bc we set it
+
+    assert isinstance(guard._validators[1], OneLine)
+    assert guard._validators[1].on_fail_descriptor == "noop"  # bc this is the default
+
+    # Test with an invalid "on" parameter, should raise a ValueError
+    with pytest.raises(ValueError):
+        guard: Guard = Guard().use_many(
+            (EndsWith, ["a"], {"on_fail": "exception"}),
+            OneLine,
+            on="response",
+        )
+

 def test_validate():
     guard: Guard = (
         Guard()
         .use(OneLine)
-        .use(LowerCase(on_fail="fix"))
+        .use(
+            LowerCase(on_fail="fix"), on="output"
+        )  # default on="output", still explicitly set
         .use(TwoWords)
         .use(ValidLength, 0, 12, on_fail="refrain")
     )
@@ -318,6 +456,87 @@ def test_validate():
     assert response_2.validation_passed is False
     assert response_2.validated_output is None

+    # Test with a combination of prompt, output, instructions and msg_history validators
+    # Should still only use the output validators to validate the output
+    guard: Guard = (
+        Guard()
+        .use(OneLine, on="prompt")
+        .use(LowerCase, on="instructions")
+        .use(UpperCase, on="msg_history")
+        .use(LowerCase, on="output", on_fail="fix")
+        .use(TwoWords, on="output")
+        .use(ValidLength, 0, 12, on="output")
+    )
+
+    llm_output: str = "Oh Canada"  # bc it meets our criteria
+
+    response = guard.validate(llm_output)
+
+    assert response.validation_passed is True
+    assert response.validated_output == llm_output.lower()
+
+    llm_output_2 = "Star Spangled Banner"  # to stick with the theme
+
+    response_2 = guard.validate(llm_output_2)
+
+    assert response_2.validation_passed is False
+    assert response_2.validated_output is None
+
+
+def test_use_and_use_many():
+    guard: Guard = (
+        Guard()
+        .use_many(OneLine(), LowerCase(), on="prompt")
+        .use(UpperCase, on="instructions")
+        .use(LowerCase, on="msg_history")
+        .use_many(
+            TwoWords(on_fail="reask"),
+            ValidLength(0, 12, on_fail="refrain"),
+            on="output",
+        )
+    )
+
+    # Check schemas for prompt, instructions and msg_history validators
+    prompt_validators = guard.rail.prompt_schema.root_datatype.validators
+    assert len(prompt_validators) == 2
+    assert prompt_validators[0].__class__.__name__ == "OneLine"
+    assert prompt_validators[1].__class__.__name__ == "LowerCase"
+
+    instructions_validators = guard.rail.instructions_schema.root_datatype.validators
+    assert len(instructions_validators) == 1
+    assert instructions_validators[0].__class__.__name__ == "UpperCase"
+
+    msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators
+    assert len(msg_history_validators) == 1
+    assert msg_history_validators[0].__class__.__name__ == "LowerCase"
+
+    # Check guard for output validators
+    assert len(guard._validators) == 2  # only 2 output validators, hence 2
+
+    assert isinstance(guard._validators[0], TwoWords)
+    assert guard._validators[0].on_fail_descriptor == "reask"  # bc we set it
+
+    assert isinstance(guard._validators[1], ValidLength)
+    assert guard._validators[1]._min == 0
+    assert guard._validators[1]._kwargs["min"] == 0
+    assert guard._validators[1]._max == 12
+    assert guard._validators[1]._kwargs["max"] == 12
+    assert guard._validators[1].on_fail_descriptor == "refrain"  # bc we set it
+
+    # Test with an invalid "on" parameter, should raise a ValueError
+    with pytest.raises(ValueError):
+        guard: Guard = (
+            Guard()
+            .use_many(OneLine(), LowerCase(), on="prompt")
+            .use(UpperCase, on="instructions")
+            .use(LowerCase, on="msg_history")
+            .use_many(
+                TwoWords(on_fail="reask"),
+                ValidLength(0, 12, on_fail="refrain"),
+                on="response",  # invalid "on" parameter
+            )
+        )
+

 # def test_call():
 #     five_seconds = 5 / 60
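
Taken together, the changes above let validators be attached to the prompt, instructions, or message history as well as the output. Below is a minimal usage sketch of that `on=` targeting, based only on the API exercised in this diff; it is illustrative and not part of the change set, and the validator choices and sample strings are assumptions.

# Illustrative sketch of the new `on=` parameter on Guard.use() / Guard.use_many()
from guardrails import Guard
from guardrails.validators import LowerCase, OneLine, TwoWords, ValidLength

guard = (
    Guard()
    # input validators land on the prompt schema, not on guard._validators
    .use_many(LowerCase(), OneLine(), on="prompt")
    # output validators (the default target) validate the LLM response
    .use(TwoWords, on="output", on_fail="reask")
    .use(ValidLength, 0, 12, on_fail="refrain")
)

# validate() runs only the output validators, as the unit tests above assert
outcome = guard.validate("Oh Canada")
print(outcome.validation_passed, outcome.validated_output)

# Any other target raises a ValueError from __add_validator:
# guard.use(LowerCase, on="response")  # ValueError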