Skip to content
41 changes: 28 additions & 13 deletions src/firebolt/async_db/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from collections import namedtuple
from datetime import date, datetime, timezone
from enum import Enum
from typing import Sequence, Union
from typing import List, Sequence, Union

from sqlparse import parse as parse_sql # type: ignore
from sqlparse.sql import Token, TokenList # type: ignore
from sqlparse.sql import Statement, Token, TokenList # type: ignore
from sqlparse.tokens import Token as TokenType # type: ignore

try:
Expand Down Expand Up @@ -224,10 +224,9 @@ def format_value(value: ParameterType) -> str:
raise DataError(f"unsupported parameter type {type(value)}")


def format_sql(query: str, parameters: Sequence[ParameterType]) -> str:
def format_statement(statement: Statement, parameters: Sequence[ParameterType]) -> str:
"""
Substitute placeholders in queries with provided values.
'?' symbol is used as a placeholder. Using '\\?' would result in a plain '?'
Substitute placeholders in a sqlparse statement with provided values.
"""
idx = 0

Expand All @@ -245,16 +244,11 @@ def process_token(token: Token) -> Token:
return Token(TokenType.Text, formatted)
if isinstance(token, TokenList):
# Process all children tokens
token.tokens = [process_token(t) for t in token.tokens]
return token

parsed = parse_sql(query)
if not parsed:
return query
if len(parsed) > 1:
raise NotSupportedError("Multi-statement queries are not supported")
return TokenList([process_token(t) for t in token.tokens])
return token

formatted_sql = str(process_token(parsed[0]))
formatted_sql = str(process_token(statement)).rstrip(";")

if idx < len(parameters):
raise DataError(
Expand All @@ -263,3 +257,24 @@ def process_token(token: Token) -> Token:
)

return formatted_sql


def split_format_sql(
    query: str, parameters: Sequence[Sequence[ParameterType]]
) -> List[str]:
    """
    Split a query into separate statements, formatting it with parameters
    when it consists of a single statement.

    Without parameters, each parsed statement is returned with surrounding
    whitespace and trailing semicolons removed. With parameters, the single
    statement is formatted once per parameter set. A query that parses to
    no statements is returned unchanged as a one-element list.

    Trying to format a multi-statement query results in NotSupportedError.
    """
    parsed = parse_sql(query)

    # Nothing was parsed: hand the original text back untouched
    if not parsed:
        return [query]

    # No parameter sets: plain split, one output entry per statement
    if not parameters:
        return [str(stmt).strip().rstrip(";") for stmt in parsed]

    if len(parsed) > 1:
        raise NotSupportedError("formatting multistatement queries is not supported")

    only_statement = parsed[0]
    return [format_statement(only_statement, paramset) for paramset in parameters]
165 changes: 103 additions & 62 deletions src/firebolt/async_db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from enum import Enum
from functools import wraps
from inspect import cleandoc
from json import JSONDecodeError
from types import TracebackType
from typing import (
TYPE_CHECKING,
Expand All @@ -27,9 +26,9 @@
Column,
ParameterType,
RawColType,
format_sql,
parse_type,
parse_value,
split_format_sql,
)
from firebolt.async_db.util import is_db_available, is_engine_running
from firebolt.client import AsyncClient
Expand All @@ -38,7 +37,6 @@
DataError,
EngineNotRunningError,
FireboltDatabaseError,
NotSupportedError,
OperationalError,
ProgrammingError,
QueryNotRunError,
Expand All @@ -55,6 +53,7 @@

class CursorState(Enum):
NONE = 1
ERROR = 2
DONE = 3
CLOSED = 4

Expand Down Expand Up @@ -99,6 +98,8 @@ class BaseCursor:
"_rows",
"_idx",
"_idx_lock",
"_row_sets",
"_next_set_idx",
)

