# Update Clusters & Jobs ♻️

## Requirements
### Databricks
* A Databricks Workspace & Workspace Access Token
* At least one runnable cluster within the workspace


In [0]:
!pip install --upgrade databricks-sdk -q
!pip install loguru -q

In [0]:
dbutils.library.restartPython()

In [0]:
from pathlib import Path
import re

import pandas as pd
from loguru import logger

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.compute import (
    ClusterDetails,
    UpdateClusterResource,
    ListClustersFilterBy,
    ClusterSource,
    InitScriptInfo,
)
from databricks.sdk.service.jobs import Job, JobSettings, BaseJob, JobCluster

## Setup

| Parameter Name                  | Description                                                                                                        | Allowed Values                      |
| ------------------------------- | ------------------------------------------------------------------------------------------------------------------ | ----------------------------------- |
| `workspace_host`                | The **domain** of the Databricks workspace.                                                                        | `str`                               |
| `workspace_token`               | The **token** for accessing the Databricks Workspace API                                                           | `str`                               |
| `desired_runtime_version`       | The desired **Databricks Runtime Version** for the updated clusters/job clusters.                                  | `str` [Eg: `"15.4"`]                |
| `init_scripts_dir`              | Path to the common **directory with init scripts** on a Unity Catalog **Volume**                                   | `str`                               |
| `cluster_init_script_files`     | **Filenames** for the scripts to be used when initializing the **clusters**. Separate names with commas (`,`).     | `str` [Eg: `"S-154.sh, RE-154.sh"`] |
| `job_cluster_init_script_files` | **Filenames** for the scripts to be used when initializing the **job clusters**. Separate names with commas (`,`). | `str` [Eg: `"S-154.sh, RE-154.sh"`] |


In [0]:
# Start from a clean widget panel so only the parameters below are present.
dbutils.widgets.removeAll()

dbutils.widgets.text("workspace_host", "")
workspace_host: str = getArgument("workspace_host")

dbutils.widgets.text("workspace_token", "")
workspace_token: str = getArgument("workspace_token")

dbutils.widgets.text("desired_runtime_version", "")
desired_runtime_version: str = getArgument("desired_runtime_version")

dbutils.widgets.text("init_scripts_dir", "")
init_scripts_dir: str = getArgument("init_scripts_dir").strip()

# Blank entries are dropped: "".split(",") yields [""], which would otherwise
# look like one (empty) file name and slip through the checks below.
dbutils.widgets.text("cluster_init_script_files", "")
cluster_init_script_files: list[str] = [
    filename.strip()
    for filename in getArgument("cluster_init_script_files").split(",")
    if filename.strip()
]

dbutils.widgets.text("job_cluster_init_script_files", "")
job_cluster_init_script_files: list[str] = [
    filename.strip()
    for filename in getArgument("job_cluster_init_script_files").split(",")
    if filename.strip()
]

# Check that every parameter was supplied BEFORE touching the filesystem:
# Path("") resolves to the current working directory, so an empty
# `init_scripts_dir` would otherwise be accepted silently.
assert all(
    [
        workspace_host,
        workspace_token,
        desired_runtime_version,
        init_scripts_dir,
        cluster_init_script_files,
        job_cluster_init_script_files,
    ]
), "One or more required parameters for notebook functioning are missing"

# Validate that the directory exists (strict=True raises FileNotFoundError
# otherwise) and normalize the path.
init_scripts_dir = str(Path(init_scripts_dir).resolve(strict=True))
assert Path(init_scripts_dir).is_dir(), f"{init_scripts_dir=} is not a directory"


def _is_usable_init_script(file_name: str) -> bool:
    """Return True when `file_name` is a regular, non-empty file in `init_scripts_dir`."""
    script_path = Path(init_scripts_dir) / file_name
    return script_path.is_file() and script_path.stat().st_size > 0


# Validate that the files exist and are not empty
assert all(
    _is_usable_init_script(file_name) for file_name in cluster_init_script_files
), "One or more cluster init script files do not exist or are empty"

assert all(
    _is_usable_init_script(file_name) for file_name in job_cluster_init_script_files
), "One or more job cluster init script files do not exist or are empty"

