diff --git a/src/databricks/sqlalchemy/_ddl.py b/src/databricks/sqlalchemy/_ddl.py index 4d825c9fe..7aefb034e 100644 --- a/src/databricks/sqlalchemy/_ddl.py +++ b/src/databricks/sqlalchemy/_ddl.py @@ -19,11 +19,11 @@ def post_create_table(self, table): return " USING DELTA" def visit_unique_constraint(self, constraint, **kw): - logger.warn("Databricks does not support unique constraints") + logger.warning("Databricks does not support unique constraints") pass def visit_check_constraint(self, constraint, **kw): - logger.warn("Databricks does not support check constraints") + logger.warning("This dialect does not support check constraints") pass def visit_identity_column(self, identity, **kw): diff --git a/src/databricks/sqlalchemy/_parse.py b/src/databricks/sqlalchemy/_parse.py index 941737ba6..55f34f950 100644 --- a/src/databricks/sqlalchemy/_parse.py +++ b/src/databricks/sqlalchemy/_parse.py @@ -1,7 +1,9 @@ from typing import List, Optional, Dict import re +import sqlalchemy from sqlalchemy.engine import CursorResult +from sqlalchemy.engine.interfaces import ReflectedColumn """ This module contains helper functions that can parse the contents @@ -9,6 +11,7 @@ wrappers around regexes. """ + def _match_table_not_found_string(message: str) -> bool: """Return True if the message contains a substring indicating that a table was not found""" @@ -22,9 +25,10 @@ def _match_table_not_found_string(message: str) -> bool: ) -def _describe_table_extended_result_to_dict_list(result: CursorResult) -> List[Dict[str, str]]: - """Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries - """ +def _describe_table_extended_result_to_dict_list( + result: CursorResult, +) -> List[Dict[str, str]]: + """Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries""" rows_to_return = [] for row in result: @@ -68,22 +72,23 @@ def extract_three_level_identifier_from_constraint_string(input_str: str) -> dic """ pat = re.compile(r"REFERENCES\s+(.*?)\s*\(") matches = pat.findall(input_str) - + if not matches: return None - + first_match = matches[0] parts = first_match.split(".") - def strip_backticks(input:str): + def strip_backticks(input: str): return input.replace("`", "") - + return { - "catalog": strip_backticks(parts[0]), + "catalog": strip_backticks(parts[0]), "schema": strip_backticks(parts[1]), - "table": strip_backticks(parts[2]) + "table": strip_backticks(parts[2]), } + def _parse_fk_from_constraint_string(constraint_str: str) -> dict: """Build a dictionary of foreign key constraint information from a constraint string. @@ -133,6 +138,7 @@ def _parse_fk_from_constraint_string(constraint_str: str) -> dict: "referred_schema": referred_schema, } + def build_fk_dict( fk_name: str, fk_constraint_string: str, schema_name: Optional[str] ) -> dict: @@ -172,6 +178,7 @@ def build_fk_dict( return complete_foreign_key_dict + def _parse_pk_columns_from_constraint_string(constraint_str: str) -> List[str]: """Build a list of constrained columns from a constraint string returned by DESCRIBE TABLE EXTENDED @@ -188,21 +195,23 @@ def _parse_pk_columns_from_constraint_string(constraint_str: str) -> List[str]: return _extracted + def build_pk_dict(pk_name: str, pk_constraint_string: str) -> dict: """Given a primary key name and a primary key constraint string, return a dictionary with the following keys: - + constrained_columns A list of string column names that make up the primary key name The name of the primary key constraint """ - + constrained_columns = _parse_pk_columns_from_constraint_string(pk_constraint_string) return {"constrained_columns": constrained_columns, "name": pk_name} - + + def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> List[dict]: """Return a list of dictionaries containing only the col_name:data_type pairs where the `data_type` value contains the match argument. @@ -221,9 +230,10 @@ def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> Lis for row_dict in dte_output: if match in row_dict["data_type"]: output_rows.append(row_dict) - + return output_rows + def get_fk_strings_from_dte_output(dte_output: List[List]) -> List[dict]: """If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries, one dictionary per defined constraint @@ -233,8 +243,10 @@ def get_fk_strings_from_dte_output(dte_output: List[List]) -> List[dict]: return output - -def get_pk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[List[dict]]: + +def get_pk_strings_from_dte_output( + dte_output: List[Dict[str, str]] +) -> Optional[List[dict]]: """If the DESCRIBE TABLE EXTENDED output contains primary key constraints, return a list of dictionaries, one dictionary per defined constraint. @@ -244,3 +256,82 @@ def get_pk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional output = match_dte_rows_by_value(dte_output, "PRIMARY KEY") return output + + +# The keys of this dictionary are the values we expect to see in a +# TGetColumnsRequest's .TYPE_NAME attribute. +# These are enumerated in ttypes.py as class TTypeId. +# TODO: confirm that all types in TTypeId are included here. +GET_COLUMNS_TYPE_MAP = { + "boolean": sqlalchemy.types.Boolean, + "smallint": sqlalchemy.types.SmallInteger, + "int": sqlalchemy.types.Integer, + "bigint": sqlalchemy.types.BigInteger, + "float": sqlalchemy.types.Float, + "double": sqlalchemy.types.Float, + "string": sqlalchemy.types.String, + "varchar": sqlalchemy.types.String, + "char": sqlalchemy.types.String, + "binary": sqlalchemy.types.String, + "array": sqlalchemy.types.String, + "map": sqlalchemy.types.String, + "struct": sqlalchemy.types.String, + "uniontype": sqlalchemy.types.String, + "decimal": sqlalchemy.types.Numeric, + "timestamp": sqlalchemy.types.DateTime, + "date": sqlalchemy.types.Date, +} + + +def parse_numeric_type_precision_and_scale(type_name_str): + """Return an intantiated sqlalchemy Numeric() type that preserves the precision and scale indicated + in the output from TGetColumnsRequest. + + type_name_str + The value of TGetColumnsReq.TYPE_NAME. + + If type_name_str is "DECIMAL(18,5) returns sqlalchemy.types.Numeric(18,5) + """ + + pattern = re.compile(r"DECIMAL\((\d+,\d+)\)") + match = re.search(pattern, type_name_str) + precision_and_scale = match.group(1) + precision, scale = tuple(precision_and_scale.split(",")) + + return sqlalchemy.types.Numeric(int(precision), int(scale)) + + +def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColumn: + """Returns a dictionary of the ReflectedColumn schema parsed from + a single of the result of a TGetColumnsRequest thrift RPC + """ + + pat = re.compile(r"^\w+") + _raw_col_type = re.search(pat, thrift_resp_row.TYPE_NAME).group(0).lower() + _col_type = GET_COLUMNS_TYPE_MAP[_raw_col_type] + + if _raw_col_type == "decimal": + final_col_type = parse_numeric_type_precision_and_scale( + thrift_resp_row.TYPE_NAME + ) + else: + final_col_type = _col_type + + # See comments about autoincrement in test_suite.py + # Since Databricks SQL doesn't currently support inline AUTOINCREMENT declarations + # the autoincrement must be manually declared with an Identity() construct in SQLAlchemy + # Other dialects can perform this extra Identity() step automatically. But that is not + # implemented in the Databricks dialect right now. So autoincrement is currently always False. + # It's not clear what IS_AUTO_INCREMENT in the thrift response actually reflects or whether + # it ever returns a `YES`. + + # Per the guidance in SQLAlchemy's docstrings, we prefer to not even include an autoincrement + # key in this dictionary. + this_column = { + "name": thrift_resp_row.COLUMN_NAME, + "type": final_col_type, + "nullable": bool(thrift_resp_row.NULLABLE), + "default": thrift_resp_row.COLUMN_DEF, + } + + return this_column diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index df3823439..fa100f4f7 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -1,5 +1,5 @@ import re -from typing import Any, List, Optional, Dict +from typing import Any, List, Optional, Dict, Collection, Iterable, Tuple import databricks.sqlalchemy._ddl as dialect_ddl_impl import databricks.sqlalchemy._types as dialect_type_impl @@ -11,14 +11,18 @@ build_pk_dict, get_fk_strings_from_dte_output, get_pk_strings_from_dte_output, + parse_column_info_from_tgetcolumnsresponse, ) import sqlalchemy from sqlalchemy import DDL, event from sqlalchemy.engine import Connection, Engine, default, reflection +from sqlalchemy.engine.reflection import ObjectKind from sqlalchemy.engine.interfaces import ( ReflectedForeignKeyConstraint, ReflectedPrimaryKeyConstraint, + ReflectedColumn, + TableKey, ) from sqlalchemy.exc import DatabaseError, SQLAlchemyError @@ -38,27 +42,6 @@ class DatabricksImpl(DefaultImpl): logger = logging.getLogger(__name__) -COLUMN_TYPE_MAP = { - "boolean": sqlalchemy.types.Boolean, - "smallint": sqlalchemy.types.SmallInteger, - "int": sqlalchemy.types.Integer, - "bigint": sqlalchemy.types.BigInteger, - "float": sqlalchemy.types.Float, - "double": sqlalchemy.types.Float, - "string": sqlalchemy.types.String, - "varchar": sqlalchemy.types.String, - "char": sqlalchemy.types.String, - "binary": sqlalchemy.types.String, - "array": sqlalchemy.types.String, - "map": sqlalchemy.types.String, - "struct": sqlalchemy.types.String, - "uniontype": sqlalchemy.types.String, - "decimal": sqlalchemy.types.Numeric, - "timestamp": sqlalchemy.types.DateTime, - "date": sqlalchemy.types.Date, -} - - class DatabricksDialect(default.DefaultDialect): """This dialect implements only those methods required to pass our e2e tests""" @@ -113,36 +96,10 @@ def create_connect_args(self, url): return [], kwargs - def get_columns(self, connection, table_name, schema=None, **kwargs): - """Return information about columns in `table_name`. - - Given a :class:`_engine.Connection`, a string - `table_name`, and an optional string `schema`, return column - information as a list of dictionaries with these keys: - - name - the column's name - - type - [sqlalchemy.types#TypeEngine] - - nullable - boolean - - default - the column's default value - - autoincrement - boolean - - sequence - a dictionary of the form - {'name' : str, 'start' :int, 'increment': int, 'minvalue': int, - 'maxvalue': int, 'nominvalue': bool, 'nomaxvalue': bool, - 'cycle': bool, 'cache': int, 'order': bool} - - Additional column attributes may be present. - """ + def get_columns( + self, connection, table_name, schema=None, **kwargs + ) -> List[ReflectedColumn]: + """Return information about columns in `table_name`.""" with self.get_connection_cursor(connection) as cur: resp = cur.columns( @@ -154,18 +111,9 @@ def get_columns(self, connection, table_name, schema=None, **kwargs): if not resp: raise sqlalchemy.exc.NoSuchTableError(table_name) columns = [] - for col in resp: - # Taken from PyHive. This removes added type info from decimals and maps - _col_type = re.search(r"^\w+", col.TYPE_NAME).group(0) - this_column = { - "name": col.COLUMN_NAME, - "type": COLUMN_TYPE_MAP[_col_type.lower()], - "nullable": bool(col.NULLABLE), - "default": col.COLUMN_DEF, - "autoincrement": False if col.IS_AUTO_INCREMENT == "NO" else True, - } - columns.append(this_column) + row_dict = parse_column_info_from_tgetcolumnsresponse(col) + columns.append(row_dict) return columns @@ -279,31 +227,68 @@ def get_foreign_keys( return fk_constraints def get_indexes(self, connection, table_name, schema=None, **kw): - """SQLAlchemy requires this method. Databricks doesn't support indexes. - """ + """SQLAlchemy requires this method. Databricks doesn't support indexes.""" return self.EMPTY_INDEX - def get_table_names(self, connection, schema=None, **kwargs): - TABLE_NAME = 1 - with self.get_connection_cursor(connection) as cur: - sql_str = "SHOW TABLES FROM {}".format( - ".".join([self.catalog, schema or self.schema]) - ) - data = cur.execute(sql_str).fetchall() - _tables = [i[TABLE_NAME] for i in data] + @reflection.cache + def get_table_names(self, connection: Connection, schema=None, **kwargs): + """Return a list of tables in the current schema.""" - return _tables + _target_catalog = self.catalog + _target_schema = schema or self.schema + _target = f"`{_target_catalog}`.`{_target_schema}`" - def get_view_names(self, connection, schema=None, **kwargs): - VIEW_NAME = 1 - with self.get_connection_cursor(connection) as cur: - sql_str = "SHOW VIEWS FROM {}".format( - ".".join([self.catalog, schema or self.schema]) - ) - data = cur.execute(sql_str).fetchall() - _tables = [i[VIEW_NAME] for i in data] + stmt = DDL(f"SHOW TABLES FROM {_target}") + + tables_result = connection.execute(stmt).all() + views_result = self.get_view_names(connection=connection, schema=schema) + + # In Databricks, SHOW TABLES FROM returns both tables and views. + # Potential optimisation: rewrite this to instead query informtation_schema + tables_minus_views = [ + row.tableName for row in tables_result if row.tableName not in views_result + ] + + return tables_minus_views - return _tables + @reflection.cache + def get_view_names( + self, + connection, + schema=None, + only_materialized=False, + only_temp=False, + **kwargs, + ) -> List[str]: + """Returns a list of string view names contained in the schema, if any.""" + + _target_catalog = self.catalog + _target_schema = schema or self.schema + _target = f"`{_target_catalog}`.`{_target_schema}`" + + stmt = DDL(f"SHOW VIEWS FROM {_target}") + result = connection.execute(stmt).all() + + return [ + row.viewName + for row in result + if (not only_materialized or row.isMaterialized) + and (not only_temp or row.isTemporary) + ] + + @reflection.cache + def get_materialized_view_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """A wrapper around get_view_names that fetches only the names of materialized views""" + return self.get_view_names(connection, schema, only_materialized=True) + + @reflection.cache + def get_temp_view_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """A wrapper around get_view_names taht fetches only the names of temporary views""" + return self.get_view_names(connection, schema, only_temp=True) def do_rollback(self, dbapi_connection): # Databricks SQL Does not support transactions diff --git a/src/databricks/sqlalchemy/requirements.py b/src/databricks/sqlalchemy/requirements.py index 614eea2e6..6fb252dbc 100644 --- a/src/databricks/sqlalchemy/requirements.py +++ b/src/databricks/sqlalchemy/requirements.py @@ -154,6 +154,11 @@ def temporary_tables(self): """ return sqlalchemy.testing.exclusions.closed() + @property + def table_reflection(self): + """target database has general support for table reflection""" + return sqlalchemy.testing.exclusions.open() + @property def temp_table_reflection(self): """ComponentReflection test is intricate and simply cannot function without this exclusion being defined here. @@ -173,7 +178,7 @@ def unique_constraint_reflection(self): """ComponentReflection test is intricate and simply cannot function without this exclusion being defined here. This happens because we cannot skip individual combinations used in ComponentReflection test. - Databricks supports unique constraints but they are not implemented in this dialect. + Databricks doesn't support UNIQUE constraints. """ return sqlalchemy.testing.exclusions.closed() diff --git a/src/databricks/sqlalchemy/test/test_suite.py b/src/databricks/sqlalchemy/test/test_suite.py index 93096b509..4b13dbeee 100644 --- a/src/databricks/sqlalchemy/test/test_suite.py +++ b/src/databricks/sqlalchemy/test/test_suite.py @@ -330,23 +330,33 @@ class CompositeKeyReflectionTest(CompositeKeyReflectionTest): pass +@pytest.mark.reviewed class ComponentReflectionTestExtra(ComponentReflectionTestExtra): - @pytest.mark.skip(reason="Test setup needs adjustment.") + @pytest.mark.skip(reason="This dialect does not support check constraints") + def test_get_check_constraints(self): + pass + + @pytest.mark.skip(reason="Databricks does not support indexes.") + def test_reflect_covering_index(self): + pass + + @pytest.mark.skip(reason="Databricks does not support indexes.") + def test_reflect_expression_based_indexes(self): + pass + + @pytest.mark.skip( + reason="Databricks doesn't enforce String or VARCHAR length limitations." + ) def test_varchar_reflection(self): - """ - Exception: - databricks.sql.exc.ServerOperationError: [TABLE_OR_VIEW_ALREADY_EXISTS] Cannot create table or view `pysql_sqlalchemy`.`t` because it already exists. - Choose a different name, drop or replace the existing object, add the IF NOT EXISTS clause to tolerate pre-existing objects, or add the OR REFRESH clause to refresh the existing streaming table. - """ + """Even if a user specifies String(52), Databricks won't enforce that limit.""" + pass - @pytest.mark.skip(reason="Test setup appears broken") - def test_numeric_reflection(self): - """ - Exception: - databricks.sql.exc.ServerOperationError: [SCHEMA_NOT_FOUND] The schema `main.test_schema` cannot be found. Verify the spelling and correctness of the schema and catalog. - If you did not qualify the name with a catalog, verify the current_schema() output, or qualify the name with the correct catalog. - To tolerate the error on drop use DROP SCHEMA IF EXISTS. - """ + @pytest.mark.skip( + reason="This dialect doesn't implement foreign key options checks." + ) + def test_get_foreign_key_options(self): + """It's not clear from the test code what the expected output is here. Further research required.""" + pass class DifficultParametersTest(DifficultParametersTest): @@ -411,11 +421,46 @@ class ComponentReflectionTest(ComponentReflectionTest): """This test requires two schemas be present in the target Databricks workspace: - The schema set in --dburi - A second schema named "test_schema" + + Note that test_get_multi_foreign keys is flaky because DBR does not guarantee the order of data returned in DESCRIBE TABLE EXTENDED """ - # We've reviewed these tests: - # test_get_schema_names - # test_not_existing_table + @pytest.mark.skip( + reason="Comment reflection is possible but not enabled in this dialect" + ) + def test_get_multi_table_comment(self): + """There are 84 permutations of this test that are skipped.""" + pass + + @pytest.mark.skip(reason="Databricks doesn't support UNIQUE constraints") + def test_get_multi_unique_constraints(self): + pass + + @pytest.mark.skip( + reason="This dialect doesn't support get_table_options. See comment in test_suite.py" + ) + def test_multi_get_table_options_tables(self): + """It's not clear what the expected ouput from this method would even _be_. Requires research.""" + pass + + @pytest.mark.skip("This dialect doesn't implement get_view_definition") + def test_get_view_definition(self): + pass + + @pytest.mark.skip(reason="This dialect doesn't implement get_view_definition") + def test_get_view_definition_does_not_exist(self): + pass + + @pytest.mark.skip(reason="Strange test design. See test_suite.py") + def test_get_temp_view_names(self): + """While Databricks supports temporary views, this test creates a temp view aimed at a temp table. + Databricks doesn't support temp tables. So the test can never pass. + """ + pass + + @pytest.mark.skip("This dialect doesn't implement get_multi_pk_constraint") + def test_get_multi_pk_constraint(self): + pass @pytest.mark.skip(reason="Databricks doesn't support temp tables.") def test_get_temp_table_columns(self): @@ -485,3 +530,21 @@ class QuotedNameArgumentTest(QuotedNameArgumentTest): also checks the behaviour of DDL identifier preparation process. We need to override some of IdentifierPreparer methods because these are the ultimate control for whether or not CHECK and UNIQUE constraints are emitted. """ + + +@pytest.mark.reviewed +@pytest.mark.skip(reason="Implementation deferred. See test_suite.py") +class BizarroCharacterFKResolutionTest: + """Some of the combinations in this test pass. Others fail. Given the esoteric nature of these failures, + we have opted to defer implementing fixes to a later time, guided by customer feedback. Passage of + these tests is not an acceptance criteria for our dialect. + """ + + +@pytest.mark.reviewed +@pytest.mark.skip(reason="Implementation deferred. See test_suite.py") +class DifficultParametersTest: + """Some of the combinations in this test pass. Others fail. Given the esoteric nature of these failures, + we have opted to defer implementing fixes to a later time, guided by customer feedback. Passage of + these tests is not an acceptance criteria for our dialect. + """