diff --git a/README.md b/README.md index 4e0f58d..d4b5f71 100644 --- a/README.md +++ b/README.md @@ -49,10 +49,36 @@ for item in result.fetchall(): print(item) ``` +### [AsyncIO](https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html) extension + +```python +from sqlalchemy import text +from sqlalchemy.ext.asyncio import create_async_engine +from firebolt_db.firebolt_async_dialect import AsyncFireboltDialect +from sqlalchemy.dialects import registry + +registry.register("firebolt", "src.firebolt_db.firebolt_async_dialect", "AsyncFireboltDialect") +engine = create_async_engine("firebolt://email@domain:password@sample_database/sample_engine") + +async with engine.connect() as conn: + + await conn.execute( + text(f"INSERT INTO example(dummy) VALUES (11)") + ) + + result = await conn.execute( + text(f"SELECT * FROM example") + ) + print(result.fetchall()) + +await engine.dispose() +``` + + ## Limitations 1. Transactions are not supported since Firebolt database does not support them at this time. -1. [AsyncIO](https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html) is not yet implemented. +1. Parametrised calls to execute and executemany are not implemented. ## Contributing diff --git a/setup.cfg b/setup.cfg index f78531d..b45f06c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,10 +41,11 @@ where = src [options.extras_require] dev = devtools==0.7.0 + mock==4.0.3 mypy==0.910 pre-commit==2.15.0 pytest==6.2.5 - sqlalchemy-stubs + sqlalchemy-stubs==0.4 [mypy] disallow_untyped_defs = True diff --git a/src/firebolt_db/__init__.py b/src/firebolt_db/__init__.py index d3f12ab..1830389 100644 --- a/src/firebolt_db/__init__.py +++ b/src/firebolt_db/__init__.py @@ -1,17 +1,3 @@ -from firebolt.common.exception import ( - DatabaseError, - DataError, - Error, - IntegrityError, - InterfaceError, - InternalError, - NotSupportedError, - OperationalError, - ProgrammingError, - Warning, -) -from firebolt.db import connect - __all__ = [ "connect", "apilevel", diff --git a/src/firebolt_db/firebolt_async_dialect.py b/src/firebolt_db/firebolt_async_dialect.py new file mode 100644 index 0000000..ef5179d --- /dev/null +++ b/src/firebolt_db/firebolt_async_dialect.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from asyncio import Lock +from types import ModuleType +from typing import Any, Iterator, List, Optional, Tuple + +import firebolt.async_db as async_dbapi +from firebolt.async_db import Connection + +# Ignoring type since sqlalchemy-stubs doesn't cover AdaptedConnection +from sqlalchemy.engine import AdaptedConnection # type: ignore[attr-defined] +from sqlalchemy.util.concurrency import await_only + +from firebolt_db.firebolt_dialect import FireboltDialect + + +class AsyncCursorWrapper: + __slots__ = ( + "_adapt_connection", + "_connection", + "await_", + "_cursor", + "_rows", + ) + + server_side = False + + def __init__(self, adapt_connection: AsyncConnectionWrapper): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + self._rows: List[List] = [] + self._cursor = self._connection.cursor() + + def close(self) -> None: + self._rows[:] = [] + self._cursor.close() + + @property + def description(self) -> str: + return self._cursor.description + + @property + def arraysize(self) -> int: + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value: int) -> None: + self._cursor.arraysize = value + + @property + def rowcount(self) -> int: + return self._cursor.rowcount + + def execute(self, operation: str, parameters: Optional[Tuple] = None) -> None: + self.await_(self._execute(operation, parameters)) + + async def _execute( + self, operation: str, parameters: Optional[Tuple] = None + ) -> None: + async with self._adapt_connection._execute_mutex: + await self._cursor.execute(operation, parameters) + if self._cursor.description: + self._rows = await self._cursor.fetchall() + else: + self._rows = [] + + def executemany(self, operation: str, seq_of_parameters: List[Tuple]) -> None: + raise NotImplementedError("executemany is not supported yet") + + def __iter__(self) -> Iterator[List]: + while self._rows: + yield self._rows.pop(0) + + def fetchone(self) -> Optional[List]: + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size: int = None) -> List[List]: + if size is None: + size = self._cursor.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self) -> List[List]: + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncConnectionWrapper(AdaptedConnection): + await_ = staticmethod(await_only) + __slots__ = ("dbapi", "_connection", "_execute_mutex") + + def __init__(self, dbapi: AsyncAPIWrapper, connection: Connection): + self.dbapi = dbapi + self._connection = connection + self._execute_mutex = Lock() + + def cursor(self) -> AsyncCursorWrapper: + return AsyncCursorWrapper(self) + + def rollback(self) -> None: + pass + + def commit(self) -> None: + self._connection.commit() + + def close(self) -> None: + self.await_(self._connection._aclose()) + + +class AsyncAPIWrapper(ModuleType): + """Wrapper around Firebolt async dbapi that returns a similar wrapper for + Cursor on connect()""" + + def __init__(self, dbapi: ModuleType): + self.dbapi = dbapi + self.paramstyle = dbapi.paramstyle # type: ignore[attr-defined] # noqa: F821 + self._init_dbapi_attributes() + + def _init_dbapi_attributes(self) -> None: + for name in ( + "DatabaseError", + "Error", + "IntegrityError", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + ): + setattr(self, name, getattr(self.dbapi, name)) + + def connect(self, *arg: Any, **kw: Any) -> AsyncConnectionWrapper: + + connection = await_only(self.dbapi.connect(*arg, **kw)) # type: ignore[attr-defined] # noqa: F821,E501 + return AsyncConnectionWrapper( + self, + connection, + ) + + +class AsyncFireboltDialect(FireboltDialect): + driver = "firebolt_aio" + supports_statement_cache: bool = False + supports_server_side_cursors: bool = False + is_async: bool = True + + @classmethod + def dbapi(cls) -> AsyncAPIWrapper: + return AsyncAPIWrapper(async_dbapi) + + +dialect = AsyncFireboltDialect diff --git a/src/firebolt_db/firebolt_dialect.py b/src/firebolt_db/firebolt_dialect.py index 6f07f96..06d682f 100644 --- a/src/firebolt_db/firebolt_dialect.py +++ b/src/firebolt_db/firebolt_dialect.py @@ -2,11 +2,12 @@ from types import ModuleType from typing import Any, Dict, List, Optional, Tuple, Union +import firebolt.db as dbapi import sqlalchemy.types as sqltypes from sqlalchemy.engine import Connection as AlchemyConnection from sqlalchemy.engine import ExecutionContext, default from sqlalchemy.engine.url import URL -from sqlalchemy.sql import compiler +from sqlalchemy.sql import compiler, text from sqlalchemy.types import ( BIGINT, BOOLEAN, @@ -19,8 +20,6 @@ VARCHAR, ) -import firebolt_db - class ARRAY(sqltypes.TypeEngine): __visit_name__ = "ARRAY" @@ -97,7 +96,7 @@ def __init__( @classmethod def dbapi(cls) -> ModuleType: - return firebolt_db + return dbapi # Build firebolt-sdk compatible connection arguments. # URL format : firebolt://username:password@host:port/db_name @@ -117,7 +116,7 @@ def get_schema_names( self, connection: AlchemyConnection, **kwargs: Any ) -> List[str]: query = "select schema_name from information_schema.databases" - result = connection.execute(query) + result = connection.execute(text(query)) return [row.schema_name for row in result] def has_table( @@ -133,8 +132,7 @@ def has_table( """.format( table_name=table_name ) - - result = connection.execute(query) + result = connection.execute(text(query)) return result.fetchone().exists_ def get_table_names( @@ -146,7 +144,7 @@ def get_table_names( query=query, schema=schema ) - result = connection.execute(query) + result = connection.execute(text(query)) return [row.table_name for row in result] def get_view_names( @@ -184,7 +182,7 @@ def get_columns( query=query, schema=schema ) - result = connection.execute(query) + result = connection.execute(text(query)) return [ { diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 1dafaf9..f18be35 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,10 +1,15 @@ +import asyncio from logging import getLogger from os import environ +import nest_asyncio from pytest import fixture from sqlalchemy import create_engine from sqlalchemy.dialects import registry from sqlalchemy.engine.base import Connection, Engine +from sqlalchemy.ext.asyncio import create_async_engine + +nest_asyncio.apply() LOGGER = getLogger(__name__) @@ -52,4 +57,87 @@ def engine( @fixture(scope="session") def connection(engine: Engine) -> Connection: - return engine.connect() + with engine.connect() as c: + yield c + + +@fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@fixture(scope="session") +def async_engine( + username: str, password: str, database_name: str, engine_name: str +) -> Engine: + registry.register( + "firebolt_aio", "src.firebolt_db.firebolt_async_dialect", "AsyncFireboltDialect" + ) + return create_async_engine( + f"firebolt_aio://{username}:{password}@{database_name}/{engine_name}" + ) + + +@fixture(scope="session") +async def async_connection( + async_engine: Engine, +) -> Connection: + async with async_engine.connect() as c: + yield c + + +@fixture +def ex_table_name() -> str: + return "ex_lineitem_alchemy" + + +@fixture +def ex_table_query(ex_table_name: str) -> str: + return f""" + CREATE EXTERNAL TABLE {ex_table_name} + ( l_orderkey LONG, + l_partkey LONG, + l_suppkey LONG, + l_linenumber INT, + l_quantity LONG, + l_extendedprice LONG, + l_discount LONG, + l_tax LONG, + l_returnflag TEXT, + l_linestatus TEXT, + l_shipdate TEXT, + l_commitdate TEXT, + l_receiptdate TEXT, + l_shipinstruct TEXT, + l_shipmode TEXT, + l_comment TEXT + ) + URL = 's3://firebolt-publishing-public/samples/tpc-h/parquet/lineitem/' + OBJECT_PATTERN = '*.parquet' + TYPE = (PARQUET); + """ + + +@fixture(scope="class") +def fact_table_name() -> str: + return "test_alchemy" + + +@fixture(scope="class", autouse=True) +def setup_test_tables(connection: Connection, engine: Engine, fact_table_name: str): + connection.execute( + f""" + CREATE FACT TABLE IF NOT EXISTS {fact_table_name} + ( + idx INT, + dummy TEXT + ) PRIMARY INDEX idx; + """ + ) + assert engine.dialect.has_table(engine, fact_table_name) + yield + # Teardown + connection.execute(f"DROP TABLE IF EXISTS {fact_table_name}") + assert not engine.dialect.has_table(engine, fact_table_name) diff --git a/tests/integration/test_sqlalchemy_async_integration.py b/tests/integration/test_sqlalchemy_async_integration.py new file mode 100644 index 0000000..ea02f8e --- /dev/null +++ b/tests/integration/test_sqlalchemy_async_integration.py @@ -0,0 +1,80 @@ +from typing import Dict, List + +import pytest +from sqlalchemy import inspect, text +from sqlalchemy.engine.base import Connection, Engine +from sqlalchemy.exc import OperationalError + + +class TestAsyncFireboltDialect: + @pytest.mark.asyncio + async def test_create_ex_table( + self, + async_connection: Connection, + async_engine: Engine, + ex_table_query: str, + ex_table_name: str, + ): + await async_connection.execute(text(ex_table_query)) + + def has_test_table(conn: Connection) -> bool: + inspector = inspect(conn) + return inspector.has_table(ex_table_name) + + assert await async_connection.run_sync(has_test_table) + # Cleanup + await async_connection.execute(text(f"DROP TABLE {ex_table_name}")) + assert not await async_connection.run_sync(has_test_table) + + @pytest.mark.asyncio + async def test_data_write(self, async_connection: Connection, fact_table_name: str): + result = await async_connection.execute( + text(f"INSERT INTO {fact_table_name}(idx, dummy) VALUES (1, 'some_text')") + ) + assert result.rowcount == -1 + result = await async_connection.execute( + text(f"SELECT * FROM {fact_table_name}") + ) + assert result.rowcount == 1 + assert len(result.fetchall()) == 1 + # Update not supported + with pytest.raises(OperationalError): + await async_connection.execute( + text( + f"UPDATE {fact_table_name} SET dummy='some_other_text' WHERE idx=1" + ) + ) + # Delete not supported + with pytest.raises(OperationalError): + await async_connection.execute( + text(f"DELETE FROM {fact_table_name} WHERE idx=1") + ) + + @pytest.mark.asyncio + async def test_get_table_names( + self, async_connection: Connection, database_name: str + ): + def get_table_names(conn: Connection) -> bool: + inspector = inspect(conn) + return inspector.get_table_names(database_name) + + results = await async_connection.run_sync(get_table_names) + assert len(results) > 0 + + @pytest.mark.asyncio + async def test_get_columns( + self, async_connection: Connection, database_name: str, fact_table_name: str + ): + def get_columns(conn: Connection) -> List[Dict]: + inspector = inspect(conn) + return inspector.get_columns(fact_table_name, database_name) + + results = await async_connection.run_sync(get_columns) + assert len(results) > 0 + row = results[0] + assert isinstance(row, dict) + row_keys = list(row.keys()) + assert row_keys[0] == "name" + assert row_keys[1] == "type" + assert row_keys[2] == "nullable" + assert row_keys[3] == "default" diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 5af0788..cc256bf 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -4,91 +4,50 @@ class TestFireboltDialect: - - test_table = "test_alchemy" - - def create_test_table(self, connection: Connection, engine: Engine, table: str): - connection.execute( - f""" - CREATE FACT TABLE IF NOT EXISTS {table} - ( - idx INT, - dummy TEXT - ) PRIMARY INDEX idx; - """ - ) - assert engine.dialect.has_table(engine, table) - - def drop_test_table(self, connection: Connection, engine: Engine, table: str): - connection.execute(f"DROP TABLE IF EXISTS {table}") - assert not engine.dialect.has_table(engine, table) - - @pytest.fixture(scope="class", autouse=True) - def setup_test_tables(self, connection: Connection, engine: Engine): - self.create_test_table(connection, engine, self.test_table) - yield - self.drop_test_table(connection, engine, self.test_table) - - def test_create_ex_table(self, connection: Connection, engine: Engine): - connection.execute( - """ - CREATE EXTERNAL TABLE ex_lineitem_alchemy - ( l_orderkey LONG, - l_partkey LONG, - l_suppkey LONG, - l_linenumber INT, - l_quantity LONG, - l_extendedprice LONG, - l_discount LONG, - l_tax LONG, - l_returnflag TEXT, - l_linestatus TEXT, - l_shipdate TEXT, - l_commitdate TEXT, - l_receiptdate TEXT, - l_shipinstruct TEXT, - l_shipmode TEXT, - l_comment TEXT - ) - URL = 's3://firebolt-publishing-public/samples/tpc-h/parquet/lineitem/' - OBJECT_PATTERN = '*.parquet' - TYPE = (PARQUET); - """ - ) - assert engine.dialect.has_table(engine, "ex_lineitem_alchemy") + def test_create_ex_table( + self, + connection: Connection, + engine: Engine, + ex_table_query: str, + ex_table_name: str, + ): + connection.execute(ex_table_query) + assert engine.dialect.has_table(engine, ex_table_name) # Cleanup - connection.execute("DROP TABLE ex_lineitem_alchemy;") - assert not engine.dialect.has_table(engine, "ex_lineitem_alchemy") + connection.execute(f"DROP TABLE {ex_table_name}") + assert not engine.dialect.has_table(engine, ex_table_name) - def test_data_write(self, connection: Connection): + def test_data_write(self, connection: Connection, fact_table_name: str): connection.execute( - "INSERT INTO test_alchemy(idx, dummy) VALUES (1, 'some_text')" + f"INSERT INTO {fact_table_name}(idx, dummy) VALUES (1, 'some_text')" ) - result = connection.execute("SELECT * FROM test_alchemy") + result = connection.execute(f"SELECT * FROM {fact_table_name}") assert len(result.fetchall()) == 1 # Update not supported with pytest.raises(OperationalError): connection.execute( - "UPDATE test_alchemy SET dummy='some_other_text' WHERE idx=1" + f"UPDATE {fact_table_name} SET dummy='some_other_text' WHERE idx=1" ) # Delete not supported with pytest.raises(OperationalError): - connection.execute("DELETE FROM test_alchemy WHERE idx=1") + connection.execute(f"DELETE FROM {fact_table_name} WHERE idx=1") def test_get_schema_names(self, engine: Engine, database_name: str): results = engine.dialect.get_schema_names(engine) assert database_name in results - def test_has_table(self, engine: Engine, database_name: str): - results = engine.dialect.has_table(engine, self.test_table, database_name) + def test_has_table(self, engine: Engine, database_name: str, fact_table_name: str): + results = engine.dialect.has_table(engine, fact_table_name, database_name) assert results == 1 def test_get_table_names(self, engine: Engine, database_name: str): results = engine.dialect.get_table_names(engine, database_name) assert len(results) > 0 - def test_get_columns(self, engine: Engine, database_name: str): - results = engine.dialect.get_columns(engine, self.test_table, database_name) + def test_get_columns( + self, engine: Engine, database_name: str, fact_table_name: str + ): + results = engine.dialect.get_columns(engine, fact_table_name, database_name) assert len(results) > 0 row = results[0] assert isinstance(row, dict) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 2e5b04d..8e5428b 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,23 +1,115 @@ from unittest import mock +from mock import AsyncMock from pytest import fixture -from firebolt_db import firebolt_dialect +from firebolt_db import firebolt_async_dialect, firebolt_dialect class MockDBApi: + class DatabaseError: + pass + + class Error: + pass + + class IntegrityError: + pass + + class NotSupportedError: + pass + + class OperationalError: + pass + + class ProgrammingError: + pass + + paramstyle = "" + def execute(): pass def executemany(): pass + def connect(): + pass + + +class MockAsyncDBApi: + class DatabaseError: + pass + + class Error: + pass + + class IntegrityError: + pass + + class NotSupportedError: + pass + + class OperationalError: + pass + + class ProgrammingError: + pass + + paramstyle = "" + + async def connect(): + pass + + +class MockAsyncConnection: + def cursor(): + pass + + +class MockAsyncCursor: + description = "" + rowcount = -1 + arraysize = 1 + + async def execute(): + pass + + async def executemany(): + pass + + async def fetchall(): + pass + + def close(): + pass + @fixture def dialect() -> firebolt_dialect.FireboltDialect: return firebolt_dialect.FireboltDialect() +@fixture +def async_dialect() -> firebolt_async_dialect.AsyncFireboltDialect: + return firebolt_async_dialect.AsyncFireboltDialect() + + @fixture def connection() -> mock.Mock(spec=MockDBApi): return mock.Mock(spec=MockDBApi) + + +@fixture +def async_api() -> AsyncMock(spec=MockAsyncDBApi): + return AsyncMock(spec=MockAsyncDBApi) + + +@fixture +def async_connection() -> AsyncMock(spec=MockAsyncConnection): + return AsyncMock(spec=MockAsyncConnection) + + +@fixture +def async_cursor() -> AsyncMock(spec=MockAsyncCursor): + return AsyncMock(spec=MockAsyncCursor) diff --git a/tests/unit/test_firebolt_async_dialect.py b/tests/unit/test_firebolt_async_dialect.py new file mode 100644 index 0000000..d6a49cb --- /dev/null +++ b/tests/unit/test_firebolt_async_dialect.py @@ -0,0 +1,168 @@ +import pytest +from conftest import MockAsyncConnection, MockAsyncCursor, MockAsyncDBApi +from mock import AsyncMock +from sqlalchemy.util import await_only, greenlet_spawn + +from firebolt_db.firebolt_async_dialect import ( + AsyncAPIWrapper, + AsyncConnectionWrapper, + AsyncCursorWrapper, + AsyncFireboltDialect, +) +from firebolt_db.firebolt_async_dialect import ( + dialect as async_dialect_definition, +) +from firebolt_db.firebolt_dialect import ( + FireboltCompiler, + FireboltIdentifierPreparer, + FireboltTypeCompiler, +) + + +class TestAsyncFireboltDialect: + def test_create_dialect(self, async_dialect: AsyncFireboltDialect): + assert issubclass(async_dialect_definition, AsyncFireboltDialect) + assert type(AsyncFireboltDialect.dbapi()) == AsyncAPIWrapper + assert async_dialect.name == "firebolt" + assert async_dialect.driver == "firebolt_aio" + assert issubclass(async_dialect.preparer, FireboltIdentifierPreparer) + assert issubclass(async_dialect.statement_compiler, FireboltCompiler) + # SQLAlchemy's DefaultDialect creates an instance of + # type_compiler behind the scenes + assert isinstance(async_dialect.type_compiler, FireboltTypeCompiler) + assert async_dialect.context == {} + + @pytest.mark.asyncio + async def test_create_api_wrapper(self, async_api: AsyncMock(spec=MockAsyncDBApi)): + def test_connect() -> AsyncAPIWrapper: + async_api.paramstyle = "quoted" + wrapper = AsyncAPIWrapper(async_api) + wrapper.connect("test arg") + return wrapper + + wrapper = await greenlet_spawn(test_connect) + assert wrapper.dbapi == async_api + assert wrapper.paramstyle == "quoted" + async_api.connect.assert_called_once_with("test arg") + + @pytest.mark.asyncio + async def test_connection_wrapper(self, async_api: AsyncMock(spec=MockAsyncDBApi)): + def test_connection() -> AsyncConnectionWrapper: + wrapper = AsyncConnectionWrapper(async_api, await_only(async_api.connect())) + # Check call propagation + wrapper.commit() + wrapper.rollback() + wrapper.close() + return wrapper + + wrapper = await greenlet_spawn(test_connection) + assert isinstance(wrapper.cursor(), AsyncCursorWrapper) + async_api.connect.return_value.commit.assert_called_once() + async_api.connect.return_value._aclose.assert_awaited_once() + + @pytest.mark.asyncio + async def test_cursor_execute( + self, + async_api: AsyncMock(spec=MockAsyncDBApi), + async_connection: AsyncMock(spec=MockAsyncConnection), + async_cursor: AsyncMock(spec=MockAsyncCursor), + ): + def test_cursor() -> AsyncCursorWrapper: + async_connection.cursor.return_value = async_cursor + conn_wrapper = AsyncConnectionWrapper(async_api, async_connection) + wrapper = AsyncCursorWrapper(conn_wrapper) + wrapper.execute("INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a")]) + return wrapper + + async_cursor.description = "dummy" + async_cursor.rowcount = -1 + wrapper = await greenlet_spawn(test_cursor) + assert wrapper.description == "dummy" + assert wrapper.rowcount == -1 + async_cursor.execute.assert_awaited_once_with( + "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a")] + ) + async_cursor.fetchall.assert_awaited_once() + + @pytest.mark.asyncio + async def test_cursor_execute_no_fetch( + self, + async_api: AsyncMock(spec=MockAsyncDBApi), + async_connection: AsyncMock(spec=MockAsyncConnection), + async_cursor: AsyncMock(spec=MockAsyncCursor), + ): + def test_cursor() -> AsyncCursorWrapper: + async_connection.cursor.return_value = async_cursor + conn_wrapper = AsyncConnectionWrapper(async_api, async_connection) + wrapper = AsyncCursorWrapper(conn_wrapper) + wrapper.execute("INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a")]) + return wrapper + + async_cursor.description = None + async_cursor.rowcount = 100 + + wrapper = await greenlet_spawn(test_cursor) + assert wrapper.description is None + assert wrapper.rowcount == 100 + async_cursor.execute.assert_awaited_once_with( + "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a")] + ) + async_cursor.fetchall.assert_not_awaited() + + @pytest.mark.asyncio + async def test_cursor_close( + self, + async_api: AsyncMock(spec=MockAsyncDBApi), + async_connection: AsyncMock(spec=MockAsyncConnection), + async_cursor: AsyncMock(spec=MockAsyncCursor), + ): + def test_cursor(): + async_connection.cursor.return_value = async_cursor + conn_wrapper = AsyncConnectionWrapper(async_api, async_connection) + wrapper = AsyncCursorWrapper(conn_wrapper) + wrapper._rows = [1, 2, 3] + wrapper.close() + assert wrapper._rows == [] + async_cursor.close.assert_called_once() + + await greenlet_spawn(test_cursor) + + @pytest.mark.asyncio + async def test_cursor_executemany( + self, + async_api: AsyncMock(spec=MockAsyncDBApi), + async_connection: AsyncMock(spec=MockAsyncConnection), + async_cursor: AsyncMock(spec=MockAsyncCursor), + ): + def test_cursor(): + async_connection.cursor.return_value = async_cursor + conn_wrapper = AsyncConnectionWrapper(async_api, async_connection) + wrapper = AsyncCursorWrapper(conn_wrapper) + with pytest.raises(NotImplementedError): + wrapper.executemany( + "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a"), (2, "b")] + ) + + await greenlet_spawn(test_cursor) + + @pytest.mark.asyncio + async def test_cursor_fetch( + self, + async_api: AsyncMock(spec=MockAsyncDBApi), + async_connection: AsyncMock(spec=MockAsyncConnection), + async_cursor: AsyncMock(spec=MockAsyncCursor), + ): + def test_cursor(): + async_connection.cursor.return_value = async_cursor + conn_wrapper = AsyncConnectionWrapper(async_api, async_connection) + wrapper = AsyncCursorWrapper(conn_wrapper) + wrapper._rows = [1, 2, 3, 4, 5, 6, 7, 8] + assert wrapper.fetchone() == 1 + assert wrapper.fetchmany() == [2] + async_cursor.arraysize = 2 + assert wrapper.fetchmany() == [3, 4] + async_cursor.arraysize = 1 + assert wrapper.fetchmany(2) == [5, 6] + assert wrapper.fetchall() == [7, 8] + + await greenlet_spawn(test_cursor) diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index d726aab..5417253 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -4,6 +4,7 @@ import sqlalchemy from conftest import MockDBApi from sqlalchemy.engine import url +from sqlalchemy.sql import text import firebolt_db # SQLAlchemy package from firebolt_db.firebolt_dialect import ( @@ -59,8 +60,9 @@ def row_with_schema(name): ] result = dialect.get_schema_names(connection) assert result == ["schema1", "schema2"] - connection.execute.assert_called_once_with( - "select schema_name from information_schema.databases" + connection.execute.assert_called_once() + assert str(connection.execute.call_args[0][0].compile()) == str( + text("select schema_name from information_schema.databases").compile() ) def test_table_names( @@ -76,15 +78,19 @@ def row_with_table_name(name): result = dialect.get_table_names(connection) assert result == ["table1", "table2"] - connection.execute.assert_called_once_with( - "select table_name from information_schema.tables" + connection.execute.assert_called_once() + assert str(connection.execute.call_args[0][0].compile()) == str( + text("select table_name from information_schema.tables").compile() ) connection.execute.reset_mock() result = dialect.get_table_names(connection, schema="schema") assert result == ["table1", "table2"] - connection.execute.assert_called_once_with( - "select table_name from information_schema.tables" - " where table_schema = 'schema'" + connection.execute.assert_called_once() + assert str(connection.execute.call_args[0][0].compile()) == str( + text( + "select table_name from information_schema.tables" + " where table_schema = 'schema'" + ).compile() ) def test_view_names( @@ -145,7 +151,10 @@ def getitem(self, idx): "default": None, }, ] - connection.execute.assert_called_once_with(expected_query) + connection.execute.assert_called_once() + assert str(connection.execute.call_args[0][0].compile()) == str( + text(expected_query).compile() + ) connection.execute.reset_mock() def test_pk_constraint(