Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ray] Fix ray worker failover #3080

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/platform-ci.yml
Expand Up @@ -144,7 +144,7 @@ jobs:
coverage combine build/ && coverage report
fi
if [ -n "$WITH_RAY" ]; then
pytest $PYTEST_CONFIG --durations=0 --timeout=600 -v -s -m ray
pytest $PYTEST_CONFIG --durations=0 --timeout=200 -v -s -m ray
coverage report
fi
if [ -n "$WITH_RAY_DAG" ]; then
Expand Down
7 changes: 5 additions & 2 deletions mars/deploy/oscar/ray.py
Expand Up @@ -37,7 +37,7 @@
AbstractClusterBackend,
)
from ...services import NodeRole
from ...utils import lazy_import
from ...utils import lazy_import, retry_callable
from ..utils import (
load_config,
get_third_party_modules_from_config,
Expand Down Expand Up @@ -274,7 +274,10 @@ async def reconstruct_worker(self, address: str):
async def _reconstruct_worker():
logger.info("Reconstruct worker %s", address)
actor = ray.get_actor(address)
state = await actor.state.remote()
# ray call will error when actor is restarting
state = await retry_callable(
actor.state.remote, ex_type=ray.exceptions.RayActorError, sync=False
)()
if state == RayPoolState.SERVICE_READY:
logger.info("Worker %s is service ready.")
return
Expand Down
5 changes: 3 additions & 2 deletions mars/deploy/oscar/tests/test_ray.py
Expand Up @@ -578,7 +578,7 @@ async def remote(self):
class FakeActor:
state = FakeActorMethod()

def _get_actor(*args):
def _get_actor(*args, **kwargs):
zhongchun marked this conversation as resolved.
Show resolved Hide resolved
return FakeActor

async def _stop_worker(*args):
Expand Down Expand Up @@ -677,7 +677,8 @@ async def test_auto_scale_in(ray_large_cluster):
assert await autoscaler_ref.get_dynamic_worker_nums() == 2


@pytest.mark.timeout(timeout=1000)
@pytest.mark.skip("Enable it when ray ownership bug is fixed")
@pytest.mark.timeout(timeout=200)
@pytest.mark.parametrize("ray_large_cluster", [{"num_nodes": 4}], indirect=True)
@require_ray
@pytest.mark.asyncio
Expand Down
31 changes: 24 additions & 7 deletions mars/oscar/backends/ray/pool.py
Expand Up @@ -28,7 +28,7 @@

from ... import ServerClosed
from ....serialization.ray import register_ray_serializers
from ....utils import lazy_import, ensure_coverage
from ....utils import lazy_import, ensure_coverage, retry_callable
from ..config import ActorPoolConfig
from ..message import CreateActorMessage
from ..pool import (
Expand Down Expand Up @@ -130,14 +130,27 @@ async def start_sub_pool(
f"process_index {process_index} is not consistent with index {_process_index} "
f"in external_address {external_address}"
)
actor_handle = config["kwargs"]["sub_pool_handles"][external_address]
state = await retry_callable(
actor_handle.state.remote, ex_type=ray.exceptions.RayActorError, sync=False
)()
if state is RayPoolState.SERVICE_READY: # pragma: no cover
logger.info("Ray sub pool %s is alive, kill it first.", external_address)
await kill_and_wait(actor_handle, no_restart=False)
# Wait sub pool process restarted.
await retry_callable(
actor_handle.state.remote,
ex_type=ray.exceptions.RayActorError,
sync=False,
)()
logger.info("Start to start ray sub pool %s.", external_address)
create_sub_pool_timeout = 120
actor_handle = config["kwargs"]["sub_pool_handles"][external_address]
done, _ = await asyncio.wait(
[actor_handle.set_actor_pool_config.remote(actor_pool_config)],
timeout=create_sub_pool_timeout,
)
if not done: # pragma: no cover
try:
await asyncio.wait_for(
actor_handle.set_actor_pool_config.remote(actor_pool_config),
timeout=create_sub_pool_timeout,
)
except asyncio.TimeoutError: # pragma: no cover
msg = (
f"Can not start ray sub pool {external_address} in {create_sub_pool_timeout} seconds.",
)
Expand All @@ -153,6 +166,10 @@ async def wait_sub_pools_ready(cls, create_pool_tasks: List[asyncio.Task]):

async def recover_sub_pool(self, address: str):
process = self.sub_processes[address]
# ray call will error when actor is restarting
await retry_callable(
process.state.remote, ex_type=ray.exceptions.RayActorError, sync=False
)()
await process.start.remote()

if self._auto_recover == "actor":
Expand Down
34 changes: 34 additions & 0 deletions mars/tests/test_utils.py
Expand Up @@ -616,3 +616,37 @@ def __call__(self, *args, **kwargs):
def test_gen_random_id(id_length):
rnd_id = utils.new_random_id(id_length)
assert len(rnd_id) == id_length


@pytest.mark.asyncio
async def test_retry_callable():
assert utils.retry_callable(lambda x: x)(1) == 1
assert utils.retry_callable(lambda x: 0)(1) == 0

class CustomException(BaseException):
pass

def f1(x):
nonlocal num_retried
num_retried += 1
if num_retried == 3:
return x
raise CustomException

num_retried = 0
with pytest.raises(CustomException):
utils.retry_callable(f1)(1)
assert utils.retry_callable(f1, ex_type=CustomException)(1) == 1
num_retried = 0
with pytest.raises(CustomException):
utils.retry_callable(f1, max_retries=2, ex_type=CustomException)(1)
num_retried = 0
assert utils.retry_callable(f1, max_retries=3, ex_type=CustomException)(1) == 1

async def f2(x):
return f1(x)

num_retried = 0
with pytest.raises(CustomException):
await utils.retry_callable(f2)(1)
assert await utils.retry_callable(f2, ex_type=CustomException)(1) == 1
38 changes: 38 additions & 0 deletions mars/utils.py
Expand Up @@ -1698,3 +1698,41 @@ def ensure_coverage():
pass
else:
cleanup_on_sigterm()


def retry_callable(
callable_,
ex_type: type = Exception,
wait_interval=1,
max_retries=-1,
sync: bool = None,
):
if inspect.iscoroutinefunction(callable_) or sync is False:

@functools.wraps(callable)
async def retry_call(*args, **kwargs):
num_retried = 0
while max_retries < 0 or num_retried < max_retries:
num_retried += 1
try:
return await callable_(*args, **kwargs)
except ex_type:
await asyncio.sleep(wait_interval)

else:

@functools.wraps(callable)
def retry_call(*args, **kwargs):
num_retried = 0
ex = None
while max_retries < 0 or num_retried < max_retries:
num_retried += 1
try:
return callable_(*args, **kwargs)
except ex_type as e:
ex = e
time.sleep(wait_interval)
assert ex is not None
raise ex # pylint: disable-msg=E0702

return retry_call