diff --git a/autoparallel/_passes/split_fsdp_collectives.py b/autoparallel/_passes/split_fsdp_collectives.py index f2983c0..a411fb3 100644 --- a/autoparallel/_passes/split_fsdp_collectives.py +++ b/autoparallel/_passes/split_fsdp_collectives.py @@ -12,15 +12,12 @@ import torch.fx.node import torch.utils._pytree as pytree from torch._functorch._aot_autograd.descriptors import AOTOutput +from torch._functorch.partitioners import _extract_graph_with_inputs_outputs from torch._inductor.fx_passes.bucketing import ( is_all_gather_into_tensor, is_reduce_scatter_tensor, ) -# Switch to once https://github.com/pytorch/pytorch/pull/166725 is landed -# from torch._functorch.partitioners import _extract_graph_with_inputs_outputs -from autoparallel._passes.utils import _extract_graph_with_inputs_outputs - @contextmanager def exclude_from_fx_side_effectful(exclude_vals: set[Any]): @@ -42,13 +39,6 @@ def exclude_from_fx_side_effectful(exclude_vals: set[Any]): ) -def _clear_partitioner_tag(g: torch.fx.Graph): - # TODO: Remove this once torch._functorch.partitioners supports ignore_must_be_in_fw_bw - # https://github.com/pytorch/pytorch/pull/166725 - for n in g.nodes: - n.meta.pop("partitioner_tag", None) - - @dataclasses.dataclass(frozen=True) class PrefetchOutput(AOTOutput): pass @@ -90,7 +80,6 @@ def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Gra next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs)) ) with exclude_wait_from_fx_side_effectful(): - _clear_partitioner_tag(g) prefetch_g = _extract_graph_with_inputs_outputs( g, g_ins, @@ -143,7 +132,6 @@ def split_fsdp_reduce_scatters_epilogue( epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))] with exclude_wait_from_fx_side_effectful(): - _clear_partitioner_tag(g) main_g = _extract_graph_with_inputs_outputs( g, g_ins, diff --git a/autoparallel/_passes/utils.py b/autoparallel/_passes/utils.py deleted file mode 100644 index a5ea1d9..0000000 --- a/autoparallel/_passes/utils.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Optional - -import torch.fx as fx -import torch.utils._pytree as pytree -from torch._functorch._aot_autograd.descriptors import AOTOutput - -# TODO(ivankobzarev): Remove parititoner function fork once https://github.com/pytorch/pytorch/pull/166725 is landed - - -class InvalidNodeBase: - def __repr__(self): - return "Invalid Node" - - -InvalidNode = InvalidNodeBase() - - -def _extract_graph_with_inputs_outputs( - joint_graph: fx.Graph, - inputs: list[fx.Node], - outputs: list[fx.Node], - outputs_descs: list[AOTOutput], - subgraph: Optional[str] = None, - ignore_must_be_in_fw_bw: bool = False, -) -> fx.Graph: - """ - Given a graph, extracts out a subgraph that takes the specified nodes as - inputs and returns the specified outputs. - - This includes specifying non-placeholder nodes as inputs. - - The general strategy is to initialize all inputs with proxies as we - encounter them, and trace through the graph, only keeping values which take - in valid proxies. Then, all dead code is eliminated. - """ - new_graph = fx.Graph() - env = {} - - # Add new placeholder nodes in the order specified by the inputs - for node in inputs: - new_node = new_graph.placeholder(node.name) - # Can't use node_copy here as we may be turning previous call_function into placeholders - new_node.meta = node.meta - # pyrefly: ignore [unsupported-operation] - env[node] = new_node - - for node in joint_graph.nodes: - if node in env: - # Node must be one of our inputs. (Any member of env which wasn't an - # input to start must have been created by this loop and won't be in - # joint_graph.nodes). - continue - elif node.op == "placeholder": - env[node] = InvalidNode # type: ignore[assignment] - elif node.op == "call_function": - all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs) - all_args = [ - isinstance(env[x], InvalidNodeBase) - for x in all_args - if isinstance(x, fx.Node) - ] - if any(all_args): - env[node] = InvalidNode # type: ignore[assignment] - continue - # pyrefly: ignore [unsupported-operation, bad-argument-type] - env[node] = new_graph.node_copy(node, lambda x: env[x]) - elif node.op == "get_attr": - # pyrefly: ignore [unsupported-operation, bad-argument-type] - env[node] = new_graph.node_copy(node, lambda x: env[x]) - elif node.op == "output": - pass - output_values = [] - for x in outputs: - if isinstance(x, fx.Node): - if x not in env: - raise RuntimeError(f"Node {x} couldn't be found in env") - assert not isinstance( - env[x], InvalidNodeBase - ), f"Node {x} was invalid, but is output" - output_values.append(env[x]) - else: - output_values.append(x) - out = new_graph.output(tuple(output_values)) - out.meta["desc"] = outputs_descs - - new_graph.eliminate_dead_code() - new_graph.lint() - return new_graph