Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhances Cursor type inference capabilities #504

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions pyathena/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, FrozenSet, Type
from typing import TYPE_CHECKING, Any, FrozenSet, Type, overload

from pyathena.error import * # noqa

if TYPE_CHECKING:
from pyathena.connection import Connection
from pyathena.connection import Connection, ConnectionCursor
from pyathena.cursor import Cursor

__version__ = "3.1.0"
user_agent_extra: str = f"PyAthena/{__version__}"
Expand Down Expand Up @@ -57,7 +58,19 @@ def __hash__(self):
Timestamp: Type[datetime.datetime] = datetime.datetime


def connect(*args, **kwargs) -> "Connection":
@overload
def connect(*args, cursor_class: None = ..., **kwargs) -> "Connection[Cursor]":
...


@overload
def connect(
*args, cursor_class: Type[ConnectionCursor], **kwargs
) -> "Connection[ConnectionCursor]":
...


def connect(*args, **kwargs) -> "Connection[Any]":
from pyathena.connection import Connection

return Connection(*args, **kwargs)
2 changes: 1 addition & 1 deletion pyathena/arrow/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class AthenaArrowResultSet(AthenaResultSet):

def __init__(
self,
connection: "Connection",
connection: "Connection[Any]",
converter: Converter,
query_execution: AthenaQueryExecution,
arraysize: int,
Expand Down
4 changes: 2 additions & 2 deletions pyathena/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class BaseCursor(metaclass=ABCMeta):

def __init__(
self,
connection: "Connection",
connection: "Connection[Any]",
converter: Converter,
formatter: Formatter,
retry_config: RetryConfig,
Expand Down Expand Up @@ -134,7 +134,7 @@ def get_default_converter(unload: bool = False) -> Union[DefaultTypeConverter, A
return DefaultTypeConverter()

@property
def connection(self) -> "Connection":
def connection(self) -> "Connection[Any]":
return self._connection

def _build_start_query_execution_request(
Expand Down
105 changes: 96 additions & 9 deletions pyathena/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,19 @@
import logging
import os
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
Type,
TypeVar,
Union,
cast,
overload,
)

from boto3.session import Session
from botocore.config import Config
Expand All @@ -23,7 +35,11 @@
_logger = logging.getLogger(__name__) # type: ignore


class Connection:
ConnectionCursor = TypeVar("ConnectionCursor", bound=BaseCursor)
FunctionalCursor = TypeVar("FunctionalCursor", bound=BaseCursor)
laughingman7743 marked this conversation as resolved.
Show resolved Hide resolved


class Connection(Generic[ConnectionCursor]):
_ENV_S3_STAGING_DIR: str = "AWS_ATHENA_S3_STAGING_DIR"
_ENV_WORK_GROUP: str = "AWS_ATHENA_WORK_GROUP"
_SESSION_PASSING_ARGS: List[str] = [
Expand All @@ -46,6 +62,68 @@ class Connection:
"config",
]

@overload
def __init__(
self: Connection[Cursor],
s3_staging_dir: Optional[str] = ...,
region_name: Optional[str] = ...,
schema_name: Optional[str] = ...,
catalog_name: Optional[str] = ...,
work_group: Optional[str] = ...,
poll_interval: float = ...,
encryption_option: Optional[str] = ...,
kms_key: Optional[str] = ...,
profile_name: Optional[str] = ...,
role_arn: Optional[str] = ...,
role_session_name: str = ...,
external_id: Optional[str] = ...,
serial_number: Optional[str] = ...,
duration_seconds: int = ...,
converter: Optional[Converter] = ...,
formatter: Optional[Formatter] = ...,
retry_config: Optional[RetryConfig] = ...,
cursor_class: None = ...,
cursor_kwargs: Optional[Dict[str, Any]] = ...,
kill_on_interrupt: bool = ...,
session: Optional[Session] = ...,
config: Optional[Config] = ...,
result_reuse_enable: bool = ...,
result_reuse_minutes: int = ...,
**kwargs,
) -> None:
...

@overload
def __init__(
self: Connection[ConnectionCursor],
s3_staging_dir: Optional[str] = ...,
region_name: Optional[str] = ...,
schema_name: Optional[str] = ...,
catalog_name: Optional[str] = ...,
work_group: Optional[str] = ...,
poll_interval: float = ...,
encryption_option: Optional[str] = ...,
kms_key: Optional[str] = ...,
profile_name: Optional[str] = ...,
role_arn: Optional[str] = ...,
role_session_name: str = ...,
external_id: Optional[str] = ...,
serial_number: Optional[str] = ...,
duration_seconds: int = ...,
converter: Optional[Converter] = ...,
formatter: Optional[Formatter] = ...,
retry_config: Optional[RetryConfig] = ...,
cursor_class: Type[ConnectionCursor] = ...,
cursor_kwargs: Optional[Dict[str, Any]] = ...,
kill_on_interrupt: bool = ...,
session: Optional[Session] = ...,
config: Optional[Config] = ...,
result_reuse_enable: bool = ...,
result_reuse_minutes: int = ...,
**kwargs,
) -> None:
...

def __init__(
self,
s3_staging_dir: Optional[str] = None,
Expand All @@ -65,7 +143,7 @@ def __init__(
converter: Optional[Converter] = None,
formatter: Optional[Formatter] = None,
retry_config: Optional[RetryConfig] = None,
cursor_class: Type[BaseCursor] = Cursor,
cursor_class: Optional[Type[ConnectionCursor]] = cast(Type[ConnectionCursor], Cursor),
cursor_kwargs: Optional[Dict[str, Any]] = None,
kill_on_interrupt: bool = True,
session: Optional[Session] = None,
Expand Down Expand Up @@ -158,7 +236,7 @@ def __init__(
self._converter = converter
self._formatter = formatter if formatter else DefaultParameterFormatter()
self._retry_config = retry_config if retry_config else RetryConfig()
self.cursor_class = cursor_class
self.cursor_class = cast(Type[ConnectionCursor], cursor_class)
self.cursor_kwargs = cursor_kwargs if cursor_kwargs else dict()
self.kill_on_interrupt = kill_on_interrupt
self.result_reuse_enable = result_reuse_enable
Expand Down Expand Up @@ -250,14 +328,23 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def cursor(self, cursor: Optional[Type[BaseCursor]] = None, **kwargs) -> BaseCursor:
@overload
def cursor(self, cursor: None = ..., **kwargs) -> ConnectionCursor:
...

@overload
def cursor(self, cursor: Type[FunctionalCursor], **kwargs) -> FunctionalCursor:
...

def cursor(
self, cursor: Optional[Type[FunctionalCursor]] = None, **kwargs
) -> Union[FunctionalCursor, ConnectionCursor]:
kwargs.update(self.cursor_kwargs)
if not cursor:
cursor = self.cursor_class
_cursor = cursor or self.cursor_class
converter = kwargs.pop("converter", self._converter)
if not converter:
converter = cursor.get_default_converter(kwargs.get("unload", False))
return cursor(
converter = _cursor.get_default_converter(kwargs.get("unload", False))
return _cursor(
connection=self,
converter=converter,
formatter=kwargs.pop("formatter", self._formatter),
Expand Down
2 changes: 1 addition & 1 deletion pyathena/filesystem/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class S3FileSystem(AbstractFileSystem):

def __init__(
self,
connection: Optional["Connection"] = None,
connection: Optional["Connection[Any]"] = None,
default_block_size: Optional[int] = None,
default_cache_type: Optional[str] = None,
max_workers: int = (cpu_count() or 1) * 5,
Expand Down
2 changes: 1 addition & 1 deletion pyathena/pandas/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class AthenaPandasResultSet(AthenaResultSet):

def __init__(
self,
connection: "Connection",
connection: "Connection[Any]",
converter: Converter,
query_execution: AthenaQueryExecution,
arraysize: int,
Expand Down
2 changes: 1 addition & 1 deletion pyathena/pandas/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def to_parquet(
def to_sql(
df: "DataFrame",
name: str,
conn: "Connection",
conn: "Connection[Any]",
location: str,
schema: str = "default",
index: bool = False,
Expand Down
8 changes: 4 additions & 4 deletions pyathena/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@
class AthenaResultSet(CursorIterator):
def __init__(
self,
connection: "Connection",
connection: "Connection[Any]",
converter: Converter,
query_execution: AthenaQueryExecution,
arraysize: int,
retry_config: RetryConfig,
) -> None:
super().__init__(arraysize=arraysize)
self._connection: Optional["Connection"] = connection
self._connection: Optional["Connection[Any]"] = connection
self._converter = converter
self._query_execution: Optional[AthenaQueryExecution] = query_execution
assert self._query_execution, "Required argument `query_execution` not found."
Expand Down Expand Up @@ -280,10 +280,10 @@ def description(
]

@property
def connection(self) -> "Connection":
def connection(self) -> "Connection[Any]":
if self.is_closed:
raise ProgrammingError("AthenaResultSet is closed.")
return cast("Connection", self._connection)
return cast("Connection[Any]", self._connection)

def __fetch(self, next_token: Optional[str] = None) -> Dict[str, Any]:
if not self.query_id:
Expand Down