In [0]:
import json
import os
import requests
import requests.packages.urllib3

# `global` at module level has no effect in Python; kept as-is.
# NOTE(review): presumably meant to expose pprint_j to notebooks that %run this one — confirm.
global pprint_j

# Silences urllib3 warnings process-wide. dbclient sends every request with
# verify=False, which would otherwise emit a warning on each call.
requests.packages.urllib3.disable_warnings()


# Helper to pretty print json
def pprint_j(i):
    """Print ``i`` as JSON text, indented by four spaces with keys sorted."""
    rendered = json.dumps(i, sort_keys=True, indent=4)
    print(rendered)


class dbclient:
    """
    Rest API Wrapper for Databricks APIs

    Stores a bearer-token header and a workspace URL, and exposes ``get`` /
    ``post`` / ``put`` / ``patch`` helpers that call
    ``{url}/api/{version}{endpoint}`` and return the decoded JSON body with an
    extra ``http_status_code`` key.
    """
    # set of http error codes to throw an exception if hit. Handles client and auth errors
    http_error_codes = (401, 403)
    # API version used when a caller passes a falsy ``version``
    _default_version = '2.0'

    def __init__(self, token, url):
        """
        :param token: API token; sent as ``Authorization: Bearer <token>``
        :param url: workspace base URL; trailing slashes are stripped
        """
        self._token = {'Authorization': 'Bearer {0}'.format(token)}
        self._url = url.rstrip("/")
        self._is_verbose = False
        # NOTE(review): certificate verification is disabled for every request
        # even though a bearer token is sent. Kept for backward compatibility.
        self._verify_ssl = False
        # The accessors below used to raise AttributeError because these were
        # never initialised. None means "not known" / "not set yet".
        self._is_aws = None
        self._skip_failed = False
        self._export_dir = None
        if self._verify_ssl:
            # NOTE(review): unreachable while _verify_ssl is hard-coded to False,
            # and the comment below talks about *skipping* verification while the
            # condition tests for it being enabled — confirm intent before use.
            # set these env variables if skip SSL verification is enabled
            os.environ['REQUESTS_CA_BUNDLE'] = ""
            os.environ['CURL_CA_BUNDLE'] = ""

    def is_aws(self):
        """Return the stored cloud flag (``None`` unless something set ``_is_aws``)."""
        return self._is_aws

    def is_verbose(self):
        """Return True when each request's endpoint should be printed."""
        return self._is_verbose

    def is_skip_failed(self):
        """Return the stored skip-failed flag (``False`` by default)."""
        return self._skip_failed

    def test_connection(self):
        """
        Check the hostname suffix and make one authenticated GET.

        :return: 0 when the request returns HTTP 200, otherwise -1
        """
        # verify the proper url settings to configure this client
        if not self._url.endswith(('.com', '.net')):
            print("Hostname should end in '.com' or '.net'")
            return -1
        results = requests.get(self._url + '/api/2.0/clusters/spark-versions', headers=self._token,
                               verify=self._verify_ssl)
        http_status_code = results.status_code
        if http_status_code != 200:
            print("Error. Either the credentials have expired or the credentials don't have proper permissions.")
            print("If you have a ~/.netrc file, check those credentials. Those take precedence over passed input.")
            print(results.text)
            return -1
        return 0

    def _build_endpoint(self, endpoint, version):
        """Join base URL, API version and endpoint; a falsy version falls back to the default."""
        ver = version if version else self._default_version
        return self._url + '/api/{0}'.format(ver) + endpoint

    @classmethod
    def _raise_on_auth_error(cls, label, raw_results):
        """Raise ``Exception`` when the response status is one of ``http_error_codes``."""
        http_status_code = raw_results.status_code
        if http_status_code in cls.http_error_codes:
            raise Exception("Error: {0} request failed with code {1}\n{2}".format(label,
                                                                                  http_status_code,
                                                                                  raw_results.text))

    def get(self, endpoint, json_params=None, version='2.0', print_json=False):
        """
        Issue a GET request and return the decoded JSON body.

        :param endpoint: path appended after ``/api/<version>``
        :param json_params: optional query-string parameters
        :param version: API version segment; falsy values use the default '2.0'
        :param print_json: when True, pretty-print the decoded body
        :return: dict with the response body plus ``http_status_code``; a JSON
                 list is wrapped as ``{'elements': [...]}``
        :raises Exception: when the status code is in ``http_error_codes``
        """
        full_endpoint = self._build_endpoint(endpoint, version)
        if self.is_verbose():
            print("Get: {0}".format(full_endpoint))
        # params=None is the requests default, so one call covers both cases
        raw_results = requests.get(full_endpoint, headers=self._token,
                                   params=json_params or None, verify=self._verify_ssl)
        self._raise_on_auth_error('GET', raw_results)
        results = raw_results.json()
        if print_json:
            print(json.dumps(results, indent=4, sort_keys=True))
        if isinstance(results, list):
            results = {'elements': results}
        results['http_status_code'] = raw_results.status_code
        return results

    def http_req(self, http_type, endpoint, json_params, version='2.0', print_json=False, files_json=None):
        """
        Issue a POST, PUT or PATCH request with a payload.

        :param http_type: one of 'post', 'put', 'patch'
        :param endpoint: path appended after ``/api/<version>``
        :param json_params: payload; sent as JSON, or as form data when
                            ``files_json`` is given for a POST
        :param version: API version segment; falsy values use the default '2.0'
        :param print_json: when True, pretty-print the decoded body
        :param files_json: optional files mapping, used for POST only
        :return: ``{}`` when no payload was given; otherwise the decoded body
                 plus ``http_status_code`` (or only the status for an empty body)
        :raises ValueError: when ``http_type`` is not a supported verb
        :raises Exception: when the status code is in ``http_error_codes``
        """
        full_endpoint = self._build_endpoint(endpoint, version)
        if self.is_verbose():
            print("{0}: {1}".format(http_type, full_endpoint))
        if not json_params:
            print("Must have a payload in json_args param.")
            return {}
        if http_type == 'post':
            if files_json:
                raw_results = requests.post(full_endpoint, headers=self._token,
                                            data=json_params, files=files_json, verify=self._verify_ssl)
            else:
                raw_results = requests.post(full_endpoint, headers=self._token,
                                            json=json_params, verify=self._verify_ssl)
        elif http_type == 'put':
            raw_results = requests.put(full_endpoint, headers=self._token,
                                       json=json_params, verify=self._verify_ssl)
        elif http_type == 'patch':
            raw_results = requests.patch(full_endpoint, headers=self._token,
                                         json=json_params, verify=self._verify_ssl)
        else:
            # previously fell through and failed on an unbound local variable
            raise ValueError("Unsupported http_type: {0!r}".format(http_type))

        self._raise_on_auth_error(http_type, raw_results)
        results = raw_results.json()
        if print_json:
            print(json.dumps(results, indent=4, sort_keys=True))
        # if results are empty, let's return the return status
        if not results:
            return {'http_status_code': raw_results.status_code}
        # wrap list bodies the same way get() does so the status can be attached
        if isinstance(results, list):
            results = {'elements': results}
        results['http_status_code'] = raw_results.status_code
        return results

    def post(self, endpoint, json_params, version='2.0', print_json=False, files_json=None):
        """POST ``json_params`` to ``endpoint``; see :meth:`http_req`."""
        return self.http_req('post', endpoint, json_params, version, print_json, files_json)

    def put(self, endpoint, json_params, version='2.0', print_json=False):
        """PUT ``json_params`` to ``endpoint``; see :meth:`http_req`."""
        return self.http_req('put', endpoint, json_params, version, print_json)

    def patch(self, endpoint, json_params, version='2.0', print_json=False):
        """PATCH ``json_params`` to ``endpoint``; see :meth:`http_req`."""
        return self.http_req('patch', endpoint, json_params, version, print_json)

    @staticmethod
    def my_map(F, items):
        """Apply ``F`` to every element of ``items`` and return the results as a list."""
        return [F(elem) for elem in items]

    def set_export_dir(self, dir_location):
        """Remember ``dir_location`` as the export directory."""
        self._export_dir = dir_location

    def get_export_dir(self):
        """Return the export directory, or ``None`` when none has been set."""
        return self._export_dir

    @staticmethod
    def _spark_version_sort_key(version):
        """Sort key: leading numeric components of the key, then the key itself."""
        key = version['key']
        numeric = []
        for part in key.split('-')[0].split('.'):
            if not part.isdecimal():
                break
            numeric.append(int(part))
        return (tuple(numeric), key)

    def get_latest_spark_version(self):
        """
        Return the highest-sorted spark version whose key's second
        dash-separated segment starts with 'scala'.

        Versions are ordered numerically on their leading components, so
        '10.4.x' ranks above '9.1.x' (a plain string sort ranks it below).

        :return: the matching version dict, or ``None`` when nothing matches
        """
        versions = self.get('/clusters/spark-versions')['versions']
        for x in sorted(versions, key=self._spark_version_sort_key, reverse=True):
            parts = x['key'].split('-')
            if len(parts) > 1 and parts[1][0:5] == 'scala':
                return x
        return None


