Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 94 additions & 22 deletions guardrails/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
from pydantic import BaseModel
from typing_extensions import deprecated

from guardrails.api_client import GuardrailsApiClient
from guardrails.classes import OT, InputType, ValidationOutcome
Expand Down Expand Up @@ -1065,6 +1066,13 @@ async def _async_parse(

return ValidationOutcome[OT].from_guard_history(call)

@deprecated(
"""The `with_prompt_validation` method is deprecated,
and will be removed in 0.5.x. Instead, please use
`Guard().use(YourValidator, on='prompt')`.""",
category=FutureWarning,
stacklevel=2,
)
def with_prompt_validation(
self,
validators: Sequence[Validator],
Expand All @@ -1082,6 +1090,13 @@ def with_prompt_validation(
self.rail.prompt_schema = schema
return self

@deprecated(
"""The `with_instructions_validation` method is deprecated,
and will be removed in 0.5.x. Instead, please use
`Guard().use(YourValidator, on='instructions')`.""",
category=FutureWarning,
stacklevel=2,
)
def with_instructions_validation(
self,
validators: Sequence[Validator],
Expand All @@ -1099,6 +1114,13 @@ def with_instructions_validation(
self.rail.instructions_schema = schema
return self

@deprecated(
"""The `with_msg_history_validation` method is deprecated,
and will be removed in 0.5.x. Instead, please use
`Guard().use(YourValidator, on='msg_history')`.""",
category=FutureWarning,
stacklevel=2,
)
def with_msg_history_validation(
self,
validators: Sequence[Validator],
Expand All @@ -1116,37 +1138,87 @@ def with_msg_history_validation(
self.rail.msg_history_schema = schema
return self

def __add_validator(self, validator: Validator, on: str = "output"):
# Only available for string output types
if self.rail.output_type != "str":
raise RuntimeError(
"The `use` method is only available for string output types."
)

if on == "prompt":
# If the prompt schema exists, add the validator to it
if self.rail.prompt_schema:
self.rail.prompt_schema.root_datatype.validators.append(validator)
else:
# Otherwise, create a new schema with the validator
schema = StringSchema.from_string(
validators=[validator],
)
self.rail.prompt_schema = schema
elif on == "instructions":
# If the instructions schema exists, add the validator to it
if self.rail.instructions_schema:
self.rail.instructions_schema.root_datatype.validators.append(validator)
else:
# Otherwise, create a new schema with the validator
schema = StringSchema.from_string(
validators=[validator],
)
self.rail.instructions_schema = schema
elif on == "msg_history":
# If the msg_history schema exists, add the validator to it
if self.rail.msg_history_schema:
self.rail.msg_history_schema.root_datatype.validators.append(validator)
else:
# Otherwise, create a new schema with the validator
schema = StringSchema.from_string(
validators=[validator],
)
self.rail.msg_history_schema = schema
elif on == "output":
self._validators.append(validator)
self.rail.output_schema.root_datatype.validators.append(validator)
else:
raise ValueError(
"""Invalid value for `on`. Must be one of the following:
'output', 'prompt', 'instructions', 'msg_history'."""
)

@overload
def use(self, validator: Validator) -> "Guard":
def use(self, validator: Validator, *, on: str = "output") -> "Guard":
...

@overload
def use(self, validator: Type[Validator], *args, **kwargs) -> "Guard":
def use(
self, validator: Type[Validator], *args, on: str = "output", **kwargs
) -> "Guard":
...

def use(
self, validator: Union[Validator, Type[Validator]], *args, **kwargs
self,
validator: Union[Validator, Type[Validator]],
*args,
on: str = "output",
**kwargs,
) -> "Guard":
"""Use a validator to validate results of an LLM request.

*Note*: `use` is only available for string output types.
"""

if self.rail.output_type != "str":
raise RuntimeError(
"The `use` method is only available for string output types."
)
"""Use a validator to validate either of the following:
- The output of an LLM request
- The prompt
- The instructions
- The message history

if validator:
hydrated_validator = get_validator(validator, *args, **kwargs)
self._validators.append(hydrated_validator)

self.rail.output_schema.root_datatype.validators.append(hydrated_validator)
*Note*: For on="output", `use` is only available for string output types.

Args:
validator: The validator to use. Either the class or an instance.
on: The part of the LLM request to validate. Defaults to "output".
"""
hydrated_validator = get_validator(validator, *args, **kwargs)
self.__add_validator(hydrated_validator, on=on)
return self

@overload
def use_many(self, *validators: Validator) -> "Guard":
def use_many(self, *validators: Validator, on: str = "output") -> "Guard":
...

@overload
Expand All @@ -1157,6 +1229,7 @@ def use_many(
Optional[Union[List[Any], Dict[str, Any]]],
Optional[Dict[str, Any]],
],
on: str = "output",
) -> "Guard":
...

Expand All @@ -1170,22 +1243,21 @@ def use_many(
Optional[Dict[str, Any]],
],
],
on: str = "output",
) -> "Guard":
"""Use a validator to validate results of an LLM request.

*Note*: `use_many` is only available for string output types.
"""

if self.rail.output_type != "str":
raise RuntimeError(
"The `use_many` method is only available for string output types."
)

# Loop through the validators
for v in validators:
hydrated_validator = get_validator(v)
self._validators.append(hydrated_validator)
self.rail.output_schema.root_datatype.validators.append(hydrated_validator)

self.__add_validator(hydrated_validator, on=on)
return self

def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]:
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/test_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ def invoke(
model = MockModel()
guard = (
Guard()
.use(RegexMatch("Ice cream", match_type="search"))
.use(RegexMatch("Ice cream", match_type="search"), on="output")
.use(ReadingTime(0.05)) # 3 seconds
)
output_parser = StrOutputParser()
Expand Down
Loading