Skip to content

Commit

Permalink
[Data] Allow to specify application-level error to retry for actor ta…
Browse files Browse the repository at this point in the history
…sk (ray-project#42492)

A user reported that they cannot specify application-level exception retries for actor tasks (`retry_exceptions`). This is because our actor pool operator does not allow specifying Ray remote arguments for actor tasks. This PR adds a config, `DataContext.actor_task_retry_on_errors`, so users can control retries on application-level exceptions.

Signed-off-by: Cheng Su <scnju13@gmail.com>
Signed-off-by: khluu <khluu000@gmail.com>
  • Loading branch information
c21 authored and khluu committed Jan 24, 2024
1 parent e3fccfd commit 11d880f
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def __init__(
ray_remote_args,
)
self._ray_remote_args = self._apply_default_remote_args(self._ray_remote_args)
self._ray_actor_task_remote_args = {}
actor_task_errors = DataContext.get_current().actor_task_retry_on_errors
if actor_task_errors:
self._ray_actor_task_remote_args["retry_exceptions"] = actor_task_errors
self._min_rows_per_bundle = min_rows_per_bundle

# Create autoscaling policy from compute strategy.
Expand Down Expand Up @@ -194,9 +198,11 @@ def _dispatch_tasks(self):
task_idx=self._next_data_task_idx,
target_max_block_size=self.actual_target_max_block_size,
)
gen = actor.submit.options(num_returns="streaming", name=self.name).remote(
DataContext.get_current(), ctx, *input_blocks
)
gen = actor.submit.options(
num_returns="streaming",
name=self.name,
**self._ray_actor_task_remote_args,
).remote(DataContext.get_current(), ctx, *input_blocks)

def _task_done_callback(actor_to_return):
# Return the actor that was running the task to the pool.
Expand Down
11 changes: 10 additions & 1 deletion python/ray/data/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import threading
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import ray
from ray._private.ray_constants import env_integer
Expand Down Expand Up @@ -159,6 +159,12 @@
"AWS Error SLOW_DOWN",
]

# The application-level errors that actor tasks will retry on.
# Defaults to `False`, meaning no errors are retried.
# Set to `True` to retry all errors, or set to a list of error types to retry.
# This follows the same format as `retry_exceptions` in Ray Core.
DEFAULT_ACTOR_TASK_RETRY_ON_ERRORS = False


@DeveloperAPI
class DataContext:
Expand Down Expand Up @@ -201,6 +207,7 @@ def __init__(
enable_get_object_locations_for_metrics: bool,
use_runtime_metrics_scheduling: bool,
write_file_retry_on_errors: List[str],
actor_task_retry_on_errors: Union[bool, List[BaseException]],
):
"""Private constructor (use get_current() instead)."""
self.target_max_block_size = target_max_block_size
Expand Down Expand Up @@ -239,6 +246,7 @@ def __init__(
)
self.use_runtime_metrics_scheduling = use_runtime_metrics_scheduling
self.write_file_retry_on_errors = write_file_retry_on_errors
self.actor_task_retry_on_errors = actor_task_retry_on_errors
# The additional Ray remote args that should be added to
# the task-pool-based data tasks.
self._task_pool_data_task_remote_args: Dict[str, Any] = {}
Expand Down Expand Up @@ -309,6 +317,7 @@ def get_current() -> "DataContext":
enable_get_object_locations_for_metrics=DEFAULT_ENABLE_GET_OBJECT_LOCATIONS_FOR_METRICS, # noqa E501
use_runtime_metrics_scheduling=DEFAULT_USE_RUNTIME_METRICS_SCHEDULING, # noqa: E501
write_file_retry_on_errors=DEFAULT_WRITE_FILE_RETRY_ON_ERRORS,
actor_task_retry_on_errors=DEFAULT_ACTOR_TASK_RETRY_ON_ERRORS,
)

return _default_context
Expand Down
21 changes: 21 additions & 0 deletions python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,27 @@ def mapper(x):
ds.map(mapper).materialize()


def test_actor_task_failure(shutdown_only, restore_data_context):
    """Verify that application-level errors listed in
    ``DataContext.actor_task_retry_on_errors`` are retried for actor tasks.

    The mapper fails its first two invocations with ``ValueError``; execution
    can only succeed if those errors are retried on the same actor (whose
    counter state persists across retries).
    """
    ray.init(num_cpus=2)

    ctx = DataContext.get_current()
    ctx.actor_task_retry_on_errors = [ValueError]

    ds = ray.data.from_items([0, 10], parallelism=2)

    class Mapper:
        """Raises ``ValueError`` on the first two calls, then passes
        batches through unchanged."""

        def __init__(self):
            self._counter = 0

        def __call__(self, x):
            if self._counter < 2:
                self._counter += 1
                raise ValueError("oops")
            return x

    # Materialization succeeds only if the injected ValueErrors are retried.
    # Also assert the data survives the retries unmodified, rather than just
    # checking that no exception escaped.
    res = ds.map_batches(Mapper, concurrency=1).materialize()
    assert sorted(row["item"] for row in res.take_all()) == [0, 10]


def test_concurrency(shutdown_only):
ray.init(num_cpus=6)
ds = ray.data.range(10, parallelism=10)
Expand Down

0 comments on commit 11d880f

Please sign in to comment.