Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/hub/api_reference_markdown/validators.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ Validates whether the generated code snippet contains any secrets.
```py

guard = Guard.from_string(validators=[
DetectSecrets(on_fail="fix")
DetectSecrets(on_fail=OnFailAction.FIX)
])
guard.parse(
llm_output=code_snippet,
Expand Down
8 changes: 4 additions & 4 deletions docs/llm_api_wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ guard = Guard.from_string(
validators=[
ValidLength(
min=48,
on_fail="fix"
on_fail=OnFailAction.FIX
),
ToxicLanguage(
on_fail="fix"
on_fail=OnFailAction.FIX
)
],
prompt=prompt
Expand Down Expand Up @@ -179,10 +179,10 @@ guard = Guard.from_string(
validators=[
ValidLength(
min=48,
on_fail="fix"
on_fail=OnFailAction.FIX
),
ToxicLanguage(
on_fail="fix"
on_fail=OnFailAction.FIX
)
],
prompt=prompt
Expand Down
6 changes: 4 additions & 2 deletions guardrails/utils/pydantic_utils/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from guardrails.datatypes import String as StringDataType
from guardrails.datatypes import Time as TimeDataType
from guardrails.utils.safe_get import safe_get
from guardrails.validator_base import Validator
from guardrails.validator_base import OnFailAction, Validator
from guardrails.validatorsattr import ValidatorsAttr


Expand Down Expand Up @@ -218,7 +218,9 @@ def process_validators(vals, fld):
)
if "validators" not in fld.field_info.extra:
fld.field_info.extra["validators"] = []
fld.field_info.extra["validators"].append((gd_validator, "reask"))
fld.field_info.extra["validators"].append(
(gd_validator, OnFailAction.REASK)
)

model_fields = {}
for field_name, field in model.__fields__.items():
Expand Down
6 changes: 4 additions & 2 deletions guardrails/utils/pydantic_utils/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from guardrails.datatypes import String as StringDataType
from guardrails.datatypes import Time as TimeDataType
from guardrails.utils.safe_get import safe_get
from guardrails.validator_base import Validator
from guardrails.validator_base import OnFailAction, Validator
from guardrails.validatorsattr import ValidatorsAttr

DataTypeT = TypeVar("DataTypeT", bound=DataType)
Expand Down Expand Up @@ -248,7 +248,9 @@ def process_validators(vals, fld):
)
if "validators" not in fld.field_info.json_schema_extra:
fld.json_schema_extra["validators"] = []
fld.json_schema_extra["validators"].append((gd_validator, "reask"))
fld.json_schema_extra["validators"].append(
(gd_validator, OnFailAction.REASK)
)

model_fields = {}
for field_name, field in model.model_fields.items():
Expand Down
19 changes: 16 additions & 3 deletions guardrails/validator_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
from collections import defaultdict
from copy import deepcopy
from enum import Enum
from string import Template
from typing import (
Any,
Expand Down Expand Up @@ -373,6 +374,16 @@ class FailResult(ValidationResult):
fix_value: Optional[Any] = None


class OnFailAction(str, Enum):
REASK = "reask"
FIX = "fix"
FILTER = "filter"
REFRAIN = "refrain"
NOOP = "noop"
EXCEPTION = "exception"
FIX_REASK = "fix_reask"


@dataclass # type: ignore
class Validator(Runnable):
"""Base class for validators."""
Expand All @@ -384,7 +395,9 @@ class Validator(Runnable):
required_metadata_keys = []
_metadata = {}

def __init__(self, on_fail: Optional[Union[Callable, str]] = None, **kwargs):
def __init__(
self, on_fail: Optional[Union[Callable, OnFailAction]] = None, **kwargs
):
# Raise a warning for deprecated validators

# Get class name and rail_alias
Expand All @@ -411,8 +424,8 @@ def __init__(self, on_fail: Optional[Union[Callable, str]] = None, **kwargs):
)

if on_fail is None:
on_fail = "noop"
if isinstance(on_fail, str):
on_fail = OnFailAction.NOOP
if isinstance(on_fail, OnFailAction):
self.on_fail_descriptor = on_fail
self.on_fail_method = None
else:
Expand Down
25 changes: 15 additions & 10 deletions guardrails/validator_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

from guardrails.classes.history import Iteration
from guardrails.datatypes import FieldValidation
Expand All @@ -17,6 +17,7 @@
from guardrails.validator_base import (
FailResult,
Filter,
OnFailAction,
PassResult,
Refrain,
ValidationResult,
Expand Down Expand Up @@ -59,13 +60,13 @@ def perform_correction(
results: List[FailResult],
value: Any,
validator: Validator,
on_fail_descriptor: str,
on_fail_descriptor: Union[OnFailAction, str],
):
if on_fail_descriptor == "fix":
if on_fail_descriptor == OnFailAction.FIX:
# FIXME: Should we still return fix_value if it is None?
# I think we should warn and return the original value.
return results[0].fix_value
elif on_fail_descriptor == "fix_reask":
elif on_fail_descriptor == OnFailAction.FIX_REASK:
# FIXME: Same thing here
fixed_value = results[0].fix_value
result = self.execute_validator(
Expand All @@ -83,21 +84,21 @@ def perform_correction(
if validator.on_fail_method is None:
raise ValueError("on_fail is 'custom' but on_fail_method is None")
return validator.on_fail_method(value, results)
if on_fail_descriptor == "reask":
if on_fail_descriptor == OnFailAction.REASK:
return FieldReAsk(
incorrect_value=value,
fail_results=results,
)
if on_fail_descriptor == "exception":
if on_fail_descriptor == OnFailAction.EXCEPTION:
raise ValidationError(
"Validation failed for field with errors: "
+ ", ".join([result.error_message for result in results])
)
if on_fail_descriptor == "filter":
if on_fail_descriptor == OnFailAction.FILTER:
return Filter()
if on_fail_descriptor == "refrain":
if on_fail_descriptor == OnFailAction.REFRAIN:
return Refrain()
if on_fail_descriptor == "noop":
if on_fail_descriptor == OnFailAction.NOOP:
return value
else:
raise ValueError(
Expand Down Expand Up @@ -251,7 +252,11 @@ def group_validators(self, validators):
validators, key=lambda v: (v.on_fail_descriptor, v.override_value_on_pass)
)
for (on_fail_descriptor, override_on_pass), group in groups:
if override_on_pass or on_fail_descriptor in ["fix", "fix_reask", "custom"]:
if override_on_pass or on_fail_descriptor in [
OnFailAction.FIX,
OnFailAction.FIX_REASK,
"custom",
]:
for validator in group:
yield on_fail_descriptor, [validator]
else:
Expand Down
2 changes: 1 addition & 1 deletion guardrails/validators/detect_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class DetectSecrets(Validator):
```py

