Skip to content

Commit

Permalink
[Data] Account internal inqueue to previous operator's memory usage (r…
Browse files Browse the repository at this point in the history
…ay-project#42930)

To make informed scheduling decisions, Ray Data estimates the object store memory usage from each operator. Currently, an operator's inputs count towards its memory usage; however, this can lead to suboptimal scheduling decisions since we primarily care about outputs for backpressuring.

This PR updates the implementation so that an operator's inputs count toward the previous operator's object store usage (i.e., they count toward the operator that created the data).

---------

Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
  • Loading branch information
bveeramani authored and kevin85421 committed Feb 17, 2024
1 parent d0ce4b9 commit 8c920a2
Show file tree
Hide file tree
Showing 16 changed files with 396 additions and 174 deletions.
2 changes: 1 addition & 1 deletion doc/source/data/inspecting-data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ To view stats about your :class:`Datasets <ray.data.Dataset>`, call :meth:`Datas
* Output num rows: 150 min, 150 max, 150 mean, 150 total
* Output size bytes: 6000 min, 6000 max, 6000 mean, 6000 total
* Tasks per node: 1 min, 1 max, 1 mean; 1 nodes used
* Extra metrics: {'obj_store_mem_freed': 5761, 'obj_store_mem_peak': 6000}
* Extra metrics: {'obj_store_mem_freed': 5761}

Dataset iterator time breakdown:
* Total time user code is blocked: 5.68ms
Expand Down
2 changes: 1 addition & 1 deletion doc/source/data/monitoring-your-workload.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,6 @@ When an operator completes, the metrics for that operator are also logged.
.. code-block:: text
Operator InputDataBuffer[Input] -> TaskPoolMapOperator[ReadRange->MapBatches(<lambda>)] completed. Operator Metrics:
{'num_inputs_received': 20, 'bytes_inputs_received': 46440, 'num_task_inputs_processed': 20, 'bytes_task_inputs_processed': 46440, 'num_task_outputs_generated': 20, 'bytes_task_outputs_generated': 800, 'rows_task_outputs_generated': 100, 'num_outputs_taken': 20, 'bytes_outputs_taken': 800, 'num_outputs_of_finished_tasks': 20, 'bytes_outputs_of_finished_tasks': 800, 'num_tasks_submitted': 20, 'num_tasks_running': 0, 'num_tasks_have_outputs': 20, 'num_tasks_finished': 20, 'obj_store_mem_freed': 46440, 'obj_store_mem_cur': 0, 'obj_store_mem_peak': 23260, 'obj_store_mem_spilled': 0, 'block_generation_time': 1.191296085, 'cpu_usage': 0, 'gpu_usage': 0, 'ray_remote_args': {'num_cpus': 1, 'scheduling_strategy': 'SPREAD'}}
{'num_inputs_received': 20, 'bytes_inputs_received': 46440, 'num_task_inputs_processed': 20, 'bytes_task_inputs_processed': 46440, 'num_task_outputs_generated': 20, 'bytes_task_outputs_generated': 800, 'rows_task_outputs_generated': 100, 'num_outputs_taken': 20, 'bytes_outputs_taken': 800, 'num_outputs_of_finished_tasks': 20, 'bytes_outputs_of_finished_tasks': 800, 'num_tasks_submitted': 20, 'num_tasks_running': 0, 'num_tasks_have_outputs': 20, 'num_tasks_finished': 20, 'obj_store_mem_freed': 46440, 'obj_store_mem_spilled': 0, 'block_generation_time': 1.191296085, 'cpu_usage': 0, 'gpu_usage': 0, 'ray_remote_args': {'num_cpus': 1, 'scheduling_strategy': 'SPREAD'}}
This log file can be found locally at `/tmp/ray/{SESSION_NAME}/logs/ray-data.log`. It can also be found on the Ray Dashboard under the head node's logs in the :ref:`Logs view <dash-logs-view>`.
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,23 @@ class OpRuntimeMetrics:

# === Object store memory metrics ===

# Size in bytes of input blocks in the operator's internal input queue.
obj_store_mem_internal_inqueue: int = field(
default=0, metadata={"export_metric": True}
)
# Size in bytes of output blocks in the operator's internal output queue.
obj_store_mem_internal_outqueue: int = field(
default=0, metadata={"export_metric": True}
)
# Size in bytes of input blocks used by pending tasks.
obj_store_mem_pending_task_inputs: int = field(
default=0, metadata={"map_only": True}
)

# Freed memory size in the object store.
obj_store_mem_freed: int = field(
default=0, metadata={"map_only": True, "export_metric": True}
)
# Current memory size in the object store.
obj_store_mem_cur: int = field(
default=0, metadata={"map_only": True, "export_metric": True}
)
# Peak memory size in the object store.
obj_store_mem_peak: int = field(default=0, metadata={"map_only": True})
# Spilled memory size in the object store.
obj_store_mem_spilled: int = field(
default=0, metadata={"map_only": True, "export_metric": True}
Expand Down Expand Up @@ -131,8 +138,8 @@ def as_dict(self, metrics_only: bool = False):
result.append((f.name, value))

# TODO: record resource usage in OpRuntimeMetrics,
# avoid calling self._op.current_resource_usage()
resource_usage = self._op.current_resource_usage()
# avoid calling self._op.current_processor_usage()
resource_usage = self._op.current_processor_usage()
result.extend(
[
("cpu_usage", resource_usage.cpu or 0),
Expand Down Expand Up @@ -166,7 +173,7 @@ def average_bytes_per_output(self) -> Optional[float]:
return self.bytes_task_outputs_generated / self.num_task_outputs_generated

@property
def obj_store_mem_pending_tasks(self) -> Optional[float]:
def obj_store_mem_pending_task_outputs(self) -> Optional[float]:
"""Estimated size in bytes of output blocks in Ray generator buffers.
If an estimate isn't available, this property returns ``None``.
Expand Down Expand Up @@ -214,41 +221,41 @@ def average_bytes_change_per_task(self) -> Optional[float]:

return self.average_bytes_outputs_per_task - self.average_bytes_inputs_per_task

@property
def input_buffer_bytes(self) -> int:
"""Size in bytes of input blocks that are not processed yet."""
return self.bytes_inputs_received - self.bytes_task_inputs_processed

@property
def output_buffer_bytes(self) -> int:
"""Size in bytes of output blocks that are not taken by the downstream yet."""
return self.bytes_task_outputs_generated - self.bytes_outputs_taken

def on_input_received(self, input: RefBundle):
"""Callback when the operator receives a new input."""
self.num_inputs_received += 1
input_size = input.size_bytes()
self.bytes_inputs_received += input_size
# Update object store metrics.
self.obj_store_mem_cur += input_size
if self.obj_store_mem_cur > self.obj_store_mem_peak:
self.obj_store_mem_peak = self.obj_store_mem_cur
self.bytes_inputs_received += input.size_bytes()

def on_input_queued(self, input: RefBundle):
    """Account for a bundle entering the operator's internal input queue.

    Args:
        input: The bundle that was queued. Its ``size_bytes()`` is added to
            ``obj_store_mem_internal_inqueue``.
    """
    queued_bytes = input.size_bytes()
    self.obj_store_mem_internal_inqueue += queued_bytes

def on_input_dequeued(self, input: RefBundle):
    """Account for a bundle leaving the operator's internal input queue.

    Args:
        input: The bundle that was dequeued. Its ``size_bytes()`` is
            subtracted from ``obj_store_mem_internal_inqueue``.
    """
    dequeued_bytes = input.size_bytes()
    self.obj_store_mem_internal_inqueue -= dequeued_bytes

def on_output_queued(self, output: RefBundle):
    """Account for a bundle entering the operator's internal output queue.

    Args:
        output: The bundle that was queued. Its ``size_bytes()`` is added to
            ``obj_store_mem_internal_outqueue``.
    """
    queued_bytes = output.size_bytes()
    self.obj_store_mem_internal_outqueue += queued_bytes

def on_output_dequeued(self, output: RefBundle):
    """Account for a bundle leaving the operator's internal output queue.

    Args:
        output: The bundle that was dequeued. Its ``size_bytes()`` is
            subtracted from ``obj_store_mem_internal_outqueue``.
    """
    dequeued_bytes = output.size_bytes()
    self.obj_store_mem_internal_outqueue -= dequeued_bytes

def on_output_taken(self, output: RefBundle):
"""Callback when an output is taken from the operator."""
output_bytes = output.size_bytes()
self.num_outputs_taken += 1
self.bytes_outputs_taken += output_bytes
self.obj_store_mem_cur -= output_bytes
self.bytes_outputs_taken += output.size_bytes()

def on_task_submitted(self, task_index: int, inputs: RefBundle):
    """Record bookkeeping for a newly submitted task.

    Args:
        task_index: Key under which the task is tracked in
            ``_running_tasks``.
        inputs: Bundle of input blocks handed to the task.
    """
    # A submitted task is counted as running from this point on.
    self.num_tasks_submitted += 1
    self.num_tasks_running += 1

    # The size is read once and applied to both byte counters.
    input_bytes = inputs.size_bytes()
    self.bytes_inputs_of_submitted_tasks += input_bytes
    self.obj_store_mem_pending_task_inputs += input_bytes

    # The two trailing zeros are presumably the initial output count and
    # output byte total -- confirm against the RunningTaskInfo definition.
    self._running_tasks[task_index] = RunningTaskInfo(inputs, 0, 0)

def on_output_generated(self, task_index: int, output: RefBundle):
def on_task_output_generated(self, task_index: int, output: RefBundle):
"""Callback when a new task generates an output."""
num_outputs = len(output)
output_bytes = output.size_bytes()
Expand All @@ -262,11 +269,6 @@ def on_output_generated(self, task_index: int, output: RefBundle):
task_info.num_outputs += num_outputs
task_info.bytes_outputs += output_bytes

# Update object store metrics.
self.obj_store_mem_cur += output_bytes
if self.obj_store_mem_cur > self.obj_store_mem_peak:
self.obj_store_mem_peak = self.obj_store_mem_cur

for block_ref, meta in output.blocks:
assert meta.exec_stats and meta.exec_stats.wall_time_s
self.block_generation_time += meta.exec_stats.wall_time_s
Expand All @@ -289,6 +291,7 @@ def on_task_finished(self, task_index: int, exception: Optional[Exception]):
self.num_task_inputs_processed += len(inputs)
total_input_size = inputs.size_bytes()
self.bytes_task_inputs_processed += total_input_size
self.obj_store_mem_pending_task_inputs -= inputs.size_bytes()

blocks = [input[0] for input in inputs.blocks]
metadata = [input[1] for input in inputs.blocks]
Expand All @@ -301,7 +304,6 @@ def on_task_finished(self, task_index: int, exception: Optional[Exception]):
self.obj_store_mem_spilled += meta.size_bytes

self.obj_store_mem_freed += total_input_size
self.obj_store_mem_cur -= total_input_size

inputs.destroy_if_owned()
del self._running_tasks[task_index]
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,11 @@ def shutdown(self) -> None:
if not self._started:
raise ValueError("Operator must be started before being shutdown.")

def current_resource_usage(self) -> ExecutionResources:
"""Returns the current estimated resource usage of this operator.
def current_processor_usage(self) -> ExecutionResources:
"""Returns the current estimated CPU and GPU usage of this operator, excluding
object store memory.
This method is called by the executor to decide how to allocate resources
This method is called by the executor to decide how to allocate processors
between different operators.
"""
return ExecutionResources(0, 0, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def _task_done_callback(res_ref):

def _add_bundled_input(self, bundle: RefBundle):
# Enqueue `bundle`, report it to the metrics as a queued input, and then try
# to dispatch. The order matters: the bundle must already be in the queue
# when `_dispatch_tasks()` runs.
self._bundle_queue.append(bundle)
self._metrics.on_input_queued(bundle)
# Try to dispatch all bundles in the queue, including this new bundle.
self._dispatch_tasks()

Expand All @@ -193,6 +194,7 @@ def _dispatch_tasks(self):
break
# Submit the map task.
bundle = self._bundle_queue.popleft()
self._metrics.on_input_dequeued(bundle)
input_blocks = [block for block, _ in bundle.blocks]
ctx = TaskContext(
task_idx=self._next_data_task_idx,
Expand Down Expand Up @@ -308,16 +310,12 @@ def base_resource_usage(self) -> ExecutionResources:
gpu=self._ray_remote_args.get("num_gpus", 0) * min_workers,
)

def current_resource_usage(self) -> ExecutionResources:
def current_processor_usage(self) -> ExecutionResources:
# Both pending and running actors count towards our current resource usage.
num_active_workers = self._actor_pool.num_total_actors()
object_store_memory = self.metrics.obj_store_mem_cur
if self.metrics.obj_store_mem_pending_tasks is not None:
object_store_memory += self.metrics.obj_store_mem_pending_tasks
return ExecutionResources(
cpu=self._ray_remote_args.get("num_cpus", 0) * num_active_workers,
gpu=self._ray_remote_args.get("num_gpus", 0) * num_active_workers,
object_store_memory=object_store_memory,
)

def incremental_resource_usage(self) -> ExecutionResources:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def slice_fn(block, metadata, num_rows) -> Tuple[Block, BlockMetadata]:
owns_blocks=refs.owns_blocks,
)
self._buffer.append(out_refs)
self._metrics.on_output_queued(out_refs)
if self._limit_reached():
self.mark_execution_completed()

Expand All @@ -98,7 +99,9 @@ def has_next(self) -> bool:
return len(self._buffer) > 0

def _get_next_inner(self) -> RefBundle:
return self._buffer.popleft()
output = self._buffer.popleft()
self._metrics.on_output_dequeued(output)
return output

def get_stats(self) -> StatsDict:
    """Return a one-entry mapping from this operator's name to its output metadata."""
    operator_name = self._name
    return {operator_name: self._output_metadata}
Expand Down
8 changes: 6 additions & 2 deletions python/ray/data/_internal/execution/operators/map_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,12 @@ def _add_input_inner(self, refs: RefBundle, input_index: int):
assert input_index == 0, input_index
# Add RefBundle to the bundler.
self._block_ref_bundler.add_bundle(refs)
self._metrics.on_input_queued(refs)
if self._block_ref_bundler.has_bundle():
# If the bundler has a full bundle, add it to the operator's task submission
# queue.
bundle = self._block_ref_bundler.get_next_bundle()
self._metrics.on_input_dequeued(bundle)
self._add_bundled_input(bundle)

def _get_runtime_ray_remote_args(
Expand Down Expand Up @@ -293,10 +295,11 @@ def _submit_data_task(
def _output_ready_callback(task_index, output: RefBundle):
# Since output is streamed, it should only contain one block.
assert len(output) == 1
self._metrics.on_output_generated(task_index, output)
self._metrics.on_task_output_generated(task_index, output)

# Notify output queue that the task has produced a new output.
self._output_queue.notify_task_output_ready(task_index, output)
self._metrics.on_output_queued(output)

def _task_done_callback(task_index: int, exception: Optional[Exception]):
self._metrics.on_task_finished(task_index, exception)
Expand Down Expand Up @@ -361,6 +364,7 @@ def has_next(self) -> bool:
def _get_next_inner(self) -> RefBundle:
    """Pop the next bundle from the output queue and return it.

    The bundle is reported to the metrics as a dequeued output, and the
    metadata of each of its blocks is appended to ``_output_metadata``.
    """
    assert self._started
    next_bundle = self._output_queue.get_next()
    self._metrics.on_output_dequeued(next_bundle)
    # Each entry of ``blocks`` is a pair; only the second item (the
    # metadata) is recorded.
    self._output_metadata.extend(meta for _, meta in next_bundle.blocks)
    return next_bundle
Expand All @@ -384,7 +388,7 @@ def shutdown(self):
self._finished_streaming_gens.clear()

@abstractmethod
def current_resource_usage(self) -> ExecutionResources:
def current_processor_usage(self) -> ExecutionResources:
raise NotImplementedError

@abstractmethod
Expand Down
20 changes: 11 additions & 9 deletions python/ray/data/_internal/execution/operators/output_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from ray.data._internal.execution.interfaces import (
ExecutionOptions,
ExecutionResources,
NodeIdStr,
PhysicalOperator,
RefBundle,
Expand Down Expand Up @@ -87,7 +86,9 @@ def has_next(self) -> bool:
return len(self._output_queue) > 0

def _get_next_inner(self) -> RefBundle:
return self._output_queue.popleft()
output = self._output_queue.popleft()
self._metrics.on_output_dequeued(output)
return output

def get_stats(self) -> StatsDict:
    """Return placeholder stats: a ``"split"`` entry with an empty list."""
    # TODO(ekl) add split metrics?
    split_stats = {"split": []}
    return split_stats
Expand All @@ -102,6 +103,7 @@ def _add_input_inner(self, bundle, input_index) -> None:
if bundle.num_rows() is None:
raise ValueError("OutputSplitter requires bundles with known row count")
self._buffer.append(bundle)
self._metrics.on_input_queued(bundle)
self._dispatch_bundles()

def all_inputs_done(self) -> None:
Expand Down Expand Up @@ -132,17 +134,12 @@ def all_inputs_done(self) -> None:
for b in bundles:
b.output_split_idx = i
self._output_queue.append(b)
self._metrics.on_output_queued(b)
self._buffer = []

def internal_queue_size(self) -> int:
    """Return the number of bundles currently held in ``_buffer``."""
    buffered_bundles = self._buffer
    return len(buffered_bundles)

def current_resource_usage(self) -> ExecutionResources:
return ExecutionResources(
object_store_memory=sum(b.size_bytes() for b in self._buffer)
+ sum(b.size_bytes() for b in self._output_queue)
)

def progress_str(self) -> str:
if self._locality_hints:
return locality_string(self._locality_hits, self._locality_misses)
Expand All @@ -161,6 +158,7 @@ def _dispatch_bundles(self, dispatch_all: bool = False) -> None:
target_bundle.output_split_idx = target_index
self._num_output[target_index] += target_bundle.num_rows()
self._output_queue.append(target_bundle)
self._metrics.on_output_queued(target_bundle)
if self._locality_hints:
preferred_loc = self._locality_hints[target_index]
if self._get_location(target_bundle) == preferred_loc:
Expand All @@ -183,8 +181,12 @@ def _pop_bundle_to_dispatch(self, target_index: int) -> RefBundle:
for bundle in self._buffer:
if self._get_location(bundle) == preferred_loc:
self._buffer.remove(bundle)
self._metrics.on_input_dequeued(bundle)
return bundle
return self._buffer.pop(0)

bundle = self._buffer.pop(0)
self._metrics.on_input_dequeued(bundle)
return bundle

def _can_safely_dispatch(self, target_index: int, nrow: int) -> bool:
if not self._equal:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,11 @@ def progress_str(self) -> str:
def base_resource_usage(self) -> ExecutionResources:
    """Return an ``ExecutionResources`` constructed with no arguments."""
    no_base_usage = ExecutionResources()
    return no_base_usage

def current_resource_usage(self) -> ExecutionResources:
def current_processor_usage(self) -> ExecutionResources:
num_active_workers = self.num_active_tasks()
object_store_memory = self.metrics.obj_store_mem_cur
if self.metrics.obj_store_mem_pending_tasks is not None:
object_store_memory += self.metrics.obj_store_mem_pending_tasks
return ExecutionResources(
cpu=self._ray_remote_args.get("num_cpus", 0) * num_active_workers,
gpu=self._ray_remote_args.get("num_gpus", 0) * num_active_workers,
object_store_memory=object_store_memory,
)

def incremental_resource_usage(self) -> ExecutionResources:
Expand Down
28 changes: 21 additions & 7 deletions python/ray/data/_internal/execution/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,9 @@ def update_usages(self):
num_ops_total = len(self._topology)
for op, state in reversed(self._topology.items()):
# Update `self._op_usages`.
op_usage = op.current_resource_usage()
# Don't count input refs towards dynamic memory usage, as they have been
# pre-created already outside this execution.
if not isinstance(op, InputDataBuffer):
op_usage.object_store_memory = (
op_usage.object_store_memory or 0
) + state.outqueue_memory_usage()
op_usage = op.current_processor_usage()
assert not op_usage.object_store_memory
op_usage.object_store_memory = _estimate_object_store_memory(op, state)
self._op_usages[op] = op_usage
# Update `self._global_usage`.
self._global_usage = self._global_usage.add(op_usage)
Expand Down Expand Up @@ -121,3 +117,21 @@ def get_downstream_fraction(self, op: PhysicalOperator) -> float:
def get_downstream_object_store_memory(self, op: PhysicalOperator) -> int:
    """Look up the downstream object store memory usage recorded for ``op``."""
    usage_by_op = self._downstream_object_store_memory
    return usage_by_op[op]


def _estimate_object_store_memory(op, state) -> int:
    """Estimate the object store memory attributed to ``op``.

    The estimate is the sum of:

    * ``op``'s internal output queue,
    * ``op``'s pending task outputs, when that estimate is not ``None``,
    * ``state.outqueue_memory_usage()``, and
    * for every operator in ``op.output_dependencies``, its internal input
      queue plus its pending task inputs.

    Args:
        op: The operator whose usage is estimated.
        state: The operator's state object; only ``outqueue_memory_usage()``
            is called on it.

    Returns:
        The estimated usage, or ``0`` when ``op`` is an ``InputDataBuffer``.
    """
    # Don't count input refs towards dynamic memory usage, as they have been
    # pre-created already outside this execution.
    if isinstance(op, InputDataBuffer):
        return 0

    total = op.metrics.obj_store_mem_internal_outqueue

    pending_outputs = op.metrics.obj_store_mem_pending_task_outputs
    if pending_outputs is not None:
        total += pending_outputs

    total += state.outqueue_memory_usage()

    # Data sitting in a downstream operator's input side is charged to the
    # operator that produced it, i.e. ``op``.
    for downstream_op in op.output_dependencies:
        downstream_metrics = downstream_op.metrics
        total += (
            downstream_metrics.obj_store_mem_internal_inqueue
            + downstream_metrics.obj_store_mem_pending_task_inputs
        )
    return total
1 change: 0 additions & 1 deletion python/ray/data/_internal/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ def update_execution_metrics(
tags = self._create_tags(dataset_tag, operator_tag)
self.bytes_spilled.set(stats.get("obj_store_mem_spilled", 0), tags)
self.bytes_freed.set(stats.get("obj_store_mem_freed", 0), tags)
self.bytes_current.set(stats.get("obj_store_mem_cur", 0), tags)
self.bytes_outputted.set(stats.get("bytes_task_outputs_generated", 0), tags)
self.rows_outputted.set(stats.get("rows_task_outputs_generated", 0), tags)
self.cpu_usage.set(stats.get("cpu_usage", 0), tags)
Expand Down

0 comments on commit 8c920a2

Please sign in to comment.