Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions protovalidate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from protovalidate import config, validator
from protovalidate import validator

Config = config.Config
Validator = validator.Validator
CompilationError = validator.CompilationError
ValidationError = validator.ValidationError
Expand All @@ -24,4 +23,4 @@
validate = _default_validator.validate
collect_violations = _default_validator.collect_violations

__all__ = ["CompilationError", "Config", "ValidationError", "Validator", "Violations", "collect_violations", "validate"]
__all__ = ["CompilationError", "ValidationError", "Validator", "Violations", "collect_violations", "validate"]
26 changes: 0 additions & 26 deletions protovalidate/config.py

This file was deleted.

14 changes: 4 additions & 10 deletions protovalidate/internal/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from google.protobuf import any_pb2, descriptor, duration_pb2, message, message_factory, timestamp_pb2

from buf.validate import validate_pb2
from protovalidate.config import Config
from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has


Expand Down Expand Up @@ -266,19 +265,14 @@ def __init__(self, *, field_value: typing.Any = None, rule_value: typing.Any = N
class RuleContext:
"""The state associated with a single rule evaluation."""

_cfg: Config
_violations: list[Violation]

def __init__(self, *, config: Config, violations: typing.Optional[list[Violation]] = None):
self._cfg = config
def __init__(self, *, fail_fast: bool = False, violations: typing.Optional[list[Violation]] = None):
self._fail_fast = fail_fast
if violations is None:
violations = []
self._violations = violations

@property
def fail_fast(self) -> bool:
return self._cfg.fail_fast

@property
def violations(self) -> list[Violation]:
return self._violations
Expand All @@ -299,13 +293,13 @@ def add_rule_path_elements(self, elements: typing.Iterable[validate_pb2.FieldPat

@property
def done(self) -> bool:
return self.fail_fast and self.has_errors()
return self._fail_fast and self.has_errors()

def has_errors(self) -> bool:
return len(self._violations) > 0

def sub_context(self) -> "RuleContext":
return RuleContext(config=self._cfg)
return RuleContext(fail_fast=self._fail_fast)


class Rules:
Expand Down
17 changes: 7 additions & 10 deletions protovalidate/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from google.protobuf import message

from buf.validate import validate_pb2
from protovalidate.config import Config
from protovalidate.internal import extra_func
from protovalidate.internal import rules as _rules

Expand All @@ -36,29 +35,25 @@ class Validator:
"""

_factory: _rules.RuleFactory
_cfg: Config

def __init__(self, config: typing.Optional[Config] = None):
self._cfg = config if config is not None else Config()
def __init__(self):
funcs = extra_func.make_extra_funcs()
self._factory = _rules.RuleFactory(funcs)

def validate(
self,
message: message.Message,
):
def validate(self, message: message.Message, *, fail_fast: bool = False):
"""
Validates the given message against the static rules defined in
the message's descriptor.

Parameters:
message: The message to validate.
fail_fast: If true, validation will stop after the first iteration.
Raises:
CompilationError: If the static rules could not be compiled.
ValidationError: If the message is invalid. The violations raised as part of this error should
always be equal to the list of violations returned by `collect_violations`.
"""
violations = self.collect_violations(message)
violations = self.collect_violations(message, fail_fast=fail_fast)
if len(violations) > 0:
msg = f"invalid {message.DESCRIPTOR.name}"
raise ValidationError(msg, violations)
Expand All @@ -67,6 +62,7 @@ def collect_violations(
self,
message: message.Message,
*,
fail_fast: bool = False,
into: typing.Optional[list[Violation]] = None,
) -> list[Violation]:
"""
Expand All @@ -80,12 +76,13 @@ def collect_violations(

Parameters:
message: The message to validate.
fail_fast: If true, validation will stop after the first iteration.
into: If provided, any violations will be appended to the
Violations object and the same object will be returned.
Raises:
CompilationError: If the static rules could not be compiled.
"""
ctx = _rules.RuleContext(config=self._cfg, violations=into)
ctx = _rules.RuleContext(fail_fast=fail_fast, violations=into)
for rule in self._factory.get(message.DESCRIPTOR):
rule.validate(ctx, message)
if ctx.done:
Expand Down
23 changes: 0 additions & 23 deletions test/test_config.py

This file was deleted.

23 changes: 7 additions & 16 deletions test/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import protovalidate
from gen.tests.example.v1 import validations_pb2
from protovalidate.config import Config
from protovalidate.internal import rules


Expand All @@ -27,13 +26,11 @@ def get_default_validator():

This allows testing for validators created via:
- module-level singleton
- instantiated class with no config
- instantiated class with config
- instantiated class
"""
return [
("module singleton", protovalidate),
("no config", protovalidate.Validator()),
("with default config", protovalidate.Validator(Config())),
("instantiated class", protovalidate.Validator()),
]


Expand All @@ -42,8 +39,7 @@ class TestCollectViolations(unittest.TestCase):

A validator can be created via various ways:
- a module-level singleton, which returns a default validator
- instantiating the Validator class with no config, which returns a default validator
- instantiating the Validator class with a config
- instantiating the Validator class

In addition, the API for validating a message allows for two approaches:
- via a call to `validate`, which will raise a ValidationError if validation fails
Expand Down Expand Up @@ -188,11 +184,7 @@ def test_concatenated_values(self):
self._run_valid_tests(msg)

def test_fail_fast(self):
"""Test that fail fast correctly fails on first violation

Note this does not use a default validator, but instead uses one with a custom config
so that fail_fast can be set to True.
"""
"""Test that fail fast correctly fails on first violation"""
msg = validations_pb2.MultipleValidations()
msg.title = "bar"
msg.name = "blah"
Expand All @@ -203,18 +195,17 @@ def test_fail_fast(self):
expected_violation.field_value = msg.title
expected_violation.rule_value = "foo"

cfg = Config(fail_fast=True)
validator = protovalidate.Validator(config=cfg)
validator = protovalidate.Validator()

# Test validate
with self.assertRaises(protovalidate.ValidationError) as cm:
validator.validate(msg)
validator.validate(msg, fail_fast=True)
e = cm.exception
self.assertEqual(str(e), f"invalid {msg.DESCRIPTOR.name}")
self._compare_violations(e.violations, [expected_violation])

# Test collect_violations
violations = validator.collect_violations(msg)
violations = validator.collect_violations(msg, fail_fast=True)
self._compare_violations(violations, [expected_violation])

def _run_valid_tests(self, msg: message.Message):
Expand Down