[fuse_elementwise] Fix external outputs collection #848

Status: Open. Wants to merge 2 commits into base: main. Showing changes from 1 commit.
69 changes: 34 additions & 35 deletions python/aitemplate/compiler/transform/fuse_ops.py
@@ -17,7 +17,7 @@
"""
import collections
import logging
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set

from aitemplate.compiler.base import Operator, Tensor
@@ -75,15 +75,15 @@ def get_node_groups(self) -> List[Set[Any]]:
return node_groups


def _find_fusable_elementwise_ops(op: Operator) -> Set[Operator]:
def _find_fusable_elementwise_ops(src_op: Operator) -> Set[Operator]:
"""
Given an elementwise op, returns a list of parent elementwise ops
which can be fused with this elementwise op.
"""

# Get parent ops.
dependent_ops = set()
for input_tensor in op._attrs["inputs"]:
for input_tensor in src_op._attrs["inputs"]:
Inline comment (Contributor Author): This is to avoid a naming collision and make debugging easier.

dependent_ops.update(input_tensor._attrs["src_ops"])
original_ops = set(dependent_ops)

@@ -147,33 +147,43 @@ class FusedElementwiseInfo:
external_outputs: Set[Tensor]


def _partition_subgraphs(ops: Set[Operator]) -> Dict[str, Set[Operator]]:
@dataclass
class SubgraphInfo:
partitioned_ops: Set[Operator] = field(default_factory=set)
external_outputs: Set[Tensor] = field(default_factory=set)


def _partition_subgraphs(ops: Set[Operator]) -> Dict[str, SubgraphInfo]:
"""
Given ops of candidate graph of fused_elementwise op graph and partition
into subgraph based on output shape, returns dict of
{output shape: ops to form subgraph based on the shape}
{output shape: ops to form subgraph based on the shape and external outputs of the subgraph}
"""
# Partition graph of elementwise into subgraph based on output shape.
output_op_map = collections.defaultdict(set)
subgraph_info_map = collections.defaultdict(SubgraphInfo)
for op in ops:
shapes = []
external_outputs = []
# Find output nodes
for output_tensor in op._attrs["outputs"]:
if (
output_tensor._attrs["is_output"]
or len(output_tensor._attrs["dst_ops"] - ops) > 0
):
shapes.append("_".join(map(str, output_tensor._attrs["shape"])))
external_outputs.append(output_tensor)
# Find anscestor of output node.
# Outputs with the same shape should form the same graph
if shapes:
key = "|".join(shapes)
op_set = output_op_map[key]
subgraph_info = subgraph_info_map[key]
subgraph_info.external_outputs.update(external_outputs)
op_set = subgraph_info.partitioned_ops
for anc_op in ops:
if transform_utils.is_ancestor(anc_op, op):
op_set.add(anc_op)
op_set.add(op)
return output_op_map
return subgraph_info_map


def _get_inputs_outputs(
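To make the new contract concrete: each output-shape key now maps to a SubgraphInfo that carries both the subgraph's ops and its external outputs, so the outputs no longer need to be re-derived later. Below is a minimal standalone sketch of that grouping (hypothetical op and tensor names, not AITemplate code):

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Set


@dataclass
class SubgraphInfo:
    partitioned_ops: Set[str] = field(default_factory=set)
    external_outputs: Set[str] = field(default_factory=set)


# Hypothetical elementwise ops and the (name, shape) of their externally
# visible outputs; "sub_1" is consumed only inside the group, so it has none.
external_output_shapes = {
    "add_3": [("r3", [1, 12, 512, 512])],
    "mul_2": [("r2", [1, 1, 1, 512])],
    "sub_1": [],
}

subgraph_info_map = defaultdict(SubgraphInfo)
for op_name, outputs in external_output_shapes.items():
    shapes = ["_".join(map(str, shape)) for _, shape in outputs]
    if shapes:
        # Ops whose external outputs share a shape land in the same bucket.
        info = subgraph_info_map["|".join(shapes)]
        info.partitioned_ops.add(op_name)
        info.external_outputs.update(name for name, _ in outputs)

for key, info in subgraph_info_map.items():
    print(key, sorted(info.partitioned_ops), sorted(info.external_outputs))
# e.g. "1_12_512_512" -> ops {"add_3"}, external outputs {"r3"}
#      "1_1_1_512"    -> ops {"mul_2"}, external outputs {"r2"}
```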
@@ -182,11 +192,9 @@ def _get_inputs_outputs(
"""
Given ops of a partitioned subgraph based on output shape, and ops of full graph
to form a complete graph with fused_elementwise op, returns all inputs/outputs of
the ops and the external input/output of the subgraph, which will serve as input/output
of fused_elementwise op.
the ops and the external input of the subgraph, which will serve as input of fused_elementwise op.
"""
external_inputs = set()
external_outputs = set()
tmp_inputs = set()
tmp_outputs = set()

@@ -201,9 +209,6 @@
assert op in input_tensor._attrs["dst_ops"]
for output_tensor in op._attrs["outputs"]:
tmp_outputs.add(output_tensor)
dst_ops = set(output_tensor._attrs["dst_ops"])
if output_tensor._attrs["is_output"] or len(dst_ops - all_ops) > 0:
external_outputs.add(output_tensor)
assert len(output_tensor._attrs["src_ops"]) == 1
assert list(output_tensor._attrs["src_ops"])[0] == op

@@ -212,22 +217,11 @@
), "external_inputs: {} is not equal to tmp_inputs: {} - tmp_outputs: {}.".format(
external_inputs, tmp_inputs, tmp_outputs
)
assert (
len(tmp_outputs - tmp_inputs - external_outputs) == 0
), "tmp_outputs: {} - tmp_inputs: {} - external_outputs: {} is not empty.".format(
tmp_outputs, tmp_inputs, external_outputs
)
assert (
len(external_outputs - tmp_outputs) == 0
), "external_outputs: {} - tmp_outputs: {} is not empty.".format(
external_outputs, tmp_outputs
)

return [tmp_inputs, tmp_outputs, external_inputs, external_outputs]
return [tmp_inputs, tmp_outputs, external_inputs]


def _collect_info(
output_op_map: Dict[str, Set[Operator]],
subgraph_info_map: Dict[str, SubgraphInfo],
all_ops: Set[Operator],
sorted_graph: List[Tensor],
) -> List[FusedElementwiseInfo]:
@@ -241,9 +235,10 @@
their external input/output, serving as input/output of fused_elementwise op.
"""
info_list = []
for op_set in output_op_map.values():
for subgraph_info in subgraph_info_map.values():
# Toposort the op_set into op_list
# because fuse_elementwise stores elementwise ops in topological order
op_set = subgraph_info.partitioned_ops
topo_set = set()
op_list = []
for tensor in sorted_graph:
@@ -259,8 +254,13 @@
), "Unable to find topological order of op list for fused_elementwise!"
# Get all inputs/outputs of elementwise ops and their external input/output,
# which will serve as input/output of fused_elementwise op.
inputs_outputs = _get_inputs_outputs(op_list, all_ops)
fused_op_info = FusedElementwiseInfo(op_list, *inputs_outputs)
tmp_inputs, tmp_outputs, external_inputs = _get_inputs_outputs(op_list, all_ops)
# Note external outputs were generated earlier because we need to group
# them by their shapes.
external_outputs = subgraph_info.external_outputs
fused_op_info = FusedElementwiseInfo(
op_list, tmp_inputs, tmp_outputs, external_inputs, external_outputs
)
info_list.append(fused_op_info)
return info_list

@@ -321,9 +321,9 @@ def fuse_elementwise(sorted_graph: List[Tensor], workdir: str = None) -> List[Tensor]:

for ops in to_be_fused_op_groups:
# Partition subgraph based on output shape.
output_op_map = _partition_subgraphs(ops)
subgraph_info_map = _partition_subgraphs(ops)
# Collect information to create fuse ops.
info_list = _collect_info(output_op_map, ops, sorted_graph)
info_list = _collect_info(subgraph_info_map, ops, sorted_graph)
# Create fuse ops.
_create_fuse_ops(info_list)

@@ -353,10 +353,9 @@ def process_singleton_elementwise(

for ops in to_be_fused_op_groups:
# Partition subgraph based on output shape.
# output_op_map = {op._attrs["op"]: set(op) for op in ops}
output_op_map = _partition_subgraphs(ops)
subgraph_info_map = _partition_subgraphs(ops)
# Collect information to create fuse ops.
info_list = _collect_info(output_op_map, set(ops), sorted_graph)
info_list = _collect_info(subgraph_info_map, set(ops), sorted_graph)
# Create fuse ops.
_create_fuse_ops(info_list)

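Taken together, the fuse_ops.py changes move external-output collection from _get_inputs_outputs into the partitioning step. A minimal sketch of the resulting per-group flow, with hypothetical helper signatures standing in for the real module's functions:

```python
from typing import Any, Callable, List, Set


def fuse_group(
    ops: Set[Any],
    partition_subgraphs: Callable,  # returns a dict of shape key -> SubgraphInfo
    get_inputs_outputs: Callable,   # now returns (tmp_inputs, tmp_outputs, external_inputs)
    toposort: Callable,             # orders a subgraph's ops topologically
    make_fused_op: Callable,        # builds one fused_elementwise op record
) -> List[Any]:
    """Sketch of the per-group flow after this PR; names are illustrative."""
    fused = []
    for subgraph_info in partition_subgraphs(ops).values():
        op_list = toposort(subgraph_info.partitioned_ops)
        tmp_inputs, tmp_outputs, external_inputs = get_inputs_outputs(op_list, ops)
        # External outputs are not re-derived here: they were recorded per
        # output-shape group during partitioning and ride along on SubgraphInfo.
        fused.append(
            make_fused_op(
                op_list,
                tmp_inputs,
                tmp_outputs,
                external_inputs,
                subgraph_info.external_outputs,
            )
        )
    return fused
```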
107 changes: 106 additions & 1 deletion tests/unittest/ops/test_fused_elementwise_broadcast.py
@@ -25,9 +25,14 @@
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target
from aitemplate.testing.test_utils import get_random_torch_tensor
from aitemplate.testing.test_utils import (
get_random_torch_tensor,
get_torch_empty_tensor,
)
from aitemplate.utils import graph_utils, shape_utils

from parameterized import parameterized


class FusedElementwiseBroadcastTestCase(unittest.TestCase):
@classmethod
@@ -957,6 +962,106 @@ def test_vectorization_fp32(self):
dtype="float",
)

@parameterized.expand([("float16"), ("float")])
def test_fused_elementwise_broadcast_with_skip_connection(self, dtype):
r"""
X0 X1 (8) X2 (1) X3
\ / \ /
Div_0 (R0) Sub_1 (R1)
\ | X4 (-1)
\ | /
\ Mul_2 (R2)
\ / \
\ / \
Add_3 (R3) \
| \
Softmax_4 (R4) /
\ /
\ /
\ /
Add_5 (R5) (output)

X0 ([1,12,512,512]) and X3 ([1,1,1,512]) have different but broadcastable shapes.
"""
target = detect_target()
if dtype == "float" and target.name == "rocm":
self.skipTest("float tensors not supported by rocm")
shape0 = [1, 12, 512, 512]
shape1 = [1, 1, 1, 512]
X0 = Tensor(
shape=shape0,
dtype=dtype,
name="X0",
is_input=True,
)
X1 = Tensor(
shape=[],
dtype=dtype,
name="X1",
value=8.0,
)
X2 = Tensor(
shape=[],
dtype=dtype,
name="X2",
value=1.0,
)
X3 = Tensor(
shape=shape1,
dtype=dtype,
name="X3",
is_input=True,
)
X4 = Tensor(
shape=[],
dtype=dtype,
name="X4",
value=-1.0,
)

R0 = ops.elementwise(FuncEnum.DIV)(X0, X1) # Div_0
R1 = ops.elementwise(FuncEnum.SUB)(X2, X3) # Sub_1
R2 = ops.elementwise(FuncEnum.MUL)(R1, X4) # Mul_2
R3 = ops.elementwise(FuncEnum.ADD)(R0, R2) # Add_3
R4 = ops.softmax()(R3, -1) # Softmax_4
R5 = ops.elementwise(FuncEnum.ADD)(R4, R2) # Add_5
R5._attrs["name"] = "R5"
R5._attrs["is_output"] = True

module = compile_model(
[R5],
target,
"./tmp",
f"test_fused_elementwise_broadcast_with_skip_connection_{dtype}",
)
debug_sorted_graph = module.debug_sorted_graph
sorted_ops = graph_utils.get_sorted_ops(debug_sorted_graph)
self.assertEqual(len(sorted_ops), 4)

x0_pt = get_random_torch_tensor(shape0, dtype)
x3_pt = get_random_torch_tensor(shape1, dtype)

r0_pt = x0_pt / 8.0
r1_pt = 1.0 - x3_pt
r2_pt = r1_pt * (-1.0)
r3_pt = r0_pt + r2_pt
r4_pt = torch.nn.functional.softmax(r3_pt, -1)
r5_pt = r4_pt + r2_pt

r5 = get_torch_empty_tensor(x0_pt.shape, dtype)

input_name_to_idx_mapping = module.get_input_name_to_index_map()
inputs = [None] * len(input_name_to_idx_mapping)
input_name_to_pt_mapping = {
"X0": x0_pt,
"X3": x3_pt,
}
for input_name, pt in input_name_to_pt_mapping.items():
inputs[input_name_to_idx_mapping[input_name]] = pt
module.run_with_tensors(inputs, {"R5": r5})

self.assertTrue(torch.allclose(r5, r5_pt, atol=1e-2, rtol=1e-2))


if __name__ == "__main__":
unittest.main()
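To exercise the updated test file locally, one option is to load it by path with the standard unittest runner. This is a usage sketch that assumes AITemplate, torch, and parameterized are installed and a supported GPU target is available:

```python
import importlib.util
import unittest

# Run from the repository root; the file path matches the diff above.
spec = importlib.util.spec_from_file_location(
    "test_fused_elementwise_broadcast",
    "tests/unittest/ops/test_fused_elementwise_broadcast.py",
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

suite = unittest.defaultTestLoader.loadTestsFromTestCase(
    module.FusedElementwiseBroadcastTestCase
)
unittest.TextTestRunner(verbosity=2).run(suite)
```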