Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ray Task Support #1093

Merged
merged 31 commits into from Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/pythonbuild.yml
Expand Up @@ -46,6 +46,7 @@ jobs:
- name: Install dependencies
run: |
make setup${{ matrix.spark-version-suffix }}
pip install --no-deps -U git+https://github.com/flyteorg/flyteidl@ray
pip freeze
- name: Test with coverage
run: |
Expand Down Expand Up @@ -80,6 +81,7 @@ jobs:
- flytekit-pandera
- flytekit-papermill
- flytekit-polars
- flytekit-ray
- flytekit-snowflake
- flytekit-spark
- flytekit-sqlalchemy
Expand Down Expand Up @@ -112,6 +114,7 @@ jobs:
pip install -r requirements.txt
if [ -f dev-requirements.txt ]; then pip install -r dev-requirements.txt; fi
pip install --no-deps -U https://github.com/flyteorg/flytekit/archive/${{ github.sha }}.zip#egg=flytekit
pip install --no-deps -U git+https://github.com/flyteorg/flyteidl@ray
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
pip freeze
- name: Test with coverage
run: |
Expand All @@ -137,6 +140,7 @@ jobs:
run: |
python -m pip install --upgrade pip==21.2.4
pip install -r dev-requirements.txt
pip install --no-deps -U git+https://github.com/flyteorg/flyteidl@ray
- name: Lint
run: |
make lint
Expand All @@ -158,5 +162,6 @@ jobs:
run: |
python -m pip install --upgrade pip==21.2.4 setuptools wheel
pip install -r doc-requirements.txt
pip install --no-deps -U git+https://github.com/flyteorg/flyteidl@ray
- name: Build the documentation
run: make -C docs html
9 changes: 9 additions & 0 deletions plugins/flytekit-ray/README.md
@@ -0,0 +1,9 @@
# Flytekit Ray Plugin

The Flyte backend can be connected with Ray. Once enabled, it allows you to run Flyte tasks on a Ray cluster.

To install the plugin, run the following command:

```bash
pip install flytekitplugins-ray
```
13 changes: 13 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/__init__.py
@@ -0,0 +1,13 @@
"""
.. currentmodule:: flytekitplugins.ray

This package contains things that are useful when extending Flytekit.

.. autosummary::
:template: custom.rst
:toctree: generated/

HeadNodeConfig
RayJobConfig
WorkerNodeConfig
"""

from .task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig
251 changes: 251 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/models.py
@@ -0,0 +1,251 @@
import typing

from flyteidl.plugins import ray_pb2 as _ray_pb2

from flytekit.models import common as _common


class WorkerGroupSpec(_common.FlyteIdlEntity):
    """
    Python-side model of ``flyteidl.plugins.ray_pb2.WorkerGroupSpec``.
    """

    def __init__(
        self,
        group_name: str,
        replicas: int,
        min_replicas: typing.Optional[int] = 0,
        max_replicas: typing.Optional[int] = None,
        ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
    ):
        """
        :param group_name: Group name of the worker group.
        :param replicas: Desired replicas of the worker group.
        :param min_replicas: Min replicas of the worker group.
        :param max_replicas: Max replicas of the worker group. A falsy value
            (``None`` or ``0``) falls back to ``replicas``.
        :param ray_start_params: The ray start params of the worker node group.
        """
        self._group_name = group_name
        self._replicas = replicas
        self._min_replicas = min_replicas
        # An unset proto int arrives here as 0, so a falsy value (not just None)
        # must fall back to the desired replica count.
        self._max_replicas = max_replicas if max_replicas else replicas
        self._ray_start_params = ray_start_params

    @property
    def group_name(self) -> str:
        """
        Group name of the current worker group.
        :rtype: str
        """
        return self._group_name

    @property
    def replicas(self) -> int:
        """
        Desired replicas of the worker group.
        :rtype: int
        """
        return self._replicas

    @property
    def min_replicas(self) -> typing.Optional[int]:
        """
        Min replicas of the worker group.
        :rtype: int
        """
        return self._min_replicas

    @property
    def max_replicas(self) -> int:
        """
        Max replicas of the worker group.
        :rtype: int
        """
        return self._max_replicas

    @property
    def ray_start_params(self) -> typing.Optional[typing.Dict[str, str]]:
        """
        The ray start params of worker node group.
        :rtype: typing.Dict[str, str]
        """
        return self._ray_start_params

    def to_flyte_idl(self):
        """
        :rtype: flyteidl.plugins._ray_pb2.WorkerGroupSpec
        """
        return _ray_pb2.WorkerGroupSpec(
            group_name=self.group_name,
            replicas=self.replicas,
            min_replicas=self.min_replicas,
            max_replicas=self.max_replicas,
            ray_start_params=self.ray_start_params,
        )

    @classmethod
    def from_flyte_idl(cls, proto):
        """
        :param flyteidl.plugins._ray_pb2.WorkerGroupSpec proto:
        :rtype: WorkerGroupSpec
        """
        return cls(
            group_name=proto.group_name,
            replicas=proto.replicas,
            min_replicas=proto.min_replicas,
            max_replicas=proto.max_replicas,
            # Copy into a plain dict so the model does not alias the proto's
            # map container (later edits to the message must not leak in here).
            ray_start_params=dict(proto.ray_start_params),
        )