guard = Guard.from_string(validators=[
DetectSecrets(on_fail="fix")
DetectSecrets(on_fail=OnFailAction.FIX)
])
guard.parse(
llm_output=code_snippet,
Expand Down
3 changes: 2 additions & 1 deletion guardrails/validators/ends_with.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from guardrails.logger import logger
from guardrails.validator_base import (
FailResult,
OnFailAction,
PassResult,
ValidationResult,
Validator,
Expand All @@ -26,7 +27,7 @@ class EndsWith(Validator):
end: The required last element.
"""

def __init__(self, end: str, on_fail: str = "fix"):
def __init__(self, end: str, on_fail: OnFailAction = OnFailAction.FIX):
super().__init__(
on_fail=on_fail,
end=end,
Expand Down
4 changes: 2 additions & 2 deletions guardrails/validators/reading_time.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional

from guardrails.logger import logger
from guardrails.validator_base import (
Expand Down Expand Up @@ -28,7 +28,7 @@ class ReadingTime(Validator):
reading_time: The maximum reading time in minutes.
"""

def __init__(self, reading_time: int, on_fail: Optional[str] = None):
def __init__(self, reading_time: int, on_fail: Optional[Callable] = None):
super().__init__(
on_fail=on_fail,
reading_time=reading_time,
Expand Down
4 changes: 2 additions & 2 deletions guardrails/validatorsattr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from guardrails.constants import hub
from guardrails.utils.xml_utils import cast_xml_to_string
from guardrails.validator_base import Validator, ValidatorSpec
from guardrails.validator_base import OnFailAction, Validator, ValidatorSpec


class ValidatorsAttr(pydantic.BaseModel):
Expand Down Expand Up @@ -176,7 +176,7 @@ def from_xml(
key = cast_xml_to_string(key)
if key.startswith("on-fail-"):
on_fail_handler_name = key[len("on-fail-") :]
on_fail_handler = value
on_fail_handler = OnFailAction(value)
on_fail_handlers[on_fail_handler_name] = on_fail_handler

validators, unregistered_validators = cls.get_validators(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@

from pydantic import BaseModel, Field

from guardrails.validator_base import OnFailAction
from guardrails.validators import LowerCase, OneLine, TwoWords


class FeeDetailsFilter(BaseModel):
index: int = Field(validators=("1-indexed", "noop"))
index: int = Field(validators=("1-indexed", OnFailAction.NOOP))
name: str = Field(
validators=[LowerCase(on_fail="filter"), TwoWords(on_fail="filter")]
validators=[
LowerCase(on_fail=OnFailAction.FILTER),
TwoWords(on_fail=OnFailAction.FILTER),
]
)
explanation: str = Field(validators=OneLine(on_fail="filter"))
value: float = Field(validators=("percentage", "noop"))
explanation: str = Field(validators=OneLine(on_fail=OnFailAction.FILTER))
value: float = Field(validators=("percentage", OnFailAction.NOOP))


class ContractDetailsFilter(BaseModel):
Expand All @@ -25,10 +29,15 @@ class ContractDetailsFilter(BaseModel):


class FeeDetailsFix(BaseModel):
index: int = Field(validators=("1-indexed", "noop"))
name: str = Field(validators=[LowerCase(on_fail="fix"), TwoWords(on_fail="fix")])
explanation: str = Field(validators=OneLine(on_fail="fix"))
value: float = Field(validators=("percentage", "noop"))
index: int = Field(validators=("1-indexed", OnFailAction.NOOP))
name: str = Field(
validators=[
LowerCase(on_fail=OnFailAction.FIX),
TwoWords(on_fail=OnFailAction.FIX),
]
)
explanation: str = Field(validators=OneLine(on_fail=OnFailAction.FIX))
value: float = Field(validators=("percentage", OnFailAction.NOOP))


class ContractDetailsFix(BaseModel):
Expand All @@ -42,10 +51,15 @@ class ContractDetailsFix(BaseModel):


class FeeDetailsNoop(BaseModel):
index: int = Field(validators=("1-indexed", "noop"))
name: str = Field(validators=[LowerCase(on_fail="noop"), TwoWords(on_fail="noop")])
explanation: str = Field(validators=OneLine(on_fail="noop"))
value: float = Field(validators=("percentage", "noop"))
index: int = Field(validators=("1-indexed", OnFailAction.NOOP))
name: str = Field(
validators=[
LowerCase(on_fail=OnFailAction.NOOP),
TwoWords(on_fail=OnFailAction.NOOP),
]
)
explanation: str = Field(validators=OneLine(on_fail=OnFailAction.NOOP))
value: float = Field(validators=("percentage", OnFailAction.NOOP))


class ContractDetailsNoop(BaseModel):
Expand All @@ -59,10 +73,15 @@ class ContractDetailsNoop(BaseModel):


class FeeDetailsReask(BaseModel):
index: int = Field(validators=("1-indexed", "noop"))
name: str = Field(validators=[LowerCase(on_fail="noop"), TwoWords(on_fail="reask")])
explanation: str = Field(validators=OneLine(on_fail="noop"))
value: float = Field(validators=("percentage", "noop"))
index: int = Field(validators=("1-indexed", OnFailAction.NOOP))
name: str = Field(
validators=[
LowerCase(on_fail=OnFailAction.NOOP),
TwoWords(on_fail=OnFailAction.REASK),
]
)
explanation: str = Field(validators=OneLine(on_fail=OnFailAction.NOOP))
value: float = Field(validators=("percentage", OnFailAction.NOOP))


class ContractDetailsReask(BaseModel):
Expand All @@ -76,12 +95,15 @@ class ContractDetailsReask(BaseModel):


class FeeDetailsRefrain(BaseModel):
index: int = Field(validators=("1-indexed", "noop"))
index: int = Field(validators=("1-indexed", OnFailAction.NOOP))
name: str = Field(
validators=[LowerCase(on_fail="refrain"), TwoWords(on_fail="refrain")]
validators=[
LowerCase(on_fail=OnFailAction.REFRAIN),
TwoWords(on_fail=OnFailAction.REFRAIN),
]
)
explanation: str = Field(validators=OneLine(on_fail="refrain"))
value: float = Field(validators=("percentage", "noop"))
explanation: str = Field(validators=OneLine(on_fail=OnFailAction.REFRAIN))
value: float = Field(validators=("percentage", OnFailAction.NOOP))


class ContractDetailsRefrain(BaseModel):
Expand Down
Loading