In [0]:
class scimclient(dbclient):
    """
    SCIM and permissions helpers built on top of dbclient.

    Used to create a user under a new email with the groups and entitlements
    of an existing user, and to grant the new email CAN_MANAGE on the existing
    user's home directory, jobs and clusters.
    """
    # Class-level default; cache_users() rebinds this on the instance.
    __cached_users = []

    def get_all_users(self):
        """
        Get list of all users today
        :return: List of registered users, or None when the response has no 'Resources'
        """
        return self.get('/preview/scim/v2/Users').get('Resources', None)

    def get_homedir_id(self, email):
        """
        Look up the workspace object id of ``/Users/<email>``.

        :param email: user email; surrounding whitespace is ignored
        :return: the 'object_id' from the get-status response, or None if absent
        """
        resp = self.get('/workspace/get-status', {'path': '/Users/' + email.strip()})
        return resp.get('object_id', None)

    def get_acl_perms(self, email):
        """Build a permissions payload granting ``email`` CAN_MANAGE."""
        return {
            'access_control_list': [
                {"user_name": email, "permission_level": "CAN_MANAGE"}
            ]
        }

    def grant_access_notebooks(self, new_email, old_email):
        """
        Grant ``new_email`` CAN_MANAGE on ``old_email``'s home directory.

        :return: the response dict of the permissions PATCH
        """
        directory_id = self.get_homedir_id(old_email)
        # NOTE(review): when the home directory is not found, directory_id is None
        # and the request goes to '/permissions/directories/None' — confirm whether
        # that should be skipped instead.
        perms_endpoint = f'/permissions/directories/{directory_id}'
        print(perms_endpoint)
        # allow new email access to old notebooks
        acl_args = self.get_acl_perms(new_email)
        return self.patch(perms_endpoint, acl_args)

    def get_users_jobs(self, old_email):
        """Return the jobs whose 'creator_user_name' equals ``old_email`` (possibly empty)."""
        jobs_list = self.get('/jobs/list').get('jobs', None)
        if not jobs_list:
            return []
        return [job for job in jobs_list if job.get('creator_user_name', "") == old_email]

    def grant_access_jobs(self, new_email, old_email):
        """
        Grant ``new_email`` CAN_MANAGE on every job created by ``old_email``.

        :return: list of the 'object_id' values returned by each PATCH
        """
        # the payload only depends on new_email, so build it once
        acl_args = self.get_acl_perms(new_email)
        jobs_updated = []
        for job in self.get_users_jobs(old_email):
            job_id = job['job_id']
            # allow new email access to old jobs
            resp = self.patch(f'/permissions/jobs/{job_id}', acl_args)
            jobs_updated.append(resp.get('object_id'))
        return jobs_updated

    def get_clusters_by_user(self, old_email):
        """
        Return the clusters whose 'creator_user_name' equals ``old_email``.

        A response without a 'clusters' key, or a cluster without a creator,
        is treated as "no match" instead of raising.
        """
        cluster_list = self.get('/clusters/list').get('clusters', None) or []
        return [cluster for cluster in cluster_list
                if cluster.get('creator_user_name', "") == old_email]

    def grant_access_clusters(self, new_email, old_email):
        """
        Grant ``new_email`` CAN_MANAGE on every cluster created by ``old_email``.

        :return: list of the 'object_id' values returned by each PATCH
        """
        acl_args = self.get_acl_perms(new_email)
        clusters_updated = []
        for cluster in self.get_clusters_by_user(old_email):
            cluster_id = cluster['cluster_id']
            resp = self.patch(f'/permissions/clusters/{cluster_id}', acl_args)
            clusters_updated.append(resp.get('object_id'))
        return clusters_updated

    def get_user_cache(self):
        """Return the cached user list (empty until cache_users() has run)."""
        return self.__cached_users

    def cache_users(self):
        """Fetch all SCIM users and store them on this instance. Always returns True."""
        users = self.get('/preview/scim/v2/Users').get('Resources', [])
        self.__cached_users = users
        return True

    def get_user_from_cache(self, old_email):
        """
        Find the cached user whose first 'emails' entry has value ``old_email``.

        :return: the user dict, or None when no cached user matches
        """
        for user in self.get_user_cache():
            emails = user.get('emails')
            if not emails:
                # users without an email entry can never match
                continue
            if emails[0].get('value', '') == old_email:
                return user
        return None

    def _get_cached_user(self, email):
        """Fill the cache on first use, then look ``email`` up in it."""
        if not self.get_user_cache():
            self.cache_users()
        else:
            print("Cache exists...")
        user_info = self.get_user_from_cache(email)
        if user_info is None:
            print(f"User {email} not found in SCIM user list")
        return user_info

    def get_user_entitlements(self, old_email):
        """Return ``old_email``'s 'entitlements' list; empty when the user is unknown."""
        user_info = self._get_cached_user(old_email)
        if user_info is None:
            return []
        return user_info.get('entitlements', [])

    def get_user_groups(self, old_email):
        """Return the non-empty group ids of ``old_email``; empty when the user is unknown."""
        user_info = self._get_cached_user(old_email)
        if user_info is None:
            return []
        group_values = (g.get('value', '') for g in user_info.get('groups', []))
        return [gid for gid in group_values if gid]

    def get_user_create_args(self, new_email, old_email):
        """
        Build the SCIM create-user payload for ``new_email`` using the groups
        and entitlements currently held by ``old_email``.
        """
        gid_list = self.get_user_groups(old_email)
        entitlements = self.get_user_entitlements(old_email)
        return {
            "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
            "userName": new_email,
            "groups": [{'value': gid} for gid in gid_list],
            "entitlements": entitlements,
        }

    def create_user(self, new_email, old_email):
        """
        Create ``new_email`` with ``old_email``'s groups and entitlements.

        :return: the response dict of the SCIM POST (also printed)
        """
        create_args = self.get_user_create_args(new_email, old_email)
        resp = self.post('/preview/scim/v2/Users', create_args)
        print(resp)
        return resp