class HeadGroupSpec(_common.FlyteIdlEntity):
    """
    Python-side model of ``flyteidl.plugins.ray_pb2.HeadGroupSpec``.
    """

    def __init__(
        self,
        ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
    ):
        """
        :param ray_start_params: The ray start params of the head node group.
        """
        self._ray_start_params = ray_start_params

    @property
    def ray_start_params(self) -> typing.Optional[typing.Dict[str, str]]:
        """
        The ray start params of head node group.
        :rtype: typing.Dict[str, str]
        """
        return self._ray_start_params

    def to_flyte_idl(self):
        """
        :rtype: flyteidl.plugins._ray_pb2.HeadGroupSpec
        """
        return _ray_pb2.HeadGroupSpec(
            # Unset or empty params are serialized as an empty map.
            ray_start_params=self.ray_start_params if self.ray_start_params else {},
        )

    @classmethod
    def from_flyte_idl(cls, proto):
        """
        :param flyteidl.plugins._ray_pb2.HeadGroupSpec proto:
        :rtype: HeadGroupSpec
        """
        return cls(
            # Copy into a plain dict so the model does not alias the proto's
            # map container.
            ray_start_params=dict(proto.ray_start_params),
        )


class ClusterSpec(_common.FlyteIdlEntity):
    """
    Python-side model of ``flyteidl.plugins.ray_pb2.ClusterSpec``: one head
    group plus a list of worker groups.
    """

    def __init__(
        self,
        worker_group_spec: typing.List[WorkerGroupSpec],
        head_group_spec: typing.Optional[HeadGroupSpec] = None,
    ):
        """
        :param worker_group_spec: The worker group configurations. ``None`` is
            treated as an empty list.
        :param head_group_spec: The head group configuration. When omitted or
            ``None``, a fresh default ``HeadGroupSpec()`` is used.
        """
        # Build the default per instance: a default evaluated in the signature
        # would be one object shared by every ClusterSpec.
        self._head_group_spec = head_group_spec if head_group_spec is not None else HeadGroupSpec()
        # Normalize None so to_flyte_idl can always iterate.
        self._worker_group_spec = worker_group_spec if worker_group_spec is not None else []

    @property
    def head_group_spec(self) -> HeadGroupSpec:
        """
        The head group configuration.
        :rtype: HeadGroupSpec
        """
        return self._head_group_spec

    @property
    def worker_group_spec(self) -> typing.List[WorkerGroupSpec]:
        """
        The worker group configurations.
        :rtype: typing.List[WorkerGroupSpec]
        """
        return self._worker_group_spec

    def to_flyte_idl(self):
        """
        :rtype: flyteidl.plugins._ray_pb2.ClusterSpec
        """
        return _ray_pb2.ClusterSpec(
            head_group_spec=self.head_group_spec.to_flyte_idl(),
            worker_group_spec=[wg.to_flyte_idl() for wg in self.worker_group_spec],
        )

    @classmethod
    def from_flyte_idl(cls, proto):
        """
        :param flyteidl.plugins._ray_pb2.ClusterSpec proto:
        :rtype: ClusterSpec
        """
        # A falsy head group falls through to the constructor default.
        head = HeadGroupSpec.from_flyte_idl(proto.head_group_spec) if proto.head_group_spec else None
        # An empty repeated field yields an empty list, never None, so the
        # result can always be serialized again.
        workers = [WorkerGroupSpec.from_flyte_idl(wg) for wg in proto.worker_group_spec]
        return cls(worker_group_spec=workers, head_group_spec=head)


