Skip to content

Commit

Permalink
[Ray] Implement gc for ray task executor context (#3061)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongchun committed May 30, 2022
1 parent 0263954 commit 505361b
Show file tree
Hide file tree
Showing 5 changed files with 324 additions and 62 deletions.
40 changes: 39 additions & 1 deletion mars/deploy/oscar/tests/test_ray_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@

import copy
import os
import time

import pytest

from .... import get_context
from .... import tensor as mt
from ....tests.core import DICT_NOT_EMPTY, require_ray
from ....utils import lazy_import
from ..local import new_cluster
from ..session import new_session
from ..session import new_session, get_default_async_session
from ..tests import test_local
from ..tests.session import new_test_session
from ..tests.test_local import _cancel_when_tile, _cancel_when_execute
Expand Down Expand Up @@ -129,3 +132,38 @@ async def test_session_get_progress(ray_start_regular_shared2, create_cluster):
@pytest.mark.parametrize("test_func", [_cancel_when_execute, _cancel_when_tile])
def test_cancel(ray_start_regular_shared2, create_cluster, test_func):
test_local.test_cancel(create_cluster, test_func)


@require_ray
@pytest.mark.parametrize("config", [{"backend": "ray"}])
def test_executor_context_gc(config):
    """Check that the ray task executor context is garbage collected.

    Starts a session on the parametrized backend with the subtask monitor
    interval set to 0, runs a small tensor pipeline and asserts that the
    execution context holds fewer than 5 entries once the result is fetched.
    The server is always stopped, even when an assertion fails, so that a
    failure here does not leave a running cluster behind for later tests.
    """
    session = new_session(
        backend=config["backend"],
        n_cpu=2,
        web=False,
        use_uvloop=False,
        config={"task.execution_config.ray.subtask_monitor_interval": 0},
    )

    def f1(c):
        # Slow the mapped chunk down a little before returning it unchanged.
        time.sleep(0.5)
        return c

    try:
        # The session was created with web=False, so no web endpoint exists.
        assert session._session.client.web_address is None
        assert session.get_web_endpoint() is None

        with session:
            t1 = mt.random.randint(10, size=(100, 10), chunk_size=100)
            t2 = mt.random.randint(10, size=(100, 10), chunk_size=50)
            t3 = t2 + t1
            t4 = t3.sum(0)
            t5 = t4.map_chunk(f1)
            r = t5.execute()
            result = r.fetch()
            assert result is not None
            assert len(result) == 10
            context = get_context()
            # NOTE(review): the bound of 5 is presumably the number of entries
            # allowed to survive gc for this graph -- confirm.
            assert len(context._task_context) < 5
    finally:
        # Release the cluster regardless of the outcome of the assertions.
        session.stop_server()

    assert get_default_async_session() is None
2 changes: 1 addition & 1 deletion mars/lib/ordered_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

# SetLike[T] is either a set of elements of type T, or a sequence, which
# we will convert to an OrderedSet by adding its elements in order.
SetLike = Union[AbstractSet[T], Sequence[T]]
SetLike = Union[AbstractSet[T], Sequence[T], Iterable[T]]
OrderedSetInitializer = Union[AbstractSet[T], Sequence[T], Iterable[T]]


Expand Down
12 changes: 12 additions & 0 deletions mars/services/task/execution/ray/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from ..api import ExecutionConfig, register_config_cls
from ..utils import get_band_resources_from_config

# The default interval, in seconds, at which progress is updated and garbage is collected.
DEFAULT_SUBTASK_MONITOR_INTERVAL = 1


@register_config_cls
class RayExecutionConfig(ExecutionConfig):
Expand Down Expand Up @@ -55,3 +58,12 @@ def create_task_state_actor_as_needed(self):
# - False:
# Create RayTaskState actor in advance when the RayTaskExecutor is created.
return self._ray_execution_config.get("create_task_state_actor_as_needed", True)

def get_subtask_monitor_interval(self):
    """
    Return the interval, in seconds, at which the monitor task updates
    progress and collects garbage.

    Falls back to ``DEFAULT_SUBTASK_MONITOR_INTERVAL`` when the
    ``subtask_monitor_interval`` key is missing from the ray execution config.
    """
    ray_config = self._ray_execution_config
    interval = ray_config.get(
        "subtask_monitor_interval", DEFAULT_SUBTASK_MONITOR_INTERVAL
    )
    return interval
188 changes: 134 additions & 54 deletions mars/services/task/execution/ray/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
execute,
)
from .....lib.aio import alru_cache
from .....lib.ordered_set import OrderedSet
from .....resource import Resource
from .....serialization import serialize, deserialize
from .....typing import BandType
Expand Down Expand Up @@ -96,7 +97,7 @@ def _optimize_subtask_graph(subtask_graph):


