diff --git a/src/firebolt_db/firebolt_async_dialect.py b/src/firebolt_db/firebolt_async_dialect.py index 31f7a57..b8c417b 100644 --- a/src/firebolt_db/firebolt_async_dialect.py +++ b/src/firebolt_db/firebolt_async_dialect.py @@ -2,7 +2,7 @@ from asyncio import Lock from types import ModuleType -from typing import Any, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import firebolt.async_db as async_dbapi from firebolt.async_db import Connection @@ -53,14 +53,24 @@ def arraysize(self, value: int) -> None: def rowcount(self) -> int: return self._cursor.rowcount - def execute(self, operation: str, parameters: Optional[Tuple] = None) -> None: - self.await_(self._execute(operation, parameters)) + def execute( + self, + operation: str, + parameters: Optional[Tuple] = None, + set_parameters: Optional[Dict] = None, + ) -> None: + self.await_(self._execute(operation, parameters, set_parameters=set_parameters)) async def _execute( - self, operation: str, parameters: Optional[Tuple] = None + self, + operation: str, + parameters: Optional[Tuple] = None, + set_parameters: Optional[Dict] = None, ) -> None: async with self._adapt_connection._execute_mutex: - await self._cursor.execute(operation, parameters) + await self._cursor.execute( + operation, parameters, set_parameters=set_parameters + ) if self._cursor.description: self._rows = await self._cursor.fetchall() else: diff --git a/src/firebolt_db/firebolt_dialect.py b/src/firebolt_db/firebolt_dialect.py index 3d1815a..6e35d19 100644 --- a/src/firebolt_db/firebolt_dialect.py +++ b/src/firebolt_db/firebolt_dialect.py @@ -270,7 +270,9 @@ def do_execute( parameters: Tuple[str, Any], context: Optional[ExecutionContext] = None, ) -> None: - cursor.execute(statement, set_parameters=self._set_parameters) + cursor.execute( + statement, parameters=parameters, set_parameters=self._set_parameters + ) def do_rollback(self, dbapi_connection: AlchemyConnection) -> None: pass @@ -300,5 +302,5 @@ def do_commit(self, dbapi_connection: AlchemyConnection) -> None: dialect = FireboltDialect -def get_is_nullable(column_is_nullable: str) -> bool: - return column_is_nullable.lower() == "yes" +def get_is_nullable(column_is_nullable: int) -> bool: + return column_is_nullable == 1 diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 2641a3e..fb99c34 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -34,6 +34,8 @@ def test_data_write(self, connection: Connection, fact_table_name: str): connection.execute( f"INSERT INTO {fact_table_name}(idx, dummy) VALUES (1, 'some_text')" ) + result = connection.execute(f"SELECT * FROM {fact_table_name} WHERE idx=?", 1) + assert result.fetchall() == [(1, "some_text")] result = connection.execute(f"SELECT * FROM {fact_table_name}") assert len(result.fetchall()) == 1 # Update not supported diff --git a/tests/unit/test_firebolt_async_dialect.py b/tests/unit/test_firebolt_async_dialect.py index d6a49cb..6c5e06b 100644 --- a/tests/unit/test_firebolt_async_dialect.py +++ b/tests/unit/test_firebolt_async_dialect.py @@ -80,7 +80,7 @@ def test_cursor() -> AsyncCursorWrapper: assert wrapper.description == "dummy" assert wrapper.rowcount == -1 async_cursor.execute.assert_awaited_once_with( - "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a")] + "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a")], set_parameters=None ) async_cursor.fetchall.assert_awaited_once() @@ -105,7 +105,7 @@ def test_cursor() -> AsyncCursorWrapper: assert wrapper.description is None assert wrapper.rowcount == 100 async_cursor.execute.assert_awaited_once_with( - "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a")] + "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a")], set_parameters=None ) async_cursor.fetchall.assert_not_awaited() diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index d70d69c..53677f6 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -64,8 +64,15 @@ def test_do_execute( self, dialect: FireboltDialect, cursor: mock.Mock(spec=MockCursor) ): dialect._set_parameters = {"a": "b"} - dialect.do_execute(cursor, "SELECT *", None, None) - cursor.execute.assert_called_once_with("SELECT *", set_parameters={"a": "b"}) + dialect.do_execute(cursor, "SELECT *", None) + cursor.execute.assert_called_once_with( + "SELECT *", parameters=None, set_parameters={"a": "b"} + ) + cursor.execute.reset_mock() + dialect.do_execute(cursor, "SELECT *", (1, 22), None) + cursor.execute.assert_called_once_with( + "SELECT *", parameters=(1, 22), set_parameters={"a": "b"} + ) def test_schema_names( self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) @@ -134,8 +141,8 @@ def getitem(self, idx): return mock.Mock(__getitem__=getitem) connection.execute.return_value = [ - multi_column_row(["name1", "INT", "YES"]), - multi_column_row(["name2", "date", "no"]), + multi_column_row(["name1", "INT", 1]), + multi_column_row(["name2", "date", 0]), ] expected_query = """ @@ -221,11 +228,8 @@ def test_unicode_description( def test_get_is_nullable(): - assert firebolt_db.firebolt_dialect.get_is_nullable("YES") - assert firebolt_db.firebolt_dialect.get_is_nullable("yes") - assert not firebolt_db.firebolt_dialect.get_is_nullable("NO") - assert not firebolt_db.firebolt_dialect.get_is_nullable("no") - assert not firebolt_db.firebolt_dialect.get_is_nullable("ABC") + assert firebolt_db.firebolt_dialect.get_is_nullable(1) + assert not firebolt_db.firebolt_dialect.get_is_nullable(0) def test_types():