In [0]:
# Threshold, in hours, compared against each run's elapsed time further down:
# a run whose duration is greater than this value is reported as long-running.
long_running_duration = 24  # unit: hours

In [0]:
import json, os, datetime, requests
import requests.packages.urllib3

# NOTE(review): a `global` statement at module scope is a no-op in Python;
# presumably kept for the notebook execution context -- confirm before removing.
global pprint_j

# Silence urllib3 warnings for the whole process. The client in this file sends
# its requests with certificate verification turned off (verify=False).
requests.packages.urllib3.disable_warnings()


# Helper to pretty print json
def pprint_j(i):
    """Print ``i`` serialized as JSON with 4-space indentation and sorted keys."""
    rendered = json.dumps(i, indent=4, sort_keys=True)
    print(rendered)


class dbclient:
    """
    Rest API Wrapper for Databricks APIs

    Every request is sent with a bearer-token ``Authorization`` header and with
    ``verify=self._verify_ssl`` (hard-coded to ``False`` in ``__init__``).
    Responses with a status code listed in ``http_error_codes`` raise an
    ``Exception``; any other response is decoded as JSON and returned as a dict
    carrying an extra ``http_status_code`` key.
    """
    # set of http error codes to throw an exception if hit. Handles client and auth errors
    http_error_codes = (401, 403)
    # API version used when a caller passes a falsy ``version`` (e.g. None or '')
    default_api_version = '2.0'
    # HTTP verbs that http_req knows how to send
    supported_payload_methods = ('post', 'put', 'patch')

    def __init__(self, token, url):
        """
        :param token: API token placed in the ``Authorization: Bearer`` header
        :param url: base URL of the workspace; trailing slashes are stripped
        """
        self._token = {'Authorization': 'Bearer {0}'.format(token)}
        self._url = url.rstrip("/")
        self._is_verbose = False
        self._verify_ssl = False
        # Defaults so that is_aws() / is_skip_failed() / get_export_dir() return a
        # value instead of raising AttributeError. ``None`` means "not set".
        self._is_aws = None
        self._skip_failed = False
        self._export_dir = None
        # NOTE(review): _verify_ssl is hard-coded to False just above, so this
        # branch never runs. The comment below describes the skip-verification
        # case, which suggests the condition may be inverted -- confirm before
        # changing, since it would start mutating the process environment.
        if self._verify_ssl:
            # set these env variables if skip SSL verification is enabled
            os.environ['REQUESTS_CA_BUNDLE'] = ""
            os.environ['CURL_CA_BUNDLE'] = ""

    def is_aws(self):
        """Return the stored ``_is_aws`` flag (``None`` unless something set it)."""
        return self._is_aws

    def is_verbose(self):
        """Return True when each request's endpoint should be printed."""
        return self._is_verbose

    def is_skip_failed(self):
        """Return the stored ``_skip_failed`` flag (``False`` unless something set it)."""
        return self._skip_failed

    def test_connection(self):
        """
        Check the configured URL and credentials.

        :return: 0 when ``/clusters/spark-versions`` answers with HTTP 200,
                 -1 when the hostname suffix is unexpected or the call fails
        """
        # verify the proper url settings to configure this client
        if not self._url.endswith(('.com', '.net')):
            print("Hostname should end in '.com' or '.net'")
            return -1
        results = requests.get(self._url + '/api/2.0/clusters/spark-versions', headers=self._token,
                               verify=self._verify_ssl)
        http_status_code = results.status_code
        if http_status_code != 200:
            print("Error. Either the credentials have expired or the credentials don't have proper permissions.")
            print("If you have a ~/.netrc file, check those credentials. Those take precedence over passed input.")
            print(results.text)
            return -1
        return 0

    def _full_endpoint(self, endpoint, version):
        """Build the absolute URL for ``endpoint``, defaulting a falsy ``version``."""
        ver = version or self.default_api_version
        return self._url + '/api/{0}'.format(ver) + endpoint

    @staticmethod
    def _raise_for_auth_error(method_label, response):
        """Raise ``Exception`` when the response status is in ``http_error_codes``."""
        if response.status_code in dbclient.http_error_codes:
            raise Exception("Error: {0} request failed with code {1}\n{2}".format(method_label,
                                                                                  response.status_code,
                                                                                  response.text))

    def get(self, endpoint, json_params=None, version='2.0', print_json=False):
        """
        Issue a GET request and return the decoded JSON body.

        :param endpoint: path below ``/api/<version>``, starting with ``/``
        :param json_params: optional query parameters
        :param version: API version; a falsy value falls back to ``default_api_version``
        :param print_json: print the decoded body when True
        :return: dict with the response fields plus ``http_status_code``; a JSON
                 list body is wrapped as ``{'elements': [...]}``
        :raises Exception: when the status code is in ``http_error_codes``
        """
        full_endpoint = self._full_endpoint(endpoint, version)
        if self.is_verbose():
            print("Get: {0}".format(full_endpoint))
        # An empty/None json_params sends no query string at all.
        raw_results = requests.get(full_endpoint, headers=self._token,
                                   params=json_params or None, verify=self._verify_ssl)
        self._raise_for_auth_error('GET', raw_results)
        results = raw_results.json()
        if print_json:
            print(json.dumps(results, indent=4, sort_keys=True))
        if isinstance(results, list):
            results = {'elements': results}
        results['http_status_code'] = raw_results.status_code
        return results

    def http_req(self, http_type, endpoint, json_params, version='2.0', print_json=False, files_json=None):
        """
        Issue a POST / PUT / PATCH request with a payload.

        :param http_type: one of ``'post'``, ``'put'``, ``'patch'``
        :param endpoint: path below ``/api/<version>``, starting with ``/``
        :param json_params: payload; sent as JSON, or as form data when
                            ``files_json`` is given on a POST
        :param version: API version; a falsy value falls back to ``default_api_version``
        :param print_json: print the decoded body when True
        :param files_json: optional files for a multipart POST
        :return: ``{}`` when no payload was supplied; otherwise the decoded body
                 plus ``http_status_code`` (or only ``http_status_code`` when the
                 body is empty)
        :raises ValueError: when ``http_type`` is not supported
        :raises Exception: when the status code is in ``http_error_codes``
        """
        full_endpoint = self._full_endpoint(endpoint, version)
        if self.is_verbose():
            print("{0}: {1}".format(http_type, full_endpoint))
        if not json_params:
            print("Must have a payload in json_params param.")
            return {}
        # Reject unknown verbs up front instead of failing later on an unbound local.
        if http_type not in self.supported_payload_methods:
            raise ValueError("Unsupported http_type '{0}'; expected one of {1}".format(
                http_type, ', '.join(self.supported_payload_methods)))
        if http_type == 'post' and files_json:
            raw_results = requests.post(full_endpoint, headers=self._token,
                                        data=json_params, files=files_json, verify=self._verify_ssl)
        elif http_type == 'post':
            raw_results = requests.post(full_endpoint, headers=self._token,
                                        json=json_params, verify=self._verify_ssl)
        elif http_type == 'put':
            raw_results = requests.put(full_endpoint, headers=self._token,
                                       json=json_params, verify=self._verify_ssl)
        else:
            raw_results = requests.patch(full_endpoint, headers=self._token,
                                         json=json_params, verify=self._verify_ssl)

        self._raise_for_auth_error(http_type, raw_results)
        results = raw_results.json()
        if print_json:
            print(json.dumps(results, indent=4, sort_keys=True))
        # if results are empty, let's return the return status
        if results:
            results['http_status_code'] = raw_results.status_code
            return results
        return {'http_status_code': raw_results.status_code}

    def post(self, endpoint, json_params, version='2.0', print_json=False, files_json=None):
        """POST ``json_params`` to ``endpoint``; see :meth:`http_req`."""
        return self.http_req('post', endpoint, json_params, version, print_json, files_json)

    def put(self, endpoint, json_params, version='2.0', print_json=False):
        """PUT ``json_params`` to ``endpoint``; see :meth:`http_req`."""
        return self.http_req('put', endpoint, json_params, version, print_json)

    def patch(self, endpoint, json_params, version='2.0', print_json=False):
        """PATCH ``json_params`` to ``endpoint``; see :meth:`http_req`."""
        return self.http_req('patch', endpoint, json_params, version, print_json)

    @staticmethod
    def my_map(F, items):
        """Return a list holding ``F(elem)`` for every element of ``items``."""
        return [F(elem) for elem in items]

    def set_export_dir(self, dir_location):
        """Remember ``dir_location`` as the export directory."""
        self._export_dir = dir_location

    def get_export_dir(self):
        """Return the export directory, or ``None`` when none has been set."""
        return self._export_dir

    @staticmethod
    def _spark_version_sort_key(version):
        """
        Sort key for a spark-version entry: the leading numeric components of its
        ``key`` as a tuple of ints, then the raw key string as a tie-breaker.
        Comparing ints keeps '10.4.x' above '9.1.x', which a plain string
        comparison gets wrong.
        """
        key = version['key']
        numbers = []
        for piece in key.split('-')[0].split('.'):
            if not piece.isdigit():
                break
            numbers.append(int(piece))
        return tuple(numbers), key

    def get_latest_spark_version(self):
        """
        Return the newest spark-version entry whose key has the form
        ``<version>-scala...`` (second dash-separated segment starts with 'scala').

        :return: the matching entry from ``/clusters/spark-versions``, or ``None``
                 when no entry matches
        """
        versions = self.get('/clusters/spark-versions')['versions']
        candidates = []
        for version in versions:
            segments = version['key'].split('-')
            # Keys without a second segment cannot match; skip them rather than
            # raising IndexError.
            if len(segments) > 1 and segments[1][0:5] == 'scala':
                candidates.append(version)
        if not candidates:
            return None
        return max(candidates, key=self._spark_version_sort_key)


