Skip to content

Commit

Permalink
Re-interpret 'fusable' to mean with predecessor operations. (#460)
Browse files Browse the repository at this point in the history
An operation like 'map_direct' which can read data across
block boundaries from its inputs is marked with `fusable=False`,
since it cannot be fused with predecessor operations. However,
it is permitted for successor operations to fuse with `map_direct`,
since they will operate on whole blocks.
  • Loading branch information
tomwhite committed May 16, 2024
1 parent ff2f23a commit 2e3a935
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 10 deletions.
2 changes: 1 addition & 1 deletion cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def wrap(*a, block_id=None, **kw):
chunks=chunks,
extra_source_arrays=args,
extra_projected_mem=extra_projected_mem,
fusable=False, # don't allow fusion since side inputs are not accounted for
fusable=False, # don't allow fusion with predecessors since side inputs are not accounted for
**kwargs,
)

Expand Down
23 changes: 15 additions & 8 deletions cubed/core/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,14 @@ def predecessor_ops(dag, name):
yield pre_list[0]


def is_primitive_op(node_dict):
"""Return True if a node is a primitive op"""
return "primitive_op" in node_dict


def is_fusable(node_dict):
"Return True if a node can be fused."
return "primitive_op" in node_dict and node_dict["primitive_op"].fusable
"""Return True if a node is a primitive op and can be fused with its predecessors."""
return is_primitive_op(node_dict) and node_dict["primitive_op"].fusable


def can_fuse_predecessors(
Expand All @@ -124,7 +129,7 @@ def can_fuse_predecessors(
return False

# if no predecessor ops can be fused then there is nothing to fuse
if all(not is_fusable(nodes[pre]) for pre in predecessor_ops(dag, name)):
if all(not is_primitive_op(nodes[pre]) for pre in predecessor_ops(dag, name)):
logger.debug("can't fuse %s since no predecessor ops can be fused", name)
return False

Expand All @@ -140,7 +145,9 @@ def can_fuse_predecessors(
# the fused function would be more than an allowed maximum, then don't fuse
if len(list(predecessor_ops(dag, name))) > 1:
total_source_arrays = sum(
len(list(predecessors_unordered(dag, pre))) if is_fusable(nodes[pre]) else 1
len(list(predecessors_unordered(dag, pre)))
if is_primitive_op(nodes[pre])
else 1
for pre in predecessor_ops(dag, name)
)
if total_source_arrays > max_total_source_arrays:
Expand All @@ -155,7 +162,7 @@ def can_fuse_predecessors(
predecessor_primitive_ops = [
nodes[pre]["primitive_op"]
for pre in predecessor_ops(dag, name)
if is_fusable(nodes[pre])
if is_primitive_op(nodes[pre])
]
return can_fuse_multiple_primitive_ops(
name,
Expand Down Expand Up @@ -193,7 +200,7 @@ def fuse_predecessors(

# if a predecessor has no primitive op then just use None
predecessor_primitive_ops = [
nodes[pre]["primitive_op"] if is_fusable(nodes[pre]) else None
nodes[pre]["primitive_op"] if is_primitive_op(nodes[pre]) else None
for pre in predecessor_ops(dag, name)
]

Expand All @@ -210,12 +217,12 @@ def fuse_predecessors(
# 1. update edges to change inputs
for input in predecessors_unordered(dag, name):
pre = next(predecessors_unordered(dag, input))
if not is_fusable(fused_nodes[pre]):
if not is_primitive_op(fused_nodes[pre]):
# if a predecessor is not fusable then don't change the edge
continue
fused_dag.remove_edge(input, name)
for pre in predecessor_ops(dag, name):
if not is_fusable(fused_nodes[pre]):
if not is_primitive_op(fused_nodes[pre]):
# if a predecessor is not fusable then don't change the edge
continue
for input in predecessors_unordered(dag, pre):
Expand Down
2 changes: 1 addition & 1 deletion cubed/primitive/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class PrimitiveOperation:
"""The number of tasks needed to run this operation."""

fusable: bool = True
"""Whether this operation should be considered for fusion."""
"""Whether this operation can be fused with predecessor operations."""

write_chunks: Optional[T_RegularChunks] = None
"""The chunk size used by this operation."""
Expand Down
22 changes: 22 additions & 0 deletions cubed/tests/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,28 @@ def test_fusion_transpose(spec):
)


def test_fusion_map_direct(spec):
# test that operations after a map_direct operation (indexing) can be fused
# with the map_direct operation
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = a[1:, :]
c = xp.negative(b) # should be fused with b

num_created_arrays = 2 # b, c
assert c.plan.num_tasks(optimize_graph=False) == num_created_arrays + 4
num_created_arrays = 1 # c
assert c.plan.num_tasks(optimize_graph=True) == num_created_arrays + 2

task_counter = TaskCounter()
result = c.compute(callbacks=[task_counter])
assert task_counter.value == num_created_arrays + 2

assert_array_equal(
result,
np.array([[-4, -5, -6], [-7, -8, -9]]),
)


def test_no_fusion(spec):
# b can't be fused with c because d also depends on b
a = xp.ones((2, 2), chunks=(2, 2), spec=spec)
Expand Down

0 comments on commit 2e3a935

Please sign in to comment.