From d02a5eab8f8cb92c38661ccddbc6cbcb2fcb54a4 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Fri, 1 Dec 2023 19:26:29 +0000
Subject: [PATCH] [inductor] post_grad batched linear fusion (#112504)

Summary: Fuse independent nn.Linear() calls that share the same input and
weight shapes into a single aten.bmm, stacking their inputs and weights with
aten.cat.

Test Plan:

Without the BMM fusion:

```
buck2 run @mode/opt //pytorch/benchmark:run -- test_module -d cuda --module test_linear_module --torchdynamo inductor --torchinductor_cudagraph 0 --torchinductor_batch_fusion 0
```

https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/torchbench_test_module_20231030_072536_6535183793.json.gz&bucket=pyper_traces

The trace shows 100 aten::mm operators.

With the BMM fusion:

```
buck2 run @mode/opt //pytorch/benchmark:run -- test_module -d cuda --module test_linear_module --torchdynamo inductor --torchinductor_cudagraph 0 --torchinductor_batch_fusion 1
```

The trace shows 20 aten::bmm operators.

https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/torchbench_test_module_20231030_072157_6535183793.json.gz&bucket=pyper_traces

Passes the accuracy test:

```
$ buck2 run @mode/opt //pytorch/benchmark:run -- test_module -d cuda --module test_linear_module --torchdynamo inductor --torchinductor_cudagraph 0 --torchinductor_batch_fusion 1 --accuracy
Running eval method from test_module on cuda in dynamo inductor mode with input batch size 4 and precision tf32.
Accuracy: pass
```

It looks like the bmm and the input cat have been fused successfully. Checking the Triton codegen:

```
TORCH_LOGS=+dynamo,+aot,+inductor buck2 run @mode/opt //pytorch/benchmark:run -- test_module -d cuda --module test_linear_module --torchdynamo inductor --torchinductor_cudagraph 0 --torchinductor_batch_fusion 1 --dump_triton 1
```

Triton code dump: https://www.internalfb.com/intern/everpaste/?handle=GHp1ABaqYuTjYCUBALiTWmteaI1PbsIXAAAB

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112504
Approved by: https://github.com/yanboliang
---
 test/inductor/test_group_batch_fusion.py     |  41 +++++-
 .../_inductor/fx_passes/group_batch_fusion.py | 123 ++++++++++++++++--
 2 files changed, 154 insertions(+), 10 deletions(-)

diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py
index 3dcc25b70f786..b7e69f8240a6f 100644
--- a/test/inductor/test_group_batch_fusion.py
+++ b/test/inductor/test_group_batch_fusion.py
@@ -217,7 +217,9 @@ def forward(self, x):
 
 
 @requires_cuda()
-@torch._inductor.config.patch(post_grad_fusion_options={"group_linear": {}})
+@torch._inductor.config.patch(
+    post_grad_fusion_options={"group_linear": {"require_fbgemm": True}}
+)
 class TestGroupBatchFusion(TestCase):
     def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
         if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
@@ -413,5 +415,42 @@ def test_pointwise_op_pre_grad_fusion(self):
             counters.clear()
 
 
+class TestBMMFusionModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.my_modules = torch.nn.ModuleList()
+        for _ in range(10):
+            self.my_modules.append(torch.nn.Linear(10, 10))
+
+    def forward(self, inputs):
+        output = None
+        for linear, input in zip(self.my_modules, inputs):
+            if output is None:
+                output = linear(input)
+            else:
+                output += linear(input)
+        return output
+
+
+@requires_cuda()
+@torch._inductor.config.patch(
+    post_grad_fusion_options={"batch_linear": {"require_fbgemm": False}}
+)
+class TestPostGradBatchLinearFusion(TestCase):
+    def test_batch_linear_post_grad_fusion(self):
+        pt1_module = TestBMMFusionModule().cuda()
+        inputs = []
+        for _ in range(10):
+            inputs.append(torch.randn(10, 10).cuda())
+        eager_output = pt1_module(inputs)
+        pt2_module = torch.compile(pt1_module)
+        pt2_output = pt2_module(inputs)
+        self.assertTrue(torch.allclose(eager_output, pt2_output))
+        self.assertEqual(
+            counters["inductor"]["batch_fusion"],
+            2,
+        )
+
+
 if __name__ == "__main__":
     run_tests()

diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py
index 9cbd3271d5e2c..f863f4a3755e4 100644
--- a/torch/_inductor/fx_passes/group_batch_fusion.py
+++ b/torch/_inductor/fx_passes/group_batch_fusion.py
@@ -83,6 +83,18 @@ def list_group_batch_fusions(pre_grad=True) -> List[str]:
         return list(POST_GRAD_FUSIONS.keys())
 
 
+def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any:
+    unsqueezed_inputs = []
+    for input_tensor in input_tensors:
+        unsqueezed_input = graph.call_function(aten.unsqueeze, args=(input_tensor, 0))
+        unsqueezed_inputs.append(unsqueezed_input)
+    stacked_inputs = graph.call_function(
+        aten.cat,
+        args=(unsqueezed_inputs, 0),
+    )
+    return stacked_inputs
+
+
 class GroupFusion(GroupBatchFusionBase):
     """
     Fuse ops in a group way, e.g, fuse mm/addmm of arbitrary input shapes with fbgemm.gmm.
@@ -105,6 +117,82 @@ def __init__(self, op, **kwargs):
         self.op = op
 
 
+@register_fusion("batch_linear", pre_grad=False)
+class PostGradBatchLinearFusion(BatchFusion):
+    """
+    Fuse ops in a batch way in post grad (aten level).
+    """
+
+    def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool:
+        return (
+            node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0
+        )
+
+    def _is_input_2d(self, input: torch.fx.Node) -> bool:
+        return len(input.meta["tensor_meta"].shape) == 2
+
+    def match(self, node: torch.fx.Node) -> Optional[Tuple[str, int, int, int, bool]]:
+        if CallFunctionVarArgs(aten.mm).match(node):
+            input_m, weight_m = node.args
+            bias_m = None
+
+        elif CallFunctionVarArgs(aten.addmm.default).match(
+            node
+        ) and self._addmm_node_can_be_fused(node):
+            bias_m, input_m, weight_m = node.args
+        else:
+            return None
+
+        # only handle the cases where inputs are 2D tensors
+        if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m):
+            return None
+        m, k = input_m.meta["tensor_meta"].shape
+        n = weight_m.meta["tensor_meta"].shape[1]
+        batch_key = ("batch_linear", m, k, n, bias_m is not None)
+        return batch_key
+
+    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
+        batch_inputs = []
+        batch_weights = []
+        batch_biases = []
+        batch_nodes = []
+
+        for node in subset:
+            if CallFunctionVarArgs(aten.addmm.default).match(node):
+                bias, input, weight = node.args
+            elif CallFunctionVarArgs(aten.mm.default).match(node):
+                input, weight = node.args
+                bias = None
+            batch_nodes.append(node)
+            batch_inputs.append(input)
+            batch_weights.append(weight)
+            batch_biases.append(bias)
+
+        with graph.inserting_before(subset[-1]):
+            fused_inputs = decompose_stack(graph, batch_inputs)
+            fused_weights = decompose_stack(graph, batch_weights)
+            fused_bmm = graph.call_function(
+                torch.ops.aten.bmm,
+                args=(fused_inputs, fused_weights),
+            )
+
+        for i, original_mm in enumerate(batch_nodes):
+            has_bias = False
+            with graph.inserting_after(fused_bmm):
+                new_mm = graph.call_function(
+                    torch.ops.aten.select, args=((fused_bmm, 0, i))
+                )
+                if batch_biases[i]:
+                    has_bias = True
+                    new_bias_add = graph.call_function(
+                        torch.ops.aten.add, args=((batch_biases[i], new_mm))
+                    )
+            new_mm_cont = new_bias_add if has_bias else new_mm
+            original_mm.replace_all_uses_with(new_mm_cont)
+            new_mm_cont.meta.update(original_mm.meta)
+            graph.erase_node(original_mm)
+
+
 @register_fusion("group_linear", pre_grad=False)
 class GroupLinearFusion(GroupFusion):
     def _addmm_node_can_be_fused(self, node: torch.fx.Node):
@@ -282,7 +370,7 @@ def is_linear_node_can_be_fused(node: torch.fx.Node):
 
 
 @register_fusion("batch_linear")
-class BatchLinearFusion(BatchFusion):
+class PreGradBatchLinearFusion(BatchFusion):
     """
     Batch linear fusion in pre grad pass.
     Fuse linear with same size with torch.baddmm
@@ -619,7 +707,8 @@ def get_fusion_candidates(
             continue
 
         key = rule.match(node)
-        if key is not None:
+        # SymInt is not hashable, so we need to skip it
+        if key is not None and not isinstance(key, torch.SymInt):
            candidate_nodes = candidate_dict[key]
            if node not in candidate_nodes:
                candidate_nodes.append(node)
@@ -651,8 +740,10 @@ def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusion
         fused_set.update(subset)
         if isinstance(rule, GroupFusion):
             counters["inductor"]["group_fusion"] += 1
-        else:
+        elif isinstance(rule, BatchFusion):
             counters["inductor"]["batch_fusion"] += 1
+        else:
+            counters["inductor"]["unknown_group_batch_fusion"] += 1
 
         log.info(
             f"{rule.__class__.__name__}: key = {key}; subset size = {len(subset)}"  # noqa: G004
@@ -670,17 +761,31 @@ def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):
 
 
 def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
-    print_graph(graph, "Before group_batch fusion in post grads pass.")
+    print_graph(graph, "Before group_batch fusion in pre grad pass.")
     fusions: List[GroupBatchFusionBase] = []
 
     if pre_grad:
-        fusions = generate_fusion_from_config(
+        fusions += generate_fusion_from_config(
             config.pre_grad_fusion_options, pre_grad=True
         )
-    elif has_fbgemm:  # Only group fusion (which needs fbgemm) in post grad.
-        fusions = generate_fusion_from_config(
-            config.post_grad_fusion_options, pre_grad=False
-        )
+    else:
+        fbgemm_fusion_keys = [
+            x
+            for x in config.post_grad_fusion_options
+            if config.post_grad_fusion_options[x].get('require_fbgemm', False)
+        ]
+        fbgemm_fusions = {
+            fusion: config.post_grad_fusion_options[fusion]
+            for fusion in fbgemm_fusion_keys
+        }
+        non_fbgemm_fusions = {
+            fusion: config.post_grad_fusion_options[fusion]
+            for fusion in config.post_grad_fusion_options.keys()
+            if fusion not in fbgemm_fusion_keys
+        }
+        fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False)
+        if has_fbgemm:
+            fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False)
 
     for rule in fusions:
         apply_group_batch_fusion(graph, rule)
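For reference, the sketch below is not part of the patch; it is a minimal standalone illustration (tensor sizes and variable names are made up) of the rewrite this pass performs: N independent same-shape addmm/linear calls are equivalent to stacking inputs and weights (unsqueeze + cat), running one aten.bmm, and then recovering each result with a per-layer select plus bias add.

```python
# Standalone illustration (not from the patch): N independent linear layers
# with identical (m, k) x (k, n) shapes can be computed by a single bmm over
# stacked inputs/weights. This mirrors decompose_stack (unsqueeze + cat),
# the fused aten.bmm, and the per-node select/add in the pass above.
import torch

torch.manual_seed(0)
num_layers, m, k, n = 10, 4, 10, 10  # illustrative sizes
inputs = [torch.randn(m, k) for _ in range(num_layers)]
weights = [torch.randn(k, n) for _ in range(num_layers)]
biases = [torch.randn(n) for _ in range(num_layers)]

# Unfused: one aten::addmm per layer.
unfused = [torch.addmm(b, x, w) for b, x, w in zip(biases, inputs, weights)]

# Fused: stack with unsqueeze + cat, then a single aten::bmm.
stacked_inputs = torch.cat([x.unsqueeze(0) for x in inputs], dim=0)    # (N, m, k)
stacked_weights = torch.cat([w.unsqueeze(0) for w in weights], dim=0)  # (N, k, n)
fused_bmm = torch.bmm(stacked_inputs, stacked_weights)                 # (N, m, n)
fused = [biases[i] + fused_bmm.select(0, i) for i in range(num_layers)]

# Each fused result matches the corresponding per-layer addmm.
for expected, actual in zip(unfused, fused):
    torch.testing.assert_close(expected, actual)
```

Only mm/addmm nodes that share the same (m, k, n) shapes and the same bias presence can be stacked into one bmm, which is why the match key above carries exactly those fields.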