Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sqlalchemy-stubs/engine/interfaces.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ from .result import ResultProxy
from ..sql.compiler import Compiled as Compiled, TypeCompiler as TypeCompiler

class Dialect(object):
@property
def name(self) -> str: ...
def create_connect_args(self, url): ...
@classmethod
def type_descriptor(cls, typeobj): ...
Expand Down
22 changes: 18 additions & 4 deletions sqlalchemy-stubs/sql/schema.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ from .. import util
from ..engine import Engine, Connection, Connectable
from ..engine.url import URL
from .compiler import DDLCompiler
from .expression import FunctionElement
import threading

_T = TypeVar('_T')
Expand Down Expand Up @@ -90,25 +91,38 @@ class Column(SchemaItem, ColumnClause[_T]):
def __init__(self, name: str, type_: Type[TypeEngine[_T]], *args: Any, autoincrement: Union[bool, str] = ...,
default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ...,
nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ...,
server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ...,
server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ...,
system: bool = ..., comment: str = ...) -> None: ...
@overload
def __init__(self, type_: Type[TypeEngine[_T]], *args: Any, autoincrement: Union[bool, str] = ...,
default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ...,
nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ...,
server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ...,
server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ...,
system: bool = ..., comment: str = ...) -> None: ...
@overload
def __init__(self, name: str, type_: TypeEngine[_T], *args: Any, autoincrement: Union[bool, str] = ...,
default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ...,
nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ...,
server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ...,
server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ...,
system: bool = ..., comment: str = ...) -> None: ...
@overload
def __init__(self, type_: TypeEngine[_T], *args: Any, autoincrement: Union[bool, str] = ...,
default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ...,
nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ...,
server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ...,
server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ...,
system: bool = ..., comment: str = ...) -> None: ...
# The two overloads below exist to make annotation more like a cast. This is a temporary measure.
@overload
def __init__(self, name: str, type_: Any, *args: Any, autoincrement: Union[bool, str] = ...,
default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ...,
nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ...,
server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ...,
system: bool = ..., comment: str = ...) -> None: ...
@overload
def __init__(self, type_: Any, *args: Any, autoincrement: Union[bool, str] = ...,
default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ...,
nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ...,
server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ...,
system: bool = ..., comment: str = ...) -> None: ...
def references(self, column: Column[Any]) -> bool: ...
def append_foreign_key(self, fk: ForeignKey) -> None: ...
Expand Down
4 changes: 2 additions & 2 deletions sqlalchemy-stubs/sql/sqltypes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Indexable(object):

# Docs say that String is unicode when DBAPI supports it
# but it should be all major DBAPIs now.
class String(Concatenable, TypeEngine[typing_Text]):
class String(Concatenable, TypeEngine[str]): # XXX: should be typing_Text
__visit_name__: str = ...
length: Optional[int] = ...
collation: Optional[str] = ...
Expand All @@ -39,7 +39,7 @@ class String(Concatenable, TypeEngine[typing_Text]):
def bind_processor(self, dialect: Dialect) -> Optional[Callable[[str], str]]: ...
def result_processor(self, dialect: Dialect, coltype: Any) -> Optional[Callable[[Optional[Any]], Optional[str]]]: ...
@property
def python_type(self) -> Type[typing_Text]: ...
def python_type(self) -> Type[str]: ...
def get_dbapi_type(self, dbapi: Any) -> Any: ...

class Text(String):
Expand Down
4 changes: 2 additions & 2 deletions sqlalchemy-stubs/sql/type_api.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Union, TypeVar, Generic, Type, Callable, ClassVar, Tuple, Mapping, overload
from typing import Any, Optional, Union, TypeVar, Generic, Type, Callable, ClassVar, Tuple, Mapping, overload, Text as typing_Text
from .. import util
from .visitors import Visitable as Visitable, VisitableType as VisitableType
from .base import SchemaEventTarget as SchemaEventTarget
Expand Down Expand Up @@ -91,7 +91,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine[_T]):
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: ...
def __getattr__(self, key: str) -> Any: ...
def process_literal_param(self, value: Optional[_T], dialect: Dialect) -> Optional[str]: ...
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Optional[str]: ...
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Optional[typing_Text]: ...
def process_result_value(self, value: Optional[Any], dialect: Dialect) -> Optional[_T]: ...
def literal_processor(self, dialect: Dialect) -> Callable[[Optional[_T]], Optional[str]]: ...
def bind_processor(self, dialect: Dialect) -> Callable[[Optional[_T]], Optional[str]]: ...
Expand Down
46 changes: 43 additions & 3 deletions sqlmypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from mypy.plugins.common import add_method
from mypy.nodes import(
NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF,
Argument, Var, ARG_STAR2
Argument, Var, ARG_STAR2, MDEF
)
from mypy.types import (
UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType
Expand Down Expand Up @@ -67,11 +67,23 @@ def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefConte
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo):
if is_declarative(sym.node):
return add_init_hook
return add_model_init_hook
return None