# The token is deliberately not logged.
logger.info(f"{workspace_host=}")
logger.info(f"{desired_runtime_version=}")
logger.info(f"{init_scripts_dir=}")
logger.info(f"{cluster_init_script_files=}")
logger.info(f"{job_cluster_init_script_files=}")

In [0]:
# Client for the target workspace, built from the host and token widget values.
ws = WorkspaceClient(host=workspace_host, token=workspace_token)
# Log the workspace ID reported by the client.
logger.info(f"{ws.get_workspace_id()=}")

In [0]:
# Distinct runtime version labels offered by the workspace, taken as the
# first space-separated token of each runtime's display name and sorted.
valid_workspace_versions: list[str] = sorted(
    {
        runtime.name.split(" ")[0]
        for runtime in ws.clusters.spark_versions().versions
    }
)

logger.info(f"{len(valid_workspace_versions)=:,}")

# Stop early when the requested runtime is not one the workspace offers.
assert (
    desired_runtime_version in valid_workspace_versions
), f"Invalid {desired_runtime_version=}"

## Init Scripts

In [0]:
def make_init_scripts(init_script_files: list[str]):
    """Build one volume-backed ``InitScriptInfo`` per file name.

    Args:
        init_script_files: File names, each joined onto the notebook-level
            ``init_scripts_dir`` to form the script's destination path.

    Returns:
        A list of ``InitScriptInfo`` objects in the same order as the input.
    """
    scripts_root = Path(init_scripts_dir)
    script_infos = []
    for script_name in init_script_files:
        volume_spec = {"volumes": {"destination": str(scripts_root / script_name)}}
        script_infos.append(InitScriptInfo.from_dict(volume_spec))
    return script_infos


# Init script definitions applied to clusters and to job clusters respectively,
# built from the file names supplied through the widgets.
cluster_init_scripts = make_init_scripts(cluster_init_script_files)
job_cluster_init_scripts = make_init_scripts(job_cluster_init_script_files)

logger.info(f"{cluster_init_scripts=}")
logger.info(f"{job_cluster_init_scripts=}")

## Clusters

According to the SDK and REST API documentation:

- Clusters created as a result of a job cannot be updated via this endpoint. Only those created either via the `UI` or `API` can be changed.
- Clusters that are `RUNNING` at the time of the update are terminated and then restarted with the new configuration.

In [0]:
# Restrict the listing to clusters whose source is the API or the UI.
_cluster_source_filter = ListClustersFilterBy(
    cluster_sources=[ClusterSource.API, ClusterSource.UI]
)
clusters = list(ws.clusters.list(filter_by=_cluster_source_filter))

logger.info(f"Found {len(clusters)} clusters")

In [0]:
# Table with one row per listed cluster; left as the cell's final expression
# so the notebook renders it.
pd.DataFrame([cluster.as_dict() for cluster in clusters])

In [0]:
# A dictionary which maps each cluster ID to parameters for the cluster update method.
# The inner dictionary's keys are later joined to form the update mask and its
# items are passed as keyword arguments to `UpdateClusterResource`.
cluster_updates: dict[str, dict] = {}

### Updating the Databricks Runtime Version

The runtime version is the `cluster.spark_version` field.

In [0]:
# Every runtime version key the workspace offers; rewritten keys are validated
# against this set. Built directly from the SDK objects (no pandas round-trip),
# which also yields an empty set instead of a KeyError when nothing is returned.
valid_versions: set[str] = {
    version.key
    for version in (ws.clusters.spark_versions().versions or [])
    if version.key
}


def get_updated_spark_version_key(
    spark_version_key: str, desired_runtime_version: str
) -> str:
    """Swap the leading ``<major>.<minor>`` of a runtime key for the desired version.

    Everything after the numeric prefix (e.g. ``.x-photon-scala2.12``) is kept.

    Args:
        spark_version_key: Existing runtime key, e.g. ``"11.3.x-photon-scala2.12"``.
        desired_runtime_version: Replacement prefix, e.g. ``"15.4"``.

    Returns:
        The rewritten key. A key without a leading ``<major>.<minor>`` prefix is
        left unchanged and only validated.

    Raises:
        ValueError: If the resulting key is not in ``valid_versions``.
    """
    # `\d+` on both sides so single-digit majors ("9.1") and multi-digit minors
    # are replaced as a whole. The callable replacement inserts the desired
    # version literally instead of interpreting it as a regex template.
    new_spark_version = re.sub(
        r"^\d+\.\d+", lambda _match: desired_runtime_version, spark_version_key
    )

    if new_spark_version not in valid_versions:
        raise ValueError(f"Could not validate version '{new_spark_version}'")

    return new_spark_version


