diff --git a/src/firebolt_db/__init__.py b/src/firebolt_db/__init__.py index 20b9379..b031390 100644 --- a/src/firebolt_db/__init__.py +++ b/src/firebolt_db/__init__.py @@ -1,3 +1,4 @@ +from firebolt.db import connect from firebolt.common.exception import ( DatabaseError, DataError, diff --git a/src/firebolt_db/firebolt_dialect.py b/src/firebolt_db/firebolt_dialect.py index 1f0ce18..cd641b0 100644 --- a/src/firebolt_db/firebolt_dialect.py +++ b/src/firebolt_db/firebolt_dialect.py @@ -1,15 +1,26 @@ +import os + +import sqlalchemy.pool.base import sqlalchemy.types as sqltypes from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.types import ( - CHAR, DATE, DATETIME, INTEGER, BIGINT, - TIMESTAMP, VARCHAR, BOOLEAN, FLOAT) + BIGINT, + BOOLEAN, + CHAR, + DATE, + DATETIME, + FLOAT, + INTEGER, + TIMESTAMP, + VARCHAR, +) + +import firebolt_db -import os -import firebolt.db class ARRAY(sqltypes.TypeEngine): - __visit_name__ = 'ARRAY' + __visit_name__ = "ARRAY" # Firebolt data types compatibility with sqlalchemy.sql.types @@ -47,7 +58,6 @@ class FireboltCompiler(compiler.SQLCompiler): class FireboltTypeCompiler(compiler.GenericTypeCompiler): - def visit_ARRAY(self, type, **kw): return "Array(%s)" % type @@ -81,7 +91,7 @@ def __init__(self, context=None, *args, **kwargs): @classmethod def dbapi(cls): - return firebolt.db + return firebolt_db # Build firebolt-sdk compatible connection arguments. # URL format : firebolt://username:password@host:port/db_name @@ -90,7 +100,7 @@ def create_connect_args(self, url): "database": url.host or None, "username": url.username or None, "password": url.password or None, - "engine_name": url.database + "engine_name": url.database, } # If URL override is not provided leave it to the sdk to determine the endpoint if "FIREBOLT_BASE_URL" in os.environ: @@ -100,9 +110,7 @@ def create_connect_args(self, url): def get_schema_names(self, connection, **kwargs): query = "select schema_name from information_schema.databases" result = connection.execute(query) - return [ - row.schema_name for row in result - ] + return [row.schema_name for row in result] def has_table(self, connection, table_name, schema=None): query = """ @@ -189,6 +197,9 @@ def _check_unicode_returns(self, connection, additional_tests=None): def _check_unicode_description(self, connection): return True + def do_commit(self, dbapi_connection: sqlalchemy.pool.base._ConnectionFairy): + pass + dialect = FireboltDialect diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 4b76505..c80ff61 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -53,9 +53,4 @@ def engine( @fixture(scope="session") def connection(engine: Engine) -> Connection: - if hasattr(firebolt_sdk.db.connection.Connection, "commit"): - return engine.connect() - else: - # Disabling autocommit allows for table creation/destruction without - # trying to call non-existing parameters - return engine.connect().execution_options(autocommit=False) + return engine.connect() diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 1825366..5af0788 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -29,7 +29,6 @@ def setup_test_tables(self, connection: Connection, engine: Engine): yield self.drop_test_table(connection, engine, self.test_table) - @pytest.mark.skip(reason="Commit not implemented in sdk") def test_create_ex_table(self, connection: Connection, engine: Engine): connection.execute( """ @@ -61,7 +60,6 @@ def test_create_ex_table(self, connection: Connection, engine: Engine): connection.execute("DROP TABLE ex_lineitem_alchemy;") assert not engine.dialect.has_table(engine, "ex_lineitem_alchemy") - @pytest.mark.skip(reason="Commit not implemented in sdk") def test_data_write(self, connection: Connection): connection.execute( "INSERT INTO test_alchemy(idx, dummy) VALUES (1, 'some_text')" diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index d4faa22..912edb6 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -19,7 +19,7 @@ class TestFireboltDialect: def test_create_dialect(self, dialect: FireboltDialect): assert issubclass(dialect_definition, FireboltDialect) - assert isinstance(FireboltDialect.dbapi(), type(firebolt.db)) + assert isinstance(FireboltDialect.dbapi(), type(firebolt_db)) assert dialect.name == "firebolt" assert dialect.driver == "firebolt" assert issubclass(dialect.preparer, FireboltIdentifierPreparer) @@ -97,7 +97,9 @@ def test_table_options( ): assert dialect.get_table_options(connection, "table") == {} - def test_columns(self, dialect, connection): + def test_columns( + self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi) + ): def multi_column_row(columns): def getitem(self, idx): for i, result in enumerate(columns):