Skip to content

Commit

Permalink
[BACKPORT] Improve wait_actor_pool_recovered (#2328) (#2350)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xuye (Chris) Qin authored Aug 18, 2021
1 parent 97443d7 commit d0282e4
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 158 deletions.
17 changes: 9 additions & 8 deletions mars/oscar/backends/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,15 @@ async def cancel(self,

async def wait_actor_pool_recovered(self, address: str,
main_address: str = None):
# get main_pool_address
control_message = ControlMessage(
new_message_id(), main_address,
ControlMessageType.get_config,
'main_pool_address',
protocol=DEFAULT_PROTOCOL)
main_address = self._process_result_message(
await self._call(main_address, control_message))
if main_address is None:
# get main_pool_address
control_message = ControlMessage(
new_message_id(), address,
ControlMessageType.get_config,
'main_pool_address',
protocol=DEFAULT_PROTOCOL)
main_address = self._process_result_message(
await self._call(address, control_message))

# if address is main pool, it is never recovered
if address == main_address:
Expand Down
14 changes: 14 additions & 0 deletions mars/oscar/backends/mars/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from ....utils import get_next_port
from ..config import ActorPoolConfig
from ..message import CreateActorMessage
from ..pool import MainActorPoolBase, SubActorPoolBase, _register_message_handler


Expand Down Expand Up @@ -198,6 +199,19 @@ async def kill_sub_pool(self, process: multiprocessing.Process,
async def is_sub_pool_alive(self, process: multiprocessing.Process):
return process.is_alive()

async def recover_sub_pool(self, address: str):
process_index = self._config.get_process_index(address)
# process dead, restart it
# remember always use spawn to recover sub pool
self.sub_processes[address] = await self.__class__.start_sub_pool(
self._config, process_index, 'spawn')

if self._auto_recover == 'actor':
# need to recover all created actors
for _, message in self._allocated_actors[address].values():
create_actor_message: CreateActorMessage = message
await self.call(address, create_actor_message)


@_register_message_handler
class SubActorPool(SubActorPoolBase):
Expand Down
30 changes: 11 additions & 19 deletions mars/oscar/backends/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,11 +864,12 @@ async def handle_control_command(self,
self.sub_processes[message.address],
timeout=timeout,
force=force)
if self._auto_recover:
self._recover_events[message.address] = asyncio.Event()
processor.result = ResultMessage(message.message_id, True,
protocol=message.protocol)
elif message.control_message_type == ControlMessageType.wait_pool_recovered:
if self._auto_recover and message.address not in self._recover_events:
self._recover_events[message.address] = asyncio.Event()

event = self._recover_events.get(message.address, None)
if event is not None:
await event.wait()
Expand Down Expand Up @@ -1014,42 +1015,33 @@ async def is_sub_pool_alive(self, process: SubProcessHandle):
bool
"""

@abstractmethod
def recover_sub_pool(self, address):
"""Recover a sub actor pool"""

def process_sub_pool_lost(self, address: str):
if self._auto_recover in (False, 'process'):
# process down, when not auto_recover
# or only recover process, remove all created actors
self._allocated_actors[address] = dict()

async def recover_sub_pool(self, address: str):
process_index = self._config.get_process_index(address)
# process dead, restart it
# remember always use spawn to recover sub pool
self.sub_processes[address] = await self.__class__.start_sub_pool(
self._config, process_index, 'spawn')

if self._auto_recover == 'actor':
# need to recover all created actors
for _, message in self._allocated_actors[address].values():
create_actor_message: CreateActorMessage = message
await self.call(address, create_actor_message)

async def monitor_sub_pools(self):
try:
while not self._stopped.is_set():
for address in self.sub_processes:
process = self.sub_processes[address]
recover_events_discovered = (address in self._recover_events)
if not await self.is_sub_pool_alive(process): # pragma: no cover
if self._on_process_down is not None:
self._on_process_down(self, address)
self.process_sub_pool_lost(address)
if self._auto_recover:
if address not in self._recover_events:
self._recover_events[address] = asyncio.Event()
await self.recover_sub_pool(address)
if self._on_process_recover is not None:
self._on_process_recover(self, address)
event = self._recover_events.pop(address)
event.set()
if recover_events_discovered:
event = self._recover_events.pop(address)
event.set()

# check every half second
await asyncio.sleep(.5)
Expand Down
108 changes: 17 additions & 91 deletions mars/oscar/backends/ray/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
from abc import ABC
from collections import namedtuple
from typing import Any, Callable, Coroutine, Dict, Type, Union
from typing import Any, Callable, Coroutine, Dict, Type
from urllib.parse import urlparse

from ....utils import implements, classproperty
Expand Down Expand Up @@ -88,7 +88,8 @@ def closed(self) -> bool:


class RayClientChannel(RayChannelBase):
"""A channel from ray driver to ray actor. Use ray call reply for client channel recv.
"""
A channel from ray driver/actor to ray actor. Use ray call reply for client channel recv.
"""
__slots__ = '_peer_actor',

Expand Down Expand Up @@ -121,16 +122,21 @@ async def recv(self):
if isinstance(result, RayChannelException):
raise result.exc_value.with_traceback(result.exc_traceback)
return deserialize(*result)
except ray.exceptions.RayActorError:
if not self._closed.is_set():
# raise a EOFError as the SocketChannel does
raise EOFError('Server may be closed')
except (RuntimeError, ServerClosed) as e: # pragma: no cover
if not self._closed.is_set():
raise e


class RayServerChannel(RayChannelBase):
"""A channel from ray actor to ray driver. Since ray actor can't call ray driver,
we use ray call reply for server channel send. Note that there can't be multiple
channel message sends for one received message, or else it will be taken as next
message's reply.
"""
A channel from ray actor to ray driver/actor. Since ray actor can't call ray driver,
we use ray call reply for server channel send. Note that there can't be multiple
channel message sends for one received message, or else it will be taken as next
message's reply.
"""
__slots__ = '_out_queue', '_msg_recv_counter', '_msg_sent_counter'

Expand Down Expand Up @@ -179,58 +185,6 @@ async def __on_ray_recv__(self, message):
return await done.pop()


class RayTwoWayChannel(RayChannelBase):
"""
Channel for communications between ray actors.
"""

__slots__ = '_peer_actor',

def __init__(self,
local_address: str = None,
dest_address: str = None,
channel_index: int = None,
channel_id: ChannelID = None,
compression=None):
super().__init__(local_address, dest_address, channel_index, channel_id, compression)
# ray actor should be created with the address as the name.
self._peer_actor: 'ray.actor.ActorHandle' = ray.get_actor(dest_address)

@implements(Channel.send)
async def send(self, message: Any):
if self._closed.is_set(): # pragma: no cover
raise ChannelClosed('Channel already closed, cannot send message')
object_ref = self._peer_actor.__on_ray_recv__.remote(self.channel_id, serialize(message))
with debug_async_timeout('ray_object_retrieval_timeout',
'Message that the server sent to actor %s is %s', self.dest_address, message):
result = await object_ref
if isinstance(result, RayChannelException): # pragma: no cover
# Peer create channel may fail
raise result.exc_value.with_traceback(result.exc_traceback)
elif isinstance(result, Exception):
raise result
else:
assert result is None

@implements(Channel.recv)
async def recv(self):
if self._closed.is_set(): # pragma: no cover
raise ChannelClosed('Channel already closed, cannot write message')
try:
result = await self._in_queue.get()
if isinstance(result, RayChannelException): # pragma: no cover
raise result.exc_value.with_traceback(result.exc_traceback)
return deserialize(*result)
except (RuntimeError, ServerClosed) as e: # pragma: no cover
if not self._closed.is_set():
raise e

async def __on_ray_recv__(self, message):
if self._closed.is_set(): # pragma: no cover
raise ChannelClosed('Channel already closed')
await self._in_queue.put(message)


@register_server
class RayServer(Server):
__slots__ = '_closed', '_channels', '_tasks'
Expand All @@ -242,7 +196,7 @@ class RayServer(Server):
def __init__(self, address, channel_handler: Callable[[Channel], Coroutine] = None):
super().__init__(address, channel_handler)
self._closed = asyncio.Event()
self._channels: Dict[ChannelID, Union[RayServerChannel, RayTwoWayChannel]] = dict()
self._channels: Dict[ChannelID, RayServerChannel] = dict()
self._tasks: Dict[ChannelID, asyncio.Task] = dict()

@classproperty
Expand Down Expand Up @@ -309,7 +263,7 @@ async def join(self, timeout=None):
@implements(Server.on_connected)
async def on_connected(self, *args, **kwargs):
channel = args[0]
assert isinstance(channel, (RayServerChannel, RayTwoWayChannel))
assert isinstance(channel, RayServerChannel)
if kwargs: # pragma: no cover
raise TypeError(f'{type(self).__name__} got unexpected '
f'arguments: {",".join(kwargs)}')
Expand Down Expand Up @@ -337,21 +291,12 @@ async def __on_ray_recv__(self, channel_id: ChannelID, message):
f'from channel {channel_id}')
channel = self._channels.get(channel_id)
if not channel:
peer_local_address, _, peer_channel_index, peer_dest_address = channel_id
if not peer_local_address:
# Peer is a ray driver.
channel = RayServerChannel(peer_dest_address, peer_channel_index, channel_id)
else:
# Peer is a ray actor too.
channel = RayTwoWayChannel(
peer_dest_address, peer_local_address, peer_channel_index, channel_id)
_, _, peer_channel_index, peer_dest_address = channel_id
channel = RayServerChannel(peer_dest_address, peer_channel_index, channel_id)
self._channels[channel_id] = channel
self._tasks[channel_id] = asyncio.create_task(self.on_connected(channel))
return await channel.__on_ray_recv__(message)

def register_channel(self, channel_id: ChannelID, channel: RayTwoWayChannel):
self._channels[channel_id] = channel


@register_client
class RayClient(Client):
Expand All @@ -375,26 +320,7 @@ async def connect(dest_address: str,
if urlparse(dest_address).scheme != RayServer.scheme: # pragma: no cover
raise ValueError(f'Destination address should start with "ray://" '
f'for RayClient, got {dest_address}')
if local_address:
if not RayServer.is_ray_actor_started(): # pragma: no cover
logger.info(f'Current process needs to be a ray actor for using {local_address} '
f'as address to receive messages from other clients. '
f'Use RayClientChannel instead to receive messages from {dest_address}.')
local_address = None # make peer use RayClientChannel too.
client_channel = RayClientChannel(dest_address)
else:
server = RayServer.get_instance()
if server is None:
raise RuntimeError(f'RayServer needs to be created first before RayClient '
f'local_address {local_address}, dest_address {dest_address}')
# Current process ia a ray actor, is connecting to another ray actor.
client_channel = RayTwoWayChannel(local_address, dest_address)
# The RayServer will push message to this channel's queue after it received
# the message from the `dest_address` actor.
server.register_channel(client_channel.channel_id, client_channel)
else:
# Current process ia a ray driver
client_channel = RayClientChannel(dest_address)
client_channel = RayClientChannel(dest_address)
client = RayClient(local_address, dest_address, client_channel)
return client

Expand Down
6 changes: 5 additions & 1 deletion mars/oscar/backends/ray/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import logging
import os
from numbers import Number
from typing import Dict

Expand Down Expand Up @@ -56,10 +57,13 @@ def stop_cluster(cls):
pg_name = cls._cluster_info['pg_name']
pg = cls._cluster_info['pg_group']
for index, bundle_spec in enumerate(pg.bundle_specs):
n_process = int(bundle_spec["CPU"])
n_process = int(bundle_spec["CPU"]) + 1
for process_index in range(n_process):
address = process_placement_to_address(pg_name, index, process_index=process_index)
try:
if 'COV_CORE_SOURCE' in os.environ: # pragma: no cover
# must clean up first, or coverage info lost
ray.get(ray.get_actor(address).cleanup.remote())
ray.kill(ray.get_actor(address))
except: # noqa: E722 # nosec # pylint: disable=bare-except
pass
Expand Down
Loading

0 comments on commit d0282e4

Please sign in to comment.