Skip to content

Commit

Permalink
[Ray] Fix Ray context GC (#3118)
Browse files Browse the repository at this point in the history
  • Loading branch information
fyrestone committed Jun 6, 2022
1 parent 78628b6 commit 6ffc7b9
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 26 deletions.
8 changes: 7 additions & 1 deletion mars/services/task/execution/ray/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging
from typing import Dict, List
from .....resource import Resource
from ..api import ExecutionConfig, register_config_cls
from ..utils import get_band_resources_from_config


logger = logging.getLogger(__name__)

IN_RAY_CI = os.environ.get("MARS_CI_BACKEND", "mars") == "ray"
# The default interval, in seconds, at which to update progress and collect garbage.
DEFAULT_SUBTASK_MONITOR_INTERVAL = 1
DEFAULT_SUBTASK_MONITOR_INTERVAL = 0 if IN_RAY_CI else 1


@register_config_cls
Expand Down
65 changes: 46 additions & 19 deletions mars/services/task/execution/ray/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import functools
import logging
import operator
import sys
from dataclasses import dataclass
from typing import List, Dict, Any, Set, Callable
from .....core import ChunkGraph, Chunk, TileContext
Expand Down Expand Up @@ -49,7 +50,7 @@
ExecutionChunkResult,
register_executor_cls,
)
from .config import RayExecutionConfig
from .config import RayExecutionConfig, IN_RAY_CI
from .context import (
RayExecutionContext,
RayExecutionWorkerContext,
Expand Down Expand Up @@ -314,35 +315,54 @@ async def execute_subtask_graph(
) -> Dict[Chunk, ExecutionChunkResult]:
if self._cancelled is True: # pragma: no cover
raise asyncio.CancelledError()
logger.info("Stage %s start.", stage_id)
# Make sure each stage uses a clean dict.
self._cur_stage_first_output_object_ref_to_subtask = dict()

def _on_monitor_task_done(fut):
def _on_monitor_aiotask_done(fut):
# Print the error of monitor task.
try:
fut.result()
except asyncio.CancelledError:
pass
except Exception: # pragma: no cover
logger.exception(
"The monitor task of stage %s is done with exception.", stage_id
)
if IN_RAY_CI: # pragma: no cover
logger.warning(
"The process will be exit due to the monitor task exception "
"when MARS_CI_BACKEND=ray."
)
sys.exit(-1)

result_meta_keys = {
chunk.key
for chunk in chunk_graph.result_chunks
if not isinstance(chunk.op, Fetch)
}
# Create a monitor task to update progress and collect garbage.
monitor_task = asyncio.create_task(
monitor_aiotask = asyncio.create_task(
self._update_progress_and_collect_garbage(
subtask_graph, self._config.get_subtask_monitor_interval()
stage_id,
subtask_graph,
result_meta_keys,
self._config.get_subtask_monitor_interval(),
)
)
monitor_task.add_done_callback(_on_monitor_task_done)
monitor_aiotask.add_done_callback(_on_monitor_aiotask_done)

def _on_execute_task_done(fut):
def _on_execute_aiotask_done(_):
# Make sure the monitor task is cancelled.
monitor_task.cancel()
monitor_aiotask.cancel()
# Just use `self._cur_stage_tile_progress` as the current stage progress:
# the current stage is completed, so its progress is 1.0.
self._cur_stage_progress = 1.0
self._pre_all_stages_progress += self._cur_stage_tile_progress
self._cur_stage_first_output_object_ref_to_subtask.clear()

self._execute_subtask_graph_aiotask = asyncio.current_task()
self._execute_subtask_graph_aiotask.add_done_callback(_on_execute_task_done)
self._execute_subtask_graph_aiotask.add_done_callback(_on_execute_aiotask_done)

logger.info("Stage %s start.", stage_id)
task_context = self._task_context
output_meta_object_refs = []
self._pre_all_stages_tile_progress = (
Expand All @@ -352,11 +372,6 @@ def _on_execute_task_done(fut):
self._tile_context.get_all_progress() - self._pre_all_stages_tile_progress
)
logger.info("Submitting %s subtasks of stage %s.", len(subtask_graph), stage_id)
result_meta_keys = {
chunk.key
for chunk in chunk_graph.result_chunks
if not isinstance(chunk.op, Fetch)
}
subtask_max_retries = self._config.get_subtask_max_retries()
for subtask in subtask_graph.topological_iter():
subtask_chunk_graph = subtask.chunk_graph
Expand Down Expand Up @@ -555,7 +570,11 @@ def _get_subtask_output_keys(chunk_graph: ChunkGraph):
return output_keys.keys()

async def _update_progress_and_collect_garbage(
self, subtask_graph: SubtaskGraph, interval_seconds: float
self,
stage_id: str,
subtask_graph: SubtaskGraph,
result_meta_keys: Set[str],
interval_seconds: float,
):
object_ref_to_subtask = self._cur_stage_first_output_object_ref_to_subtask
total = len(subtask_graph)
Expand All @@ -579,7 +598,7 @@ def gc():
# Iterate the completed subtasks once.
subtask = completed_subtasks[i]
i += 1
logger.debug("GC: %s", subtask)
logger.debug("GC[stage=%s]: %s", stage_id, subtask)

# Note: There may be a scenario in which delayed gc occurs.
# When a subtask has more than one predecessor, like A, B,
Expand All @@ -595,15 +614,23 @@ def gc():
):
yield
for chunk in pred.chunk_graph.results:
self._task_context.pop(chunk.key, None)
chunk_key = chunk.key
# Only collect a chunk key that is not one of the result meta
# keys, because in some special cases the result meta keys
# are not the leaves.
#
# Example: test_cut_execution
if chunk_key not in result_meta_keys:
logger.debug("GC[stage=%s]: %s", stage_id, chunk)
self._task_context.pop(chunk_key, None)
gc_subtasks.add(pred)

# TODO(fyrestone): Check the remaining self._task_context.keys()
# in the result subtasks

collect_garbage = gc()

while len(completed_subtasks) != total:
while len(completed_subtasks) < total:
if len(object_ref_to_subtask) <= 0: # pragma: no cover
await asyncio.sleep(interval_seconds)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,21 @@ async def submit_subtask_graph(
subtask_graph: SubtaskGraph,
chunk_graph: ChunkGraph,
):
monitor_task = asyncio.create_task(
self._update_progress_and_collect_garbage(
subtask_graph, self._config.get_subtask_monitor_interval()
)
)

result_meta_keys = {
chunk.key
for chunk in chunk_graph.result_chunks
if not isinstance(chunk.op, Fetch)
}

monitor_task = asyncio.create_task(
self._update_progress_and_collect_garbage(
stage_id,
subtask_graph,
result_meta_keys,
self._config.get_subtask_monitor_interval(),
)
)

for subtask in subtask_graph.topological_iter():
subtask_chunk_graph = subtask.chunk_graph
task_context = self._task_context
Expand Down

0 comments on commit 6ffc7b9

Please sign in to comment.