Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhances Cursor type inference capabilities #504

Merged

Conversation

woosuk-choi-g
Copy link
Contributor

Related Issue

Overview

Enhances type inference capabilities. While there may be no significant impact in some use cases, it is better compared to before.

Use Case

connection = Connection() # -> pyathena.connection.Connection[pyathena.cursor.Cursor]
connection_cursor = connection.cursor() # -> pyathena.cursor.Cursor

async_connection = Connection(cursor_class=AsyncCursor) # -> pyathena.connection.Connection[pyathena.async_cursor.AsyncCursor]
async_connection_cursor = async_connection.cursor() # -> pyathena.async_cursor.AsyncCursor
dict_cursor = async_connection.cursor(DictCursor) # -> (pyathena.async_cursor.AsyncCursor | pyathena.cursor.DictCursor) (IDE seems to be pretty confusing)

factory_connection = connect(cursor_class=AsyncCursor) # -> pyathena.connection.Connection (factory can't be inferred)
factory_connection_cursor = factory_connection.cursor() # -> Any (factory can't be inferred)
factory_connection_functional_cursor = factory_connection.cursor(DictCursor) # -> pyathena.cursor.DictCursor

Limitation

  • using a mix of multiple cursor is inferred a union of cursor types
  • def connection factory method still not support type hint

@@ -250,14 +254,16 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def cursor(self, cursor: Optional[Type[BaseCursor]] = None, **kwargs) -> BaseCursor:
def cursor(self, cursor: Optional[Type[FunctionalCursor]] = None, **kwargs):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def cursor(self, cursor: Optional[Type[FunctionalCursor]] = None, **kwargs):
def cursor(self, cursor: Optional[Type[FunctionalCursor]] = None, **kwargs): -> FunctionalCursor

The return type is missing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the cursor parameter is omitted, the type analyzer consistently infers Any.

def cursor(self, cursor: Optional[Type[FunctionalCursor]] = None, **kwargs): -> FunctionalCursor:
    return cursor()

_cursor = cursor() # -> Any

To enhance the clarity of the return type, I recommend using a union of types

def cursor(self, cursor: Optional[Type[FunctionalCursor]] = None, **kwargs): -> Union[FunctionalCursor, ConnectionCursor]:
    return cursor()

_cursor = cursor() # -> ConnectionCursor
_other_cursor = cursor(MyCursor) # -> (MyCursor | ConnectionCursor)

Alternatively, following your suggestion, we can ignore the generic ConnectionCursor type. I agree with this approach as it simplifies the code.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be a good choice to use the Union type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update Union return type 😊

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pyathena/connection.py:257:101: E501 Line too long (125 > 100)

You can format it with make fmt.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following error occurs with make chk.

$ make chk                                                                                                                                                                                                                                                                                                                                                              1043ms  Sun Jan 21 18:37:01 2024
pdm run ruff check .
pdm run ruff format --check .
42 files already formatted
pdm run mypy .
pyathena/__init__.py:60: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/common.py:100: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/common.py:137: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/result_set.py:36: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/result_set.py:43: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/result_set.py:283: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/result_set.py:286: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/connection.py:72: error: Incompatible default for argument "cursor_class" (default has type "Type[Cursor]", argument has type "Type[ConnectionCursor]")  [assignment]
pyathena/connection.py:264: error: Incompatible types in assignment (expression has type "Type[ConnectionCursor]", variable has type "Type[FunctionalCursor]")  [assignment]
pyathena/pandas/util.py:138: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/filesystem/s3.py:41: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/arrow/result_set.py:54: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/pandas/result_set.py:103: error: Missing type parameters for generic type "Connection"  [type-arg]
Found 13 errors in 8 files (checked 40 source files)
make: *** [chk] Error 1

@laughingman7743
Copy link
Owner

Thank you for your contribution. I have made some comments, please check them out.

@woosuk-choi-g
Copy link
Contributor Author

woosuk-choi-g commented Jan 24, 2024

Now, the type analyzer can accurately infer the Cursor type of Connection.

# reveal_type.py
connection_aysnc_cursor = connect(cursor_class=AsyncCursor)
reveal_type(connection_aysnc_cursor)

aysnc_cursor = connection_aysnc_cursor.cursor()
reveal_type(aysnc_cursor)

dict_cursor = connection_aysnc_cursor.cursor(DictCursor)
reveal_type(dict_cursor)

connection_cursor = Connection()
reveal_type(connection_cursor)

cursor = connection_cursor.cursor()
reveal_type(cursor)

connection_async = Connection(cursor_class=AsyncCursor)
reveal_type(connection_async)

async_cursor = connection_async.cursor()
reveal_type(async_cursor)

dict_cursor = connection_async.cursor(DictCursor)
reveal_type(dict_cursor)
$ mypy reveal_type.py
.vscode\main.py:8: note: Revealed type is "pyathena.connection.Connection[pyathena.async_cursor.AsyncCursor]"
.vscode\main.py:11: note: Revealed type is "pyathena.async_cursor.AsyncCursor"
.vscode\main.py:14: note: Revealed type is "pyathena.cursor.DictCursor"
.vscode\main.py:17: note: Revealed type is "pyathena.connection.Connection[pyathena.cursor.Cursor]" 
.vscode\main.py:20: note: Revealed type is "pyathena.cursor.Cursor"
.vscode\main.py:23: note: Revealed type is "pyathena.connection.Connection[pyathena.async_cursor.AsyncCursor]"
.vscode\main.py:26: note: Revealed type is "pyathena.async_cursor.AsyncCursor"
.vscode\main.py:29: note: Revealed type is "pyathena.cursor.DictCursor"

and fixed formatting error 😊

$ mypy .
Success: no issues found in 40 source files

Copy link
Owner

