diff --git a/guardrails/guard.py b/guardrails/guard.py index 77f044518..c707f8a21 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -987,8 +987,21 @@ def use(self, validator: Type[Validator], *args, **kwargs) -> "Guard": def use( self, validator: Union[Validator, Type[Validator]], *args, **kwargs ) -> "Guard": + """Use a validator to validate results of an LLM request. + + *Note*: `use` is only available for string output types. + """ + + if self.rail.output_type != "str": + raise RuntimeError( + "The `use` method is only available for string output types." + ) + if validator: - self._validators.append(get_validator(validator, *args, **kwargs)) + hydrated_validator = get_validator(validator, *args, **kwargs) + self._validators.append(hydrated_validator) + + self.rail.output_schema.root_datatype.validators.append(hydrated_validator) return self @@ -1018,8 +1031,20 @@ def use_many( ], ], ) -> "Guard": + """Use a validator to validate results of an LLM request. + + *Note*: `use_many` is only available for string output types. + """ + + if self.rail.output_type != "str": + raise RuntimeError( + "The `use_many` method is only available for string output types." + ) + for v in validators: - self._validators.append(get_validator(v)) + hydrated_validator = get_validator(v) + self._validators.append(hydrated_validator) + self.rail.output_schema.root_datatype.validators.append(hydrated_validator) return self diff --git a/tests/unit_tests/test_guard.py b/tests/unit_tests/test_guard.py index 40798cf45..729593e1a 100644 --- a/tests/unit_tests/test_guard.py +++ b/tests/unit_tests/test_guard.py @@ -214,6 +214,15 @@ def test_use(): assert guard._validators[4]._kwargs["max"] == 12 assert guard._validators[4].on_fail_descriptor == "refrain" # bc we set it + # Raises error when trying to `use` a validator on a non-string + with pytest.raises(RuntimeError): + + class TestClass(BaseModel): + another_field: str + + py_guard = Guard.from_pydantic(output_class=TestClass) + py_guard.use(EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail="reask")) + def test_use_many_instances(): guard: Guard = Guard().use_many( @@ -237,6 +246,17 @@ def test_use_many_instances(): assert isinstance(guard._validators[3], TwoWords) assert guard._validators[3].on_fail_descriptor == "reask" # bc we set it + # Raises error when trying to `use_many` a validator on a non-string + with pytest.raises(RuntimeError): + + class TestClass(BaseModel): + another_field: str + + py_guard = Guard.from_pydantic(output_class=TestClass) + py_guard.use_many( + [EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail="reask")] + ) + def test_use_many_tuple(): guard: Guard = Guard().use_many( @@ -278,14 +298,13 @@ def test_use_many_tuple(): def test_validate(): guard: Guard = ( Guard() - .use(EndsWith("a")) .use(OneLine) .use(LowerCase(on_fail="fix")) .use(TwoWords) .use(ValidLength, 0, 12, on_fail="refrain") ) - llm_output = "Oh Canada" # bc it meets our criteria + llm_output: str = "Oh Canada" # bc it meets our criteria response = guard.validate(llm_output)