From e4b002d7f4cb6e118cc527016f15f25250ddbd56 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 8 Nov 2018 20:16:33 -0800 Subject: [PATCH 01/16] Some stuff --- setup.py | 4 +-- sqlalchemy-plugin/sqlmypy.py | 0 sqlalchemy-stubs/sql/schema.pyi | 8 +++--- sqlmypy.py | 26 +++++++++++++++++++ .../sqltyping.py => sqltyping.py | 0 test/sqlalchemy.ini | 3 +++ test/testsql.py | 2 ++ 7 files changed, 37 insertions(+), 6 deletions(-) delete mode 100644 sqlalchemy-plugin/sqlmypy.py create mode 100644 sqlmypy.py rename sqlalchemy-typing/sqltyping.py => sqltyping.py (100%) create mode 100644 test/sqlalchemy.ini diff --git a/setup.py b/setup.py index 81fc445..2dd7e74 100644 --- a/setup.py +++ b/setup.py @@ -22,10 +22,10 @@ def find_stub_files(): author='Ivan Levkivskyi', author_email='levkivskyi@gmail.com', license='MIT License', - py_modules=[], + py_modules=['sqlmypy'], install_requires=[ 'typing-extensions>=3.6.5' ], packages=['sqlalchemy-stubs'], - package_data={'sqlalchemy-stubs': find_stub_files()} + package_data={'sqlalchemy-stubs': find_stub_files()}, ) diff --git a/sqlalchemy-plugin/sqlmypy.py b/sqlalchemy-plugin/sqlmypy.py deleted file mode 100644 index e69de29..0000000 diff --git a/sqlalchemy-stubs/sql/schema.pyi b/sqlalchemy-stubs/sql/schema.pyi index 8eaa91b..1205bcf 100644 --- a/sqlalchemy-stubs/sql/schema.pyi +++ b/sqlalchemy-stubs/sql/schema.pyi @@ -67,13 +67,13 @@ class Column(SchemaItem, ColumnClause[_T]): foreign_keys: Any = ... info: Any = ... @overload - def __init__(self, name: str, type_: Type[TypeEngine[_T]], *args, **kwargs) -> None: ... + def __init__(self, name: str, type_: Type[TypeEngine[_T]], *args, primary_key: Any = ..., nullable: bool = ..., **kwargs) -> None: ... @overload - def __init__(self, type_: Type[TypeEngine[_T]], *args, **kwargs) -> None: ... + def __init__(self, type_: Type[TypeEngine[_T]], *args, primary_key: Any = ..., nullable: bool = ..., **kwargs) -> None: ... @overload - def __init__(self, name: str, type_: TypeEngine[_T], *args, **kwargs) -> None: ... + def __init__(self, name: str, type_: TypeEngine[_T], *args, primary_key: Any = ..., nullable: bool = ..., **kwargs) -> None: ... @overload - def __init__(self, type_: TypeEngine[_T], *args, **kwargs) -> None: ... + def __init__(self, type_: TypeEngine[_T], *args, primary_key: Any = ..., nullable: bool = ..., **kwargs) -> None: ... @overload def __init__(self, *args, **kwargs) -> None: ... def references(self, column): ... diff --git a/sqlmypy.py b/sqlmypy.py new file mode 100644 index 0000000..afd0c8e --- /dev/null +++ b/sqlmypy.py @@ -0,0 +1,26 @@ +from mypy.plugin import Plugin +from mypy.nodes import NameExpr +from mypy.types import UnionType, NoneTyp, Instance + + +class BasicSQLAlchemyPlugin(Plugin): + def get_function_hook(self, fullname): + if fullname == 'sqlalchemy.sql.schema.Column': + return column_hook + return None + + +def column_hook(ctx): + assert isinstance(ctx.default_return_type, Instance) + last_arg_exprs = ctx.args[-1] + if nullable: + return ctx.default_return_type + assert len(ctx.default_return_type.args) == 1 + arg_type = ctx.default_return_type.args[0] + return Instance(ctx.default_return_type.type, [UnionType([arg_type, NoneTyp()])], + line=ctx.default_return_type.line, + column=ctx.default_return_type.column) + + +def plugin(version): + return BasicSQLAlchemyPlugin diff --git a/sqlalchemy-typing/sqltyping.py b/sqltyping.py similarity index 100% rename from sqlalchemy-typing/sqltyping.py rename to sqltyping.py diff --git a/test/sqlalchemy.ini b/test/sqlalchemy.ini new file mode 100644 index 0000000..9f1b8e1 --- /dev/null +++ b/test/sqlalchemy.ini @@ -0,0 +1,3 @@ +[mypy] +plugins = sqlmypy + diff --git a/test/testsql.py b/test/testsql.py index a162fa9..4e60cc6 100644 --- a/test/testsql.py +++ b/test/testsql.py @@ -14,6 +14,7 @@ this_file_dir = os.path.dirname(os.path.realpath(__file__)) prefix = os.path.dirname(this_file_dir) +inipath = os.path.abspath(os.path.join(prefix, 'test')) # Locations of test data files such as test case descriptions (.test). test_data_prefix = os.path.join(prefix, 'test', 'test-data') @@ -31,6 +32,7 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: mypy_cmdline = [ '--show-traceback', '--no-silence-site-packages', + '--config-file={}/sqlalchemy.ini'.format(inipath), ] py2 = testcase.name.lower().endswith('python2') if py2: From 3f5b33f91b1cc7b698476b8af13b2286d64c2fb4 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 14 Nov 2018 18:52:56 -0800 Subject: [PATCH 02/16] Add basic Column plugin; tests TBD --- setup.py | 2 +- sqlmypy.py | 44 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index 2dd7e74..ae95c7f 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ def find_stub_files(): author='Ivan Levkivskyi', author_email='levkivskyi@gmail.com', license='MIT License', - py_modules=['sqlmypy'], + py_modules=['sqlmypy', 'sqltyping'], install_requires=[ 'typing-extensions>=3.6.5' ], diff --git a/sqlmypy.py b/sqlmypy.py index afd0c8e..95bbe8c 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -1,19 +1,41 @@ from mypy.plugin import Plugin -from mypy.nodes import NameExpr -from mypy.types import UnionType, NoneTyp, Instance +from mypy.nodes import NameExpr, Expression +from mypy.types import UnionType, NoneTyp, Instance, Type + +from typing import Optional, TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from mypy.plugin import FunctionContext class BasicSQLAlchemyPlugin(Plugin): - def get_function_hook(self, fullname): + def get_function_hook(self, fullname: str) -> Optional[Callable[['FunctionContext'], Type]]: if fullname == 'sqlalchemy.sql.schema.Column': return column_hook return None -def column_hook(ctx): +def column_hook(ctx: 'FunctionContext') -> Type: assert isinstance(ctx.default_return_type, Instance) - last_arg_exprs = ctx.args[-1] - if nullable: + # This is very fragile, need to update the plugin API. + if len(ctx.args) in (5, 6): # overloads with and without the name + nullable_index = len(ctx.args) - 2 + primary_index = len(ctx.args) - 3 + else: + # Something new, give up. + return ctx.default_return_type + + nullable_args = ctx.args[nullable_index] + primary_args = ctx.args[primary_index] + if nullable_args: + nullable = parse_bool(nullable_args[0]) + else: + if primary_args: + nullable = not parse_bool(primary_args[0]) + else: + nullable = True + + if not nullable: return ctx.default_return_type assert len(ctx.default_return_type.args) == 1 arg_type = ctx.default_return_type.args[0] @@ -22,5 +44,15 @@ def column_hook(ctx): column=ctx.default_return_type.column) +# We really need to add this to TypeChecker API +def parse_bool(expr: Expression) -> Optional[bool]: + if isinstance(expr, NameExpr): + if expr.fullname == 'builtins.True': + return True + if expr.fullname == 'builtins.False': + return False + return None + + def plugin(version): return BasicSQLAlchemyPlugin From 7e993f152f2f50843f216cfc70ec75dbbbc14973 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 14 Nov 2018 19:57:49 -0800 Subject: [PATCH 03/16] Add basic plugin for relationship; tests TBD --- sqlalchemy-stubs/orm/relationships.pyi | 21 +++++++++++++++-- sqlmypy.py | 32 ++++++++++++++++++++++++-- test/test-data/sqlalchemy-basics.test | 28 ++++++++++++++-------- 3 files changed, 67 insertions(+), 14 deletions(-) diff --git a/sqlalchemy-stubs/orm/relationships.pyi b/sqlalchemy-stubs/orm/relationships.pyi index c656294..ef814a2 100644 --- a/sqlalchemy-stubs/orm/relationships.pyi +++ b/sqlalchemy-stubs/orm/relationships.pyi @@ -1,4 +1,4 @@ -from typing import Any, Optional, Generic, TypeVar, Union, overload +from typing import Any, Optional, Generic, TypeVar, Union, overload, Type from .interfaces import ( MANYTOMANY as MANYTOMANY, MANYTOONE as MANYTOONE, @@ -51,7 +51,24 @@ class RelationshipProperty(StrategizedProperty, Generic[_T]): order_by: Any = ... back_populates: Any = ... backref: Any = ... - def __init__(self, argument, secondary: Optional[Any] = ..., + @overload + def __init__(self, argument: Type[_T], secondary: Optional[Any] = ..., + primaryjoin: Optional[Any] = ..., secondaryjoin: Optional[Any] = ..., + foreign_keys: Optional[Any] = ..., uselist: Optional[Any] = ..., + order_by: Any = ..., backref: Optional[Any] = ..., + back_populates: Optional[Any] = ..., post_update: bool = ..., cascade: Union[str, bool] = ..., + extension: Optional[Any] = ..., viewonly: bool = ..., + lazy: Optional[Union[str, bool]] = ..., collection_class: Optional[Any] = ..., + passive_deletes: bool = ..., passive_updates: bool = ..., + remote_side: Optional[Any] = ..., enable_typechecks: bool = ..., + join_depth: Optional[Any] = ..., comparator_factory: Optional[Any] = ..., + single_parent: bool = ..., innerjoin: bool = ..., distinct_target_key: Optional[Any] = ..., + doc: Optional[Any] = ..., active_history: bool = ..., cascade_backrefs: bool = ..., + load_on_pending: bool = ..., bake_queries: bool = ..., + _local_remote_pairs: Optional[Any] = ..., query_class: Optional[Any] = ..., + info: Optional[Any] = ...) -> None: ... + @overload + def __init__(self, argument: Any, secondary: Optional[Any] = ..., primaryjoin: Optional[Any] = ..., secondaryjoin: Optional[Any] = ..., foreign_keys: Optional[Any] = ..., uselist: Optional[Any] = ..., order_by: Any = ..., backref: Optional[Any] = ..., diff --git a/sqlmypy.py b/sqlmypy.py index 95bbe8c..9d07283 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -1,6 +1,7 @@ from mypy.plugin import Plugin -from mypy.nodes import NameExpr, Expression -from mypy.types import UnionType, NoneTyp, Instance, Type +from mypy.nodes import NameExpr, Expression, StrExpr, TypeInfo +from mypy.types import UnionType, NoneTyp, Instance, Type, CallableType, AnyType, TypeOfAny +from mypy.erasetype import erase_typevars from typing import Optional, TYPE_CHECKING, Callable @@ -12,9 +13,35 @@ class BasicSQLAlchemyPlugin(Plugin): def get_function_hook(self, fullname: str) -> Optional[Callable[['FunctionContext'], Type]]: if fullname == 'sqlalchemy.sql.schema.Column': return column_hook + if fullname == 'sqlalchemy.orm.relationships.RelationshipProperty': + return relationship_hook return None +def relationship_hook(ctx: 'FunctionContext') -> Type: + assert isinstance(ctx.default_return_type, Instance) + arg_type = ctx.arg_types[0][0] + arg = ctx.args[0][0] + if isinstance(arg_type, CallableType) and arg_type.is_type_obj(): + return Instance(ctx.default_return_type.type, [erase_typevars(arg_type.ret_type)], + line=ctx.default_return_type.line, + column=ctx.default_return_type.column) + elif isinstance(arg, StrExpr): + name = arg.value + # Private API, but probably needs to be public. + try: + sym = ctx.api.lookup_qualified(name) + except (KeyError, AssertionError): + return ctx.default_return_type + if sym and isinstance(sym.node, TypeInfo): + any = AnyType(TypeOfAny.special_form) + new_arg = Instance(sym.node, [any] * len(sym.node.defn.type_vars)) + return Instance(ctx.default_return_type.type, [new_arg], + line=ctx.default_return_type.line, + column=ctx.default_return_type.column) + return ctx.default_return_type + + def column_hook(ctx: 'FunctionContext') -> Type: assert isinstance(ctx.default_return_type, Instance) # This is very fragile, need to update the plugin API. @@ -34,6 +61,7 @@ def column_hook(ctx: 'FunctionContext') -> Type: nullable = not parse_bool(primary_args[0]) else: nullable = True + # TODO: Add support for literal types when they will be available. if not nullable: return ctx.default_return_type diff --git a/test/test-data/sqlalchemy-basics.test b/test/test-data/sqlalchemy-basics.test index 6df3e9a..0e25bd0 100644 --- a/test/test-data/sqlalchemy-basics.test +++ b/test/test-data/sqlalchemy-basics.test @@ -35,30 +35,38 @@ reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[buil [out] [case testColumnFieldsRelationship] -from typing import Any +from typing import Any, TYPE_CHECKING -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import Column, Integer, String from sqlalchemy.orm import relationship -from sqlalchemy.orm.properties import RelationshipProperty - -Base: Any = declarative_base() +from base import Base +if TYPE_CHECKING: + from other import Other class User(Base): __tablename__ = 'users' id = Column(Integer(), primary_key=True) name = Column(String()) - other: RelationshipProperty[Other] = relationship('Other') + other = relationship('Other') + +user: User +reveal_type(user.other) # E: Revealed type is 'main.Other*' +reveal_type(User.other) # E: Revealed type is 'sqlalchemy.orm.relationships.RelationshipProperty[main.Other*]' +reveal_type(user.other.name) # E: Revealed type is 'builtins.str*' + +[file other.py] +from sqlalchemy import Column, Integer, String +from base import Base class Other(Base): __tablename__ = 'other' id = Column(Integer(), primary_key=True) name = Column(String()) -user: User -reveal_type(user.other) # E: Revealed type is 'main.Other*' -reveal_type(User.other) # E: Revealed type is 'sqlalchemy.orm.relationships.RelationshipProperty[main.Other*]' -reveal_type(user.other.name) # E: Revealed type is 'builtins.str*' +[file base.py] +from typing import Any +from sqlalchemy.ext.declarative import declarative_base +Base: Any = declarative_base() [out] [case testTableColumns] From 1a8f2e640ef31374b4ace8433cb89053b7769c45 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 14 Nov 2018 20:07:36 -0800 Subject: [PATCH 04/16] Add basic declarative base hook; tests TBD --- sqlmypy.py | 22 +++++++++++++++++++++- test/test-data/sqlalchemy-basics.test | 2 +- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/sqlmypy.py b/sqlmypy.py index 9d07283..1425c3e 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -1,5 +1,5 @@ from mypy.plugin import Plugin -from mypy.nodes import NameExpr, Expression, StrExpr, TypeInfo +from mypy.nodes import NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF from mypy.types import UnionType, NoneTyp, Instance, Type, CallableType, AnyType, TypeOfAny from mypy.erasetype import erase_typevars @@ -8,6 +8,8 @@ if TYPE_CHECKING: from mypy.plugin import FunctionContext +DECL_BASES = set() + class BasicSQLAlchemyPlugin(Plugin): def get_function_hook(self, fullname: str) -> Optional[Callable[['FunctionContext'], Type]]: @@ -17,6 +19,24 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[['FunctionContex return relationship_hook return None + def get_dynamic_class_hook(self, fullname): + if fullname == 'sqlalchemy.ext.declarative.api.declarative_base': + return decl_info_hook + return None + + +def decl_info_hook(ctx): + class_def = ClassDef(ctx.name, Block([])) + class_def.fullname = ctx.api.qualified_name(ctx.name) + + info = TypeInfo(SymbolTable(), class_def, ctx.api.cur_mod_id) + class_def.info = info + obj = ctx.api.builtin_type('builtins.object') + info.mro = [info, obj.type] + info.bases = [obj] + ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) + DECL_BASES.add(class_def.fullname) + def relationship_hook(ctx: 'FunctionContext') -> Type: assert isinstance(ctx.default_return_type, Instance) diff --git a/test/test-data/sqlalchemy-basics.test b/test/test-data/sqlalchemy-basics.test index 0e25bd0..0ee8119 100644 --- a/test/test-data/sqlalchemy-basics.test +++ b/test/test-data/sqlalchemy-basics.test @@ -66,7 +66,7 @@ class Other(Base): [file base.py] from typing import Any from sqlalchemy.ext.declarative import declarative_base -Base: Any = declarative_base() +Base = declarative_base() [out] [case testTableColumns] From 9664932b2aac84941725020a238f5facccbb62ae Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 14 Nov 2018 23:04:33 -0800 Subject: [PATCH 05/16] A sketchy implementation of basic __init__ plugin --- sqlmypy.py | 107 ++++++++++++++++++++++++-- test/test-data/sqlalchemy-basics.test | 3 + 2 files changed, 105 insertions(+), 5 deletions(-) diff --git a/sqlmypy.py b/sqlmypy.py index 1425c3e..63c4d76 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -1,14 +1,21 @@ from mypy.plugin import Plugin -from mypy.nodes import NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF -from mypy.types import UnionType, NoneTyp, Instance, Type, CallableType, AnyType, TypeOfAny +from mypy.nodes import( + NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF, + AssignmentStmt, CallExpr, RefExpr, Argument, Var, ARG_OPT +) +from mypy.types import ( + UnionType, NoneTyp, Instance, Type, CallableType, AnyType, TypeOfAny, Overloaded +) from mypy.erasetype import erase_typevars +from mypy.maptype import map_instance_to_supertype -from typing import Optional, TYPE_CHECKING, Callable +from typing import Optional, TYPE_CHECKING, Callable, Set, List if TYPE_CHECKING: - from mypy.plugin import FunctionContext + from mypy.plugin import FunctionContext, ClassDefContext -DECL_BASES = set() +DECL_BASES = set() # type: Set[str] +DECL_BASES.add('base.Base') class BasicSQLAlchemyPlugin(Plugin): @@ -24,6 +31,96 @@ def get_dynamic_class_hook(self, fullname): return decl_info_hook return None + def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[['ClassDefContext'], None]]: + if fullname == 'sqlalchemy.ext.declarative.api.as_declarative': + return decl_deco_hook + return None + + def get_base_class_hook(self, fullname: str) -> Optional[Callable[['ClassDefContext'], None]]: + if fullname in DECL_BASES: + return add_init_hook + return None + + +def _get_column_argument(call: CallExpr, name: str) -> Optional[Expression]: + """Return the expression for the specific argument.""" + # This is super sketchy. + callee_node = call.callee.node + callee_node_type = callee_node.names['__init__'].type + assert isinstance(callee_node_type, Overloaded) + if isinstance(call.args[0], StrExpr): + overload_index = 0 + else: + overload_index = 1 + callee_type = callee_node_type.items()[overload_index] + if not callee_type: + return None + + argument = callee_type.argument_by_name(name) + if not argument: + return None + assert argument.name + + for i, (attr_name, attr_value) in enumerate(zip(call.arg_names, call.args)): + if argument.pos is not None and not attr_name and i == argument.pos - 1: + return attr_value + if attr_name == argument.name: + return attr_value + return None + + +def add_init_hook(ctx: 'ClassDefContext') -> None: + from mypy.plugins.common import _add_method + if '__init__' in ctx.cls.info.names: + # Don't override existing definition. + return + col_types = [] # type: List[Type] + col_names = [] # type: List[str] + engine_info = ctx.api.named_type_or_none('sqlalchemy.sql.type_api.TypeEngine').type + for stmt in ctx.cls.defs.body: + if (isinstance(stmt, AssignmentStmt) and isinstance(stmt.lvalues[0], NameExpr) and + isinstance(stmt.rvalue, CallExpr) and + isinstance(stmt.rvalue.callee, RefExpr) and + stmt.rvalue.callee.fullname == 'sqlalchemy.sql.schema.Column'): + # OK, this is what we a looking for. + col_names.append(stmt.lvalues[0].name) + # First try the easy way... + if isinstance(stmt.type, Instance): + col_types.append(stmt.type.args[0]) + continue + # ...otherwise, the hard way (hard because types are not inferred yet, + # we are in semantic analysis pass) + typ_arg = _get_column_argument(stmt.rvalue, 'type_') + if isinstance(typ_arg, RefExpr): + typ_name = typ_arg.fullname + elif isinstance(typ_arg, CallExpr) and isinstance(typ_arg.callee, RefExpr): + typ_name = typ_arg.callee.fullname + else: + col_types.append(AnyType(TypeOfAny.special_form)) + continue + typ = ctx.api.named_type_or_none(typ_name) + if typ and typ.type.has_base('sqlalchemy.sql.type_api.TypeEngine'): + # Using maptype at this stage is dangerous, since if there is an import cycle, + # the result is unpredictable. + engine = map_instance_to_supertype(typ, engine_info) + if engine.args and isinstance(engine.args[0], Instance): + # OK, the column type already analyzed, we are good to go + col_types.append(engine.args[0]) + continue + # Can't figure out type, fall back to Any + col_types.append(AnyType(TypeOfAny.special_form)) + init_args = [] # type: List[Argument] + for typ, name in zip(col_types, col_names): + typ = UnionType([typ, NoneTyp()]) + var = Var(name, typ) + i_arg = Argument(variable=var, type_annotation=typ, initializer=None, kind=ARG_OPT) + init_args.append(i_arg) + _add_method(ctx, '__init__', init_args, NoneTyp()) + + +def decl_deco_hook(ctx: 'ClassDefContext') -> None: + DECL_BASES.add(ctx.cls.fullname) + def decl_info_hook(ctx): class_def = ClassDef(ctx.name, Block([])) diff --git a/test/test-data/sqlalchemy-basics.test b/test/test-data/sqlalchemy-basics.test index 0ee8119..88cc997 100644 --- a/test/test-data/sqlalchemy-basics.test +++ b/test/test-data/sqlalchemy-basics.test @@ -49,6 +49,8 @@ class User(Base): name = Column(String()) other = relationship('Other') +reveal_type(User.__init__) + user: User reveal_type(user.other) # E: Revealed type is 'main.Other*' reveal_type(User.other) # E: Revealed type is 'sqlalchemy.orm.relationships.RelationshipProperty[main.Other*]' @@ -67,6 +69,7 @@ class Other(Base): from typing import Any from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() +reveal_type(Base) [out] [case testTableColumns] From f94f37719475e6c9fb9643bbbeab34b9b345de23 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 18 Nov 2018 23:50:58 +0000 Subject: [PATCH 06/16] Some more cleanup --- sqlmypy.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/sqlmypy.py b/sqlmypy.py index 63c4d76..ac09fc4 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -1,4 +1,5 @@ -from mypy.plugin import Plugin +from mypy.plugin import Plugin, FunctionContext, ClassDefContext +from mypy.plugins.common import add_method from mypy.nodes import( NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF, AssignmentStmt, CallExpr, RefExpr, Argument, Var, ARG_OPT @@ -9,17 +10,13 @@ from mypy.erasetype import erase_typevars from mypy.maptype import map_instance_to_supertype -from typing import Optional, TYPE_CHECKING, Callable, Set, List - -if TYPE_CHECKING: - from mypy.plugin import FunctionContext, ClassDefContext +from typing import Optional, Callable, Set, List DECL_BASES = set() # type: Set[str] -DECL_BASES.add('base.Base') class BasicSQLAlchemyPlugin(Plugin): - def get_function_hook(self, fullname: str) -> Optional[Callable[['FunctionContext'], Type]]: + def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]: if fullname == 'sqlalchemy.sql.schema.Column': return column_hook if fullname == 'sqlalchemy.orm.relationships.RelationshipProperty': @@ -31,12 +28,12 @@ def get_dynamic_class_hook(self, fullname): return decl_info_hook return None - def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[['ClassDefContext'], None]]: + def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: if fullname == 'sqlalchemy.ext.declarative.api.as_declarative': return decl_deco_hook return None - def get_base_class_hook(self, fullname: str) -> Optional[Callable[['ClassDefContext'], None]]: + def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: if fullname in DECL_BASES: return add_init_hook return None @@ -69,8 +66,7 @@ def _get_column_argument(call: CallExpr, name: str) -> Optional[Expression]: return None -def add_init_hook(ctx: 'ClassDefContext') -> None: - from mypy.plugins.common import _add_method +def add_init_hook(ctx: ClassDefContext) -> None: if '__init__' in ctx.cls.info.names: # Don't override existing definition. return @@ -115,10 +111,10 @@ def add_init_hook(ctx: 'ClassDefContext') -> None: var = Var(name, typ) i_arg = Argument(variable=var, type_annotation=typ, initializer=None, kind=ARG_OPT) init_args.append(i_arg) - _add_method(ctx, '__init__', init_args, NoneTyp()) + add_method(ctx, '__init__', init_args, NoneTyp()) -def decl_deco_hook(ctx: 'ClassDefContext') -> None: +def decl_deco_hook(ctx: ClassDefContext) -> None: DECL_BASES.add(ctx.cls.fullname) @@ -135,7 +131,7 @@ def decl_info_hook(ctx): DECL_BASES.add(class_def.fullname) -def relationship_hook(ctx: 'FunctionContext') -> Type: +def relationship_hook(ctx: FunctionContext) -> Type: assert isinstance(ctx.default_return_type, Instance) arg_type = ctx.arg_types[0][0] arg = ctx.args[0][0] @@ -159,7 +155,7 @@ def relationship_hook(ctx: 'FunctionContext') -> Type: return ctx.default_return_type -def column_hook(ctx: 'FunctionContext') -> Type: +def column_hook(ctx: FunctionContext) -> Type: assert isinstance(ctx.default_return_type, Instance) # This is very fragile, need to update the plugin API. if len(ctx.args) in (5, 6): # overloads with and without the name From fb7b80147e4f4a8a228073a6eed964692d308e4f Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 15 Jan 2019 14:22:53 +0000 Subject: [PATCH 07/16] Sync mypy once more --- external/mypy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/mypy b/external/mypy index 01c2686..f82319f 160000 --- a/external/mypy +++ b/external/mypy @@ -1 +1 @@ -Subproject commit 01c268644d1d22506442df4e21b39c04710b7e8b +Subproject commit f82319f85043ac2afa9708cc90a7fec32a3aa767 From 51a70a85d58d7bb4c81eb58d648788bd1c74b75c Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 15 Jan 2019 15:22:08 +0000 Subject: [PATCH 08/16] Fix tests --- sqlalchemy-stubs/orm/relationships.pyi | 5 ++--- test/test-data/sqlalchemy-basics.test | 7 ++----- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/sqlalchemy-stubs/orm/relationships.pyi b/sqlalchemy-stubs/orm/relationships.pyi index a25cb45..e55c7b8 100644 --- a/sqlalchemy-stubs/orm/relationships.pyi +++ b/sqlalchemy-stubs/orm/relationships.pyi @@ -11,12 +11,11 @@ def remote(expr): ... def foreign(expr): ... -_T = TypeVar('_T') _T_co = TypeVar('_T_co', covariant=True) # Note: typical use case is where argument is a string, so this will require -# a plugin to infer '_T', otherwise a user will need to write an explicit annotation. +# a plugin to infer '_T_co', otherwise a user will need to write an explicit annotation. # It is not clear whether RelationshipProperty is covariant at this stage since # many types are still missing. class RelationshipProperty(StrategizedProperty, Generic[_T_co]): @@ -56,7 +55,7 @@ class RelationshipProperty(StrategizedProperty, Generic[_T_co]): back_populates: Any = ... backref: Any = ... @overload - def __init__(self, argument: Type[_T], secondary: Optional[Any] = ..., + def __init__(self, argument: Type[_T_co], secondary: Optional[Any] = ..., primaryjoin: Optional[Any] = ..., secondaryjoin: Optional[Any] = ..., foreign_keys: Optional[Any] = ..., uselist: Optional[Any] = ..., order_by: Any = ..., backref: Optional[Any] = ..., diff --git a/test/test-data/sqlalchemy-basics.test b/test/test-data/sqlalchemy-basics.test index 0d62ed9..e705adf 100644 --- a/test/test-data/sqlalchemy-basics.test +++ b/test/test-data/sqlalchemy-basics.test @@ -61,11 +61,9 @@ class User(Base): name = Column(String()) other = relationship('Other') -reveal_type(User.__init__) - user: User -reveal_type(user.other) # E: Revealed type is 'main.Other*' -reveal_type(User.other) # E: Revealed type is 'sqlalchemy.orm.relationships.RelationshipProperty[main.Other*]' +reveal_type(user.other) # E: Revealed type is 'other.Other*' +reveal_type(User.other) # E: Revealed type is 'sqlalchemy.orm.relationships.RelationshipProperty[other.Other*]' reveal_type(user.other.name) # E: Revealed type is 'builtins.str*' [file other.py] @@ -81,7 +79,6 @@ class Other(Base): from typing import Any from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() -reveal_type(Base) [out] [case testTableColumns] From 9ebe55a2cbe2fcdcb17363c5603c9f089d0f63d1 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 15 Jan 2019 16:53:39 +0000 Subject: [PATCH 09/16] Some cleanups --- sqlmypy.py | 60 +++++++++++++++++-- test/test-data/sqlalchemy-basics.test | 3 +- .../test-data/sqlalchemy-plugin-features.test | 0 test/testsql.py | 3 +- 4 files changed, 57 insertions(+), 9 deletions(-) create mode 100644 test/test-data/sqlalchemy-plugin-features.test diff --git a/sqlmypy.py b/sqlmypy.py index ac09fc4..480ea2b 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -10,17 +10,43 @@ from mypy.erasetype import erase_typevars from mypy.maptype import map_instance_to_supertype -from typing import Optional, Callable, Set, List +from typing import Optional, Callable, List -DECL_BASES = set() # type: Set[str] + +def is_declarative(info: TypeInfo) -> bool: + """Check if this is a subclass of a declarative base.""" + if info.mro: + for base in info.mro: + metadata = info.metadata.get('sqlalchemy') + if metadata and metadata.get('declarative_base'): + return True + return False + + +def set_declarative(info: TypeInfo) -> None: + """Record given class as a declarative base.""" + info.metadata.setdefault('sqlalchemy', {})['declarative_base'] = True class BasicSQLAlchemyPlugin(Plugin): + """Basic plugin to support simple operations with models. + + Currently supported functionality: + * Recognize dynamically defined declarative bases. + * Add an __init__() method to models. + * Provide better types for 'Column's and 'RelationshipProperty's + using flags 'primary_key', 'nullable', 'uselist', etc. + """ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]: if fullname == 'sqlalchemy.sql.schema.Column': return column_hook if fullname == 'sqlalchemy.orm.relationships.RelationshipProperty': return relationship_hook + sym = self.lookup_fully_qualified(fullname) + if sym and isinstance(sym.node, TypeInfo): + # May be a model instantiation + if is_declarative(sym.node): + return model_hook return None def get_dynamic_class_hook(self, fullname): @@ -34,8 +60,10 @@ def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDef return None def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: - if fullname in DECL_BASES: - return add_init_hook + sym = self.lookup_fully_qualified(fullname) + if sym and isinstance(sym.node, TypeInfo): + if is_declarative(sym.node): + return add_init_hook return None @@ -115,10 +143,26 @@ def add_init_hook(ctx: ClassDefContext) -> None: def decl_deco_hook(ctx: ClassDefContext) -> None: - DECL_BASES.add(ctx.cls.fullname) + """Support declaring base class as declarative with a decorator. + + For example: + from from sqlalchemy.ext.declarative import as_declarative + + @as_declarative + class Base: + ... + """ + set_declarative(ctx.cls.info) def decl_info_hook(ctx): + """Support dynamically defining declarative bases. + + For example: + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + """ class_def = ClassDef(ctx.name, Block([])) class_def.fullname = ctx.api.qualified_name(ctx.name) @@ -128,7 +172,11 @@ def decl_info_hook(ctx): info.mro = [info, obj.type] info.bases = [obj] ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) - DECL_BASES.add(class_def.fullname) + set_declarative(info) + + +def model_hook(ctx: FunctionContext) -> Type: + return ctx.default_return_type def relationship_hook(ctx: FunctionContext) -> Type: diff --git a/test/test-data/sqlalchemy-basics.test b/test/test-data/sqlalchemy-basics.test index e705adf..c03e868 100644 --- a/test/test-data/sqlalchemy-basics.test +++ b/test/test-data/sqlalchemy-basics.test @@ -47,7 +47,7 @@ reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[buil [out] [case testColumnFieldsRelationship] -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING from sqlalchemy import Column, Integer, String from sqlalchemy.orm import relationship @@ -76,7 +76,6 @@ class Other(Base): name = Column(String()) [file base.py] -from typing import Any from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() [out] diff --git a/test/test-data/sqlalchemy-plugin-features.test b/test/test-data/sqlalchemy-plugin-features.test new file mode 100644 index 0000000..e69de29 diff --git a/test/testsql.py b/test/testsql.py index f074a40..d04dd6c 100644 --- a/test/testsql.py +++ b/test/testsql.py @@ -25,7 +25,8 @@ class SQLDataSuite(DataSuite): 'sqlalchemy-sql-elements.test', 'sqlalchemy-sql-sqltypes.test', 'sqlalchemy-sql-selectable.test', - 'sqlalchemy-sql-schema.test'] + 'sqlalchemy-sql-schema.test', + 'sqlalchemy-plugin-features.test'] data_prefix = test_data_prefix def run_case(self, testcase: DataDrivenTestCase) -> None: From c056851f8ea799cfb7290afa5b6850e03860c6ff Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 15 Jan 2019 18:00:37 +0000 Subject: [PATCH 10/16] Update model __init__ hooks --- sqlmypy.py | 104 +++++++++--------- .../test-data/sqlalchemy-plugin-features.test | 24 ++++ 2 files changed, 79 insertions(+), 49 deletions(-) diff --git a/sqlmypy.py b/sqlmypy.py index 480ea2b..ea17873 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -2,22 +2,26 @@ from mypy.plugins.common import add_method from mypy.nodes import( NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF, - AssignmentStmt, CallExpr, RefExpr, Argument, Var, ARG_OPT + CallExpr, Argument, Var, ARG_STAR2 ) from mypy.types import ( UnionType, NoneTyp, Instance, Type, CallableType, AnyType, TypeOfAny, Overloaded ) from mypy.erasetype import erase_typevars -from mypy.maptype import map_instance_to_supertype -from typing import Optional, Callable, List +from typing import Optional, Callable, Dict, TYPE_CHECKING +if TYPE_CHECKING: + from typing_extensions import Final + +COLUMN_NAME = 'sqlalchemy.sql.schema.Column' # type: Final +RELATIONSHIP_NAME = 'sqlalchemy.orm.relationships.RelationshipProperty' # type: Final def is_declarative(info: TypeInfo) -> bool: """Check if this is a subclass of a declarative base.""" if info.mro: for base in info.mro: - metadata = info.metadata.get('sqlalchemy') + metadata = base.metadata.get('sqlalchemy') if metadata and metadata.get('declarative_base'): return True return False @@ -38,9 +42,9 @@ class BasicSQLAlchemyPlugin(Plugin): using flags 'primary_key', 'nullable', 'uselist', etc. """ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]: - if fullname == 'sqlalchemy.sql.schema.Column': + if fullname == COLUMN_NAME: return column_hook - if fullname == 'sqlalchemy.orm.relationships.RelationshipProperty': + if fullname == RELATIONSHIP_NAME: return relationship_hook sym = self.lookup_fully_qualified(fullname) if sym and isinstance(sym.node, TypeInfo): @@ -95,51 +99,19 @@ def _get_column_argument(call: CallExpr, name: str) -> Optional[Expression]: def add_init_hook(ctx: ClassDefContext) -> None: + """Add a dummy __init__() to a model and record it is generated. + + Instantiation will be checked more precisely when we inferred types + (using get_function_hook and model_hook). + """ if '__init__' in ctx.cls.info.names: # Don't override existing definition. return - col_types = [] # type: List[Type] - col_names = [] # type: List[str] - engine_info = ctx.api.named_type_or_none('sqlalchemy.sql.type_api.TypeEngine').type - for stmt in ctx.cls.defs.body: - if (isinstance(stmt, AssignmentStmt) and isinstance(stmt.lvalues[0], NameExpr) and - isinstance(stmt.rvalue, CallExpr) and - isinstance(stmt.rvalue.callee, RefExpr) and - stmt.rvalue.callee.fullname == 'sqlalchemy.sql.schema.Column'): - # OK, this is what we a looking for. - col_names.append(stmt.lvalues[0].name) - # First try the easy way... - if isinstance(stmt.type, Instance): - col_types.append(stmt.type.args[0]) - continue - # ...otherwise, the hard way (hard because types are not inferred yet, - # we are in semantic analysis pass) - typ_arg = _get_column_argument(stmt.rvalue, 'type_') - if isinstance(typ_arg, RefExpr): - typ_name = typ_arg.fullname - elif isinstance(typ_arg, CallExpr) and isinstance(typ_arg.callee, RefExpr): - typ_name = typ_arg.callee.fullname - else: - col_types.append(AnyType(TypeOfAny.special_form)) - continue - typ = ctx.api.named_type_or_none(typ_name) - if typ and typ.type.has_base('sqlalchemy.sql.type_api.TypeEngine'): - # Using maptype at this stage is dangerous, since if there is an import cycle, - # the result is unpredictable. - engine = map_instance_to_supertype(typ, engine_info) - if engine.args and isinstance(engine.args[0], Instance): - # OK, the column type already analyzed, we are good to go - col_types.append(engine.args[0]) - continue - # Can't figure out type, fall back to Any - col_types.append(AnyType(TypeOfAny.special_form)) - init_args = [] # type: List[Argument] - for typ, name in zip(col_types, col_names): - typ = UnionType([typ, NoneTyp()]) - var = Var(name, typ) - i_arg = Argument(variable=var, type_annotation=typ, initializer=None, kind=ARG_OPT) - init_args.append(i_arg) - add_method(ctx, '__init__', init_args, NoneTyp()) + typ = AnyType(TypeOfAny.special_form) + var = Var('kwargs', typ) + kw_arg = Argument(variable=var, type_annotation=typ, initializer=None, kind=ARG_STAR2) + add_method(ctx, '__init__', [kw_arg], NoneTyp()) + ctx.cls.info.metadata.setdefault('sqlalchemy', {})['generated_init'] = True def decl_deco_hook(ctx: ClassDefContext) -> None: @@ -176,6 +148,40 @@ def decl_info_hook(ctx): def model_hook(ctx: FunctionContext) -> Type: + """More precise model instantiation check. + + Note: sub-models are not supported. + Note: this is still not perfect, since the context for inference of + argument types is 'Any'. + """ + assert isinstance(ctx.default_return_type, Instance) + model = ctx.default_return_type.type + metadata = model.metadata.get('sqlalchemy') + if not metadata or not metadata.get('generated_init'): + return ctx.default_return_type + + # Collect column names and types defined in the model + # TODO: cache this? + expected_types = {} # type: Dict[str, Type] + for name, sym in model.names.items(): + if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): + tp = sym.node.type + if tp.type.fullname() in (COLUMN_NAME, RELATIONSHIP_NAME): + assert len(tp.args) == 1 + expected_types[name] = tp.args[0] + + assert len(ctx.arg_names) == 1 # only **kwargs in generated __init__ + assert len(ctx.arg_types) == 1 + for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]): + if actual_name not in expected_types: + ctx.api.fail('Unexpected column "{}" for model "{}"'.format(actual_name, model.name()), + ctx.context) + continue + # Using private API to simplify life. + ctx.api.check_subtype(actual_type, expected_types[actual_name], + ctx.context, + 'Incompatible type for "{}" of "{}"'.format(actual_name, model.name()), + 'got', 'expected') return ctx.default_return_type @@ -222,7 +228,7 @@ def column_hook(ctx: FunctionContext) -> Type: nullable = not parse_bool(primary_args[0]) else: nullable = True - # TODO: Add support for literal types when they will be available. + # TODO: Add support for literal types. if not nullable: return ctx.default_return_type diff --git a/test/test-data/sqlalchemy-plugin-features.test b/test/test-data/sqlalchemy-plugin-features.test index e69de29..be0ff54 100644 --- a/test/test-data/sqlalchemy-plugin-features.test +++ b/test/test-data/sqlalchemy-plugin-features.test @@ -0,0 +1,24 @@ +[case testModelInitColumnDeclared] +# flags: --strict-optional +from sqlalchemy import Column, Integer, String +from base import Base +from typing import Optional + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + name = Column(String()) + +oi: Optional[int] +os: Optional[str] + +User() +User(1, 2) # E: Too many arguments for "User" +User(id=int(), name=str()) +User(id=oi) # E: Incompatible type for "id" of "User" (got "Optional[int]", expected "int") +User(name=os) + +[file base.py] +from sqlalchemy.ext.declarative import declarative_base +Base = declarative_base() +[out] From 87f2d8c77898a8997f537f22a2296375efb8564d Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 15 Jan 2019 18:19:25 +0000 Subject: [PATCH 11/16] Update column hook --- sqlmypy.py | 114 +++++++++++++++++++++++++---------------------------- 1 file changed, 54 insertions(+), 60 deletions(-) diff --git a/sqlmypy.py b/sqlmypy.py index ea17873..baf170b 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -71,33 +71,6 @@ def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefConte return None -def _get_column_argument(call: CallExpr, name: str) -> Optional[Expression]: - """Return the expression for the specific argument.""" - # This is super sketchy. - callee_node = call.callee.node - callee_node_type = callee_node.names['__init__'].type - assert isinstance(callee_node_type, Overloaded) - if isinstance(call.args[0], StrExpr): - overload_index = 0 - else: - overload_index = 1 - callee_type = callee_node_type.items()[overload_index] - if not callee_type: - return None - - argument = callee_type.argument_by_name(name) - if not argument: - return None - assert argument.name - - for i, (attr_name, attr_value) in enumerate(zip(call.arg_names, call.args)): - if argument.pos is not None and not attr_name and i == argument.pos - 1: - return attr_value - if attr_name == argument.name: - return attr_value - return None - - def add_init_hook(ctx: ClassDefContext) -> None: """Add a dummy __init__() to a model and record it is generated. @@ -179,12 +152,63 @@ def model_hook(ctx: FunctionContext) -> Type: continue # Using private API to simplify life. ctx.api.check_subtype(actual_type, expected_types[actual_name], - ctx.context, - 'Incompatible type for "{}" of "{}"'.format(actual_name, model.name()), - 'got', 'expected') + ctx.context, + 'Incompatible type for "{}" of "{}"'.format(actual_name, model.name()), + 'got', 'expected') return ctx.default_return_type +def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]: + """Return the expression for the specific argument. + + This helper should only be used with non-star arguments. + """ + if name not in ctx.callee_arg_names: + return None + idx = ctx.callee_arg_names.index(name) + args = ctx.args[idx] + if len(args) != 1: + # Either an error or no value passed. + return None + return args[0] + + +def column_hook(ctx: FunctionContext) -> Type: + """Infer better types for Column calls. + + Examples: + Column(String) -> Column[Optional[str]] + Column(String, primary_key=True) -> Column[str] + Column(String, nullable=False) -> Column[str] + Column(String, default=...) -> Column[str] + Column(String, default=..., nullable=True) -> Column[Optional[str]] + + TODO: check the type of 'default'. + """ + assert isinstance(ctx.default_return_type, Instance) + + nullable_arg = get_argument_by_name(ctx, 'nullable') + primary_arg = get_argument_by_name(ctx, 'primary_key') + default_arg = get_argument_by_name(ctx, 'default') + + if nullable_arg: + nullable = parse_bool(nullable_arg) + else: + if primary_arg: + nullable = not parse_bool(primary_arg) + else: + nullable = default_arg is None + # TODO: Add support for literal types. + + if not nullable: + return ctx.default_return_type + assert len(ctx.default_return_type.args) == 1 + arg_type = ctx.default_return_type.args[0] + return Instance(ctx.default_return_type.type, [UnionType([arg_type, NoneTyp()])], + line=ctx.default_return_type.line, + column=ctx.default_return_type.column) + + def relationship_hook(ctx: FunctionContext) -> Type: assert isinstance(ctx.default_return_type, Instance) arg_type = ctx.arg_types[0][0] @@ -209,36 +233,6 @@ def relationship_hook(ctx: FunctionContext) -> Type: return ctx.default_return_type -def column_hook(ctx: FunctionContext) -> Type: - assert isinstance(ctx.default_return_type, Instance) - # This is very fragile, need to update the plugin API. - if len(ctx.args) in (5, 6): # overloads with and without the name - nullable_index = len(ctx.args) - 2 - primary_index = len(ctx.args) - 3 - else: - # Something new, give up. - return ctx.default_return_type - - nullable_args = ctx.args[nullable_index] - primary_args = ctx.args[primary_index] - if nullable_args: - nullable = parse_bool(nullable_args[0]) - else: - if primary_args: - nullable = not parse_bool(primary_args[0]) - else: - nullable = True - # TODO: Add support for literal types. - - if not nullable: - return ctx.default_return_type - assert len(ctx.default_return_type.args) == 1 - arg_type = ctx.default_return_type.args[0] - return Instance(ctx.default_return_type.type, [UnionType([arg_type, NoneTyp()])], - line=ctx.default_return_type.line, - column=ctx.default_return_type.column) - - # We really need to add this to TypeChecker API def parse_bool(expr: Expression) -> Optional[bool]: if isinstance(expr, NameExpr): From e2a7b217f1570a2c90ef9bd69144fb8e2a007094 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 15 Jan 2019 19:18:18 +0000 Subject: [PATCH 12/16] Update the relationship hook --- sqlalchemy-stubs/orm/__init__.pyi | 2 +- sqlmypy.py | 64 ++++++++++++++++++++------- test/test-data/sqlalchemy-basics.test | 5 ++- 3 files changed, 52 insertions(+), 19 deletions(-) diff --git a/sqlalchemy-stubs/orm/__init__.pyi b/sqlalchemy-stubs/orm/__init__.pyi index 57665a4..ff5c7ad 100644 --- a/sqlalchemy-stubs/orm/__init__.pyi +++ b/sqlalchemy-stubs/orm/__init__.pyi @@ -43,7 +43,7 @@ from .strategy_options import Load as Load def create_session(bind: Optional[Any] = ..., **kwargs): ... -relationship = RelationshipProperty[Any] +relationship = RelationshipProperty def relation(*arg, **kw): ... def dynamic_loader(argument, **kw): ... diff --git a/sqlmypy.py b/sqlmypy.py index baf170b..62ba196 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -5,9 +5,9 @@ CallExpr, Argument, Var, ARG_STAR2 ) from mypy.types import ( - UnionType, NoneTyp, Instance, Type, CallableType, AnyType, TypeOfAny, Overloaded + UnionType, NoneTyp, Instance, Type, CallableType, AnyType, TypeOfAny, Overloaded, + UninhabitedType ) -from mypy.erasetype import erase_typevars from typing import Optional, Callable, Dict, TYPE_CHECKING if TYPE_CHECKING: @@ -210,27 +210,59 @@ def column_hook(ctx: FunctionContext) -> Type: def relationship_hook(ctx: FunctionContext) -> Type: + """Support basic use cases for relationships. + + Examples: + from sqlalchemy.orm import relationship + + from one import OneModel + if TYPE_CHECKING: + from other import OtherModel + + class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + one = relationship(OneModel) + other = relationship("OtherModel") + + This also tries to infer the type argument for 'RelationshipProperty' + using the 'uselist' flag. + """ assert isinstance(ctx.default_return_type, Instance) - arg_type = ctx.arg_types[0][0] - arg = ctx.args[0][0] - if isinstance(arg_type, CallableType) and arg_type.is_type_obj(): - return Instance(ctx.default_return_type.type, [erase_typevars(arg_type.ret_type)], - line=ctx.default_return_type.line, - column=ctx.default_return_type.column) - elif isinstance(arg, StrExpr): + original_type_arg = ctx.default_return_type.args[0] + arg = get_argument_by_name(ctx, 'argument') + uselist_arg = get_argument_by_name(ctx, 'uselist') + if not isinstance(original_type_arg, UninhabitedType): + # The type was inferred using the stub signature. + return ctx.default_return_type + if isinstance(arg, StrExpr): name = arg.value - # Private API, but probably needs to be public. + # Private API for local lookup, but probably needs to be public. try: - sym = ctx.api.lookup_qualified(name) + sym = ctx.api.lookup_qualified(name) # type: Optional[SymbolTableNode] except (KeyError, AssertionError): - return ctx.default_return_type + sym = None if sym and isinstance(sym.node, TypeInfo): + if not is_declarative(sym.node): + ctx.api.fail('First argument to relationship must be a model', ctx.context) any = AnyType(TypeOfAny.special_form) new_arg = Instance(sym.node, [any] * len(sym.node.defn.type_vars)) - return Instance(ctx.default_return_type.type, [new_arg], - line=ctx.default_return_type.line, - column=ctx.default_return_type.column) - return ctx.default_return_type + else: + ctx.api.fail('Cannot find model"{}"'.format(name), ctx.context) + ctx.api.note('Only imported models can be found', ctx.context) + ctx.api.note('Use "if TYPE_CHECKING: ..." to avoid import cycles', ctx.context) + new_arg = AnyType(TypeOfAny.from_error) + else: + new_arg = original_type_arg + if uselist_arg: + if parse_bool(uselist_arg): + new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg]) + else: + ctx.api.fail('Cannot figure out kind of relationship', ctx.context) + ctx.api.note('Either add an annotation or use an explicit "uselist" flag', ctx.context) + return Instance(ctx.default_return_type.type, [new_arg], + line=ctx.default_return_type.line, + column=ctx.default_return_type.column) # We really need to add this to TypeChecker API diff --git a/test/test-data/sqlalchemy-basics.test b/test/test-data/sqlalchemy-basics.test index c03e868..817983c 100644 --- a/test/test-data/sqlalchemy-basics.test +++ b/test/test-data/sqlalchemy-basics.test @@ -50,7 +50,7 @@ reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[buil from typing import TYPE_CHECKING from sqlalchemy import Column, Integer, String -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, RelationshipProperty from base import Base if TYPE_CHECKING: from other import Other @@ -59,7 +59,8 @@ class User(Base): __tablename__ = 'users' id = Column(Integer(), primary_key=True) name = Column(String()) - other = relationship('Other') + other = relationship('Other', uselist=False) + another: RelationshipProperty[Other] = relationship('Other') user: User reveal_type(user.other) # E: Revealed type is 'other.Other*' From 463b582d5284d54fb93f57dafb95d3d4fd49fb62 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 15 Jan 2019 19:40:54 +0000 Subject: [PATCH 13/16] Remove unused imports --- sqlmypy.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sqlmypy.py b/sqlmypy.py index 62ba196..6189ee4 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -2,11 +2,10 @@ from mypy.plugins.common import add_method from mypy.nodes import( NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF, - CallExpr, Argument, Var, ARG_STAR2 + Argument, Var, ARG_STAR2 ) from mypy.types import ( - UnionType, NoneTyp, Instance, Type, CallableType, AnyType, TypeOfAny, Overloaded, - UninhabitedType + UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType ) from typing import Optional, Callable, Dict, TYPE_CHECKING @@ -248,7 +247,7 @@ class User(Base): any = AnyType(TypeOfAny.special_form) new_arg = Instance(sym.node, [any] * len(sym.node.defn.type_vars)) else: - ctx.api.fail('Cannot find model"{}"'.format(name), ctx.context) + ctx.api.fail('Cannot find model "{}"'.format(name), ctx.context) ctx.api.note('Only imported models can be found', ctx.context) ctx.api.note('Use "if TYPE_CHECKING: ..." to avoid import cycles', ctx.context) new_arg = AnyType(TypeOfAny.from_error) From 7f2e2ce855ddb27a24d2e10e5e25309b457e3517 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 16 Jan 2019 17:12:25 +0000 Subject: [PATCH 14/16] Fix relationahip; add tests --- sqlalchemy-stubs/orm/relationships.pyi | 17 -- sqlmypy.py | 42 +++-- test/test-data/sqlalchemy-basics.test | 26 ++- .../test-data/sqlalchemy-plugin-features.test | 153 ++++++++++++++++++ test/test-data/sqlalchemy-sql-elements.test | 2 +- 5 files changed, 194 insertions(+), 46 deletions(-) diff --git a/sqlalchemy-stubs/orm/relationships.pyi b/sqlalchemy-stubs/orm/relationships.pyi index e55c7b8..ecb6b0b 100644 --- a/sqlalchemy-stubs/orm/relationships.pyi +++ b/sqlalchemy-stubs/orm/relationships.pyi @@ -54,23 +54,6 @@ class RelationshipProperty(StrategizedProperty, Generic[_T_co]): order_by: Any = ... back_populates: Any = ... backref: Any = ... - @overload - def __init__(self, argument: Type[_T_co], secondary: Optional[Any] = ..., - primaryjoin: Optional[Any] = ..., secondaryjoin: Optional[Any] = ..., - foreign_keys: Optional[Any] = ..., uselist: Optional[Any] = ..., - order_by: Any = ..., backref: Optional[Any] = ..., - back_populates: Optional[Any] = ..., post_update: bool = ..., cascade: Union[str, bool] = ..., - extension: Optional[Any] = ..., viewonly: bool = ..., - lazy: Optional[Union[str, bool]] = ..., collection_class: Optional[Any] = ..., - passive_deletes: bool = ..., passive_updates: bool = ..., - remote_side: Optional[Any] = ..., enable_typechecks: bool = ..., - join_depth: Optional[Any] = ..., comparator_factory: Optional[Any] = ..., - single_parent: bool = ..., innerjoin: bool = ..., distinct_target_key: Optional[Any] = ..., - doc: Optional[Any] = ..., active_history: bool = ..., cascade_backrefs: bool = ..., - load_on_pending: bool = ..., bake_queries: bool = ..., - _local_remote_pairs: Optional[Any] = ..., query_class: Optional[Any] = ..., - info: Optional[Any] = ...) -> None: ... - @overload def __init__(self, argument: Any, secondary: Optional[Any] = ..., primaryjoin: Optional[Any] = ..., secondaryjoin: Optional[Any] = ..., foreign_keys: Optional[Any] = ..., uselist: Optional[Any] = ..., diff --git a/sqlmypy.py b/sqlmypy.py index 6189ee4..3938cf2 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -5,8 +5,9 @@ Argument, Var, ARG_STAR2 ) from mypy.types import ( - UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType + UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType ) +from mypy.typevars import fill_typevars_with_any from typing import Optional, Callable, Dict, TYPE_CHECKING if TYPE_CHECKING: @@ -172,6 +173,18 @@ def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression return args[0] +def get_argtype_by_name(ctx: FunctionContext, name: str) -> Optional[Type]: + """Same as above but for argument type.""" + if name not in ctx.callee_arg_names: + return None + idx = ctx.callee_arg_names.index(name) + arg_types = ctx.arg_types[idx] + if len(arg_types) != 1: + # Either an error or no value passed. + return None + return arg_types[0] + + def column_hook(ctx: FunctionContext) -> Type: """Infer better types for Column calls. @@ -229,11 +242,13 @@ class User(Base): """ assert isinstance(ctx.default_return_type, Instance) original_type_arg = ctx.default_return_type.args[0] + has_no_annotation = isinstance(original_type_arg, UninhabitedType) + arg = get_argument_by_name(ctx, 'argument') + arg_type = get_argtype_by_name(ctx, 'argument') + uselist_arg = get_argument_by_name(ctx, 'uselist') - if not isinstance(original_type_arg, UninhabitedType): - # The type was inferred using the stub signature. - return ctx.default_return_type + if isinstance(arg, StrExpr): name = arg.value # Private API for local lookup, but probably needs to be public. @@ -242,23 +257,28 @@ class User(Base): except (KeyError, AssertionError): sym = None if sym and isinstance(sym.node, TypeInfo): - if not is_declarative(sym.node): - ctx.api.fail('First argument to relationship must be a model', ctx.context) - any = AnyType(TypeOfAny.special_form) - new_arg = Instance(sym.node, [any] * len(sym.node.defn.type_vars)) + new_arg = fill_typevars_with_any(sym.node) else: ctx.api.fail('Cannot find model "{}"'.format(name), ctx.context) ctx.api.note('Only imported models can be found', ctx.context) ctx.api.note('Use "if TYPE_CHECKING: ..." to avoid import cycles', ctx.context) new_arg = AnyType(TypeOfAny.from_error) else: - new_arg = original_type_arg + if isinstance(arg_type, CallableType) and arg_type.is_type_obj(): + new_arg = fill_typevars_with_any(arg_type.type_object()) + else: + # Something complex, stay silent for now. + new_arg = AnyType(TypeOfAny.special_form) + + # We figured out, the model type. Now check if we need to wrap it in Iterable if uselist_arg: if parse_bool(uselist_arg): new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg]) else: - ctx.api.fail('Cannot figure out kind of relationship', ctx.context) - ctx.api.note('Either add an annotation or use an explicit "uselist" flag', ctx.context) + if has_no_annotation: + ctx.api.fail('Cannot figure out kind of relationship', ctx.context) + ctx.api.note('Suggestion: use an explicit "uselist" flag', ctx.context) + return Instance(ctx.default_return_type.type, [new_arg], line=ctx.default_return_type.line, column=ctx.default_return_type.column) diff --git a/test/test-data/sqlalchemy-basics.test b/test/test-data/sqlalchemy-basics.test index 817983c..9fecf29 100644 --- a/test/test-data/sqlalchemy-basics.test +++ b/test/test-data/sqlalchemy-basics.test @@ -1,10 +1,8 @@ [case testColumnFieldsInferred] -from typing import Any - from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import Column, Integer, String -Base: Any = declarative_base() +Base = declarative_base() class User(Base): __tablename__ = 'users' @@ -13,7 +11,7 @@ class User(Base): user: User reveal_type(user.id) # E: Revealed type is 'builtins.int*' -reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.str*]' +reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.str*, None]]' [out] [case testTypeEngineCovariance] @@ -29,17 +27,15 @@ func(String()) # E: Value of type variable "T" of "func" cannot be "str" [out] [case testColumnFieldsInferredInstance] -from typing import Any - from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import Column, Integer, String -Base: Any = declarative_base() +Base = declarative_base() class User(Base): __tablename__ = 'users' id = Column(Integer(), primary_key=True) - name = Column(String()) + name = Column(String(), nullable=False) user: User reveal_type(user.id) # E: Revealed type is 'builtins.int*' @@ -74,7 +70,7 @@ from base import Base class Other(Base): __tablename__ = 'other' id = Column(Integer(), primary_key=True) - name = Column(String()) + name = Column(String(), nullable=False) [file base.py] from sqlalchemy.ext.declarative import declarative_base @@ -94,12 +90,10 @@ reveal_type(users.c.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column*[ [out] [case testColumnFieldsInferred_python2] -from typing import Any - from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import Column, Integer, String -Base = declarative_base() # type: Any +Base = declarative_base() class User(Base): __tablename__ = 'users' @@ -108,21 +102,19 @@ class User(Base): user = User() reveal_type(user.id) # E: Revealed type is 'builtins.int*' -reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.unicode*]' +reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.unicode*, None]]' [out] [case testColumnFieldsInferredInstance_python2] -from typing import Any - from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import Column, Integer, String -Base = declarative_base() # type: Any +Base = declarative_base() class User(Base): __tablename__ = 'users' id = Column(Integer(), primary_key=True) - name = Column(String()) + name = Column(String(), default='John Doe') user = User() reveal_type(user.id) # E: Revealed type is 'builtins.int*' diff --git a/test/test-data/sqlalchemy-plugin-features.test b/test/test-data/sqlalchemy-plugin-features.test index be0ff54..30ca7e7 100644 --- a/test/test-data/sqlalchemy-plugin-features.test +++ b/test/test-data/sqlalchemy-plugin-features.test @@ -22,3 +22,156 @@ User(name=os) from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() [out] + +[case testModelInitColumnDecorated] +# flags: --strict-optional +from sqlalchemy import Column, Integer, String +from base import Base +from typing import Optional + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + name = Column(String()) + +oi: Optional[int] +os: Optional[str] + +User() +User(1, 2) # E: Too many arguments for "User" +User(id=int(), name=str()) +User(id=oi) # E: Incompatible type for "id" of "User" (got "Optional[int]", expected "int") +User(name=os) + +[file base.py] +from sqlalchemy.ext.declarative import as_declarative +@as_declarative() +class Base: + ... +[out] + +[case testModelInitRelationship] +from typing import TYPE_CHECKING, List + +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import relationship, RelationshipProperty +from base import Base +if TYPE_CHECKING: + from other import Other + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + other = relationship('Other', uselist=False) + many_others = relationship(Other, uselist=True) + +o: Other +mo: List[Other] +User() +User(other=o, many_others=mo) +User(other=mo) # E: Incompatible type for "other" of "User" (got "List[Other]", expected "Other") +User(unknown=42) # E: Unexpected column "unknown" for model "User" + +[file other.py] +from sqlalchemy import Column, Integer, String +from base import Base + +class Other(Base): + __tablename__ = 'other' + id = Column(Integer(), primary_key=True) + name = Column(String(), nullable=False) + +[file base.py] +from sqlalchemy.ext.declarative import declarative_base +Base = declarative_base() +[out] + +[case testRelationshipType] +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import relationship, RelationshipProperty +from typing import Iterable + +Base = declarative_base() + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + first_other = relationship(Other, uselist=False) + second_other = relationship(Other, uselist=True) + first_bad_other = relationship(Other) # E: Cannot figure out kind of relationship \ + # N: Suggestion: use an explicit "uselist" flag + second_bad_other: RelationshipProperty[int] = relationship(Other, uselist=False) # E: Incompatible types in assignment (expression has type "RelationshipProperty[Other]", variable has type "RelationshipProperty[int]") + +user = User() +reveal_type(user.first_other) # E: Revealed type is 'main.Other*' +reveal_type(user.second_other) # E: Revealed type is 'typing.Iterable*[main.Other]' + +class Other(Base): + __tablename__ = 'other' + id = Column(Integer(), primary_key=True) +[out] + +[case testRelationshipString] +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import relationship, RelationshipProperty +from typing import Iterable + +Base = declarative_base() + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + first_other = relationship('Other', uselist=False) + second_other = relationship('Other', uselist=True) + first_bad_other = relationship('Other') # E: Cannot figure out kind of relationship \ + # N: Suggestion: use an explicit "uselist" flag + second_bad_other = relationship('What', uselist=False) # E: Cannot find model "What" \ + # N: Only imported models can be found \ + # N: Use "if TYPE_CHECKING: ..." to avoid import cycles + +user = User() +reveal_type(user.first_other) # E: Revealed type is 'main.Other*' +reveal_type(user.second_other) # E: Revealed type is 'typing.Iterable*[main.Other]' + +class Other(Base): + __tablename__ = 'other' + id = Column(Integer(), primary_key=True) +[out] + +[case testRelationshipAnnotated] +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import relationship, RelationshipProperty +from typing import Iterable + +Base = declarative_base() + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + first_other: RelationshipProperty[Other] = relationship('Other') + second_other: RelationshipProperty[Iterable[Other]] = relationship(Other, uselist=True) + third_other: RelationshipProperty[Other] = relationship(Other, uselist=False) + bad_other: RelationshipProperty[Other] = relationship('Other', uselist=True) # E: Incompatible types in assignment (expression has type "RelationshipProperty[Iterable[Other]]", variable has type "RelationshipProperty[Other]") + +class Other(Base): + __tablename__ = 'other' + id = Column(Integer(), primary_key=True) +[out] + +[case testColumnCombo] +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String + +Base = declarative_base() + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + name = Column(String(), default='John Doe', nullable=True) + +user: User +reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.str*, None]]' +[out] diff --git a/test/test-data/sqlalchemy-sql-elements.test b/test/test-data/sqlalchemy-sql-elements.test index ebc0d60..f6a0d77 100644 --- a/test/test-data/sqlalchemy-sql-elements.test +++ b/test/test-data/sqlalchemy-sql-elements.test @@ -85,7 +85,7 @@ Base: Any = declarative_base() class Model(Base): __tablename__ = 'users' - tags = Column(ARRAY(String(16))) + tags = Column(ARRAY(String(16)), nullable=False) tags: Set[str] = set() reveal_type(cast(tags, Model.tags.type)) # E: Revealed type is 'sqlalchemy.sql.elements.Cast[builtins.list*[builtins.str*]]' From 40dce9b5609357d9cdee4b0a364bad58a0ff82b0 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 16 Jan 2019 18:05:18 +0000 Subject: [PATCH 15/16] Simplify error messages --- sqlmypy.py | 5 ++--- test/test-data/sqlalchemy-plugin-features.test | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sqlmypy.py b/sqlmypy.py index 3938cf2..28a993e 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -260,8 +260,7 @@ class User(Base): new_arg = fill_typevars_with_any(sym.node) else: ctx.api.fail('Cannot find model "{}"'.format(name), ctx.context) - ctx.api.note('Only imported models can be found', ctx.context) - ctx.api.note('Use "if TYPE_CHECKING: ..." to avoid import cycles', ctx.context) + ctx.api.note('Only imported models can be found; use "if TYPE_CHECKING: ..." to avoid import cycles', ctx.context) new_arg = AnyType(TypeOfAny.from_error) else: if isinstance(arg_type, CallableType) and arg_type.is_type_obj(): @@ -275,7 +274,7 @@ class User(Base): if parse_bool(uselist_arg): new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg]) else: - if has_no_annotation: + if has_no_annotation and not isinstance(new_arg, AnyType): ctx.api.fail('Cannot figure out kind of relationship', ctx.context) ctx.api.note('Suggestion: use an explicit "uselist" flag', ctx.context) diff --git a/test/test-data/sqlalchemy-plugin-features.test b/test/test-data/sqlalchemy-plugin-features.test index 30ca7e7..7bfa6f0 100644 --- a/test/test-data/sqlalchemy-plugin-features.test +++ b/test/test-data/sqlalchemy-plugin-features.test @@ -127,9 +127,8 @@ class User(Base): second_other = relationship('Other', uselist=True) first_bad_other = relationship('Other') # E: Cannot figure out kind of relationship \ # N: Suggestion: use an explicit "uselist" flag - second_bad_other = relationship('What', uselist=False) # E: Cannot find model "What" \ - # N: Only imported models can be found \ - # N: Use "if TYPE_CHECKING: ..." to avoid import cycles + second_bad_other = relationship('What') # E: Cannot find model "What" \ + # N: Only imported models can be found; use "if TYPE_CHECKING: ..." to avoid import cycles user = User() reveal_type(user.first_other) # E: Revealed type is 'main.Other*' From 4632d6a5817c7a82de116fbe1088fd32b4edf350 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 16 Jan 2019 18:50:43 +0000 Subject: [PATCH 16/16] Don't complain about uselist, it is not that helpful --- sqlmypy.py | 9 +++++---- test/test-data/sqlalchemy-plugin-features.test | 8 ++------ 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/sqlmypy.py b/sqlmypy.py index 28a993e..e9d2819 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -242,7 +242,7 @@ class User(Base): """ assert isinstance(ctx.default_return_type, Instance) original_type_arg = ctx.default_return_type.args[0] - has_no_annotation = isinstance(original_type_arg, UninhabitedType) + has_annotation = not isinstance(original_type_arg, UninhabitedType) arg = get_argument_by_name(ctx, 'argument') arg_type = get_argtype_by_name(ctx, 'argument') @@ -274,9 +274,10 @@ class User(Base): if parse_bool(uselist_arg): new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg]) else: - if has_no_annotation and not isinstance(new_arg, AnyType): - ctx.api.fail('Cannot figure out kind of relationship', ctx.context) - ctx.api.note('Suggestion: use an explicit "uselist" flag', ctx.context) + if has_annotation: + # If there is an annotation we use it as a source of truth. + # This will cause false negatives, but it is better than lots of false positives. + new_arg = original_type_arg return Instance(ctx.default_return_type.type, [new_arg], line=ctx.default_return_type.line, diff --git a/test/test-data/sqlalchemy-plugin-features.test b/test/test-data/sqlalchemy-plugin-features.test index 7bfa6f0..92cf4ac 100644 --- a/test/test-data/sqlalchemy-plugin-features.test +++ b/test/test-data/sqlalchemy-plugin-features.test @@ -97,11 +97,9 @@ Base = declarative_base() class User(Base): __tablename__ = 'users' id = Column(Integer(), primary_key=True) - first_other = relationship(Other, uselist=False) + first_other = relationship(Other) second_other = relationship(Other, uselist=True) - first_bad_other = relationship(Other) # E: Cannot figure out kind of relationship \ - # N: Suggestion: use an explicit "uselist" flag - second_bad_other: RelationshipProperty[int] = relationship(Other, uselist=False) # E: Incompatible types in assignment (expression has type "RelationshipProperty[Other]", variable has type "RelationshipProperty[int]") + bad_other: RelationshipProperty[int] = relationship(Other, uselist=False) # E: Incompatible types in assignment (expression has type "RelationshipProperty[Other]", variable has type "RelationshipProperty[int]") user = User() reveal_type(user.first_other) # E: Revealed type is 'main.Other*' @@ -125,8 +123,6 @@ class User(Base): id = Column(Integer(), primary_key=True) first_other = relationship('Other', uselist=False) second_other = relationship('Other', uselist=True) - first_bad_other = relationship('Other') # E: Cannot figure out kind of relationship \ - # N: Suggestion: use an explicit "uselist" flag second_bad_other = relationship('What') # E: Cannot find model "What" \ # N: Only imported models can be found; use "if TYPE_CHECKING: ..." to avoid import cycles