From 5ff6ace33108324d90f7617b949fec789e02d277 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 26 Jan 2022 16:39:22 +0200 Subject: [PATCH 1/3] fixes for multi-statement queries --- src/firebolt/async_db/cursor.py | 6 ++++++ src/firebolt/db/cursor.py | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 8f39cf52e48..f84aecfe1d0 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -222,6 +222,7 @@ def _pop_next_set(self) -> Optional[bool]: self._rowcount, self._descriptions, self._rows = self._row_sets[ self._next_set_idx ] + self._idx = 0 self._next_set_idx += 1 return True @@ -464,6 +465,11 @@ async def fetchall(self) -> List[List[ColType]]: return super().fetchall() """Fetch all remaining rows of a query result""" + @wraps(BaseCursor.nextset) + async def nextrow(self) -> None: + async with self._async_query_lock.reader: + return super().nextset() + # Iteration support @check_not_closed @check_query_executed diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index e7086597856..54704787457 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -75,6 +75,11 @@ def fetchall(self) -> List[List[ColType]]: with self._query_lock.gen_rlock(): return super().fetchall() + @wraps(AsyncBaseCursor.nextset) + def nextset(self) -> None: + with self._query_lock.gen_rlock(), self._idx_lock: + return super().nextset() + # Iteration support @check_not_closed @check_query_executed From e6c9f76e120ade40957774edd6a63ab44935b4e6 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 26 Jan 2022 17:00:34 +0200 Subject: [PATCH 2/3] extend unit tests --- src/firebolt/async_db/cursor.py | 2 +- tests/unit/async_db/test_cursor.py | 23 +++++++++++++++++++---- tests/unit/db/test_cursor.py | 17 +++++++++++++++-- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index f84aecfe1d0..35cc69d36ad 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -466,7 +466,7 @@ async def fetchall(self) -> List[List[ColType]]: """Fetch all remaining rows of a query result""" @wraps(BaseCursor.nextset) - async def nextrow(self) -> None: + async def nextset(self) -> None: async with self._async_query_lock.reader: return super().nextset() diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index 430465c76f4..ad2ca0e4c40 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -63,8 +63,9 @@ async def test_closed_cursor(cursor: Cursor): ("fetchone", ()), ("fetchmany", ()), ("fetchall", ()), + ("nextset", ()), ) - methods = ("setinputsizes", "setoutputsize", "nextset") + methods = ("setinputsizes", "setoutputsize") cursor.close() @@ -439,8 +440,11 @@ async def test_cursor_multi_statement( httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(query_callback, url=query_url) httpx_mock.add_callback(insert_query_callback, url=query_url) + httpx_mock.add_callback(query_callback, url=query_url) - rc = await cursor.execute("select * from t; insert into t values (1, 2)") + rc = await cursor.execute( + "select * from t; insert into t values (1, 2); select * from t" + ) assert rc == len(python_query_data), "Invalid row count returned" assert cursor.rowcount == len(python_query_data), "Invalid cursor row count" for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)): @@ -451,7 +455,7 @@ async def test_cursor_multi_statement( await cursor.fetchone() == python_query_data[i] ), f"Invalid data row at position {i}" - assert cursor.nextset() + assert await cursor.nextset() assert cursor.rowcount == -1, "Invalid cursor row count" assert cursor.description is None, "Invalid cursor description" with raises(DataError) as exc_info: @@ -459,4 +463,15 @@ async def test_cursor_multi_statement( assert str(exc_info.value) == "no rows to fetch", "Invalid error message" - assert cursor.nextset() is None + assert await cursor.nextset() + + assert cursor.rowcount == len(python_query_data), "Invalid cursor row count" + for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)): + assert desc == exp, f"Invalid column description at position {i}" + + for i in range(cursor.rowcount): + assert ( + await cursor.fetchone() == python_query_data[i] + ), f"Invalid data row at position {i}" + + assert await cursor.nextset() is None diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index e5a9fdc6a2a..6713e0872a2 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -60,7 +60,7 @@ def test_closed_cursor(cursor: Cursor): ("fetchall", ()), ("setinputsizes", (cursor, [0])), ("setoutputsize", (cursor, 0)), - ("nextset", (cursor, [])), + ("nextset", ()), ) cursor.close() @@ -71,6 +71,7 @@ def test_closed_cursor(cursor: Cursor): for method, args in methods: with raises(CursorClosedError): + print(method, args) getattr(cursor, method)(*args) with raises(CursorClosedError): @@ -386,8 +387,9 @@ def test_cursor_multi_statement( httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(query_callback, url=query_url) httpx_mock.add_callback(insert_query_callback, url=query_url) + httpx_mock.add_callback(query_callback, url=query_url) - rc = cursor.execute("select * from t; insert into t values (1, 2)") + rc = cursor.execute("select * from t; insert into t values (1, 2); select * from t") assert rc == len(python_query_data), "Invalid row count returned" assert cursor.rowcount == len(python_query_data), "Invalid cursor row count" for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)): @@ -406,4 +408,15 @@ def test_cursor_multi_statement( assert str(exc_info.value) == "no rows to fetch", "Invalid error message" + assert cursor.nextset() + + assert cursor.rowcount == len(python_query_data), "Invalid cursor row count" + for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)): + assert desc == exp, f"Invalid column description at position {i}" + + for i in range(cursor.rowcount): + assert ( + cursor.fetchone() == python_query_data[i] + ), f"Invalid data row at position {i}" + assert cursor.nextset() is None From 7dc191e5bbd740af9f207edc8e640025ac4b18dd Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 26 Jan 2022 17:37:58 +0200 Subject: [PATCH 3/3] add integration tests --- .../dbapi/async/test_queries_async.py | 25 ++++++++++++++++--- tests/integration/dbapi/sync/test_queries.py | 21 +++++++++++++++- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/tests/integration/dbapi/async/test_queries_async.py b/tests/integration/dbapi/async/test_queries_async.py index a3556805bf6..1a1d39a0b69 100644 --- a/tests/integration/dbapi/async/test_queries_async.py +++ b/tests/integration/dbapi/async/test_queries_async.py @@ -260,14 +260,15 @@ async def test_multi_statement_query(connection: Connection) -> None: assert ( await c.execute( "INSERT INTO test_tb_multi_statement values (1, 'a'), (2, 'b');" - "SELECT * FROM test_tb_multi_statement" + "SELECT * FROM test_tb_multi_statement;" + "SELECT * FROM test_tb_multi_statement WHERE i <= 1" ) == -1 ), "Invalid row count returned for insert" assert c.rowcount == -1, "Invalid row count" assert c.description is None, "Invalid description" - assert c.nextset() + assert await c.nextset() assert c.rowcount == 2, "Invalid select row count" assert_deep_eq( @@ -285,4 +286,22 @@ async def test_multi_statement_query(connection: Connection) -> None: "Invalid data in table after parameterized insert", ) - assert c.nextset() is None + assert await c.nextset() + + assert c.rowcount == 1, "Invalid select row count" + assert_deep_eq( + c.description, + [ + Column("i", int, None, None, None, None, None), + Column("s", str, None, None, None, None, None), + ], + "Invalid select query description", + ) + + assert_deep_eq( + await c.fetchall(), + [[1, "a"]], + "Invalid data in table after parameterized insert", + ) + + assert await c.nextset() is None diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py index 8016fccbd69..4e91e501194 100644 --- a/tests/integration/dbapi/sync/test_queries.py +++ b/tests/integration/dbapi/sync/test_queries.py @@ -251,7 +251,8 @@ def test_multi_statement_query(connection: Connection) -> None: assert ( c.execute( "INSERT INTO test_tb_multi_statement values (1, 'a'), (2, 'b');" - "SELECT * FROM test_tb_multi_statement" + "SELECT * FROM test_tb_multi_statement;" + "SELECT * FROM test_tb_multi_statement WHERE i <= 1" ) == -1 ), "Invalid row count returned for insert" @@ -276,4 +277,22 @@ def test_multi_statement_query(connection: Connection) -> None: "Invalid data in table after parameterized insert", ) + assert c.nextset() + + assert c.rowcount == 1, "Invalid select row count" + assert_deep_eq( + c.description, + [ + Column("i", int, None, None, None, None, None), + Column("s", str, None, None, None, None, None), + ], + "Invalid select query description", + ) + + assert_deep_eq( + c.fetchall(), + [[1, "a"]], + "Invalid data in table after parameterized insert", + ) + assert c.nextset() is None