From 4d3e7946c3c04f68e8a698441a9c93c6e51ba0d4 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 19 Jan 2022 11:55:46 +0200 Subject: [PATCH 01/12] implement split_format_sql --- src/firebolt/async_db/_types.py | 33 ++++++++++++++++------- tests/unit/async_db/test_typing_format.py | 33 +++++++++++++---------- 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/src/firebolt/async_db/_types.py b/src/firebolt/async_db/_types.py index 9ad06e2cd48..515e50e5f4c 100644 --- a/src/firebolt/async_db/_types.py +++ b/src/firebolt/async_db/_types.py @@ -3,10 +3,11 @@ from collections import namedtuple from datetime import date, datetime, timezone from enum import Enum -from typing import Sequence, Union +from typing import List, Sequence, Union from sqlparse import parse as parse_sql # type: ignore -from sqlparse.sql import Token, TokenList # type: ignore +from sqlparse import split +from sqlparse.sql import Statement, Token, TokenList # type: ignore from sqlparse.tokens import Token as TokenType # type: ignore try: @@ -224,7 +225,7 @@ def format_value(value: ParameterType) -> str: raise DataError(f"unsupported parameter type {type(value)}") -def format_sql(query: str, parameters: Sequence[ParameterType]) -> str: +def format_statement(statement: Statement, parameters: Sequence[ParameterType]) -> str: """ Substitute placeholders in queries with provided values. '?' symbol is used as a placeholder. Using '\\?' would result in a plain '?' @@ -248,13 +249,7 @@ def process_token(token: Token) -> Token: token.tokens = [process_token(t) for t in token.tokens] return token - parsed = parse_sql(query) - if not parsed: - return query - if len(parsed) > 1: - raise NotSupportedError("Multi-statement queries are not supported") - - formatted_sql = str(process_token(parsed[0])) + formatted_sql = str(process_token(statement)) if idx < len(parameters): raise DataError( @@ -263,3 +258,21 @@ def process_token(token: Token) -> Token: ) return formatted_sql + + +def format_sql(query: str, parameters: Sequence[ParameterType]) -> str: + return format_statement(parse_sql(query)[0], parameters) + + +def split_format_sql(query: str, parameters: Sequence[ParameterType]) -> List[str]: + statements = split(query) + if not statements: + return query + + if parameters: + if len(statements) > 1: + raise NotSupportedError( + "formatting multistatement queries is not supported" + ) + return [format_statement(statements[0])] + return [str(st) for st in statements] diff --git a/tests/unit/async_db/test_typing_format.py b/tests/unit/async_db/test_typing_format.py index bb03b79b29d..763b3c2993e 100644 --- a/tests/unit/async_db/test_typing_format.py +++ b/tests/unit/async_db/test_typing_format.py @@ -1,9 +1,11 @@ from datetime import date, datetime, timedelta, timezone from pytest import mark, raises +from sqlparse import parse +from sqlparse.sql import Statement from firebolt.async_db import DataError -from firebolt.async_db._types import format_sql, format_value +from firebolt.async_db._types import format_statement, format_value @mark.parametrize( @@ -45,52 +47,55 @@ def test_format_value_errors() -> None: assert str(exc_info.value) == "unsupported parameter type " +def to_statement(sql: str) -> Statement: + return parse(sql)[0] + + @mark.parametrize( - "sql,params,result", + "statement,params,result", [ - ("", (), ""), - ("select * from table", (), "select * from table"), + (to_statement("select * from table"), (), "select * from table"), ( - "select * from table where id == ?", + to_statement("select * from table where id == ?"), (1,), "select * from table where id == 1", ), ( - "select * from table where id == '?'", + to_statement("select * from table where id == '?'"), (), "select * from table where id == '?'", ), ( - "insert into table values (?, ?, '?')", + to_statement("insert into table values (?, ?, '?')"), (1, "1"), "insert into table values (1, '1', '?')", ), ( - "select * from t where /*comment ?*/ id == ?", + to_statement("select * from t where /*comment ?*/ id == ?"), ("*/ 1 == 1 or /*",), "select * from t where /*comment ?*/ id == '*/ 1 == 1 or /*'", ), ( - "select * from t where id == ?", + to_statement("select * from t where id == ?"), ("' or '' == '",), r"select * from t where id == '\' or \'\' == \''", ), ], ) -def test_format_sql(sql: str, params: tuple, result: str) -> None: - assert format_sql(sql, params) == result, "Invalid format sql result" +def test_format_statement(statement: Statement, params: tuple, result: str) -> None: + assert format_statement(statement, params) == result, "Invalid format sql result" -def test_format_sql_errors() -> None: +def test_format_statement_errors() -> None: with raises(DataError) as exc_info: - format_sql("?", []) + format_statement(to_statement("?"), []) assert ( str(exc_info.value) == "not enough parameters provided for substitution: given 0, found one more" ), "Invalid not enought parameters error" with raises(DataError) as exc_info: - format_sql("?", (1, 2)) + format_statement(to_statement("?"), (1, 2)) assert ( str(exc_info.value) == "too many parameters provided for substitution: given 2, used only 1" From d8e471529a2d12ca805514e40b9c404332dd6834 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 19 Jan 2022 13:00:24 +0200 Subject: [PATCH 02/12] add split_format_sql unit tests --- src/firebolt/async_db/_types.py | 11 +++--- tests/unit/async_db/test_typing_format.py | 46 ++++++++++++++++++++++- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/src/firebolt/async_db/_types.py b/src/firebolt/async_db/_types.py index 515e50e5f4c..1a645d2ea15 100644 --- a/src/firebolt/async_db/_types.py +++ b/src/firebolt/async_db/_types.py @@ -6,7 +6,6 @@ from typing import List, Sequence, Union from sqlparse import parse as parse_sql # type: ignore -from sqlparse import split from sqlparse.sql import Statement, Token, TokenList # type: ignore from sqlparse.tokens import Token as TokenType # type: ignore @@ -249,7 +248,7 @@ def process_token(token: Token) -> Token: token.tokens = [process_token(t) for t in token.tokens] return token - formatted_sql = str(process_token(statement)) + formatted_sql = str(process_token(statement)).rstrip(";") if idx < len(parameters): raise DataError( @@ -265,14 +264,14 @@ def format_sql(query: str, parameters: Sequence[ParameterType]) -> str: def split_format_sql(query: str, parameters: Sequence[ParameterType]) -> List[str]: - statements = split(query) + statements = parse_sql(query) if not statements: - return query + return [query] if parameters: if len(statements) > 1: raise NotSupportedError( "formatting multistatement queries is not supported" ) - return [format_statement(statements[0])] - return [str(st) for st in statements] + return [format_statement(statements[0], parameters)] + return [str(st).strip().rstrip(";") for st in statements] diff --git a/tests/unit/async_db/test_typing_format.py b/tests/unit/async_db/test_typing_format.py index 763b3c2993e..0af568ab28c 100644 --- a/tests/unit/async_db/test_typing_format.py +++ b/tests/unit/async_db/test_typing_format.py @@ -1,11 +1,16 @@ from datetime import date, datetime, timedelta, timezone +from typing import List from pytest import mark, raises from sqlparse import parse from sqlparse.sql import Statement -from firebolt.async_db import DataError -from firebolt.async_db._types import format_statement, format_value +from firebolt.async_db import DataError, NotSupportedError +from firebolt.async_db._types import ( + format_statement, + format_value, + split_format_sql, +) @mark.parametrize( @@ -100,3 +105,40 @@ def test_format_statement_errors() -> None: str(exc_info.value) == "too many parameters provided for substitution: given 2, used only 1" ), "Invalid not enought parameters error" + + +@mark.parametrize( + "query,params,result", + [ + ("", (), [""]), + ("select * from t", (), ["select * from t"]), + ("select * from t;", (), ["select * from t"]), + ("select * from t where id == ?", (1,), ["select * from t where id == 1"]), + ("select * from t where id == ?;", (1,), ["select * from t where id == 1"]), + ( + "select * from t;insert into t values (1, 2)", + (), + ["select * from t", "insert into t values (1, 2)"], + ), + ( + "insert into t values (1, 2);select * from t;", + (), + ["insert into t values (1, 2)", "select * from t"], + ), + ], +) +def test_split_format_sql(query: str, params: tuple, result: List[str]) -> None: + assert ( + split_format_sql(query, params) == result + ), "Invalid split and format sql result" + + +def test_split_format_error() -> None: + with raises(NotSupportedError) as exc_info: + split_format_sql( + "select * from t where id == ?; insert into t values (?, ?)", (1, 2, 3) + ) + + assert ( + str(exc_info.value) == "formatting multistatement queries is not supported" + ), "Invalid not supported error message" From 3e20bae315598b00bf47b56028c45f50b07e5b8b Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 19 Jan 2022 16:52:16 +0200 Subject: [PATCH 03/12] add cursor error state --- src/firebolt/async_db/cursor.py | 34 ++++++++++++++++++------------ tests/unit/async_db/test_cursor.py | 15 +++++++++++++ tests/unit/db/test_cursor.py | 13 ++++++++++++ 3 files changed, 48 insertions(+), 14 deletions(-) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 2418bd4d4a2..236703b4c42 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -55,6 +55,7 @@ class CursorState(Enum): NONE = 1 + ERROR = 2 DONE = 3 CLOSED = 4 @@ -220,21 +221,26 @@ async def _do_execute_request( parameters: Optional[Sequence[ParameterType]] = None, set_parameters: Optional[Dict] = None, ) -> Response: - if parameters: - query = format_sql(query, parameters) - - resp = await self._client.request( - url="/", - method="POST", - params={ - "database": self.connection.database, - "output_format": JSON_OUTPUT_FORMAT, - **(set_parameters or dict()), - }, - content=query, - ) + try: + if parameters: + query = format_sql(query, parameters) + + resp = await self._client.request( + url="/", + method="POST", + params={ + "database": self.connection.database, + "output_format": JSON_OUTPUT_FORMAT, + **(set_parameters or dict()), + }, + content=query, + ) + + await self._raise_if_error(resp) + except Exception: + self._state = CursorState.ERROR + raise - await self._raise_if_error(resp) return resp @check_not_closed diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index 5dda697db74..31799685681 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -37,6 +37,16 @@ async def test_cursor_state( await cursor.execute("select") assert cursor._state == CursorState.DONE + def error_query_callback(*args, **kwargs): + raise Exception() + + httpx_mock.add_callback(error_query_callback, url=query_url) + + cursor._reset() + with raises(Exception): + await cursor.execute("select") + assert cursor._state == CursorState.ERROR + cursor._reset() assert cursor._state == CursorState.NONE @@ -204,6 +214,7 @@ def http_error(**kwargs): with raises(StreamError) as excinfo: await query() + assert cursor._state == CursorState.ERROR assert str(excinfo.value) == "httpx error", "Invalid query error message" # HTTP error @@ -212,6 +223,7 @@ def http_error(**kwargs): await query() errmsg = str(excinfo.value) + assert cursor._state == CursorState.ERROR assert "Bad Request" in errmsg, "Invalid query error message" # Database query error @@ -223,6 +235,7 @@ def http_error(**kwargs): with raises(OperationalError) as excinfo: await query() + assert cursor._state == CursorState.ERROR assert ( str(excinfo.value) == "Error executing query:\nQuery error message" ), "Invalid authentication error message" @@ -239,6 +252,7 @@ def http_error(**kwargs): ) with raises(FireboltDatabaseError) as excinfo: await query() + assert cursor._state == CursorState.ERROR # Engine is not running error httpx_mock.add_response( @@ -255,6 +269,7 @@ def http_error(**kwargs): ) with raises(EngineNotRunningError) as excinfo: await query() + assert cursor._state == CursorState.ERROR httpx_mock.reset(True) diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index 13da4c989e4..67db8391b90 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -33,6 +33,16 @@ def test_cursor_state( cursor.execute("select") assert cursor._state == CursorState.DONE + def error_query_callback(*args, **kwargs): + raise Exception() + + httpx_mock.add_callback(error_query_callback, url=query_url) + + cursor._reset() + with raises(Exception): + cursor.execute("select") + assert cursor._state == CursorState.ERROR + cursor._reset() assert cursor._state == CursorState.NONE @@ -186,6 +196,7 @@ def http_error(**kwargs): with raises(StreamError) as excinfo: query() + assert cursor._state == CursorState.ERROR assert str(excinfo.value) == "httpx error", "Invalid query error message" # HTTP error @@ -194,6 +205,7 @@ def http_error(**kwargs): query() errmsg = str(excinfo.value) + assert cursor._state == CursorState.ERROR assert "Bad Request" in errmsg, "Invalid query error message" # Database query error @@ -205,6 +217,7 @@ def http_error(**kwargs): with raises(OperationalError) as excinfo: query() + assert cursor._state == CursorState.ERROR assert ( str(excinfo.value) == "Error executing query:\nQuery error message" ), "Invalid authentication error message" From a00580c525a80fe05468721e39983651e61948b7 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 19 Jan 2022 17:21:33 +0200 Subject: [PATCH 04/12] enable multi-statement queries --- src/firebolt/async_db/_types.py | 9 +- src/firebolt/async_db/cursor.py | 115 ++++++++++++---------- tests/unit/async_db/test_typing_format.py | 11 ++- 3 files changed, 78 insertions(+), 57 deletions(-) diff --git a/src/firebolt/async_db/_types.py b/src/firebolt/async_db/_types.py index 1a645d2ea15..22fb17d297f 100644 --- a/src/firebolt/async_db/_types.py +++ b/src/firebolt/async_db/_types.py @@ -245,7 +245,8 @@ def process_token(token: Token) -> Token: return Token(TokenType.Text, formatted) if isinstance(token, TokenList): # Process all children tokens - token.tokens = [process_token(t) for t in token.tokens] + + return TokenList([process_token(t) for t in token.tokens]) return token formatted_sql = str(process_token(statement)).rstrip(";") @@ -263,7 +264,9 @@ def format_sql(query: str, parameters: Sequence[ParameterType]) -> str: return format_statement(parse_sql(query)[0], parameters) -def split_format_sql(query: str, parameters: Sequence[ParameterType]) -> List[str]: +def split_format_sql( + query: str, parameters: Sequence(Sequence[ParameterType]) +) -> List[str]: statements = parse_sql(query) if not statements: return [query] @@ -273,5 +276,5 @@ def split_format_sql(query: str, parameters: Sequence[ParameterType]) -> List[st raise NotSupportedError( "formatting multistatement queries is not supported" ) - return [format_statement(statements[0], parameters)] + return [format_statement(statements[0], paramset) for paramset in parameters] return [str(st).strip().rstrip(";") for st in statements] diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 236703b4c42..f82fed1f27b 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -27,9 +27,9 @@ Column, ParameterType, RawColType, - format_sql, parse_type, parse_value, + split_format_sql, ) from firebolt.async_db.util import is_db_available, is_engine_running from firebolt.client import AsyncClient @@ -38,7 +38,6 @@ DataError, EngineNotRunningError, FireboltDatabaseError, - NotSupportedError, OperationalError, ProgrammingError, QueryNotRunError, @@ -100,6 +99,8 @@ class BaseCursor: "_rows", "_idx", "_idx_lock", + "_row_sets", + "_next_set_idx", ) default_arraysize = 1 @@ -110,6 +111,9 @@ def __init__(self, client: AsyncClient, connection: Connection): self._arraysize = self.default_arraysize self._rows: Optional[List[List[RawColType]]] = None self._descriptions: Optional[List[Column]] = None + self._row_sets: List[ + Tuple[int, Optional[List[Column]], Optional[List[List[RawColType]]]] + ] = [] self._reset() def __del__(self) -> None: @@ -165,25 +169,40 @@ def close(self) -> None: # remove typecheck skip after connection is implemented self.connection._remove_cursor(self) # type: ignore - def _store_query_data(self, response: Response) -> None: + def _append_query_data(self, response: Response) -> None: """Store information about executed query from httpx response.""" # Empty response is returned for insert query if response.headers.get("content-length", "") == "0": - return + row_set = (-1, None, None) try: query_data = response.json() - self._rowcount = int(query_data["rows"]) - self._descriptions = [ + rowcount = int(query_data["rows"]) + descriptions = [ Column(d["name"], parse_type(d["type"]), None, None, None, None, None) for d in query_data["meta"] ] # Parse data during fetch - self._rows = query_data["data"] + rows = query_data["data"] + row_set = (rowcount, descriptions, rows) except (KeyError, JSONDecodeError) as err: raise DataError(f"Invalid query data format: {str(err)}") + self._row_sets.append(row_set) + if self._next_set_idx == 0: + # Populate values for first set + self.nextset() + + def nextset(self) -> None: + if self._next_set_idx >= len(self._row_sets): + raise ProgrammingError("no more sets in cursor") + cur_set = self._row_sets[self._next_set_idx] + self._rowcount = cur_set[0] + self._descriptions = cur_set[1] + self._rows = cur_set[2] + self._next_set_idx += 1 + async def _raise_if_error(self, resp: Response) -> None: """Raise a proper error if any""" if resp.status_code == codes.INTERNAL_SERVER_ERROR: @@ -214,35 +233,53 @@ def _reset(self) -> None: self._descriptions = None self._rowcount = -1 self._idx = 0 + self._row_sets = [] + self._next_set_idx = 0 async def _do_execute_request( self, query: str, - parameters: Optional[Sequence[ParameterType]] = None, + parameters: Sequence[Sequence[ParameterType]] = None, set_parameters: Optional[Dict] = None, ) -> Response: + self._reset() try: - if parameters: - query = format_sql(query, parameters) - - resp = await self._client.request( - url="/", - method="POST", - params={ - "database": self.connection.database, - "output_format": JSON_OUTPUT_FORMAT, - **(set_parameters or dict()), - }, - content=query, - ) - await self._raise_if_error(resp) + queries = split_format_sql(query, parameters) + + for query in queries: + + start_time = time.time() + # our CREATE EXTERNAL TABLE queries currently require credentials, + # so we will skip logging those queries. + # https://docs.firebolt.io/sql-reference/commands/ddl-commands#create-external-table + if not re.search("aws_key_id|credentials", query, flags=re.IGNORECASE): + logger.debug(f"Running query: {query}") + + resp = await self._client.request( + url="/", + method="POST", + params={ + "database": self.connection.database, + "output_format": JSON_OUTPUT_FORMAT, + **(set_parameters or dict()), + }, + content=query, + ) + + await self._raise_if_error(resp) + self._append_query_data(resp) + logger.info( + f"Query fetched {self.rowcount} rows in" + f" {time.time() - start_time} seconds" + ) + + self._state = CursorState.DONE + except Exception: self._state = CursorState.ERROR raise - return resp - @check_not_closed async def execute( self, @@ -251,21 +288,9 @@ async def execute( set_parameters: Optional[Dict] = None, ) -> int: """Prepare and execute a database query. Return row count.""" - start_time = time.time() - - # our CREATE EXTERNAL TABLE queries currently require credentials, - # so we will skip logging those queries. - # https://docs.firebolt.io/sql-reference/commands/ddl-commands#create-external-table - if not re.search("aws_key_id|credentials", query, flags=re.IGNORECASE): - logger.debug(f"Running query: {query}") - self._reset() - resp = await self._do_execute_request(query, parameters, set_parameters) - self._store_query_data(resp) - self._state = CursorState.DONE - logger.info( - f"Query fetched {self.rowcount} rows in {time.time() - start_time} seconds" - ) + params_list = [parameters] if parameters else [] + await self._do_execute_request(query, params_list, set_parameters) return self.rowcount @check_not_closed @@ -276,19 +301,7 @@ async def executemany( Prepare and execute a database query against all parameter sequences provided. Return last query row count. """ - - if len(parameters_seq) > 1: - raise NotSupportedError( - "Parameterized multi-statement queries are not supported" - ) - - self._reset() - resp = None - for parameters in parameters_seq: - resp = await self._do_execute_request(query, parameters) - if resp is not None: - self._store_query_data(resp) - self._state = CursorState.DONE + await self._do_execute_request(query, parameters_seq) return self.rowcount def _parse_row(self, row: List[RawColType]) -> List[ColType]: diff --git a/tests/unit/async_db/test_typing_format.py b/tests/unit/async_db/test_typing_format.py index 0af568ab28c..b394cc7e3b4 100644 --- a/tests/unit/async_db/test_typing_format.py +++ b/tests/unit/async_db/test_typing_format.py @@ -113,8 +113,8 @@ def test_format_statement_errors() -> None: ("", (), [""]), ("select * from t", (), ["select * from t"]), ("select * from t;", (), ["select * from t"]), - ("select * from t where id == ?", (1,), ["select * from t where id == 1"]), - ("select * from t where id == ?;", (1,), ["select * from t where id == 1"]), + ("select * from t where id == ?", ((1,),), ["select * from t where id == 1"]), + ("select * from t where id == ?;", ((1,),), ["select * from t where id == 1"]), ( "select * from t;insert into t values (1, 2)", (), @@ -125,6 +125,11 @@ def test_format_statement_errors() -> None: (), ["insert into t values (1, 2)", "select * from t"], ), + ( + "select * from t where id == ?", + ((1,), (2,)), + ["select * from t where id == 1", "select * from t where id == 2"], + ), ], ) def test_split_format_sql(query: str, params: tuple, result: List[str]) -> None: @@ -136,7 +141,7 @@ def test_split_format_sql(query: str, params: tuple, result: List[str]) -> None: def test_split_format_error() -> None: with raises(NotSupportedError) as exc_info: split_format_sql( - "select * from t where id == ?; insert into t values (?, ?)", (1, 2, 3) + "select * from t where id == ?; insert into t values (?, ?)", ((1, 2, 3),) ) assert ( From cd81fde2966bf65ffc136d753a4ff4b141b6c2b9 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 20 Jan 2022 10:41:31 +0200 Subject: [PATCH 05/12] fix unit tests --- src/firebolt/async_db/cursor.py | 29 ++++++++++++++++------------- tests/unit/async_db/test_cursor.py | 9 +++------ tests/unit/db/test_cursor.py | 7 ++----- 3 files changed, 21 insertions(+), 24 deletions(-) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index f82fed1f27b..666ef8bcb0c 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -175,19 +175,22 @@ def _append_query_data(self, response: Response) -> None: # Empty response is returned for insert query if response.headers.get("content-length", "") == "0": row_set = (-1, None, None) - try: - query_data = response.json() - rowcount = int(query_data["rows"]) - descriptions = [ - Column(d["name"], parse_type(d["type"]), None, None, None, None, None) - for d in query_data["meta"] - ] - - # Parse data during fetch - rows = query_data["data"] - row_set = (rowcount, descriptions, rows) - except (KeyError, JSONDecodeError) as err: - raise DataError(f"Invalid query data format: {str(err)}") + else: + try: + query_data = response.json() + rowcount = int(query_data["rows"]) + descriptions = [ + Column( + d["name"], parse_type(d["type"]), None, None, None, None, None + ) + for d in query_data["meta"] + ] + + # Parse data during fetch + rows = query_data["data"] + row_set = (rowcount, descriptions, rows) + except (KeyError, JSONDecodeError) as err: + raise DataError(f"Invalid query data format: {str(err)}") self._row_sets.append(row_set) if self._next_set_idx == 0: diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index 31799685681..fb242e43723 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -13,7 +13,6 @@ DataError, EngineNotRunningError, FireboltDatabaseError, - NotSupportedError, OperationalError, QueryNotRunError, ) @@ -157,8 +156,8 @@ async def test_cursor_execute( """Cursor is able to execute query, all fields are populated properly.""" for query in ( - lambda: cursor.execute("select *"), - lambda: cursor.executemany("select *", [None]), + lambda: cursor.execute("select * from t"), + lambda: cursor.executemany("select * from t", []), ): # Query with json output httpx_mock.add_callback(auth_callback, url=auth_url) @@ -202,7 +201,7 @@ async def test_cursor_execute_error( """Cursor handles all types of errors properly.""" for query in ( lambda: cursor.execute("select *"), - lambda: cursor.executemany("select *", [None]), + lambda: cursor.executemany("select *", []), ): httpx_mock.add_callback(auth_callback, url=auth_url) @@ -427,5 +426,3 @@ async def test_set_parameters( @mark.asyncio async def test_cursor_multi_statement(cursor: Cursor): """executemany with multiple parameter sets is not supported""" - with raises(NotSupportedError): - await cursor.executemany("select ?", [(1,), (2,)]) diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index 67db8391b90..928b94c6bb9 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -9,7 +9,6 @@ from firebolt.common.exception import ( CursorClosedError, DataError, - NotSupportedError, OperationalError, QueryNotRunError, ) @@ -145,7 +144,7 @@ def test_cursor_execute( for query in ( lambda: cursor.execute("select *"), - lambda: cursor.executemany("select *", [None]), + lambda: cursor.executemany("select *", []), ): # Query with json output httpx_mock.add_callback(auth_callback, url=auth_url) @@ -184,7 +183,7 @@ def test_cursor_execute_error( """Cursor handles all types of errors properly.""" for query in ( lambda: cursor.execute("select *"), - lambda: cursor.executemany("select *", [None]), + lambda: cursor.executemany("select *", []), ): httpx_mock.add_callback(auth_callback, url=auth_url) @@ -372,5 +371,3 @@ def test_set_parameters( def test_cursor_multi_statement(cursor: Cursor): """executemany with multiple parameter sets is not supported""" - with raises(NotSupportedError): - cursor.executemany("select ?", [(1,), (2,)]) From 88025621fbbbd367256eb5a1523fea6aa29b1632 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 20 Jan 2022 11:04:21 +0200 Subject: [PATCH 06/12] add nextset unit tests --- tests/unit/async_db/test_cursor.py | 34 +++++++++++++++++++++++++++++- tests/unit/db/test_cursor.py | 34 +++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index fb242e43723..4cce20c99c1 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -424,5 +424,37 @@ async def test_set_parameters( @mark.asyncio -async def test_cursor_multi_statement(cursor: Cursor): +async def test_cursor_multi_statement( + httpx_mock: HTTPXMock, + auth_callback: Callable, + auth_url: str, + query_callback: Callable, + insert_query_callback: Callable, + query_url: str, + cursor: Cursor, + python_query_description: List[Column], + python_query_data: List[List[ColType]], +): """executemany with multiple parameter sets is not supported""" + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(query_callback, url=query_url) + httpx_mock.add_callback(insert_query_callback, url=query_url) + + rc = await cursor.execute("select * from t; insert into t values (1, 2)") + assert rc == len(python_query_data), "Invalid row count returned" + assert cursor.rowcount == len(python_query_data), "Invalid cursor row count" + for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)): + assert desc == exp, f"Invalid column description at position {i}" + + for i in range(cursor.rowcount): + assert ( + await cursor.fetchone() == python_query_data[i] + ), f"Invalid data row at position {i}" + + cursor.nextset() + assert cursor.rowcount == -1, "Invalid cursor row count" + assert cursor.description is None, "Invalid cursor description" + with raises(DataError) as exc_info: + await cursor.fetchall() + + assert str(exc_info.value) == "no rows to fetch", "Invalid error message" diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index 928b94c6bb9..4edc5986afc 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -369,5 +369,37 @@ def test_set_parameters( cursor.execute("select 1", set_parameters=set_params) -def test_cursor_multi_statement(cursor: Cursor): +def test_cursor_multi_statement( + httpx_mock: HTTPXMock, + auth_callback: Callable, + auth_url: str, + query_callback: Callable, + insert_query_callback: Callable, + query_url: str, + cursor: Cursor, + python_query_description: List[Column], + python_query_data: List[List[ColType]], +): """executemany with multiple parameter sets is not supported""" + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(query_callback, url=query_url) + httpx_mock.add_callback(insert_query_callback, url=query_url) + + rc = cursor.execute("select * from t; insert into t values (1, 2)") + assert rc == len(python_query_data), "Invalid row count returned" + assert cursor.rowcount == len(python_query_data), "Invalid cursor row count" + for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)): + assert desc == exp, f"Invalid column description at position {i}" + + for i in range(cursor.rowcount): + assert ( + cursor.fetchone() == python_query_data[i] + ), f"Invalid data row at position {i}" + + cursor.nextset() + assert cursor.rowcount == -1, "Invalid cursor row count" + assert cursor.description is None, "Invalid cursor description" + with raises(DataError) as exc_info: + cursor.fetchall() + + assert str(exc_info.value) == "no rows to fetch", "Invalid error message" From 66d431993182fefdc19a6d7aa232fec5b7645cf7 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 20 Jan 2022 14:47:16 +0200 Subject: [PATCH 07/12] add integration tests --- .../dbapi/async/test_queries_async.py | 39 +++++++++++++++++++ tests/integration/dbapi/sync/test_queries.py | 38 ++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/tests/integration/dbapi/async/test_queries_async.py b/tests/integration/dbapi/async/test_queries_async.py index e7c24c33880..e56fe1c3a37 100644 --- a/tests/integration/dbapi/async/test_queries_async.py +++ b/tests/integration/dbapi/async/test_queries_async.py @@ -237,3 +237,42 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None: [params + ["?"]], "Invalid data in table after parameterized insert", ) + + +@mark.asyncio +async def test_multi_statement_query(connection: Connection) -> None: + """Query parameters are handled properly""" + + with connection.cursor() as c: + await c.execute("DROP TABLE IF EXISTS test_tb_multi_statement") + await c.execute( + "CREATE FACT TABLE test_tb_multi_statement(i int, s string) primary index i" + ) + + assert ( + await c.execute( + "INSERT INTO test_tb_multi_statement values (1, 'a'), (2, 'b');" + "SELECT * FROM test_tb_multi_statement" + ) + == -1 + ), "Invalid row count returned for insert" + assert c.rowcount == -1, "Invalid row count" + assert c.description is None, "Invalid description" + + c.nextset() + + assert c.rowcount == 2, "Invalid select row count" + assert_deep_eq( + c.description, + [ + Column("i", int, None, None, None, None, None), + Column("s", str, None, None, None, None, None), + ], + "Invalid select query description", + ) + + assert_deep_eq( + await c.fetchall(), + [[1, "a"], [2, "b"]], + "Invalid data in table after parameterized insert", + ) diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py index fa00f9b129f..ac47a227ab7 100644 --- a/tests/integration/dbapi/sync/test_queries.py +++ b/tests/integration/dbapi/sync/test_queries.py @@ -229,3 +229,41 @@ def test_empty_query(c: Cursor, query: str, params: tuple) -> None: [params + ["?"]], "Invalid data in table after parameterized insert", ) + + +def test_multi_statement_query(connection: Connection) -> None: + """Query parameters are handled properly""" + + with connection.cursor() as c: + c.execute("DROP TABLE IF EXISTS test_tb_multi_statement") + c.execute( + "CREATE FACT TABLE test_tb_multi_statement(i int, s string) primary index i" + ) + + assert ( + c.execute( + "INSERT INTO test_tb_multi_statement values (1, 'a'), (2, 'b');" + "SELECT * FROM test_tb_multi_statement" + ) + == -1 + ), "Invalid row count returned for insert" + assert c.rowcount == -1, "Invalid row count" + assert c.description is None, "Invalid description" + + c.nextset() + + assert c.rowcount == 2, "Invalid select row count" + assert_deep_eq( + c.description, + [ + Column("i", int, None, None, None, None, None), + Column("s", str, None, None, None, None, None), + ], + "Invalid select query description", + ) + + assert_deep_eq( + c.fetchall(), + [[1, "a"], [2, "b"]], + "Invalid data in table after parameterized insert", + ) From 1606e5749fb40b4f90679a87ba40b475f2cd31a3 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 20 Jan 2022 15:25:06 +0200 Subject: [PATCH 08/12] fix mypy issues, improve nextset --- src/firebolt/async_db/_types.py | 2 +- src/firebolt/async_db/cursor.py | 29 ++++++++++++++----- .../dbapi/async/test_queries_async.py | 4 ++- tests/integration/dbapi/sync/test_queries.py | 4 ++- tests/unit/async_db/test_cursor.py | 4 ++- tests/unit/db/test_cursor.py | 4 ++- 6 files changed, 35 insertions(+), 12 deletions(-) diff --git a/src/firebolt/async_db/_types.py b/src/firebolt/async_db/_types.py index 22fb17d297f..46098f492de 100644 --- a/src/firebolt/async_db/_types.py +++ b/src/firebolt/async_db/_types.py @@ -265,7 +265,7 @@ def format_sql(query: str, parameters: Sequence[ParameterType]) -> str: def split_format_sql( - query: str, parameters: Sequence(Sequence[ParameterType]) + query: str, parameters: Sequence[Sequence[ParameterType]] ) -> List[str]: statements = parse_sql(query) if not statements: diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 666ef8bcb0c..5a3e7f9bab8 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -109,11 +109,15 @@ def __init__(self, client: AsyncClient, connection: Connection): self.connection = connection self._client = client self._arraysize = self.default_arraysize + # These fields initialized here for type annotations purpose self._rows: Optional[List[List[RawColType]]] = None self._descriptions: Optional[List[Column]] = None self._row_sets: List[ Tuple[int, Optional[List[Column]], Optional[List[List[RawColType]]]] ] = [] + self._rowcount = -1 + self._idx = 0 + self._next_set_idx = 0 self._reset() def __del__(self) -> None: @@ -172,10 +176,12 @@ def close(self) -> None: def _append_query_data(self, response: Response) -> None: """Store information about executed query from httpx response.""" + row_set: Tuple[ + int, Optional[List[Column]], Optional[List[List[RawColType]]] + ] = (-1, None, None) + # Empty response is returned for insert query - if response.headers.get("content-length", "") == "0": - row_set = (-1, None, None) - else: + if response.headers.get("content-length", "") != "0": try: query_data = response.json() rowcount = int(query_data["rows"]) @@ -197,14 +203,23 @@ def _append_query_data(self, response: Response) -> None: # Populate values for first set self.nextset() - def nextset(self) -> None: + @check_not_closed + @check_query_executed + def nextset(self) -> Optional[bool]: + """ + Skip to the next available set, discarding any remaining rows + from the current set. + Returns True if operation was successful, + None if there are no more sets to retrive + """ if self._next_set_idx >= len(self._row_sets): - raise ProgrammingError("no more sets in cursor") + return None cur_set = self._row_sets[self._next_set_idx] self._rowcount = cur_set[0] self._descriptions = cur_set[1] self._rows = cur_set[2] self._next_set_idx += 1 + return True async def _raise_if_error(self, resp: Response) -> None: """Raise a proper error if any""" @@ -242,9 +257,9 @@ def _reset(self) -> None: async def _do_execute_request( self, query: str, - parameters: Sequence[Sequence[ParameterType]] = None, + parameters: Sequence[Sequence[ParameterType]], set_parameters: Optional[Dict] = None, - ) -> Response: + ) -> None: self._reset() try: diff --git a/tests/integration/dbapi/async/test_queries_async.py b/tests/integration/dbapi/async/test_queries_async.py index e56fe1c3a37..205298f45b8 100644 --- a/tests/integration/dbapi/async/test_queries_async.py +++ b/tests/integration/dbapi/async/test_queries_async.py @@ -259,7 +259,7 @@ async def test_multi_statement_query(connection: Connection) -> None: assert c.rowcount == -1, "Invalid row count" assert c.description is None, "Invalid description" - c.nextset() + assert c.nextset() assert c.rowcount == 2, "Invalid select row count" assert_deep_eq( @@ -276,3 +276,5 @@ async def test_multi_statement_query(connection: Connection) -> None: [[1, "a"], [2, "b"]], "Invalid data in table after parameterized insert", ) + + assert c.nextset() is None diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py index ac47a227ab7..f730405d25b 100644 --- a/tests/integration/dbapi/sync/test_queries.py +++ b/tests/integration/dbapi/sync/test_queries.py @@ -250,7 +250,7 @@ def test_multi_statement_query(connection: Connection) -> None: assert c.rowcount == -1, "Invalid row count" assert c.description is None, "Invalid description" - c.nextset() + assert c.nextset() assert c.rowcount == 2, "Invalid select row count" assert_deep_eq( @@ -267,3 +267,5 @@ def test_multi_statement_query(connection: Connection) -> None: [[1, "a"], [2, "b"]], "Invalid data in table after parameterized insert", ) + + assert c.nextset() is None diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index 4cce20c99c1..bd5f503473c 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -451,10 +451,12 @@ async def test_cursor_multi_statement( await cursor.fetchone() == python_query_data[i] ), f"Invalid data row at position {i}" - cursor.nextset() + assert cursor.nextset() assert cursor.rowcount == -1, "Invalid cursor row count" assert cursor.description is None, "Invalid cursor description" with raises(DataError) as exc_info: await cursor.fetchall() assert str(exc_info.value) == "no rows to fetch", "Invalid error message" + + assert cursor.nextset() is None diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index 4edc5986afc..f739a20f247 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -396,10 +396,12 @@ def test_cursor_multi_statement( cursor.fetchone() == python_query_data[i] ), f"Invalid data row at position {i}" - cursor.nextset() + assert cursor.nextset() assert cursor.rowcount == -1, "Invalid cursor row count" assert cursor.description is None, "Invalid cursor description" with raises(DataError) as exc_info: cursor.fetchall() assert str(exc_info.value) == "no rows to fetch", "Invalid error message" + + assert cursor.nextset() is None From 16a544c2dfb402f56e6836a79a16f66f6f094670 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 20 Jan 2022 15:29:31 +0200 Subject: [PATCH 09/12] extended tests for nextset --- src/firebolt/async_db/cursor.py | 8 +++++++- tests/unit/async_db/test_cursor.py | 8 ++++---- tests/unit/db/test_cursor.py | 2 ++ 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 5a3e7f9bab8..ed72eccb162 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -201,7 +201,7 @@ def _append_query_data(self, response: Response) -> None: self._row_sets.append(row_set) if self._next_set_idx == 0: # Populate values for first set - self.nextset() + self._pop_next_set() @check_not_closed @check_query_executed @@ -212,6 +212,12 @@ def nextset(self) -> Optional[bool]: Returns True if operation was successful, None if there are no more sets to retrive """ + return self._pop_next_set() + + def _pop_next_set(self) -> Optional[bool]: + """ + Same functionality as .nextset, but doesn't check that query has been executed. + """ if self._next_set_idx >= len(self._row_sets): return None cur_set = self._row_sets[self._next_set_idx] diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index bd5f503473c..ead05bb0174 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -64,10 +64,7 @@ async def test_closed_cursor(cursor: Cursor): ("fetchmany", ()), ("fetchall", ()), ) - methods = ( - "setinputsizes", - "setoutputsize", - ) + methods = ("setinputsizes", "setoutputsize", "nextset") cursor.close() @@ -118,6 +115,9 @@ async def test_cursor_no_query( with raises(QueryNotRunError): await getattr(cursor, amethod)() + with raises(QueryNotRunError): + await cursor.nextset() + with raises(QueryNotRunError): [r async for r in cursor] diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index f739a20f247..652d7b56d03 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -60,6 +60,7 @@ def test_closed_cursor(cursor: Cursor): ("fetchall", ()), ("setinputsizes", (cursor, [0])), ("setoutputsize", (cursor, 0)), + ("nextset", (cursor, [])), ) cursor.close() @@ -97,6 +98,7 @@ def test_cursor_no_query( "fetchone", "fetchmany", "fetchall", + "nextset", ) httpx_mock.add_callback(auth_callback, url=auth_url) From 63efdbe86af8bdd27ba9962458743b504fcbe1b7 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 20 Jan 2022 17:09:00 +0200 Subject: [PATCH 10/12] address comments --- src/firebolt/async_db/_types.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/firebolt/async_db/_types.py b/src/firebolt/async_db/_types.py index 46098f492de..0d8e17f3d25 100644 --- a/src/firebolt/async_db/_types.py +++ b/src/firebolt/async_db/_types.py @@ -226,8 +226,7 @@ def format_value(value: ParameterType) -> str: def format_statement(statement: Statement, parameters: Sequence[ParameterType]) -> str: """ - Substitute placeholders in queries with provided values. - '?' symbol is used as a placeholder. Using '\\?' would result in a plain '?' + Substitute placeholders in a sqlparse statement with provided values. """ idx = 0 @@ -260,13 +259,14 @@ def process_token(token: Token) -> Token: return formatted_sql -def format_sql(query: str, parameters: Sequence[ParameterType]) -> str: - return format_statement(parse_sql(query)[0], parameters) - - def split_format_sql( query: str, parameters: Sequence[Sequence[ParameterType]] ) -> List[str]: + """ + Split a query into separate statement, and format it with parameters + if it's a single statement + Trying to format a multi-statement query would result in NotSupportedError + """ statements = parse_sql(query) if not statements: return [query] From 65aaedf3ef48e5d3ca2c959261e0982c68d428fc Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 20 Jan 2022 17:20:58 +0200 Subject: [PATCH 11/12] resolve new comments --- src/firebolt/async_db/cursor.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index ed72eccb162..fde9ca993ef 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -195,7 +195,7 @@ def _append_query_data(self, response: Response) -> None: # Parse data during fetch rows = query_data["data"] row_set = (rowcount, descriptions, rows) - except (KeyError, JSONDecodeError) as err: + except (KeyError, JSONDecodeError, ValueError) as err: raise DataError(f"Invalid query data format: {str(err)}") self._row_sets.append(row_set) @@ -220,10 +220,9 @@ def _pop_next_set(self) -> Optional[bool]: """ if self._next_set_idx >= len(self._row_sets): return None - cur_set = self._row_sets[self._next_set_idx] - self._rowcount = cur_set[0] - self._descriptions = cur_set[1] - self._rows = cur_set[2] + self._rowcount, self._descriptions, self._rows = self._row_sets[ + self._next_set_idx + ] self._next_set_idx += 1 return True From 5d5d25ac39f0b7473afe7fd3a585ba6671a53ca3 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Fri, 21 Jan 2022 12:14:07 +0200 Subject: [PATCH 12/12] remove code smell --- src/firebolt/async_db/cursor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index fde9ca993ef..8f39cf52e48 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -6,7 +6,6 @@ from enum import Enum from functools import wraps from inspect import cleandoc -from json import JSONDecodeError from types import TracebackType from typing import ( TYPE_CHECKING, @@ -195,7 +194,7 @@ def _append_query_data(self, response: Response) -> None: # Parse data during fetch rows = query_data["data"] row_set = (rowcount, descriptions, rows) - except (KeyError, JSONDecodeError, ValueError) as err: + except (KeyError, ValueError) as err: raise DataError(f"Invalid query data format: {str(err)}") self._row_sets.append(row_set)