diff --git a/src/firebolt_db/firebolt_dialect.py b/src/firebolt_db/firebolt_dialect.py index 47400ae..4623ea6 100644 --- a/src/firebolt_db/firebolt_dialect.py +++ b/src/firebolt_db/firebolt_dialect.py @@ -12,6 +12,7 @@ from sqlalchemy.engine.url import URL from sqlalchemy.sql import compiler, text from sqlalchemy.types import ( + ARRAY, BIGINT, BOOLEAN, CHAR, @@ -24,8 +25,8 @@ ) -class ARRAY(sqltypes.TypeEngine): - __visit_name__ = "ARRAY" +class BYTEA(sqltypes.LargeBinary): + __visit_name__ = "BYTEA" # Firebolt data types compatibility with sqlalchemy.sql.types @@ -37,18 +38,47 @@ class ARRAY(sqltypes.TypeEngine): "float": FLOAT, "double": FLOAT, "double precision": FLOAT, + "real": FLOAT, "boolean": BOOLEAN, "int": INTEGER, "integer": INTEGER, "bigint": BIGINT, "long": BIGINT, "timestamp": TIMESTAMP, + "timestamptz": TIMESTAMP, + "timestampntz": TIMESTAMP, "datetime": DATETIME, "date": DATE, - "array": ARRAY, + "bytea": BYTEA, } +def resolve_type(fb_type: str) -> sqltypes.TypeEngine: + def removesuffix(s: str, suffix: str) -> str: + """Python < 3.9 compatibility""" + if s.endswith(suffix): + s = s[: -len(suffix)] + return s + + result: sqltypes.TypeEngine + if fb_type.startswith("array"): + # Nested arrays not supported + dimensions = 0 + while fb_type.startswith("array"): + dimensions += 1 + fb_type = fb_type[6:-1] # Strip ARRAY() + fb_type = removesuffix(removesuffix(fb_type, " not null"), " null") + result = ARRAY(resolve_type(fb_type), dimensions=dimensions) + else: + # Strip complex type info e.g. DECIMAL(8,23) -> DECIMAL + fb_type = fb_type[: fb_type.find("(")] if "(" in fb_type else fb_type + result = type_map.get(fb_type, DEFAULT_TYPE) # type: ignore + return result + + +DEFAULT_TYPE = VARCHAR + + class UniversalSet(set): def __contains__(self, item: Any) -> bool: return True @@ -193,6 +223,7 @@ def get_columns( schema: Optional[str] = None, **kwargs: Any ) -> List[Dict]: + query = """ select column_name, data_type, @@ -212,7 +243,7 @@ def get_columns( return [ { "name": row[0], - "type": type_map[row[1].lower()], + "type": resolve_type(row[1].lower()), "nullable": get_is_nullable(row[2]), "default": None, } diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index f04dd57..b9f0002 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,6 +1,7 @@ import asyncio from logging import getLogger from os import environ +from typing import List from pytest import fixture from sqlalchemy import create_engine, text @@ -139,6 +140,44 @@ def ex_table_query(ex_table_name: str) -> str: """ +@fixture(scope="class") +def type_table_name() -> str: + return "types_alchemy" + + +@fixture(scope="class") +def firebolt_columns() -> List[str]: + return [ + "INTEGER", + "NUMERIC", + "BIGINT", + "REAL", + "DOUBLE PRECISION", + "TEXT", + "TIMESTAMPNTZ", + "TIMESTAMPTZ", + "DATE", + "TIMESTAMP", + "BOOLEAN", + "BYTEA", + ] + + +@fixture(scope="class") +def type_table_query(firebolt_columns: List[str], type_table_name: str) -> str: + col_names = [c.replace(" ", "_").lower() for c in firebolt_columns] + cols = ",\n".join( + [f"c_{name} {c_type}" for name, c_type in zip(col_names, firebolt_columns)] + ) + return f""" + CREATE DIMENSION TABLE {type_table_name} + ( + {cols}, + c_array ARRAY(ARRAY(INTEGER)) + ); + """ + + @fixture(scope="class") def fact_table_name() -> str: return "test_alchemy" @@ -155,6 +194,8 @@ def setup_test_tables( engine: Engine, fact_table_name: str, dimension_table_name: str, + type_table_query: str, + type_table_name: str, ): connection.execute( text( @@ -178,11 +219,15 @@ def setup_test_tables( """ ) ) + connection.execute(text(type_table_query)) assert engine.dialect.has_table(connection, fact_table_name) assert engine.dialect.has_table(connection, dimension_table_name) + assert engine.dialect.has_table(connection, type_table_name) yield # Teardown connection.execute(text(f"DROP TABLE IF EXISTS {fact_table_name} CASCADE;")) connection.execute(text(f"DROP TABLE IF EXISTS {dimension_table_name} CASCADE;")) + connection.execute(text(f"DROP TABLE IF EXISTS {type_table_name} CASCADE;")) assert not engine.dialect.has_table(connection, fact_table_name) assert not engine.dialect.has_table(connection, dimension_table_name) + assert not engine.dialect.has_table(connection, type_table_name) diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index ec9aa32..2093c45 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -5,6 +5,7 @@ from sqlalchemy import create_engine, text from sqlalchemy.engine.base import Connection, Engine from sqlalchemy.exc import OperationalError +from sqlalchemy.types import ARRAY, INTEGER, TypeEngine class TestFireboltDialect: @@ -114,17 +115,19 @@ def test_get_table_names(self, engine: Engine, connection: Connection): assert len(results) == 0 def test_get_columns( - self, engine: Engine, connection: Connection, fact_table_name: str + self, engine: Engine, connection: Connection, type_table_name: str ): - results = engine.dialect.get_columns(connection, fact_table_name) + results = engine.dialect.get_columns(connection, type_table_name) assert len(results) > 0 - row = results[0] - assert isinstance(row, dict) - row_keys = list(row.keys()) - assert row_keys[0] == "name" - assert row_keys[1] == "type" - assert row_keys[2] == "nullable" - assert row_keys[3] == "default" + for column in results: + assert isinstance(column, dict) + # Check only works for basic types + if type(column["type"]) == ARRAY: + # ARRAY[[INT]] + assert column["type"].dimensions == 2 + assert type(column["type"].item_type) == INTEGER + else: + assert issubclass(column["type"], TypeEngine) def test_service_account_connect(self, connection_service_account: Connection): result = connection_service_account.execute(text("SELECT 1"))