From 7d2e925116d88c8ca8ed637512de8c75e24dbf4d Mon Sep 17 00:00:00 2001 From: Mehdi Date: Sun, 27 Jan 2019 17:43:35 +0100 Subject: [PATCH 1/2] Plugin: handle `declarative_base` `cls` argument --- sqlmypy.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/sqlmypy.py b/sqlmypy.py index eeb4fa1..569ef1d 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -1,15 +1,16 @@ +from mypy.mro import calculate_mro from mypy.plugin import Plugin, FunctionContext, ClassDefContext from mypy.plugins.common import add_method from mypy.nodes import( NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF, - Argument, Var, ARG_STAR2, MDEF + Argument, Var, ARG_STAR2, MDEF, TupleExpr ) from mypy.types import ( UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType ) from mypy.typevars import fill_typevars_with_any -from typing import Optional, Callable, Dict, TYPE_CHECKING +from typing import Optional, Callable, Dict, TYPE_CHECKING, List if TYPE_CHECKING: from typing_extensions import Final @@ -141,14 +142,23 @@ def decl_info_hook(ctx): Base = declarative_base() """ + cls_instances = [] # type: List[Instance] + + if 'cls' in ctx.call.arg_names: + declarative_base_cls_arg = ctx.call.args[ctx.call.arg_names.index("cls")] + if isinstance(declarative_base_cls_arg, TupleExpr): + cls_instances = [Instance(item.node, []) for item in declarative_base_cls_arg.items] + else: + cls_instances = [Instance(declarative_base_cls_arg.node, [])] + class_def = ClassDef(ctx.name, Block([])) class_def.fullname = ctx.api.qualified_name(ctx.name) info = TypeInfo(SymbolTable(), class_def, ctx.api.cur_mod_id) class_def.info = info obj = ctx.api.builtin_type('builtins.object') - info.mro = [info, obj.type] - info.bases = [obj] + info.bases = cls_instances + [obj] + calculate_mro(info) ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) set_declarative(info) From c717f4de549f742409fb4236208e73373d1688f5 Mon Sep 17 00:00:00 2001 From: Mehdi Date: Wed, 30 Jan 2019 18:50:44 +0100 Subject: [PATCH 2/2] catch MRO exception --- sqlmypy.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sqlmypy.py b/sqlmypy.py index 569ef1d..b4b7220 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -1,4 +1,4 @@ -from mypy.mro import calculate_mro +from mypy.mro import calculate_mro, MroError from mypy.plugin import Plugin, FunctionContext, ClassDefContext from mypy.plugins.common import add_method from mypy.nodes import( @@ -158,7 +158,14 @@ def decl_info_hook(ctx): class_def.info = info obj = ctx.api.builtin_type('builtins.object') info.bases = cls_instances + [obj] - calculate_mro(info) + try: + calculate_mro(info) + except MroError: + ctx.api.errors.report(ctx.get_line(), ctx.get_column(), "Not able to calculate MRO for declarative base", + blocker=False) + info.bases = [obj] + info.fallback_to_any = True + ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) set_declarative(info)