Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKPORT][Ray] Support Ray client mode (#2773) #2796

Merged
merged 2 commits into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/platform-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ jobs:
if [ -n "$WITH_RAY" ]; then
pip install ray[default]==1.9.2
pip install xgboost_ray==0.1.5
pip install --upgrade numpy
fi
if [ -n "$RUN_DASK" ]; then
pip install dask[complete] mimesis sklearn
Expand Down Expand Up @@ -131,7 +130,7 @@ jobs:
coverage combine build/ && coverage report
fi
if [ -n "$WITH_RAY" ]; then
pytest $PYTEST_CONFIG --timeout=300 -m ray
pytest $PYTEST_CONFIG --durations=0 --timeout=300 -v -s -m ray
coverage report
fi
if [ -n "$RUN_DASK" ]; then
Expand Down
2 changes: 1 addition & 1 deletion mars/deploy/oscar/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ async def request_worker(
# TODO rescale ray placement group instead of creating new placement group
pg_name = f"{self._pg_name}_{next(self._pg_counter)}"
pg = ray.util.placement_group(name=pg_name, bundles=[bundle], strategy="SPREAD")
create_pg_timeout = timeout or 60
create_pg_timeout = timeout or 120
try:
await asyncio.wait_for(pg.ready(), timeout=create_pg_timeout)
except asyncio.TimeoutError:
Expand Down
21 changes: 21 additions & 0 deletions mars/deploy/oscar/tests/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,25 @@ def test_new_cluster_in_ray(stop_ray):

@require_ray
def test_new_ray_session(stop_ray):
    # Thin pytest entry point; the shared assertions live in
    # new_ray_session_test() so the same checks can also be reused by the
    # ray-client-mode test (test_ray_client).
    new_ray_session_test()


def new_ray_session_test():
session = new_ray_session(session_id="abc", worker_num=2)
mt.random.RandomState(0).rand(100, 5).sum().execute()
session.execute(mt.random.RandomState(0).rand(100, 5).sum())
mars.execute(mt.random.RandomState(0).rand(100, 5).sum())
session = new_ray_session(session_id="abcd", worker_num=2, default=True)
session.execute(mt.random.RandomState(0).rand(100, 5).sum())
mars.execute(mt.random.RandomState(0).rand(100, 5).sum())
df = md.DataFrame(mt.random.rand(100, 4), columns=list("abcd"))
# Convert mars dataframe to ray dataset
ds = md.to_ray_dataset(df)
print(ds.schema(), ds.count())
ds.filter(lambda row: row["a"] > 0.5).show(5)
# Convert ray dataset to mars dataframe
df2 = md.read_ray_dataset(ds)
print(df2.head(5).execute())
# Test ray cluster exists after session got gc.
del session
import gc
Expand All @@ -168,6 +180,15 @@ def test_new_ray_session(stop_ray):
mars.execute(mt.random.RandomState(0).rand(100, 5).sum())


@require_ray
def test_ray_client(ray_large_cluster):
    """Run the shared ray session checks through a ray client connection."""
    from ray._private.client_mode_hook import enable_client_mode
    from ray.util.client.ray_client_helpers import ray_start_client_server

    # Stand up an in-process client/server pair, then force client-mode
    # semantics while the shared session test body runs.
    with ray_start_client_server():
        with enable_client_mode():
            new_ray_session_test()


@require_ray
@pytest.mark.parametrize(
"test_option",
Expand Down
46 changes: 38 additions & 8 deletions mars/oscar/backends/ray/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import logging
from abc import ABC
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Callable, Coroutine, Dict, Type
from urllib.parse import urlparse

Expand All @@ -38,6 +39,35 @@
)


def _argwrapper_unpickler(serialized_message):
    """Unpickle hook: rebuild an _ArgWrapper from mars-serialized parts."""
    payload = deserialize(*serialized_message)
    return _ArgWrapper(payload)


@dataclass
class _ArgWrapper:
message: Any = None

def __init__(self, message):
self.message = message

def __reduce__(self):
return _argwrapper_unpickler, (serialize(self.message),)


if ray:
    # Capture the original (unbound) serialization entry points so the
    # wrappers below still reach them after the reassignment.
    _ray_serialize = ray.serialization.SerializationContext.serialize
    _ray_deserialize_object = ray.serialization.SerializationContext._deserialize_object

    def _serialize(self, value):
        # Pure pass-through to ray's own serializer.
        return _ray_serialize(self, value)

    def _deserialize_object(self, data, metadata, object_ref):
        # Pure pass-through to ray's own deserializer.
        return _ray_deserialize_object(self, data, metadata, object_ref)

    # NOTE(review): the wrappers add no behavior of their own; presumably the
    # indirection exists so these hooks can be located/patched by name (e.g.
    # under ray client mode) — confirm against the ray client integration.
    ray.serialization.SerializationContext.serialize = _serialize
    ray.serialization.SerializationContext._deserialize_object = _deserialize_object


class RayChannelException(Exception):
def __init__(self, exc_type, exc_value: BaseException, exc_traceback):
self.exc_type = exc_type
Expand Down Expand Up @@ -121,7 +151,7 @@ async def send(self, message: Any):
(
message,
self._peer_actor.__on_ray_recv__.remote(
self.channel_id, serialize(message)
self.channel_id, _ArgWrapper(message)
),
)
)
Expand All @@ -139,7 +169,7 @@ async def recv(self):
result = await object_ref
if isinstance(result, RayChannelException):
raise result.exc_value.with_traceback(result.exc_traceback)
return deserialize(*result)
return result.message
except ray.exceptions.RayActorError:
if not self._closed.is_set():
# raise a EOFError as the SocketChannel does
Expand Down Expand Up @@ -178,7 +208,7 @@ async def send(self, message: Any):
# Current process is ray actor, we use ray call reply to send message to ray driver/actor.
# Note that we can only send once for every received message in the channel,
# otherwise it will be taken as another message's reply.
await self._out_queue.put(serialize(message))
await self._out_queue.put(message)
self._msg_sent_counter += 1
assert (
self._msg_sent_counter <= self._msg_recv_counter
Expand All @@ -189,19 +219,19 @@ async def recv(self):
if self._closed.is_set(): # pragma: no cover
raise ChannelClosed("Channel already closed, cannot write message")
try:
return deserialize(*(await self._in_queue.get()))
return await self._in_queue.get()
except RuntimeError: # pragma: no cover
if not self._closed.is_set():
raise

    async def __on_ray_recv__(self, message_wrapper):
        """This method will be invoked when current process is a ray actor rather than a ray driver"""
        # Unwrap the payload: the _ArgWrapper detour makes ray serialize the
        # message with mars' serializer instead of plain pickle.
        self._msg_recv_counter += 1
        await self._in_queue.put(message_wrapper.message)
        # Block until the matching reply is published to the out queue; the
        # ray call's return value is the reply for this message.
        result_message = await self._out_queue.get()
        if self._closed.is_set():  # pragma: no cover
            raise ChannelClosed("Channel already closed")
        # Wrap the reply so the return trip is serialized the same way.
        return _ArgWrapper(result_message)

@implements(Channel.close)
async def close(self):
Expand Down Expand Up @@ -319,7 +349,7 @@ def stopped(self) -> bool:
async def __on_ray_recv__(self, channel_id: ChannelID, message):
if self.stopped:
raise ServerClosed(
f"Remote server {self.address} closed, but got message {deserialize(*message)} "
f"Remote server {self.address} closed, but got message {message} "
f"from channel {channel_id}"
)
channel = self._channels.get(channel_id)
Expand Down
2 changes: 1 addition & 1 deletion mars/oscar/backends/ray/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def setup_cluster(cls, address_to_resources: Dict[str, Dict[str, Number]]):
pg_name, bundles = addresses_to_placement_group_info(address_to_resources)
logger.info("Creating placement group %s with bundles %s.", pg_name, bundles)
pg = ray.util.placement_group(name=pg_name, bundles=bundles, strategy="SPREAD")
create_pg_timeout = 60
create_pg_timeout = 120
done, _ = ray.wait([pg.ready()], timeout=create_pg_timeout)
if not done: # pragma: no cover
raise Exception(
Expand Down
14 changes: 13 additions & 1 deletion mars/oscar/backends/ray/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import sys
import time
import threading
import types
from abc import ABC, abstractmethod
from enum import Enum
Expand Down Expand Up @@ -205,7 +206,10 @@ class RayPoolBase(ABC):
def __new__(cls, *args, **kwargs):
if not _is_windows:
try:
if "COV_CORE_SOURCE" in os.environ: # pragma: no branch
if (
"COV_CORE_SOURCE" in os.environ
and threading.current_thread() is threading.main_thread()
): # pragma: no branch
# register coverage hooks on SIGTERM
from pytest_cov.embed import cleanup_on_sigterm

Expand Down Expand Up @@ -375,3 +379,11 @@ async def check_main_pool_alive(self, main_pool):
"Main pool %s has exited, exit current sub pool now.", main_pool
)
os._exit(0)


if ray and ray.is_initialized():
    # When using ray client to connect to a ray cluster, the ray server acts
    # as the mars driver: every mars call from the mars client reaches the
    # ray server first, and the ray server then issues ray calls to the other
    # actors. The ray server therefore needs the mars ray serializers
    # registered as well.
    # TODO: need a way to check whether the current process is a ray server.
    register_ray_serializers()