In [0]:
class jobsclient(dbclient):
    """
    Jobs API helpers built on top of ``dbclient``: list jobs and runs, trigger a
    job by name, poll or cancel runs, and compute how long a run has been going.
    """

    def get_jobs_list(self):
        """Return the ``jobs`` list from ``/jobs/list`` (empty list when absent)."""
        return self.get('/jobs/list').get('jobs', [])

    def get_current_job_runs(self):
        """
        Return every run whose ``state.life_cycle_state`` is ``'RUNNING'``.

        ``/jobs/runs/list`` returns results one page at a time, so a single call
        can miss the oldest active runs -- exactly the long-running ones. This
        asks for active runs only and keeps fetching with an increasing
        ``offset`` while the response reports ``has_more``.
        """
        running_runs = []
        offset = 0
        while True:
            # 'true' is passed as a string so it reaches the query string in the
            # lowercase form the REST API documents.
            page = self.get('/jobs/runs/list', json_params={'active_only': 'true', 'offset': offset})
            runs = page.get('runs', [])
            for run in runs:
                # A run without a 'state' object is treated as not running.
                if (run.get('state') or {}).get('life_cycle_state') == 'RUNNING':
                    running_runs.append(run)
            # Stop on the last page; the empty-page check guards against looping
            # forever if has_more is reported without any runs.
            if not runs or not page.get('has_more'):
                break
            offset += len(runs)
        return running_runs

    def get_job_id_by_name(self, current_job_name):
        """
        Return the ``job_id`` of the first job whose ``settings.name`` equals
        ``current_job_name``, or ``None`` when no job matches.
        """
        for job in self.get_jobs_list():
            if (job.get('settings') or {}).get('name') == current_job_name:
                return job['job_id']
        return None

    def run_job_by_name(self, current_job_name, job_args=None):
        """
        Trigger ``/jobs/run-now`` for the job named ``current_job_name``.

        :param current_job_name: job name to look up
        :param job_args: optional dict sent as ``notebook_params``
        :return: the response dict from ``/jobs/run-now``
        :raises ValueError: when no job with that name exists
        """
        current_job_id = self.get_job_id_by_name(current_job_name)
        if current_job_id is None:
            # Fail here instead of posting {"job_id": null} to the API.
            raise ValueError("No job named '{0}' found".format(current_job_name))
        job_args_json = {"job_id": current_job_id}
        if job_args:
            job_args_json["notebook_params"] = job_args
        return self.post('/jobs/run-now', job_args_json)

    def poll_status_of_jobs(self, run_id_list):
        """Print the ``state`` object of every run id in ``run_id_list``."""
        for run_id in run_id_list:
            run_endpoint = f'/jobs/runs/get?run_id={run_id}'
            status = self.get(run_endpoint)
            print(status['state'])

    def cancel_job(self, run_id):
        """Request cancellation of ``run_id``, print the response and return it."""
        resp = self.post('/jobs/runs/cancel', {'run_id': run_id})
        print('Cancelling: ', resp)
        return resp

    def get_job_duration(self, job_json):
        """
        Return how many hours have elapsed since the run's ``start_time``.

        :param job_json: run dict whose ``start_time`` is read as epoch
                         milliseconds (it is divided by 1000 to get seconds)
        :return: elapsed time in hours, as a float
        :raises ValueError: when ``start_time`` is missing
        """
        start_time = job_json.get('start_time')
        if start_time is None:
            raise ValueError("Run has no 'start_time'; cannot compute its duration")
        utc = datetime.timezone.utc
        # Whole seconds only, matching the millisecond truncation used before.
        started_at = datetime.datetime.fromtimestamp(start_time // 1000, tz=utc)
        elapsed = datetime.datetime.now(utc) - started_at
        return elapsed.total_seconds() / 3600

In [0]:
# Read the API URL and API token from the notebook context, then build the
# jobs client from them. Either value is None when the context lacks it.
_notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
url = _notebook_context.apiUrl().getOrElse(None)
token = _notebook_context.apiToken().getOrElse(None)

client = jobsclient(token, url)

In [0]:
# Collect every currently running run whose elapsed time exceeds the
# long_running_duration threshold, printing one status line per run.
long_running_jobs = []

running_jobs_list = client.get_current_job_runs()

for active_run in running_jobs_list:
  elapsed_hours = client.get_job_duration(active_run)
  status_line = 'Job ID: {0} \t Name: {1} \t Duration: {2}'.format(
    active_run['job_id'], active_run['run_name'], str(elapsed_hours))
  print(status_line)
  if elapsed_hours > long_running_duration:
    # Keep the computed duration on the run so later cells can report it.
    active_run['hours_run'] = elapsed_hours
    long_running_jobs.append(active_run)

In [0]:
# Dump the full JSON of each offending run for inspection.
for offending_run in long_running_jobs:
  run_as_json = json.dumps(offending_run, indent=4)
  print(run_as_json)

In [0]:
# One summary line per offending run, then fail the notebook if any exist.
summary_template = "Job Creator: {0} \t Job Name: {1} \t Job ID: {2} \t Runtime (hours): {3}"
for offending_run in long_running_jobs:
  print(summary_template.format(
    offending_run['creator_user_name'],
    offending_run['run_name'],
    offending_run['job_id'],
    offending_run['hours_run'],
  ))

if long_running_jobs:
  raise ValueError('Long running jobs on the platform. Please check job output above for offending jobs')