diff --git a/autoparallel/_passes/split_fsdp_collectives.py b/autoparallel/_passes/split_fsdp_collectives.py
index 387b9e4..f2983c0 100644
--- a/autoparallel/_passes/split_fsdp_collectives.py
+++ b/autoparallel/_passes/split_fsdp_collectives.py
@@ -4,11 +4,49 @@
 # LICENSE file in the root directory of this source tree.
 
 import dataclasses
+from contextlib import contextmanager
+from functools import partial
+from typing import Any
 
 import torch
+import torch.fx.node
 import torch.utils._pytree as pytree
 from torch._functorch._aot_autograd.descriptors import AOTOutput
-from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
+from torch._inductor.fx_passes.bucketing import (
+    is_all_gather_into_tensor,
+    is_reduce_scatter_tensor,
+)
+
+# Switch back to this import once https://github.com/pytorch/pytorch/pull/166725 is landed
+# from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
+from autoparallel._passes.utils import _extract_graph_with_inputs_outputs
+
+
+@contextmanager
+def exclude_from_fx_side_effectful(exclude_vals: set[Any]):
+    original_val = torch.fx.node._side_effectful_functions.copy()
+    try:
+        torch.fx.node._side_effectful_functions -= exclude_vals
+        yield
+    finally:
+        torch.fx.node._side_effectful_functions.clear()
+        torch.fx.node._side_effectful_functions.update(original_val)
+
+
+exclude_wait_from_fx_side_effectful = partial(
+    exclude_from_fx_side_effectful,
+    {
+        torch.ops._c10d_functional.wait_tensor,
+        torch.ops._c10d_functional.wait_tensor.default,
+    },
+)
+
+
+def _clear_partitioner_tag(g: torch.fx.Graph):
+    # TODO: Remove this once torch._functorch.partitioners supports ignore_must_be_in_fw_bw
+    # https://github.com/pytorch/pytorch/pull/166725
+    for n in g.nodes:
+        n.meta.pop("partitioner_tag", None)
 
 
 @dataclasses.dataclass(frozen=True)
@@ -16,15 +54,18 @@ class PrefetchOutput(AOTOutput):
     pass
 
 
-def split_fsdp_prefetch(
-    gm: torch.fx.GraphModule,
-) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
-    g = gm.graph
+@dataclasses.dataclass(frozen=True)
+class EpilogueInput(AOTOutput):
+    pass
+
+
+def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
     g_ins = g.find_nodes(op="placeholder")
-    prefetch_g_outs_map = {}
+    prefetch_g_outs_map = []
 
     for g_in in g_ins:
         n = g_in
+        last_ag = None
         while True:
             if len(n.users) != 1:
                 break
@@ -32,30 +73,90 @@
             if len(user.all_input_nodes) > 1:
                 break
             n = user
-        prefetch_g_outs_map[g_in] = n
+            if is_all_gather_into_tensor(n):
+                last_ag = n
+        if last_ag is None:
+            prefetch_g_outs_map.append(g_in)
+        else:
+            w_n = next(iter(last_ag.users))
+            prefetch_g_outs_map.append(w_n)
 
-    prefetch_g_outs = list(prefetch_g_outs_map.values())
+    prefetch_g_outs = prefetch_g_outs_map
     prefetch_g_outs_descs: list[AOTOutput] = [
         PrefetchOutput() for _ in range(len(prefetch_g_outs))
     ]
-
-    prefetch_g = _extract_graph_with_inputs_outputs(
-        g,
-        g_ins,
-        prefetch_g_outs,
-        prefetch_g_outs_descs,
+    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
+    g_outs_descs = pytree.arg_tree_leaves(
+        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
     )
+    with exclude_wait_from_fx_side_effectful():
+        _clear_partitioner_tag(g)
+        prefetch_g = _extract_graph_with_inputs_outputs(
+            g,
+            g_ins,
+            prefetch_g_outs,
+            prefetch_g_outs_descs,
+            ignore_must_be_in_fw_bw=True,
+        )
+
+        main_g = _extract_graph_with_inputs_outputs(
+            g,
+            prefetch_g_outs,
+            g_outs,
+            g_outs_descs,
+            ignore_must_be_in_fw_bw=True,
+        )
+    return prefetch_g, main_g
+
+
+def split_fsdp_reduce_scatters_epilogue(
+    g: torch.fx.Graph,
+) -> tuple[torch.fx.Graph, torch.fx.Graph]:
+    g_ins = g.find_nodes(op="placeholder")
     g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
     g_outs_descs = pytree.arg_tree_leaves(
         next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
     )
-    main_g = _extract_graph_with_inputs_outputs(
-        g,
-        prefetch_g_outs,
-        g_outs,
-        g_outs_descs,
-    )
-    main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
-    prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g)
-    return prefetch_gm, main_gm
+
+    g_outs_map = []
+    for g_out in g_outs:
+        n = g_out
+        last_rs = None
+        while n is not None:
+            if len(n.all_input_nodes) != 1:
+                break
+            n_in = n.all_input_nodes[0]
+            if len(n_in.users) > 1:
+                break
+            prev_n = n
+            n = n_in
+            if is_reduce_scatter_tensor(prev_n):
+                # In AP, for mesh dim > 1 the reduction of gradients
+                # happens in multiple steps.
+                last_rs = n
+        if last_rs is not None:
+            g_outs_map.append(last_rs)
+        else:
+            g_outs_map.append(g_out)
+
+    epi_g_ins = [n for n in g_outs_map if n is not None]
+    epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]
+
+    with exclude_wait_from_fx_side_effectful():
+        _clear_partitioner_tag(g)
+        main_g = _extract_graph_with_inputs_outputs(
+            g,
+            g_ins,
+            epi_g_ins,
+            epi_g_ins_descs,
+            ignore_must_be_in_fw_bw=True,
+        )
+        epi_g = _extract_graph_with_inputs_outputs(
+            g,
+            epi_g_ins,
+            g_outs,
+            g_outs_descs,
+            ignore_must_be_in_fw_bw=True,
+        )
+
+    return main_g, epi_g
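A minimal caller-side sketch (not part of the patch) of how the two passes above could be wired up. It assumes fw_gm and bw_gm are forward/backward fx.GraphModules obtained elsewhere (e.g. from apply_placement_pp); since the passes now take and return fx.Graphs rather than GraphModules, rebuilding runnable modules is left to the caller:

    import torch.fx

    from autoparallel._passes.split_fsdp_collectives import (
        split_fsdp_prefetch,
        split_fsdp_reduce_scatters_epilogue,
    )

    # fw_gm / bw_gm are hypothetical forward/backward GraphModules.
    unshard_g, fw_main_g = split_fsdp_prefetch(fw_gm.graph)
    bw_main_g, reduce_grad_g = split_fsdp_reduce_scatters_epilogue(bw_gm.graph)

    # Rebuild GraphModules from the split graphs so they can be called.
    unshard_gm = torch.fx.GraphModule(fw_gm, unshard_g)          # all-gather prologue
    fw_main_gm = torch.fx.GraphModule(fw_gm, fw_main_g)          # forward compute
    bw_main_gm = torch.fx.GraphModule(bw_gm, bw_main_g)          # backward compute
    reduce_grad_gm = torch.fx.GraphModule(bw_gm, reduce_grad_g)  # reduce-scatter epilogue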
diff --git a/autoparallel/_passes/utils.py b/autoparallel/_passes/utils.py
new file mode 100644
index 0000000..a5ea1d9
--- /dev/null
+++ b/autoparallel/_passes/utils.py
@@ -0,0 +1,93 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch.fx as fx
+import torch.utils._pytree as pytree
+from torch._functorch._aot_autograd.descriptors import AOTOutput
+
+# TODO(ivankobzarev): Remove partitioner function fork once https://github.com/pytorch/pytorch/pull/166725 is landed
+
+
+class InvalidNodeBase:
+    def __repr__(self):
+        return "Invalid Node"
+
+
+InvalidNode = InvalidNodeBase()
+
+
+def _extract_graph_with_inputs_outputs(
+    joint_graph: fx.Graph,
+    inputs: list[fx.Node],
+    outputs: list[fx.Node],
+    outputs_descs: list[AOTOutput],
+    subgraph: Optional[str] = None,
+    ignore_must_be_in_fw_bw: bool = False,
+) -> fx.Graph:
+    """
+    Given a graph, extracts out a subgraph that takes the specified nodes as
+    inputs and returns the specified outputs.
+
+    This includes specifying non-placeholder nodes as inputs.
+
+    The general strategy is to initialize all inputs with proxies as we
+    encounter them, and trace through the graph, only keeping values which take
+    in valid proxies. Then, all dead code is eliminated.
+    """
+    new_graph = fx.Graph()
+    env = {}
+
+    # Add new placeholder nodes in the order specified by the inputs
+    for node in inputs:
+        new_node = new_graph.placeholder(node.name)
+        # Can't use node_copy here as we may be turning previous call_function into placeholders
+        new_node.meta = node.meta
+        # pyrefly: ignore [unsupported-operation]
+        env[node] = new_node
+
+    for node in joint_graph.nodes:
+        if node in env:
+            # Node must be one of our inputs. (Any member of env which wasn't an
+            # input to start must have been created by this loop and won't be in
+            # joint_graph.nodes).
+            continue
+        elif node.op == "placeholder":
+            env[node] = InvalidNode  # type: ignore[assignment]
+        elif node.op == "call_function":
+            all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs)
+            all_args = [
+                isinstance(env[x], InvalidNodeBase)
+                for x in all_args
+                if isinstance(x, fx.Node)
+            ]
+            if any(all_args):
+                env[node] = InvalidNode  # type: ignore[assignment]
+                continue
+            # pyrefly: ignore [unsupported-operation, bad-argument-type]
+            env[node] = new_graph.node_copy(node, lambda x: env[x])
+        elif node.op == "get_attr":
+            # pyrefly: ignore [unsupported-operation, bad-argument-type]
+            env[node] = new_graph.node_copy(node, lambda x: env[x])
+        elif node.op == "output":
+            pass
+    output_values = []
+    for x in outputs:
+        if isinstance(x, fx.Node):
+            if x not in env:
+                raise RuntimeError(f"Node {x} couldn't be found in env")
+            assert not isinstance(
+                env[x], InvalidNodeBase
+            ), f"Node {x} was invalid, but is output"
+            output_values.append(env[x])
+        else:
+            output_values.append(x)
+    out = new_graph.output(tuple(output_values))
+    out.meta["desc"] = outputs_descs
+
+    new_graph.eliminate_dead_code()
+    new_graph.lint()
+    return new_graph
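A toy sketch (not part of the patch) of the forked helper's contract: split a traced graph at an arbitrary interior node and stitch the halves back together. All names below (f, gm, add) are illustrative only, and None descriptors stand in for real AOTOutput descriptors:

    import operator

    import torch
    import torch.fx as fx

    from autoparallel._passes.utils import _extract_graph_with_inputs_outputs


    def f(x):
        return (x + 1) * 2


    gm = fx.symbolic_trace(f)
    g = gm.graph
    placeholders = [n for n in g.nodes if n.op == "placeholder"]
    add = next(n for n in g.nodes if n.op == "call_function" and n.target is operator.add)
    out = next(n for n in g.nodes if n.op == "output").args[0]

    # Prefix subgraph: placeholders -> add; suffix subgraph: add -> original output.
    prefix_g = _extract_graph_with_inputs_outputs(g, placeholders, [add], [None])
    suffix_g = _extract_graph_with_inputs_outputs(g, [add], [out], [None])

    prefix_gm = fx.GraphModule(gm, prefix_g)
    suffix_gm = fx.GraphModule(gm, suffix_g)
    x = torch.randn(3)
    # Each extracted graph returns a tuple of its outputs.
    assert torch.equal(suffix_gm(*prefix_gm(x))[0], gm(x))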
diff --git a/autoparallel/pipeline/passes.py b/autoparallel/pipeline/passes.py
deleted file mode 100644
index f219199..0000000
--- a/autoparallel/pipeline/passes.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-#
-# This source code is licensed under the BSD license found in the
-# LICENSE file in the root directory of this source tree.
-
-import dataclasses
-from contextlib import contextmanager
-from functools import partial
-from typing import Any
-
-import torch
-import torch.fx.node
-import torch.utils._pytree as pytree
-from torch._functorch._aot_autograd.descriptors import AOTOutput
-from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
-from torch._inductor.fx_passes.bucketing import (
-    is_all_gather_into_tensor,
-    is_reduce_scatter_tensor,
-)
-
-
-@contextmanager
-def exclude_from_fx_side_effectful(exclude_vals: set[Any]):
-    original_val = torch.fx.node._side_effectful_functions.copy()
-    try:
-        torch.fx.node._side_effectful_functions -= exclude_vals
-        yield
-    finally:
-        torch.fx.node._side_effectful_functions.clear()
-        torch.fx.node._side_effectful_functions.update(original_val)
-
-
-exclude_wait_from_fx_side_effectful = partial(
-    exclude_from_fx_side_effectful,
-    {
-        torch.ops._c10d_functional.wait_tensor,
-        torch.ops._c10d_functional.wait_tensor.default,
-    },
-)
-
-
-@dataclasses.dataclass(frozen=True)
-class PrefetchOutput(AOTOutput):
-    pass
-
-
-@dataclasses.dataclass(frozen=True)
-class EpilogueInput(AOTOutput):
-    pass
-
-
-def split_fsdp_prefetch(
-    g: torch.fx.Graph, stop_at_all_gather: bool = True
-) -> tuple[torch.fx.Graph, torch.fx.Graph]:
-    g_ins = g.find_nodes(op="placeholder")
-    prefetch_g_outs_map = []
-
-    for g_in in g_ins:
-        n = g_in
-        has_ag = False
-        while True:
-            if len(n.users) != 1:
-                break
-            user = next(iter(n.users))
-            if len(user.all_input_nodes) > 1:
-                break
-            n = user
-            if stop_at_all_gather and is_all_gather_into_tensor(n):
-                has_ag = True
-                w_n = next(iter(n.users))
-                n = w_n
-                break
-        if stop_at_all_gather and not has_ag:
-            prefetch_g_outs_map.append(g_in)
-        else:
-            prefetch_g_outs_map.append(n)
-
-    prefetch_g_outs = prefetch_g_outs_map
-    prefetch_g_outs_descs: list[AOTOutput] = [
-        PrefetchOutput() for _ in range(len(prefetch_g_outs))
-    ]
-    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
-    g_outs_descs = pytree.arg_tree_leaves(
-        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
-    )
-    with exclude_wait_from_fx_side_effectful():
-        prefetch_g = _extract_graph_with_inputs_outputs(
-            g,
-            g_ins,
-            prefetch_g_outs,
-            prefetch_g_outs_descs,
-        )
-
-        main_g = _extract_graph_with_inputs_outputs(
-            g,
-            prefetch_g_outs,
-            g_outs,
-            g_outs_descs,
-        )
-    return prefetch_g, main_g
-
-
-def split_fsdp_reduce_scatters_epilogue(
-    g: torch.fx.Graph,
-) -> tuple[torch.fx.Graph, torch.fx.Graph]:
-    g_ins = g.find_nodes(op="placeholder")
-    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
-    g_outs_descs = pytree.arg_tree_leaves(
-        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
-    )
-
-    g_outs_map = []
-    for g_out in g_outs:
-        n = g_out
-        has_rs = False
-        while n is not None:
-            if len(n.all_input_nodes) != 1:
-                break
-            n_in = n.all_input_nodes[0]
-            if len(n_in.users) > 1:
-                break
-            prev_n = n
-            n = n_in
-            if is_reduce_scatter_tensor(prev_n):
-                has_rs = True
-                break
-        if has_rs:
-            g_outs_map.append(n)
-        else:
-            g_outs_map.append(g_out)
-
-    epi_g_ins = [n for n in g_outs_map if n is not None]
-    epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]
-    main_g = _extract_graph_with_inputs_outputs(
-        g,
-        g_ins,
-        epi_g_ins,
-        epi_g_ins_descs,
-    )
-    epi_g = _extract_graph_with_inputs_outputs(
-        g,
-        epi_g_ins,
-        g_outs,
-        g_outs_descs,
-    )
-
-    return main_g, epi_g
diff --git a/tests/test_graph_partition.py b/tests/test_graph_partition.py
index f811886..cb0a4d3 100644
--- a/tests/test_graph_partition.py
+++ b/tests/test_graph_partition.py
@@ -11,6 +11,10 @@
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
+from autoparallel._passes.split_fsdp_collectives import (
+    split_fsdp_prefetch,
+    split_fsdp_reduce_scatters_epilogue,
+)
 from autoparallel._testing.models.dsv3 import (
     DeepSeekV3Model,
     DeepSeekV3ModelArgs,
@@ -97,13 +101,21 @@ def input_fn():
 
 sharding_placement = autop.optimize_placement()
 pp_mod = autop.apply_placement_pp(sharding_placement)
-pp_mod.to_empty(device="cuda")
+fw_module, bw_module, graph_meta, shared_param_dict, shared_buffer_dict = pp_mod
+# pp_mod.to_empty(device="cuda")
 
 # run weight init on our sharded DTensor params
 # TODO: plumb init_std through
 # pp_mod.init_weights(
 #     init_std=0.02, buffer_device="cuda"
 # )  # maybe not correct value
-pp_mod.init_weights(buffer_device="cuda")
+# pp_mod.init_weights(buffer_device="cuda")
+
+fw_g = fw_module.graph
+bw_g = bw_module.graph
+
+fw_unshard_g, fw_main_g = split_fsdp_prefetch(fw_g)
+bw_main_g, bw_reduce_grad_g = split_fsdp_reduce_scatters_epilogue(bw_g)
+
 x = (
     torch.randint(
         0,