default_arraysize = 1
Expand All @@ -107,8 +108,15 @@ def __init__(self, client: AsyncClient, connection: Connection):
self.connection = connection
self._client = client
self._arraysize = self.default_arraysize
# These fields initialized here for type annotations purpose
self._rows: Optional[List[List[RawColType]]] = None
self._descriptions: Optional[List[Column]] = None
self._row_sets: List[
Tuple[int, Optional[List[Column]], Optional[List[List[RawColType]]]]
] = []
self._rowcount = -1
self._idx = 0
self._next_set_idx = 0
self._reset()

def __del__(self) -> None:
Expand Down Expand Up @@ -164,24 +172,58 @@ def close(self) -> None:
# remove typecheck skip after connection is implemented
self.connection._remove_cursor(self) # type: ignore

def _store_query_data(self, response: Response) -> None:
def _append_query_data(self, response: Response) -> None:
"""Store information about executed query from httpx response."""

row_set: Tuple[
int, Optional[List[Column]], Optional[List[List[RawColType]]]
] = (-1, None, None)

# Empty response is returned for insert query
if response.headers.get("content-length", "") == "0":
return
try:
query_data = response.json()
self._rowcount = int(query_data["rows"])
self._descriptions = [
Column(d["name"], parse_type(d["type"]), None, None, None, None, None)
for d in query_data["meta"]
]

# Parse data during fetch
self._rows = query_data["data"]
except (KeyError, JSONDecodeError) as err:
raise DataError(f"Invalid query data format: {str(err)}")
if response.headers.get("content-length", "") != "0":
try:
query_data = response.json()
rowcount = int(query_data["rows"])
descriptions = [
Column(
d["name"], parse_type(d["type"]), None, None, None, None, None
)
for d in query_data["meta"]
]

# Parse data during fetch
rows = query_data["data"]
row_set = (rowcount, descriptions, rows)
except (KeyError, ValueError) as err:
raise DataError(f"Invalid query data format: {str(err)}")

self._row_sets.append(row_set)
if self._next_set_idx == 0:
# Populate values for first set
self._pop_next_set()

@check_not_closed
@check_query_executed
def nextset(self) -> Optional[bool]:
    """
    Advance to the next available result set; any rows left unread
    in the current set are discarded.
    Returns True if the operation was successful,
    None if there are no more sets to retrieve.
    """
    advanced = self._pop_next_set()
    return advanced

def _pop_next_set(self) -> Optional[bool]:
    """
    Make the next stored result set the current one.

    Same functionality as .nextset, but doesn't check that query has been executed.

    Returns True if a set was loaded, None if all stored sets have already
    been consumed (in which case no cursor field is modified).
    """
    if self._next_set_idx >= len(self._row_sets):
        return None

    self._rowcount, self._descriptions, self._rows = self._row_sets[
        self._next_set_idx
    ]
    # Start the freshly loaded set from its beginning, the same value
    # __init__ and _reset use. Otherwise a position left over from the
    # previous set would carry into this one.
    # NOTE(review): assumes _idx is the read position within _rows - confirm
    self._idx = 0
    self._next_set_idx += 1
    return True

async def _raise_if_error(self, resp: Response) -> None:
"""Raise a proper error if any"""
Expand Down Expand Up @@ -213,29 +255,52 @@ def _reset(self) -> None:
self._descriptions = None
self._rowcount = -1
self._idx = 0
self._row_sets = []
self._next_set_idx = 0

async def _do_execute_request(
self,
query: str,
parameters: Optional[Sequence[ParameterType]] = None,
parameters: Sequence[Sequence[ParameterType]],
set_parameters: Optional[Dict] = None,
) -> Response:
if parameters:
query = format_sql(query, parameters)

resp = await self._client.request(
url="/",
method="POST",
params={
"database": self.connection.database,
"output_format": JSON_OUTPUT_FORMAT,
**(set_parameters or dict()),
},
content=query,
)
) -> None:
self._reset()
try:

queries = split_format_sql(query, parameters)

for query in queries:

