diff --git a/docs/05-dataclasses.md b/docs/05-dataclasses.md index 52a352e..a9d191f 100644 --- a/docs/05-dataclasses.md +++ b/docs/05-dataclasses.md @@ -793,12 +793,11 @@ the specified item validator). The `DataclassValidator` supports these context arguments and uses them in two ways: First, it passes them as they are to any field validator (which might pass them to other validators as well). Second, it also passes them to the -`__post_validate__()` method of the dataclass. +`__post_validate__()` method of the dataclass as long as the method accepts them. -However, for this to work, the method MUST accept arbitrary keyword arguments, i.e. it needs to be declared with a -`**kwargs` parameter (the parameter name doesn't matter). You can of course declare specific keyword arguments that you -want to use for post-validation (make sure to define them as optional!), but you still need to accept any other keyword -argument as well, otherwise the context arguments will not be passed to the method at all. +You can define the `__post_validate__()` method with specific keyword-only arguments and/or with a `**kwargs` parameter. +The `DataclassValidator` will make sure to only pass the arguments that the method accepts. Please make sure to use +**keyword-only** arguments instead of positional arguments. The latter will still work, but emit a warning. Example: @@ -814,8 +813,8 @@ class ContextSensitiveExampleClass: # This field is optional, unless the context says otherwise. some_value: Optional[int] = IntegerValidator(), Default(None) - # Note: Prefix the kwargs parameter with an underscore to avoid "unused parameter" notices. - def __post_validate__(self, *, require_some_value: bool = False, **_kwargs): + # Note: You can also specify **kwargs here to get all context arguments. + def __post_validate__(self, *, require_some_value: bool = False): # If require_some_value was set at validation time, ensure that some_value is set! if require_some_value and self.some_value is None: raise DataclassPostValidationError(field_errors={ diff --git a/src/validataclass/validators/dataclass_validator.py b/src/validataclass/validators/dataclass_validator.py index 4f32449..bd9d6fb 100644 --- a/src/validataclass/validators/dataclass_validator.py +++ b/src/validataclass/validators/dataclass_validator.py @@ -6,6 +6,7 @@ import dataclasses import inspect +import warnings from typing import Any, Dict, Generic, Optional, Type, TypeVar from validataclass.dataclasses import Default, NoDefault @@ -67,7 +68,7 @@ class ExampleDataclass: is part of regular dataclasses and thus also works without validataclass) or using a `__post_validate__()` method (which is called by the DataclassValidator after creating the object). The latter also supports *context-sensitive* validation, which means you can pass extra arguments to the `validate()` call that will be passed both to all field - validators and to the `__post_validate__()` method (as long as it is defined with a `**kwargs` argument). + validators and to the `__post_validate__()` method (as long as it is defined to accept the keyword arguments). In post-validation you can either raise regular `ValidationError` exceptions, which will be automatically wrapped inside a `DataclassPostValidationError` exception, or raise such an exception directly (in which case you can @@ -80,10 +81,7 @@ class ExampleDataclass: class ExampleDataclass: optional_field: str = StringValidator(), Default('') - # Note: The method MUST accept arbitrary keyword arguments (**kwargs), not just the parameter you defined, - # otherwise no context arguments will be passed to it at all. To avoid "unused parameter" notices, you can - # prepend the variable name with an underscore. - def __post_validate__(self, *, require_optional_field: bool = False, **_kwargs): + def __post_validate__(self, *, require_optional_field: bool = False): if require_optional_field and not self.optional_field: raise DataclassPostValidationError(field_errors={ 'value': RequiredValueError(reason='The optional field is required for some reason.'), @@ -216,10 +214,26 @@ def _post_validate(validated_object: T_Dataclass, **kwargs) -> T_Dataclass: """ # Post validation using the custom __post_validate__() method in the dataclass (if defined) if hasattr(validated_object, '__post_validate__'): - # Only pass context arguments if __post_validate__() accepts them - if inspect.getfullargspec(validated_object.__post_validate__).varkw is not None: - validated_object.__post_validate__(**kwargs) + post_validate_spec = inspect.getfullargspec(validated_object.__post_validate__) + + # Warn about __post_validate__() with positional arguments (ignoring "self") + if len(post_validate_spec.args) > 1 or post_validate_spec.varargs: + warnings.warn( + f'{validated_object.__class__.__name__}.__post_validate__() is defined with positional arguments. ' + 'This should still work, but it is recommended to use keyword-only arguments instead.' + ) + + # If __post_validate__() accepts arbitrary keyword arguments (**kwargs), we can just pass all keyword + # arguments to the function. Otherwise we need to filter out all keys that are not accepted as keyword + # arguments by the function. + if post_validate_spec.varkw is not None: + context_kwargs = kwargs else: - validated_object.__post_validate__() + context_kwargs = { + key: value for key, value in kwargs.items() + if key in post_validate_spec.kwonlyargs + post_validate_spec.args + } + + validated_object.__post_validate__(**context_kwargs) return validated_object diff --git a/tests/validators/dataclass_validator_test.py b/tests/validators/dataclass_validator_test.py index bf30d6e..fde5aed 100644 --- a/tests/validators/dataclass_validator_test.py +++ b/tests/validators/dataclass_validator_test.py @@ -94,13 +94,51 @@ class UnitTestContextSensitiveDataclass: name: str = UnitTestContextValidator() value: Optional[int] = (IntegerValidator(), Default(None)) - def __post_validate__(self, *, value_required: bool = False, **_kwargs): + def __post_validate__(self, *, value_required: bool = False): if value_required and self.value is None: raise DataclassPostValidationError(field_errors={ 'value': RequiredValueError(reason='Value is required in this context.'), }) +@validataclass +class UnitTestContextSensitiveDataclassWithPosArgs(UnitTestContextSensitiveDataclass): + """ + Dataclass with a __post_validate__() method that takes *positional* arguments. This should work, but emit a warning. + """ + + # Same as UnitTestContextSensitiveDataclass, but with positional arguments + def __post_validate__(self, value_required: bool = False): + super().__post_validate__(value_required=value_required) + + +# Regex-escaped warning text emitted when using __post_validate__ of the dataclass above +POST_VALIDATE_POS_ARGS_WARNING = \ + r'UnitTestContextSensitiveDataclassWithPosArgs\.__post_validate__\(\) is defined with positional arguments' + + +@validataclass +class UnitTestContextSensitiveDataclassWithVarKwargs: + """ + Dataclass with a __post_validate__() method that takes fixed *and* variable keyword arguments (`**kwargs`). + + This class only has one validated field "name". Additionally it takes two context parameters "ctx_a" and "ctx_b", as + well as arbitrary keyword arguments, which will be written into the attributes "ctx_a", "ctx_b" and "extra_kwargs" + respectively. + """ + name: str = UnitTestContextValidator() + + # These are no validated fields, just attributes that are populated by __post_validate__ + ctx_a = None + ctx_b = None + extra_kwargs = None + + def __post_validate__(self, *, ctx_a: str = '', ctx_b: str = '', **kwargs): + self.ctx_a = ctx_a + self.ctx_b = ctx_b + self.extra_kwargs = kwargs + + class DataclassValidatorTest: # Tests for DataclassValidator with a simple dataclass @@ -451,6 +489,71 @@ def test_dataclass_with_context_sensitive_post_validate_invalid(): }, } + @staticmethod + def test_dataclass_with_context_sensitive_post_validate_with_pos_args(): + """ Validate dataclass with a __post_validate__() method that accepts positional arguments. """ + validator = DataclassValidator(UnitTestContextSensitiveDataclassWithPosArgs) + + with pytest.warns(UserWarning, match=POST_VALIDATE_POS_ARGS_WARNING): + validated_data = validator.validate({'name': 'banana', 'value': 13}, value_required=True, foo=42) + + assert validated_data.name == "banana / {'value_required': True, 'foo': 42}" + assert validated_data.value == 13 + + @staticmethod + def test_dataclass_with_context_sensitive_post_validate_with_pos_args_invalid(): + """ Validate dataclass with a __post_validate__() method that accepts positional arguments, with invalid input. """ + validator = DataclassValidator(UnitTestContextSensitiveDataclassWithPosArgs) + + with pytest.raises(DataclassPostValidationError): + with pytest.warns(UserWarning, match=POST_VALIDATE_POS_ARGS_WARNING): + validator.validate({'name': 'banana'}, value_required=True) + + @staticmethod + @pytest.mark.parametrize( + 'validate_kwargs, expected_ctx_a, expected_ctx_b, expected_extra_kwargs', + [ + # No context arguments + ({}, '', '', {}), + + # Only context parameters defined as keyword arguments in __post_validate__ (ctx_a, ctx_b) + ({'ctx_a': 'foo'}, 'foo', '', {}), + ({'ctx_b': 'bar'}, '', 'bar', {}), + ({'ctx_b': 'bar', 'ctx_a': 'foo'}, 'foo', 'bar', {}), + + # Arbitrary context arguments not defined as keyword arguments in __post_validate__ + ( + {'some_value': 42}, + '', + '', + {'some_value': 42}, + ), + ( + {'ctx_a': 'foo', 'some_value': 42}, + 'foo', + '', + {'some_value': 42}, + ), + ( + {'any_value': 3, 'ctx_a': 'foo', 'some_value': 42, 'ctx_b': 'bar'}, + 'foo', + 'bar', + {'any_value': 3, 'some_value': 42}, + ), + ] + ) + def test_dataclass_with_context_sensitive_post_validate_with_var_kwargs( + validate_kwargs, expected_ctx_a, expected_ctx_b, expected_extra_kwargs, + ): + """ Validate dataclass with a context-sensitive __post_validate__() method that accepts arbitrary keyword arguments. """ + validator = DataclassValidator(UnitTestContextSensitiveDataclassWithVarKwargs) + validated_data = validator.validate({'name': 'unit-test'}, **validate_kwargs) + + assert validated_data.name == f"unit-test / {validate_kwargs}" + assert validated_data.ctx_a == expected_ctx_a + assert validated_data.ctx_b == expected_ctx_b + assert validated_data.extra_kwargs == expected_extra_kwargs + # Test invalid validator options @staticmethod