Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgafni committed Apr 19, 2024
1 parent 495bd81 commit 43f4038
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 12 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ The following backends are implemented:

`dagster-ray` is tested across multiple version combinations of components such as `ray`, `dagster`, `KubeRay Operator`, and `Python`.

`dagster-ray` integrates with `Dagster Plus` out of the box by using environment variables such as `DAGSTER_CLOUD_DEPLOYMENT_NAME` and tags such as `dagster/user`.
`dagster-ray` integrates with [Dagster+](https://dagster.io/plus) out of the box.

Documentation can be found below.


> [!NOTE]
> This project is in early development. Contributions are very welcome! See the [Development](#development) section below.
Expand Down Expand Up @@ -116,6 +115,8 @@ definitions = Definitions(

This backend requires a Kubernetes cluster with the `KubeRay Operator` installed.

This backend integrates with [Dagster+](https://dagster.io/plus) by injecting environment variables such as `DAGSTER_CLOUD_DEPLOYMENT_NAME` and tags such as `dagster/user` into default configuration values and `RayCluster` labels.

The public objects can be imported from the `dagster_ray.kuberay` module.

### Resources
Expand Down Expand Up @@ -168,6 +169,7 @@ ray_cluster = KubeRayCluster(
"replicas": 2,
"minReplicas": 1,
"maxReplicas": 10,
# ...
}
],
)
Expand Down
1 change: 1 addition & 0 deletions dagster_ray/kuberay/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
}
DEFAULT_HEAD_GROUP_SPEC = {
"serviceType": "ClusterIP",
"rayStartParams": {},
"metadata": {
"labels": {},
"annotations": {},
Expand Down
27 changes: 22 additions & 5 deletions dagster_ray/kuberay/resources.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import hashlib
import os
import random
import re
import string
Expand Down Expand Up @@ -110,6 +111,8 @@ def yield_for_execution(self, context: InitResourceContext) -> Generator[Self, N
assert context.log is not None
assert context.dagster_run is not None

self.api.setup_for_execution(context)

self._cluster_name = self._get_ray_cluster_step_name(context)

# self._host = f"{self.cluster_name}-head-svc.{self.namespace}.svc.cluster.local"
Expand All @@ -125,7 +128,12 @@ def yield_for_execution(self, context: InitResourceContext) -> Generator[Self, N
labels=self._get_labels(context),
)

self.api.kuberay.create_ray_cluster(body=cluster_body, k8s_namespace=self.namespace)
resource = self.api.kuberay.create_ray_cluster(body=cluster_body, k8s_namespace=self.namespace)
if resource is None:
raise Exception(
f"Couldn't create RayCluster {self.namespace}/{self.cluster_name}! Reason logged above."
)

context.log.info(
f"Created RayCluster {self.namespace}/{self.cluster_name}. Waiting for it to become ready..."
)
Expand Down Expand Up @@ -154,7 +162,7 @@ def yield_for_execution(self, context: InitResourceContext) -> Generator[Self, N

yield self
except Exception as e:
context.log.critical("Couldn't create or connect to RayCluster!")
context.log.critical(f"Couldn't create or connect to RayCluster {self.namespace}/{self.cluster_name}!")
self._maybe_cleanup_raycluster(context)
raise e

Expand All @@ -175,6 +183,12 @@ def _get_labels(self, context: InitResourceContext) -> Dict[str, str]:
if context.dagster_run.tags.get("user"):
labels["dagster.io/user"] = context.dagster_run.tags["user"]

if os.getenv("DAGSTER_CLOUD_GIT_BRANCH"):
labels["dagster.io/git-branch"] = os.environ["DAGSTER_CLOUD_GIT_BRANCH"]

if os.getenv("DAGSTER_CLOUD_GIT_SHA"):
labels["dagster.io/git-sha"] = os.environ["DAGSTER_CLOUD_GIT_SHA"]

return labels

def _build_raycluster(
Expand All @@ -187,6 +201,9 @@ def _build_raycluster(
"""
# TODO: inject self.redis_port and self.dashboard_port into the RayCluster configuration

labels = labels or {}
assert isinstance(labels, dict)

image = self.ray_cluster.image or image
head_group_spec = self.ray_cluster.head_group_spec.copy()
worker_group_specs = self.ray_cluster.worker_group_specs.copy()
Expand All @@ -197,12 +214,12 @@ def update_group_spec(group_spec: Dict[str, Any]):
group_spec["template"]["spec"]["containers"][0]["image"] = image

if group_spec.get("metadata") is None:
group_spec["metadata"] = {"labels": labels or {}}
group_spec["metadata"] = {"labels": labels}
else:
if group_spec["metadata"].get("labels") is None:
group_spec["metadata"]["labels"] = labels or {}
group_spec["metadata"]["labels"] = labels
else:
group_spec["metadata"]["labels"].update(labels or {})
group_spec["metadata"]["labels"].update(labels)

update_group_spec(head_group_spec)
for worker_group_spec in worker_group_specs:
Expand Down
20 changes: 15 additions & 5 deletions tests/test_kuberay.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from dagster_ray.kuberay import KubeRayAPI, KubeRayCluster, cleanup_kuberay_clusters
from dagster_ray.kuberay.configs import DEFAULT_HEAD_GROUP_SPEC, DEFAULT_WORKER_GROUP_SPECS, RayClusterConfig
from dagster_ray.kuberay.ops import CleanupKuberayClustersConfig
from dagster_ray.kuberay.ray_cluster_api import PatchedRayClusterApi
from tests import ROOT_DIR


Expand Down Expand Up @@ -229,11 +230,18 @@ def my_asset(context: AssetExecutionContext, ray_cluster: RayResource) -> None:
resources={"ray_cluster": ray_cluster_resource},
)

kuberay_api = PatchedRayClusterApi(config_file=str(k8s_with_raycluster.kubeconfig))

# make sure the RayCluster is cleaned up

assert not ray_cluster_resource.api.kuberay.list_ray_clusters(
k8s_namespace=ray_cluster_resource.namespace, label_selector=f"dagster.io/run_id={result.run_id}"
)["items"]
assert (
len(
kuberay_api.list_ray_clusters(
k8s_namespace=ray_cluster_resource.namespace, label_selector=f"dagster.io/run_id={result.run_id}"
)["items"]
)
== 0
)


def test_kuberay_cleanup_job(
Expand All @@ -249,9 +257,11 @@ def my_asset(ray_cluster: RayResource) -> None:
resources={"ray_cluster": ray_cluster_resource_skip_cleanup},
)

kuberay_api = PatchedRayClusterApi(config_file=str(k8s_with_raycluster.kubeconfig))

assert (
len(
ray_cluster_resource_skip_cleanup.api.kuberay.list_ray_clusters(
kuberay_api.list_ray_clusters(
k8s_namespace=ray_cluster_resource_skip_cleanup.namespace,
label_selector=f"dagster.io/run_id={result.run_id}",
)["items"]
Expand All @@ -269,6 +279,6 @@ def my_asset(ray_cluster: RayResource) -> None:
)
)

assert not ray_cluster_resource_skip_cleanup.api.kuberay.list_ray_clusters(
assert not kuberay_api.list_ray_clusters(
k8s_namespace=ray_cluster_resource_skip_cleanup.namespace, label_selector=f"dagster.io/run_id={result.run_id}"
)["items"]

0 comments on commit 43f4038

Please sign in to comment.