From b67f87bb3aba62236f07108359a895e7cf92cce9 Mon Sep 17 00:00:00 2001 From: ruisizhang123 Date: Thu, 4 Sep 2025 11:11:02 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- autoparallel/auto_bucketing.py | 20 +- .../autobucketing_util/bucket_func.py | 632 ++++++++++++++++++ .../autobucketing_util/bucket_plan.py | 2 +- .../autobucketing_util/bucket_utils.py | 412 +++++++++++- 4 files changed, 1061 insertions(+), 5 deletions(-) create mode 100644 autoparallel/autobucketing_util/bucket_func.py diff --git a/autoparallel/auto_bucketing.py b/autoparallel/auto_bucketing.py index 85e1573a..e868f0cd 100644 --- a/autoparallel/auto_bucketing.py +++ b/autoparallel/auto_bucketing.py @@ -5,7 +5,7 @@ import torch -from .autobucketing_util import bucket_plan, bucket_utils +from .autobucketing_util import bucket_func, bucket_plan, bucket_utils class simplefsdp_autobucketing_config: @@ -53,4 +53,22 @@ def simple_fsdp_autobucketing_reordering_pass( bucketable_nodes, configs, ) + + snodes = bucket_func.bucket_fsdp_all_gather_with_plan( + scheduler, + snodes, + scheduler.name_to_buf, + scheduler.name_to_fused_node, + all_gather_plan, + bucketable_nodes, + ) + if len(reduce_scatter_plan) > 0: + snodes = bucket_func.bucket_fsdp_reduce_scatter_with_plan( + scheduler, + snodes, + scheduler.name_to_buf, + scheduler.name_to_fused_node, + reduce_scatter_plan, + bucketable_nodes, + ) return snodes diff --git a/autoparallel/autobucketing_util/bucket_func.py b/autoparallel/autobucketing_util/bucket_func.py new file mode 100644 index 00000000..7f9f18f7 --- /dev/null +++ b/autoparallel/autobucketing_util/bucket_func.py @@ -0,0 +1,632 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# mypy: ignore-errors +import functools +from typing import Any, Dict + +import torch +from torch._inductor import ir, scheduler +from torch._inductor.comms import bucket_all_gathers, bucket_reduce_scatters, get_op_idx +from torch._inductor.dependencies import StarDep, WeakDep +from torch._inductor.utils import is_collective, is_wait +from torch._inductor.virtualized import V +from torch.utils._ordered_set import OrderedSet + +from .bucket_utils import ( + _find_recursive_deps_of_snode, + _find_recursive_users_of_snode, + _get_fx_node, + _remove_operation, + _replace_scheduler_buffer, + _schedule_fallback_operation, + _schedule_snode, + check_ir_node_bucketable, +) + + +def bucket_fsdp_all_gather_with_plan( + sched: "scheduler.Scheduler", + snodes: list["scheduler.BaseSchedulerNode"], + name_to_buf: Dict[str, "scheduler.SchedulerBuffer"], + name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"], + all_gather_bucket_plan: list[ + Dict[tuple[Any, ...], list["scheduler.BaseSchedulerNode"]] + ], + bucketable_nodes: set[str], +) -> list["scheduler.BaseSchedulerNode"]: + # Given a list of scheduler nodes `snodes`, pick out all_gather nodes and bucket them according to `all_gather_bucket_plan`. + # It will return a new list of scheduler nodes, which is the same as `snodes` except that all_gather nodes are bucketed. 
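+    # Illustrative sketch of the expected plan shape (the key contents come from bucket_plan
+    # and are treated as opaque here; the placeholder names below are assumptions):
+    #   all_gather_bucket_plan = [
+    #       {plan_info_a: [ag_snode_0, ag_snode_1]},  # bucket 0: merged into one bucketed all_gather
+    #       {plan_info_b: [ag_snode_2]},              # bucket 1: single node, scheduled as-is
+    #   ]
+    # All all_gathers grouped under one key must share group_size/group_name, which is
+    # asserted when the bucketed op is constructed below.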
+ new_order: list[scheduler.BaseSchedulerNode] = [] + scheduled = OrderedSet() + ag_exists = False + ag_snode_to_cast_snode: Dict[ + scheduler.BaseSchedulerNode, scheduler.BaseSchedulerNode + ] = {} + ag_snode_to_wait_snode: Dict[ + scheduler.BaseSchedulerNode, scheduler.BaseSchedulerNode + ] = {} + new_operation_name_to_snode = {} + + schedule_snode = functools.partial( + _schedule_snode, new_order=new_order, scheduled=scheduled + ) + replace_scheduler_buffer = functools.partial( + _replace_scheduler_buffer, name_to_buf=name_to_buf + ) + remove_operation = functools.partial( + _remove_operation, name_to_fused_node=name_to_fused_node + ) + schedule_fallback_operation = functools.partial( + _schedule_fallback_operation, + scheduler=sched, + name_to_buf=name_to_buf, + name_to_fused_node=name_to_fused_node, + schedule_snode_fn=schedule_snode, + new_operation_name_to_snode=new_operation_name_to_snode, + ) + + # Step 1: Find all all_gather nodes + for snode in snodes: + if is_collective( + snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default + ) and check_ir_node_bucketable(snode.node, bucketable_nodes): + ag_exists = True + ag_snode = snode + ag_related_snode_set: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet() + + # Find the "cast + all_gather" code block + _find_recursive_deps_of_snode( + ag_snode, + ag_related_snode_set, + name_to_buf, + name_to_fused_node, + allow_weak_dep=False, + ) # sort nodes by original operation order + ag_related_snodes = sorted( + ag_related_snode_set, key=lambda x: get_op_idx(x) + ) + if len(ag_related_snodes) >= 2: + cast_node = ag_related_snodes[-2] + for node in ag_related_snodes: + if not is_wait(node.node) and not is_collective(node.node): + cast_node = node + ag_snode = snode + ag_snode_to_cast_snode[ag_snode] = cast_node + else: + ag_snode = ag_related_snodes[0] + + # Find the "all_gather + wait_tensor" code block + assert len(ag_snode.outputs) == 1 + assert len(ag_snode.outputs[0].users) == 1 + wait_snode = ag_snode.outputs[0].users[0].node + ag_snode_to_wait_snode[ag_snode] = wait_snode + + if ag_exists: + assert len(ag_snode_to_wait_snode) > 0 + else: + return snodes + + # Step 2: Put all_gather nodes into buckets + ag_snode_to_bucket_id = {} + ag_snode_to_bucket_id_coarsen = {} + cur_bucket_id: int = 0 + + for all_gather_bucket in all_gather_bucket_plan: + for all_gather_info, all_gather_list in all_gather_bucket.items(): + ag_snode_to_bucket_id.update( + dict.fromkeys(all_gather_list, all_gather_info + (cur_bucket_id,)) + ) + ag_snode_to_bucket_id_coarsen.update( + dict.fromkeys(all_gather_list, cur_bucket_id) + ) + cur_bucket_id += 1 + assert len(ag_snode_to_bucket_id) == len(ag_snode_to_wait_snode) + + # Step 3: Create new (bucketed) all_gather nodes + # TODO(yf225): horizontally fuse all cast ops into one op + bucket_id_to_bucketed_op_info = {} + bucket_id_is_scheduled = {} + + for bucket_id, ag_bucket in enumerate(all_gather_bucket_plan): + all_ag_snodes = [] + all_wait_snodes = [] + all_ag_input_ir_nodes = [] + group_sizes = [] + group_names = [] + for ag_info, ag_snodes in ag_bucket.items(): + if len(ag_snodes) == 0: + continue + example_ag_fx_node = _get_fx_node( + ag_snodes[0], + expected_op=torch.ops._c10d_functional.all_gather_into_tensor.default, + ) + _, group_size, group_name = example_ag_fx_node.args + ag_input_ir_nodes: list[ir.IRNode] = [] + wait_snodes = [] + for ag_snode in ag_snodes: + ag_fx_node = _get_fx_node( + ag_snode, + expected_op=torch.ops._c10d_functional.all_gather_into_tensor.default, + ) + 
assert ( + ag_fx_node.args[1] == group_size + and ag_fx_node.args[2] == group_name + ), f"Expected group_size {group_size} and group_name {group_name}, but got {ag_fx_node.args[1:]}" + # TODO(yf225): this needs to consider the "no cast op" case, in which case we should directly take graph input as input + # storage = V.graph.graph_inputs[name].data + # assert isinstance(storage, ir.StorageBox) and storage.is_input_buffer() + if cast_snode := ag_snode_to_cast_snode.get(ag_snode, None): + assert len(cast_snode.get_outputs()) == 1 + ag_input_ir_node = list(cast_snode.get_outputs())[0].node + else: + met_deps = ag_snode.read_writes.reads - ag_snode.unmet_dependencies + assert ( + len(met_deps) == 1 + ), f"ag_snode: {ag_snode}, ag_snode.debug_str(): {ag_snode.debug_str()}, met_deps: {met_deps}" + ag_input_name = list(met_deps)[0].name + ag_input_ir_node = V.graph.graph_inputs[ag_input_name].data + assert ( + isinstance(ag_input_ir_node, ir.StorageBox) + and ag_input_ir_node.is_input_buffer() + ) + ag_input_ir_nodes.append(ag_input_ir_node) + wait_snodes.append(ag_snode_to_wait_snode[ag_snode]) + group_sizes.append(group_size) + group_names.append(group_name) + all_ag_snodes.append(ag_snodes) + all_wait_snodes.append(wait_snodes) + all_ag_input_ir_nodes.append(ag_input_ir_nodes) + bucket_id_to_bucketed_op_info[bucket_id] = ( + all_ag_input_ir_nodes, + group_sizes, + group_names, + all_ag_snodes, + all_wait_snodes, + ) + + ag_snodes = OrderedSet(ag_snode_to_wait_snode.keys()) + ag_and_wait_snodes = OrderedSet() + ag_and_wait_snodes |= ag_snodes # all_gather + ag_and_wait_snodes |= OrderedSet(ag_snode_to_wait_snode.values()) # wait_tensor + + for snode in snodes: + if ( + snode not in ag_and_wait_snodes + ): # and snode not in list(ag_snode_to_cast_snode.values()): + # not all_gather or its wait_tensor - schedule it normally + schedule_snode(snode) + elif snode in ag_snodes: + assert ( + snode in ag_snode_to_bucket_id + ), f"{snode} not in {ag_snode_to_bucket_id}" + # bucket_id is the smaller one with (group_info, bucket) + bucket_id = ag_snode_to_bucket_id[snode] + # coarsen_bucket_id is the bigger one with bucket info + coarsen_bucket_id = ag_snode_to_bucket_id_coarsen[snode] + + if coarsen_bucket_id not in bucket_id_is_scheduled: + ( + all_ag_input_ir_nodes, + group_sizes, + group_names, + all_orig_ag_snodes, + all_orig_wait_snodes, + ) = bucket_id_to_bucketed_op_info[coarsen_bucket_id] + + AG_Group_node_list = [] + Wait_Group_node_list = [] + for idx, ( + ag_input_ir_nodes, + orig_ag_snodes, + orig_wait_snodes, + group_size, + group_name, + ) in enumerate( + zip( + all_ag_input_ir_nodes, + all_orig_ag_snodes, + all_orig_wait_snodes, + group_sizes, + group_names, + ) + ): + if len(orig_ag_snodes) == 1: + # If there is only one all_gather in the bucket, schedule it normally. + if orig_ag_snodes[0] in ag_snode_to_cast_snode: + AG_Group_node_list.append( + ag_snode_to_cast_snode[orig_ag_snodes[0]] + ) + AG_Group_node_list.append(orig_ag_snodes[0]) + Wait_Group_node_list.append(orig_wait_snodes[0]) + else: + original_length = len(new_order) + outs = bucket_all_gathers( + schedule_fallback_operation, + group_size, + group_name, + ag_input_ir_nodes, + orig_ag_snodes, + name_to_buf, + orig_wait_snodes, + schedule_snode, + ) + # Swap out the original wait output buffer with the new buffer, + # so that downstream user nodes can read from the new buffer just by using the original dep buffer name. 
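+                        # Rough picture of the swap below (buffer names are illustrative):
+                        #   before: the original all_gather defines "buf_ag"; its wait_tensor and
+                        #           downstream users depend on "buf_ag" by name.
+                        #   after:  the bucketed copy-out op's output buffer is renamed to "buf_ag"
+                        #           via replace_scheduler_buffer, so existing deps resolve unchanged
+                        #           and no downstream node has to be rewritten.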
+ for out_operation, orig_ag_snode, orig_wait_snode in zip( + outs, orig_ag_snodes, orig_wait_snodes + ): + out_snode = new_operation_name_to_snode[ + out_operation.get_operation_name() + ] + assert len(orig_ag_snode.outputs) == 1 + orig_ag_snode_output = orig_ag_snode.outputs[-1] + orig_wait_snode_output = orig_wait_snode.outputs[-1] + out_snode_output = out_snode.outputs[-1] + replace_scheduler_buffer( + orig_sched_buf=orig_ag_snode_output, + new_sched_buf=out_snode_output, + ) + # wait_tensor node output is modeled as a mutation on the all_gather node output. + # We need to preserve this property even after swapping. + assert ( + isinstance( + orig_wait_snode_output.node, ir.MutationOutput + ) + and len(orig_wait_snode_output.get_mutations()) == 1 + and orig_wait_snode_output.get_mutations()[0] + == orig_ag_snode_output.get_name() + ) + out_snode.outputs.append(orig_wait_snode_output) + out_snode.read_writes.writes.add( + StarDep( + name=orig_wait_snode_output.get_name(), mode=None + ) + ) + # Remove original all_gather and wait_tensor operations + remove_operation(orig_ag_snode.node) + remove_operation(orig_wait_snode.node) + new_length = len(new_order) + current_Wait_Group_node = [] + current_AG_Group_node = [] + wait_node = True + for node in range(new_length - original_length): + node = new_order.pop() + node.min_order = 0 + node.max_order = 0 + if wait_node: + current_Wait_Group_node.append(node) + else: + current_AG_Group_node.append(node) + + if ( + isinstance(node.node, ir.FallbackKernel) + and node.node.python_kernel_name + == "torch.ops._c10d_functional.wait_tensor.default" + ): + wait_node = False + current_AG_Group_node.reverse() + current_Wait_Group_node.reverse() + AG_Group_node_list.extend(current_AG_Group_node) + Wait_Group_node_list.extend(current_Wait_Group_node) + AG_Group_node = scheduler.GroupedSchedulerNode.create( + AG_Group_node_list + ) + Wait_Group_node = scheduler.GroupedSchedulerNode.create( + Wait_Group_node_list + ) + AG_Group_node.temp_grouping = True + Wait_Group_node.temp_grouping = True + new_order.append(AG_Group_node) + new_order.append(Wait_Group_node) + bucket_id_is_scheduled[coarsen_bucket_id] = True + + return new_order + + +def bucket_fsdp_reduce_scatter_with_plan( + sched: "scheduler.Scheduler", + snodes: list["scheduler.BaseSchedulerNode"], + name_to_buf: Dict[str, "scheduler.SchedulerBuffer"], + name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"], + reduce_scatter_bucket_plan: list[ + Dict[tuple[Any, ...], list["scheduler.BaseSchedulerNode"]] + ], + bucketable_nodes: set[str], +) -> list["scheduler.BaseSchedulerNode"]: + # Given a list of scheduler nodes `snodes`, pick out reduce_scatter nodes and bucket them according to `reduce_scatter_bucket_plan`. + # It will return a new list of scheduler nodes, which is the same as `snodes` except that reduce_scatter nodes are bucketed. 
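+    # The plan has the same nested shape as the all_gather plan above (sketch, with assumed
+    # placeholder names):
+    #   reduce_scatter_bucket_plan = [
+    #       {plan_info_a: [rs_snode_0, rs_snode_1]},  # bucket 0: merged into one bucketed reduce_scatter
+    #       {plan_info_b: [rs_snode_2]},              # bucket 1: single node, scheduled as-is
+    #   ]
+    # A bucket is only emitted once its last reduce_scatter snode is reached in program order.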
+ + new_order: list[scheduler.BaseSchedulerNode] = [] + scheduled = OrderedSet() + rs_exists = False + rs_snode_to_wait_snode = {} + new_operation_name_to_snode = {} + + schedule_snode = functools.partial( + _schedule_snode, new_order=new_order, scheduled=scheduled + ) + replace_scheduler_buffer = functools.partial( + _replace_scheduler_buffer, name_to_buf=name_to_buf + ) + remove_operation = functools.partial( + _remove_operation, name_to_fused_node=name_to_fused_node + ) + schedule_fallback_operation = functools.partial( + _schedule_fallback_operation, + scheduler=sched, + name_to_buf=name_to_buf, + name_to_fused_node=name_to_fused_node, + schedule_snode_fn=schedule_snode, + new_operation_name_to_snode=new_operation_name_to_snode, + ) + + # Step 1: Find all reduce_scatter nodes + for snode in snodes: + if is_collective( + snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default + ) and check_ir_node_bucketable(snode.node, bucketable_nodes): + rs_exists = True + rs_snode = snode + + # Find the "reduce_scatter + wait_tensor" code block + assert len(rs_snode.outputs) == 1 + assert ( + len(rs_snode.outputs[0].users) == 1 + ), f"rs_snode.outputs[0].users: {rs_snode.outputs[0].users}" + wait_snode = rs_snode.outputs[0].users[0].node + rs_snode_to_wait_snode[rs_snode] = wait_snode + + if rs_exists: + assert len(rs_snode_to_wait_snode) > 0 + else: + return snodes + + # Step 2: Put reduce_scatter nodes into buckets + rs_snode_to_bucket_id = {} + rs_snode_to_bucket_id_coarsen = {} + cur_bucket_id: int = 0 + + for reduce_scatter_bucket in reduce_scatter_bucket_plan: + for reduce_scatter_info, reduce_scatter_list in reduce_scatter_bucket.items(): + rs_snode_to_bucket_id.update( + dict.fromkeys( + reduce_scatter_list, reduce_scatter_info + (cur_bucket_id,) + ) + ) + rs_snode_to_bucket_id_coarsen.update( + dict.fromkeys(reduce_scatter_list, cur_bucket_id) + ) + cur_bucket_id += 1 + + assert len(rs_snode_to_bucket_id) == len(rs_snode_to_wait_snode) + + # Step 3: Create new (bucketed) reduce_scatter nodes + order = {x: i for i, x in enumerate(snodes)} + rs_snodes = OrderedSet(rs_snode_to_wait_snode.keys()) + rs_and_its_recursive_users = OrderedSet() + rs_and_its_recursive_users |= rs_snodes # all_gather + rs_and_its_recursive_users |= OrderedSet( + rs_snode_to_wait_snode.values() + ) # wait_tensor + + bucket_id_to_bucketed_op_info = {} + bucket_id_is_scheduled = {} + for bucket_id, rs_bucket in enumerate(reduce_scatter_bucket_plan): + all_rs_input_ir_nodes = [] + all_wait_snodes = [] + all_wait_snode_recursive_users = [] + all_rs_snodes = [] + group_sizes = [] + group_names = [] + for rs_info, rs_snodes in rs_bucket.items(): + if len(rs_snodes) == 0: + continue + example_rs_fx_node = _get_fx_node( + rs_snodes[0], + expected_op=torch.ops._c10d_functional.reduce_scatter_tensor.default, + ) + _, reduce_op, group_size, group_name = example_rs_fx_node.args + rs_input_ir_nodes: list[ir.IRNode] = [] + wait_snodes = [] + wait_snode_recursive_users = OrderedSet() + for rs_snode in rs_snodes: + rs_fx_node = _get_fx_node( + rs_snode, + expected_op=torch.ops._c10d_functional.reduce_scatter_tensor.default, + ) + assert ( + rs_fx_node.args[1] == reduce_op + and rs_fx_node.args[2] == group_size + and rs_fx_node.args[3] == group_name + ), f"Expected reduce_op {reduce_op} and group_size {group_size} and group_name {group_name}, but got {rs_fx_node.args[1:]}" + unmet_real_deps = [ + dep + for dep in rs_snode.unmet_dependencies + if not isinstance(dep, WeakDep) + ] + assert len(unmet_real_deps) == 1 + # 
rs_input_ir_nodes.append(name_to_buf[unmet_real_deps[0].name].node) + rs_input_ir_nodes.append(rs_snode.node.inputs[0]) + wait_snode = rs_snode_to_wait_snode[rs_snode] + wait_snodes.append(wait_snode) + _find_recursive_users_of_snode( + wait_snode, + wait_snode_recursive_users, + name_to_buf, + name_to_fused_node, + ) + # _find_recursive_users_of_snode() is inclusive - need to manually remove wait_snode from set + wait_snode_recursive_users.remove(wait_snode) + rs_and_its_recursive_users |= wait_snode_recursive_users + + all_rs_input_ir_nodes.append(rs_input_ir_nodes) + all_wait_snodes.append(wait_snodes) + all_wait_snode_recursive_users.append(wait_snode_recursive_users) + all_rs_snodes.append(rs_snodes) + group_sizes.append(group_size) + group_names.append(group_name) + bucket_id_to_bucketed_op_info[bucket_id] = ( + all_rs_input_ir_nodes, + reduce_op, + group_sizes, + group_names, + all_rs_snodes, + all_wait_snodes, + all_wait_snode_recursive_users, + ) + + RS_Group_node_list = [] + Wait_Group_node_list = [] + for snode in snodes: + if snode not in rs_and_its_recursive_users: + # not reduce_scatter or its wait_tensor - schedule it normally + schedule_snode(snode) + elif snode in rs_snode_to_wait_snode: + assert ( + snode in rs_snode_to_bucket_id + ), f"{snode} not in {rs_snode_to_bucket_id}" + bucket_id = rs_snode_to_bucket_id[snode] + coarsen_bucket_id = bucket_id[-1] + + if ( + coarsen_bucket_id not in bucket_id_is_scheduled + and snode + == bucket_id_to_bucketed_op_info[coarsen_bucket_id][-3][-1][-1] + ): + # If we are at the last node in the bucket, we can start to schedule the bucketed reduce_scatter node + ( + all_rs_input_ir_nodes, + reduce_op, + group_sizes, + group_names, + all_orig_rs_snodes, + all_orig_wait_snodes, + all_orig_wait_snode_recursive_users, + ) = bucket_id_to_bucketed_op_info[coarsen_bucket_id] + + RS_Group_node_list = [] + Wait_Group_node_list = [] + for idx, ( + rs_input_ir_nodes, + orig_rs_snodes, + orig_wait_snodes, + orig_wait_snode_recursive_users, + group_size, + group_name, + ) in enumerate( + zip( + all_rs_input_ir_nodes, + all_orig_rs_snodes, + all_orig_wait_snodes, + all_orig_wait_snode_recursive_users, + group_sizes, + group_names, + ) + ): + if len(rs_input_ir_nodes) == 1: + # If there is only one input, we can directly use the input as the output + RS_Group_node_list.append(orig_rs_snodes[0]) + Wait_Group_node_list.append(orig_wait_snodes[0]) + Wait_Group_node_list.extend(orig_wait_snode_recursive_users) + else: + original_length = len(new_order) + new_sharded_grads = bucket_reduce_scatters( + schedule_fallback_operation, + group_size, + group_name, + reduce_op, + rs_input_ir_nodes, + orig_rs_snodes, + name_to_buf, + orig_wait_snodes, + ) + for out_operation, orig_rs_snode, orig_wait_snode in zip( + new_sharded_grads, orig_rs_snodes, orig_wait_snodes + ): + out_snode = new_operation_name_to_snode[ + out_operation.get_operation_name() + ] + assert len(orig_rs_snode.outputs) == 1 + orig_rs_snode_output = orig_rs_snode.outputs[-1] + orig_wait_snode_output = orig_wait_snode.outputs[-1] + out_snode_output = out_snode.outputs[-1] + replace_scheduler_buffer( + orig_sched_buf=orig_rs_snode_output, + new_sched_buf=out_snode_output, + ) + # wait_tensor node output is modeled as a mutation on the reduce_scatter node output. + # We need to preserve this property even after swapping. 
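+                            # Illustrative invariant being preserved (buffer names made up):
+                            #   reduce_scatter defines "buf_rs"; its wait_tensor output is a
+                            #   MutationOutput whose sole mutation target is "buf_rs".
+                            # After replace_scheduler_buffer renames the bucketed output to
+                            # "buf_rs", the assert below checks the mutation still points at it.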
+ assert ( + isinstance( + orig_wait_snode_output.node, ir.MutationOutput + ) + and len(orig_wait_snode_output.get_mutations()) == 1 + and orig_wait_snode_output.get_mutations()[0] + == orig_rs_snode_output.get_name() + ) + out_snode.outputs.append(orig_wait_snode_output) + out_snode.read_writes.writes.add( + StarDep( + name=orig_wait_snode_output.get_name(), mode=None + ) + ) + # Remove original reduce_scatter and wait_tensor operations + remove_operation(orig_rs_snode.node) + remove_operation(orig_wait_snode.node) + + if len(rs_input_ir_nodes) != 1: + new_length = len(new_order) + current_RS_Group_node = [] + current_Wait_Group_node = [] + wait_node = True + for node in range(new_length - original_length): + node = new_order.pop() + node.min_order = 0 + node.max_order = 0 + if wait_node: + current_Wait_Group_node.append(node) + else: + current_RS_Group_node.append(node) + if ( + isinstance(node.node, ir.FallbackKernel) + and node.node.python_kernel_name + == "torch.ops._c10d_functional.wait_tensor.default" + ): + wait_node = False + current_RS_Group_node.reverse() + current_Wait_Group_node.reverse() + RS_Group_node_list.extend(current_RS_Group_node) + Wait_Group_node_list.extend(current_Wait_Group_node) + + RS_Group_node = scheduler.GroupedSchedulerNode.create( + RS_Group_node_list + ) + for ( + orig_wait_snode_recursive_users + ) in all_orig_wait_snode_recursive_users: + for user in sorted( + orig_wait_snode_recursive_users, key=lambda x: order[x] + ): + Wait_Group_node_list.append(user) + Wait_Group_node = scheduler.GroupedSchedulerNode.create( + Wait_Group_node_list + ) + RS_Group_node.temp_grouping = True + Wait_Group_node.temp_grouping = True + new_order.append(RS_Group_node) + new_order.append(Wait_Group_node) + bucket_id_is_scheduled[coarsen_bucket_id] = True + else: + continue + + if len(RS_Group_node_list) > 0: + RS_Group_node = scheduler.GroupedSchedulerNode.create(RS_Group_node_list) + Wait_Group_node = scheduler.GroupedSchedulerNode.create(Wait_Group_node_list) + RS_Group_node.temp_grouping = True + Wait_Group_node.temp_grouping = True + new_order.append(RS_Group_node) + new_order.append(Wait_Group_node) + return new_order diff --git a/autoparallel/autobucketing_util/bucket_plan.py b/autoparallel/autobucketing_util/bucket_plan.py index e08cb1ae..95ab1386 100644 --- a/autoparallel/autobucketing_util/bucket_plan.py +++ b/autoparallel/autobucketing_util/bucket_plan.py @@ -12,10 +12,10 @@ import torch from torch._C._distributed_c10d import ReduceOp from torch._inductor import scheduler -from torch._inductor.comm import _schedule_fallback_operation from torch._inductor.utils import is_collective from .bucket_utils import ( + _schedule_fallback_operation, check_ir_node_bucketable, estimate_bucketed_snode_runtime, get_data_size, diff --git a/autoparallel/autobucketing_util/bucket_utils.py b/autoparallel/autobucketing_util/bucket_utils.py index 075c1ce4..2bb153a3 100644 --- a/autoparallel/autobucketing_util/bucket_utils.py +++ b/autoparallel/autobucketing_util/bucket_utils.py @@ -4,15 +4,17 @@ # LICENSE file in the root directory of this source tree. 
# mypy: ignore-errors +import math from functools import reduce -from typing import Any, Callable, Dict, Union +from typing import Any, Callable, Dict, Union, cast import torch +import torch.utils._pytree as pytree from torch._inductor import ir, scheduler -from torch._inductor.comm import bucket_all_gathers, bucket_reduce_scatters -from torch._inductor.dependencies import WeakDep +from torch._inductor.dependencies import StarDep, WeakDep from torch._inductor.ir import NoneLayout from torch._inductor.utils import buf_name_to_fused_snode, is_collective, is_wait +from torch._inductor.virtualized import V from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _resolve_process_group from torch.utils._ordered_set import OrderedSet @@ -340,3 +342,407 @@ def estimate_bucketed_snode_runtime( comm_size_inp += get_data_size(local_comm_size_inp) comm_size_out += get_data_size(local_comm_size_out) return estimated_comm, comm_size_inp, comm_size_out + + +def _schedule_snode( + snode: "scheduler.BaseSchedulerNode", + new_order: list["scheduler.BaseSchedulerNode"], + scheduled: list["scheduler.BaseSchedulerNode"], +): + if snode in scheduled: + return + + new_order.append(snode) + scheduled.add(snode) + + +def _remove_operation( + operation: ir.Operation, + name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"], +): + assert isinstance( + operation, ir.Operation + ), f"Expected ir.Operation, but got {type(ir.Operation)}. Offending value: {ir.Operation}" + idx = V.graph.operations.index(operation) + del V.graph.operations[idx] + del V.graph.name_to_op[operation.get_operation_name()] + del name_to_fused_node[operation.get_operation_name()] + + +def _replace_scheduler_buffer( + orig_sched_buf: "scheduler.SchedulerBuffer", + new_sched_buf: "scheduler.SchedulerBuffer", + name_to_buf: Dict[str, "scheduler.SchedulerBuffer"], +): + new_buf = new_sched_buf.node + new_buf_name = new_buf.get_name() + orig_buf = orig_sched_buf.node + orig_buf_name = orig_buf.get_name() + V.graph.buffers[V.graph.buffers.index(orig_buf)] = new_buf + V.graph.name_to_buffer[orig_buf_name] = new_buf + name_to_buf[orig_buf_name] = new_sched_buf + new_buf.name = orig_buf_name + new_sched_buf.defining_op.set_read_writes( + new_sched_buf.defining_op.read_writes.rename({new_buf_name: orig_buf_name}) + ) + new_sched_buf.users = orig_sched_buf.users + + +def _schedule_fallback_operation( + target: Any, + args: list[Any], + kwargs: dict[str, Any], + scheduler: "scheduler.Scheduler", + name_to_buf: Dict[str, "scheduler.SchedulerBuffer"], + name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"], + schedule_snode_fn: Union[Callable[..., Any], Any] = None, + new_operation_name_to_snode: Dict[str, "scheduler.BaseSchedulerNode"] = {}, + dep_operations: Union[ir.Operation, list[ir.Operation], None] = None, +) -> Union[ir.Operation, list[ir.Operation]]: + # NOTE: `dep_operations` enforces strong ordering between ops, helpful if the dependency chain is not clear + # from direct input-output relationship + # (i.e. if OP1 mutates a view of buffer X and then OP2 reads from X, and OP1 is expected to run before OP2 -> OP2 + # must have `dep_operations` pointing to OP1 to ensure reordering pass would not mess up the order). 
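+    # Concrete use of `dep_operations` in this file: bucket_reduce_scatters() schedules the
+    # bucketed reduce_scatter with dep_operations=chunk_cat, because chunk_cat writes through a
+    # reshape view of the reduce_scatter input and would otherwise not show up as a direct
+    # input-output dependency of the collective.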
+ + def wrap_tensors(x): + if isinstance(x, ir.MutationOutput): + mutated_buf_names = x.get_mutation_names() + assert ( + isinstance(mutated_buf_names, list) and len(mutated_buf_names) == 1 + ), "Expect only one mutated buffer in MutationOutput" + return wrap_tensors(name_to_buf[mutated_buf_names[0]].node) + elif isinstance(x, ir.IRNode): + if isinstance(x, ir.StorageBox): + return x + else: + return ir.TensorBox.create(x) + else: + return x + + operations_prev_watermark = len(V.graph.operations) + # this will append newly created operations to V.graph.operations + ir.FallbackKernel.create( + target, + *pytree.tree_map(wrap_tensors, args), + **pytree.tree_map(wrap_tensors, kwargs), + ) + new_operations = V.graph.operations[operations_prev_watermark:] + new_snodes = [] + if isinstance(dep_operations, ir.Operation): + dep_operations = [dep_operations] + for new_operation in new_operations: + new_snode = scheduler.create_scheduler_node(new_operation) + if dep_operations is not None: + # make the new snode depend on all output buffers of all the dep operations, + # to ensure that the new snode will always be executed after all the dep operations. + for dep_operation in dep_operations: + dep_snode = name_to_fused_node[dep_operation.get_operation_name()] + for buf_name in dep_snode.get_buffer_names(): + new_snode.set_read_writes( + new_snode.read_writes.with_read( + StarDep(name=buf_name, mode=None) + ) + ) + if schedule_snode_fn is not None: + schedule_snode_fn(new_snode) + new_snodes.append(new_snode) + new_operation_name_to_snode[new_operation.get_operation_name()] = new_snode + for o in new_snode.get_outputs(): + name_to_buf[o.get_name()] = o + name_to_fused_node[new_snode.get_name()] = new_snode + multi_output_operations = [] + # identify the trailing MultiOutput operations, if any + for operation in reversed(new_operations): + if isinstance(operation, ir.MultiOutput): + multi_output_operations.insert(0, operation) + else: + break + if len(multi_output_operations) == 0: + # if no MultiOutput operations, it means this fallback kernel has no output - + # in this case, just return the FallbackKernel operation. 
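+        # (For reference, callers in this file rely on the branches below: e.g.
+        # torch.ops.fsdp.all_gather_copy_in produces two trailing MultiOutput operations,
+        # which bucket_all_gathers unpacks as a pair, while single-output fallbacks such as
+        # the wait_tensor calls come back as one MultiOutput.)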
+ assert len(new_operations) == 1 + return new_operations[0] + elif len(multi_output_operations) == 1: + return multi_output_operations[0] + else: + return multi_output_operations + + +def bucket_all_gathers( + schedule_fallback_operation: Callable, + group_size: int, + group_name: str, + ag_input_ir_nodes: list["ir.IRNode"], + orig_ag_snodes: list["scheduler.BaseSchedulerNode"], + name_to_buf: Dict[str, "scheduler.SchedulerBuffer"], + orig_wait_snodes: list["scheduler.BaseSchedulerNode"] = None, + schedule_snode_fn: Union[Callable[..., Any], Any] = None, + return_ag_only: bool = False, +): + """ + bucket multiple all_gather nodes into one all_gather node + return_ag_only set to True: only return the bucketed all_gather node + return_ag_only set to False: return the bucketed all_gather node and the bucketed wait node (in GroupedSchedulerNode) + """ + orig_ag_fx_nodes = [ + _get_fx_node( + sn, expected_op=torch.ops._c10d_functional.all_gather_into_tensor.default + ) + for sn in orig_ag_snodes + ] + ag_input_fx_nodes = [n.args[0] for n in orig_ag_fx_nodes] + assert all( + n.meta["val"].dtype == orig_ag_fx_nodes[0].meta["val"].dtype + for n in orig_ag_fx_nodes + ), "All all_gather inputs in the same bucket must have the same dtype" + + # must schedule all the all_gather input nodes first, before the bucketed all_gather node + param_all_gather_inputs_orig: list[Union[ir.IRNode, scheduler.SchedulerBuffer]] = [] + for ag_input_ir_node in ag_input_ir_nodes: + if ag_input_ir_node.is_input_buffer() or isinstance( + ag_input_ir_node, ir.ReinterpretView + ): + param_all_gather_inputs_orig.append(ag_input_ir_node) + elif ag_input_sched_buf := name_to_buf.get(ag_input_ir_node.get_name()): + if not return_ag_only: + schedule_snode_fn(ag_input_sched_buf.defining_op) + param_all_gather_inputs_orig.append(ag_input_sched_buf.node) + else: + raise ValueError("Unexpected node type") + # assert ag_input_ir_node.is_input_buffer() + # param_all_gather_inputs_orig.append(ag_input_ir_node) + + # schedule the bucketed all_gather node + param_all_gather_inputs_flattened = [ + schedule_fallback_operation(torch.ops.aten.reshape.default, (n, [-1]), {}) + for n in param_all_gather_inputs_orig + ] + + inp_split_sizes = [n.meta["val"].numel() for n in ag_input_fx_nodes] + param_all_gather_outputs = [ + schedule_fallback_operation( + torch.ops.aten.empty.memory_format, + ([n.meta["val"].numel() * group_size],), + { + "dtype": n.meta["val"].dtype, + "device": n.meta["val"].device, + "pin_memory": False, + }, + ) + for n in ag_input_fx_nodes + ] + # TODO(yf225): This assumes dim-0 sharding. + # If we need to support sharding on another dim, we should look at how FSDP2 does it (e.g. 
search for `shard_dim` in FSDP2 codebase) + param_all_gather_outputs_shape_orig = [ + (n.meta["val"].shape[0] * group_size,) + n.meta["val"].shape[1:] + for n in ag_input_fx_nodes + ] + all_gather_input_numel = sum(inp_split_sizes) + param_all_gather_outputs_flattened = schedule_fallback_operation( + torch.ops.aten.empty.memory_format, + ([all_gather_input_numel * group_size],), + { + "dtype": ag_input_fx_nodes[0].meta["val"].dtype, + "device": ag_input_fx_nodes[0].meta["val"].device, + "pin_memory": False, + }, + ) + + example_ag_input_tensor = ag_input_fx_nodes[0].meta["val"] + all_gather_input, all_gather_output = schedule_fallback_operation( + torch.ops.fsdp.all_gather_copy_in.default, + ( + param_all_gather_inputs_flattened, + param_all_gather_outputs_flattened, + inp_split_sizes, + all_gather_input_numel, + example_ag_input_tensor.device.index % group_size, + ), + {}, + ) + all_gather_into_tensor_out = schedule_fallback_operation( + torch.ops._c10d_functional.all_gather_into_tensor_out.default, + (all_gather_input, group_size, group_name), + {"out": all_gather_output}, + ) + if return_ag_only: + assert len(all_gather_into_tensor_out.inputs) == 1 + return all_gather_input, all_gather_output, all_gather_into_tensor_out + + wait_tensor = schedule_fallback_operation( + torch.ops._c10d_functional.wait_tensor.default, + (all_gather_into_tensor_out,), + {}, + ) + all_gather_output_reshaped = schedule_fallback_operation( + torch.ops.aten.reshape.default, + (wait_tensor, [group_size, -1]), + {}, + ) + outs_flattened = [ + schedule_fallback_operation( + torch.ops.aten.reshape.default, + (n, [group_size, -1]), + {}, + ) + for n in param_all_gather_outputs + ] + split_with_sizes_copy = schedule_fallback_operation( + torch.ops.fsdp.split_with_sizes_copy.default, + (all_gather_output_reshaped, inp_split_sizes), + {"dim": 1, "out": outs_flattened}, + ) + outs = [ + schedule_fallback_operation( + torch.ops.aten.reshape.default, + (n, orig_shape), + {}, + dep_operations=split_with_sizes_copy, + ) + for n, orig_shape in zip(outs_flattened, param_all_gather_outputs_shape_orig) + ] + # Make sure downstream users of original wait nodes are now dependent on the new `outs` nodes + assert len(outs) == len(orig_wait_snodes) + assert len(outs) == len(orig_ag_snodes) + return outs + + +def bucket_reduce_scatters( + schedule_fallback_operation: Callable, + group_size: int, + group_name: str, + reduce_op: Any, + rs_input_ir_nodes: list["ir.IRNode"], + orig_rs_snodes: list["scheduler.BaseSchedulerNode"], + name_to_buf: Dict[str, "scheduler.SchedulerBuffer"], + orig_wait_snodes: list["scheduler.BaseSchedulerNode"] = None, + return_rs_only: bool = False, +): + orig_rs_fx_nodes = [ + _get_fx_node( + sn, expected_op=torch.ops._c10d_functional.reduce_scatter_tensor.default + ) + for sn in orig_rs_snodes + ] + # must schedule all the reduce_scatter input nodes first, before the bucketed reduce_scatter node + unsharded_grads = [] + unsharded_grads_fx_nodes = [n.args[0] for n in orig_rs_fx_nodes] + for rs_input_ir_node in rs_input_ir_nodes: + if rs_input_ir_node.is_input_buffer() or isinstance( + rs_input_ir_node, ir.ReinterpretView + ): + unsharded_grads.append(rs_input_ir_node) + elif rs_input_sched_buf := name_to_buf.get(rs_input_ir_node.get_name()): + unsharded_grads.append(rs_input_sched_buf.node) + else: + raise ValueError("Unexpected node type") + reduce_dtype = unsharded_grads_fx_nodes[0].meta["val"].dtype + # Only float32 and bfloat16 are supported for now. 
+ # To support fp16, please see FSDP2 `_get_gradient_divide_factors`. + assert reduce_dtype in ( + torch.float32, + torch.bfloat16, + ), f"reduce_dtype {reduce_dtype} is not supported" + assert all(n.meta["val"].dtype == reduce_dtype for n in unsharded_grads_fx_nodes) + device = unsharded_grads_fx_nodes[0].meta["val"].device + rank = device.index % group_size + # TODO(yf225): need more work if we want to support non-dim-0 sharding (e.g. search for `shard_dim` in FSDP2 codebase) + shard_dim = 0 + + def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Size: + padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor + return cast(torch.Size, torch.Size([padded_dim0]) + tensor_size[1:]) + + padded_unsharded_sizes = tuple( + _get_dim0_padded_size(n.meta["val"].size(), group_size) + for n in unsharded_grads_fx_nodes + ) + + reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes) + reduce_scatter_input = schedule_fallback_operation( + torch.ops.aten.empty.memory_format, + ([reduce_scatter_input_numel],), + { + "dtype": reduce_dtype, + "device": device, + "pin_memory": False, + }, + ) + reduce_scatter_input_reshaped = schedule_fallback_operation( + torch.ops.aten.reshape.default, + (reduce_scatter_input, [group_size, -1]), + {}, + ) + # NOTE(yf225): have to turn off Inductor config shape_padding and comprehensive_padding, + # otherwise we get "torch.Size([4096, 80096]) and strides (80128, 1) cannot be viewed as shape (2, 164036608)" error. + chunk_cat = schedule_fallback_operation( + torch.ops.fsdp.chunk_cat.default, + (unsharded_grads,), + { + "dim": 0, + "num_chunks": group_size, + "out": reduce_scatter_input_reshaped, + }, + ) + + reduce_scatter_tensor = schedule_fallback_operation( + torch.ops._c10d_functional.reduce_scatter_tensor.default, + (reduce_scatter_input, reduce_op, group_size, group_name), + {}, + dep_operations=chunk_cat, + ) + + if return_rs_only: + assert len(reduce_scatter_tensor.inputs) == 1 + return reduce_scatter_tensor.inputs[0].inputs[0], reduce_scatter_tensor + + wait_tensor = schedule_fallback_operation( + torch.ops._c10d_functional.wait_tensor.default, + (reduce_scatter_tensor,), + {}, + ) + + def _chunk_with_empty( + tensor: torch.Tensor, num_chunks: int, dim: int + ) -> list[torch.Tensor]: + chunks = list(torch.chunk(tensor, num_chunks, dim=dim)) + while len(chunks) < num_chunks: + chunks.append(chunks[0].new_empty(0)) + return chunks + + reduce_output = wait_tensor + # View out and accumulate sharded gradients + new_sharded_grads = [] + + flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1] + for padded_unsharded_size, unsharded_grad_fx_node in zip( + padded_unsharded_sizes, unsharded_grads_fx_nodes + ): + # NOTE: we only care about the shape of tensors in `chunks`, so using meta tensor here + chunks = _chunk_with_empty( + torch.empty_like(unsharded_grad_fx_node.meta["val"], device="meta"), + group_size, + dim=shard_dim, + ) + sharded_param = chunks[rank] + sharded_size = sharded_param.size() + contiguous_sharded_stride = torch._prims_common.make_contiguous_strides_for( + sharded_size + ) + # Assume even sharding for Shard(i), i > 0; otherwise would require + # copy-out for contiguous strides + new_sharded_grad = schedule_fallback_operation( + torch.ops.aten.as_strided.default, + (reduce_output,), + { + "size": sharded_size, + "stride": contiguous_sharded_stride, + "storage_offset": flat_grad_offset, + }, + ) + new_sharded_grads.append(new_sharded_grad) + padded_sharded_numel = 
padded_unsharded_size.numel() // group_size + flat_grad_offset += padded_sharded_numel + assert len(orig_wait_snodes) == len(new_sharded_grads) + assert len(orig_wait_snodes) == len(orig_rs_snodes) + return new_sharded_grads
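+# Layout sketch for the view-out above (assumes dim-0 sharding, matching the padded sizes
+# computed in bucket_reduce_scatters): the wait_tensor output is one flat buffer holding this
+# rank's shard of every bucketed gradient back to back,
+#
+#   [ grad0 shard | grad1 shard | grad2 shard | ... ]
+#     offset 0      offset += padded_unsharded_size_0.numel() // group_size, etc.
+#
+# and each as_strided call reads its shard starting at the accumulated flat_grad_offset.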