diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py
index ec8dc524b7..640dee5d4b 100644
--- a/src/databricks/labs/ucx/cli.py
+++ b/src/databricks/labs/ucx/cli.py
@@ -509,8 +509,8 @@ def migrate_locations(w: WorkspaceClient, aws_profile: str | None = None):
 def create_catalogs_schemas(w: WorkspaceClient, prompts: Prompts):
     """Create UC catalogs and schemas based on the destinations created from create_table_mapping command."""
     installation = Installation.current(w, 'ucx')
-    catalog_schema = CatalogSchema.for_cli(w, installation, prompts)
-    catalog_schema.create_catalog_schema()
+    catalog_schema = CatalogSchema.for_cli(w, installation)
+    catalog_schema.create_all_catalogs_schemas(prompts)
 
 
 @ucx.command
diff --git a/src/databricks/labs/ucx/hive_metastore/catalog_schema.py b/src/databricks/labs/ucx/hive_metastore/catalog_schema.py
index 737553ada6..a83a8125f7 100644
--- a/src/databricks/labs/ucx/hive_metastore/catalog_schema.py
+++ b/src/databricks/labs/ucx/hive_metastore/catalog_schema.py
@@ -1,9 +1,11 @@
 import logging
+from pathlib import PurePath
 
 from databricks.labs.blueprint.installation import Installation
 from databricks.labs.blueprint.tui import Prompts
 from databricks.labs.lsql.backends import StatementExecutionBackend
 from databricks.sdk import WorkspaceClient
+from databricks.sdk.errors import NotFound
 
 from databricks.labs.ucx.config import WorkspaceConfig
 from databricks.labs.ucx.hive_metastore.mapping import TableMapping
@@ -12,17 +14,43 @@
 
 
 class CatalogSchema:
-    def __init__(self, ws: WorkspaceClient, table_mapping: TableMapping, prompts: Prompts):
+    def __init__(self, ws: WorkspaceClient, table_mapping: TableMapping):
         self._ws = ws
         self._table_mapping = table_mapping
-        self._prompts = prompts
+        # Materialize: list() returns a one-shot iterator, but validation re-scans it for every catalog and retry.
+        self._external_locations = list(self._ws.external_locations.list())
 
     @classmethod
-    def for_cli(cls, ws: WorkspaceClient, installation: Installation, prompts: Prompts):
+    def for_cli(cls, ws: WorkspaceClient, installation: Installation):
         config = installation.load(WorkspaceConfig)
         sql_backend = StatementExecutionBackend(ws, config.warehouse_id)
         table_mapping = TableMapping(installation, ws, sql_backend)
-        return cls(ws, table_mapping, prompts)
+        return cls(ws, table_mapping)
+
+    def create_all_catalogs_schemas(self, prompts: Prompts):
+        """Create every catalog and schema required by the table mapping that does not exist in UC yet."""
+        candidate_catalogs, candidate_schemas = self._get_missing_catalogs_schemas()
+        for candidate_catalog in candidate_catalogs:
+            self._create_catalog_validate(candidate_catalog, prompts)
+        for candidate_catalog, schemas in candidate_schemas.items():
+            for candidate_schema in schemas:
+                self._create_schema(candidate_catalog, candidate_schema)
+
+    def _create_catalog_validate(self, catalog, prompts: Prompts):
+        """Prompt for the catalog storage location (up to 3 attempts), validate it, then create the catalog."""
+        logger.info(f"Creating UC catalog: {catalog}")
+        # create catalogs
+        attempts = 3
+        while True:
+            catalog_storage = prompts.question(
+                f"Please provide storage location url for catalog:{catalog}.", default="metastore"
+            )
+            if self._validate_location(catalog_storage):
+                break
+            attempts -= 1
+            if attempts == 0:
+                raise NotFound(f"Failed to validate location for {catalog} catalog")
+        self._create_catalog(catalog, catalog_storage)
 
     def _list_existing(self) -> tuple[set[str], dict[str, set[str]]]:
         """generate a list of existing UC catalogs and schema."""
