diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index f1d667be11a..bebe1be0d5a 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -1,8 +1,21 @@ name: Integration tests on: workflow_dispatch: + inputs: + environment: + description: 'Environment to run the tests against' + type: choice + required: true + default: 'dev' + options: + - dev + - staging workflow_call: inputs: + environment: + default: 'staging' + required: false + type: string branch: required: false type: string @@ -35,13 +48,27 @@ jobs: python -m pip install --upgrade pip pip install ".[dev]" + - name: Determine env variables + run: | + if [ "${{ inputs.environment }}" == 'staging' ]; then + echo "USERNAME=${{ secrets.FIREBOLT_USERNAME_STAGING }}" >> "$GITHUB_ENV" + echo "PASSWORD=${{ secrets.FIREBOLT_PASSWORD_STAGING }}" >> "$GITHUB_ENV" + echo "CLIENT_ID=${{ secrets.FIREBOLT_CLIENT_ID_STAGING }}" >> "$GITHUB_ENV" + echo "CLIENT_SECRET=${{ secrets.FIREBOLT_CLIENT_SECRET_STAGING }}" >> "$GITHUB_ENV" + else + echo "USERNAME=${{ secrets.FIREBOLT_USERNAME_DEV }}" >> "$GITHUB_ENV" + echo "PASSWORD=${{ secrets.FIREBOLT_PASSWORD_DEV }}" >> "$GITHUB_ENV" + echo "CLIENT_ID=${{ secrets.FIREBOLT_CLIENT_ID_DEV }}" >> "$GITHUB_ENV" + echo "CLIENT_SECRET=${{ secrets.FIREBOLT_CLIENT_SECRET_DEV }}" >> "$GITHUB_ENV" + fi + - name: Setup database and engine id: setup uses: firebolt-db/integration-testing-setup@master with: firebolt-username: ${{ secrets.FIREBOLT_USERNAME }} firebolt-password: ${{ secrets.FIREBOLT_PASSWORD }} - api-endpoint: "api.dev.firebolt.io" + api-endpoint: "api.${{ inputs.environment }}.firebolt.io" region: "us-east-1" - name: Run integration tests @@ -55,7 +82,7 @@ jobs: ENGINE_URL: ${{ steps.setup.outputs.engine_url }} STOPPED_ENGINE_NAME: ${{ steps.setup.outputs.stopped_engine_name }} STOPPED_ENGINE_URL: ${{ steps.setup.outputs.stopped_engine_url }} - API_ENDPOINT: "api.dev.firebolt.io" + API_ENDPOINT: "api.${{ inputs.environment }}.firebolt.io" ACCOUNT_NAME: "firebolt" run: | pytest -n 6 --dist loadgroup --timeout_method "signal" -o log_cli=true -o log_cli_level=INFO tests/integration --alluredir=allure-results diff --git a/src/firebolt/async_db/__init__.py b/src/firebolt/async_db/__init__.py index 4d74ebe04dd..35ff4a15a1c 100644 --- a/src/firebolt/async_db/__init__.py +++ b/src/firebolt/async_db/__init__.py @@ -1,4 +1,6 @@ -from firebolt.async_db._types import ( +from firebolt.async_db.connection import Connection, connect +from firebolt.async_db.cursor import Cursor +from firebolt.common._types import ( ARRAY, BINARY, DATETIME, @@ -14,8 +16,6 @@ Timestamp, TimestampFromTicks, ) -from firebolt.async_db.connection import Connection, connect -from firebolt.async_db.cursor import Cursor from firebolt.utils.exception import ( DatabaseError, DataError, diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 09b94bdda03..db87148a3d3 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -4,15 +4,22 @@ import socket from json import JSONDecodeError from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional from httpcore.backends.auto import AutoBackend from httpcore.backends.base import AsyncNetworkStream from httpx import AsyncHTTPTransport, HTTPStatusError, RequestError, Timeout -from firebolt.async_db.cursor import BaseCursor, Cursor +from firebolt.async_db.cursor import 
Cursor from firebolt.client import DEFAULT_API_URL, AsyncClient -from firebolt.client.auth import Auth, Token, UsernamePassword +from firebolt.client.auth import Auth, _get_auth +from firebolt.common.base_connection import BaseConnection +from firebolt.common.settings import ( + DEFAULT_TIMEOUT_SECONDS, + KEEPALIVE_FLAG, + KEEPIDLE_RATE, +) +from firebolt.common.util import validate_engine_name_and_url from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, @@ -27,9 +34,6 @@ from firebolt.utils.usage_tracker import get_user_agent_header from firebolt.utils.util import fix_url_schema -DEFAULT_TIMEOUT_SECONDS: int = 60 -KEEPALIVE_FLAG: int = 1 -KEEPIDLE_RATE: int = 60 # seconds AUTH_CREDENTIALS_DEPRECATION_MESSAGE = """ Passing connection credentials directly to the `connect` function is deprecated. Pass the `Auth` object instead. @@ -82,7 +86,7 @@ async def _resolve_engine_url( # the backend so it's safe to assume the cause of the error is # missing engine. raise FireboltEngineError(f"Firebolt engine {engine_name} does not exist.") - except (JSONDecodeError, RequestError, RuntimeError, HTTPStatusError) as e: + except (JSONDecodeError, RequestError, RuntimeError) as e: raise InterfaceError( f"Error {e.__class__.__name__}: " f"Unable to retrieve engine endpoint {url}." @@ -120,142 +124,6 @@ async def _get_database_default_engine_url( raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") -def _validate_engine_name_and_url( - engine_name: Optional[str], engine_url: Optional[str] -) -> None: - if engine_name and engine_url: - raise ConfigurationError( - "Both engine_name and engine_url are provided. Provide only one to connect." - ) - - -def _get_auth( - username: Optional[str], - password: Optional[str], - access_token: Optional[str], - use_token_cache: bool, -) -> Auth: - """Create `Auth` class based on provided credentials. - - If `access_token` is provided, it's used for `Auth` creation. - Otherwise, username/password are used. - - Returns: - Auth: `auth object` - - Raises: - `ConfigurationError`: Invalid combination of credentials provided - - """ - if not access_token: - if not username or not password: - raise ConfigurationError( - "Neither username/password nor access_token are provided. Provide one" - " to authenticate." - ) - return UsernamePassword(username, password, use_token_cache) - if username or password: - raise ConfigurationError( - "Username/password and access_token are both provided. Provide only one" - " to authenticate." - ) - return Token(access_token) - - -def async_connect_factory(connection_class: Type) -> Callable: - async def connect_inner( - database: str = None, - username: Optional[str] = None, - password: Optional[str] = None, - access_token: Optional[str] = None, - auth: Auth = None, - engine_name: Optional[str] = None, - engine_url: Optional[str] = None, - account_name: Optional[str] = None, - api_endpoint: str = DEFAULT_API_URL, - use_token_cache: bool = True, - additional_parameters: Dict[str, Any] = {}, - ) -> Connection: - """Connect to Firebolt database. - - Args: - `database` (str): Name of the database to connect - `username` (Optional[str]): User name to use for authentication (Deprecated) - `password` (Optional[str]): Password to use for authentication (Deprecated) - `access_token` (Optional[str]): Authentication token to use instead of - credentials (Deprecated) - `auth` (Auth)L Authentication object. 
- `engine_name` (Optional[str]): Name of the engine to connect to - `engine_url` (Optional[str]): The engine endpoint to use - `account_name` (Optional[str]): For customers with multiple accounts; - if none, default is used - `api_endpoint` (str): Firebolt API endpoint. Used for authentication - `use_token_cache` (bool): Cached authentication token in filesystem - Default: True - `additional_parameters` (Optional[Dict]): Dictionary of less widely-used - arguments for connection - - Note: - Providing both `engine_name` and `engine_url` will result in an error - - """ - # These parameters are optional in function signature - # but are required to connect. - # PEP 249 recommends making them kwargs. - if not database: - raise ConfigurationError("database name is required to connect.") - - _validate_engine_name_and_url(engine_name, engine_url) - - if not auth: - if any([username, password, access_token, api_endpoint, use_token_cache]): - logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) - auth = _get_auth(username, password, access_token, use_token_cache) - else: - raise ConfigurationError("No authentication provided.") - api_endpoint = fix_url_schema(api_endpoint) - - # Mypy checks, this should never happen - assert database is not None - - if not engine_name and not engine_url: - engine_url = await _get_database_default_engine_url( - database=database, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) - - elif engine_name: - engine_url = await _resolve_engine_url( - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) - elif account_name: - # In above if branches account name is validated since it's used to - # resolve or get an engine url. - # We need to manually validate account_name if none of the above - # cases are triggered. - async with AsyncClient( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - ) as client: - await client.account_id - - assert engine_url is not None - - engine_url = fix_url_schema(engine_url) - return connection_class( - engine_url, database, auth, api_endpoint, additional_parameters - ) - - return connect_inner - - class OverriddenHttpBackend(AutoBackend): """ `OverriddenHttpBackend` is a short-term solution for the TCP @@ -293,9 +161,30 @@ async def connect_tcp( return stream -class BaseConnection: +class Connection(BaseConnection): + """ + Firebolt asynchronous database connection class. Implements `PEP 249`_. + + Args: + `engine_url`: Firebolt database engine REST API url + `database`: Firebolt database name + `username`: Firebolt account username + `password`: Firebolt account password + `api_endpoint`: Optional. Firebolt API endpoint used for authentication + `connector_versions`: Optional. Tuple of connector name and version, or + a list of tuples of your connector stack. Useful for tracking custom + connector usage. + + Note: + Firebolt does not support transactions, + so commit and rollback methods are not implemented. + + .. 
_PEP 249: +        https://www.python.org/dev/peps/pep-0249/ + +    """ + +    client_class: type -    cursor_class: type     __slots__ = (         "_client",         "_cursors", @@ -313,6 +202,10 @@ def __init__(         api_endpoint: str = DEFAULT_API_URL,         additional_parameters: Dict[str, Any] = {},     ): +        self.api_endpoint = api_endpoint +        self.engine_url = engine_url +        self.database = database +        self._cursors: List[Cursor] = []         # Override tcp keepalive settings for connection         transport = AsyncHTTPTransport()         transport._pool._network_backend = OverriddenHttpBackend() @@ -326,25 +219,23 @@ def __init__(             transport=transport,             headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)},         ) -        self.api_endpoint = api_endpoint -        self.engine_url = engine_url -        self.database = database -        self._cursors: List[BaseCursor] = [] -        self._is_closed = False - -    def _cursor(self, **kwargs: Any) -> BaseCursor: -        """ -        Create new cursor object. -        """ +        super().__init__() +    def cursor(self, **kwargs: Any) -> Cursor:         if self.closed:             raise ConnectionClosedError("Unable to create cursor: connection closed.") -        c = self.cursor_class(self._client, self, **kwargs) +        c = Cursor(client=self._client, connection=self, **kwargs)         self._cursors.append(c)         return c -    async def _aclose(self) -> None: +    # Context manager support +    async def __aenter__(self) -> Connection: +        if self.closed: +            raise ConnectionClosedError("Connection is already closed.") +        return self + +    async def aclose(self) -> None:         """Close connection and all underlying cursors."""         if self.closed:             return @@ -360,67 +251,96 @@ async def _aclose(self) -> None:         await self._client.aclose()         self._is_closed = True -    @property -    def closed(self) -> bool: -        """`True` if connection is closed; `False` otherwise.""" -        return self._is_closed - -    def _remove_cursor(self, cursor: Cursor) -> None: -        # This way it's atomic -        try: -            self._cursors.remove(cursor) -        except ValueError: -            pass - -    def commit(self) -> None: -        """Does nothing since Firebolt doesn't have transactions.""" - -        if self.closed: -            raise ConnectionClosedError("Unable to commit: Connection closed.") +    async def __aexit__( +        self, exc_type: type, exc_val: Exception, exc_tb: TracebackType +    ) -> None: +        await self.aclose() -class Connection(BaseConnection): -    """ -    Firebolt asynchronous database connection class. Implements `PEP 249`_. +async def connect( +    database: str = None, +    username: Optional[str] = None, +    password: Optional[str] = None, +    access_token: Optional[str] = None, +    auth: Auth = None, +    engine_name: Optional[str] = None, +    engine_url: Optional[str] = None, +    account_name: Optional[str] = None, +    api_endpoint: str = DEFAULT_API_URL, +    use_token_cache: bool = True, +    additional_parameters: Dict[str, Any] = {}, +) -> Connection: +    """Connect to Firebolt database.     Args: -        `engine_url`: Firebolt database engine REST API url -        `database`: Firebolt database name -        `username`: Firebolt account username -        `password`: Firebolt account password -        `api_endpoint`: Optional. Firebolt API endpoint used for authentication -        `connector_versions`: Optional. Tuple of connector name and version, or -            a list of tuples of your connector stack. Useful for tracking custom -            connector usage. +        `database` (str): Name of the database to connect +        `username` (Optional[str]): User name to use for authentication (Deprecated) +        `password` (Optional[str]): Password to use for authentication (Deprecated) +        `access_token` (Optional[str]): Authentication token to use instead of +            credentials (Deprecated) +        `auth` (Auth): Authentication object.
+ `engine_name` (Optional[str]): Name of the engine to connect to + `engine_url` (Optional[str]): The engine endpoint to use + `account_name` (Optional[str]): For customers with multiple accounts; + if none, default is used + `api_endpoint` (str): Firebolt API endpoint. Used for authentication + `use_token_cache` (bool): Cached authentication token in filesystem + Default: True + `additional_parameters` (Optional[Dict]): Dictionary of less widely-used + arguments for connection Note: - Firebolt does not support transactions, - so commit and rollback methods are not implemented. - - .. _PEP 249: - https://www.python.org/dev/peps/pep-0249/ + Providing both `engine_name` and `engine_url` will result in an error """ + # These parameters are optional in function signature + # but are required to connect. + # PEP 249 recommends making them kwargs. + if not database: + raise ConfigurationError("database name is required to connect.") + + validate_engine_name_and_url(engine_name, engine_url) + + if not auth: + if any([username, password, access_token, api_endpoint, use_token_cache]): + logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) + auth = _get_auth(username, password, access_token, use_token_cache) + else: + raise ConfigurationError("No authentication provided.") + api_endpoint = fix_url_schema(api_endpoint) - cursor_class = Cursor - - aclose = BaseConnection._aclose - - def cursor(self) -> Cursor: - c = super()._cursor() - assert isinstance(c, Cursor) # typecheck - return c + # Mypy checks, this should never happen + assert database is not None - # Context manager support - async def __aenter__(self) -> Connection: - if self.closed: - raise ConnectionClosedError("Connection is already closed.") - return self + if not engine_name and not engine_url: + engine_url = await _get_database_default_engine_url( + database=database, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) - async def __aexit__( - self, exc_type: type, exc_val: Exception, exc_tb: TracebackType - ) -> None: - await self._aclose() + elif engine_name: + engine_url = await _resolve_engine_url( + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + elif account_name: + # In above if branches account name is validated since it's used to + # resolve or get an engine url. + # We need to manually validate account_name if none of the above + # cases are triggered. 
+ async with AsyncClient( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + ) as client: + await client.account_id + assert engine_url is not None -connect = async_connect_factory(Connection) + engine_url = fix_url_schema(engine_url) + return Connection(engine_url, database, auth, api_endpoint, additional_parameters) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 6bc01849acb..890e71674a7 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -3,14 +3,13 @@ import logging import re import time -from enum import Enum from functools import wraps from types import TracebackType from typing import ( TYPE_CHECKING, Any, Callable, - Dict, + Iterator, List, Optional, Sequence, @@ -20,24 +19,25 @@ from aiorwlock import RWLock from httpx import Response, codes -from pydantic import BaseModel -from firebolt.async_db._types import ( +from firebolt.async_db.util import is_db_available, is_engine_running +from firebolt.common._types import ( ColType, Column, ParameterType, RawColType, SetParameter, - parse_type, - parse_value, split_format_sql, ) -from firebolt.async_db.util import is_db_available, is_engine_running -from firebolt.client import AsyncClient +from firebolt.common.base_cursor import ( + BaseCursor, + CursorState, + QueryStatus, + Statistics, +) from firebolt.utils.exception import ( AsyncExecutionUnavailableError, CursorClosedError, - DataError, EngineNotRunningError, FireboltDatabaseError, OperationalError, @@ -48,46 +48,14 @@ if TYPE_CHECKING: from firebolt.async_db.connection import Connection +from httpx import AsyncClient as AsyncHttpxClient + logger = logging.getLogger(__name__) JSON_OUTPUT_FORMAT = "JSON_Compact" -class CursorState(Enum): - NONE = 1 - ERROR = 2 - DONE = 3 - CLOSED = 4 - - -class QueryStatus(Enum): - """Enumeration of query responses on server-side async queries.""" - - RUNNING = 1 - ENDED_SUCCESSFULLY = 2 - ENDED_UNSUCCESSFULLY = 3 - NOT_READY = 4 - STARTED_EXECUTION = 5 - PARSE_ERROR = 6 - CANCELED_EXECUTION = 7 - EXECUTION_ERROR = 8 - - -class Statistics(BaseModel): - """ - Class for query execution statistics. 
- """ - - elapsed: float - rows_read: int - bytes_read: int - time_before_execution: float - time_to_execute: float - scanned_bytes_cache: Optional[float] - scanned_bytes_storage: Optional[float] - - def check_not_closed(func: Callable) -> Callable: """(Decorator) ensure cursor is not closed before calling method.""" @@ -115,142 +83,34 @@ def inner(self: Cursor, *args: Any, **kwargs: Any) -> Any: return inner -class BaseCursor: - __slots__ = ( - "connection", - "_arraysize", - "_client", - "_state", - "_descriptions", - "_statistics", - "_rowcount", - "_rows", - "_idx", - "_idx_lock", - "_row_sets", - "_next_set_idx", - "_set_parameters", - "_query_id", - ) - - default_arraysize = 1 - - def __init__(self, client: AsyncClient, connection: Connection): - self.connection = connection - self._client = client - self._arraysize = self.default_arraysize - # These fields initialized here for type annotations purpose - self._rows: Optional[List[List[RawColType]]] = None - self._descriptions: Optional[List[Column]] = None - self._statistics: Optional[Statistics] = None - self._row_sets: List[ - Tuple[ - int, - Optional[List[Column]], - Optional[Statistics], - Optional[List[List[RawColType]]], - ] - ] = [] - self._set_parameters: Dict[str, Any] = dict() - self._rowcount = -1 - self._idx = 0 - self._next_set_idx = 0 - self._query_id = "" - self._reset() - - def __del__(self) -> None: - self.close() - - @property # type: ignore - @check_not_closed - def description(self) -> Optional[List[Column]]: - """ - Provides information about a single result row of a query. - - Attributes: - * ``name`` - * ``type_code`` - * ``display_size`` - * ``internal_size`` - * ``precision`` - * ``scale`` - * ``null_ok`` - """ - return self._descriptions - - @property # type: ignore - @check_not_closed - def statistics(self) -> Optional[Statistics]: - """Query execution statistics returned by the backend.""" - return self._statistics - - @property # type: ignore - @check_not_closed - def rowcount(self) -> int: - """The number of rows produced by last query.""" - return self._rowcount - - @property # type: ignore - @check_not_closed - def query_id(self) -> str: - """The query id of a query executed asynchronously.""" - return self._query_id - - @property - def arraysize(self) -> int: - """Default number of rows returned by fetchmany.""" - return self._arraysize - - @arraysize.setter - def arraysize(self, value: int) -> None: - if not isinstance(value, int): - raise TypeError( - "Invalid arraysize value type, expected int," - f" got {type(value).__name__}" - ) - self._arraysize = value +class Cursor(BaseCursor): + """ + Executes async queries to Firebolt Database. + Should not be created directly; + use :py:func:`connection.cursor ` - @property - def closed(self) -> bool: - """True if connection is closed, False otherwise.""" - return self._state == CursorState.CLOSED + Args: + description: Information about a single result row. + rowcount: The number of rows produced by last query. + closed: True if connection is closed; False otherwise. + arraysize: Read/Write, specifies the number of rows to fetch at a time + with the :py:func:`fetchmany` method. 
- def close(self) -> None: - """Terminate an ongoing query (if any) and mark connection as closed.""" - self._state = CursorState.CLOSED - # remove typecheck skip after connection is implemented - self.connection._remove_cursor(self) # type: ignore + """ - @check_not_closed - @check_query_executed - def nextset(self) -> Optional[bool]: - """ - Skip to the next available set, discarding any remaining rows - from the current set. - Returns True if operation was successful; - None if there are no more sets to retrive. - """ - return self._pop_next_set() + __slots__ = BaseCursor.__slots__ + ("_async_query_lock",) - def _pop_next_set(self) -> Optional[bool]: - """ - Same functionality as .nextset, but doesn't check that query has been executed. - """ - if self._next_set_idx >= len(self._row_sets): - return None - ( - self._rowcount, - self._descriptions, - self._statistics, - self._rows, - ) = self._row_sets[self._next_set_idx] - self._idx = 0 - self._next_set_idx += 1 - return True - - def flush_parameters(self) -> None: - """Cleanup all previously set parameters""" - self._set_parameters = dict() + def __init__( + self, + *args: Any, + client: AsyncHttpxClient, + connection: Connection, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._async_query_lock = RWLock() + self._client = client + self.connection = connection async def _raise_if_error(self, resp: Response) -> None: """Raise a proper error if any""" @@ -275,63 +135,6 @@ async def _raise_if_error(self, resp: Response) -> None: ) resp.raise_for_status() - def _reset(self) -> None: - """Clear all data stored from previous query.""" - self._state = CursorState.NONE - self._rows = None - self._descriptions = None - self._statistics = None - self._rowcount = -1 - self._idx = 0 - self._row_sets = [] - self._next_set_idx = 0 - self._query_id = "" - - def _row_set_from_response( - self, response: Response - ) -> Tuple[ - int, - Optional[List[Column]], - Optional[Statistics], - Optional[List[List[RawColType]]], - ]: - """Fetch information about executed query from http response.""" - - # Empty response is returned for insert query - if response.headers.get("content-length", "") == "0": - return (-1, None, None, None) - try: - # Skip parsing floats to properly parse them later - query_data = response.json(parse_float=str) - rowcount = int(query_data["rows"]) - descriptions: Optional[List[Column]] = [ - Column(d["name"], parse_type(d["type"]), None, None, None, None, None) - for d in query_data["meta"] - ] - if not descriptions: - descriptions = None - statistics = Statistics(**query_data["statistics"]) - # Parse data during fetch - rows = query_data["data"] - return (rowcount, descriptions, statistics, rows) - except (KeyError, ValueError) as err: - raise DataError(f"Invalid query data format: {str(err)}") - - def _append_row_set( - self, - row_set: Tuple[ - int, - Optional[List[Column]], - Optional[Statistics], - Optional[List[List[RawColType]]], - ], - ) -> None: - """Store information about executed query.""" - self._row_sets.append(row_set) - if self._next_set_idx == 0: - # Populate values for first set - self._pop_next_set() - async def _api_request( self, query: Optional[str] = "", @@ -383,33 +186,6 @@ async def _validate_set_parameter(self, parameter: SetParameter) -> None: # set parameter passed validation self._set_parameters[parameter.name] = parameter.value - def _validate_server_side_async_settings( - self, - parameters: Sequence[Sequence[ParameterType]], - queries: List[Union[SetParameter, str]], - 
skip_parsing: bool = False, - async_execution: Optional[bool] = False, - ) -> None: - if async_execution and self._set_parameters.get("use_standard_sql", "1") == "0": - raise AsyncExecutionUnavailableError( - "It is not possible to execute queries asynchronously if " - "use_standard_sql=0." - ) - if parameters and skip_parsing: - logger.warning( - "Query formatting parameters are provided but skip_parsing " - "is specified. They will be ignored." - ) - non_set_queries = 0 - for query in queries: - if type(query) is not SetParameter: - non_set_queries += 1 - if non_set_queries > 1 and async_execution: - raise AsyncExecutionUnavailableError( - "It is not possible to execute multi-statement " - "queries asynchronously." - ) - async def _do_execute( self, raw_query: str, @@ -568,77 +344,50 @@ async def executemany( else: return self.rowcount - def _parse_row(self, row: List[RawColType]) -> List[ColType]: - """Parse a single data row based on query column types.""" - assert len(row) == len(self.description) - return [ - parse_value(col, self.description[i].type_code) for i, col in enumerate(row) - ] - - def _get_next_range(self, size: int) -> Tuple[int, int]: - """ - Return range of next rows of size (if possible), - and update _idx to point to the end of this range - """ + # Iteration support + @check_not_closed + @check_query_executed + def __aiter__(self) -> Cursor: + return self - if self._rows is None: - # No elements to take - raise DataError("no rows to fetch") + def close(self) -> None: + """Terminate an ongoing query (if any) and mark connection as closed.""" + self._state = CursorState.CLOSED + self.connection._remove_cursor(self) - left = self._idx - right = min(self._idx + size, len(self._rows)) - self._idx = right - return left, right + def __del__(self) -> None: + self.close() + # Context manager support @check_not_closed - @check_query_executed - def fetchone(self) -> Optional[List[ColType]]: - """Fetch the next row of a query result set.""" - left, right = self._get_next_range(1) - if left == right: - # We are out of elements - return None - assert self._rows is not None - return self._parse_row(self._rows[left]) + def __enter__(self) -> Cursor: + return self - @check_not_closed - @check_query_executed - def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: - """ - Fetch the next set of rows of a query result; - cursor.arraysize is default size. 
- """ - size = size if size is not None else self.arraysize - left, right = self._get_next_range(size) - assert self._rows is not None - rows = self._rows[left:right] - return [self._parse_row(row) for row in rows] + def __exit__( + self, exc_type: type, exc_val: Exception, exc_tb: TracebackType + ) -> None: + self.close() + # TODO: figure out how to implement __aenter__ and __await__ @check_not_closed - @check_query_executed - def fetchall(self) -> List[List[ColType]]: - """Fetch all remaining rows of a query result.""" - left, right = self._get_next_range(self.rowcount) - assert self._rows is not None - rows = self._rows[left:right] - return [self._parse_row(row) for row in rows] + def __aenter__(self) -> Cursor: + return self - @check_not_closed - def setinputsizes(self, sizes: List[int]) -> None: - """Predefine memory areas for query parameters (does nothing).""" + def __await__(self) -> Iterator: + pass - @check_not_closed - def setoutputsize(self, size: int, column: Optional[int] = None) -> None: - """Set a column buffer size for fetches of large columns (does nothing).""" + async def __aexit__( + self, exc_type: type, exc_val: Exception, exc_tb: TracebackType + ) -> None: + self.close() @check_not_closed - async def cancel(self, query_id: str) -> None: - """Cancel a server-side async query.""" - await self._api_request( - parameters={"query_id": query_id}, - path="cancel", - use_set_parameters=False, - ) + @check_query_executed + async def __anext__(self) -> List[ColType]: + row = await self.fetchone() + if row is None: + raise StopAsyncIteration + return row @check_not_closed async def get_status(self, query_id: str) -> QueryStatus: @@ -669,64 +418,14 @@ async def get_status(self, query_id: str) -> QueryStatus: return QueryStatus.NOT_READY return QueryStatus[resp_json["status"]] - # Context manager support @check_not_closed - def __enter__(self) -> BaseCursor: - return self - - def __exit__( - self, exc_type: type, exc_val: Exception, exc_tb: TracebackType - ) -> None: - self.close() - - -class Cursor(BaseCursor): - """ - Executes async queries to Firebolt Database. - Should not be created directly; - use :py:func:`connection.cursor ` - - Args: - description: Information about a single result row. - rowcount: The number of rows produced by last query. - closed: True if connection is closed; False otherwise. - arraysize: Read/Write, specifies the number of rows to fetch at a time - with the :py:func:`fetchmany` method. - - """ - - __slots__ = BaseCursor.__slots__ + ("_async_query_lock",) - - def __init__(self, *args: Any, **kwargs: Any) -> None: - self._async_query_lock = RWLock() - super().__init__(*args, **kwargs) - - @wraps(BaseCursor.execute) - async def execute( - self, - query: str, - parameters: Optional[Sequence[ParameterType]] = None, - skip_parsing: bool = False, - async_execution: Optional[bool] = False, - ) -> Union[int, str]: - async with self._async_query_lock.writer: - return await super().execute( - query, parameters, skip_parsing, async_execution - ) - - @wraps(BaseCursor.executemany) - async def executemany( - self, - query: str, - parameters_seq: Sequence[Sequence[ParameterType]], - async_execution: Optional[bool] = False, - ) -> int: - """ - Prepare and execute a database query against all parameter - sequences provided. 
- """ - async with self._async_query_lock.writer: - return await super().executemany(query, parameters_seq, async_execution) + async def cancel(self, query_id: str) -> None: + """Cancel a server-side async query.""" + await self._api_request( + parameters={"query_id": query_id}, + path="cancel", + use_set_parameters=False, + ) @wraps(BaseCursor.fetchone) async def fetchone(self) -> Optional[List[ColType]]: @@ -753,17 +452,3 @@ async def fetchall(self) -> List[List[ColType]]: async def nextset(self) -> None: async with self._async_query_lock.reader: return super().nextset() - - # Iteration support - @check_not_closed - @check_query_executed - def __aiter__(self) -> Cursor: - return self - - @check_not_closed - @check_query_executed - async def __anext__(self) -> List[ColType]: - row = await self.fetchone() - if row is None: - raise StopAsyncIteration - return row diff --git a/src/firebolt/client/auth/__init__.py b/src/firebolt/client/auth/__init__.py index 2fef2f52339..90e4d4149a5 100644 --- a/src/firebolt/client/auth/__init__.py +++ b/src/firebolt/client/auth/__init__.py @@ -2,3 +2,4 @@ from firebolt.client.auth.service_account import ServiceAccount from firebolt.client.auth.token import Token from firebolt.client.auth.username_password import UsernamePassword +from firebolt.client.auth.utils import _get_auth diff --git a/src/firebolt/client/auth/utils.py b/src/firebolt/client/auth/utils.py new file mode 100644 index 00000000000..f6bb6b09374 --- /dev/null +++ b/src/firebolt/client/auth/utils.py @@ -0,0 +1,37 @@ +from typing import Optional + +from firebolt.client.auth import Auth, Token, UsernamePassword +from firebolt.utils.exception import ConfigurationError + + +def _get_auth( + username: Optional[str], + password: Optional[str], + access_token: Optional[str], + use_token_cache: bool, +) -> Auth: + """Create `Auth` class based on provided credentials. + + If `access_token` is provided, it's used for `Auth` creation. + Otherwise, username/password are used. + + Returns: + Auth: `auth object` + + Raises: + `ConfigurationError`: Invalid combination of credentials provided + + """ + if not access_token: + if not username or not password: + raise ConfigurationError( + "Neither username/password nor access_token are provided. Provide one" + " to authenticate." + ) + return UsernamePassword(username, password, use_token_cache) + if username or password: + raise ConfigurationError( + "Username/password and access_token are both provided. Provide only one" + " to authenticate." 
+ ) + return Token(access_token) diff --git a/src/firebolt/async_db/_types.py b/src/firebolt/common/_types.py similarity index 98% rename from src/firebolt/async_db/_types.py rename to src/firebolt/common/_types.py index 1881a07fe10..25105957b32 100644 --- a/src/firebolt/async_db/_types.py +++ b/src/firebolt/common/_types.py @@ -63,17 +63,17 @@ def parse_datetime(datetime_string: str) -> datetime: Date = date -def DateFromTicks(t: int) -> date: +def DateFromTicks(t: int) -> date: # NOSONAR """Convert `ticks` to `date` for Firebolt DB.""" return datetime.fromtimestamp(t).date() -def Time(hour: int, minute: int, second: int) -> None: +def Time(hour: int, minute: int, second: int) -> None: # NOSONAR """Unsupported: Construct `time`, for Firebolt DB.""" raise NotSupportedError("The time construct is not supported by Firebolt") -def TimeFromTicks(t: int) -> None: +def TimeFromTicks(t: int) -> None: # NOSONAR """Unsupported: Convert `ticks` to `time` for Firebolt DB.""" raise NotSupportedError("The time construct is not supported by Firebolt") @@ -82,7 +82,7 @@ def TimeFromTicks(t: int) -> None: TimestampFromTicks = datetime.fromtimestamp -def Binary(value: str) -> bytes: +def Binary(value: str) -> bytes: # NOSONAR """Encode a string into UTF-8.""" return value.encode("utf-8") diff --git a/src/firebolt/common/base_connection.py b/src/firebolt/common/base_connection.py new file mode 100644 index 00000000000..1d47374015c --- /dev/null +++ b/src/firebolt/common/base_connection.py @@ -0,0 +1,27 @@ +from typing import Any, List + +from firebolt.utils.exception import ConnectionClosedError + + +class BaseConnection: + def __init__(self) -> None: + self._cursors: List[Any] = [] + self._is_closed = False + + def _remove_cursor(self, cursor: Any) -> None: + # This way it's atomic + try: + self._cursors.remove(cursor) + except ValueError: + pass + + @property + def closed(self) -> bool: + """`True` if connection is closed; `False` otherwise.""" + return self._is_closed + + def commit(self) -> None: + """Does nothing since Firebolt doesn't have transactions.""" + + if self.closed: + raise ConnectionClosedError("Unable to commit: Connection closed.") diff --git a/src/firebolt/common/base_cursor.py b/src/firebolt/common/base_cursor.py new file mode 100644 index 00000000000..eb288998116 --- /dev/null +++ b/src/firebolt/common/base_cursor.py @@ -0,0 +1,365 @@ +from __future__ import annotations + +import logging +from enum import Enum +from functools import wraps +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from httpx import Response +from pydantic import BaseModel + +from firebolt.common._types import ( + ColType, + Column, + ParameterType, + RawColType, + SetParameter, + parse_type, + parse_value, +) +from firebolt.utils.exception import ( + AsyncExecutionUnavailableError, + CursorClosedError, + DataError, + QueryNotRunError, +) + +logger = logging.getLogger(__name__) + + +JSON_OUTPUT_FORMAT = "JSON_Compact" + + +class CursorState(Enum): + NONE = 1 + ERROR = 2 + DONE = 3 + CLOSED = 4 + + +class QueryStatus(Enum): + """Enumeration of query responses on server-side async queries.""" + + RUNNING = 1 + ENDED_SUCCESSFULLY = 2 + ENDED_UNSUCCESSFULLY = 3 + NOT_READY = 4 + STARTED_EXECUTION = 5 + PARSE_ERROR = 6 + CANCELED_EXECUTION = 7 + EXECUTION_ERROR = 8 + + +class Statistics(BaseModel): + """ + Class for query execution statistics. 
+ """ + + elapsed: float + rows_read: int + bytes_read: int + time_before_execution: float + time_to_execute: float + scanned_bytes_cache: Optional[float] + scanned_bytes_storage: Optional[float] + + +def check_not_closed(func: Callable) -> Callable: + """(Decorator) ensure cursor is not closed before calling method.""" + + @wraps(func) + def inner(self: BaseCursor, *args: Any, **kwargs: Any) -> Any: + if self.closed: + raise CursorClosedError(method_name=func.__name__) + return func(self, *args, **kwargs) + + return inner + + +def check_query_executed(func: Callable) -> Callable: + """ + (Decorator) ensure that some query has been executed before + calling cursor method. + """ + + @wraps(func) + def inner(self: BaseCursor, *args: Any, **kwargs: Any) -> Any: + if self._state == CursorState.NONE: + raise QueryNotRunError(method_name=func.__name__) + return func(self, *args, **kwargs) + + return inner + + +class BaseCursor: + __slots__ = ( + "connection", + "_arraysize", + "_client", + "_state", + "_descriptions", + "_statistics", + "_rowcount", + "_rows", + "_idx", + "_idx_lock", + "_row_sets", + "_next_set_idx", + "_set_parameters", + "_query_id", + ) + + default_arraysize = 1 + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self._arraysize = self.default_arraysize + # These fields initialized here for type annotations purpose + self._rows: Optional[List[List[RawColType]]] = None + self._descriptions: Optional[List[Column]] = None + self._statistics: Optional[Statistics] = None + self._row_sets: List[ + Tuple[ + int, + Optional[List[Column]], + Optional[Statistics], + Optional[List[List[RawColType]]], + ] + ] = [] + self._set_parameters: Dict[str, Any] = dict() + self._rowcount = -1 + self._idx = 0 + self._next_set_idx = 0 + self._query_id = "" + self._reset() + + @property # type: ignore + @check_not_closed + def description(self) -> Optional[List[Column]]: + """ + Provides information about a single result row of a query. + + Attributes: + * ``name`` + * ``type_code`` + * ``display_size`` + * ``internal_size`` + * ``precision`` + * ``scale`` + * ``null_ok`` + """ + return self._descriptions + + @property # type: ignore + @check_not_closed + def statistics(self) -> Optional[Statistics]: + """Query execution statistics returned by the backend.""" + return self._statistics + + @property # type: ignore + @check_not_closed + def rowcount(self) -> int: + """The number of rows produced by last query.""" + return self._rowcount + + @property # type: ignore + @check_not_closed + def query_id(self) -> str: + """The query id of a query executed asynchronously.""" + return self._query_id + + @property + def arraysize(self) -> int: + """Default number of rows returned by fetchmany.""" + return self._arraysize + + @arraysize.setter + def arraysize(self, value: int) -> None: + if not isinstance(value, int): + raise TypeError( + "Invalid arraysize value type, expected int," + f" got {type(value).__name__}" + ) + self._arraysize = value + + @property + def closed(self) -> bool: + """True if connection is closed, False otherwise.""" + return self._state == CursorState.CLOSED + + @check_not_closed + @check_query_executed + def nextset(self) -> Optional[bool]: + """ + Skip to the next available set, discarding any remaining rows + from the current set. + Returns True if operation was successful; + None if there are no more sets to retrive. 
+ """ + return self._pop_next_set() + + def _pop_next_set(self) -> Optional[bool]: + """ + Same functionality as .nextset, but doesn't check that query has been executed. + """ + if self._next_set_idx >= len(self._row_sets): + return None + ( + self._rowcount, + self._descriptions, + self._statistics, + self._rows, + ) = self._row_sets[self._next_set_idx] + self._idx = 0 + self._next_set_idx += 1 + return True + + def flush_parameters(self) -> None: + """Cleanup all previously set parameters""" + self._set_parameters = dict() + + def _reset(self) -> None: + """Clear all data stored from previous query.""" + self._state = CursorState.NONE + self._rows = None + self._descriptions = None + self._statistics = None + self._rowcount = -1 + self._idx = 0 + self._row_sets = [] + self._next_set_idx = 0 + self._query_id = "" + + def _row_set_from_response( + self, response: Response + ) -> Tuple[ + int, + Optional[List[Column]], + Optional[Statistics], + Optional[List[List[RawColType]]], + ]: + """Fetch information about executed query from http response.""" + + # Empty response is returned for insert query + if response.headers.get("content-length", "") == "0": + return (-1, None, None, None) + try: + # Skip parsing floats to properly parse them later + query_data = response.json(parse_float=str) + rowcount = int(query_data["rows"]) + descriptions: Optional[List[Column]] = [ + Column(d["name"], parse_type(d["type"]), None, None, None, None, None) + for d in query_data["meta"] + ] + if not descriptions: + descriptions = None + statistics = Statistics(**query_data["statistics"]) + # Parse data during fetch + rows = query_data["data"] + return (rowcount, descriptions, statistics, rows) + except (KeyError, ValueError) as err: + raise DataError(f"Invalid query data format: {str(err)}") + + def _append_row_set( + self, + row_set: Tuple[ + int, + Optional[List[Column]], + Optional[Statistics], + Optional[List[List[RawColType]]], + ], + ) -> None: + """Store information about executed query.""" + self._row_sets.append(row_set) + if self._next_set_idx == 0: + # Populate values for first set + self._pop_next_set() + + def _validate_server_side_async_settings( + self, + parameters: Sequence[Sequence[ParameterType]], + queries: List[Union[SetParameter, str]], + skip_parsing: bool = False, + async_execution: Optional[bool] = False, + ) -> None: + if async_execution and self._set_parameters.get("use_standard_sql", "1") == "0": + raise AsyncExecutionUnavailableError( + "It is not possible to execute queries asynchronously if " + "use_standard_sql=0." + ) + if parameters and skip_parsing: + logger.warning( + "Query formatting parameters are provided but skip_parsing " + "is specified. They will be ignored." + ) + non_set_queries = 0 + for query in queries: + if type(query) is not SetParameter: + non_set_queries += 1 + if non_set_queries > 1 and async_execution: + raise AsyncExecutionUnavailableError( + "It is not possible to execute multi-statement " + "queries asynchronously." 
+ ) + + def _parse_row(self, row: List[RawColType]) -> List[ColType]: + """Parse a single data row based on query column types.""" + assert len(row) == len(self.description) + return [ + parse_value(col, self.description[i].type_code) for i, col in enumerate(row) + ] + + def _get_next_range(self, size: int) -> Tuple[int, int]: + """ + Return range of next rows of size (if possible), + and update _idx to point to the end of this range + """ + + if self._rows is None: + # No elements to take + raise DataError("no rows to fetch") + + left = self._idx + right = min(self._idx + size, len(self._rows)) + self._idx = right + return left, right + + @check_not_closed + @check_query_executed + def fetchone(self) -> Optional[List[ColType]]: + """Fetch the next row of a query result set.""" + left, right = self._get_next_range(1) + if left == right: + # We are out of elements + return None + assert self._rows is not None + return self._parse_row(self._rows[left]) + + @check_not_closed + @check_query_executed + def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: + """ + Fetch the next set of rows of a query result; + cursor.arraysize is default size. + """ + size = size if size is not None else self.arraysize + left, right = self._get_next_range(size) + assert self._rows is not None + rows = self._rows[left:right] + return [self._parse_row(row) for row in rows] + + @check_not_closed + @check_query_executed + def fetchall(self) -> List[List[ColType]]: + """Fetch all remaining rows of a query result.""" + left, right = self._get_next_range(self.rowcount) + assert self._rows is not None + rows = self._rows[left:right] + return [self._parse_row(row) for row in rows] + + @check_not_closed + def setinputsizes(self, sizes: List[int]) -> None: + """Predefine memory areas for query parameters (does nothing).""" + + @check_not_closed + def setoutputsize(self, size: int, column: Optional[int] = None) -> None: + """Set a column buffer size for fetches of large columns (does nothing).""" diff --git a/src/firebolt/common/settings.py b/src/firebolt/common/settings.py index f77db994626..0864e2a5fbc 100644 --- a/src/firebolt/common/settings.py +++ b/src/firebolt/common/settings.py @@ -7,6 +7,11 @@ logger = logging.getLogger(__name__) +KEEPALIVE_FLAG: int = 1 + +KEEPIDLE_RATE: int = 60 # seconds +DEFAULT_TIMEOUT_SECONDS: int = 60 + AUTH_CREDENTIALS_DEPRECATION_MESSAGE = """ Passing connection credentials directly in Settings is deprecated. Use Auth object instead. 
Examples: diff --git a/src/firebolt/db/__init__.py b/src/firebolt/db/__init__.py index 8ee02c1448f..fcf6ae53c80 100644 --- a/src/firebolt/db/__init__.py +++ b/src/firebolt/db/__init__.py @@ -1,4 +1,4 @@ -from firebolt.async_db._types import ( +from firebolt.common._types import ( ARRAY, BINARY, DATETIME, diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 10fa91309e1..415603a8729 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -1,20 +1,158 @@ from __future__ import annotations -from functools import wraps +import logging +import socket +from json import JSONDecodeError from types import TracebackType -from typing import Any +from typing import Any, Dict, List, Optional from warnings import warn +from httpcore.backends.base import NetworkStream +from httpcore.backends.sync import SyncBackend +from httpx import HTTPStatusError, HTTPTransport, RequestError, Timeout from readerwriterlock.rwlock import RWLockWrite -from firebolt.async_db.connection import BaseConnection as AsyncBaseConnection -from firebolt.async_db.connection import async_connect_factory +from firebolt.client import DEFAULT_API_URL, Client +from firebolt.client.auth import Auth, _get_auth +from firebolt.common.base_connection import BaseConnection +from firebolt.common.settings import ( + AUTH_CREDENTIALS_DEPRECATION_MESSAGE, + DEFAULT_TIMEOUT_SECONDS, + KEEPALIVE_FLAG, + KEEPIDLE_RATE, +) +from firebolt.common.util import validate_engine_name_and_url from firebolt.db.cursor import Cursor -from firebolt.utils.exception import ConnectionClosedError -from firebolt.utils.util import async_to_sync +from firebolt.utils.exception import ( + ConfigurationError, + ConnectionClosedError, + FireboltEngineError, + InterfaceError, +) +from firebolt.utils.urls import ( + ACCOUNT_ENGINE_ID_BY_NAME_URL, + ACCOUNT_ENGINE_URL, + ACCOUNT_ENGINE_URL_BY_DATABASE_NAME, +) +from firebolt.utils.usage_tracker import get_user_agent_header +from firebolt.utils.util import fix_url_schema +logger = logging.getLogger(__name__) -class Connection(AsyncBaseConnection): + +def _resolve_engine_url( + engine_name: str, + auth: Auth, + api_endpoint: str, + account_name: Optional[str] = None, +) -> str: + with Client( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + account_id = client.account_id + url = ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) + try: + response = client.get( + url=url, + params={"engine_name": engine_name}, + ) + response.raise_for_status() + engine_id = response.json()["engine_id"]["engine_id"] + url = ACCOUNT_ENGINE_URL.format(account_id=account_id, engine_id=engine_id) + response = client.get(url=url) + response.raise_for_status() + return response.json()["engine"]["endpoint"] + except HTTPStatusError as e: + # Engine error would be 404. + if e.response.status_code != 404: + raise InterfaceError( + f"Error {e.__class__.__name__}: Unable to retrieve engine " + f"endpoint {url}." + ) + # Once this is point is reached we've already authenticated with + # the backend so it's safe to assume the cause of the error is + # missing engine. + raise FireboltEngineError(f"Firebolt engine {engine_name} does not exist.") + except (JSONDecodeError, RequestError, RuntimeError) as e: + raise InterfaceError( + f"Error {e.__class__.__name__}: " + f"Unable to retrieve engine endpoint {url}." 
+ ) + + +def _get_database_default_engine_url( + database: str, + auth: Auth, + api_endpoint: str, + account_name: Optional[str] = None, +) -> str: + with Client( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + try: + account_id = client.account_id + response = client.get( + url=ACCOUNT_ENGINE_URL_BY_DATABASE_NAME.format(account_id=account_id), + params={"database_name": database}, + ) + response.raise_for_status() + return response.json()["engine_url"] + except ( + JSONDecodeError, + RequestError, + RuntimeError, + HTTPStatusError, + KeyError, + ) as e: + raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") + + +class OverriddenHttpBackend(SyncBackend): + """ + `OverriddenHttpBackend` is a short-term solution for the TCP + connection idle timeout issue described in the following article: + https://docs.aws.amazon.com/elasticloadbalancing/latest/network/network-load-balancers.html#connection-idle-timeout + Since httpx creates a connection right before executing a request, the + backend must be overridden to set the socket to `KEEPALIVE` + and `KEEPIDLE` settings. + """ + + def connect_tcp( + self, + host: str, + port: int, + timeout: Optional[float] = None, + local_address: Optional[str] = None, + ) -> NetworkStream: + stream = super().connect_tcp( + host, port, timeout=timeout, local_address=local_address + ) + # Enable keepalive + stream.get_extra_info("socket").setsockopt( + socket.SOL_SOCKET, socket.SO_KEEPALIVE, KEEPALIVE_FLAG + ) + # MacOS does not have TCP_KEEPIDLE + if hasattr(socket, "TCP_KEEPIDLE"): + keepidle = socket.TCP_KEEPIDLE + else: + keepidle = 0x10 # TCP_KEEPALIVE on mac + + # Set keepalive to 60 seconds + stream.get_extra_info("socket").setsockopt( + socket.IPPROTO_TCP, keepidle, KEEPIDLE_RATE + ) + return stream + + +class Connection(BaseConnection): """ Firebolt database connection class. Implements PEP-249. @@ -31,26 +169,78 @@ class Connection(AsyncBaseConnection): are not implemented. """ - __slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock",) - - cursor_class = Cursor + client_class: type + __slots__ = ( + "_client", + "_cursors", + "database", + "engine_url", + "api_endpoint", + "_is_closed", + "_closing_lock", + ) - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__( + self, + engine_url: str, + database: str, + auth: Auth, + api_endpoint: str = DEFAULT_API_URL, + additional_parameters: Dict[str, Any] = {}, + ): + self.api_endpoint = api_endpoint + self.engine_url = engine_url + self.database = database + self._cursors: List[Cursor] = [] + # Override tcp keepalive settings for connection + transport = HTTPTransport() + transport._pool._network_backend = OverriddenHttpBackend() + user_drivers = additional_parameters.get("user_drivers", []) + user_clients = additional_parameters.get("user_clients", []) + self._client = Client( + auth=auth, + base_url=engine_url, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None), + transport=transport, + headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)}, + ) # Holding this lock for write means that connection is closing itself. 
# cursor() should hold this lock for read to read/write state         self._closing_lock = RWLockWrite() +        super().__init__() + +    def cursor(self, **kwargs: Any) -> Cursor: +        if self.closed: +            raise ConnectionClosedError("Unable to create cursor: connection closed.") -    def cursor(self) -> Cursor:         with self._closing_lock.gen_rlock(): -            c = super()._cursor() -            assert isinstance(c, Cursor)  # typecheck -            return c +            c = Cursor(client=self._client, connection=self, **kwargs) +            self._cursors.append(c) +            return c + +    def _remove_cursor(self, cursor: Cursor) -> None: +        # This way it's atomic +        try: +            self._cursors.remove(cursor) +        except ValueError: +            pass -    @wraps(AsyncBaseConnection._aclose)     def close(self) -> None: +        if self.closed: +            return + +        # self._cursors is going to be changed during closing cursors +        # after this point no cursors would be added to _cursors, only removed since +        # closing lock is held, and later connection will be marked as closed         with self._closing_lock.gen_wlock(): -            async_to_sync(self._aclose)() +            cursors = self._cursors[:] +            for c in cursors: +                # Here c can already be closed by another thread, +                # but it shouldn't raise an error in this case +                c.close() +            self._client.close() +            self._is_closed = True     # Context manager support     def __enter__(self) -> Connection: @@ -68,4 +258,90 @@ def __del__(self) -> None:         warn(f"Unclosed {self!r}", UserWarning) -connect = async_to_sync(async_connect_factory(Connection)) +def connect( +    database: str = None, +    username: Optional[str] = None, +    password: Optional[str] = None, +    access_token: Optional[str] = None, +    auth: Auth = None, +    engine_name: Optional[str] = None, +    engine_url: Optional[str] = None, +    account_name: Optional[str] = None, +    api_endpoint: str = DEFAULT_API_URL, +    use_token_cache: bool = True, +    additional_parameters: Dict[str, Any] = {}, +) -> Connection: +    """Connect to Firebolt database. + +    Args: +        `database` (str): Name of the database to connect +        `username` (Optional[str]): User name to use for authentication (Deprecated) +        `password` (Optional[str]): Password to use for authentication (Deprecated) +        `access_token` (Optional[str]): Authentication token to use instead of +            credentials (Deprecated) +        `auth` (Auth): Authentication object. +        `engine_name` (Optional[str]): Name of the engine to connect to +        `engine_url` (Optional[str]): The engine endpoint to use +        `account_name` (Optional[str]): For customers with multiple accounts; +            if none, default is used +        `api_endpoint` (str): Firebolt API endpoint. Used for authentication +        `use_token_cache` (bool): Cache authentication token in filesystem. +            Default: True +        `additional_parameters` (Optional[Dict]): Dictionary of less widely-used +            arguments for connection + +    Note: +        Providing both `engine_name` and `engine_url` will result in an error + +    """ +    # These parameters are optional in function signature +    # but are required to connect. +    # PEP 249 recommends making them kwargs.
+ if not database: + raise ConfigurationError("database name is required to connect.") + + validate_engine_name_and_url(engine_name, engine_url) + + if not auth: + if any([username, password, access_token, api_endpoint, use_token_cache]): + logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) + auth = _get_auth(username, password, access_token, use_token_cache) + else: + raise ConfigurationError("No authentication provided.") + api_endpoint = fix_url_schema(api_endpoint) + + # Mypy checks, this should never happen + assert database is not None + + if not engine_name and not engine_url: + engine_url = _get_database_default_engine_url( + database=database, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + + elif engine_name: + engine_url = _resolve_engine_url( + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + elif account_name: + # In above if branches account name is validated since it's used to + # resolve or get an engine url. + # We need to manually validate account_name if none of the above + # cases are triggered. + with Client( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + ) as client: + client.account_id + + assert engine_url is not None + + engine_url = fix_url_schema(engine_url) + return Connection(engine_url, database, auth, api_endpoint, additional_parameters) diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index 96e1f976def..69c84a81115 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -1,23 +1,58 @@ from __future__ import annotations -from functools import wraps +import logging +import re +import time from threading import Lock -from typing import Any, Generator, List, Optional, Sequence, Tuple, Union +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Generator, + List, + Optional, + Sequence, + Tuple, + Union, +) +from httpx import Client as HttpxClient +from httpx import Response, codes from readerwriterlock.rwlock import RWLockWrite -from firebolt.async_db._types import ColType -from firebolt.async_db.cursor import BaseCursor as AsyncBaseCursor -from firebolt.async_db.cursor import ( +from firebolt.common._types import ( + ColType, + Column, ParameterType, + RawColType, + SetParameter, + split_format_sql, +) +from firebolt.common.base_cursor import ( + JSON_OUTPUT_FORMAT, + BaseCursor, + CursorState, QueryStatus, + Statistics, check_not_closed, check_query_executed, ) -from firebolt.utils.util import async_to_sync +from firebolt.db.util import is_db_available, is_engine_running +from firebolt.utils.exception import ( + AsyncExecutionUnavailableError, + EngineNotRunningError, + FireboltDatabaseError, + OperationalError, + ProgrammingError, +) +if TYPE_CHECKING: + from firebolt.db.connection import Connection -class Cursor(AsyncBaseCursor): +logger = logging.getLogger(__name__) + + +class Cursor(BaseCursor): """ Class, responsible for executing queries to Firebolt Database. 
diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py
index 96e1f976def..69c84a81115 100644
--- a/src/firebolt/db/cursor.py
+++ b/src/firebolt/db/cursor.py
@@ -1,23 +1,58 @@
 from __future__ import annotations
 
-from functools import wraps
+import logging
+import re
+import time
 from threading import Lock
-from typing import Any, Generator, List, Optional, Sequence, Tuple, Union
+from types import TracebackType
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
+from httpx import Client as HttpxClient
+from httpx import Response, codes
 from readerwriterlock.rwlock import RWLockWrite
 
-from firebolt.async_db._types import ColType
-from firebolt.async_db.cursor import BaseCursor as AsyncBaseCursor
-from firebolt.async_db.cursor import (
+from firebolt.common._types import (
+    ColType,
+    Column,
     ParameterType,
+    RawColType,
+    SetParameter,
+    split_format_sql,
+)
+from firebolt.common.base_cursor import (
+    JSON_OUTPUT_FORMAT,
+    BaseCursor,
+    CursorState,
     QueryStatus,
+    Statistics,
     check_not_closed,
     check_query_executed,
 )
-from firebolt.utils.util import async_to_sync
+from firebolt.db.util import is_db_available, is_engine_running
+from firebolt.utils.exception import (
+    AsyncExecutionUnavailableError,
+    EngineNotRunningError,
+    FireboltDatabaseError,
+    OperationalError,
+    ProgrammingError,
+)
+
+if TYPE_CHECKING:
+    from firebolt.db.connection import Connection
 
-class Cursor(AsyncBaseCursor):
+logger = logging.getLogger(__name__)
+
+
+class Cursor(BaseCursor):
     """
     Class, responsible for executing queries to Firebolt Database.
     Should not be created directly,
@@ -31,17 +66,172 @@ class Cursor(AsyncBaseCursor):
         with the :py:func:`fetchmany` method
     """
 
-    __slots__ = AsyncBaseCursor.__slots__ + (
+    __slots__ = BaseCursor.__slots__ + (
         "_query_lock",
         "_idx_lock",
     )
 
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
+    def __init__(
+        self, *args: Any, client: HttpxClient, connection: Connection, **kwargs: Any
+    ) -> None:
+        super().__init__(*args, **kwargs)
         self._query_lock = RWLockWrite()
         self._idx_lock = Lock()
-        super().__init__(*args, **kwargs)
+        self._client = client
+        self.connection = connection
+
+    def _raise_if_error(self, resp: Response) -> None:
+        """Raise an appropriate error from the response, if there is one."""
+        if resp.status_code == codes.INTERNAL_SERVER_ERROR:
+            raise OperationalError(
+                f"Error executing query:\n{resp.read().decode('utf-8')}"
+            )
+        if resp.status_code == codes.FORBIDDEN:
+            if not is_db_available(self.connection, self.connection.database):
+                raise FireboltDatabaseError(
+                    f"Database {self.connection.database} does not exist"
+                )
+            raise ProgrammingError(resp.read().decode("utf-8"))
+        if (
+            resp.status_code == codes.SERVICE_UNAVAILABLE
+            or resp.status_code == codes.NOT_FOUND
+        ):
+            if not is_engine_running(self.connection, self.connection.engine_url):
+                raise EngineNotRunningError(
+                    f"Firebolt engine {self.connection.engine_url} "
+                    "needs to be running to run queries against it."
+                )
+        resp.raise_for_status()
+
+    def _api_request(
+        self,
+        query: Optional[str] = "",
+        parameters: Optional[dict[str, Any]] = {},
+        path: Optional[str] = "",
+        use_set_parameters: Optional[bool] = True,
+    ) -> Response:
+        """
+        Query API, return Response object.
+
+        Args:
+            query (str): SQL query
+            parameters (Optional[dict]): Dictionary of request parameters sent
+                alongside the query. Note: in order to clear the "output_format"
+                dict value, it must be set to an empty string. If no value is
+                specified, JSON_OUTPUT_FORMAT will be used.
+            path (str): endpoint suffix, for example "cancel" or "status"
+            use_set_parameters (Optional[bool]): Some queries will fail if additional
+                set parameters are sent. Setting this to False makes
+                self._set_parameters be ignored.
+        """
+        if use_set_parameters:
+            parameters = {**(self._set_parameters or {}), **(parameters or {})}
+        return self._client.request(
+            url=f"/{path}",
+            method="POST",
+            params={
+                "database": self.connection.database,
+                **(parameters or dict()),
+            },
+            content=query,
+        )
+
+    def _validate_set_parameter(self, parameter: SetParameter) -> None:
+        """Validate a parameter by executing a simple query with it."""
+        if parameter.name == "async_execution":
+            raise AsyncExecutionUnavailableError(
+                "It is not possible to set async_execution using a SET command. "
+                "Instead, pass it as an argument to the execute() or "
+                "executemany() function."
+            )
+        resp = self._api_request("select 1", {parameter.name: parameter.value})
+        # Handle invalid set parameter
+        if resp.status_code == codes.BAD_REQUEST:
+            raise OperationalError(resp.text)
+        self._raise_if_error(resp)
+
+        # set parameter passed validation
+        self._set_parameters[parameter.name] = parameter.value
+
+    def _do_execute(
+        self,
+        raw_query: str,
+        parameters: Sequence[Sequence[ParameterType]],
+        skip_parsing: bool = False,
+        async_execution: Optional[bool] = False,
+    ) -> None:
+        self._reset()
+        # Allow users to manually skip parsing for performance improvement.
+        queries: List[Union[SetParameter, str]] = (
+            [raw_query] if skip_parsing else split_format_sql(raw_query, parameters)
+        )
+        try:
+            for query in queries:
+
+                start_time = time.time()
+                # Our CREATE EXTERNAL TABLE queries currently require credentials,
+                # so we will skip logging those queries.
+                # https://docs.firebolt.io/sql-reference/commands/create-external-table.html
+                if isinstance(query, SetParameter) or not re.search(
+                    "aws_key_id|credentials", query, flags=re.IGNORECASE
+                ):
+                    logger.debug(f"Running query: {query}")
+
+                # Define type for mypy
+                row_set: Tuple[
+                    int,
+                    Optional[List[Column]],
+                    Optional[Statistics],
+                    Optional[List[List[RawColType]]],
+                ] = (-1, None, None, None)
+                if isinstance(query, SetParameter):
+                    self._validate_set_parameter(query)
+                elif async_execution:
+                    self._validate_server_side_async_settings(
+                        parameters,
+                        queries,
+                        skip_parsing,
+                        async_execution,
+                    )
+                    response = self._api_request(
+                        query,
+                        {
+                            "async_execution": 1,
+                            "advanced_mode": 1,
+                            "output_format": JSON_OUTPUT_FORMAT,
+                        },
+                    )
+                    self._raise_if_error(response)
+                    if response.headers.get("content-length", "") == "0":
+                        raise OperationalError("No response to asynchronous query.")
+                    resp = response.json()
+                    if "query_id" not in resp or resp["query_id"] == "":
+                        raise OperationalError(
+                            "Invalid response to asynchronous query: missing query_id."
+                        )
+                    self._query_id = resp["query_id"]
+                else:
+                    resp = self._api_request(
+                        query, {"output_format": JSON_OUTPUT_FORMAT}
+                    )
+                    self._raise_if_error(resp)
+                    row_set = self._row_set_from_response(resp)
+
+                self._append_row_set(row_set)
+
+                logger.info(
+                    f"Query fetched {self.rowcount} rows in"
+                    f" {time.time() - start_time} seconds."
+                )
 
-    @wraps(AsyncBaseCursor.execute)
+            self._state = CursorState.DONE
+
+        except Exception:
+            self._state = CursorState.ERROR
+            raise
+
+    @check_not_closed
     def execute(
         self,
         query: str,
@@ -49,47 +239,134 @@ def execute(
         skip_parsing: bool = False,
         async_execution: Optional[bool] = False,
     ) -> Union[int, str]:
-        with self._query_lock.gen_wlock():
-            return async_to_sync(super().execute)(
-                query, parameters, skip_parsing, async_execution
-            )
+        """Prepare and execute a database query.
+
+        Supported features:
+            Parameterized queries: placeholder characters ('?') are substituted
+                with values provided in `parameters`. Values are formatted to
+                be properly recognized by the database and to prevent SQL injection.
+            Multi-statement queries: multiple statements, provided in a single query
+                and separated by semicolons, are executed separately and sequentially.
+                To switch to the next statement result, use the `nextset` method.
+            SET statements: to provide additional query execution parameters, execute
+                `SET param=value` statement before it. All parameters are stored in
+                the cursor object until it's closed. They can also be removed with
+                the `flush_parameters` method call.
+
+        Args:
+            query (str): SQL query to execute
+            parameters (Optional[Sequence[ParameterType]]): A sequence of substitution
+                parameters. Used to replace '?' placeholders inside a query with
+                actual values
+            skip_parsing (bool): Flag to disable query parsing. This will
+                disable parameterized, multi-statement and SET queries,
+                while improving performance
+            async_execution (bool): flag to determine if the query should be
+                executed asynchronously
 
-    @wraps(AsyncBaseCursor.executemany)
+        Returns:
+            int|str: Query row count for synchronous execution,
+                query ID string for asynchronous execution.
+        """
+        params_list = [parameters] if parameters else []
+        self._do_execute(query, params_list, skip_parsing, async_execution)
+        return self.query_id if async_execution else self.rowcount
+
+    @check_not_closed
     def executemany(
         self,
         query: str,
         parameters_seq: Sequence[Sequence[ParameterType]],
        async_execution: Optional[bool] = False,
     ) -> Union[int, str]:
-        with self._query_lock.gen_wlock():
-            return async_to_sync(super().executemany)(
-                query, parameters_seq, async_execution
+        """Prepare and execute a database query.
+
+        Supports providing multiple substitution parameter sets, executing them
+        as multiple statements sequentially.
+
+        Supported features:
+            Parameterized queries: Placeholder characters ('?') are substituted
+                with values provided in `parameters_seq`. Values are formatted to
+                be properly recognized by the database and to prevent SQL injection.
+            Multi-statement queries: Multiple statements, provided in a single query
+                and separated by semicolons, are executed separately and sequentially.
+                To switch to the next statement result, use the `nextset` method.
+            SET statements: To provide additional query execution parameters, execute
+                `SET param=value` statement before it. All parameters are stored in
+                the cursor object until it's closed. They can also be removed with
+                the `flush_parameters` method call.
+
+        Args:
+            query (str): SQL query to execute.
+            parameters_seq (Sequence[Sequence[ParameterType]]): A sequence of
+                substitution parameter sets. Used to replace '?' placeholders inside a
+                query with actual values from each set in a sequence. Resulting queries
+                for each subset are executed sequentially.
+            async_execution (bool): flag to determine if the query should be
+                executed asynchronously
+
+        Returns:
+            int|str: Query row count for synchronous execution of queries,
+                query ID string for asynchronous execution.
+        """
+        self._do_execute(query, parameters_seq, async_execution=async_execution)
+        if async_execution:
+            return self.query_id
+        else:
+            return self.rowcount
+
+    @check_not_closed
+    def get_status(self, query_id: str) -> QueryStatus:
+        """Get status of a server-side async query. Return the state of the query."""
+        try:
+            resp = self._api_request(
+                # output_format must be empty for status to work correctly.
+                # And set parameters will cause 400 errors.
+                parameters={"query_id": query_id},
+                path="status",
+                use_set_parameters=False,
             )
+            if resp.status_code == codes.BAD_REQUEST:
+                raise OperationalError(
+                    f"Asynchronous query {query_id} status check failed: "
+                    f"{resp.status_code}."
+                )
+            resp_json = resp.json()
+            if "status" not in resp_json:
+                raise OperationalError(
+                    "Invalid response to asynchronous query: missing status."
+                )
+        except Exception:
+            self._state = CursorState.ERROR
+            raise
+        # Remember that query_id might be empty.
+        if resp_json["status"] == "":
+            return QueryStatus.NOT_READY
+        return QueryStatus[resp_json["status"]]
 
-    @wraps(AsyncBaseCursor._get_next_range)
-    def _get_next_range(self, size: int) -> Tuple[int, int]:
-        with self._idx_lock:
-            return super()._get_next_range(size)
+    def close(self) -> None:
+        """Terminate an ongoing query (if any) and mark the cursor as closed."""
+        self._state = CursorState.CLOSED
+        self.connection._remove_cursor(self)
 
-    @wraps(AsyncBaseCursor.fetchone)
-    def fetchone(self) -> Optional[List[ColType]]:
-        with self._query_lock.gen_rlock():
-            return super().fetchone()
+    def __del__(self) -> None:
+        self.close()
 
-    @wraps(AsyncBaseCursor.fetchmany)
-    def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]:
-        with self._query_lock.gen_rlock():
-            return super().fetchmany(size)
+    # Context manager support
+    @check_not_closed
+    def __enter__(self) -> Cursor:
+        return self
 
-    @wraps(AsyncBaseCursor.fetchall)
-    def fetchall(self) -> List[List[ColType]]:
-        with self._query_lock.gen_rlock():
-            return super().fetchall()
+    def __exit__(
+        self, exc_type: type, exc_val: Exception, exc_tb: TracebackType
+    ) -> None:
+        self.close()
 
-    @wraps(AsyncBaseCursor.nextset)
-    def nextset(self) -> None:
-        with self._query_lock.gen_rlock(), self._idx_lock:
-            return super().nextset()
+    @check_not_closed
+    def cancel(self, query_id: str) -> None:
+        """Cancel a server-side async query."""
+        self._api_request(
+            parameters={"query_id": query_id},
+            path="cancel",
+            use_set_parameters=False,
+        )
 
     # Iteration support
     @check_not_closed
@@ -100,13 +377,3 @@ def __iter__(self) -> Generator[List[ColType], None, None]:
             if row is None:
                 return
             yield row
-
-    @wraps(AsyncBaseCursor.get_status)
-    def get_status(self, query_id: str) -> QueryStatus:
-        with self._query_lock.gen_rlock():
-            return async_to_sync(super().get_status)(query_id)
-
-    @wraps(AsyncBaseCursor.cancel)
-    def cancel(self, query_id: str) -> None:
-        with self._query_lock.gen_rlock():
-            return async_to_sync(super().cancel)(query_id)
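A sketch of how the new server-side async methods fit together. It assumes `connection` is an open `firebolt.db` Connection obtained as in the earlier example; the query and polling interval are placeholders.

```python
# Hypothetical long-running query; assumes `connection` is an open firebolt.db Connection.
import time

from firebolt.common.base_cursor import QueryStatus

cursor = connection.cursor()

# With async_execution=True, execute() returns the server-side query ID
# instead of a row count.
query_id = cursor.execute(
    "SELECT checksum(*) FROM large_table", async_execution=True
)

# Poll the status endpoint; NOT_READY corresponds to an empty status from the server.
while cursor.get_status(query_id) == QueryStatus.NOT_READY:
    time.sleep(1)

# A query that is no longer needed could also be stopped server-side:
# cursor.cancel(query_id)
```
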
diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py
new file mode 100644
index 00000000000..94f24c3be7c
--- /dev/null
+++ b/src/firebolt/db/util.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from httpx import URL, Response
+
+from firebolt.utils.urls import DATABASES_URL, ENGINES_URL
+
+if TYPE_CHECKING:
+    from firebolt.db.connection import Connection
+
+
+def is_db_available(connection: Connection, database_name: str) -> bool:
+    """
+    Verify that the database exists.
+
+    Args:
+        connection (firebolt.db.connection.Connection)
+    """
+    resp = _filter_request(
+        connection, DATABASES_URL, {"filter.name_contains": database_name}
+    )
+    return len(resp.json()["edges"]) > 0
+
+
+def is_engine_running(connection: Connection, engine_url: str) -> bool:
+    """
+    Verify that the engine is running.
+
+    Args:
+        connection (firebolt.db.connection.Connection): connection.
+    """
+    # The URL is not guaranteed to have this structure,
+    # but for the sake of error checking this is sufficient.
+    engine_name = URL(engine_url).host.split(".")[0].replace("-", "_")
+    resp = _filter_request(
+        connection,
+        ENGINES_URL,
+        {
+            "filter.name_contains": engine_name,
+            "filter.current_status_eq": "ENGINE_STATUS_RUNNING_REVISION_SERVING",
+        },
+    )
+    return len(resp.json()["edges"]) > 0
+
+
+def _filter_request(connection: Connection, endpoint: str, filters: dict) -> Response:
+    resp = connection._client.request(
+        # The full URL overrides the client base URL, which has the engine as a prefix.
+        url=connection.api_endpoint + endpoint,
+        method="GET",
+        params=filters,
+    )
+    resp.raise_for_status()
+    return resp
diff --git a/src/firebolt/model/engine.py b/src/firebolt/model/engine.py
index 6109a1f4df4..8eab5e8e35b 100644
--- a/src/firebolt/model/engine.py
+++ b/src/firebolt/model/engine.py
@@ -194,7 +194,7 @@ def get_connection(self) -> Connection:
         """
         return connect(
             database=self.database.name,  # type: ignore # already checked by decorator
-            auth=self._service.client.auth,
+            auth=self._service.client.auth,  # type: ignore
             engine_url=self.endpoint,
             account_name=self._service.settings.account_name,
             api_endpoint=self._service.settings.server,
diff --git a/src/firebolt/utils/util.py b/src/firebolt/utils/util.py
index e05ae677fb5..96ac7a1ae64 100644
--- a/src/firebolt/utils/util.py
+++ b/src/firebolt/utils/util.py
@@ -1,9 +1,11 @@
 from functools import lru_cache, partial, wraps
-from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar
 
 import trio
 from httpx import URL
 
+from firebolt.utils.exception import ConfigurationError
+
 T = TypeVar("T")
 
 
@@ -103,3 +105,12 @@ def merge_urls(base: URL, merge: URL) -> URL:
         merge_raw_path = base.raw_path + merge.raw_path.lstrip(b"/")
         return base.copy_with(raw_path=merge_raw_path)
     return merge
+
+
+def validate_engine_name_and_url(
+    engine_name: Optional[str], engine_url: Optional[str]
+) -> None:
+    if engine_name and engine_url:
+        raise ConfigurationError(
+            "Both engine_name and engine_url are provided. Provide only one to connect."
+        )
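A small sketch of the validation behaviour added above; the engine name and URL are placeholders, and the check fires before any network request is made.

```python
# Hypothetical engine values; validate_engine_name_and_url() rejects the combination.
from firebolt.client.auth import UsernamePassword
from firebolt.db import connect
from firebolt.utils.exception import ConfigurationError

try:
    connect(
        database="my_database",
        auth=UsernamePassword("user@example.com", "secret"),
        engine_name="my_engine",
        engine_url="my-engine.example.firebolt.io",
    )
except ConfigurationError as exc:
    print(exc)  # "Both engine_name and engine_url are provided. Provide only one to connect."
```
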
diff --git a/tests/integration/dbapi/async/test_queries_async.py b/tests/integration/dbapi/async/test_queries_async.py
index 3fd749a7ad9..d671ae6f029 100644
--- a/tests/integration/dbapi/async/test_queries_async.py
+++ b/tests/integration/dbapi/async/test_queries_async.py
@@ -11,8 +11,8 @@
     DataError,
     OperationalError,
 )
-from firebolt.async_db._types import ColType, Column
 from firebolt.async_db.cursor import QueryStatus
+from firebolt.common._types import ColType, Column
 
 VALS_TO_INSERT_2 = ",".join(
     [f"({i}, {i-3}, '{val}')" for (i, val) in enumerate(range(4, 1000))]
diff --git a/tests/integration/dbapi/conftest.py b/tests/integration/dbapi/conftest.py
index 1843a22b894..b61bedf4077 100644
--- a/tests/integration/dbapi/conftest.py
+++ b/tests/integration/dbapi/conftest.py
@@ -5,8 +5,8 @@
 
 from pytest import fixture
 
-from firebolt.async_db._types import ColType
 from firebolt.async_db.cursor import Column
+from firebolt.common._types import ColType
 from firebolt.db import ARRAY, DECIMAL, Connection
 
 LOGGER = getLogger(__name__)
diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py
index 8af7f781984..44eb217c8d0 100644
--- a/tests/integration/dbapi/sync/test_queries.py
+++ b/tests/integration/dbapi/sync/test_queries.py
@@ -5,9 +5,9 @@
 
 from pytest import mark, raises
 
-from firebolt.async_db._types import ColType, Column
 from firebolt.async_db.cursor import QueryStatus
 from firebolt.client.auth import Auth
+from firebolt.common._types import ColType, Column
 from firebolt.db import (
     Binary,
     Connection,
diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py
index 570a3ed1217..6da6cf46445 100644
--- a/tests/unit/async_db/test_connection.py
+++ b/tests/unit/async_db/test_connection.py
@@ -7,9 +7,9 @@
 from pytest import mark, raises
 from pytest_httpx import HTTPXMock
 
-from firebolt.async_db._types import ColType
 from firebolt.async_db.connection import Connection, connect
 from firebolt.client.auth import Auth, Token, UsernamePassword
+from firebolt.common._types import ColType
 from firebolt.common.settings import Settings
 from firebolt.utils.exception import (
     AccountNotFoundError,
diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py
index 69be65ae398..64f13d4318f 100644
--- a/tests/unit/async_db/test_cursor.py
+++ b/tests/unit/async_db/test_cursor.py
@@ -6,8 +6,8 @@
 from pytest_httpx import HTTPXMock
 
 from firebolt.async_db import Cursor
-from firebolt.async_db._types import Column
-from firebolt.async_db.cursor import ColType, CursorState, QueryStatus
+from firebolt.common._types import Column
+from firebolt.common.base_cursor import ColType, CursorState, QueryStatus
 from firebolt.utils.exception import (
     AsyncExecutionUnavailableError,
     CursorClosedError,
@@ -139,6 +139,9 @@ async def test_cursor_no_query(
     # Context manager is also available
     with cursor:
         pass
+    # should this be available?
+    # async with cursor:
+    #     pass
 
 
 async def test_cursor_execute(
diff --git a/tests/unit/async_db/test_typing_format.py b/tests/unit/async_db/test_typing_format.py
index a5e68e07aed..c04027ee6c6 100644
--- a/tests/unit/async_db/test_typing_format.py
+++ b/tests/unit/async_db/test_typing_format.py
@@ -12,7 +12,7 @@
     InterfaceError,
     NotSupportedError,
 )
-from firebolt.async_db._types import (
+from firebolt.common._types import (
     SetParameter,
     format_statement,
     format_value,
diff --git a/tests/unit/async_db/test_typing_parse.py b/tests/unit/async_db/test_typing_parse.py
index fad2c93a693..76528a96c78 100644
--- a/tests/unit/async_db/test_typing_parse.py
+++ b/tests/unit/async_db/test_typing_parse.py
@@ -11,7 +11,7 @@
     TimeFromTicks,
     TimestampFromTicks,
 )
-from firebolt.async_db._types import parse_type, parse_value
+from firebolt.common._types import parse_type, parse_value
 from firebolt.utils.exception import DataError, NotSupportedError
 
 
diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py
index 594a99060ce..571c8d7f06d 100644
--- a/tests/unit/db/test_connection.py
+++ b/tests/unit/db/test_connection.py
@@ -8,8 +8,8 @@
 from pytest import mark, raises, warns
 from pytest_httpx import HTTPXMock
 
-from firebolt.async_db._types import ColType
 from firebolt.client.auth import Auth, Token, UsernamePassword
+from firebolt.common._types import ColType
 from firebolt.common.settings import Settings
 from firebolt.db import Connection, connect
 from firebolt.utils.exception import (
diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py
index 8562d50a901..8b8e18be77c 100644
--- a/tests/unit/db/test_cursor.py
+++ b/tests/unit/db/test_cursor.py
@@ -5,8 +5,8 @@
 from pytest import raises
 from pytest_httpx import HTTPXMock
 
-from firebolt.async_db.cursor import ColType, Column, CursorState, QueryStatus
 from firebolt.db import Cursor
+from firebolt.db.cursor import ColType, Column, CursorState, QueryStatus
 from firebolt.utils.exception import (
     CursorClosedError,
     DataError,
@@ -542,11 +542,11 @@ def test_cursor_skip_parse(
     httpx_mock.add_callback(auth_callback, url=auth_url)
     httpx_mock.add_callback(query_callback, url=query_url)
 
-    with patch("firebolt.async_db.cursor.split_format_sql") as split_format_sql_mock:
+    with patch("firebolt.db.cursor.split_format_sql") as split_format_sql_mock:
         cursor.execute("non-an-actual-sql")
         split_format_sql_mock.assert_called_once()
 
-    with patch("firebolt.async_db.cursor.split_format_sql") as split_format_sql_mock:
+    with patch("firebolt.db.cursor.split_format_sql") as split_format_sql_mock:
         cursor.execute("non-an-actual-sql", skip_parsing=True)
         split_format_sql_mock.assert_not_called()
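Finally, a sketch of the parameterized and multi-statement behaviour described in the new `execute()`/`executemany()` docstrings. The table and values are illustrative only, and `connection` is assumed to be an open `firebolt.db` Connection as in the first example.

```python
# Hypothetical table and data; assumes `connection` is an open firebolt.db Connection.
cursor = connection.cursor()

# executemany() formats each parameter set into its own statement and runs them in order.
cursor.executemany(
    "INSERT INTO example_table VALUES (?, ?)",
    [(1, "one"), (2, "two"), (3, "three")],
)

# Multi-statement queries are split on semicolons; nextset() moves to the next result.
cursor.execute("SELECT 1; SELECT 2")
print(cursor.fetchall())  # rows of the first statement
cursor.nextset()
print(cursor.fetchall())  # rows of the second statement
```
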