Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions src/firebolt/async_db/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
from datetime import date, datetime, timezone
from decimal import Decimal
from enum import Enum
from typing import List, Sequence, Union
from typing import List, Optional, Sequence, Union

from sqlparse import parse as parse_sql # type: ignore
from sqlparse.sql import Statement, Token, TokenList # type: ignore
from sqlparse.sql import ( # type: ignore
Comparison,
Statement,
Token,
TokenList,
)
from sqlparse.tokens import Token as TokenType # type: ignore

try:
Expand All @@ -22,7 +27,11 @@ def parse_datetime(date_string: str) -> datetime: # type: ignore
return datetime.fromisoformat(date_string)


from firebolt.common.exception import DataError, NotSupportedError
from firebolt.common.exception import (
DataError,
InterfaceError,
NotSupportedError,
)
from firebolt.common.util import cached_property

_NoneType = type(None)
Expand Down Expand Up @@ -312,7 +321,7 @@ def process_token(token: Token) -> Token:
return TokenList([process_token(t) for t in token.tokens])
return token

formatted_sql = str(process_token(statement)).rstrip(";")
formatted_sql = statement_to_sql(process_token(statement))

if idx < len(parameters):
raise DataError(
Expand All @@ -323,9 +332,43 @@ def process_token(token: Token) -> Token:
return formatted_sql


# Result of parsing a `SET <param> = <value>` command: `name` holds the text on
# the left-hand side of the `=`, `value` the text on the right-hand side.
SetParameter = namedtuple("SetParameter", ["name", "value"])


def statement_to_set(statement: Statement) -> Optional[SetParameter]:
    """
    Try to parse a statement as a SET command.

    Args:
        statement: A single parsed sqlparse statement.

    Returns:
        A SetParameter holding the text on each side of the ``=``, or None
        if the statement is not a SET command.

    Raises:
        InterfaceError: The statement starts with SET but does not have the
            form ``SET <param> = <value>``.
    """
    # A SET command must *start* with the SET keyword. Searching for SET among
    # the plain keywords only is not enough: leading tokens such as UPDATE
    # (Keyword.DML) or ALTER (Keyword.DDL) are not plain keywords, so
    # "UPDATE t SET a = 1" would otherwise be mistaken for a SET command.
    first = statement.token_first(skip_ws=True, skip_cm=True)
    if (
        first is None
        or first.ttype != TokenType.Keyword
        or first.value.lower() != "set"
    ):
        return None

    # Filter out meaningless tokens like Punctuation and Whitespaces
    tokens = [
        token
        for token in statement.tokens
        if token.ttype == TokenType.Keyword or isinstance(token, Comparison)
    ]

    # Check if set statement has a valid format: SET keyword + one comparison
    if len(tokens) != 2 or not isinstance(tokens[1], Comparison):
        raise InterfaceError(
            f"Invalid set statement format: {statement_to_sql(statement)},"
            " expected SET <param> = <value>"
        )

    comparison = tokens[1]
    return SetParameter(
        statement_to_sql(comparison.left), statement_to_sql(comparison.right)
    )


def statement_to_sql(statement: Statement) -> str:
    """
    Render a parsed statement (or any object with a string form) as plain SQL.

    Surrounding whitespace and trailing semicolons are removed, together with
    any whitespace that was left in front of the removed semicolons.

    Args:
        statement: Object whose ``str()`` form is the SQL text.

    Returns:
        The cleaned-up SQL string.
    """
    # The final rstrip() handles inputs like "select 1 ;", which would
    # otherwise keep a dangling space once the semicolon is removed.
    return str(statement).strip().rstrip(";").rstrip()


def split_format_sql(
query: str, parameters: Sequence[Sequence[ParameterType]]
) -> List[str]:
) -> List[Union[str, SetParameter]]:
"""
Split a query into separate statement, and format it with parameters
if it's a single statement
Expand All @@ -340,5 +383,9 @@ def split_format_sql(
raise NotSupportedError(
"formatting multistatement queries is not supported"
)
if statement_to_set(statements[0]):
raise NotSupportedError("formatting set statements is not supported")
return [format_statement(statements[0], paramset) for paramset in parameters]
return [str(st).strip().rstrip(";") for st in statements]

# Try parsing each statement as a SET, otherwise return as a plain sql string
return [statement_to_set(st) or statement_to_sql(st) for st in statements]
139 changes: 90 additions & 49 deletions src/firebolt/async_db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Column,
ParameterType,
RawColType,
SetParameter,
parse_type,
parse_value,
split_format_sql,
Expand Down Expand Up @@ -100,6 +101,7 @@ class BaseCursor:
"_idx_lock",
"_row_sets",
"_next_set_idx",
"_set_parameters",
)

default_arraysize = 1
Expand All @@ -114,6 +116,7 @@ def __init__(self, client: AsyncClient, connection: Connection):
self._row_sets: List[
Tuple[int, Optional[List[Column]], Optional[List[List[RawColType]]]]
] = []
self._set_parameters: Dict[str, Any] = dict()
self._rowcount = -1
self._idx = 0
self._next_set_idx = 0
Expand Down Expand Up @@ -172,37 +175,6 @@ def close(self) -> None:
# remove typecheck skip after connection is implemented
self.connection._remove_cursor(self) # type: ignore

def _append_query_data(self, response: Response) -> None:
"""Store information about executed query from httpx response."""

row_set: Tuple[
int, Optional[List[Column]], Optional[List[List[RawColType]]]
] = (-1, None, None)

# Empty response is returned for insert query
if response.headers.get("content-length", "") != "0":
try:
# Skip parsing floats to properly parse them later
query_data = response.json(parse_float=str)
rowcount = int(query_data["rows"])
descriptions = [
Column(
d["name"], parse_type(d["type"]), None, None, None, None, None
)
for d in query_data["meta"]
]

# Parse data during fetch
rows = query_data["data"]
row_set = (rowcount, descriptions, rows)
except (KeyError, ValueError) as err:
raise DataError(f"Invalid query data format: {str(err)}")

self._row_sets.append(row_set)
if self._next_set_idx == 0:
# Populate values for first set
self._pop_next_set()

@check_not_closed
@check_query_executed
def nextset(self) -> Optional[bool]:
Expand All @@ -227,6 +199,9 @@ def _pop_next_set(self) -> Optional[bool]:
self._next_set_idx += 1
return True

def flush_parameters(self) -> None:
    """Forget all stored set parameters by rebinding them to a new empty dict."""
    self._set_parameters = {}

async def _raise_if_error(self, resp: Response) -> None:
"""Raise a proper error if any"""
if resp.status_code == codes.INTERNAL_SERVER_ERROR:
Expand Down Expand Up @@ -260,39 +235,105 @@ def _reset(self) -> None:
self._row_sets = []
self._next_set_idx = 0

async def _do_execute_request(
def _row_set_from_response(
self, response: Response
) -> Tuple[int, Optional[List[Column]], Optional[List[List[RawColType]]]]:
"""Fetch information about executed query from http response"""

# Empty response is returned for insert query
if response.headers.get("content-length", "") == "0":
return (-1, None, None)

try:
# Skip parsing floats to properly parse them later
query_data = response.json(parse_float=str)
rowcount = int(query_data["rows"])
descriptions = [
Column(d["name"], parse_type(d["type"]), None, None, None, None, None)
for d in query_data["meta"]
]

# Parse data during fetch
rows = query_data["data"]
return (rowcount, descriptions, rows)
except (KeyError, ValueError) as err:
raise DataError(f"Invalid query data format: {str(err)}")

def _append_row_set(
self,
query: str,
row_set: Tuple[int, Optional[List[Column]], Optional[List[List[RawColType]]]],
) -> None:
"""Store information about executed query."""
self._row_sets.append(row_set)
if self._next_set_idx == 0:
# Populate values for first set
self._pop_next_set()

async def _api_request(
    self, query: str, set_parameters: Optional[dict]
) -> Response:
    """
    POST a query to the root url and return the response.

    The request parameters are the connection database and the output
    format, then the set parameters stored on the cursor, then the ones
    passed in ``set_parameters``. Later entries override earlier ones.

    Args:
        query: Text sent as the request content.
        set_parameters: Extra request parameters for this call, or None.
    """
    request_params = {
        "database": self.connection.database,
        "output_format": JSON_OUTPUT_FORMAT,
    }
    # Parameters stored on the cursor first, then per-call overrides
    request_params.update(self._set_parameters)
    request_params.update(set_parameters or dict())

    response = await self._client.request(
        url="/",
        method="POST",
        params=request_params,
        content=query,
    )
    return response

async def _do_execute(
self,
raw_query: str,
parameters: Sequence[Sequence[ParameterType]],
set_parameters: Optional[Dict] = None,
) -> None:
self._reset()
if set_parameters is not None:
logger.warning(
"Passing set parameters as an argument is deprecated. Please run "
"a query 'SET <param> = <value>'"
)
try:

queries = split_format_sql(query, parameters)
queries = split_format_sql(raw_query, parameters)

for query in queries:

start_time = time.time()
# our CREATE EXTERNAL TABLE queries currently require credentials,
# so we will skip logging those queries.
# https://docs.firebolt.io/sql-reference/commands/ddl-commands#create-external-table
if not re.search("aws_key_id|credentials", query, flags=re.IGNORECASE):
if isinstance(query, SetParameter) or not re.search(
"aws_key_id|credentials", query, flags=re.IGNORECASE
):
logger.debug(f"Running query: {query}")

resp = await self._client.request(
url="/",
method="POST",
params={
"database": self.connection.database,
"output_format": JSON_OUTPUT_FORMAT,
**(set_parameters or dict()),
},
content=query,
)
# Define type for mypy
row_set: Tuple[
int, Optional[List[Column]], Optional[List[List[RawColType]]]
] = (-1, None, None)
if isinstance(query, SetParameter):
# Validate parameter by executing simple query with it
resp = await self._api_request(
"select 1", {query.name: query.value}
)
# Handle invalid set parameter
if resp.status_code == codes.BAD_REQUEST:
raise OperationalError(resp.text)
await self._raise_if_error(resp)

# set parameter passed validation
self._set_parameters[query.name] = query.value
else:
resp = await self._api_request(query, set_parameters)
await self._raise_if_error(resp)
row_set = self._row_set_from_response(resp)

self._append_row_set(row_set)

await self._raise_if_error(resp)
self._append_query_data(resp)
logger.info(
f"Query fetched {self.rowcount} rows in"
f" {time.time() - start_time} seconds"
Expand All @@ -314,7 +355,7 @@ async def execute(
"""Prepare and execute a database query. Return row count."""

params_list = [parameters] if parameters else []
await self._do_execute_request(query, params_list, set_parameters)
await self._do_execute(query, params_list, set_parameters)
return self.rowcount

@check_not_closed
Expand All @@ -325,7 +366,7 @@ async def executemany(
Prepare and execute a database query against all parameter
sequences provided. Return last query row count.
"""
await self._do_execute_request(query, parameters_seq)
await self._do_execute(query, parameters_seq)
return self.rowcount

def _parse_row(self, row: List[RawColType]) -> List[ColType]:
Expand Down
8 changes: 0 additions & 8 deletions tests/conftest.py

This file was deleted.

Loading