Improve wait_actor_pool_recovered #2328

Merged · 3 commits · Aug 14, 2021
20 changes: 10 additions & 10 deletions mars/oscar/backends/pool.py
@@ -864,15 +864,11 @@ async def handle_control_command(self,
self.sub_processes[message.address],
timeout=timeout,
force=force)
if self._auto_recover:
self._recover_events[message.address] = asyncio.Event()
processor.result = ResultMessage(message.message_id, True,
protocol=message.protocol)
elif message.control_message_type == ControlMessageType.wait_pool_recovered:
# check the aliveness of the sub pool first, in case the monitor task hasn't found it.
if not await self.is_sub_pool_alive(self.sub_processes[message.address]):
if self._auto_recover and message.address not in self._recover_events:
self._recover_events[message.address] = asyncio.Event()
if self._auto_recover and message.address not in self._recover_events:
self._recover_events[message.address] = asyncio.Event()
@wjsi (Member) commented on Aug 13, 2021:

I think a single event object plus a lock is enough. The recover event is created during pool initialization and replaced with a new event object every time a monitor loop finishes.


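(For illustration, here is a minimal sketch of the single-event-plus-lock approach described in the comment above. It assumes plain asyncio primitives; the RecoverSignal class and its method names are hypothetical and not part of the Mars code base.)

```python
import asyncio


class RecoverSignal:
    """Hypothetical sketch: one shared event plus a lock instead of a
    per-address event dict."""

    def __init__(self):
        self._lock = asyncio.Lock()
        # created once during pool initialization
        self._event = asyncio.Event()

    async def wait(self):
        # callers of wait_pool_recovered block on the current event object
        await self._event.wait()

    async def notify_recovered(self):
        # called when a monitor loop finishes recovery: wake all waiters,
        # then swap in a fresh event so later waiters block again
        async with self._lock:
            self._event.set()
            self._event = asyncio.Event()
```

Under this sketch, waiters would call wait() from the wait_pool_recovered control handler, and the monitor loop would call notify_recovered() after bringing a sub pool back up.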
event = self._recover_events.get(message.address, None)
if event is not None:
@@ -1019,6 +1015,10 @@ async def is_sub_pool_alive(self, process: SubProcessHandle):
bool
"""

@abstractmethod
def recover_sub_pool(self, address):
"""Recover a sub actor pool"""

def process_sub_pool_lost(self, address: str):
if self._auto_recover in (False, 'process'):
# process down, when not auto_recover
@@ -1030,18 +1030,18 @@ async def monitor_sub_pools(self):
while not self._stopped.is_set():
for address in self.sub_processes:
process = self.sub_processes[address]
recover_events_discovered = (address in self._recover_events)
if not await self.is_sub_pool_alive(process): # pragma: no cover
if self._on_process_down is not None:
self._on_process_down(self, address)
self.process_sub_pool_lost(address)
if self._auto_recover:
if address not in self._recover_events:
self._recover_events[address] = asyncio.Event()
await self.recover_sub_pool(address)
if self._on_process_recover is not None:
self._on_process_recover(self, address)
event = self._recover_events.pop(address)
event.set()
if recover_events_discovered:
event = self._recover_events.pop(address)
event.set()

# check every half second
await asyncio.sleep(.5)
3 changes: 3 additions & 0 deletions mars/oscar/backends/ray/pool.py
@@ -29,6 +29,7 @@
from ..message import CreateActorMessage
from ..pool import AbstractActorPool, MainActorPoolBase, SubActorPoolBase, create_actor_pool, _register_message_handler
from ..router import Router
from ... import ServerClosed
from ....serialization.ray import register_ray_serializers
from ....utils import lazy_import

@@ -174,6 +175,8 @@ def _set_ray_server(self, actor_pool: AbstractActorPool):
async def __on_ray_recv__(self, channel_id: ChannelID, message):
"""Method for communication based on ray actors"""
try:
if self._ray_server is None:
raise ServerClosed(f'Remote server {channel_id.dest_address} closed')
return await self._ray_server.__on_ray_recv__(channel_id, message)
except Exception: # pragma: no cover
return RayChannelException(*sys.exc_info())
115 changes: 112 additions & 3 deletions mars/oscar/backends/ray/tests/test_ray_pool.py
@@ -11,17 +11,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os

import pytest

import mars.oscar as mo
from mars.oscar.errors import ServerClosed
from mars.oscar.backends.allocate_strategy import ProcessIndex, MainPool
from mars.oscar.backends.ray.pool import RayMainPool, RayMainActorPool, create_actor_pool, PoolStatus
from mars.oscar.backends.ray.utils import process_placement_to_address
from mars.oscar.context import get_context
from mars.tests.core import require_ray
from .....utils import lazy_import
from ..pool import RayMainPool, RayMainActorPool, create_actor_pool
from ..utils import process_placement_to_address
from mars.utils import lazy_import

ray = lazy_import('ray')


class TestActor(mo.Actor):
async def kill(self, address, uid):
actor_ref = await mo.actor_ref(address, uid)
task = asyncio.create_task(actor_ref.crash())
return await task

async def crash(self):
os._exit(0)


@require_ray
@pytest.mark.asyncio
async def test_main_pool(ray_start_regular):
@@ -71,3 +87,96 @@ async def test_shutdown_sub_pool(ray_start_regular):
await sub_pool_handle1.actor_pool.remote('health_check')
with pytest.raises(AttributeError, match='NoneType'):
await sub_pool_handle2.actor_pool.remote('health_check')


@require_ray
@pytest.mark.asyncio
async def test_server_closed(ray_start_regular):
pg_name, n_process = 'ray_cluster', 1
pg = ray.util.placement_group(name=pg_name, bundles=[{'CPU': n_process}])
ray.get(pg.ready())
address = process_placement_to_address(pg_name, 0, process_index=0)
# start the actor pool
actor_handle = await mo.create_actor_pool(address, n_process=n_process)
await actor_handle.actor_pool.remote('start')

ctx = get_context()
actor_main = await ctx.create_actor(
TestActor, address=address, uid='Test-main',
allocate_strategy=ProcessIndex(0))

actor_sub = await ctx.create_actor(
TestActor, address=address, uid='Test-sub',
allocate_strategy=ProcessIndex(1))

# test calling from ray driver to ray actor
task = asyncio.create_task(actor_sub.crash())

with pytest.raises(ServerClosed):
# process already died,
# ServerClosed will be raised
await task

# wait for the sub pool to recover
await ctx.wait_actor_pool_recovered(actor_sub.address, address)

# test calling from ray actor to ray actor
task = asyncio.create_task(actor_main.kill(actor_sub.address, 'Test-sub'))

with pytest.raises(ServerClosed):
await task


@require_ray
@pytest.mark.asyncio
@pytest.mark.parametrize(
'auto_recover',
[False, True, 'actor', 'process']
)
async def test_auto_recover(ray_start_regular, auto_recover):
pg_name, n_process = 'ray_cluster', 1
pg = ray.util.placement_group(name=pg_name, bundles=[{'CPU': n_process}])
assert pg.wait(timeout_seconds=20)
address = process_placement_to_address(pg_name, 0, process_index=0)
actor_handle = await mo.create_actor_pool(address, n_process=n_process, auto_recover=auto_recover)
await actor_handle.actor_pool.remote('start')

ctx = get_context()

# waiting for recovery of the main pool always returns immediately
await ctx.wait_actor_pool_recovered(address, address)

# create actor on main
actor_ref = await ctx.create_actor(
TestActor, address=address,
allocate_strategy=MainPool())

with pytest.raises(ValueError):
# cannot kill actors on main pool
await mo.kill_actor(actor_ref)

# create actor
actor_ref = await ctx.create_actor(
TestActor, address=address,
allocate_strategy=ProcessIndex(1))
# kill_actor will kill the corresponding process
await ctx.kill_actor(actor_ref)

if auto_recover:
await ctx.wait_actor_pool_recovered(actor_ref.address, address)
sub_pool_address = process_placement_to_address(pg_name, 0, process_index=1)
sub_pool_handle = ray.get_actor(sub_pool_address)
assert await sub_pool_handle.actor_pool.remote('health_check') == PoolStatus.HEALTHY

expect_has_actor = True if auto_recover in ['actor', True] else False
assert await ctx.has_actor(actor_ref) is expect_has_actor
else:
with pytest.raises((ServerClosed, ConnectionError)):
await ctx.has_actor(actor_ref)

if 'COV_CORE_SOURCE' in os.environ:
for addr in [process_placement_to_address(pg_name, 0, process_index=i) for i in range(2)]:
# must save the local reference until this is fixed:
# https://github.com/ray-project/ray/issues/7815
ray_actor = ray.get_actor(addr)
ray.get(ray_actor.cleanup.remote())