diff --git a/sqlalchemy-stubs/engine/interfaces.pyi b/sqlalchemy-stubs/engine/interfaces.pyi index a6f2adf..24c01ed 100644 --- a/sqlalchemy-stubs/engine/interfaces.pyi +++ b/sqlalchemy-stubs/engine/interfaces.pyi @@ -4,6 +4,8 @@ from .result import ResultProxy from ..sql.compiler import Compiled as Compiled, TypeCompiler as TypeCompiler class Dialect(object): + @property + def name(self) -> str: ... def create_connect_args(self, url): ... @classmethod def type_descriptor(cls, typeobj): ... diff --git a/sqlalchemy-stubs/sql/schema.pyi b/sqlalchemy-stubs/sql/schema.pyi index 4d837d3..b2cc49a 100644 --- a/sqlalchemy-stubs/sql/schema.pyi +++ b/sqlalchemy-stubs/sql/schema.pyi @@ -11,6 +11,7 @@ from .. import util from ..engine import Engine, Connection, Connectable from ..engine.url import URL from .compiler import DDLCompiler +from .expression import FunctionElement import threading _T = TypeVar('_T') @@ -90,25 +91,38 @@ class Column(SchemaItem, ColumnClause[_T]): def __init__(self, name: str, type_: Type[TypeEngine[_T]], *args: Any, autoincrement: Union[bool, str] = ..., default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., - server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ..., + server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ..., system: bool = ..., comment: str = ...) -> None: ... @overload def __init__(self, type_: Type[TypeEngine[_T]], *args: Any, autoincrement: Union[bool, str] = ..., default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., - server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ..., + server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ..., system: bool = ..., comment: str = ...) -> None: ... @overload def __init__(self, name: str, type_: TypeEngine[_T], *args: Any, autoincrement: Union[bool, str] = ..., default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., - server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ..., + server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ..., system: bool = ..., comment: str = ...) -> None: ... @overload def __init__(self, type_: TypeEngine[_T], *args: Any, autoincrement: Union[bool, str] = ..., default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., - server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ..., + server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ..., + system: bool = ..., comment: str = ...) -> None: ... + # The two overloads below exist to make annotation more like a cast. This is a temporary measure. + @overload + def __init__(self, name: str, type_: Any, *args: Any, autoincrement: Union[bool, str] = ..., + default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., + nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., + server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ..., + system: bool = ..., comment: str = ...) -> None: ... + @overload + def __init__(self, type_: Any, *args: Any, autoincrement: Union[bool, str] = ..., + default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ..., + nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ..., + server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ..., system: bool = ..., comment: str = ...) -> None: ... def references(self, column: Column[Any]) -> bool: ... def append_foreign_key(self, fk: ForeignKey) -> None: ... diff --git a/sqlalchemy-stubs/sql/sqltypes.pyi b/sqlalchemy-stubs/sql/sqltypes.pyi index 2827bc2..cb0e61a 100644 --- a/sqlalchemy-stubs/sql/sqltypes.pyi +++ b/sqlalchemy-stubs/sql/sqltypes.pyi @@ -22,7 +22,7 @@ class Indexable(object): # Docs say that String is unicode when DBAPI supports it # but it should be all major DBAPIs now. -class String(Concatenable, TypeEngine[typing_Text]): +class String(Concatenable, TypeEngine[str]): # XXX: should be typing_Text __visit_name__: str = ... length: Optional[int] = ... collation: Optional[str] = ... @@ -39,7 +39,7 @@ class String(Concatenable, TypeEngine[typing_Text]): def bind_processor(self, dialect: Dialect) -> Optional[Callable[[str], str]]: ... def result_processor(self, dialect: Dialect, coltype: Any) -> Optional[Callable[[Optional[Any]], Optional[str]]]: ... @property - def python_type(self) -> Type[typing_Text]: ... + def python_type(self) -> Type[str]: ... def get_dbapi_type(self, dbapi: Any) -> Any: ... class Text(String): diff --git a/sqlalchemy-stubs/sql/type_api.pyi b/sqlalchemy-stubs/sql/type_api.pyi index a2c0ca6..0e8ce50 100644 --- a/sqlalchemy-stubs/sql/type_api.pyi +++ b/sqlalchemy-stubs/sql/type_api.pyi @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union, TypeVar, Generic, Type, Callable, ClassVar, Tuple, Mapping, overload +from typing import Any, Optional, Union, TypeVar, Generic, Type, Callable, ClassVar, Tuple, Mapping, overload, Text as typing_Text from .. import util from .visitors import Visitable as Visitable, VisitableType as VisitableType from .base import SchemaEventTarget as SchemaEventTarget @@ -91,7 +91,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine[_T]): def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: ... def __getattr__(self, key: str) -> Any: ... def process_literal_param(self, value: Optional[_T], dialect: Dialect) -> Optional[str]: ... - def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Optional[str]: ... + def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Optional[typing_Text]: ... def process_result_value(self, value: Optional[Any], dialect: Dialect) -> Optional[_T]: ... def literal_processor(self, dialect: Dialect) -> Callable[[Optional[_T]], Optional[str]]: ... def bind_processor(self, dialect: Dialect) -> Callable[[Optional[_T]], Optional[str]]: ... diff --git a/sqlmypy.py b/sqlmypy.py index e9d2819..8dbae68 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -2,7 +2,7 @@ from mypy.plugins.common import add_method from mypy.nodes import( NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF, - Argument, Var, ARG_STAR2 + Argument, Var, ARG_STAR2, MDEF ) from mypy.types import ( UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType @@ -67,11 +67,23 @@ def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefConte sym = self.lookup_fully_qualified(fullname) if sym and isinstance(sym.node, TypeInfo): if is_declarative(sym.node): - return add_init_hook + return add_model_init_hook return None -def add_init_hook(ctx: ClassDefContext) -> None: +def add_var_to_class(name: str, typ: Type, info: TypeInfo) -> None: + """Add a variable with given name and type to the symbol table of a class. + + This also takes care about setting necessary attributes on the variable node. + """ + var = Var(name) + var.info = info + var._fullname = info.fullname() + '.' + name + var.type = typ + info.names[name] = SymbolTableNode(MDEF, var) + + +def add_model_init_hook(ctx: ClassDefContext) -> None: """Add a dummy __init__() to a model and record it is generated. Instantiation will be checked more precisely when we inferred types @@ -86,6 +98,26 @@ def add_init_hook(ctx: ClassDefContext) -> None: add_method(ctx, '__init__', [kw_arg], NoneTyp()) ctx.cls.info.metadata.setdefault('sqlalchemy', {})['generated_init'] = True + # Also add a selection of auto-generated attributes. + sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.Table') + if sym: + assert isinstance(sym.node, TypeInfo) + typ = Instance(sym.node, []) # type: Type + else: + typ = AnyType(TypeOfAny.special_form) + add_var_to_class('__table__', typ, ctx.cls.info) + + +def add_metadata_var(ctx: ClassDefContext, info: TypeInfo) -> None: + """Add .metadata attribute to a declarative base.""" + sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.MetaData') + if sym: + assert isinstance(sym.node, TypeInfo) + typ = Instance(sym.node, []) # type: Type + else: + typ = AnyType(TypeOfAny.special_form) + add_var_to_class('metadata', typ, info) + def decl_deco_hook(ctx: ClassDefContext) -> None: """Support declaring base class as declarative with a decorator. @@ -98,6 +130,7 @@ class Base: ... """ set_declarative(ctx.cls.info) + add_metadata_var(ctx, ctx.cls.info) def decl_info_hook(ctx): @@ -119,6 +152,9 @@ def decl_info_hook(ctx): ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) set_declarative(info) + # TODO: check what else is added. + add_metadata_var(ctx, info) + def model_hook(ctx: FunctionContext) -> Type: """More precise model instantiation check. @@ -146,6 +182,10 @@ def model_hook(ctx: FunctionContext) -> Type: assert len(ctx.arg_names) == 1 # only **kwargs in generated __init__ assert len(ctx.arg_types) == 1 for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]): + if actual_name is None: + # We can't check kwargs reliably. + # TODO: support TypedDict? + continue if actual_name not in expected_types: ctx.api.fail('Unexpected column "{}" for model "{}"'.format(actual_name, model.name()), ctx.context) diff --git a/test/test-data/sqlalchemy-basics.test b/test/test-data/sqlalchemy-basics.test index 9fecf29..141851d 100644 --- a/test/test-data/sqlalchemy-basics.test +++ b/test/test-data/sqlalchemy-basics.test @@ -102,7 +102,7 @@ class User(Base): user = User() reveal_type(user.id) # E: Revealed type is 'builtins.int*' -reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.unicode*, None]]' +reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.str*, None]]' [out] [case testColumnFieldsInferredInstance_python2] @@ -118,5 +118,5 @@ class User(Base): user = User() reveal_type(user.id) # E: Revealed type is 'builtins.int*' -reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.unicode*]' +reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.str*]' [out] diff --git a/test/test-data/sqlalchemy-plugin-features.test b/test/test-data/sqlalchemy-plugin-features.test index 92cf4ac..8b8f530 100644 --- a/test/test-data/sqlalchemy-plugin-features.test +++ b/test/test-data/sqlalchemy-plugin-features.test @@ -170,3 +170,52 @@ class User(Base): user: User reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.str*, None]]' [out] + +[case testAddedAttributesDeclared] +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String + +Base = declarative_base() + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + name = Column(String(), default='John Doe', nullable=True) + +user: User +reveal_type(User.metadata) # E: Revealed type is 'sqlalchemy.sql.schema.MetaData' +reveal_type(User.__table__) # E: Revealed type is 'sqlalchemy.sql.schema.Table' +[out] + +[case testAddedAttributedDecorated] +from sqlalchemy.ext.declarative import as_declarative +from sqlalchemy import Column, Integer, String + +@as_declarative() +class Base: + ... + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + name = Column(String(), default='John Doe', nullable=True) + +user: User +reveal_type(User.metadata) # E: Revealed type is 'sqlalchemy.sql.schema.MetaData' +reveal_type(User.__table__) # E: Revealed type is 'sqlalchemy.sql.schema.Table' +[out] + +[case testKwArgsModelOK] +from sqlalchemy import Column, Integer, String +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +class User(Base): + __tablename__ = 'users' + id = Column(Integer, primary_key=True) + name = Column(String) + +record = {'name': 'John Doe'} +User(**record) # OK +[out]