18 changes: 18 additions & 0 deletions src/databricks/labs/ucx/source_code/base.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import codecs
+import dataclasses
 import locale
 import logging
 from abc import abstractmethod
@@ -343,3 +344,20 @@ def __hash__(self) -> int:
 
     def __eq__(self, other) -> bool:
         return isinstance(other, DFSA) and self.key == other.key
+
+
+def fix_field_types(klass):
+    """Works around PEP 563 (`from __future__ import annotations`): dataclasses.fields(DFSA)
+    returns fields whose type is a type name ('str') instead of a type (<class 'str'>), which breaks
+    our ORM. Hacking this locally for now because the workaround belongs in lsql, where I can't submit PRs."""
+    import builtins  # local import: this shim should eventually move to lsql
+
+    for field in dataclasses.fields(klass):
+        if isinstance(field.type, str):
+            try:
+                field.type = getattr(builtins, field.type)
+            except AttributeError:
+                logger.warning(f"Can't infer type of {field.type}")
+
+
+fix_field_types(DFSA)
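
For context, the string-valued field types are standard PEP 563 behavior rather than a bug: `from __future__ import annotations` (already at the top of base.py) stores every annotation as a string, and dataclasses.fields() reports them verbatim. Below is a minimal repro, together with the typing.get_type_hints alternative that also resolves non-builtin types; this sketch is for illustration only and is not part of the PR:

    from __future__ import annotations

    import dataclasses
    import typing

    @dataclasses.dataclass
    class Example:
        name: str
        count: int

    # With postponed annotations, field.type is the string 'str', not <class 'str'>:
    print([(f.name, f.type) for f in dataclasses.fields(Example)])  # [('name', 'str'), ('count', 'int')]

    # get_type_hints() resolves the strings back to real types, including non-builtins:
    print(typing.get_type_hints(Example))  # {'name': <class 'str'>, 'count': <class 'int'>}
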
43 changes: 36 additions & 7 deletions src/databricks/labs/ucx/source_code/dfsa.py
@@ -1,14 +1,17 @@
 import logging
 from collections.abc import Iterable
+from functools import partial
 from pathlib import Path
 
 from sqlglot import Expression as SqlExpression, parse as parse_sql, ParseError as SqlParseError
 from sqlglot.expressions import AlterTable, Create, Delete, Drop, Identifier, Insert, Literal, Select
 
 from databricks.sdk.service.workspace import Language
+from databricks.sdk.service.sql import Query
 
 from databricks.labs.lsql.backends import SqlBackend
 from databricks.labs.ucx.framework.crawlers import CrawlerBase
+from databricks.labs.ucx.framework.utils import escape_sql_identifier
 from databricks.labs.ucx.source_code.base import (
     is_a_notebook,
     CurrentSessionState,
@@ -43,6 +46,14 @@ def __init__(self, backend: SqlBackend, schema: str):
     def append(self, dfsa: DFSA):
         self._append_records([dfsa])
 
+    def snapshot(self) -> Iterable[DFSA]:
+        return self._snapshot(partial(self._try_load), lambda: [])
+
+    def _try_load(self) -> Iterable[DFSA]:
+        """Tries to load DFSA records from the inventory table, or raises a TABLE_OR_VIEW_NOT_FOUND error"""
+        for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
+            yield DFSA(*row)
+
 
 class DfsaCollector:
     """DfsaCollector is responsible for collecting and storing DFSAs, i.e. Direct File System Access records"""
@@ -52,7 +63,27 @@ def __init__(self, crawler: DfsaCrawler, path_lookup: PathLookup, session_state:
         self._path_lookup = path_lookup
         self._session_state = session_state
 
-    def collect(self, graph: DependencyGraph) -> Iterable[DFSA]:
+    def collect_from_workspace_queries(self, ws) -> Iterable[DFSA]:
+        for query in ws.queries.list():
+            yield from self.collect_from_query(query)
+
+    def collect_from_query(self, query: Query) -> Iterable[DFSA]:
+        if query.query is None:
+            return
+        name: str = query.name or "<anonymous>"
+        source_path = Path(name) if query.parent is None else Path(query.parent) / name
+        for dfsa in self._collect_from_sql(query.query):
+            dfsa = DFSA(
+                source_type="QUERY",
+                source_id=str(source_path),
+                path=dfsa.path,
+                is_read=dfsa.is_read,
+                is_write=dfsa.is_write,
+            )
+            self._crawler.append(dfsa)
+            yield dfsa
+
+    def collect_from_graph(self, graph: DependencyGraph) -> Iterable[DFSA]:
         collected_paths: set[Path] = set()
         for dependency in graph.root_dependencies:
             root = dependency.path  # since it's a root
@@ -110,9 +141,9 @@ def _collect_from_source(
     ) -> Iterable[DFSA]:
         iterable: Iterable[DFSA] | None = None
         if language is Language.SQL:
-            iterable = cls._collect_from_sql(path, source)
+            iterable = cls._collect_from_sql(source)
         if language is Language.PYTHON:
-            iterable = cls._collect_from_python(path, source, graph, inherited_tree)
+            iterable = cls._collect_from_python(source, graph, inherited_tree)
         if iterable is None:
             logger.warning(f"Language {language.name} not supported yet!")
             return
@@ -122,14 +153,12 @@ def _collect_from_source(
         )
 
     @classmethod
-    def _collect_from_python(
-        cls, _path: Path, source: str, graph: DependencyGraph, inherited_tree: Tree | None
-    ) -> Iterable[DFSA]:
+    def _collect_from_python(cls, source: str, graph: DependencyGraph, inherited_tree: Tree | None) -> Iterable[DFSA]:
         analyzer = PythonCodeAnalyzer(graph.new_dependency_graph_context(), source)
         yield from analyzer.collect_dfsas(inherited_tree)
 
     @classmethod
-    def _collect_from_sql(cls, _path: Path, source: str) -> Iterable[DFSA]:
+    def _collect_from_sql(cls, source: str) -> Iterable[DFSA]:
         try:
             sqls = parse_sql(source, read='databricks')
             for sql in sqls:
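
To illustrate what _collect_from_sql can detect once a statement is parsed with the Databricks dialect, here is a minimal sqlglot sketch (illustration only; the PR's actual matching logic is elided above) that pulls the LOCATION literal out of a CREATE TABLE statement:

    from sqlglot import parse as parse_sql
    from sqlglot.expressions import Literal, LocationProperty

    sql = (
        "CREATE TABLE hive_metastore.indices_historical_data.sp_500 "
        "LOCATION 's3a://db-gtm-industry-solutions/data/fsi/capm/sp_500/'"
    )
    for statement in parse_sql(sql, read="databricks"):
        # CREATE ... LOCATION '...' carries the path as a literal under a LocationProperty node
        for prop in statement.find_all(LocationProperty):
            for literal in prop.find_all(Literal):
                print(literal.this)  # s3a://db-gtm-industry-solutions/data/fsi/capm/sp_500/
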
60 changes: 60 additions & 0 deletions tests/integration/source_code/test_dfsa.py
@@ -0,0 +1,60 @@
+import logging
+
+import pytest
+
+
+from databricks.labs.ucx.mixins.fixtures import get_test_purge_time, factory
+from databricks.labs.ucx.source_code.base import CurrentSessionState
+from databricks.labs.ucx.source_code.dfsa import DfsaCrawler, DfsaCollector
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture
+def make_query(ws, make_random):
+    def create(name: str, sql: str, **kwargs):
+        # add a RemoveAfter tag so stale test queries get cleaned up
+        date_to_remove = get_test_purge_time()
+        tags: list[str] = kwargs.get("tags", [])
+        tags.append(str({"key": "RemoveAfter", "value": date_to_remove}))
+        query = ws.queries.create(name=name, query=sql, tags=tags)
+        logger.info(f"Query: {ws.config.host}#query/{query.id}")
+        return query
+
+    yield from factory("query", create, lambda query: ws.queries.delete(query.id))


+@pytest.fixture
+def crawler(make_schema, sql_backend):
+    schema = make_schema(catalog_name="hive_metastore")
+    return DfsaCrawler(sql_backend, schema.name)
+
+
+@pytest.fixture
+def collector(crawler, simple_ctx):
+    return DfsaCollector(crawler, simple_ctx.path_lookup, CurrentSessionState())
+
+
+@pytest.mark.parametrize(
+    "name, sql, dfsa_paths, is_read, is_write",
+    [
+        (
+            "create_location",
+            "CREATE TABLE hive_metastore.indices_historical_data.sp_500 LOCATION 's3a://db-gtm-industry-solutions/data/fsi/capm/sp_500/'",
+            ['s3a://db-gtm-industry-solutions/data/fsi/capm/sp_500/'],
+            False,
+            True,
+        )
+    ],
+)
+def test_dfsa_collector_collects_dfsas_from_query(
+    name, sql, dfsa_paths, is_read, is_write, ws, crawler, collector, make_query
+):
+    query = make_query(name=name, sql=sql)
+    _ = list(collector.collect_from_workspace_queries(ws))
+    dfsas = list(crawler.snapshot())
+    assert {dfsa.path for dfsa in dfsas} == set(dfsa_paths)
+    assert all(dfsa.source_type == "QUERY" for dfsa in dfsas)
+    assert all(dfsa.source_id.endswith(query.name) for dfsa in dfsas)
+    assert all(dfsa.is_read == is_read for dfsa in dfsas)
+    assert all(dfsa.is_write == is_write for dfsa in dfsas)
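
The factory helper imported from databricks.labs.ucx.mixins.fixtures drives the make_query fixture above; the sketch below shows the contract this test relies on, as an assumed and simplified shape rather than the real implementation: it yields a create function that tracks every object it makes, then removes them all on teardown.

    def factory(name, create, remove):
        # name: label used for log messages in the real helper (unused in this sketch)
        created = []

        def create_and_track(*args, **kwargs):
            obj = create(*args, **kwargs)
            created.append(obj)
            return obj

        yield create_and_track
        # teardown: remove everything the test created
        for obj in created:
            remove(obj)
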
35 changes: 32 additions & 3 deletions tests/unit/source_code/test_dfsa.py
@@ -1,7 +1,10 @@
 from pathlib import Path
+from unittest.mock import create_autospec
 
 import pytest
 
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.sql import Query
 
 from databricks.labs.lsql.backends import MockBackend
 
@@ -23,7 +26,7 @@ def test_dfsa_does_not_collect_erroneously(simple_dependency_resolver, migration
     maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("leaf4.py"), CurrentSessionState())
     crawler = DfsaCrawler(MockBackend(), "schema")
     collector = DfsaCollector(crawler, mock_path_lookup, CurrentSessionState())
-    dfsas = list(collector.collect(maybe.graph))
+    dfsas = list(collector.collect_from_graph(maybe.graph))
     assert not dfsas
 
 
@@ -48,7 +51,7 @@ def test_dfsa_does_not_collect_erroneously(simple_dependency_resolver, migration
         ("dfsa/spark_dbfs_mount.py", ["/mnt/some_file.csv"], True, False),
     ],
 )
-def test_dfsa_collects_sql_dfsas(
+def test_dfsa_collects_file_dfsas(
     source_path, dfsa_paths, is_read, is_write, simple_dependency_resolver, migration_index, mock_path_lookup
 ):
     """SQL expression not supported by sqlglot for Databricks; restore the test data below once we drop sqlglot
@@ -57,8 +60,34 @@ def test_dfsa_collects_sql_dfsas(
     assert not maybe.problems
     crawler = DfsaCrawler(MockBackend(), "schema")
     collector = DfsaCollector(crawler, mock_path_lookup, CurrentSessionState())
-    dfsas = list(collector.collect(maybe.graph))
+    dfsas = list(collector.collect_from_graph(maybe.graph))
     assert set(dfsa.path for dfsa in dfsas) == set(dfsa_paths)
     assert not any(dfsa for dfsa in dfsas if dfsa.source_type == DFSA.UNKNOWN)
     assert not any(dfsa for dfsa in dfsas if dfsa.is_read != is_read)
     assert not any(dfsa for dfsa in dfsas if dfsa.is_write != is_write)


+@pytest.mark.parametrize(
+    "name, sql, dfsa_paths, is_read, is_write",
+    [
+        ("none", "SELECT * from dual", [], False, False),
+        (
+            "location",
+            "CREATE TABLE hive_metastore.indices_historical_data.sp_500 LOCATION 's3a://db-gtm-industry-solutions/data/fsi/capm/sp_500/'",
+            ["s3a://db-gtm-industry-solutions/data/fsi/capm/sp_500/"],
+            False,
+            True,
+        ),
+    ],
+)
+def test_dfsa_collects_query_dfsas(name, sql, dfsa_paths, is_read, is_write, mock_path_lookup):
+    ws = create_autospec(WorkspaceClient)
+    query = Query.from_dict({"parent": "workspace", "name": name, "query": sql})
+    ws.queries.list.return_value = iter([query])
+    crawler = DfsaCrawler(MockBackend(), "schema")
+    collector = DfsaCollector(crawler, mock_path_lookup, CurrentSessionState())
+    dfsas = list(collector.collect_from_workspace_queries(ws))
+    assert set(dfsa.path for dfsa in dfsas) == set(dfsa_paths)
+    assert not any(dfsa for dfsa in dfsas if dfsa.source_type != "QUERY")
+    assert not any(dfsa for dfsa in dfsas if dfsa.is_read != is_read)
+    assert not any(dfsa for dfsa in dfsas if dfsa.is_write != is_write)
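
For reference, the record shape these assertions exercise, as implied by the DFSA(...) constructor call in collect_from_query (field names and order taken from this diff; the real dataclass lives in databricks.labs.ucx.source_code.base and may carry extra fields and constants such as DFSA.UNKNOWN):

    import dataclasses

    @dataclasses.dataclass
    class DFSA:
        source_type: str  # e.g. "QUERY" for workspace queries; DFSA.UNKNOWN when undetermined
        source_id: str    # workspace path of the query or source file
        path: str         # the file-system path that is accessed directly
        is_read: bool
        is_write: bool
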