@@ -56,7 +84,7 @@ def _list_target(self) -> tuple[set[str], dict[str, set[str]]]:
             target_schemas[target_catalog].add(target_schema)
         return target_catalogs, target_schemas
 
-    def _prepare(self) -> tuple[set[str], dict[str, set[str]]]:
+    def _get_missing_catalogs_schemas(self) -> tuple[set[str], dict[str, set[str]]]:
         """prepare a list of catalogs and schema to be created"""
         existing_catalogs, existing_schemas = self._list_existing()
         target_catalogs, target_schemas = self._list_target()
@@ -72,23 +100,31 @@ def _prepare(self) -> tuple[set[str], dict[str, set[str]]]:
                 target_schemas[catalog] = target_schemas[catalog] - schemas
         return target_catalogs, target_schemas
 
-    def _create(self, catalogs, schemas):
-        logger.info("Creating UC catalogs and schemas.")
-        # create catalogs
-        for catalog_name in catalogs:
-            catalog_storage = self._prompts.question(
-                f"Please provide storage location url for catalog:{catalog_name}.", default="metastore"
-            )
-            if catalog_storage == "metastore":
-                self._ws.catalogs.create(catalog_name, comment="Created by UCX")
-                continue
-            self._ws.catalogs.create(catalog_name, storage_root=catalog_storage, comment="Created by UCX")
-
-        # create schemas
-        for catalog_name, schema_names in schemas.items():
-            for schema_name in schema_names:
-                self._ws.schemas.create(schema_name, catalog_name, comment="Created by UCX")
-
-    def create_catalog_schema(self):
-        candidate_catalogs, candidate_schemas = self._prepare()
-        self._create(candidate_catalogs, candidate_schemas)
+    def _validate_location(self, location: str):
+        """Return True for `metastore` or a URL that is, or is nested at any depth under, an external location."""
+        if location == "metastore":
+            return True
+        try:
+            location_path = PurePath(location)
+        except ValueError:
+            logger.error(f"Invalid location path {location}")
+            return False
+        for external_location in self._external_locations:
+            if not external_location.url:
+                continue
+            if location_path.is_relative_to(PurePath(external_location.url)):
+                return True
+        return False
+
+    def _create_catalog(self, catalog, catalog_storage):
+        """Create the catalog, on metastore default storage or under the given storage root."""
+        logger.info(f"Creating UC catalog: {catalog}")
+        if catalog_storage == "metastore":
+            self._ws.catalogs.create(catalog, comment="Created by UCX")
+        else:
+            self._ws.catalogs.create(catalog, storage_root=catalog_storage, comment="Created by UCX")
+
+    def _create_schema(self, catalog, schema):
+        """Create the schema inside the given catalog."""
+        logger.info(f"Creating UC schema: {schema} in catalog: {catalog}")
+        self._ws.schemas.create(schema, catalog, comment="Created by UCX")
diff --git a/tests/unit/hive_metastore/test_catalog_schema.py b/tests/unit/hive_metastore/test_catalog_schema.py
index 3a91219a25..59cd38cf95 100644
--- a/tests/unit/hive_metastore/test_catalog_schema.py
+++ b/tests/unit/hive_metastore/test_catalog_schema.py
@@ -1,18 +1,21 @@
 from unittest.mock import create_autospec
 
+import pytest
 from databricks.labs.blueprint.installation import MockInstallation
 from databricks.labs.blueprint.tui import MockPrompts
 from databricks.labs.lsql.backends import MockBackend
 from databricks.sdk import WorkspaceClient
-from databricks.sdk.service.catalog import CatalogInfo, SchemaInfo
+from databricks.sdk.errors import NotFound
+from databricks.sdk.service.catalog import CatalogInfo, ExternalLocationInfo, SchemaInfo
 
 from databricks.labs.ucx.hive_metastore.catalog_schema import CatalogSchema
 from databricks.labs.ucx.hive_metastore.mapping import TableMapping
 
 
-def prepare_test(ws, mock_prompts) -> CatalogSchema:
+def prepare_test(ws) -> CatalogSchema:
     ws.catalogs.list.return_value = [CatalogInfo(name="catalog1")]
     ws.schemas.list.return_value = [SchemaInfo(name="schema1")]