In [0]:
# Read the API URL and API token from the notebook context. `dbutils` is not
# defined in this notebook; presumably it is injected by the Databricks runtime.
# NOTE(review): getOrElse(None) can yield None; scimclient() would then fail on
# url.rstrip("/") — confirm both values are always present.
url = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None) 
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)

# Client shared by all cells below.
client = scimclient(token, url)

In [0]:
# Mapping that drives every cell below: each key is the NEW email to create and
# each value is the EXISTING (old) email whose groups, entitlements, home
# directory, clusters and jobs the new user should get access to.
# { new_email : old_email }
# NOTE(review): the sample entry below reads as if key and value are swapped
# (the value is the one named "new_...") — replace it with real addresses.

user_dict = {
  'foobar@databricks.com': 'new_foobar@databricks.com',
}

In [0]:
# Create each new user from the old user's groups and entitlements.
# create_user() prints the raw response itself; it is printed again here as indented JSON.
for new_email, old_email in user_dict.items():
  print(f"Creating user {new_email} using old user entitlements {old_email}")
  resp = client.create_user(new_email, old_email)
  print(json.dumps(resp, indent=2))

In [0]:
# Grant each new email CAN_MANAGE on the old user's home directory (/Users/<old_email>).
# grant_access_notebooks() resolves the directory itself, so no path is built here.
for new_email, old_email in user_dict.items():
  print(f"Granting {new_email} access to their old workspace {old_email}")
  resp = client.grant_access_notebooks(new_email, old_email)
  print(resp, '\n')

In [0]:
# Grant each new email CAN_MANAGE on every cluster created by the old email,
# then print the object ids returned by the permission updates.
for new_email, old_email in user_dict.items():
  print(f"Granting {new_email} access to their old clusters {old_email}")
  updated_clusters = client.grant_access_clusters(new_email, old_email)
  print('Clusters granted access:', updated_clusters, '\n')

In [0]:
# Grant each new email CAN_MANAGE on every job created by the old email,
# then print the object ids returned by the permission updates.
for new_email, old_email in user_dict.items():
  print(f"Granting {new_email} access to their old job templates {old_email}")
  resp = client.grant_access_jobs(new_email, old_email)
  print('Jobs granted access:', resp, '\n')