From 20c0c6d9219e29a95f5e1cadeb05ec39d8c15c24 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Mon, 4 Dec 2023 12:37:21 -0600 Subject: [PATCH] Fix `mypy` error on free before validator (classmethod) (#8285) --- pydantic/functional_validators.py | 38 ++++++++++++++++--- tests/mypy/modules/success.py | 13 +++++++ .../outputs/1.0.1/mypy-default_ini/success.py | 13 +++++++ .../1.0.1/pyproject-default_toml/success.py | 13 +++++++ 4 files changed, 72 insertions(+), 5 deletions(-) diff --git a/pydantic/functional_validators.py b/pydantic/functional_validators.py index 26ff76d4cb..f5654967e6 100644 --- a/pydantic/functional_validators.py +++ b/pydantic/functional_validators.py @@ -417,6 +417,21 @@ def __call__( # noqa: D102 ... +class FreeModelBeforeValidatorWithoutInfo(Protocol): + """A @model_validator decorated function signature. + This is used when `mode='before'` and the function does not have info argument. + """ + + def __call__( # noqa: D102 + self, + # this can be a dict, a model instance + # or anything else that gets passed to validate_python + # thus validators _must_ handle all cases + __value: Any, + ) -> Any: + ... + + class ModelBeforeValidatorWithoutInfo(Protocol): """A @model_validator decorated function signature. This is used when `mode='before'` and the function does not have info argument. @@ -433,6 +448,20 @@ def __call__( # noqa: D102 ... +class FreeModelBeforeValidator(Protocol): + """A `@model_validator` decorated function signature. This is used when `mode='before'`.""" + + def __call__( # noqa: D102 + self, + # this can be a dict, a model instance + # or anything else that gets passed to validate_python + # thus validators _must_ handle all cases + __value: Any, + __info: _core_schema.ValidationInfo, + ) -> Any: + ... + + class ModelBeforeValidator(Protocol): """A `@model_validator` decorated function signature. This is used when `mode='before'`.""" @@ -457,7 +486,9 @@ def __call__( # noqa: D102 """A `@model_validator` decorated function signature. This is used when `mode='after'`.""" _AnyModelWrapValidator = Union[ModelWrapValidator[_ModelType], ModelWrapValidatorWithoutInfo[_ModelType]] -_AnyModeBeforeValidator = Union[ModelBeforeValidator, ModelBeforeValidatorWithoutInfo] +_AnyModeBeforeValidator = Union[ + FreeModelBeforeValidator, ModelBeforeValidator, FreeModelBeforeValidatorWithoutInfo, ModelBeforeValidatorWithoutInfo +] _AnyModelAfterValidator = Union[ModelAfterValidator[_ModelType], ModelAfterValidatorWithoutInfo[_ModelType]] @@ -499,8 +530,6 @@ def model_validator( Example usage: ```py - from typing import Optional - from typing_extensions import Self from pydantic import BaseModel, ValidationError, model_validator @@ -525,8 +554,7 @@ def verify_square(self) -> Self: print(e) ''' 1 validation error for Square - __root__ - width and height do not match (type=value_error) + Value error, width and height do not match [type=value_error, input_value={'width': 1, 'height': 2}, input_type=dict] ''' ``` diff --git a/tests/mypy/modules/success.py b/tests/mypy/modules/success.py index d68f2f2256..4f0d6db99e 100644 --- a/tests/mypy/modules/success.py +++ b/tests/mypy/modules/success.py @@ -42,6 +42,7 @@ WrapValidator, create_model, field_validator, + model_validator, root_validator, validate_call, ) @@ -308,3 +309,15 @@ class Abstract(BaseModel): class Concrete(Abstract): class_id = 1 + + +def two_dim_shape_validator(v: Dict[str, Any]) -> Dict[str, Any]: + assert 'volume' not in v, 'shape is 2d, cannot have volume' + return v + + +class Square(BaseModel): + width: float + height: float + + free_validator = model_validator(mode='before')(two_dim_shape_validator) diff --git a/tests/mypy/outputs/1.0.1/mypy-default_ini/success.py b/tests/mypy/outputs/1.0.1/mypy-default_ini/success.py index ceda32fe49..9406c543df 100644 --- a/tests/mypy/outputs/1.0.1/mypy-default_ini/success.py +++ b/tests/mypy/outputs/1.0.1/mypy-default_ini/success.py @@ -42,6 +42,7 @@ WrapValidator, create_model, field_validator, + model_validator, root_validator, validate_call, ) @@ -314,3 +315,15 @@ class Abstract(BaseModel): class Concrete(Abstract): class_id = 1 + + +def two_dim_shape_validator(v: Dict[str, Any]) -> Dict[str, Any]: + assert 'volume' not in v, 'shape is 2d, cannot have volume' + return v + + +class Square(BaseModel): + width: float + height: float + + free_validator = model_validator(mode='before')(two_dim_shape_validator) diff --git a/tests/mypy/outputs/1.0.1/pyproject-default_toml/success.py b/tests/mypy/outputs/1.0.1/pyproject-default_toml/success.py index ceda32fe49..9406c543df 100644 --- a/tests/mypy/outputs/1.0.1/pyproject-default_toml/success.py +++ b/tests/mypy/outputs/1.0.1/pyproject-default_toml/success.py @@ -42,6 +42,7 @@ WrapValidator, create_model, field_validator, + model_validator, root_validator, validate_call, ) @@ -314,3 +315,15 @@ class Abstract(BaseModel): class Concrete(Abstract): class_id = 1 + + +def two_dim_shape_validator(v: Dict[str, Any]) -> Dict[str, Any]: + assert 'volume' not in v, 'shape is 2d, cannot have volume' + return v + + +class Square(BaseModel): + width: float + height: float + + free_validator = model_validator(mode='before')(two_dim_shape_validator)