+    ws.external_locations.list.return_value = [ExternalLocationInfo(url="s3://foo/bar")]
     backend = MockBackend()
     installation = MockInstallation(
         {
@@ -46,26 +49,45 @@ def prepare_test(ws, mock_prompts) -> CatalogSchema:
     )
     table_mapping = TableMapping(installation, ws, backend)
 
-    return CatalogSchema(ws, table_mapping, mock_prompts)
+    return CatalogSchema(ws, table_mapping)
 
 
 def test_create():
     ws = create_autospec(WorkspaceClient)
     mock_prompts = MockPrompts({"Please provide storage location url for catalog: *": "s3://foo/bar"})
 
-    catalog_schema = prepare_test(ws, mock_prompts)
-    catalog_schema.create_catalog_schema()
+    catalog_schema = prepare_test(ws)
+    catalog_schema.create_all_catalogs_schemas(mock_prompts)
     ws.catalogs.create.assert_called_once_with("catalog2", storage_root="s3://foo/bar", comment="Created by UCX")
     ws.schemas.create.assert_any_call("schema2", "catalog2", comment="Created by UCX")
     ws.schemas.create.assert_any_call("schema3", "catalog1", comment="Created by UCX")
 
 
+def test_create_sub_location():
+    ws = create_autospec(WorkspaceClient)
+    mock_prompts = MockPrompts({"Please provide storage location url for catalog: *": "s3://foo/bar/test"})
+
+    catalog_schema = prepare_test(ws)
+    catalog_schema.create_all_catalogs_schemas(mock_prompts)
+    ws.catalogs.create.assert_called_once_with("catalog2", storage_root="s3://foo/bar/test", comment="Created by UCX")
+    ws.schemas.create.assert_any_call("schema2", "catalog2", comment="Created by UCX")
+    ws.schemas.create.assert_any_call("schema3", "catalog1", comment="Created by UCX")
+
+
+def test_create_bad_location():
+    ws = create_autospec(WorkspaceClient)
+    mock_prompts = MockPrompts({"Please provide storage location url for catalog: *": "s3://foo/fail"})
+    catalog_schema = prepare_test(ws)
+    with pytest.raises(NotFound):
+        catalog_schema.create_all_catalogs_schemas(mock_prompts)
+
+
 def test_no_catalog_storage():
     ws = create_autospec(WorkspaceClient)
     mock_prompts = MockPrompts({"Please provide storage location url for catalog: *": ""})
 
-    catalog_schema = prepare_test(ws, mock_prompts)
-    catalog_schema.create_catalog_schema()
+    catalog_schema = prepare_test(ws)
+    catalog_schema.create_all_catalogs_schemas(mock_prompts)
     ws.catalogs.create.assert_called_once_with("catalog2", comment="Created by UCX")
 
 
@@ -83,6 +105,5 @@ def test_for_cli():
             }
         }
     )
-    prompts = MockPrompts({"hello": "world"})
-    catalog_schema = CatalogSchema.for_cli(ws, installation, prompts)
+    catalog_schema = CatalogSchema.for_cli(ws, installation)
     assert isinstance(catalog_schema, CatalogSchema)
diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py
index b297c5501a..1ee4ccb775 100644
--- a/tests/unit/test_cli.py
+++ b/tests/unit/test_cli.py
@@ -10,6 +10,7 @@
 from databricks.sdk import AccountClient, WorkspaceClient
 from databricks.sdk.errors import NotFound
 from databricks.sdk.service import iam, sql
+from databricks.sdk.service.catalog import ExternalLocationInfo
 from databricks.sdk.service.compute import ClusterDetails, ClusterSource
 from databricks.sdk.service.workspace import ObjectInfo
 
@@ -452,6 +453,7 @@ def test_migrate_locations_gcp(ws, caplog):
 
 def test_create_catalogs_schemas(ws):
     prompts = MockPrompts({'.*': 's3://test'})
+    ws.external_locations.list.return_value = [ExternalLocationInfo(url="s3://test")]
     create_catalogs_schemas(ws, prompts)
     ws.catalogs.list.assert_called_once()
 