From cc6c69649bbb5ee1092ad650490368b5c8d421f5 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 17 Jan 2019 15:05:40 +0000 Subject: [PATCH 1/7] Small fixes; mostly temporary --- sqlalchemy-stubs/engine/interfaces.pyi | 2 ++ sqlalchemy-stubs/sql/schema.pyi | 13 +++++++++++++ sqlalchemy-stubs/sql/sqltypes.pyi | 2 +- sqlalchemy-stubs/sql/type_api.pyi | 4 ++-- 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sqlalchemy-stubs/engine/interfaces.pyi b/sqlalchemy-stubs/engine/interfaces.pyi index a6f2adf..24c01ed 100644 --- a/sqlalchemy-stubs/engine/interfaces.pyi +++ b/sqlalchemy-stubs/engine/interfaces.pyi @@ -4,6 +4,8 @@ from .result import ResultProxy from ..sql.compiler import Compiled as Compiled, TypeCompiler as TypeCompiler class Dialect(object): + @property + def name(self) -> str: ... def create_connect_args(self, url): ... @classmethod def type_descriptor(cls, typeobj): ... diff --git a/sqlalchemy-stubs/sql/schema.pyi b/sqlalchemy-stubs/sql/schema.pyi index 4d837d3..3564c35 100644 --- a/sqlalchemy-stubs/sql/schema.pyi +++ b/sqlalchemy-stubs/sql/schema.pyi @@ -110,6 +110,19 @@ class Column(SchemaItem, ColumnClause[_T]): nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ..., system: bool = ..., comment: str = ...) -> None: ... + # The two overloads below exist to make annotation more like a cast. This is a temporary measure. + @overload + def __init__(self, name: str, type_: Any, *args: Any, autoincrement: Union[bool, str] = ..., + default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., + nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., + server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ..., + system: bool = ..., comment: str = ...) -> None: ... + @overload + def __init__(self, type_: Any, *args: Any, autoincrement: Union[bool, str] = ..., + default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., + nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., + server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ..., + system: bool = ..., comment: str = ...) -> None: ... def references(self, column: Column[Any]) -> bool: ... def append_foreign_key(self, fk: ForeignKey) -> None: ... def copy(self: _C, **kw: Any) -> _C: ... diff --git a/sqlalchemy-stubs/sql/sqltypes.pyi b/sqlalchemy-stubs/sql/sqltypes.pyi index 2827bc2..6cfd8b7 100644 --- a/sqlalchemy-stubs/sql/sqltypes.pyi +++ b/sqlalchemy-stubs/sql/sqltypes.pyi @@ -22,7 +22,7 @@ class Indexable(object): # Docs say that String is unicode when DBAPI supports it # but it should be all major DBAPIs now. -class String(Concatenable, TypeEngine[typing_Text]): +class String(Concatenable, TypeEngine[str]): # XXX: should be typing_Text __visit_name__: str = ... length: Optional[int] = ... collation: Optional[str] = ... diff --git a/sqlalchemy-stubs/sql/type_api.pyi b/sqlalchemy-stubs/sql/type_api.pyi index a2c0ca6..0e8ce50 100644 --- a/sqlalchemy-stubs/sql/type_api.pyi +++ b/sqlalchemy-stubs/sql/type_api.pyi @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union, TypeVar, Generic, Type, Callable, ClassVar, Tuple, Mapping, overload +from typing import Any, Optional, Union, TypeVar, Generic, Type, Callable, ClassVar, Tuple, Mapping, overload, Text as typing_Text from .. import util from .visitors import Visitable as Visitable, VisitableType as VisitableType from .base import SchemaEventTarget as SchemaEventTarget @@ -91,7 +91,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine[_T]): def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: ... def __getattr__(self, key: str) -> Any: ... def process_literal_param(self, value: Optional[_T], dialect: Dialect) -> Optional[str]: ... - def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Optional[str]: ... + def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Optional[typing_Text]: ... def process_result_value(self, value: Optional[Any], dialect: Dialect) -> Optional[_T]: ... def literal_processor(self, dialect: Dialect) -> Callable[[Optional[_T]], Optional[str]]: ... def bind_processor(self, dialect: Dialect) -> Callable[[Optional[_T]], Optional[str]]: ... From e560a501bd59c3ef80e9de6999bd9f8d2a107b64 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 17 Jan 2019 20:29:10 +0000 Subject: [PATCH 2/7] A bunch of minor fixes --- sqlalchemy-stubs/sql/schema.pyi | 13 +++++++------ sqlmypy.py | 28 +++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/sqlalchemy-stubs/sql/schema.pyi b/sqlalchemy-stubs/sql/schema.pyi index 3564c35..b2cc49a 100644 --- a/sqlalchemy-stubs/sql/schema.pyi +++ b/sqlalchemy-stubs/sql/schema.pyi @@ -11,6 +11,7 @@ from .. import util from ..engine import Engine, Connection, Connectable from ..engine.url import URL from .compiler import DDLCompiler +from .expression import FunctionElement import threading _T = TypeVar('_T') @@ -90,38 +91,38 @@ class Column(SchemaItem, ColumnClause[_T]): def __init__(self, name: str, type_: Type[TypeEngine[_T]], *args: Any, autoincrement: Union[bool, str] = ..., default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., - server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ..., + server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ..., system: bool = ..., comment: str = ...) -> None: ... @overload def __init__(self, type_: Type[TypeEngine[_T]], *args: Any, autoincrement: Union[bool, str] = ..., default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., - server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ..., + server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ..., system: bool = ..., comment: str = ...) -> None: ... @overload def __init__(self, name: str, type_: TypeEngine[_T], *args: Any, autoincrement: Union[bool, str] = ..., default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., - server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ..., + server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ..., system: bool = ..., comment: str = ...) -> None: ... @overload def __init__(self, type_: TypeEngine[_T], *args: Any, autoincrement: Union[bool, str] = ..., default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., - server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ..., + server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ..., system: bool = ..., comment: str = ...) -> None: ... # The two overloads below exist to make annotation more like a cast. This is a temporary measure. @overload def __init__(self, name: str, type_: Any, *args: Any, autoincrement: Union[bool, str] = ..., default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., - server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ..., + server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ..., system: bool = ..., comment: str = ...) -> None: ... @overload def __init__(self, type_: Any, *args: Any, autoincrement: Union[bool, str] = ..., default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., - server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ..., + server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ..., system: bool = ..., comment: str = ...) -> None: ... def references(self, column: Column[Any]) -> bool: ... def append_foreign_key(self, fk: ForeignKey) -> None: ... diff --git a/sqlmypy.py b/sqlmypy.py index e9d2819..b07cbc7 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -2,7 +2,7 @@ from mypy.plugins.common import add_method from mypy.nodes import( NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF, - Argument, Var, ARG_STAR2 + Argument, Var, ARG_STAR2, MDEF ) from mypy.types import ( UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType @@ -86,6 +86,19 @@ def add_init_hook(ctx: ClassDefContext) -> None: add_method(ctx, '__init__', [kw_arg], NoneTyp()) ctx.cls.info.metadata.setdefault('sqlalchemy', {})['generated_init'] = True + # Also add a selection of auto-generated attributes. + table = Var('__table__') + table.info = ctx.cls.info + table._fullname = ctx.cls.fullname + '.__table__' + sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.Table') + if sym: + assert isinstance(sym.node, TypeInfo) + tp = Instance(sym.node, []) # type: Type + else: + tp = AnyType(TypeOfAny.special_form) + table.type = tp + ctx.cls.info.names['__table__'] = SymbolTableNode(MDEF, table) + def decl_deco_hook(ctx: ClassDefContext) -> None: """Support declaring base class as declarative with a decorator. @@ -119,6 +132,19 @@ def decl_info_hook(ctx): ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) set_declarative(info) + # Also add a selection of generated attributes. + meta = Var('metadata') + meta.info = info + meta._fullname = class_def.fullname + '.metadata' + sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.MetaData') + if sym: + assert isinstance(sym.node, TypeInfo) + tp = Instance(sym.node, []) # type: Type + else: + tp = AnyType(TypeOfAny.special_form) + meta.type = tp + ctx.cls.info.names['metadata'] = SymbolTableNode(MDEF, meta) + def model_hook(ctx: FunctionContext) -> Type: """More precise model instantiation check. From f134e1375feedf0576eed3ff8afe738429fa4d56 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 17 Jan 2019 20:31:01 +0000 Subject: [PATCH 3/7] Fix bug --- sqlmypy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmypy.py b/sqlmypy.py index b07cbc7..b76abaf 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -143,7 +143,7 @@ def decl_info_hook(ctx): else: tp = AnyType(TypeOfAny.special_form) meta.type = tp - ctx.cls.info.names['metadata'] = SymbolTableNode(MDEF, meta) + info.names['metadata'] = SymbolTableNode(MDEF, meta) def model_hook(ctx: FunctionContext) -> Type: From 4b2832244561526456f20f0dac9aa44aec2664ef Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 18 Jan 2019 12:52:06 +0000 Subject: [PATCH 4/7] Fix kwargs in models --- sqlmypy.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sqlmypy.py b/sqlmypy.py index b76abaf..52000b5 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -172,6 +172,10 @@ def model_hook(ctx: FunctionContext) -> Type: assert len(ctx.arg_names) == 1 # only **kwargs in generated __init__ assert len(ctx.arg_types) == 1 for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]): + if actual_name is None: + # We can't check kwargs reliably. + # TODO: support TypedDict? + continue if actual_name not in expected_types: ctx.api.fail('Unexpected column "{}" for model "{}"'.format(actual_name, model.name()), ctx.context) From a9d5c8e3b9369bf518a1cd22edf6b70826477427 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 18 Jan 2019 14:21:23 +0000 Subject: [PATCH 5/7] Refactor attribute addition --- sqlmypy.py | 52 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/sqlmypy.py b/sqlmypy.py index 52000b5..8dbae68 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -67,11 +67,23 @@ def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefConte sym = self.lookup_fully_qualified(fullname) if sym and isinstance(sym.node, TypeInfo): if is_declarative(sym.node): - return add_init_hook + return add_model_init_hook return None -def add_init_hook(ctx: ClassDefContext) -> None: +def add_var_to_class(name: str, typ: Type, info: TypeInfo) -> None: + """Add a variable with given name and type to the symbol table of a class. + + This also takes care about setting necessary attributes on the variable node. + """ + var = Var(name) + var.info = info + var._fullname = info.fullname() + '.' + name + var.type = typ + info.names[name] = SymbolTableNode(MDEF, var) + + +def add_model_init_hook(ctx: ClassDefContext) -> None: """Add a dummy __init__() to a model and record it is generated. Instantiation will be checked more precisely when we inferred types @@ -87,17 +99,24 @@ def add_init_hook(ctx: ClassDefContext) -> None: ctx.cls.info.metadata.setdefault('sqlalchemy', {})['generated_init'] = True # Also add a selection of auto-generated attributes. - table = Var('__table__') - table.info = ctx.cls.info - table._fullname = ctx.cls.fullname + '.__table__' sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.Table') if sym: assert isinstance(sym.node, TypeInfo) - tp = Instance(sym.node, []) # type: Type + typ = Instance(sym.node, []) # type: Type + else: + typ = AnyType(TypeOfAny.special_form) + add_var_to_class('__table__', typ, ctx.cls.info) + + +def add_metadata_var(ctx: ClassDefContext, info: TypeInfo) -> None: + """Add .metadata attribute to a declarative base.""" + sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.MetaData') + if sym: + assert isinstance(sym.node, TypeInfo) + typ = Instance(sym.node, []) # type: Type else: - tp = AnyType(TypeOfAny.special_form) - table.type = tp - ctx.cls.info.names['__table__'] = SymbolTableNode(MDEF, table) + typ = AnyType(TypeOfAny.special_form) + add_var_to_class('metadata', typ, info) def decl_deco_hook(ctx: ClassDefContext) -> None: @@ -111,6 +130,7 @@ class Base: ... """ set_declarative(ctx.cls.info) + add_metadata_var(ctx, ctx.cls.info) def decl_info_hook(ctx): @@ -132,18 +152,8 @@ def decl_info_hook(ctx): ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) set_declarative(info) - # Also add a selection of generated attributes. - meta = Var('metadata') - meta.info = info - meta._fullname = class_def.fullname + '.metadata' - sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.MetaData') - if sym: - assert isinstance(sym.node, TypeInfo) - tp = Instance(sym.node, []) # type: Type - else: - tp = AnyType(TypeOfAny.special_form) - meta.type = tp - info.names['metadata'] = SymbolTableNode(MDEF, meta) + # TODO: check what else is added. + add_metadata_var(ctx, info) def model_hook(ctx: FunctionContext) -> Type: From 6edb179c2bf049d56cc7bf89bffc427bc656937e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 18 Jan 2019 14:30:51 +0000 Subject: [PATCH 6/7] Add tests --- sqlalchemy-stubs/sql/sqltypes.pyi | 2 +- test/test-data/sqlalchemy-basics.test | 4 +-- .../test-data/sqlalchemy-plugin-features.test | 34 +++++++++++++++++++ 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/sqlalchemy-stubs/sql/sqltypes.pyi b/sqlalchemy-stubs/sql/sqltypes.pyi index 6cfd8b7..cb0e61a 100644 --- a/sqlalchemy-stubs/sql/sqltypes.pyi +++ b/sqlalchemy-stubs/sql/sqltypes.pyi @@ -39,7 +39,7 @@ class String(Concatenable, TypeEngine[str]): # XXX: should be typing_Text def bind_processor(self, dialect: Dialect) -> Optional[Callable[[str], str]]: ... def result_processor(self, dialect: Dialect, coltype: Any) -> Optional[Callable[[Optional[Any]], Optional[str]]]: ... @property - def python_type(self) -> Type[typing_Text]: ... + def python_type(self) -> Type[str]: ... def get_dbapi_type(self, dbapi: Any) -> Any: ... class Text(String): diff --git a/test/test-data/sqlalchemy-basics.test b/test/test-data/sqlalchemy-basics.test index 9fecf29..141851d 100644 --- a/test/test-data/sqlalchemy-basics.test +++ b/test/test-data/sqlalchemy-basics.test @@ -102,7 +102,7 @@ class User(Base): user = User() reveal_type(user.id) # E: Revealed type is 'builtins.int*' -reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.unicode*, None]]' +reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.str*, None]]' [out] [case testColumnFieldsInferredInstance_python2] @@ -118,5 +118,5 @@ class User(Base): user = User() reveal_type(user.id) # E: Revealed type is 'builtins.int*' -reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.unicode*]' +reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.str*]' [out] diff --git a/test/test-data/sqlalchemy-plugin-features.test b/test/test-data/sqlalchemy-plugin-features.test index 92cf4ac..246b742 100644 --- a/test/test-data/sqlalchemy-plugin-features.test +++ b/test/test-data/sqlalchemy-plugin-features.test @@ -170,3 +170,37 @@ class User(Base): user: User reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.str*, None]]' [out] + +[case testAddedAttributesDeclared] +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String + +Base = declarative_base() + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + name = Column(String(), default='John Doe', nullable=True) + +user: User +reveal_type(User.metadata) # E: Revealed type is 'sqlalchemy.sql.schema.MetaData' +reveal_type(User.__table__) # E: Revealed type is 'sqlalchemy.sql.schema.Table' +[out] + +[case testAddedAttributedDecorated] +from sqlalchemy.ext.declarative import as_declarative +from sqlalchemy import Column, Integer, String + +@as_declarative() +class Base: + ... + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + name = Column(String(), default='John Doe', nullable=True) + +user: User +reveal_type(User.metadata) # E: Revealed type is 'sqlalchemy.sql.schema.MetaData' +reveal_type(User.__table__) # E: Revealed type is 'sqlalchemy.sql.schema.Table' +[out] From 5eda02f4874eaf6305c851ac9e801d1f62b9494a Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 18 Jan 2019 15:17:57 +0000 Subject: [PATCH 7/7] One more test --- test/test-data/sqlalchemy-plugin-features.test | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/test-data/sqlalchemy-plugin-features.test b/test/test-data/sqlalchemy-plugin-features.test index 246b742..8b8f530 100644 --- a/test/test-data/sqlalchemy-plugin-features.test +++ b/test/test-data/sqlalchemy-plugin-features.test @@ -204,3 +204,18 @@ user: User reveal_type(User.metadata) # E: Revealed type is 'sqlalchemy.sql.schema.MetaData' reveal_type(User.__table__) # E: Revealed type is 'sqlalchemy.sql.schema.Table' [out] + +[case testKwArgsModelOK] +from sqlalchemy import Column, Integer, String +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +class User(Base): + __tablename__ = 'users' + id = Column(Integer, primary_key=True) + name = Column(String) + +record = {'name': 'John Doe'} +User(**record) # OK +[out]