Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
script: |
set -e
pytest
- name: "run direct typecheck"
- name: "run typecheck on stubs"
python: 3.6
script: |
set -e
Expand All @@ -29,6 +29,11 @@ jobs:
python: 3.6
script: |
flake8 sqlalchemy-stubs
- name: "run typecheck on plugin"
python: 3.6
script: |
set -e
MYPYPATH=external/mypy python3 -m mypy --disallow-untyped-defs sqlmypy.py


before_install: |
Expand Down
49 changes: 30 additions & 19 deletions sqlmypy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from mypy.mro import calculate_mro, MroError
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
from mypy.plugin import (
Plugin, FunctionContext, ClassDefContext, DynamicClassDefContext,
SemanticAnalyzerPluginInterface
)
from mypy.plugins.common import add_method
from mypy.nodes import (
NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF,
Expand All @@ -10,10 +13,13 @@
)
from mypy.typevars import fill_typevars_with_any

from typing import Optional, Callable, Dict, TYPE_CHECKING, List
from typing import Optional, Callable, Dict, TYPE_CHECKING, List, Type as TypingType, TypeVar
if TYPE_CHECKING:
from typing_extensions import Final

T = TypeVar('T')
CB = Optional[Callable[[T], None]]

COLUMN_NAME = 'sqlalchemy.sql.schema.Column' # type: Final
RELATIONSHIP_NAME = 'sqlalchemy.orm.relationships.RelationshipProperty' # type: Final

Expand Down Expand Up @@ -54,17 +60,17 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext
return model_hook
return None

def get_dynamic_class_hook(self, fullname):
def get_dynamic_class_hook(self, fullname: str) -> CB[DynamicClassDefContext]:
if fullname == 'sqlalchemy.ext.declarative.api.declarative_base':
return decl_info_hook
return None

def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
def get_class_decorator_hook(self, fullname: str) -> CB[ClassDefContext]:
if fullname == 'sqlalchemy.ext.declarative.api.as_declarative':
return decl_deco_hook
return None

def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
def get_base_class_hook(self, fullname: str) -> CB[ClassDefContext]:
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo):
if is_declarative(sym.node):
Expand Down Expand Up @@ -109,9 +115,9 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
add_var_to_class('__table__', typ, ctx.cls.info)


def add_metadata_var(ctx: ClassDefContext, info: TypeInfo) -> None:
def add_metadata_var(api: SemanticAnalyzerPluginInterface, info: TypeInfo) -> None:
"""Add .metadata attribute to a declarative base."""
sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.MetaData')
sym = api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.MetaData')
if sym:
assert isinstance(sym.node, TypeInfo)
typ = Instance(sym.node, []) # type: Type
Expand All @@ -131,10 +137,10 @@ class Base:
...
"""
set_declarative(ctx.cls.info)
add_metadata_var(ctx, ctx.cls.info)
add_metadata_var(ctx.api, ctx.cls.info)


def decl_info_hook(ctx):
def decl_info_hook(ctx: DynamicClassDefContext) -> None:
"""Support dynamically defining declarative bases.

For example:
Expand Down Expand Up @@ -177,7 +183,7 @@ def decl_info_hook(ctx):
set_declarative(info)

# TODO: check what else is added.
add_metadata_var(ctx, info)
add_metadata_var(ctx.api, info)


def model_hook(ctx: FunctionContext) -> Type:
Expand Down Expand Up @@ -211,13 +217,15 @@ def model_hook(ctx: FunctionContext) -> Type:
# TODO: support TypedDict?
continue
if actual_name not in expected_types:
ctx.api.fail('Unexpected column "{}" for model "{}"'.format(actual_name, model.name()),
ctx.api.fail('Unexpected column "{}" for model "{}"'.format(actual_name,
model.name()),
ctx.context)
continue
# Using private API to simplify life.
ctx.api.check_subtype(actual_type, expected_types[actual_name],
ctx.api.check_subtype(actual_type, expected_types[actual_name], # type: ignore
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name, model.name()),
'Incompatible type for "{}" of "{}"'.format(actual_name,
model.name()),
'got', 'expected')
return ctx.default_return_type

Expand Down Expand Up @@ -315,16 +323,19 @@ class User(Base):

if isinstance(arg, StrExpr):
name = arg.value
# Private API for local lookup, but probably needs to be public.
sym = None # type: Optional[SymbolTableNode]
try:
sym = ctx.api.lookup_qualified(name) # type: Optional[SymbolTableNode]
# Private API for local lookup, but probably needs to be public.
sym = ctx.api.lookup_qualified(name) # type: ignore
except (KeyError, AssertionError):
sym = None
pass
if sym and isinstance(sym.node, TypeInfo):
new_arg = fill_typevars_with_any(sym.node)
new_arg = fill_typevars_with_any(sym.node) # type: Type
else:
ctx.api.fail('Cannot find model "{}"'.format(name), ctx.context)
ctx.api.note('Only imported models can be found; use "if TYPE_CHECKING: ..." to avoid import cycles',
# TODO: Add note() to public API.
ctx.api.note('Only imported models can be found;' # type: ignore
' use "if TYPE_CHECKING: ..." to avoid import cycles',
ctx.context)
new_arg = AnyType(TypeOfAny.from_error)
else:
Expand Down Expand Up @@ -359,5 +370,5 @@ def parse_bool(expr: Expression) -> Optional[bool]:
return None


def plugin(version):
def plugin(version: str) -> TypingType[Plugin]:
return BasicSQLAlchemyPlugin