diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 362ad2c..8dc9347 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,4 +1,4 @@ -name: QA & Tests +name: CI on: workflow_dispatch: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f937fa8..591924a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,6 +31,12 @@ repos: entry: ruff format language: system pass_filenames: false + - id: format-docs + name: format docs + entry: blacken-docs + language: system + pass_filenames: true + files: "\\.(md)$" - id: pyright name: pyright entry: pyright . diff --git a/README.md b/README.md index e384121..a56aa91 100644 --- a/README.md +++ b/README.md @@ -4,40 +4,214 @@ [![image](https://img.shields.io/pypi/v/dagster-ray.svg)](https://pypi.python.org/pypi/dagster-ray) [![image](https://img.shields.io/pypi/l/dagster-ray.svg)](https://pypi.python.org/pypi/dagster-ray) [![image](https://img.shields.io/pypi/pyversions/dagster-ray.svg)](https://pypi.python.org/pypi/dagster-ray) -[![CI](https://github.com/danielgafni/dagster-ray/actions/workflows/ci.yml/badge.svg)](https://github.com/danielgafni/dagster-ray/actions/workflows/ci.yml) +[![CI](https://github.com/danielgafni/dagster-ray/actions/workflows/CI.yml/badge.svg)](https://github.com/danielgafni/dagster-ray/actions/workflows/CI.yml) [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) [![Checked with pyright](https://microsoft.github.io/pyright/img/pyright_badge.svg)](https://microsoft.github.io/pyright/) [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) [Ray](https://github.com/ray-project/ray) integration library for [Dagster](https://github.com/dagster-io/dagster). -`dagster-ray` allows you to run Ray computations in your Dagster pipelines. The following backends are implemented: +`dagster-ray` allows running Ray computations in Dagster pipelines. It provides Dagster abstractions for multiple backends, the most important being `Resource`, along with helper `@op`s and `@schedule`s. + +The following backends are implemented: +- local - `KubeRay` (kubernetes) `dagster-ray` is tested across multiple version combinations of components such as `ray`, `dagster`, `KubeRay Operator`, and `Python`. -# Features +`dagster-ray` integrates with [Dagster+](https://dagster.io/plus) out of the box. + +Documentation can be found below. + +> [!NOTE] +> This project is in early development. Contributions are very welcome! See the [Development](#development) section below. + +# Backends + +`dagster-ray` provides a `RayResource` class, which does not implement any specific backend. +It defines the common interface for all `Ray` resources. +It can be used for type annotations in your `@op` and `@asset` definitions. + +Examples: + +```python +from dagster import asset +from dagster_ray import RayResource +import ray -## Resources -### `LocalRay` +@asset +def my_asset( + ray_cluster: RayResource, # RayResource is only used as a type annotation +): + return ray.get(ray.put(42)) +``` + +The other resources below are the actual backends that implement the `RayResource` interface. + +## Local + +These resources can be used for development and testing purposes. +They provide the same interface as the other `*Ray` resources, but don't require any external infrastructure.
+ +The public objects can be imported from `dagster_ray.local` module. + +### Resources + +#### `LocalRay` A dummy resource which is useful for testing and development. It doesn't do anything, but provides the same interface as the other `*Ray` resources. -### `KubeRayCluster` +Examples: + + +Using the `LocalRay` resource + +```python +from dagster import asset, Definitions +from dagster_ray import RayResource +from dagster_ray.local import LocalRay +import ray + + +@asset +def my_asset( + ray_cluster: RayResource, # RayResource is only used as a type annotation +): # this type annotation only defines the interface + return ray.get(ray.put(42)) + + +definitions = Definitions(resources={"ray_cluster": LocalRay()}, assets=[my_asset]) +``` + +Conditionally using the `LocalRay` resource in development and `KubeRayCluster` in production: + +```python +from dagster import asset, Definitions +from dagster_ray import RayResource +from dagster_ray.local import LocalRay +from dagster_ray.kuberay import KubeRayCluster +import ray + + +@asset +def my_asset( + ray_cluster: RayResource, # RayResource is only used as a type annotation +): # this type annotation only defines the interface + return ray.get(ray.put(42)) + + +IN_K8s = ... + + +definitions = Definitions( + resources={"ray_cluster": KubeRayCluster() if IN_K8s else LocalRay()}, + assets=[my_asset], +) +``` + +## KubeRay + +This backend requires a Kubernetes cluster with the `KubeRay Operator` installed. + +Integrates with [Dagster+](https://dagster.io/plus) by injecting environment variables such as `DAGSTER_CLOUD_DEPLOYMENT_NAME` and tags such as `dagster/user` into default configuration values and `RayCluster` labels. -`KubeRayCluster` can be used for running Ray computations on Kubernetes. Requires `KubeRay Operator` to be installed the Kubernetes cluster. +The public objects can be imported from `dagster_ray.kuberay` module. + +### Resources + +#### `KubeRayCluster` + +`KubeRayCluster` can be used for running Ray computations on Kubernetes. When added as resource dependency to an `@op/@asset`, the `KubeRayCluster`: - Starts a dedicated `RayCluster` for it - - connects `ray.init()` to the cluster (if `ray` is installed) - - tears down the cluster after the step is executed + - Connects to the cluster in client mode with `ray.init()` (unless `skip_init` is set to `True`) + - Tears down the cluster after the step is executed (unless `skip_cleanup` is set to `True`) + +`RayCluster` comes with minimal default configuration, matching `KubeRay` defaults. + +Examples: + +Basic usage (will create a single-node, non-scaling `RayCluster`): + +```python +from dagster import asset, Definitions +from dagster_ray import RayResource +from dagster_ray.kuberay import KubeRayCluster +import ray + + +@asset +def my_asset( + ray_cluster: RayResource, # RayResource is only used as a type annotation +): # this type annotation only defines the interface + return ray.get(ray.put(42)) + + +definitions = Definitions( + resources={"ray_cluster": KubeRayCluster()}, assets=[my_asset] +) +``` + +Larger cluster with auto-scaling enabled: + +```python +from dagster_ray.kuberay import KubeRayCluster, RayClusterConfig + +ray_cluster = KubeRayCluster( + ray_cluster=RayClusterConfig( + enable_in_tree_autoscaling=True, + worker_group_specs=[ + { + "groupName": "workers", + "replicas": 2, + "minReplicas": 1, + "maxReplicas": 10, + # ... + } + ], + ) +) +``` +#### `KubeRayAPI` + +This resource can be used to interact with the Kubernetes API Server. 
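+It exposes the KubeRay Python client as the `.kuberay` property, as well as the Kubernetes `CustomObjectsApi` and `CoreV1Api` clients as `.k8s` and `.k8s_core`. The example below uses the `.kuberay` client to list `RayClusters`.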
+ +Examples: + +Listing currently running `RayClusters`: + +```python +from dagster import op, Definitions +from dagster_ray.kuberay import KubeRayAPI + + +@op +def list_ray_clusters( + kube_ray_api: KubeRayAPI, +): + return kube_ray_api.kuberay.list_ray_clusters(k8s_namespace="kuberay") +``` + +### Jobs + +#### `delete_kuberay_clusters` + +This `job` can be used to delete `RayClusters` from a given list of names. + +#### `cleanup_kuberay_clusters` + +This `job` can be used to delete old `RayClusters` which no longer correspond to any active Dagster Runs. +They may be left behind if the automatic cluster cleanup was disabled or failed. + +### Schedules -## Schedules +Cleanup schedules can be trivially created using the `cleanup_kuberay_clusters` or `delete_kuberay_clusters` jobs. -`dagster-ray` provides a schedule for automatic cleanup of old `RayClusters` in the cluster. -They may be left behind if the automatic cleanup was disabled or failed. +#### `cleanup_kuberay_clusters_daily` +`dagster-ray` provides an example daily schedule which runs the `cleanup_kuberay_clusters` job. ## Executor WIP @@ -62,13 +236,13 @@ Required tools: - `minikube` Running `pytest` will **automatically**: - - build an image with the local `dagster-ray` code, using the current Python's interpreter version + - build an image with the local `dagster-ray` code - start a `minikube` Kubernetes cluster - load the built `dagster-ray` and loaded `kuberay-operator` images into the cluster - install the `KubeRay Operator` in the cluster with `helm` - run the tests -Thus, no manual setup is required, just the presence of the tools listed above. +Thus, no manual setup is required, just the presence of the tools listed above. This makes testing a breeze! > [!NOTE] > Specifying a comma-separated list of `KubeRay Operator` versions in the `KUBE_RAY_OPERATOR_VERSIONS` environment variable will spawn a new test for each version. diff --git a/dagster_ray/__init__.py b/dagster_ray/__init__.py index e69de29..80d642a 100644 --- a/dagster_ray/__init__.py +++ b/dagster_ray/__init__.py @@ -0,0 +1,6 @@ +from dagster_ray._base.resources import BaseRayResource + +RayResource = BaseRayResource + + +__all__ = ["RayResource"] diff --git a/dagster_ray/_base/resources.py b/dagster_ray/_base/resources.py index cf2f694..55f228a 100644 --- a/dagster_ray/_base/resources.py +++ b/dagster_ray/_base/resources.py @@ -1,9 +1,9 @@ import sys import uuid from abc import ABC, abstractmethod -from typing import Optional, cast +from typing import TYPE_CHECKING, Optional, Union, cast -from dagster import ConfigurableResource, InitResourceContext +from dagster import ConfigurableResource, InitResourceContext, OpExecutionContext from pydantic import Field, PrivateAttr # yes, `python-client` is actually the KubeRay package name @@ -18,8 +18,9 @@ else: pass -import ray -from ray._private.worker import BaseContext as RayBaseContext # noqa + +if TYPE_CHECKING: + from ray._private.worker import BaseContext as RayBaseContext # noqa class BaseRayResource(ConfigurableResource, ABC): @@ -36,7 +37,13 @@ class BaseRayResource(ConfigurableResource, ABC): default=8265, description="Dashboard port for connection. Make sure to match with the actual available port." ) - _context: Optional[RayBaseContext] = PrivateAttr() + _context: Optional["RayBaseContext"] = PrivateAttr() + + def setup_for_execution(self, context: InitResourceContext) -> None: + raise NotImplementedError( + "This is an abstract resource, it's not meant to be provided directly. " + "Use a backend-specific resource instead."
+ ) @property def context(self) -> "RayBaseContext": @@ -62,15 +69,22 @@ def runtime_job_id(self) -> str: Returns the Ray Job ID for the current job which was created with `ray.init()`. :return: """ + import ray + return ray.get_runtime_context().get_job_id() @retry(stop=stop_after_delay(120), retry=retry_if_exception_type(ConnectionError), reraise=True) - def init_ray(self) -> "RayBaseContext": + def init_ray(self, context: Union[OpExecutionContext, InitResourceContext]) -> "RayBaseContext": + assert context.log is not None + + import ray + self.data_execution_options.apply() self._context = ray.init(address=self.ray_address, ignore_reinit_error=True) self.data_execution_options.apply() self.data_execution_options.apply_remote() - return cast(RayBaseContext, self._context) + context.log.info("Initialized Ray!") + return cast("RayBaseContext", self._context) def _get_step_key(self, context: InitResourceContext) -> str: # just return a random string diff --git a/dagster_ray/kuberay/__init__.py b/dagster_ray/kuberay/__init__.py index 05a07ad..8580d21 100644 --- a/dagster_ray/kuberay/__init__.py +++ b/dagster_ray/kuberay/__init__.py @@ -1,3 +1,16 @@ -from dagster_ray.kuberay.resources import KubeRayCluster +from dagster_ray.kuberay.configs import RayClusterConfig +from dagster_ray.kuberay.jobs import cleanup_kuberay_clusters, delete_kuberay_clusters +from dagster_ray.kuberay.ops import cleanup_kuberay_clusters_op, delete_kuberay_clusters_op +from dagster_ray.kuberay.resources import KubeRayAPI, KubeRayCluster +from dagster_ray.kuberay.schedules import cleanup_kuberay_clusters_daily -__all__ = ["KubeRayCluster"] +__all__ = [ + "KubeRayCluster", + "RayClusterConfig", + "KubeRayAPI", + "cleanup_kuberay_clusters", + "delete_kuberay_clusters", + "cleanup_kuberay_clusters_op", + "delete_kuberay_clusters_op", + "cleanup_kuberay_clusters_daily", +] diff --git a/dagster_ray/kuberay/configs.py b/dagster_ray/kuberay/configs.py new file mode 100644 index 0000000..e7a2e81 --- /dev/null +++ b/dagster_ray/kuberay/configs.py @@ -0,0 +1,91 @@ +import os +from typing import Any, Dict, List, Optional + +from dagster import Config + +in_k8s = os.environ.get("KUBERNETES_SERVICE_HOST") is not None +IS_PROD = os.getenv("DAGSTER_CLOUD_DEPLOYMENT_NAME") == "prod" +DEFAULT_AUTOSCALER_OPTIONS = { + "upscalingMode": "Default", + "idleTimeoutSeconds": 60, + "env": [], + "envFrom": [], + "resources": { + "limits": {"cpu": "1000m", "memory": "1Gi"}, + "requests": {"cpu": "1000m", "memory": "1Gi"}, + }, +} +DEFAULT_HEAD_GROUP_SPEC = { + "serviceType": "ClusterIP", + "rayStartParams": {}, + "metadata": { + "labels": {}, + "annotations": {}, + }, + "template": { + "spec": { + "imagePullSecrets": [], + "containers": [ + { + "volumeMounts": [ + # {"mountPath": "/tmp/ray", "name": "log-volume"}, + ], + "name": "head", + "imagePullPolicy": "Always", + }, + ], + "volumes": [ + {"name": "log-volume", "emptyDir": {}}, + ], + "affinity": {}, + "tolerations": [], + "nodeSelector": {}, + }, + }, +} +DEFAULT_WORKER_GROUP_SPECS = [ + { + "groupName": "workers", + "replicas": 0, + "minReplicas": 0, + "maxReplicas": 1, + "rayStartParams": {}, + "template": { + "metadata": {"labels": {}, "annotations": {}}, + "spec": { + "imagePullSecrets": [], + "containers": [ + { + "volumeMounts": [ + # {"mountPath": "/tmp/ray", "name": "log-volume"} + ], + "name": "worker", + "imagePullPolicy": "Always", + } + ], + "volumes": [ + {"name": "log-volume", "emptyDir": {}}, + ], + "affinity": {}, + "tolerations": [], + "nodeSelector": {}, + }, + }, + } +] + 
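+# NOTE: the defaults above are intentionally minimal and are intended to match the
+# KubeRay defaults: a single head pod plus one "workers" group with `replicas: 0`
+# (and `maxReplicas: 1`). Real workloads are expected to override `worker_group_specs`
+# (and usually `head_group_spec`) through the `RayClusterConfig` below.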
+ +class RayClusterConfig(Config): + image: Optional[str] = None + namespace: str = "kuberay" + enable_in_tree_autoscaling: bool = False + autoscaler_options: Dict[str, Any] = DEFAULT_AUTOSCALER_OPTIONS # TODO: add a dedicated Config type + head_group_spec: Dict[str, Any] = DEFAULT_HEAD_GROUP_SPEC # TODO: add a dedicated Config type + worker_group_specs: List[Dict[str, Any]] = DEFAULT_WORKER_GROUP_SPECS # TODO: add a dedicated Config type + + +DEFAULT_DEPLOYMENT_NAME = ( + os.getenv("DAGSTER_CLOUD_DEPLOYMENT_NAME") + if os.getenv("DAGSTER_CLOUD_IS_BRANCH_DEPLOYMENT") == "0" + else os.getenv("DAGSTER_CLOUD_GIT_BRANCH") +) or "dev" diff --git a/dagster_ray/kuberay/jobs.py b/dagster_ray/kuberay/jobs.py index e69de29..59e2f77 100644 --- a/dagster_ray/kuberay/jobs.py +++ b/dagster_ray/kuberay/jobs.py @@ -0,0 +1,16 @@ +from dagster import job + +from dagster_ray.kuberay.ops import cleanup_kuberay_clusters_op, delete_kuberay_clusters_op + + +@job(description="Deletes KubeRay clusters from Kubernetes", name="delete_kuberay_clusters") +def delete_kuberay_clusters(): + delete_kuberay_clusters_op() + + +@job( + description="Deletes KubeRay clusters which do not correspond to any active Dagster Runs in this deployment", + name="cleanup_kuberay_clusters", +) +def cleanup_kuberay_clusters(): + cleanup_kuberay_clusters_op() diff --git a/dagster_ray/kuberay/ops.py b/dagster_ray/kuberay/ops.py index 3c32745..cc5ac6b 100644 --- a/dagster_ray/kuberay/ops.py +++ b/dagster_ray/kuberay/ops.py @@ -1,23 +1,76 @@ -# from typing import List -# -# from dagster import Config, OpExecutionContext, op -# -# -# -# class DeleteRayClustersConfig(Config): -# cluster_names: List[str] -# namespace: str = "ray" -# -# -# @op(description="Deletes RayClusters from Kubernetes") -# def delete_ray_clusters_op(context: OpExecutionContext, config: DeleteRayClustersConfig) -> None: -# for cluster_name in config.cluster_names: -# try: -# if check_exists(cluster_name, namespace=config.namespace): -# stop(cluster_name, namespace=config.namespace) -# context.log.info(f"RayCluster {cluster_name} deleted") -# else: -# context.log.debug(f"RayCluster {cluster_name} doesn't exist") -# except Exception as e: # noqa -# context.log.error(f"Couldn't delete RayCluster {cluster_name}") -# context.log.exception(str(e)) +from typing import List + +from dagster import Config, DagsterRunStatus, OpExecutionContext, RunsFilter, op +from pydantic import Field + +from dagster_ray.kuberay.configs import DEFAULT_DEPLOYMENT_NAME +from dagster_ray.kuberay.resources import KubeRayAPI + + +class DeleteKubeRayClustersConfig(Config): + namespace: str = "kuberay" + cluster_names: List[str] = Field(default_factory=list, description="Specific RayCluster names to delete") + + +@op(description="Deletes KubeRay clusters from Kubernetes", name="delete_kuberay_clusters") +def delete_kuberay_clusters_op( + context: OpExecutionContext, + config: DeleteKubeRayClustersConfig, + kuberay_api: KubeRayAPI, +) -> None: + for cluster_name in config.cluster_names: + try: + if kuberay_api.kuberay.get_ray_cluster(name=cluster_name, k8s_namespace=config.namespace).get("items"): + kuberay_api.kuberay.delete_ray_cluster(name=cluster_name, k8s_namespace=config.namespace) + context.log.info(f"RayCluster {config.namespace}/{cluster_name} deleted!") + else: + context.log.warning(f"RayCluster {config.namespace}/{cluster_name} doesn't exist") + except Exception as e: # noqa + context.log.exception(f"Couldn't delete RayCluster {config.namespace}/{cluster_name}") + + +class 
CleanupKuberayClustersConfig(Config): + namespace: str = "kuberay" + label_selector: str = Field( + default=f"dagster.io/deployment={DEFAULT_DEPLOYMENT_NAME}", description="Label selector to filter RayClusters" + ) + + +@op( + description="Deletes KubeRay clusters which do not correspond to any active Dagster Runs in this deployment", + name="cleanup_kuberay_clusters", +) +def cleanup_kuberay_clusters_op( + context: OpExecutionContext, + config: CleanupKuberayClustersConfig, + kuberay_api: KubeRayAPI, +) -> None: + current_runs = context.instance.get_runs( + filters=RunsFilter( + statuses=[ + DagsterRunStatus.STARTED, + DagsterRunStatus.QUEUED, + DagsterRunStatus.CANCELING, + ] + ) + ) + + clusters = kuberay_api.kuberay.list_ray_clusters( + k8s_namespace=config.namespace, + label_selector=config.label_selector, + )["items"] + + # filter clusters by current runs using dagster.io/run_id label + + old_cluster_names = [ + cluster["metadata"]["name"] + for cluster in clusters + if not any(run.run_id == cluster["metadata"]["labels"]["dagster.io/run_id"] for run in current_runs) + ] + + for cluster_name in old_cluster_names: + try: + kuberay_api.kuberay.delete_ray_cluster(name=cluster_name, k8s_namespace=config.namespace) + context.log.info(f"RayCluster {config.namespace}/{cluster_name} deleted!") + except: # noqa + context.log.exception(f"Couldn't delete RayCluster {config.namespace}/{cluster_name}") diff --git a/dagster_ray/kuberay/resources.py b/dagster_ray/kuberay/resources.py index 090536e..65e9863 100644 --- a/dagster_ray/kuberay/resources.py +++ b/dagster_ray/kuberay/resources.py @@ -5,10 +5,10 @@ import re import string import sys -from typing import Any, Dict, Generator, List, Optional, cast +from typing import Any, Dict, Generator, Optional, cast import dagster._check as check -from dagster import InitResourceContext +from dagster import ConfigurableResource, InitResourceContext from kubernetes import client, config, watch from pydantic import Field, PrivateAttr @@ -16,6 +16,7 @@ # https://github.com/ray-project/kuberay/issues/2078 from python_client import kuberay_cluster_api +from dagster_ray.kuberay.configs import DEFAULT_DEPLOYMENT_NAME, RayClusterConfig from dagster_ray.kuberay.ray_cluster_api import PatchedRayClusterApi if sys.version_info >= (3, 11): @@ -23,127 +24,50 @@ else: from typing_extensions import Self -from dagster import Config, DagsterRun, DagsterRunStatus +from dagster import DagsterRun, DagsterRunStatus from ray._private.worker import BaseContext as RayBaseContext # noqa from dagster_ray._base.resources import BaseRayResource -in_k8s = os.environ.get("KUBERNETES_SERVICE_HOST") is not None -IS_PROD = os.getenv("DAGSTER_CLOUD_DEPLOYMENT_NAME") == "prod" - -DEFAULT_AUTOSCALER_OPTIONS = { - "upscalingMode": "Default", - "idleTimeoutSeconds": 60, - "env": [], - "envFrom": [], - "resources": { - "limits": {"cpu": "1000m", "memory": "1Gi"}, - "requests": {"cpu": "1000m", "memory": "1Gi"}, - }, -} - -DEFAULT_HEAD_GROUP_SPEC = { - "serviceType": "ClusterIP", - "rayStartParams": {"dashboard-host": "0.0.0.0"}, - "metadata": { - "labels": {}, - "annotations": {}, - }, - "template": { - "spec": { - "imagePullSecrets": [], - "containers": [ - { - "volumeMounts": [ - {"mountPath": "/tmp/ray", "name": "log-volume"}, - ], - "name": "head", - "imagePullPolicy": "Always", - }, - ], - "volumes": [ - {"name": "log-volume", "emptyDir": {}}, - ], - "affinity": {}, - "tolerations": [], - "nodeSelector": {}, - }, - }, -} -# -# -# class WorkerGroupSpecConfig(Config): -# imagePullSecrets: 
List[Dict[str, Any]] = [] -# containers: List[Dict[str, Any]] = [{ -# "volumeMounts": [ -# {"mountPath": "/tmp/ray", "name": "log-volume"}, -# ], -# "name": "worker", -# }] -# volumes: List[Dict[str, Any]] = [] -# affinity: Dict[str, Any] = {} -# tolerations: List[Dict[str, Any]] = [] -# nodeSelector: Dict[str, Any] = {} -# -# -# -# class WorkerGroupTemplateConfig(Config): -# spec: WorkerGroupSpecConfig = WorkerGroupSpecConfig() -# -# -# class WorkerGroupConfig(Config): -# template: WorkerGroupTemplateConfig = WorkerGroupTemplateConfig() -# -# class HeadGroupConfig(Config): -# serviceType: str = "ClusterIP" -# rayStartParams: Dict[str, Any] = {"dashboard-host": "0.0.0.0"} -# template: GroupTemplateConfig = GroupTemplateConfig() - - -DEFAULT_WORKER_GROUP_SPECS = [ - { - "groupName": "workers", - "rayStartParams": {}, - "template": { - "metadata": {"labels": {}, "annotations": {}}, - "spec": { - "imagePullSecrets": [], - "containers": [ - { - "volumeMounts": [ - {"mountPath": "/tmp/ray", "name": "log-volume"}, - ], - "name": "worker", - "imagePullPolicy": "Always", - } - ], - "volumes": [ - {"name": "log-volume", "emptyDir": {}}, - ], - "affinity": {}, - "tolerations": [], - "nodeSelector": {}, - }, - }, - } -] +class KubeRayAPI(ConfigurableResource): + kubeconfig_file: Optional[str] = None + + _kuberay_api: PatchedRayClusterApi = PrivateAttr() + _k8s_api: client.CustomObjectsApi = PrivateAttr() + _k8s_core_api: client.CoreV1Api = PrivateAttr() + + @property + def kuberay(self) -> kuberay_cluster_api.RayClusterApi: + if self._kuberay_api is None: + raise ValueError("KubeRayAPI not initialized") + return self._kuberay_api + + @property + def k8s(self) -> client.CustomObjectsApi: + if self._k8s_api is None: + raise ValueError("KubeRayAPI not initialized") + return self._k8s_api + + @property + def k8s_core(self) -> client.CoreV1Api: + if self._k8s_core_api is None: + raise ValueError("KubeRayAPI not initialized") + return self._k8s_core_api -class RayClusterConfig(Config): - image: Optional[str] = None - namespace: str = "kuberay" - enable_in_tree_autoscaling: bool = False - autoscaler_options: Dict[str, Any] = DEFAULT_AUTOSCALER_OPTIONS # TODO: add a dedicated Config type - head_group_spec: Dict[str, Any] = DEFAULT_HEAD_GROUP_SPEC # TODO: add a dedicated Config type - worker_group_specs: List[Dict[str, Any]] = DEFAULT_WORKER_GROUP_SPECS # TODO: add a dedicated Config type + def setup_for_execution(self, context: InitResourceContext) -> None: + self._load_kubeconfig(self.kubeconfig_file) + self._kuberay_api = PatchedRayClusterApi(config_file=self.kubeconfig_file) + self._k8s_api = client.CustomObjectsApi() + self._k8s_core_api = client.CoreV1Api() -DEFAULT_CLUSTER_NAME_PREFIX = ( - os.getenv("DAGSTER_CLOUD_DEPLOYMENT_NAME") - if os.getenv("DAGSTER_CLOUD_IS_BRANCH_DEPLOYMENT") == "0" - else os.getenv("DAGSTER_CLOUD_GIT_BRANCH") -) -DEFAULT_CLUSTER_NAME_PREFIX = DEFAULT_CLUSTER_NAME_PREFIX or "dev" + @staticmethod + def _load_kubeconfig(kubeconfig_file: Optional[str] = None): + try: + config.load_kube_config(config_file=kubeconfig_file) + except config.config_exception.ConfigException: + config.load_incluster_config() class KubeRayCluster(BaseRayResource): @@ -152,21 +76,19 @@ class KubeRayCluster(BaseRayResource): The cluster is automatically deleted after steps execution """ - cluster_name_prefix: str = Field( - default=DEFAULT_CLUSTER_NAME_PREFIX, - description="Prefix for the RayCluster name. It's recommended to match it with the Dagster deployment name. 
Dagster Cloud variables are used for the default value.", + deployment_name: str = Field( + default=DEFAULT_DEPLOYMENT_NAME, + description="Prefix for the RayCluster name. It's recommended to match it with the Dagster deployment name. " + "Dagster Cloud variables are used to determine the default value.", ) ray_cluster: RayClusterConfig = Field(default_factory=RayClusterConfig) - disable_cluster_cleanup: bool = False + skip_cleanup: bool = False skip_init: bool = False - kubeconfig_file: Optional[str] = None + api: KubeRayAPI = Field(default_factory=KubeRayAPI) _cluster_name: str = PrivateAttr() _host: str = PrivateAttr() - _kuberay_api: PatchedRayClusterApi = PrivateAttr() - _k8s_api: client.CustomObjectsApi = PrivateAttr() - _k8s_core_api: client.CoreV1Api = PrivateAttr() @property def host(self) -> str: @@ -184,28 +106,12 @@ def cluster_name(self) -> str: raise ValueError("RayClusterResource not initialized") return self._cluster_name - @property - def kuberay_api(self) -> kuberay_cluster_api.RayClusterApi: - if self._kuberay_api is None: - raise ValueError("RayClusterResource not initialized") - return self._kuberay_api - - @property - def k8s_api(self) -> client.CustomObjectsApi: - if self._k8s_api is None: - raise ValueError("RayClusterResource not initialized") - return self._k8s_api - @contextlib.contextmanager def yield_for_execution(self, context: InitResourceContext) -> Generator[Self, None, None]: assert context.log is not None assert context.dagster_run is not None - self._load_kubeconfig(self.kubeconfig_file) - - self._kuberay_api = PatchedRayClusterApi(config_file=self.kubeconfig_file) - self._k8s_api = client.CustomObjectsApi() - self._k8s_core_api = client.CoreV1Api() + self.api.setup_for_execution(context) self._cluster_name = self._get_ray_cluster_step_name(context) @@ -213,20 +119,21 @@ def yield_for_execution(self, context: InitResourceContext) -> Generator[Self, N try: # just a safety measure, no need to recreate the cluster for step retries or smth - if not self.kuberay_api.list_ray_clusters( + if not self.api.kuberay.list_ray_clusters( k8s_namespace=self.namespace, label_selector=f"dagster.io/cluster={self.cluster_name}", )["items"]: cluster_body = self._build_raycluster( image=(self.ray_cluster.image or context.dagster_run.tags["dagster/image"]), - labels={ - "dagster.io/run_id": cast(str, context.run_id), - "dagster.io/cluster": self.cluster_name, - # TODO: add more labels - }, + labels=self._get_labels(context), ) - self.kuberay_api.create_ray_cluster(body=cluster_body, k8s_namespace=self.namespace) + resource = self.api.kuberay.create_ray_cluster(body=cluster_body, k8s_namespace=self.namespace) + if resource is None: + raise Exception( + f"Couldn't create RayCluster {self.namespace}/{self.cluster_name}! Reason logged above." + ) + context.log.info( f"Created RayCluster {self.namespace}/{self.cluster_name}. Waiting for it to become ready..." 
) @@ -236,7 +143,7 @@ def yield_for_execution(self, context: InitResourceContext) -> Generator[Self, N # TODO: currently this will only work from withing the cluster # find a way to make it work from outside # probably would need a LoadBalancer/Ingress - self._host = self.kuberay_api.get_ray_cluster(name=self.cluster_name, k8s_namespace=self.namespace)[ + self._host = self.api.kuberay.get_ray_cluster(name=self.cluster_name, k8s_namespace=self.namespace)[ "status" ]["head"]["serviceIP"] @@ -249,14 +156,13 @@ def yield_for_execution(self, context: InitResourceContext) -> Generator[Self, N context.log.debug(f"Ray host: {self.host}") if not self.skip_init: - self.init_ray() - context.log.info("Initialized Ray!") + self.init_ray(context) else: self._context = None yield self except Exception as e: - context.log.critical("Couldn't create or connect to RayCluster!") + context.log.critical(f"Couldn't create or connect to RayCluster {self.namespace}/{self.cluster_name}!") self._maybe_cleanup_raycluster(context) raise e @@ -264,6 +170,27 @@ def yield_for_execution(self, context: InitResourceContext) -> Generator[Self, N if self._context is not None: self._context.disconnect() + def _get_labels(self, context: InitResourceContext) -> Dict[str, str]: + assert context.dagster_run is not None + + labels = { + "dagster.io/run_id": cast(str, context.run_id), + "dagster.io/cluster": self.cluster_name, + "dagster.io/deployment": self.deployment_name, + # TODO: add more labels + } + + if context.dagster_run.tags.get("user"): + labels["dagster.io/user"] = context.dagster_run.tags["user"] + + if os.getenv("DAGSTER_CLOUD_GIT_BRANCH"): + labels["dagster.io/git-branch"] = os.environ["DAGSTER_CLOUD_GIT_BRANCH"] + + if os.getenv("DAGSTER_CLOUD_GIT_SHA"): + labels["dagster.io/git-sha"] = os.environ["DAGSTER_CLOUD_GIT_SHA"] + + return labels + def _build_raycluster( self, image: str, @@ -273,6 +200,10 @@ def _build_raycluster( Builds a RayCluster from the provided configuration, while injecting custom image and labels (only known during resource setup) """ # TODO: inject self.redis_port and self.dashboard_port into the RayCluster configuration + # TODO: autoa-apply some tags from dagster-k8s/config + + labels = labels or {} + assert isinstance(labels, dict) image = self.ray_cluster.image or image head_group_spec = self.ray_cluster.head_group_spec.copy() @@ -284,12 +215,12 @@ def update_group_spec(group_spec: Dict[str, Any]): group_spec["template"]["spec"]["containers"][0]["image"] = image if group_spec.get("metadata") is None: - group_spec["metadata"] = {"labels": labels or {}} + group_spec["metadata"] = {"labels": labels} else: if group_spec["metadata"].get("labels") is None: - group_spec["metadata"]["labels"] = labels or {} + group_spec["metadata"]["labels"] = labels else: - group_spec["metadata"]["labels"].update(labels or {}) + group_spec["metadata"]["labels"].update(labels) update_group_spec(head_group_spec) for worker_group_spec in worker_group_specs: @@ -311,8 +242,8 @@ def update_group_spec(group_spec: Dict[str, Any]): } def _wait_raycluster_ready(self): - if not self.kuberay_api.wait_until_ray_cluster_running(self.cluster_name, k8s_namespace=self.namespace): - status = self.kuberay_api.get_ray_cluster_status(self.cluster_name, k8s_namespace=self.namespace) + if not self.api.kuberay.wait_until_ray_cluster_running(self.cluster_name, k8s_namespace=self.namespace): + status = self.api.kuberay.get_ray_cluster_status(self.cluster_name, k8s_namespace=self.namespace) raise Exception(f"RayCluster 
{self.namespace}/{self.cluster_name} failed to start: {status}") # the above code only checks for RayCluster creation @@ -320,7 +251,7 @@ def _wait_raycluster_ready(self): w = watch.Watch() for event in w.stream( - func=self._k8s_core_api.list_namespaced_pod, + func=self.api.k8s_core.list_namespaced_pod, namespace=self.namespace, label_selector=f"ray.io/cluster={self.cluster_name},ray.io/group=headgroup", timeout_seconds=60, @@ -337,11 +268,8 @@ def _wait_raycluster_ready(self): def _maybe_cleanup_raycluster(self, context: InitResourceContext): assert context.log is not None - if ( - not self.disable_cluster_cleanup - and cast(DagsterRun, context.dagster_run).status != DagsterRunStatus.FAILURE - ): - self.kuberay_api.delete_ray_cluster(self.cluster_name, k8s_namespace=self.namespace) + if not self.skip_cleanup and cast(DagsterRun, context.dagster_run).status != DagsterRunStatus.FAILURE: + self.api.kuberay.delete_ray_cluster(self.cluster_name, k8s_namespace=self.namespace) context.log.info(f"Deleted RayCluster {self.namespace}/{self.cluster_name}") else: context.log.warning( @@ -356,7 +284,7 @@ def _get_ray_cluster_step_name(self, context: InitResourceContext) -> str: # try to make the name as short as possible - cluster_name_prefix = f"dr-{self.cluster_name_prefix.replace('-', '')[:8]}-{context.run_id[:8]}" + cluster_name_prefix = f"dr-{self.deployment_name.replace('-', '')[:8]}-{context.run_id[:8]}" dagster_user_email = context.dagster_run.tags.get("user") if dagster_user_email is not None: @@ -374,13 +302,6 @@ def _get_ray_cluster_step_name(self, context: InitResourceContext) -> str: return step_name - @staticmethod - def _load_kubeconfig(kubeconfig_file: Optional[str] = None): - try: - config.load_kube_config(config_file=kubeconfig_file) - except config.config_exception.ConfigException: - config.load_incluster_config() - def get_k8s_object_name(run_id: str, step_key: Optional[str] = None): """Creates a unique (short!) identifier to name k8s objects based on run ID and step key(s). 
diff --git a/dagster_ray/kuberay/schedules.py b/dagster_ray/kuberay/schedules.py index 0180b84..4d64dd5 100644 --- a/dagster_ray/kuberay/schedules.py +++ b/dagster_ray/kuberay/schedules.py @@ -1,51 +1,9 @@ -# from dagster import ( -# DagsterRunStatus, -# RunConfig, -# RunRequest, -# RunsFilter, -# ScheduleEvaluationContext, -# SkipReason, -# schedule, -# ) -# -# -# @schedule( -# job=delete_ray_clusters, -# cron_schedule="0 * * * *", -# description="Deletes old KubeRay cluster created by Dagster which don't correspond to any current Runs", -# ) -# def cleanup_old_kuberay_clusters(context: ScheduleEvaluationContext): -# current_runs = context.instance.get_runs( -# filters=RunsFilter( -# statuses=[ -# DagsterRunStatus.STARTED, -# DagsterRunStatus.QUEUED, -# DagsterRunStatus.CANCELING, -# ] -# ) -# ) -# deployed_names = list_deployed_names() -# -# dangling_cluster_names = [] -# -# for cluster_name in deployed_names: -# if cluster_name.startswith("dagster-run-"): -# dagster_short_run_id = cluster_name.split("-")[2] -# -# is_dangling = True -# for run in current_runs: -# if run.run_id.startswith(dagster_short_run_id): -# is_dangling = False -# -# if is_dangling: -# dangling_cluster_names.append(cluster_name) -# -# if len(dangling_cluster_names) > 0: -# return RunRequest( -# run_key=None, -# run_config=RunConfig( -# ops={"delete_ray_clusters_op": DeleteRayClustersConfig(cluster_names=dangling_cluster_names)} -# ), -# ) -# else: -# return SkipReason(skip_message="No dangling RayClusters were found") +from dagster import ScheduleDefinition + +from dagster_ray.kuberay.jobs import cleanup_kuberay_clusters + +cleanup_kuberay_clusters_daily = ScheduleDefinition( + job=cleanup_kuberay_clusters, + cron_schedule="0 0 * * *", + name="cleanup_kuberay_clusters_schedule_daily", +) diff --git a/dagster_ray/local/__init__.py b/dagster_ray/local/__init__.py index e69de29..d63c8a7 100644 --- a/dagster_ray/local/__init__.py +++ b/dagster_ray/local/__init__.py @@ -0,0 +1,3 @@ +from dagster_ray.local.resources import LocalRay + +__all__ = ["LocalRay"] diff --git a/dagster_ray/local/resources.py b/dagster_ray/local/resources.py index a997f43..d80955d 100644 --- a/dagster_ray/local/resources.py +++ b/dagster_ray/local/resources.py @@ -36,8 +36,7 @@ def yield_for_execution(self, context: InitResourceContext) -> Generator[Self, N context.log.debug(f"Ray host: {self.host}") - self.init_ray() - context.log.info("Initialized Ray!") + self.init_ray(context) yield self diff --git a/poetry.lock b/poetry.lock index a35a05d..65fc9e2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -220,6 +220,66 @@ files = [ {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, ] +[[package]] +name = "black" +version = "24.4.0" +description = "The uncompromising code formatter." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "black-24.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6ad001a9ddd9b8dfd1b434d566be39b1cd502802c8d38bbb1ba612afda2ef436"}, + {file = "black-24.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3a3a092b8b756c643fe45f4624dbd5a389f770a4ac294cf4d0fce6af86addaf"}, + {file = "black-24.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dae79397f367ac8d7adb6c779813328f6d690943f64b32983e896bcccd18cbad"}, + {file = "black-24.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:71d998b73c957444fb7c52096c3843875f4b6b47a54972598741fe9a7f737fcb"}, + {file = "black-24.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8e5537f456a22cf5cfcb2707803431d2feeb82ab3748ade280d6ccd0b40ed2e8"}, + {file = "black-24.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:64e60a7edd71fd542a10a9643bf369bfd2644de95ec71e86790b063aa02ff745"}, + {file = "black-24.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cd5b4f76056cecce3e69b0d4c228326d2595f506797f40b9233424e2524c070"}, + {file = "black-24.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:64578cf99b6b46a6301bc28bdb89f9d6f9b592b1c5837818a177c98525dbe397"}, + {file = "black-24.4.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f95cece33329dc4aa3b0e1a771c41075812e46cf3d6e3f1dfe3d91ff09826ed2"}, + {file = "black-24.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4396ca365a4310beef84d446ca5016f671b10f07abdba3e4e4304218d2c71d33"}, + {file = "black-24.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44d99dfdf37a2a00a6f7a8dcbd19edf361d056ee51093b2445de7ca09adac965"}, + {file = "black-24.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:21f9407063ec71c5580b8ad975653c66508d6a9f57bd008bb8691d273705adcd"}, + {file = "black-24.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:652e55bb722ca026299eb74e53880ee2315b181dfdd44dca98e43448620ddec1"}, + {file = "black-24.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7f2966b9b2b3b7104fca9d75b2ee856fe3fdd7ed9e47c753a4bb1a675f2caab8"}, + {file = "black-24.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1bb9ca06e556a09f7f7177bc7cb604e5ed2d2df1e9119e4f7d2f1f7071c32e5d"}, + {file = "black-24.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:d4e71cdebdc8efeb6deaf5f2deb28325f8614d48426bed118ecc2dcaefb9ebf3"}, + {file = "black-24.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6644f97a7ef6f401a150cca551a1ff97e03c25d8519ee0bbc9b0058772882665"}, + {file = "black-24.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:75a2d0b4f5eb81f7eebc31f788f9830a6ce10a68c91fbe0fade34fff7a2836e6"}, + {file = "black-24.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb949f56a63c5e134dfdca12091e98ffb5fd446293ebae123d10fc1abad00b9e"}, + {file = "black-24.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:7852b05d02b5b9a8c893ab95863ef8986e4dda29af80bbbda94d7aee1abf8702"}, + {file = "black-24.4.0-py3-none-any.whl", hash = "sha256:74eb9b5420e26b42c00a3ff470dc0cd144b80a766128b1771d07643165e08d0e"}, + {file = "black-24.4.0.tar.gz", hash = "sha256:f07b69fda20578367eaebbd670ff8fc653ab181e1ff95d84497f9fa20e7d0641"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} + +[package.extras] 
+colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + +[[package]] +name = "blacken-docs" +version = "1.16.0" +description = "Run Black on Python code blocks in documentation files." +optional = false +python-versions = ">=3.8" +files = [ + {file = "blacken_docs-1.16.0-py3-none-any.whl", hash = "sha256:b0dcb84b28ebfb352a2539202d396f50e15a54211e204a8005798f1d1edb7df8"}, + {file = "blacken_docs-1.16.0.tar.gz", hash = "sha256:b4bdc3f3d73898dfbf0166f292c6ccfe343e65fc22ddef5319c95d1a8dcc6c1c"}, +] + +[package.dependencies] +black = ">=22.1.0" + [[package]] name = "cachetools" version = "5.3.1" @@ -1898,6 +1958,17 @@ files = [ {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, ] +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + [[package]] name = "networkx" version = "3.1" @@ -2216,6 +2287,17 @@ sql-other = ["SQLAlchemy (>=1.4.16)"] test = ["hypothesis (>=6.34.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.6.3)"] +[[package]] +name = "pathspec" +version = "0.12.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, +] + [[package]] name = "pendulum" version = "2.1.2" @@ -4259,4 +4341,4 @@ kuberay = ["kubernetes", "python-client", "pyyaml"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "44e4c5ae1cac883342a554d23175f758b8b60ea7a635c03bd2c9c10b86c365f7" +content-hash = "7d063237e1ebef30fd2ebdd01cbe16c73ddafb006d879e7b75ba4d6b5e95dc6d" diff --git a/pyproject.toml b/pyproject.toml index 026630c..625fc14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ pre-commit = "^3.3.2" dagit = "^1.3.9" pytest-cases = "^3.6.14" pytest-kubernetes = "^0.3.1" +blacken-docs = "^1.16.0" [build-system] requires = ["poetry-core"] diff --git a/tests/test_kuberay.py b/tests/test_kuberay.py index 6231a96..74c188c 100644 --- a/tests/test_kuberay.py +++ b/tests/test_kuberay.py @@ -9,12 +9,15 @@ import pytest import pytest_cases import ray -from dagster import asset, materialize_to_memory +from dagster import AssetExecutionContext, RunConfig, asset, materialize_to_memory from pytest_kubernetes.options import ClusterOptions from pytest_kubernetes.providers import AClusterManager, select_provider_manager -from dagster_ray.kuberay import KubeRayCluster -from dagster_ray.kuberay.resources import DEFAULT_HEAD_GROUP_SPEC, DEFAULT_WORKER_GROUP_SPECS, RayClusterConfig +from dagster_ray import RayResource +from dagster_ray.kuberay import KubeRayAPI, KubeRayCluster, RayClusterConfig, cleanup_kuberay_clusters +from dagster_ray.kuberay.configs import DEFAULT_HEAD_GROUP_SPEC, DEFAULT_WORKER_GROUP_SPECS 
+from dagster_ray.kuberay.ops import CleanupKuberayClustersConfig +from dagster_ray.kuberay.ray_cluster_api import PatchedRayClusterApi from tests import ROOT_DIR @@ -149,7 +152,31 @@ def ray_cluster_resource( # have have to first run port-forwarding with minikube # we can only init ray after that skip_init=True, - kubeconfig_file=str(k8s_with_raycluster.kubeconfig), + api=KubeRayAPI(kubeconfig_file=str(k8s_with_raycluster.kubeconfig)), + ray_cluster=RayClusterConfig( + image=dagster_ray_image, + head_group_spec=head_group_spec, + worker_group_specs=worker_group_specs, + ), + redis_port=redis_port, + ) + + +@pytest.fixture(scope="session") +def ray_cluster_resource_skip_cleanup( + k8s_with_raycluster: AClusterManager, + dagster_ray_image: str, + head_group_spec: Dict[str, Any], + worker_group_specs: List[Dict[str, Any]], +) -> KubeRayCluster: + redis_port = get_random_free_port() + + return KubeRayCluster( + # have have to first run port-forwarding with minikube + # we can only init ray after that + skip_init=True, + skip_cleanup=True, + api=KubeRayAPI(kubeconfig_file=str(k8s_with_raycluster.kubeconfig)), ray_cluster=RayClusterConfig( image=dagster_ray_image, head_group_spec=head_group_spec, @@ -169,10 +196,13 @@ def test_kuberay_cluster_resource( k8s_with_raycluster: AClusterManager, ): @asset - def my_asset(ray_cluster: KubeRayCluster) -> None: + # testing RayResource type annotation too! + def my_asset(context: AssetExecutionContext, ray_cluster: RayResource) -> None: # port-forward to the head node # because it's not possible to access it otherwise + assert isinstance(ray_cluster, KubeRayCluster) + with k8s_with_raycluster.port_forwarding( target=f"svc/{ray_cluster.cluster_name}-head-svc", source_port=cast(int, ray_cluster.redis_port), @@ -182,14 +212,76 @@ def my_asset(ray_cluster: KubeRayCluster) -> None: # now we can access the head node # hack the _host attribute to point to the port-forwarded address ray_cluster._host = "127.0.0.1" - ray_cluster.init_ray() # normally this would happen automatically during resource setup + ray_cluster.init_ray(context) # normally this would happen automatically during resource setup assert ray_cluster.context is not None # make sure a @remote function runs inside the cluster # not in localhost assert ray_cluster.cluster_name in ray.get(get_hostname.remote()) - materialize_to_memory( + ray_cluster_description = ray_cluster.api.kuberay.get_ray_cluster( + ray_cluster.cluster_name, k8s_namespace=ray_cluster.namespace + ) + assert ray_cluster_description["metadata"]["labels"]["dagster.io/run_id"] == context.run_id + assert ray_cluster_description["metadata"]["labels"]["dagster.io/cluster"] == ray_cluster.cluster_name + + result = materialize_to_memory( [my_asset], resources={"ray_cluster": ray_cluster_resource}, ) + + kuberay_api = PatchedRayClusterApi(config_file=str(k8s_with_raycluster.kubeconfig)) + + # make sure the RayCluster is cleaned up + + assert ( + len( + kuberay_api.list_ray_clusters( + k8s_namespace=ray_cluster_resource.namespace, label_selector=f"dagster.io/run_id={result.run_id}" + )["items"] + ) + == 0 + ) + + +def test_kuberay_cleanup_job( + ray_cluster_resource_skip_cleanup: KubeRayCluster, + k8s_with_raycluster: AClusterManager, +): + @asset + def my_asset(ray_cluster: RayResource) -> None: + assert isinstance(ray_cluster, KubeRayCluster) + + result = materialize_to_memory( + [my_asset], + resources={"ray_cluster": ray_cluster_resource_skip_cleanup}, + ) + + kuberay_api = 
PatchedRayClusterApi(config_file=str(k8s_with_raycluster.kubeconfig)) + + assert ( + len( + kuberay_api.list_ray_clusters( + k8s_namespace=ray_cluster_resource_skip_cleanup.namespace, + label_selector=f"dagster.io/run_id={result.run_id}", + )["items"] + ) + > 0 + ) + + cleanup_kuberay_clusters.execute_in_process( + resources={ + "kuberay_api": KubeRayAPI(kubeconfig_file=str(k8s_with_raycluster.kubeconfig)), + }, + run_config=RunConfig( + ops={ + "cleanup_kuberay_clusters": CleanupKuberayClustersConfig( + namespace=ray_cluster_resource_skip_cleanup.namespace, + ) + } + ), + ) + + assert not kuberay_api.list_ray_clusters( + k8s_namespace=ray_cluster_resource_skip_cleanup.namespace, label_selector=f"dagster.io/run_id={result.run_id}" + )["items"]