Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion autoparallel/auto_bucketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from .autobucketing_util import bucket_func, bucket_plan, bucket_utils
from .autobucketing_util import bucket_func, bucket_plan, bucket_utils, reorder


class simplefsdp_autobucketing_config:
Expand Down Expand Up @@ -71,4 +71,19 @@ def simple_fsdp_autobucketing_reordering_pass(
reduce_scatter_plan,
bucketable_nodes,
)

if configs.enable_reorder_ir:
print("Reorder scheduler nodes with autobucketing algroithm")
node_length = len(snodes)
snodes = reorder.reorder_all_gather(
snodes, bucketable_nodes, all_gather_before_last_wait=False
)
assert node_length == len(
snodes
), f"Missed nodes in reordering all gather: expected {node_length}, but got {len(snodes)}"
snodes = reorder.reorder_reduce_scatter(snodes, bucketable_nodes)
assert node_length == len(
snodes
), f"Missed nodes in reordering reduce scatter: expected {node_length}, but got {len(snodes)}"

return snodes
264 changes: 264 additions & 0 deletions autoparallel/autobucketing_util/reorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# mypy: ignore-errors
from collections import defaultdict
from enum import IntEnum
from typing import Dict, List, Optional, Tuple

import torch
from torch._inductor import ir, scheduler
from torch._inductor.utils import contains_collective, contains_wait, is_collective
from torch.utils._ordered_set import OrderedSet

from .bucket_utils import check_ir_node_bucketable


class NodeType(IntEnum):
ALL_GATHER = 0
COMPUTE = 1
REDUCE_SCATTER = 2
AG_WAIT = 3
RS_WAIT = 4


def compute_node_users(
snodes: List["scheduler.BaseSchedulerNode"],
) -> Tuple[
Dict["scheduler.BaseSchedulerNode", OrderedSet["scheduler.BaseSchedulerNode"]],
Dict["scheduler.BaseSchedulerNode", OrderedSet["scheduler.BaseSchedulerNode"]],
]:
"""
Compute the inverse users and users of each node
"""
buf_to_snode: Dict[str, scheduler.BaseSchedulerNode] = {}
for node in snodes:
if isinstance(node, scheduler.FusedSchedulerNode):
for x in node.snodes:
for buf in x.get_outputs():
buf_to_snode[buf.get_name()] = node

for buf in node.get_outputs():
buf_to_snode[buf.get_name()] = node

inverse_users = {}
keys = list(buf_to_snode.keys())
for node in snodes:
dep_list = []
for dep in node.unmet_dependencies:
if dep.name in keys:
dep_list.append(buf_to_snode[dep.name])
inverse_users.update({node: OrderedSet(dep_list)})

node_users: Dict[
scheduler.BaseSchedulerNode, OrderedSet[scheduler.BaseSchedulerNode]
] = defaultdict(OrderedSet)
for node, node_inverse_users in inverse_users.items():
for inverse_user in node_inverse_users:
node_users[inverse_user].add(node)

return inverse_users, node_users


def _get_ir_node_type(ir_node: "ir.Operation", bucketable_ir_nodes) -> NodeType:
"""
Determine the type of a ir node
"""
if isinstance(ir_node, ir._WaitKernel):
# Determine if the wait node is waiting for ALL_GATHER or REDUCE_SCATTER
ir_op_overload = getattr(ir_node.inputs[0], "op_overload", None)
if (
ir_op_overload == torch.ops._c10d_functional.all_gather_into_tensor.default
and check_ir_node_bucketable(ir_node.inputs[0], bucketable_ir_nodes)
):
return NodeType.AG_WAIT
elif (
ir_op_overload == torch.ops._c10d_functional.reduce_scatter_tensor.default
and check_ir_node_bucketable(ir_node.inputs[0], bucketable_ir_nodes)
):
return NodeType.RS_WAIT
if isinstance(ir_node, ir._CollectiveKernel):
# Determine if the collective kernel is for ALL_GATHER or REDUCE_SCATTER
ir_op_overload = getattr(ir_node, "op_overload", None)
if is_collective(
ir_node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes):
return NodeType.ALL_GATHER
elif is_collective(
ir_node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes):
return NodeType.REDUCE_SCATTER

if isinstance(ir_node, ir.FallbackKernel):
python_kernel_name = ir_node.python_kernel_name
if (
python_kernel_name == "torch.ops._c10d_functional.wait_tensor.default"
and check_ir_node_bucketable(ir_node, bucketable_ir_nodes)
):
inputs_rs_kernel_name1 = (
getattr(ir_node.inputs[0], "python_kernel_name", "")
== "torch.ops._c10d_functional.reduce_scatter_tensor.default"
)
inputs_rs_kernel_name2 = (
hasattr(ir_node.inputs[0], "inputs")
and getattr(ir_node.inputs[0].inputs[0], "python_kernel_name", "")
== "torch.ops._c10d_functional.reduce_scatter_tensor.default"
)
if inputs_rs_kernel_name1 or inputs_rs_kernel_name2:
return NodeType.RS_WAIT

inputs_ag_kernel_name1 = (
getattr(ir_node.inputs[0], "python_kernel_name", "")
== "torch.ops._c10d_functional.all_gather_into_tensor_out.default"
)
inputs_ag_kernel_name2 = (
hasattr(ir_node.inputs[0], "inputs")
and getattr(ir_node.inputs[0].inputs[0], "python_kernel_name", "")
== "torch.ops._c10d_functional.all_gather_into_tensor_out.default"
)
if inputs_ag_kernel_name1 or inputs_ag_kernel_name2:
return NodeType.AG_WAIT
elif (
python_kernel_name
== "torch.ops._c10d_functional.reduce_scatter_tensor.default"
) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes):
return NodeType.REDUCE_SCATTER
elif (
python_kernel_name
== "torch.ops._c10d_functional.all_gather_into_tensor_out.default"
) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes):
return NodeType.ALL_GATHER
return NodeType.COMPUTE