async def _cancel_ray_task(obj_ref, kill_timeout: int = 3):
ray.cancel(obj_ref, force=False)
await asyncio.to_thread(ray.cancel, obj_ref, force=False)
try:
await asyncio.to_thread(ray.get, obj_ref, timeout=kill_timeout)
except ray.exceptions.TaskCancelledError: # pragma: no cover
Expand All @@ -108,7 +109,7 @@ async def _cancel_ray_task(obj_ref, kill_timeout: int = 3):
e,
obj_ref,
)
ray.cancel(obj_ref, force=True)
await asyncio.to_thread(ray.cancel, obj_ref, force=True)


def execute_subtask(
Expand Down Expand Up @@ -183,17 +184,12 @@ def __init__(

self._available_band_resources = None

# For progress
# For progress and task cancel
self._pre_all_stages_progress = 0.0
self._pre_all_stages_tile_progress = 0
self._cur_stage_tile_progress = 0
self._cur_stage_output_object_refs = []
# This list records the output object ref number of subtasks, so with
# `self._cur_stage_output_object_refs` we can just call `ray.cancel`
# with one object ref to cancel a subtask instead of cancel all object
# refs. In this way we can reduce a lot of unnecessary calls of ray.
self._output_object_refs_nums = []
# For meta and data gc
self._pre_all_stages_tile_progress = 0.0
self._cur_stage_progress = 0.0
self._cur_stage_tile_progress = 0.0
self._cur_stage_first_output_object_ref_to_subtask = dict()
self._execute_subtask_graph_aiotask = None
self._cancelled = False

Expand Down Expand Up @@ -258,12 +254,12 @@ def destroy(self):

self._available_band_resources = None

# For progress
self._pre_all_stages_progress = 1
self._pre_all_stages_tile_progress = 1
self._cur_stage_tile_progress = 1
self._cur_stage_output_object_refs = []
self._output_object_refs_nums = []
# For progress and task cancel
self._pre_all_stages_progress = 1.0
self._pre_all_stages_tile_progress = 1.0
self._cur_stage_progress = 1.0
self._cur_stage_tile_progress = 1.0
self._cur_stage_first_output_object_ref_to_subtask = dict()
self._execute_subtask_graph_aiotask = None
self._cancelled = None

Expand Down Expand Up @@ -318,7 +314,33 @@ async def execute_subtask_graph(
) -> Dict[Chunk, ExecutionChunkResult]:
if self._cancelled is True: # pragma: no cover
raise asyncio.CancelledError()

def _on_monitor_task_done(fut):
# Print the error of monitor task.
try:
fut.result()
except asyncio.CancelledError:
pass

# Create a monitor task to update progress and collect garbage.
monitor_task = asyncio.create_task(
self._update_progress_and_collect_garbage(
subtask_graph, self._config.get_subtask_monitor_interval()
)
)
monitor_task.add_done_callback(_on_monitor_task_done)

def _on_execute_task_done(fut):
# Make sure the monitor task is cancelled.
monitor_task.cancel()
# Just use `self._cur_stage_tile_progress` as current stage progress
# because current stage is completed, its progress is 1.0.
self._cur_stage_progress = 1.0
self._pre_all_stages_progress += self._cur_stage_tile_progress
self._cur_stage_first_output_object_ref_to_subtask.clear()

self._execute_subtask_graph_aiotask = asyncio.current_task()
self._execute_subtask_graph_aiotask.add_done_callback(_on_execute_task_done)

logger.info("Stage %s start.", stage_id)
task_context = self._task_context
Expand Down Expand Up @@ -359,8 +381,9 @@ async def execute_subtask_graph(
continue
elif output_count == 1:
output_object_refs = [output_object_refs]
self._cur_stage_output_object_refs.extend(output_object_refs)
self._output_object_refs_nums.append(len(output_object_refs))
self._cur_stage_first_output_object_ref_to_subtask[
output_object_refs[0]
] = subtask
if output_meta_keys:
meta_object_ref, *output_object_refs = output_object_refs
# TODO(fyrestone): Fetch(not get) meta object here.
Expand Down Expand Up @@ -395,16 +418,16 @@ async def execute_subtask_graph(
logger.info("Waiting for stage %s complete.", stage_id)
# Patched the asyncio.to_thread for Python < 3.9 at mars/lib/aio/__init__.py
await asyncio.to_thread(ray.wait, list(output_object_refs), fetch_local=False)
# Just use `self._cur_stage_tile_progress` as current stage progress
# because current stage is finished, its progress is 1.
self._pre_all_stages_progress += self._cur_stage_tile_progress
self._cur_stage_output_object_refs.clear()
self._output_object_refs_nums.clear()

logger.info("Stage %s is complete.", stage_id)
return chunk_to_meta

async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
try:
await self.cancel()
except BaseException: # noqa: E722 # nosec # pylint: disable=bare-except
pass
return

# Update info if no exception occurs.
Expand Down Expand Up @@ -458,19 +481,7 @@ async def get_available_band_resources(self) -> Dict[BandType, Resource]:

async def get_progress(self) -> float:
"""Get the execution progress."""
stage_progress = 0.0
total = len(self._cur_stage_output_object_refs)
if total > 0:
finished_objects, _ = ray.wait(
self._cur_stage_output_object_refs,
num_returns=total,
timeout=0, # Avoid blocking the asyncio loop.
fetch_local=False,
)
stage_progress = (
len(finished_objects) / total * self._cur_stage_tile_progress
)
return self._pre_all_stages_progress + stage_progress
return self._cur_stage_progress

async def cancel(self):
"""
Expand All @@ -480,26 +491,17 @@ async def cancel(self):
2. Try to cancel the submitted subtasks by `ray.cancel`
"""
logger.info("Start to cancel task %s.", self._task)
if self._task is None:
if self._task is None or self._cancelled is True:
return
self._cancelled = True
if (
self._execute_subtask_graph_aiotask is not None
and not self._execute_subtask_graph_aiotask.cancelled()
):
if self._execute_subtask_graph_aiotask is not None:
self._execute_subtask_graph_aiotask.cancel()
timeout = self._config.get_subtask_cancel_timeout()
subtask_num = len(self._output_object_refs_nums)
if subtask_num > 0:
pos = 0
obj_refs_to_be_cancelled_ = []
for i in range(0, subtask_num):
if i > 0:
pos += self._output_object_refs_nums[i - 1]
obj_refs_to_be_cancelled_.append(
_cancel_ray_task(self._cur_stage_output_object_refs[pos], timeout)
)
await asyncio.gather(*obj_refs_to_be_cancelled_)
to_be_cancelled_coros = [
_cancel_ray_task(object_ref, timeout)
for object_ref in self._cur_stage_first_output_object_ref_to_subtask.keys()
]
await asyncio.gather(*to_be_cancelled_coros)

async def _load_subtask_inputs(
self, stage_id: str, subtask: Subtask, chunk_graph: ChunkGraph, context: Dict
Expand Down Expand Up @@ -551,3 +553,81 @@ def _get_subtask_output_keys(chunk_graph: ChunkGraph):
else:
output_keys[chunk.key] = 1
return output_keys.keys()

async def _update_progress_and_collect_garbage(
    self, subtask_graph: SubtaskGraph, interval_seconds: float
):
    """
    Monitor the current stage: update its progress and collect garbage.

    Polls the first output object ref of every submitted subtask with a
    non-blocking ``ray.wait``. Subtasks whose ref is ready are moved into an
    ordered set of completed subtasks, ``self._cur_stage_progress`` is
    refreshed, and the gc generator is advanced by one step.

    Args:
        subtask_graph: The subtask graph of the stage being executed.
        interval_seconds: Seconds to sleep between two polls when nothing
            new has completed (or nothing has been submitted yet).
    """
    object_ref_to_subtask = self._cur_stage_first_output_object_ref_to_subtask
    total = len(subtask_graph)
    completed_subtasks = OrderedSet()

    def gc():
        """
        Consume the completed subtasks and collect garbage.

        A predecessor's result chunks are removed from the task context only
        once all of its successors have completed. Collecting as soon as the
        successors are merely submitted would reduce the memory peak further,
        but then slow subtasks could not be cancelled and rerun, because the
        input object refs of running subtasks might already be deleted.

        The generator yields whenever it has to wait for more subtasks to
        complete; the caller resumes it after each progress update.
        """
        i = 0
        gc_subtasks = set()

        while i < total:
            # Wait until the i-th completed subtask is available.
            while i >= len(completed_subtasks):
                yield
            # Iterate the completed subtasks once.
            subtask = completed_subtasks[i]
            i += 1
            logger.debug("GC: %s", subtask)

            # Note: There may be a scenario in which delayed gc occurs.
            # When a subtask has more than one predecessor, like A, B,
            # and in the `for ... in ...` loop we get A first while
            # B's successors are completed but A's are not, B's result
            # chunks cannot be removed before A's.
            for pred in subtask_graph.iter_predecessors(subtask):
                if pred in gc_subtasks:
                    continue
                while not all(
                    succ in completed_subtasks
                    for succ in subtask_graph.iter_successors(pred)
                ):
                    yield
                for chunk in pred.chunk_graph.results:
                    self._task_context.pop(chunk.key, None)
                gc_subtasks.add(pred)

        # TODO(fyrestone): Check the remaining self._task_context.keys()
        # in the result subtasks

    collect_garbage = gc()

    while len(completed_subtasks) != total:
        if not object_ref_to_subtask:  # pragma: no cover
            # Nothing is pending yet: wait and re-check instead of falling
            # through to `ray.wait` with an empty list and `num_returns=0`
            # (and then sleeping a second time).
            await asyncio.sleep(interval_seconds)
            continue

        # Only wait for unready subtask object refs.
        ready_objects, _ = await asyncio.to_thread(
            ray.wait,
            list(object_ref_to_subtask),
            num_returns=len(object_ref_to_subtask),
            timeout=0,
            fetch_local=False,
        )
        if not ready_objects:
            await asyncio.sleep(interval_seconds)
            continue

        # Pop the completed subtasks from object_ref_to_subtask.
        completed_subtasks.update(map(object_ref_to_subtask.pop, ready_objects))
        # Update progress.
        stage_progress = (
            len(completed_subtasks) / total * self._cur_stage_tile_progress
        )
        self._cur_stage_progress = self._pre_all_stages_progress + stage_progress
        # Collect garbage, use `for ... in ...` to avoid raising StopIteration.
        for _ in collect_garbage:
            break
        # Fast to next loop and give it a chance to update object_ref_to_subtask.
        await asyncio.sleep(0)

0 comments on commit 505361b

Please sign in to comment.