start_time = time.time()
# our CREATE EXTERNAL TABLE queries currently require credentials,
# so we will skip logging those queries.
# https://docs.firebolt.io/sql-reference/commands/ddl-commands#create-external-table
if not re.search("aws_key_id|credentials", query, flags=re.IGNORECASE):
logger.debug(f"Running query: {query}")

resp = await self._client.request(
url="/",
method="POST",
params={
"database": self.connection.database,
"output_format": JSON_OUTPUT_FORMAT,
**(set_parameters or dict()),
},
content=query,
)

await self._raise_if_error(resp)
self._append_query_data(resp)
logger.info(
f"Query fetched {self.rowcount} rows in"
f" {time.time() - start_time} seconds"
)

self._state = CursorState.DONE

await self._raise_if_error(resp)
return resp
except Exception:
self._state = CursorState.ERROR
raise

@check_not_closed
async def execute(
Expand All @@ -245,21 +310,9 @@ async def execute(
set_parameters: Optional[Dict] = None,
) -> int:
"""Prepare and execute a database query. Return row count."""
start_time = time.time()

# our CREATE EXTERNAL TABLE queries currently require credentials,
# so we will skip logging those queries.
# https://docs.firebolt.io/sql-reference/commands/ddl-commands#create-external-table
if not re.search("aws_key_id|credentials", query, flags=re.IGNORECASE):
logger.debug(f"Running query: {query}")

self._reset()
resp = await self._do_execute_request(query, parameters, set_parameters)
self._store_query_data(resp)
self._state = CursorState.DONE
logger.info(
f"Query fetched {self.rowcount} rows in {time.time() - start_time} seconds"
)
params_list = [parameters] if parameters else []
await self._do_execute_request(query, params_list, set_parameters)
return self.rowcount

@check_not_closed
Expand All @@ -270,19 +323,7 @@ async def executemany(
Prepare and execute a database query against all parameter
sequences provided. Return last query row count.
"""

if len(parameters_seq) > 1:
raise NotSupportedError(
"Parameterized multi-statement queries are not supported"
)

self._reset()
resp = None
for parameters in parameters_seq:
resp = await self._do_execute_request(query, parameters)
if resp is not None:
self._store_query_data(resp)
self._state = CursorState.DONE
await self._do_execute_request(query, parameters_seq)
return self.rowcount

def _parse_row(self, row: List[RawColType]) -> List[ColType]:
Expand Down
41 changes: 41 additions & 0 deletions tests/integration/dbapi/async/test_queries_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,44 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None:
[params + ["?"]],
"Invalid data in table after parameterized insert",
)


@mark.asyncio
async def test_multi_statement_query(connection: Connection) -> None:
    """Multi-statement query results are exposed as consecutive result sets"""

    with connection.cursor() as c:
        await c.execute("DROP TABLE IF EXISTS test_tb_multi_statement")
        await c.execute(
            "CREATE FACT TABLE test_tb_multi_statement(i int, s string) primary index i"
        )

        # The first statement is an insert, so the cursor initially exposes
        # an empty result set: row count -1 and no description
        assert (
            await c.execute(
                "INSERT INTO test_tb_multi_statement values (1, 'a'), (2, 'b');"
                "SELECT * FROM test_tb_multi_statement"
            )
            == -1
        ), "Invalid row count returned for insert"
        assert c.rowcount == -1, "Invalid row count"
        assert c.description is None, "Invalid description"

        # Move on to the result set of the select statement
        assert c.nextset()

        assert c.rowcount == 2, "Invalid select row count"
        assert_deep_eq(
            c.description,
            [
                Column("i", int, None, None, None, None, None),
                Column("s", str, None, None, None, None, None),
            ],
            "Invalid select query description",
        )

        assert_deep_eq(
            await c.fetchall(),
            [[1, "a"], [2, "b"]],
            "Invalid data returned by select in multi-statement query",
        )

        # Both result sets are consumed, nothing is left to switch to
        assert c.nextset() is None
Loading