From 3c2652844312e36015f2ec236ec14d1d812d7f3f Mon Sep 17 00:00:00 2001
From: keyile <32282736+keyile@users.noreply.github.com>
Date: Sat, 14 Aug 2021 15:13:10 +0800
Subject: [PATCH 1/2] Improve wait_actor_pool_recovered (#2328)

(cherry picked from commit 2fb00d8fbbaf1464dbf5e208c16d3237b62ede53)
---
 mars/oscar/backends/pool.py                   |  17 ++-
 mars/oscar/backends/ray/pool.py               |   3 +
 .../oscar/backends/ray/tests/test_ray_pool.py | 115 +++++++++++++++++-
 3 files changed, 126 insertions(+), 9 deletions(-)

diff --git a/mars/oscar/backends/pool.py b/mars/oscar/backends/pool.py
index c9e47bdba5..b8b1af3434 100644
--- a/mars/oscar/backends/pool.py
+++ b/mars/oscar/backends/pool.py
@@ -864,11 +864,12 @@ async def handle_control_command(self,
                     self.sub_processes[message.address],
                     timeout=timeout,
                     force=force)
-                if self._auto_recover:
-                    self._recover_events[message.address] = asyncio.Event()
                 processor.result = ResultMessage(message.message_id, True,
                                                  protocol=message.protocol)
             elif message.control_message_type == ControlMessageType.wait_pool_recovered:
+                if self._auto_recover and message.address not in self._recover_events:
+                    self._recover_events[message.address] = asyncio.Event()
+
                 event = self._recover_events.get(message.address, None)
                 if event is not None:
                     await event.wait()
@@ -1014,6 +1015,10 @@ async def is_sub_pool_alive(self, process: SubProcessHandle):
         bool
         """
 
+    @abstractmethod
+    def recover_sub_pool(self, address):
+        """Recover a sub actor pool"""
+
     def process_sub_pool_lost(self, address: str):
         if self._auto_recover in (False, 'process'):
             # process down, when not auto_recover
@@ -1038,18 +1043,18 @@ async def monitor_sub_pools(self):
             while not self._stopped.is_set():
                 for address in self.sub_processes:
                     process = self.sub_processes[address]
+                    recover_events_discovered = (address in self._recover_events)
                     if not await self.is_sub_pool_alive(process):  # pragma: no cover
                         if self._on_process_down is not None:
                             self._on_process_down(self, address)
                         self.process_sub_pool_lost(address)
                         if self._auto_recover:
-                            if address not in self._recover_events:
-                                self._recover_events[address] = asyncio.Event()
                             await self.recover_sub_pool(address)
                             if self._on_process_recover is not None:
                                 self._on_process_recover(self, address)
-                            event = self._recover_events.pop(address)
-                            event.set()
+                            if recover_events_discovered:
+                                event = self._recover_events.pop(address)
+                                event.set()
                 # check every half second
                 await asyncio.sleep(.5)
diff --git a/mars/oscar/backends/ray/pool.py b/mars/oscar/backends/ray/pool.py
index b50697efc6..23e22964e2 100644
--- a/mars/oscar/backends/ray/pool.py
+++ b/mars/oscar/backends/ray/pool.py
@@ -28,6 +28,7 @@
 from ..config import ActorPoolConfig
 from ..pool import AbstractActorPool, MainActorPoolBase, SubActorPoolBase, create_actor_pool, _register_message_handler
 from ..router import Router
+from ... import ServerClosed
 from .communication import ChannelID, RayServer, RayChannelException
 from .utils import process_address_to_placement, process_placement_to_address, get_placement_group
 
@@ -166,6 +167,8 @@ def _set_ray_server(self, actor_pool: AbstractActorPool):
     async def __on_ray_recv__(self, channel_id: ChannelID, message):
         """Method for communication based on ray actors"""
         try:
+            if self._ray_server is None:
+                raise ServerClosed(f'Remote server {channel_id.dest_address} closed')
             return await self._ray_server.__on_ray_recv__(channel_id, message)
         except Exception: # pragma: no cover
             return RayChannelException(*sys.exc_info())
diff --git a/mars/oscar/backends/ray/tests/test_ray_pool.py b/mars/oscar/backends/ray/tests/test_ray_pool.py
index 256c7a73bf..19e2841ed9 100644
--- a/mars/oscar/backends/ray/tests/test_ray_pool.py
+++ b/mars/oscar/backends/ray/tests/test_ray_pool.py
@@ -11,17 +11,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
+import os
 
 import pytest
 
+import mars.oscar as mo
+from mars.oscar.errors import ServerClosed
+from mars.oscar.backends.allocate_strategy import ProcessIndex, MainPool
+from mars.oscar.backends.ray.pool import RayMainPool, RayMainActorPool, create_actor_pool, PoolStatus
+from mars.oscar.backends.ray.utils import process_placement_to_address
+from mars.oscar.context import get_context
 from mars.tests.core import require_ray
-from .....utils import lazy_import
-from ..pool import RayMainPool, RayMainActorPool, create_actor_pool
-from ..utils import process_placement_to_address
+from mars.utils import lazy_import
 
 ray = lazy_import('ray')
 
 
+class TestActor(mo.Actor):
+    async def kill(self, address, uid):
+        actor_ref = await mo.actor_ref(address, uid)
+        task = asyncio.create_task(actor_ref.crash())
+        return await task
+
+    async def crash(self):
+        os._exit(0)
+
+
 @require_ray
 @pytest.mark.asyncio
 async def test_main_pool(ray_start_regular):
@@ -71,3 +87,96 @@ async def test_shutdown_sub_pool(ray_start_regular):
     with pytest.raises(ray.exceptions.RayActorError):
         await sub_pool_handle1.health_check.remote()
         await sub_pool_handle2.health_check.remote()
+
+
+@require_ray
+@pytest.mark.asyncio
+async def test_server_closed(ray_start_regular):
+    pg_name, n_process = 'ray_cluster', 1
+    pg = ray.util.placement_group(name=pg_name, bundles=[{'CPU': n_process}])
+    ray.get(pg.ready())
+    address = process_placement_to_address(pg_name, 0, process_index=0)
+    # start the actor pool
+    actor_handle = await mo.create_actor_pool(address, n_process=n_process)
+    await actor_handle.actor_pool.remote('start')
+
+    ctx = get_context()
+    actor_main = await ctx.create_actor(
+        TestActor, address=address, uid='Test-main',
+        allocate_strategy=ProcessIndex(0))
+
+    actor_sub = await ctx.create_actor(
+        TestActor, address=address, uid='Test-sub',
+        allocate_strategy=ProcessIndex(1))
+
+    # test calling from ray driver to ray actor
+    task = asyncio.create_task(actor_sub.crash())
+
+    with pytest.raises(ServerClosed):
+        # the process has already died,
+        # so ServerClosed will be raised
+        await task
+
+    # wait for recovery of the sub pool
+    await ctx.wait_actor_pool_recovered(actor_sub.address, address)
+
+    # test calling from ray actor to ray actor
+    task = asyncio.create_task(actor_main.kill(actor_sub.address, 'Test-sub'))
+
+    with pytest.raises(ServerClosed):
+        await task
+
+
+@require_ray
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    'auto_recover',
+    [False, True, 'actor', 'process']
+)
+async def test_auto_recover(ray_start_regular, auto_recover):
+    pg_name, n_process = 'ray_cluster', 1
+    pg = ray.util.placement_group(name=pg_name, bundles=[{'CPU': n_process}])
+    assert pg.wait(timeout_seconds=20)
+    address = process_placement_to_address(pg_name, 0, process_index=0)
+    actor_handle = await mo.create_actor_pool(address, n_process=n_process, auto_recover=auto_recover)
+    await actor_handle.actor_pool.remote('start')
+
+    ctx = get_context()
+
+    # waiting for recovery of the main pool always returns immediately
+    await ctx.wait_actor_pool_recovered(address, address)
+
+    # create actor on main
+    actor_ref = await ctx.create_actor(
+        TestActor, address=address,
+        allocate_strategy=MainPool())
+
+    with pytest.raises(ValueError):
+        # cannot kill actors on main pool
+        await mo.kill_actor(actor_ref)
+
+    # create actor
+    actor_ref = await ctx.create_actor(
+        TestActor, address=address,
+        allocate_strategy=ProcessIndex(1))
+    # kill_actor will kill the corresponding process
+    await ctx.kill_actor(actor_ref)
+
+    if auto_recover:
+        await ctx.wait_actor_pool_recovered(actor_ref.address, address)
+        sub_pool_address = process_placement_to_address(pg_name, 0, process_index=1)
+        sub_pool_handle = ray.get_actor(sub_pool_address)
+        assert await sub_pool_handle.actor_pool.remote('health_check') == PoolStatus.HEALTHY
+
+        expect_has_actor = True if auto_recover in ['actor', True] else False
+        assert await ctx.has_actor(actor_ref) is expect_has_actor
+    else:
+        with pytest.raises((ServerClosed, ConnectionError)):
+            await ctx.has_actor(actor_ref)
+
+    if 'COV_CORE_SOURCE' in os.environ:
+        for addr in [process_placement_to_address(pg_name, 0, process_index=i) for i in range(2)]:
+            # must save the local reference until this is fixed:
+            # https://github.com/ray-project/ray/issues/7815
+            ray_actor = ray.get_actor(addr)
+            ray.get(ray_actor.cleanup.remote())

From 62e83da63a87004af2560474dfbb1de285690cd2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=BB=A7=E7=9B=9B?=
Date: Wed, 18 Aug 2021 11:18:40 +0800
Subject: [PATCH 2/2] Backport part code from #2288

---
 mars/oscar/backends/context.py                |  17 +--
 mars/oscar/backends/mars/pool.py              |  14 +++
 mars/oscar/backends/pool.py                   |  13 ---
 mars/oscar/backends/ray/communication.py      | 108 +++--------------
 mars/oscar/backends/ray/driver.py             |   6 +-
 mars/oscar/backends/ray/pool.py               |  68 ++++++-----
 .../oscar/backends/ray/tests/test_ray_pool.py |  10 +-
 mars/oscar/context.pyx                        |   1 -
 8 files changed, 88 insertions(+), 149 deletions(-)

diff --git a/mars/oscar/backends/context.py b/mars/oscar/backends/context.py
index 91a86c876d..b37b33b76a 100644
--- a/mars/oscar/backends/context.py
+++ b/mars/oscar/backends/context.py
@@ -166,14 +166,15 @@ async def cancel(self,
     async def wait_actor_pool_recovered(self,
                                         address: str,
                                         main_address: str = None):
-        # get main_pool_address
-        control_message = ControlMessage(
-            new_message_id(), main_address,
-            ControlMessageType.get_config,
-            'main_pool_address',
-            protocol=DEFAULT_PROTOCOL)
-        main_address = self._process_result_message(
-            await self._call(main_address, control_message))
+        if main_address is None:
+            # get main_pool_address
+            control_message = ControlMessage(
+                new_message_id(), address,
+                ControlMessageType.get_config,
+                'main_pool_address',
+                protocol=DEFAULT_PROTOCOL)
+            main_address = self._process_result_message(
+                await self._call(address, control_message))
 
         # if address is main pool, it is never recovered
         if address == main_address:
diff --git a/mars/oscar/backends/mars/pool.py b/mars/oscar/backends/mars/pool.py
index 2c385da93c..d18dc34c8b 100644
--- a/mars/oscar/backends/mars/pool.py
+++ b/mars/oscar/backends/mars/pool.py
@@ -23,6 +23,7 @@
 from ....utils import get_next_port
 
 from ..config import ActorPoolConfig
+from ..message import CreateActorMessage
 from ..pool import MainActorPoolBase, SubActorPoolBase, _register_message_handler
 
 
@@ -198,6 +199,19 @@ async def kill_sub_pool(self, process: multiprocessing.Process,
     async def is_sub_pool_alive(self, process: multiprocessing.Process):
         return process.is_alive()
 
+    async def recover_sub_pool(self, address: str):
+        process_index = self._config.get_process_index(address)
+        # process dead, restart it
+        # remember to always use spawn to recover the sub pool
+        self.sub_processes[address] = await self.__class__.start_sub_pool(
+            self._config, process_index, 'spawn')
+
+        if self._auto_recover == 'actor':
+            # need to recover all created actors
+            for _, message in self._allocated_actors[address].values():
+                create_actor_message: CreateActorMessage = message
+                await self.call(address, create_actor_message)
+
 
 @_register_message_handler
 class SubActorPool(SubActorPoolBase):
diff --git a/mars/oscar/backends/pool.py b/mars/oscar/backends/pool.py
index b8b1af3434..3d32669221 100644
--- a/mars/oscar/backends/pool.py
+++ b/mars/oscar/backends/pool.py
@@ -1025,19 +1025,6 @@ def process_sub_pool_lost(self, address: str):
             # or only recover process, remove all created actors
             self._allocated_actors[address] = dict()
 
-    async def recover_sub_pool(self, address: str):
-        process_index = self._config.get_process_index(address)
-        # process dead, restart it
-        # remember always use spawn to recover sub pool
-        self.sub_processes[address] = await self.__class__.start_sub_pool(
-            self._config, process_index, 'spawn')
-
-        if self._auto_recover == 'actor':
-            # need to recover all created actors
-            for _, message in self._allocated_actors[address].values():
-                create_actor_message: CreateActorMessage = message
-                await self.call(address, create_actor_message)
-
     async def monitor_sub_pools(self):
         try:
             while not self._stopped.is_set():
diff --git a/mars/oscar/backends/ray/communication.py b/mars/oscar/backends/ray/communication.py
index b4bb471a27..e4c8a6a74b 100644
--- a/mars/oscar/backends/ray/communication.py
+++ b/mars/oscar/backends/ray/communication.py
@@ -18,7 +18,7 @@
 import logging
 from abc import ABC
 from collections import namedtuple
-from typing import Any, Callable, Coroutine, Dict, Type, Union
+from typing import Any, Callable, Coroutine, Dict, Type
 from urllib.parse import urlparse
 
 from ....utils import implements, classproperty
@@ -88,7 +88,8 @@ def closed(self) -> bool:
 
 
 class RayClientChannel(RayChannelBase):
-    """A channel from ray driver to ray actor. Use ray call reply for client channel recv.
+    """
+    A channel from ray driver/actor to ray actor. Use ray call reply for client channel recv.
     """
     __slots__ = '_peer_actor',
 
@@ -121,16 +122,21 @@ async def recv(self):
             if isinstance(result, RayChannelException):
                 raise result.exc_value.with_traceback(result.exc_traceback)
             return deserialize(*result)
+        except ray.exceptions.RayActorError:
+            if not self._closed.is_set():
+                # raise an EOFError as the SocketChannel does
+                raise EOFError('Server may be closed')
         except (RuntimeError, ServerClosed) as e:  # pragma: no cover
             if not self._closed.is_set():
                 raise e
 
 
 class RayServerChannel(RayChannelBase):
-    """A channel from ray actor to ray driver. Since ray actor can't call ray driver,
-    we use ray call reply for server channel send. Note that there can't be multiple
-    channel message sends for one received message, or else it will be taken as next
-    message's reply.
+    """
+    A channel from ray actor to ray driver/actor. Since ray actor can't call ray driver,
+    we use ray call reply for server channel send. Note that there can't be multiple
+    channel message sends for one received message, or else it will be taken as next
+    message's reply.
     """
     __slots__ = '_out_queue', '_msg_recv_counter', '_msg_sent_counter'
 
@@ -179,58 +185,6 @@ async def __on_ray_recv__(self, message):
         return await done.pop()
 
 
-class RayTwoWayChannel(RayChannelBase):
-    """
-    Channel for communications between ray actors.
-    """
-
-    __slots__ = '_peer_actor',
-
-    def __init__(self,
-                 local_address: str = None,
-                 dest_address: str = None,
-                 channel_index: int = None,
-                 channel_id: ChannelID = None,
-                 compression=None):
-        super().__init__(local_address, dest_address, channel_index, channel_id, compression)
-        # ray actor should be created with the address as the name.
-        self._peer_actor: 'ray.actor.ActorHandle' = ray.get_actor(dest_address)
-
-    @implements(Channel.send)
-    async def send(self, message: Any):
-        if self._closed.is_set():  # pragma: no cover
-            raise ChannelClosed('Channel already closed, cannot send message')
-        object_ref = self._peer_actor.__on_ray_recv__.remote(self.channel_id, serialize(message))
-        with debug_async_timeout('ray_object_retrieval_timeout',
-                                 'Message that the server sent to actor %s is %s', self.dest_address, message):
-            result = await object_ref
-        if isinstance(result, RayChannelException):  # pragma: no cover
-            # Peer create channel may fail
-            raise result.exc_value.with_traceback(result.exc_traceback)
-        elif isinstance(result, Exception):
-            raise result
-        else:
-            assert result is None
-
-    @implements(Channel.recv)
-    async def recv(self):
-        if self._closed.is_set():  # pragma: no cover
-            raise ChannelClosed('Channel already closed, cannot write message')
-        try:
-            result = await self._in_queue.get()
-            if isinstance(result, RayChannelException):  # pragma: no cover
-                raise result.exc_value.with_traceback(result.exc_traceback)
-            return deserialize(*result)
-        except (RuntimeError, ServerClosed) as e:  # pragma: no cover
-            if not self._closed.is_set():
-                raise e
-
-    async def __on_ray_recv__(self, message):
-        if self._closed.is_set():  # pragma: no cover
-            raise ChannelClosed('Channel already closed')
-        await self._in_queue.put(message)
-
-
 @register_server
 class RayServer(Server):
     __slots__ = '_closed', '_channels', '_tasks'
@@ -242,7 +196,7 @@ class RayServer(Server):
     def __init__(self, address, channel_handler: Callable[[Channel], Coroutine] = None):
         super().__init__(address, channel_handler)
         self._closed = asyncio.Event()
-        self._channels: Dict[ChannelID, Union[RayServerChannel, RayTwoWayChannel]] = dict()
+        self._channels: Dict[ChannelID, RayServerChannel] = dict()
         self._tasks: Dict[ChannelID, asyncio.Task] = dict()
 
     @classproperty
@@ -309,7 +263,7 @@ async def join(self, timeout=None):
     @implements(Server.on_connected)
     async def on_connected(self, *args, **kwargs):
         channel = args[0]
-        assert isinstance(channel, (RayServerChannel, RayTwoWayChannel))
+        assert isinstance(channel, RayServerChannel)
         if kwargs:  # pragma: no cover
             raise TypeError(f'{type(self).__name__} got unexpected '
                             f'arguments: {",".join(kwargs)}')
@@ -337,21 +291,12 @@ async def __on_ray_recv__(self, channel_id: ChannelID, message):
                              f'from channel {channel_id}')
         channel = self._channels.get(channel_id)
         if not channel:
-            peer_local_address, _, peer_channel_index, peer_dest_address = channel_id
-            if not peer_local_address:
-                # Peer is a ray driver.
-                channel = RayServerChannel(peer_dest_address, peer_channel_index, channel_id)
-            else:
-                # Peer is a ray actor too.
-                channel = RayTwoWayChannel(
-                    peer_dest_address, peer_local_address, peer_channel_index, channel_id)
+            _, _, peer_channel_index, peer_dest_address = channel_id
+            channel = RayServerChannel(peer_dest_address, peer_channel_index, channel_id)
             self._channels[channel_id] = channel
             self._tasks[channel_id] = asyncio.create_task(self.on_connected(channel))
         return await channel.__on_ray_recv__(message)
 
-    def register_channel(self, channel_id: ChannelID, channel: RayTwoWayChannel):
-        self._channels[channel_id] = channel
-
 
 @register_client
 class RayClient(Client):
@@ -375,26 +320,7 @@ async def connect(dest_address: str,
         if urlparse(dest_address).scheme != RayServer.scheme:  # pragma: no cover
             raise ValueError(f'Destination address should start with "ray://" '
                              f'for RayClient, got {dest_address}')
-        if local_address:
-            if not RayServer.is_ray_actor_started():  # pragma: no cover
-                logger.info(f'Current process needs to be a ray actor for using {local_address} '
-                            f'as address to receive messages from other clients. '
-                            f'Use RayClientChannel instead to receive messages from {dest_address}.')
-                local_address = None  # make peer use RayClientChannel too.
-                client_channel = RayClientChannel(dest_address)
-            else:
-                server = RayServer.get_instance()
-                if server is None:
-                    raise RuntimeError(f'RayServer needs to be created first before RayClient '
-                                       f'local_address {local_address}, dest_address {dest_address}')
-                # Current process ia a ray actor, is connecting to another ray actor.
-                client_channel = RayTwoWayChannel(local_address, dest_address)
-                # The RayServer will push message to this channel's queue after it received
-                # the message from the `dest_address` actor.
-                server.register_channel(client_channel.channel_id, client_channel)
-        else:
-            # Current process ia a ray driver
-            client_channel = RayClientChannel(dest_address)
+        client_channel = RayClientChannel(dest_address)
         client = RayClient(local_address, dest_address, client_channel)
         return client
 
diff --git a/mars/oscar/backends/ray/driver.py b/mars/oscar/backends/ray/driver.py
index 0eb3f1b65b..0234185f61 100644
--- a/mars/oscar/backends/ray/driver.py
+++ b/mars/oscar/backends/ray/driver.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import os
 from numbers import Number
 from typing import Dict
 
@@ -56,10 +57,13 @@ def stop_cluster(cls):
         pg_name = cls._cluster_info['pg_name']
         pg = cls._cluster_info['pg_group']
         for index, bundle_spec in enumerate(pg.bundle_specs):
-            n_process = int(bundle_spec["CPU"])
+            n_process = int(bundle_spec["CPU"]) + 1
            for process_index in range(n_process):
                 address = process_placement_to_address(pg_name, index, process_index=process_index)
                 try:
+                    if 'COV_CORE_SOURCE' in os.environ:  # pragma: no cover
+                        # must clean up first, or coverage info lost
+                        ray.get(ray.get_actor(address).cleanup.remote())
                     ray.kill(ray.get_actor(address))
                 except:  # noqa: E722  # nosec  # pylint: disable=bare-except
                     pass
diff --git a/mars/oscar/backends/ray/pool.py b/mars/oscar/backends/ray/pool.py
index 23e22964e2..c823c744bf 100644
--- a/mars/oscar/backends/ray/pool.py
+++ b/mars/oscar/backends/ray/pool.py
@@ -23,14 +23,15 @@
 from enum import Enum
 from typing import List, Optional
 
-from ....serialization.ray import register_ray_serializers
-from ....utils import lazy_import
+from .communication import ChannelID, RayServer, RayChannelException
+from .utils import process_address_to_placement, process_placement_to_address, get_placement_group
 from ..config import ActorPoolConfig
+from ..message import CreateActorMessage
 from ..pool import AbstractActorPool, MainActorPoolBase, SubActorPoolBase, create_actor_pool, _register_message_handler
 from ..router import Router
 from ... import ServerClosed
-from .communication import ChannelID, RayServer, RayChannelException
-from .utils import process_address_to_placement, process_placement_to_address, get_placement_group
+from ....serialization.ray import register_ray_serializers
+from ....utils import lazy_import
 
 ray = lazy_import('ray')
 logger = logging.getLogger(__name__)
 
@@ -78,27 +79,30 @@ async def start_sub_pool(
         actor_handle = ray.remote(RaySubPool).options(
             num_cpus=num_cpus, name=external_address,
             max_concurrency=10000,  # By default, 1000 tasks can be running concurrently.
+            max_restarts=-1,  # Auto restarts by ray
             placement_group=pg,
             placement_group_bundle_index=bundle_index).remote()
         await actor_handle.start.remote(actor_pool_config, process_index)
         return actor_handle
 
+    async def recover_sub_pool(self, address: str):
+        process = self.sub_processes[address]
+        process_index = self._config.get_process_index(address)
+        await process.start.remote(self._config, process_index)
+
+        if self._auto_recover == 'actor':
+            # need to recover all created actors
+            for _, message in self._allocated_actors[address].values():
+                create_actor_message: CreateActorMessage = message
+                await self.call(address, create_actor_message)
+
     async def kill_sub_pool(self, process: 'ray.actor.ActorHandle',
                             force: bool = False):
         if 'COV_CORE_SOURCE' in os.environ and not force:  # pragma: no cover
-            # must shutdown gracefully, or coverage info lost
-            process.exit_actor.remote()
-            wait_time, waited_time = 10, 0
-            while await self.is_sub_pool_alive(process):  # pragma: no cover
-                if waited_time > wait_time:
-                    logger.info('''Can't stop %s in %s, kill sub_pool forcibly''', process, wait_time)
-                    await self._kill_actor_forcibly(process)
-                    return
-                await asyncio.sleep(0.1)
-                wait_time += 0.1
-        else:
-            await self._kill_actor_forcibly(process)
+            # must clean up first, or coverage info lost
+            await process.cleanup.remote()
+        await self._kill_actor_forcibly(process)
 
     async def _kill_actor_forcibly(self, process: 'ray.actor.ActorHandle'):
-        ray.kill(process)
+        ray.kill(process, no_restart=False)
         wait_time, waited_time = 30, 0
         while await self.is_sub_pool_alive(process):  # pragma: no cover
             if waited_time > wait_time:
@@ -108,13 +112,19 @@ async def _kill_actor_forcibly(self, process: 'ray.actor.ActorHandle'):
 
     async def is_sub_pool_alive(self, process: 'ray.actor.ActorHandle'):
         try:
-            await process.health_check.remote()
+            # try to call a method of the sub pool; if it succeeds, it's alive.
+            await process.actor_pool.remote('health_check')
             return True
         except Exception:
             logger.info("Detected RaySubPool %s died", process)
             return False
 
 
+class PoolStatus(Enum):
+    HEALTHY = 0
+    UNHEALTHY = 1
+
+
 @_register_message_handler
 class RaySubActorPool(SubActorPoolBase):
 
@@ -127,10 +137,8 @@ async def stop(self):
         finally:
             self._stopped.set()
 
-
-class PoolStatus(Enum):
-    HEALTHY = 0
-    UNHEALTHY = 1
+    def health_check(self):  # noqa: R0201  # pylint: disable=no-self-use
+        return PoolStatus.HEALTHY
 
 
 class RayPoolBase(ABC):
@@ -170,12 +178,9 @@ async def __on_ray_recv__(self, channel_id: ChannelID, message):
             if self._ray_server is None:
                 raise ServerClosed(f'Remote server {channel_id.dest_address} closed')
             return await self._ray_server.__on_ray_recv__(channel_id, message)
-        except Exception: # pragma: no cover
+        except Exception:  # pragma: no cover
             return RayChannelException(*sys.exc_info())
 
-    def health_check(self):  # noqa: R0201  # pylint: disable=no-self-use
-        return PoolStatus.HEALTHY
-
     async def actor_pool(self, attribute, *args, **kwargs):
         attr = getattr(self._actor_pool, attribute)
         if isinstance(attr, types.MethodType):
@@ -185,10 +190,13 @@ async def actor_pool(self, attribute, *args, **kwargs):
         else:
             return attr
 
-    def exit_actor(self):
-        """Exiting current process gracefully."""
-        logger.info('Exiting %s of process %s now', self, os.getpid())
-        ray.actor.exit_actor()
+    def cleanup(self):
+        logger.info('Cleaning up %s of process %s now', self, os.getpid())
+        try:
+            from pytest_cov.embed import cleanup
+            cleanup()
+        except ImportError:  # pragma: no cover
+            pass
 
 
 class RayMainPool(RayPoolBase):
diff --git a/mars/oscar/backends/ray/tests/test_ray_pool.py b/mars/oscar/backends/ray/tests/test_ray_pool.py
index 19e2841ed9..ce45e1e593 100644
--- a/mars/oscar/backends/ray/tests/test_ray_pool.py
+++ b/mars/oscar/backends/ray/tests/test_ray_pool.py
@@ -76,17 +76,17 @@ async def test_shutdown_sub_pool(ray_start_regular):
     address = process_placement_to_address(pg_name, 0, process_index=0)
     actor_handle = ray.remote(RayMainPool).options(
         name=address, placement_group=pg, placement_group_bundle_index=bundle_index).remote()
-    await actor_handle.start.remote(address, n_process)
+    await actor_handle.start.remote(address, n_process, auto_recover=False)
     sub_pool_address1 = process_placement_to_address(pg_name, 0, process_index=1)
     sub_pool_handle1 = ray.get_actor(sub_pool_address1)
     sub_pool_address2 = process_placement_to_address(pg_name, 0, process_index=2)
     sub_pool_handle2 = ray.get_actor(sub_pool_address2)
     await actor_handle.actor_pool.remote('stop_sub_pool', sub_pool_address1, sub_pool_handle1, force=True)
     await actor_handle.actor_pool.remote('stop_sub_pool', sub_pool_address2, sub_pool_handle2, force=False)
-    import ray.exceptions
-    with pytest.raises(ray.exceptions.RayActorError):
-        await sub_pool_handle1.health_check.remote()
-        await sub_pool_handle2.health_check.remote()
+    with pytest.raises(AttributeError, match='NoneType'):
+        await sub_pool_handle1.actor_pool.remote('health_check')
+    with pytest.raises(AttributeError, match='NoneType'):
+        await sub_pool_handle2.actor_pool.remote('health_check')
 
 
 @require_ray
diff --git a/mars/oscar/context.pyx b/mars/oscar/context.pyx
index b90d340b79..d89756057a 100644
--- a/mars/oscar/context.pyx
+++ b/mars/oscar/context.pyx
@@ -196,7 +196,6 @@ cdef class ClientActorContext(BaseActorContext):
         return context.send(actor_ref, message, wait_response=wait_response)
 
     def wait_actor_pool_recovered(self, str address, str main_address = None):
-        main_address = main_address or address
         context = self._get_backend_context(address)
         return context.wait_actor_pool_recovered(address, main_address)