Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions src/firebolt_db/firebolt_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import compiler, text
from sqlalchemy.types import (
ARRAY,
BIGINT,
BOOLEAN,
CHAR,
Expand All @@ -24,8 +25,8 @@
)


class ARRAY(sqltypes.TypeEngine):
__visit_name__ = "ARRAY"
class BYTEA(sqltypes.LargeBinary):
__visit_name__ = "BYTEA"


# Firebolt data types compatibility with sqlalchemy.sql.types
Expand All @@ -37,18 +38,47 @@ class ARRAY(sqltypes.TypeEngine):
"float": FLOAT,
"double": FLOAT,
"double precision": FLOAT,
"real": FLOAT,
"boolean": BOOLEAN,
"int": INTEGER,
"integer": INTEGER,
"bigint": BIGINT,
"long": BIGINT,
"timestamp": TIMESTAMP,
"timestamptz": TIMESTAMP,
"timestampntz": TIMESTAMP,
"datetime": DATETIME,
"date": DATE,
"array": ARRAY,
"bytea": BYTEA,
}


def resolve_type(fb_type: str) -> sqltypes.TypeEngine:
def removesuffix(s: str, suffix: str) -> str:
"""Python < 3.9 compatibility"""
if s.endswith(suffix):
s = s[: -len(suffix)]
return s

result: sqltypes.TypeEngine
if fb_type.startswith("array"):
# Nested arrays not supported
dimensions = 0
while fb_type.startswith("array"):
dimensions += 1
fb_type = fb_type[6:-1] # Strip ARRAY()
fb_type = removesuffix(removesuffix(fb_type, " not null"), " null")
result = ARRAY(resolve_type(fb_type), dimensions=dimensions)
else:
# Strip complex type info e.g. DECIMAL(8,23) -> DECIMAL
fb_type = fb_type[: fb_type.find("(")] if "(" in fb_type else fb_type
result = type_map.get(fb_type, DEFAULT_TYPE) # type: ignore
return result


DEFAULT_TYPE = VARCHAR


class UniversalSet(set):
def __contains__(self, item: Any) -> bool:
return True
Expand Down Expand Up @@ -193,6 +223,7 @@ def get_columns(
schema: Optional[str] = None,
**kwargs: Any
) -> List[Dict]:

query = """
select column_name,
data_type,
Expand All @@ -212,7 +243,7 @@ def get_columns(
return [
{
"name": row[0],
"type": type_map[row[1].lower()],
"type": resolve_type(row[1].lower()),
"nullable": get_is_nullable(row[2]),
"default": None,
}
Expand Down
45 changes: 45 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from logging import getLogger
from os import environ
from typing import List

from pytest import fixture
from sqlalchemy import create_engine, text
Expand Down Expand Up @@ -139,6 +140,44 @@ def ex_table_query(ex_table_name: str) -> str:
"""


@fixture(scope="class")
def type_table_name() -> str:
return "types_alchemy"


@fixture(scope="class")
def firebolt_columns() -> List[str]:
return [
"INTEGER",
"NUMERIC",
"BIGINT",
"REAL",
"DOUBLE PRECISION",
"TEXT",
"TIMESTAMPNTZ",
"TIMESTAMPTZ",
"DATE",
"TIMESTAMP",
"BOOLEAN",
"BYTEA",
]


@fixture(scope="class")
def type_table_query(firebolt_columns: List[str], type_table_name: str) -> str:
col_names = [c.replace(" ", "_").lower() for c in firebolt_columns]
cols = ",\n".join(
[f"c_{name} {c_type}" for name, c_type in zip(col_names, firebolt_columns)]
)
return f"""
CREATE DIMENSION TABLE {type_table_name}
(
{cols},
c_array ARRAY(ARRAY(INTEGER))
);
"""


@fixture(scope="class")
def fact_table_name() -> str:
return "test_alchemy"
Expand All @@ -155,6 +194,8 @@ def setup_test_tables(
engine: Engine,
fact_table_name: str,
dimension_table_name: str,
type_table_query: str,
type_table_name: str,
):
connection.execute(
text(
Expand All @@ -178,11 +219,15 @@ def setup_test_tables(
"""
)
)
connection.execute(text(type_table_query))
assert engine.dialect.has_table(connection, fact_table_name)
assert engine.dialect.has_table(connection, dimension_table_name)
assert engine.dialect.has_table(connection, type_table_name)
yield
# Teardown
connection.execute(text(f"DROP TABLE IF EXISTS {fact_table_name} CASCADE;"))
connection.execute(text(f"DROP TABLE IF EXISTS {dimension_table_name} CASCADE;"))
connection.execute(text(f"DROP TABLE IF EXISTS {type_table_name} CASCADE;"))
assert not engine.dialect.has_table(connection, fact_table_name)
assert not engine.dialect.has_table(connection, dimension_table_name)
assert not engine.dialect.has_table(connection, type_table_name)
21 changes: 12 additions & 9 deletions tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy import create_engine, text
from sqlalchemy.engine.base import Connection, Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.types import ARRAY, INTEGER, TypeEngine


class TestFireboltDialect:
Expand Down Expand Up @@ -114,17 +115,19 @@ def test_get_table_names(self, engine: Engine, connection: Connection):
assert len(results) == 0

def test_get_columns(
self, engine: Engine, connection: Connection, fact_table_name: str
self, engine: Engine, connection: Connection, type_table_name: str
):
results = engine.dialect.get_columns(connection, fact_table_name)
results = engine.dialect.get_columns(connection, type_table_name)
assert len(results) > 0
row = results[0]
assert isinstance(row, dict)
row_keys = list(row.keys())
assert row_keys[0] == "name"
assert row_keys[1] == "type"
assert row_keys[2] == "nullable"
assert row_keys[3] == "default"
for column in results:
assert isinstance(column, dict)
# Check only works for basic types
if type(column["type"]) == ARRAY:
# ARRAY[[INT]]
assert column["type"].dimensions == 2
assert type(column["type"].item_type) == INTEGER
else:
assert issubclass(column["type"], TypeEngine)

def test_service_account_connect(self, connection_service_account: Connection):
result = connection_service_account.execute(text("SELECT 1"))
Expand Down