Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/firebolt/async_db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ async def aclose(self) -> None:
for c in cursors:
# Here c can already be closed by another thread,
# but it shouldn't raise an error in this case
c.close()
await c.aclose()
await self._client.aclose()
self._is_closed = True

Expand Down
145 changes: 95 additions & 50 deletions src/firebolt/async_db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,11 @@

import logging
import time
import warnings
from abc import ABCMeta, abstractmethod
from functools import wraps
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
List,
Optional,
Sequence,
Union,
)
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from urllib.parse import urljoin

from httpx import (
Expand All @@ -28,20 +20,26 @@

from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2
from firebolt.common._types import ColType, ParameterType, SetParameter
from firebolt.common.base_cursor import (
from firebolt.common.constants import (
JSON_OUTPUT_FORMAT,
RESET_SESSION_HEADER,
UPDATE_ENDPOINT_HEADER,
UPDATE_PARAMETERS_HEADER,
BaseCursor,
CursorState,
)
from firebolt.common.cursor.base_cursor import (
BaseCursor,
_parse_update_endpoint,
_parse_update_parameters,
_raise_if_internal_set_parameter,
)
from firebolt.common.cursor.decorators import (
async_not_allowed,
check_not_closed,
check_query_executed,
)
from firebolt.common.row_set.asynchronous.base import BaseAsyncRowSet
from firebolt.common.row_set.asynchronous.in_memory import InMemoryAsyncRowSet
from firebolt.common.statement_formatter import create_statement_formatter
from firebolt.utils.exception import (
EngineNotRunningError,
Expand All @@ -58,7 +56,12 @@
if TYPE_CHECKING:
from firebolt.async_db.connection import Connection

from firebolt.utils.util import _print_error_body, raise_errors_from_body
from firebolt.utils.async_util import async_islice
from firebolt.utils.util import (
Timer,
_print_error_body,
raise_errors_from_body,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,6 +91,7 @@ def __init__(
self._client = client
self.connection = connection
self.engine_url = connection.engine_url
self._row_set: Optional[BaseAsyncRowSet] = None
if connection.init_parameters:
self._update_set_parameters(connection.init_parameters)

Expand Down Expand Up @@ -121,13 +125,14 @@ async def _api_request(
if self.parameters:
parameters = {**self.parameters, **parameters}
try:
return await self._client.request(
req = self._client.build_request(
url=urljoin(self.engine_url.rstrip("/") + "/", path or ""),
method="POST",
params=parameters,
content=query,
timeout=timeout if timeout is not None else USE_CLIENT_DEFAULT,
)
return await self._client.send(req, stream=True)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this the default, previous behavior with stream? I thought this is the refactoring and then you will add the streaming functionality.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We used to read the response here fully and then work with the result in memory. Now we stream the response here, and let the RowSet to decide whether to read the stream in memory and process it or (in future streaming implementation) to read it line by line and not store it in memory

except TimeoutException:
raise QueryTimeoutError()

Expand Down Expand Up @@ -170,6 +175,9 @@ async def _validate_set_parameter(
# set parameter passed validation
self._set_parameters[parameter.name] = parameter.value

# append empty result set
await self._append_row_set_from_response(None)

async def _parse_response_headers(self, headers: Headers) -> None:
if headers.get(UPDATE_ENDPOINT_HEADER):
endpoint, params = _parse_update_endpoint(
Expand Down Expand Up @@ -271,8 +279,7 @@ async def _handle_query_execution(
self._parse_async_response(resp)
else:
await self._parse_response_headers(resp.headers)
row_set = self._row_set_from_response(resp)
self._append_row_set(row_set)
await self._append_row_set_from_response(resp)

@check_not_closed
async def execute(
Expand Down Expand Up @@ -353,75 +360,113 @@ async def executemany(
await self._do_execute(query, parameters_seq, timeout=timeout_seconds)
return self.rowcount

@abstractmethod
async def is_db_available(self, database: str) -> bool:
"""Verify that the database exists."""
...
async def _append_row_set_from_response(
self,
response: Optional[Response],
) -> None:
"""Store information about executed query."""
if self._row_set is None:
self._row_set = InMemoryAsyncRowSet()
if response is None:
self._row_set.append_empty_response()
else:
await self._row_set.append_response(response)

@abstractmethod
async def is_engine_running(self, engine_url: str) -> bool:
"""Verify that the engine is running."""
...
_performance_log_message = (
"[PERFORMANCE] Parsing query output into native Python types"
)

@wraps(BaseCursor.fetchone)
@check_not_closed
@async_not_allowed
@check_query_executed
async def fetchone(self) -> Optional[List[ColType]]:
"""Fetch the next row of a query result set."""
return super().fetchone()
assert self._row_set is not None
with Timer(self._performance_log_message):
# anext() is only supported in Python 3.10+
# this means we cannot just do return anext(self._row_set),
# we need to handle iteration manually
try:
return await self._row_set.__anext__()
except StopAsyncIteration:
return None

@wraps(BaseCursor.fetchmany)
@check_not_closed
@async_not_allowed
@check_query_executed
async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]:
"""
Fetch the next set of rows of a query result;
size is cursor.arraysize by default.
cursor.arraysize is default size.
"""
return super().fetchmany(size)
assert self._row_set is not None
size = size if size is not None else self.arraysize
with Timer(self._performance_log_message):
return await async_islice(self._row_set, size)

@wraps(BaseCursor.fetchall)
@check_not_closed
@async_not_allowed
@check_query_executed
async def fetchall(self) -> List[List[ColType]]:
"""Fetch all remaining rows of a query result."""
return super().fetchall()
assert self._row_set is not None
with Timer(self._performance_log_message):
return [it async for it in self._row_set]

@wraps(BaseCursor.nextset)
async def nextset(self) -> None:
return super().nextset()

async def aclose(self) -> None:
super().close()
if self._row_set is not None:
await self._row_set.aclose()

@abstractmethod
async def is_db_available(self, database: str) -> bool:
"""Verify that the database exists."""
...

@abstractmethod
async def is_engine_running(self, engine_url: str) -> bool:
"""Verify that the engine is running."""
...

# Iteration support
@check_not_closed
@async_not_allowed
@check_query_executed
def __aiter__(self) -> Cursor:
return self

# TODO: figure out how to implement __aenter__ and __await__
@check_not_closed
def __aenter__(self) -> Cursor:
return self

@check_not_closed
def __enter__(self) -> Cursor:
async def __aenter__(self) -> Cursor:
return self

def __exit__(
self, exc_type: type, exc_val: Exception, exc_tb: TracebackType
) -> None:
self.close()

def __await__(self) -> Iterator:
yield None

async def __aexit__(
self, exc_type: type, exc_val: Exception, exc_tb: TracebackType
) -> None:
self.close()
await self.aclose()

@check_not_closed
@async_not_allowed
@check_query_executed
async def __anext__(self) -> List[ColType]:
row = await self.fetchone()
if row is None:
raise StopAsyncIteration
return row
assert self._row_set is not None
return await self._row_set.__anext__()

@check_not_closed
def __enter__(self) -> Cursor:
warnings.warn(
"Using __enter__ is deprecated, use 'async with' instead",
DeprecationWarning,
)
return self

def __exit__(
self, exc_type: type, exc_val: Exception, exc_tb: TracebackType
) -> None:
return None


class CursorV2(Cursor):
Expand Down
15 changes: 1 addition & 14 deletions src/firebolt/common/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,6 @@ def Binary(value: str) -> bytes: # NOSONAR
DATETIME = datetime
ROWID = int

Column = namedtuple(
"Column",
(
"name",
"type_code",
"display_size",
"internal_size",
"precision",
"scale",
"null_ok",
),
)


class ExtendedType:
"""Base type for all extended types in Firebolt (array, decimal, struct, etc.)."""
Expand Down Expand Up @@ -338,7 +325,7 @@ def parse_value(
raise DataError(f"Invalid bytea value {value}: str expected")
return _parse_bytea(value)
if isinstance(ctype, DECIMAL):
if not isinstance(value, (str, int)):
if not isinstance(value, (str, int, float)):
raise DataError(f"Invalid decimal value {value}: str or int expected")
return Decimal(value)
if isinstance(ctype, ARRAY):
Expand Down
Loading