# Sanity check of the version-prefix rewrite. The helper also validates its
# result against `valid_versions`, so this raises ValueError (rather than
# failing the assert) when "15.4.x-photon-scala2.12" is not in that set.
assert (
    get_updated_spark_version_key("11.3.x-photon-scala2.12", "15.4")
    == "15.4.x-photon-scala2.12"
)

In [0]:
def update_cluster_spark_version(cluster: ClusterDetails):
    """Queue a ``spark_version`` change for ``cluster`` in ``cluster_updates``.

    Any updates already queued for the cluster are kept; only the
    ``spark_version`` entry is added or overwritten.
    """
    pending = dict(cluster_updates.get(cluster.cluster_id) or {})
    pending["spark_version"] = get_updated_spark_version_key(
        cluster.spark_version, desired_runtime_version
    )
    cluster_updates[cluster.cluster_id] = pending


# Queue the runtime change for every listed cluster.
for cluster in clusters:
    update_cluster_spark_version(cluster)

### Update the Init Scripts

In [0]:
def update_cluster_init_scripts(
    cluster: ClusterDetails, init_scripts: list[InitScriptInfo]
):
    """Queue an ``init_scripts`` change for ``cluster`` in ``cluster_updates``.

    Any updates already queued for the cluster are kept; only the
    ``init_scripts`` entry is added or overwritten.
    """
    pending = dict(cluster_updates.get(cluster.cluster_id) or {})
    pending["init_scripts"] = init_scripts
    cluster_updates[cluster.cluster_id] = pending


# Queue the same init scripts for every listed cluster.
for cluster in clusters:
    update_cluster_init_scripts(cluster, cluster_init_scripts)

### Execute the updates

In [0]:
clusters_to_update = clusters
names_for_clusters_that_failed_update = []

# ID of the cluster running this notebook; looked up once rather than on
# every loop iteration.
current_cluster_id = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")

for cluster in clusters_to_update:
    cluster_id = cluster.cluster_id

    # Do not update the cluster which is running this notebook
    # because it will force a restart
    if cluster_id == current_cluster_id:
        logger.info(
            f"Skipping cluster: '{cluster.cluster_name}', because it is running this notebook"
        )
        continue

    updates = cluster_updates.get(cluster_id)

    # Nothing queued (missing or empty) means there is nothing to send; an
    # empty dict would otherwise produce an empty update mask.
    if not updates:
        continue

    # The update mask names exactly the fields carried in `updates`.
    update_mask = ",".join(updates)

    try:
        ws.clusters.update(
            cluster_id=cluster_id,
            update_mask=update_mask,
            cluster=UpdateClusterResource(**updates),
        )
        logger.info(f"Updated cluster: '{cluster.cluster_name}'")
    except Exception as e:
        # Best effort: record the failure and keep going with the other clusters.
        logger.error(f"Failed to update cluster: '{cluster.cluster_name}'")
        logger.error(e)
        names_for_clusters_that_failed_update.append(cluster.cluster_name)


cluster_update_failures = len(names_for_clusters_that_failed_update)
cluster_count = len(clusters_to_update)

if cluster_update_failures > 0:
    cluster_update_failure_message = (
        f"Failed to update {cluster_update_failures} of {cluster_count} cluster(s)"
    )
    # Abort when a quarter or more of the clusters failed. The ratio uses the
    # same count that the message reports.
    if cluster_update_failures / cluster_count >= 0.25:
        raise Exception(cluster_update_failure_message)

    logger.warning(cluster_update_failure_message)
else:
    logger.info(f"Updated all {cluster_count} cluster(s)")

## Jobs

In [0]:
# Materialize the job listing (requested with expand_tasks=True) into a list
# so it can be counted and iterated more than once.
jobs = list(ws.jobs.list(expand_tasks=True))
logger.info(f"Found {len(jobs)} jobs")

In [0]:
# Table with one row per listed job; left as the cell's final expression so
# the notebook renders it.
pd.DataFrame([job.as_dict() for job in jobs])

