147 changes: 124 additions & 23 deletions autoparallel/_passes/split_fsdp_collectives.py
@@ -4,58 +4,159 @@
# LICENSE file in the root directory of this source tree.

import dataclasses
from contextlib import contextmanager
from functools import partial
from typing import Any

import torch
import torch.fx.node
import torch.utils._pytree as pytree
from torch._functorch._aot_autograd.descriptors import AOTOutput
from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
from torch._inductor.fx_passes.bucketing import (
is_all_gather_into_tensor,
is_reduce_scatter_tensor,
)

# Switch to the upstream import once https://github.com/pytorch/pytorch/pull/166725 is landed:
# from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
from autoparallel._passes.utils import _extract_graph_with_inputs_outputs


@contextmanager
def exclude_from_fx_side_effectful(exclude_vals: set[Any]):
original_val = torch.fx.node._side_effectful_functions.copy()
try:
torch.fx.node._side_effectful_functions -= exclude_vals
yield
finally:
torch.fx.node._side_effectful_functions.clear()
torch.fx.node._side_effectful_functions.update(original_val)


exclude_wait_from_fx_side_effectful = partial(
exclude_from_fx_side_effectful,
{
torch.ops._c10d_functional.wait_tensor,
torch.ops._c10d_functional.wait_tensor.default,
},
)
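# Note: wait_tensor is treated as side-effectful by FX, so graph DCE would
# normally keep every wait node; excluding it here lets the extraction passes
# below drop waits that are not needed for the subgraph's outputs.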


def _clear_partitioner_tag(g: torch.fx.Graph):
# TODO: Remove this once torch._functorch.partitioners supports ignore_must_be_in_fw_bw
# https://github.com/pytorch/pytorch/pull/166725
for n in g.nodes:
n.meta.pop("partitioner_tag", None)


@dataclasses.dataclass(frozen=True)
class PrefetchOutput(AOTOutput):
pass


def split_fsdp_prefetch(
gm: torch.fx.GraphModule,
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
g = gm.graph
@dataclasses.dataclass(frozen=True)
class EpilogueInput(AOTOutput):
pass


def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
g_ins = g.find_nodes(op="placeholder")
prefetch_g_outs_map = {}
prefetch_g_outs_map = []

for g_in in g_ins:
n = g_in
last_ag = None
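# Walk the chain of single-user, single-input ops starting at this placeholder,
# remembering the last all_gather_into_tensor we pass through.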
while True:
if len(n.users) != 1:
break
user = next(iter(n.users))
if len(user.all_input_nodes) > 1:
break
n = user
prefetch_g_outs_map[g_in] = n
if is_all_gather_into_tensor(n):
last_ag = n
if last_ag is None:
prefetch_g_outs_map.append(g_in)
else:
w_n = next(iter(last_ag.users))
prefetch_g_outs_map.append(w_n)

prefetch_g_outs = list(prefetch_g_outs_map.values())
prefetch_g_outs = prefetch_g_outs_map
prefetch_g_outs_descs: list[AOTOutput] = [
PrefetchOutput() for _ in range(len(prefetch_g_outs))
]

prefetch_g = _extract_graph_with_inputs_outputs(
g,
g_ins,
prefetch_g_outs,
prefetch_g_outs_descs,
g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
g_outs_descs = pytree.arg_tree_leaves(
next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
)
with exclude_wait_from_fx_side_effectful():
_clear_partitioner_tag(g)
prefetch_g = _extract_graph_with_inputs_outputs(
g,
g_ins,
prefetch_g_outs,
prefetch_g_outs_descs,
ignore_must_be_in_fw_bw=True,
)

main_g = _extract_graph_with_inputs_outputs(
g,
prefetch_g_outs,
g_outs,
g_outs_descs,
ignore_must_be_in_fw_bw=True,
)
return prefetch_g, main_g


def split_fsdp_reduce_scatters_epilogue(
g: torch.fx.Graph,
) -> tuple[torch.fx.Graph, torch.fx.Graph]:
g_ins = g.find_nodes(op="placeholder")
g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
g_outs_descs = pytree.arg_tree_leaves(
next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
)
main_g = _extract_graph_with_inputs_outputs(
g,
prefetch_g_outs,
g_outs,
g_outs_descs,
)
main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g)
return prefetch_gm, main_gm

g_outs_map = []
for g_out in g_outs:
n = g_out
last_rs = None
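# Walk backwards from this graph output through single-input, single-user ops;
# when we step over a reduce_scatter_tensor, remember its input so the main
# graph ends just before the reduce-scatter and the epilogue starts with it.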
while n is not None:
if len(n.all_input_nodes) != 1:
break
n_in = n.all_input_nodes[0]
if len(n_in.users) > 1:
break
prev_n = n
n = n_in
if is_reduce_scatter_tensor(prev_n):
# In AP, for mesh dims > 1 the reduction
# of gradients happens in multiple steps
last_rs = n
if last_rs is not None:
g_outs_map.append(last_rs)
else:
g_outs_map.append(g_out)

epi_g_ins = [n for n in g_outs_map if n is not None]
epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]

with exclude_wait_from_fx_side_effectful():
_clear_partitioner_tag(g)
main_g = _extract_graph_with_inputs_outputs(
g,
g_ins,
epi_g_ins,
epi_g_ins_descs,
ignore_must_be_in_fw_bw=True,
)
epi_g = _extract_graph_with_inputs_outputs(
g,
epi_g_ins,
g_outs,
g_outs_descs,
ignore_must_be_in_fw_bw=True,
)

return main_g, epi_g
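
For reference, a minimal wiring sketch (assumed, not part of this diff — the actual call sites are not shown here) of how the two passes compose: peel the all-gather prefetch off the front of the graph, then peel the reduce-scatter epilogue off the back of the remaining main graph. Callers that need GraphModules back can rebuild them with torch.fx._lazy_graph_module._make_graph_module, as the previous version of split_fsdp_prefetch did.

import torch.fx as fx

from autoparallel._passes.split_fsdp_collectives import (
    split_fsdp_prefetch,
    split_fsdp_reduce_scatters_epilogue,
)


def split_collectives(g: fx.Graph) -> tuple[fx.Graph, fx.Graph, fx.Graph]:
    # Prefetch graph: the leading all-gathers up to their wait_tensor calls.
    prefetch_g, main_g = split_fsdp_prefetch(g)
    # Epilogue graph: the trailing reduce-scatters of the gradients.
    main_g, epilogue_g = split_fsdp_reduce_scatters_epilogue(main_g)
    return prefetch_g, main_g, epilogue_g
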
93 changes: 93 additions & 0 deletions autoparallel/_passes/utils.py
@@ -0,0 +1,93 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch.fx as fx
import torch.utils._pytree as pytree
from torch._functorch._aot_autograd.descriptors import AOTOutput

# TODO(ivankobzarev): Remove this partitioner function fork once https://github.com/pytorch/pytorch/pull/166725 is landed


class InvalidNodeBase:
def __repr__(self):
return "Invalid Node"


InvalidNode = InvalidNodeBase()


def _extract_graph_with_inputs_outputs(
joint_graph: fx.Graph,
inputs: list[fx.Node],
outputs: list[fx.Node],
outputs_descs: list[AOTOutput],
subgraph: Optional[str] = None,
ignore_must_be_in_fw_bw: bool = False,
) -> fx.Graph:
"""
Given a graph, extracts out a subgraph that takes the specified nodes as
inputs and returns the specified outputs.

This includes specifying non-placeholder nodes as inputs.

The general strategy is to initialize all inputs with proxies as we
encounter them, and trace through the graph, only keeping values which take
in valid proxies. Then, all dead code is eliminated.
"""
new_graph = fx.Graph()
env = {}

# Add new placeholder nodes in the order specified by the inputs
for node in inputs:
new_node = new_graph.placeholder(node.name)
# Can't use node_copy here, as we may be turning a previous call_function into a placeholder
new_node.meta = node.meta
# pyrefly: ignore [unsupported-operation]
env[node] = new_node

for node in joint_graph.nodes:
if node in env:
# Node must be one of our inputs. (Any member of env which wasn't an
# input to start must have been created by this loop and won't be in
# joint_graph.nodes).
continue
elif node.op == "placeholder":
env[node] = InvalidNode # type: ignore[assignment]
elif node.op == "call_function":
all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs)
all_args = [
isinstance(env[x], InvalidNodeBase)
for x in all_args
if isinstance(x, fx.Node)
]
if any(all_args):
env[node] = InvalidNode # type: ignore[assignment]
continue
# pyrefly: ignore [unsupported-operation, bad-argument-type]
env[node] = new_graph.node_copy(node, lambda x: env[x])
elif node.op == "get_attr":
# pyrefly: ignore [unsupported-operation, bad-argument-type]
env[node] = new_graph.node_copy(node, lambda x: env[x])
elif node.op == "output":
pass
output_values = []
for x in outputs:
if isinstance(x, fx.Node):
if x not in env:
raise RuntimeError(f"Node {x} couldn't be found in env")
assert not isinstance(
env[x], InvalidNodeBase
), f"Node {x} was invalid, but is output"
output_values.append(env[x])
else:
output_values.append(x)
out = new_graph.output(tuple(output_values))
out.meta["desc"] = outputs_descs

new_graph.eliminate_dead_code()
new_graph.lint()
return new_graph
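
A minimal usage sketch for the forked helper (assumed, not from this diff; the example function and node names rely on FX's default naming for symbolically traced operators): promote an intermediate node to a placeholder and extract the tail of the computation.

import torch.fx as fx

from autoparallel._passes.utils import _extract_graph_with_inputs_outputs


def f(x):
    y = x + 1  # traced as node "add"
    z = y * 2  # traced as node "mul"
    return z - 3  # traced as node "sub"


gm = fx.symbolic_trace(f)
nodes = {n.name: n for n in gm.graph.nodes}

# "mul" becomes a placeholder of the new graph; "x" and "add" are never copied.
tail_g = _extract_graph_with_inputs_outputs(
    gm.graph,
    inputs=[nodes["mul"]],
    outputs=[nodes["sub"]],
    outputs_descs=[None],  # descriptors land in the output node's meta["desc"]
)
print(tail_g)  # placeholder "mul" -> sub = mul - 3 -> output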