40 changes: 40 additions & 0 deletions autoparallel/auto_bucketing.py
@@ -0,0 +1,40 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch

from .autobucketing_util import bucket_utils


class simplefsdp_autobucketing_config:
"""
    Config for SimpleFSDP's autobucketing pass; the defaults generally give good performance.
    To make the results tunable, we expose the following parameters:
    - relax_ratio: relax comp time to include more comm in one bucket;
      with this config, comp is updated as comp * (1 + relax_ratio)
    - peak_memory_offset: relax peak_memory to include more comm in one bucket;
      with this config, peak_memory is updated as (peak_memory + peak_memory_offset)
    - load_cache: set to True to load the estimation cache from save_estimation_path
    - save_estimation_path: path used to save (and load) the runtime estimation cache
    - enable_bucket_ir: set to True to bucket all_gather/reduce_scatter
    - enable_reorder_ir: set to True to reorder all_gather/reduce_scatter
"""

relax_ratio = 0
peak_memory_offset = 0
load_cache = False
save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast.pkl"
enable_bucket_ir = True
enable_reorder_ir = True


def simple_fsdp_autobucketing_reordering_pass(
snodes: list["torch._inductor.scheduler.BaseSchedulerNode"],
configs: "simplefsdp_autobucketing_config",
) -> list["torch._inductor.scheduler.BaseSchedulerNode"]:
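    # NOTE: for now this pass only identifies the set of bucketable collective
    # nodes; the scheduler node order is returned unchanged and `configs` is
    # not consumed in this step yet.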
scheduler = snodes[0].scheduler
bucket_utils.get_bucketable_ir_nodes(
snodes, scheduler.name_to_fused_node, scheduler.name_to_buf
)
return snodes
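As a rough usage sketch (not part of this diff), the pass could be registered as one of Inductor's compute/comm reordering passes. Binding the config with `functools.partial` and using the `reorder_for_compute_comm_overlap_passes` hook are assumptions about the intended integration, not something this PR establishes.

from functools import partial

import torch._inductor.config as inductor_config

from autoparallel.auto_bucketing import (
    simple_fsdp_autobucketing_reordering_pass,
    simplefsdp_autobucketing_config,
)

configs = simplefsdp_autobucketing_config()
configs.relax_ratio = 0.1  # allow ~10% more compute time per bucket (hypothetical tuning)
configs.load_cache = False  # recompute comm/comp estimations instead of loading the cache

# Hook assumed from torch._inductor.config; each registered pass receives the
# scheduler node list and returns the (possibly reordered) list.
inductor_config.reorder_for_compute_comm_overlap = True
inductor_config.reorder_for_compute_comm_overlap_passes = [
    partial(simple_fsdp_autobucketing_reordering_pass, configs=configs),
]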
119 changes: 119 additions & 0 deletions autoparallel/autobucketing_util/bucket_utils.py
@@ -0,0 +1,119 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# mypy: ignore-errors
from typing import Any, Callable, Dict

import torch
from torch._inductor import scheduler
from torch._inductor.dependencies import WeakDep
from torch._inductor.utils import buf_name_to_fused_snode, is_collective
from torch.utils._ordered_set import OrderedSet


def _find_recursive_deps_of_snode(
snode: "scheduler.BaseSchedulerNode",
collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"],
name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
criteria_cb: Callable[[Any], bool] = lambda snode: False,
allow_weak_dep: bool = True,
):
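    """
    Recursively collect `snode` and the (fused) scheduler nodes that define its
    unmet dependencies into `collected_node_set`. Traversal stops at nodes for
    which `criteria_cb` returns True; WeakDeps are skipped when
    `allow_weak_dep` is False.
    """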
if criteria_cb(snode):
return
collected_node_set.add(snode)
for dep in snode.unmet_dependencies:
if isinstance(dep, WeakDep) and not allow_weak_dep:
continue
defining_op_for_dep = buf_name_to_fused_snode(
dep.name, name_to_buf, name_to_fused_node
)
if defining_op_for_dep in collected_node_set:
continue
_find_recursive_deps_of_snode(
defining_op_for_dep,
collected_node_set,
name_to_buf,
name_to_fused_node,
criteria_cb=criteria_cb,
)


def _find_recursive_users_of_snode(
snode: "scheduler.BaseSchedulerNode",
collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"],
name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
criteria_cb: Callable[[Any], bool] = lambda snode: False,
):
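    """
    Recursively collect `snode` and the (fused) scheduler nodes that consume
    its output buffers into `collected_node_set`. Traversal stops at nodes for
    which `criteria_cb` returns True; graph OUTPUT users and users without an
    entry in `name_to_fused_node` are skipped.
    """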
if criteria_cb(snode):
return
collected_node_set.add(snode)
for o in snode.get_outputs():
for user in o.users:
assert user.node is not None
if user.node.get_name() == "OUTPUT":
continue
if user.node.get_name() not in name_to_fused_node:
continue
user_op = name_to_fused_node[user.node.get_name()]
if user_op in collected_node_set:
continue
_find_recursive_users_of_snode(
user_op,
collected_node_set,
name_to_buf,
name_to_fused_node,
criteria_cb=criteria_cb,
)


def get_bucketable_ir_nodes(
snodes: list["torch._inductor.scheduler.BaseSchedulerNode"],
name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
) -> set[str]:
"""
    Select the names of IR nodes that are bucketable.

    From first principles, only all-gathers that gather parameters and
    reduce-scatters that update model gradients can be bucketed together.
    Thus, a bucketable all-gather's recursive deps are (1) an optional compute
    buffer for dtype conversion and (2) the all-gather itself, and a bucketable
    reduce-scatter wait's only recursive user is the wait node itself.
"""
bucketable_ir_nodes = set()
for snode in snodes:
if is_collective(
snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
):
ag_related_snode_set: OrderedSet[
"torch._inductor.scheduler.BaseSchedulerNode"
] = OrderedSet()
_find_recursive_deps_of_snode(
snode,
ag_related_snode_set,
name_to_buf,
name_to_fused_node,
allow_weak_dep=False,
)
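            # The collected set holds the all-gather itself plus, at most, one
            # compute buffer for dtype conversion, so a parameter all-gather
            # has no more than 2 recursive deps.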
if len(ag_related_snode_set) <= 2:
bucketable_ir_nodes.add(snode.node.get_name())
elif is_collective(
snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
):
wait_snode = snode.get_outputs()[0].users[0].node
wait_snode_recursive_users: OrderedSet[
"torch._inductor.scheduler.BaseSchedulerNode"
] = OrderedSet()
_find_recursive_users_of_snode(
wait_snode,
wait_snode_recursive_users,
name_to_buf,
name_to_fused_node,
)
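            # The collected set holds only the reduce-scatter wait itself when
            # the gradient it produces has no further consumers in the graph.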
if len(wait_snode_recursive_users) <= 1:
bucketable_ir_nodes.add(snode.node.get_name())

return bucketable_ir_nodes