Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/firebolt_db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from firebolt.db import connect
from firebolt.common.exception import (
DatabaseError,
DataError,
Expand Down
33 changes: 22 additions & 11 deletions src/firebolt_db/firebolt_dialect.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
import os

import sqlalchemy.pool.base
import sqlalchemy.types as sqltypes
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.types import (
CHAR, DATE, DATETIME, INTEGER, BIGINT,
TIMESTAMP, VARCHAR, BOOLEAN, FLOAT)
BIGINT,
BOOLEAN,
CHAR,
DATE,
DATETIME,
FLOAT,
INTEGER,
TIMESTAMP,
VARCHAR,
)

import firebolt_db

import os
import firebolt.db

class ARRAY(sqltypes.TypeEngine):
__visit_name__ = 'ARRAY'
__visit_name__ = "ARRAY"


# Firebolt data types compatibility with sqlalchemy.sql.types
Expand Down Expand Up @@ -47,7 +58,6 @@ class FireboltCompiler(compiler.SQLCompiler):


class FireboltTypeCompiler(compiler.GenericTypeCompiler):

def visit_ARRAY(self, type, **kw):
return "Array(%s)" % type

Expand Down Expand Up @@ -81,7 +91,7 @@ def __init__(self, context=None, *args, **kwargs):

@classmethod
def dbapi(cls):
return firebolt.db
return firebolt_db

# Build firebolt-sdk compatible connection arguments.
# URL format : firebolt://username:password@host:port/db_name
Expand All @@ -90,7 +100,7 @@ def create_connect_args(self, url):
"database": url.host or None,
"username": url.username or None,
"password": url.password or None,
"engine_name": url.database
"engine_name": url.database,
}
# If URL override is not provided leave it to the sdk to determine the endpoint
if "FIREBOLT_BASE_URL" in os.environ:
Expand All @@ -100,9 +110,7 @@ def create_connect_args(self, url):
def get_schema_names(self, connection, **kwargs):
query = "select schema_name from information_schema.databases"
result = connection.execute(query)
return [
row.schema_name for row in result
]
return [row.schema_name for row in result]

def has_table(self, connection, table_name, schema=None):
query = """
Expand Down Expand Up @@ -189,6 +197,9 @@ def _check_unicode_returns(self, connection, additional_tests=None):
def _check_unicode_description(self, connection):
return True

def do_commit(self, dbapi_connection: sqlalchemy.pool.base._ConnectionFairy):
pass


dialect = FireboltDialect

Expand Down
7 changes: 1 addition & 6 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,4 @@ def engine(

@fixture(scope="session")
def connection(engine: Engine) -> Connection:
if hasattr(firebolt_sdk.db.connection.Connection, "commit"):
return engine.connect()
else:
# Disabling autocommit allows for table creation/destruction without
# trying to call non-existing parameters
return engine.connect().execution_options(autocommit=False)
return engine.connect()
2 changes: 0 additions & 2 deletions tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def setup_test_tables(self, connection: Connection, engine: Engine):
yield
self.drop_test_table(connection, engine, self.test_table)

@pytest.mark.skip(reason="Commit not implemented in sdk")
def test_create_ex_table(self, connection: Connection, engine: Engine):
connection.execute(
"""
Expand Down Expand Up @@ -61,7 +60,6 @@ def test_create_ex_table(self, connection: Connection, engine: Engine):
connection.execute("DROP TABLE ex_lineitem_alchemy;")
assert not engine.dialect.has_table(engine, "ex_lineitem_alchemy")

@pytest.mark.skip(reason="Commit not implemented in sdk")
def test_data_write(self, connection: Connection):
connection.execute(
"INSERT INTO test_alchemy(idx, dummy) VALUES (1, 'some_text')"
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/test_firebolt_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class TestFireboltDialect:
def test_create_dialect(self, dialect: FireboltDialect):
assert issubclass(dialect_definition, FireboltDialect)
assert isinstance(FireboltDialect.dbapi(), type(firebolt.db))
assert isinstance(FireboltDialect.dbapi(), type(firebolt_db))
assert dialect.name == "firebolt"
assert dialect.driver == "firebolt"
assert issubclass(dialect.preparer, FireboltIdentifierPreparer)
Expand Down Expand Up @@ -97,7 +97,9 @@ def test_table_options(
):
assert dialect.get_table_options(connection, "table") == {}

def test_columns(self, dialect, connection):
def test_columns(
self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi)
):
def multi_column_row(columns):
def getitem(self, idx):
for i, result in enumerate(columns):
Expand Down