diff --git a/databricks_cli/cli.py b/databricks_cli/cli.py index 5825efc4..0b0a8f1b 100644 --- a/databricks_cli/cli.py +++ b/databricks_cli/cli.py @@ -67,4 +67,4 @@ def cli(): cli.add_command(pipelines_group, name='pipelines') if __name__ == "__main__": - cli() \ No newline at end of file + cli() diff --git a/databricks_cli/configure/cli.py b/databricks_cli/configure/cli.py index a933ddef..3e8303da 100644 --- a/databricks_cli/configure/cli.py +++ b/databricks_cli/configure/cli.py @@ -21,6 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import click from click import ParamType @@ -34,16 +35,37 @@ PROMPT_USERNAME = 'Username' PROMPT_PASSWORD = 'Password' # NOQA PROMPT_TOKEN = 'Token' # NOQA +ENV_AAD_TOKEN = 'DATABRICKS_AAD_TOKEN' def _configure_cli_token(profile, insecure): config = ProfileConfigProvider(profile).get_config() or DatabricksConfig.empty() host = click.prompt(PROMPT_HOST, default=config.host, type=_DbfsHost()) - token = click.prompt(PROMPT_TOKEN, default=config.token, hide_input=True) + token = click.prompt(PROMPT_TOKEN, default=config.token) new_config = DatabricksConfig.from_token(host, token, insecure) update_and_persist_config(profile, new_config) +def _configure_cli_aad_token(profile, insecure): + config = ProfileConfigProvider(profile).get_config() or DatabricksConfig.empty() + + if ENV_AAD_TOKEN not in os.environ: + print('[ERROR] Set Environment Variable \'%s\' with your ' + 'AAD Token and run again.\n' % ENV_AAD_TOKEN) + print('Commands to run to get your AAD token:\n' + '\t az login\n' + '\t token_response=$(az account get-access-token ' + '--resource 2ff814a6-3304-4ab8-85cb-cd0e6f879c1d)\n' + '\t export %s=$(jq .accessToken -r <<< "$token_response")\n' % ENV_AAD_TOKEN + ) + return + + host = click.prompt(PROMPT_HOST, default=config.host, type=_DbfsHost()) + aad_token = os.environ.get(ENV_AAD_TOKEN) + new_config = DatabricksConfig.from_token(host, aad_token, insecure) + 
update_and_persist_config(profile, new_config) + + def _configure_cli_password(profile, insecure): config = ProfileConfigProvider(profile).get_config() or DatabricksConfig.empty() if config.password: @@ -63,10 +85,11 @@ def _configure_cli_password(profile, insecure): @click.command(context_settings=CONTEXT_SETTINGS, short_help='Configures host and authentication info for the CLI.') @click.option('--token', show_default=True, is_flag=True, default=False) +@click.option('--aad-token', show_default=True, is_flag=True, default=False) @click.option('--insecure', show_default=True, is_flag=True, default=None) @debug_option @profile_option -def configure_cli(token, insecure): +def configure_cli(token, aad_token, insecure): """ Configures host and authentication info for the CLI. """ @@ -74,6 +97,8 @@ def configure_cli(token, insecure): insecure_str = str(insecure) if insecure is not None else None if token: _configure_cli_token(profile, insecure_str) + elif aad_token: + _configure_cli_aad_token(profile, insecure_str) else: _configure_cli_password(profile, insecure_str) diff --git a/databricks_cli/configure/provider.py b/databricks_cli/configure/provider.py index d8b96aae..e1a1eeac 100644 --- a/databricks_cli/configure/provider.py +++ b/databricks_cli/configure/provider.py @@ -263,7 +263,7 @@ def get_config(self): class DatabricksConfig(object): - def __init__(self, host, username, password, token, insecure): # noqa + def __init__(self, host, username, password, token, insecure): # noqa self.host = host self.username = username self.password = password diff --git a/databricks_cli/sdk/service.py b/databricks_cli/sdk/service.py old mode 100644 new mode 100755 index c0b4b55b..0d6a05f3 --- a/databricks_cli/sdk/service.py +++ b/databricks_cli/sdk/service.py @@ -78,7 +78,7 @@ def create_job(self, name=None, existing_cluster_id=None, new_cluster=None, libr if max_concurrent_runs is not None: _data['max_concurrent_runs'] = max_concurrent_runs return 
self.client.perform_query('POST', '/jobs/create', data=_data, headers=headers) - + def submit_run(self, run_name=None, existing_cluster_id=None, new_cluster=None, libraries=None, notebook_task=None, spark_jar_task=None, spark_python_task=None, spark_submit_task=None, timeout_seconds=None, headers=None): @@ -112,7 +112,7 @@ def submit_run(self, run_name=None, existing_cluster_id=None, new_cluster=None, if timeout_seconds is not None: _data['timeout_seconds'] = timeout_seconds return self.client.perform_query('POST', '/jobs/runs/submit', data=_data, headers=headers) - + def reset_job(self, job_id, new_settings, headers=None): _data = {} if job_id is not None: @@ -122,24 +122,24 @@ def reset_job(self, job_id, new_settings, headers=None): if not isinstance(new_settings, dict): raise TypeError('Expected databricks.JobSettings() or dict for field new_settings') return self.client.perform_query('POST', '/jobs/reset', data=_data, headers=headers) - + def delete_job(self, job_id, headers=None): _data = {} if job_id is not None: _data['job_id'] = job_id return self.client.perform_query('POST', '/jobs/delete', data=_data, headers=headers) - + def get_job(self, job_id, headers=None): _data = {} if job_id is not None: _data['job_id'] = job_id return self.client.perform_query('GET', '/jobs/get', data=_data, headers=headers) - + def list_jobs(self, headers=None): _data = {} - + return self.client.perform_query('GET', '/jobs/list', data=_data, headers=headers) - + def run_now(self, job_id=None, jar_params=None, notebook_params=None, python_params=None, spark_submit_params=None, headers=None): _data = {} @@ -154,7 +154,7 @@ def run_now(self, job_id=None, jar_params=None, notebook_params=None, python_par if spark_submit_params is not None: _data['spark_submit_params'] = spark_submit_params return self.client.perform_query('POST', '/jobs/run-now', data=_data, headers=headers) - + def list_runs(self, job_id=None, active_only=None, completed_only=None, offset=None, limit=None, 
headers=None): _data = {} @@ -169,31 +169,31 @@ def list_runs(self, job_id=None, active_only=None, completed_only=None, offset=N if limit is not None: _data['limit'] = limit return self.client.perform_query('GET', '/jobs/runs/list', data=_data, headers=headers) - + def get_run(self, run_id=None, headers=None): _data = {} if run_id is not None: _data['run_id'] = run_id return self.client.perform_query('GET', '/jobs/runs/get', data=_data, headers=headers) - + def delete_run(self, run_id=None, headers=None): _data = {} if run_id is not None: _data['run_id'] = run_id return self.client.perform_query('POST', '/jobs/runs/delete', data=_data, headers=headers) - + def cancel_run(self, run_id, headers=None): _data = {} if run_id is not None: _data['run_id'] = run_id return self.client.perform_query('POST', '/jobs/runs/cancel', data=_data, headers=headers) - + def get_run_output(self, run_id, headers=None): _data = {} if run_id is not None: _data['run_id'] = run_id return self.client.perform_query('GET', '/jobs/runs/get-output', data=_data, headers=headers) - + def export_run(self, run_id, views_to_export=None, headers=None): _data = {} if run_id is not None: @@ -201,7 +201,7 @@ def export_run(self, run_id, views_to_export=None, headers=None): if views_to_export is not None: _data['views_to_export'] = views_to_export return self.client.perform_query('GET', '/jobs/runs/export', data=_data, headers=headers) - + class ClusterService(object): def __init__(self, client): @@ -209,15 +209,15 @@ def __init__(self, client): def list_clusters(self, headers=None): _data = {} - + return self.client.perform_query('GET', '/clusters/list', data=_data, headers=headers) - + def create_cluster(self, num_workers=None, autoscale=None, cluster_name=None, spark_version=None, spark_conf=None, aws_attributes=None, node_type_id=None, driver_node_type_id=None, ssh_public_keys=None, custom_tags=None, - cluster_log_conf=None, init_scripts=None, spark_env_vars=None, - autotermination_minutes=None, 
enable_elastic_disk=None, cluster_source=None, - instance_pool_id=None, headers=None): + cluster_log_conf=None, spark_env_vars=None, autotermination_minutes=None, + enable_elastic_disk=None, cluster_source=None, instance_pool_id=None, + headers=None): _data = {} if num_workers is not None: _data['num_workers'] = num_workers @@ -247,8 +247,6 @@ def create_cluster(self, num_workers=None, autoscale=None, cluster_name=None, sp _data['cluster_log_conf'] = cluster_log_conf if not isinstance(cluster_log_conf, dict): raise TypeError('Expected databricks.ClusterLogConf() or dict for field cluster_log_conf') - if init_scripts is not None: - _data['init_scripts'] = init_scripts if spark_env_vars is not None: _data['spark_env_vars'] = spark_env_vars if autotermination_minutes is not None: @@ -260,36 +258,36 @@ def create_cluster(self, num_workers=None, autoscale=None, cluster_name=None, sp if instance_pool_id is not None: _data['instance_pool_id'] = instance_pool_id return self.client.perform_query('POST', '/clusters/create', data=_data, headers=headers) - + def start_cluster(self, cluster_id, headers=None): _data = {} if cluster_id is not None: _data['cluster_id'] = cluster_id return self.client.perform_query('POST', '/clusters/start', data=_data, headers=headers) - + def list_spark_versions(self, headers=None): _data = {} - + return self.client.perform_query('GET', '/clusters/spark-versions', data=_data, headers=headers) - + def delete_cluster(self, cluster_id, headers=None): _data = {} if cluster_id is not None: _data['cluster_id'] = cluster_id return self.client.perform_query('POST', '/clusters/delete', data=_data, headers=headers) - + def permanent_delete_cluster(self, cluster_id, headers=None): _data = {} if cluster_id is not None: _data['cluster_id'] = cluster_id return self.client.perform_query('POST', '/clusters/permanent-delete', data=_data, headers=headers) - + def restart_cluster(self, cluster_id, headers=None): _data = {} if cluster_id is not None: 
_data['cluster_id'] = cluster_id return self.client.perform_query('POST', '/clusters/restart', data=_data, headers=headers) - + def resize_cluster(self, cluster_id, num_workers=None, autoscale=None, headers=None): _data = {} if cluster_id is not None: @@ -301,13 +299,13 @@ def resize_cluster(self, cluster_id, num_workers=None, autoscale=None, headers=N if not isinstance(autoscale, dict): raise TypeError('Expected databricks.AutoScale() or dict for field autoscale') return self.client.perform_query('POST', '/clusters/resize', data=_data, headers=headers) - + def edit_cluster(self, cluster_id, num_workers=None, autoscale=None, cluster_name=None, spark_version=None, spark_conf=None, aws_attributes=None, node_type_id=None, driver_node_type_id=None, ssh_public_keys=None, custom_tags=None, - cluster_log_conf=None, init_scripts=None, spark_env_vars=None, - autotermination_minutes=None, enable_elastic_disk=None, cluster_source=None, - instance_pool_id=None, headers=None): + cluster_log_conf=None, spark_env_vars=None, autotermination_minutes=None, + enable_elastic_disk=None, cluster_source=None, instance_pool_id=None, + headers=None): _data = {} if cluster_id is not None: _data['cluster_id'] = cluster_id @@ -339,8 +337,6 @@ def edit_cluster(self, cluster_id, num_workers=None, autoscale=None, cluster_nam _data['cluster_log_conf'] = cluster_log_conf if not isinstance(cluster_log_conf, dict): raise TypeError('Expected databricks.ClusterLogConf() or dict for field cluster_log_conf') - if init_scripts is not None: - _data['init_scripts'] = init_scripts if spark_env_vars is not None: _data['spark_env_vars'] = spark_env_vars if autotermination_minutes is not None: @@ -352,35 +348,35 @@ def edit_cluster(self, cluster_id, num_workers=None, autoscale=None, cluster_nam if instance_pool_id is not None: _data['instance_pool_id'] = instance_pool_id return self.client.perform_query('POST', '/clusters/edit', data=_data, headers=headers) - + def get_cluster(self, cluster_id, headers=None): 
_data = {} if cluster_id is not None: _data['cluster_id'] = cluster_id return self.client.perform_query('GET', '/clusters/get', data=_data, headers=headers) - + def pin_cluster(self, cluster_id, headers=None): _data = {} if cluster_id is not None: _data['cluster_id'] = cluster_id return self.client.perform_query('POST', '/clusters/pin', data=_data, headers=headers) - + def unpin_cluster(self, cluster_id, headers=None): _data = {} if cluster_id is not None: _data['cluster_id'] = cluster_id return self.client.perform_query('POST', '/clusters/unpin', data=_data, headers=headers) - + def list_node_types(self, headers=None): _data = {} - + return self.client.perform_query('GET', '/clusters/list-node-types', data=_data, headers=headers) - + def list_available_zones(self, headers=None): _data = {} - + return self.client.perform_query('GET', '/clusters/list-zones', data=_data, headers=headers) - + def get_events(self, cluster_id, start_time=None, end_time=None, order=None, event_types=None, offset=None, limit=None, headers=None): _data = {} @@ -400,6 +396,7 @@ def get_events(self, cluster_id, start_time=None, end_time=None, order=None, eve _data['limit'] = limit return self.client.perform_query('POST', '/clusters/events', data=_data, headers=headers) + class PolicyService(object): def __init__(self, client): self.client = client @@ -452,12 +449,12 @@ def cluster_status(self, cluster_id, headers=None): if cluster_id is not None: _data['cluster_id'] = cluster_id return self.client.perform_query('GET', '/libraries/cluster-status', data=_data, headers=headers) - + def all_cluster_statuses(self, headers=None): _data = {} - + return self.client.perform_query('GET', '/libraries/all-cluster-statuses', data=_data, headers=headers) - + def install_libraries(self, cluster_id, libraries=None, headers=None): _data = {} if cluster_id is not None: @@ -465,7 +462,7 @@ def install_libraries(self, cluster_id, libraries=None, headers=None): if libraries is not None: _data['libraries'] = 
libraries return self.client.perform_query('POST', '/libraries/install', data=_data, headers=headers) - + def uninstall_libraries(self, cluster_id, libraries=None, headers=None): _data = {} if cluster_id is not None: @@ -473,7 +470,7 @@ def uninstall_libraries(self, cluster_id, libraries=None, headers=None): if libraries is not None: _data['libraries'] = libraries return self.client.perform_query('POST', '/libraries/uninstall', data=_data, headers=headers) - + class DbfsService(object): def __init__(self, client): @@ -488,19 +485,41 @@ def read(self, path, offset=None, length=None, headers=None): if length is not None: _data['length'] = length return self.client.perform_query('GET', '/dbfs/read', data=_data, headers=headers) - + + def read_test(self, path, offset=None, length=None, headers=None): + _data = {} + if path is not None: + _data['path'] = path + if offset is not None: + _data['offset'] = offset + if length is not None: + _data['length'] = length + return self.client.perform_query('GET', '/dbfs-testing/read', data=_data, headers=headers) + def get_status(self, path, headers=None): _data = {} if path is not None: _data['path'] = path return self.client.perform_query('GET', '/dbfs/get-status', data=_data, headers=headers) - + + def get_status_test(self, path, headers=None): + _data = {} + if path is not None: + _data['path'] = path + return self.client.perform_query('GET', '/dbfs-testing/get-status', data=_data, headers=headers) + def list(self, path, headers=None): _data = {} if path is not None: _data['path'] = path return self.client.perform_query('GET', '/dbfs/list', data=_data, headers=headers) - + + def list_test(self, path, headers=None): + _data = {} + if path is not None: + _data['path'] = path + return self.client.perform_query('GET', '/dbfs-testing/list', data=_data, headers=headers) + def put(self, path, contents=None, overwrite=None, headers=None): _data = {} if path is not None: @@ -510,13 +529,29 @@ def put(self, path, contents=None, 
overwrite=None, headers=None): if overwrite is not None: _data['overwrite'] = overwrite return self.client.perform_query('POST', '/dbfs/put', data=_data, headers=headers) - + + def put_test(self, path, contents=None, overwrite=None, headers=None): + _data = {} + if path is not None: + _data['path'] = path + if contents is not None: + _data['contents'] = contents + if overwrite is not None: + _data['overwrite'] = overwrite + return self.client.perform_query('POST', '/dbfs-testing/put', data=_data, headers=headers) + def mkdirs(self, path, headers=None): _data = {} if path is not None: _data['path'] = path return self.client.perform_query('POST', '/dbfs/mkdirs', data=_data, headers=headers) - + + def mkdirs_test(self, path, headers=None): + _data = {} + if path is not None: + _data['path'] = path + return self.client.perform_query('POST', '/dbfs-testing/mkdirs', data=_data, headers=headers) + def move(self, source_path, destination_path, headers=None): _data = {} if source_path is not None: @@ -524,7 +559,15 @@ def move(self, source_path, destination_path, headers=None): if destination_path is not None: _data['destination_path'] = destination_path return self.client.perform_query('POST', '/dbfs/move', data=_data, headers=headers) - + + def move_test(self, source_path, destination_path, headers=None): + _data = {} + if source_path is not None: + _data['source_path'] = source_path + if destination_path is not None: + _data['destination_path'] = destination_path + return self.client.perform_query('POST', '/dbfs-testing/move', data=_data, headers=headers) + def delete(self, path, recursive=None, headers=None): _data = {} if path is not None: @@ -532,7 +575,15 @@ def delete(self, path, recursive=None, headers=None): if recursive is not None: _data['recursive'] = recursive return self.client.perform_query('POST', '/dbfs/delete', data=_data, headers=headers) - + + def delete_test(self, path, recursive=None, headers=None): + _data = {} + if path is not None: + _data['path'] 
= path + if recursive is not None: + _data['recursive'] = recursive + return self.client.perform_query('POST', '/dbfs-testing/delete', data=_data, headers=headers) + def create(self, path, overwrite=None, headers=None): _data = {} if path is not None: @@ -540,7 +591,15 @@ def create(self, path, overwrite=None, headers=None): if overwrite is not None: _data['overwrite'] = overwrite return self.client.perform_query('POST', '/dbfs/create', data=_data, headers=headers) - + + def create_test(self, path, overwrite=None, headers=None): + _data = {} + if path is not None: + _data['path'] = path + if overwrite is not None: + _data['overwrite'] = overwrite + return self.client.perform_query('POST', '/dbfs-testing/create', data=_data, headers=headers) + def add_block(self, handle, data, headers=None): _data = {} if handle is not None: @@ -548,13 +607,27 @@ def add_block(self, handle, data, headers=None): if data is not None: _data['data'] = data return self.client.perform_query('POST', '/dbfs/add-block', data=_data, headers=headers) - + + def add_block_test(self, handle, data, headers=None): + _data = {} + if handle is not None: + _data['handle'] = handle + if data is not None: + _data['data'] = data + return self.client.perform_query('POST', '/dbfs-testing/add-block', data=_data, headers=headers) + def close(self, handle, headers=None): _data = {} if handle is not None: _data['handle'] = handle return self.client.perform_query('POST', '/dbfs/close', data=_data, headers=headers) - + + def close_test(self, handle, headers=None): + _data = {} + if handle is not None: + _data['handle'] = handle + return self.client.perform_query('POST', '/dbfs-testing/close', data=_data, headers=headers) + class WorkspaceService(object): def __init__(self, client): @@ -565,13 +638,13 @@ def mkdirs(self, path, headers=None): if path is not None: _data['path'] = path return self.client.perform_query('POST', '/workspace/mkdirs', data=_data, headers=headers) - + def list(self, path, headers=None): 
_data = {} if path is not None: _data['path'] = path return self.client.perform_query('GET', '/workspace/list', data=_data, headers=headers) - + def import_workspace(self, path, format=None, language=None, content=None, overwrite=None, headers=None): _data = {} @@ -586,7 +659,7 @@ def import_workspace(self, path, format=None, language=None, content=None, overw if overwrite is not None: _data['overwrite'] = overwrite return self.client.perform_query('POST', '/workspace/import', data=_data, headers=headers) - + def export_workspace(self, path, format=None, direct_download=None, headers=None): _data = {} if path is not None: @@ -596,7 +669,7 @@ def export_workspace(self, path, format=None, direct_download=None, headers=None if direct_download is not None: _data['direct_download'] = direct_download return self.client.perform_query('GET', '/workspace/export', data=_data, headers=headers) - + def delete(self, path, recursive=None, headers=None): _data = {} if path is not None: @@ -604,20 +677,20 @@ def delete(self, path, recursive=None, headers=None): if recursive is not None: _data['recursive'] = recursive return self.client.perform_query('POST', '/workspace/delete', data=_data, headers=headers) - + def get_status(self, path, headers=None): _data = {} if path is not None: _data['path'] = path return self.client.perform_query('GET', '/workspace/get-status', data=_data, headers=headers) - + class SecretService(object): def __init__(self, client): self.client = client def create_scope(self, scope, initial_manage_principal=None, scope_backend_type=None, - headers=None): + backend_azure_keyvault=None, headers=None): _data = {} if scope is not None: _data['scope'] = scope @@ -625,19 +698,23 @@ def create_scope(self, scope, initial_manage_principal=None, scope_backend_type= _data['initial_manage_principal'] = initial_manage_principal if scope_backend_type is not None: _data['scope_backend_type'] = scope_backend_type + if backend_azure_keyvault is not None: + 
_data['backend_azure_keyvault'] = backend_azure_keyvault + if not isinstance(backend_azure_keyvault, dict): + raise TypeError('Expected databricks.AzureKeyVaultSecretScopeMetadata() or dict for field backend_azure_keyvault') return self.client.perform_query('POST', '/secrets/scopes/create', data=_data, headers=headers) - + def delete_scope(self, scope, headers=None): _data = {} if scope is not None: _data['scope'] = scope return self.client.perform_query('POST', '/secrets/scopes/delete', data=_data, headers=headers) - + def list_scopes(self, headers=None): _data = {} - + return self.client.perform_query('GET', '/secrets/scopes/list', data=_data, headers=headers) - + def put_secret(self, scope, key, string_value=None, bytes_value=None, headers=None): _data = {} if scope is not None: @@ -649,7 +726,7 @@ def put_secret(self, scope, key, string_value=None, bytes_value=None, headers=No if bytes_value is not None: _data['bytes_value'] = bytes_value return self.client.perform_query('POST', '/secrets/put', data=_data, headers=headers) - + def delete_secret(self, scope, key, headers=None): _data = {} if scope is not None: @@ -657,13 +734,13 @@ def delete_secret(self, scope, key, headers=None): if key is not None: _data['key'] = key return self.client.perform_query('POST', '/secrets/delete', data=_data, headers=headers) - + def list_secrets(self, scope, headers=None): _data = {} if scope is not None: _data['scope'] = scope return self.client.perform_query('GET', '/secrets/list', data=_data, headers=headers) - + def put_acl(self, scope, principal, permission, headers=None): _data = {} if scope is not None: @@ -673,7 +750,7 @@ def put_acl(self, scope, principal, permission, headers=None): if permission is not None: _data['permission'] = permission return self.client.perform_query('POST', '/secrets/acls/put', data=_data, headers=headers) - + def delete_acl(self, scope, principal, headers=None): _data = {} if scope is not None: @@ -681,13 +758,13 @@ def delete_acl(self, scope, 
principal, headers=None): if principal is not None: _data['principal'] = principal return self.client.perform_query('POST', '/secrets/acls/delete', data=_data, headers=headers) - + def list_acls(self, scope, headers=None): _data = {} if scope is not None: _data['scope'] = scope return self.client.perform_query('GET', '/secrets/acls/list', data=_data, headers=headers) - + def get_acl(self, scope, principal, headers=None): _data = {} if scope is not None: @@ -695,7 +772,7 @@ def get_acl(self, scope, principal, headers=None): if principal is not None: _data['principal'] = principal return self.client.perform_query('GET', '/secrets/acls/get', data=_data, headers=headers) - + class GroupsService(object): def __init__(self, client): @@ -706,7 +783,7 @@ def create_group(self, group_name, headers=None): if group_name is not None: _data['group_name'] = group_name return self.client.perform_query('POST', '/groups/create', data=_data, headers=headers) - + def add_to_group(self, parent_name, user_name=None, group_name=None, headers=None): _data = {} if user_name is not None: @@ -716,7 +793,7 @@ def add_to_group(self, parent_name, user_name=None, group_name=None, headers=Non if parent_name is not None: _data['parent_name'] = parent_name return self.client.perform_query('POST', '/groups/add-member', data=_data, headers=headers) - + def remove_from_group(self, parent_name, user_name=None, group_name=None, headers=None): _data = {} if user_name is not None: @@ -726,24 +803,24 @@ def remove_from_group(self, parent_name, user_name=None, group_name=None, header if parent_name is not None: _data['parent_name'] = parent_name return self.client.perform_query('POST', '/groups/remove-member', data=_data, headers=headers) - + def get_groups(self, headers=None): _data = {} - + return self.client.perform_query('GET', '/groups/list', data=_data, headers=headers) - + def get_group_members(self, group_name, headers=None): _data = {} if group_name is not None: _data['group_name'] = group_name 
return self.client.perform_query('GET', '/groups/list-members', data=_data, headers=headers) - + def remove_group(self, group_name, headers=None): _data = {} if group_name is not None: _data['group_name'] = group_name return self.client.perform_query('POST', '/groups/delete', data=_data, headers=headers) - + def get_groups_for_principal(self, user_name=None, group_name=None, headers=None): _data = {} if user_name is not None: @@ -751,7 +828,7 @@ def get_groups_for_principal(self, user_name=None, group_name=None, headers=None if group_name is not None: _data['group_name'] = group_name return self.client.perform_query('GET', '/groups/list-parents', data=_data, headers=headers) - + class TokenService(object): def __init__(self, client): @@ -764,12 +841,12 @@ def create_token(self, lifetime_seconds=None, comment=None, headers=None): if comment is not None: _data['comment'] = comment return self.client.perform_query('POST', '/token/create', data=_data, headers=headers) - + def list_tokens(self, headers=None): _data = {} - + return self.client.perform_query('GET', '/token/list', data=_data, headers=headers) - + def revoke_token(self, token_id, headers=None): _data = {} if token_id is not None: @@ -893,7 +970,7 @@ def create(self, id=None, name=None, storage=None, configuration=None, clusters= if allow_duplicate_names is not None: _data['allow_duplicate_names'] = allow_duplicate_names return self.client.perform_query('POST', '/pipelines', data=_data, headers=headers) - + def deploy(self, pipeline_id=None, id=None, name=None, storage=None, configuration=None, clusters=None, libraries=None, trigger=None, filters=None, allow_duplicate_names=None, headers=None): @@ -920,21 +997,29 @@ def deploy(self, pipeline_id=None, id=None, name=None, storage=None, configurati raise TypeError('Expected databricks.Filters() or dict for field filters') if allow_duplicate_names is not None: _data['allow_duplicate_names'] = allow_duplicate_names - return self.client.perform_query('PUT', - 
'/pipelines/{pipeline_id}'.format(pipeline_id=pipeline_id), - data=_data, headers=headers) - + return self.client.perform_query('PUT', '/pipelines/{pipeline_id}'.format(pipeline_id=pipeline_id), data=_data, headers=headers) + def delete(self, pipeline_id=None, headers=None): _data = {} - + return self.client.perform_query('DELETE', '/pipelines/{pipeline_id}'.format(pipeline_id=pipeline_id), data=_data, headers=headers) - + def get(self, pipeline_id=None, headers=None): _data = {} - + return self.client.perform_query('GET', '/pipelines/{pipeline_id}'.format(pipeline_id=pipeline_id), data=_data, headers=headers) - + def reset(self, pipeline_id=None, headers=None): _data = {} - + return self.client.perform_query('POST', '/pipelines/{pipeline_id}/reset'.format(pipeline_id=pipeline_id), data=_data, headers=headers) + + def run(self, pipeline_id=None, headers=None): + _data = {} + + return self.client.perform_query('POST', '/pipelines/{pipeline_id}/run'.format(pipeline_id=pipeline_id), data=_data, headers=headers) + + def stop(self, pipeline_id=None, headers=None): + _data = {} + + return self.client.perform_query('POST', '/pipelines/{pipeline_id}/stop'.format(pipeline_id=pipeline_id), data=_data, headers=headers) diff --git a/databricks_cli/secrets/api.py b/databricks_cli/secrets/api.py index 6ebf510c..c3487705 100644 --- a/databricks_cli/secrets/api.py +++ b/databricks_cli/secrets/api.py @@ -28,8 +28,10 @@ class SecretApi(object): def __init__(self, api_client): self.client = SecretService(api_client) - def create_scope(self, scope, initial_manage_principal): - return self.client.create_scope(scope, initial_manage_principal) + def create_scope(self, scope, initial_manage_principal, scope_backend_type, + backend_azure_keyvault): + return self.client.create_scope(scope, initial_manage_principal, + scope_backend_type, backend_azure_keyvault) def delete_scope(self, scope): return self.client.delete_scope(scope) diff --git a/databricks_cli/secrets/cli.py 
b/databricks_cli/secrets/cli.py index 485dfdb9..3695e9d4 100644 --- a/databricks_cli/secrets/cli.py +++ b/databricks_cli/secrets/cli.py @@ -50,15 +50,35 @@ ' principal for this option is the group "users", which contains all users in the' ' workspace. If not specified, the initial ACL with MANAGE permission applied to the' ' scope is assigned to the request issuer\'s user identity.') +@click.option('--scope-backend-type', + type=click.Choice(['AZURE_KEYVAULT', 'DATABRICKS'], case_sensitive=True), + default='DATABRICKS', help='The backend that will be used for this secret scope. ' + 'Options are (case-sensitive): 1) \'AZURE_KEYVAULT\' and ' + '2) \'DATABRICKS\' (default option)' + '\nNote: To create an Azure Keyvault, be sure ' + 'to configure an AAD Token using ' + '\'databricks-cli configure --aad-token\'') +@click.option('--resource-id', default=None, type=click.STRING, + help='The resource ID associated with the azure keyvault to be used as the backend' + ' for the secret scope. NOTE: Only use with azure-keyvault as backend') +@click.option('--dns-name', default=None, type=click.STRING, + help='The dns name associated with the azure keyvault to be used as the' + ' backend for the secret scope. NOTE: Only use with azure-keyvault as backend') @debug_option @profile_option @eat_exceptions @provide_api_client -def create_scope(api_client, scope, initial_manage_principal): +def create_scope(api_client, scope, initial_manage_principal, + scope_backend_type, resource_id, dns_name): """ Creates a new secret scope with given name.
""" - SecretApi(api_client).create_scope(scope, initial_manage_principal) + backend_azure_keyvault = { + 'resource_id': resource_id, + 'dns_name': dns_name + } + SecretApi(api_client).create_scope(scope, initial_manage_principal, + scope_backend_type, backend_azure_keyvault) def _scopes_to_table(scopes_json): @@ -316,7 +336,7 @@ def get_acl(api_client, scope, principal, output): @debug_option @profile_option @eat_exceptions -def secrets_group(): +def secrets_group(): # pragma: no cover """ Utility to interact with secret API. """