class RayCluster(_common.FlyteIdlEntity):
    """
    Holds the RayCluster spec that KubeRay will use to launch the cluster.
    """

    def __init__(self, cluster_spec: ClusterSpec):
        """
        :param ClusterSpec cluster_spec: The ray cluster configuration to wrap.
        """
        self._cluster_spec = cluster_spec

    @property
    def cluster_spec(self):
        """
        The ray cluster configuration carried by this object.
        :rtype: ClusterSpec
        """
        return self._cluster_spec

    def to_flyte_idl(self) -> _ray_pb2.RayCluster:
        """
        Convert this model, including its cluster spec, to its protobuf message.
        :rtype: flyteidl.plugins._ray_pb2.RayCluster
        """
        spec_pb = self.cluster_spec.to_flyte_idl()
        return _ray_pb2.RayCluster(cluster_spec=spec_pb)

    @classmethod
    def from_flyte_idl(cls, proto):
        """
        Build the model from its protobuf message.
        :param flyteidl.plugins._ray_pb2.RayCluster proto:
        :rtype: RayCluster
        """
        spec = ClusterSpec.from_flyte_idl(proto.cluster_spec)
        return cls(cluster_spec=spec)


class RayJob(_common.FlyteIdlEntity):
    """
    Models _ray_pb2.RayJob
    """

    def __init__(
        self,
        ray_cluster: RayCluster,
        runtime_env: typing.Optional[str],
        shutdown_after_job_finishes: typing.Optional[bool] = True,
        ttl_seconds_after_finished: typing.Optional[int] = 3600,
    ):
        """
        :param ray_cluster: The ray cluster definition for this job.
        :param runtime_env: Runtime environment, passed through to the proto as
            a string. NOTE(review): the expected encoding is not visible in
            this module; confirm against the caller.
        :param shutdown_after_job_finishes: Forwarded to the proto field of the
            same name.
        :param ttl_seconds_after_finished: Forwarded to the proto field of the
            same name; an integer number of seconds (default 3600).
        """
        self._ray_cluster = ray_cluster
        self._runtime_env = runtime_env
        self._shutdown_after_job_finishes = shutdown_after_job_finishes
        self._ttl_seconds_after_finished = ttl_seconds_after_finished

    @property
    def ray_cluster(self) -> RayCluster:
        """
        The ray cluster definition for this job.
        """
        return self._ray_cluster

    @property
    def runtime_env(self) -> typing.Optional[str]:
        """
        The runtime environment string, or None when not provided.
        """
        return self._runtime_env

    @property
    def shutdown_after_job_finishes(self) -> bool:
        """
        Value forwarded to the ``shutdown_after_job_finishes`` proto field.
        """
        return self._shutdown_after_job_finishes

    @property
    def ttl_seconds_after_finished(self) -> int:
        """
        Value forwarded to the ``ttl_seconds_after_finished`` proto field.
        """
        return self._ttl_seconds_after_finished

    def to_flyte_idl(self) -> _ray_pb2.RayJob:
        """
        :rtype: flyteidl.plugins._ray_pb2.RayJob
        """
        return _ray_pb2.RayJob(
            ray_cluster=self.ray_cluster.to_flyte_idl(),
            runtime_env=self.runtime_env,
            shutdown_after_job_finishes=self.shutdown_after_job_finishes,
            ttl_seconds_after_finished=self.ttl_seconds_after_finished,
        )

    @classmethod
    def from_flyte_idl(cls, proto: _ray_pb2.RayJob):
        """
        :param flyteidl.plugins._ray_pb2.RayJob proto:
        :rtype: RayJob
        """
        return cls(
            ray_cluster=RayCluster.from_flyte_idl(proto.ray_cluster) if proto.ray_cluster else None,
            runtime_env=proto.runtime_env,
            shutdown_after_job_finishes=proto.shutdown_after_job_finishes,
            ttl_seconds_after_finished=proto.ttl_seconds_after_finished,
        )