diff --git a/autoparallel/auto_bucketing.py b/autoparallel/auto_bucketing.py
index d78fd9cf..85e1573a 100644
--- a/autoparallel/auto_bucketing.py
+++ b/autoparallel/auto_bucketing.py
@@ -5,7 +5,7 @@
 
 import torch
 
-from .autobucketing_util import bucket_utils
+from .autobucketing_util import bucket_plan, bucket_utils
 
 
 class simplefsdp_autobucketing_config:
@@ -36,7 +36,21 @@ def simple_fsdp_autobucketing_reordering_pass(
     configs: "simplefsdp_autobucketing_config",
 ) -> list["torch._inductor.scheduler.BaseSchedulerNode"]:
     scheduler = snodes[0].scheduler
-    bucket_utils.get_bucketable_ir_nodes(
+    bucketable_nodes = bucket_utils.get_bucketable_ir_nodes(
         snodes, scheduler.name_to_fused_node, scheduler.name_to_buf
     )
+
+    assert (
+        not torch._inductor.config.allow_buffer_reuse
+    ), "bucketing algorithm requires torch._inductor.config.allow_buffer_reuse to be False"
+
+    if configs.enable_bucket_ir:
+        all_gather_plan, reduce_scatter_plan = bucket_plan.get_simplefsdp_auto_plan(
+            scheduler,
+            snodes,
+            scheduler.name_to_buf,
+            scheduler.name_to_fused_node,
+            bucketable_nodes,
+            configs,
+        )
 
     return snodes
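For context, a minimal wiring sketch showing how this reordering pass can be
registered with inductor. It assumes the stock torch._inductor.config
compute/comm-overlap knobs and treats simplefsdp_autobucketing_config as a
plain namespace of defaults; this snippet is not part of the diff:

    import functools

    import torch

    from autoparallel.auto_bucketing import (
        simple_fsdp_autobucketing_reordering_pass,
        simplefsdp_autobucketing_config,
    )

    # the pass asserts that buffer reuse is disabled
    torch._inductor.config.allow_buffer_reuse = False
    torch._inductor.config.reorder_for_compute_comm_overlap = True
    torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
        functools.partial(
            simple_fsdp_autobucketing_reordering_pass,
            configs=simplefsdp_autobucketing_config,
        )
    ]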
diff --git a/autoparallel/autobucketing_util/bucket_plan.py b/autoparallel/autobucketing_util/bucket_plan.py
new file mode 100644
index 00000000..d32d875a
--- /dev/null
+++ b/autoparallel/autobucketing_util/bucket_plan.py
@@ -0,0 +1,345 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import functools
+
+# mypy: ignore-errors
+from collections import defaultdict
+from typing import Any, Dict
+
+import torch
+from torch._C._distributed_c10d import ReduceOp
+from torch._inductor import scheduler
+from torch._inductor.comm import _schedule_fallback_operation
+from torch._inductor.utils import is_collective
+
+from .bucket_utils import (
+    check_ir_node_bucketable,
+    estimate_bucketed_snode_runtime,
+    get_data_size,
+    get_snode_process_group_info,
+    get_snode_tensor_info,
+)
+from .estimation import benchmark_and_sync_runtime
+
+
+def get_dynamic_memory_threshold(
+    peak_memory,
+    peak_memory_per_step_dict,
+    current_step,
+) -> int:
+    """
+    Calculate the memory gap (headroom) between the current step's peak memory
+    and the overall peak memory.
+    """
+    left_peak_memory = 0
+    right_peak_memory = 0
+    for idx, memory in peak_memory_per_step_dict.items():
+        if idx <= current_step:
+            left_peak_memory = max(memory, left_peak_memory)
+        if idx >= current_step:
+            right_peak_memory = max(memory, right_peak_memory)
+    current_peak_memory = min(left_peak_memory, right_peak_memory)
+    return peak_memory - current_peak_memory
+
+
+def get_simplefsdp_auto_plan(
+    sched: "scheduler.Scheduler",
+    snodes: list["scheduler.BaseSchedulerNode"],
+    name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
+    name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
+    bucketable_nodes: set[str],
+    configs: Any,
+    verbose: bool = True,
+) -> tuple[
+    list[Dict[tuple[Any, ...], list["scheduler.BaseSchedulerNode"]]],
+    list[Dict[tuple[Any, ...], list["scheduler.BaseSchedulerNode"]]],
+]:
+    """
+    This function implements a greedy algorithm that decides whether a node can be
+    bucketed with the previous one, based on the criteria below:
+    FWD pass:
+    (1) the bucketed AG communication can be overlapped by the preceding computation;
+    (2) the bucketed AG memory doesn't exceed peak memory;
+    (3) the bucketed AG communication size doesn't exceed 0.2 * sum(fwd_ag_tensor_list),
+        so that the estimated AG communication time always stays within the calibration bound.
+    BWD pass:
+    (1) the bucketed AG + RS communication can be overlapped by the preceding computation;
+    (2) the bucketed AG + RS memory doesn't exceed peak memory;
+    (3) each RS always has future compute to overlap it, so that its final exposed
+        communication is small;
+    (4) the bucketed AG/RS communication size doesn't exceed 0.2 * sum(fwd_ag_tensor_list)
+        and 0.2 * sum(bwd_rs_tensor_list), so that the estimated AG/RS communication time
+        always stays within the calibration bound.
+    """
+    all_gather_plan = []
+    reduce_scatter_plan = []
+    current_ag_bucket: Dict[
+        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
+    ] = defaultdict(list)
+    current_rs_bucket: Dict[
+        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
+    ] = defaultdict(list)
+    schedule_fallback_operation = functools.partial(
+        _schedule_fallback_operation,
+        scheduler=sched,
+        name_to_buf=name_to_buf,
+        name_to_fused_node=name_to_fused_node,
+    )
+
+    heuristic_info = {
+        # time info
+        "last_step_rs_comm_time": 0.0,
+        "this_step_comp_time": 0.0,
+        "this_step_rs_comm_time": 0.0,
+        "next_step_comp_time": 0.0,
+        "next_step_nonfsdp_comm_time": 0.0,
+        # memory info
+        "accumulated_gradient_size": 0,
+        "last_step_rs_comm_size": 0,
+        "this_step_rs_comm_out_size": 0,
+        "this_step_rs_comm_inp_size": 0,
+        "this_step_memory": 0,
+        "next_step_memory": 0,
+    }
+
+    # sync runtime info across ranks
+    (
+        comm_cache,
+        comp_time_dict,
+        memory_dict,
+        peak_memory_per_step_dict,
+    ) = benchmark_and_sync_runtime(
+        sched, snodes, name_to_buf, name_to_fused_node, bucketable_nodes, configs
+    )
+    future_comp_time = sum(comp_time_dict.values())
+    peak_memory = max(peak_memory_per_step_dict.values()) + configs.peak_memory_offset
+
+    # autobucket algorithm
+    bucketable_ag_idx = -1
+    seen_new_bucketable_ag = True
+    for snode in snodes:
+        if is_collective(
+            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
+        ) and check_ir_node_bucketable(snode.node, bucketable_nodes):
+            bucketable_ag_idx += 1
+            seen_new_bucketable_ag = True
+            future_comp_time -= comp_time_dict[bucketable_ag_idx]
+
+            ag_node_info = get_snode_tensor_info(
+                snode, return_data_size=False
+            ) + get_snode_process_group_info(
+                snode,
+                expected_op=torch.ops._c10d_functional.all_gather_into_tensor.default,
+                resolve_pg=False,
+            )
+            current_ag_bucket[ag_node_info].append(snode)
+            (
+                estimated_comm,
+                comm_size_inp,
+                comm_size_out,
+            ) = estimate_bucketed_snode_runtime(
+                current_ag_bucket,
+                schedule_fallback_operation,
+                name_to_buf,
+                "torch.ops._c10d_functional.all_gather_into_tensor.default",
+                comm_cache,
+            )
+
+            # Check if the current bucketing breaks the greedy criteria
+            # (1) Overlapping criteria
+            comp_time = heuristic_info["this_step_comp_time"] * (
+                1 + configs.relax_ratio
+            )
+            comm_time = estimated_comm + heuristic_info["last_step_rs_comm_time"]
+            break_overlap_criteria = comp_time < comm_time
+
+            # (2) Memory criteria
+            memory_threshold = get_dynamic_memory_threshold(
+                peak_memory,
+                peak_memory_per_step_dict,
+                bucketable_ag_idx,
+            )
+            # the bucketed AG/RS nodes are created on the fly, so their memory is not
+            # captured by the estimate_peak_memory function. The bucketed_comm_memory
+            # consists of:
+            # in the FWD pass:
+            # (1) all-gather copy-in (comm_size_inp): smaller buffers for dtype conversion
+            #     + a bigger buffer to copy the smaller buffers into; thus, comm_size_inp * 2
+            # (2) all-gather copy-out (comm_size_out): a bigger buffer copied out of ag_wait
+            #     + the smaller buffers split out for compute; thus, comm_size_out * 2
+            # in the BWD pass:
+            # TODO (ruisizhang123): we need to double check this. From the memory trace, we
+            # can clearly see these three regions stack together at a certain moment:
+            # due to reordering, the peak memory occurs at the end of the current step's
+            # all-gather, when the last step's & this step's reduce-scatters are not yet
+            # cleared
+            # (1) all-gather copy-in/copy-out (as in the FWD pass)
+            # (2) last step's reduce-scatter: a bigger buffer containing the gradient
+            # (3) next step's reduce-scatter: smaller buffers for dtype conversion
+            #     + a bigger buffer to copy the gradient into
+            bucketed_comm_memory = (
+                2 * comm_size_inp
+                + 2 * comm_size_out
+                + heuristic_info["this_step_rs_comm_inp_size"] * 2
+                + heuristic_info["last_step_rs_comm_size"]
+            )
+            break_memory_criteria = (
+                memory_threshold
+                < heuristic_info["next_step_memory"] + bucketed_comm_memory
+            )
+
+            # (3) Communication size criteria
+            break_comm_size_criteria = comm_cache.ag_max_inp_size < comm_size_inp
+            if comm_cache.rs_max_out_size > 0:
+                break_comm_size_criteria = (
+                    break_comm_size_criteria
+                    or comm_cache.rs_max_out_size
+                    < heuristic_info["this_step_rs_comm_out_size"]
+                )
+
+            if (
+                break_overlap_criteria
+                or break_memory_criteria
+                or break_comm_size_criteria
+            ):
+                if heuristic_info["this_step_comp_time"] > 0:
+                    # if bucketing breaks the greedy criteria, pop the last node out
+                    overflow_ag = current_ag_bucket[ag_node_info].pop()
+                    all_gather_plan.append(current_ag_bucket)
+                    current_ag_bucket: Dict[
+                        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
+                    ] = defaultdict(list)
+                    current_ag_bucket[ag_node_info].append(overflow_ag)
+                else:
+                    # if there is no compute, keep the all_gather in the bucket to avoid deadlock
+                    all_gather_plan.append(current_ag_bucket)
+                    current_ag_bucket: Dict[
+                        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
+                    ] = defaultdict(list)
+
+                if verbose:
+                    print(
+                        "break_overlap_criteria",
+                        break_overlap_criteria,
+                    )
+                    print("Current comm time", comm_time, "comp time", comp_time)
+                    print(
+                        "break_memory_criteria",
+                        break_memory_criteria,
+                    )
+                    print(
+                        "memory_threshold",
+                        memory_threshold,
+                        "total memory",
+                        heuristic_info["next_step_memory"] + bucketed_comm_memory,
+                    )
+                    print(
+                        "break_comm_size_criteria",
+                        break_comm_size_criteria,
+                    )
+                    print("current_ag_bucket", all_gather_plan[-1])
+
+                # bucket reduce-scatters if there are any
+                if len(current_rs_bucket) > 0:
+                    (
+                        current_estimated_rs,
+                        rs_comm_size_inp,
+                        rs_comm_size_out,
+                    ) = estimate_bucketed_snode_runtime(
+                        current_rs_bucket,
+                        schedule_fallback_operation,
+                        name_to_buf,
+                        "torch.ops._c10d_functional.reduce_scatter_tensor.default",
+                        comm_cache,
+                        ReduceOp.AVG,
+                    )
+                    heuristic_info["last_step_rs_comm_time"] = current_estimated_rs
+                    reduce_scatter_plan.append(current_rs_bucket)
+                    heuristic_info["last_step_rs_comm_size"] = rs_comm_size_out
+                    current_rs_bucket: Dict[
+                        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
+                    ] = defaultdict(list)
+
+                # update heuristic info for the next step
+                (
+                    heuristic_info["this_step_comp_time"],
+                    heuristic_info["this_step_memory"],
+                ) = (
+                    heuristic_info["next_step_comp_time"]
+                    + heuristic_info["next_step_nonfsdp_comm_time"],
heuristic_info["next_step_memory"], + ) + ( + heuristic_info["next_step_comp_time"], + heuristic_info["next_step_memory"], + ) = ( + 0, + 0, + ) + elif is_collective( + snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default + ) and check_ir_node_bucketable(snode.node, bucketable_nodes): + node_info = get_snode_tensor_info( + snode, return_data_size=False + ) + get_snode_process_group_info( + snode, + expected_op=torch.ops._c10d_functional.reduce_scatter_tensor.default, + resolve_pg=False, + ) + current_rs_bucket[node_info].append(snode) + + ( + heuristic_info["this_step_rs_comm_time"], + rs_comm_size_inp, + rs_comm_size_out, + ) = estimate_bucketed_snode_runtime( + current_rs_bucket, + schedule_fallback_operation, + name_to_buf, + "torch.ops._c10d_functional.reduce_scatter_tensor.default", + comm_cache, + ReduceOp.AVG, + ) + heuristic_info["this_step_rs_comm_out_size"] = rs_comm_size_out + heuristic_info["this_step_rs_comm_inp_size"] = rs_comm_size_inp + heuristic_info["accumulated_gradient_size"] += get_data_size( + snode.node.layout.size + ) + + # Check if current bucketing breaks the greedy criteria + # (4) future compute to overlap RS criteria + break_rs_overlap_criteria = ( + future_comp_time < heuristic_info["this_step_rs_comm_time"] * 5 + ) + if break_rs_overlap_criteria: + reduce_scatter_plan.append(current_rs_bucket) + heuristic_info["last_step_rs_comm_time"] = heuristic_info[ + "this_step_rs_comm_time" + ] + heuristic_info["this_step_rs_comm_time"] = 0 + current_rs_bucket: Dict[ + tuple[Any, ...], list["scheduler.BaseSchedulerNode"] + ] = defaultdict(list) + else: + # TODO (ruisizhang123): for now, we only consider FSDP + (TP & CP), whose comms are AG & RS & All_Reduce + # For TP and CP, we consider the node as a "COMP" node with exposed communication as Comp time + if is_collective(snode.node): + current_comm = comm_cache.get_comm_time( + snode.node.inputs[0].layout.size, + snode.node.layout.size, + getattr(snode.node, "python_kernel_name", ""), + calibrated=True, + ) + heuristic_info["next_step_nonfsdp_comm_time"] += current_comm + else: + if seen_new_bucketable_ag: + heuristic_info["next_step_memory"] += memory_dict[bucketable_ag_idx] + heuristic_info["next_step_comp_time"] += comp_time_dict[ + bucketable_ag_idx + ] + seen_new_bucketable_ag = False + + if len(current_ag_bucket) > 0: + all_gather_plan.append(current_ag_bucket) + + if len(current_rs_bucket) > 0: + reduce_scatter_plan.append(current_rs_bucket) + + return all_gather_plan, reduce_scatter_plan diff --git a/autoparallel/autobucketing_util/bucket_utils.py b/autoparallel/autobucketing_util/bucket_utils.py index 1d6082ad..52041642 100644 --- a/autoparallel/autobucketing_util/bucket_utils.py +++ b/autoparallel/autobucketing_util/bucket_utils.py @@ -9,6 +9,7 @@ import torch from torch._inductor import ir, scheduler +from torch._inductor.comm import bucket_all_gathers, bucket_reduce_scatters from torch._inductor.dependencies import WeakDep from torch._inductor.ir import NoneLayout from torch._inductor.utils import buf_name_to_fused_snode, is_collective, is_wait @@ -241,3 +242,101 @@ def get_snode_tensor_info( if return_data_size: result += (input_size, output_size) return result + + +def _estimate_bucketed_node_list( + current_node_list: list["scheduler.BaseSchedulerNode"], + schedule_fallback_operation: Callable[[Any], Any], + group_size: int, + group_name: str, + name_to_buf: Dict[str, "scheduler.SchedulerBuffer"], + comm_func: Callable[[Any], Any], + comm_cache: Dict[Any, Any], + reduce_op: Any = None, 
diff --git a/autoparallel/autobucketing_util/bucket_utils.py b/autoparallel/autobucketing_util/bucket_utils.py
index 1d6082ad..52041642 100644
--- a/autoparallel/autobucketing_util/bucket_utils.py
+++ b/autoparallel/autobucketing_util/bucket_utils.py
@@ -9,6 +9,7 @@
 
 import torch
 from torch._inductor import ir, scheduler
+from torch._inductor.comm import bucket_all_gathers, bucket_reduce_scatters
 from torch._inductor.dependencies import WeakDep
 from torch._inductor.ir import NoneLayout
 from torch._inductor.utils import buf_name_to_fused_snode, is_collective, is_wait
@@ -241,3 +242,101 @@ def get_snode_tensor_info(
     if return_data_size:
         result += (input_size, output_size)
     return result
+
+
+def _estimate_bucketed_node_list(
+    current_node_list: list["scheduler.BaseSchedulerNode"],
+    schedule_fallback_operation: Callable[[Any], Any],
+    group_size: int,
+    group_name: str,
+    name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
+    comm_func: str,
+    comm_cache: Dict[Any, Any],
+    reduce_op: Any = None,
+):
+    input_ir_nodes = [n.node.inputs[0] for n in current_node_list]
+    if len(input_ir_nodes) == 1:
+        # standalone node, no need to bucket
+        comm_node = current_node_list[0].node
+        comm_size_inp, comm_size_out = (
+            comm_node.inputs[0].layout.size,
+            comm_node.layout.size,
+        )
+        estimated_comm = comm_cache.get_comm_time(
+            comm_size_inp,
+            comm_size_out,
+            comm_func,
+            calibrated=True,
+        )
+        return estimated_comm, comm_size_inp, comm_size_out
+
+    if comm_func == "torch.ops._c10d_functional.all_gather_into_tensor.default":
+        bucketed_node = bucket_all_gathers(
+            schedule_fallback_operation,
+            group_size,
+            group_name,
+            input_ir_nodes,
+            current_node_list,
+            name_to_buf,
+            return_ag_only=True,
+        )
+        comm_size_inp = bucketed_node[0].layout.size
+        comm_size_out = bucketed_node[1].layout.size
+    elif comm_func == "torch.ops._c10d_functional.reduce_scatter_tensor.default":
+        bucketed_node = bucket_reduce_scatters(
+            schedule_fallback_operation,
+            group_size,
+            group_name,
+            reduce_op,
+            input_ir_nodes,
+            current_node_list,
+            name_to_buf,
+            return_rs_only=True,
+        )
+        comm_size_inp = bucketed_node[0].layout.size
+        comm_size_out = bucketed_node[1].layout.size
+    else:
+        raise ValueError(f"Unsupported comm_func {comm_func} for bucketing")
+
+    estimated_comm = comm_cache.get_comm_time(
+        bucketed_node[0].layout.size,
+        bucketed_node[1].layout.size,
+        comm_func,
+        calibrated=True,
+    )
+    return estimated_comm, comm_size_inp, comm_size_out
+
+
+def estimate_bucketed_snode_runtime(
+    node_bucket_dict: Dict[tuple[Any, ...], list["scheduler.BaseSchedulerNode"]],
+    schedule_fallback_operation: Callable[[Any], Any],
+    name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
+    comm_func: str,
+    comm_cache: Dict[Any, Any],
+    reduce_op: Any = None,
+) -> tuple[float, int, int]:
+    """
+    A bucket may contain AG/RS nodes from different process groups and with different dtypes.
+    This function estimates the accumulated time & comm size of a bucketed AG/RS node.
+    """
+    estimated_comm, comm_size_inp, comm_size_out = 0, 0, 0
+    for node_info, node_list in node_bucket_dict.items():
+        group_size, group_name = node_info[-2], node_info[-1]
+        (
+            local_comm,
+            local_comm_size_inp,
+            local_comm_size_out,
+        ) = _estimate_bucketed_node_list(
+            node_list,
+            schedule_fallback_operation,
+            group_size,
+            group_name,
+            name_to_buf,
+            comm_func,
+            comm_cache,
+            reduce_op,
+        )
+        estimated_comm += local_comm
+        comm_size_inp += get_data_size(local_comm_size_inp)
+        comm_size_out += get_data_size(local_comm_size_out)
+    return estimated_comm, comm_size_inp, comm_size_out
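The plans returned by get_simplefsdp_auto_plan are consumed per bucket: each list
element is one bucket, keyed so that only compatible collectives are merged. A
schematic sketch with string stand-ins for scheduler nodes; the exact key layout
is an assumption, except that estimate_bucketed_snode_runtime reads group_size
and group_name from the last two key fields:

    # one dict per bucket; each key maps to the AG nodes fused into one collective
    all_gather_plan = [
        {("bf16", "cuda", 8, "fsdp_mesh_pg"): ["ag_0", "ag_1", "ag_2"]},  # bucket 0
        {("bf16", "cuda", 8, "fsdp_mesh_pg"): ["ag_3"]},                  # bucket 1
    ]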
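Plugging illustrative numbers into the memory criterion shows how the extra
copy-in/copy-out buffers are accounted for (all values made up, in bytes):

    # bucketed AG: comm_size_inp = 100, comm_size_out = 800
    # this step's RS input size = 50, last step's RS output size = 40
    bucketed_comm_memory = 2 * 100 + 2 * 800 + 2 * 50 + 40  # = 1940
    # the bucket is rejected when the remaining headroom is smaller than
    # next_step_memory + bucketed_comm_memory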