def add_init_hook(ctx: ClassDefContext) -> None:
def add_var_to_class(name: str, typ: Type, info: TypeInfo) -> None:
"""Add a variable with given name and type to the symbol table of a class.

This also takes care about setting necessary attributes on the variable node.
"""
var = Var(name)
var.info = info
var._fullname = info.fullname() + '.' + name
var.type = typ
info.names[name] = SymbolTableNode(MDEF, var)


def add_model_init_hook(ctx: ClassDefContext) -> None:
"""Add a dummy __init__() to a model and record it is generated.

Instantiation will be checked more precisely when we inferred types
Expand All @@ -86,6 +98,26 @@ def add_init_hook(ctx: ClassDefContext) -> None:
add_method(ctx, '__init__', [kw_arg], NoneTyp())
ctx.cls.info.metadata.setdefault('sqlalchemy', {})['generated_init'] = True

# Also add a selection of auto-generated attributes.
sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.Table')
if sym:
assert isinstance(sym.node, TypeInfo)
typ = Instance(sym.node, []) # type: Type
else:
typ = AnyType(TypeOfAny.special_form)
add_var_to_class('__table__', typ, ctx.cls.info)


def add_metadata_var(ctx: ClassDefContext, info: TypeInfo) -> None:
"""Add .metadata attribute to a declarative base."""
sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.MetaData')
if sym:
assert isinstance(sym.node, TypeInfo)
typ = Instance(sym.node, []) # type: Type
else:
typ = AnyType(TypeOfAny.special_form)
add_var_to_class('metadata', typ, info)


def decl_deco_hook(ctx: ClassDefContext) -> None:
"""Support declaring base class as declarative with a decorator.
Expand All @@ -98,6 +130,7 @@ class Base:
...
"""
set_declarative(ctx.cls.info)
add_metadata_var(ctx, ctx.cls.info)


def decl_info_hook(ctx):
Expand All @@ -119,6 +152,9 @@ def decl_info_hook(ctx):
ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
set_declarative(info)

# TODO: check what else is added.
add_metadata_var(ctx, info)


def model_hook(ctx: FunctionContext) -> Type:
"""More precise model instantiation check.
Expand Down Expand Up @@ -146,6 +182,10 @@ def model_hook(ctx: FunctionContext) -> Type:
assert len(ctx.arg_names) == 1 # only **kwargs in generated __init__
assert len(ctx.arg_types) == 1
for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]):
if actual_name is None:
# We can't check kwargs reliably.
# TODO: support TypedDict?
continue
if actual_name not in expected_types:
ctx.api.fail('Unexpected column "{}" for model "{}"'.format(actual_name, model.name()),
ctx.context)
Expand Down
4 changes: 2 additions & 2 deletions test/test-data/sqlalchemy-basics.test
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class User(Base):

user = User()
reveal_type(user.id) # E: Revealed type is 'builtins.int*'
reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.unicode*, None]]'
reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.str*, None]]'
[out]

[case testColumnFieldsInferredInstance_python2]
Expand All @@ -118,5 +118,5 @@ class User(Base):

user = User()
reveal_type(user.id) # E: Revealed type is 'builtins.int*'
reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.unicode*]'
reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.str*]'
[out]
49 changes: 49 additions & 0 deletions test/test-data/sqlalchemy-plugin-features.test
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,52 @@ class User(Base):
user: User
reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.str*, None]]'
[out]

[case testAddedAttributesDeclared]
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String

Base = declarative_base()

class User(Base):
__tablename__ = 'users'
id = Column(Integer(), primary_key=True)
name = Column(String(), default='John Doe', nullable=True)

user: User
reveal_type(User.metadata) # E: Revealed type is 'sqlalchemy.sql.schema.MetaData'
reveal_type(User.__table__) # E: Revealed type is 'sqlalchemy.sql.schema.Table'
[out]

[case testAddedAttributedDecorated]
from sqlalchemy.ext.declarative import as_declarative
from sqlalchemy import Column, Integer, String

@as_declarative()
class Base:
...

class User(Base):
__tablename__ = 'users'
id = Column(Integer(), primary_key=True)
name = Column(String(), default='John Doe', nullable=True)

user: User
reveal_type(User.metadata) # E: Revealed type is 'sqlalchemy.sql.schema.MetaData'
reveal_type(User.__table__) # E: Revealed type is 'sqlalchemy.sql.schema.Table'
[out]

[case testKwArgsModelOK]
from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class User(Base):
__tablename__ = 'users'
id = Column(Integer, primary_key=True)
name = Column(String)

record = {'name': 'John Doe'}
User(**record) # OK
[out]