diff --git a/autoparallel/_passes/graph_multiplex.py b/autoparallel/_passes/graph_multiplex.py
index 387ad9f5..bd61ad6f 100644
--- a/autoparallel/_passes/graph_multiplex.py
+++ b/autoparallel/_passes/graph_multiplex.py
@@ -4,13 +4,44 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+from itertools import dropwhile
 
 import torch
 import torch.fx as fx
+from torch._inductor.fx_passes.bucketing import is_wait_tensor
+from torch._logging import trace_structured
+
+
+def _add_compute_annotations(gm: fx.GraphModule, tag: str):
+    """Tag untagged nodes with a compute_region and prefix existing comm_region tags with ``tag``."""
+    for n in gm.graph.nodes:
+        if n.op == "placeholder":
+            continue
+        if n.meta.get("custom", None) is None:
+            n.meta["custom"] = {"compute_region": tag}
+        else:
+            assert "comm_region" in n.meta["custom"]
+            val = n.meta["custom"]["comm_region"]
+            n.meta["custom"]["comm_region"] = tag + " " + val
+
+
+def _move_wait_tensors_to_compute_region(gm: fx.GraphModule, tag: str):
+    """Move wait_tensor nodes from comm_region to the compute_region of their users."""
+    for n in gm.graph.nodes:
+        if n.op == "placeholder":
+            continue
+        if "comm_region" in n.meta["custom"] and is_wait_tensor(n):
+            assert len(n.users) >= 1, "wait tensor must have at least one user"
+            user: fx.Node = next(iter(n.users))
+            if "compute_region" in user.meta["custom"]:
+                n.meta["custom"].pop("comm_region")
+                n.meta["custom"].update({"compute_region": tag + " " + "wait"})
+                if n.next is not user:
+                    user.prepend(n)
 
 
 def multiplex_fw_bw_graph(
-    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule
+    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule, overlap_with_annotations: bool = True
 ) -> fx.GraphModule:
     """
     Multiplexes forward and backward graphs into a single unified graph module.
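A minimal sketch of the ``node.meta["custom"]`` layout these helpers assume and produce; the toy module, node names, and the hand-written ``comm_region`` tag are illustrative stand-ins for what the tracing-time annotations are expected to record::

    import torch
    import torch.fx as fx

    from autoparallel._passes.graph_multiplex import _add_compute_annotations

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return x.sin().cos()

    gm = fx.symbolic_trace(Tiny())
    # Pretend the tracer tagged the first compute node as communication.
    sin_node = next(n for n in gm.graph.nodes if n.op == "call_method")
    sin_node.meta["custom"] = {"comm_region": "token_dispatch"}

    _add_compute_annotations(gm, "forward")
    for n in gm.graph.nodes:
        if n.op != "placeholder":
            print(n.name, n.meta["custom"])
    # Untagged nodes now carry {"compute_region": "forward"}, while the pre-tagged
    # node carries {"comm_region": "forward token_dispatch"}.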
""" - # Mapping to track correspondence between backward graph nodes and new nodes + if overlap_with_annotations: + _add_compute_annotations(fw_gm, "forward") + _add_compute_annotations(bw_gm, "backward") + _move_wait_tensors_to_compute_region(fw_gm, "forward") + _move_wait_tensors_to_compute_region(bw_gm, "backward") + + # Mapping to track correspondence between forward graph nodes and new nodes old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {} - # Start with a deep copy of the forward graph as the base - multiplexed_gm = copy.deepcopy(fw_gm) + # Start with a deep copy of the backward graph as the base + multiplexed_gm = copy.deepcopy(bw_gm) - # Collect all placeholder nodes from the backward graph + # Collect all placeholder nodes from all the graphs bw_placeholders = bw_gm.graph.find_nodes(op="placeholder") fw_placeholders = fw_gm.graph.find_nodes(op="placeholder") + insert_point = multiplexed_gm.graph.find_nodes(op="placeholder")[-1] - # Insert backward placeholders at the beginning of the multiplexed graph - # Reversed order ensures correct execution sequence - with multiplexed_gm.graph.inserting_before(): - for n in reversed(bw_placeholders): + # Insert forward placeholders after the backward placeholders of the multiplexed graph + for n in fw_placeholders: + with multiplexed_gm.graph.inserting_after(insert_point): new_placeholder = multiplexed_gm.graph.placeholder(n.name) - new_placeholder.meta = n.meta + new_placeholder.meta = copy.copy(n.meta) new_placeholder.target = new_placeholder.name old_node_to_new_node[n] = new_placeholder + insert_point = new_placeholder - # Find the last placeholder and the output node in the multiplexed graph - multiplxed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder") - assert len(multiplxed_gm_placeholders) == ( - len(fw_placeholders) + len(bw_placeholders) + multiplexed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder") + assert len(multiplexed_gm_placeholders) == len(fw_placeholders) + len( + bw_placeholders ) - insert_point = multiplxed_gm_placeholders[-1] - - # Copy all computation nodes from backward graph into multiplexed graph - fw_outputs = fw_gm.graph.find_nodes(op="output") - bw_outputs = bw_gm.graph.find_nodes(op="output") - assert len(bw_outputs) == 1 and len(fw_outputs) == 1 - bw_graph_op_node = bw_outputs[0] - for n in bw_gm.graph.nodes: - if n.op == "placeholder": - continue - if n.op == "output": - continue - with multiplexed_gm.graph.inserting_after(insert_point): + fw_nodes_iter = iter(fw_gm.graph.nodes) + fw_nodes_iter = dropwhile(lambda n: n.op == "placeholder", fw_nodes_iter) + # Initialize the forward node to be the first non-placeholder node + fn = next(fw_nodes_iter) + if overlap_with_annotations: + # Interleave forward and backward nodes to create overlap pattern: + # bw_compute (if any) -> bw_comm -> fw_compute (if any) -> fw_comm -> [repeat] + # This allows bw_comm to overlap with fw_compute, and fw_comm to overlap with bw_compute + bw_in_comm = False + for bn in multiplexed_gm.graph.nodes: + if bn.op == "placeholder" or bn.op == "output": + continue + # Track when we enter a backward comm region + if "comm_region" in bn.meta["custom"] and not bw_in_comm: + bw_in_comm = True + # When we transition from bw_comm to bw_compute, insert forward nodes + elif "compute_region" in bn.meta["custom"] and bw_in_comm: + bw_in_comm = False + fw_in_comm = False + insert_point = bn + # Insert forward nodes before this bw_compute node + # Note: We cannot reorder nodes within a graph, only their 
relative order between graphs + while fn.op != "output": + if "comm_region" in fn.meta["custom"] and not fw_in_comm: + fw_in_comm = True + elif "compute_region" in fn.meta["custom"] and fw_in_comm: + # Stop when we reach the next fw_compute after fw_comm + # This ensures we insert one fw_compute + fw_comm cycle per bw_comm -> bw_compute transition + # If fw starts with comm (no compute before it), we still insert it to overlap with future bw_compute + fw_in_comm = False + break + with multiplexed_gm.graph.inserting_before(insert_point): + # Copy node and remap its arguments using the node mapping + new_node = multiplexed_gm.graph.node_copy( + fn, lambda x: old_node_to_new_node[x] + ) + new_node.meta = copy.copy(fn.meta) + old_node_to_new_node[fn] = new_node + fn = next(fw_nodes_iter) + # Insert any remaining forward nodes at the end + # If overlap_with_annotations is False, this concatenates all fw nodes after bw nodes + insert_point = multiplexed_gm.graph.find_nodes(op="output")[-1] + while fn.op != "output": + with multiplexed_gm.graph.inserting_before(insert_point): # Copy node and remap its arguments using the node mapping new_node = multiplexed_gm.graph.node_copy( - n, lambda x: old_node_to_new_node[x] + fn, lambda x: old_node_to_new_node[x] ) - new_node.meta = n.meta - old_node_to_new_node[n] = new_node - insert_point = new_node + new_node.meta = copy.copy(fn.meta) + old_node_to_new_node[fn] = new_node + fn = next(fw_nodes_iter) - # Collect output arguments from backward graph, remapping to new nodes - bw_op_node_args = [ + # Collect output arguments from forward graph, remapping to new nodes + fw_outputs = fw_gm.graph.find_nodes(op="output") + multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output") + assert len(multiplexed_graph_outputs) == 1 and len(fw_outputs) == 1 + fw_graph_op_node = fw_outputs[0] + fw_op_node_args = [ old_node_to_new_node[n] if n is not None else None - for n in bw_graph_op_node.args[0] + for n in fw_graph_op_node.args[0] ] - # Collect output arguments from multiplexed graph (will contain only fwd_outs) - multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output") - assert len(multiplexed_graph_outputs) == 1 + # Collect output arguments from multiplexed graph (will contain only bwd_outs) multiplexed_graph_op_node = multiplexed_graph_outputs[0] - fw_op_node_args = list(multiplexed_graph_op_node.args[0]) + bw_op_node_args = list(multiplexed_graph_op_node.args[0]) # Update output node args to prepend backward outputs before forward outputs multiplexed_graph_op_node.args = (tuple(bw_op_node_args + fw_op_node_args),) @@ -95,4 +165,14 @@ def multiplex_fw_bw_graph( multiplexed_gm.graph.eliminate_dead_code() multiplexed_gm.graph.lint() multiplexed_gm.recompile() + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "autoparallel_multiplexed_graph", + "encoding": "string", + }, + payload_fn=lambda: multiplexed_gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) return multiplexed_gm diff --git a/autoparallel/_testing/models/dsv3.py b/autoparallel/_testing/models/dsv3.py index a88bc2b1..efb43525 100644 --- a/autoparallel/_testing/models/dsv3.py +++ b/autoparallel/_testing/models/dsv3.py @@ -8,6 +8,7 @@ from typing import Callable, ClassVar, Literal, Optional, Tuple, Union import torch +import torch.fx.traceback as fx_traceback import torch.nn.functional as F import triton import triton.language as tl @@ -631,61 +632,63 @@ def forward( def _token_dispatch(routed_input, 
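A usage sketch of the pass and the ordering contract it produces; the toy modules are illustrative, and ``overlap_with_annotations=False`` exercises the plain concatenation path that needs no region annotations::

    import torch
    import torch.fx as fx

    from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph

    class Fw(torch.nn.Module):
        def forward(self, x):
            return (x.sin(),)

    class Bw(torch.nn.Module):
        def forward(self, g):
            return (g.cos(),)

    fw_gm = fx.symbolic_trace(Fw())
    bw_gm = fx.symbolic_trace(Bw())
    fused = multiplex_fw_bw_graph(fw_gm, bw_gm, overlap_with_annotations=False)

    # Backward placeholders come first, then forward placeholders ...
    print([n.name for n in fused.graph.find_nodes(op="placeholder")])  # ['g', 'x']
    # ... and the single output node lists backward outputs before forward outputs.
    print(fused.graph.find_nodes(op="output")[0].args[0])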
diff --git a/autoparallel/_testing/models/dsv3.py b/autoparallel/_testing/models/dsv3.py
index a88bc2b1..efb43525 100644
--- a/autoparallel/_testing/models/dsv3.py
+++ b/autoparallel/_testing/models/dsv3.py
@@ -8,6 +8,7 @@
 from typing import Callable, ClassVar, Literal, Optional, Tuple, Union
 
 import torch
+import torch.fx.traceback as fx_traceback
 import torch.nn.functional as F
 import triton
 import triton.language as tl
@@ -631,61 +632,63 @@ def forward(
 
 
     def _token_dispatch(routed_input, num_tokens_per_expert, axis_name):
-        # annotate module input placements/sharding with input_layouts
-        # ep_size = device_mesh.shape[0]
-        ep_size = axis_size(axis_name)
-
-        # generate the input splits and output splits for all-to-all
-        with torch.no_grad():
-            num_tokens_per_expert_group = all_to_all(
-                num_tokens_per_expert,
-                None,
-                None,
+        with fx_traceback.annotate({"comm_region": "token_dispatch"}):
+            # annotate module input placements/sharding with input_layouts
+            # ep_size = device_mesh.shape[0]
+            ep_size = axis_size(axis_name)
+
+            # generate the input splits and output splits for all-to-all
+            with torch.no_grad():
+                num_tokens_per_expert_group = all_to_all(
+                    num_tokens_per_expert,
+                    None,
+                    None,
+                    axis_name,
+                )
+                input_splits = (
+                    num_tokens_per_expert.view(ep_size, -1)
+                    .sum(dim=1)
+                    .to(torch.device("cpu"), non_blocking=True)
+                )
+                # NOTE: this would incur a device-to-host sync
+                output_splits = (
+                    num_tokens_per_expert_group.view(ep_size, -1)
+                    .sum(dim=1)
+                    .to(torch.device("cpu"), non_blocking=False)
+                )
+                input_splits = input_splits.tolist()
+                output_splits = output_splits.tolist()
+
+            # perform all-to-all
+            routed_input = all_to_all(
+                routed_input,
+                output_splits,
+                input_splits,
                 axis_name,
             )
-            input_splits = (
-                num_tokens_per_expert.view(ep_size, -1)
-                .sum(dim=1)
-                .to(torch.device("cpu"), non_blocking=True)
-            )
-            # NOTE: this would incur a device-to-host sync
-            output_splits = (
-                num_tokens_per_expert_group.view(ep_size, -1)
-                .sum(dim=1)
-                .to(torch.device("cpu"), non_blocking=False)
-            )
-            input_splits = input_splits.tolist()
-            output_splits = output_splits.tolist()
-
-        # perform all-to-all
-        routed_input = all_to_all(
-            routed_input,
-            output_splits,
-            input_splits,
-            axis_name,
-        )
-        # NOTE: After this all-to-all, the routed input is put on proper EP rank.
-        # However, the num_tokens_per_expert_group is not of the final target format
-        # [#tokens for local expert 0, #tokens for local expert 1, ...]
-        # Rather, it is of the format
-        # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
-        # #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
-        # We need to perform another shuffle to get the correct format -- this is done via the function
-        # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
-        # each expert gets locally is a multiple of ALIGN_SIZE_M.
+            # NOTE: After this all-to-all, the routed input is put on proper EP rank.
+            # However, the num_tokens_per_expert_group is not of the final target format
+            # [#tokens for local expert 0, #tokens for local expert 1, ...]
+            # Rather, it is of the format
+            # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
+            # #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
+            # We need to perform another shuffle to get the correct format -- this is done via the function
+            # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
+            # each expert gets locally is a multiple of ALIGN_SIZE_M.
 
-        return routed_input, num_tokens_per_expert_group, input_splits, output_splits
+            return routed_input, num_tokens_per_expert_group, input_splits, output_splits
 
 
     def _token_combine(routed_output, input_splits, output_splits, axis_name):
-        routed_output = all_to_all(
-            routed_output,
-            input_splits,
-            output_splits,
-            axis_name,
-        )
-        return routed_output
+        with fx_traceback.annotate({"comm_region": "token_combine"}):
+            routed_output = all_to_all(
+                routed_output,
+                input_splits,
+                output_splits,
+                axis_name,
+            )
+            return routed_output
 
 
 # @torch.library.custom_op("autoparallel::local_mapped_region", mutates_args=())
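A worked toy example of the split arithmetic inside ``_token_dispatch``; the sizes are made up, and only the ``view(ep_size, -1).sum(dim=1)`` pattern mirrors the code above::

    import torch

    ep_size = 2  # two ranks in the expert-parallel all-to-all group
    # Tokens this rank wants to send to each of 4 global experts (2 hosted per rank).
    num_tokens_per_expert = torch.tensor([3, 1, 2, 5])
    # Group the per-expert counts by destination rank: rank 0 hosts experts 0-1,
    # rank 1 hosts experts 2-3.
    input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1)
    print(input_splits.tolist())  # [4, 7]: send 4 tokens to rank 0 and 7 to rank 1
    # output_splits is derived the same way from num_tokens_per_expert_group, i.e.
    # from the counts received in the first (metadata) all-to-all, and the two lists
    # then parameterize the token all-to-all.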
diff --git a/autoparallel/dtensor_util/utils.py b/autoparallel/dtensor_util/utils.py
index 3341e2e9..09f3c11c 100644
--- a/autoparallel/dtensor_util/utils.py
+++ b/autoparallel/dtensor_util/utils.py
@@ -17,10 +17,10 @@
     OpStrategy,
     StrategyType,
 )
+from torch.distributed.tensor._ops.registration import register_op_strategy
 from torch.distributed.tensor._ops.utils import (
     generate_redistribute_costs,
     is_tensor_shardable,
-    register_op_strategy,
 )
 from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
 
diff --git a/autoparallel/graph_pp_runner.py b/autoparallel/graph_pp_runner.py
index ed3a7820..506e9c9d 100644
--- a/autoparallel/graph_pp_runner.py
+++ b/autoparallel/graph_pp_runner.py
@@ -5,7 +5,7 @@
 
 import logging
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Union, cast
+from typing import Any, Callable, Optional, Protocol, Union, cast
 
 import torch
 import torch.fx as fx
@@ -48,29 +48,77 @@ class GraphMeta:
     num_input_grads: int
 
 
+class MultiplexFwBwGraphPass(Protocol):
+    """Protocol defining the contract for forward-backward graph multiplexing passes.
+
+    Implementations must accept two GraphModules (forward and backward) and return a fused
+    GraphModule that multiplexes their execution.
+
+    Contract Requirements:
+        1. Input placeholders ordering: The returned GraphModule's placeholders must be ordered
+           as ``bw_placeholders + fw_placeholders`` (backward placeholders concatenated with
+           forward placeholders, each maintaining their original order from the input graphs).
+
+        2. Output node args ordering: The returned GraphModule's output node args must contain
+           outputs ordered as ``bw_outputs + fw_outputs`` (backward outputs concatenated with
+           forward outputs, each maintaining their original order from the input graphs).
+
+    Example::
+
+        def my_multiplex_pass(
+            fw_graph: fx.GraphModule,
+            bw_graph: fx.GraphModule
+        ) -> fx.GraphModule:
+            # Implementation that satisfies the contract
+            ...
+            return multiplexed_graph
+    """
+
+    def __call__(
+        self,
+        fw_graph: fx.GraphModule,
+        bw_graph: fx.GraphModule,
+    ) -> fx.GraphModule:
+        """Multiplex forward and backward graphs into a single fused graph.
+
+        Args:
+            fw_graph (fx.GraphModule): Forward graph module.
+            bw_graph (fx.GraphModule): Backward graph module.
+
+        Returns:
+            fx.GraphModule: Fused graph module satisfying the contract requirements.
+        """
+        ...
+
+
 def get_multiplexed_graph_callables(
-    stage_graphs: dict[int, GraphCallables]
+    stage_graphs: dict[int, GraphCallables],
+    multiplex_fw_bw_graph_pass: MultiplexFwBwGraphPass,
 ) -> dict[tuple[int, int], fx.GraphModule]:
     """Generate multiplexed graph modules that fuse forward and backward passes from different stages.
 
     Creates fused modules for all stage pairs where fw_stage_idx != bw_stage_idx. This enables
-    pipeline schedules (e.g., ZeroBubble) to overlap computation and reduce bubbles.
+    pipeline schedules (e.g., DualPipe) to overlap communication with computation.
 
     Args:
-        stage_graphs: Mapping from stage index to GraphCallables containing forward/backward modules.
+        stage_graphs (dict[int, GraphCallables]): Mapping from stage index to GraphCallables
+            containing forward/backward modules.
+        multiplex_fw_bw_graph_pass (MultiplexFwBwGraphPass): A callable that takes two
+            GraphModules (forward and backward) and returns a fused GraphModule that multiplexes
+            their execution. Must satisfy the contract defined in
+            :class:`MultiplexFwBwGraphPass`.
 
     Returns:
-        Mapping from (fw_stage_idx, bw_stage_idx) to fused GraphModule that executes
-        forward from fw_stage_idx and backward from bw_stage_idx.
+        dict[tuple[int, int], fx.GraphModule]: Mapping from (fw_stage_idx, bw_stage_idx) to fused
+            GraphModule that executes forward from fw_stage_idx and backward from bw_stage_idx.
     """
-    from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
-
     multiplexed_graph_callables: dict[tuple[int, int], torch.fx.GraphModule] = {}
     for bw_stage_idx, bw_stage_graph_callables in stage_graphs.items():
         for fw_stage_idx, fw_stage_graph_callables in stage_graphs.items():
             if bw_stage_idx != fw_stage_idx:
-                fw_bw_module = multiplex_fw_bw_graph(
-                    fw_stage_graph_callables.fw, bw_stage_graph_callables.full_bw
+                fw_bw_module = multiplex_fw_bw_graph_pass(
+                    fw_stage_graph_callables.fw,
+                    bw_stage_graph_callables.full_bw,
                 )
                 multiplexed_graph_callables[(fw_stage_idx, bw_stage_idx)] = fw_bw_module
     return multiplexed_graph_callables
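A sketch of plugging a pass that satisfies ``MultiplexFwBwGraphPass`` into the runner; ``concat_only_pass`` and ``build_overlap_callables`` are illustrative names, and ``stage_graphs`` is assumed to be the per-stage mapping the PP runner builds::

    from functools import partial

    import torch.fx as fx

    from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
    from autoparallel.graph_pp_runner import get_multiplexed_graph_callables

    def concat_only_pass(
        fw_graph: fx.GraphModule, bw_graph: fx.GraphModule
    ) -> fx.GraphModule:
        # Satisfies the protocol: bw placeholders/outputs still come first, but the
        # forward nodes are appended after the backward nodes rather than interleaved.
        return multiplex_fw_bw_graph(fw_graph, bw_graph, overlap_with_annotations=False)

    def build_overlap_callables(stage_graphs):
        # stage_graphs: dict[int, GraphCallables] as built by the PP runner.
        # The annotation-driven interleaving is what example_ds3_pp.py wires up for DualPipeV.
        return get_multiplexed_graph_callables(
            stage_graphs, partial(multiplex_fw_bw_graph, overlap_with_annotations=True)
        )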
diff --git a/examples/example_ds3_pp.py b/examples/example_ds3_pp.py
index 1f682fd5..fa220644 100644
--- a/examples/example_ds3_pp.py
+++ b/examples/example_ds3_pp.py
@@ -3,10 +3,10 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-import functools
 import logging
 import os
 from contextlib import nullcontext
+from functools import partial
 from typing import Callable, Optional
 
 import torch
@@ -384,6 +384,7 @@ def last_stage_inp_with_loss_fn():
     for stages in pp_rank_to_stage_indices.values():
         assert len(stages) * pp_degree == len(virtual_pp_stages)
     stage_indices_current_pp_rank = pp_rank_to_stage_indices[pp_rank]
+    should_log_weights = should_log_fw_outs = False
     if rng_seed:
         # Compute the ranks to log from
         # 1. for fw_outs, log from coord [pp_rank_containing_last_stage, 0, 0]
@@ -591,9 +592,14 @@ def last_stage_inp_with_loss_fn():
         schedule.register_custom_function(BACKWARD_INPUT, stage_backward_input)
         schedule.register_custom_function(BACKWARD_WEIGHT, stage_backward_weight)
         if schedule_name == "DualPipeV":
-            multiplexed_graph_callables = get_multiplexed_graph_callables(stage_graphs)
+            from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
+
+            multiplexed_graph_callables = get_multiplexed_graph_callables(
+                stage_graphs,
+                partial(multiplex_fw_bw_graph, overlap_with_annotations=True),
+            )
             schedule.register_custom_function(
-                OVERLAP_F_B, functools.partial(overlap_fw_bw, multiplexed_graph_callables)
+                OVERLAP_F_B, partial(overlap_fw_bw, multiplexed_graph_callables)
             )
 
     # Step 7. Register the schedule with the graph runner
@@ -618,7 +624,7 @@ def last_stage_inp_with_loss_fn():
         )
         if pp_rank == 0:
             x = runtime_input_fn_first_stage()
-            if rng_seed:
+            if numerics_logger is not None:
                 numerics_logger.log_diff(
                     x.to(torch.float32), prefix="full batch input"
                 )
@@ -635,10 +641,10 @@ def last_stage_inp_with_loss_fn():
             },
             payload_fn=lambda: f"losses: {losses}",
         )
-
-        numerics_logger.log_pp_grads(
-            model, stage_mods, num_world_stages, should_log=should_log_weights
-        )
+        if numerics_logger is not None:
+            numerics_logger.log_pp_grads(
+                model, stage_mods, num_world_stages, should_log=should_log_weights
+            )
 
     print("All good!")
 
diff --git a/examples/example_pp_graph_passes.py b/examples/example_pp_graph_passes.py
index e6d8e98c..bed8c37f 100644
--- a/examples/example_pp_graph_passes.py
+++ b/examples/example_pp_graph_passes.py
@@ -183,9 +183,8 @@ def _run_graph_test(
     """Execute forward and backward passes with specified graph options."""
     if use_multiplexed_graph:
         multiplexed_fw_bw_module = multiplex_fw_bw_graph(
-            graph_modules.fw, graph_modules.full_bw
+            graph_modules.fw, graph_modules.full_bw, overlap_with_annotations=True
         )
-
     with (
         FakeTensorMode(
             allow_non_fake_inputs=True,
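A sketch of inspecting and running the fused module produced in the test above; ``graph_modules``, ``bw_args``, and ``fw_args`` are stand-ins for the per-stage graphs and inputs that the example scripts construct::

    from torch._subclasses.fake_tensor import FakeTensorMode

    from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph

    # graph_modules.fw / graph_modules.full_bw: stand-ins for a partitioned PP stage,
    # as built in examples/example_pp_graph_passes.py.
    fused = multiplex_fw_bw_graph(
        graph_modules.fw, graph_modules.full_bw, overlap_with_annotations=True
    )

    # Walk the fused graph and print each node's region tag; runs of backward comm
    # nodes followed by forward nodes show where forward work was slotted in to
    # overlap with backward communication.
    for n in fused.graph.nodes:
        if n.op not in ("placeholder", "output"):
            print(n.name, n.meta.get("custom"))

    # Per the placeholder contract, backward inputs are passed before forward inputs,
    # and the flat output is backward outputs followed by forward outputs.
    with FakeTensorMode(allow_non_fake_inputs=True):
        outs = fused(*bw_args, *fw_args)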