[Inductor][fx pass] Fuse pointwise operators in the post grad (pytorch#114778)

Summary:

We construct a unified API that makes it easy to add pointwise ops to be batched in the post-grad pass.
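
A minimal usage sketch (not part of this commit; it assumes this patch is applied and uses the option names registered in `group_batch_fusion.py` below):
```
import torch

# Sketch: enable the new post-grad pointwise batch fusions by name.
# "batch_aten_add" and "batch_aten_mul" are the names registered by this
# change; the post_grad_fusion_options config knob itself already exists.
torch._inductor.config.post_grad_fusion_options = {
    "batch_aten_add": {},
    "batch_aten_mul": {},
}

def f(xs, ys):
    # Independent same-shape adds are candidates for post-grad batching.
    return [x + y for x, y in zip(xs, ys)]

compiled_f = torch.compile(f)
```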

Test Plan:
# unit test
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:group_batch_fusion
```
Buck UI: https://www.internalfb.com/buck2/6c5d1d31-e4d1-4865-bf79-1e7ac3b6e051
Test UI: https://www.internalfb.com/intern/testinfra/testrun/1970325050015770
Network: Up: 72KiB  Down: 22KiB  (reSessionID-44adc8b2-54e9-453a-bd20-710cefefaed1)
Jobs completed: 20. Time elapsed: 1:44.6s.
Cache hits: 0%. Commands: 2 (cached: 0, remote: 0, local: 2)
Tests finished: Pass 7. Fail 0. Fatal 0. Skip 0. Build failure 0
# local reproduce
### cmf
P887605070
### igctr
P892987433
### mai
P893109069
### icvr
P893075846
### oc
P893109069

Reviewed By: xuzhao9

Differential Revision: D51332067
mengluy authored and facebook-github-bot committed Dec 4, 2023
1 parent bfa2c84 commit 33eccc7
Showing 2 changed files with 85 additions and 6 deletions.
1 change: 1 addition & 0 deletions test/inductor/test_group_batch_fusion.py
@@ -433,6 +433,7 @@ def forward(self, inputs):


@requires_cuda()
@torch._inductor.config.patch(group_fusion=False, batch_fusion=False)
@torch._inductor.config.patch(post_grad_fusion_options={"batch_linear_post_grad": {}})
class TestPostGradBatchLinearFusion(TestCase):
def test_batch_linear_post_grad_fusion(self):
90 changes: 84 additions & 6 deletions torch/_inductor/fx_passes/group_batch_fusion.py
@@ -86,11 +86,12 @@ def list_group_batch_fusions(pre_grad=True) -> List[str]:
def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any:
unsqueezed_inputs = []
for input_tensor in input_tensors:
unsqueezed_input = graph.call_function(aten.unsqueeze, args=(input_tensor, 0))
unsqueezed_input = graph.call_function(
aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0}
)
unsqueezed_inputs.append(unsqueezed_input)
stacked_inputs = graph.call_function(
aten.cat,
args=(unsqueezed_inputs, 0),
aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0}
)
return stacked_inputs

@@ -271,6 +272,69 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
graph.erase_node(original_mm)


class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
"""
Batch pointwise operator (e.g., add, mul) in post grad pass.
"""

def __init__(self, op, **kwargs):
super().__init__(op, **kwargs)
self.op = op

def _pointwise_node_can_be_fused(self, node: torch.fx.Node):
# note: we only consider the case where the inputs are tensors
input, other = node.args
return (
input.meta["tensor_meta"].shape == other.meta["tensor_meta"].shape
if hasattr(input, "meta") and hasattr(other, "meta")
and "tensor_meta" in input.meta and "tensor_meta" in other.meta
else False
)

def match(self, node: torch.fx.Node):
if CallFunctionVarArgs(self.op).match(
node
) and self._pointwise_node_can_be_fused(node):
alpha = node.kwargs.get("alpha", 1.0)
input, other = node.args
shape = list(input.meta["tensor_meta"].shape)
group_key = (
"batch_" + self.op.__name__.lower() + "_post_grad",
str(shape),
str(alpha),
)
else:
group_key = None
return group_key

def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
batch_inputs, batch_others = [], []
alpha = subset[0].kwargs.get("alpha", 1.0)

for node in subset:
input, other = node.args
batch_inputs.append(input)
batch_others.append(other)

with graph.inserting_before(subset[0]):
stack_inputs = decompose_stack(graph, batch_inputs)
stack_others = decompose_stack(graph, batch_others)

batch_op = graph.call_function(
self.op,
args=(stack_inputs, stack_others),
kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {},
)
for i, original_add in enumerate(subset):
with graph.inserting_after(batch_op):
new_add = graph.call_function(
torch.ops.aten.select, args=((batch_op, 0, i))
)
original_add.replace_all_uses_with(new_add)
new_add.meta.update(original_add.meta)
graph.erase_node(original_add)


@register_fusion("batch_linear_lhs")
class BatchLinearLHSFusion(BatchFusion):
"""
@@ -633,6 +697,18 @@ def __init__(self, **kwargs):
super().__init__(torch.nn.functional.relu, **kwargs)


@register_fusion("batch_aten_add", pre_grad=False)
class BatchAddPostGradFusion(BatchPointwiseOpsPostGradFusion):
def __init__(self, **kwargs):
super().__init__(aten.add.Tensor, **kwargs)


@register_fusion("batch_aten_mul", pre_grad=False)
class BatchMulPostGradFusion(BatchPointwiseOpsPostGradFusion):
def __init__(self, **kwargs):
super().__init__(aten.mul.Tensor, **kwargs)


def find_independent_subset_greedy(
node_list: List[torch.fx.Node],
graph_search_options: Dict[str, Any],
@@ -773,9 +849,11 @@ def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
"batch_relu": {},
"batch_sigmoid": {},
}
# config.post_grad_fusion_options = {
# "batch_linear_post_grad": {},
# }
config.post_grad_fusion_options = {
# "batch_linear_post_grad": {},
"batch_aten_add": {},
"batch_aten_mul": {},
}
if config.group_fusion:
config.post_grad_fusion_options = {
"group_linear": {"require_fbgemm": True},
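For intuition, the rewrite performed by `BatchPointwiseOpsPostGradFusion.fuse` is morally equivalent to the eager-mode sketch below (a hypothetical illustration, not code from this commit): same-shape `aten.add` calls are stacked via unsqueeze + cat, added once, and split back out with `select`.
```
import torch

# Before: three independent pointwise adds over same-shape tensors.
a1, b1 = torch.randn(4, 8), torch.randn(4, 8)
a2, b2 = torch.randn(4, 8), torch.randn(4, 8)
a3, b3 = torch.randn(4, 8), torch.randn(4, 8)
outs_before = [a1 + b1, a2 + b2, a3 + b3]

# After: unsqueeze + cat (what decompose_stack emits), one batched add,
# then select(0, i) to recover each original output.
stack_inputs = torch.cat([t.unsqueeze(0) for t in (a1, a2, a3)], dim=0)
stack_others = torch.cat([t.unsqueeze(0) for t in (b1, b2, b3)], dim=0)
batch_out = stack_inputs + stack_others
outs_after = [batch_out.select(0, i) for i in range(3)]

assert all(torch.equal(x, y) for x, y in zip(outs_before, outs_after))
```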
