From df3dda71b28e0deac3cd11e48239003eb10ea5bb Mon Sep 17 00:00:00 2001 From: ptiurin Date: Wed, 3 May 2023 13:54:12 +0100 Subject: [PATCH 01/13] WIP: tests working no mypy --- src/firebolt/async_db/__init__.py | 6 +- src/firebolt/async_db/connection.py | 188 +++----- src/firebolt/async_db/cursor.py | 442 ++---------------- src/firebolt/{async_db => common}/_types.py | 0 src/firebolt/common/base_connection.py | 112 +++++ src/firebolt/common/base_cursor.py | 401 ++++++++++++++++ src/firebolt/db/__init__.py | 2 +- src/firebolt/db/connection.py | 294 +++++++++++- src/firebolt/db/cursor.py | 332 +++++++++++-- src/firebolt/db/util.py | 55 +++ .../dbapi/async/test_queries_async.py | 2 +- tests/integration/dbapi/conftest.py | 2 +- tests/integration/dbapi/sync/test_queries.py | 2 +- tests/unit/async_db/test_connection.py | 2 +- tests/unit/async_db/test_cursor.py | 4 +- tests/unit/async_db/test_typing_format.py | 2 +- tests/unit/async_db/test_typing_parse.py | 2 +- tests/unit/db/test_connection.py | 2 +- tests/unit/db/test_cursor.py | 6 +- 19 files changed, 1242 insertions(+), 614 deletions(-) rename src/firebolt/{async_db => common}/_types.py (100%) create mode 100644 src/firebolt/common/base_connection.py create mode 100644 src/firebolt/common/base_cursor.py create mode 100644 src/firebolt/db/util.py diff --git a/src/firebolt/async_db/__init__.py b/src/firebolt/async_db/__init__.py index 4d74ebe04dd..35ff4a15a1c 100644 --- a/src/firebolt/async_db/__init__.py +++ b/src/firebolt/async_db/__init__.py @@ -1,4 +1,6 @@ -from firebolt.async_db._types import ( +from firebolt.async_db.connection import Connection, connect +from firebolt.async_db.cursor import Cursor +from firebolt.common._types import ( ARRAY, BINARY, DATETIME, @@ -14,8 +16,6 @@ Timestamp, TimestampFromTicks, ) -from firebolt.async_db.connection import Connection, connect -from firebolt.async_db.cursor import Cursor from firebolt.utils.exception import ( DatabaseError, DataError, diff --git 
a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 09b94bdda03..7cf550ba437 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -4,15 +4,23 @@ import socket from json import JSONDecodeError from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any, Callable, Dict, Optional, Type from httpcore.backends.auto import AutoBackend from httpcore.backends.base import AsyncNetworkStream from httpx import AsyncHTTPTransport, HTTPStatusError, RequestError, Timeout -from firebolt.async_db.cursor import BaseCursor, Cursor +from firebolt.async_db.cursor import Cursor from firebolt.client import DEFAULT_API_URL, AsyncClient -from firebolt.client.auth import Auth, Token, UsernamePassword +from firebolt.client.auth import Auth +from firebolt.common.base_connection import ( + DEFAULT_TIMEOUT_SECONDS, + KEEPALIVE_FLAG, + KEEPIDLE_RATE, + BaseConnection, + _get_auth, + _validate_engine_name_and_url, +) from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, @@ -27,9 +35,6 @@ from firebolt.utils.usage_tracker import get_user_agent_header from firebolt.utils.util import fix_url_schema -DEFAULT_TIMEOUT_SECONDS: int = 60 -KEEPALIVE_FLAG: int = 1 -KEEPIDLE_RATE: int = 60 # seconds AUTH_CREDENTIALS_DEPRECATION_MESSAGE = """ Passing connection credentials directly to the `connect` function is deprecated. Pass the `Auth` object instead. @@ -120,49 +125,7 @@ async def _get_database_default_engine_url( raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") -def _validate_engine_name_and_url( - engine_name: Optional[str], engine_url: Optional[str] -) -> None: - if engine_name and engine_url: - raise ConfigurationError( - "Both engine_name and engine_url are provided. Provide only one to connect." 
- ) - - -def _get_auth( - username: Optional[str], - password: Optional[str], - access_token: Optional[str], - use_token_cache: bool, -) -> Auth: - """Create `Auth` class based on provided credentials. - - If `access_token` is provided, it's used for `Auth` creation. - Otherwise, username/password are used. - - Returns: - Auth: `auth object` - - Raises: - `ConfigurationError`: Invalid combination of credentials provided - - """ - if not access_token: - if not username or not password: - raise ConfigurationError( - "Neither username/password nor access_token are provided. Provide one" - " to authenticate." - ) - return UsernamePassword(username, password, use_token_cache) - if username or password: - raise ConfigurationError( - "Username/password and access_token are both provided. Provide only one" - " to authenticate." - ) - return Token(access_token) - - -def async_connect_factory(connection_class: Type) -> Callable: +def async_connect_factory(connection_class: Type, async_class=False) -> Callable: async def connect_inner( database: str = None, username: Optional[str] = None, @@ -293,17 +256,32 @@ async def connect_tcp( return stream -class BaseConnection: - client_class: type - cursor_class: type - __slots__ = ( - "_client", - "_cursors", - "database", - "engine_url", - "api_endpoint", - "_is_closed", - ) +class Connection(BaseConnection): + """ + Firebolt asynchronous database connection class. Implements `PEP 249`_. + + Args: + `engine_url`: Firebolt database engine REST API url + `database`: Firebolt database name + `username`: Firebolt account username + `password`: Firebolt account password + `api_endpoint`: Optional. Firebolt API endpoint used for authentication + `connector_versions`: Optional. Tuple of connector name and version, or + a list of tuples of your connector stack. Useful for tracking custom + connector usage. + + Note: + Firebolt does not support transactions, + so commit and rollback methods are not implemented. + + .. 
_PEP 249: + https://www.python.org/dev/peps/pep-0249/ + + """ + + __slots__ = BaseConnection.__slots__ + + cursor_class = Cursor def __init__( self, @@ -313,11 +291,14 @@ def __init__( api_endpoint: str = DEFAULT_API_URL, additional_parameters: Dict[str, Any] = {}, ): + super().__init__( + engine_url, database, auth, api_endpoint, additional_parameters + ) + user_drivers = additional_parameters.get("user_drivers", []) + user_clients = additional_parameters.get("user_clients", []) # Override tcp keepalive settings for connection transport = AsyncHTTPTransport() transport._pool._network_backend = OverriddenHttpBackend() - user_drivers = additional_parameters.get("user_drivers", []) - user_clients = additional_parameters.get("user_clients", []) self._client = AsyncClient( auth=auth, base_url=engine_url, @@ -326,25 +307,19 @@ def __init__( transport=transport, headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)}, ) - self.api_endpoint = api_endpoint - self.engine_url = engine_url - self.database = database - self._cursors: List[BaseCursor] = [] - self._is_closed = False - def _cursor(self, **kwargs: Any) -> BaseCursor: - """ - Create new cursor object. 
- """ + def cursor(self) -> Cursor: + c = super()._cursor() + assert isinstance(c, Cursor) # typecheck + return c + # Context manager support + async def __aenter__(self) -> Connection: if self.closed: - raise ConnectionClosedError("Unable to create cursor: connection closed.") - - c = self.cursor_class(self._client, self, **kwargs) - self._cursors.append(c) - return c + raise ConnectionClosedError("Connection is already closed.") + return self - async def _aclose(self) -> None: + async def aclose(self) -> None: """Close connection and all underlying cursors.""" if self.closed: return @@ -360,67 +335,10 @@ async def _aclose(self) -> None: await self._client.aclose() self._is_closed = True - @property - def closed(self) -> bool: - """`True` if connection is closed; `False` otherwise.""" - return self._is_closed - - def _remove_cursor(self, cursor: Cursor) -> None: - # This way it's atomic - try: - self._cursors.remove(cursor) - except ValueError: - pass - - def commit(self) -> None: - """Does nothing since Firebolt doesn't have transactions.""" - - if self.closed: - raise ConnectionClosedError("Unable to commit: Connection closed.") - - -class Connection(BaseConnection): - """ - Firebolt asynchronous database connection class. Implements `PEP 249`_. - - Args: - `engine_url`: Firebolt database engine REST API url - `database`: Firebolt database name - `username`: Firebolt account username - `password`: Firebolt account password - `api_endpoint`: Optional. Firebolt API endpoint used for authentication - `connector_versions`: Optional. Tuple of connector name and version, or - a list of tuples of your connector stack. Useful for tracking custom - connector usage. - - Note: - Firebolt does not support transactions, - so commit and rollback methods are not implemented. - - .. 
_PEP 249: - https://www.python.org/dev/peps/pep-0249/ - - """ - - cursor_class = Cursor - - aclose = BaseConnection._aclose - - def cursor(self) -> Cursor: - c = super()._cursor() - assert isinstance(c, Cursor) # typecheck - return c - - # Context manager support - async def __aenter__(self) -> Connection: - if self.closed: - raise ConnectionClosedError("Connection is already closed.") - return self - async def __aexit__( self, exc_type: type, exc_val: Exception, exc_tb: TracebackType ) -> None: - await self._aclose() + await self.aclose() connect = async_connect_factory(Connection) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 6bc01849acb..8dfbabedbe5 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -3,14 +3,11 @@ import logging import re import time -from enum import Enum from functools import wraps -from types import TracebackType from typing import ( TYPE_CHECKING, Any, Callable, - Dict, List, Optional, Sequence, @@ -20,24 +17,25 @@ from aiorwlock import RWLock from httpx import Response, codes -from pydantic import BaseModel -from firebolt.async_db._types import ( +from firebolt.async_db.util import is_db_available, is_engine_running +from firebolt.common._types import ( ColType, Column, ParameterType, RawColType, SetParameter, - parse_type, - parse_value, split_format_sql, ) -from firebolt.async_db.util import is_db_available, is_engine_running -from firebolt.client import AsyncClient +from firebolt.common.base_cursor import ( + BaseCursor, + CursorState, + QueryStatus, + Statistics, +) from firebolt.utils.exception import ( AsyncExecutionUnavailableError, CursorClosedError, - DataError, EngineNotRunningError, FireboltDatabaseError, OperationalError, @@ -46,7 +44,7 @@ ) if TYPE_CHECKING: - from firebolt.async_db.connection import Connection + pass logger = logging.getLogger(__name__) @@ -54,40 +52,6 @@ JSON_OUTPUT_FORMAT = "JSON_Compact" -class CursorState(Enum): - NONE = 1 - ERROR = 
2 - DONE = 3 - CLOSED = 4 - - -class QueryStatus(Enum): - """Enumeration of query responses on server-side async queries.""" - - RUNNING = 1 - ENDED_SUCCESSFULLY = 2 - ENDED_UNSUCCESSFULLY = 3 - NOT_READY = 4 - STARTED_EXECUTION = 5 - PARSE_ERROR = 6 - CANCELED_EXECUTION = 7 - EXECUTION_ERROR = 8 - - -class Statistics(BaseModel): - """ - Class for query execution statistics. - """ - - elapsed: float - rows_read: int - bytes_read: int - time_before_execution: float - time_to_execute: float - scanned_bytes_cache: Optional[float] - scanned_bytes_storage: Optional[float] - - def check_not_closed(func: Callable) -> Callable: """(Decorator) ensure cursor is not closed before calling method.""" @@ -115,142 +79,26 @@ def inner(self: Cursor, *args: Any, **kwargs: Any) -> Any: return inner -class BaseCursor: - __slots__ = ( - "connection", - "_arraysize", - "_client", - "_state", - "_descriptions", - "_statistics", - "_rowcount", - "_rows", - "_idx", - "_idx_lock", - "_row_sets", - "_next_set_idx", - "_set_parameters", - "_query_id", - ) - - default_arraysize = 1 - - def __init__(self, client: AsyncClient, connection: Connection): - self.connection = connection - self._client = client - self._arraysize = self.default_arraysize - # These fields initialized here for type annotations purpose - self._rows: Optional[List[List[RawColType]]] = None - self._descriptions: Optional[List[Column]] = None - self._statistics: Optional[Statistics] = None - self._row_sets: List[ - Tuple[ - int, - Optional[List[Column]], - Optional[Statistics], - Optional[List[List[RawColType]]], - ] - ] = [] - self._set_parameters: Dict[str, Any] = dict() - self._rowcount = -1 - self._idx = 0 - self._next_set_idx = 0 - self._query_id = "" - self._reset() - - def __del__(self) -> None: - self.close() - - @property # type: ignore - @check_not_closed - def description(self) -> Optional[List[Column]]: - """ - Provides information about a single result row of a query. 
- - Attributes: - * ``name`` - * ``type_code`` - * ``display_size`` - * ``internal_size`` - * ``precision`` - * ``scale`` - * ``null_ok`` - """ - return self._descriptions - - @property # type: ignore - @check_not_closed - def statistics(self) -> Optional[Statistics]: - """Query execution statistics returned by the backend.""" - return self._statistics - - @property # type: ignore - @check_not_closed - def rowcount(self) -> int: - """The number of rows produced by last query.""" - return self._rowcount - - @property # type: ignore - @check_not_closed - def query_id(self) -> str: - """The query id of a query executed asynchronously.""" - return self._query_id - - @property - def arraysize(self) -> int: - """Default number of rows returned by fetchmany.""" - return self._arraysize - - @arraysize.setter - def arraysize(self, value: int) -> None: - if not isinstance(value, int): - raise TypeError( - "Invalid arraysize value type, expected int," - f" got {type(value).__name__}" - ) - self._arraysize = value +class Cursor(BaseCursor): + """ + Executes async queries to Firebolt Database. + Should not be created directly; + use :py:func:`connection.cursor ` - @property - def closed(self) -> bool: - """True if connection is closed, False otherwise.""" - return self._state == CursorState.CLOSED + Args: + description: Information about a single result row. + rowcount: The number of rows produced by last query. + closed: True if connection is closed; False otherwise. + arraysize: Read/Write, specifies the number of rows to fetch at a time + with the :py:func:`fetchmany` method. 
- def close(self) -> None: - """Terminate an ongoing query (if any) and mark connection as closed.""" - self._state = CursorState.CLOSED - # remove typecheck skip after connection is implemented - self.connection._remove_cursor(self) # type: ignore + """ - @check_not_closed - @check_query_executed - def nextset(self) -> Optional[bool]: - """ - Skip to the next available set, discarding any remaining rows - from the current set. - Returns True if operation was successful; - None if there are no more sets to retrive. - """ - return self._pop_next_set() + __slots__ = BaseCursor.__slots__ + ("_async_query_lock",) - def _pop_next_set(self) -> Optional[bool]: - """ - Same functionality as .nextset, but doesn't check that query has been executed. - """ - if self._next_set_idx >= len(self._row_sets): - return None - ( - self._rowcount, - self._descriptions, - self._statistics, - self._rows, - ) = self._row_sets[self._next_set_idx] - self._idx = 0 - self._next_set_idx += 1 - return True - - def flush_parameters(self) -> None: - """Cleanup all previously set parameters""" - self._set_parameters = dict() + def __init__(self, *args: Any, **kwargs: Any) -> None: + self._async_query_lock = RWLock() + super().__init__(*args, **kwargs) async def _raise_if_error(self, resp: Response) -> None: """Raise a proper error if any""" @@ -275,63 +123,6 @@ async def _raise_if_error(self, resp: Response) -> None: ) resp.raise_for_status() - def _reset(self) -> None: - """Clear all data stored from previous query.""" - self._state = CursorState.NONE - self._rows = None - self._descriptions = None - self._statistics = None - self._rowcount = -1 - self._idx = 0 - self._row_sets = [] - self._next_set_idx = 0 - self._query_id = "" - - def _row_set_from_response( - self, response: Response - ) -> Tuple[ - int, - Optional[List[Column]], - Optional[Statistics], - Optional[List[List[RawColType]]], - ]: - """Fetch information about executed query from http response.""" - - # Empty response is returned 
for insert query - if response.headers.get("content-length", "") == "0": - return (-1, None, None, None) - try: - # Skip parsing floats to properly parse them later - query_data = response.json(parse_float=str) - rowcount = int(query_data["rows"]) - descriptions: Optional[List[Column]] = [ - Column(d["name"], parse_type(d["type"]), None, None, None, None, None) - for d in query_data["meta"] - ] - if not descriptions: - descriptions = None - statistics = Statistics(**query_data["statistics"]) - # Parse data during fetch - rows = query_data["data"] - return (rowcount, descriptions, statistics, rows) - except (KeyError, ValueError) as err: - raise DataError(f"Invalid query data format: {str(err)}") - - def _append_row_set( - self, - row_set: Tuple[ - int, - Optional[List[Column]], - Optional[Statistics], - Optional[List[List[RawColType]]], - ], - ) -> None: - """Store information about executed query.""" - self._row_sets.append(row_set) - if self._next_set_idx == 0: - # Populate values for first set - self._pop_next_set() - async def _api_request( self, query: Optional[str] = "", @@ -383,33 +174,6 @@ async def _validate_set_parameter(self, parameter: SetParameter) -> None: # set parameter passed validation self._set_parameters[parameter.name] = parameter.value - def _validate_server_side_async_settings( - self, - parameters: Sequence[Sequence[ParameterType]], - queries: List[Union[SetParameter, str]], - skip_parsing: bool = False, - async_execution: Optional[bool] = False, - ) -> None: - if async_execution and self._set_parameters.get("use_standard_sql", "1") == "0": - raise AsyncExecutionUnavailableError( - "It is not possible to execute queries asynchronously if " - "use_standard_sql=0." - ) - if parameters and skip_parsing: - logger.warning( - "Query formatting parameters are provided but skip_parsing " - "is specified. They will be ignored." 
- ) - non_set_queries = 0 - for query in queries: - if type(query) is not SetParameter: - non_set_queries += 1 - if non_set_queries > 1 and async_execution: - raise AsyncExecutionUnavailableError( - "It is not possible to execute multi-statement " - "queries asynchronously." - ) - async def _do_execute( self, raw_query: str, @@ -568,77 +332,19 @@ async def executemany( else: return self.rowcount - def _parse_row(self, row: List[RawColType]) -> List[ColType]: - """Parse a single data row based on query column types.""" - assert len(row) == len(self.description) - return [ - parse_value(col, self.description[i].type_code) for i, col in enumerate(row) - ] - - def _get_next_range(self, size: int) -> Tuple[int, int]: - """ - Return range of next rows of size (if possible), - and update _idx to point to the end of this range - """ - - if self._rows is None: - # No elements to take - raise DataError("no rows to fetch") - - left = self._idx - right = min(self._idx + size, len(self._rows)) - self._idx = right - return left, right - - @check_not_closed - @check_query_executed - def fetchone(self) -> Optional[List[ColType]]: - """Fetch the next row of a query result set.""" - left, right = self._get_next_range(1) - if left == right: - # We are out of elements - return None - assert self._rows is not None - return self._parse_row(self._rows[left]) - + # Iteration support @check_not_closed @check_query_executed - def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: - """ - Fetch the next set of rows of a query result; - cursor.arraysize is default size. 
- """ - size = size if size is not None else self.arraysize - left, right = self._get_next_range(size) - assert self._rows is not None - rows = self._rows[left:right] - return [self._parse_row(row) for row in rows] + def __aiter__(self) -> Cursor: + return self @check_not_closed @check_query_executed - def fetchall(self) -> List[List[ColType]]: - """Fetch all remaining rows of a query result.""" - left, right = self._get_next_range(self.rowcount) - assert self._rows is not None - rows = self._rows[left:right] - return [self._parse_row(row) for row in rows] - - @check_not_closed - def setinputsizes(self, sizes: List[int]) -> None: - """Predefine memory areas for query parameters (does nothing).""" - - @check_not_closed - def setoutputsize(self, size: int, column: Optional[int] = None) -> None: - """Set a column buffer size for fetches of large columns (does nothing).""" - - @check_not_closed - async def cancel(self, query_id: str) -> None: - """Cancel a server-side async query.""" - await self._api_request( - parameters={"query_id": query_id}, - path="cancel", - use_set_parameters=False, - ) + async def __anext__(self) -> List[ColType]: + row = await self.fetchone() + if row is None: + raise StopAsyncIteration + return row @check_not_closed async def get_status(self, query_id: str) -> QueryStatus: @@ -669,72 +375,20 @@ async def get_status(self, query_id: str) -> QueryStatus: return QueryStatus.NOT_READY return QueryStatus[resp_json["status"]] - # Context manager support @check_not_closed - def __enter__(self) -> BaseCursor: - return self - - def __exit__( - self, exc_type: type, exc_val: Exception, exc_tb: TracebackType - ) -> None: - self.close() - - -class Cursor(BaseCursor): - """ - Executes async queries to Firebolt Database. - Should not be created directly; - use :py:func:`connection.cursor ` - - Args: - description: Information about a single result row. - rowcount: The number of rows produced by last query. 
- closed: True if connection is closed; False otherwise. - arraysize: Read/Write, specifies the number of rows to fetch at a time - with the :py:func:`fetchmany` method. - - """ - - __slots__ = BaseCursor.__slots__ + ("_async_query_lock",) - - def __init__(self, *args: Any, **kwargs: Any) -> None: - self._async_query_lock = RWLock() - super().__init__(*args, **kwargs) - - @wraps(BaseCursor.execute) - async def execute( - self, - query: str, - parameters: Optional[Sequence[ParameterType]] = None, - skip_parsing: bool = False, - async_execution: Optional[bool] = False, - ) -> Union[int, str]: - async with self._async_query_lock.writer: - return await super().execute( - query, parameters, skip_parsing, async_execution - ) - - @wraps(BaseCursor.executemany) - async def executemany( - self, - query: str, - parameters_seq: Sequence[Sequence[ParameterType]], - async_execution: Optional[bool] = False, - ) -> int: - """ - Prepare and execute a database query against all parameter - sequences provided. 
- """ - async with self._async_query_lock.writer: - return await super().executemany(query, parameters_seq, async_execution) + async def cancel(self, query_id: str) -> None: + """Cancel a server-side async query.""" + await self._api_request( + parameters={"query_id": query_id}, + path="cancel", + use_set_parameters=False, + ) - @wraps(BaseCursor.fetchone) async def fetchone(self) -> Optional[List[ColType]]: async with self._async_query_lock.reader: """Fetch the next row of a query result set.""" return super().fetchone() - @wraps(BaseCursor.fetchmany) async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: async with self._async_query_lock.reader: """ @@ -743,27 +397,11 @@ async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: """ return super().fetchmany(size) - @wraps(BaseCursor.fetchall) async def fetchall(self) -> List[List[ColType]]: async with self._async_query_lock.reader: """Fetch all remaining rows of a query result.""" return super().fetchall() - @wraps(BaseCursor.nextset) async def nextset(self) -> None: async with self._async_query_lock.reader: return super().nextset() - - # Iteration support - @check_not_closed - @check_query_executed - def __aiter__(self) -> Cursor: - return self - - @check_not_closed - @check_query_executed - async def __anext__(self) -> List[ColType]: - row = await self.fetchone() - if row is None: - raise StopAsyncIteration - return row diff --git a/src/firebolt/async_db/_types.py b/src/firebolt/common/_types.py similarity index 100% rename from src/firebolt/async_db/_types.py rename to src/firebolt/common/_types.py diff --git a/src/firebolt/common/base_connection.py b/src/firebolt/common/base_connection.py new file mode 100644 index 00000000000..209c8268dbe --- /dev/null +++ b/src/firebolt/common/base_connection.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from firebolt.async_db.cursor import BaseCursor, Cursor +from 
firebolt.client import DEFAULT_API_URL +from firebolt.client.auth import Auth, Token, UsernamePassword +from firebolt.utils.exception import ConfigurationError, ConnectionClosedError + +KEEPALIVE_FLAG: int = 1 + +KEEPIDLE_RATE: int = 60 # seconds +DEFAULT_TIMEOUT_SECONDS: int = 60 + + +def _get_auth( + username: Optional[str], + password: Optional[str], + access_token: Optional[str], + use_token_cache: bool, +) -> Auth: + """Create `Auth` class based on provided credentials. + + If `access_token` is provided, it's used for `Auth` creation. + Otherwise, username/password are used. + + Returns: + Auth: `auth object` + + Raises: + `ConfigurationError`: Invalid combination of credentials provided + + """ + if not access_token: + if not username or not password: + raise ConfigurationError( + "Neither username/password nor access_token are provided. Provide one" + " to authenticate." + ) + return UsernamePassword(username, password, use_token_cache) + if username or password: + raise ConfigurationError( + "Username/password and access_token are both provided. Provide only one" + " to authenticate." + ) + return Token(access_token) + + +def _validate_engine_name_and_url( + engine_name: Optional[str], engine_url: Optional[str] +) -> None: + if engine_name and engine_url: + raise ConfigurationError( + "Both engine_name and engine_url are provided. Provide only one to connect." + ) + + +class BaseConnection: + client_class: type + cursor_class: type + __slots__ = ( + "_client", + "_cursors", + "database", + "engine_url", + "api_endpoint", + "_is_closed", + ) + + def __init__( + self, + engine_url: str, + database: str, + auth: Auth, + api_endpoint: str = DEFAULT_API_URL, + additional_parameters: Dict[str, Any] = {}, + ): + self.api_endpoint = api_endpoint + self.engine_url = engine_url + self.database = database + self._cursors: List[BaseCursor] = [] + self._is_closed = False + + def _cursor(self, **kwargs: Any) -> BaseCursor: + """ + Create new cursor object. 
+ """ + + if self.closed: + raise ConnectionClosedError("Unable to create cursor: connection closed.") + + c = self.cursor_class(self._client, self, **kwargs) + self._cursors.append(c) + return c + + @property + def closed(self) -> bool: + """`True` if connection is closed; `False` otherwise.""" + return self._is_closed + + def _remove_cursor(self, cursor: Cursor) -> None: + # This way it's atomic + try: + self._cursors.remove(cursor) + except ValueError: + pass + + def commit(self) -> None: + """Does nothing since Firebolt doesn't have transactions.""" + + if self.closed: + raise ConnectionClosedError("Unable to commit: Connection closed.") diff --git a/src/firebolt/common/base_cursor.py b/src/firebolt/common/base_cursor.py new file mode 100644 index 00000000000..0e7fa512642 --- /dev/null +++ b/src/firebolt/common/base_cursor.py @@ -0,0 +1,401 @@ +from __future__ import annotations + +import logging +from enum import Enum +from functools import wraps +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, +) + +from httpx import Response +from pydantic import BaseModel + +from firebolt.client import AsyncClient, Client +from firebolt.common._types import ( + ColType, + Column, + ParameterType, + RawColType, + SetParameter, + parse_type, + parse_value, +) +from firebolt.utils.exception import ( + AsyncExecutionUnavailableError, + CursorClosedError, + DataError, + QueryNotRunError, +) + +if TYPE_CHECKING: + from firebolt.common.base_connection import BaseConnection + +logger = logging.getLogger(__name__) + + +JSON_OUTPUT_FORMAT = "JSON_Compact" + + +class CursorState(Enum): + NONE = 1 + ERROR = 2 + DONE = 3 + CLOSED = 4 + + +class QueryStatus(Enum): + """Enumeration of query responses on server-side async queries.""" + + RUNNING = 1 + ENDED_SUCCESSFULLY = 2 + ENDED_UNSUCCESSFULLY = 3 + NOT_READY = 4 + STARTED_EXECUTION = 5 + PARSE_ERROR = 6 + CANCELED_EXECUTION = 7 + 
EXECUTION_ERROR = 8 + + +class Statistics(BaseModel): + """ + Class for query execution statistics. + """ + + elapsed: float + rows_read: int + bytes_read: int + time_before_execution: float + time_to_execute: float + scanned_bytes_cache: Optional[float] + scanned_bytes_storage: Optional[float] + + +def check_not_closed(func: Callable) -> Callable: + """(Decorator) ensure cursor is not closed before calling method.""" + + @wraps(func) + def inner(self: BaseCursor, *args: Any, **kwargs: Any) -> Any: + if self.closed: + raise CursorClosedError(method_name=func.__name__) + return func(self, *args, **kwargs) + + return inner + + +def check_query_executed(func: Callable) -> Callable: + """ + (Decorator) ensure that some query has been executed before + calling cursor method. + """ + + @wraps(func) + def inner(self: BaseCursor, *args: Any, **kwargs: Any) -> Any: + if self._state == CursorState.NONE: + raise QueryNotRunError(method_name=func.__name__) + return func(self, *args, **kwargs) + + return inner + + +class BaseCursor: + __slots__ = ( + "connection", + "_arraysize", + "_client", + "_state", + "_descriptions", + "_statistics", + "_rowcount", + "_rows", + "_idx", + "_idx_lock", + "_row_sets", + "_next_set_idx", + "_set_parameters", + "_query_id", + ) + + default_arraysize = 1 + + def __init__(self, client: Union[AsyncClient, Client], connection: BaseConnection): + self.connection = connection + self._client = client + self._arraysize = self.default_arraysize + # These fields initialized here for type annotations purpose + self._rows: Optional[List[List[RawColType]]] = None + self._descriptions: Optional[List[Column]] = None + self._statistics: Optional[Statistics] = None + self._row_sets: List[ + Tuple[ + int, + Optional[List[Column]], + Optional[Statistics], + Optional[List[List[RawColType]]], + ] + ] = [] + self._set_parameters: Dict[str, Any] = dict() + self._rowcount = -1 + self._idx = 0 + self._next_set_idx = 0 + self._query_id = "" + self._reset() + + def 
__del__(self) -> None: + self.close() + + @property # type: ignore + @check_not_closed + def description(self) -> Optional[List[Column]]: + """ + Provides information about a single result row of a query. + + Attributes: + * ``name`` + * ``type_code`` + * ``display_size`` + * ``internal_size`` + * ``precision`` + * ``scale`` + * ``null_ok`` + """ + return self._descriptions + + @property # type: ignore + @check_not_closed + def statistics(self) -> Optional[Statistics]: + """Query execution statistics returned by the backend.""" + return self._statistics + + @property # type: ignore + @check_not_closed + def rowcount(self) -> int: + """The number of rows produced by last query.""" + return self._rowcount + + @property # type: ignore + @check_not_closed + def query_id(self) -> str: + """The query id of a query executed asynchronously.""" + return self._query_id + + @property + def arraysize(self) -> int: + """Default number of rows returned by fetchmany.""" + return self._arraysize + + @arraysize.setter + def arraysize(self, value: int) -> None: + if not isinstance(value, int): + raise TypeError( + "Invalid arraysize value type, expected int," + f" got {type(value).__name__}" + ) + self._arraysize = value + + @property + def closed(self) -> bool: + """True if connection is closed, False otherwise.""" + return self._state == CursorState.CLOSED + + def close(self) -> None: + """Terminate an ongoing query (if any) and mark connection as closed.""" + self._state = CursorState.CLOSED + # remove typecheck skip after connection is implemented + self.connection._remove_cursor(self) # type: ignore + + @check_not_closed + @check_query_executed + def nextset(self) -> Optional[bool]: + """ + Skip to the next available set, discarding any remaining rows + from the current set. + Returns True if operation was successful; + None if there are no more sets to retrive. 
+ """ + return self._pop_next_set() + + def _pop_next_set(self) -> Optional[bool]: + """ + Same functionality as .nextset, but doesn't check that query has been executed. + """ + if self._next_set_idx >= len(self._row_sets): + return None + ( + self._rowcount, + self._descriptions, + self._statistics, + self._rows, + ) = self._row_sets[self._next_set_idx] + self._idx = 0 + self._next_set_idx += 1 + return True + + def flush_parameters(self) -> None: + """Cleanup all previously set parameters""" + self._set_parameters = dict() + + def _reset(self) -> None: + """Clear all data stored from previous query.""" + self._state = CursorState.NONE + self._rows = None + self._descriptions = None + self._statistics = None + self._rowcount = -1 + self._idx = 0 + self._row_sets = [] + self._next_set_idx = 0 + self._query_id = "" + + def _row_set_from_response( + self, response: Response + ) -> Tuple[ + int, + Optional[List[Column]], + Optional[Statistics], + Optional[List[List[RawColType]]], + ]: + """Fetch information about executed query from http response.""" + + # Empty response is returned for insert query + if response.headers.get("content-length", "") == "0": + return (-1, None, None, None) + try: + # Skip parsing floats to properly parse them later + query_data = response.json(parse_float=str) + rowcount = int(query_data["rows"]) + descriptions: Optional[List[Column]] = [ + Column(d["name"], parse_type(d["type"]), None, None, None, None, None) + for d in query_data["meta"] + ] + if not descriptions: + descriptions = None + statistics = Statistics(**query_data["statistics"]) + # Parse data during fetch + rows = query_data["data"] + return (rowcount, descriptions, statistics, rows) + except (KeyError, ValueError) as err: + raise DataError(f"Invalid query data format: {str(err)}") + + def _append_row_set( + self, + row_set: Tuple[ + int, + Optional[List[Column]], + Optional[Statistics], + Optional[List[List[RawColType]]], + ], + ) -> None: + """Store information about 
executed query.""" + self._row_sets.append(row_set) + if self._next_set_idx == 0: + # Populate values for first set + self._pop_next_set() + + def _validate_server_side_async_settings( + self, + parameters: Sequence[Sequence[ParameterType]], + queries: List[Union[SetParameter, str]], + skip_parsing: bool = False, + async_execution: Optional[bool] = False, + ) -> None: + if async_execution and self._set_parameters.get("use_standard_sql", "1") == "0": + raise AsyncExecutionUnavailableError( + "It is not possible to execute queries asynchronously if " + "use_standard_sql=0." + ) + if parameters and skip_parsing: + logger.warning( + "Query formatting parameters are provided but skip_parsing " + "is specified. They will be ignored." + ) + non_set_queries = 0 + for query in queries: + if type(query) is not SetParameter: + non_set_queries += 1 + if non_set_queries > 1 and async_execution: + raise AsyncExecutionUnavailableError( + "It is not possible to execute multi-statement " + "queries asynchronously." 
+ ) + + def _parse_row(self, row: List[RawColType]) -> List[ColType]: + """Parse a single data row based on query column types.""" + assert len(row) == len(self.description) + return [ + parse_value(col, self.description[i].type_code) for i, col in enumerate(row) + ] + + def _get_next_range(self, size: int) -> Tuple[int, int]: + """ + Return range of next rows of size (if possible), + and update _idx to point to the end of this range + """ + + if self._rows is None: + # No elements to take + raise DataError("no rows to fetch") + + left = self._idx + right = min(self._idx + size, len(self._rows)) + self._idx = right + return left, right + + @check_not_closed + @check_query_executed + def fetchone(self) -> Optional[List[ColType]]: + """Fetch the next row of a query result set.""" + left, right = self._get_next_range(1) + if left == right: + # We are out of elements + return None + assert self._rows is not None + return self._parse_row(self._rows[left]) + + @check_not_closed + @check_query_executed + def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: + """ + Fetch the next set of rows of a query result; + cursor.arraysize is default size. 
+ """ + size = size if size is not None else self.arraysize + left, right = self._get_next_range(size) + assert self._rows is not None + rows = self._rows[left:right] + return [self._parse_row(row) for row in rows] + + @check_not_closed + @check_query_executed + def fetchall(self) -> List[List[ColType]]: + """Fetch all remaining rows of a query result.""" + left, right = self._get_next_range(self.rowcount) + assert self._rows is not None + rows = self._rows[left:right] + return [self._parse_row(row) for row in rows] + + @check_not_closed + def setinputsizes(self, sizes: List[int]) -> None: + """Predefine memory areas for query parameters (does nothing).""" + + @check_not_closed + def setoutputsize(self, size: int, column: Optional[int] = None) -> None: + """Set a column buffer size for fetches of large columns (does nothing).""" + + # Context manager support + @check_not_closed + def __enter__(self) -> BaseCursor: + return self + + def __exit__( + self, exc_type: type, exc_val: Exception, exc_tb: TracebackType + ) -> None: + self.close() diff --git a/src/firebolt/db/__init__.py b/src/firebolt/db/__init__.py index 8ee02c1448f..fcf6ae53c80 100644 --- a/src/firebolt/db/__init__.py +++ b/src/firebolt/db/__init__.py @@ -1,4 +1,4 @@ -from firebolt.async_db._types import ( +from firebolt.common._types import ( ARRAY, BINARY, DATETIME, diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 10fa91309e1..ac3c1c9a1a0 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -1,20 +1,254 @@ from __future__ import annotations -from functools import wraps +import logging +import socket +from json import JSONDecodeError from types import TracebackType -from typing import Any +from typing import Any, Callable, Dict, Optional, Type from warnings import warn +from httpcore.backends.base import NetworkStream +from httpcore.backends.sync import SyncBackend +from httpx import HTTPStatusError, HTTPTransport, RequestError, Timeout from 
readerwriterlock.rwlock import RWLockWrite -from firebolt.async_db.connection import BaseConnection as AsyncBaseConnection -from firebolt.async_db.connection import async_connect_factory +from firebolt.client import DEFAULT_API_URL, Client +from firebolt.client.auth import Auth +from firebolt.common.base_connection import ( + DEFAULT_TIMEOUT_SECONDS, + KEEPALIVE_FLAG, + KEEPIDLE_RATE, + BaseConnection, + _get_auth, + _validate_engine_name_and_url, +) +from firebolt.common.settings import AUTH_CREDENTIALS_DEPRECATION_MESSAGE from firebolt.db.cursor import Cursor -from firebolt.utils.exception import ConnectionClosedError -from firebolt.utils.util import async_to_sync +from firebolt.utils.exception import ( + ConfigurationError, + ConnectionClosedError, + FireboltEngineError, + InterfaceError, +) +from firebolt.utils.urls import ( + ACCOUNT_ENGINE_ID_BY_NAME_URL, + ACCOUNT_ENGINE_URL, + ACCOUNT_ENGINE_URL_BY_DATABASE_NAME, +) +from firebolt.utils.usage_tracker import get_user_agent_header +from firebolt.utils.util import fix_url_schema +logger = logging.getLogger(__name__) -class Connection(AsyncBaseConnection): + +def _resolve_engine_url( + engine_name: str, + auth: Auth, + api_endpoint: str, + account_name: Optional[str] = None, +) -> str: + with Client( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + account_id = client.account_id + url = ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) + try: + response = client.get( + url=url, + params={"engine_name": engine_name}, + ) + response.raise_for_status() + engine_id = response.json()["engine_id"]["engine_id"] + url = ACCOUNT_ENGINE_URL.format(account_id=account_id, engine_id=engine_id) + response = client.get(url=url) + response.raise_for_status() + return response.json()["engine"]["endpoint"] + except HTTPStatusError as e: + # Engine error would be 404. 
+ if e.response.status_code != 404: + raise InterfaceError( + f"Error {e.__class__.__name__}: Unable to retrieve engine " + f"endpoint {url}." + ) + # Once this is point is reached we've already authenticated with + # the backend so it's safe to assume the cause of the error is + # missing engine. + raise FireboltEngineError(f"Firebolt engine {engine_name} does not exist.") + except (JSONDecodeError, RequestError, RuntimeError, HTTPStatusError) as e: + raise InterfaceError( + f"Error {e.__class__.__name__}: " + f"Unable to retrieve engine endpoint {url}." + ) + + +def _get_database_default_engine_url( + database: str, + auth: Auth, + api_endpoint: str, + account_name: Optional[str] = None, +) -> str: + with Client( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + try: + account_id = client.account_id + response = client.get( + url=ACCOUNT_ENGINE_URL_BY_DATABASE_NAME.format(account_id=account_id), + params={"database_name": database}, + ) + response.raise_for_status() + return response.json()["engine_url"] + except ( + JSONDecodeError, + RequestError, + RuntimeError, + HTTPStatusError, + KeyError, + ) as e: + raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") + + +def sync_connect_factory(connection_class: Type) -> Callable: + def connect_inner( + database: str = None, + username: Optional[str] = None, + password: Optional[str] = None, + access_token: Optional[str] = None, + auth: Auth = None, + engine_name: Optional[str] = None, + engine_url: Optional[str] = None, + account_name: Optional[str] = None, + api_endpoint: str = DEFAULT_API_URL, + use_token_cache: bool = True, + additional_parameters: Dict[str, Any] = {}, + ) -> Connection: + """Connect to Firebolt database. 
+ + Args: + `database` (str): Name of the database to connect to + `username` (Optional[str]): User name to use for authentication (Deprecated) + `password` (Optional[str]): Password to use for authentication (Deprecated) + `access_token` (Optional[str]): Authentication token to use instead of + credentials (Deprecated) + `auth` (Auth): Authentication object. + `engine_name` (Optional[str]): Name of the engine to connect to + `engine_url` (Optional[str]): The engine endpoint to use + `account_name` (Optional[str]): For customers with multiple accounts; + if none, default is used + `api_endpoint` (str): Firebolt API endpoint. Used for authentication + `use_token_cache` (bool): Cache authentication token in filesystem + Default: True + `additional_parameters` (Optional[Dict]): Dictionary of less widely-used + arguments for connection + + Note: + Providing both `engine_name` and `engine_url` will result in an error + + """ + # These parameters are optional in function signature + # but are required to connect. + # PEP 249 recommends making them kwargs.
+ if not database: + raise ConfigurationError("database name is required to connect.") + + _validate_engine_name_and_url(engine_name, engine_url) + + if not auth: + if any([username, password, access_token, api_endpoint, use_token_cache]): + logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) + auth = _get_auth(username, password, access_token, use_token_cache) + else: + raise ConfigurationError("No authentication provided.") + api_endpoint = fix_url_schema(api_endpoint) + + # Mypy checks, this should never happen + assert database is not None + + if not engine_name and not engine_url: + engine_url = _get_database_default_engine_url( + database=database, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + + elif engine_name: + engine_url = _resolve_engine_url( + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + elif account_name: + # In above if branches account name is validated since it's used to + # resolve or get an engine url. + # We need to manually validate account_name if none of the above + # cases are triggered. + with Client( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + ) as client: + client.account_id # TODO: is this correct? + + assert engine_url is not None + + engine_url = fix_url_schema(engine_url) + return connection_class( + engine_url, database, auth, api_endpoint, additional_parameters + ) + + return connect_inner + + +# TODO: verify new httpx has not improved this +class OverriddenHttpBackend(SyncBackend): + """ + `OverriddenHttpBackend` is a short-term solution for the TCP + connection idle timeout issue described in the following article: + https://docs.aws.amazon.com/elasticloadbalancing/latest/network/network-load-balancers.html#connection-idle-timeout + Since httpx creates a connection right before executing a request, the + backend must be overridden to set the socket to `KEEPALIVE` + and `KEEPIDLE` settings. 
+ """ + + def connect_tcp( + self, + host: str, + port: int, + timeout: Optional[float] = None, + local_address: Optional[str] = None, + ) -> NetworkStream: + stream = super().connect_tcp( + host, port, timeout=timeout, local_address=local_address + ) + # Enable keepalive + stream.get_extra_info("socket").setsockopt( + socket.SOL_SOCKET, socket.SO_KEEPALIVE, KEEPALIVE_FLAG + ) + # MacOS does not have TCP_KEEPIDLE + if hasattr(socket, "TCP_KEEPIDLE"): + keepidle = socket.TCP_KEEPIDLE + else: + keepidle = 0x10 # TCP_KEEPALIVE on mac + + # Set keepalive to 60 seconds + stream.get_extra_info("socket").setsockopt( + socket.IPPROTO_TCP, keepidle, KEEPIDLE_RATE + ) + return stream + + +class Connection(BaseConnection): """ Firebolt database connection class. Implements PEP-249. @@ -31,12 +265,34 @@ class Connection(AsyncBaseConnection): are not implemented. """ - __slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock",) + __slots__ = BaseConnection.__slots__ + ("_closing_lock",) cursor_class = Cursor - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__( + self, + engine_url: str, + database: str, + auth: Auth, + api_endpoint: str = DEFAULT_API_URL, + additional_parameters: Dict[str, Any] = {}, + ): + super().__init__( + engine_url, database, auth, api_endpoint, additional_parameters + ) + user_drivers = additional_parameters.get("user_drivers", []) + user_clients = additional_parameters.get("user_clients", []) + # Override tcp keepalive settings for connection + transport = HTTPTransport() + transport._pool._network_backend = OverriddenHttpBackend() + self._client = Client( + auth=auth, + base_url=engine_url, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None), + transport=transport, + headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)}, + ) # Holding this lock for write means that connection is closing itself. 
# cursor() should hold this lock for read to read/write state self._closing_lock = RWLockWrite() @@ -47,10 +303,20 @@ def cursor(self) -> Cursor: assert isinstance(c, Cursor) # typecheck return c - @wraps(AsyncBaseConnection._aclose) def close(self) -> None: - with self._closing_lock.gen_wlock(): - async_to_sync(self._aclose)() + if self.closed: + return + + # self._cursors is going to be changed during closing cursors + # after this point no cursors would be added to _cursors, only removed since + # closing lock is held, and later connection will be marked as closed + cursors = self._cursors[:] + for c in cursors: + # Here c can already be closed by another thread, + # but it shouldn't raise an error in this case + c.close() + self._client.close() + self._is_closed = True # Context manager support def __enter__(self) -> Connection: @@ -68,4 +334,4 @@ def __del__(self) -> None: warn(f"Unclosed {self!r}", UserWarning) -connect = async_to_sync(async_connect_factory(Connection)) +connect = sync_connect_factory(Connection) diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index 96e1f976def..874512638cc 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -1,23 +1,51 @@ from __future__ import annotations -from functools import wraps +import logging +import re +import time from threading import Lock from typing import Any, Generator, List, Optional, Sequence, Tuple, Union +from httpx import Response, codes from readerwriterlock.rwlock import RWLockWrite -from firebolt.async_db._types import ColType -from firebolt.async_db.cursor import BaseCursor as AsyncBaseCursor -from firebolt.async_db.cursor import ( +from firebolt.common._types import ( + ColType, + Column, ParameterType, + RawColType, + SetParameter, + split_format_sql, +) +from firebolt.common.base_cursor import ( + JSON_OUTPUT_FORMAT, + BaseCursor, + CursorState, QueryStatus, + Statistics, check_not_closed, check_query_executed, ) -from firebolt.utils.util import 
async_to_sync +# from firebolt.db.cursor import ( +# ParameterType, +# QueryStatus, +# check_not_closed, +# check_query_executed, +# ) +from firebolt.db.util import is_db_available, is_engine_running +from firebolt.utils.exception import ( + AsyncExecutionUnavailableError, + EngineNotRunningError, + FireboltDatabaseError, + OperationalError, + ProgrammingError, +) + +logger = logging.getLogger(__name__) -class Cursor(AsyncBaseCursor): + +class Cursor(BaseCursor): """ Class, responsible for executing queries to Firebolt Database. Should not be created directly, @@ -31,7 +59,7 @@ class Cursor(AsyncBaseCursor): with the :py:func:`fetchmany` method """ - __slots__ = AsyncBaseCursor.__slots__ + ( + __slots__ = BaseCursor.__slots__ + ( "_query_lock", "_idx_lock", ) @@ -41,7 +69,158 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._idx_lock = Lock() super().__init__(*args, **kwargs) - @wraps(AsyncBaseCursor.execute) + def _raise_if_error(self, resp: Response) -> None: + """Raise a proper error if any""" + if resp.status_code == codes.INTERNAL_SERVER_ERROR: + raise OperationalError( + f"Error executing query:\n{resp.read().decode('utf-8')}" + ) + if resp.status_code == codes.FORBIDDEN: + if not is_db_available(self.connection, self.connection.database): + raise FireboltDatabaseError( + f"Database {self.connection.database} does not exist" + ) + raise ProgrammingError(resp.read().decode("utf-8")) + if ( + resp.status_code == codes.SERVICE_UNAVAILABLE + or resp.status_code == codes.NOT_FOUND + ): + if not is_engine_running(self.connection, self.connection.engine_url): + raise EngineNotRunningError( + f"Firebolt engine {self.connection.engine_url} " + "needs to be running to run queries against it." + ) + resp.raise_for_status() + + def _api_request( + self, + query: Optional[str] = "", + parameters: Optional[dict[str, Any]] = {}, + path: Optional[str] = "", + use_set_parameters: Optional[bool] = True, + ) -> Response: + """ + Query API, return Response object. 
+ + Args: + query (str): SQL query + parameters (Optional[Sequence[ParameterType]]): A sequence of substitution + parameters. Used to replace '?' placeholders inside a query with + actual values. Note: In order to "output_format" dict value, it + must be an empty string. If no value not specified, + JSON_OUTPUT_FORMAT will be used. + path (str): endpoint suffix, for example "cancel" or "status" + use_set_parameters: Optional[bool]: Some queries will fail if additional + set parameters are sent. Setting this to False will allow + self._set_parameters to be ignored. + """ + if use_set_parameters: + parameters = {**(self._set_parameters or {}), **(parameters or {})} + return self._client.request( + url=f"/{path}", + method="POST", + params={ + "database": self.connection.database, + **(parameters or dict()), + }, + content=query, + ) + + def _validate_set_parameter(self, parameter: SetParameter) -> None: + """Validate parameter by executing simple query with it.""" + if parameter.name == "async_execution": + raise AsyncExecutionUnavailableError( + "It is not possible to set async_execution using a SET command. " + "Instead, pass it as an argument to the execute() or " + "executemany() function." + ) + resp = self._api_request("select 1", {parameter.name: parameter.value}) + # Handle invalid set parameter + if resp.status_code == codes.BAD_REQUEST: + raise OperationalError(resp.text) + self._raise_if_error(resp) + + # set parameter passed validation + self._set_parameters[parameter.name] = parameter.value + + def _do_execute( + self, + raw_query: str, + parameters: Sequence[Sequence[ParameterType]], + skip_parsing: bool = False, + async_execution: Optional[bool] = False, + ) -> None: + self._reset() + # Allow users to manually skip parsing for performance improvement. 
+ queries: List[Union[SetParameter, str]] = ( + [raw_query] if skip_parsing else split_format_sql(raw_query, parameters) + ) + try: + for query in queries: + + start_time = time.time() + # Our CREATE EXTERNAL TABLE queries currently require credentials, + # so we will skip logging those queries. + # https://docs.firebolt.io/sql-reference/commands/create-external-table.html + if isinstance(query, SetParameter) or not re.search( + "aws_key_id|credentials", query, flags=re.IGNORECASE + ): + logger.debug(f"Running query: {query}") + + # Define type for mypy + row_set: Tuple[ + int, + Optional[List[Column]], + Optional[Statistics], + Optional[List[List[RawColType]]], + ] = (-1, None, None, None) + if isinstance(query, SetParameter): + self._validate_set_parameter(query) + elif async_execution: + self._validate_server_side_async_settings( + parameters, + queries, + skip_parsing, + async_execution, + ) + response = self._api_request( + query, + { + "async_execution": 1, + "advanced_mode": 1, + "output_format": JSON_OUTPUT_FORMAT, + }, + ) + self._raise_if_error(response) + if response.headers.get("content-length", "") == "0": + raise OperationalError("No response to asynchronous query.") + resp = response.json() + if "query_id" not in resp or resp["query_id"] == "": + raise OperationalError( + "Invalid response to asynchronous query: missing query_id." + ) + self._query_id = resp["query_id"] + else: + resp = self._api_request( + query, {"output_format": JSON_OUTPUT_FORMAT} + ) + self._raise_if_error(resp) + row_set = self._row_set_from_response(resp) + + self._append_row_set(row_set) + + logger.info( + f"Query fetched {self.rowcount} rows in" + f" {time.time() - start_time} seconds." 
+ ) + + self._state = CursorState.DONE + + except Exception: + self._state = CursorState.ERROR + raise + + @check_not_closed def execute( self, query: str, @@ -49,47 +228,116 @@ def execute( skip_parsing: bool = False, async_execution: Optional[bool] = False, ) -> Union[int, str]: - with self._query_lock.gen_wlock(): - return async_to_sync(super().execute)( - query, parameters, skip_parsing, async_execution - ) + """Prepare and execute a database query. + + Supported features: + Parameterized queries: placeholder characters ('?') are substituted + with values provided in `parameters`. Values are formatted to + be properly recognized by database and to exclude SQL injection. + Multi-statement queries: multiple statements, provided in a single query + and separated by semicolon, are executed separatelly and sequentially. + To switch to next statement result, `nextset` method should be used. + SET statements: to provide additional query execution parameters, execute + `SET param=value` statement before it. All parameters are stored in + cursor object until it's closed. They can also be removed with + `flush_parameters` method call. - @wraps(AsyncBaseCursor.executemany) + Args: + query (str): SQL query to execute + parameters (Optional[Sequence[ParameterType]]): A sequence of substitution + parameters. Used to replace '?' placeholders inside a query with + actual values + skip_parsing (bool): Flag to disable query parsing. This will + disable parameterized, multi-statement and SET queries, + while improving performance + async_execution (bool): flag to determine if query should be asynchronous + + Returns: + int: Query row count. 
+ """ + params_list = [parameters] if parameters else [] + self._do_execute(query, params_list, skip_parsing, async_execution) + return self.query_id if async_execution else self.rowcount + + @check_not_closed def executemany( self, query: str, parameters_seq: Sequence[Sequence[ParameterType]], async_execution: Optional[bool] = False, ) -> Union[int, str]: - with self._query_lock.gen_wlock(): - return async_to_sync(super().executemany)( - query, parameters_seq, async_execution - ) + """Prepare and execute a database query. - @wraps(AsyncBaseCursor._get_next_range) - def _get_next_range(self, size: int) -> Tuple[int, int]: - with self._idx_lock: - return super()._get_next_range(size) + Supports providing multiple substitution parameter sets, executing them + as multiple statements sequentially. - @wraps(AsyncBaseCursor.fetchone) - def fetchone(self) -> Optional[List[ColType]]: - with self._query_lock.gen_rlock(): - return super().fetchone() + Supported features: + Parameterized queries: Placeholder characters ('?') are substituted + with values provided in `parameters`. Values are formatted to + be properly recognized by database and to exclude SQL injection. + Multi-statement queries: Multiple statements, provided in a single query + and separated by semicolon, are executed separately and sequentially. + To switch to next statement result, use `nextset` method. + SET statements: To provide additional query execution parameters, execute + `SET param=value` statement before it. All parameters are stored in + cursor object until it's closed. They can also be removed with + `flush_parameters` method call. - @wraps(AsyncBaseCursor.fetchmany) - def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: - with self._query_lock.gen_rlock(): - return super().fetchmany(size) + Args: + query (str): SQL query to execute. + parameters_seq (Sequence[Sequence[ParameterType]]): A sequence of + substitution parameter sets. Used to replace '?' 
placeholders inside a + query with actual values from each set in a sequence. Resulting queries + for each subset are executed sequentially. + async_execution (bool): flag to determine if query should be asynchronous - @wraps(AsyncBaseCursor.fetchall) - def fetchall(self) -> List[List[ColType]]: - with self._query_lock.gen_rlock(): - return super().fetchall() + Returns: + int|str: Query row count for synchronous execution of queries, + query ID string for asynchronous execution. + """ + self._do_execute(query, parameters_seq, async_execution=async_execution) + if async_execution: + return self.query_id + else: + return self.rowcount - @wraps(AsyncBaseCursor.nextset) - def nextset(self) -> None: - with self._query_lock.gen_rlock(), self._idx_lock: - return super().nextset() + @check_not_closed + def get_status(self, query_id: str) -> QueryStatus: + """Get status of a server-side async query. Return the state of the query.""" + try: + resp = self._api_request( + # output_format must be empty for status to work correctly. + # And set parameters will cause 400 errors. + parameters={"query_id": query_id}, + path="status", + use_set_parameters=False, + ) + if resp.status_code == codes.BAD_REQUEST: + raise OperationalError( + f"Asynchronous query {query_id} status check failed: " + f"{resp.status_code}." + ) + resp_json = resp.json() + if "status" not in resp_json: + raise OperationalError( + "Invalid response to asynchronous query: missing status." + ) + except Exception: + self._state = CursorState.ERROR + raise + # Remember that query_id might be empty. 
+ if resp_json["status"] == "": + return QueryStatus.NOT_READY + return QueryStatus[resp_json["status"]] + + @check_not_closed + def cancel(self, query_id: str) -> None: + """Cancel a server-side async query.""" + self._api_request( + parameters={"query_id": query_id}, + path="cancel", + use_set_parameters=False, + ) # Iteration support @check_not_closed @@ -100,13 +348,3 @@ def __iter__(self) -> Generator[List[ColType], None, None]: if row is None: return yield row - - @wraps(AsyncBaseCursor.get_status) - def get_status(self, query_id: str) -> QueryStatus: - with self._query_lock.gen_rlock(): - return async_to_sync(super().get_status)(query_id) - - @wraps(AsyncBaseCursor.cancel) - def cancel(self, query_id: str) -> None: - with self._query_lock.gen_rlock(): - return async_to_sync(super().cancel)(query_id) diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py new file mode 100644 index 00000000000..94f24c3be7c --- /dev/null +++ b/src/firebolt/db/util.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from httpx import URL, Response + +from firebolt.utils.urls import DATABASES_URL, ENGINES_URL + +if TYPE_CHECKING: + from firebolt.db.connection import Connection + + +def is_db_available(connection: Connection, database_name: str) -> bool: + """ + Verify that the database exists. + + Args: + connection (firebolt.async_db.connection.Connection) + """ + resp = _filter_request( + connection, DATABASES_URL, {"filter.name_contains": database_name} + ) + return len(resp.json()["edges"]) > 0 + + +def is_engine_running(connection: Connection, engine_url: str) -> bool: + """ + Verify that the engine is running. + + Args: + connection (firebolt.async_db.connection.Connection): connection. + """ + # Url is not guaranteed to be of this structure, + # but for the sake of error checking this is sufficient. 
+ engine_name = URL(engine_url).host.split(".")[0].replace("-", "_") + resp = _filter_request( + connection, + ENGINES_URL, + { + "filter.name_contains": engine_name, + "filter.current_status_eq": "ENGINE_STATUS_RUNNING_REVISION_SERVING", + }, + ) + return len(resp.json()["edges"]) > 0 + + +def _filter_request(connection: Connection, endpoint: str, filters: dict) -> Response: + resp = connection._client.request( + # Full url overrides the client url, which contains engine as a prefix. + url=connection.api_endpoint + endpoint, + method="GET", + params=filters, + ) + resp.raise_for_status() + return resp diff --git a/tests/integration/dbapi/async/test_queries_async.py b/tests/integration/dbapi/async/test_queries_async.py index 90c1b665e9d..ed793f808d2 100644 --- a/tests/integration/dbapi/async/test_queries_async.py +++ b/tests/integration/dbapi/async/test_queries_async.py @@ -11,8 +11,8 @@ DataError, OperationalError, ) -from firebolt.async_db._types import ColType, Column from firebolt.async_db.cursor import QueryStatus +from firebolt.common._types import ColType, Column VALS_TO_INSERT_2 = ",".join( [f"({i}, {i-3}, '{val}')" for (i, val) in enumerate(range(4, 1000))] diff --git a/tests/integration/dbapi/conftest.py b/tests/integration/dbapi/conftest.py index 1843a22b894..b61bedf4077 100644 --- a/tests/integration/dbapi/conftest.py +++ b/tests/integration/dbapi/conftest.py @@ -5,8 +5,8 @@ from pytest import fixture -from firebolt.async_db._types import ColType from firebolt.async_db.cursor import Column +from firebolt.common._types import ColType from firebolt.db import ARRAY, DECIMAL, Connection LOGGER = getLogger(__name__) diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py index a409dea4729..1977b3b2bf9 100644 --- a/tests/integration/dbapi/sync/test_queries.py +++ b/tests/integration/dbapi/sync/test_queries.py @@ -5,9 +5,9 @@ from pytest import mark, raises -from firebolt.async_db._types import ColType, Column 
from firebolt.async_db.cursor import QueryStatus from firebolt.client.auth import Auth +from firebolt.common._types import ColType, Column from firebolt.db import ( Binary, Connection, diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 2d665297972..69b2ed68dc8 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -7,9 +7,9 @@ from pytest import mark, raises from pytest_httpx import HTTPXMock -from firebolt.async_db._types import ColType from firebolt.async_db.connection import Connection, connect from firebolt.client.auth import Auth, Token, UsernamePassword +from firebolt.common._types import ColType from firebolt.common.settings import Settings from firebolt.utils.exception import ( AccountNotFoundError, diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index 69be65ae398..b5e52316627 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -6,8 +6,8 @@ from pytest_httpx import HTTPXMock from firebolt.async_db import Cursor -from firebolt.async_db._types import Column -from firebolt.async_db.cursor import ColType, CursorState, QueryStatus +from firebolt.common._types import Column +from firebolt.common.base_cursor import ColType, CursorState, QueryStatus from firebolt.utils.exception import ( AsyncExecutionUnavailableError, CursorClosedError, diff --git a/tests/unit/async_db/test_typing_format.py b/tests/unit/async_db/test_typing_format.py index a5e68e07aed..c04027ee6c6 100644 --- a/tests/unit/async_db/test_typing_format.py +++ b/tests/unit/async_db/test_typing_format.py @@ -12,7 +12,7 @@ InterfaceError, NotSupportedError, ) -from firebolt.async_db._types import ( +from firebolt.common._types import ( SetParameter, format_statement, format_value, diff --git a/tests/unit/async_db/test_typing_parse.py b/tests/unit/async_db/test_typing_parse.py index fad2c93a693..76528a96c78 100644 --- 
a/tests/unit/async_db/test_typing_parse.py +++ b/tests/unit/async_db/test_typing_parse.py @@ -11,7 +11,7 @@ TimeFromTicks, TimestampFromTicks, ) -from firebolt.async_db._types import parse_type, parse_value +from firebolt.common._types import parse_type, parse_value from firebolt.utils.exception import DataError, NotSupportedError diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index be1f7ee325d..0acf4392a91 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -8,8 +8,8 @@ from pytest import mark, raises, warns from pytest_httpx import HTTPXMock -from firebolt.async_db._types import ColType from firebolt.client.auth import Auth, Token, UsernamePassword +from firebolt.common._types import ColType from firebolt.common.settings import Settings from firebolt.db import Connection, connect from firebolt.utils.exception import ( diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index 8562d50a901..8b8e18be77c 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -5,8 +5,8 @@ from pytest import raises from pytest_httpx import HTTPXMock -from firebolt.async_db.cursor import ColType, Column, CursorState, QueryStatus from firebolt.db import Cursor +from firebolt.db.cursor import ColType, Column, CursorState, QueryStatus from firebolt.utils.exception import ( CursorClosedError, DataError, @@ -542,11 +542,11 @@ def test_cursor_skip_parse( httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(query_callback, url=query_url) - with patch("firebolt.async_db.cursor.split_format_sql") as split_format_sql_mock: + with patch("firebolt.db.cursor.split_format_sql") as split_format_sql_mock: cursor.execute("non-an-actual-sql") split_format_sql_mock.assert_called_once() - with patch("firebolt.async_db.cursor.split_format_sql") as split_format_sql_mock: + with patch("firebolt.db.cursor.split_format_sql") as split_format_sql_mock: 
cursor.execute("non-an-actual-sql", skip_parsing=True) split_format_sql_mock.assert_not_called() From d2da3dde2dde4c56d09085e8467abf4623930da7 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 5 May 2023 14:27:34 +0100 Subject: [PATCH 02/13] refactor for mypy --- src/firebolt/async_db/connection.py | 65 ++++++++++---- src/firebolt/async_db/cursor.py | 14 +++- src/firebolt/client/auth/__init__.py | 1 + src/firebolt/client/auth/utils.py | 37 ++++++++ src/firebolt/common/base_connection.py | 112 ------------------------- src/firebolt/common/base_cursor.py | 9 +- src/firebolt/common/settings.py | 5 ++ src/firebolt/db/connection.py | 68 +++++++++++---- src/firebolt/db/cursor.py | 21 ++++- 9 files changed, 182 insertions(+), 150 deletions(-) create mode 100644 src/firebolt/client/auth/utils.py delete mode 100644 src/firebolt/common/base_connection.py diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 7cf550ba437..e73d223151c 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -4,7 +4,7 @@ import socket from json import JSONDecodeError from types import TracebackType -from typing import Any, Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type from httpcore.backends.auto import AutoBackend from httpcore.backends.base import AsyncNetworkStream @@ -12,14 +12,11 @@ from firebolt.async_db.cursor import Cursor from firebolt.client import DEFAULT_API_URL, AsyncClient -from firebolt.client.auth import Auth -from firebolt.common.base_connection import ( +from firebolt.client.auth import Auth, _get_auth +from firebolt.common.settings import ( DEFAULT_TIMEOUT_SECONDS, KEEPALIVE_FLAG, KEEPIDLE_RATE, - BaseConnection, - _get_auth, - _validate_engine_name_and_url, ) from firebolt.utils.exception import ( ConfigurationError, @@ -50,6 +47,15 @@ logger = logging.getLogger(__name__) +def _validate_engine_name_and_url( + engine_name: Optional[str], engine_url: 
Optional[str] +) -> None: + if engine_name and engine_url: + raise ConfigurationError( + "Both engine_name and engine_url are provided. Provide only one to connect." + ) + + async def _resolve_engine_url( engine_name: str, auth: Auth, @@ -125,7 +131,7 @@ async def _get_database_default_engine_url( raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") -def async_connect_factory(connection_class: Type, async_class=False) -> Callable: +def async_connect_factory(connection_class: Type) -> Callable: async def connect_inner( database: str = None, username: Optional[str] = None, @@ -256,7 +262,7 @@ async def connect_tcp( return stream -class Connection(BaseConnection): +class Connection: """ Firebolt asynchronous database connection class. Implements `PEP 249`_. @@ -279,7 +285,15 @@ class Connection(BaseConnection): """ - __slots__ = BaseConnection.__slots__ + client_class: type + __slots__ = ( + "_client", + "_cursors", + "database", + "engine_url", + "api_endpoint", + "_is_closed", + ) cursor_class = Cursor @@ -291,9 +305,11 @@ def __init__( api_endpoint: str = DEFAULT_API_URL, additional_parameters: Dict[str, Any] = {}, ): - super().__init__( - engine_url, database, auth, api_endpoint, additional_parameters - ) + self.api_endpoint = api_endpoint + self.engine_url = engine_url + self.database = database + self._cursors: List[Cursor] = [] + self._is_closed = False user_drivers = additional_parameters.get("user_drivers", []) user_clients = additional_parameters.get("user_clients", []) # Override tcp keepalive settings for connection @@ -309,8 +325,11 @@ def __init__( ) def cursor(self) -> Cursor: - c = super()._cursor() - assert isinstance(c, Cursor) # typecheck + if self.closed: + raise ConnectionClosedError("Unable to create cursor: connection closed.") + + c = self.cursor_class(client=self._client, connection=self) + self._cursors.append(c) return c # Context manager support @@ -319,6 +338,18 @@ async def __aenter__(self) -> Connection: raise 
ConnectionClosedError("Connection is already closed.") return self + def _remove_cursor(self, cursor: Cursor) -> None: + # This way it's atomic + try: + self._cursors.remove(cursor) + except ValueError: + pass + + @property + def closed(self) -> bool: + """`True` if connection is closed; `False` otherwise.""" + return self._is_closed + async def aclose(self) -> None: """Close connection and all underlying cursors.""" if self.closed: @@ -335,6 +366,12 @@ async def aclose(self) -> None: await self._client.aclose() self._is_closed = True + def commit(self) -> None: + """Does nothing since Firebolt doesn't have transactions.""" + + if self.closed: + raise ConnectionClosedError("Unable to commit: Connection closed.") + async def __aexit__( self, exc_type: type, exc_val: Exception, exc_tb: TracebackType ) -> None: diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 8dfbabedbe5..5c64b09be0f 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -44,7 +44,9 @@ ) if TYPE_CHECKING: - pass + from firebolt.async_db.connection import Connection + +from httpx import AsyncClient as AsyncHttpxClient logger = logging.getLogger(__name__) @@ -96,9 +98,17 @@ class Cursor(BaseCursor): __slots__ = BaseCursor.__slots__ + ("_async_query_lock",) - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__( + self, + *args: Any, + client: AsyncHttpxClient, + connection: Connection, + **kwargs: Any, + ) -> None: self._async_query_lock = RWLock() super().__init__(*args, **kwargs) + self._client = client + self.connection = connection async def _raise_if_error(self, resp: Response) -> None: """Raise a proper error if any""" diff --git a/src/firebolt/client/auth/__init__.py b/src/firebolt/client/auth/__init__.py index 2fef2f52339..90e4d4149a5 100644 --- a/src/firebolt/client/auth/__init__.py +++ b/src/firebolt/client/auth/__init__.py @@ -2,3 +2,4 @@ from firebolt.client.auth.service_account import ServiceAccount from 
firebolt.client.auth.token import Token from firebolt.client.auth.username_password import UsernamePassword +from firebolt.client.auth.utils import _get_auth diff --git a/src/firebolt/client/auth/utils.py b/src/firebolt/client/auth/utils.py new file mode 100644 index 00000000000..f6bb6b09374 --- /dev/null +++ b/src/firebolt/client/auth/utils.py @@ -0,0 +1,37 @@ +from typing import Optional + +from firebolt.client.auth import Auth, Token, UsernamePassword +from firebolt.utils.exception import ConfigurationError + + +def _get_auth( + username: Optional[str], + password: Optional[str], + access_token: Optional[str], + use_token_cache: bool, +) -> Auth: + """Create `Auth` class based on provided credentials. + + If `access_token` is provided, it's used for `Auth` creation. + Otherwise, username/password are used. + + Returns: + Auth: `auth object` + + Raises: + `ConfigurationError`: Invalid combination of credentials provided + + """ + if not access_token: + if not username or not password: + raise ConfigurationError( + "Neither username/password nor access_token are provided. Provide one" + " to authenticate." + ) + return UsernamePassword(username, password, use_token_cache) + if username or password: + raise ConfigurationError( + "Username/password and access_token are both provided. Provide only one" + " to authenticate." 
+ ) + return Token(access_token) diff --git a/src/firebolt/common/base_connection.py b/src/firebolt/common/base_connection.py deleted file mode 100644 index 209c8268dbe..00000000000 --- a/src/firebolt/common/base_connection.py +++ /dev/null @@ -1,112 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -from firebolt.async_db.cursor import BaseCursor, Cursor -from firebolt.client import DEFAULT_API_URL -from firebolt.client.auth import Auth, Token, UsernamePassword -from firebolt.utils.exception import ConfigurationError, ConnectionClosedError - -KEEPALIVE_FLAG: int = 1 - -KEEPIDLE_RATE: int = 60 # seconds -DEFAULT_TIMEOUT_SECONDS: int = 60 - - -def _get_auth( - username: Optional[str], - password: Optional[str], - access_token: Optional[str], - use_token_cache: bool, -) -> Auth: - """Create `Auth` class based on provided credentials. - - If `access_token` is provided, it's used for `Auth` creation. - Otherwise, username/password are used. - - Returns: - Auth: `auth object` - - Raises: - `ConfigurationError`: Invalid combination of credentials provided - - """ - if not access_token: - if not username or not password: - raise ConfigurationError( - "Neither username/password nor access_token are provided. Provide one" - " to authenticate." - ) - return UsernamePassword(username, password, use_token_cache) - if username or password: - raise ConfigurationError( - "Username/password and access_token are both provided. Provide only one" - " to authenticate." - ) - return Token(access_token) - - -def _validate_engine_name_and_url( - engine_name: Optional[str], engine_url: Optional[str] -) -> None: - if engine_name and engine_url: - raise ConfigurationError( - "Both engine_name and engine_url are provided. Provide only one to connect." 
- ) - - -class BaseConnection: - client_class: type - cursor_class: type - __slots__ = ( - "_client", - "_cursors", - "database", - "engine_url", - "api_endpoint", - "_is_closed", - ) - - def __init__( - self, - engine_url: str, - database: str, - auth: Auth, - api_endpoint: str = DEFAULT_API_URL, - additional_parameters: Dict[str, Any] = {}, - ): - self.api_endpoint = api_endpoint - self.engine_url = engine_url - self.database = database - self._cursors: List[BaseCursor] = [] - self._is_closed = False - - def _cursor(self, **kwargs: Any) -> BaseCursor: - """ - Create new cursor object. - """ - - if self.closed: - raise ConnectionClosedError("Unable to create cursor: connection closed.") - - c = self.cursor_class(self._client, self, **kwargs) - self._cursors.append(c) - return c - - @property - def closed(self) -> bool: - """`True` if connection is closed; `False` otherwise.""" - return self._is_closed - - def _remove_cursor(self, cursor: Cursor) -> None: - # This way it's atomic - try: - self._cursors.remove(cursor) - except ValueError: - pass - - def commit(self) -> None: - """Does nothing since Firebolt doesn't have transactions.""" - - if self.closed: - raise ConnectionClosedError("Unable to commit: Connection closed.") diff --git a/src/firebolt/common/base_cursor.py b/src/firebolt/common/base_cursor.py index 0e7fa512642..5e9c40f9fc2 100644 --- a/src/firebolt/common/base_cursor.py +++ b/src/firebolt/common/base_cursor.py @@ -19,7 +19,6 @@ from httpx import Response from pydantic import BaseModel -from firebolt.client import AsyncClient, Client from firebolt.common._types import ( ColType, Column, @@ -37,7 +36,7 @@ ) if TYPE_CHECKING: - from firebolt.common.base_connection import BaseConnection + pass logger = logging.getLogger(__name__) @@ -126,9 +125,9 @@ class BaseCursor: default_arraysize = 1 - def __init__(self, client: Union[AsyncClient, Client], connection: BaseConnection): - self.connection = connection - self._client = client + def __init__(self, *args: 
Any, **kwargs: Any) -> None: + # self.connection = None + # self._client = None self._arraysize = self.default_arraysize # These fields initialized here for type annotations purpose self._rows: Optional[List[List[RawColType]]] = None diff --git a/src/firebolt/common/settings.py b/src/firebolt/common/settings.py index f14fda233e3..3cc2321937a 100644 --- a/src/firebolt/common/settings.py +++ b/src/firebolt/common/settings.py @@ -7,6 +7,11 @@ logger = logging.getLogger(__name__) +KEEPALIVE_FLAG: int = 1 + +KEEPIDLE_RATE: int = 60 # seconds +DEFAULT_TIMEOUT_SECONDS: int = 60 + AUTH_CREDENTIALS_DEPRECATION_MESSAGE = """ Passing connection credentials directly in Settings is deprecated. Use Auth object instead. Examples: diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index ac3c1c9a1a0..5c6cb061f9a 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -4,7 +4,7 @@ import socket from json import JSONDecodeError from types import TracebackType -from typing import Any, Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type from warnings import warn from httpcore.backends.base import NetworkStream @@ -13,16 +13,13 @@ from readerwriterlock.rwlock import RWLockWrite from firebolt.client import DEFAULT_API_URL, Client -from firebolt.client.auth import Auth -from firebolt.common.base_connection import ( +from firebolt.client.auth import Auth, _get_auth +from firebolt.common.settings import ( + AUTH_CREDENTIALS_DEPRECATION_MESSAGE, DEFAULT_TIMEOUT_SECONDS, KEEPALIVE_FLAG, KEEPIDLE_RATE, - BaseConnection, - _get_auth, - _validate_engine_name_and_url, ) -from firebolt.common.settings import AUTH_CREDENTIALS_DEPRECATION_MESSAGE from firebolt.db.cursor import Cursor from firebolt.utils.exception import ( ConfigurationError, @@ -41,6 +38,15 @@ logger = logging.getLogger(__name__) +def _validate_engine_name_and_url( + engine_name: Optional[str], engine_url: Optional[str] +) -> None: + if 
engine_name and engine_url: + raise ConfigurationError( + "Both engine_name and engine_url are provided. Provide only one to connect." + ) + + def _resolve_engine_url( engine_name: str, auth: Auth, @@ -248,7 +254,7 @@ def connect_tcp( return stream -class Connection(BaseConnection): +class Connection: """ Firebolt database connection class. Implements PEP-249. @@ -265,7 +271,16 @@ class Connection(BaseConnection): are not implemented. """ - __slots__ = BaseConnection.__slots__ + ("_closing_lock",) + client_class: type + __slots__ = ( + "_client", + "_cursors", + "database", + "engine_url", + "api_endpoint", + "_is_closed", + "_closing_lock", + ) cursor_class = Cursor @@ -277,9 +292,11 @@ def __init__( api_endpoint: str = DEFAULT_API_URL, additional_parameters: Dict[str, Any] = {}, ): - super().__init__( - engine_url, database, auth, api_endpoint, additional_parameters - ) + self.api_endpoint = api_endpoint + self.engine_url = engine_url + self.database = database + self._cursors: List[Cursor] = [] + self._is_closed = False user_drivers = additional_parameters.get("user_drivers", []) user_clients = additional_parameters.get("user_clients", []) # Override tcp keepalive settings for connection @@ -298,10 +315,25 @@ def __init__( self._closing_lock = RWLockWrite() def cursor(self) -> Cursor: + if self.closed: + raise ConnectionClosedError("Unable to create cursor: connection closed.") + with self._closing_lock.gen_rlock(): - c = super()._cursor() - assert isinstance(c, Cursor) # typecheck - return c + c = self.cursor_class(client=self._client, connection=self) + self._cursors.append(c) + return c + + def _remove_cursor(self, cursor: Cursor) -> None: + # This way it's atomic + try: + self._cursors.remove(cursor) + except ValueError: + pass + + @property + def closed(self) -> bool: + """`True` if connection is closed; `False` otherwise.""" + return self._is_closed def close(self) -> None: if self.closed: @@ -318,6 +350,12 @@ def close(self) -> None: self._client.close() 
self._is_closed = True + def commit(self) -> None: + """Does nothing since Firebolt doesn't have transactions.""" + + if self.closed: + raise ConnectionClosedError("Unable to commit: Connection closed.") + # Context manager support def __enter__(self) -> Connection: if self.closed: diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index 874512638cc..50da126dea5 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -4,8 +4,18 @@ import re import time from threading import Lock -from typing import Any, Generator, List, Optional, Sequence, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Generator, + List, + Optional, + Sequence, + Tuple, + Union, +) +from httpx import Client as HttpxClient from httpx import Response, codes from readerwriterlock.rwlock import RWLockWrite @@ -42,6 +52,9 @@ ProgrammingError, ) +if TYPE_CHECKING: + from firebolt.db.connection import Connection + logger = logging.getLogger(__name__) @@ -64,10 +77,14 @@ class Cursor(BaseCursor): "_idx_lock", ) - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__( + self, *args: Any, client: HttpxClient, connection: Connection, **kwargs: Any + ) -> None: self._query_lock = RWLockWrite() self._idx_lock = Lock() super().__init__(*args, **kwargs) + self._client = client + self.connection = connection def _raise_if_error(self, resp: Response) -> None: """Raise a proper error if any""" From 3fb87ab5f9d19fae6b5a9d070064c47ddb03fe55 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 5 May 2023 17:31:24 +0100 Subject: [PATCH 03/13] cleanup --- src/firebolt/async_db/connection.py | 11 ++++----- src/firebolt/async_db/cursor.py | 25 ++++++++++++++++++- src/firebolt/common/base_cursor.py | 37 +---------------------------- src/firebolt/db/connection.py | 28 +++++++++++----------- src/firebolt/db/cursor.py | 28 +++++++++++++++------- 5 files changed, 64 insertions(+), 65 deletions(-) diff --git a/src/firebolt/async_db/connection.py 
b/src/firebolt/async_db/connection.py index e73d223151c..9f0eb06978f 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -286,6 +286,7 @@ class Connection: """ client_class: type + cursor_class = Cursor __slots__ = ( "_client", "_cursors", @@ -295,8 +296,6 @@ class Connection: "_is_closed", ) - cursor_class = Cursor - def __init__( self, engine_url: str, @@ -310,11 +309,11 @@ def __init__( self.database = database self._cursors: List[Cursor] = [] self._is_closed = False - user_drivers = additional_parameters.get("user_drivers", []) - user_clients = additional_parameters.get("user_clients", []) # Override tcp keepalive settings for connection transport = AsyncHTTPTransport() transport._pool._network_backend = OverriddenHttpBackend() + user_drivers = additional_parameters.get("user_drivers", []) + user_clients = additional_parameters.get("user_clients", []) self._client = AsyncClient( auth=auth, base_url=engine_url, @@ -324,11 +323,11 @@ def __init__( headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)}, ) - def cursor(self) -> Cursor: + def cursor(self, **kwargs: Any) -> Cursor: if self.closed: raise ConnectionClosedError("Unable to create cursor: connection closed.") - c = self.cursor_class(client=self._client, connection=self) + c = self.cursor_class(client=self._client, connection=self, **kwargs) self._cursors.append(c) return c diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 5c64b09be0f..4f41426ac04 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -4,6 +4,7 @@ import re import time from functools import wraps +from types import TracebackType from typing import ( TYPE_CHECKING, Any, @@ -105,8 +106,8 @@ def __init__( connection: Connection, **kwargs: Any, ) -> None: - self._async_query_lock = RWLock() super().__init__(*args, **kwargs) + self._async_query_lock = RWLock() self._client = client self.connection = connection @@ -348,6 
+349,24 @@ async def executemany( def __aiter__(self) -> Cursor: return self + def close(self) -> None: + """Terminate an ongoing query (if any) and mark connection as closed.""" + self._state = CursorState.CLOSED + self.connection._remove_cursor(self) + + def __del__(self) -> None: + self.close() + + # Context manager support + @check_not_closed + def __aenter__(self) -> Cursor: + return self + + async def __aexit__( + self, exc_type: type, exc_val: Exception, exc_tb: TracebackType + ) -> None: + self.close() + @check_not_closed @check_query_executed async def __anext__(self) -> List[ColType]: @@ -394,11 +413,13 @@ async def cancel(self, query_id: str) -> None: use_set_parameters=False, ) + @wraps(BaseCursor.fetchone) async def fetchone(self) -> Optional[List[ColType]]: async with self._async_query_lock.reader: """Fetch the next row of a query result set.""" return super().fetchone() + @wraps(BaseCursor.fetchmany) async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: async with self._async_query_lock.reader: """ @@ -407,11 +428,13 @@ async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: """ return super().fetchmany(size) + @wraps(BaseCursor.fetchall) async def fetchall(self) -> List[List[ColType]]: async with self._async_query_lock.reader: """Fetch all remaining rows of a query result.""" return super().fetchall() + @wraps(BaseCursor.nextset) async def nextset(self) -> None: async with self._async_query_lock.reader: return super().nextset() diff --git a/src/firebolt/common/base_cursor.py b/src/firebolt/common/base_cursor.py index 5e9c40f9fc2..eb288998116 100644 --- a/src/firebolt/common/base_cursor.py +++ b/src/firebolt/common/base_cursor.py @@ -3,18 +3,7 @@ import logging from enum import Enum from functools import wraps -from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Optional, - Sequence, - Tuple, - Union, -) +from typing import Any, Callable, Dict, List, 
Optional, Sequence, Tuple, Union from httpx import Response from pydantic import BaseModel @@ -35,9 +24,6 @@ QueryNotRunError, ) -if TYPE_CHECKING: - pass - logger = logging.getLogger(__name__) @@ -126,8 +112,6 @@ class BaseCursor: default_arraysize = 1 def __init__(self, *args: Any, **kwargs: Any) -> None: - # self.connection = None - # self._client = None self._arraysize = self.default_arraysize # These fields initialized here for type annotations purpose self._rows: Optional[List[List[RawColType]]] = None @@ -148,9 +132,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._query_id = "" self._reset() - def __del__(self) -> None: - self.close() - @property # type: ignore @check_not_closed def description(self) -> Optional[List[Column]]: @@ -205,12 +186,6 @@ def closed(self) -> bool: """True if connection is closed, False otherwise.""" return self._state == CursorState.CLOSED - def close(self) -> None: - """Terminate an ongoing query (if any) and mark connection as closed.""" - self._state = CursorState.CLOSED - # remove typecheck skip after connection is implemented - self.connection._remove_cursor(self) # type: ignore - @check_not_closed @check_query_executed def nextset(self) -> Optional[bool]: @@ -388,13 +363,3 @@ def setinputsizes(self, sizes: List[int]) -> None: @check_not_closed def setoutputsize(self, size: int, column: Optional[int] = None) -> None: """Set a column buffer size for fetches of large columns (does nothing).""" - - # Context manager support - @check_not_closed - def __enter__(self) -> BaseCursor: - return self - - def __exit__( - self, exc_type: type, exc_val: Exception, exc_tb: TracebackType - ) -> None: - self.close() diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 5c6cb061f9a..c1b093bd228 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -204,7 +204,7 @@ def connect_inner( account_name=account_name, api_endpoint=api_endpoint, ) as client: - client.account_id # 
TODO: is this correct? + client.account_id assert engine_url is not None @@ -272,6 +272,7 @@ class Connection: """ client_class: type + cursor_class = Cursor __slots__ = ( "_client", "_cursors", @@ -282,8 +283,6 @@ class Connection: "_closing_lock", ) - cursor_class = Cursor - def __init__( self, engine_url: str, @@ -297,11 +296,11 @@ def __init__( self.database = database self._cursors: List[Cursor] = [] self._is_closed = False - user_drivers = additional_parameters.get("user_drivers", []) - user_clients = additional_parameters.get("user_clients", []) # Override tcp keepalive settings for connection transport = HTTPTransport() transport._pool._network_backend = OverriddenHttpBackend() + user_drivers = additional_parameters.get("user_drivers", []) + user_clients = additional_parameters.get("user_clients", []) self._client = Client( auth=auth, base_url=engine_url, @@ -314,12 +313,12 @@ def __init__( # cursor() should hold this lock for read to read/write state self._closing_lock = RWLockWrite() - def cursor(self) -> Cursor: + def cursor(self, **kwargs: Any) -> Cursor: if self.closed: raise ConnectionClosedError("Unable to create cursor: connection closed.") with self._closing_lock.gen_rlock(): - c = self.cursor_class(client=self._client, connection=self) + c = self.cursor_class(client=self._client, connection=self, **kwargs) self._cursors.append(c) return c @@ -342,13 +341,14 @@ def close(self) -> None: # self._cursors is going to be changed during closing cursors # after this point no cursors would be added to _cursors, only removed since # closing lock is held, and later connection will be marked as closed - cursors = self._cursors[:] - for c in cursors: - # Here c can already be closed by another thread, - # but it shouldn't raise an error in this case - c.close() - self._client.close() - self._is_closed = True + with self._closing_lock.gen_wlock(): + cursors = self._cursors[:] + for c in cursors: + # Here c can already be closed by another thread, + # but it 
shouldn't raise an error in this case + c.close() + self._client.close() + self._is_closed = True def commit(self) -> None: """Does nothing since Firebolt doesn't have transactions.""" diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index 50da126dea5..69c84a81115 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -4,6 +4,7 @@ import re import time from threading import Lock +from types import TracebackType from typing import ( TYPE_CHECKING, Any, @@ -36,13 +37,6 @@ check_not_closed, check_query_executed, ) - -# from firebolt.db.cursor import ( -# ParameterType, -# QueryStatus, -# check_not_closed, -# check_query_executed, -# ) from firebolt.db.util import is_db_available, is_engine_running from firebolt.utils.exception import ( AsyncExecutionUnavailableError, @@ -80,9 +74,9 @@ class Cursor(BaseCursor): def __init__( self, *args: Any, client: HttpxClient, connection: Connection, **kwargs: Any ) -> None: + super().__init__(*args, **kwargs) self._query_lock = RWLockWrite() self._idx_lock = Lock() - super().__init__(*args, **kwargs) self._client = client self.connection = connection @@ -347,6 +341,24 @@ def get_status(self, query_id: str) -> QueryStatus: return QueryStatus.NOT_READY return QueryStatus[resp_json["status"]] + def close(self) -> None: + """Terminate an ongoing query (if any) and mark connection as closed.""" + self._state = CursorState.CLOSED + self.connection._remove_cursor(self) + + def __del__(self) -> None: + self.close() + + # Context manager support + @check_not_closed + def __enter__(self) -> Cursor: + return self + + def __exit__( + self, exc_type: type, exc_val: Exception, exc_tb: TracebackType + ) -> None: + self.close() + @check_not_closed def cancel(self, query_id: str) -> None: """Cancel a server-side async query.""" From 02612376cd1f45938d96687197cf799d88a2292b Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 5 May 2023 17:48:37 +0100 Subject: [PATCH 04/13] fix tests --- 
src/firebolt/async_db/cursor.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 4f41426ac04..890e71674a7 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -9,6 +9,7 @@ TYPE_CHECKING, Any, Callable, + Iterator, List, Optional, Sequence, @@ -359,9 +360,22 @@ def __del__(self) -> None: # Context manager support @check_not_closed + def __enter__(self) -> Cursor: + return self + + def __exit__( + self, exc_type: type, exc_val: Exception, exc_tb: TracebackType + ) -> None: + self.close() + + # TODO: figure out how to implement __aenter__ and __await__ + @check_not_closed def __aenter__(self) -> Cursor: return self + def __await__(self) -> Iterator: + pass + async def __aexit__( self, exc_type: type, exc_val: Exception, exc_tb: TracebackType ) -> None: From 4b1e6c0a218d0dc302c3e7bc0b1de7b625a5c907 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 5 May 2023 17:49:44 +0100 Subject: [PATCH 05/13] add commented out test --- tests/unit/async_db/test_cursor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index b5e52316627..64f13d4318f 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -139,6 +139,9 @@ async def test_cursor_no_query( # Context manager is also available with cursor: pass + # should this be available? 
+ # async with cursor: + # pass async def test_cursor_execute( From 9c5be480137e56e6fff1b411ddee5d262e394d7a Mon Sep 17 00:00:00 2001 From: ptiurin Date: Tue, 9 May 2023 11:38:19 +0100 Subject: [PATCH 06/13] code smell - redundant exception --- src/firebolt/async_db/connection.py | 2 +- src/firebolt/db/connection.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 9f0eb06978f..7c155fd71f0 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -93,7 +93,7 @@ async def _resolve_engine_url( # the backend so it's safe to assume the cause of the error is # missing engine. raise FireboltEngineError(f"Firebolt engine {engine_name} does not exist.") - except (JSONDecodeError, RequestError, RuntimeError, HTTPStatusError) as e: + except (JSONDecodeError, RequestError, RuntimeError) as e: raise InterfaceError( f"Error {e.__class__.__name__}: " f"Unable to retrieve engine endpoint {url}." diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index c1b093bd228..373f2880d67 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -84,7 +84,7 @@ def _resolve_engine_url( # the backend so it's safe to assume the cause of the error is # missing engine. raise FireboltEngineError(f"Firebolt engine {engine_name} does not exist.") - except (JSONDecodeError, RequestError, RuntimeError, HTTPStatusError) as e: + except (JSONDecodeError, RequestError, RuntimeError) as e: raise InterfaceError( f"Error {e.__class__.__name__}: " f"Unable to retrieve engine endpoint {url}." 
From 04a085db2e7508c7744b34b13a2ceff8c79dcb23 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Tue, 9 May 2023 11:42:58 +0100 Subject: [PATCH 07/13] sonar skips for some required functions --- src/firebolt/common/_types.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/firebolt/common/_types.py b/src/firebolt/common/_types.py index 1881a07fe10..25105957b32 100644 --- a/src/firebolt/common/_types.py +++ b/src/firebolt/common/_types.py @@ -63,17 +63,17 @@ def parse_datetime(datetime_string: str) -> datetime: Date = date -def DateFromTicks(t: int) -> date: +def DateFromTicks(t: int) -> date: # NOSONAR """Convert `ticks` to `date` for Firebolt DB.""" return datetime.fromtimestamp(t).date() -def Time(hour: int, minute: int, second: int) -> None: +def Time(hour: int, minute: int, second: int) -> None: # NOSONAR """Unsupported: Construct `time`, for Firebolt DB.""" raise NotSupportedError("The time construct is not supported by Firebolt") -def TimeFromTicks(t: int) -> None: +def TimeFromTicks(t: int) -> None: # NOSONAR """Unsupported: Convert `ticks` to `time` for Firebolt DB.""" raise NotSupportedError("The time construct is not supported by Firebolt") @@ -82,7 +82,7 @@ def TimeFromTicks(t: int) -> None: TimestampFromTicks = datetime.fromtimestamp -def Binary(value: str) -> bytes: +def Binary(value: str) -> bytes: # NOSONAR """Encode a string into UTF-8.""" return value.encode("utf-8") From 70b375baacaf158409cab559cdb517b8478d8935 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 22 May 2023 13:57:38 +0100 Subject: [PATCH 08/13] refactor validate_engine_name_and_url --- src/firebolt/async_db/connection.py | 12 ++---------- src/firebolt/db/connection.py | 12 ++---------- src/firebolt/utils/util.py | 13 ++++++++++++- 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 7c155fd71f0..c4de061edf8 100644 --- a/src/firebolt/async_db/connection.py +++ 
b/src/firebolt/async_db/connection.py @@ -18,6 +18,7 @@ KEEPALIVE_FLAG, KEEPIDLE_RATE, ) +from firebolt.common.util import validate_engine_name_and_url from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, @@ -47,15 +48,6 @@ logger = logging.getLogger(__name__) -def _validate_engine_name_and_url( - engine_name: Optional[str], engine_url: Optional[str] -) -> None: - if engine_name and engine_url: - raise ConfigurationError( - "Both engine_name and engine_url are provided. Provide only one to connect." - ) - - async def _resolve_engine_url( engine_name: str, auth: Auth, @@ -174,7 +166,7 @@ async def connect_inner( if not database: raise ConfigurationError("database name is required to connect.") - _validate_engine_name_and_url(engine_name, engine_url) + validate_engine_name_and_url(engine_name, engine_url) if not auth: if any([username, password, access_token, api_endpoint, use_token_cache]): diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 373f2880d67..8be8dacb53e 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -20,6 +20,7 @@ KEEPALIVE_FLAG, KEEPIDLE_RATE, ) +from firebolt.common.util import validate_engine_name_and_url from firebolt.db.cursor import Cursor from firebolt.utils.exception import ( ConfigurationError, @@ -38,15 +39,6 @@ logger = logging.getLogger(__name__) -def _validate_engine_name_and_url( - engine_name: Optional[str], engine_url: Optional[str] -) -> None: - if engine_name and engine_url: - raise ConfigurationError( - "Both engine_name and engine_url are provided. Provide only one to connect." 
- ) - - def _resolve_engine_url( engine_name: str, auth: Auth, @@ -165,7 +157,7 @@ def connect_inner( if not database: raise ConfigurationError("database name is required to connect.") - _validate_engine_name_and_url(engine_name, engine_url) + validate_engine_name_and_url(engine_name, engine_url) if not auth: if any([username, password, access_token, api_endpoint, use_token_cache]): diff --git a/src/firebolt/utils/util.py b/src/firebolt/utils/util.py index e05ae677fb5..96ac7a1ae64 100644 --- a/src/firebolt/utils/util.py +++ b/src/firebolt/utils/util.py @@ -1,9 +1,11 @@ from functools import lru_cache, partial, wraps -from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar import trio from httpx import URL +from firebolt.utils.exception import ConfigurationError + T = TypeVar("T") @@ -103,3 +105,12 @@ def merge_urls(base: URL, merge: URL) -> URL: merge_raw_path = base.raw_path + merge.raw_path.lstrip(b"/") return base.copy_with(raw_path=merge_raw_path) return merge + + +def validate_engine_name_and_url( + engine_name: Optional[str], engine_url: Optional[str] +) -> None: + if engine_name and engine_url: + raise ConfigurationError( + "Both engine_name and engine_url are provided. Provide only one to connect." 
+ ) From 70d8954d39bca3c1989679d9634d712a9a22ed17 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 22 May 2023 14:18:27 +0100 Subject: [PATCH 09/13] get rid of factory --- src/firebolt/async_db/connection.py | 184 +++++++++++++--------------- src/firebolt/db/connection.py | 184 +++++++++++++--------------- src/firebolt/model/engine.py | 2 +- 3 files changed, 177 insertions(+), 193 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index c4de061edf8..b2ccb55048a 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -4,7 +4,7 @@ import socket from json import JSONDecodeError from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional from httpcore.backends.auto import AutoBackend from httpcore.backends.base import AsyncNetworkStream @@ -123,100 +123,6 @@ async def _get_database_default_engine_url( raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") -def async_connect_factory(connection_class: Type) -> Callable: - async def connect_inner( - database: str = None, - username: Optional[str] = None, - password: Optional[str] = None, - access_token: Optional[str] = None, - auth: Auth = None, - engine_name: Optional[str] = None, - engine_url: Optional[str] = None, - account_name: Optional[str] = None, - api_endpoint: str = DEFAULT_API_URL, - use_token_cache: bool = True, - additional_parameters: Dict[str, Any] = {}, - ) -> Connection: - """Connect to Firebolt database. - - Args: - `database` (str): Name of the database to connect - `username` (Optional[str]): User name to use for authentication (Deprecated) - `password` (Optional[str]): Password to use for authentication (Deprecated) - `access_token` (Optional[str]): Authentication token to use instead of - credentials (Deprecated) - `auth` (Auth)L Authentication object. 
- `engine_name` (Optional[str]): Name of the engine to connect to - `engine_url` (Optional[str]): The engine endpoint to use - `account_name` (Optional[str]): For customers with multiple accounts; - if none, default is used - `api_endpoint` (str): Firebolt API endpoint. Used for authentication - `use_token_cache` (bool): Cached authentication token in filesystem - Default: True - `additional_parameters` (Optional[Dict]): Dictionary of less widely-used - arguments for connection - - Note: - Providing both `engine_name` and `engine_url` will result in an error - - """ - # These parameters are optional in function signature - # but are required to connect. - # PEP 249 recommends making them kwargs. - if not database: - raise ConfigurationError("database name is required to connect.") - - validate_engine_name_and_url(engine_name, engine_url) - - if not auth: - if any([username, password, access_token, api_endpoint, use_token_cache]): - logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) - auth = _get_auth(username, password, access_token, use_token_cache) - else: - raise ConfigurationError("No authentication provided.") - api_endpoint = fix_url_schema(api_endpoint) - - # Mypy checks, this should never happen - assert database is not None - - if not engine_name and not engine_url: - engine_url = await _get_database_default_engine_url( - database=database, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) - - elif engine_name: - engine_url = await _resolve_engine_url( - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) - elif account_name: - # In above if branches account name is validated since it's used to - # resolve or get an engine url. - # We need to manually validate account_name if none of the above - # cases are triggered. 
- async with AsyncClient( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - ) as client: - await client.account_id - - assert engine_url is not None - - engine_url = fix_url_schema(engine_url) - return connection_class( - engine_url, database, auth, api_endpoint, additional_parameters - ) - - return connect_inner - - class OverriddenHttpBackend(AutoBackend): """ `OverriddenHttpBackend` is a short-term solution for the TCP @@ -369,4 +275,90 @@ async def __aexit__( await self.aclose() -connect = async_connect_factory(Connection) +async def connect( + database: str = None, + username: Optional[str] = None, + password: Optional[str] = None, + access_token: Optional[str] = None, + auth: Auth = None, + engine_name: Optional[str] = None, + engine_url: Optional[str] = None, + account_name: Optional[str] = None, + api_endpoint: str = DEFAULT_API_URL, + use_token_cache: bool = True, + additional_parameters: Dict[str, Any] = {}, +) -> Connection: + """Connect to Firebolt database. + + Args: + `database` (str): Name of the database to connect + `username` (Optional[str]): User name to use for authentication (Deprecated) + `password` (Optional[str]): Password to use for authentication (Deprecated) + `access_token` (Optional[str]): Authentication token to use instead of + credentials (Deprecated) + `auth` (Auth)L Authentication object. + `engine_name` (Optional[str]): Name of the engine to connect to + `engine_url` (Optional[str]): The engine endpoint to use + `account_name` (Optional[str]): For customers with multiple accounts; + if none, default is used + `api_endpoint` (str): Firebolt API endpoint. 
Used for authentication + `use_token_cache` (bool): Cached authentication token in filesystem + Default: True + `additional_parameters` (Optional[Dict]): Dictionary of less widely-used + arguments for connection + + Note: + Providing both `engine_name` and `engine_url` will result in an error + + """ + # These parameters are optional in function signature + # but are required to connect. + # PEP 249 recommends making them kwargs. + if not database: + raise ConfigurationError("database name is required to connect.") + + validate_engine_name_and_url(engine_name, engine_url) + + if not auth: + if any([username, password, access_token, api_endpoint, use_token_cache]): + logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) + auth = _get_auth(username, password, access_token, use_token_cache) + else: + raise ConfigurationError("No authentication provided.") + api_endpoint = fix_url_schema(api_endpoint) + + # Mypy checks, this should never happen + assert database is not None + + if not engine_name and not engine_url: + engine_url = await _get_database_default_engine_url( + database=database, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + + elif engine_name: + engine_url = await _resolve_engine_url( + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + elif account_name: + # In above if branches account name is validated since it's used to + # resolve or get an engine url. + # We need to manually validate account_name if none of the above + # cases are triggered. 
+ async with AsyncClient( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + ) as client: + await client.account_id + + assert engine_url is not None + + engine_url = fix_url_schema(engine_url) + return Connection(engine_url, database, auth, api_endpoint, additional_parameters) diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 8be8dacb53e..9835b2bda32 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -4,7 +4,7 @@ import socket from json import JSONDecodeError from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional from warnings import warn from httpcore.backends.base import NetworkStream @@ -114,100 +114,6 @@ def _get_database_default_engine_url( raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") -def sync_connect_factory(connection_class: Type) -> Callable: - def connect_inner( - database: str = None, - username: Optional[str] = None, - password: Optional[str] = None, - access_token: Optional[str] = None, - auth: Auth = None, - engine_name: Optional[str] = None, - engine_url: Optional[str] = None, - account_name: Optional[str] = None, - api_endpoint: str = DEFAULT_API_URL, - use_token_cache: bool = True, - additional_parameters: Dict[str, Any] = {}, - ) -> Connection: - """Connect to Firebolt database. - - Args: - `database` (str): Name of the database to connect - `username` (Optional[str]): User name to use for authentication (Deprecated) - `password` (Optional[str]): Password to use for authentication (Deprecated) - `access_token` (Optional[str]): Authentication token to use instead of - credentials (Deprecated) - `auth` (Auth)L Authentication object. 
- `engine_name` (Optional[str]): Name of the engine to connect to - `engine_url` (Optional[str]): The engine endpoint to use - `account_name` (Optional[str]): For customers with multiple accounts; - if none, default is used - `api_endpoint` (str): Firebolt API endpoint. Used for authentication - `use_token_cache` (bool): Cached authentication token in filesystem - Default: True - `additional_parameters` (Optional[Dict]): Dictionary of less widely-used - arguments for connection - - Note: - Providing both `engine_name` and `engine_url` will result in an error - - """ - # These parameters are optional in function signature - # but are required to connect. - # PEP 249 recommends making them kwargs. - if not database: - raise ConfigurationError("database name is required to connect.") - - validate_engine_name_and_url(engine_name, engine_url) - - if not auth: - if any([username, password, access_token, api_endpoint, use_token_cache]): - logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) - auth = _get_auth(username, password, access_token, use_token_cache) - else: - raise ConfigurationError("No authentication provided.") - api_endpoint = fix_url_schema(api_endpoint) - - # Mypy checks, this should never happen - assert database is not None - - if not engine_name and not engine_url: - engine_url = _get_database_default_engine_url( - database=database, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) - - elif engine_name: - engine_url = _resolve_engine_url( - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) - elif account_name: - # In above if branches account name is validated since it's used to - # resolve or get an engine url. - # We need to manually validate account_name if none of the above - # cases are triggered. 
- with Client( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - ) as client: - client.account_id - - assert engine_url is not None - - engine_url = fix_url_schema(engine_url) - return connection_class( - engine_url, database, auth, api_endpoint, additional_parameters - ) - - return connect_inner - - # TODO: verify new httpx has not improved this class OverriddenHttpBackend(SyncBackend): """ @@ -364,4 +270,90 @@ def __del__(self) -> None: warn(f"Unclosed {self!r}", UserWarning) -connect = sync_connect_factory(Connection) +def connect( + database: str = None, + username: Optional[str] = None, + password: Optional[str] = None, + access_token: Optional[str] = None, + auth: Auth = None, + engine_name: Optional[str] = None, + engine_url: Optional[str] = None, + account_name: Optional[str] = None, + api_endpoint: str = DEFAULT_API_URL, + use_token_cache: bool = True, + additional_parameters: Dict[str, Any] = {}, +) -> Connection: + """Connect to Firebolt database. + + Args: + `database` (str): Name of the database to connect + `username` (Optional[str]): User name to use for authentication (Deprecated) + `password` (Optional[str]): Password to use for authentication (Deprecated) + `access_token` (Optional[str]): Authentication token to use instead of + credentials (Deprecated) + `auth` (Auth)L Authentication object. + `engine_name` (Optional[str]): Name of the engine to connect to + `engine_url` (Optional[str]): The engine endpoint to use + `account_name` (Optional[str]): For customers with multiple accounts; + if none, default is used + `api_endpoint` (str): Firebolt API endpoint. 
Used for authentication + `use_token_cache` (bool): Cached authentication token in filesystem + Default: True + `additional_parameters` (Optional[Dict]): Dictionary of less widely-used + arguments for connection + + Note: + Providing both `engine_name` and `engine_url` will result in an error + + """ + # These parameters are optional in function signature + # but are required to connect. + # PEP 249 recommends making them kwargs. + if not database: + raise ConfigurationError("database name is required to connect.") + + validate_engine_name_and_url(engine_name, engine_url) + + if not auth: + if any([username, password, access_token, api_endpoint, use_token_cache]): + logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) + auth = _get_auth(username, password, access_token, use_token_cache) + else: + raise ConfigurationError("No authentication provided.") + api_endpoint = fix_url_schema(api_endpoint) + + # Mypy checks, this should never happen + assert database is not None + + if not engine_name and not engine_url: + engine_url = _get_database_default_engine_url( + database=database, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + + elif engine_name: + engine_url = _resolve_engine_url( + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + elif account_name: + # In above if branches account name is validated since it's used to + # resolve or get an engine url. + # We need to manually validate account_name if none of the above + # cases are triggered. 
+ with Client( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + ) as client: + client.account_id + + assert engine_url is not None + + engine_url = fix_url_schema(engine_url) + return Connection(engine_url, database, auth, api_endpoint, additional_parameters) diff --git a/src/firebolt/model/engine.py b/src/firebolt/model/engine.py index 6109a1f4df4..8eab5e8e35b 100644 --- a/src/firebolt/model/engine.py +++ b/src/firebolt/model/engine.py @@ -194,7 +194,7 @@ def get_connection(self) -> Connection: """ return connect( database=self.database.name, # type: ignore # already checked by decorator - auth=self._service.client.auth, + auth=self._service.client.auth, # type: ignore engine_url=self.endpoint, account_name=self._service.settings.account_name, api_endpoint=self._service.settings.server, From a351510005d966459fdf5b60eb92285f18ed9c43 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 22 May 2023 14:20:00 +0100 Subject: [PATCH 10/13] get rid of cursor class --- src/firebolt/async_db/connection.py | 3 +-- src/firebolt/db/connection.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index b2ccb55048a..7d7405ad3f3 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -184,7 +184,6 @@ class Connection: """ client_class: type - cursor_class = Cursor __slots__ = ( "_client", "_cursors", @@ -225,7 +224,7 @@ def cursor(self, **kwargs: Any) -> Cursor: if self.closed: raise ConnectionClosedError("Unable to create cursor: connection closed.") - c = self.cursor_class(client=self._client, connection=self, **kwargs) + c = Cursor(client=self._client, connection=self, **kwargs) self._cursors.append(c) return c diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 9835b2bda32..937cff4b582 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -170,7 
+170,6 @@ class Connection: """ client_class: type - cursor_class = Cursor __slots__ = ( "_client", "_cursors", @@ -216,7 +215,7 @@ def cursor(self, **kwargs: Any) -> Cursor: raise ConnectionClosedError("Unable to create cursor: connection closed.") with self._closing_lock.gen_rlock(): - c = self.cursor_class(client=self._client, connection=self, **kwargs) + c = Cursor(client=self._client, connection=self, **kwargs) self._cursors.append(c) return c From 28378e471a2fccd7c5569aec1eb8a633db508611 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 22 May 2023 15:24:08 +0100 Subject: [PATCH 11/13] Using BaseConnection --- src/firebolt/async_db/connection.py | 23 +++------------------- src/firebolt/common/base_connection.py | 27 ++++++++++++++++++++++++++ src/firebolt/db/connection.py | 17 +++------------- 3 files changed, 33 insertions(+), 34 deletions(-) create mode 100644 src/firebolt/common/base_connection.py diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 7d7405ad3f3..db87148a3d3 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -13,6 +13,7 @@ from firebolt.async_db.cursor import Cursor from firebolt.client import DEFAULT_API_URL, AsyncClient from firebolt.client.auth import Auth, _get_auth +from firebolt.common.base_connection import BaseConnection from firebolt.common.settings import ( DEFAULT_TIMEOUT_SECONDS, KEEPALIVE_FLAG, @@ -160,7 +161,7 @@ async def connect_tcp( return stream -class Connection: +class Connection(BaseConnection): """ Firebolt asynchronous database connection class. Implements `PEP 249`_. 
@@ -205,7 +206,6 @@ def __init__( self.engine_url = engine_url self.database = database self._cursors: List[Cursor] = [] - self._is_closed = False # Override tcp keepalive settings for connection transport = AsyncHTTPTransport() transport._pool._network_backend = OverriddenHttpBackend() @@ -219,6 +219,7 @@ def __init__( transport=transport, headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)}, ) + super().__init__() def cursor(self, **kwargs: Any) -> Cursor: if self.closed: @@ -234,18 +235,6 @@ async def __aenter__(self) -> Connection: raise ConnectionClosedError("Connection is already closed.") return self - def _remove_cursor(self, cursor: Cursor) -> None: - # This way it's atomic - try: - self._cursors.remove(cursor) - except ValueError: - pass - - @property - def closed(self) -> bool: - """`True` if connection is closed; `False` otherwise.""" - return self._is_closed - async def aclose(self) -> None: """Close connection and all underlying cursors.""" if self.closed: @@ -262,12 +251,6 @@ async def aclose(self) -> None: await self._client.aclose() self._is_closed = True - def commit(self) -> None: - """Does nothing since Firebolt doesn't have transactions.""" - - if self.closed: - raise ConnectionClosedError("Unable to commit: Connection closed.") - async def __aexit__( self, exc_type: type, exc_val: Exception, exc_tb: TracebackType ) -> None: diff --git a/src/firebolt/common/base_connection.py b/src/firebolt/common/base_connection.py new file mode 100644 index 00000000000..1d47374015c --- /dev/null +++ b/src/firebolt/common/base_connection.py @@ -0,0 +1,27 @@ +from typing import Any, List + +from firebolt.utils.exception import ConnectionClosedError + + +class BaseConnection: + def __init__(self) -> None: + self._cursors: List[Any] = [] + self._is_closed = False + + def _remove_cursor(self, cursor: Any) -> None: + # This way it's atomic + try: + self._cursors.remove(cursor) + except ValueError: + pass + + @property + def closed(self) -> 
bool: + """`True` if connection is closed; `False` otherwise.""" + return self._is_closed + + def commit(self) -> None: + """Does nothing since Firebolt doesn't have transactions.""" + + if self.closed: + raise ConnectionClosedError("Unable to commit: Connection closed.") diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 937cff4b582..415603a8729 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -14,6 +14,7 @@ from firebolt.client import DEFAULT_API_URL, Client from firebolt.client.auth import Auth, _get_auth +from firebolt.common.base_connection import BaseConnection from firebolt.common.settings import ( AUTH_CREDENTIALS_DEPRECATION_MESSAGE, DEFAULT_TIMEOUT_SECONDS, @@ -114,7 +115,6 @@ def _get_database_default_engine_url( raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") -# TODO: verify new httpx has not improved this class OverriddenHttpBackend(SyncBackend): """ `OverriddenHttpBackend` is a short-term solution for the TCP @@ -152,7 +152,7 @@ def connect_tcp( return stream -class Connection: +class Connection(BaseConnection): """ Firebolt database connection class. Implements PEP-249. @@ -192,7 +192,6 @@ def __init__( self.engine_url = engine_url self.database = database self._cursors: List[Cursor] = [] - self._is_closed = False # Override tcp keepalive settings for connection transport = HTTPTransport() transport._pool._network_backend = OverriddenHttpBackend() @@ -209,6 +208,7 @@ def __init__( # Holding this lock for write means that connection is closing itself. 
# cursor() should hold this lock for read to read/write state self._closing_lock = RWLockWrite() + super().__init__() def cursor(self, **kwargs: Any) -> Cursor: if self.closed: @@ -226,11 +226,6 @@ def _remove_cursor(self, cursor: Cursor) -> None: except ValueError: pass - @property - def closed(self) -> bool: - """`True` if connection is closed; `False` otherwise.""" - return self._is_closed - def close(self) -> None: if self.closed: return @@ -247,12 +242,6 @@ def close(self) -> None: self._client.close() self._is_closed = True - def commit(self) -> None: - """Does nothing since Firebolt doesn't have transactions.""" - - if self.closed: - raise ConnectionClosedError("Unable to commit: Connection closed.") - # Context manager support def __enter__(self) -> Connection: if self.closed: From 0f3cd099b340dce0daf90268a776d698490d7056 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 22 May 2023 17:15:06 +0100 Subject: [PATCH 12/13] ci: integration tests on stg --- .github/workflows/integration-tests.yml | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index f1d667be11a..4608d31ae7b 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -3,6 +3,10 @@ on: workflow_dispatch: workflow_call: inputs: + environment: + default: 'staging' + required: false + type: string branch: required: false type: string @@ -35,13 +39,27 @@ jobs: python -m pip install --upgrade pip pip install ".[dev]" + - name: Determine env variables + run: | + if [ "${{ inputs.environment }}" == 'staging' ]; then + echo "USERNAME=${{ secrets.FIREBOLT_USERNAME_STAGING }}" >> "$GITHUB_ENV" + echo "PASSWORD=${{ secrets.FIREBOLT_PASSWORD_STAGING }}" >> "$GITHUB_ENV" + echo "CLIENT_ID=${{ secrets.FIREBOLT_CLIENT_ID_STAGING }}" >> "$GITHUB_ENV" + echo "CLIENT_SECRET=${{ secrets.FIREBOLT_CLIENT_SECRET_STAGING }}" >> "$GITHUB_ENV" + else + echo 
"USERNAME=${{ secrets.FIREBOLT_USERNAME_DEV }}" >> "$GITHUB_ENV" + echo "PASSWORD=${{ secrets.FIREBOLT_PASSWORD_DEV }}" >> "$GITHUB_ENV" + echo "CLIENT_ID=${{ secrets.FIREBOLT_CLIENT_ID_DEV }}" >> "$GITHUB_ENV" + echo "CLIENT_SECRET=${{ secrets.FIREBOLT_CLIENT_SECRET_DEV }}" >> "$GITHUB_ENV" + fi + - name: Setup database and engine id: setup uses: firebolt-db/integration-testing-setup@master with: firebolt-username: ${{ secrets.FIREBOLT_USERNAME }} firebolt-password: ${{ secrets.FIREBOLT_PASSWORD }} - api-endpoint: "api.dev.firebolt.io" + api-endpoint: "api.${{ inputs.environment }}.firebolt.io" region: "us-east-1" - name: Run integration tests @@ -55,7 +73,7 @@ jobs: ENGINE_URL: ${{ steps.setup.outputs.engine_url }} STOPPED_ENGINE_NAME: ${{ steps.setup.outputs.stopped_engine_name }} STOPPED_ENGINE_URL: ${{ steps.setup.outputs.stopped_engine_url }} - API_ENDPOINT: "api.dev.firebolt.io" + API_ENDPOINT: "api.${{ inputs.environment }}.firebolt.io" ACCOUNT_NAME: "firebolt" run: | pytest -n 6 --dist loadgroup --timeout_method "signal" -o log_cli=true -o log_cli_level=INFO tests/integration --alluredir=allure-results From bf41c0b91c6965d66a924c514bf28442691ec967 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 22 May 2023 19:56:20 +0100 Subject: [PATCH 13/13] dispatch --- .github/workflows/integration-tests.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 4608d31ae7b..bebe1be0d5a 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -1,6 +1,15 @@ name: Integration tests on: workflow_dispatch: + inputs: + environment: + description: 'Environment to run the tests against' + type: choice + required: true + default: 'dev' + options: + - dev + - staging workflow_call: inputs: environment: