# Ray on Spark - Cluster Setup Recommender

### Overview
* This script recommends configuration values for the `ray.util.spark.setup_ray_cluster()` command used to launch a "Ray on Spark" cluster.
* Attach it to any Classic All-Purpose Cluster during development; after testing, you can carry the recommendations over to an automated Job Cluster.
* *As of August 2025, Ray on Spark will not work on Serverless clusters*
* The script also verifies baseline cluster settings (such as runtime version and security mode).

### Steps
* Attach this script to a Classic All-Purpose cluster
* Click "Run all"
* Copy the [Ray on Spark setup command](https://docs.databricks.com/aws/en/machine-learning/ray/ray-create) printed at the end, paste it into a different notebook, and run it there (on the same cluster).
* Continue to modify this baseline setup as your workload evolves

In [0]:
%pip install -qqq --upgrade databricks-sdk 
%restart_python

In [0]:
from databricks.sdk import WorkspaceClient
from pprint import pprint
import re

def get_workspace_client() -> WorkspaceClient:
    """
    Build a WorkspaceClient from the current Databricks notebook context.

    The API URL and API token are both read from the notebook context; each
    one falls back to None when the context does not provide it.
    """
    notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
    api_url = notebook_context.apiUrl().getOrElse(None)
    api_token = notebook_context.apiToken().getOrElse(None)
    return WorkspaceClient(host=api_url, token=api_token)
  
def get_cluster_id() -> str:
    """
    Look up the ID of the cluster the current notebook is attached to.

    The value is read from the Databricks notebook context.
    """
    notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
    return notebook_context.clusterId().get()
  
# Workspace client authenticated from the current notebook context.
w = get_workspace_client()

# Details of the cluster this notebook is attached to; the checks below read from this object.
current_cluster = w.clusters.get(get_cluster_id())
# pprint(current_cluster)

In [0]:
def check_security_mode(current_cluster):
    """
    Verify that the cluster's data_security_mode can be used for Ray on Spark.

    The mode is converted to a string and searched for any of the allowed mode
    names, so a longer string that merely contains an allowed name also passes.
    Prints a confirmation on success.

    Raises:
        ValueError: if none of the allowed mode names occurs in the mode string.
    """
    allowed_modes = ["USER_ISOLATION", "SINGLE_USER", "DEDICATED", "NONE"]
    data_security_mode_str = str(current_cluster.data_security_mode)

    # Alternation of the allowed names, searched anywhere in the mode string.
    allowed_pattern = "(" + "|".join(allowed_modes) + ")"
    if re.search(allowed_pattern, data_security_mode_str) is None:
        raise ValueError(
            f"data_security_mode '{data_security_mode_str}' is not allowed for Ray on Spark clusters. Allowed: {allowed_modes}. See docs: https://docs.databricks.com/aws/en/machine-learning/ray/#limitations"
        )

    print(f"Clusters with data_security_mode '{data_security_mode_str}' can be used to create Ray on Spark clusters.")

# Run the security-mode check on the attached cluster; a ValueError here stops the cell.
check_security_mode(current_cluster=current_cluster)

In [0]:
def check_runtime_version(current_cluster):
    """
    Check that the cluster's spark_version meets the minimum runtime for Ray on Spark.

    The leading "<major>.<minor>" of spark_version is compared to the minimum
    version component by component, so that e.g. "12.10" is treated as newer
    than "12.2". Prints a confirmation when the version is acceptable and an
    extra recommendation when the cluster does not appear to use the ML Runtime.

    Args:
        current_cluster: object exposing `spark_version` and, optionally,
            `use_ml_runtime`.

    Raises:
        ValueError: if spark_version does not start with "<major>.<minor>" or
            is older than the minimum version.
    """
    allowed_starting_version = ["12.2"]
    spark_version = str(current_cluster.spark_version)

    # Extract numeric parts from spark_version (e.g., "12.2.x-cpu-ml-scala2.12" -> (12, 2))
    version_match = re.match(r"(\d+)\.(\d+)", spark_version)
    if not version_match:
        raise ValueError(f"Could not parse spark_version: {current_cluster.spark_version}")

    # Compare as integer tuples; a float comparison would read "12.10" as 12.1.
    current_version = tuple(int(part) for part in version_match.groups())
    minimum_version = tuple(int(part) for part in allowed_starting_version[0].split("."))

    if current_version < minimum_version:
        raise ValueError(
            f"Cluster spark_version '{current_cluster.spark_version}' is less than the required version {allowed_starting_version[0]}"
        )
    print(f"Cluster Runtime '{current_cluster.spark_version}' can be used to run Ray on Spark. Consider upgrading to latest stable LTS ML Runtime for the best performance. ")

    # The use_ml_runtime flag may be unset even on an ML image, so the "-ml-"
    # marker in the version string (as in "12.2.x-cpu-ml-scala2.12") also counts.
    is_ml_runtime = bool(getattr(current_cluster, "use_ml_runtime", False)) or "-ml-" in spark_version
    if not is_ml_runtime:
        print("Your cluster is not using ML Runtime. Recommend upgrading to most recent LTS Machine Learning Runtime where Ray and other dependencies are pre-installed")
    
# Run the runtime-version check on the attached cluster; a ValueError here stops the cell.
check_runtime_version(current_cluster=current_cluster)

In [0]:
def check_single_node(current_cluster):
    """
    Reject single-node clusters, since this setup targets multi-node clusters.

    A cluster is treated as single-node when any of the following holds:
      * its `is_single_node` attribute is truthy,
      * its `spark_conf` sets "spark.databricks.cluster.profile" to "singleNode",
      * it has no autoscale settings and a fixed `num_workers` of 0.

    Missing attributes are treated as "not single-node".

    Raises:
        ValueError: if the cluster is detected as single-node.
    """
    flagged_single_node = bool(getattr(current_cluster, "is_single_node", False))

    # spark_conf may be absent or None; treat both as an empty mapping.
    spark_conf = getattr(current_cluster, "spark_conf", None) or {}
    has_single_node_profile = spark_conf.get("spark.databricks.cluster.profile") == "singleNode"

    # A fixed-size cluster with zero workers has only a driver node.
    has_no_workers = (
        getattr(current_cluster, "autoscale", None) is None
        and getattr(current_cluster, "num_workers", None) == 0
    )

    if flagged_single_node or has_single_node_profile or has_no_workers:
        raise ValueError("This script is intended to determine setup for a multi-node cluster to use Ray on Spark. This is a single-node cluster. To use ray, just run ray.init(). See Ray docs for more info.")

# Run the single-node check on the attached cluster; a ValueError here stops the cell.
check_single_node(current_cluster=current_cluster)

In [0]:
def check_gpu(current_cluster):
    """
    Report whether the cluster's worker node type is a known GPU instance type.

    The cloud is chosen as the first of aws/azure/gcp whose `<cloud>_attributes`
    attribute is present and not None. The worker `node_type_id` is then
    compared, case-insensitively and by substring, with that cloud's list of
    GPU instance types.

    Args:
        current_cluster: object exposing `node_type_id` and the
            `aws_attributes` / `azure_attributes` / `gcp_attributes` fields.

    Returns:
        True if the node type matches a listed GPU instance type, otherwise
        False (including when the cloud or node type cannot be determined).
    """
    gpu_instance_types = {
        "aws": ["p5", "p4", "g6e", "g6", "g5", "g4dn", "p3"],
        "azure": [
            "Standard_NC40ads_H100_v5",
            "Standard_NC80adis_H100_v5",
            "Standard_NC24ads_A100_v4",
            "Standard_NC48ads_A100_v4",
            "Standard_NC96ads_A100_v4",
            "Standard_ND96asr_v4",
            "Standard_NV36ads_A10_v5",
            "Standard_NV36adms_A10_v5",
            "Standard_NV72ads_A10_v5",
            "Standard_NC4as_T4_v3",
            "Standard_NC8as_T4_v3",
            "Standard_NC16as_T4_v3",
            "Standard_NC64as_T4_v3",
            "Standard_NC6s_v3",
            "Standard_NC12s_v3",
            "Standard_NC24s_v3",
            "Standard_NC24rs_v3",
        ],
        "gcp": [
            "a2-ultragpu-8g",
            "a2-highgpu-1g",
            "a2-highgpu-2g",
            "a2-highgpu-4g",
            "a2-megagpu-16g",
            "g2-standard-8",
        ],
    }

    # The attribute can exist but be None for clouds the cluster is not on, so
    # test the value rather than mere presence of the attribute.
    current_cloud = next(
        (
            cloud
            for cloud in ("aws", "azure", "gcp")
            if getattr(current_cluster, f"{cloud}_attributes", None) is not None
        ),
        None,
    )

    node_type_id = getattr(current_cluster, "node_type_id", None)
    if current_cloud is None or node_type_id is None:
        return False

    node_type = str(node_type_id).lower()
    return any(gpu_type.lower() in node_type for gpu_type in gpu_instance_types[current_cloud])

# Evaluate GPU detection for the attached cluster; the boolean result is the cell's last expression.
check_gpu(current_cluster=current_cluster)

In [0]:
# Builds and prints a recommended `setup_ray_cluster(...)` call for the attached
# cluster, using `current_cluster`, the notebook's `spark` session and `check_gpu`.

# TODO: Change to function input
# Values above 0.0 only trigger the advisory message printed at the end.
spark_share = 0.0

# Header of the recommendation text; arguments are appended step by step below.
setup_cmd = """
>>> Use setup command >>>
setup_ray_cluster(
"""

print("Observations for setup script recommendation:")

# STEP 1: Determine min and max worker nodes
if current_cluster.autoscale:
    ## Autoscaling = TRUE
    print(" - Autoscaling cluster")
    min_workers = worker_nodes = current_cluster.autoscale.min_workers
    max_workers = current_cluster.autoscale.max_workers
else:
    ## Autoscaling = FALSE
    print(" - Non-Autoscaling cluster")
    min_workers = max_workers = worker_nodes = current_cluster.num_workers

setup_cmd += f"  min_worker_nodes={min_workers},\n  max_worker_nodes={max_workers},\n"


# STEP 2: Determine if Driver and Worker nodes match (homogenous cluster)
# Both branches size per-node CPUs from the same worker count: the cluster's
# reported num_workers when it is set and non-zero, otherwise the count from STEP 1.
# NOTE(review): assumes num_workers tracks the workers behind cluster_cores and
# defaultParallelism on autoscaling clusters - confirm.
active_workers = current_cluster.num_workers or worker_nodes

# An unset driver node type is treated as "same as the worker node type".
driver_node_type = current_cluster.driver_node_type_id or current_cluster.node_type_id
worker_driver_match = driver_node_type == current_cluster.node_type_id

if not worker_driver_match:
    ## Worker Driver Match = FALSE
    print(" - Heterogenous cluster, Driver and Workers are different instance types")
    if not active_workers:
        # Fail with a clear message instead of a ZeroDivisionError/TypeError below.
        raise ValueError("Cannot compute CPUs per worker node: the cluster reports no worker nodes.")
    # Cores not counted in Spark's default parallelism are attributed to the driver.
    driver_cores = int(current_cluster.cluster_cores - spark.sparkContext.defaultParallelism)
    worker_cores = int(spark.sparkContext.defaultParallelism / active_workers)

    setup_cmd += f"  num_cpus_worker_node={worker_cores},\n  num_cpus_head_node={driver_cores},\n"
else:
    ## Homogenous cluster
    print(" - Homogenous cluster, Driver and Workers are same instance type:")
    worker_nodes = active_workers
    # Total cores split evenly across the workers plus the driver (+1).
    cores_per_node = int(current_cluster.cluster_cores / ((worker_nodes or 0) + 1))

    setup_cmd += f"  num_cpus_worker_node={cores_per_node},\n  num_cpus_head_node={cores_per_node},\n"


# STEP 3: Determine if GPUs onboard
# TODO: Update and test this
if check_gpu(current_cluster):
    print(" - GPU cluster; please manually configure GPUs per worker node")

    # Placeholder GPU counts; the message above asks the user to adjust them.
    setup_cmd += "  num_gpus_worker_node=1,\n  num_gpus_head_node=0,\n"


setup_cmd += """  head_node_options={
      'dashboard_port': 9999,
      'include_dashboard':True,
    }
)
"""
print(setup_cmd)

# STEP 4: Determine if Spark Share is enabled
if spark_share > 0.0:
    print(" ^ Determine how many resources to give to Spark, then decrease the values of num_cpus_worker_node and num_cpus_head_node to reserve resources for Spark")



## End 
Copy the `setup_ray_cluster()` command printed after running the previous cell.