diff --git a/autoparallel/auto_bucketing.py b/autoparallel/auto_bucketing.py
new file mode 100644
index 00000000..4bbd2770
--- /dev/null
+++ b/autoparallel/auto_bucketing.py
@@ -0,0 +1,40 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from .autobucketing_util import bucket_utils
+
+
+class simplefsdp_autobucketing_config:
+    """
+    Config for simplefsdp's autobucketing pass; the defaults should give good performance.
+    To make the results tunable, we expose the following parameters:
+    - relax_ratio: relax comp time to include more comm in one bucket;
+        with this config, comp is updated as comp * (1 + relax_ratio)
+    - peak_memory_offset: relax peak_memory to include more comm in one bucket;
+        with this config, peak_memory is updated as (peak_memory + peak_memory_offset)
+    - load_cache: set to True to load the cache from save_estimation_path
+    - enable_bucket_ir: set to True to bucket all_gather/reduce_scatter
+    - enable_reorder_ir: set to True to reorder all_gather/reduce_scatter
+    """
+
+    relax_ratio = 0
+    peak_memory_offset = 0
+    load_cache = False
+    save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast.pkl"
+    enable_bucket_ir = True
+    enable_reorder_ir = True
+
+
+def simple_fsdp_autobucketing_reordering_pass(
+    snodes: list["torch._inductor.scheduler.BaseSchedulerNode"],
+    configs: "simplefsdp_autobucketing_config",
+) -> list["torch._inductor.scheduler.BaseSchedulerNode"]:
+    scheduler = snodes[0].scheduler
+    bucket_utils.get_bucketable_ir_nodes(
+        snodes, scheduler.name_to_fused_node, scheduler.name_to_buf
+    )
+    return snodes
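For context, a minimal sketch of how this pass might be wired up, outside of this diff. It assumes that `torch._inductor.config.reorder_for_compute_comm_overlap_passes` accepts user-supplied callables of signature `(snodes) -> snodes` and that the knobs below exist in your PyTorch build; treat it as an illustration, not part of the change:

```python
# Hypothetical wiring sketch (not part of this diff): register the autobucketing
# reordering pass with Inductor's compute/comm overlap hook.
import functools

import torch._inductor.config as inductor_config

from autoparallel.auto_bucketing import (
    simple_fsdp_autobucketing_reordering_pass,
    simplefsdp_autobucketing_config,
)

# Tune the bucketing heuristics; the values here are illustrative only.
simplefsdp_autobucketing_config.relax_ratio = 0.1
simplefsdp_autobucketing_config.peak_memory_offset = 0

# Assumption: this Inductor config accepts callables taking the scheduler nodes,
# so we bind the config object with functools.partial.
inductor_config.reorder_for_compute_comm_overlap = True
inductor_config.reorder_for_compute_comm_overlap_passes = [
    functools.partial(
        simple_fsdp_autobucketing_reordering_pass,
        configs=simplefsdp_autobucketing_config,
    )
]
```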
diff --git a/autoparallel/autobucketing_util/bucket_utils.py b/autoparallel/autobucketing_util/bucket_utils.py
new file mode 100644
index 00000000..09f27440
--- /dev/null
+++ b/autoparallel/autobucketing_util/bucket_utils.py
@@ -0,0 +1,119 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+# mypy: ignore-errors
+from typing import Any, Callable, Dict
+
+import torch
+from torch._inductor import scheduler
+from torch._inductor.dependencies import WeakDep
+from torch._inductor.utils import buf_name_to_fused_snode, is_collective
+from torch.utils._ordered_set import OrderedSet
+
+
+def _find_recursive_deps_of_snode(
+    snode: "scheduler.BaseSchedulerNode",
+    collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"],
+    name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
+    name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
+    criteria_cb: Callable[[Any], bool] = lambda snode: False,
+    allow_weak_dep: bool = True,
+):
+    if criteria_cb(snode):
+        return
+    collected_node_set.add(snode)
+    for dep in snode.unmet_dependencies:
+        if isinstance(dep, WeakDep) and not allow_weak_dep:
+            continue
+        defining_op_for_dep = buf_name_to_fused_snode(
+            dep.name, name_to_buf, name_to_fused_node
+        )
+        if defining_op_for_dep in collected_node_set:
+            continue
+        _find_recursive_deps_of_snode(
+            defining_op_for_dep,
+            collected_node_set,
+            name_to_buf,
+            name_to_fused_node,
+            criteria_cb=criteria_cb,
+        )
+
+
+def _find_recursive_users_of_snode(
+    snode: "scheduler.BaseSchedulerNode",
+    collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"],
+    name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
+    name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
+    criteria_cb: Callable[[Any], bool] = lambda snode: False,
+):
+    if criteria_cb(snode):
+        return
+    collected_node_set.add(snode)
+    for o in snode.get_outputs():
+        for user in o.users:
+            assert user.node is not None
+            if user.node.get_name() == "OUTPUT":
+                continue
+            if user.node.get_name() not in name_to_fused_node:
+                continue
+            user_op = name_to_fused_node[user.node.get_name()]
+            if user_op in collected_node_set:
+                continue
+            _find_recursive_users_of_snode(
+                user_op,
+                collected_node_set,
+                name_to_buf,
+                name_to_fused_node,
+                criteria_cb=criteria_cb,
+            )
+
+
+def get_bucketable_ir_nodes(
+    snodes: list["torch._inductor.scheduler.BaseSchedulerNode"],
+    name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
+    name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
+) -> set[str]:
+    """
+    This function selects the names of the IR nodes that are bucketable.
+    From first principles, only all-gathers that gather parameters and reduce-scatters
+    that update model gradients can be bucketed together.
+    Thus, a bucketable all-gather's deps are (1) a computed buffer for dtype conversion (optional)
+    and (2) the all-gather itself; a bucketable reduce-scatter's wait node has no recursive
+    users other than the wait itself.
+    """
+    bucketable_ir_nodes = set()
+    for snode in snodes:
+        if is_collective(
+            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
+        ):
+            ag_related_snode_set: OrderedSet[
+                "torch._inductor.scheduler.BaseSchedulerNode"
+            ] = OrderedSet()
+            _find_recursive_deps_of_snode(
+                snode,
+                ag_related_snode_set,
+                name_to_buf,
+                name_to_fused_node,
+                allow_weak_dep=False,
+            )
+            if len(ag_related_snode_set) <= 2:
+                bucketable_ir_nodes.add(snode.node.get_name())
+        elif is_collective(
+            snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
+        ):
+            wait_snode = snode.get_outputs()[0].users[0].node
+            wait_snode_recursive_users: OrderedSet[
+                "torch._inductor.scheduler.BaseSchedulerNode"
+            ] = OrderedSet()
+            _find_recursive_users_of_snode(
+                wait_snode,
+                wait_snode_recursive_users,
+                name_to_buf,
+                name_to_fused_node,
+            )
+            if len(wait_snode_recursive_users) <= 1:
+                bucketable_ir_nodes.add(snode.node.get_name())
+
+    return bucketable_ir_nodes
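For intuition, here is a small self-contained sketch of the selection criterion described in the docstring above. `MiniNode` and `dep_closure` are hypothetical stand-ins for Inductor's scheduler nodes and `_find_recursive_deps_of_snode`; they are not part of this diff and ignore details such as weak deps and fused nodes:

```python
# Illustrative sketch only: the dependency-closure-size criterion used to
# decide whether an all-gather looks like a parameter all-gather.
from dataclasses import dataclass, field


@dataclass
class MiniNode:
    name: str
    deps: list["MiniNode"] = field(default_factory=list)


def dep_closure(node: MiniNode, seen: set[str] | None = None) -> set[str]:
    # Collect the node and everything it transitively depends on,
    # mirroring what _find_recursive_deps_of_snode does for scheduler nodes.
    seen = set() if seen is None else seen
    seen.add(node.name)
    for dep in node.deps:
        if dep.name not in seen:
            dep_closure(dep, seen)
    return seen


# A parameter all-gather typically depends only on an optional dtype-cast
# buffer, so its dependency closure has at most 2 nodes -> bucketable.
cast = MiniNode("dtype_cast")
param_ag = MiniNode("all_gather_param", deps=[cast])
assert len(dep_closure(param_ag)) <= 2

# An all-gather fed by a chain of compute has a larger closure -> skipped.
matmul = MiniNode("matmul")
relu = MiniNode("relu", deps=[matmul])
act_ag = MiniNode("all_gather_activation", deps=[relu])
assert len(dep_closure(act_ag)) > 2
```

The reduce-scatter side is symmetric: instead of walking dependencies of the collective, it walks the recursive users of the reduce-scatter's wait node and accepts it only when the wait is its sole user.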