From d7e04ae8affc1fd57353a7f6a7e8a9a342449021 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 17 Jul 2019 15:58:36 +0100 Subject: [PATCH 1/3] Add a fake __eq__ method to relationship() to avoid false positives with --strict-equality --- sqlalchemy-stubs/orm/relationships.pyi | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sqlalchemy-stubs/orm/relationships.pyi b/sqlalchemy-stubs/orm/relationships.pyi index ecb6b0b..d334b41 100644 --- a/sqlalchemy-stubs/orm/relationships.pyi +++ b/sqlalchemy-stubs/orm/relationships.pyi @@ -87,6 +87,8 @@ class RelationshipProperty(StrategizedProperty, Generic[_T_co]): def __ne__(self, other): ... @property def property(self): ... + # This doesn't exist at runtime, and Comparator is used instead, but it is hard to explain to mypy. + def __eq__(self, other: Any) -> Any: ... def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive, _resolve_conflict_map): ... def cascade_iterator(self, *args, **kwargs): ... From ec86542f2eb35fa8b99a9d93505316f8ed3b0bdf Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 17 Jul 2019 16:23:25 +0100 Subject: [PATCH 2/3] Add test --- test/test-data/sqlalchemy-basics.test | 23 +++++++++++++++++++++++ test/testsql.py | 8 ++++++++ 2 files changed, 31 insertions(+) diff --git a/test/test-data/sqlalchemy-basics.test b/test/test-data/sqlalchemy-basics.test index a215b46..fa9440c 100644 --- a/test/test-data/sqlalchemy-basics.test +++ b/test/test-data/sqlalchemy-basics.test @@ -120,3 +120,26 @@ user = User() reveal_type(user.id) # N: Revealed type is 'builtins.int*' reveal_type(User.name) # N: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.unicode*]' [out] + +[case testRelationshipStrictEquality] +# flags: --strict-equality +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session + +Base = declarative_base() +session = Session() + +class User(Base): + __tablename__ = 'users' + id = Column(Integer(), primary_key=True) + other = relationship('Other') + +class Other(Base): + __tablename__ = 'other' + id = Column(Integer(), primary_key=True) + +other: Other +session.query(User).filter(User.other == other) +[out] diff --git a/test/testsql.py b/test/testsql.py index 4cefc4e..0d87850 100644 --- a/test/testsql.py +++ b/test/testsql.py @@ -3,6 +3,7 @@ import os import os.path import sys +import re import pytest # type: ignore # no pytest in typeshed @@ -54,6 +55,13 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: version = sys.version_info[:2] mypy_cmdline.append('--python-version={}'.format('.'.join(map(str, version)))) + program_text = '\n'.join(testcase.input) + flags = re.search('# flags: (.*)$', program_text, flags=re.MULTILINE) + if flags: + if flags: + flag_list = flags.group(1).split() + mypy_cmdline.extend(flag_list) + # Write the program to a file. program_path = os.path.join(test_temp_dir, 'main.py') mypy_cmdline.append(program_path) From 8300b2df5c51c084bc516b414a55abd08926b6d6 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 17 Jul 2019 23:27:17 +0100 Subject: [PATCH 3/3] Remove redundant if. --- test/testsql.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/testsql.py b/test/testsql.py index 0d87850..fb7a582 100644 --- a/test/testsql.py +++ b/test/testsql.py @@ -58,8 +58,7 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: program_text = '\n'.join(testcase.input) flags = re.search('# flags: (.*)$', program_text, flags=re.MULTILINE) if flags: - if flags: - flag_list = flags.group(1).split() + flag_list = flags.group(1).split() mypy_cmdline.extend(flag_list) # Write the program to a file.