[inductor] post_grad batched linear fusion (pytorch#112504)
Summary: Fuse independent nn.Linear() layers by stacking their inputs and weights with aten.cat and computing them with a single aten.bmm.
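
As a rough eager-mode illustration (not part of this commit): the fusion relies on the fact that several Linear layers of identical shape can be evaluated with one batched matmul over stacked inputs and weights.

```
import torch

# Ten independent Linear(10, 10) layers applied to ten [4, 10] inputs (example shapes).
linears = [torch.nn.Linear(10, 10) for _ in range(10)]
xs = [torch.randn(4, 10) for _ in range(10)]

# Unfused: one aten::mm (+ bias add) per layer.
unfused = [lin(x) for lin, x in zip(linears, xs)]

# Fused: stack inputs to [B, m, k] and transposed weights to [B, k, n],
# run a single aten::bmm, then add each bias back per slice.
stacked_x = torch.stack(xs)
stacked_w = torch.stack([lin.weight.t() for lin in linears])
bmm_out = torch.bmm(stacked_x, stacked_w)
fused = [bmm_out[i] + linears[i].bias for i in range(10)]

for a, b in zip(unfused, fused):
    torch.testing.assert_close(a, b)
```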

Test Plan:
Without the BMM fusion:
```
buck2 run @mode/opt //pytorch/benchmark:run -- test_module -d cuda --module test_linear_module --torchdynamo inductor --torchinductor_cudagraph 0 --torchinductor_batch_fusion 0
```
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/torchbench_test_module_20231030_072536_6535183793.json.gz&bucket=pyper_traces

100 aten::mm operators

With the BMM fusion:
```
buck2 run @mode/opt //pytorch/benchmark:run -- test_module -d cuda --module test_linear_module --torchdynamo inductor --torchinductor_cudagraph 0 --torchinductor_batch_fusion 1
```

20 aten::bmm operators

https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/torchbench_test_module_20231030_072157_6535183793.json.gz&bucket=pyper_traces
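
(Aside: outside the internal perfdoctor tooling, per-operator counts like the ones above can be reproduced approximately with the standard PyTorch profiler. `my_module` and `example_inputs` below are placeholders for the benchmark model:)

```
import torch
from torch.profiler import profile, ProfilerActivity

compiled = torch.compile(my_module)   # placeholder module
compiled(example_inputs)              # warm-up so compilation is excluded

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    compiled(example_inputs)

for evt in prof.key_averages():
    if evt.key in ("aten::mm", "aten::bmm"):
        print(evt.key, evt.count)
```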

Passes accuracy test:
```
$ buck2 run @mode/opt //pytorch/benchmark:run -- test_module -d cuda --module test_linear_module --torchdynamo inductor --torchinductor_cudagraph 0 --torchinductor_batch_fusion 1 --accuracy
Running eval method from test_module on cuda in dynamo inductor mode with input batch size 4 and precision tf32.
Accuracy:                            pass
```
Looks like the bmm and the input cat have been fused successfully.

Checking the Triton codegen:

```
TORCH_LOGS=+dynamo,+aot,+inductor buck2 run @mode/opt //pytorch/benchmark:run -- test_module -d cuda --module test_linear_module --torchdynamo inductor --torchinductor_cudagraph 0 --torchinductor_batch_fusion 1 --dump_triton 1
```

Triton code dump: https://www.internalfb.com/intern/everpaste/?handle=GHp1ABaqYuTjYCUBALiTWmteaI1PbsIXAAAB

Pull Request resolved: pytorch#112504
Approved by: https://github.com/yanboliang
xuzhao9 authored and dmenig committed Dec 21, 2023
1 parent c02d735 commit d02a5ea
Showing 2 changed files with 154 additions and 10 deletions.
41 changes: 40 additions & 1 deletion test/inductor/test_group_batch_fusion.py
@@ -217,7 +217,9 @@ def forward(self, x):


@requires_cuda()
-@torch._inductor.config.patch(post_grad_fusion_options={"group_linear": {}})
+@torch._inductor.config.patch(
+    post_grad_fusion_options={"group_linear": {"require_fbgemm": True}}
+)
class TestGroupBatchFusion(TestCase):
    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
        if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
@@ -413,5 +415,42 @@ def test_pointwise_op_pre_grad_fusion(self):
        counters.clear()


+class TestBMMFusionModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.my_modules = torch.nn.ModuleList()
+        for _ in range(10):
+            self.my_modules.append(torch.nn.Linear(10, 10))
+
+    def forward(self, inputs):
+        output = None
+        for linear, input in zip(self.my_modules, inputs):
+            if output is None:
+                output = linear(input)
+            else:
+                output += linear(input)
+        return output
+
+
+@requires_cuda()
+@torch._inductor.config.patch(
+    post_grad_fusion_options={"batch_linear": {"require_fbgemm": False}}
+)
+class TestPostGradBatchLinearFusion(TestCase):
+    def test_batch_linear_post_grad_fusion(self):
+        pt1_module = TestBMMFusionModule().cuda()
+        inputs = []
+        for _ in range(10):
+            inputs.append(torch.randn(10, 10).cuda())
+        eager_output = pt1_module(inputs)
+        pt2_module = torch.compile(pt1_module)
+        pt2_output = pt2_module(inputs)
+        self.assertTrue(torch.allclose(eager_output, pt2_output))
+        self.assertEqual(
+            counters["inductor"]["batch_fusion"],
+            2,
+        )
+
+
if __name__ == "__main__":
    run_tests()
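
(Note: the unit test above enables the new pass through a config patch decorator. Outside the test, the same knob can presumably be set directly on the inductor config before compiling; a sketch, reusing TestBMMFusionModule from the test above:)

```
import torch
import torch._inductor.config as inductor_config

# Enable the post-grad batched linear fusion added in this PR; require_fbgemm=False
# selects the aten.bmm path rather than the fbgemm group-GEMM path.
inductor_config.post_grad_fusion_options = {"batch_linear": {"require_fbgemm": False}}

model = TestBMMFusionModule().cuda()
compiled = torch.compile(model)
out = compiled([torch.randn(10, 10, device="cuda") for _ in range(10)])
```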
123 changes: 114 additions & 9 deletions torch/_inductor/fx_passes/group_batch_fusion.py
@@ -83,6 +83,18 @@ def list_group_batch_fusions(pre_grad=True) -> List[str]:
        return list(POST_GRAD_FUSIONS.keys())


+def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any:
+    unsqueezed_inputs = []
+    for input_tensor in input_tensors:
+        unsqueezed_input = graph.call_function(aten.unsqueeze, args=(input_tensor, 0))
+        unsqueezed_inputs.append(unsqueezed_input)
+    stacked_inputs = graph.call_function(
+        aten.cat,
+        args=(unsqueezed_inputs, 0),
+    )
+    return stacked_inputs
+
+
class GroupFusion(GroupBatchFusionBase):
    """
    Fuse ops in a group way, e.g, fuse mm/addmm of arbitrary input shapes with fbgemm.gmm.
@@ -105,6 +117,82 @@ def __init__(self, op, **kwargs):
        self.op = op


+@register_fusion("batch_linear", pre_grad=False)
+class PostGradBatchLinearFusion(BatchFusion):
+    """
+    Fuse ops in a batch way in post grad (aten level).
+    """
+
+    def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool:
+        return (
+            node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0
+        )
+
+    def _is_input_2d(self, input: torch.fx.Node) -> bool:
+        return len(input.meta["tensor_meta"].shape) == 2
+
+    def match(self, node: torch.fx.Node) -> Optional[Tuple[str, int, int, int, bool]]:
+        if CallFunctionVarArgs(aten.mm).match(node):
+            input_m, weight_m = node.args
+            bias_m = None
+
+        elif CallFunctionVarArgs(aten.addmm.default).match(
+            node
+        ) and self._addmm_node_can_be_fused(node):
+            bias_m, input_m, weight_m = node.args
+        else:
+            return None
+
+        # only handle the cases where inputs are 2D tensors
+        if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m):
+            return None
+        m, k = input_m.meta["tensor_meta"].shape
+        n = weight_m.meta["tensor_meta"].shape[1]
+        batch_key = ("batch_linear", m, k, n, bias_m is not None)
+        return batch_key
+
+    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
+        batch_inputs = []
+        batch_weights = []
+        batch_biases = []
+        batch_nodes = []
+
+        for node in subset:
+            if CallFunctionVarArgs(aten.addmm.default).match(node):
+                bias, input, weight = node.args
+            elif CallFunctionVarArgs(aten.mm.default).match(node):
+                input, weight = node.args
+                bias = None
+            batch_nodes.append(node)
+            batch_inputs.append(input)
+            batch_weights.append(weight)
+            batch_biases.append(bias)
+
+        with graph.inserting_before(subset[-1]):
+            fused_inputs = decompose_stack(graph, batch_inputs)
+            fused_weights = decompose_stack(graph, batch_weights)
+            fused_bmm = graph.call_function(
+                torch.ops.aten.bmm,
+                args=(fused_inputs, fused_weights),
+            )
+
+        for i, original_mm in enumerate(batch_nodes):
+            has_bias = False
+            with graph.inserting_after(fused_bmm):
+                new_mm = graph.call_function(
+                    torch.ops.aten.select, args=((fused_bmm, 0, i))
+                )
+                if batch_biases[i]:
+                    has_bias = True
+                    new_bias_add = graph.call_function(
+                        torch.ops.aten.add, args=((batch_biases[i], new_mm))
+                    )
+            new_mm_cont = new_bias_add if has_bias else new_mm
+            original_mm.replace_all_uses_with(new_mm_cont)
+            new_mm_cont.meta.update(original_mm.meta)
+            graph.erase_node(original_mm)
+
+
@register_fusion("group_linear", pre_grad=False)
class GroupLinearFusion(GroupFusion):
    def _addmm_node_can_be_fused(self, node: torch.fx.Node):
@@ -282,7 +370,7 @@ def is_linear_node_can_be_fused(node: torch.fx.Node):


@register_fusion("batch_linear")
-class BatchLinearFusion(BatchFusion):
+class PreGradBatchLinearFusion(BatchFusion):
    """
    Batch linear fusion in pre grad pass.
    Fuse linear with same size with torch.baddmm
@@ -619,7 +707,8 @@ def get_fusion_candidates(
            continue

        key = rule.match(node)
-        if key is not None:
+        # SymInt is not hashable, so we need to skip it
+        if key is not None and not isinstance(key, torch.SymInt):
            candidate_nodes = candidate_dict[key]
            if node not in candidate_nodes:
                candidate_nodes.append(node)
@@ -651,8 +740,10 @@ def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusion
                fused_set.update(subset)
                if isinstance(rule, GroupFusion):
                    counters["inductor"]["group_fusion"] += 1
-                else:
+                elif isinstance(rule, BatchFusion):
                    counters["inductor"]["batch_fusion"] += 1
+                else:
+                    counters["inductor"]["unknown_group_batch_fusion"] += 1

                log.info(
                    f"{rule.__class__.__name__}: key = {key}; subset size = {len(subset)}"  # noqa: G004
@@ -670,17 +761,31 @@ def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):


def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
-    print_graph(graph, "Before group_batch fusion in post grads pass.")
+    print_graph(graph, "Before group_batch fusion in pre grad pass.")
    fusions: List[GroupBatchFusionBase] = []

    if pre_grad:
-        fusions = generate_fusion_from_config(
+        fusions += generate_fusion_from_config(
            config.pre_grad_fusion_options, pre_grad=True
        )
-    elif has_fbgemm:  # Only group fusion (which needs fbgemm) in post grad.
-        fusions = generate_fusion_from_config(
-            config.post_grad_fusion_options, pre_grad=False
-        )
+    else:
+        fbgemm_fusion_keys = [
+            x
+            for x in config.post_grad_fusion_options
+            if config.post_grad_fusion_options[x].get('require_fbgemm', False)
+        ]
+        fbgemm_fusions = {
+            fusion: config.post_grad_fusion_options[fusion]
+            for fusion in fbgemm_fusion_keys
+        }
+        non_fbgemm_fusions = {
+            fusion: config.post_grad_fusion_options[fusion]
+            for fusion in config.post_grad_fusion_options.keys()
+            if fusion not in fbgemm_fusion_keys
+        }
+        fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False)
+        if has_fbgemm:
+            fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False)

    for rule in fusions:
        apply_group_batch_fusion(graph, rule)
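
(In eager terms, the rewrite performed by PostGradBatchLinearFusion above corresponds roughly to the following reference sketch: stack via unsqueeze + cat as in decompose_stack, run one bmm, then recover each original result with a select plus an optional bias add.)

```
import torch

def batched_linear_reference(inputs, weights, biases):
    # decompose_stack: unsqueeze each [m, k] input / [k, n] weight and
    # concatenate along a new leading batch dimension.
    stacked_inputs = torch.cat([x.unsqueeze(0) for x in inputs], dim=0)    # [B, m, k]
    stacked_weights = torch.cat([w.unsqueeze(0) for w in weights], dim=0)  # [B, k, n]

    # One bmm replaces B separate mm/addmm nodes.
    fused = torch.bmm(stacked_inputs, stacked_weights)                     # [B, m, n]

    # Each original node becomes a select on the bmm output, plus an add
    # when the original op was an addmm with a bias.
    outputs = []
    for i, bias in enumerate(biases):
        y = fused.select(0, i)
        outputs.append(y if bias is None else bias + y)
    return outputs
```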

