From 548afbfb2691e5de0f08d48fe0db1dad1ff15d68 Mon Sep 17 00:00:00 2001 From: Stefan VanBuren Date: Tue, 26 Aug 2025 16:31:41 -0400 Subject: [PATCH] Remove Config class in favor of fail_fast kwargs Brings us full circle, effectively backing out #323. Resolves #362. [1]: https://docs.python.org/3/tutorial/controlflow.html#keyword-only-arguments --- protovalidate/__init__.py | 5 ++--- protovalidate/config.py | 26 -------------------------- protovalidate/internal/rules.py | 14 ++++---------- protovalidate/validator.py | 17 +++++++---------- test/test_config.py | 23 ----------------------- test/test_validate.py | 23 +++++++---------------- 6 files changed, 20 insertions(+), 88 deletions(-) delete mode 100644 protovalidate/config.py delete mode 100644 test/test_config.py diff --git a/protovalidate/__init__.py b/protovalidate/__init__.py index 1c078426..2ce8261b 100644 --- a/protovalidate/__init__.py +++ b/protovalidate/__init__.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from protovalidate import config, validator +from protovalidate import validator -Config = config.Config Validator = validator.Validator CompilationError = validator.CompilationError ValidationError = validator.ValidationError @@ -24,4 +23,4 @@ validate = _default_validator.validate collect_violations = _default_validator.collect_violations -__all__ = ["CompilationError", "Config", "ValidationError", "Validator", "Violations", "collect_violations", "validate"] +__all__ = ["CompilationError", "ValidationError", "Validator", "Violations", "collect_violations", "validate"] diff --git a/protovalidate/config.py b/protovalidate/config.py deleted file mode 100644 index 1e21683b..00000000 --- a/protovalidate/config.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2023-2025 Buf Technologies, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass - - -@dataclass -class Config: - """A class for holding configuration values for validation. - - Attributes: - fail_fast (bool): If true, validation will stop after the first violation. Defaults to False. - """ - - fail_fast: bool = False diff --git a/protovalidate/internal/rules.py b/protovalidate/internal/rules.py index ff2e319a..b4d995a3 100644 --- a/protovalidate/internal/rules.py +++ b/protovalidate/internal/rules.py @@ -22,7 +22,6 @@ from google.protobuf import any_pb2, descriptor, duration_pb2, message, message_factory, timestamp_pb2 from buf.validate import validate_pb2 -from protovalidate.config import Config from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has @@ -266,19 +265,14 @@ def __init__(self, *, field_value: typing.Any = None, rule_value: typing.Any = N class RuleContext: """The state associated with a single rule evaluation.""" - _cfg: Config _violations: list[Violation] - def __init__(self, *, config: Config, violations: typing.Optional[list[Violation]] = None): - self._cfg = config + def __init__(self, *, fail_fast: bool = False, violations: typing.Optional[list[Violation]] = None): + self._fail_fast = fail_fast if violations is None: violations = [] self._violations = violations - @property - def fail_fast(self) -> bool: - return self._cfg.fail_fast - @property def violations(self) -> list[Violation]: return self._violations @@ -299,13 +293,13 @@ def add_rule_path_elements(self, elements: typing.Iterable[validate_pb2.FieldPat @property def done(self) -> bool: - return self.fail_fast and self.has_errors() + return self._fail_fast and self.has_errors() def has_errors(self) -> bool: return len(self._violations) > 0 def sub_context(self) -> "RuleContext": - return RuleContext(config=self._cfg) + return RuleContext(fail_fast=self._fail_fast) class Rules: diff --git a/protovalidate/validator.py b/protovalidate/validator.py index 2215997d..443e08ff 100644 --- a/protovalidate/validator.py +++ b/protovalidate/validator.py @@ -17,7 +17,6 @@ from google.protobuf import message from buf.validate import validate_pb2 -from protovalidate.config import Config from protovalidate.internal import extra_func from protovalidate.internal import rules as _rules @@ -36,29 +35,25 @@ class Validator: """ _factory: _rules.RuleFactory - _cfg: Config - def __init__(self, config: typing.Optional[Config] = None): - self._cfg = config if config is not None else Config() + def __init__(self): funcs = extra_func.make_extra_funcs() self._factory = _rules.RuleFactory(funcs) - def validate( - self, - message: message.Message, - ): + def validate(self, message: message.Message, *, fail_fast: bool = False): """ Validates the given message against the static rules defined in the message's descriptor. Parameters: message: The message to validate. + fail_fast: If true, validation will stop after the first iteration. Raises: CompilationError: If the static rules could not be compiled. ValidationError: If the message is invalid. The violations raised as part of this error should always be equal to the list of violations returned by `collect_violations`. """ - violations = self.collect_violations(message) + violations = self.collect_violations(message, fail_fast=fail_fast) if len(violations) > 0: msg = f"invalid {message.DESCRIPTOR.name}" raise ValidationError(msg, violations) @@ -67,6 +62,7 @@ def collect_violations( self, message: message.Message, *, + fail_fast: bool = False, into: typing.Optional[list[Violation]] = None, ) -> list[Violation]: """ @@ -80,12 +76,13 @@ def collect_violations( Parameters: message: The message to validate. + fail_fast: If true, validation will stop after the first iteration. into: If provided, any violations will be appended to the Violations object and the same object will be returned. Raises: CompilationError: If the static rules could not be compiled. """ - ctx = _rules.RuleContext(config=self._cfg, violations=into) + ctx = _rules.RuleContext(fail_fast=fail_fast, violations=into) for rule in self._factory.get(message.DESCRIPTOR): rule.validate(ctx, message) if ctx.done: diff --git a/test/test_config.py b/test/test_config.py deleted file mode 100644 index 71f33af7..00000000 --- a/test/test_config.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2023-2025 Buf Technologies, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from protovalidate import Config - - -class TestConfig(unittest.TestCase): - def test_defaults(self): - cfg = Config() - self.assertFalse(cfg.fail_fast) diff --git a/test/test_validate.py b/test/test_validate.py index 2233ca6c..b7ea1036 100644 --- a/test/test_validate.py +++ b/test/test_validate.py @@ -18,7 +18,6 @@ import protovalidate from gen.tests.example.v1 import validations_pb2 -from protovalidate.config import Config from protovalidate.internal import rules @@ -27,13 +26,11 @@ def get_default_validator(): This allows testing for validators created via: - module-level singleton - - instantiated class with no config - - instantiated class with config + - instantiated class """ return [ ("module singleton", protovalidate), - ("no config", protovalidate.Validator()), - ("with default config", protovalidate.Validator(Config())), + ("instantiated class", protovalidate.Validator()), ] @@ -42,8 +39,7 @@ class TestCollectViolations(unittest.TestCase): A validator can be created via various ways: - a module-level singleton, which returns a default validator - - instantiating the Validator class with no config, which returns a default validator - - instantiating the Validator class with a config + - instantiating the Validator class In addition, the API for validating a message allows for two approaches: - via a call to `validate`, which will raise a ValidationError if validation fails @@ -188,11 +184,7 @@ def test_concatenated_values(self): self._run_valid_tests(msg) def test_fail_fast(self): - """Test that fail fast correctly fails on first violation - - Note this does not use a default validator, but instead uses one with a custom config - so that fail_fast can be set to True. - """ + """Test that fail fast correctly fails on first violation""" msg = validations_pb2.MultipleValidations() msg.title = "bar" msg.name = "blah" @@ -203,18 +195,17 @@ def test_fail_fast(self): expected_violation.field_value = msg.title expected_violation.rule_value = "foo" - cfg = Config(fail_fast=True) - validator = protovalidate.Validator(config=cfg) + validator = protovalidate.Validator() # Test validate with self.assertRaises(protovalidate.ValidationError) as cm: - validator.validate(msg) + validator.validate(msg, fail_fast=True) e = cm.exception self.assertEqual(str(e), f"invalid {msg.DESCRIPTOR.name}") self._compare_violations(e.violations, [expected_violation]) # Test collect_violations - violations = validator.collect_violations(msg) + violations = validator.collect_violations(msg, fail_fast=True) self._compare_violations(violations, [expected_violation]) def _run_valid_tests(self, msg: message.Message):