In [0]:
def update_job_clusters_spark_version(job: Job | BaseJob) -> Job | BaseJob:
    """Return a copy of ``job`` whose job clusters use the desired runtime.

    The input job is not modified: the job and each of its job clusters are
    copied through an ``as_dict`` / ``from_dict`` round-trip.

    Args:
        job: The job whose ``settings.job_clusters`` should be rewritten.

    Returns:
        A new object of the same class as ``job``. A job without settings is
        returned as a plain copy; a job without job clusters gets an empty
        ``job_clusters`` list.

    Raises:
        ValueError: Propagated from ``get_updated_spark_version_key`` when a
            rewritten runtime key cannot be validated.
    """
    new_job = job.__class__.from_dict(job.as_dict())
    if job.settings is None or new_job.settings is None:
        return new_job

    job_clusters = []
    # `or []` tolerates jobs that define no job clusters at all.
    for jc in job.settings.job_clusters or []:
        njc = jc.__class__.from_dict(jc.as_dict())
        njc.new_cluster.spark_version = get_updated_spark_version_key(
            njc.new_cluster.spark_version, desired_runtime_version
        )
        job_clusters.append(njc)

    new_job.settings.job_clusters = job_clusters
    return new_job

In [0]:
def update_job_clusters_init_scripts(
    job: Job | BaseJob, init_scripts: list[InitScriptInfo]
) -> Job | BaseJob:
    """Return a copy of ``job`` whose job clusters use ``init_scripts``.

    The input job is not modified: the job and each of its job clusters are
    copied through an ``as_dict`` / ``from_dict`` round-trip.

    Args:
        job: The job whose ``settings.job_clusters`` should be rewritten.
        init_scripts: Init scripts assigned to every job cluster; the same
            list object is shared by all of them.

    Returns:
        A new object of the same class as ``job``. A job without settings is
        returned as a plain copy; a job without job clusters gets an empty
        ``job_clusters`` list.
    """
    new_job = job.__class__.from_dict(job.as_dict())
    if job.settings is None or new_job.settings is None:
        return new_job

    job_clusters = []
    # `or []` tolerates jobs that define no job clusters at all.
    for jc in job.settings.job_clusters or []:
        njc = jc.__class__.from_dict(jc.as_dict())
        njc.new_cluster.init_scripts = init_scripts
        job_clusters.append(njc)

    new_job.settings.job_clusters = job_clusters
    return new_job

In [0]:
names_for_jobs_that_failed_update = []

jobs_to_update = jobs

# Number of jobs for which an update was actually attempted.
job_count = 0

for job in jobs_to_update:
    # Jobs without job clusters have nothing to change here; skipping them
    # avoids both a pointless API call and iterating over a missing list.
    if job.settings is None or not job.settings.job_clusters:
        logger.info(f"Skipping job {job.job_id}: it defines no job clusters")
        continue

    job_count += 1

    try:
        # Kept inside the `try` so that a job whose runtime key cannot be
        # rewritten is recorded as a failure instead of aborting the loop
        # after some jobs have already been updated.
        njob = update_job_clusters_spark_version(job)
        njob = update_job_clusters_init_scripts(njob, job_cluster_init_scripts)

        # Send ONLY the job clusters: a partial update replaces the top-level
        # fields it is given, so unrelated settings are left untouched.
        new_settings = JobSettings(job_clusters=njob.settings.job_clusters)

        ws.jobs.update(job_id=job.job_id, new_settings=new_settings)
        logger.info(f"Updated job: '{job.settings.name}'")
    except Exception as e:
        # Best effort: record the failure and keep going with the other jobs.
        logger.error(f"Failed to update job: '{job.settings.name}'")
        logger.error(e)
        names_for_jobs_that_failed_update.append(job.settings.name)

job_update_failures = len(names_for_jobs_that_failed_update)

if job_update_failures > 0:
    job_update_failure_message = (
        f"Failed to update {job_update_failures} of {job_count} job(s)"
    )
    # Abort when a quarter or more of the attempted updates failed. The ratio
    # uses the same count that the message reports.
    if job_update_failures / job_count >= 0.25:
        raise Exception(job_update_failure_message)

    logger.warning(job_update_failure_message)
else:
    logger.info(f"Updated all {job_count} job(s)")