Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhances Cursor type inference capabilities #504

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions pyathena/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import os
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, TypeVar

from boto3.session import Session
from botocore.config import Config
Expand All @@ -23,7 +23,11 @@
_logger = logging.getLogger(__name__) # type: ignore


class Connection:
ConnectionCursor = TypeVar("ConnectionCursor", bound=BaseCursor)
FunctionalCursor = TypeVar("FunctionalCursor", bound=BaseCursor)
laughingman7743 marked this conversation as resolved.
Show resolved Hide resolved


class Connection(Generic[ConnectionCursor]):
_ENV_S3_STAGING_DIR: str = "AWS_ATHENA_S3_STAGING_DIR"
_ENV_WORK_GROUP: str = "AWS_ATHENA_WORK_GROUP"
_SESSION_PASSING_ARGS: List[str] = [
Expand Down Expand Up @@ -65,7 +69,7 @@ def __init__(
converter: Optional[Converter] = None,
formatter: Optional[Formatter] = None,
retry_config: Optional[RetryConfig] = None,
cursor_class: Type[BaseCursor] = Cursor,
cursor_class: Type[ConnectionCursor] = Cursor,
cursor_kwargs: Optional[Dict[str, Any]] = None,
kill_on_interrupt: bool = True,
session: Optional[Session] = None,
Expand Down Expand Up @@ -250,14 +254,16 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def cursor(self, cursor: Optional[Type[BaseCursor]] = None, **kwargs) -> BaseCursor:
def cursor(self, cursor: Optional[Type[FunctionalCursor]] = None, **kwargs):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def cursor(self, cursor: Optional[Type[FunctionalCursor]] = None, **kwargs):
def cursor(self, cursor: Optional[Type[FunctionalCursor]] = None, **kwargs): -> FunctionalCursor

The return type is missing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the cursor parameter is omitted, the type analyzer consistently infers Any.

def cursor(self, cursor: Optional[Type[FunctionalCursor]] = None, **kwargs): -> FunctionalCursor:
    return cursor()

_cursor = cursor() # -> Any

To enhance the clarity of the return type, I recommend using a union of types

def cursor(self, cursor: Optional[Type[FunctionalCursor]] = None, **kwargs): -> Union[FunctionalCursor, ConnectionCursor]:
    return cursor()

_cursor = cursor() # -> ConnectionCursor
_other_cursor = cursor(MyCursor) # -> (MyCursor | ConnectionCursor)

Alternatively, following your suggestion, we can ignore the generic ConnectionCursor type. I agree with this approach as it simplifies the code.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be a good choice to use the Union type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update Union return type 😊

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pyathena/connection.py:257:101: E501 Line too long (125 > 100)

You can format it with make fmt.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following error occurs with make chk.

$ make chk                                                                                                                                                                                                                                                                                                                                                              1043ms  Sun Jan 21 18:37:01 2024
pdm run ruff check .
pdm run ruff format --check .
42 files already formatted
pdm run mypy .
pyathena/__init__.py:60: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/common.py:100: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/common.py:137: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/result_set.py:36: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/result_set.py:43: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/result_set.py:283: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/result_set.py:286: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/connection.py:72: error: Incompatible default for argument "cursor_class" (default has type "Type[Cursor]", argument has type "Type[ConnectionCursor]")  [assignment]
pyathena/connection.py:264: error: Incompatible types in assignment (expression has type "Type[ConnectionCursor]", variable has type "Type[FunctionalCursor]")  [assignment]
pyathena/pandas/util.py:138: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/filesystem/s3.py:41: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/arrow/result_set.py:54: error: Missing type parameters for generic type "Connection"  [type-arg]
pyathena/pandas/result_set.py:103: error: Missing type parameters for generic type "Connection"  [type-arg]
Found 13 errors in 8 files (checked 40 source files)
make: *** [chk] Error 1

kwargs.update(self.cursor_kwargs)
if not cursor:
cursor = self.cursor_class
if cursor:
_cursor = cursor
else:
_cursor = self.cursor_class
converter = kwargs.pop("converter", self._converter)
if not converter:
converter = cursor.get_default_converter(kwargs.get("unload", False))
return cursor(
converter = _cursor.get_default_converter(kwargs.get("unload", False))
return _cursor(
connection=self,
converter=converter,
formatter=kwargs.pop("formatter", self._formatter),
Expand Down