diff --git a/external/mypy b/external/mypy index 01c2686..f82319f 160000 --- a/external/mypy +++ b/external/mypy @@ -1 +1 @@ -Subproject commit 01c268644d1d22506442df4e21b39c04710b7e8b +Subproject commit f82319f85043ac2afa9708cc90a7fec32a3aa767 diff --git a/setup.py b/setup.py index 81fc445..ae95c7f 100644 --- a/setup.py +++ b/setup.py @@ -22,10 +22,10 @@ def find_stub_files(): author='Ivan Levkivskyi', author_email='levkivskyi@gmail.com', license='MIT License', - py_modules=[], + py_modules=['sqlmypy', 'sqltyping'], install_requires=[ 'typing-extensions>=3.6.5' ], packages=['sqlalchemy-stubs'], - package_data={'sqlalchemy-stubs': find_stub_files()} + package_data={'sqlalchemy-stubs': find_stub_files()}, ) diff --git a/sqlalchemy-plugin/sqlmypy.py b/sqlalchemy-plugin/sqlmypy.py deleted file mode 100644 index e69de29..0000000 diff --git a/sqlalchemy-stubs/orm/__init__.pyi b/sqlalchemy-stubs/orm/__init__.pyi index 57665a4..ff5c7ad 100644 --- a/sqlalchemy-stubs/orm/__init__.pyi +++ b/sqlalchemy-stubs/orm/__init__.pyi @@ -43,7 +43,7 @@ from .strategy_options import Load as Load def create_session(bind: Optional[Any] = ..., **kwargs): ... -relationship = RelationshipProperty[Any] +relationship = RelationshipProperty def relation(*arg, **kw): ... def dynamic_loader(argument, **kw): ... diff --git a/sqlalchemy-stubs/orm/relationships.pyi b/sqlalchemy-stubs/orm/relationships.pyi index dcc7798..ecb6b0b 100644 --- a/sqlalchemy-stubs/orm/relationships.pyi +++ b/sqlalchemy-stubs/orm/relationships.pyi @@ -1,4 +1,4 @@ -from typing import Any, Optional, Generic, TypeVar, Union, overload +from typing import Any, Optional, Generic, TypeVar, Union, overload, Type from .interfaces import ( MANYTOMANY as MANYTOMANY, MANYTOONE as MANYTOONE, @@ -11,12 +11,11 @@ def remote(expr): ... def foreign(expr): ... -_T = TypeVar('_T') _T_co = TypeVar('_T_co', covariant=True) # Note: typical use case is where argument is a string, so this will require -# a plugin to infer '_T', otherwise a user will need to write an explicit annotation. +# a plugin to infer '_T_co', otherwise a user will need to write an explicit annotation. # It is not clear whether RelationshipProperty is covariant at this stage since # many types are still missing. class RelationshipProperty(StrategizedProperty, Generic[_T_co]): @@ -55,7 +54,7 @@ class RelationshipProperty(StrategizedProperty, Generic[_T_co]): order_by: Any = ... back_populates: Any = ... backref: Any = ... - def __init__(self, argument, secondary: Optional[Any] = ..., + def __init__(self, argument: Any, secondary: Optional[Any] = ..., primaryjoin: Optional[Any] = ..., secondaryjoin: Optional[Any] = ..., foreign_keys: Optional[Any] = ..., uselist: Optional[Any] = ..., order_by: Any = ..., backref: Optional[Any] = ..., diff --git a/sqlmypy.py b/sqlmypy.py new file mode 100644 index 0000000..e9d2819 --- /dev/null +++ b/sqlmypy.py @@ -0,0 +1,298 @@ +from mypy.plugin import Plugin, FunctionContext, ClassDefContext +from mypy.plugins.common import add_method +from mypy.nodes import( + NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF, + Argument, Var, ARG_STAR2 +) +from mypy.types import ( + UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType +) +from mypy.typevars import fill_typevars_with_any + +from typing import Optional, Callable, Dict, TYPE_CHECKING +if TYPE_CHECKING: + from typing_extensions import Final + +COLUMN_NAME = 'sqlalchemy.sql.schema.Column' # type: Final +RELATIONSHIP_NAME = 'sqlalchemy.orm.relationships.RelationshipProperty' # type: Final + + +def is_declarative(info: TypeInfo) -> bool: + """Check if this is a subclass of a declarative base.""" + if info.mro: + for base in info.mro: + metadata = base.metadata.get('sqlalchemy') + if metadata and metadata.get('declarative_base'): + return True + return False + + +def set_declarative(info: TypeInfo) -> None: + """Record given class as a declarative base.""" + info.metadata.setdefault('sqlalchemy', {})['declarative_base'] = True + + +class BasicSQLAlchemyPlugin(Plugin): + """Basic plugin to support simple operations with models. + + Currently supported functionality: + * Recognize dynamically defined declarative bases. + * Add an __init__() method to models. + * Provide better types for 'Column's and 'RelationshipProperty's + using flags 'primary_key', 'nullable', 'uselist', etc. + """ + def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]: + if fullname == COLUMN_NAME: + return column_hook + if fullname == RELATIONSHIP_NAME: + return relationship_hook + sym = self.lookup_fully_qualified(fullname) + if sym and isinstance(sym.node, TypeInfo): + # May be a model instantiation + if is_declarative(sym.node): + return model_hook + return None + + def get_dynamic_class_hook(self, fullname): + if fullname == 'sqlalchemy.ext.declarative.api.declarative_base': + return decl_info_hook + return None + + def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: + if fullname == 'sqlalchemy.ext.declarative.api.as_declarative': + return decl_deco_hook + return None + + def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: + sym = self.lookup_fully_qualified(fullname) + if sym and isinstance(sym.node, TypeInfo): + if is_declarative(sym.node): + return add_init_hook + return None + + +def add_init_hook(ctx: ClassDefContext) -> None: + """Add a dummy __init__() to a model and record it is generated. + + Instantiation will be checked more precisely when we inferred types + (using get_function_hook and model_hook). + """ + if '__init__' in ctx.cls.info.names: + # Don't override existing definition. + return + typ = AnyType(TypeOfAny.special_form) + var = Var('kwargs', typ) + kw_arg = Argument(variable=var, type_annotation=typ, initializer=None, kind=ARG_STAR2) + add_method(ctx, '__init__', [kw_arg], NoneTyp()) + ctx.cls.info.metadata.setdefault('sqlalchemy', {})['generated_init'] = True + + +def decl_deco_hook(ctx: ClassDefContext) -> None: + """Support declaring base class as declarative with a decorator. + + For example: + from from sqlalchemy.ext.declarative import as_declarative + + @as_declarative + class Base: + ... + """ + set_declarative(ctx.cls.info) + + +def decl_info_hook(ctx): + """Support dynamically defining declarative bases. + + For example: + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + """ + class_def = ClassDef(ctx.name, Block([])) + class_def.fullname = ctx.api.qualified_name(ctx.name) + + info = TypeInfo(SymbolTable(), class_def, ctx.api.cur_mod_id) + class_def.info = info + obj = ctx.api.builtin_type('builtins.object') + info.mro = [info, obj.type] + info.bases = [obj] + ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) + set_declarative(info) + + +def model_hook(ctx: FunctionContext) -> Type: + """More precise model instantiation check. + + Note: sub-models are not supported. + Note: this is still not perfect, since the context for inference of + argument types is 'Any'. + """ + assert isinstance(ctx.default_return_type, Instance) + model = ctx.default_return_type.type + metadata = model.metadata.get('sqlalchemy') + if not metadata or not metadata.get('generated_init'): + return ctx.default_return_type + + # Collect column names and types defined in the model + # TODO: cache this? + expected_types = {} # type: Dict[str, Type] + for name, sym in model.names.items(): + if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): + tp = sym.node.type + if tp.type.fullname() in (COLUMN_NAME, RELATIONSHIP_NAME): + assert len(tp.args) == 1 + expected_types[name] = tp.args[0] + + assert len(ctx.arg_names) == 1 # only **kwargs in generated __init__ + assert len(ctx.arg_types) == 1 + for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]): + if actual_name not in expected_types: + ctx.api.fail('Unexpected column "{}" for model "{}"'.format(actual_name, model.name()), + ctx.context) + continue + # Using private API to simplify life. + ctx.api.check_subtype(actual_type, expected_types[actual_name], + ctx.context, + 'Incompatible type for "{}" of "{}"'.format(actual_name, model.name()), + 'got', 'expected') + return ctx.default_return_type + + +def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]: + """Return the expression for the specific argument. + + This helper should only be used with non-star arguments. + """ + if name not in ctx.callee_arg_names: + return None + idx = ctx.callee_arg_names.index(name) + args = ctx.args[idx] + if len(args) != 1: + # Either an error or no value passed. + return None + return args[0] + + +def get_argtype_by_name(ctx: FunctionContext, name: str) -> Optional[Type]: + """Same as above but for argument type.""" + if name not in ctx.callee_arg_names: + return None + idx = ctx.callee_arg_names.index(name) + arg_types = ctx.arg_types[idx] + if len(arg_types) != 1: + # Either an error or no value passed. + return None + return arg_types[0] + + +def column_hook(ctx: FunctionContext) -> Type: + """Infer better types for Column calls. + + Examples: + Column(String) -> Column[Optional[str]] + Column(String, primary_key=True) -> Column[str] + Column(String, nullable=False) -> Column[str] + Column(String, default=...) -> Column[str] + Column(String, default=..., nullable=True) -> Column[Optional[str]] + + TODO: check the type of 'default'. + """ + assert isinstance(ctx.default_return_type, Instance) + + nullable_arg = get_argument_by_name(ctx, 'nullable') + primary_arg = get_argument_by_name(ctx, 'primary_key') + default_arg = get_argument_by_name(ctx, 'default') + + if nullable_arg: + nullable = parse_bool(nullable_arg) + else: + if primary_arg: + nullable = not parse_bool(primary_arg) + else: + nullable = default_arg is None + # TODO: Add support for literal types. + + if not nullable: + return ctx.default_return_type + assert len(ctx.default_return_type.args) == 1 + arg_type = ctx.default_return_type.args[0] + return Instance(ctx.default_return_type.type, [UnionType([arg_type, NoneTyp()])], + line=ctx.default_return_type.line, + column=ctx.default_return_type.column) + + +def relationship_hook(ctx: FunctionContext) -> Type: + """Support basic use cases for relationships. + + Examples: + from sqlalchemy.orm import relationship + + from one import OneModel + if TYPE_CHECKING: + from other import OtherModel + + class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + one = relationship(OneModel) + other = relationship("OtherModel") + + This also tries to infer the type argument for 'RelationshipProperty' + using the 'uselist' flag. + """ + assert isinstance(ctx.default_return_type, Instance) + original_type_arg = ctx.default_return_type.args[0] + has_annotation = not isinstance(original_type_arg, UninhabitedType) + + arg = get_argument_by_name(ctx, 'argument') + arg_type = get_argtype_by_name(ctx, 'argument') + + uselist_arg = get_argument_by_name(ctx, 'uselist') + + if isinstance(arg, StrExpr): + name = arg.value + # Private API for local lookup, but probably needs to be public. + try: + sym = ctx.api.lookup_qualified(name) # type: Optional[SymbolTableNode] + except (KeyError, AssertionError): + sym = None + if sym and isinstance(sym.node, TypeInfo): + new_arg = fill_typevars_with_any(sym.node) + else: + ctx.api.fail('Cannot find model "{}"'.format(name), ctx.context) + ctx.api.note('Only imported models can be found; use "if TYPE_CHECKING: ..." to avoid import cycles', ctx.context) + new_arg = AnyType(TypeOfAny.from_error) + else: + if isinstance(arg_type, CallableType) and arg_type.is_type_obj(): + new_arg = fill_typevars_with_any(arg_type.type_object()) + else: + # Something complex, stay silent for now. + new_arg = AnyType(TypeOfAny.special_form) + + # We figured out, the model type. Now check if we need to wrap it in Iterable + if uselist_arg: + if parse_bool(uselist_arg): + new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg]) + else: + if has_annotation: + # If there is an annotation we use it as a source of truth. + # This will cause false negatives, but it is better than lots of false positives. + new_arg = original_type_arg + + return Instance(ctx.default_return_type.type, [new_arg], + line=ctx.default_return_type.line, + column=ctx.default_return_type.column) + + +# We really need to add this to TypeChecker API +def parse_bool(expr: Expression) -> Optional[bool]: + if isinstance(expr, NameExpr): + if expr.fullname == 'builtins.True': + return True + if expr.fullname == 'builtins.False': + return False + return None + + +def plugin(version): + return BasicSQLAlchemyPlugin diff --git a/sqlalchemy-typing/sqltyping.py b/sqltyping.py similarity index 100% rename from sqlalchemy-typing/sqltyping.py rename to sqltyping.py diff --git a/test/sqlalchemy.ini b/test/sqlalchemy.ini new file mode 100644 index 0000000..9f1b8e1 --- /dev/null +++ b/test/sqlalchemy.ini @@ -0,0 +1,3 @@ +[mypy] +plugins = sqlmypy + diff --git a/test/test-data/sqlalchemy-basics.test b/test/test-data/sqlalchemy-basics.test index 3c9774f..9fecf29 100644 --- a/test/test-data/sqlalchemy-basics.test +++ b/test/test-data/sqlalchemy-basics.test @@ -1,10 +1,8 @@ [case testColumnFieldsInferred] -from typing import Any - from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import Column, Integer, String -Base: Any = declarative_base() +Base = declarative_base() class User(Base): __tablename__ = 'users' @@ -13,7 +11,7 @@ class User(Base): user: User reveal_type(user.id) # E: Revealed type is 'builtins.int*' -reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.str*]' +reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.str*, None]]' [out] [case testTypeEngineCovariance] @@ -29,17 +27,15 @@ func(String()) # E: Value of type variable "T" of "func" cannot be "str" [out] [case testColumnFieldsInferredInstance] -from typing import Any - from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import Column, Integer, String -Base: Any = declarative_base() +Base = declarative_base() class User(Base): __tablename__ = 'users' id = Column(Integer(), primary_key=True) - name = Column(String()) + name = Column(String(), nullable=False) user: User reveal_type(user.id) # E: Revealed type is 'builtins.int*' @@ -47,30 +43,38 @@ reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[buil [out] [case testColumnFieldsRelationship] -from typing import Any +from typing import TYPE_CHECKING -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import Column, Integer, String -from sqlalchemy.orm import relationship -from sqlalchemy.orm.properties import RelationshipProperty - -Base: Any = declarative_base() +from sqlalchemy.orm import relationship, RelationshipProperty +from base import Base +if TYPE_CHECKING: + from other import Other class User(Base): __tablename__ = 'users' id = Column(Integer(), primary_key=True) name = Column(String()) - other: RelationshipProperty[Other] = relationship('Other') + other = relationship('Other', uselist=False) + another: RelationshipProperty[Other] = relationship('Other') + +user: User +reveal_type(user.other) # E: Revealed type is 'other.Other*' +reveal_type(User.other) # E: Revealed type is 'sqlalchemy.orm.relationships.RelationshipProperty[other.Other*]' +reveal_type(user.other.name) # E: Revealed type is 'builtins.str*' + +[file other.py] +from sqlalchemy import Column, Integer, String +from base import Base class Other(Base): __tablename__ = 'other' id = Column(Integer(), primary_key=True) - name = Column(String()) + name = Column(String(), nullable=False) -user: User -reveal_type(user.other) # E: Revealed type is 'main.Other*' -reveal_type(User.other) # E: Revealed type is 'sqlalchemy.orm.relationships.RelationshipProperty[main.Other*]' -reveal_type(user.other.name) # E: Revealed type is 'builtins.str*' +[file base.py] +from sqlalchemy.ext.declarative import declarative_base +Base = declarative_base() [out] [case testTableColumns] @@ -86,12 +90,10 @@ reveal_type(users.c.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column*[ [out] [case testColumnFieldsInferred_python2] -from typing import Any - from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import Column, Integer, String -Base = declarative_base() # type: Any +Base = declarative_base() class User(Base): __tablename__ = 'users' @@ -100,21 +102,19 @@ class User(Base): user = User() reveal_type(user.id) # E: Revealed type is 'builtins.int*' -reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.unicode*]' +reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.unicode*, None]]' [out] [case testColumnFieldsInferredInstance_python2] -from typing import Any - from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import Column, Integer, String -Base = declarative_base() # type: Any +Base = declarative_base() class User(Base): __tablename__ = 'users' id = Column(Integer(), primary_key=True) - name = Column(String()) + name = Column(String(), default='John Doe') user = User() reveal_type(user.id) # E: Revealed type is 'builtins.int*' diff --git a/test/test-data/sqlalchemy-plugin-features.test b/test/test-data/sqlalchemy-plugin-features.test new file mode 100644 index 0000000..92cf4ac --- /dev/null +++ b/test/test-data/sqlalchemy-plugin-features.test @@ -0,0 +1,172 @@ +[case testModelInitColumnDeclared] +# flags: --strict-optional +from sqlalchemy import Column, Integer, String +from base import Base +from typing import Optional + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + name = Column(String()) + +oi: Optional[int] +os: Optional[str] + +User() +User(1, 2) # E: Too many arguments for "User" +User(id=int(), name=str()) +User(id=oi) # E: Incompatible type for "id" of "User" (got "Optional[int]", expected "int") +User(name=os) + +[file base.py] +from sqlalchemy.ext.declarative import declarative_base +Base = declarative_base() +[out] + +[case testModelInitColumnDecorated] +# flags: --strict-optional +from sqlalchemy import Column, Integer, String +from base import Base +from typing import Optional + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + name = Column(String()) + +oi: Optional[int] +os: Optional[str] + +User() +User(1, 2) # E: Too many arguments for "User" +User(id=int(), name=str()) +User(id=oi) # E: Incompatible type for "id" of "User" (got "Optional[int]", expected "int") +User(name=os) + +[file base.py] +from sqlalchemy.ext.declarative import as_declarative +@as_declarative() +class Base: + ... +[out] + +[case testModelInitRelationship] +from typing import TYPE_CHECKING, List + +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import relationship, RelationshipProperty +from base import Base +if TYPE_CHECKING: + from other import Other + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + other = relationship('Other', uselist=False) + many_others = relationship(Other, uselist=True) + +o: Other +mo: List[Other] +User() +User(other=o, many_others=mo) +User(other=mo) # E: Incompatible type for "other" of "User" (got "List[Other]", expected "Other") +User(unknown=42) # E: Unexpected column "unknown" for model "User" + +[file other.py] +from sqlalchemy import Column, Integer, String +from base import Base + +class Other(Base): + __tablename__ = 'other' + id = Column(Integer(), primary_key=True) + name = Column(String(), nullable=False) + +[file base.py] +from sqlalchemy.ext.declarative import declarative_base +Base = declarative_base() +[out] + +[case testRelationshipType] +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import relationship, RelationshipProperty +from typing import Iterable + +Base = declarative_base() + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + first_other = relationship(Other) + second_other = relationship(Other, uselist=True) + bad_other: RelationshipProperty[int] = relationship(Other, uselist=False) # E: Incompatible types in assignment (expression has type "RelationshipProperty[Other]", variable has type "RelationshipProperty[int]") + +user = User() +reveal_type(user.first_other) # E: Revealed type is 'main.Other*' +reveal_type(user.second_other) # E: Revealed type is 'typing.Iterable*[main.Other]' + +class Other(Base): + __tablename__ = 'other' + id = Column(Integer(), primary_key=True) +[out] + +[case testRelationshipString] +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import relationship, RelationshipProperty +from typing import Iterable + +Base = declarative_base() + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + first_other = relationship('Other', uselist=False) + second_other = relationship('Other', uselist=True) + second_bad_other = relationship('What') # E: Cannot find model "What" \ + # N: Only imported models can be found; use "if TYPE_CHECKING: ..." to avoid import cycles + +user = User() +reveal_type(user.first_other) # E: Revealed type is 'main.Other*' +reveal_type(user.second_other) # E: Revealed type is 'typing.Iterable*[main.Other]' + +class Other(Base): + __tablename__ = 'other' + id = Column(Integer(), primary_key=True) +[out] + +[case testRelationshipAnnotated] +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import relationship, RelationshipProperty +from typing import Iterable + +Base = declarative_base() + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + first_other: RelationshipProperty[Other] = relationship('Other') + second_other: RelationshipProperty[Iterable[Other]] = relationship(Other, uselist=True) + third_other: RelationshipProperty[Other] = relationship(Other, uselist=False) + bad_other: RelationshipProperty[Other] = relationship('Other', uselist=True) # E: Incompatible types in assignment (expression has type "RelationshipProperty[Iterable[Other]]", variable has type "RelationshipProperty[Other]") + +class Other(Base): + __tablename__ = 'other' + id = Column(Integer(), primary_key=True) +[out] + +[case testColumnCombo] +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String + +Base = declarative_base() + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + name = Column(String(), default='John Doe', nullable=True) + +user: User +reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.str*, None]]' +[out] diff --git a/test/test-data/sqlalchemy-sql-elements.test b/test/test-data/sqlalchemy-sql-elements.test index ebc0d60..f6a0d77 100644 --- a/test/test-data/sqlalchemy-sql-elements.test +++ b/test/test-data/sqlalchemy-sql-elements.test @@ -85,7 +85,7 @@ Base: Any = declarative_base() class Model(Base): __tablename__ = 'users' - tags = Column(ARRAY(String(16))) + tags = Column(ARRAY(String(16)), nullable=False) tags: Set[str] = set() reveal_type(cast(tags, Model.tags.type)) # E: Revealed type is 'sqlalchemy.sql.elements.Cast[builtins.list*[builtins.str*]]' diff --git a/test/testsql.py b/test/testsql.py index 265e08a..d04dd6c 100644 --- a/test/testsql.py +++ b/test/testsql.py @@ -14,6 +14,7 @@ this_file_dir = os.path.dirname(os.path.realpath(__file__)) prefix = os.path.dirname(this_file_dir) +inipath = os.path.abspath(os.path.join(prefix, 'test')) # Locations of test data files such as test case descriptions (.test). test_data_prefix = os.path.join(prefix, 'test', 'test-data') @@ -24,7 +25,8 @@ class SQLDataSuite(DataSuite): 'sqlalchemy-sql-elements.test', 'sqlalchemy-sql-sqltypes.test', 'sqlalchemy-sql-selectable.test', - 'sqlalchemy-sql-schema.test'] + 'sqlalchemy-sql-schema.test', + 'sqlalchemy-plugin-features.test'] data_prefix = test_data_prefix def run_case(self, testcase: DataDrivenTestCase) -> None: @@ -33,6 +35,7 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: mypy_cmdline = [ '--show-traceback', '--no-silence-site-packages', + '--config-file={}/sqlalchemy.ini'.format(inipath), ] py2 = testcase.name.lower().endswith('python2') if py2: