diff --git a/README.md b/README.md index da274ce..a56aa91 100644 --- a/README.md +++ b/README.md @@ -19,11 +19,10 @@ The following backends are implemented: `dagster-ray` is tested across multiple version combinations of components such as `ray`, `dagster`, `KubeRay Operator`, and `Python`. -`dagster-ray` integrates with `Dagster Plus` out of the box by using environment variables such as `DAGSTER_CLOUD_DEPLOYMENT_NAME` and tags such as `dagster/user`. +`dagster-ray` integrates with [Dagster+](https://dagster.io/plus) out of the box. Documentation can be found below. - > [!NOTE] > This project is in early development. Contributions are very welcome! See the [Development](#development) section below. @@ -116,6 +115,8 @@ definitions = Definitions( This backend requires a Kubernetes cluster with the `KubeRay Operator` installed. +Integrates with [Dagster+](https://dagster.io/plus) by injecting environment variables such as `DAGSTER_CLOUD_DEPLOYMENT_NAME` and tags such as `dagster/user` into default configuration values and `RayCluster` labels. + The public objects can be imported from `dagster_ray.kuberay` module. ### Resources @@ -168,6 +169,7 @@ ray_cluster = KubeRayCluster( "replicas": 2, "minReplicas": 1, "maxReplicas": 10, + # ... } ], ) diff --git a/dagster_ray/kuberay/configs.py b/dagster_ray/kuberay/configs.py index f170504..e7a2e81 100644 --- a/dagster_ray/kuberay/configs.py +++ b/dagster_ray/kuberay/configs.py @@ -17,6 +17,7 @@ } DEFAULT_HEAD_GROUP_SPEC = { "serviceType": "ClusterIP", + "rayStartParams": {}, "metadata": { "labels": {}, "annotations": {}, diff --git a/dagster_ray/kuberay/resources.py b/dagster_ray/kuberay/resources.py index fdb7a85..18573d6 100644 --- a/dagster_ray/kuberay/resources.py +++ b/dagster_ray/kuberay/resources.py @@ -1,5 +1,6 @@ import contextlib import hashlib +import os import random import re import string @@ -110,6 +111,8 @@ def yield_for_execution(self, context: InitResourceContext) -> Generator[Self, N assert context.log is not None assert context.dagster_run is not None + self.api.setup_for_execution(context) + self._cluster_name = self._get_ray_cluster_step_name(context) # self._host = f"{self.cluster_name}-head-svc.{self.namespace}.svc.cluster.local" @@ -125,7 +128,12 @@ def yield_for_execution(self, context: InitResourceContext) -> Generator[Self, N labels=self._get_labels(context), ) - self.api.kuberay.create_ray_cluster(body=cluster_body, k8s_namespace=self.namespace) + resource = self.api.kuberay.create_ray_cluster(body=cluster_body, k8s_namespace=self.namespace) + if resource is None: + raise Exception( + f"Couldn't create RayCluster {self.namespace}/{self.cluster_name}! Reason logged above." + ) + context.log.info( f"Created RayCluster {self.namespace}/{self.cluster_name}. Waiting for it to become ready..." ) @@ -154,7 +162,7 @@ def yield_for_execution(self, context: InitResourceContext) -> Generator[Self, N yield self except Exception as e: - context.log.critical("Couldn't create or connect to RayCluster!") + context.log.critical(f"Couldn't create or connect to RayCluster {self.namespace}/{self.cluster_name}!") self._maybe_cleanup_raycluster(context) raise e @@ -175,6 +183,12 @@ def _get_labels(self, context: InitResourceContext) -> Dict[str, str]: if context.dagster_run.tags.get("user"): labels["dagster.io/user"] = context.dagster_run.tags["user"] + if os.getenv("DAGSTER_CLOUD_GIT_BRANCH"): + labels["dagster.io/git-branch"] = os.environ["DAGSTER_CLOUD_GIT_BRANCH"] + + if os.getenv("DAGSTER_CLOUD_GIT_SHA"): + labels["dagster.io/git-sha"] = os.environ["DAGSTER_CLOUD_GIT_SHA"] + return labels def _build_raycluster( @@ -187,6 +201,9 @@ def _build_raycluster( """ # TODO: inject self.redis_port and self.dashboard_port into the RayCluster configuration + labels = labels or {} + assert isinstance(labels, dict) + image = self.ray_cluster.image or image head_group_spec = self.ray_cluster.head_group_spec.copy() worker_group_specs = self.ray_cluster.worker_group_specs.copy() @@ -197,12 +214,12 @@ def update_group_spec(group_spec: Dict[str, Any]): group_spec["template"]["spec"]["containers"][0]["image"] = image if group_spec.get("metadata") is None: - group_spec["metadata"] = {"labels": labels or {}} + group_spec["metadata"] = {"labels": labels} else: if group_spec["metadata"].get("labels") is None: - group_spec["metadata"]["labels"] = labels or {} + group_spec["metadata"]["labels"] = labels else: - group_spec["metadata"]["labels"].update(labels or {}) + group_spec["metadata"]["labels"].update(labels) update_group_spec(head_group_spec) for worker_group_spec in worker_group_specs: diff --git a/tests/test_kuberay.py b/tests/test_kuberay.py index 4e7093c..f98c512 100644 --- a/tests/test_kuberay.py +++ b/tests/test_kuberay.py @@ -17,6 +17,7 @@ from dagster_ray.kuberay import KubeRayAPI, KubeRayCluster, cleanup_kuberay_clusters from dagster_ray.kuberay.configs import DEFAULT_HEAD_GROUP_SPEC, DEFAULT_WORKER_GROUP_SPECS, RayClusterConfig from dagster_ray.kuberay.ops import CleanupKuberayClustersConfig +from dagster_ray.kuberay.ray_cluster_api import PatchedRayClusterApi from tests import ROOT_DIR @@ -229,11 +230,18 @@ def my_asset(context: AssetExecutionContext, ray_cluster: RayResource) -> None: resources={"ray_cluster": ray_cluster_resource}, ) + kuberay_api = PatchedRayClusterApi(config_file=str(k8s_with_raycluster.kubeconfig)) + # make sure the RayCluster is cleaned up - assert not ray_cluster_resource.api.kuberay.list_ray_clusters( - k8s_namespace=ray_cluster_resource.namespace, label_selector=f"dagster.io/run_id={result.run_id}" - )["items"] + assert ( + len( + kuberay_api.list_ray_clusters( + k8s_namespace=ray_cluster_resource.namespace, label_selector=f"dagster.io/run_id={result.run_id}" + )["items"] + ) + == 0 + ) def test_kuberay_cleanup_job( @@ -249,9 +257,11 @@ def my_asset(ray_cluster: RayResource) -> None: resources={"ray_cluster": ray_cluster_resource_skip_cleanup}, ) + kuberay_api = PatchedRayClusterApi(config_file=str(k8s_with_raycluster.kubeconfig)) + assert ( len( - ray_cluster_resource_skip_cleanup.api.kuberay.list_ray_clusters( + kuberay_api.list_ray_clusters( k8s_namespace=ray_cluster_resource_skip_cleanup.namespace, label_selector=f"dagster.io/run_id={result.run_id}", )["items"] @@ -269,6 +279,6 @@ def my_asset(ray_cluster: RayResource) -> None: ) ) - assert not ray_cluster_resource_skip_cleanup.api.kuberay.list_ray_clusters( + assert not kuberay_api.list_ray_clusters( k8s_namespace=ray_cluster_resource_skip_cleanup.namespace, label_selector=f"dagster.io/run_id={result.run_id}" )["items"]