def get_node_type(node: "scheduler.BaseSchedulerNode", bucketable_ir_nodes) -> NodeType:
"""
Determine the NodeType of a node
"""
if isinstance(node, scheduler.FusedSchedulerNode):
# Only compute nodes are fused
return NodeType.COMPUTE

if isinstance(node, scheduler.GroupedSchedulerNode):
# [Only for bucketing]: newly created AG and RS are grouped as GroupedSchedulerNode
child_nodes_type = [
_get_ir_node_type(n.node, bucketable_ir_nodes) for n in node.snodes
]
if NodeType.AG_WAIT in child_nodes_type:
return NodeType.AG_WAIT
elif NodeType.RS_WAIT in child_nodes_type:
return NodeType.RS_WAIT
elif NodeType.ALL_GATHER in child_nodes_type:
return NodeType.ALL_GATHER
elif NodeType.REDUCE_SCATTER in child_nodes_type:
return NodeType.REDUCE_SCATTER
else:
return NodeType.COMPUTE

return _get_ir_node_type(node.node, bucketable_ir_nodes)


def reorder_all_gather(
snodes: List["scheduler.BaseSchedulerNode"],
bucketable_ir_nodes: set[str],
all_gather_before_last_wait: Optional[bool] = True,
) -> List["scheduler.BaseSchedulerNode"]:
"""
Reorder All Gather and Wait in the forward/backward pass;
1. all_gather_before_last_wait set to True: all_gather_i is reordered before wait_i-1
2. all_gather_before_last_wait set to False: all_gather_i is reordered after wait_i-1
"""
result_list: List[scheduler.BaseSchedulerNode] = []
all_gather_list: List[scheduler.BaseSchedulerNode] = []
node_to_type: Dict[scheduler.BaseSchedulerNode, int] = {}
inverse_users, node_users = compute_node_users(snodes)

for node in snodes:
node_to_type[node] = get_node_type(node, bucketable_ir_nodes)
snodes.reverse()
for idx, node in enumerate(snodes):
node_type = node_to_type[node]
if node_type in [NodeType.REDUCE_SCATTER, NodeType.COMPUTE, NodeType.RS_WAIT]:
# we do not reorder reduce scatter and compute node
if node not in result_list and node not in all_gather_list:
result_list.append(node)
elif node_type == NodeType.ALL_GATHER:
# gather i-th all gather node and its dependencies
all_gather_list.append(node)
inverse_user = list(inverse_users[node])
inverse_user = [
n
for n in inverse_user
if node_to_type[n] == NodeType.COMPUTE
and not contains_collective(n)
and not contains_wait(n)
]
if len(inverse_user) > 0:
all_gather_list.extend(inverse_user)
elif node_type == NodeType.AG_WAIT:
if not all_gather_before_last_wait and len(all_gather_list) > 0:
assert node_to_type[snodes[idx + 1]] == NodeType.ALL_GATHER
# move i-th all gather node and its dependencies after (i-1)-th wait node (bc this is a reverse list)
result_list.extend(all_gather_list)
all_gather_list = []

result_list.append(node)

if all_gather_before_last_wait and len(all_gather_list) > 0:
assert node_to_type[snodes[idx + 1]] == NodeType.ALL_GATHER
# move i-th all gather node and its dependencies before (i-1)-th wait node (bc this is a reverse list)
result_list.extend(all_gather_list)
all_gather_list = []
if len(all_gather_list) > 0:
result_list.extend(all_gather_list)
result_list.reverse()

return result_list


def reorder_reduce_scatter(
snodes: List["scheduler.BaseSchedulerNode"],
bucketable_ir_nodes: set[str],
) -> List["scheduler.BaseSchedulerNode"]:
"""
Reorder Reduce Scatter and Wait in the backward pass
reorder wait_i_rs before reduce_scatter_i+1
"""
result_list: List[scheduler.BaseSchedulerNode] = []
wait_list: List[scheduler.BaseSchedulerNode] = []
node_to_type: Dict[scheduler.BaseSchedulerNode, int] = {}
inverse_users, node_users = compute_node_users(snodes)
types = []
for node in snodes:
node_to_type[node] = get_node_type(node, bucketable_ir_nodes)
types.append(get_node_type(node, bucketable_ir_nodes))

if NodeType.REDUCE_SCATTER not in types:
return snodes

for idx, node in enumerate(snodes):
node_type = node_to_type[node]
if node_type in [NodeType.ALL_GATHER, NodeType.COMPUTE, NodeType.AG_WAIT]:
if node not in result_list and node not in wait_list:
result_list.append(node)
elif node_type == NodeType.RS_WAIT:
# there will sometimes be a memory checker node between rs and rs wait
assert node_to_type[snodes[idx - 1]] == NodeType.REDUCE_SCATTER
# gather wait node after reduce scatter
wait_list.append(node)
node_user = node_users[node]
node_user = [n for n in node_user if node_to_type[n] == NodeType.COMPUTE]
# wait_list.extend(node_user)
elif node_type == NodeType.REDUCE_SCATTER:
if len(wait_list) > 0:
# move the i-th wait node before (i+1)-th reduce scatter node
result_list.extend(wait_list)
wait_list = []
# add reduce scatter node
result_list.append(node)

if len(wait_list) > 0:
result_list.extend(wait_list)
return result_list