Skip to content

Commit

Permalink
[data] Implement zero-copy fusion for Read op (ray-project#38789)
Browse files Browse the repository at this point in the history
Optimize `Read -> Map/Write` fusion. In this case, we can drop the unnecessary `BuildOutputBlocks` transform_fn.

Also change `MapTransformFn` to an abstract class and require implementations to subclass it. This makes it easier for optimization rules to detect the pattern.
---------

Signed-off-by: Hao Chen <chenh1024@gmail.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
raulchen authored and arvind-chandra committed Aug 31, 2023
1 parent d276c49 commit 672187d
Show file tree
Hide file tree
Showing 8 changed files with 254 additions and 45 deletions.
103 changes: 85 additions & 18 deletions python/ray/data/_internal/execution/operators/map_transformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
from abc import abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, TypeVar, Union

Expand Down Expand Up @@ -31,7 +32,6 @@ class MapTransformFn:

def __init__(
self,
callable: MapTransformCallable[MapTransformFnData, MapTransformFnData],
input_type: MapTransformFnDataType,
output_type: MapTransformFnDataType,
):
Expand All @@ -45,10 +45,11 @@ def __init__(
self._input_type = input_type
self._output_type = output_type

@abstractmethod
def __call__(
self, input: Iterable[MapTransformFnData], ctx: TaskContext
) -> Iterable[MapTransformFnData]:
return self._callable(input, ctx)
...

@property
def input_type(self) -> MapTransformFnDataType:
Expand Down Expand Up @@ -80,6 +81,11 @@ def __init__(
init_fn: A function that will be called before transforming data.
Used for the actor-based map operator.
"""
self.set_transform_fns(transform_fns)
self._init_fn = init_fn if init_fn is not None else lambda: None

def set_transform_fns(self, transform_fns: List[MapTransformFn]) -> None:
"""Set the transform functions."""
assert len(transform_fns) > 0
assert (
transform_fns[0].input_type == MapTransformFnDataType.Block
Expand All @@ -93,9 +99,11 @@ def __init__(
"The output type of the previous transform function must match "
"the input type of the next transform function."
)

self._transform_fns = transform_fns
self._init_fn = init_fn if init_fn is not None else lambda: None

def get_transform_fns(self) -> List[MapTransformFn]:
    """Get the transform functions.

    Returns:
        The list currently stored on this transformer. The stored list
        object itself is returned, not a copy.
    """
    return self._transform_fns

def init(self) -> None:
"""Initialize the transformer.
Expand Down Expand Up @@ -140,32 +148,78 @@ def create_map_transformer_from_block_fn(
"""
return MapTransformer(
[
MapTransformFn(
block_fn,
MapTransformFnDataType.Block,
MapTransformFnDataType.Block,
)
BlockMapTransformFn(block_fn),
],
init_fn,
)


# Below are util `MapTransformFn`s for converting input/output data.
# Below are subclasses of MapTransformFn.


class RowMapTransformFn(MapTransformFn):
    """MapTransformFn whose input and output are both rows.

    Wraps a ``row_fn`` callable and declares ``MapTransformFnDataType.Row``
    as both the input type and the output type.
    """

    def __init__(self, row_fn: MapTransformCallable[Row, Row]):
        """Store ``row_fn`` and register Row as the input and output type.

        Args:
            row_fn: Callable invoked with the input rows and the task context.
        """
        self._row_fn = row_fn
        row_type = MapTransformFnDataType.Row
        super().__init__(row_type, row_type)

    def __call__(self, input: Iterable[Row], ctx: TaskContext) -> Iterable[Row]:
        """Lazily yield everything produced by ``row_fn(input, ctx)``."""
        yield from self._row_fn(input, ctx)

    def __repr__(self) -> str:
        """Return a string naming this class and the wrapped callable."""
        wrapped = self._row_fn
        return f"RowMapTransformFn({wrapped})"


class BatchMapTransformFn(MapTransformFn):
    """MapTransformFn whose input and output are both batches.

    Wraps a ``batch_fn`` callable and declares ``MapTransformFnDataType.Batch``
    as both the input type and the output type.
    """

    def __init__(self, batch_fn: MapTransformCallable[DataBatch, DataBatch]):
        """Store ``batch_fn`` and register Batch as the input and output type.

        Args:
            batch_fn: Callable invoked with the input batches and the task
                context.
        """
        self._batch_fn = batch_fn
        batch_type = MapTransformFnDataType.Batch
        super().__init__(batch_type, batch_type)

    def __call__(
        self, input: Iterable[DataBatch], ctx: TaskContext
    ) -> Iterable[DataBatch]:
        """Lazily yield everything produced by ``batch_fn(input, ctx)``."""
        yield from self._batch_fn(input, ctx)

    def __repr__(self) -> str:
        """Return a string naming this class and the wrapped callable."""
        wrapped = self._batch_fn
        return f"BatchMapTransformFn({wrapped})"


class BlockMapTransformFn(MapTransformFn):
    """MapTransformFn whose input and output are both blocks.

    Wraps a ``block_fn`` callable and declares ``MapTransformFnDataType.Block``
    as both the input type and the output type.
    """

    def __init__(self, block_fn: MapTransformCallable[Block, Block]):
        """Store ``block_fn`` and register Block as the input and output type.

        Args:
            block_fn: Callable invoked with the input blocks and the task
                context.
        """
        self._block_fn = block_fn
        block_type = MapTransformFnDataType.Block
        super().__init__(block_type, block_type)

    def __call__(self, input: Iterable[Block], ctx: TaskContext) -> Iterable[Block]:
        """Lazily yield everything produced by ``block_fn(input, ctx)``."""
        yield from self._block_fn(input, ctx)

    def __repr__(self) -> str:
        """Return a string naming this class and the wrapped callable."""
        wrapped = self._block_fn
        return f"BlockMapTransformFn({wrapped})"


class BlocksToRowsMapTransformFn(MapTransformFn):
"""A MapTransformFn that converts input blocks to rows."""

def __init__(self):
super().__init__(
self._input_blocks_to_rows,
MapTransformFnDataType.Block,
MapTransformFnDataType.Row,
)

def _input_blocks_to_rows(
self, blocks: Iterable[Block], _: TaskContext
) -> Iterable[Row]:
def __call__(self, blocks: Iterable[Block], _: TaskContext) -> Iterable[Row]:
for block in blocks:
block = BlockAccessor.for_block(block)
for row in block.iter_rows(public_row_format=True):
Expand All @@ -178,6 +232,9 @@ def instance(cls) -> "BlocksToRowsMapTransformFn":
cls._instance = cls()
return cls._instance

def __repr__(self) -> str:
return "BlocksToRowsMapTransformFn()"


class BlocksToBatchesMapTransformFn(MapTransformFn):
"""A MapTransformFn that converts input blocks to batches."""
Expand All @@ -192,12 +249,11 @@ def __init__(
self._batch_format = batch_format
self._ensure_copy = not zero_copy_batch and batch_size is not None
super().__init__(
self._input_blocks_to_batches,
MapTransformFnDataType.Block,
MapTransformFnDataType.Batch,
)

def _input_blocks_to_batches(
def __call__(
self,
blocks: Iterable[Block],
_: TaskContext,
Expand Down Expand Up @@ -241,6 +297,15 @@ def batch_format(self) -> str:
def zero_copy_batch(self) -> bool:
return not self._ensure_copy

def __repr__(self) -> str:
return (
f"BlocksToBatchesMapTransformFn("
f"batch_size={self._batch_size}, "
f"batch_format={self._batch_format}, "
f"zero_copy_batch={self.zero_copy_batch}"
f")"
)


class BuildOutputBlocksMapTransformFn(MapTransformFn):
"""A MapTransformFn that converts UDF-returned data to output blocks."""
Expand All @@ -252,12 +317,11 @@ def __init__(self, input_type: MapTransformFnDataType):
"""
self._input_type = input_type
super().__init__(
self._to_output_blocks,
input_type,
MapTransformFnDataType.Block,
)

def _to_output_blocks(
def __call__(
self,
iter: Iterable[MapTransformFnData],
_: TaskContext,
Expand Down Expand Up @@ -306,3 +370,6 @@ def for_blocks(cls) -> "BuildOutputBlocksMapTransformFn":
if getattr(cls, "_instance_for_blocks", None) is None:
cls._instance_for_blocks = cls(MapTransformFnDataType.Block)
return cls._instance_for_blocks

def __repr__(self) -> str:
return f"BuildOutputBlocksMapTransformFn(input_type={self._input_type})"
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from ray.data._internal.logical.rules.operator_fusion import OperatorFusionRule
from ray.data._internal.logical.rules.randomize_blocks import ReorderRandomizeBlocksRule
from ray.data._internal.logical.rules.zero_copy_map_fusion import (
EliminateBuildOutputBlocks,
)


def get_logical_optimizer_rules():
Expand All @@ -8,5 +11,7 @@ def get_logical_optimizer_rules():


def get_physical_optimizer_rules():
    """Return the list of physical optimizer rule classes.

    Returns:
        A new list containing ``OperatorFusionRule`` followed by
        ``EliminateBuildOutputBlocks``.
    """
    # Subclasses of ZeroCopyMapFusionRule (e.g., EliminateBuildOutputBlocks) should
    # be run after OperatorFusionRule.
    return [OperatorFusionRule, EliminateBuildOutputBlocks]
88 changes: 88 additions & 0 deletions python/ray/data/_internal/logical/rules/zero_copy_map_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from abc import abstractmethod
from typing import List

from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.execution.operators.map_transformer import (
BuildOutputBlocksMapTransformFn,
MapTransformFn,
MapTransformFnDataType,
)
from ray.data._internal.logical.interfaces.optimizer import Rule
from ray.data._internal.logical.interfaces.physical_plan import PhysicalPlan


class ZeroCopyMapFusionRule(Rule):
    """Abstract base class for zero-copy map fusion rules.

    A zero-copy map fusion rule optimizes the transform_fn chain of a fused
    MapOperator, usually by removing unnecessary data conversions.

    This base class supplies the shared traversal utilities; subclasses
    implement `_optimize` with the concrete optimization strategy.
    """

    def apply(self, plan: PhysicalPlan) -> PhysicalPlan:
        """Run the rule over the operators reachable from ``plan.dag``.

        Args:
            plan: The physical plan to optimize.

        Returns:
            The same ``plan`` object that was passed in.
        """
        root_op = plan.dag
        self._traverse(root_op)
        return plan

    def _traverse(self, op):
        """Visit ``op`` and then, recursively, each of its input dependencies.

        For every MapOperator encountered, its transform_fns chain is replaced
        by the result of `_optimize`.
        """
        if isinstance(op, MapOperator):
            transformer = op.get_map_transformer()
            # Physical operators won't be shared,
            # so it's safe to modify the transform_fns in place.
            transformer.set_transform_fns(
                self._optimize(transformer.get_transform_fns())
            )

        for upstream_op in op.input_dependencies:
            self._traverse(upstream_op)

    @abstractmethod
    def _optimize(self, transform_fns: List[MapTransformFn]) -> List[MapTransformFn]:
        """Optimize the transform_fns chain of a MapOperator.

        Args:
            transform_fns: The old transform_fns chain.

        Returns:
            The optimized transform_fns chain.
        """
        ...


class EliminateBuildOutputBlocks(ZeroCopyMapFusionRule):
    """Rule that removes unnecessary BuildOutputBlocksMapTransformFn entries.

    A BuildOutputBlocksMapTransformFn is unnecessary when the previous fn
    already outputs blocks. This happens for the "Read -> Map/Write" fusion.
    """

    def _optimize(self, transform_fns: List[MapTransformFn]) -> List[MapTransformFn]:
        """Return the chain without redundant BuildOutputBlocksMapTransformFn.

        Args:
            transform_fns: The old transform_fns chain.

        Returns:
            A new list with the same fns in the same order, minus the dropped
            ones. The input list is not modified.
        """
        # For the following subsequence,
        # 1. Any MapTransformFn with block output.
        # 2. BuildOutputBlocksMapTransformFn
        # 3. Any MapTransformFn with block input.
        # We drop the BuildOutputBlocksMapTransformFn in the middle.
        block_type = MapTransformFnDataType.Block
        last_index = len(transform_fns) - 1
        kept_fns = []

        for idx, fn in enumerate(transform_fns):
            # Only an interior element has both a predecessor and a successor.
            # Neighbors are always read from the original chain.
            if (
                0 < idx < last_index
                and isinstance(fn, BuildOutputBlocksMapTransformFn)
                and transform_fns[idx - 1].output_type == block_type
                and transform_fns[idx + 1].input_type == block_type
            ):
                continue
            kept_fns.append(fn)

        return kept_fns
23 changes: 9 additions & 14 deletions python/ray/data/_internal/planner/plan_read_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
from ray.data._internal.execution.operators.input_data_buffer import InputDataBuffer
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.execution.operators.map_transformer import (
BlockMapTransformFn,
BuildOutputBlocksMapTransformFn,
MapTransformer,
MapTransformFn,
MapTransformFnDataType,
)
from ray.data._internal.logical.operators.read_operator import Read
from ray.data.block import Block, BlockAccessor
Expand Down Expand Up @@ -110,23 +109,19 @@ def do_read(blocks: Iterable[ReadTask], _: TaskContext) -> Iterable[Block]:
# Create a MapTransformer for a read operator
transform_fns = [
# First, execute the read tasks.
MapTransformFn(
do_read, MapTransformFnDataType.Block, MapTransformFnDataType.Block
),
BlockMapTransformFn(do_read),
# Then build the output blocks.
BuildOutputBlocksMapTransformFn.for_blocks(),
]

if op._additional_split_factor is not None:
# If an additional split is needed, do it last.
transform_fns.append(
MapTransformFn(
BlockMapTransformFn(
functools.partial(
_do_additional_splits,
additional_output_splits=op._additional_split_factor,
),
MapTransformFnDataType.Block,
MapTransformFnDataType.Block,
)
),
)

Expand All @@ -148,17 +143,17 @@ def apply_output_blocks_handling_to_read_task(
This function is only used for compatibility with the legacy LazyBlockList code path.
"""
transform_fns: List[MapTransformFn] = [BuildOutputBlocksMapTransformFn.for_blocks()]
transform_fns: List[BlockMapTransformFn] = [
BuildOutputBlocksMapTransformFn.for_blocks()
]

if additional_split_factor is not None:
transform_fns.append(
MapTransformFn(
BlockMapTransformFn(
functools.partial(
_do_additional_splits,
additional_output_splits=additional_split_factor,
),
MapTransformFnDataType.Block,
MapTransformFnDataType.Block,
)
),
)
map_transformer = MapTransformer(transform_fns)
Expand Down
Loading

0 comments on commit 672187d

Please sign in to comment.