diff --git a/sqlmypy.py b/sqlmypy.py index eeb4fa1..b4b7220 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -1,15 +1,16 @@ +from mypy.mro import calculate_mro, MroError from mypy.plugin import Plugin, FunctionContext, ClassDefContext from mypy.plugins.common import add_method from mypy.nodes import( NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF, - Argument, Var, ARG_STAR2, MDEF + Argument, Var, ARG_STAR2, MDEF, TupleExpr ) from mypy.types import ( UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType ) from mypy.typevars import fill_typevars_with_any -from typing import Optional, Callable, Dict, TYPE_CHECKING +from typing import Optional, Callable, Dict, TYPE_CHECKING, List if TYPE_CHECKING: from typing_extensions import Final @@ -141,14 +142,30 @@ def decl_info_hook(ctx): Base = declarative_base() """ + cls_instances = [] # type: List[Instance] + + if 'cls' in ctx.call.arg_names: + declarative_base_cls_arg = ctx.call.args[ctx.call.arg_names.index("cls")] + if isinstance(declarative_base_cls_arg, TupleExpr): + cls_instances = [Instance(item.node, []) for item in declarative_base_cls_arg.items] + else: + cls_instances = [Instance(declarative_base_cls_arg.node, [])] + class_def = ClassDef(ctx.name, Block([])) class_def.fullname = ctx.api.qualified_name(ctx.name) info = TypeInfo(SymbolTable(), class_def, ctx.api.cur_mod_id) class_def.info = info obj = ctx.api.builtin_type('builtins.object') - info.mro = [info, obj.type] - info.bases = [obj] + info.bases = cls_instances + [obj] + try: + calculate_mro(info) + except MroError: + ctx.api.errors.report(ctx.get_line(), ctx.get_column(), "Not able to calculate MRO for declarative base", + blocker=False) + info.bases = [obj] + info.fallback_to_any = True + ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) set_declarative(info)