@laughingman7743 laughingman7743 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

pdm run ruff check .
WARNING: The following problems are found in your project:
  package-type: package-type = "library" has been renamed to distribution = true under [tool.pdm] table
Run pdm fix to fix all or pdm fix <name> to fix individual problem.
pdm run ruff format --check .
WARNING: The following problems are found in your project:
  package-type: package-type = "library" has been renamed to distribution = true under [tool.pdm] table
Run pdm fix to fix all or pdm fix <name> to fix individual problem.
42 files already formatted
pdm run mypy .
WARNING: The following problems are found in your project:
  package-type: package-type = "library" has been renamed to distribution = true under [tool.pdm] table
Run pdm fix to fix all or pdm fix <name> to fix individual problem.
Success: no issues found in 40 source files
pdm run pytest -n 8 --cov pyathena --cov-report html --cov-report term tests/pyathena/
WARNING: The following problems are found in your project:
  package-type: package-type = "library" has been renamed to distribution = true under [tool.pdm] table
Run pdm fix to fix all or pdm fix <name> to fix individual problem.
============================================================================================================================================= test session starts =============================================================================================================================================
platform darwin -- Python 3.11.3, pytest-7.4.4, pluggy-1.3.0
rootdir: /Users/foobar/github/PyAthena
configfile: pyproject.toml
plugins: cov-4.1.0, dependency-0.6.0, xdist-3.5.0
8 workers [498 items]
....................................................................................................................................................................................................................................................................................................... [ 59%]
...........................................................................................................................................................................................................                                                                                             [100%]
============================================================================================================================================== warnings summary ===============================================================================================================================================
tests/pyathena/spark/test_spark_cursor.py:77
tests/pyathena/spark/test_spark_cursor.py:77
tests/pyathena/spark/test_spark_cursor.py:77
tests/pyathena/spark/test_spark_cursor.py:77
tests/pyathena/spark/test_spark_cursor.py:77
tests/pyathena/spark/test_spark_cursor.py:77
tests/pyathena/spark/test_spark_cursor.py:77
tests/pyathena/spark/test_spark_cursor.py:77
  /Users/foobar/github/PyAthena/tests/pyathena/spark/test_spark_cursor.py:77: PytestUnknownMarkWarning: Unknown pytest.mark.depends - is this a typo?  You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
    @pytest.mark.depends(on="test_spark_dataframe")

tests/pyathena/sqlalchemy/test_base.py: 61 warnings
  /Users/foobar/github/PyAthena/tests/pyathena/conftest.py:96: SADeprecationWarning: The dbapi() classmethod on dialect classes has been renamed to import_dbapi().  Implement an import_dbapi() classmethod directly on class <class 'pyathena.sqlalchemy.rest.AthenaRestDialect'> to remove this warning; the old .dbapi() classmethod may be maintained for backwards compatibility.
    return sqlalchemy.engine.create_engine(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html

---------- coverage: platform darwin, python 3.11.3-final-0 ----------
Name                                  Stmts   Miss  Cover
---------------------------------------------------------
pyathena/__init__.py                     43      9    79%
pyathena/arrow/__init__.py                0      0   100%
pyathena/arrow/async_cursor.py           42      0   100%
pyathena/arrow/converter.py              32      2    94%
pyathena/arrow/cursor.py                 90      1    99%
pyathena/arrow/result_set.py            148     14    91%
pyathena/arrow/util.py                   45      3    93%
pyathena/async_cursor.py                 54      1    98%
pyathena/common.py                      291     47    84%
pyathena/connection.py                  127     32    75%
pyathena/converter.py                    81      9    89%
pyathena/cursor.py                       72      1    99%
pyathena/error.py                        21      0   100%
pyathena/fastparquet/__init__.py          0      0   100%
pyathena/fastparquet/util.py             44      3    93%
pyathena/filesystem/__init__.py           0      0   100%
pyathena/filesystem/s3.py               279     74    73%
pyathena/filesystem/s3_object.py         34      0   100%
pyathena/formatter.py                    99      5    95%
pyathena/model.py                       490     13    97%
pyathena/pandas/__init__.py               3      0   100%
pyathena/pandas/async_cursor.py          44      0   100%
pyathena/pandas/converter.py             23      0   100%
pyathena/pandas/cursor.py                97      1    99%
pyathena/pandas/result_set.py           230     26    89%
pyathena/pandas/util.py                 155      6    96%
pyathena/result_set.py                  523     97    81%
pyathena/spark/__init__.py                0      0   100%
pyathena/spark/async_cursor.py           34      5    85%
pyathena/spark/common.py                188     48    74%
pyathena/spark/cursor.py                 32      2    94%
pyathena/sqlalchemy/__init__.py           0      0   100%
pyathena/sqlalchemy/arrow.py             15     15     0%
pyathena/sqlalchemy/base.py             507     73    86%
pyathena/sqlalchemy/pandas.py            19     19     0%
pyathena/sqlalchemy/requirements.py      95     95     0%
pyathena/sqlalchemy/rest.py               4      0   100%
pyathena/sqlalchemy/types.py             26      9    65%
pyathena/sqlalchemy/util.py               3      1    67%
pyathena/util.py                         31      1    97%
---------------------------------------------------------
TOTAL                                  4021    612    85%
Coverage HTML written to dir htmlcov

================================================================================================================================ 498 passed, 69 warnings in 324.53s (0:05:24) =================================================================================================================================

@laughingman7743 laughingman7743 merged commit d1d6af5 into laughingman7743:master Jan 25, 2024
@laughingman7743
Copy link
Owner

@laughingman7743
Copy link
Owner

@woosuk-choi-g Please check the following issues:
#506

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Mypy Error When using Connection.cursor method to instantiate cursor
2 participants