From 972c2037c584c0af70cb71282c365492d38a8a8a Mon Sep 17 00:00:00 2001 From: ptiurin Date: Wed, 1 Jun 2022 16:56:39 +0100 Subject: [PATCH 1/2] feat: Parametrised query --- src/firebolt_db/firebolt_dialect.py | 4 +++- tests/integration/test_sqlalchemy_integration.py | 4 +++- tests/unit/test_firebolt_dialect.py | 6 ++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/firebolt_db/firebolt_dialect.py b/src/firebolt_db/firebolt_dialect.py index 3d1815a..965b931 100644 --- a/src/firebolt_db/firebolt_dialect.py +++ b/src/firebolt_db/firebolt_dialect.py @@ -270,7 +270,9 @@ def do_execute( parameters: Tuple[str, Any], context: Optional[ExecutionContext] = None, ) -> None: - cursor.execute(statement, set_parameters=self._set_parameters) + cursor.execute( + statement, parameters=parameters, set_parameters=self._set_parameters + ) def do_rollback(self, dbapi_connection: AlchemyConnection) -> None: pass diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 2641a3e..dbffc6a 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -34,7 +34,9 @@ def test_data_write(self, connection: Connection, fact_table_name: str): connection.execute( f"INSERT INTO {fact_table_name}(idx, dummy) VALUES (1, 'some_text')" ) - result = connection.execute(f"SELECT * FROM {fact_table_name}") + result = connection.execute( + f"SELECT * FROM {fact_table_name} WHERE dummy=?", (1,) + ) assert len(result.fetchall()) == 1 # Update not supported with pytest.raises(OperationalError): diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index d70d69c..1753bef 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -64,8 +64,10 @@ def test_do_execute( self, dialect: FireboltDialect, cursor: mock.Mock(spec=MockCursor) ): dialect._set_parameters = {"a": "b"} - dialect.do_execute(cursor, "SELECT *", None, None) - cursor.execute.assert_called_once_with("SELECT *", set_parameters={"a": "b"}) + dialect.do_execute(cursor, "SELECT *", (1, 22), None) + cursor.execute.assert_called_once_with( + "SELECT *", parameters=(1, 22), set_parameters={"a": "b"} + ) def test_schema_names( self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) From 6575d941ef04bbb5de62d9159517fd992db3dffb Mon Sep 17 00:00:00 2001 From: ptiurin Date: Wed, 1 Jun 2022 17:55:55 +0100 Subject: [PATCH 2/2] test fixes --- src/firebolt_db/firebolt_async_dialect.py | 20 ++++++++++++++----- src/firebolt_db/firebolt_dialect.py | 4 ++-- .../test_sqlalchemy_integration.py | 6 +++--- tests/unit/test_firebolt_async_dialect.py | 4 ++-- tests/unit/test_firebolt_dialect.py | 16 ++++++++------- 5 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/firebolt_db/firebolt_async_dialect.py b/src/firebolt_db/firebolt_async_dialect.py index 31f7a57..b8c417b 100644 --- a/src/firebolt_db/firebolt_async_dialect.py +++ b/src/firebolt_db/firebolt_async_dialect.py @@ -2,7 +2,7 @@ from asyncio import Lock from types import ModuleType -from typing import Any, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import firebolt.async_db as async_dbapi from firebolt.async_db import Connection @@ -53,14 +53,24 @@ def arraysize(self, value: int) -> None: def rowcount(self) -> int: return self._cursor.rowcount - def execute(self, operation: str, parameters: Optional[Tuple] = None) -> None: - self.await_(self._execute(operation, parameters)) + def execute( + self, + operation: str, + parameters: Optional[Tuple] = None, + set_parameters: Optional[Dict] = None, + ) -> None: + self.await_(self._execute(operation, parameters, set_parameters=set_parameters)) async def _execute( - self, operation: str, parameters: Optional[Tuple] = None + self, + operation: str, + parameters: Optional[Tuple] = None, + set_parameters: Optional[Dict] = None, ) -> None: async with self._adapt_connection._execute_mutex: - await self._cursor.execute(operation, parameters) + await self._cursor.execute( + operation, parameters, set_parameters=set_parameters + ) if self._cursor.description: self._rows = await self._cursor.fetchall() else: diff --git a/src/firebolt_db/firebolt_dialect.py b/src/firebolt_db/firebolt_dialect.py index 965b931..6e35d19 100644 --- a/src/firebolt_db/firebolt_dialect.py +++ b/src/firebolt_db/firebolt_dialect.py @@ -302,5 +302,5 @@ def do_commit(self, dbapi_connection: AlchemyConnection) -> None: dialect = FireboltDialect -def get_is_nullable(column_is_nullable: str) -> bool: - return column_is_nullable.lower() == "yes" +def get_is_nullable(column_is_nullable: int) -> bool: + return column_is_nullable == 1 diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index dbffc6a..fb99c34 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -34,9 +34,9 @@ def test_data_write(self, connection: Connection, fact_table_name: str): connection.execute( f"INSERT INTO {fact_table_name}(idx, dummy) VALUES (1, 'some_text')" ) - result = connection.execute( - f"SELECT * FROM {fact_table_name} WHERE dummy=?", (1,) - ) + result = connection.execute(f"SELECT * FROM {fact_table_name} WHERE idx=?", 1) + assert result.fetchall() == [(1, "some_text")] + result = connection.execute(f"SELECT * FROM {fact_table_name}") assert len(result.fetchall()) == 1 # Update not supported with pytest.raises(OperationalError): diff --git a/tests/unit/test_firebolt_async_dialect.py b/tests/unit/test_firebolt_async_dialect.py index d6a49cb..6c5e06b 100644 --- a/tests/unit/test_firebolt_async_dialect.py +++ b/tests/unit/test_firebolt_async_dialect.py @@ -80,7 +80,7 @@ def test_cursor() -> AsyncCursorWrapper: assert wrapper.description == "dummy" assert wrapper.rowcount == -1 async_cursor.execute.assert_awaited_once_with( - "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a")] + "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a")], set_parameters=None ) async_cursor.fetchall.assert_awaited_once() @@ -105,7 +105,7 @@ def test_cursor() -> AsyncCursorWrapper: assert wrapper.description is None assert wrapper.rowcount == 100 async_cursor.execute.assert_awaited_once_with( - "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a")] + "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a")], set_parameters=None ) async_cursor.fetchall.assert_not_awaited() diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index 1753bef..53677f6 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -64,6 +64,11 @@ def test_do_execute( self, dialect: FireboltDialect, cursor: mock.Mock(spec=MockCursor) ): dialect._set_parameters = {"a": "b"} + dialect.do_execute(cursor, "SELECT *", None) + cursor.execute.assert_called_once_with( + "SELECT *", parameters=None, set_parameters={"a": "b"} + ) + cursor.execute.reset_mock() dialect.do_execute(cursor, "SELECT *", (1, 22), None) cursor.execute.assert_called_once_with( "SELECT *", parameters=(1, 22), set_parameters={"a": "b"} @@ -136,8 +141,8 @@ def getitem(self, idx): return mock.Mock(__getitem__=getitem) connection.execute.return_value = [ - multi_column_row(["name1", "INT", "YES"]), - multi_column_row(["name2", "date", "no"]), + multi_column_row(["name1", "INT", 1]), + multi_column_row(["name2", "date", 0]), ] expected_query = """ @@ -223,11 +228,8 @@ def test_unicode_description( def test_get_is_nullable(): - assert firebolt_db.firebolt_dialect.get_is_nullable("YES") - assert firebolt_db.firebolt_dialect.get_is_nullable("yes") - assert not firebolt_db.firebolt_dialect.get_is_nullable("NO") - assert not firebolt_db.firebolt_dialect.get_is_nullable("no") - assert not firebolt_db.firebolt_dialect.get_is_nullable("ABC") + assert firebolt_db.firebolt_dialect.get_is_nullable(1) + assert not firebolt_db.firebolt_dialect.get_is_nullable(0) def test_types():