Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 1 addition & 13 deletions autoparallel/_passes/split_fsdp_collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,12 @@
import torch.fx.node
import torch.utils._pytree as pytree
from torch._functorch._aot_autograd.descriptors import AOTOutput
from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
from torch._inductor.fx_passes.bucketing import (
is_all_gather_into_tensor,
is_reduce_scatter_tensor,
)

# Switch to once https://github.com/pytorch/pytorch/pull/166725 is landed
# from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
from autoparallel._passes.utils import _extract_graph_with_inputs_outputs


@contextmanager
def exclude_from_fx_side_effectful(exclude_vals: set[Any]):
Expand All @@ -42,13 +39,6 @@ def exclude_from_fx_side_effectful(exclude_vals: set[Any]):
)


def _clear_partitioner_tag(g: torch.fx.Graph):
# TODO: Remove this once torch._functorch.partitioners supports ignore_must_be_in_fw_bw
# https://github.com/pytorch/pytorch/pull/166725
for n in g.nodes:
n.meta.pop("partitioner_tag", None)


@dataclasses.dataclass(frozen=True)
class PrefetchOutput(AOTOutput):
pass
Expand Down Expand Up @@ -90,7 +80,6 @@ def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Gra
next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
)
with exclude_wait_from_fx_side_effectful():
_clear_partitioner_tag(g)
prefetch_g = _extract_graph_with_inputs_outputs(
g,
g_ins,
Expand Down Expand Up @@ -143,7 +132,6 @@ def split_fsdp_reduce_scatters_epilogue(
epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]

with exclude_wait_from_fx_side_effectful():
_clear_partitioner_tag(g)
main_g = _extract_graph_with_inputs_outputs(
g,
g_ins,
Expand Down
93 changes: 0 additions & 93 deletions autoparallel/_passes/utils.py

This file was deleted.

Loading