From fc9387a0e5431828e12457b44658f9737b036c01 Mon Sep 17 00:00:00 2001 From: Jesse Whitehouse Date: Sat, 14 Oct 2023 16:55:47 -0400 Subject: [PATCH] Move contents of __init__.py into base.py Finish refactor of get_primary_keys and get_pk_constraint to improve readability. Signed-off-by: Jesse Whitehouse --- src/databricks/sqlalchemy/__init__.py | 456 +----------------- src/databricks/sqlalchemy/_parse.py | 107 +++- src/databricks/sqlalchemy/base.py | 376 +++++++++++++++ .../sqlalchemy/test_local/test_parsing.py | 88 +++- .../sqlalchemy/test_local/test_types.py | 2 +- 5 files changed, 566 insertions(+), 463 deletions(-) create mode 100644 src/databricks/sqlalchemy/base.py diff --git a/src/databricks/sqlalchemy/__init__.py b/src/databricks/sqlalchemy/__init__.py index 325f6c513..0eed85f33 100644 --- a/src/databricks/sqlalchemy/__init__.py +++ b/src/databricks/sqlalchemy/__init__.py @@ -1,455 +1 @@ -import re -from typing import Any, Optional, List, Tuple - -import sqlalchemy -from sqlalchemy import event, DDL -from sqlalchemy.engine import Engine, default, reflection, Connection, Row, CursorResult -from sqlalchemy.engine.interfaces import ( - ReflectedForeignKeyConstraint, - ReflectedPrimaryKeyConstraint, -) -from sqlalchemy.exc import DatabaseError, SQLAlchemyError - -import databricks.sqlalchemy._ddl as dialect_ddl_impl -from databricks.sql.exc import ServerOperationError - -# This import is required to process our @compiles decorators -import databricks.sqlalchemy._types as dialect_type_impl -from databricks import sql -from databricks.sqlalchemy._parse import ( - build_fk_dict, - extract_identifiers_from_string, - extract_three_level_identifier_from_constraint_string, -) - -try: - import alembic -except ImportError: - pass -else: - from alembic.ddl import DefaultImpl - - class DatabricksImpl(DefaultImpl): - __dialect__ = "databricks" - - -import logging - -logger = logging.getLogger(__name__) - -DBR_LTE_12_NOT_FOUND_STRING = "Table or view not found" -DBR_GT_12_NOT_FOUND_STRING = "TABLE_OR_VIEW_NOT_FOUND" - - -def _match_table_not_found_string(message: str) -> bool: - """Return True if the message contains a substring indicating that a table was not found""" - return any( - [ - DBR_LTE_12_NOT_FOUND_STRING in message, - DBR_GT_12_NOT_FOUND_STRING in message, - ] - ) - - -def _describe_table_extended_result_to_dict(result: CursorResult) -> dict: - """Transform the output of DESCRIBE TABLE EXTENDED into a dictionary - - The output from DESCRIBE TABLE EXTENDED puts all values in the `data_type` column - Even CONSTRAINT descriptions are contained in the `data_type` column - Some rows have an empty string for their col_name. These are present only for spacing - so we ignore them. - """ - - result_dict = {row.col_name: row.data_type for row in result if row.col_name != ""} - - return result_dict - - -def _extract_pk_from_dte_result(result: dict) -> ReflectedPrimaryKeyConstraint: - """Return a dictionary with the keys: - - constrained_columns - a list of column names that make up the primary key. Results is an empty list - if no PRIMARY KEY is defined. - - name - the name of the primary key constraint - - Today, DESCRIBE TABLE EXTENDED doesn't give a deterministic name to the field where - a primary key constraint will be found in its output. So we cycle through its - output looking for a match that includes "PRIMARY KEY". This is brittle. We - could optionally make two roundtrips: the first would query information_schema - for the name of the primary key constraint on this table, and a second to - DESCRIBE TABLE EXTENDED, at which point we would know the name of the constraint. - But for now we instead assume that Python list comprehension is faster than a - network roundtrip. - """ - - # find any rows that contain "PRIMARY KEY" as the `data_type` - filtered_rows = [(k, v) for k, v in result.items() if "PRIMARY KEY" in v] - - # bail if no primary key was found - if not filtered_rows: - return {"constrained_columns": [], "name": None} - - # there should only ever be one PRIMARY KEY that matches - if len(filtered_rows) > 1: - logger.warning( - "Found more than one primary key constraint in DESCRIBE TABLE EXTENDED output. " - "This is unexpected. Please report this as a bug. " - "Only the first primary key constraint will be returned." - ) - - # target is a tuple of (constraint_name, constraint_string) - target = filtered_rows[0] - name = target[0] - _constraint_string = target[1] - column_list = extract_identifiers_from_string(_constraint_string) - - return {"constrained_columns": column_list, "name": name} - - -def _extract_fk_from_dte_result( - result: dict, schema_name: Optional[str] -) -> ReflectedForeignKeyConstraint: - """Extract a list of foreign key information dictionaries from the result - of a DESCRIBE TABLE EXTENDED call. - - Returns an empty list if no foreign key is defined. - - Today, DESCRIBE TABLE EXTENDED doesn't give a deterministic name to the field where - a foreign key constraint will be found in its output. So we cycle through its - output looking for a match that includes "FOREIGN KEY". This is brittle. We - could optionally make two roundtrips: the first would query information_schema - for the name of the foreign key constraint on this table, and a second to - DESCRIBE TABLE EXTENDED, at which point we would know the name of the constraint. - But for now we instead assume that Python list comprehension is faster than a - network roundtrip. - """ - - # find any rows that contain "FOREIGN_KEY" as the `data_type` - filtered_rows: List[Tuple] = [(k, v) for k, v in result.items() if "FOREIGN KEY" in v] - - # bail if no foreign key was found - if not filtered_rows: - return [] - - constraint_list = [] - - # target is a tuple of (constraint_name, constraint_string) - for target in filtered_rows: - _constraint_name, _constraint_string = target - this_constraint_dict = build_fk_dict( - _constraint_name, _constraint_string, schema_name - ) - constraint_list.append(this_constraint_dict) - - return constraint_list - - -COLUMN_TYPE_MAP = { - "boolean": sqlalchemy.types.Boolean, - "smallint": sqlalchemy.types.SmallInteger, - "int": sqlalchemy.types.Integer, - "bigint": sqlalchemy.types.BigInteger, - "float": sqlalchemy.types.Float, - "double": sqlalchemy.types.Float, - "string": sqlalchemy.types.String, - "varchar": sqlalchemy.types.String, - "char": sqlalchemy.types.String, - "binary": sqlalchemy.types.String, - "array": sqlalchemy.types.String, - "map": sqlalchemy.types.String, - "struct": sqlalchemy.types.String, - "uniontype": sqlalchemy.types.String, - "decimal": sqlalchemy.types.Numeric, - "timestamp": sqlalchemy.types.DateTime, - "date": sqlalchemy.types.Date, -} - - -class DatabricksDialect(default.DefaultDialect): - """This dialect implements only those methods required to pass our e2e tests""" - - # Possible attributes are defined here: https://docs.sqlalchemy.org/en/14/core/internals.html#sqlalchemy.engine.Dialect - name: str = "databricks" - driver: str = "databricks" - default_schema_name: str = "default" - preparer = dialect_ddl_impl.DatabricksIdentifierPreparer # type: ignore - ddl_compiler = dialect_ddl_impl.DatabricksDDLCompiler - statement_compiler = dialect_ddl_impl.DatabricksStatementCompiler - supports_statement_cache: bool = True - supports_multivalues_insert: bool = True - supports_native_decimal: bool = True - supports_sane_rowcount: bool = False - non_native_boolean_check_constraint: bool = False - supports_identity_columns: bool = True - supports_schemas: bool = True - paramstyle: str = "named" - - colspecs = { - sqlalchemy.types.DateTime: dialect_type_impl.DatabricksDateTimeNoTimezoneType, - sqlalchemy.types.Time: dialect_type_impl.DatabricksTimeType, - sqlalchemy.types.String: dialect_type_impl.DatabricksStringType, - } - - @classmethod - def dbapi(cls): - return sql - - def create_connect_args(self, url): - # TODO: can schema be provided after HOST? - # Expected URI format is: databricks+thrift://token:dapi***@***.cloud.databricks.com?http_path=/sql/*** - - kwargs = { - "server_hostname": url.host, - "access_token": url.password, - "http_path": url.query.get("http_path"), - "catalog": url.query.get("catalog"), - "schema": url.query.get("schema"), - } - - self.schema = kwargs["schema"] - self.catalog = kwargs["catalog"] - - return [], kwargs - - def get_columns(self, connection, table_name, schema=None, **kwargs): - """Return information about columns in `table_name`. - - Given a :class:`_engine.Connection`, a string - `table_name`, and an optional string `schema`, return column - information as a list of dictionaries with these keys: - - name - the column's name - - type - [sqlalchemy.types#TypeEngine] - - nullable - boolean - - default - the column's default value - - autoincrement - boolean - - sequence - a dictionary of the form - {'name' : str, 'start' :int, 'increment': int, 'minvalue': int, - 'maxvalue': int, 'nominvalue': bool, 'nomaxvalue': bool, - 'cycle': bool, 'cache': int, 'order': bool} - - Additional column attributes may be present. - """ - - with self.get_connection_cursor(connection) as cur: - resp = cur.columns( - catalog_name=self.catalog, - schema_name=schema or self.schema, - table_name=table_name, - ).fetchall() - - if not resp: - raise sqlalchemy.exc.NoSuchTableError(table_name) - columns = [] - - for col in resp: - # Taken from PyHive. This removes added type info from decimals and maps - _col_type = re.search(r"^\w+", col.TYPE_NAME).group(0) - this_column = { - "name": col.COLUMN_NAME, - "type": COLUMN_TYPE_MAP[_col_type.lower()], - "nullable": bool(col.NULLABLE), - "default": col.COLUMN_DEF, - "autoincrement": False if col.IS_AUTO_INCREMENT == "NO" else True, - } - columns.append(this_column) - - return columns - - def _describe_table_extended( - self, - connection: Connection, - table_name: str, - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, - expect_result=True, - ): - """Run DESCRIBE TABLE EXTENDED on a table and return a dictionary of the result. - - This method is the fastest way to check for the presence of a table in a schema. - - If expect_result is False, this method returns None as the output dict isn't required. - - Raises NoSuchTableError if the table is not present in the schema. - """ - - _target_catalog = catalog_name or self.catalog - _target_schema = schema_name or self.schema - _target = f"`{_target_catalog}`.`{_target_schema}`.`{table_name}`" - - # sql injection risk? - # DESCRIBE TABLE EXTENDED in DBR doesn't support parameterised inputs :( - stmt = DDL(f"DESCRIBE TABLE EXTENDED {_target}") - - try: - result = connection.execute(stmt).all() - except DatabaseError as e: - if _match_table_not_found_string(str(e)): - raise sqlalchemy.exc.NoSuchTableError( - f"No such table {table_name}" - ) from e - raise e - - if not expect_result: - return None - - fmt_result = _describe_table_extended_result_to_dict(result) - return fmt_result - - @reflection.cache - def get_pk_constraint( - self, - connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> ReflectedPrimaryKeyConstraint: - """Return information about the primary key constraint on - table_name`. - """ - - result = self._describe_table_extended( - connection=connection, - table_name=table_name, - schema_name=schema, - ) - - return _extract_pk_from_dte_result(result) - - def get_foreign_keys( - self, connection, table_name, schema=None, **kw - ) -> ReflectedForeignKeyConstraint: - """Return information about foreign_keys in `table_name`.""" - - result = self._describe_table_extended( - connection=connection, - table_name=table_name, - schema_name=schema, - ) - - return _extract_fk_from_dte_result(result, schema) - - def get_indexes(self, connection, table_name, schema=None, **kw): - """Return information about indexes in `table_name`. - - Given a :class:`_engine.Connection`, a string - `table_name` and an optional string `schema`, return index - information as a list of dictionaries with these keys: - - name - the index's name - - column_names - list of column names in order - - unique - boolean - """ - # TODO: Implement this behaviour - return [] - - def get_table_names(self, connection, schema=None, **kwargs): - TABLE_NAME = 1 - with self.get_connection_cursor(connection) as cur: - sql_str = "SHOW TABLES FROM {}".format( - ".".join([self.catalog, schema or self.schema]) - ) - data = cur.execute(sql_str).fetchall() - _tables = [i[TABLE_NAME] for i in data] - - return _tables - - def get_view_names(self, connection, schema=None, **kwargs): - VIEW_NAME = 1 - with self.get_connection_cursor(connection) as cur: - sql_str = "SHOW VIEWS FROM {}".format( - ".".join([self.catalog, schema or self.schema]) - ) - data = cur.execute(sql_str).fetchall() - _tables = [i[VIEW_NAME] for i in data] - - return _tables - - def do_rollback(self, dbapi_connection): - # Databricks SQL Does not support transactions - pass - - @reflection.cache - def has_table( - self, connection, table_name, schema=None, catalog=None, **kwargs - ) -> bool: - """For internal dialect use, check the existence of a particular table - or view in the database. - """ - - try: - self._describe_table_extended( - connection=connection, - table_name=table_name, - catalog_name=catalog, - schema_name=schema, - ) - return True - except sqlalchemy.exc.NoSuchTableError as e: - return False - - def get_connection_cursor(self, connection): - """Added for backwards compatibility with 1.3.x""" - if hasattr(connection, "_dbapi_connection"): - return connection._dbapi_connection.dbapi_connection.cursor() - elif hasattr(connection, "raw_connection"): - return connection.raw_connection().cursor() - elif hasattr(connection, "connection"): - return connection.connection.cursor() - - raise SQLAlchemyError( - "Databricks dialect can't obtain a cursor context manager from the dbapi" - ) - - @reflection.cache - def get_schema_names(self, connection, **kw): - """Return a list of all schema names available in the database.""" - stmt = DDL("SHOW SCHEMAS") - result = connection.execute(stmt) - schema_list = [row[0] for row in result] - return schema_list - - -@event.listens_for(Engine, "do_connect") -def receive_do_connect(dialect, conn_rec, cargs, cparams): - """Helpful for DS on traffic from clients using SQLAlchemy in particular""" - - # Ignore connect invocations that don't use our dialect - if not dialect.name == "databricks": - return - - if "_user_agent_entry" in cparams: - new_user_agent = f"sqlalchemy + {cparams['_user_agent_entry']}" - else: - new_user_agent = "sqlalchemy" - - cparams["_user_agent_entry"] = new_user_agent - - if sqlalchemy.__version__.startswith("1.3"): - # SQLAlchemy 1.3.x fails to parse the http_path, catalog, and schema from our connection string - # These should be passed in as connect_args when building the Engine - - if "schema" in cparams: - dialect.schema = cparams["schema"] - - if "catalog" in cparams: - dialect.catalog = cparams["catalog"] +from databricks.sqlalchemy.base import DatabricksDialect \ No newline at end of file diff --git a/src/databricks/sqlalchemy/_parse.py b/src/databricks/sqlalchemy/_parse.py index 587b1381d..941737ba6 100644 --- a/src/databricks/sqlalchemy/_parse.py +++ b/src/databricks/sqlalchemy/_parse.py @@ -1,11 +1,39 @@ -from typing import List, Optional +from typing import List, Optional, Dict import re +from sqlalchemy.engine import CursorResult + """ This module contains helper functions that can parse the contents -of DESCRIBE TABLE EXTENDED calls. Mostly wrappers around regexes. +of metadata and exceptions received from DBR. These are mostly just +wrappers around regexes. """ +def _match_table_not_found_string(message: str) -> bool: + """Return True if the message contains a substring indicating that a table was not found""" + + DBR_LTE_12_NOT_FOUND_STRING = "Table or view not found" + DBR_GT_12_NOT_FOUND_STRING = "TABLE_OR_VIEW_NOT_FOUND" + return any( + [ + DBR_LTE_12_NOT_FOUND_STRING in message, + DBR_GT_12_NOT_FOUND_STRING in message, + ] + ) + + +def _describe_table_extended_result_to_dict_list(result: CursorResult) -> List[Dict[str, str]]: + """Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries + """ + + rows_to_return = [] + for row in result: + this_row = {"col_name": row.col_name, "data_type": row.data_type} + rows_to_return.append(this_row) + + return rows_to_return + + def extract_identifiers_from_string(input_str: str) -> List[str]: """For a string input resembling (`a`, `b`, `c`) return a list of identifiers ['a', 'b', 'c']""" @@ -142,4 +170,77 @@ def build_fk_dict( **schema_override_dict, } - return complete_foreign_key_dict \ No newline at end of file + return complete_foreign_key_dict + +def _parse_pk_columns_from_constraint_string(constraint_str: str) -> List[str]: + """Build a list of constrained columns from a constraint string returned by DESCRIBE TABLE EXTENDED + + For example: + + PRIMARY KEY (`id`, `name`, `email_address`) + + Returns a list like + + ["id", "name", "email_address"] + """ + + _extracted = extract_identifiers_from_string(constraint_str) + + return _extracted + +def build_pk_dict(pk_name: str, pk_constraint_string: str) -> dict: + """Given a primary key name and a primary key constraint string, return a dictionary + with the following keys: + + constrained_columns + A list of string column names that make up the primary key + + name + The name of the primary key constraint + """ + + constrained_columns = _parse_pk_columns_from_constraint_string(pk_constraint_string) + + return {"constrained_columns": constrained_columns, "name": pk_name} + +def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> List[dict]: + """Return a list of dictionaries containing only the col_name:data_type pairs where the `data_type` + value contains the match argument. + + Today, DESCRIBE TABLE EXTENDED doesn't give a deterministic name to the fields + a constraint will be found in its output. So we cycle through its output looking + for a match. This is brittle. We could optionally make two roundtrips: the first + would query information_schema for the name of the constraint on this table, and + a second to DESCRIBE TABLE EXTENDED, at which point we would know the name of the + constraint. But for now we instead assume that Python list comprehension is faster + than a network roundtrip + """ + + output_rows = [] + + for row_dict in dte_output: + if match in row_dict["data_type"]: + output_rows.append(row_dict) + + return output_rows + +def get_fk_strings_from_dte_output(dte_output: List[List]) -> List[dict]: + """If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries, + one dictionary per defined constraint + """ + + output = match_dte_rows_by_value(dte_output, "FOREIGN KEY") + + return output + + +def get_pk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[List[dict]]: + """If the DESCRIBE TABLE EXTENDED output contains primary key constraints, return a list of dictionaries, + one dictionary per defined constraint. + + Returns None if no primary key constraints are found. + """ + + output = match_dte_rows_by_value(dte_output, "PRIMARY KEY") + + return output diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py new file mode 100644 index 000000000..df3823439 --- /dev/null +++ b/src/databricks/sqlalchemy/base.py @@ -0,0 +1,376 @@ +import re +from typing import Any, List, Optional, Dict + +import databricks.sqlalchemy._ddl as dialect_ddl_impl +import databricks.sqlalchemy._types as dialect_type_impl +from databricks import sql +from databricks.sqlalchemy._parse import ( + _describe_table_extended_result_to_dict_list, + _match_table_not_found_string, + build_fk_dict, + build_pk_dict, + get_fk_strings_from_dte_output, + get_pk_strings_from_dte_output, +) + +import sqlalchemy +from sqlalchemy import DDL, event +from sqlalchemy.engine import Connection, Engine, default, reflection +from sqlalchemy.engine.interfaces import ( + ReflectedForeignKeyConstraint, + ReflectedPrimaryKeyConstraint, +) +from sqlalchemy.exc import DatabaseError, SQLAlchemyError + +try: + import alembic +except ImportError: + pass +else: + from alembic.ddl import DefaultImpl + + class DatabricksImpl(DefaultImpl): + __dialect__ = "databricks" + + +import logging + +logger = logging.getLogger(__name__) + + +COLUMN_TYPE_MAP = { + "boolean": sqlalchemy.types.Boolean, + "smallint": sqlalchemy.types.SmallInteger, + "int": sqlalchemy.types.Integer, + "bigint": sqlalchemy.types.BigInteger, + "float": sqlalchemy.types.Float, + "double": sqlalchemy.types.Float, + "string": sqlalchemy.types.String, + "varchar": sqlalchemy.types.String, + "char": sqlalchemy.types.String, + "binary": sqlalchemy.types.String, + "array": sqlalchemy.types.String, + "map": sqlalchemy.types.String, + "struct": sqlalchemy.types.String, + "uniontype": sqlalchemy.types.String, + "decimal": sqlalchemy.types.Numeric, + "timestamp": sqlalchemy.types.DateTime, + "date": sqlalchemy.types.Date, +} + + +class DatabricksDialect(default.DefaultDialect): + """This dialect implements only those methods required to pass our e2e tests""" + + # Possible attributes are defined here: https://docs.sqlalchemy.org/en/14/core/internals.html#sqlalchemy.engine.Dialect + name: str = "databricks" + driver: str = "databricks" + default_schema_name: str = "default" + preparer = dialect_ddl_impl.DatabricksIdentifierPreparer # type: ignore + ddl_compiler = dialect_ddl_impl.DatabricksDDLCompiler + statement_compiler = dialect_ddl_impl.DatabricksStatementCompiler + supports_statement_cache: bool = True + supports_multivalues_insert: bool = True + supports_native_decimal: bool = True + supports_sane_rowcount: bool = False + non_native_boolean_check_constraint: bool = False + supports_identity_columns: bool = True + supports_schemas: bool = True + paramstyle: str = "named" + + colspecs = { + sqlalchemy.types.DateTime: dialect_type_impl.DatabricksDateTimeNoTimezoneType, + sqlalchemy.types.Time: dialect_type_impl.DatabricksTimeType, + sqlalchemy.types.String: dialect_type_impl.DatabricksStringType, + } + + # SQLAlchemy requires that a table with no primary key + # constraint return a dictionary that looks like this. + EMPTY_PK = {"constrained_columns": [], "name": None} + + # SQLAlchemy requires that a table with no foreign keys + # defined return an empty list. Same for indexes. + EMPTY_FK = EMPTY_INDEX = [] + + @classmethod + def dbapi(cls): + return sql + + def create_connect_args(self, url): + # TODO: can schema be provided after HOST? + # Expected URI format is: databricks+thrift://token:dapi***@***.cloud.databricks.com?http_path=/sql/*** + + kwargs = { + "server_hostname": url.host, + "access_token": url.password, + "http_path": url.query.get("http_path"), + "catalog": url.query.get("catalog"), + "schema": url.query.get("schema"), + } + + self.schema = kwargs["schema"] + self.catalog = kwargs["catalog"] + + return [], kwargs + + def get_columns(self, connection, table_name, schema=None, **kwargs): + """Return information about columns in `table_name`. + + Given a :class:`_engine.Connection`, a string + `table_name`, and an optional string `schema`, return column + information as a list of dictionaries with these keys: + + name + the column's name + + type + [sqlalchemy.types#TypeEngine] + + nullable + boolean + + default + the column's default value + + autoincrement + boolean + + sequence + a dictionary of the form + {'name' : str, 'start' :int, 'increment': int, 'minvalue': int, + 'maxvalue': int, 'nominvalue': bool, 'nomaxvalue': bool, + 'cycle': bool, 'cache': int, 'order': bool} + + Additional column attributes may be present. + """ + + with self.get_connection_cursor(connection) as cur: + resp = cur.columns( + catalog_name=self.catalog, + schema_name=schema or self.schema, + table_name=table_name, + ).fetchall() + + if not resp: + raise sqlalchemy.exc.NoSuchTableError(table_name) + columns = [] + + for col in resp: + # Taken from PyHive. This removes added type info from decimals and maps + _col_type = re.search(r"^\w+", col.TYPE_NAME).group(0) + this_column = { + "name": col.COLUMN_NAME, + "type": COLUMN_TYPE_MAP[_col_type.lower()], + "nullable": bool(col.NULLABLE), + "default": col.COLUMN_DEF, + "autoincrement": False if col.IS_AUTO_INCREMENT == "NO" else True, + } + columns.append(this_column) + + return columns + + def _describe_table_extended( + self, + connection: Connection, + table_name: str, + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + expect_result=True, + ) -> List[Dict[str, str]]: + """Run DESCRIBE TABLE EXTENDED on a table and return a list of dictionaries of the result. + + This method is the fastest way to check for the presence of a table in a schema. + + If expect_result is False, this method returns None as the output dict isn't required. + + Raises NoSuchTableError if the table is not present in the schema. + """ + + _target_catalog = catalog_name or self.catalog + _target_schema = schema_name or self.schema + _target = f"`{_target_catalog}`.`{_target_schema}`.`{table_name}`" + + # sql injection risk? + # DESCRIBE TABLE EXTENDED in DBR doesn't support parameterised inputs :( + stmt = DDL(f"DESCRIBE TABLE EXTENDED {_target}") + + try: + result = connection.execute(stmt).all() + except DatabaseError as e: + if _match_table_not_found_string(str(e)): + raise sqlalchemy.exc.NoSuchTableError( + f"No such table {table_name}" + ) from e + raise e + + if not expect_result: + return None + + fmt_result = _describe_table_extended_result_to_dict_list(result) + return fmt_result + + @reflection.cache + def get_pk_constraint( + self, + connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> ReflectedPrimaryKeyConstraint: + """Fetch information about the primary key constraint on table_name. + + Returns a dictionary with these keys: + constrained_columns + a list of column names that make up the primary key. Results is an empty list + if no PRIMARY KEY is defined. + + name + the name of the primary key constraint + """ + + result = self._describe_table_extended( + connection=connection, + table_name=table_name, + schema_name=schema, + ) + + raw_pk_constraints: List = get_pk_strings_from_dte_output(result) + if not any(raw_pk_constraints): + return self.EMPTY_PK + + if len(raw_pk_constraints) > 1: + logger.warning( + "Found more than one primary key constraint in DESCRIBE TABLE EXTENDED output. " + "This is unexpected. Please report this as a bug. " + "Only the first primary key constraint will be returned." + ) + + first_pk_constraint = raw_pk_constraints[0] + pk_name = first_pk_constraint.get("col_name") + pk_constraint_string = first_pk_constraint.get("data_type") + + return build_pk_dict(pk_name, pk_constraint_string) + + def get_foreign_keys( + self, connection, table_name, schema=None, **kw + ) -> ReflectedForeignKeyConstraint: + """Return information about foreign_keys in `table_name`.""" + + result = self._describe_table_extended( + connection=connection, + table_name=table_name, + schema_name=schema, + ) + + raw_fk_constraints: List = get_fk_strings_from_dte_output(result) + + if not any(raw_fk_constraints): + return self.EMPTY_FK + + fk_constraints = [] + for constraint_dict in raw_fk_constraints: + fk_name = constraint_dict.get("col_name") + fk_constraint_string = constraint_dict.get("data_type") + this_constraint_dict = build_fk_dict( + fk_name, fk_constraint_string, schema_name=schema + ) + fk_constraints.append(this_constraint_dict) + + return fk_constraints + + def get_indexes(self, connection, table_name, schema=None, **kw): + """SQLAlchemy requires this method. Databricks doesn't support indexes. + """ + return self.EMPTY_INDEX + + def get_table_names(self, connection, schema=None, **kwargs): + TABLE_NAME = 1 + with self.get_connection_cursor(connection) as cur: + sql_str = "SHOW TABLES FROM {}".format( + ".".join([self.catalog, schema or self.schema]) + ) + data = cur.execute(sql_str).fetchall() + _tables = [i[TABLE_NAME] for i in data] + + return _tables + + def get_view_names(self, connection, schema=None, **kwargs): + VIEW_NAME = 1 + with self.get_connection_cursor(connection) as cur: + sql_str = "SHOW VIEWS FROM {}".format( + ".".join([self.catalog, schema or self.schema]) + ) + data = cur.execute(sql_str).fetchall() + _tables = [i[VIEW_NAME] for i in data] + + return _tables + + def do_rollback(self, dbapi_connection): + # Databricks SQL Does not support transactions + pass + + @reflection.cache + def has_table( + self, connection, table_name, schema=None, catalog=None, **kwargs + ) -> bool: + """For internal dialect use, check the existence of a particular table + or view in the database. + """ + + try: + self._describe_table_extended( + connection=connection, + table_name=table_name, + catalog_name=catalog, + schema_name=schema, + ) + return True + except sqlalchemy.exc.NoSuchTableError as e: + return False + + def get_connection_cursor(self, connection): + """Added for backwards compatibility with 1.3.x""" + if hasattr(connection, "_dbapi_connection"): + return connection._dbapi_connection.dbapi_connection.cursor() + elif hasattr(connection, "raw_connection"): + return connection.raw_connection().cursor() + elif hasattr(connection, "connection"): + return connection.connection.cursor() + + raise SQLAlchemyError( + "Databricks dialect can't obtain a cursor context manager from the dbapi" + ) + + @reflection.cache + def get_schema_names(self, connection, **kw): + """Return a list of all schema names available in the database.""" + stmt = DDL("SHOW SCHEMAS") + result = connection.execute(stmt) + schema_list = [row[0] for row in result] + return schema_list + + +@event.listens_for(Engine, "do_connect") +def receive_do_connect(dialect, conn_rec, cargs, cparams): + """Helpful for DS on traffic from clients using SQLAlchemy in particular""" + + # Ignore connect invocations that don't use our dialect + if not dialect.name == "databricks": + return + + if "_user_agent_entry" in cparams: + new_user_agent = f"sqlalchemy + {cparams['_user_agent_entry']}" + else: + new_user_agent = "sqlalchemy" + + cparams["_user_agent_entry"] = new_user_agent + + if sqlalchemy.__version__.startswith("1.3"): + # SQLAlchemy 1.3.x fails to parse the http_path, catalog, and schema from our connection string + # These should be passed in as connect_args when building the Engine + + if "schema" in cparams: + dialect.schema = cparams["schema"] + + if "catalog" in cparams: + dialect.catalog = cparams["catalog"] diff --git a/src/databricks/sqlalchemy/test_local/test_parsing.py b/src/databricks/sqlalchemy/test_local/test_parsing.py index 3ebb8616e..ab82613ed 100644 --- a/src/databricks/sqlalchemy/test_local/test_parsing.py +++ b/src/databricks/sqlalchemy/test_local/test_parsing.py @@ -3,7 +3,9 @@ extract_identifiers_from_string, extract_identifier_groups_from_string, extract_three_level_identifier_from_constraint_string, - build_fk_dict + build_fk_dict, + build_pk_dict, + match_dte_rows_by_value, ) @@ -39,16 +41,19 @@ def test_extract_identifer_batches(input, expected): extract_identifier_groups_from_string(input) == expected ), "Failed to extract identifier groups from string" -def test_extract_3l_namespace_from_constraint_string(): +def test_extract_3l_namespace_from_constraint_string(): input = "FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`pysql_dialect_compliance`.`users` (`user_id`)" expected = { "catalog": "main", "schema": "pysql_dialect_compliance", - "table": "users" + "table": "users", } - assert extract_three_level_identifier_from_constraint_string(input) == expected, "Failed to extract 3L namespace from constraint string" + assert ( + extract_three_level_identifier_from_constraint_string(input) == expected + ), "Failed to extract 3L namespace from constraint string" + @pytest.mark.parametrize("schema", [None, "some_schema"]) def test_build_fk_dict(schema): @@ -64,3 +69,78 @@ def test_build_fk_dict(schema): "referred_columns": ["user_id"], } + +def test_build_pk_dict(): + pk_constraint_string = "PRIMARY KEY (`id`, `name`, `email_address`)" + pk_name = "pk1" + + result = build_pk_dict(pk_name, pk_constraint_string) + + assert result == { + "constrained_columns": ["id", "name", "email_address"], + "name": "pk1", + } + + +# This is a real example of the output from DESCRIBE TABLE EXTENDED as of 15 October 2023 +RAW_SAMPLE_DTE_OUTPUT = [ + ["id", "int"], + ["name", "string"], + ["", ""], + ["# Detailed Table Information", ""], + ["Catalog", "main"], + ["Database", "pysql_sqlalchemy"], + ["Table", "exampleexampleexample"], + ["Created Time", "Sun Oct 15 21:12:54 UTC 2023"], + ["Last Access", "UNKNOWN"], + ["Created By", "Spark "], + ["Type", "MANAGED"], + ["Location", "s3://us-west-2-****-/19a85dee-****/tables/ccb7***"], + ["Provider", "delta"], + ["Owner", "some.user@example.com"], + ["Is_managed_location", "true"], + ["Predictive Optimization", "ENABLE (inherited from CATALOG main)"], + [ + "Table Properties", + "[delta.checkpoint.writeStatsAsJson=false,delta.checkpoint.writeStatsAsStruct=true,delta.minReaderVersion=1,delta.minWriterVersion=2]", + ], + ["", ""], + ["# Constraints", ""], + ["exampleexampleexample_pk", "PRIMARY KEY (`id`)"], + [ + "exampleexampleexample_fk", + "FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`pysql_dialect_compliance`.`users` (`user_id`)", + ], +] + +FMT_SAMPLE_DT_OUTPUT = [ + {"col_name": i[0], "data_type": i[1]} for i in RAW_SAMPLE_DTE_OUTPUT +] + + +@pytest.mark.parametrize( + "match, output", + [ + ( + "PRIMARY KEY", + [ + { + "col_name": "exampleexampleexample_pk", + "data_type": "PRIMARY KEY (`id`)", + } + ], + ), + ( + "FOREIGN KEY", + [ + { + "col_name": "exampleexampleexample_fk", + "data_type": "FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`pysql_dialect_compliance`.`users` (`user_id`)", + } + ], + ), + ], +) +def test_filter_dict_by_value(match, output): + result = match_dte_rows_by_value(FMT_SAMPLE_DT_OUTPUT, match) + assert result == output diff --git a/src/databricks/sqlalchemy/test_local/test_types.py b/src/databricks/sqlalchemy/test_local/test_types.py index f7423f697..fb66562a7 100644 --- a/src/databricks/sqlalchemy/test_local/test_types.py +++ b/src/databricks/sqlalchemy/test_local/test_types.py @@ -3,7 +3,7 @@ import pytest import sqlalchemy -from databricks.sqlalchemy import DatabricksDialect +from databricks.sqlalchemy.base import DatabricksDialect class DatabricksDataType(enum.Enum):