From 0165362480628f781d12bb09df99fe3fdbd52112 Mon Sep 17 00:00:00 2001 From: Petro Date: Wed, 17 Nov 2021 15:11:57 +0000 Subject: [PATCH 01/17] Splitting and adding unit tests --- tests/unit/test_firebolt_dialect.py | 183 ++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 tests/unit/test_firebolt_dialect.py diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py new file mode 100644 index 0000000..e978cb6 --- /dev/null +++ b/tests/unit/test_firebolt_dialect.py @@ -0,0 +1,183 @@ +import enum +import os + +import sqlalchemy +from sqlalchemy.sql.expression import false, true +from firebolt_db import firebolt_dialect + +from sqlalchemy.engine import url + +from unittest import mock +from pytest import fixture + +import firebolt_db + +class DBApi(): + def execute(): + pass + def executemany(): + pass + +@fixture +def dialect(): + return firebolt_dialect.FireboltDialect() + +@fixture +def mock_connection(): + return mock.Mock(spec=DBApi) + +class TestFireboltDialect: + + def test_create_dialect(self, dialect): + assert issubclass(firebolt_dialect.dialect, firebolt_dialect.FireboltDialect) + assert isinstance(firebolt_dialect.FireboltDialect.dbapi(), type(firebolt_db)) + assert dialect.name == "firebolt" + assert dialect.driver == "firebolt" + assert issubclass(dialect.preparer, firebolt_dialect.FireboltIdentifierPreparer) + assert issubclass(dialect.statement_compiler, firebolt_dialect.FireboltCompiler) + #assert issubclass(dialect.type_compiler, firebolt_dialect.FireboltTypeCompiler) + assert isinstance(dialect.type_compiler, firebolt_dialect.FireboltTypeCompiler) # ?? + assert dialect.context == {} + + def test_create_connect_args(self, dialect): + connection_url = "test_engine://test_user@email:test_password@test_db_name/test_engine_name" + u = url.make_url(connection_url) + with mock.patch.dict(os.environ, {"FIREBOLT_BASE_URL": "test_url"}): + result_list, result_dict = dialect.create_connect_args(u) + assert result_dict["engine_name"] == "test_engine_name" + assert result_dict["username"] == "test_user@email" + assert result_dict["password"] == "test_password" + assert result_dict["database"] == "test_db_name" + assert result_dict["api_endpoint"] == "test_url" + # No endpoint override + with mock.patch.dict(os.environ, {}, clear=True): + result_list, result_dict = dialect.create_connect_args(u) + assert "api_endpoint" not in result_dict + + def test_schema_names(self, dialect, mock_connection): + def row_with_schema(name): + return mock.Mock(schema_name=name) + + mock_connection.execute.return_value = [ + row_with_schema("schema1"), + row_with_schema("schema2"), + ] + result = dialect.get_schema_names(mock_connection) + assert result == ["schema1", "schema2"] + mock_connection.execute.assert_called_once_with( + "select schema_name from information_schema.databases" + ) + + def test_table_names(self, dialect, mock_connection): + def row_with_table_name(name): + return mock.Mock(table_name=name) + + mock_connection.execute.return_value = [ + row_with_table_name("table1"), + row_with_table_name("table2"), + ] + + result = dialect.get_table_names(mock_connection) + assert result == ["table1", "table2"] + mock_connection.execute.assert_called_once_with( + "select table_name from information_schema.tables" + ) + mock_connection.execute.reset_mock() + result = dialect.get_table_names(mock_connection, schema="schema") + assert result == ["table1", "table2"] + mock_connection.execute.assert_called_once_with( + "select table_name from information_schema.tables" + " where table_schema = 'schema'" + ) + + def test_view_names(self, dialect, mock_connection): + assert dialect.get_view_names(mock_connection) == [] + + def test_table_options(self, dialect, mock_connection): + assert dialect.get_table_options(mock_connection, "table") == {} + + def test_columns(self, dialect, mock_connection): + def multi_column_row(columns): + def getitem(self, idx): + for i, result in enumerate(columns): + if idx == i: + return result + + return mock.Mock(__getitem__=getitem) + + mock_connection.execute.return_value = [ + multi_column_row(["name1", "INT", "YES"]), + multi_column_row(["name2", "date", "no"]), + ] + + result = dialect.get_columns(mock_connection, "table") + assert result == [ + { + "name": "name1", + "type": sqlalchemy.types.INTEGER, + "nullable": True, + "default": None, + }, + { + "name": "name2", + "type": sqlalchemy.types.DATE, + "nullable": False, + "default": None, + }, + ] + mock_connection.execute.assert_called_once_with( + """ + select column_name, + data_type, + is_nullable + from information_schema.columns + where table_name = 'table' + """ + ) + mock_connection.execute.reset_mock() + result = dialect.get_columns(mock_connection, "table", "schema") + mock_connection.execute.assert_called_once_with( + """ + select column_name, + data_type, + is_nullable + from information_schema.columns + where table_name = 'table' + """ + " and table_schema = 'schema'" + ) + + def test_pk_constraint(self, dialect, mock_connection): + assert dialect.get_pk_constraint(mock_connection, "table") == { + "constrained_columns": [], + "name": None, + } + def test_foreign_keys(self, dialect, mock_connection): + assert dialect.get_foreign_keys(mock_connection, "table") == [] + + def test_check_constraints(self, dialect, mock_connection): + assert dialect.get_check_constraints(mock_connection, "table") == [] + + def test_table_comment(self, dialect, mock_connection): + assert dialect.get_table_comment(mock_connection, "table") == { + "text": "" + } + + def test_indexes(self, dialect, mock_connection): + assert dialect.get_indexes(mock_connection, "table") == [] + + def test_unique_constraints(self, dialect, mock_connection): + assert dialect.get_unique_constraints(mock_connection, "table") == [] + + def test_unicode_returns(self, dialect, mock_connection): + assert dialect._check_unicode_returns(mock_connection) + + def test_unicode_description(self, dialect, mock_connection): + assert dialect._check_unicode_description(mock_connection) + +def test_get_is_nullable(): + assert firebolt_dialect.get_is_nullable("YES") + assert firebolt_dialect.get_is_nullable("yes") + assert not firebolt_dialect.get_is_nullable("NO") + assert not firebolt_dialect.get_is_nullable("no") + assert not firebolt_dialect.get_is_nullable("ABC") \ No newline at end of file From f0a48eb5eb38e2a265f95bbfebc20a3fd09dd5ae Mon Sep 17 00:00:00 2001 From: Petro Date: Wed, 17 Nov 2021 17:03:08 +0000 Subject: [PATCH 02/17] Testing dialect types --- tests/unit/test_firebolt_dialect.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index e978cb6..82b3f2f 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -175,9 +175,23 @@ def test_unicode_returns(self, dialect, mock_connection): def test_unicode_description(self, dialect, mock_connection): assert dialect._check_unicode_description(mock_connection) + def test_get_is_nullable(): assert firebolt_dialect.get_is_nullable("YES") assert firebolt_dialect.get_is_nullable("yes") assert not firebolt_dialect.get_is_nullable("NO") assert not firebolt_dialect.get_is_nullable("no") - assert not firebolt_dialect.get_is_nullable("ABC") \ No newline at end of file + assert not firebolt_dialect.get_is_nullable("ABC") + + +def test_types(): + assert firebolt_dialect.CHAR is sqlalchemy.sql.sqltypes.CHAR + assert firebolt_dialect.DATE is sqlalchemy.sql.sqltypes.DATE + assert firebolt_dialect.DATETIME is sqlalchemy.sql.sqltypes.DATETIME + assert firebolt_dialect.INTEGER is sqlalchemy.sql.sqltypes.INTEGER + assert firebolt_dialect.BIGINT is sqlalchemy.sql.sqltypes.BIGINT + assert firebolt_dialect.TIMESTAMP is sqlalchemy.sql.sqltypes.TIMESTAMP + assert firebolt_dialect.VARCHAR is sqlalchemy.sql.sqltypes.VARCHAR + assert firebolt_dialect.BOOLEAN is sqlalchemy.sql.sqltypes.BOOLEAN + assert firebolt_dialect.FLOAT is sqlalchemy.sql.sqltypes.FLOAT + assert issubclass(firebolt_dialect.ARRAY, sqlalchemy.types.TypeEngine) From 8065d39c5e43a1c51afdd678d3146f81f794766d Mon Sep 17 00:00:00 2001 From: Petro Date: Thu, 18 Nov 2021 14:10:33 +0000 Subject: [PATCH 03/17] Adding integration tests --- .../test_sqlalchemy_integration.py | 149 ++++++++++++++++++ tests/test_fireboltdialect.py | 97 ------------ 2 files changed, 149 insertions(+), 97 deletions(-) create mode 100644 tests/integration/test_sqlalchemy_integration.py delete mode 100644 tests/test_fireboltdialect.py diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py new file mode 100644 index 0000000..25338c8 --- /dev/null +++ b/tests/integration/test_sqlalchemy_integration.py @@ -0,0 +1,149 @@ +from _pytest.fixtures import fixture +import pytest +import os + +import sqlalchemy + +from firebolt_db import firebolt_dialect + +from sqlalchemy import create_engine +from sqlalchemy.dialects import registry + + +test_username = os.environ["username"] +test_password = os.environ["password"] +test_engine_name = os.environ["engine_name"] +test_db_name = os.environ["db_name"] + + +@pytest.fixture(scope="class") +def get_engine(): + registry.register("firebolt", "src.firebolt_db.firebolt_dialect", "FireboltDialect") + return create_engine(f"firebolt://{test_username}:{test_password}@{test_db_name}/{test_engine_name}") + + +@pytest.fixture(scope="class") +def get_connection(get_engine): + engine = get_engine + # TODO: once commit is implemented remove execution options + return engine.connect().execution_options(autocommit=False) + + +dialect = firebolt_dialect.FireboltDialect() + + +class TestFireboltDialect: + + test_table = "test_alchemy" + + def create_test_table(self, get_connection, get_engine, table): + connection = get_connection + connection.commit = lambda x: x + connection.execute(f""" + CREATE FACT TABLE IF NOT EXISTS {table} + ( + dummy TEXT + ) PRIMARY INDEX dummy; + """) + assert get_engine.dialect.has_table(get_engine, table) + + def drop_test_table(self, get_connection, get_engine, table): + connection = get_connection + connection.commit = lambda x: x + connection.execute(f"DROP TABLE IF EXISTS {table}") + assert not get_engine.dialect.has_table(get_engine, table) + + @pytest.fixture(scope="class", autouse=True) + def setup_test_tables(self, get_connection, get_engine): + self.create_test_table(get_connection, get_engine, self.table_name) + yield + self.drop_test_table(get_connection, get_engine, self.table_name) + + @pytest.mark.skip(reason="Commit not implemented in sdk") + def test_create_ex_table(self, get_engine, get_connection): + engine = get_engine + connection = get_connection + connection.execute(""" + CREATE EXTERNAL TABLE ex_lineitem_alchemy + ( l_orderkey LONG, + l_partkey LONG, + l_suppkey LONG, + l_linenumber INT, + l_quantity LONG, + l_extendedprice LONG, + l_discount LONG, + l_tax LONG, + l_returnflag TEXT, + l_linestatus TEXT, + l_shipdate TEXT, + l_commitdate TEXT, + l_receiptdate TEXT, + l_shipinstruct TEXT, + l_shipmode TEXT, + l_comment TEXT + ) + URL = 's3://firebolt-publishing-public/samples/tpc-h/parquet/lineitem/' + OBJECT_PATTERN = '*.parquet' + TYPE = (PARQUET); + """ + ) + assert engine.dialect.has_table(engine, "ex_lineitem_alchemy") + # Cleanup + connection.execute("DROP TABLE ex_lineitem_alchemy;") + assert not engine.dialect.has_table(engine, "ex_lineitem_alchemy") + + @pytest.mark.skip(reason="Commit not implemented in sdk") + def test_data_write(self, get_connection): + connection = get_connection + connection.execute( + "INSERT INTO test_alchemy(dummy) VALUES ('some_text')" + ) + result = connection.execute("SELECT * FROM lineitem") + assert result.rowcount == 1 + connection.execute( + "DELETE FROM lineitem WHERE l_orderkey=10" + ) + result = connection.execute("SELECT * FROM lineitem") + assert result.rowcount == 0 + + def test_get_schema_names(self, get_engine): + engine = get_engine + try: + results = dialect.get_schema_names(engine) + assert test_db_name in results + except sqlalchemy.exc.InternalError as http_err: + assert http_err != "" + + def test_has_table(self, get_engine): + schema = test_db_name + engine = get_engine + try: + results = dialect.has_table(engine, self.test_table, schema) + assert results == 1 + except sqlalchemy.exc.InternalError as http_err: + assert http_err != "" + + def test_get_table_names(self, get_engine): + schema = test_db_name + engine = get_engine + try: + results = dialect.get_table_names(engine, schema) + assert len(results) > 0 + except sqlalchemy.exc.InternalError as http_err: + assert http_err != "" + + def test_get_columns(self, get_engine): + schema = test_db_name + engine = get_engine + try: + results = dialect.get_columns(engine, self.test_table, schema) + assert len(results) > 0 + row = results[0] + assert isinstance(row, dict) + row_keys = list(row.keys()) + assert row_keys[0] == "name" + assert row_keys[1] == "type" + assert row_keys[2] == "nullable" + assert row_keys[3] == "default" + except sqlalchemy.exc.InternalError as http_err: + assert http_err != "" \ No newline at end of file diff --git a/tests/test_fireboltdialect.py b/tests/test_fireboltdialect.py deleted file mode 100644 index 7cf7f5c..0000000 --- a/tests/test_fireboltdialect.py +++ /dev/null @@ -1,97 +0,0 @@ -import pytest -import os - -import sqlalchemy - -from firebolt_db import firebolt_dialect - -from sqlalchemy.engine import url -from sqlalchemy import create_engine -from sqlalchemy.dialects import registry - -from unittest import mock - - -test_username = os.environ["username"] -test_password = os.environ["password"] -test_engine_name = os.environ["engine_name"] -test_db_name = os.environ["db_name"] - - -@pytest.fixture -def get_engine(): - registry.register("firebolt", "src.firebolt_db.firebolt_dialect", "FireboltDialect") - return create_engine(f"firebolt://{test_username}:{test_password}@{test_db_name}/{test_engine_name}") - - -dialect = firebolt_dialect.FireboltDialect() - - -class TestFireboltDialect: - - def test_create_connect_args(self): - connection_url = "test_engine://test_user@email:test_password@test_db_name/test_engine_name" - u = url.make_url(connection_url) - with mock.patch.dict(os.environ, {"FIREBOLT_BASE_URL": "test_url"}): - result_list, result_dict = dialect.create_connect_args(u) - assert result_dict["engine_name"] == "test_engine_name" - assert result_dict["username"] == "test_user@email" - assert result_dict["password"] == "test_password" - assert result_dict["database"] == "test_db_name" - assert result_dict["api_endpoint"] == "test_url" - # No endpoint override - with mock.patch.dict(os.environ, {}, clear=True): - result_list, result_dict = dialect.create_connect_args(u) - assert "api_endpoint" not in result_dict - - def test_get_schema_names(self, get_engine): - engine = get_engine - try: - results = dialect.get_schema_names(engine) - assert test_db_name in results - except sqlalchemy.exc.InternalError as http_err: - assert http_err != "" - - def test_has_table(self, get_engine): - table = 'ci_fact_table' - schema = test_db_name - engine = get_engine - try: - results = dialect.has_table(engine, table, schema) - assert results == 1 - except sqlalchemy.exc.InternalError as http_err: - assert http_err != "" - - def test_get_table_names(self, get_engine): - schema = test_db_name - engine = get_engine - try: - results = dialect.get_table_names(engine, schema) - assert len(results) > 0 - except sqlalchemy.exc.InternalError as http_err: - assert http_err != "" - - def test_get_columns(self, get_engine): - table = 'ci_fact_table' - schema = test_db_name - engine = get_engine - try: - results = dialect.get_columns(engine, table, schema) - assert len(results) > 0 - row = results[0] - assert isinstance(row, dict) - row_keys = list(row.keys()) - assert row_keys[0] == "name" - assert row_keys[1] == "type" - assert row_keys[2] == "nullable" - assert row_keys[3] == "default" - except sqlalchemy.exc.InternalError as http_err: - assert http_err != "" - - -def test_get_is_nullable(): - assert firebolt_dialect.get_is_nullable("YES") - assert firebolt_dialect.get_is_nullable("yes") - assert not firebolt_dialect.get_is_nullable("NO") - assert not firebolt_dialect.get_is_nullable("no") - assert not firebolt_dialect.get_is_nullable("ABC") From da0914444c5dc4103f44f1da3d7bdb333928ac05 Mon Sep 17 00:00:00 2001 From: Petro Date: Thu, 18 Nov 2021 15:05:09 +0000 Subject: [PATCH 04/17] Not masking errors in the integration --- .../test_sqlalchemy_integration.py | 46 +++++++------------ 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 25338c8..5b24b01 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -55,9 +55,9 @@ def drop_test_table(self, get_connection, get_engine, table): @pytest.fixture(scope="class", autouse=True) def setup_test_tables(self, get_connection, get_engine): - self.create_test_table(get_connection, get_engine, self.table_name) + self.create_test_table(get_connection, get_engine, self.test_table) yield - self.drop_test_table(get_connection, get_engine, self.table_name) + self.drop_test_table(get_connection, get_engine, self.test_table) @pytest.mark.skip(reason="Commit not implemented in sdk") def test_create_ex_table(self, get_engine, get_connection): @@ -108,42 +108,30 @@ def test_data_write(self, get_connection): def test_get_schema_names(self, get_engine): engine = get_engine - try: - results = dialect.get_schema_names(engine) - assert test_db_name in results - except sqlalchemy.exc.InternalError as http_err: - assert http_err != "" + results = dialect.get_schema_names(engine) + assert test_db_name in results def test_has_table(self, get_engine): schema = test_db_name engine = get_engine - try: - results = dialect.has_table(engine, self.test_table, schema) - assert results == 1 - except sqlalchemy.exc.InternalError as http_err: - assert http_err != "" + results = dialect.has_table(engine, self.test_table, schema) + assert results == 1 def test_get_table_names(self, get_engine): schema = test_db_name engine = get_engine - try: - results = dialect.get_table_names(engine, schema) - assert len(results) > 0 - except sqlalchemy.exc.InternalError as http_err: - assert http_err != "" + results = dialect.get_table_names(engine, schema) + assert len(results) > 0 def test_get_columns(self, get_engine): schema = test_db_name engine = get_engine - try: - results = dialect.get_columns(engine, self.test_table, schema) - assert len(results) > 0 - row = results[0] - assert isinstance(row, dict) - row_keys = list(row.keys()) - assert row_keys[0] == "name" - assert row_keys[1] == "type" - assert row_keys[2] == "nullable" - assert row_keys[3] == "default" - except sqlalchemy.exc.InternalError as http_err: - assert http_err != "" \ No newline at end of file + results = dialect.get_columns(engine, self.test_table, schema) + assert len(results) > 0 + row = results[0] + assert isinstance(row, dict) + row_keys = list(row.keys()) + assert row_keys[0] == "name" + assert row_keys[1] == "type" + assert row_keys[2] == "nullable" + assert row_keys[3] == "default" \ No newline at end of file From f553a42281097ffdfa5ee842fc4df2c1f6c32c67 Mon Sep 17 00:00:00 2001 From: Petro Date: Thu, 18 Nov 2021 15:34:37 +0000 Subject: [PATCH 05/17] Fix skipped tests --- .../test_sqlalchemy_integration.py | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 5b24b01..e266be5 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -3,6 +3,7 @@ import os import sqlalchemy +from sqlalchemy.exc import OperationalError from firebolt_db import firebolt_dialect @@ -25,7 +26,7 @@ def get_engine(): @pytest.fixture(scope="class") def get_connection(get_engine): engine = get_engine - # TODO: once commit is implemented remove execution options + # TODO: once commit is implemented in sdk remove execution options return engine.connect().execution_options(autocommit=False) @@ -42,8 +43,9 @@ def create_test_table(self, get_connection, get_engine, table): connection.execute(f""" CREATE FACT TABLE IF NOT EXISTS {table} ( + idx INT, dummy TEXT - ) PRIMARY INDEX dummy; + ) PRIMARY INDEX idx; """) assert get_engine.dialect.has_table(get_engine, table) @@ -96,15 +98,20 @@ def test_create_ex_table(self, get_engine, get_connection): def test_data_write(self, get_connection): connection = get_connection connection.execute( - "INSERT INTO test_alchemy(dummy) VALUES ('some_text')" + "INSERT INTO test_alchemy(idx, dummy) VALUES (1, 'some_text')" ) - result = connection.execute("SELECT * FROM lineitem") - assert result.rowcount == 1 - connection.execute( - "DELETE FROM lineitem WHERE l_orderkey=10" - ) - result = connection.execute("SELECT * FROM lineitem") - assert result.rowcount == 0 + result = connection.execute("SELECT * FROM test_alchemy") + assert len(result.fetchall()) == 1 + # Update not supported + with pytest.raises(OperationalError): + connection.execute( + "UPDATE test_alchemy SET dummy='some_other_text' WHERE idx=1" + ) + # Delete not supported + with pytest.raises(OperationalError): + connection.execute( + "DELETE FROM test_alchemy WHERE idx=1" + ) def test_get_schema_names(self, get_engine): engine = get_engine From 54ae2ade28ccd43392eb4d8ae0475a97ecd5828b Mon Sep 17 00:00:00 2001 From: Petro Date: Thu, 18 Nov 2021 15:53:39 +0000 Subject: [PATCH 06/17] Adding condition to failing tests --- tests/integration/test_sqlalchemy_integration.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index e266be5..31f5d64 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -2,7 +2,6 @@ import pytest import os -import sqlalchemy from sqlalchemy.exc import OperationalError from firebolt_db import firebolt_dialect @@ -10,6 +9,8 @@ from sqlalchemy import create_engine from sqlalchemy.dialects import registry +import firebolt as firebolt_sdk + test_username = os.environ["username"] test_password = os.environ["password"] @@ -26,8 +27,11 @@ def get_engine(): @pytest.fixture(scope="class") def get_connection(get_engine): engine = get_engine - # TODO: once commit is implemented in sdk remove execution options - return engine.connect().execution_options(autocommit=False) + if hasattr(firebolt_sdk.db.connection.Connection, "commit"): + return engine.connect() + else: + # Disabling autocommit allows for table creation/destruction without trying to call non-existing parameters + return engine.connect().execution_options(autocommit=False) dialect = firebolt_dialect.FireboltDialect() @@ -61,7 +65,7 @@ def setup_test_tables(self, get_connection, get_engine): yield self.drop_test_table(get_connection, get_engine, self.test_table) - @pytest.mark.skip(reason="Commit not implemented in sdk") + @pytest.mark.skipif(not hasattr(firebolt_sdk.db.connection.Connection, "commit"), reason="Commit not implemented in sdk") def test_create_ex_table(self, get_engine, get_connection): engine = get_engine connection = get_connection @@ -94,7 +98,7 @@ def test_create_ex_table(self, get_engine, get_connection): connection.execute("DROP TABLE ex_lineitem_alchemy;") assert not engine.dialect.has_table(engine, "ex_lineitem_alchemy") - @pytest.mark.skip(reason="Commit not implemented in sdk") + @pytest.mark.skipif(not hasattr(firebolt_sdk.db.connection.Connection, "commit"), reason="Commit not implemented in sdk") def test_data_write(self, get_connection): connection = get_connection connection.execute( From 3bc3407e9c685260db6914e8d52e69e803e37e09 Mon Sep 17 00:00:00 2001 From: Petro Date: Thu, 18 Nov 2021 16:04:21 +0000 Subject: [PATCH 07/17] Adding explanation for odd unit test --- tests/unit/test_firebolt_dialect.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index 82b3f2f..fe7ed77 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -35,8 +35,8 @@ def test_create_dialect(self, dialect): assert dialect.driver == "firebolt" assert issubclass(dialect.preparer, firebolt_dialect.FireboltIdentifierPreparer) assert issubclass(dialect.statement_compiler, firebolt_dialect.FireboltCompiler) - #assert issubclass(dialect.type_compiler, firebolt_dialect.FireboltTypeCompiler) - assert isinstance(dialect.type_compiler, firebolt_dialect.FireboltTypeCompiler) # ?? + # SQLAlchemy's DefaultDialect creates an instance of type_compiler behind the scenes + assert isinstance(dialect.type_compiler, firebolt_dialect.FireboltTypeCompiler) assert dialect.context == {} def test_create_connect_args(self, dialect): From 9813281b159c7f51621bc90271196d941b3bb16f Mon Sep 17 00:00:00 2001 From: Petro Date: Thu, 18 Nov 2021 16:47:38 +0000 Subject: [PATCH 08/17] Formatting linting changes --- .../test_sqlalchemy_integration.py | 30 ++++++++----- tests/unit/test_firebolt_dialect.py | 43 ++++++++++--------- 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 31f5d64..15ae06a 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -1,4 +1,3 @@ -from _pytest.fixtures import fixture import pytest import os @@ -21,7 +20,9 @@ @pytest.fixture(scope="class") def get_engine(): registry.register("firebolt", "src.firebolt_db.firebolt_dialect", "FireboltDialect") - return create_engine(f"firebolt://{test_username}:{test_password}@{test_db_name}/{test_engine_name}") + return create_engine( + f"firebolt://{test_username}:{test_password}@{test_db_name}/{test_engine_name}" + ) @pytest.fixture(scope="class") @@ -44,13 +45,15 @@ class TestFireboltDialect: def create_test_table(self, get_connection, get_engine, table): connection = get_connection connection.commit = lambda x: x - connection.execute(f""" + connection.execute( + f""" CREATE FACT TABLE IF NOT EXISTS {table} ( idx INT, dummy TEXT ) PRIMARY INDEX idx; - """) + """ + ) assert get_engine.dialect.has_table(get_engine, table) def drop_test_table(self, get_connection, get_engine, table): @@ -65,11 +68,15 @@ def setup_test_tables(self, get_connection, get_engine): yield self.drop_test_table(get_connection, get_engine, self.test_table) - @pytest.mark.skipif(not hasattr(firebolt_sdk.db.connection.Connection, "commit"), reason="Commit not implemented in sdk") + @pytest.mark.skipif( + not hasattr(firebolt_sdk.db.connection.Connection, "commit"), + reason="Commit not implemented in sdk", + ) def test_create_ex_table(self, get_engine, get_connection): engine = get_engine connection = get_connection - connection.execute(""" + connection.execute( + """ CREATE EXTERNAL TABLE ex_lineitem_alchemy ( l_orderkey LONG, l_partkey LONG, @@ -98,7 +105,10 @@ def test_create_ex_table(self, get_engine, get_connection): connection.execute("DROP TABLE ex_lineitem_alchemy;") assert not engine.dialect.has_table(engine, "ex_lineitem_alchemy") - @pytest.mark.skipif(not hasattr(firebolt_sdk.db.connection.Connection, "commit"), reason="Commit not implemented in sdk") + @pytest.mark.skipif( + not hasattr(firebolt_sdk.db.connection.Connection, "commit"), + reason="Commit not implemented in sdk", + ) def test_data_write(self, get_connection): connection = get_connection connection.execute( @@ -113,9 +123,7 @@ def test_data_write(self, get_connection): ) # Delete not supported with pytest.raises(OperationalError): - connection.execute( - "DELETE FROM test_alchemy WHERE idx=1" - ) + connection.execute("DELETE FROM test_alchemy WHERE idx=1") def test_get_schema_names(self, get_engine): engine = get_engine @@ -145,4 +153,4 @@ def test_get_columns(self, get_engine): assert row_keys[0] == "name" assert row_keys[1] == "type" assert row_keys[2] == "nullable" - assert row_keys[3] == "default" \ No newline at end of file + assert row_keys[3] == "default" diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index fe7ed77..b73828c 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -1,35 +1,35 @@ -import enum import os +from unittest import mock import sqlalchemy -from sqlalchemy.sql.expression import false, true -from firebolt_db import firebolt_dialect - -from sqlalchemy.engine import url - -from unittest import mock from pytest import fixture +from sqlalchemy.engine import url import firebolt_db +from firebolt_db import firebolt_dialect + -class DBApi(): +class DBApi: def execute(): pass + def executemany(): pass + @fixture def dialect(): return firebolt_dialect.FireboltDialect() + @fixture def mock_connection(): return mock.Mock(spec=DBApi) -class TestFireboltDialect: +class TestFireboltDialect: def test_create_dialect(self, dialect): - assert issubclass(firebolt_dialect.dialect, firebolt_dialect.FireboltDialect) + assert issubclass(firebolt_dialect.dialect, firebolt_dialect.FireboltDialect) assert isinstance(firebolt_dialect.FireboltDialect.dbapi(), type(firebolt_db)) assert dialect.name == "firebolt" assert dialect.driver == "firebolt" @@ -40,7 +40,9 @@ def test_create_dialect(self, dialect): assert dialect.context == {} def test_create_connect_args(self, dialect): - connection_url = "test_engine://test_user@email:test_password@test_db_name/test_engine_name" + connection_url = ( + "test_engine://test_user@email:test_password@test_db_name/test_engine_name" + ) u = url.make_url(connection_url) with mock.patch.dict(os.environ, {"FIREBOLT_BASE_URL": "test_url"}): result_list, result_dict = dialect.create_connect_args(u) @@ -65,7 +67,7 @@ def row_with_schema(name): result = dialect.get_schema_names(mock_connection) assert result == ["schema1", "schema2"] mock_connection.execute.assert_called_once_with( - "select schema_name from information_schema.databases" + "select schema_name from information_schema.databases" ) def test_table_names(self, dialect, mock_connection): @@ -76,7 +78,7 @@ def row_with_table_name(name): row_with_table_name("table1"), row_with_table_name("table2"), ] - + result = dialect.get_table_names(mock_connection) assert result == ["table1", "table2"] mock_connection.execute.assert_called_once_with( @@ -89,7 +91,7 @@ def row_with_table_name(name): "select table_name from information_schema.tables" " where table_schema = 'schema'" ) - + def test_view_names(self, dialect, mock_connection): assert dialect.get_view_names(mock_connection) == [] @@ -102,7 +104,7 @@ def getitem(self, idx): for i, result in enumerate(columns): if idx == i: return result - + return mock.Mock(__getitem__=getitem) mock_connection.execute.return_value = [ @@ -144,24 +146,23 @@ def getitem(self, idx): from information_schema.columns where table_name = 'table' """ - " and table_schema = 'schema'" + " and table_schema = 'schema'" ) - + def test_pk_constraint(self, dialect, mock_connection): assert dialect.get_pk_constraint(mock_connection, "table") == { "constrained_columns": [], "name": None, } + def test_foreign_keys(self, dialect, mock_connection): assert dialect.get_foreign_keys(mock_connection, "table") == [] def test_check_constraints(self, dialect, mock_connection): assert dialect.get_check_constraints(mock_connection, "table") == [] - + def test_table_comment(self, dialect, mock_connection): - assert dialect.get_table_comment(mock_connection, "table") == { - "text": "" - } + assert dialect.get_table_comment(mock_connection, "table") == {"text": ""} def test_indexes(self, dialect, mock_connection): assert dialect.get_indexes(mock_connection, "table") == [] From e37fb46e0b4d46c1065226357fd853bcfcf3f536 Mon Sep 17 00:00:00 2001 From: Petro Date: Fri, 19 Nov 2021 13:50:17 +0000 Subject: [PATCH 09/17] Refactor and cleanup --- tests/integration/conftest.py | 63 ++++++++ .../test_sqlalchemy_integration.py | 153 ++++++------------ 2 files changed, 115 insertions(+), 101 deletions(-) create mode 100644 tests/integration/conftest.py diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..fc2c735 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,63 @@ + +from os import environ + +from logging import getLogger +from pytest import fixture + +from sqlalchemy import create_engine +from sqlalchemy.dialects import registry + +import firebolt as firebolt_sdk + +LOGGER = getLogger(__name__) + +ENGINE_NAME_ENV = "engine_name" +DATABASE_NAME_ENV = "database_name" +USERNAME_ENV = "username" +PASSWORD_ENV = "password" + + +def must_env(var_name: str) -> str: + assert var_name in environ, f"Expected {var_name} to be provided in environment" + LOGGER.info(f"{var_name}: {environ[var_name]}") + return environ[var_name] + + +@fixture(scope="session") +def engine_name() -> str: + return must_env(ENGINE_NAME_ENV) + + +@fixture(scope="session") +def database_name() -> str: + return must_env(DATABASE_NAME_ENV) + + +@fixture(scope="session") +def username() -> str: + return must_env(USERNAME_ENV) + + +@fixture(scope="session") +def password() -> str: + return must_env(PASSWORD_ENV) + + +@fixture(scope="session") +def engine(username, password, database_name, engine_name): + registry.register( + "firebolt", "src.firebolt_db.firebolt_dialect", "FireboltDialect") + return create_engine( + f"firebolt://{username}:{password}@{database_name}/{engine_name}" + ) + + +@fixture(scope="session") +def connection(engine): + engine = engine + if hasattr(firebolt_sdk.db.connection.Connection, "commit"): + return engine.connect() + else: + # Disabling autocommit allows for table creation/destruction without + # trying to call non-existing parameters + return engine.connect().execution_options(autocommit=False) diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 15ae06a..9a40877 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -1,104 +1,62 @@ import pytest -import os from sqlalchemy.exc import OperationalError -from firebolt_db import firebolt_dialect - -from sqlalchemy import create_engine -from sqlalchemy.dialects import registry - -import firebolt as firebolt_sdk - - -test_username = os.environ["username"] -test_password = os.environ["password"] -test_engine_name = os.environ["engine_name"] -test_db_name = os.environ["db_name"] - - -@pytest.fixture(scope="class") -def get_engine(): - registry.register("firebolt", "src.firebolt_db.firebolt_dialect", "FireboltDialect") - return create_engine( - f"firebolt://{test_username}:{test_password}@{test_db_name}/{test_engine_name}" - ) - - -@pytest.fixture(scope="class") -def get_connection(get_engine): - engine = get_engine - if hasattr(firebolt_sdk.db.connection.Connection, "commit"): - return engine.connect() - else: - # Disabling autocommit allows for table creation/destruction without trying to call non-existing parameters - return engine.connect().execution_options(autocommit=False) - - -dialect = firebolt_dialect.FireboltDialect() - class TestFireboltDialect: test_table = "test_alchemy" - def create_test_table(self, get_connection, get_engine, table): - connection = get_connection - connection.commit = lambda x: x + def create_test_table(self, connection, engine, table): connection.execute( f""" - CREATE FACT TABLE IF NOT EXISTS {table} - ( - idx INT, - dummy TEXT - ) PRIMARY INDEX idx; - """ + CREATE FACT TABLE IF NOT EXISTS {table} + ( + idx INT, + dummy TEXT + ) PRIMARY INDEX idx; + """ ) - assert get_engine.dialect.has_table(get_engine, table) + assert engine.dialect.has_table(engine, table) - def drop_test_table(self, get_connection, get_engine, table): - connection = get_connection - connection.commit = lambda x: x + def drop_test_table(self, connection, engine, table): connection.execute(f"DROP TABLE IF EXISTS {table}") - assert not get_engine.dialect.has_table(get_engine, table) + assert not engine.dialect.has_table(engine, table) @pytest.fixture(scope="class", autouse=True) - def setup_test_tables(self, get_connection, get_engine): - self.create_test_table(get_connection, get_engine, self.test_table) + def setup_test_tables(self, connection, engine): + self.create_test_table(connection, engine, self.test_table) yield - self.drop_test_table(get_connection, get_engine, self.test_table) + self.drop_test_table(connection, engine, self.test_table) - @pytest.mark.skipif( - not hasattr(firebolt_sdk.db.connection.Connection, "commit"), - reason="Commit not implemented in sdk", + @pytest.mark.skip( + reason="Commit not implemented in sdk" ) - def test_create_ex_table(self, get_engine, get_connection): - engine = get_engine - connection = get_connection + def test_create_ex_table(self, engine, connection): connection.execute( - """ - CREATE EXTERNAL TABLE ex_lineitem_alchemy - ( l_orderkey LONG, - l_partkey LONG, - l_suppkey LONG, - l_linenumber INT, - l_quantity LONG, - l_extendedprice LONG, - l_discount LONG, - l_tax LONG, - l_returnflag TEXT, - l_linestatus TEXT, - l_shipdate TEXT, - l_commitdate TEXT, - l_receiptdate TEXT, - l_shipinstruct TEXT, - l_shipmode TEXT, - l_comment TEXT - ) - URL = 's3://firebolt-publishing-public/samples/tpc-h/parquet/lineitem/' - OBJECT_PATTERN = '*.parquet' - TYPE = (PARQUET); - """ + """ + CREATE EXTERNAL TABLE ex_lineitem_alchemy + ( l_orderkey LONG, + l_partkey LONG, + l_suppkey LONG, + l_linenumber INT, + l_quantity LONG, + l_extendedprice LONG, + l_discount LONG, + l_tax LONG, + l_returnflag TEXT, + l_linestatus TEXT, + l_shipdate TEXT, + l_commitdate TEXT, + l_receiptdate TEXT, + l_shipinstruct TEXT, + l_shipmode TEXT, + l_comment TEXT + ) + URL = 's3://firebolt-publishing-public/samples/tpc-h/parquet/lineitem/' + OBJECT_PATTERN = '*.parquet' + TYPE = (PARQUET); + """ ) assert engine.dialect.has_table(engine, "ex_lineitem_alchemy") # Cleanup @@ -106,11 +64,9 @@ def test_create_ex_table(self, get_engine, get_connection): assert not engine.dialect.has_table(engine, "ex_lineitem_alchemy") @pytest.mark.skipif( - not hasattr(firebolt_sdk.db.connection.Connection, "commit"), - reason="Commit not implemented in sdk", + reason="Commit not implemented in sdk" ) - def test_data_write(self, get_connection): - connection = get_connection + def test_data_write(self, connection): connection.execute( "INSERT INTO test_alchemy(idx, dummy) VALUES (1, 'some_text')" ) @@ -125,27 +81,22 @@ def test_data_write(self, get_connection): with pytest.raises(OperationalError): connection.execute("DELETE FROM test_alchemy WHERE idx=1") - def test_get_schema_names(self, get_engine): - engine = get_engine - results = dialect.get_schema_names(engine) - assert test_db_name in results + def test_get_schema_names(self, engine, database_name): + results = engine.dialect.get_schema_names(engine) + assert database_name in results - def test_has_table(self, get_engine): - schema = test_db_name - engine = get_engine - results = dialect.has_table(engine, self.test_table, schema) + def test_has_table(self, engine, database_name): + results = engine.dialect.has_table( + engine, self.test_table, database_name) assert results == 1 - def test_get_table_names(self, get_engine): - schema = test_db_name - engine = get_engine - results = dialect.get_table_names(engine, schema) + def test_get_table_names(self, engine, database_name): + results = engine.dialect.get_table_names(engine, database_name) assert len(results) > 0 - def test_get_columns(self, get_engine): - schema = test_db_name - engine = get_engine - results = dialect.get_columns(engine, self.test_table, schema) + def test_get_columns(self, engine, database_name): + results = engine.dialect.get_columns( + engine, self.test_table, database_name) assert len(results) > 0 row = results[0] assert isinstance(row, dict) From 25093d1dc7285e87bfaf1aa4f8cede06bcce14a1 Mon Sep 17 00:00:00 2001 From: Petro Date: Fri, 19 Nov 2021 13:51:42 +0000 Subject: [PATCH 10/17] Fix typo --- tests/integration/test_sqlalchemy_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 9a40877..cf3a4b7 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -63,7 +63,7 @@ def test_create_ex_table(self, engine, connection): connection.execute("DROP TABLE ex_lineitem_alchemy;") assert not engine.dialect.has_table(engine, "ex_lineitem_alchemy") - @pytest.mark.skipif( + @pytest.mark.skip( reason="Commit not implemented in sdk" ) def test_data_write(self, connection): From cb2fc4749127336129be45283e3503e3ed55f826 Mon Sep 17 00:00:00 2001 From: Petro Date: Fri, 19 Nov 2021 14:17:53 +0000 Subject: [PATCH 11/17] Refactor unit tests --- tests/unit/conftest.py | 23 ++++ tests/unit/test_firebolt_dialect.py | 158 +++++++++++++--------------- 2 files changed, 94 insertions(+), 87 deletions(-) create mode 100644 tests/unit/conftest.py diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000..9c22e37 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,23 @@ + + +from pytest import fixture +from firebolt_db import firebolt_dialect +from unittest import mock + + +class DBApi: + def execute(): + pass + + def executemany(): + pass + + +@fixture +def dialect(): + return firebolt_dialect.FireboltDialect() + + +@fixture +def connection(): + return mock.Mock(spec=DBApi) diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index b73828c..a2d9ea8 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -2,41 +2,27 @@ from unittest import mock import sqlalchemy -from pytest import fixture from sqlalchemy.engine import url import firebolt_db from firebolt_db import firebolt_dialect -class DBApi: - def execute(): - pass - - def executemany(): - pass - - -@fixture -def dialect(): - return firebolt_dialect.FireboltDialect() - - -@fixture -def mock_connection(): - return mock.Mock(spec=DBApi) - - class TestFireboltDialect: def test_create_dialect(self, dialect): - assert issubclass(firebolt_dialect.dialect, firebolt_dialect.FireboltDialect) - assert isinstance(firebolt_dialect.FireboltDialect.dbapi(), type(firebolt_db)) + assert issubclass(firebolt_dialect.dialect, + firebolt_dialect.FireboltDialect) + assert isinstance( + firebolt_dialect.FireboltDialect.dbapi(), type(firebolt_db)) assert dialect.name == "firebolt" assert dialect.driver == "firebolt" - assert issubclass(dialect.preparer, firebolt_dialect.FireboltIdentifierPreparer) - assert issubclass(dialect.statement_compiler, firebolt_dialect.FireboltCompiler) + assert issubclass(dialect.preparer, + firebolt_dialect.FireboltIdentifierPreparer) + assert issubclass(dialect.statement_compiler, + firebolt_dialect.FireboltCompiler) # SQLAlchemy's DefaultDialect creates an instance of type_compiler behind the scenes - assert isinstance(dialect.type_compiler, firebolt_dialect.FireboltTypeCompiler) + assert isinstance(dialect.type_compiler, + firebolt_dialect.FireboltTypeCompiler) assert dialect.context == {} def test_create_connect_args(self, dialect): @@ -51,54 +37,56 @@ def test_create_connect_args(self, dialect): assert result_dict["password"] == "test_password" assert result_dict["database"] == "test_db_name" assert result_dict["api_endpoint"] == "test_url" + assert result_list == [] # No endpoint override with mock.patch.dict(os.environ, {}, clear=True): result_list, result_dict = dialect.create_connect_args(u) assert "api_endpoint" not in result_dict + assert result_list == [] - def test_schema_names(self, dialect, mock_connection): + def test_schema_names(self, dialect, connection): def row_with_schema(name): return mock.Mock(schema_name=name) - mock_connection.execute.return_value = [ + connection.execute.return_value = [ row_with_schema("schema1"), row_with_schema("schema2"), ] - result = dialect.get_schema_names(mock_connection) + result = dialect.get_schema_names(connection) assert result == ["schema1", "schema2"] - mock_connection.execute.assert_called_once_with( + connection.execute.assert_called_once_with( "select schema_name from information_schema.databases" ) - def test_table_names(self, dialect, mock_connection): + def test_table_names(self, dialect, connection): def row_with_table_name(name): return mock.Mock(table_name=name) - mock_connection.execute.return_value = [ + connection.execute.return_value = [ row_with_table_name("table1"), row_with_table_name("table2"), ] - result = dialect.get_table_names(mock_connection) + result = dialect.get_table_names(connection) assert result == ["table1", "table2"] - mock_connection.execute.assert_called_once_with( + connection.execute.assert_called_once_with( "select table_name from information_schema.tables" ) - mock_connection.execute.reset_mock() - result = dialect.get_table_names(mock_connection, schema="schema") + connection.execute.reset_mock() + result = dialect.get_table_names(connection, schema="schema") assert result == ["table1", "table2"] - mock_connection.execute.assert_called_once_with( + connection.execute.assert_called_once_with( "select table_name from information_schema.tables" " where table_schema = 'schema'" ) - def test_view_names(self, dialect, mock_connection): - assert dialect.get_view_names(mock_connection) == [] + def test_view_names(self, dialect, connection): + assert dialect.get_view_names(connection) == [] - def test_table_options(self, dialect, mock_connection): - assert dialect.get_table_options(mock_connection, "table") == {} + def test_table_options(self, dialect, connection): + assert dialect.get_table_options(connection, "table") == {} - def test_columns(self, dialect, mock_connection): + def test_columns(self, dialect, connection): def multi_column_row(columns): def getitem(self, idx): for i, result in enumerate(columns): @@ -107,74 +95,70 @@ def getitem(self, idx): return mock.Mock(__getitem__=getitem) - mock_connection.execute.return_value = [ + connection.execute.return_value = [ multi_column_row(["name1", "INT", "YES"]), multi_column_row(["name2", "date", "no"]), ] - result = dialect.get_columns(mock_connection, "table") - assert result == [ - { - "name": "name1", - "type": sqlalchemy.types.INTEGER, - "nullable": True, - "default": None, - }, - { - "name": "name2", - "type": sqlalchemy.types.DATE, - "nullable": False, - "default": None, - }, - ] - mock_connection.execute.assert_called_once_with( - """ + expected_query = """ select column_name, data_type, is_nullable from information_schema.columns where table_name = 'table' """ - ) - mock_connection.execute.reset_mock() - result = dialect.get_columns(mock_connection, "table", "schema") - mock_connection.execute.assert_called_once_with( - """ - select column_name, - data_type, - is_nullable - from information_schema.columns - where table_name = 'table' - """ - " and table_schema = 'schema'" - ) - def test_pk_constraint(self, dialect, mock_connection): - assert dialect.get_pk_constraint(mock_connection, "table") == { + expected_query_schema = expected_query + " and table_schema = 'schema'" + + for call, expected_query in ( + (lambda: dialect.get_columns(connection, "table"), expected_query), + (lambda: dialect.get_columns(connection, + "table", "schema"), expected_query_schema) + ): + + assert call() == [ + { + "name": "name1", + "type": sqlalchemy.types.INTEGER, + "nullable": True, + "default": None, + }, + { + "name": "name2", + "type": sqlalchemy.types.DATE, + "nullable": False, + "default": None, + }, + ] + connection.execute.assert_called_once_with(expected_query) + connection.execute.reset_mock() + + def test_pk_constraint(self, dialect, connection): + assert dialect.get_pk_constraint(connection, "table") == { "constrained_columns": [], "name": None, } - def test_foreign_keys(self, dialect, mock_connection): - assert dialect.get_foreign_keys(mock_connection, "table") == [] + def test_foreign_keys(self, dialect, connection): + assert dialect.get_foreign_keys(connection, "table") == [] - def test_check_constraints(self, dialect, mock_connection): - assert dialect.get_check_constraints(mock_connection, "table") == [] + def test_check_constraints(self, dialect, connection): + assert dialect.get_check_constraints(connection, "table") == [] - def test_table_comment(self, dialect, mock_connection): - assert dialect.get_table_comment(mock_connection, "table") == {"text": ""} + def test_table_comment(self, dialect, connection): + assert dialect.get_table_comment(connection, "table") == {"text": ""} - def test_indexes(self, dialect, mock_connection): - assert dialect.get_indexes(mock_connection, "table") == [] + def test_indexes(self, dialect, connection): + assert dialect.get_indexes(connection, "table") == [] - def test_unique_constraints(self, dialect, mock_connection): - assert dialect.get_unique_constraints(mock_connection, "table") == [] + def test_unique_constraints(self, dialect, connection): + assert dialect.get_unique_constraints(connection, "table") == [] - def test_unicode_returns(self, dialect, mock_connection): - assert dialect._check_unicode_returns(mock_connection) + def test_unicode_returns(self, dialect, connection): + assert dialect._check_unicode_returns(connection) - def test_unicode_description(self, dialect, mock_connection): - assert dialect._check_unicode_description(mock_connection) + def test_unicode_description(self, dialect, connection): + assert dialect._check_unicode_description(connection) def test_get_is_nullable(): From 82f35ba315d390838657cb26142c48679b05649b Mon Sep 17 00:00:00 2001 From: Petro Date: Fri, 19 Nov 2021 14:18:09 +0000 Subject: [PATCH 12/17] Remove ci folder --- ci/firebolt_ingest_data.py | 350 ------------------------------------- 1 file changed, 350 deletions(-) delete mode 100644 ci/firebolt_ingest_data.py diff --git a/ci/firebolt_ingest_data.py b/ci/firebolt_ingest_data.py deleted file mode 100644 index 35277ab..0000000 --- a/ci/firebolt_ingest_data.py +++ /dev/null @@ -1,350 +0,0 @@ -import json -import sys - -import requests -from requests.exceptions import HTTPError - -from firebolt_db import exceptions -from firebolt_db import constants - -# Arguments passed -user_email = sys.argv[1] -password = sys.argv[2] -db_name = sys.argv[3] -if len(sys.argv) == 5: - engine_name = sys.argv[4] -else: - engine_name = None - - -class IngestFireboltData: - - @staticmethod - def create_tables(): - """ - Create tables for testing purpose. - This method creates tables inside firebolt database to test for code changes - It internally calls methods to get access token and engine URL. - :input user-email, password, database name, external table name and fact table name - """ - # get access token - access_token = IngestFireboltData.get_access_token({'username': user_email, 'password': password}) - # get engine url - if engine_name is None or engine_name == '': - engine_url = IngestFireboltData.get_engine_url_by_db(access_token) - else: - engine_url = IngestFireboltData.get_engine_url_by_engine(access_token) - # create external table - IngestFireboltData.create_external_table(engine_url, access_token) - # create fact table - IngestFireboltData.create_fact_table(engine_url, access_token) - # ingest data into fact table - IngestFireboltData.ingest_data(engine_url, access_token) - - @staticmethod - def get_access_token(data): - """ - Retrieve authentication token - This method uses the user email and the password to fire the API to generate access-token. - :input dictionary containing user-email and password - :returns access-token - """ - json_data = {} # base case - payload = {} - try: - - """ - General format of request: - curl --request POST 'https://api.app.firebolt.io/auth/v1/login' --header 'Content-Type: application/json;charset=UTF-8' --data-binary '{"username":"username","password":"password"}' - """ - token_response = requests.post(url=constants.token_url, data=json.dumps(data), - headers=constants.token_header) - token_response.raise_for_status() - - """ - General format of response: - - { - "access_token": "YOUR_ACCESS_TOKEN_VALUE", - "expires_in": 86400, - "refresh_token": "YOUR_REFRESH_TOKEN_VALUE", - "scope": "offline_access", - "token_type": "Bearer" - } - """ - - json_data = json.loads(token_response.text) - access_token = json_data["access_token"] - - except HTTPError as http_err: - payload = { - "error": "Access Token API Exception", - "errorMessage": http_err.response.text, - } - except Exception as err: - payload = { - "error": "Access Token API Exception", - "errorMessage": str(err), - } - - if payload != {}: - msg = "{error} : {errorMessage}".format(**payload) - raise exceptions.InvalidCredentialsError(msg) - - return access_token - - @staticmethod - def get_engine_url_by_db(access_token): - """ - Get engine url by db name - This method generates engine url using db name and access-token - :input api url, request type, authentication header and access-token - :returns engine url - """ - engine_url = "" # base case - payload = {} - try: - """ - Request: - curl --request GET 'https://api.app.firebolt.io/core/v1/account/engines:getURLByDatabaseName?database_name=YOUR_DATABASE_NAME' \ - --header 'Authorization: Bearer YOUR_ACCESS_TOKEN_VALUE' - """ - header = {'Authorization': "Bearer " + access_token} - query_engine_response = requests.get(constants.query_engine_url, params={'database_name': db_name}, - headers=header) - query_engine_response.raise_for_status() - - """ - Response: - {"engine_url": "YOUR_DATABASES_DEFAULT_ENGINE_URL"} - """ - json_data = json.loads(query_engine_response.text) - engine_url = json_data["engine_url"] - - except HTTPError as http_err: - payload = { - "error": "Engine Url API Exception", - "errorMessage": http_err.response.text, - } - except Exception as err: - payload = { - "error": "Engine Url API Exception", - "errorMessage": str(err), - } - if payload != {}: - msg = "{error} : {errorMessage}".format(**payload) - raise exceptions.SchemaNotFoundError(msg) - - return engine_url - - @staticmethod - def get_engine_url_by_engine(access_token): - """ - Get engine url by engine name - This method generates engine url using engine name and access-token - :input engine name and access-token - :returns engine url - """ - engine_url = "" # base case - payload = {} - try: - """ - Request: - curl --request GET 'https://api.app.firebolt.io/core/v1/account/engines?filter.name_contains=YOUR_ENGINE_NAME' - --header 'Authorization: Bearer YOUR_ACCESS_TOKEN_VALUE' - """ - header = {'Authorization': "Bearer " + access_token} - query_engine_response = requests.get(constants.query_engine_url_by_engine_name, - params={'filter.name_contains': engine_name}, - headers=header) - query_engine_response.raise_for_status() - - """ - Response: - { - "page": { - ... - }, - "edges": [ - { - ... - "endpoint": "YOUR_ENGINE_URL", - ... - } - } - ] - } - """ - json_data = json.loads(query_engine_response.text) - engine_url = json_data["edges"][0]["node"]["endpoint"] - - except HTTPError as http_err: - payload = { - "error": "Engine Url API Exception", - "errorMessage": http_err.response.text, - } - except Exception as err: - payload = { - "error": "Engine Url API Exception", - "errorMessage": str(err), - } - if payload != {}: - msg = "{error} : {errorMessage}".format(**payload) - raise exceptions.EngineNotFoundError(msg) - - return engine_url - - @staticmethod - def create_external_table(engine_url, access_token): - """ - This method is used to create an external table. - :input engine url, access_token - """ - payload = {} - try: - external_table_script = 'CREATE EXTERNAL TABLE IF NOT EXISTS test_external_table' \ - '( l_orderkey LONG,' \ - 'l_partkey LONG,' \ - 'l_suppkey LONG,' \ - 'l_linenumber INT,' \ - 'l_quantity LONG,' \ - 'l_extendedprice LONG,' \ - 'l_discount LONG,' \ - 'l_tax LONG,' \ - 'l_returnflag TEXT,' \ - 'l_linestatus TEXT,' \ - 'l_shipdate TEXT,' \ - 'l_commitdate TEXT,' \ - 'l_receiptdate TEXT,' \ - 'l_shipinstruct TEXT,' \ - 'l_shipmode TEXT,' \ - 'l_comment TEXT' \ - ')' \ - 'URL = \'s3://firebolt-publishing-public/samples/tpc-h/parquet/lineitem/\'' \ - 'OBJECT_PATTERN = \'*.parquet\'' \ - 'TYPE = (PARQUET);' - # '-- CREDENTIALS = ( AWS_KEY_ID = \'******\' AWS_SECRET_KEY = \'******\' )' \ - - """ - General format of request: - echo "CREATE_EXTERNAL_TABLE_SCRIPT" | curl \ - --request POST 'https://YOUR_ENGINE_URL/?database=YOUR_DATABASE_NAME' \ - --header 'Authorization: Bearer YOUR_ACCESS_TOKEN_VALUE' \ - --data-binary @- - """ - - header = {'Authorization': "Bearer " + access_token} - api_response = requests.post(url="https://" + engine_url + "/", params={'database': db_name}, - headers=header, files={"query": (None, external_table_script)}) - api_response.raise_for_status() - - except HTTPError as http_err: - payload = { - "error": "DB-API Exception", - "errorMessage": http_err.response.text, - } - except Exception as err: - payload = { - "error": "DB-API Exception", - "errorMessage": str(err), - } - if payload != {}: - msg = "{error} : {errorMessage}".format(**payload) - raise exceptions.InternalError(msg) - - @staticmethod - def create_fact_table(engine_url, access_token): - """ - Create a fact table - This method is used to create a fact table. - :input engine url, access_token - """ - payload = {} - try: - fact_table_script = 'CREATE FACT TABLE IF NOT EXISTS ci_fact_table' \ - '( l_orderkey LONG,' \ - ' l_partkey LONG,' \ - ' l_suppkey LONG,' \ - ' l_linenumber INT,' \ - ' l_quantity LONG,' \ - ' l_extendedprice LONG,' \ - ' l_discount LONG,' \ - ' l_tax LONG,' \ - ' l_returnflag TEXT,' \ - ' l_linestatus TEXT,' \ - ' l_shipdate TEXT,' \ - ' l_commitdate TEXT,' \ - ' l_receiptdate TEXT,' \ - ' l_shipinstruct TEXT,' \ - ' l_shipmode TEXT,' \ - ' l_comment TEXT' \ - ') PRIMARY INDEX l_orderkey, l_linenumber;' - """ - General format of request: - echo "CREATE_FACT_TABLE_SCRIPT" | curl \ - --request POST 'https://YOUR_ENGINE_URL/?database=YOUR_DATABASE_NAME' \ - --header 'Authorization: Bearer YOUR_ACCESS_TOKEN_VALUE' \ - --data-binary @- - """ - - header = {'Authorization': "Bearer " + access_token} - api_response = requests.post(url="https://" + engine_url, params={'database': db_name}, - headers=header, files={"query": (None, fact_table_script)}) - api_response.raise_for_status() - - except HTTPError as http_err: - payload = { - "error": "DB-API Exception", - "errorMessage": http_err.response.text, - } - except Exception as err: - payload = { - "error": "DB-API Exception", - "errorMessage": str(err), - } - if payload != {}: - msg = "{error} : {errorMessage}".format(**payload) - raise exceptions.InternalError(msg) - - @staticmethod - def ingest_data(engine_url, access_token): - """ - This method is used to ingest data into the fact table. - :input engine url, access_token - """ - payload = {} - try: - import_script = 'INSERT INTO ci_fact_table\n' \ - 'SELECT *' \ - 'FROM test_external_table WHERE NOT EXISTS ' \ - '(SELECT l_orderkey FROM ci_fact_table WHERE ci_fact_table.l_orderkey = ' \ - 'test_external_table.l_orderkey) ;' - - """ - General format of request: - echo "IMPORT_SCRIPT" | curl \ - --request POST 'https://YOUR_ENGINE_URL/?database=YOUR_DATABASE_NAME' \ - --header 'Authorization: Bearer YOUR_ACCESS_TOKEN_VALUE' --data-binary @- - """ - - header = {'Authorization': "Bearer " + access_token} - api_response = requests.post(url="https://" + engine_url, params={'database': db_name}, - headers=header, files={"query": (None, import_script)}) - api_response.raise_for_status() - - except HTTPError as http_err: - payload = { - "error": "DB-API Exception", - "errorMessage": http_err.response.text, - } - except Exception as err: - payload = { - "error": "DB-API Exception", - "errorMessage": str(err), - } - if payload != {}: - msg = "{error} : {errorMessage}".format(**payload) - raise exceptions.InternalError(msg) - - -IngestFireboltData.create_tables() From f1d4b5335978e16b0255f9a189542524715ca344 Mon Sep 17 00:00:00 2001 From: Petro Date: Fri, 19 Nov 2021 14:20:46 +0000 Subject: [PATCH 13/17] Formatting fixes --- tests/integration/conftest.py | 10 +++------ .../test_sqlalchemy_integration.py | 15 ++++--------- tests/unit/conftest.py | 4 ++-- tests/unit/test_firebolt_dialect.py | 21 ++++++++----------- 4 files changed, 18 insertions(+), 32 deletions(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index fc2c735..c99177b 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,14 +1,11 @@ - +from logging import getLogger from os import environ -from logging import getLogger +import firebolt as firebolt_sdk from pytest import fixture - from sqlalchemy import create_engine from sqlalchemy.dialects import registry -import firebolt as firebolt_sdk - LOGGER = getLogger(__name__) ENGINE_NAME_ENV = "engine_name" @@ -45,8 +42,7 @@ def password() -> str: @fixture(scope="session") def engine(username, password, database_name, engine_name): - registry.register( - "firebolt", "src.firebolt_db.firebolt_dialect", "FireboltDialect") + registry.register("firebolt", "src.firebolt_db.firebolt_dialect", "FireboltDialect") return create_engine( f"firebolt://{username}:{password}@{database_name}/{engine_name}" ) diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index cf3a4b7..371e03d 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -1,5 +1,4 @@ import pytest - from sqlalchemy.exc import OperationalError @@ -29,9 +28,7 @@ def setup_test_tables(self, connection, engine): yield self.drop_test_table(connection, engine, self.test_table) - @pytest.mark.skip( - reason="Commit not implemented in sdk" - ) + @pytest.mark.skip(reason="Commit not implemented in sdk") def test_create_ex_table(self, engine, connection): connection.execute( """ @@ -63,9 +60,7 @@ def test_create_ex_table(self, engine, connection): connection.execute("DROP TABLE ex_lineitem_alchemy;") assert not engine.dialect.has_table(engine, "ex_lineitem_alchemy") - @pytest.mark.skip( - reason="Commit not implemented in sdk" - ) + @pytest.mark.skip(reason="Commit not implemented in sdk") def test_data_write(self, connection): connection.execute( "INSERT INTO test_alchemy(idx, dummy) VALUES (1, 'some_text')" @@ -86,8 +81,7 @@ def test_get_schema_names(self, engine, database_name): assert database_name in results def test_has_table(self, engine, database_name): - results = engine.dialect.has_table( - engine, self.test_table, database_name) + results = engine.dialect.has_table(engine, self.test_table, database_name) assert results == 1 def test_get_table_names(self, engine, database_name): @@ -95,8 +89,7 @@ def test_get_table_names(self, engine, database_name): assert len(results) > 0 def test_get_columns(self, engine, database_name): - results = engine.dialect.get_columns( - engine, self.test_table, database_name) + results = engine.dialect.get_columns(engine, self.test_table, database_name) assert len(results) > 0 row = results[0] assert isinstance(row, dict) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9c22e37..2c08093 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,8 +1,8 @@ - +from unittest import mock from pytest import fixture + from firebolt_db import firebolt_dialect -from unittest import mock class DBApi: diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index a2d9ea8..09c2836 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -10,19 +10,14 @@ class TestFireboltDialect: def test_create_dialect(self, dialect): - assert issubclass(firebolt_dialect.dialect, - firebolt_dialect.FireboltDialect) - assert isinstance( - firebolt_dialect.FireboltDialect.dbapi(), type(firebolt_db)) + assert issubclass(firebolt_dialect.dialect, firebolt_dialect.FireboltDialect) + assert isinstance(firebolt_dialect.FireboltDialect.dbapi(), type(firebolt_db)) assert dialect.name == "firebolt" assert dialect.driver == "firebolt" - assert issubclass(dialect.preparer, - firebolt_dialect.FireboltIdentifierPreparer) - assert issubclass(dialect.statement_compiler, - firebolt_dialect.FireboltCompiler) + assert issubclass(dialect.preparer, firebolt_dialect.FireboltIdentifierPreparer) + assert issubclass(dialect.statement_compiler, firebolt_dialect.FireboltCompiler) # SQLAlchemy's DefaultDialect creates an instance of type_compiler behind the scenes - assert isinstance(dialect.type_compiler, - firebolt_dialect.FireboltTypeCompiler) + assert isinstance(dialect.type_compiler, firebolt_dialect.FireboltTypeCompiler) assert dialect.context == {} def test_create_connect_args(self, dialect): @@ -112,8 +107,10 @@ def getitem(self, idx): for call, expected_query in ( (lambda: dialect.get_columns(connection, "table"), expected_query), - (lambda: dialect.get_columns(connection, - "table", "schema"), expected_query_schema) + ( + lambda: dialect.get_columns(connection, "table", "schema"), + expected_query_schema, + ), ): assert call() == [ From f9f540338fe073ea54b2be6f2a4fdd66ff46bbc6 Mon Sep 17 00:00:00 2001 From: Petro Date: Mon, 22 Nov 2021 11:24:01 +0000 Subject: [PATCH 14/17] Adding typehints to tests --- tests/integration/conftest.py | 7 +- .../test_sqlalchemy_integration.py | 19 ++-- tests/unit/conftest.py | 8 +- tests/unit/test_firebolt_dialect.py | 99 ++++++++++++------- 4 files changed, 83 insertions(+), 50 deletions(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index c99177b..a1c29a7 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -5,6 +5,7 @@ from pytest import fixture from sqlalchemy import create_engine from sqlalchemy.dialects import registry +from sqlalchemy.engine.base import Connection, Engine LOGGER = getLogger(__name__) @@ -41,7 +42,9 @@ def password() -> str: @fixture(scope="session") -def engine(username, password, database_name, engine_name): +def engine( + username: str, password: str, database_name: str, engine_name: str +) -> Engine: registry.register("firebolt", "src.firebolt_db.firebolt_dialect", "FireboltDialect") return create_engine( f"firebolt://{username}:{password}@{database_name}/{engine_name}" @@ -49,7 +52,7 @@ def engine(username, password, database_name, engine_name): @fixture(scope="session") -def connection(engine): +def connection(engine: Engine) -> Connection: engine = engine if hasattr(firebolt_sdk.db.connection.Connection, "commit"): return engine.connect() diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 371e03d..1825366 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -1,4 +1,5 @@ import pytest +from sqlalchemy.engine.base import Connection, Engine from sqlalchemy.exc import OperationalError @@ -6,7 +7,7 @@ class TestFireboltDialect: test_table = "test_alchemy" - def create_test_table(self, connection, engine, table): + def create_test_table(self, connection: Connection, engine: Engine, table: str): connection.execute( f""" CREATE FACT TABLE IF NOT EXISTS {table} @@ -18,18 +19,18 @@ def create_test_table(self, connection, engine, table): ) assert engine.dialect.has_table(engine, table) - def drop_test_table(self, connection, engine, table): + def drop_test_table(self, connection: Connection, engine: Engine, table: str): connection.execute(f"DROP TABLE IF EXISTS {table}") assert not engine.dialect.has_table(engine, table) @pytest.fixture(scope="class", autouse=True) - def setup_test_tables(self, connection, engine): + def setup_test_tables(self, connection: Connection, engine: Engine): self.create_test_table(connection, engine, self.test_table) yield self.drop_test_table(connection, engine, self.test_table) @pytest.mark.skip(reason="Commit not implemented in sdk") - def test_create_ex_table(self, engine, connection): + def test_create_ex_table(self, connection: Connection, engine: Engine): connection.execute( """ CREATE EXTERNAL TABLE ex_lineitem_alchemy @@ -61,7 +62,7 @@ def test_create_ex_table(self, engine, connection): assert not engine.dialect.has_table(engine, "ex_lineitem_alchemy") @pytest.mark.skip(reason="Commit not implemented in sdk") - def test_data_write(self, connection): + def test_data_write(self, connection: Connection): connection.execute( "INSERT INTO test_alchemy(idx, dummy) VALUES (1, 'some_text')" ) @@ -76,19 +77,19 @@ def test_data_write(self, connection): with pytest.raises(OperationalError): connection.execute("DELETE FROM test_alchemy WHERE idx=1") - def test_get_schema_names(self, engine, database_name): + def test_get_schema_names(self, engine: Engine, database_name: str): results = engine.dialect.get_schema_names(engine) assert database_name in results - def test_has_table(self, engine, database_name): + def test_has_table(self, engine: Engine, database_name: str): results = engine.dialect.has_table(engine, self.test_table, database_name) assert results == 1 - def test_get_table_names(self, engine, database_name): + def test_get_table_names(self, engine: Engine, database_name: str): results = engine.dialect.get_table_names(engine, database_name) assert len(results) > 0 - def test_get_columns(self, engine, database_name): + def test_get_columns(self, engine: Engine, database_name: str): results = engine.dialect.get_columns(engine, self.test_table, database_name) assert len(results) > 0 row = results[0] diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 2c08093..2e5b04d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -5,7 +5,7 @@ from firebolt_db import firebolt_dialect -class DBApi: +class MockDBApi: def execute(): pass @@ -14,10 +14,10 @@ def executemany(): @fixture -def dialect(): +def dialect() -> firebolt_dialect.FireboltDialect: return firebolt_dialect.FireboltDialect() @fixture -def connection(): - return mock.Mock(spec=DBApi) +def connection() -> mock.Mock(spec=MockDBApi): + return mock.Mock(spec=MockDBApi) diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index 09c2836..ff76631 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -2,25 +2,32 @@ from unittest import mock import sqlalchemy +from conftest import MockDBApi from sqlalchemy.engine import url import firebolt_db -from firebolt_db import firebolt_dialect +from firebolt_db.firebolt_dialect import ( + FireboltCompiler, + FireboltDialect, + FireboltIdentifierPreparer, + FireboltTypeCompiler, +) +from firebolt_db.firebolt_dialect import dialect as dialect_definition class TestFireboltDialect: - def test_create_dialect(self, dialect): - assert issubclass(firebolt_dialect.dialect, firebolt_dialect.FireboltDialect) - assert isinstance(firebolt_dialect.FireboltDialect.dbapi(), type(firebolt_db)) + def test_create_dialect(self, dialect: FireboltDialect): + assert issubclass(dialect_definition, FireboltDialect) + assert isinstance(FireboltDialect.MockDBApi(), type(firebolt_db)) assert dialect.name == "firebolt" assert dialect.driver == "firebolt" - assert issubclass(dialect.preparer, firebolt_dialect.FireboltIdentifierPreparer) - assert issubclass(dialect.statement_compiler, firebolt_dialect.FireboltCompiler) + assert issubclass(dialect.preparer, FireboltIdentifierPreparer) + assert issubclass(dialect.statement_compiler, FireboltCompiler) # SQLAlchemy's DefaultDialect creates an instance of type_compiler behind the scenes - assert isinstance(dialect.type_compiler, firebolt_dialect.FireboltTypeCompiler) + assert isinstance(dialect.type_compiler, FireboltTypeCompiler) assert dialect.context == {} - def test_create_connect_args(self, dialect): + def test_create_connect_args(self, dialect: FireboltDialect): connection_url = ( "test_engine://test_user@email:test_password@test_db_name/test_engine_name" ) @@ -39,7 +46,9 @@ def test_create_connect_args(self, dialect): assert "api_endpoint" not in result_dict assert result_list == [] - def test_schema_names(self, dialect, connection): + def test_schema_names( + self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) + ): def row_with_schema(name): return mock.Mock(schema_name=name) @@ -53,7 +62,9 @@ def row_with_schema(name): "select schema_name from information_schema.databases" ) - def test_table_names(self, dialect, connection): + def test_table_names( + self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) + ): def row_with_table_name(name): return mock.Mock(table_name=name) @@ -75,10 +86,14 @@ def row_with_table_name(name): " where table_schema = 'schema'" ) - def test_view_names(self, dialect, connection): + def test_view_names( + self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) + ): assert dialect.get_view_names(connection) == [] - def test_table_options(self, dialect, connection): + def test_table_options( + self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) + ): assert dialect.get_table_options(connection, "table") == {} def test_columns(self, dialect, connection): @@ -130,50 +145,64 @@ def getitem(self, idx): connection.execute.assert_called_once_with(expected_query) connection.execute.reset_mock() - def test_pk_constraint(self, dialect, connection): + def test_pk_constraint( + self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) + ): assert dialect.get_pk_constraint(connection, "table") == { "constrained_columns": [], "name": None, } - def test_foreign_keys(self, dialect, connection): + def test_foreign_keys( + self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) + ): assert dialect.get_foreign_keys(connection, "table") == [] - def test_check_constraints(self, dialect, connection): + def test_check_constraints( + self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) + ): assert dialect.get_check_constraints(connection, "table") == [] - def test_table_comment(self, dialect, connection): + def test_table_comment( + self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) + ): assert dialect.get_table_comment(connection, "table") == {"text": ""} - def test_indexes(self, dialect, connection): + def test_indexes(self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi)): assert dialect.get_indexes(connection, "table") == [] - def test_unique_constraints(self, dialect, connection): + def test_unique_constraints( + self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) + ): assert dialect.get_unique_constraints(connection, "table") == [] - def test_unicode_returns(self, dialect, connection): + def test_unicode_returns( + self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) + ): assert dialect._check_unicode_returns(connection) - def test_unicode_description(self, dialect, connection): + def test_unicode_description( + self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) + ): assert dialect._check_unicode_description(connection) def test_get_is_nullable(): - assert firebolt_dialect.get_is_nullable("YES") - assert firebolt_dialect.get_is_nullable("yes") - assert not firebolt_dialect.get_is_nullable("NO") - assert not firebolt_dialect.get_is_nullable("no") - assert not firebolt_dialect.get_is_nullable("ABC") + assert firebolt_db.firebolt_dialect.get_is_nullable("YES") + assert firebolt_db.firebolt_dialect.get_is_nullable("yes") + assert not firebolt_db.firebolt_dialect.get_is_nullable("NO") + assert not firebolt_db.firebolt_dialect.get_is_nullable("no") + assert not firebolt_db.firebolt_dialect.get_is_nullable("ABC") def test_types(): - assert firebolt_dialect.CHAR is sqlalchemy.sql.sqltypes.CHAR - assert firebolt_dialect.DATE is sqlalchemy.sql.sqltypes.DATE - assert firebolt_dialect.DATETIME is sqlalchemy.sql.sqltypes.DATETIME - assert firebolt_dialect.INTEGER is sqlalchemy.sql.sqltypes.INTEGER - assert firebolt_dialect.BIGINT is sqlalchemy.sql.sqltypes.BIGINT - assert firebolt_dialect.TIMESTAMP is sqlalchemy.sql.sqltypes.TIMESTAMP - assert firebolt_dialect.VARCHAR is sqlalchemy.sql.sqltypes.VARCHAR - assert firebolt_dialect.BOOLEAN is sqlalchemy.sql.sqltypes.BOOLEAN - assert firebolt_dialect.FLOAT is sqlalchemy.sql.sqltypes.FLOAT - assert issubclass(firebolt_dialect.ARRAY, sqlalchemy.types.TypeEngine) + assert firebolt_db.firebolt_dialect.CHAR is sqlalchemy.sql.sqltypes.CHAR + assert firebolt_db.firebolt_dialect.DATE is sqlalchemy.sql.sqltypes.DATE + assert firebolt_db.firebolt_dialect.DATETIME is sqlalchemy.sql.sqltypes.DATETIME + assert firebolt_db.firebolt_dialect.INTEGER is sqlalchemy.sql.sqltypes.INTEGER + assert firebolt_db.firebolt_dialect.BIGINT is sqlalchemy.sql.sqltypes.BIGINT + assert firebolt_db.firebolt_dialect.TIMESTAMP is sqlalchemy.sql.sqltypes.TIMESTAMP + assert firebolt_db.firebolt_dialect.VARCHAR is sqlalchemy.sql.sqltypes.VARCHAR + assert firebolt_db.firebolt_dialect.BOOLEAN is sqlalchemy.sql.sqltypes.BOOLEAN + assert firebolt_db.firebolt_dialect.FLOAT is sqlalchemy.sql.sqltypes.FLOAT + assert issubclass(firebolt_db.firebolt_dialect.ARRAY, sqlalchemy.types.TypeEngine) From cd114d6aa21e170f6793f78c58c0185222f78e7f Mon Sep 17 00:00:00 2001 From: Petro Date: Mon, 22 Nov 2021 11:39:33 +0000 Subject: [PATCH 15/17] Removing redundant operation --- tests/integration/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index a1c29a7..4b76505 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -53,7 +53,6 @@ def engine( @fixture(scope="session") def connection(engine: Engine) -> Connection: - engine = engine if hasattr(firebolt_sdk.db.connection.Connection, "commit"): return engine.connect() else: From 277c978ca368d53f58e4361e7baf944624408678 Mon Sep 17 00:00:00 2001 From: Petro Date: Mon, 22 Nov 2021 15:04:26 +0000 Subject: [PATCH 16/17] Returning db api object on dbapi() call --- src/firebolt_db/firebolt_dialect.py | 5 ++--- tests/unit/test_firebolt_dialect.py | 9 ++++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/firebolt_db/firebolt_dialect.py b/src/firebolt_db/firebolt_dialect.py index 8913377..1f0ce18 100644 --- a/src/firebolt_db/firebolt_dialect.py +++ b/src/firebolt_db/firebolt_dialect.py @@ -5,9 +5,8 @@ CHAR, DATE, DATETIME, INTEGER, BIGINT, TIMESTAMP, VARCHAR, BOOLEAN, FLOAT) -import firebolt_db import os - +import firebolt.db class ARRAY(sqltypes.TypeEngine): __visit_name__ = 'ARRAY' @@ -82,7 +81,7 @@ def __init__(self, context=None, *args, **kwargs): @classmethod def dbapi(cls): - return firebolt_db + return firebolt.db # Build firebolt-sdk compatible connection arguments. # URL format : firebolt://username:password@host:port/db_name diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index ff76631..d4faa22 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -1,11 +1,12 @@ import os from unittest import mock +import firebolt.db # Firebolt sdk import sqlalchemy from conftest import MockDBApi from sqlalchemy.engine import url -import firebolt_db +import firebolt_db # SQLAlchemy package from firebolt_db.firebolt_dialect import ( FireboltCompiler, FireboltDialect, @@ -18,7 +19,7 @@ class TestFireboltDialect: def test_create_dialect(self, dialect: FireboltDialect): assert issubclass(dialect_definition, FireboltDialect) - assert isinstance(FireboltDialect.MockDBApi(), type(firebolt_db)) + assert isinstance(FireboltDialect.dbapi(), type(firebolt.db)) assert dialect.name == "firebolt" assert dialect.driver == "firebolt" assert issubclass(dialect.preparer, FireboltIdentifierPreparer) @@ -168,7 +169,9 @@ def test_table_comment( ): assert dialect.get_table_comment(connection, "table") == {"text": ""} - def test_indexes(self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi)): + def test_indexes( + self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) + ): assert dialect.get_indexes(connection, "table") == [] def test_unique_constraints( From 019023b01d3366fc047df124e85edae333cef62f Mon Sep 17 00:00:00 2001 From: Petro Date: Mon, 22 Nov 2021 15:10:52 +0000 Subject: [PATCH 17/17] Removing redundant import in __init__ --- src/firebolt_db/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/firebolt_db/__init__.py b/src/firebolt_db/__init__.py index b031390..20b9379 100644 --- a/src/firebolt_db/__init__.py +++ b/src/firebolt_db/__init__.py @@ -1,4 +1,3 @@ -from firebolt.db import connect from firebolt.common.exception import ( DatabaseError, DataError,