Backport part code from #2288
继盛 committed Aug 18, 2021
1 parent 3c26528 commit 62e83da
Showing 8 changed files with 88 additions and 149 deletions.
17 changes: 9 additions & 8 deletions mars/oscar/backends/context.py
@@ -166,14 +166,15 @@ async def cancel(self,

async def wait_actor_pool_recovered(self, address: str,
main_address: str = None):
# get main_pool_address
control_message = ControlMessage(
new_message_id(), main_address,
ControlMessageType.get_config,
'main_pool_address',
protocol=DEFAULT_PROTOCOL)
main_address = self._process_result_message(
await self._call(main_address, control_message))
if main_address is None:
# get main_pool_address
control_message = ControlMessage(
new_message_id(), address,
ControlMessageType.get_config,
'main_pool_address',
protocol=DEFAULT_PROTOCOL)
main_address = self._process_result_message(
await self._call(address, control_message))

# if address is main pool, it is never recovered
if address == main_address:
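
The rewritten lookup only queries the 'main_pool_address' config entry when the caller did not pass main_address, and it asks the pool at `address` for it instead of requiring the main address up front. A minimal caller-side sketch, assuming a hypothetical context object `ctx` and actor ref `ref` (neither is defined in this diff):

    from mars.oscar import ServerClosed  # same error type imported by the ray backend below

    # hedged sketch: retry once the crashed sub pool has been recovered;
    # after this change the main pool address no longer has to be supplied
    try:
        result = await ref.do_work()
    except ServerClosed:
        await ctx.wait_actor_pool_recovered(ref.address)
        result = await ref.do_work()
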
14 changes: 14 additions & 0 deletions mars/oscar/backends/mars/pool.py
@@ -23,6 +23,7 @@

from ....utils import get_next_port
from ..config import ActorPoolConfig
from ..message import CreateActorMessage
from ..pool import MainActorPoolBase, SubActorPoolBase, _register_message_handler


@@ -198,6 +199,19 @@ async def kill_sub_pool(self, process: multiprocessing.Process,
async def is_sub_pool_alive(self, process: multiprocessing.Process):
return process.is_alive()

async def recover_sub_pool(self, address: str):
process_index = self._config.get_process_index(address)
# process dead, restart it
# remember always use spawn to recover sub pool
self.sub_processes[address] = await self.__class__.start_sub_pool(
self._config, process_index, 'spawn')

if self._auto_recover == 'actor':
# need to recover all created actors
for _, message in self._allocated_actors[address].values():
create_actor_message: CreateActorMessage = message
await self.call(address, create_actor_message)


@_register_message_handler
class SubActorPool(SubActorPoolBase):
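
The added recover_sub_pool restarts the dead process with the spawn start method and, when the pool was configured to auto-recover actors, replays the recorded CreateActorMessage for every actor that lived at the lost address. A rough sketch of how that behaviour is requested (the auto_recover keyword of create_actor_pool is an assumption based on the surrounding code, not something this diff shows):

    import mars.oscar as mo

    # hedged sketch: a local pool whose dead sub processes are respawned and
    # whose actors are re-created from the recorded CreateActorMessage objects
    pool = await mo.create_actor_pool('127.0.0.1', n_process=2, auto_recover='actor')
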
13 changes: 0 additions & 13 deletions mars/oscar/backends/pool.py
@@ -1025,19 +1025,6 @@ def process_sub_pool_lost(self, address: str):
# or only recover process, remove all created actors
self._allocated_actors[address] = dict()

async def recover_sub_pool(self, address: str):
process_index = self._config.get_process_index(address)
# process dead, restart it
# remember always use spawn to recover sub pool
self.sub_processes[address] = await self.__class__.start_sub_pool(
self._config, process_index, 'spawn')

if self._auto_recover == 'actor':
# need to recover all created actors
for _, message in self._allocated_actors[address].values():
create_actor_message: CreateActorMessage = message
await self.call(address, create_actor_message)

async def monitor_sub_pools(self):
try:
while not self._stopped.is_set():
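
This deletion mirrors the addition in mars/oscar/backends/mars/pool.py above: recovering a lost sub pool is no longer implemented once in the shared MainActorPoolBase, because respawning an OS process and restarting a Ray actor are different operations. A simplified sketch of the resulting split (whether the base class keeps an abstract stub is not visible in this diff):

    # hedged sketch of the per-backend override after this commit
    class MainActorPool(MainActorPoolBase):        # mars backend (multiprocessing)
        async def recover_sub_pool(self, address: str):
            ...  # respawn the sub pool process with 'spawn', then replay actor creation

    class RayMainActorPool(MainActorPoolBase):     # ray backend
        async def recover_sub_pool(self, address: str):
            ...  # restart the named Ray sub pool actor, then replay actor creation
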
108 changes: 17 additions & 91 deletions mars/oscar/backends/ray/communication.py
@@ -18,7 +18,7 @@
import logging
from abc import ABC
from collections import namedtuple
from typing import Any, Callable, Coroutine, Dict, Type, Union
from typing import Any, Callable, Coroutine, Dict, Type
from urllib.parse import urlparse

from ....utils import implements, classproperty
@@ -88,7 +88,8 @@ def closed(self) -> bool:


class RayClientChannel(RayChannelBase):
"""A channel from ray driver to ray actor. Use ray call reply for client channel recv.
"""
A channel from ray driver/actor to ray actor. Use ray call reply for client channel recv.
"""
__slots__ = '_peer_actor',

@@ -121,16 +122,21 @@ async def recv(self):
if isinstance(result, RayChannelException):
raise result.exc_value.with_traceback(result.exc_traceback)
return deserialize(*result)
except ray.exceptions.RayActorError:
if not self._closed.is_set():
# raise an EOFError as the SocketChannel does
raise EOFError('Server may be closed')
except (RuntimeError, ServerClosed) as e: # pragma: no cover
if not self._closed.is_set():
raise e


class RayServerChannel(RayChannelBase):
"""A channel from ray actor to ray driver. Since ray actor can't call ray driver,
we use ray call reply for server channel send. Note that there can't be multiple
channel message sends for one received message, or else it will be taken as next
message's reply.
"""
A channel from ray actor to ray driver/actor. Since ray actor can't call ray driver,
we use ray call reply for server channel send. Note that there can't be multiple
channel message sends for one received message, or else it will be taken as next
message's reply.
"""
__slots__ = '_out_queue', '_msg_recv_counter', '_msg_sent_counter'

@@ -179,58 +185,6 @@ async def __on_ray_recv__(self, message):
return await done.pop()


class RayTwoWayChannel(RayChannelBase):
"""
Channel for communications between ray actors.
"""

__slots__ = '_peer_actor',

def __init__(self,
local_address: str = None,
dest_address: str = None,
channel_index: int = None,
channel_id: ChannelID = None,
compression=None):
super().__init__(local_address, dest_address, channel_index, channel_id, compression)
# ray actor should be created with the address as the name.
self._peer_actor: 'ray.actor.ActorHandle' = ray.get_actor(dest_address)

@implements(Channel.send)
async def send(self, message: Any):
if self._closed.is_set(): # pragma: no cover
raise ChannelClosed('Channel already closed, cannot send message')
object_ref = self._peer_actor.__on_ray_recv__.remote(self.channel_id, serialize(message))
with debug_async_timeout('ray_object_retrieval_timeout',
'Message that the server sent to actor %s is %s', self.dest_address, message):
result = await object_ref
if isinstance(result, RayChannelException): # pragma: no cover
# Peer create channel may fail
raise result.exc_value.with_traceback(result.exc_traceback)
elif isinstance(result, Exception):
raise result
else:
assert result is None

@implements(Channel.recv)
async def recv(self):
if self._closed.is_set(): # pragma: no cover
raise ChannelClosed('Channel already closed, cannot write message')
try:
result = await self._in_queue.get()
if isinstance(result, RayChannelException): # pragma: no cover
raise result.exc_value.with_traceback(result.exc_traceback)
return deserialize(*result)
except (RuntimeError, ServerClosed) as e: # pragma: no cover
if not self._closed.is_set():
raise e

async def __on_ray_recv__(self, message):
if self._closed.is_set(): # pragma: no cover
raise ChannelClosed('Channel already closed')
await self._in_queue.put(message)


@register_server
class RayServer(Server):
__slots__ = '_closed', '_channels', '_tasks'
@@ -242,7 +196,7 @@ class RayServer(Server):
def __init__(self, address, channel_handler: Callable[[Channel], Coroutine] = None):
super().__init__(address, channel_handler)
self._closed = asyncio.Event()
self._channels: Dict[ChannelID, Union[RayServerChannel, RayTwoWayChannel]] = dict()
self._channels: Dict[ChannelID, RayServerChannel] = dict()
self._tasks: Dict[ChannelID, asyncio.Task] = dict()

@classproperty
@@ -309,7 +263,7 @@ async def join(self, timeout=None):
@implements(Server.on_connected)
async def on_connected(self, *args, **kwargs):
channel = args[0]
assert isinstance(channel, (RayServerChannel, RayTwoWayChannel))
assert isinstance(channel, RayServerChannel)
if kwargs: # pragma: no cover
raise TypeError(f'{type(self).__name__} got unexpected '
f'arguments: {",".join(kwargs)}')
@@ -337,21 +291,12 @@ async def __on_ray_recv__(self, channel_id: ChannelID, message):
f'from channel {channel_id}')
channel = self._channels.get(channel_id)
if not channel:
peer_local_address, _, peer_channel_index, peer_dest_address = channel_id
if not peer_local_address:
# Peer is a ray driver.
channel = RayServerChannel(peer_dest_address, peer_channel_index, channel_id)
else:
# Peer is a ray actor too.
channel = RayTwoWayChannel(
peer_dest_address, peer_local_address, peer_channel_index, channel_id)
_, _, peer_channel_index, peer_dest_address = channel_id
channel = RayServerChannel(peer_dest_address, peer_channel_index, channel_id)
self._channels[channel_id] = channel
self._tasks[channel_id] = asyncio.create_task(self.on_connected(channel))
return await channel.__on_ray_recv__(message)

def register_channel(self, channel_id: ChannelID, channel: RayTwoWayChannel):
self._channels[channel_id] = channel


@register_client
class RayClient(Client):
@@ -375,26 +320,7 @@ async def connect(dest_address: str,
if urlparse(dest_address).scheme != RayServer.scheme: # pragma: no cover
raise ValueError(f'Destination address should start with "ray://" '
f'for RayClient, got {dest_address}')
if local_address:
if not RayServer.is_ray_actor_started(): # pragma: no cover
logger.info(f'Current process needs to be a ray actor for using {local_address} '
f'as address to receive messages from other clients. '
f'Use RayClientChannel instead to receive messages from {dest_address}.')
local_address = None # make peer use RayClientChannel too.
client_channel = RayClientChannel(dest_address)
else:
server = RayServer.get_instance()
if server is None:
raise RuntimeError(f'RayServer needs to be created first before RayClient '
f'local_address {local_address}, dest_address {dest_address}')
# Current process is a ray actor connecting to another ray actor.
client_channel = RayTwoWayChannel(local_address, dest_address)
# The RayServer will push message to this channel's queue after it received
# the message from the `dest_address` actor.
server.register_channel(client_channel.channel_id, client_channel)
else:
# Current process is a ray driver
client_channel = RayClientChannel(dest_address)
client_channel = RayClientChannel(dest_address)
client = RayClient(local_address, dest_address, client_channel)
return client

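
With RayTwoWayChannel removed, a Ray driver and a Ray actor both reach a peer actor through RayClientChannel, and the peer answers through the matching RayServerChannel created in __on_ray_recv__; a peer that has died now surfaces as EOFError, the same signal SocketChannel gives. A hedged sketch of one round trip (the address and the message object are illustrative placeholders):

    # hedged sketch: request/reply over the simplified channel pair
    channel = RayClientChannel('ray://my_pg/0/1')  # peer ray actor is looked up by this name
    await channel.send(request_message)            # carried by a ray remote call to __on_ray_recv__
    reply = await channel.recv()                   # the remote call's return value, deserialized
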
6 changes: 5 additions & 1 deletion mars/oscar/backends/ray/driver.py
@@ -13,6 +13,7 @@
# limitations under the License.

import logging
import os
from numbers import Number
from typing import Dict

@@ -56,10 +57,13 @@ def stop_cluster(cls):
pg_name = cls._cluster_info['pg_name']
pg = cls._cluster_info['pg_group']
for index, bundle_spec in enumerate(pg.bundle_specs):
n_process = int(bundle_spec["CPU"])
n_process = int(bundle_spec["CPU"]) + 1
for process_index in range(n_process):
address = process_placement_to_address(pg_name, index, process_index=process_index)
try:
if 'COV_CORE_SOURCE' in os.environ: # pragma: no cover
# must clean up first, or coverage info lost
ray.get(ray.get_actor(address).cleanup.remote())
ray.kill(ray.get_actor(address))
except: # noqa: E722 # nosec # pylint: disable=bare-except
pass
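
The extra +1 appears to account for the main pool actor: each placement-group bundle hosts a main pool at process_index 0 plus one sub pool per CPU, so stop_cluster now visits CPU + 1 addresses per bundle, and (only when collecting coverage) asks each actor to flush its data before being killed. For a bundle with CPU=2, the loop visits three addresses:

    # hedged sketch: addresses enumerated for one bundle with CPU=2 -> process_index 0, 1, 2
    for process_index in range(int(bundle_spec["CPU"]) + 1):
        # index 0 is the bundle's main pool, 1..CPU are its sub pools (per the utils helpers)
        address = process_placement_to_address(pg_name, index, process_index=process_index)
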
68 changes: 38 additions & 30 deletions mars/oscar/backends/ray/pool.py
@@ -23,14 +23,15 @@
from enum import Enum
from typing import List, Optional

from ....serialization.ray import register_ray_serializers
from ....utils import lazy_import
from .communication import ChannelID, RayServer, RayChannelException
from .utils import process_address_to_placement, process_placement_to_address, get_placement_group
from ..config import ActorPoolConfig
from ..message import CreateActorMessage
from ..pool import AbstractActorPool, MainActorPoolBase, SubActorPoolBase, create_actor_pool, _register_message_handler
from ..router import Router
from ... import ServerClosed
from .communication import ChannelID, RayServer, RayChannelException
from .utils import process_address_to_placement, process_placement_to_address, get_placement_group
from ....serialization.ray import register_ray_serializers
from ....utils import lazy_import

ray = lazy_import('ray')
logger = logging.getLogger(__name__)
@@ -78,27 +79,30 @@ async def start_sub_pool(
actor_handle = ray.remote(RaySubPool).options(
num_cpus=num_cpus, name=external_address,
max_concurrency=10000, # By default, 1000 tasks can be running concurrently.
max_restarts=-1, # Auto restarts by ray
placement_group=pg, placement_group_bundle_index=bundle_index).remote()
await actor_handle.start.remote(actor_pool_config, process_index)
return actor_handle

async def recover_sub_pool(self, address: str):
process = self.sub_processes[address]
process_index = self._config.get_process_index(address)
await process.start.remote(self._config, process_index)

if self._auto_recover == 'actor':
# need to recover all created actors
for _, message in self._allocated_actors[address].values():
create_actor_message: CreateActorMessage = message
await self.call(address, create_actor_message)

async def kill_sub_pool(self, process: 'ray.actor.ActorHandle', force: bool = False):
if 'COV_CORE_SOURCE' in os.environ and not force: # pragma: no cover
# must shutdown gracefully, or coverage info lost
process.exit_actor.remote()
wait_time, waited_time = 10, 0
while await self.is_sub_pool_alive(process): # pragma: no cover
if waited_time > wait_time:
logger.info('''Can't stop %s in %s, kill sub_pool forcibly''', process, wait_time)
await self._kill_actor_forcibly(process)
return
await asyncio.sleep(0.1)
wait_time += 0.1
else:
await self._kill_actor_forcibly(process)
# must clean up first, or coverage info lost
await process.cleanup.remote()
await self._kill_actor_forcibly(process)

async def _kill_actor_forcibly(self, process: 'ray.actor.ActorHandle'):
ray.kill(process)
ray.kill(process, no_restart=False)
wait_time, waited_time = 30, 0
while await self.is_sub_pool_alive(process): # pragma: no cover
if waited_time > wait_time:
@@ -108,13 +112,19 @@ async def _kill_actor_forcibly(self, process: 'ray.actor.ActorHandle'):

async def is_sub_pool_alive(self, process: 'ray.actor.ActorHandle'):
try:
await process.health_check.remote()
# try to call a method of the sub pool; if it succeeds, the pool is alive.
await process.actor_pool.remote('health_check')
return True
except Exception:
logger.info("Detected RaySubPool %s died", process)
return False


class PoolStatus(Enum):
HEALTHY = 0
UNHEALTHY = 1


@_register_message_handler
class RaySubActorPool(SubActorPoolBase):

@@ -127,10 +137,8 @@ async def stop(self):
finally:
self._stopped.set()


class PoolStatus(Enum):
HEALTHY = 0
UNHEALTHY = 1
def health_check(self): # noqa: R0201 # pylint: disable=no-self-use
return PoolStatus.HEALTHY


class RayPoolBase(ABC):
@@ -170,12 +178,9 @@ async def __on_ray_recv__(self, channel_id: ChannelID, message):
if self._ray_server is None:
raise ServerClosed(f'Remote server {channel_id.dest_address} closed')
return await self._ray_server.__on_ray_recv__(channel_id, message)
except Exception: # pragma: no cover
except Exception: # pragma: no cover
return RayChannelException(*sys.exc_info())

def health_check(self): # noqa: R0201 # pylint: disable=no-self-use
return PoolStatus.HEALTHY

async def actor_pool(self, attribute, *args, **kwargs):
attr = getattr(self._actor_pool, attribute)
if isinstance(attr, types.MethodType):
@@ -185,10 +190,13 @@ async def actor_pool(self, attribute, *args, **kwargs):
else:
return attr

def exit_actor(self):
"""Exiting current process gracefully."""
logger.info('Exiting %s of process %s now', self, os.getpid())
ray.actor.exit_actor()
def cleanup(self):
logger.info('Cleaning up %s of process %s now', self, os.getpid())
try:
from pytest_cov.embed import cleanup
cleanup()
except ImportError: # pragma: no cover
pass


class RayMainPool(RayPoolBase):
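
Two details from this file are worth noting: health_check now lives on the wrapped RaySubActorPool rather than on the Ray actor wrapper, so liveness probing goes through the generic actor_pool proxy shown above, and the wrapper's cleanup method flushes pytest-cov data so coverage is not lost when the actor is killed. A hedged sketch of the probe (the handle name is a placeholder, not from this diff):

    # hedged sketch: what is_sub_pool_alive effectively does with a sub pool actor handle
    status = await sub_pool_handle.actor_pool.remote('health_check')
    alive = status == PoolStatus.HEALTHY
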