156 changes: 118 additions & 38 deletions autoparallel/_passes/graph_multiplex.py
@@ -4,13 +4,44 @@
# LICENSE file in the root directory of this source tree.

import copy
from itertools import dropwhile

import torch
import torch.fx as fx
from torch._inductor.fx_passes.bucketing import is_wait_tensor
from torch._logging import trace_structured


def _add_compute_annotations(gm: fx.GraphModule, tag: str):
"""Add compute_region annotations to nodes without custom metadata."""
for n in gm.graph.nodes:
if n.op == "placeholder":
continue
if n.meta.get("custom", None) is None:
n.meta["custom"] = {"compute_region": tag}
else:
assert "comm_region" in n.meta["custom"]
val = n.meta["custom"]["comm_region"]
n.meta["custom"]["comm_region"] = tag + " " + val


def _move_wait_tensors_to_compute_region(gm: fx.GraphModule, tag: str):
"""Move wait_tensor nodes from comm_region to compute_region of their users."""
for n in gm.graph.nodes:
if n.op == "placeholder":
continue
if "comm_region" in n.meta["custom"] and is_wait_tensor(n):
assert len(n.users) >= 1, "wait tensor must have at least one user"
user: fx.Node = next(iter(n.users))
if "compute_region" in user.meta["custom"]:
n.meta["custom"].pop("comm_region")
n.meta["custom"].update({"compute_region": tag + " " + "wait"})
if n.next is not user:
user.prepend(n)


def multiplex_fw_bw_graph(
fw_gm: fx.GraphModule, bw_gm: fx.GraphModule
fw_gm: fx.GraphModule, bw_gm: fx.GraphModule, overlap_with_annotations: bool = True
) -> fx.GraphModule:
"""
Multiplexes forward and backward graphs into a single unified graph module.
@@ -32,67 +63,116 @@ def multiplex_fw_bw_graph(
Note:
The function preserves node metadata during the merging process.
"""
# Mapping to track correspondence between backward graph nodes and new nodes
if overlap_with_annotations:
_add_compute_annotations(fw_gm, "forward")
_add_compute_annotations(bw_gm, "backward")
_move_wait_tensors_to_compute_region(fw_gm, "forward")
_move_wait_tensors_to_compute_region(bw_gm, "backward")

# Mapping to track correspondence between forward graph nodes and new nodes
old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {}

# Start with a deep copy of the forward graph as the base
multiplexed_gm = copy.deepcopy(fw_gm)
# Start with a deep copy of the backward graph as the base
multiplexed_gm = copy.deepcopy(bw_gm)

# Collect all placeholder nodes from the backward graph
# Collect all placeholder nodes from all the graphs
bw_placeholders = bw_gm.graph.find_nodes(op="placeholder")
fw_placeholders = fw_gm.graph.find_nodes(op="placeholder")
insert_point = multiplexed_gm.graph.find_nodes(op="placeholder")[-1]

# Insert backward placeholders at the beginning of the multiplexed graph
# Reversed order ensures correct execution sequence
with multiplexed_gm.graph.inserting_before():
for n in reversed(bw_placeholders):
# Insert forward placeholders after the backward placeholders of the multiplexed graph
for n in fw_placeholders:
with multiplexed_gm.graph.inserting_after(insert_point):
new_placeholder = multiplexed_gm.graph.placeholder(n.name)
new_placeholder.meta = n.meta
new_placeholder.meta = copy.copy(n.meta)
Member: probably want to deepcopy it?

Contributor Author: I get this warning: pytorch/torch/_tensor.py:167: UserWarning: Accessing the data pointer of FakeTensor is deprecated and will error in PyTorch 2.5. This is almost definitely a bug in your code and will cause undefined behavior with subsystems like torch.compile. Please wrap calls to tensor.data_ptr() in an opaque custom op; If all else fails, you can guard accesses to tensor.data_ptr() on isinstance(tensor, FakeTensor). (Triggered internally at pytorch/c10/core/StorageImpl.cpp:34.) or (type(self) is not Tensor and self.data_ptr() == 0)

new_placeholder.target = new_placeholder.name
old_node_to_new_node[n] = new_placeholder
insert_point = new_placeholder

# Find the last placeholder and the output node in the multiplexed graph
multiplxed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder")
assert len(multiplxed_gm_placeholders) == (
len(fw_placeholders) + len(bw_placeholders)
multiplexed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder")
assert len(multiplexed_gm_placeholders) == len(fw_placeholders) + len(
bw_placeholders
)
insert_point = multiplxed_gm_placeholders[-1]

# Copy all computation nodes from backward graph into multiplexed graph
fw_outputs = fw_gm.graph.find_nodes(op="output")
bw_outputs = bw_gm.graph.find_nodes(op="output")
assert len(bw_outputs) == 1 and len(fw_outputs) == 1
bw_graph_op_node = bw_outputs[0]
for n in bw_gm.graph.nodes:
if n.op == "placeholder":
continue
if n.op == "output":
continue
with multiplexed_gm.graph.inserting_after(insert_point):
fw_nodes_iter = iter(fw_gm.graph.nodes)
fw_nodes_iter = dropwhile(lambda n: n.op == "placeholder", fw_nodes_iter)
# Initialize the forward node to be the first non-placeholder node
fn = next(fw_nodes_iter)
if overlap_with_annotations:
Member: imo a simpler way to write this block is to first get a list of contiguous compute/comm/compute/comm/compute segments, then do your overlapping using those groups of nodes without needing much state tracking

Contributor Author: For that algorithm to work, it would strictly require pairs of compute and comm blocks, but we don't have a guarantee that they would even exist. More formally, in regular-expression terms the current algorithm implements: [bw_compute]* ([bw_comm]+ [fw_compute]* [fw_comm]+)* [remaining_bw]* [remaining_fw]*
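To make that pattern concrete, here is an illustrative sketch of the resulting order at the segment level. The interleave helper, its arguments, and the segment labels are all made up for illustration; the actual pass operates on fx nodes and their custom metadata:

# Sketch only: interleave pre-grouped segment labels the way the node-level
# loop below does. Labels are hypothetical, not real fx nodes.
def interleave(bw_segments, fw_segments):
    out, fw = [], list(fw_segments)
    for seg in bw_segments:
        # At each bw_comm -> bw_compute boundary, splice in forward segments up
        # to and including the next run of forward comm.
        if seg.startswith("bw_compute") and out and out[-1].startswith("bw_comm"):
            while fw and fw[0].startswith("fw_compute"):
                out.append(fw.pop(0))
            while fw and fw[0].startswith("fw_comm"):
                out.append(fw.pop(0))
        out.append(seg)
    return out + fw  # any remaining forward segments go last


print(interleave(
    ["bw_compute0", "bw_comm0", "bw_compute1", "bw_comm1", "bw_compute2"],
    ["fw_compute0", "fw_comm0", "fw_compute1", "fw_comm1"],
))
# ['bw_compute0', 'bw_comm0', 'fw_compute0', 'fw_comm0',
#  'bw_compute1', 'bw_comm1', 'fw_compute1', 'fw_comm1', 'bw_compute2']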

# Interleave forward and backward nodes to create overlap pattern:
# bw_compute (if any) -> bw_comm -> fw_compute (if any) -> fw_comm -> [repeat]
# This allows bw_comm to overlap with fw_compute, and fw_comm to overlap with bw_compute
bw_in_comm = False
for bn in multiplexed_gm.graph.nodes:
if bn.op == "placeholder" or bn.op == "output":
continue
# Track when we enter a backward comm region
if "comm_region" in bn.meta["custom"] and not bw_in_comm:
bw_in_comm = True
# When we transition from bw_comm to bw_compute, insert forward nodes
elif "compute_region" in bn.meta["custom"] and bw_in_comm:
bw_in_comm = False
fw_in_comm = False
insert_point = bn
# Insert forward nodes before this bw_compute node
# Note: We cannot reorder nodes within a graph, only their relative order between graphs
while fn.op != "output":
if "comm_region" in fn.meta["custom"] and not fw_in_comm:
fw_in_comm = True
elif "compute_region" in fn.meta["custom"] and fw_in_comm:
# Stop when we reach the next fw_compute after fw_comm
# This ensures we insert one fw_compute + fw_comm cycle per bw_comm -> bw_compute transition
# If fw starts with comm (no compute before it), we still insert it to overlap with future bw_compute
fw_in_comm = False
break
with multiplexed_gm.graph.inserting_before(insert_point):
# Copy node and remap its arguments using the node mapping
new_node = multiplexed_gm.graph.node_copy(
fn, lambda x: old_node_to_new_node[x]
)
new_node.meta = copy.copy(fn.meta)
old_node_to_new_node[fn] = new_node
fn = next(fw_nodes_iter)
# Insert any remaining forward nodes at the end
# If overlap_with_annotations is False, this concatenates all fw nodes after bw nodes
insert_point = multiplexed_gm.graph.find_nodes(op="output")[-1]
while fn.op != "output":
with multiplexed_gm.graph.inserting_before(insert_point):
# Copy node and remap its arguments using the node mapping
new_node = multiplexed_gm.graph.node_copy(
n, lambda x: old_node_to_new_node[x]
fn, lambda x: old_node_to_new_node[x]
)
new_node.meta = n.meta
old_node_to_new_node[n] = new_node
insert_point = new_node
new_node.meta = copy.copy(fn.meta)
old_node_to_new_node[fn] = new_node
fn = next(fw_nodes_iter)

# Collect output arguments from backward graph, remapping to new nodes
bw_op_node_args = [
# Collect output arguments from forward graph, remapping to new nodes
fw_outputs = fw_gm.graph.find_nodes(op="output")
multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output")
assert len(multiplexed_graph_outputs) == 1 and len(fw_outputs) == 1
fw_graph_op_node = fw_outputs[0]
fw_op_node_args = [
old_node_to_new_node[n] if n is not None else None
for n in bw_graph_op_node.args[0]
for n in fw_graph_op_node.args[0]
]

# Collect output arguments from multiplexed graph (will contain only fwd_outs)
multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output")
assert len(multiplexed_graph_outputs) == 1
# Collect output arguments from multiplexed graph (will contain only bwd_outs)
multiplexed_graph_op_node = multiplexed_graph_outputs[0]
fw_op_node_args = list(multiplexed_graph_op_node.args[0])
bw_op_node_args = list(multiplexed_graph_op_node.args[0])

# Update output node args to prepend backward outputs before forward outputs
multiplexed_graph_op_node.args = (tuple(bw_op_node_args + fw_op_node_args),)

multiplexed_gm.graph.eliminate_dead_code()
multiplexed_gm.graph.lint()
multiplexed_gm.recompile()
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "autoparallel_multiplexed_graph",
"encoding": "string",
},
payload_fn=lambda: multiplexed_gm.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
return multiplexed_gm
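For context, a minimal usage sketch of the pass. The toy modules below are hypothetical stand-ins for the real partitioned forward/backward graph modules; since they contain no collectives, the pass simply annotates every node as forward/backward compute and concatenates the two graphs rather than interleaving them:

import torch
import torch.fx as fx

from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph


class ToyForward(torch.nn.Module):
    def forward(self, x):
        # Stand-in for forward compute; returns a tuple like a partitioned graph.
        return (torch.relu(x) * 2,)


class ToyBackward(torch.nn.Module):
    def forward(self, grad_out):
        # Stand-in for backward compute.
        return (grad_out + 1,)


fw_gm = fx.symbolic_trace(ToyForward())
bw_gm = fx.symbolic_trace(ToyBackward())

# Backward placeholders come first, then forward ones; the output returns
# backward outputs followed by forward outputs.
multiplexed = multiplex_fw_bw_graph(fw_gm, bw_gm)
print(multiplexed.print_readable(print_output=False))

Passing overlap_with_annotations=False skips the annotation-driven interleaving and simply appends the forward nodes after the backward nodes.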
99 changes: 51 additions & 48 deletions autoparallel/_testing/models/dsv3.py
@@ -8,6 +8,7 @@
from typing import Callable, ClassVar, Literal, Optional, Tuple, Union

import torch
import torch.fx.traceback as fx_traceback
import torch.nn.functional as F
import triton
import triton.language as tl
@@ -631,61 +632,63 @@ def forward(


def _token_dispatch(routed_input, num_tokens_per_expert, axis_name):
# annotate module input placements/sharding with input_layouts
# ep_size = device_mesh.shape[0]
ep_size = axis_size(axis_name)

# generate the input splits and output splits for all-to-all
with torch.no_grad():
num_tokens_per_expert_group = all_to_all(
num_tokens_per_expert,
None,
None,
with fx_traceback.annotate({"comm_region": "token_dispatch"}):
# annotate module input placements/sharding with input_layouts
# ep_size = device_mesh.shape[0]
ep_size = axis_size(axis_name)

# generate the input splits and output splits for all-to-all
with torch.no_grad():
num_tokens_per_expert_group = all_to_all(
num_tokens_per_expert,
None,
None,
axis_name,
)
input_splits = (
num_tokens_per_expert.view(ep_size, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
)
# NOTE: this would incur a device-to-host sync
output_splits = (
num_tokens_per_expert_group.view(ep_size, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=False)
)
input_splits = input_splits.tolist()
output_splits = output_splits.tolist()

# perform all-to-all
routed_input = all_to_all(
routed_input,
output_splits,
input_splits,
axis_name,
)
input_splits = (
num_tokens_per_expert.view(ep_size, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
)
# NOTE: this would incur a device-to-host sync
output_splits = (
num_tokens_per_expert_group.view(ep_size, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=False)
)
input_splits = input_splits.tolist()
output_splits = output_splits.tolist()

# perform all-to-all
routed_input = all_to_all(
routed_input,
output_splits,
input_splits,
axis_name,
)

# NOTE: After this all-to-all, the routed input is put on proper EP rank.
# However, the num_tokens_per_expert_group is not of the final target format
# [#tokens for local expert 0, #tokens for local expert 1, ...]
# Rather, it is of the format
# [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
# #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
# We need to perform another shuffle to get the correct format -- this is done via the function
# generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
# each expert gets locally is a multiple of ALIGN_SIZE_M.
# NOTE: After this all-to-all, the routed input is put on proper EP rank.
# However, the num_tokens_per_expert_group is not of the final target format
# [#tokens for local expert 0, #tokens for local expert 1, ...]
# Rather, it is of the format
# [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
# #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
# We need to perform another shuffle to get the correct format -- this is done via the function
# generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
# each expert gets locally is a multiple of ALIGN_SIZE_M.

return routed_input, num_tokens_per_expert_group, input_splits, output_splits
return routed_input, num_tokens_per_expert_group, input_splits, output_splits
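As a hedged numeric illustration of the layout described in the note above (hypothetical counts; 2 EP ranks, 2 local experts):

# num_tokens_per_expert_group arrives grouped by source EP rank:
#   [expert 0 from rank 0, expert 1 from rank 0, expert 0 from rank 1, expert 1 from rank 1]
num_tokens_per_expert_group = [3, 5, 2, 4]

# Per-local-expert totals that the later shuffle (generate_permute_indices in
# moe.py) works toward, before padding to a multiple of ALIGN_SIZE_M:
#   expert 0: 3 + 2 = 5 tokens, expert 1: 5 + 4 = 9 tokens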


def _token_combine(routed_output, input_splits, output_splits, axis_name):
routed_output = all_to_all(
routed_output,
input_splits,
output_splits,
axis_name,
)
return routed_output
with fx_traceback.annotate({"comm_region": "token_combine"}):
routed_output = all_to_all(
routed_output,
input_splits,
output_splits,
axis_name,
)
return routed_output
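For reference, the fx_traceback.annotate regions above are what put the comm_region entries into node.meta["custom"], which the multiplexing pass in graph_multiplex.py then keys on. A minimal sketch of that mechanism, assuming a tracer that preserves node metadata (make_fx under preserve_node_meta here; the toy arithmetic stands in for the real all_to_all calls):

import torch
import torch.fx.traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import make_fx


def fn(x):
    with fx_traceback.annotate({"comm_region": "toy_comm"}):
        y = x + 1  # stand-in for a collective such as all_to_all
    return (y * 2,)


# Trace with node-metadata preservation so the annotation lands in meta["custom"].
with fx_traceback.preserve_node_meta():
    gm = make_fx(fn)(torch.randn(4))

for node in gm.graph.nodes:
    print(node.name, node.meta.get("custom"))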


# @torch.library.custom_op("autoparallel::local_mapped_region", mutates_args=())
2 changes: 1 addition & 1 deletion autoparallel/dtensor_util/utils.py
@@ -17,10 +17,10 @@
OpStrategy,
StrategyType,
)
from torch.distributed.tensor._ops.registration import register_op_strategy
from torch.distributed.tensor._ops.utils import (
generate_redistribute_costs,
is_tensor_shardable,
register_op_strategy,
)
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
