From fdcb2ba79e21da2087a37cdcf76dbe9c9fd4650a Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 2 Apr 2025 12:42:55 +0300 Subject: [PATCH 01/10] initialize row set absraction --- src/firebolt/async_db/cursor.py | 12 ++-- src/firebolt/common/_types.py | 13 ---- src/firebolt/common/base_cursor.py | 66 ++----------------- src/firebolt/common/constants.py | 25 ++++++- src/firebolt/common/row_set/__init__.py | 0 src/firebolt/common/row_set/async/__init__.py | 0 src/firebolt/common/row_set/async/base.py | 51 ++++++++++++++ src/firebolt/common/row_set/sync/__init__.py | 0 src/firebolt/common/row_set/sync/base.py | 53 +++++++++++++++ src/firebolt/common/row_set/types.py | 53 +++++++++++++++ src/firebolt/db/cursor.py | 12 ++-- .../dbapi/async/V1/test_queries_async.py | 3 +- .../dbapi/async/V2/test_queries_async.py | 3 +- .../async/V2/test_system_engine_async.py | 3 +- tests/integration/dbapi/conftest.py | 3 +- .../integration/dbapi/sync/V1/test_queries.py | 3 +- .../integration/dbapi/sync/V2/test_queries.py | 3 +- .../dbapi/sync/V2/test_system_engine.py | 3 +- tests/unit/V1/async_db/test_cursor.py | 5 +- tests/unit/V1/db/test_cursor.py | 5 +- tests/unit/async_db/test_cursor.py | 5 +- tests/unit/db/test_cursor.py | 5 +- tests/unit/db_conftest.py | 6 +- 23 files changed, 232 insertions(+), 100 deletions(-) create mode 100644 src/firebolt/common/row_set/__init__.py create mode 100644 src/firebolt/common/row_set/async/__init__.py create mode 100644 src/firebolt/common/row_set/async/base.py create mode 100644 src/firebolt/common/row_set/sync/__init__.py create mode 100644 src/firebolt/common/row_set/sync/base.py create mode 100644 src/firebolt/common/row_set/types.py diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 95910a8fb9..ddeda5bfda 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -29,12 +29,7 @@ from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2 from firebolt.common._types import ColType, ParameterType, SetParameter from firebolt.common.base_cursor import ( - JSON_OUTPUT_FORMAT, - RESET_SESSION_HEADER, - UPDATE_ENDPOINT_HEADER, - UPDATE_PARAMETERS_HEADER, BaseCursor, - CursorState, _parse_update_endpoint, _parse_update_parameters, _raise_if_internal_set_parameter, @@ -42,6 +37,13 @@ check_not_closed, check_query_executed, ) +from firebolt.common.constants import ( + JSON_OUTPUT_FORMAT, + RESET_SESSION_HEADER, + UPDATE_ENDPOINT_HEADER, + UPDATE_PARAMETERS_HEADER, + CursorState, +) from firebolt.common.statement_formatter import create_statement_formatter from firebolt.utils.exception import ( EngineNotRunningError, diff --git a/src/firebolt/common/_types.py b/src/firebolt/common/_types.py index 877c6ed387..172fe00b04 100644 --- a/src/firebolt/common/_types.py +++ b/src/firebolt/common/_types.py @@ -77,19 +77,6 @@ def Binary(value: str) -> bytes: # NOSONAR DATETIME = datetime ROWID = int -Column = namedtuple( - "Column", - ( - "name", - "type_code", - "display_size", - "internal_size", - "precision", - "scale", - "null_ok", - ), -) - class ExtendedType: """Base type for all extended types in Firebolt (array, decimal, struct, etc.).""" diff --git a/src/firebolt/common/base_cursor.py b/src/firebolt/common/base_cursor.py index 0db6fa70e5..c6e8e28c77 100644 --- a/src/firebolt/common/base_cursor.py +++ b/src/firebolt/common/base_cursor.py @@ -2,8 +2,6 @@ import logging import re -from dataclasses import dataclass, fields -from enum import Enum from functools import wraps from types import TracebackType from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -12,12 +10,18 @@ from firebolt.common._types import ( ColType, - Column, RawColType, SetParameter, parse_type, parse_value, ) +from firebolt.common.constants import ( + DISALLOWED_PARAMETER_LIST, + IMMUTABLE_PARAMETER_LIST, + USE_PARAMETER_LIST, + CursorState, +) +from firebolt.common.row_set.types import AsyncResponse, Column, Statistics from firebolt.common.statement_formatter import StatementFormatter from firebolt.utils.exception import ( ConfigurationError, @@ -32,28 +36,6 @@ logger = logging.getLogger(__name__) -JSON_OUTPUT_FORMAT = "JSON_Compact" - - -class CursorState(Enum): - NONE = 1 - ERROR = 2 - DONE = 3 - CLOSED = 4 - - -# Parameters that should be set using USE instead of SET -USE_PARAMETER_LIST = ["database", "engine"] -# parameters that can only be set by the backend -DISALLOWED_PARAMETER_LIST = ["output_format"] -# parameters that are set by the backend and should not be set by the user -IMMUTABLE_PARAMETER_LIST = USE_PARAMETER_LIST + DISALLOWED_PARAMETER_LIST - -UPDATE_ENDPOINT_HEADER = "Firebolt-Update-Endpoint" -UPDATE_PARAMETERS_HEADER = "Firebolt-Update-Parameters" -RESET_SESSION_HEADER = "Firebolt-Reset-Session" - - def _parse_update_parameters(parameter_header: str) -> Dict[str, str]: """Parse update parameters and set them as attributes.""" # parse key1=value1,key2=value2 comma separated string into dict @@ -88,40 +70,6 @@ def _raise_if_internal_set_parameter(parameter: SetParameter) -> None: ) -@dataclass -class AsyncResponse: - token: str - message: str - monitorSql: str - - -@dataclass -class Statistics: - """ - Class for query execution statistics. - """ - - elapsed: float - rows_read: int - bytes_read: int - time_before_execution: float - time_to_execute: float - scanned_bytes_cache: Optional[float] = None - scanned_bytes_storage: Optional[float] = None - - def __post_init__(self) -> None: - for field in fields(self): - value = getattr(self, field.name) - _type = eval(field.type) # type: ignore - - # Unpack Optional - if hasattr(_type, "__args__"): - _type = _type.__args__[0] - if value is not None and not isinstance(value, _type): - # convert values to proper types - setattr(self, field.name, _type(value)) - - RowSet = Tuple[ int, Optional[List[Column]], diff --git a/src/firebolt/common/constants.py b/src/firebolt/common/constants.py index 2d11b82d1a..50493d0867 100644 --- a/src/firebolt/common/constants.py +++ b/src/firebolt/common/constants.py @@ -1,6 +1,29 @@ +from __future__ import annotations + +from enum import Enum + KEEPALIVE_FLAG: int = 1 KEEPIDLE_RATE: int = 60 # seconds DEFAULT_TIMEOUT_SECONDS: int = 60 -# Running statuses in infromation schema +# Running statuses in information schema ENGINE_STATUS_RUNNING_LIST = ["RUNNING", "Running", "ENGINE_STATE_RUNNING"] +JSON_OUTPUT_FORMAT = "JSON_Compact" + + +class CursorState(Enum): + NONE = 1 + ERROR = 2 + DONE = 3 + CLOSED = 4 + + +# Parameters that should be set using USE instead of SET +USE_PARAMETER_LIST = ["database", "engine"] +# parameters that can only be set by the backend +DISALLOWED_PARAMETER_LIST = ["output_format"] +# parameters that are set by the backend and should not be set by the user +IMMUTABLE_PARAMETER_LIST = USE_PARAMETER_LIST + DISALLOWED_PARAMETER_LIST +UPDATE_ENDPOINT_HEADER = "Firebolt-Update-Endpoint" +UPDATE_PARAMETERS_HEADER = "Firebolt-Update-Parameters" +RESET_SESSION_HEADER = "Firebolt-Reset-Session" diff --git a/src/firebolt/common/row_set/__init__.py b/src/firebolt/common/row_set/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/firebolt/common/row_set/async/__init__.py b/src/firebolt/common/row_set/async/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/firebolt/common/row_set/async/base.py b/src/firebolt/common/row_set/async/base.py new file mode 100644 index 0000000000..a9afba8ab5 --- /dev/null +++ b/src/firebolt/common/row_set/async/base.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from async_property import async_property # type: ignore +from httpx import AsyncByteStream + +from firebolt.common._types import ColType +from firebolt.common.row_set.types import Column, Statistics + + +class BaseAsyncRowSet(ABC): + """ + Base class for all async row sets. + """ + + @classmethod + @abstractmethod + async def from_response_stream(cls, stream: AsyncByteStream) -> "BaseAsyncRowSet": + ... + + @async_property + @abstractmethod + async def row_count(self) -> Optional[int]: + ... + + @async_property + @abstractmethod + def statistics(self) -> Optional[Statistics]: + ... + + @async_property + @abstractmethod + async def columns(self) -> List[Column]: + ... + + @async_property + @abstractmethod + def nextset(self) -> bool: + ... + + @abstractmethod + async def __aiter__(self) -> "BaseAsyncRowSet": + ... + + @abstractmethod + async def __anext__(self) -> List[ColType]: + ... + + @abstractmethod + async def aclose(self) -> None: + ... diff --git a/src/firebolt/common/row_set/sync/__init__.py b/src/firebolt/common/row_set/sync/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/firebolt/common/row_set/sync/base.py b/src/firebolt/common/row_set/sync/base.py new file mode 100644 index 0000000000..a9e63fd92b --- /dev/null +++ b/src/firebolt/common/row_set/sync/base.py @@ -0,0 +1,53 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from httpx import SyncByteStream + +from firebolt.common._types import ColType +from firebolt.common.row_set.types import Column, Statistics + + +class BaseRowSet(ABC): + """ + Base class for all sync row sets. + """ + + @classmethod + @abstractmethod + def from_response_stream(cls, stream: SyncByteStream) -> "BaseRowSet": + ... + + @property + @abstractmethod + def row_count(self) -> Optional[int]: + # This is optional because for streaming it will not be available + # until all rows are read + ... + + @property + @abstractmethod + def statistics(self) -> Optional[Statistics]: + # This is optional because for streaming it will not be available + # until all rows are read + ... + + @property + @abstractmethod + def columns(self) -> List[Column]: + ... + + @abstractmethod + def nextset(self) -> bool: + ... + + @abstractmethod + def __iter__(self) -> "BaseRowSet": + ... + + @abstractmethod + def __next__(self) -> List[ColType]: + ... + + @abstractmethod + def close(self) -> None: + ... diff --git a/src/firebolt/common/row_set/types.py b/src/firebolt/common/row_set/types.py new file mode 100644 index 0000000000..26963c6a34 --- /dev/null +++ b/src/firebolt/common/row_set/types.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from collections import namedtuple +from dataclasses import dataclass, fields +from typing import Optional + + +@dataclass +class AsyncResponse: + token: str + message: str + monitorSql: str + + +@dataclass +class Statistics: + """ + Class for query execution statistics. + """ + + elapsed: float + rows_read: int + bytes_read: int + time_before_execution: float + time_to_execute: float + scanned_bytes_cache: Optional[float] = None + scanned_bytes_storage: Optional[float] = None + + def __post_init__(self) -> None: + for field in fields(self): + value = getattr(self, field.name) + _type = eval(field.type) # type: ignore + + # Unpack Optional + if hasattr(_type, "__args__"): + _type = _type.__args__[0] + if value is not None and not isinstance(value, _type): + # convert values to proper types + setattr(self, field.name, _type(value)) + + +Column = namedtuple( + "Column", + ( + "name", + "type_code", + "display_size", + "internal_size", + "precision", + "scale", + "null_ok", + ), +) diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index 96e7038b19..8b6ba71002 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -27,12 +27,7 @@ from firebolt.client import Client, ClientV1, ClientV2 from firebolt.common._types import ColType, ParameterType, SetParameter from firebolt.common.base_cursor import ( - JSON_OUTPUT_FORMAT, - RESET_SESSION_HEADER, - UPDATE_ENDPOINT_HEADER, - UPDATE_PARAMETERS_HEADER, BaseCursor, - CursorState, _parse_update_endpoint, _parse_update_parameters, _raise_if_internal_set_parameter, @@ -40,6 +35,13 @@ check_not_closed, check_query_executed, ) +from firebolt.common.constants import ( + JSON_OUTPUT_FORMAT, + RESET_SESSION_HEADER, + UPDATE_ENDPOINT_HEADER, + UPDATE_PARAMETERS_HEADER, + CursorState, +) from firebolt.common.statement_formatter import create_statement_formatter from firebolt.utils.exception import ( EngineNotRunningError, diff --git a/tests/integration/dbapi/async/V1/test_queries_async.py b/tests/integration/dbapi/async/V1/test_queries_async.py index 3f02aaead5..2f9ac59be2 100644 --- a/tests/integration/dbapi/async/V1/test_queries_async.py +++ b/tests/integration/dbapi/async/V1/test_queries_async.py @@ -7,7 +7,8 @@ from pytest import fixture, mark, raises from firebolt.async_db import Binary, Connection, Cursor, OperationalError -from firebolt.common._types import ColType, Column +from firebolt.common._types import ColType +from firebolt.common.row_set.types import Column VALS_TO_INSERT_2 = ",".join( [f"({i}, {i-3}, '{val}')" for (i, val) in enumerate(range(4, 1000))] diff --git a/tests/integration/dbapi/async/V2/test_queries_async.py b/tests/integration/dbapi/async/V2/test_queries_async.py index 9811aad4ee..335fa358e7 100644 --- a/tests/integration/dbapi/async/V2/test_queries_async.py +++ b/tests/integration/dbapi/async/V2/test_queries_async.py @@ -9,7 +9,8 @@ from firebolt.async_db import Binary, Connection, Cursor, OperationalError from firebolt.async_db.connection import connect from firebolt.client.auth.base import Auth -from firebolt.common._types import ColType, Column +from firebolt.common._types import ColType +from firebolt.common.row_set.types import Column from tests.integration.dbapi.utils import assert_deep_eq VALS_TO_INSERT_2 = ",".join( diff --git a/tests/integration/dbapi/async/V2/test_system_engine_async.py b/tests/integration/dbapi/async/V2/test_system_engine_async.py index c3861d583d..321ef50547 100644 --- a/tests/integration/dbapi/async/V2/test_system_engine_async.py +++ b/tests/integration/dbapi/async/V2/test_system_engine_async.py @@ -4,7 +4,8 @@ from pytest import raises from firebolt.async_db import Connection -from firebolt.common._types import ColType, Column +from firebolt.common._types import ColType +from firebolt.common.row_set.types import Column from firebolt.utils.exception import FireboltStructuredError from tests.integration.dbapi.utils import assert_deep_eq diff --git a/tests/integration/dbapi/conftest.py b/tests/integration/dbapi/conftest.py index 60f513463e..9d2b53c98e 100644 --- a/tests/integration/dbapi/conftest.py +++ b/tests/integration/dbapi/conftest.py @@ -5,7 +5,8 @@ from pytest import fixture -from firebolt.common._types import STRUCT, ColType, Column +from firebolt.common._types import STRUCT, ColType +from firebolt.common.row_set.types import Column from firebolt.db import ARRAY, DECIMAL, Connection LOGGER = getLogger(__name__) diff --git a/tests/integration/dbapi/sync/V1/test_queries.py b/tests/integration/dbapi/sync/V1/test_queries.py index f810800949..9950e231d6 100644 --- a/tests/integration/dbapi/sync/V1/test_queries.py +++ b/tests/integration/dbapi/sync/V1/test_queries.py @@ -7,7 +7,8 @@ from pytest import fixture, mark, raises from firebolt.client.auth import Auth -from firebolt.common._types import ColType, Column +from firebolt.common._types import ColType +from firebolt.common.row_set.types import Column from firebolt.db import Binary, Connection, Cursor, OperationalError, connect VALS_TO_INSERT = ",".join([f"({i},'{val}')" for (i, val) in enumerate(range(1, 360))]) diff --git a/tests/integration/dbapi/sync/V2/test_queries.py b/tests/integration/dbapi/sync/V2/test_queries.py index fa4adfcd93..5df7ef99ad 100644 --- a/tests/integration/dbapi/sync/V2/test_queries.py +++ b/tests/integration/dbapi/sync/V2/test_queries.py @@ -8,7 +8,8 @@ from pytest import mark, raises from firebolt.client.auth import Auth -from firebolt.common._types import ColType, Column +from firebolt.common._types import ColType +from firebolt.common.row_set.types import Column from firebolt.db import Binary, Connection, Cursor, OperationalError, connect from tests.integration.dbapi.utils import assert_deep_eq diff --git a/tests/integration/dbapi/sync/V2/test_system_engine.py b/tests/integration/dbapi/sync/V2/test_system_engine.py index 7ba36b6bb6..e17f0ea495 100644 --- a/tests/integration/dbapi/sync/V2/test_system_engine.py +++ b/tests/integration/dbapi/sync/V2/test_system_engine.py @@ -3,7 +3,8 @@ from pytest import raises -from firebolt.common._types import ColType, Column +from firebolt.common._types import ColType +from firebolt.common.row_set.types import Column from firebolt.db import Connection from firebolt.utils.exception import FireboltStructuredError from tests.integration.dbapi.utils import assert_deep_eq diff --git a/tests/unit/V1/async_db/test_cursor.py b/tests/unit/V1/async_db/test_cursor.py index 2e5f6f39f0..a3eae31e5f 100644 --- a/tests/unit/V1/async_db/test_cursor.py +++ b/tests/unit/V1/async_db/test_cursor.py @@ -6,8 +6,9 @@ from pytest_httpx import HTTPXMock from firebolt.async_db import Cursor -from firebolt.common._types import Column -from firebolt.common.base_cursor import ColType, CursorState +from firebolt.common.base_cursor import ColType +from firebolt.common.constants import CursorState +from firebolt.common.row_set.types import Column from firebolt.utils.exception import ( ConfigurationError, CursorClosedError, diff --git a/tests/unit/V1/db/test_cursor.py b/tests/unit/V1/db/test_cursor.py index 06ff9f5ea8..59b33ea099 100644 --- a/tests/unit/V1/db/test_cursor.py +++ b/tests/unit/V1/db/test_cursor.py @@ -5,9 +5,10 @@ from pytest import LogCaptureFixture, mark, raises from pytest_httpx import HTTPXMock -from firebolt.common._types import Column +from firebolt.common.constants import CursorState +from firebolt.common.row_set.types import Column from firebolt.db import Cursor -from firebolt.db.cursor import ColType, CursorState +from firebolt.db.cursor import ColType from firebolt.utils.exception import ( ConfigurationError, CursorClosedError, diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index 40f3bc271e..bc3d0b0815 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -7,8 +7,9 @@ from pytest_httpx import HTTPXMock from firebolt.async_db import Cursor -from firebolt.common._types import Column -from firebolt.common.base_cursor import ColType, CursorState +from firebolt.common.base_cursor import ColType +from firebolt.common.constants import CursorState +from firebolt.common.row_set.types import Column from firebolt.utils.exception import ( ConfigurationError, CursorClosedError, diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index 9a12ee84ae..bba871ee66 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -6,9 +6,10 @@ from pytest import LogCaptureFixture, mark, raises from pytest_httpx import HTTPXMock -from firebolt.common._types import Column +from firebolt.common.constants import CursorState +from firebolt.common.row_set.types import Column from firebolt.db import Cursor -from firebolt.db.cursor import ColType, CursorState +from firebolt.db.cursor import ColType from firebolt.utils.exception import ( ConfigurationError, CursorClosedError, diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index b23a6f6aa1..5ca70fb3f1 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -7,8 +7,10 @@ from pytest import fixture from pytest_httpx import HTTPXMock -from firebolt.async_db.cursor import JSON_OUTPUT_FORMAT, ColType -from firebolt.common._types import STRUCT, Column +from firebolt.async_db.cursor import ColType +from firebolt.common._types import STRUCT +from firebolt.common.constants import JSON_OUTPUT_FORMAT +from firebolt.common.row_set.types import Column from firebolt.db import ARRAY, DECIMAL from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME from tests.unit.response import Response From 201d1b04652885655247302e6910e3b77ddcd9d4 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 2 Apr 2025 15:26:09 +0300 Subject: [PATCH 02/10] implement in memory rowset --- .../{async => asynchronous}/__init__.py | 0 .../row_set/{async => asynchronous}/base.py | 10 +-- .../common/row_set/asynchronous/in_memory.py | 44 ++++++++++ .../row_set/{sync => synchronous}/__init__.py | 0 .../row_set/{sync => synchronous}/base.py | 11 +-- .../common/row_set/synchronous/in_memory.py | 83 +++++++++++++++++++ src/firebolt/common/row_set/types.py | 54 ++++++++---- 7 files changed, 175 insertions(+), 27 deletions(-) rename src/firebolt/common/row_set/{async => asynchronous}/__init__.py (100%) rename src/firebolt/common/row_set/{async => asynchronous}/base.py (74%) create mode 100644 src/firebolt/common/row_set/asynchronous/in_memory.py rename src/firebolt/common/row_set/{sync => synchronous}/__init__.py (100%) rename src/firebolt/common/row_set/{sync => synchronous}/base.py (77%) create mode 100644 src/firebolt/common/row_set/synchronous/in_memory.py diff --git a/src/firebolt/common/row_set/async/__init__.py b/src/firebolt/common/row_set/asynchronous/__init__.py similarity index 100% rename from src/firebolt/common/row_set/async/__init__.py rename to src/firebolt/common/row_set/asynchronous/__init__.py diff --git a/src/firebolt/common/row_set/async/base.py b/src/firebolt/common/row_set/asynchronous/base.py similarity index 74% rename from src/firebolt/common/row_set/async/base.py rename to src/firebolt/common/row_set/asynchronous/base.py index a9afba8ab5..f0390e61a7 100644 --- a/src/firebolt/common/row_set/async/base.py +++ b/src/firebolt/common/row_set/asynchronous/base.py @@ -1,11 +1,10 @@ from abc import ABC, abstractmethod -from typing import List, Optional +from typing import AsyncIterator, List, Optional from async_property import async_property # type: ignore -from httpx import AsyncByteStream from firebolt.common._types import ColType -from firebolt.common.row_set.types import Column, Statistics +from firebolt.common.row_set.types import AsyncByteStream, Column, Statistics class BaseAsyncRowSet(ABC): @@ -13,9 +12,8 @@ class BaseAsyncRowSet(ABC): Base class for all async row sets. """ - @classmethod @abstractmethod - async def from_response_stream(cls, stream: AsyncByteStream) -> "BaseAsyncRowSet": + async def append_response_stream(self, stream: AsyncByteStream) -> None: ... @async_property @@ -39,7 +37,7 @@ def nextset(self) -> bool: ... @abstractmethod - async def __aiter__(self) -> "BaseAsyncRowSet": + def __aiter__(self) -> AsyncIterator[List[ColType]]: ... @abstractmethod diff --git a/src/firebolt/common/row_set/asynchronous/in_memory.py b/src/firebolt/common/row_set/asynchronous/in_memory.py new file mode 100644 index 0000000000..428ec8914d --- /dev/null +++ b/src/firebolt/common/row_set/asynchronous/in_memory.py @@ -0,0 +1,44 @@ +import io +from typing import AsyncIterator, List, Optional + +from firebolt.common._types import ColType +from firebolt.common.row_set.asynchronous.base import BaseAsyncRowSet +from firebolt.common.row_set.synchronous.in_memory import InMemoryRowSet +from firebolt.common.row_set.types import AsyncByteStream, Column, Statistics + + +class InMemoryAsyncRowSet(BaseAsyncRowSet): + """ + A row set that holds all rows in memory. + """ + + def __init__(self) -> None: + self._sync_row_set = InMemoryRowSet() + + async def append_response_stream(self, stream: AsyncByteStream) -> None: + sync_stream = io.BytesIO(b"".join([b async for b in stream])) + self._sync_row_set.append_response_stream(sync_stream) + + @property + async def row_count(self) -> int: + return self._sync_row_set.row_count + + @property + async def columns(self) -> List[Column]: + return self._sync_row_set.columns + + @property + def statistics(self) -> Optional[Statistics]: + return self._sync_row_set.statistics + + def nextset(self) -> bool: + return self._sync_row_set.nextset() + + def __aiter__(self) -> AsyncIterator[List[ColType]]: + return self + + async def __anext__(self) -> List[ColType]: + return next(self._sync_row_set) + + async def aclose(self) -> None: + return self._sync_row_set.close() diff --git a/src/firebolt/common/row_set/sync/__init__.py b/src/firebolt/common/row_set/synchronous/__init__.py similarity index 100% rename from src/firebolt/common/row_set/sync/__init__.py rename to src/firebolt/common/row_set/synchronous/__init__.py diff --git a/src/firebolt/common/row_set/sync/base.py b/src/firebolt/common/row_set/synchronous/base.py similarity index 77% rename from src/firebolt/common/row_set/sync/base.py rename to src/firebolt/common/row_set/synchronous/base.py index a9e63fd92b..827f9d11f8 100644 --- a/src/firebolt/common/row_set/sync/base.py +++ b/src/firebolt/common/row_set/synchronous/base.py @@ -1,10 +1,8 @@ from abc import ABC, abstractmethod -from typing import List, Optional - -from httpx import SyncByteStream +from typing import Iterator, List, Optional from firebolt.common._types import ColType -from firebolt.common.row_set.types import Column, Statistics +from firebolt.common.row_set.types import ByteStream, Column, Statistics class BaseRowSet(ABC): @@ -12,9 +10,8 @@ class BaseRowSet(ABC): Base class for all sync row sets. """ - @classmethod @abstractmethod - def from_response_stream(cls, stream: SyncByteStream) -> "BaseRowSet": + def append_response_stream(self, stream: ByteStream) -> None: ... @property @@ -41,7 +38,7 @@ def nextset(self) -> bool: ... @abstractmethod - def __iter__(self) -> "BaseRowSet": + def __iter__(self) -> Iterator[List[ColType]]: ... @abstractmethod diff --git a/src/firebolt/common/row_set/synchronous/in_memory.py b/src/firebolt/common/row_set/synchronous/in_memory.py new file mode 100644 index 0000000000..98559f7f26 --- /dev/null +++ b/src/firebolt/common/row_set/synchronous/in_memory.py @@ -0,0 +1,83 @@ +import json +from typing import Iterator, List, Optional + +from firebolt.common._types import ColType, RawColType, parse_type, parse_value +from firebolt.common.row_set.synchronous.base import BaseRowSet +from firebolt.common.row_set.types import ( + ByteStream, + Column, + RowsResponse, + Statistics, +) +from firebolt.utils.exception import DataError + + +class InMemoryRowSet(BaseRowSet): + """ + A row set that holds all rows in memory. + """ + + def __init__(self) -> None: + self._row_sets: List[RowsResponse] = [] + self._current_row_set_idx = 0 + self._current_row = -1 + + def append_response_stream(self, stream: ByteStream) -> None: + """ + Create an InMemoryRowSet from a response stream. + """ + try: + content = b"".join(stream) + query_data = json.loads(content) + columns = [ + Column(d["name"], parse_type(d["type"]), None, None, None, None, None) + for d in query_data["meta"] + ] + # Extract rows + rows = query_data["data"] + row_count = len(rows) + statistics = query_data.get("statistics") + self._row_sets.append(RowsResponse(row_count, columns, statistics, rows)) + except (KeyError, ValueError) as err: + raise DataError(f"Invalid query data format: {str(err)}") + + @property + def _row_set(self) -> RowsResponse: + return self._row_sets[self._current_row_set_idx] + + @property + def row_count(self) -> int: + return self._row_set.row_count + + @property + def columns(self) -> List[Column]: + return self._row_set.columns + + @property + def statistics(self) -> Optional[Statistics]: + return self._row_set.statistics + + def nextset(self) -> bool: + if self._current_row_set_idx + 1 < len(self._row_sets): + self._current_row_set_idx += 1 + self._current_row = -1 + return True + return False + + def _parse_row(self, row: List[RawColType]) -> List[ColType]: + assert len(row) == len(self.columns) + return [ + parse_value(col, self.columns[i].type_code) for i, col in enumerate(row) + ] + + def __iter__(self) -> Iterator[List[ColType]]: + return self + + def __next__(self) -> List[ColType]: + self._current_row += 1 + if self._current_row >= self._row_set.row_count: + raise StopIteration + return self._parse_row(self._row_set.rows[self._current_row]) + + def close(self) -> None: + pass diff --git a/src/firebolt/common/row_set/types.py b/src/firebolt/common/row_set/types.py index 26963c6a34..d008d8128b 100644 --- a/src/firebolt/common/row_set/types.py +++ b/src/firebolt/common/row_set/types.py @@ -1,8 +1,9 @@ from __future__ import annotations -from collections import namedtuple from dataclasses import dataclass, fields -from typing import Optional +from typing import AsyncIterator, Iterator, List, Optional, Protocol, Union + +from firebolt.common._types import ExtendedType, RawColType @dataclass @@ -39,15 +40,40 @@ def __post_init__(self) -> None: setattr(self, field.name, _type(value)) -Column = namedtuple( - "Column", - ( - "name", - "type_code", - "display_size", - "internal_size", - "precision", - "scale", - "null_ok", - ), -) +@dataclass +class RowsResponse: + """ + Class for query execution response. + """ + + row_count: int + columns: List[Column] + statistics: Optional[Statistics] + rows: List[List[RawColType]] + + +@dataclass +class Column: + name: str + type_code: Union[type, ExtendedType] + display_size: Optional[int] = None + internal_size: Optional[int] = None + precision: Optional[int] = None + scale: Optional[int] = None + null_ok: Optional[bool] = None + + +class ByteStream(Protocol): + def __iter__(self) -> Iterator[bytes]: + ... + + def close(self) -> None: + ... + + +class AsyncByteStream(Protocol): + def __aiter__(self) -> AsyncIterator[bytes]: + ... + + def aclose(self) -> None: + ... From e6e958d9db50c02d4b040cec4ec8a27ce83cd478 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 8 Apr 2025 16:09:13 +0300 Subject: [PATCH 03/10] fix tests --- src/firebolt/async_db/connection.py | 2 +- src/firebolt/async_db/cursor.py | 139 ++++++---- src/firebolt/common/_types.py | 2 +- src/firebolt/common/cursor/__init__.py | 0 .../common/{ => cursor}/base_cursor.py | 241 +++--------------- src/firebolt/common/cursor/decorators.py | 59 +++++ .../common/row_set/asynchronous/base.py | 30 +-- .../common/row_set/asynchronous/in_memory.py | 21 +- src/firebolt/common/row_set/base.py | 33 +++ .../common/row_set/synchronous/base.py | 33 +-- .../common/row_set/synchronous/in_memory.py | 65 +++-- src/firebolt/common/row_set/types.py | 2 +- src/firebolt/db/cursor.py | 87 ++++++- src/firebolt/utils/async_util.py | 36 +++ tests/integration/dbapi/async/V1/conftest.py | 14 + .../dbapi/async/V1/test_auth_async.py | 4 +- .../dbapi/async/V1/test_errors_async.py | 2 +- .../dbapi/async/V1/test_queries_async.py | 56 ++-- tests/integration/dbapi/async/V2/conftest.py | 2 +- .../dbapi/async/V2/test_auth_async.py | 4 +- .../dbapi/async/V2/test_errors_async.py | 2 +- .../dbapi/async/V2/test_queries_async.py | 36 +-- .../async/V2/test_system_engine_async.py | 4 +- .../dbapi/async/V2/test_timeout.py | 2 +- tests/integration/dbapi/conftest.py | 1 + tests/integration/dbapi/sync/V1/conftest.py | 14 + .../integration/dbapi/sync/V1/test_queries.py | 28 +- .../integration/dbapi/sync/V2/test_queries.py | 10 +- tests/unit/V1/async_db/test_cursor.py | 8 +- tests/unit/V1/db/test_cursor.py | 2 +- tests/unit/async_db/test_cursor.py | 8 +- tests/unit/common/test_base_cursor.py | 2 +- tests/unit/db/test_cursor.py | 2 +- 33 files changed, 503 insertions(+), 448 deletions(-) create mode 100644 src/firebolt/common/cursor/__init__.py rename src/firebolt/common/{ => cursor}/base_cursor.py (53%) create mode 100644 src/firebolt/common/cursor/decorators.py create mode 100644 src/firebolt/common/row_set/base.py create mode 100644 src/firebolt/utils/async_util.py diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 139a817895..05603bbe10 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -179,7 +179,7 @@ async def aclose(self) -> None: for c in cursors: # Here c can already be closed by another thread, # but it shouldn't raise an error in this case - c.close() + await c.aclose() await self._client.aclose() self._is_closed = True diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index ddeda5bfda..c5bb278cf3 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -5,16 +5,7 @@ from abc import ABCMeta, abstractmethod from functools import wraps from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterator, - List, - Optional, - Sequence, - Union, -) +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from urllib.parse import urljoin from httpx import ( @@ -28,22 +19,26 @@ from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2 from firebolt.common._types import ColType, ParameterType, SetParameter -from firebolt.common.base_cursor import ( +from firebolt.common.constants import ( + JSON_OUTPUT_FORMAT, + RESET_SESSION_HEADER, + UPDATE_ENDPOINT_HEADER, + UPDATE_PARAMETERS_HEADER, + CursorState, +) +from firebolt.common.cursor.base_cursor import ( BaseCursor, _parse_update_endpoint, _parse_update_parameters, _raise_if_internal_set_parameter, +) +from firebolt.common.cursor.decorators import ( async_not_allowed, check_not_closed, check_query_executed, ) -from firebolt.common.constants import ( - JSON_OUTPUT_FORMAT, - RESET_SESSION_HEADER, - UPDATE_ENDPOINT_HEADER, - UPDATE_PARAMETERS_HEADER, - CursorState, -) +from firebolt.common.row_set.asynchronous.base import BaseAsyncRowSet +from firebolt.common.row_set.asynchronous.in_memory import InMemoryAsyncRowSet from firebolt.common.statement_formatter import create_statement_formatter from firebolt.utils.exception import ( EngineNotRunningError, @@ -60,7 +55,12 @@ if TYPE_CHECKING: from firebolt.async_db.connection import Connection -from firebolt.utils.util import _print_error_body, raise_errors_from_body +from firebolt.utils.async_util import async_islice +from firebolt.utils.util import ( + Timer, + _print_error_body, + raise_errors_from_body, +) logger = logging.getLogger(__name__) @@ -90,6 +90,7 @@ def __init__( self._client = client self.connection = connection self.engine_url = connection.engine_url + self._row_set: Optional[BaseAsyncRowSet] = None if connection.init_parameters: self._update_set_parameters(connection.init_parameters) @@ -123,13 +124,14 @@ async def _api_request( if self.parameters: parameters = {**self.parameters, **parameters} try: - return await self._client.request( + req = self._client.build_request( url=urljoin(self.engine_url.rstrip("/") + "/", path or ""), method="POST", params=parameters, content=query, timeout=timeout if timeout is not None else USE_CLIENT_DEFAULT, ) + return await self._client.send(req, stream=True) except TimeoutException: raise QueryTimeoutError() @@ -172,6 +174,9 @@ async def _validate_set_parameter( # set parameter passed validation self._set_parameters[parameter.name] = parameter.value + # append empty result set + await self._append_row_set_from_response(None) + async def _parse_response_headers(self, headers: Headers) -> None: if headers.get(UPDATE_ENDPOINT_HEADER): endpoint, params = _parse_update_endpoint( @@ -273,8 +278,7 @@ async def _handle_query_execution( self._parse_async_response(resp) else: await self._parse_response_headers(resp.headers) - row_set = self._row_set_from_response(resp) - self._append_row_set(row_set) + await self._append_row_set_from_response(resp) @check_not_closed async def execute( @@ -355,38 +359,76 @@ async def executemany( await self._do_execute(query, parameters_seq, timeout=timeout_seconds) return self.rowcount - @abstractmethod - async def is_db_available(self, database: str) -> bool: - """Verify that the database exists.""" - ... + async def _append_row_set_from_response( + self, + response: Optional[Response], + ) -> None: + """Store information about executed query.""" + if self._row_set is None: + self._row_set = InMemoryAsyncRowSet() + if response is None: + self._row_set.append_empty_response() + else: + await self._row_set.append_response(response) - @abstractmethod - async def is_engine_running(self, engine_url: str) -> bool: - """Verify that the engine is running.""" - ... + _performance_log_message = ( + "[PERFORMANCE] Parsing query output into native Python types" + ) - @wraps(BaseCursor.fetchone) + @check_not_closed + @async_not_allowed + @check_query_executed async def fetchone(self) -> Optional[List[ColType]]: """Fetch the next row of a query result set.""" - return super().fetchone() + assert self._row_set is not None + with Timer(self._performance_log_message): + # anext() is only supported in Python 3.10+ + try: + return await self._row_set.__anext__() + except StopAsyncIteration: + return None - @wraps(BaseCursor.fetchmany) + @check_not_closed + @async_not_allowed + @check_query_executed async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: """ Fetch the next set of rows of a query result; - size is cursor.arraysize by default. + cursor.arraysize is default size. """ - return super().fetchmany(size) + assert self._row_set is not None + size = size if size is not None else self.arraysize + with Timer(self._performance_log_message): + return await async_islice(self._row_set, size) - @wraps(BaseCursor.fetchall) + @check_not_closed + @async_not_allowed + @check_query_executed async def fetchall(self) -> List[List[ColType]]: """Fetch all remaining rows of a query result.""" - return super().fetchall() + assert self._row_set is not None + with Timer(self._performance_log_message): + return [it async for it in self._row_set] @wraps(BaseCursor.nextset) async def nextset(self) -> None: return super().nextset() + async def aclose(self) -> None: + super().close() + if self._row_set is not None: + await self._row_set.aclose() + + @abstractmethod + async def is_db_available(self, database: str) -> bool: + """Verify that the database exists.""" + ... + + @abstractmethod + async def is_engine_running(self, engine_url: str) -> bool: + """Verify that the engine is running.""" + ... + # Iteration support @check_not_closed @async_not_allowed @@ -394,36 +436,21 @@ async def nextset(self) -> None: def __aiter__(self) -> Cursor: return self - # TODO: figure out how to implement __aenter__ and __await__ @check_not_closed - def __aenter__(self) -> Cursor: + async def __aenter__(self) -> Cursor: return self - @check_not_closed - def __enter__(self) -> Cursor: - return self - - def __exit__( - self, exc_type: type, exc_val: Exception, exc_tb: TracebackType - ) -> None: - self.close() - - def __await__(self) -> Iterator: - yield None - async def __aexit__( self, exc_type: type, exc_val: Exception, exc_tb: TracebackType ) -> None: - self.close() + await self.aclose() @check_not_closed @async_not_allowed @check_query_executed async def __anext__(self) -> List[ColType]: - row = await self.fetchone() - if row is None: - raise StopAsyncIteration - return row + assert self._row_set is not None + return await self._row_set.__anext__() class CursorV2(Cursor): diff --git a/src/firebolt/common/_types.py b/src/firebolt/common/_types.py index 172fe00b04..b72e0f237d 100644 --- a/src/firebolt/common/_types.py +++ b/src/firebolt/common/_types.py @@ -325,7 +325,7 @@ def parse_value( raise DataError(f"Invalid bytea value {value}: str expected") return _parse_bytea(value) if isinstance(ctype, DECIMAL): - if not isinstance(value, (str, int)): + if not isinstance(value, (str, int, float)): raise DataError(f"Invalid decimal value {value}: str or int expected") return Decimal(value) if isinstance(ctype, ARRAY): diff --git a/src/firebolt/common/cursor/__init__.py b/src/firebolt/common/cursor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/firebolt/common/base_cursor.py b/src/firebolt/common/cursor/base_cursor.py similarity index 53% rename from src/firebolt/common/base_cursor.py rename to src/firebolt/common/cursor/base_cursor.py index c6e8e28c77..e3f31910a2 100644 --- a/src/firebolt/common/base_cursor.py +++ b/src/firebolt/common/cursor/base_cursor.py @@ -2,36 +2,28 @@ import logging import re -from functools import wraps from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from httpx import URL, Response -from firebolt.common._types import ( - ColType, - RawColType, - SetParameter, - parse_type, - parse_value, -) +from firebolt.common._types import RawColType, SetParameter from firebolt.common.constants import ( DISALLOWED_PARAMETER_LIST, IMMUTABLE_PARAMETER_LIST, USE_PARAMETER_LIST, CursorState, ) +from firebolt.common.cursor.decorators import ( + async_not_allowed, + check_not_closed, + check_query_executed, +) +from firebolt.common.row_set.base import BaseRowSet from firebolt.common.row_set.types import AsyncResponse, Column, Statistics from firebolt.common.statement_formatter import StatementFormatter -from firebolt.utils.exception import ( - ConfigurationError, - CursorClosedError, - DataError, - FireboltError, - MethodNotAllowedInAsyncError, - QueryNotRunError, -) -from firebolt.utils.util import Timer, fix_url_schema +from firebolt.utils.exception import ConfigurationError, FireboltError +from firebolt.utils.util import fix_url_schema logger = logging.getLogger(__name__) @@ -78,51 +70,6 @@ def _raise_if_internal_set_parameter(parameter: SetParameter) -> None: ] -def check_not_closed(func: Callable) -> Callable: - """(Decorator) ensure cursor is not closed before calling method.""" - - @wraps(func) - def inner(self: BaseCursor, *args: Any, **kwargs: Any) -> Any: - if self.closed: - raise CursorClosedError(method_name=func.__name__) - return func(self, *args, **kwargs) - - return inner - - -def check_query_executed(func: Callable) -> Callable: - """ - (Decorator) ensure that some query has been executed before - calling cursor method. - """ - - @wraps(func) - def inner(self: BaseCursor, *args: Any, **kwargs: Any) -> Any: - if self._state == CursorState.NONE: - raise QueryNotRunError(method_name=func.__name__) - if self._query_token: - # query_token is set only for async queries - raise MethodNotAllowedInAsyncError(method_name=func.__name__) - return func(self, *args, **kwargs) - - return inner - - -def async_not_allowed(func: Callable) -> Callable: - """ - (Decorator) ensure that fetch methods are not called on async queries. - """ - - @wraps(func) - def inner(self: BaseCursor, *args: Any, **kwargs: Any) -> Any: - if self._query_token: - # query_token is set only for async queries - raise MethodNotAllowedInAsyncError(method_name=func.__name__) - return func(self, *args, **kwargs) - - return inner - - class BaseCursor: __slots__ = ( "connection", @@ -130,17 +77,11 @@ class BaseCursor: "_arraysize", "_client", "_state", - "_descriptions", - "_statistics", - "_rowcount", - "_rows", - "_idx", - "_row_sets", "_formatter", - "_next_set_idx", "_set_parameters", "_query_id", "_query_token", + "_row_set", "engine_url", ) @@ -151,21 +92,15 @@ def __init__( ) -> None: self._arraysize = self.default_arraysize # These fields initialized here for type annotations purpose - self._rows: Optional[List[List[RawColType]]] = None - self._descriptions: Optional[List[Column]] = None - self._statistics: Optional[Statistics] = None - self._row_sets: List[RowSet] = [] self._formatter = formatter # User-defined set parameters self._set_parameters: Dict[str, Any] = dict() # Server-side parameters (user can't change them) self.parameters: Dict[str, str] = dict() self.engine_url = "" - self._rowcount = -1 - self._idx = 0 - self._next_set_idx = 0 self._query_id = "" self._query_token = "" + self._row_set: Optional[BaseRowSet] = None self._reset() @property @@ -191,19 +126,25 @@ def description(self) -> Optional[List[Column]]: * ``scale`` * ``null_ok`` """ - return self._descriptions + if not self._row_set: + return None + return self._row_set.columns @property # type: ignore @check_not_closed def statistics(self) -> Optional[Statistics]: """Query execution statistics returned by the backend.""" - return self._statistics + if not self._row_set: + return None + return self._row_set.statistics @property # type: ignore @check_not_closed def rowcount(self) -> int: """The number of rows produced by last query.""" - return self._rowcount + if not self._row_set: + return -1 + return self._row_set.row_count @property # type: ignore @check_not_closed @@ -234,38 +175,23 @@ def arraysize(self, value: int) -> None: ) self._arraysize = value - @property - def closed(self) -> bool: - """True if connection is closed, False otherwise.""" - return self._state == CursorState.CLOSED - @check_not_closed @async_not_allowed @check_query_executed - def nextset(self) -> Optional[bool]: + def nextset(self) -> bool: """ Skip to the next available set, discarding any remaining rows from the current set. Returns True if operation was successful; - None if there are no more sets to retrive. + False if there are no more sets to retrieve. """ - return self._pop_next_set() + assert self._row_set is not None + return self._row_set.nextset() - def _pop_next_set(self) -> Optional[bool]: - """ - Same functionality as .nextset, but doesn't check that query has been executed. - """ - if self._next_set_idx >= len(self._row_sets): - return None - ( - self._rowcount, - self._descriptions, - self._statistics, - self._rows, - ) = self._row_sets[self._next_set_idx] - self._idx = 0 - self._next_set_idx += 1 - return True + @property + def closed(self) -> bool: + """True if connection is closed, False otherwise.""" + return self._state == CursorState.CLOSED def flush_parameters(self) -> None: """Cleanup all previously set parameters""" @@ -274,13 +200,7 @@ def flush_parameters(self) -> None: def _reset(self) -> None: """Clear all data stored from previous query.""" self._state = CursorState.NONE - self._rows = None - self._descriptions = None - self._statistics = None - self._rowcount = -1 - self._idx = 0 - self._row_sets = [] - self._next_set_idx = 0 + self._row_set = None self._query_id = "" self._query_token = "" @@ -332,107 +252,6 @@ def _parse_async_response(self, response: Response) -> None: async_response = AsyncResponse(**response.json()) self._query_token = async_response.token - def _row_set_from_response(self, response: Response) -> RowSet: - """Fetch information about executed query from http response.""" - - # Empty response is returned for insert query - if response.headers.get("content-length", "") == "0": - return -1, None, None, None - try: - # Skip parsing floats to properly parse them later - query_data = response.json(parse_float=str) - rowcount = int(query_data["rows"]) - descriptions: Optional[List[Column]] = [ - Column(d["name"], parse_type(d["type"]), None, None, None, None, None) - for d in query_data["meta"] - ] - if not descriptions: - descriptions = None - statistics = Statistics(**query_data["statistics"]) - # Parse data during fetch - rows = query_data["data"] - return rowcount, descriptions, statistics, rows - except (KeyError, ValueError) as err: - raise DataError(f"Invalid query data format: {str(err)}") - - def _append_row_set( - self, - row_set: RowSet, - ) -> None: - """Store information about executed query.""" - self._row_sets.append(row_set) - if self._next_set_idx == 0: - # Populate values for first set - self._pop_next_set() - - def _parse_row(self, row: List[RawColType]) -> List[ColType]: - """Parse a single data row based on query column types.""" - assert len(row) == len(self.description) - return [ - parse_value(col, self.description[i].type_code) for i, col in enumerate(row) - ] - - def _get_next_range(self, size: int) -> Tuple[int, int]: - """ - Return range of next rows of size (if possible), - and update _idx to point to the end of this range - """ - - if self._rows is None: - # No elements to take - raise DataError("no rows to fetch") - - left = self._idx - right = min(self._idx + size, len(self._rows)) - self._idx = right - return left, right - - _performance_log_message = ( - "[PERFORMANCE] Parsing query output into native Python types" - ) - - @check_not_closed - @async_not_allowed - @check_query_executed - def fetchone(self) -> Optional[List[ColType]]: - """Fetch the next row of a query result set.""" - left, right = self._get_next_range(1) - if left == right: - # We are out of elements - return None - assert self._rows is not None - with Timer(self._performance_log_message): - result = self._parse_row(self._rows[left]) - return result - - @check_not_closed - @async_not_allowed - @check_query_executed - def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: - """ - Fetch the next set of rows of a query result; - cursor.arraysize is default size. - """ - size = size if size is not None else self.arraysize - left, right = self._get_next_range(size) - assert self._rows is not None - rows = self._rows[left:right] - with Timer(self._performance_log_message): - result = [self._parse_row(row) for row in rows] - return result - - @check_not_closed - @async_not_allowed - @check_query_executed - def fetchall(self) -> List[List[ColType]]: - """Fetch all remaining rows of a query result.""" - left, right = self._get_next_range(self.rowcount) - assert self._rows is not None - rows = self._rows[left:right] - with Timer(self._performance_log_message): - result = [self._parse_row(row) for row in rows] - return result - @check_not_closed def setinputsizes(self, sizes: List[int]) -> None: """Predefine memory areas for query parameters (does nothing).""" diff --git a/src/firebolt/common/cursor/decorators.py b/src/firebolt/common/cursor/decorators.py new file mode 100644 index 0000000000..1507151ee8 --- /dev/null +++ b/src/firebolt/common/cursor/decorators.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable + +from firebolt.common.constants import CursorState +from firebolt.utils.exception import ( + CursorClosedError, + MethodNotAllowedInAsyncError, + QueryNotRunError, +) + +if TYPE_CHECKING: + from firebolt.common.cursor.base_cursor import BaseCursor + + +def check_not_closed(func: Callable) -> Callable: + """(Decorator) ensure cursor is not closed before calling method.""" + + @wraps(func) + def inner(self: BaseCursor, *args: Any, **kwargs: Any) -> Any: + if self.closed: + raise CursorClosedError(method_name=func.__name__) + return func(self, *args, **kwargs) + + return inner + + +def check_query_executed(func: Callable) -> Callable: + """ + (Decorator) ensure that some query has been executed before + calling cursor method. + """ + + @wraps(func) + def inner(self: BaseCursor, *args: Any, **kwargs: Any) -> Any: + if self._state == CursorState.NONE or self._row_set is None: + raise QueryNotRunError(method_name=func.__name__) + if self._query_token: + # query_token is set only for async queries + raise MethodNotAllowedInAsyncError(method_name=func.__name__) + return func(self, *args, **kwargs) + + return inner + + +def async_not_allowed(func: Callable) -> Callable: + """ + (Decorator) ensure that fetch methods are not called on async queries. + """ + + @wraps(func) + def inner(self: BaseCursor, *args: Any, **kwargs: Any) -> Any: + if self._query_token: + # query_token is set only for async queries + raise MethodNotAllowedInAsyncError(method_name=func.__name__) + return func(self, *args, **kwargs) + + return inner diff --git a/src/firebolt/common/row_set/asynchronous/base.py b/src/firebolt/common/row_set/asynchronous/base.py index f0390e61a7..5b02e6acce 100644 --- a/src/firebolt/common/row_set/asynchronous/base.py +++ b/src/firebolt/common/row_set/asynchronous/base.py @@ -1,39 +1,19 @@ from abc import ABC, abstractmethod -from typing import AsyncIterator, List, Optional +from typing import AsyncIterator, List -from async_property import async_property # type: ignore +from httpx import Response from firebolt.common._types import ColType -from firebolt.common.row_set.types import AsyncByteStream, Column, Statistics +from firebolt.common.row_set.base import BaseRowSet -class BaseAsyncRowSet(ABC): +class BaseAsyncRowSet(BaseRowSet, ABC): """ Base class for all async row sets. """ @abstractmethod - async def append_response_stream(self, stream: AsyncByteStream) -> None: - ... - - @async_property - @abstractmethod - async def row_count(self) -> Optional[int]: - ... - - @async_property - @abstractmethod - def statistics(self) -> Optional[Statistics]: - ... - - @async_property - @abstractmethod - async def columns(self) -> List[Column]: - ... - - @async_property - @abstractmethod - def nextset(self) -> bool: + async def append_response(self, response: Response) -> None: ... @abstractmethod diff --git a/src/firebolt/common/row_set/asynchronous/in_memory.py b/src/firebolt/common/row_set/asynchronous/in_memory.py index 428ec8914d..2a08e8d8c9 100644 --- a/src/firebolt/common/row_set/asynchronous/in_memory.py +++ b/src/firebolt/common/row_set/asynchronous/in_memory.py @@ -1,10 +1,12 @@ import io from typing import AsyncIterator, List, Optional +from httpx import Response + from firebolt.common._types import ColType from firebolt.common.row_set.asynchronous.base import BaseAsyncRowSet from firebolt.common.row_set.synchronous.in_memory import InMemoryRowSet -from firebolt.common.row_set.types import AsyncByteStream, Column, Statistics +from firebolt.common.row_set.types import Column, Statistics class InMemoryAsyncRowSet(BaseAsyncRowSet): @@ -15,16 +17,20 @@ class InMemoryAsyncRowSet(BaseAsyncRowSet): def __init__(self) -> None: self._sync_row_set = InMemoryRowSet() - async def append_response_stream(self, stream: AsyncByteStream) -> None: - sync_stream = io.BytesIO(b"".join([b async for b in stream])) + def append_empty_response(self) -> None: + self._sync_row_set.append_empty_response() + + async def append_response(self, response: Response) -> None: + sync_stream = io.BytesIO(b"".join([b async for b in response.aiter_bytes()])) self._sync_row_set.append_response_stream(sync_stream) + await response.aclose() @property - async def row_count(self) -> int: + def row_count(self) -> int: return self._sync_row_set.row_count @property - async def columns(self) -> List[Column]: + def columns(self) -> List[Column]: return self._sync_row_set.columns @property @@ -38,7 +44,10 @@ def __aiter__(self) -> AsyncIterator[List[ColType]]: return self async def __anext__(self) -> List[ColType]: - return next(self._sync_row_set) + try: + return next(self._sync_row_set) + except StopIteration: + raise StopAsyncIteration async def aclose(self) -> None: return self._sync_row_set.close() diff --git a/src/firebolt/common/row_set/base.py b/src/firebolt/common/row_set/base.py new file mode 100644 index 0000000000..b1bfceb3eb --- /dev/null +++ b/src/firebolt/common/row_set/base.py @@ -0,0 +1,33 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from firebolt.common.row_set.types import Column, Statistics + + +class BaseRowSet(ABC): + """ + Base class for all async row sets. + """ + + @property + @abstractmethod + def row_count(self) -> int: + ... + + @property + @abstractmethod + def statistics(self) -> Optional[Statistics]: + ... + + @property + @abstractmethod + def columns(self) -> List[Column]: + ... + + @abstractmethod + def nextset(self) -> bool: + ... + + @abstractmethod + def append_empty_response(self) -> None: + ... diff --git a/src/firebolt/common/row_set/synchronous/base.py b/src/firebolt/common/row_set/synchronous/base.py index 827f9d11f8..7d273335fc 100644 --- a/src/firebolt/common/row_set/synchronous/base.py +++ b/src/firebolt/common/row_set/synchronous/base.py @@ -1,40 +1,19 @@ from abc import ABC, abstractmethod -from typing import Iterator, List, Optional +from typing import Iterator, List + +from httpx import Response from firebolt.common._types import ColType -from firebolt.common.row_set.types import ByteStream, Column, Statistics +from firebolt.common.row_set.base import BaseRowSet -class BaseRowSet(ABC): +class BaseSyncRowSet(BaseRowSet, ABC): """ Base class for all sync row sets. """ @abstractmethod - def append_response_stream(self, stream: ByteStream) -> None: - ... - - @property - @abstractmethod - def row_count(self) -> Optional[int]: - # This is optional because for streaming it will not be available - # until all rows are read - ... - - @property - @abstractmethod - def statistics(self) -> Optional[Statistics]: - # This is optional because for streaming it will not be available - # until all rows are read - ... - - @property - @abstractmethod - def columns(self) -> List[Column]: - ... - - @abstractmethod - def nextset(self) -> bool: + def append_response(self, response: Response) -> None: ... @abstractmethod diff --git a/src/firebolt/common/row_set/synchronous/in_memory.py b/src/firebolt/common/row_set/synchronous/in_memory.py index 98559f7f26..da1203bd79 100644 --- a/src/firebolt/common/row_set/synchronous/in_memory.py +++ b/src/firebolt/common/row_set/synchronous/in_memory.py @@ -1,18 +1,15 @@ import json from typing import Iterator, List, Optional +from httpx import Response + from firebolt.common._types import ColType, RawColType, parse_type, parse_value -from firebolt.common.row_set.synchronous.base import BaseRowSet -from firebolt.common.row_set.types import ( - ByteStream, - Column, - RowsResponse, - Statistics, -) +from firebolt.common.row_set.synchronous.base import BaseSyncRowSet +from firebolt.common.row_set.types import Column, RowsResponse, Statistics from firebolt.utils.exception import DataError -class InMemoryRowSet(BaseRowSet): +class InMemoryRowSet(BaseSyncRowSet): """ A row set that holds all rows in memory. """ @@ -22,24 +19,44 @@ def __init__(self) -> None: self._current_row_set_idx = 0 self._current_row = -1 - def append_response_stream(self, stream: ByteStream) -> None: + def append_empty_response(self) -> None: + """ + Create an InMemoryRowSet from an empty response. + """ + self._row_sets.append(RowsResponse(-1, [], None, [])) + + def append_response(self, response: Response) -> None: + """ + Create an InMemoryRowSet from a response. + """ + self.append_response_stream(response.iter_bytes()) + response.close() + + def append_response_stream(self, stream: Iterator[bytes]) -> None: """ Create an InMemoryRowSet from a response stream. """ - try: - content = b"".join(stream) - query_data = json.loads(content) - columns = [ - Column(d["name"], parse_type(d["type"]), None, None, None, None, None) - for d in query_data["meta"] - ] - # Extract rows - rows = query_data["data"] - row_count = len(rows) - statistics = query_data.get("statistics") - self._row_sets.append(RowsResponse(row_count, columns, statistics, rows)) - except (KeyError, ValueError) as err: - raise DataError(f"Invalid query data format: {str(err)}") + content = b"".join(stream).decode("utf-8") + if len(content) == 0: + self.append_empty_response() + else: + try: + query_data = json.loads(content) + columns = [ + Column( + d["name"], parse_type(d["type"]), None, None, None, None, None + ) + for d in query_data["meta"] + ] + # Extract rows + rows = query_data["data"] + row_count = len(rows) + statistics = Statistics(**query_data.get("statistics", {})) + self._row_sets.append( + RowsResponse(row_count, columns, statistics, rows) + ) + except (KeyError, ValueError) as err: + raise DataError(f"Invalid query data format: {str(err)}") @property def _row_set(self) -> RowsResponse: @@ -74,6 +91,8 @@ def __iter__(self) -> Iterator[List[ColType]]: return self def __next__(self) -> List[ColType]: + if self._row_set.row_count == -1: + raise DataError("no rows to fetch") self._current_row += 1 if self._current_row >= self._row_set.row_count: raise StopIteration diff --git a/src/firebolt/common/row_set/types.py b/src/firebolt/common/row_set/types.py index d008d8128b..924f720dc5 100644 --- a/src/firebolt/common/row_set/types.py +++ b/src/firebolt/common/row_set/types.py @@ -75,5 +75,5 @@ class AsyncByteStream(Protocol): def __aiter__(self) -> AsyncIterator[bytes]: ... - def aclose(self) -> None: + async def aclose(self) -> None: ... diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index 8b6ba71002..5cf88fc40c 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -3,6 +3,7 @@ import logging import time from abc import ABCMeta, abstractmethod +from itertools import islice from typing import ( TYPE_CHECKING, Any, @@ -26,22 +27,26 @@ from firebolt.client import Client, ClientV1, ClientV2 from firebolt.common._types import ColType, ParameterType, SetParameter -from firebolt.common.base_cursor import ( +from firebolt.common.constants import ( + JSON_OUTPUT_FORMAT, + RESET_SESSION_HEADER, + UPDATE_ENDPOINT_HEADER, + UPDATE_PARAMETERS_HEADER, + CursorState, +) +from firebolt.common.cursor.base_cursor import ( BaseCursor, _parse_update_endpoint, _parse_update_parameters, _raise_if_internal_set_parameter, +) +from firebolt.common.cursor.decorators import ( async_not_allowed, check_not_closed, check_query_executed, ) -from firebolt.common.constants import ( - JSON_OUTPUT_FORMAT, - RESET_SESSION_HEADER, - UPDATE_ENDPOINT_HEADER, - UPDATE_PARAMETERS_HEADER, - CursorState, -) +from firebolt.common.row_set.synchronous.base import BaseSyncRowSet +from firebolt.common.row_set.synchronous.in_memory import InMemoryRowSet from firebolt.common.statement_formatter import create_statement_formatter from firebolt.utils.exception import ( EngineNotRunningError, @@ -54,7 +59,11 @@ ) from firebolt.utils.timeout_controller import TimeoutController from firebolt.utils.urls import DATABASES_URL, ENGINES_URL -from firebolt.utils.util import _print_error_body, raise_errors_from_body +from firebolt.utils.util import ( + Timer, + _print_error_body, + raise_errors_from_body, +) if TYPE_CHECKING: from firebolt.db.connection import Connection @@ -87,6 +96,7 @@ def __init__( self._client = client self.connection = connection self.engine_url = connection.engine_url + self._row_set: Optional[BaseSyncRowSet] = None if connection.init_parameters: self._update_set_parameters(connection.init_parameters) @@ -170,6 +180,9 @@ def _validate_set_parameter( # set parameter passed validation self._set_parameters[parameter.name] = parameter.value + # append empty result set + self._append_row_set_from_response(None) + def _parse_response_headers(self, headers: Headers) -> None: if headers.get(UPDATE_ENDPOINT_HEADER): endpoint, params = _parse_update_endpoint( @@ -267,8 +280,7 @@ def _handle_query_execution( self._parse_async_response(resp) else: self._parse_response_headers(resp.headers) - row_set = self._row_set_from_response(resp) - self._append_row_set(row_set) + self._append_row_set_from_response(resp) @check_not_closed def execute( @@ -347,6 +359,59 @@ def executemany( self._do_execute(query, parameters_seq, timeout=timeout_seconds) return self.rowcount + def _append_row_set_from_response( + self, + response: Optional[Response], + ) -> None: + """Store information about executed query.""" + if self._row_set is None: + self._row_set = InMemoryRowSet() + + if response is None: + self._row_set.append_empty_response() + else: + self._row_set.append_response(response) + + _performance_log_message = ( + "[PERFORMANCE] Parsing query output into native Python types" + ) + + @check_not_closed + @async_not_allowed + @check_query_executed + def fetchone(self) -> Optional[List[ColType]]: + """Fetch the next row of a query result set.""" + assert self._row_set is not None + with Timer(self._performance_log_message): + return next(self._row_set, None) + + @check_not_closed + @async_not_allowed + @check_query_executed + def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: + """ + Fetch the next set of rows of a query result; + cursor.arraysize is default size. + """ + assert self._row_set is not None + size = size if size is not None else self.arraysize + with Timer(self._performance_log_message): + return list(islice(self._row_set, size)) + + @check_not_closed + @async_not_allowed + @check_query_executed + def fetchall(self) -> List[List[ColType]]: + """Fetch all remaining rows of a query result.""" + assert self._row_set is not None + with Timer(self._performance_log_message): + return list(self._row_set) + + def close(self) -> None: + super().close() + if self._row_set is not None: + self._row_set.close() + @abstractmethod def is_db_available(self, database: str) -> bool: """Verify that the database exists.""" diff --git a/src/firebolt/utils/async_util.py b/src/firebolt/utils/async_util.py new file mode 100644 index 0000000000..078b4fea6b --- /dev/null +++ b/src/firebolt/utils/async_util.py @@ -0,0 +1,36 @@ +from typing import AsyncIterator, List, TypeVar + +from firebolt.common.row_set.types import AsyncByteStream + +TIter = TypeVar("TIter") + + +async def async_islice(async_iterator: AsyncIterator[TIter], n: int) -> List[TIter]: + result = [] + try: + for _ in range(n): + result.append(await async_iterator.__anext__()) + except StopAsyncIteration: + pass + return result + + +def async_byte_stream(b: bytes) -> AsyncByteStream: + class ABS: + def __init__(self, b: bytes): + self.b = b + self.read = False + + def __aiter__(self) -> AsyncIterator[bytes]: + return self + + async def __anext__(self) -> bytes: + if self.read: + raise StopAsyncIteration + self.read = True + return self.b + + async def aclose(self) -> None: + pass + + return ABS(b) diff --git a/tests/integration/dbapi/async/V1/conftest.py b/tests/integration/dbapi/async/V1/conftest.py index 895d4b2d56..cb2fc40a56 100644 --- a/tests/integration/dbapi/async/V1/conftest.py +++ b/tests/integration/dbapi/async/V1/conftest.py @@ -1,7 +1,10 @@ +from decimal import Decimal + from pytest import fixture from firebolt.async_db import Connection, connect from firebolt.client.auth.base import Auth +from firebolt.common._types import ColType @fixture @@ -74,3 +77,14 @@ async def connection_no_engine( api_endpoint=api_endpoint, ) as connection: yield connection + + +@fixture +def all_types_query_response_v1( + all_types_query_response: list[list[ColType]], +) -> list[list[ColType]]: + """ + V1 still returns decimals as floats, despite overflow. That's why it's not fully accurate. + """ + all_types_query_response[0][18] = Decimal("1231232.1234599999152123928070068359375") + return all_types_query_response diff --git a/tests/integration/dbapi/async/V1/test_auth_async.py b/tests/integration/dbapi/async/V1/test_auth_async.py index 081f17a304..98f71a2dcb 100644 --- a/tests/integration/dbapi/async/V1/test_auth_async.py +++ b/tests/integration/dbapi/async/V1/test_auth_async.py @@ -9,7 +9,7 @@ @mark.skip(reason="flaky, token not updated each time") async def test_refresh_token(connection: Connection) -> None: """Auth refreshes token on expiration/invalidation""" - with connection.cursor() as c: + async with connection.cursor() as c: # Works fine await c.execute("show tables") @@ -35,7 +35,7 @@ async def test_credentials_invalidation( """Auth raises authentication error on credentials invalidation""" # Can't pytest.parametrize it due to nested event loop error for conn in [connection, username_password_connection]: - with conn.cursor() as c: + async with conn.cursor() as c: # Works fine await c.execute("show tables") diff --git a/tests/integration/dbapi/async/V1/test_errors_async.py b/tests/integration/dbapi/async/V1/test_errors_async.py index be19f63a9b..1fb46143f4 100644 --- a/tests/integration/dbapi/async/V1/test_errors_async.py +++ b/tests/integration/dbapi/async/V1/test_errors_async.py @@ -117,7 +117,7 @@ async def test_database_not_exists( async def test_sql_error(connection: Connection) -> None: """Connection properly reacts to SQL execution error.""" - with connection.cursor() as c: + async with connection.cursor() as c: with raises(OperationalError) as exc_info: await c.execute("select ]") diff --git a/tests/integration/dbapi/async/V1/test_queries_async.py b/tests/integration/dbapi/async/V1/test_queries_async.py index 2f9ac59be2..fd32ab3789 100644 --- a/tests/integration/dbapi/async/V1/test_queries_async.py +++ b/tests/integration/dbapi/async/V1/test_queries_async.py @@ -78,7 +78,7 @@ async def test_connect_engine_name( connection_engine_name: Connection, all_types_query: str, all_types_query_description: List[Column], - all_types_query_response: List[ColType], + all_types_query_response_v1: List[ColType], timezone_name: str, ) -> None: """Connecting with engine name is handled properly.""" @@ -86,7 +86,7 @@ async def test_connect_engine_name( connection_engine_name, all_types_query, all_types_query_description, - all_types_query_response, + all_types_query_response_v1, timezone_name, ) @@ -95,7 +95,7 @@ async def test_connect_no_engine( connection_no_engine: Connection, all_types_query: str, all_types_query_description: List[Column], - all_types_query_response: List[ColType], + all_types_query_response_v1: List[ColType], timezone_name: str, ) -> None: """Connecting with engine name is handled properly.""" @@ -103,7 +103,7 @@ async def test_connect_no_engine( connection_no_engine, all_types_query, all_types_query_description, - all_types_query_response, + all_types_query_response_v1, timezone_name, ) @@ -112,11 +112,11 @@ async def test_select( connection: Connection, all_types_query: str, all_types_query_description: List[Column], - all_types_query_response: List[ColType], + all_types_query_response_v1: List[ColType], timezone_name: str, ) -> None: """Select handles all data types properly.""" - with connection.cursor() as c: + async with connection.cursor() as c: # For timestamptz test assert ( await c.execute(f"SET time_zone={timezone_name}") == -1 @@ -130,7 +130,7 @@ async def test_select( assert c.rowcount == 1, "Invalid rowcount value" data = await c.fetchall() assert len(data) == c.rowcount, "Invalid data length" - assert_deep_eq(data, all_types_query_response, "Invalid data") + assert_deep_eq(data, all_types_query_response_v1, "Invalid data") assert c.description == all_types_query_description, "Invalid description value" assert len(data[0]) == len(c.description), "Invalid description length" assert len(await c.fetchall()) == 0, "Redundant data returned by fetchall" @@ -138,7 +138,7 @@ async def test_select( # Different fetch types await c.execute(all_types_query) assert ( - await c.fetchone() == all_types_query_response[0] + await c.fetchone() == all_types_query_response_v1[0] ), "Invalid fetchone data" assert await c.fetchone() is None, "Redundant data returned by fetchone" @@ -147,12 +147,12 @@ async def test_select( data = await c.fetchmany() assert len(data) == 1, "Invalid data size returned by fetchmany" assert_deep_eq( - data, all_types_query_response, "Invalid data returned by fetchmany" + data, all_types_query_response_v1, "Invalid data returned by fetchmany" ) async def test_select_inf(connection: Connection) -> None: - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute("SELECT 'inf'::float, '-inf'::float") data = await c.fetchall() assert len(data) == 1, "Invalid data size returned by fetchall" @@ -161,7 +161,7 @@ async def test_select_inf(connection: Connection) -> None: async def test_select_nan(connection: Connection) -> None: - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute("SELECT 'nan'::float, '-nan'::float") data = await c.fetchall() assert len(data) == 1, "Invalid data size returned by fetchall" @@ -180,7 +180,7 @@ async def test_long_query( # Fail test if it takes less than 350 seconds minimal_time(350) - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute(LONG_SELECT) data = await c.fetchall() assert len(data) == 1, "Invalid data size returned by fetchall" @@ -191,12 +191,12 @@ async def test_drop_create(connection: Connection) -> None: async def test_query(c: Cursor, query: str) -> None: await c.execute(query) - assert c.description == None + assert c.description == [] # This is inconsistent, commenting for now # assert c.rowcount == -1 """Create table query is handled properly""" - with connection.cursor() as c: + async with connection.cursor() as c: # Cleanup await c.execute( 'DROP AGGREGATING INDEX IF EXISTS "test_drop_create_async_db_agg_idx"' @@ -242,12 +242,12 @@ async def test_insert(connection: Connection) -> None: async def test_empty_query(c: Cursor, query: str) -> None: assert await c.execute(query) == 0, "Invalid row count returned" assert c.rowcount == 0, "Invalid rowcount value" - assert c.description is None, "Invalid description" + assert c.description == [], "Invalid description" assert await c.fetchone() is None assert len(await c.fetchmany()) == 0 assert len(await c.fetchall()) == 0 - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute('DROP TABLE IF EXISTS "test_insert_async_tb"') await c.execute( 'CREATE FACT TABLE "test_insert_async_tb"(id int, sn string null, f float,' @@ -286,7 +286,7 @@ async def test_empty_query(c: Cursor, query: str) -> None: async def test_parameterized_query_with_special_chars(connection: Connection) -> None: """Query parameters are handled properly.""" - with connection.cursor() as c: + async with connection.cursor() as c: params = ["text with 'quote'", "text with \\slashes"] await c.execute( @@ -306,12 +306,12 @@ async def test_parameterized_query(connection: Connection) -> None: async def test_empty_query(c: Cursor, query: str, params: tuple) -> None: assert await c.execute(query, params) == 0, "Invalid row count returned" assert c.rowcount == 0, "Invalid rowcount value" - assert c.description is None, "Invalid description" + assert c.description == [], "Invalid description" assert await c.fetchone() is None assert len(await c.fetchmany()) == 0 assert len(await c.fetchall()) == 0 - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute('DROP TABLE IF EXISTS "test_tb_async_parameterized"') await c.execute( 'CREATE FACT TABLE "test_tb_async_parameterized"(i int, f float, s string, sn' @@ -328,7 +328,7 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None: datetime(2022, 1, 1, 1, 1, 1), True, [1, 2, 3], - Decimal("123.456"), + Decimal(123.456), ] await test_empty_query( @@ -353,9 +353,9 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None: async def test_multi_statement_query(connection: Connection) -> None: - """Query parameters are handled properly""" + """Query parameters are handled properly.""" - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute('DROP TABLE IF EXISTS "test_tb_async_multi_statement"') await c.execute( 'CREATE FACT TABLE "test_tb_async_multi_statement"(i int, s string)' @@ -367,7 +367,7 @@ async def test_multi_statement_query(connection: Connection) -> None: 'SELECT * FROM "test_tb_async_multi_statement";' 'SELECT * FROM "test_tb_async_multi_statement" WHERE i <= 1' ) - assert c.description is None, "Invalid description" + assert c.description == [], "Invalid description" assert await c.nextset() @@ -405,11 +405,11 @@ async def test_multi_statement_query(connection: Connection) -> None: "Invalid data in table after parameterized insert", ) - assert await c.nextset() is None + assert await c.nextset() is False async def test_set_invalid_parameter(connection: Connection): - with connection.cursor() as c: + async with connection.cursor() as c: assert len(c._set_parameters) == 0 with raises(OperationalError): await c.execute("SET some_invalid_parameter = 1") @@ -421,7 +421,7 @@ async def test_bytea_roundtrip( connection: Connection, ) -> None: """Inserted and than selected bytea value doesn't get corrupted.""" - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute('DROP TABLE IF EXISTS "test_bytea_roundtrip"') await c.execute( 'CREATE FACT TABLE "test_bytea_roundtrip"(id int, b bytea) primary index id' @@ -444,7 +444,7 @@ async def test_bytea_roundtrip( @fixture async def setup_db(connection_no_engine: Connection, use_db_name: str): use_db_name = f"{use_db_name}_async" - with connection_no_engine.cursor() as cursor: + async with connection_no_engine.cursor() as cursor: suffix = "".join(choice("0123456789") for _ in range(2)) await cursor.execute(f'CREATE DATABASE "{use_db_name}{suffix}"') yield @@ -461,7 +461,7 @@ async def test_use_database( test_db_name = f"{use_db_name}_async" test_table_name = "verify_use_db_async" """Use database works as expected.""" - with connection_no_engine.cursor() as c: + async with connection_no_engine.cursor() as c: await c.execute(f'USE DATABASE "{test_db_name}"') assert c.database == test_db_name await c.execute(f'CREATE TABLE "{test_table_name}" (id int)') diff --git a/tests/integration/dbapi/async/V2/conftest.py b/tests/integration/dbapi/async/V2/conftest.py index a379f0c667..eb14724c75 100644 --- a/tests/integration/dbapi/async/V2/conftest.py +++ b/tests/integration/dbapi/async/V2/conftest.py @@ -103,7 +103,7 @@ async def service_account_no_user( # function-level fixture so we need to make sa name is unique randomness = "".join(random.choices(string.ascii_letters + string.digits, k=2)) sa_account_name = f"{database_name}_no_user_{randomness}" - with connection_system_engine_no_db.cursor() as cursor: + async with connection_system_engine_no_db.cursor() as cursor: await cursor.execute( f'CREATE SERVICE ACCOUNT "{sa_account_name}" ' "WITH DESCRIPTION = 'Ecosytem test with no user'" diff --git a/tests/integration/dbapi/async/V2/test_auth_async.py b/tests/integration/dbapi/async/V2/test_auth_async.py index 081f17a304..98f71a2dcb 100644 --- a/tests/integration/dbapi/async/V2/test_auth_async.py +++ b/tests/integration/dbapi/async/V2/test_auth_async.py @@ -9,7 +9,7 @@ @mark.skip(reason="flaky, token not updated each time") async def test_refresh_token(connection: Connection) -> None: """Auth refreshes token on expiration/invalidation""" - with connection.cursor() as c: + async with connection.cursor() as c: # Works fine await c.execute("show tables") @@ -35,7 +35,7 @@ async def test_credentials_invalidation( """Auth raises authentication error on credentials invalidation""" # Can't pytest.parametrize it due to nested event loop error for conn in [connection, username_password_connection]: - with conn.cursor() as c: + async with conn.cursor() as c: # Works fine await c.execute("show tables") diff --git a/tests/integration/dbapi/async/V2/test_errors_async.py b/tests/integration/dbapi/async/V2/test_errors_async.py index 8be4e5c1cc..39a7815667 100644 --- a/tests/integration/dbapi/async/V2/test_errors_async.py +++ b/tests/integration/dbapi/async/V2/test_errors_async.py @@ -104,7 +104,7 @@ async def test_database_not_exists( async def test_sql_error(connection: Connection) -> None: """Connection properly reacts to SQL execution error.""" - with connection.cursor() as c: + async with connection.cursor() as c: with raises(FireboltStructuredError) as exc_info: await c.execute("select ]") diff --git a/tests/integration/dbapi/async/V2/test_queries_async.py b/tests/integration/dbapi/async/V2/test_queries_async.py index 335fa358e7..d782d54e7e 100644 --- a/tests/integration/dbapi/async/V2/test_queries_async.py +++ b/tests/integration/dbapi/async/V2/test_queries_async.py @@ -47,7 +47,7 @@ async def test_select( timezone_name: str, ) -> None: """Select handles all data types properly.""" - with connection.cursor() as c: + async with connection.cursor() as c: # For timestamptz test assert ( await c.execute(f"SET time_zone={timezone_name}") == -1 @@ -79,7 +79,7 @@ async def test_select( async def test_select_inf(connection: Connection) -> None: - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute("SELECT 'inf'::float, '-inf'::float") data = await c.fetchall() assert len(data) == 1, "Invalid data size returned by fetchall" @@ -88,7 +88,7 @@ async def test_select_inf(connection: Connection) -> None: async def test_select_nan(connection: Connection) -> None: - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute("SELECT 'nan'::float, '-nan'::float") data = await c.fetchall() assert len(data) == 1, "Invalid data size returned by fetchall" @@ -106,7 +106,7 @@ async def test_long_query( minimal_time(350) - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute(LONG_SELECT) data = await c.fetchall() assert len(data) == 1, "Invalid data size returned by fetchall" @@ -117,11 +117,11 @@ async def test_drop_create(connection: Connection) -> None: async def test_query(c: Cursor, query: str) -> None: await c.execute(query) - assert c.description == None + assert c.description == [] assert c.rowcount == 0 """Create table query is handled properly""" - with connection.cursor() as c: + async with connection.cursor() as c: # Cleanup await c.execute('DROP AGGREGATING INDEX IF EXISTS "test_db_agg_idx"') await c.execute('DROP TABLE IF EXISTS "test_drop_create_async"') @@ -165,12 +165,12 @@ async def test_insert(connection: Connection) -> None: async def test_empty_query(c: Cursor, query: str) -> None: assert await c.execute(query) == 0, "Invalid row count returned" assert c.rowcount == 0, "Invalid rowcount value" - assert c.description is None, "Invalid description" + assert c.description == [], "Invalid description" assert await c.fetchone() is None assert len(await c.fetchmany()) == 0 assert len(await c.fetchall()) == 0 - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute('DROP TABLE IF EXISTS "test_insert_async_tb"') await c.execute( 'CREATE FACT TABLE "test_insert_async_tb"(id int, sn string null, f float,' @@ -213,12 +213,12 @@ async def test_parameterized_query(connection: Connection) -> None: async def test_empty_query(c: Cursor, query: str, params: tuple) -> None: assert await c.execute(query, params) == 0, "Invalid row count returned" assert c.rowcount == 0, "Invalid rowcount value" - assert c.description is None, "Invalid description" + assert c.description == [], "Invalid description" assert await c.fetchone() is None assert len(await c.fetchmany()) == 0 assert len(await c.fetchall()) == 0 - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute('DROP TABLE IF EXISTS "test_tb_async_parameterized"') await c.execute( 'CREATE FACT TABLE "test_tb_async_parameterized"(i int, f float, s string, sn' @@ -261,7 +261,7 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None: async def test_parameterized_query_with_special_chars(connection: Connection) -> None: """Query parameters are handled properly.""" - with connection.cursor() as c: + async with connection.cursor() as c: params = ["text with 'quote'", "text with \\slashes"] await c.execute( @@ -278,7 +278,7 @@ async def test_parameterized_query_with_special_chars(connection: Connection) -> async def test_multi_statement_query(connection: Connection) -> None: """Query parameters are handled properly""" - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute('DROP TABLE IF EXISTS "test_tb_async_multi_statement"') await c.execute( 'CREATE FACT TABLE "test_tb_async_multi_statement"(i int, s string)' @@ -290,7 +290,7 @@ async def test_multi_statement_query(connection: Connection) -> None: 'SELECT * FROM "test_tb_async_multi_statement";' 'SELECT * FROM "test_tb_async_multi_statement" WHERE i <= 1' ) - assert c.description is None, "Invalid description" + assert c.description == [], "Invalid description" assert await c.nextset() @@ -328,11 +328,11 @@ async def test_multi_statement_query(connection: Connection) -> None: "Invalid data in table after parameterized insert", ) - assert await c.nextset() is None + assert await c.nextset() is False async def test_set_invalid_parameter(connection: Connection): - with connection.cursor() as c: + async with connection.cursor() as c: assert len(c._set_parameters) == 0 with raises(OperationalError): await c.execute("SET some_invalid_parameter = 1") @@ -344,7 +344,7 @@ async def test_bytea_roundtrip( connection: Connection, ) -> None: """Inserted and than selected bytea value doesn't get corrupted.""" - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute('DROP TABLE IF EXISTS "test_bytea_roundtrip_2"') await c.execute( 'CREATE FACT TABLE "test_bytea_roundtrip_2"(id int, b bytea) primary index id' @@ -435,7 +435,7 @@ async def test_select_geography( select_geography_description: List[Column], select_geography_response: List[ColType], ) -> None: - with connection.cursor() as c: + async with connection.cursor() as c: await c.execute(select_geography_query) assert ( c.description == select_geography_description @@ -457,7 +457,7 @@ async def test_select_struct( select_struct_description: List[Column], select_struct_response: List[ColType], ): - with connection.cursor() as c: + async with connection.cursor() as c: try: await c.execute(setup_struct_query) await c.execute(select_struct_query) diff --git a/tests/integration/dbapi/async/V2/test_system_engine_async.py b/tests/integration/dbapi/async/V2/test_system_engine_async.py index 321ef50547..d6eb8079ab 100644 --- a/tests/integration/dbapi/async/V2/test_system_engine_async.py +++ b/tests/integration/dbapi/async/V2/test_system_engine_async.py @@ -24,7 +24,7 @@ async def test_system_engine( all_types_query_system_engine_response: List[ColType], timezone_name: str, ) -> None: - with connection_system_engine.cursor() as c: + async with connection_system_engine.cursor() as c: assert await c.execute(all_types_query) == 1, "Invalid row count returned" assert c.rowcount == 1, "Invalid rowcount value" data = await c.fetchall() @@ -86,7 +86,7 @@ async def test_system_engine_use_engine( connection_system_engine: Connection, database_name: str, engine_name: str ): table_name = "test_table_async" - with connection_system_engine.cursor() as cursor: + async with connection_system_engine.cursor() as cursor: await cursor.execute(f'USE DATABASE "{database_name}"') await cursor.execute(f'USE ENGINE "{engine_name}"') await cursor.execute(f'CREATE TABLE IF NOT EXISTS "{table_name}" (id int)') diff --git a/tests/integration/dbapi/async/V2/test_timeout.py b/tests/integration/dbapi/async/V2/test_timeout.py index 9cd3dc53ee..0aff66e3d6 100644 --- a/tests/integration/dbapi/async/V2/test_timeout.py +++ b/tests/integration/dbapi/async/V2/test_timeout.py @@ -8,6 +8,6 @@ async def test_query_timeout(connection: Connection): - with connection.cursor() as cursor: + async with connection.cursor() as cursor: with raises(QueryTimeoutError): await cursor.execute(LONG_SELECT, timeout_seconds=1) diff --git a/tests/integration/dbapi/conftest.py b/tests/integration/dbapi/conftest.py index 9d2b53c98e..112749684d 100644 --- a/tests/integration/dbapi/conftest.py +++ b/tests/integration/dbapi/conftest.py @@ -215,6 +215,7 @@ def select_geography_response() -> List[ColType]: def setup_struct_query() -> str: return """ SET advanced_mode=1; + SET enable_struct=true; SET enable_create_table_v2=true; SET enable_struct_syntax=true; SET prevent_create_on_information_schema=true; diff --git a/tests/integration/dbapi/sync/V1/conftest.py b/tests/integration/dbapi/sync/V1/conftest.py index 9beea7962f..4fb34f0c8a 100644 --- a/tests/integration/dbapi/sync/V1/conftest.py +++ b/tests/integration/dbapi/sync/V1/conftest.py @@ -1,6 +1,9 @@ +from decimal import Decimal + from pytest import fixture from firebolt.client.auth.base import Auth +from firebolt.common._types import ColType from firebolt.db import Connection, connect @@ -93,3 +96,14 @@ def connection_system_engine( ) yield connection connection.close() + + +@fixture +def all_types_query_response_v1( + all_types_query_response: list[list[ColType]], +) -> list[list[ColType]]: + """ + V1 still returns decimals as floats, despite overflow. That's why it's not fully accurate. + """ + all_types_query_response[0][18] = Decimal("1231232.1234599999152123928070068359375") + return all_types_query_response diff --git a/tests/integration/dbapi/sync/V1/test_queries.py b/tests/integration/dbapi/sync/V1/test_queries.py index 9950e231d6..aaae6dbdd5 100644 --- a/tests/integration/dbapi/sync/V1/test_queries.py +++ b/tests/integration/dbapi/sync/V1/test_queries.py @@ -30,7 +30,7 @@ def test_connect_engine_name( connection_engine_name: Connection, all_types_query: str, all_types_query_description: List[Column], - all_types_query_response: List[ColType], + all_types_query_response_v1: List[ColType], timezone_name: str, ) -> None: """Connecting with engine name is handled properly.""" @@ -38,7 +38,7 @@ def test_connect_engine_name( connection_engine_name, all_types_query, all_types_query_description, - all_types_query_response, + all_types_query_response_v1, timezone_name, ) @@ -47,7 +47,7 @@ def test_connect_no_engine( connection_no_engine: Connection, all_types_query: str, all_types_query_description: List[Column], - all_types_query_response: List[ColType], + all_types_query_response_v1: List[ColType], timezone_name: str, ) -> None: """Connecting with engine name is handled properly.""" @@ -55,7 +55,7 @@ def test_connect_no_engine( connection_no_engine, all_types_query, all_types_query_description, - all_types_query_response, + all_types_query_response_v1, timezone_name, ) @@ -64,7 +64,7 @@ def test_select( connection: Connection, all_types_query: str, all_types_query_description: List[Column], - all_types_query_response: List[ColType], + all_types_query_response_v1: List[ColType], timezone_name: str, ) -> None: """Select handles all data types properly.""" @@ -82,14 +82,14 @@ def test_select( assert c.rowcount == 1, "Invalid rowcount value" data = c.fetchall() assert len(data) == c.rowcount, "Invalid data length" - assert_deep_eq(data, all_types_query_response, "Invalid data") + assert_deep_eq(data, all_types_query_response_v1, "Invalid data") assert c.description == all_types_query_description, "Invalid description value" assert len(data[0]) == len(c.description), "Invalid description length" assert len(c.fetchall()) == 0, "Redundant data returned by fetchall" # Different fetch types c.execute(all_types_query) - assert c.fetchone() == all_types_query_response[0], "Invalid fetchone data" + assert c.fetchone() == all_types_query_response_v1[0], "Invalid fetchone data" assert c.fetchone() is None, "Redundant data returned by fetchone" c.execute(all_types_query) @@ -97,7 +97,7 @@ def test_select( data = c.fetchmany() assert len(data) == 1, "Invalid data size returned by fetchmany" assert_deep_eq( - data, all_types_query_response, "Invalid data returned by fetchmany" + data, all_types_query_response_v1, "Invalid data returned by fetchmany" ) @@ -140,7 +140,7 @@ def test_drop_create(connection: Connection) -> None: def test_query(c: Cursor, query: str) -> None: c.execute(query) - assert c.description == None + assert c.description == [] # Inconsistent behaviour in Firebolt # assert c.rowcount == -1 @@ -189,7 +189,7 @@ def test_insert(connection: Connection) -> None: def test_empty_query(c: Cursor, query: str) -> None: assert c.execute(query) == 0, "Invalid row count returned" assert c.rowcount == 0, "Invalid rowcount value" - assert c.description is None, "Invalid description" + assert c.description == [], "Invalid description" assert c.fetchone() is None assert len(c.fetchmany()) == 0 assert len(c.fetchall()) == 0 @@ -251,7 +251,7 @@ def test_parameterized_query(connection: Connection) -> None: def test_empty_query(c: Cursor, query: str, params: tuple) -> None: assert c.execute(query, params) == 0, "Invalid row count returned" assert c.rowcount == 0, "Invalid rowcount value" - assert c.description is None, "Invalid description" + assert c.description == [], "Invalid description" assert c.fetchone() is None assert len(c.fetchmany()) == 0 assert len(c.fetchall()) == 0 @@ -273,7 +273,7 @@ def test_empty_query(c: Cursor, query: str, params: tuple) -> None: datetime(2022, 1, 1, 1, 1, 1), True, [1, 2, 3], - Decimal("123.456"), + Decimal(123.456), ] test_empty_query( @@ -311,7 +311,7 @@ def test_multi_statement_query(connection: Connection) -> None: 'SELECT * FROM "test_tb_multi_statement";' 'SELECT * FROM "test_tb_multi_statement" WHERE i <= 1' ) - assert c.description is None, "Invalid description" + assert c.description == [], "Invalid description" assert c.nextset() @@ -349,7 +349,7 @@ def test_multi_statement_query(connection: Connection) -> None: "Invalid data in table after parameterized insert", ) - assert c.nextset() is None + assert c.nextset() is False def test_set_invalid_parameter(connection: Connection): diff --git a/tests/integration/dbapi/sync/V2/test_queries.py b/tests/integration/dbapi/sync/V2/test_queries.py index 5df7ef99ad..20653ba6c8 100644 --- a/tests/integration/dbapi/sync/V2/test_queries.py +++ b/tests/integration/dbapi/sync/V2/test_queries.py @@ -121,7 +121,7 @@ def test_drop_create(connection: Connection) -> None: def test_query(c: Cursor, query: str) -> None: c.execute(query) - assert c.description == None + assert c.description == [] assert c.rowcount == 0 """Create table query is handled properly""" @@ -169,7 +169,7 @@ def test_insert(connection: Connection) -> None: def test_empty_query(c: Cursor, query: str) -> None: assert c.execute(query) == 0, "Invalid row count returned" assert c.rowcount == 0, "Invalid rowcount value" - assert c.description is None, "Invalid description" + assert c.description == [], "Invalid description" assert c.fetchone() is None assert len(c.fetchmany()) == 0 assert len(c.fetchall()) == 0 @@ -230,7 +230,7 @@ def test_parameterized_query(connection: Connection) -> None: def test_empty_query(c: Cursor, query: str, params: tuple) -> None: assert c.execute(query, params) == 0, "Invalid row count returned" assert c.rowcount == 0, "Invalid rowcount value" - assert c.description is None, "Invalid description" + assert c.description == [], "Invalid description" assert c.fetchone() is None assert len(c.fetchmany()) == 0 assert len(c.fetchall()) == 0 @@ -290,7 +290,7 @@ def test_multi_statement_query(connection: Connection) -> None: 'SELECT * FROM "test_tb_multi_statement";' 'SELECT * FROM "test_tb_multi_statement" WHERE i <= 1' ) - assert c.description is None, "Invalid description" + assert c.description == [], "Invalid description" assert c.nextset() @@ -328,7 +328,7 @@ def test_multi_statement_query(connection: Connection) -> None: "Invalid data in table after parameterized insert", ) - assert c.nextset() is None + assert c.nextset() is False def test_set_invalid_parameter(connection: Connection): diff --git a/tests/unit/V1/async_db/test_cursor.py b/tests/unit/V1/async_db/test_cursor.py index a3eae31e5f..a02dd89058 100644 --- a/tests/unit/V1/async_db/test_cursor.py +++ b/tests/unit/V1/async_db/test_cursor.py @@ -6,7 +6,7 @@ from pytest_httpx import HTTPXMock from firebolt.async_db import Cursor -from firebolt.common.base_cursor import ColType +from firebolt.common._types import ColType from firebolt.common.constants import CursorState from firebolt.common.row_set.types import Column from firebolt.utils.exception import ( @@ -85,7 +85,7 @@ async def test_closed_cursor(cursor: Cursor): await getattr(cursor, amethod)(*args) with raises(CursorClosedError): - with cursor: + async with cursor: pass with raises(CursorClosedError): @@ -139,7 +139,7 @@ async def test_cursor_no_query( cursor._reset() cursor.setoutputsize(0) # Context manager is also available - with cursor: + async with cursor: pass # should this be available? # async with cursor: @@ -496,7 +496,7 @@ async def test_cursor_multi_statement( await cursor.fetchone() == python_query_data[i] ), f"Invalid data row at position {i}" - assert await cursor.nextset() is None + assert await cursor.nextset() is False async def test_cursor_set_statements( diff --git a/tests/unit/V1/db/test_cursor.py b/tests/unit/V1/db/test_cursor.py index 59b33ea099..7d9bcaa3cf 100644 --- a/tests/unit/V1/db/test_cursor.py +++ b/tests/unit/V1/db/test_cursor.py @@ -439,7 +439,7 @@ def test_cursor_multi_statement( cursor.fetchone() == python_query_data[i] ), f"Invalid data row at position {i}" - assert cursor.nextset() is None + assert cursor.nextset() is False def test_cursor_set_statements( diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index bc3d0b0815..55e441906e 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -7,7 +7,7 @@ from pytest_httpx import HTTPXMock from firebolt.async_db import Cursor -from firebolt.common.base_cursor import ColType +from firebolt.common._types import ColType from firebolt.common.constants import CursorState from firebolt.common.row_set.types import Column from firebolt.utils.exception import ( @@ -84,7 +84,7 @@ async def test_closed_cursor(cursor: Cursor): await getattr(cursor, amethod)(*args) with raises(CursorClosedError): - with cursor: + async with cursor: pass with raises(CursorClosedError): @@ -134,7 +134,7 @@ async def test_cursor_no_query( cursor._reset() cursor.setoutputsize(0) # Context manager is also available - with cursor: + async with cursor: pass # should this be available? # async with cursor: @@ -439,7 +439,7 @@ async def test_cursor_multi_statement( await cursor.fetchone() == python_query_data[i] ), f"Invalid data row at position {i}" - assert await cursor.nextset() is None + assert await cursor.nextset() is False async def test_cursor_set_statements( diff --git a/tests/unit/common/test_base_cursor.py b/tests/unit/common/test_base_cursor.py index 6e32647e22..830732b593 100644 --- a/tests/unit/common/test_base_cursor.py +++ b/tests/unit/common/test_base_cursor.py @@ -3,7 +3,7 @@ from pytest import fixture, mark -from firebolt.common.base_cursor import BaseCursor +from firebolt.common.cursor.base_cursor import BaseCursor from firebolt.common.statement_formatter import create_statement_formatter diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index bba871ee66..416a92b22a 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -424,7 +424,7 @@ def test_cursor_multi_statement( cursor.fetchone() == python_query_data[i] ), f"Invalid data row at position {i}" - assert cursor.nextset() is None + assert cursor.nextset() is False def test_cursor_set_statements( From 7e2437db1250800f43b59be9f9d21e56cc354ad9 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 8 Apr 2025 17:19:47 +0300 Subject: [PATCH 04/10] fix type annotations --- tests/integration/dbapi/async/V1/conftest.py | 5 +++-- tests/integration/dbapi/sync/V1/conftest.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/integration/dbapi/async/V1/conftest.py b/tests/integration/dbapi/async/V1/conftest.py index cb2fc40a56..b3cd40edac 100644 --- a/tests/integration/dbapi/async/V1/conftest.py +++ b/tests/integration/dbapi/async/V1/conftest.py @@ -1,4 +1,5 @@ from decimal import Decimal +from typing import List from pytest import fixture @@ -81,8 +82,8 @@ async def connection_no_engine( @fixture def all_types_query_response_v1( - all_types_query_response: list[list[ColType]], -) -> list[list[ColType]]: + all_types_query_response: List[List[ColType]], +) -> List[List[ColType]]: """ V1 still returns decimals as floats, despite overflow. That's why it's not fully accurate. """ diff --git a/tests/integration/dbapi/sync/V1/conftest.py b/tests/integration/dbapi/sync/V1/conftest.py index 4fb34f0c8a..c05d4d42ba 100644 --- a/tests/integration/dbapi/sync/V1/conftest.py +++ b/tests/integration/dbapi/sync/V1/conftest.py @@ -1,4 +1,5 @@ from decimal import Decimal +from typing import List from pytest import fixture @@ -100,8 +101,8 @@ def connection_system_engine( @fixture def all_types_query_response_v1( - all_types_query_response: list[list[ColType]], -) -> list[list[ColType]]: + all_types_query_response: List[List[ColType]], +) -> List[List[ColType]]: """ V1 still returns decimals as floats, despite overflow. That's why it's not fully accurate. """ From 91aa3f323fa43b6e76bc24c9934ad71c87d66f4d Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 9 Apr 2025 10:58:51 +0300 Subject: [PATCH 05/10] fix unit tests --- tests/unit/V1/async_db/test_cursor.py | 10 +++++----- tests/unit/V1/db/test_cursor.py | 10 +++++----- tests/unit/async_db/test_cursor.py | 10 +++++----- tests/unit/db/test_cursor.py | 10 +++++----- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/unit/V1/async_db/test_cursor.py b/tests/unit/V1/async_db/test_cursor.py index a02dd89058..aae527b1e8 100644 --- a/tests/unit/V1/async_db/test_cursor.py +++ b/tests/unit/V1/async_db/test_cursor.py @@ -201,7 +201,7 @@ async def test_cursor_execute( cursor.rowcount == -1 ), f"Invalid rowcount value for insert using {message}." assert ( - cursor.description is None + cursor.description == [] ), f"Invalid description for insert using {message}." @@ -469,7 +469,7 @@ async def test_cursor_multi_statement( assert await cursor.nextset() assert cursor.rowcount == -1, "Invalid cursor row count" - assert cursor.description is None, "Invalid cursor description" + assert cursor.description == [], "Invalid cursor description" assert cursor.statistics is None, "Invalid cursor statistics" with raises(DataError) as exc_info: @@ -515,7 +515,7 @@ async def test_cursor_set_statements( rc = await cursor.execute("set a = b") assert rc == -1, "Invalid row count returned" - assert cursor.description is None, "Non-empty description for set" + assert cursor.description == [], "Non-empty description for set" with raises(DataError): await cursor.fetchall() @@ -533,7 +533,7 @@ async def test_cursor_set_statements( rc = await cursor.execute("set param1=1") assert rc == -1, "Invalid row count returned" - assert cursor.description is None, "Non-empty description for set" + assert cursor.description == [], "Non-empty description for set" with raises(DataError): await cursor.fetchall() @@ -549,7 +549,7 @@ async def test_cursor_set_statements( rc = await cursor.execute("set param2=0") assert rc == -1, "Invalid row count returned" - assert cursor.description is None, "Non-empty description for set" + assert cursor.description == [], "Non-empty description for set" with raises(DataError): await cursor.fetchall() diff --git a/tests/unit/V1/db/test_cursor.py b/tests/unit/V1/db/test_cursor.py index 7d9bcaa3cf..9c3f5b5ee3 100644 --- a/tests/unit/V1/db/test_cursor.py +++ b/tests/unit/V1/db/test_cursor.py @@ -196,7 +196,7 @@ def test_cursor_execute( cursor.rowcount == -1 ), f"Invalid rowcount value for insert using {message}." assert ( - cursor.description is None + cursor.description == [] ), f"Invalid description for insert using {message}." @@ -422,7 +422,7 @@ def test_cursor_multi_statement( assert cursor.nextset() assert cursor.rowcount == -1, "Invalid cursor row count" - assert cursor.description is None, "Invalid cursor description" + assert cursor.description == [], "Invalid cursor description" with raises(DataError) as exc_info: cursor.fetchall() @@ -461,7 +461,7 @@ def test_cursor_set_statements( rc = cursor.execute("set a = b") assert rc == -1, "Invalid row count returned" - assert cursor.description is None, "Non-empty description for set" + assert cursor.description == [], "Non-empty description for set" with raises(DataError): cursor.fetchall() @@ -479,7 +479,7 @@ def test_cursor_set_statements( rc = cursor.execute("set param1=1") assert rc == -1, "Invalid row count returned" - assert cursor.description is None, "Non-empty description for set" + assert cursor.description == [], "Non-empty description for set" with raises(DataError): cursor.fetchall() @@ -495,7 +495,7 @@ def test_cursor_set_statements( rc = cursor.execute("set param2=0") assert rc == -1, "Invalid row count returned" - assert cursor.description is None, "Non-empty description for set" + assert cursor.description == [], "Non-empty description for set" with raises(DataError): cursor.fetchall() diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index 55e441906e..0f77502c82 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -188,7 +188,7 @@ async def test_cursor_execute( cursor.rowcount == -1 ), f"Invalid rowcount value for insert using {message}." assert ( - cursor.description is None + cursor.description == [] ), f"Invalid description for insert using {message}." @@ -412,7 +412,7 @@ async def test_cursor_multi_statement( assert await cursor.nextset() assert cursor.rowcount == -1, "Invalid cursor row count" - assert cursor.description is None, "Invalid cursor description" + assert cursor.description == [], "Invalid cursor description" assert cursor.statistics is None, "Invalid cursor statistics" with raises(DataError) as exc_info: @@ -455,7 +455,7 @@ async def test_cursor_set_statements( rc = await cursor.execute("set a = b") assert rc == -1, "Invalid row count returned" - assert cursor.description is None, "Non-empty description for set" + assert cursor.description == [], "Non-empty description for set" with raises(DataError): await cursor.fetchall() @@ -473,7 +473,7 @@ async def test_cursor_set_statements( rc = await cursor.execute("set param1=1") assert rc == -1, "Invalid row count returned" - assert cursor.description is None, "Non-empty description for set" + assert cursor.description == [], "Non-empty description for set" with raises(DataError): await cursor.fetchall() @@ -489,7 +489,7 @@ async def test_cursor_set_statements( rc = await cursor.execute("set param2=0") assert rc == -1, "Invalid row count returned" - assert cursor.description is None, "Non-empty description for set" + assert cursor.description == [], "Non-empty description for set" with raises(DataError): await cursor.fetchall() diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index 416a92b22a..4295f25edf 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -185,7 +185,7 @@ def test_cursor_execute( cursor.rowcount == -1 ), f"Invalid rowcount value for insert using {message}." assert ( - cursor.description is None + cursor.description == [] ), f"Invalid description for insert using {message}." @@ -407,7 +407,7 @@ def test_cursor_multi_statement( assert cursor.nextset() assert cursor.rowcount == -1, "Invalid cursor row count" - assert cursor.description is None, "Invalid cursor description" + assert cursor.description == [], "Invalid cursor description" with raises(DataError) as exc_info: cursor.fetchall() @@ -440,7 +440,7 @@ def test_cursor_set_statements( rc = cursor.execute("set a = b") assert rc == -1, "Invalid row count returned" - assert cursor.description is None, "Non-empty description for set" + assert cursor.description == [], "Non-empty description for set" with raises(DataError): cursor.fetchall() @@ -458,7 +458,7 @@ def test_cursor_set_statements( rc = cursor.execute("set param1=1") assert rc == -1, "Invalid row count returned" - assert cursor.description is None, "Non-empty description for set" + assert cursor.description == [], "Non-empty description for set" with raises(DataError): cursor.fetchall() @@ -474,7 +474,7 @@ def test_cursor_set_statements( rc = cursor.execute("set param2=0") assert rc == -1, "Invalid row count returned" - assert cursor.description is None, "Non-empty description for set" + assert cursor.description == [], "Non-empty description for set" with raises(DataError): cursor.fetchall() From f2014db86df7c24a9477186ae182bf71bf028961 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 9 Apr 2025 14:59:57 +0300 Subject: [PATCH 06/10] fix sonar --- src/firebolt/common/row_set/synchronous/in_memory.py | 1 + src/firebolt/utils/async_util.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/firebolt/common/row_set/synchronous/in_memory.py b/src/firebolt/common/row_set/synchronous/in_memory.py index da1203bd79..fb546da18b 100644 --- a/src/firebolt/common/row_set/synchronous/in_memory.py +++ b/src/firebolt/common/row_set/synchronous/in_memory.py @@ -99,4 +99,5 @@ def __next__(self) -> List[ColType]: return self._parse_row(self._row_set.rows[self._current_row]) def close(self) -> None: + # No-op for in-memory row set pass diff --git a/src/firebolt/utils/async_util.py b/src/firebolt/utils/async_util.py index 078b4fea6b..fe925881cf 100644 --- a/src/firebolt/utils/async_util.py +++ b/src/firebolt/utils/async_util.py @@ -31,6 +31,7 @@ async def __anext__(self) -> bytes: return self.b async def aclose(self) -> None: + # No-op since there is nothing to close pass return ABS(b) From 2efe40b0f7dfe37cba0e047edf008bb42f1e49ec Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Mon, 14 Apr 2025 17:29:16 +0300 Subject: [PATCH 07/10] address comments --- src/firebolt/common/row_set/types.py | 17 ++++++++++++++++- src/firebolt/utils/async_util.py | 24 ------------------------ tests/unit/common/test_types.py | 17 +++++++++++++++++ 3 files changed, 33 insertions(+), 25 deletions(-) create mode 100644 tests/unit/common/test_types.py diff --git a/src/firebolt/common/row_set/types.py b/src/firebolt/common/row_set/types.py index 924f720dc5..e6a11c9ecf 100644 --- a/src/firebolt/common/row_set/types.py +++ b/src/firebolt/common/row_set/types.py @@ -1,7 +1,15 @@ from __future__ import annotations from dataclasses import dataclass, fields -from typing import AsyncIterator, Iterator, List, Optional, Protocol, Union +from typing import ( + Any, + AsyncIterator, + Iterator, + List, + Optional, + Protocol, + Union, +) from firebolt.common._types import ExtendedType, RawColType @@ -62,6 +70,13 @@ class Column: scale: Optional[int] = None null_ok: Optional[bool] = None + def __getitem__(self, item: int) -> Any: + """Support indexing for column attributes.""" + field_names = [f.name for f in fields(self)] + if item >= len(field_names): + raise IndexError("Index out of range") + return getattr(self, field_names[item]) + class ByteStream(Protocol): def __iter__(self) -> Iterator[bytes]: diff --git a/src/firebolt/utils/async_util.py b/src/firebolt/utils/async_util.py index fe925881cf..69f4a8f8bb 100644 --- a/src/firebolt/utils/async_util.py +++ b/src/firebolt/utils/async_util.py @@ -1,7 +1,5 @@ from typing import AsyncIterator, List, TypeVar -from firebolt.common.row_set.types import AsyncByteStream - TIter = TypeVar("TIter") @@ -13,25 +11,3 @@ async def async_islice(async_iterator: AsyncIterator[TIter], n: int) -> List[TIt except StopAsyncIteration: pass return result - - -def async_byte_stream(b: bytes) -> AsyncByteStream: - class ABS: - def __init__(self, b: bytes): - self.b = b - self.read = False - - def __aiter__(self) -> AsyncIterator[bytes]: - return self - - async def __anext__(self) -> bytes: - if self.read: - raise StopAsyncIteration - self.read = True - return self.b - - async def aclose(self) -> None: - # No-op since there is nothing to close - pass - - return ABS(b) diff --git a/tests/unit/common/test_types.py b/tests/unit/common/test_types.py new file mode 100644 index 0000000000..62f007105d --- /dev/null +++ b/tests/unit/common/test_types.py @@ -0,0 +1,17 @@ +from dataclasses import fields + +from firebolt.common.row_set.types import Column + + +def test_columns_supports_indexing(): + column = Column( + name="test_column", + type_code=int, + display_size=10, + internal_size=20, + precision=5, + scale=2, + null_ok=True, + ) + for i, field in enumerate(fields(column)): + assert getattr(column, field.name) == column[i] From c07db142fb83c49c25b071ec14cf93350e923558 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Mon, 14 Apr 2025 19:00:37 +0300 Subject: [PATCH 08/10] resolve more comments --- src/firebolt/async_db/cursor.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index c5bb278cf3..55d82c391f 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -2,6 +2,7 @@ import logging import time +import warnings from abc import ABCMeta, abstractmethod from functools import wraps from types import TracebackType @@ -383,6 +384,8 @@ async def fetchone(self) -> Optional[List[ColType]]: assert self._row_set is not None with Timer(self._performance_log_message): # anext() is only supported in Python 3.10+ + # this means we cannot just do return anext(self._row_set), + # we need to handle iteration manually try: return await self._row_set.__anext__() except StopAsyncIteration: @@ -452,6 +455,18 @@ async def __anext__(self) -> List[ColType]: assert self._row_set is not None return await self._row_set.__anext__() + @check_not_closed + def __enter__(self) -> Cursor: + warnings.warn( + "Using __enter__ is deprecated, use async with instead", DeprecationWarning + ) + return self + + def __exit__( + self, exc_type: type, exc_val: Exception, exc_tb: TracebackType + ) -> None: + return None + class CursorV2(Cursor): def __init__( From 4724e740f855306657198cda63908a4aba32bb91 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Mon, 14 Apr 2025 19:31:02 +0300 Subject: [PATCH 09/10] Update src/firebolt/async_db/cursor.py Co-authored-by: Petro Tiurin <93913847+ptiurin@users.noreply.github.com> --- src/firebolt/async_db/cursor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 55d82c391f..654edbd9b0 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -458,7 +458,7 @@ async def __anext__(self) -> List[ColType]: @check_not_closed def __enter__(self) -> Cursor: warnings.warn( - "Using __enter__ is deprecated, use async with instead", DeprecationWarning + "Using __enter__ is deprecated, use 'async with' instead", DeprecationWarning ) return self From f1eb789473038042aa693162b191fa9b046f5aea Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 15 Apr 2025 10:04:27 +0300 Subject: [PATCH 10/10] fix linters --- src/firebolt/async_db/cursor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 654edbd9b0..a2a2ede7e2 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -458,7 +458,8 @@ async def __anext__(self) -> List[ColType]: @check_not_closed def __enter__(self) -> Cursor: warnings.warn( - "Using __enter__ is deprecated, use 'async with' instead", DeprecationWarning + "Using __enter__ is deprecated, use 'async with' instead", + DeprecationWarning, ) return self