diff --git a/src/databricks/labs/ucx/account/aggregate.py b/src/databricks/labs/ucx/account/aggregate.py
index 237ebfebf1..be52d60067 100644
--- a/src/databricks/labs/ucx/account/aggregate.py
+++ b/src/databricks/labs/ucx/account/aggregate.py
@@ -49,7 +49,7 @@ def __init__(
         self._workspace_context_factory = workspace_context_factory

     @cached_property
-    def _workspace_contexts(self):
+    def _workspace_contexts(self) -> list[WorkspaceContext]:
         contexts = []
         for workspace_client in self._account_workspaces.workspace_clients():
             ctx = self._workspace_context_factory(workspace_client)
diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py
index b9b75b1640..a6e1bb57bd 100644
--- a/src/databricks/labs/ucx/cli.py
+++ b/src/databricks/labs/ucx/cli.py
@@ -593,11 +593,19 @@ def assign_metastore(
     ctx: AccountContext | None = None,
 ):
     """Assign metastore to a workspace"""
+    if workspace_id is None:
+        logger.error("--workspace-id is a required parameter.")
+        return
+    try:
+        workspace_id_casted = int(workspace_id)
+    except ValueError:
+        logger.error("--workspace-id should be an integer.")
+        return
     logger.info(f"Account ID: {a.config.account_id}")
     ctx = ctx or AccountContext(a)
     ctx.account_metastores.assign_metastore(
         ctx.prompts,
-        workspace_id,
+        workspace_id_casted,
         metastore_id=metastore_id,
         default_catalog=default_catalog,
     )
@@ -635,7 +643,7 @@ def migrate_tables(
     deployed_workflows = workspace_context.deployed_workflows
     deployed_workflows.run_workflow("migrate-tables")

-    tables = workspace_context.tables_crawler.snapshot()
+    tables = list(workspace_context.tables_crawler.snapshot())
     hiveserde_tables = [table for table in tables if table.what == What.EXTERNAL_HIVESERDE]
     if len(hiveserde_tables) > 0:
         percentage_hiveserde_tables = len(hiveserde_tables) / len(tables) * 100
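The `assign_metastore` change validates `--workspace-id` before casting, since CLI flags arrive as strings (or not at all). A minimal standalone sketch of the same validate-then-cast pattern; the function name here is illustrative and not part of the ucx CLI:

```python
import logging

logger = logging.getLogger(__name__)


def parse_workspace_id(raw: str | None) -> int | None:
    """Validate a string CLI flag before use; log and return None on bad input."""
    if raw is None:
        logger.error("--workspace-id is a required parameter.")
        return None
    try:
        return int(raw)
    except ValueError:
        logger.error("--workspace-id should be an integer.")
        return None


assert parse_workspace_id("123456") == 123456
assert parse_workspace_id("abc") is None
```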
diff --git a/src/databricks/labs/ucx/contexts/account_cli.py b/src/databricks/labs/ucx/contexts/account_cli.py
index 9f88049f65..ef4b1282af 100644
--- a/src/databricks/labs/ucx/contexts/account_cli.py
+++ b/src/databricks/labs/ucx/contexts/account_cli.py
@@ -19,21 +19,21 @@ def account_client(self) -> AccountClient:
         return self._ac

     @cached_property
-    def workspace_ids(self):
+    def workspace_ids(self) -> list[int]:
         return [int(_.strip()) for _ in self.named_parameters.get("workspace_ids", "").split(",") if _]

     @cached_property
-    def account_workspaces(self):
+    def account_workspaces(self) -> AccountWorkspaces:
         return AccountWorkspaces(self.account_client, self.workspace_ids)

     @cached_property
-    def account_aggregate(self):
+    def account_aggregate(self) -> AccountAggregate:
         return AccountAggregate(self.account_workspaces)

     @cached_property
-    def is_account_install(self):
+    def is_account_install(self) -> bool:
         return environ.get("UCX_FORCE_INSTALL") == "account"

     @cached_property
-    def account_metastores(self):
+    def account_metastores(self) -> AccountMetastores:
         return AccountMetastores(self.account_client)
diff --git a/src/databricks/labs/ucx/contexts/application.py b/src/databricks/labs/ucx/contexts/application.py
index 6e2e66c5c1..a1ff1d63ef 100644
--- a/src/databricks/labs/ucx/contexts/application.py
+++ b/src/databricks/labs/ucx/contexts/application.py
@@ -1,5 +1,7 @@
 import abc
 import logging
+import sys
+from collections.abc import Callable, Iterable
 from datetime import timedelta
 from functools import cached_property
 from pathlib import Path
@@ -28,12 +30,14 @@
 from databricks.labs.ucx.hive_metastore import ExternalLocations, Mounts, TablesCrawler
 from databricks.labs.ucx.hive_metastore.catalog_schema import CatalogSchema
 from databricks.labs.ucx.hive_metastore.grants import (
+    ACLMigrator,
+    AwsACL,
     AzureACL,
+    ComputeLocations,
+    Grant,
     GrantsCrawler,
-    PrincipalACL,
-    AwsACL,
     MigrateGrants,
-    ACLMigrator,
+    PrincipalACL,
 )
 from databricks.labs.ucx.hive_metastore.mapping import TableMapping
 from databricks.labs.ucx.hive_metastore.table_migration_status import TableMigrationIndex
@@ -45,15 +49,15 @@
 from databricks.labs.ucx.hive_metastore.udfs import UdfsCrawler
 from databricks.labs.ucx.hive_metastore.verification import VerifyHasMetastore
 from databricks.labs.ucx.installer.workflows import DeployedWorkflows
+from databricks.labs.ucx.source_code.graph import DependencyResolver
 from databricks.labs.ucx.source_code.jobs import WorkflowLinter
+from databricks.labs.ucx.source_code.known import KnownList
+from databricks.labs.ucx.source_code.linters.files import FileLoader, FolderLoader, ImportFileResolver
 from databricks.labs.ucx.source_code.notebooks.loaders import (
     NotebookResolver,
     NotebookLoader,
 )
-from databricks.labs.ucx.source_code.linters.files import FileLoader, FolderLoader, ImportFileResolver
 from databricks.labs.ucx.source_code.path_lookup import PathLookup
-from databricks.labs.ucx.source_code.graph import DependencyResolver
-from databricks.labs.ucx.source_code.known import KnownList
 from databricks.labs.ucx.source_code.queries import QueryLinter
 from databricks.labs.ucx.source_code.redash import Redash
 from databricks.labs.ucx.workspace_access import generic, redash
@@ -63,6 +67,11 @@
 from databricks.labs.ucx.workspace_access.secrets import SecretScopesSupport
 from databricks.labs.ucx.workspace_access.tacl import TableAclSupport

+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    from typing_extensions import Self
+
 # "Service Factories" would always have a lot of public methods.
 # This is because they are responsible for creating objects that are
 # used throughout the application. That being said, we'll do best
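`typing.Self` only exists from Python 3.11 on, hence the version gate with a `typing_extensions` fallback. Annotating `replace()` (in the next hunk) with `Self` means subclasses keep their own type through the call. A minimal sketch of the idiom, with hypothetical classes:

```python
import sys

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self


class Context:
    def replace(self, **kwargs) -> Self:
        # Returning Self (not "Context") preserves the subclass type for callers.
        self.__dict__.update(kwargs)
        return self


class WorkspaceCtx(Context):
    pass


ctx = WorkspaceCtx().replace(workspace_client=object())
# Type checkers now infer WorkspaceCtx here, not the base Context.
```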
@@ -79,7 +88,7 @@ def __init__(self, named_parameters: dict[str, str] | None = None):
             named_parameters = {}
         self._named_parameters = named_parameters

-    def replace(self, **kwargs):
+    def replace(self, **kwargs) -> Self:
         """Replace cached properties for unit testing purposes."""
         for key, value in kwargs.items():
             self.__dict__[key] = value
@@ -102,11 +111,11 @@ def named_parameters(self) -> dict[str, str]:
         return self._named_parameters

     @cached_property
-    def product_info(self):
+    def product_info(self) -> ProductInfo:
         return ProductInfo.from_class(WorkspaceConfig)

     @cached_property
-    def installation(self):
+    def installation(self) -> Installation:
         return Installation.current(self.workspace_client, self.product_info.product_name())

     @cached_property
@@ -134,7 +143,7 @@ def inventory_database(self) -> str:
         return self.config.inventory_database

     @cached_property
-    def workspace_listing(self):
+    def workspace_listing(self) -> generic.WorkspaceListing:
         return generic.WorkspaceListing(
             self.workspace_client,
             self.sql_backend,
@@ -144,7 +153,7 @@ def workspace_listing(self):
         )

     @cached_property
-    def generic_permissions_support(self):
+    def generic_permissions_support(self) -> generic.GenericPermissionsSupport:
         models_listing = generic.models_listing(self.workspace_client, self.config.num_threads)
         acl_listing = [
             generic.Listing(self.workspace_client.clusters.list, "cluster_id", "clusters"),
@@ -169,7 +178,7 @@ def generic_permissions_support(self):
         )

     @cached_property
-    def redash_permissions_support(self):
+    def redash_permissions_support(self) -> redash.RedashPermissionsSupport:
         acl_listing = [
             redash.Listing(self.workspace_client.alerts.list, sql.ObjectTypePlural.ALERTS),
             redash.Listing(self.workspace_client.dashboards.list, sql.ObjectTypePlural.DASHBOARDS),
@@ -182,17 +191,17 @@ def redash_permissions_support(self):
         )

     @cached_property
-    def scim_entitlements_support(self):
+    def scim_entitlements_support(self) -> ScimSupport:
         return ScimSupport(self.workspace_client, include_object_permissions=self.config.include_object_permissions)

     @cached_property
-    def secret_scope_acl_support(self):
+    def secret_scope_acl_support(self) -> SecretScopesSupport:
         return SecretScopesSupport(
             self.workspace_client, include_object_permissions=self.config.include_object_permissions
         )

     @cached_property
-    def legacy_table_acl_support(self):
+    def legacy_table_acl_support(self) -> TableAclSupport:
         return TableAclSupport(
             self.grants_crawler,
             self.sql_backend,
@@ -200,7 +209,7 @@ def legacy_table_acl_support(self):
         )

     @cached_property
-    def permission_manager(self):
+    def permission_manager(self) -> PermissionManager:
         return PermissionManager(
             self.sql_backend,
             self.inventory_database,
@@ -214,7 +223,7 @@ def permission_manager(self):
         )

     @cached_property
-    def group_manager(self):
+    def group_manager(self) -> GroupManager:
         return GroupManager(
             self.sql_backend,
             self.workspace_client,
@@ -228,19 +237,19 @@ def group_manager(self):
         )

     @cached_property
-    def grants_crawler(self):
+    def grants_crawler(self) -> GrantsCrawler:
         return GrantsCrawler(self.tables_crawler, self.udfs_crawler, self.config.include_databases)

     @cached_property
-    def udfs_crawler(self):
+    def udfs_crawler(self) -> UdfsCrawler:
         return UdfsCrawler(self.sql_backend, self.inventory_database, self.config.include_databases)

     @cached_property
-    def tables_crawler(self):
+    def tables_crawler(self) -> TablesCrawler:
         return TablesCrawler(self.sql_backend, self.inventory_database, self.config.include_databases)

     @cached_property
-    def tables_migrator(self):
+    def tables_migrator(self) -> TablesMigrator:
         return TablesMigrator(
             self.tables_crawler,
             self.workspace_client,
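Most of this file is the same mechanical change: spelling out the return type of each `@cached_property` factory. The annotation costs nothing at runtime (the decorator still computes once and caches), but type checkers and IDEs stop treating the attribute as `Any`. A minimal sketch with illustrative classes:

```python
from functools import cached_property


class Service:
    def ping(self) -> str:
        return "pong"


class AppContext:
    @cached_property
    def service(self) -> Service:  # without the hint, ctx.service is typed Any
        return Service()


ctx = AppContext()
assert ctx.service is ctx.service  # cached: the factory runs exactly once
assert ctx.service.ping() == "pong"
```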
@@ -251,7 +260,7 @@ def tables_migrator(self):
         )

     @cached_property
-    def acl_migrator(self):
+    def acl_migrator(self) -> ACLMigrator:
         return ACLMigrator(
             self.tables_crawler,
             self.workspace_info,
@@ -260,8 +269,8 @@ def acl_migrator(self):
         )

     @cached_property
-    def migrate_grants(self):
-        grant_loaders = [
+    def migrate_grants(self) -> MigrateGrants:
+        grant_loaders: list[Callable[[], Iterable[Grant]]] = [
             self.grants_crawler.snapshot,
             self.principal_acl.get_interactive_cluster_grants,
         ]
@@ -272,23 +281,23 @@ def migrate_grants(self):
         )

     @cached_property
-    def table_move(self):
+    def table_move(self) -> TableMove:
         return TableMove(self.workspace_client, self.sql_backend)

     @cached_property
-    def mounts_crawler(self):
+    def mounts_crawler(self) -> Mounts:
         return Mounts(self.sql_backend, self.workspace_client, self.inventory_database)

     @cached_property
-    def azure_service_principal_crawler(self):
+    def azure_service_principal_crawler(self) -> AzureServicePrincipalCrawler:
         return AzureServicePrincipalCrawler(self.workspace_client, self.sql_backend, self.inventory_database)

     @cached_property
-    def external_locations(self):
+    def external_locations(self) -> ExternalLocations:
         return ExternalLocations(self.workspace_client, self.sql_backend, self.inventory_database)

     @cached_property
-    def azure_acl(self):
+    def azure_acl(self) -> AzureACL:
         return AzureACL(
             self.workspace_client,
             self.sql_backend,
@@ -297,7 +306,7 @@ def azure_acl(self):
         )

     @cached_property
-    def aws_acl(self):
+    def aws_acl(self) -> AwsACL:
         return AwsACL(
             self.workspace_client,
             self.sql_backend,
@@ -305,7 +314,7 @@ def aws_acl(self):
         )

     @cached_property
-    def principal_locations_retriever(self):
+    def principal_locations_retriever(self) -> Callable[[], list[ComputeLocations]]:
         def inner():
             if self.is_azure:
                 return self.azure_acl.get_eligible_locations_principals()
@@ -316,7 +325,7 @@ def inner():
         return inner

     @cached_property
-    def principal_acl(self):
+    def principal_acl(self) -> PrincipalACL:
         return PrincipalACL(
             self.workspace_client,
             self.sql_backend,
@@ -327,7 +336,7 @@ def principal_acl(self):
         )

     @cached_property
-    def migration_status_refresher(self):
+    def migration_status_refresher(self) -> TableMigrationStatusRefresher:
         return TableMigrationStatusRefresher(
             self.workspace_client,
             self.sql_backend,
@@ -336,15 +345,15 @@ def migration_status_refresher(self):
         )

     @cached_property
-    def iam_credential_manager(self):
+    def iam_credential_manager(self) -> CredentialManager:
         return CredentialManager(self.workspace_client)

     @cached_property
-    def table_mapping(self):
+    def table_mapping(self) -> TableMapping:
         return TableMapping(self.installation, self.workspace_client, self.sql_backend)

     @cached_property
-    def catalog_schema(self):
+    def catalog_schema(self) -> CatalogSchema:
         return CatalogSchema(
             self.workspace_client,
             self.table_mapping,
@@ -355,31 +364,31 @@ def catalog_schema(self):
         )

     @cached_property
-    def verify_timeout(self):
+    def verify_timeout(self) -> timedelta:
         return timedelta(minutes=2)

     @cached_property
-    def wheels(self):
+    def wheels(self) -> WheelsV2:
         return WheelsV2(self.installation, self.product_info)

     @cached_property
-    def install_state(self):
+    def install_state(self) -> InstallState:
         return InstallState.from_installation(self.installation)

     @cached_property
-    def deployed_workflows(self):
+    def deployed_workflows(self) -> DeployedWorkflows:
         return DeployedWorkflows(self.workspace_client, self.install_state)

     @cached_property
-    def workspace_info(self):
+    def workspace_info(self) -> WorkspaceInfo:
         return WorkspaceInfo(self.installation, self.workspace_client)

     @cached_property
-    def verify_has_metastore(self):
+    def verify_has_metastore(self) -> VerifyHasMetastore:
         return VerifyHasMetastore(self.workspace_client)

     @cached_property
-    def pip_resolver(self):
+    def pip_resolver(self) -> PythonLibraryResolver:
         return PythonLibraryResolver(self.allow_list)

     @cached_property
@@ -387,37 +396,37 @@ def notebook_loader(self) -> NotebookLoader:
         return NotebookLoader()

     @cached_property
-    def notebook_resolver(self):
+    def notebook_resolver(self) -> NotebookResolver:
         return NotebookResolver(self.notebook_loader)

     @cached_property
-    def site_packages_path(self):
+    def site_packages_path(self) -> Path:
         lookup = self.path_lookup
         return next(path for path in lookup.library_roots if "site-packages" in path.as_posix())

     @cached_property
-    def path_lookup(self):
+    def path_lookup(self) -> PathLookup:
         # TODO find a solution to enable a different cwd per job/task (maybe it's not necessary or possible?)
         return PathLookup.from_sys_path(Path.cwd())

     @cached_property
-    def file_loader(self):
+    def file_loader(self) -> FileLoader:
         return FileLoader()

     @cached_property
-    def folder_loader(self):
+    def folder_loader(self) -> FolderLoader:
         return FolderLoader(self.notebook_loader, self.file_loader)

     @cached_property
-    def allow_list(self):
+    def allow_list(self) -> KnownList:
         return KnownList()

     @cached_property
-    def file_resolver(self):
+    def file_resolver(self) -> ImportFileResolver:
         return ImportFileResolver(self.file_loader, self.allow_list)

     @cached_property
-    def dependency_resolver(self):
+    def dependency_resolver(self) -> DependencyResolver:
         return DependencyResolver(
             self.pip_resolver, self.notebook_resolver, self.file_resolver, self.file_resolver, self.path_lookup
         )
@@ -435,7 +444,7 @@ def workflow_linter(self) -> WorkflowLinter:
         )

     @cached_property
-    def query_linter(self):
+    def query_linter(self) -> QueryLinter:
         return QueryLinter(
             self.workspace_client,
             TableMigrationIndex([]),  # TODO: bring back self.tables_migrator.index()
@@ -444,11 +453,11 @@ def query_linter(self):
         )

     @cached_property
-    def directfs_access_crawler_for_paths(self):
+    def directfs_access_crawler_for_paths(self) -> DirectFsAccessCrawler:
         return DirectFsAccessCrawler.for_paths(self.sql_backend, self.inventory_database)

     @cached_property
-    def directfs_access_crawler_for_queries(self):
+    def directfs_access_crawler_for_queries(self) -> DirectFsAccessCrawler:
         return DirectFsAccessCrawler.for_queries(self.sql_backend, self.inventory_database)

     @cached_property
@@ -460,7 +469,7 @@ def used_tables_crawler_for_queries(self):
         return UsedTablesCrawler.for_queries(self.sql_backend, self.inventory_database)

     @cached_property
-    def redash(self):
+    def redash(self) -> Redash:
         return Redash(
             self.migration_status_refresher.index(),
             self.workspace_client,
@@ -468,23 +477,23 @@ def redash(self):
         )

     @cached_property
-    def metadata_retriever(self):
+    def metadata_retriever(self) -> DatabricksTableMetadataRetriever:
         return DatabricksTableMetadataRetriever(self.sql_backend)

     @cached_property
-    def schema_comparator(self):
+    def schema_comparator(self) -> StandardSchemaComparator:
         return StandardSchemaComparator(self.metadata_retriever)

     @cached_property
-    def data_profiler(self):
+    def data_profiler(self) -> StandardDataProfiler:
         return StandardDataProfiler(self.sql_backend, self.metadata_retriever)

     @cached_property
-    def data_comparator(self):
+    def data_comparator(self) -> StandardDataComparator:
         return StandardDataComparator(self.sql_backend, self.data_profiler)

     @cached_property
-    def migration_recon(self):
+    def migration_recon(self) -> MigrationRecon:
         return MigrationRecon(
             self.sql_backend,
             self.inventory_database,
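Two of the application.py changes go beyond bare return hints: `grant_loaders` gets an explicit `list[Callable[[], Iterable[Grant]]]`, and `principal_locations_retriever` declares that it returns a zero-argument callable. Without the list annotation, a checker would infer the element type from the first entry alone. A reduced sketch with stand-in types (`Grant` below is a placeholder, not the ucx dataclass):

```python
from collections.abc import Callable, Iterable
from dataclasses import dataclass


@dataclass
class Grant:  # stand-in for the real ucx Grant
    principal: str
    action: str


def crawled_grants() -> list[Grant]:
    return [Grant("alice", "SELECT")]


def cluster_grants() -> Iterable[Grant]:
    yield Grant("bob", "MODIFY")


# The up-front annotation keeps a heterogeneous list of loaders type-correct;
# a list and a generator are both valid Iterable[Grant] returns.
loaders: list[Callable[[], Iterable[Grant]]] = [crawled_grants, cluster_grants]
for load in loaders:
    for grant in load():
        print(grant.principal, grant.action)
```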
diff --git a/src/databricks/labs/ucx/contexts/workflow_task.py b/src/databricks/labs/ucx/contexts/workflow_task.py
index 715fb5c67d..14a6130fca 100644
--- a/src/databricks/labs/ucx/contexts/workflow_task.py
+++ b/src/databricks/labs/ucx/contexts/workflow_task.py
@@ -46,16 +46,16 @@ def sql_backend(self) -> SqlBackend:
         return RuntimeBackend(debug_truncate_bytes=self.connect_config.debug_truncate_bytes)

     @cached_property
-    def installation(self):
+    def installation(self) -> Installation:
         install_folder = self._config_path.parent.as_posix().removeprefix("/Workspace")
         return Installation(self.workspace_client, "ucx", install_folder=install_folder)

     @cached_property
-    def jobs_crawler(self):
+    def jobs_crawler(self) -> JobsCrawler:
         return JobsCrawler(self.workspace_client, self.sql_backend, self.inventory_database)

     @cached_property
-    def submit_runs_crawler(self):
+    def submit_runs_crawler(self) -> SubmitRunsCrawler:
         return SubmitRunsCrawler(
             self.workspace_client,
             self.sql_backend,
@@ -64,31 +64,32 @@ def submit_runs_crawler(self):
         )

     @cached_property
-    def clusters_crawler(self):
+    def clusters_crawler(self) -> ClustersCrawler:
         return ClustersCrawler(self.workspace_client, self.sql_backend, self.inventory_database)

     @cached_property
-    def pipelines_crawler(self):
+    def pipelines_crawler(self) -> PipelinesCrawler:
         return PipelinesCrawler(self.workspace_client, self.sql_backend, self.inventory_database)

     @cached_property
-    def table_size_crawler(self):
+    def table_size_crawler(self) -> TableSizeCrawler:
         return TableSizeCrawler(self.sql_backend, self.inventory_database, self.config.include_databases)

     @cached_property
-    def policies_crawler(self):
+    def policies_crawler(self) -> PoliciesCrawler:
         return PoliciesCrawler(self.workspace_client, self.sql_backend, self.inventory_database)

     @cached_property
-    def global_init_scripts_crawler(self):
+    def global_init_scripts_crawler(self) -> GlobalInitScriptCrawler:
         return GlobalInitScriptCrawler(self.workspace_client, self.sql_backend, self.inventory_database)

     @cached_property
     def tables_crawler(self):
+        # TODO: Update tables crawler inheritance to specify return type hint
         return FasterTableScanCrawler(self.sql_backend, self.inventory_database)

     @cached_property
-    def tables_in_mounts(self):
+    def tables_in_mounts(self) -> TablesInMounts:
         return TablesInMounts(
             self.sql_backend,
             self.workspace_client,
@@ -100,7 +101,7 @@ def tables_in_mounts(self):
         )

     @cached_property
-    def task_run_warning_recorder(self):
+    def task_run_warning_recorder(self) -> TaskRunWarningRecorder:
         return TaskRunWarningRecorder(
             self._config_path.parent,
             self.named_parameters["workflow"],
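`tables_crawler` is the one factory left unannotated: per the TODO, `FasterTableScanCrawler` and the `TablesCrawler` used elsewhere don't yet share an inheritance relationship that captures the contract. If restructuring inheritance is undesirable, one generic option (a hypothetical sketch, not what ucx decided on) is a structural `Protocol`:

```python
from typing import Protocol


class Snapshots(Protocol):
    """Structural type: anything with a snapshot() method conforms."""

    def snapshot(self) -> list:
        ...


class FastCrawler:  # hypothetical stand-in; no shared base class required
    def snapshot(self) -> list:
        return ["hive_metastore.db.table_a"]


def crawl(crawler: Snapshots) -> list:
    return crawler.snapshot()


print(crawl(FastCrawler()))
```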
diff --git a/src/databricks/labs/ucx/contexts/workspace_cli.py b/src/databricks/labs/ucx/contexts/workspace_cli.py
index 2c4830a863..5191c1792e 100644
--- a/src/databricks/labs/ucx/contexts/workspace_cli.py
+++ b/src/databricks/labs/ucx/contexts/workspace_cli.py
@@ -1,6 +1,7 @@
 import logging
 import os
 import shutil
+from collections.abc import Callable
 from functools import cached_property

 from databricks.labs.lsql.backends import SqlBackend, StatementExecutionBackend
@@ -42,11 +43,11 @@ def sql_backend(self) -> SqlBackend:
         return StatementExecutionBackend(self.workspace_client, self.config.warehouse_id)

     @cached_property
-    def cluster_access(self):
+    def cluster_access(self) -> ClusterAccess:
         return ClusterAccess(self.installation, self.workspace_client, self.prompts)

     @cached_property
-    def azure_cli_authenticated(self):
+    def azure_cli_authenticated(self) -> bool:
         if not self.is_azure:
             raise NotImplementedError("Azure only")
         if self.connect_config.auth_type != "azure-cli":
@@ -54,7 +55,7 @@ def azure_cli_authenticated(self):
         return True

     @cached_property
-    def azure_management_client(self):
+    def azure_management_client(self) -> AzureAPIClient:
         if not self.azure_cli_authenticated:
             raise NotImplementedError
         return AzureAPIClient(
@@ -63,7 +64,7 @@ def azure_management_client(self):
         )

     @cached_property
-    def microsoft_graph_client(self):
+    def microsoft_graph_client(self) -> AzureAPIClient:
         if not self.azure_cli_authenticated:
             raise NotImplementedError
         return AzureAPIClient("https://graph.microsoft.com", "https://graph.microsoft.com")
@@ -76,7 +77,7 @@ def azure_subscription_ids(self) -> list[str]:
         return subscription_ids.split(",")

     @cached_property
-    def azure_resources(self):
+    def azure_resources(self) -> AzureResources:
         return AzureResources(
             self.azure_management_client,
             self.microsoft_graph_client,
@@ -84,7 +85,7 @@ def azure_resources(self):
         )

     @cached_property
-    def azure_resource_permissions(self):
+    def azure_resource_permissions(self) -> AzureResourcePermissions:
         return AzureResourcePermissions(
             self.installation,
             self.workspace_client,
@@ -93,11 +94,11 @@ def azure_resource_permissions(self):
         )

     @cached_property
-    def azure_credential_manager(self):
+    def azure_credential_manager(self) -> StorageCredentialManager:
         return StorageCredentialManager(self.workspace_client)

     @cached_property
-    def service_principal_migration(self):
+    def service_principal_migration(self) -> ServicePrincipalMigration:
         return ServicePrincipalMigration(
             self.installation,
             self.workspace_client,
@@ -107,7 +108,7 @@ def service_principal_migration(self):
         )

     @cached_property
-    def external_locations_migration(self):
+    def external_locations_migration(self) -> AWSExternalLocationsMigration | ExternalLocationsMigration:
         if self.is_aws:
             return AWSExternalLocationsMigration(
                 self.workspace_client,
@@ -126,7 +127,7 @@ def external_locations_migration(self):
         raise NotImplementedError

     @cached_property
-    def aws_cli_run_command(self):
+    def aws_cli_run_command(self) -> Callable[[str | list[str]], tuple[int, str, str]]:
         # this is a convenience method for unit testing
         if not shutil.which("aws"):
             raise ValueError("Couldn't find AWS CLI in path. Please install the CLI from https://aws.amazon.com/cli/")
@@ -145,13 +146,13 @@ def aws_profile(self) -> str:
         return aws_profile

     @cached_property
-    def aws_resources(self):
+    def aws_resources(self) -> AWSResources:
         if not self.is_aws:
             raise NotImplementedError("AWS only")
         return AWSResources(self.aws_profile, self.aws_cli_run_command)

     @cached_property
-    def aws_resource_permissions(self):
+    def aws_resource_permissions(self) -> AWSResourcePermissions:
         return AWSResourcePermissions(
             self.installation,
             self.workspace_client,
@@ -161,7 +162,7 @@ def aws_resource_permissions(self):
         )

     @cached_property
-    def iam_role_migration(self):
+    def iam_role_migration(self) -> IamRoleMigration:
         return IamRoleMigration(
             self.installation,
             self.aws_resource_permissions,
@@ -169,7 +170,7 @@ def iam_role_migration(self):
         )

     @cached_property
-    def iam_role_creation(self):
+    def iam_role_creation(self) -> IamRoleCreation:
         return IamRoleCreation(
             self.installation,
             self.workspace_client,
@@ -200,11 +201,11 @@ def linter_context_factory(self, session_state: CurrentSessionState | None = Non
         return LinterContext(index, session_state)

     @cached_property
-    def local_file_migrator(self):
+    def local_file_migrator(self) -> LocalFileMigrator:
         return LocalFileMigrator(lambda: self.linter_context_factory(CurrentSessionState()))

     @cached_property
-    def local_code_linter(self):
+    def local_code_linter(self) -> LocalCodeLinter:
         session_state = CurrentSessionState()
         return LocalCodeLinter(
             self.notebook_loader,
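The new `Callable[[str | list[str]], tuple[int, str, str]]` hint spells out the contract of the injected command runner: it accepts a command as a string or an argv list and returns (return code, stdout, stderr). A minimal subprocess-based function satisfying that signature, illustrative only and not the ucx implementation:

```python
import shlex
import subprocess


def run_command(command: str | list[str]) -> tuple[int, str, str]:
    """Run a command, returning (returncode, stdout, stderr)."""
    args = shlex.split(command) if isinstance(command, str) else command
    result = subprocess.run(args, capture_output=True, text=True, check=False)
    return result.returncode, result.stdout, result.stderr


code, out, err = run_command("echo hello")
assert code == 0 and out.strip() == "hello"
```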
diff --git a/src/databricks/labs/ucx/hive_metastore/locations.py b/src/databricks/labs/ucx/hive_metastore/locations.py
index 05802153b4..6ad6817f8e 100644
--- a/src/databricks/labs/ucx/hive_metastore/locations.py
+++ b/src/databricks/labs/ucx/hive_metastore/locations.py
@@ -54,10 +54,10 @@ class LocationTrie:
     tables: list[Table] = dataclasses.field(default_factory=list)

     @cached_property
-    def _path(self):
+    def _path(self) -> list[str]:
         """The path to traverse to get to the current node."""
         parts = []
-        current = self
+        current: LocationTrie | None = self
         while current:
             parts.append(current.key)
             current = current.parent
diff --git a/src/databricks/labs/ucx/hive_metastore/tables.py b/src/databricks/labs/ucx/hive_metastore/tables.py
index f935aada95..097faca778 100644
--- a/src/databricks/labs/ucx/hive_metastore/tables.py
+++ b/src/databricks/labs/ucx/hive_metastore/tables.py
@@ -4,7 +4,7 @@
 from collections.abc import Iterable, Iterator, Collection
 from dataclasses import dataclass
 from enum import Enum, auto
-from functools import partial, cached_property
+from functools import cached_property, partial

 import sqlglot
 from sqlglot import expressions
diff --git a/src/databricks/labs/ucx/install.py b/src/databricks/labs/ucx/install.py
index 812be2c22d..d2cc3a9561 100644
--- a/src/databricks/labs/ucx/install.py
+++ b/src/databricks/labs/ucx/install.py
@@ -165,15 +165,15 @@ def __init__(
         self._tasks = tasks if tasks else Workflows.all().tasks()

     @cached_property
-    def upgrades(self):
+    def upgrades(self) -> Upgrades:
         return Upgrades(self.product_info, self.installation)

     @cached_property
-    def policy_installer(self):
+    def policy_installer(self) -> ClusterPolicyInstaller:
         return ClusterPolicyInstaller(self.installation, self.workspace_client, self.prompts)

     @cached_property
-    def installation(self):
+    def installation(self) -> Installation:
         try:
             return self.product_info.current_installation(self.workspace_client)
         except NotFound:
@@ -820,12 +820,13 @@ def join_collection(self, workspace_ids: list[int], join_on_install: bool = Fals
         account_client = self._get_safe_account_client()
         ctx = AccountContext(account_client)

-        try:
+        try:  # pylint: disable=too-many-try-statements
            # if user is account admin list all the available workspace the user has admin access on.
            # This code is run if joining collection after installation or through cli
             accessible_workspaces = ctx.account_workspaces.get_accessible_workspaces()
             for workspace in accessible_workspaces:
-                ids_to_workspace[workspace.workspace_id] = workspace
+                if workspace.workspace_id is not None:
+                    ids_to_workspace[workspace.workspace_id] = workspace
             if join_on_install:
                 # if run as part of ucx installation allow user to select from the list to join
                 target_workspace = self._get_collection_workspace(accessible_workspaces, account_client)
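The `join_collection` fix is a typical Optional-narrowing change: the SDK's workspace model allows `workspace_id` to be `None`, so inserting it into an `int`-keyed dict only type-checks (and only makes sense) after an explicit guard. A reduced sketch with a stand-in dataclass:

```python
from dataclasses import dataclass


@dataclass
class Workspace:  # stand-in for the SDK's workspace model
    workspace_id: int | None
    workspace_name: str


workspaces = [Workspace(123, "prod"), Workspace(None, "incomplete")]
ids_to_workspace: dict[int, Workspace] = {}
for workspace in workspaces:
    if workspace.workspace_id is not None:  # narrows int | None to int
        ids_to_workspace[workspace.workspace_id] = workspace

assert list(ids_to_workspace) == [123]
```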
diff --git a/src/databricks/labs/ucx/source_code/known.py b/src/databricks/labs/ucx/source_code/known.py
index 60a266b26d..4d08f746a8 100644
--- a/src/databricks/labs/ucx/source_code/known.py
+++ b/src/databricks/labs/ucx/source_code/known.py
@@ -8,6 +8,7 @@
 import re
 import sys
 from dataclasses import dataclass
+from email.message import Message
 from functools import cached_property
 from pathlib import Path

@@ -208,7 +209,7 @@ def module_paths(self) -> list[Path]:
         return files

     @cached_property
-    def _metadata(self):
+    def _metadata(self) -> Message:
         with Path(self._path, "METADATA").open(encoding=_DEFAULT_ENCODING) as f:
             return email.message_from_file(f)

diff --git a/src/databricks/labs/ucx/source_code/python_libraries.py b/src/databricks/labs/ucx/source_code/python_libraries.py
index f675362ed8..21e3738fe6 100644
--- a/src/databricks/labs/ucx/source_code/python_libraries.py
+++ b/src/databricks/labs/ucx/source_code/python_libraries.py
@@ -43,7 +43,7 @@ def register_library(self, path_lookup: PathLookup, *libraries: str) -> list[Dep
         return self._install_library(path_lookup, *libraries)

     @cached_property
-    def _temporary_virtual_environment(self):
+    def _temporary_virtual_environment(self) -> Path:
         # TODO: for `databricks labs ucx lint-local-code`, detect if we already have a virtual environment
         # and use that one. See Databricks CLI code for the labs command to see how to detect the virtual
         # environment. If we don't have a virtual environment, create a temporary one.
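The `Message` return hint works because a wheel's `METADATA` file uses an RFC 822-style header format, which the stdlib email parser reads directly. A minimal sketch of the same parse, using a string instead of a file on disk:

```python
import email
from email.message import Message


def read_wheel_metadata(text: str) -> Message:
    """Wheel METADATA is RFC 822-style; the stdlib email parser handles it."""
    return email.message_from_string(text)


meta = read_wheel_metadata("Name: example\nVersion: 1.0\n\n")
assert isinstance(meta, Message)
assert meta["Name"] == "example"
```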
diff --git a/src/databricks/labs/ucx/workspace_access/generic.py b/src/databricks/labs/ucx/workspace_access/generic.py
index 0fd06db6d9..6945f22d73 100644
--- a/src/databricks/labs/ucx/workspace_access/generic.py
+++ b/src/databricks/labs/ucx/workspace_access/generic.py
@@ -397,7 +397,7 @@ def __repr__(self):
         return f"WorkspaceListing(start_path={self._start_path})"


-def models_listing(ws: WorkspaceClient, num_threads: int):
+def models_listing(ws: WorkspaceClient, num_threads: int | None) -> Callable[[], Iterator[ml.ModelDatabricks]]:
     def inner() -> Iterator[ml.ModelDatabricks]:
         tasks = []
         for model in ws.model_registry.list_models():
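Two things happen in the `models_listing` signature: the parameter is widened to `int | None` (matching the caller, which passes the Optional `config.num_threads`), and the factory's return type is spelled out as a zero-argument callable producing an iterator. The inner function already carried its hint; only the outer one was missing. A reduced sketch of the same shape with stand-in names:

```python
from collections.abc import Callable, Iterator


def numbers_listing(limit: int | None) -> Callable[[], Iterator[int]]:
    """Factory returning a lazy, zero-argument listing; limit may be absent."""

    def inner() -> Iterator[int]:
        yield from range(limit if limit is not None else 10)

    return inner


listing = numbers_listing(3)
assert list(listing()) == [0, 1, 2]
```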
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 2fc3f47b08..853257e134 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -26,12 +26,12 @@
 from databricks.sdk import AccountClient, WorkspaceClient
 from databricks.sdk.errors import NotFound
 from databricks.sdk.retries import retried
-from databricks.sdk.service import iam
+from databricks.sdk.service import iam, jobs
 from databricks.sdk.service.catalog import FunctionInfo, SchemaInfo, TableInfo
 from databricks.sdk.service.compute import ClusterSpec
 from databricks.sdk.service.dashboards import Dashboard as SDKDashboard
 from databricks.sdk.service.iam import Group
-from databricks.sdk.service.jobs import Task, SparkPythonTask
+from databricks.sdk.service.jobs import SparkPythonTask
 from databricks.sdk.service.sql import Dashboard, WidgetPosition, WidgetOptions, LegacyQuery

 from databricks.labs.ucx.__about__ import __version__
@@ -46,6 +46,7 @@
 from databricks.labs.ucx.config import WorkspaceConfig
 from databricks.labs.ucx.contexts.workspace_cli import WorkspaceContext
 from databricks.labs.ucx.contexts.workflow_task import RuntimeContext
+from databricks.labs.ucx.framework.tasks import Task
 from databricks.labs.ucx.hive_metastore import TablesCrawler
 from databricks.labs.ucx.hive_metastore.grants import Grant
 from databricks.labs.ucx.hive_metastore.locations import Mount, Mounts, ExternalLocation, ExternalLocations
@@ -436,7 +437,7 @@ def with_aws_storage_permissions(
         self.installation.save(uc_roles_mapping, filename=AWSResourcePermissions.UC_ROLES_FILE_NAME)

     @cached_property
-    def installation(self):
+    def installation(self) -> Installation:
         return MockInstallation()

     @cached_property
@@ -564,7 +565,7 @@ def config(self) -> WorkspaceConfig:
         )

     @cached_property
-    def tables_crawler(self) -> TablesCrawler:
+    def tables_crawler(self):
         """
         Returns a TablesCrawler instance with the tables that were created in the context.
         Overrides the FasterTableScanCrawler with TablesCrawler used as DBR is not available while running integration tests
@@ -658,14 +659,15 @@ def created_databases(self) -> list[str]:
         return list(created_databases)

     @cached_property
-    def created_groups(self):
+    def created_groups(self) -> list[str]:
         created_groups = []
         for group in self._groups:
-            created_groups.append(group.display_name)
+            if group.display_name is not None:
+                created_groups.append(group.display_name)
         return created_groups

     @cached_property
-    def azure_service_principal_crawler(self):
+    def azure_service_principal_crawler(self) -> StaticServicePrincipalCrawler:
         return StaticServicePrincipalCrawler(
             self._spn_infos,
             self.workspace_client,
@@ -674,7 +676,7 @@ def azure_service_principal_crawler(self):
         )

     @cached_property
-    def mounts_crawler(self):
+    def mounts_crawler(self) -> StaticMountCrawler:
         mount = Mount(
             f'/mnt/{self._env_or_skip("TEST_MOUNT_NAME")}/a', f'{self._env_or_skip("TEST_MOUNT_CONTAINER")}/a'
         )
@@ -686,7 +688,7 @@ def mounts_crawler(self):
         )

     @cached_property
-    def group_manager(self):
+    def group_manager(self) -> GroupManager:
         return GroupManager(
             self.sql_backend,
             self.workspace_client,
@@ -768,7 +770,7 @@ def save_locations(self) -> None:

 class MockLocalAzureCli(MockWorkspaceContext):
     @cached_property
-    def azure_cli_authenticated(self):
+    def azure_cli_authenticated(self) -> bool:
         if not self.is_azure:
             pytest.skip("Azure only")
         if self.connect_config.auth_type != "azure-cli":
@@ -788,7 +790,7 @@ def az_cli_ctx(ws, env_or_skip, make_catalog, make_schema, make_random, sql_back

 class MockLocalAwsCli(MockWorkspaceContext):
     @cached_property
-    def aws_cli_run_command(self):
+    def aws_cli_run_command(self) -> Callable[[str | list[str]], tuple[int, str, str]]:
         if not self.is_aws:
             pytest.skip("Aws only")
         if not shutil.which("aws"):
@@ -796,7 +798,7 @@ def aws_cli_run_command(self):
         return run_command

     @cached_property
-    def aws_profile(self):
+    def aws_profile(self) -> str:
         return self._env_or_skip("AWS_PROFILE")


@@ -884,7 +886,7 @@ def make_ucx_group(self, workspace_group_name=None, account_group_name=None, wai
         return ws_group, acc_group

     @cached_property
-    def running_clusters(self):
+    def running_clusters(self) -> tuple[str, str, str]:
         logger.debug("Waiting for clusters to start...")
         default_cluster_id = self._env_or_skip("TEST_DEFAULT_CLUSTER_ID")
         tacl_cluster_id = self._env_or_skip("TEST_LEGACY_TABLE_ACL_CLUSTER_ID")
@@ -902,15 +904,15 @@ def running_clusters(self):
         return default_cluster_id, tacl_cluster_id, table_migration_cluster_id

     @cached_property
-    def installation(self):
+    def installation(self) -> Installation:
         return Installation(self.workspace_client, self.product_info.product_name())

     @cached_property
-    def account_client(self):
+    def account_client(self) -> AccountClient:
         return AccountClient(product="ucx", product_version=__version__)

     @cached_property
-    def account_installer(self):
+    def account_installer(self) -> AccountInstaller:
         return AccountInstaller(self.account_client)

     @cached_property
@@ -918,7 +920,7 @@ def environ(self) -> dict[str, str]:
         return {**os.environ}

     @cached_property
-    def workspace_installer(self):
+    def workspace_installer(self) -> WorkspaceInstaller:
         return WorkspaceInstaller(
             self.workspace_client,
             self.environ,
@@ -929,7 +931,7 @@ def config_transform(self) -> Callable[[WorkspaceConfig], WorkspaceConfig]:
         return lambda wc: wc

     @cached_property
-    def include_object_permissions(self):
+    def include_object_permissions(self) -> None:
         return None

     @cached_property
@@ -956,15 +958,15 @@ def config(self) -> WorkspaceConfig:
         return workspace_config

     @cached_property
-    def product_info(self):
+    def product_info(self) -> ProductInfo:
         return ProductInfo.for_testing(WorkspaceConfig)

     @cached_property
-    def tasks(self):
+    def tasks(self) -> list[Task]:
         return Workflows.all().tasks()

     @cached_property
-    def workflows_deployment(self):
+    def workflows_deployment(self) -> WorkflowsDeployment:
         return WorkflowsDeployment(
             self.config,
             self.installation,
@@ -976,7 +978,7 @@ def workflows_deployment(self):
         )

     @cached_property
-    def workspace_installation(self):
+    def workspace_installation(self) -> WorkspaceInstallation:
         return WorkspaceInstallation(
             self.config,
             self.installation,
@@ -993,15 +995,15 @@ def progress_tracking_installation(self) -> ProgressTrackingInstallation:
         return ProgressTrackingInstallation(self.sql_backend, self.ucx_catalog)

     @cached_property
-    def extend_prompts(self):
+    def extend_prompts(self) -> dict[str, str]:
         return {}

     @cached_property
-    def renamed_group_prefix(self):
+    def renamed_group_prefix(self) -> str:
         return f"rename-{self.product_info.product_name()}-"

     @cached_property
-    def prompts(self):
+    def prompts(self) -> MockPrompts:
         return MockPrompts(
             {
                 r'Open job overview in your browser.*': 'no',
@@ -1262,7 +1264,7 @@ def create(installation, **_kwargs):
         file_name = f"dummy_{make_random(4)}_{watchdog_purge_suffix}"
         file_path = WorkspacePath(ws, installation.install_folder()) / file_name
         file_path.write_text("spark.read.parquet('dbfs://mnt/foo/bar')")
-        task = Task(
+        task = jobs.Task(
             task_key=make_random(4),
             description=make_random(4),
             new_cluster=ClusterSpec(
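The conftest.py change resolves a name collision: annotating `tasks` as `list[Task]` needs ucx's own `framework.tasks.Task`, while the fixture at the bottom builds an SDK job task, so the SDK type is now reached through its module as `jobs.Task`. A sketch of the resulting pattern, using the two imports the diff itself adds (field set trimmed to `task_key` for brevity):

```python
from databricks.sdk.service import jobs  # SDK task model, module-qualified
from databricks.labs.ucx.framework.tasks import Task  # ucx's Task keeps the bare name

# The two same-named classes now coexist without aliasing tricks:
sdk_task = jobs.Task(task_key="example")
ucx_tasks: list[Task] = []
```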