Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/firebolt/async_db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def _pop_next_set(self) -> Optional[bool]:
self._rowcount, self._descriptions, self._rows = self._row_sets[
self._next_set_idx
]
self._idx = 0
self._next_set_idx += 1
return True

Expand Down Expand Up @@ -464,6 +465,11 @@ async def fetchall(self) -> List[List[ColType]]:
return super().fetchall()
"""Fetch all remaining rows of a query result"""

@wraps(BaseCursor.nextset)
async def nextset(self) -> None:
async with self._async_query_lock.reader:
return super().nextset()

# Iteration support
@check_not_closed
@check_query_executed
Expand Down
5 changes: 5 additions & 0 deletions src/firebolt/db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def fetchall(self) -> List[List[ColType]]:
with self._query_lock.gen_rlock():
return super().fetchall()

@wraps(AsyncBaseCursor.nextset)
def nextset(self) -> None:
with self._query_lock.gen_rlock(), self._idx_lock:
return super().nextset()

# Iteration support
@check_not_closed
@check_query_executed
Expand Down
25 changes: 22 additions & 3 deletions tests/integration/dbapi/async/test_queries_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,15 @@ async def test_multi_statement_query(connection: Connection) -> None:
assert (
await c.execute(
"INSERT INTO test_tb_multi_statement values (1, 'a'), (2, 'b');"
"SELECT * FROM test_tb_multi_statement"
"SELECT * FROM test_tb_multi_statement;"
"SELECT * FROM test_tb_multi_statement WHERE i <= 1"
)
== -1
), "Invalid row count returned for insert"
assert c.rowcount == -1, "Invalid row count"
assert c.description is None, "Invalid description"

assert c.nextset()
assert await c.nextset()

assert c.rowcount == 2, "Invalid select row count"
assert_deep_eq(
Expand All @@ -285,4 +286,22 @@ async def test_multi_statement_query(connection: Connection) -> None:
"Invalid data in table after parameterized insert",
)

assert c.nextset() is None
assert await c.nextset()

assert c.rowcount == 1, "Invalid select row count"
assert_deep_eq(
c.description,
[
Column("i", int, None, None, None, None, None),
Column("s", str, None, None, None, None, None),
],
"Invalid select query description",
)

assert_deep_eq(
await c.fetchall(),
[[1, "a"]],
"Invalid data in table after parameterized insert",
)

assert await c.nextset() is None
21 changes: 20 additions & 1 deletion tests/integration/dbapi/sync/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ def test_multi_statement_query(connection: Connection) -> None:
assert (
c.execute(
"INSERT INTO test_tb_multi_statement values (1, 'a'), (2, 'b');"
"SELECT * FROM test_tb_multi_statement"
"SELECT * FROM test_tb_multi_statement;"
"SELECT * FROM test_tb_multi_statement WHERE i <= 1"
)
== -1
), "Invalid row count returned for insert"
Expand All @@ -276,4 +277,22 @@ def test_multi_statement_query(connection: Connection) -> None:
"Invalid data in table after parameterized insert",
)

assert c.nextset()

assert c.rowcount == 1, "Invalid select row count"
assert_deep_eq(
c.description,
[
Column("i", int, None, None, None, None, None),
Column("s", str, None, None, None, None, None),
],
"Invalid select query description",
)

assert_deep_eq(
c.fetchall(),
[[1, "a"]],
"Invalid data in table after parameterized insert",
)

assert c.nextset() is None
23 changes: 19 additions & 4 deletions tests/unit/async_db/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ async def test_closed_cursor(cursor: Cursor):
("fetchone", ()),
("fetchmany", ()),
("fetchall", ()),
("nextset", ()),
)
methods = ("setinputsizes", "setoutputsize", "nextset")
methods = ("setinputsizes", "setoutputsize")

cursor.close()

Expand Down Expand Up @@ -439,8 +440,11 @@ async def test_cursor_multi_statement(
httpx_mock.add_callback(auth_callback, url=auth_url)
httpx_mock.add_callback(query_callback, url=query_url)
httpx_mock.add_callback(insert_query_callback, url=query_url)
httpx_mock.add_callback(query_callback, url=query_url)

rc = await cursor.execute("select * from t; insert into t values (1, 2)")
rc = await cursor.execute(
"select * from t; insert into t values (1, 2); select * from t"
)
assert rc == len(python_query_data), "Invalid row count returned"
assert cursor.rowcount == len(python_query_data), "Invalid cursor row count"
for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)):
Expand All @@ -451,12 +455,23 @@ async def test_cursor_multi_statement(
await cursor.fetchone() == python_query_data[i]
), f"Invalid data row at position {i}"

assert cursor.nextset()
assert await cursor.nextset()
assert cursor.rowcount == -1, "Invalid cursor row count"
assert cursor.description is None, "Invalid cursor description"
with raises(DataError) as exc_info:
await cursor.fetchall()

assert str(exc_info.value) == "no rows to fetch", "Invalid error message"

assert cursor.nextset() is None
assert await cursor.nextset()

assert cursor.rowcount == len(python_query_data), "Invalid cursor row count"
for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)):
assert desc == exp, f"Invalid column description at position {i}"

for i in range(cursor.rowcount):
assert (
await cursor.fetchone() == python_query_data[i]
), f"Invalid data row at position {i}"

assert await cursor.nextset() is None
17 changes: 15 additions & 2 deletions tests/unit/db/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_closed_cursor(cursor: Cursor):
("fetchall", ()),
("setinputsizes", (cursor, [0])),
("setoutputsize", (cursor, 0)),
("nextset", (cursor, [])),
("nextset", ()),
)

cursor.close()
Expand All @@ -71,6 +71,7 @@ def test_closed_cursor(cursor: Cursor):

for method, args in methods:
with raises(CursorClosedError):
print(method, args)
getattr(cursor, method)(*args)

with raises(CursorClosedError):
Expand Down Expand Up @@ -386,8 +387,9 @@ def test_cursor_multi_statement(
httpx_mock.add_callback(auth_callback, url=auth_url)
httpx_mock.add_callback(query_callback, url=query_url)
httpx_mock.add_callback(insert_query_callback, url=query_url)
httpx_mock.add_callback(query_callback, url=query_url)

rc = cursor.execute("select * from t; insert into t values (1, 2)")
rc = cursor.execute("select * from t; insert into t values (1, 2); select * from t")
assert rc == len(python_query_data), "Invalid row count returned"
assert cursor.rowcount == len(python_query_data), "Invalid cursor row count"
for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)):
Expand All @@ -406,4 +408,15 @@ def test_cursor_multi_statement(

assert str(exc_info.value) == "no rows to fetch", "Invalid error message"

assert cursor.nextset()

assert cursor.rowcount == len(python_query_data), "Invalid cursor row count"
for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)):
assert desc == exp, f"Invalid column description at position {i}"

for i in range(cursor.rowcount):
assert (
cursor.fetchone() == python_query_data[i]
), f"Invalid data row at position {i}"

assert cursor.nextset() is None