diff --git a/pydantic_xml/mypy.py b/pydantic_xml/mypy.py index d62e56a..5dde656 100644 --- a/pydantic_xml/mypy.py +++ b/pydantic_xml/mypy.py @@ -1,7 +1,7 @@ from typing import Callable, Optional, Tuple, Union from mypy import nodes -from mypy.plugin import ClassDefContext, FunctionContext, Plugin, Type +from mypy.plugin import ClassDefContext, Plugin from pydantic.mypy import PydanticModelTransformer, PydanticPlugin MODEL_METACLASS_FULLNAME = 'pydantic_xml.model.XmlModelMeta' @@ -21,38 +21,6 @@ def get_metaclass_hook(self, fullname: str) -> Optional[Callable[[ClassDefContex return self._pydantic_model_metaclass_marker_callback return super().get_metaclass_hook(fullname) - def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]: - sym = self.lookup_fully_qualified(fullname) - if sym and sym.fullname == ATTR_FULLNAME: - return self._attribute_callback - elif sym and sym.fullname == ELEMENT_FULLNAME: - return self._element_callback - elif sym and sym.fullname == WRAPPED_FULLNAME: - return self._wrapped_callback - - return super().get_function_hook(fullname) - - def _attribute_callback(self, ctx: FunctionContext) -> Type: - return super()._pydantic_field_callback(self._pop_first_args(ctx, 2)) - - def _element_callback(self, ctx: FunctionContext) -> Type: - return super()._pydantic_field_callback(self._pop_first_args(ctx, 4)) - - def _wrapped_callback(self, ctx: FunctionContext) -> Type: - return super()._pydantic_field_callback(self._pop_first_args(ctx, 4)) - - def _pop_first_args(self, ctx: FunctionContext, num: int) -> FunctionContext: - return FunctionContext( - arg_types=ctx.arg_types[num:], - arg_kinds=ctx.arg_kinds[num:], - callee_arg_names=ctx.callee_arg_names[num:], - arg_names=ctx.arg_names[num:], - default_return_type=ctx.default_return_type, - args=ctx.args[num:], - context=ctx.context, - api=ctx.api, - ) - def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> bool: transformer = PydanticXmlModelTransformer(ctx.cls, ctx.reason, ctx.api, self.plugin_config) return transformer.transform() @@ -100,3 +68,43 @@ def get_alias_info(stmt: nodes.AssignmentStmt) -> Tuple[Union[str, None], bool]: return None, True return PydanticModelTransformer.get_alias_info(stmt) + + @staticmethod + def get_strict(stmt: nodes.AssignmentStmt) -> Optional[bool]: + expr = stmt.rvalue + if ( + isinstance(expr, nodes.CallExpr) and + isinstance(expr.callee, nodes.RefExpr) and + expr.callee.fullname in ENTITIES_FULLNAME + ): + for arg, name in zip(expr.args, expr.arg_names): + if name != 'strict': + continue + if isinstance(arg, nodes.NameExpr): + if arg.fullname == 'builtins.True': + return True + elif arg.fullname == 'builtins.False': + return False + return None + + return PydanticModelTransformer.get_strict(stmt) + + @staticmethod + def is_field_frozen(stmt: nodes.AssignmentStmt) -> bool: + expr = stmt.rvalue + if isinstance(expr, nodes.TempNode): + return False + + if not ( + isinstance(expr, nodes.CallExpr) and + isinstance(expr.callee, nodes.RefExpr) and + expr.callee.fullname in ENTITIES_FULLNAME + ): + return False + + for i, arg_name in enumerate(expr.arg_names): + if arg_name == 'frozen': + arg = expr.args[i] + return isinstance(arg, nodes.NameExpr) and arg.fullname == 'builtins.True' + + return PydanticModelTransformer.is_field_frozen(stmt)