diff --git a/src/databricks/labs/ucx/source_code/base.py b/src/databricks/labs/ucx/source_code/base.py
index 4815aa86aa..e9a15a5097 100644
--- a/src/databricks/labs/ucx/source_code/base.py
+++ b/src/databricks/labs/ucx/source_code/base.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import codecs
+import dataclasses
 import locale
 import logging
 from abc import abstractmethod
@@ -343,3 +344,20 @@ def __hash__(self) -> int:
 
     def __eq__(self, other) -> bool:
         return isinstance(other, DFSA) and self.key == other.key
+
+
+def fix_field_types(klass):
+    """Work around dataclasses.fields() reporting each field's type as a name instead of a type.
+    This is PEP 563 behavior rather than a Python bug: under `from __future__ import annotations`,
+    field.type is a string such as 'str' instead of the type str, which breaks our ORM.
+    Hacking this locally for now because I can't submit PRs to lsql, where the workaround belongs.
+    """
+    for field in dataclasses.fields(klass):
+        if isinstance(field.type, str):
+            try:
+                field.type = __builtins__[field.type]
+            except KeyError:
+                logger.warning(f"Can't infer type of {field.type}")
+
+
+fix_field_types(DFSA)
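
A note on the workaround above: because base.py opens with `from __future__ import annotations`, every annotation in the module is stored as a string (PEP 563), and dataclasses.fields() does not resolve them back to types. A minimal standalone reproduction, where the Example class is a hypothetical stand-in for DFSA:

from __future__ import annotations  # PEP 563: annotations are stored as strings

import dataclasses


@dataclasses.dataclass
class Example:  # hypothetical stand-in for DFSA
    path: str
    is_read: bool


for field in dataclasses.fields(Example):
    print(repr(field.type))  # prints 'str' and 'bool' -- names, not types

The standard library's escape hatch is typing.get_type_hints(), which resolves the strings back to real types; fix_field_types() only looks names up among builtins and falls back to a warning for anything else, which covers the DFSA fields used here.
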
diff --git a/src/databricks/labs/ucx/source_code/dfsa.py b/src/databricks/labs/ucx/source_code/dfsa.py
index 0838cc1565..cb8415be21 100644
--- a/src/databricks/labs/ucx/source_code/dfsa.py
+++ b/src/databricks/labs/ucx/source_code/dfsa.py
@@ -1,14 +1,17 @@
 import logging
 from collections.abc import Iterable
+from functools import partial
 from pathlib import Path
 
 from sqlglot import Expression as SqlExpression, parse as parse_sql, ParseError as SqlParseError
 from sqlglot.expressions import AlterTable, Create, Delete, Drop, Identifier, Insert, Literal, Select
 
 from databricks.sdk.service.workspace import Language
+from databricks.sdk.service.sql import Query
 from databricks.labs.lsql.backends import SqlBackend
 
 from databricks.labs.ucx.framework.crawlers import CrawlerBase
+from databricks.labs.ucx.framework.utils import escape_sql_identifier
 from databricks.labs.ucx.source_code.base import (
     is_a_notebook,
     CurrentSessionState,
@@ -43,6 +46,14 @@ def __init__(self, backend: SqlBackend, schema: str):
 
     def append(self, dfsa: DFSA):
         self._append_records([dfsa])
 
+    def snapshot(self) -> Iterable[DFSA]:
+        return self._snapshot(partial(self._try_load), lambda: [])
+
+    def _try_load(self) -> Iterable[DFSA]:
+        """Tries to load table information from the database or throws TABLE_OR_VIEW_NOT_FOUND error"""
+        for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
+            yield DFSA(*row)
+
 
 class DfsaCollector:
     """DfsaCollector is responsible for collecting and storing DFSAs i.e. Direct File System Access records"""
@@ -52,7 +63,27 @@ def __init__(self, crawler: DfsaCrawler, path_lookup: PathLookup, session_state:
         self._path_lookup = path_lookup
         self._session_state = session_state
 
-    def collect(self, graph: DependencyGraph) -> Iterable[DFSA]:
+    def collect_from_workspace_queries(self, ws) -> Iterable[DFSA]:
+        for query in ws.queries.list():
+            yield from self.collect_from_query(query)
+
+    def collect_from_query(self, query: Query) -> Iterable[DFSA]:
+        if query.query is None:
+            return
+        name: str = query.name or ""
+        source_path = Path(name) if query.parent is None else Path(query.parent) / name
+        for dfsa in self._collect_from_sql(query.query):
+            dfsa = DFSA(
+                source_type="QUERY",
+                source_id=str(source_path),
+                path=dfsa.path,
+                is_read=dfsa.is_read,
+                is_write=dfsa.is_write,
+            )
+            self._crawler.append(dfsa)
+            yield dfsa
+
+    def collect_from_graph(self, graph: DependencyGraph) -> Iterable[DFSA]:
         collected_paths: set[Path] = set()
         for dependency in graph.root_dependencies:
             root = dependency.path  # since it's a root
@@ -110,9 +141,9 @@ def _collect_from_source(
     ) -> Iterable[DFSA]:
         iterable: Iterable[DFSA] | None = None
         if language is Language.SQL:
-            iterable = cls._collect_from_sql(path, source)
+            iterable = cls._collect_from_sql(source)
         if language is Language.PYTHON:
-            iterable = cls._collect_from_python(path, source, graph, inherited_tree)
+            iterable = cls._collect_from_python(source, graph, inherited_tree)
         if iterable is None:
             logger.warning(f"Language {language.name} not supported yet!")
             return
@@ -122,14 +153,12 @@ def _collect_from_source(
         )
 
     @classmethod
-    def _collect_from_python(
-        cls, _path: Path, source: str, graph: DependencyGraph, inherited_tree: Tree | None
-    ) -> Iterable[DFSA]:
+    def _collect_from_python(cls, source: str, graph: DependencyGraph, inherited_tree: Tree | None) -> Iterable[DFSA]:
        analyzer = PythonCodeAnalyzer(graph.new_dependency_graph_context(), source)
         yield from analyzer.collect_dfsas(inherited_tree)
 
     @classmethod
-    def _collect_from_sql(cls, _path: Path, source: str) -> Iterable[DFSA]:
+    def _collect_from_sql(cls, source: str) -> Iterable[DFSA]:
         try:
             sqls = parse_sql(source, read='databricks')
             for sql in sqls:
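
To make the intended flow of the new entry point concrete, here is a hedged sketch; the function name report_query_dfsas, the "ucx" schema name, and the sql_backend/path_lookup parameters are illustrative stand-ins for what the application context normally wires up:

from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.source_code.base import CurrentSessionState
from databricks.labs.ucx.source_code.dfsa import DfsaCollector, DfsaCrawler


def report_query_dfsas(ws: WorkspaceClient, sql_backend, path_lookup) -> None:
    crawler = DfsaCrawler(sql_backend, "ucx")  # inventory schema name is illustrative
    collector = DfsaCollector(crawler, path_lookup, CurrentSessionState())
    # collect_from_workspace_queries() yields each DFSA and, as a side effect,
    # appends it to the crawler's inventory table via collect_from_query()
    for dfsa in collector.collect_from_workspace_queries(ws):
        print(f"{dfsa.source_id}: {dfsa.path} (read={dfsa.is_read}, write={dfsa.is_write})")
    # snapshot() reads the persisted records back from the inventory table
    for dfsa in crawler.snapshot():
        print(dfsa)
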
diff --git a/tests/integration/source_code/test_dfsa.py b/tests/integration/source_code/test_dfsa.py
new file mode 100644
index 0000000000..c213d8f6af
--- /dev/null
+++ b/tests/integration/source_code/test_dfsa.py
@@ -0,0 +1,60 @@
+import logging
+
+import pytest
+
+
+from databricks.labs.ucx.mixins.fixtures import get_test_purge_time, factory
+from databricks.labs.ucx.source_code.base import CurrentSessionState
+from databricks.labs.ucx.source_code.dfsa import DfsaCrawler, DfsaCollector
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture
+def make_query(ws, make_random):
+    def create(name: str, sql: str, **kwargs):
+        # add RemoveAfter tag for test job cleanup
+        date_to_remove = get_test_purge_time()
+        tags: list[str] = kwargs.get("tags", [])
+        tags.append(str({"key": "RemoveAfter", "value": date_to_remove}))
+        query = ws.queries.create(name=name, query=sql, tags=tags)
+        logger.info(f"Query: {ws.config.host}#query/{query.id}")
+        return query
+
+    yield from factory("query", create, lambda query: ws.queries.delete(query.id))
+
+
+@pytest.fixture
+def crawler(make_schema, sql_backend):
+    schema = make_schema(catalog_name="hive_metastore")
+    return DfsaCrawler(sql_backend, schema.name)
+
+
+@pytest.fixture
+def collector(crawler, simple_ctx):
+    return DfsaCollector(crawler, simple_ctx.path_lookup, CurrentSessionState())
+
+
+@pytest.mark.parametrize(
+    "name, sql, dfsa_paths, is_read, is_write",
+    [
+        (
+            "create_location",
+            "CREATE TABLE hive_metastore.indices_historical_data.sp_500 LOCATION 's3a://db-gtm-industry-solutions/data/fsi/capm/sp_500/'",
+            ['s3a://db-gtm-industry-solutions/data/fsi/capm/sp_500/'],
+            False,
+            True,
+        )
+    ],
+)
+def test_dfsa_collector_collects_dfsas_from_query(
+    name, sql, dfsa_paths, is_read, is_write, ws, crawler, collector, make_query
+):
+    query = make_query(name=name, sql=sql)
+    _ = list(collector.collect_from_workspace_queries(ws))
+    for dfsa in crawler.snapshot():
+        assert dfsa.path in set(dfsa_paths)
+        assert dfsa.source_type == "QUERY"
+        assert dfsa.source_id.endswith(query.name)
+        assert dfsa.is_read == is_read
+        assert dfsa.is_write == is_write
diff --git a/tests/unit/source_code/test_dfsa.py b/tests/unit/source_code/test_dfsa.py
index fb67a85562..d539e4075e 100644
--- a/tests/unit/source_code/test_dfsa.py
+++ b/tests/unit/source_code/test_dfsa.py
@@ -1,7 +1,10 @@
 from pathlib import Path
+from unittest.mock import create_autospec
 
 import pytest
 
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.sql import Query
 from databricks.labs.lsql.backends import MockBackend
 
 
@@ -23,7 +26,7 @@ def test_dfsa_does_not_collect_erroneously(simple_dependency_resolver, migration
     maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("leaf4.py"), CurrentSessionState())
     crawler = DfsaCrawler(MockBackend(), "schema")
     collector = DfsaCollector(crawler, mock_path_lookup, CurrentSessionState())
-    dfsas = list(collector.collect(maybe.graph))
+    dfsas = list(collector.collect_from_graph(maybe.graph))
     assert not dfsas
 
 
@@ -48,7 +51,7 @@ def test_dfsa_does_not_collect_erroneously(simple_dependency_resolver, migration
         ("dfsa/spark_dbfs_mount.py", ["/mnt/some_file.csv"], True, False),
     ],
 )
-def test_dfsa_collects_sql_dfsas(
+def test_dfsa_collects_file_dfsas(
     source_path, dfsa_paths, is_read, is_write, simple_dependency_resolver, migration_index, mock_path_lookup
 ):
     """SQL expression not supported by sqlglot for Databricks, restore the test data below for once we drop sqlglot
@@ -57,8 +60,34 @@ def test_dfsa_collects_sql_dfsas(
     assert not maybe.problems
     crawler = DfsaCrawler(MockBackend(), "schema")
     collector = DfsaCollector(crawler, mock_path_lookup, CurrentSessionState())
-    dfsas = list(collector.collect(maybe.graph))
+    dfsas = list(collector.collect_from_graph(maybe.graph))
     assert set(dfsa.path for dfsa in dfsas) == set(dfsa_paths)
     assert not any(dfsa for dfsa in dfsas if dfsa.source_type == DFSA.UNKNOWN)
     assert not any(dfsa for dfsa in dfsas if dfsa.is_read != is_read)
     assert not any(dfsa for dfsa in dfsas if dfsa.is_write != is_write)
+
+
+@pytest.mark.parametrize(
+    "name, query, dfsa_paths, is_read, is_write",
+    [
+        ("none", "SELECT * from dual", [], False, False),
+        (
+            "location",
+            "CREATE TABLE hive_metastore.indices_historical_data.sp_500 LOCATION 's3a://db-gtm-industry-solutions/data/fsi/capm/sp_500/'",
+            ["s3a://db-gtm-industry-solutions/data/fsi/capm/sp_500/"],
+            False,
+            True,
+        ),
+    ],
+)
+def test_dfsa_collects_query_dfsas(name, query, dfsa_paths, is_read, is_write, mock_path_lookup):
+    ws = create_autospec(WorkspaceClient)
+    query = Query.from_dict({"parent": "workspace", "name": name, "query": query})
+    ws.queries.list.return_value = iter([query])
+    crawler = DfsaCrawler(MockBackend(), "schema")
+    collector = DfsaCollector(crawler, mock_path_lookup, CurrentSessionState())
+    dfsas = list(collector.collect_from_workspace_queries(ws))
+    assert set(dfsa.path for dfsa in dfsas) == set(dfsa_paths)
+    assert not any(dfsa for dfsa in dfsas if dfsa.source_type != "QUERY")
+    assert not any(dfsa for dfsa in dfsas if dfsa.is_read != is_read)
+    assert not any(dfsa for dfsa in dfsas if dfsa.is_write != is_write)
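
As an aside on why both tests expect the CREATE TABLE ... LOCATION statement to produce a write DFSA: sqlglot parses the statement into a Create expression, and the location surfaces as a string literal beneath it. A small illustration of that parse, a sketch using only names the module already imports (parse as parse_sql, Create, Literal):

from sqlglot import parse as parse_sql
from sqlglot.expressions import Create, Literal

sql = (
    "CREATE TABLE hive_metastore.indices_historical_data.sp_500 "
    "LOCATION 's3a://db-gtm-industry-solutions/data/fsi/capm/sp_500/'"
)
for statement in parse_sql(sql, read='databricks'):
    assert isinstance(statement, Create)  # a CREATE statement implies is_write=True
    for literal in statement.find_all(Literal):
        if literal.this.startswith("s3a://"):
            print(literal.this)  # s3a://db-gtm-industry-solutions/data/fsi/capm/sp_500/
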