DualPipeV Fw-Bw Overlapping pass with User Annotations #261
Base: main
```diff
@@ -4,13 +4,44 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+from itertools import dropwhile
 
 import torch
 import torch.fx as fx
+from torch._inductor.fx_passes.bucketing import is_wait_tensor
 from torch._logging import trace_structured
 
 
+def _add_compute_annotations(gm: fx.GraphModule, tag: str):
+    """Add compute_region annotations to nodes without custom metadata."""
+    for n in gm.graph.nodes:
+        if n.op == "placeholder":
+            continue
+        if n.meta.get("custom", None) is None:
+            n.meta["custom"] = {"compute_region": tag}
+        else:
+            assert "comm_region" in n.meta["custom"]
+            val = n.meta["custom"]["comm_region"]
+            n.meta["custom"]["comm_region"] = tag + " " + val
+
+
+def _move_wait_tensors_to_compute_region(gm: fx.GraphModule, tag: str):
+    """Move wait_tensor nodes from comm_region to compute_region of their users."""
+    for n in gm.graph.nodes:
+        if n.op == "placeholder":
+            continue
+        if "comm_region" in n.meta["custom"] and is_wait_tensor(n):
+            assert len(n.users) >= 1, "wait tensor must have at least one user"
+            user: fx.Node = next(iter(n.users))
+            if "compute_region" in user.meta["custom"]:
+                n.meta["custom"].pop("comm_region")
+                n.meta["custom"].update({"compute_region": tag + " " + "wait"})
+                if n.next is not user:
+                    user.prepend(n)
+
+
 def multiplex_fw_bw_graph(
-    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule
+    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule, overlap_with_annotations: bool = True
 ) -> fx.GraphModule:
     """
     Multiplexes forward and backward graphs into a single unified graph module.
@@ -32,67 +63,116 @@ def multiplex_fw_bw_graph(
     Note:
         The function preserves node metadata during the merging process.
     """
-    # Mapping to track correspondence between backward graph nodes and new nodes
+    if overlap_with_annotations:
+        _add_compute_annotations(fw_gm, "forward")
+        _add_compute_annotations(bw_gm, "backward")
+        _move_wait_tensors_to_compute_region(fw_gm, "forward")
+        _move_wait_tensors_to_compute_region(bw_gm, "backward")
+
+    # Mapping to track correspondence between forward graph nodes and new nodes
     old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {}
 
-    # Start with a deep copy of the forward graph as the base
-    multiplexed_gm = copy.deepcopy(fw_gm)
+    # Start with a deep copy of the backward graph as the base
+    multiplexed_gm = copy.deepcopy(bw_gm)
 
-    # Collect all placeholder nodes from the backward graph
+    # Collect all placeholder nodes from all the graphs
     bw_placeholders = bw_gm.graph.find_nodes(op="placeholder")
+    fw_placeholders = fw_gm.graph.find_nodes(op="placeholder")
+    insert_point = multiplexed_gm.graph.find_nodes(op="placeholder")[-1]
 
-    # Insert backward placeholders at the beginning of the multiplexed graph
-    # Reversed order ensures correct execution sequence
-    with multiplexed_gm.graph.inserting_before():
-        for n in reversed(bw_placeholders):
+    # Insert forward placeholders after the backward placeholders of the multiplexed graph
+    for n in fw_placeholders:
+        with multiplexed_gm.graph.inserting_after(insert_point):
             new_placeholder = multiplexed_gm.graph.placeholder(n.name)
-            new_placeholder.meta = n.meta
+            new_placeholder.meta = copy.copy(n.meta)
             new_placeholder.target = new_placeholder.name
             old_node_to_new_node[n] = new_placeholder
+        insert_point = new_placeholder
 
-    # Find the last placeholder and the output node in the multiplexed graph
-    multiplxed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder")
-    assert len(multiplxed_gm_placeholders) == (
-        len(fw_placeholders) + len(bw_placeholders)
+    multiplexed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder")
+    assert len(multiplexed_gm_placeholders) == len(fw_placeholders) + len(
+        bw_placeholders
     )
-    insert_point = multiplxed_gm_placeholders[-1]
 
-    # Copy all computation nodes from backward graph into multiplexed graph
-    fw_outputs = fw_gm.graph.find_nodes(op="output")
-    bw_outputs = bw_gm.graph.find_nodes(op="output")
-    assert len(bw_outputs) == 1 and len(fw_outputs) == 1
-    bw_graph_op_node = bw_outputs[0]
-    for n in bw_gm.graph.nodes:
-        if n.op == "placeholder":
-            continue
-        if n.op == "output":
-            continue
-        with multiplexed_gm.graph.inserting_after(insert_point):
+    fw_nodes_iter = iter(fw_gm.graph.nodes)
+    fw_nodes_iter = dropwhile(lambda n: n.op == "placeholder", fw_nodes_iter)
+    # Initialize the forward node to be the first non-placeholder node
+    fn = next(fw_nodes_iter)
+    if overlap_with_annotations:
```
A review thread anchored on this block:

**Member:** imo a simpler way to write this block is to first get a list of contiguous compute/comm/compute/comm/compute segments, then do your overlapping using those groups of nodes without needing much state tracking.

**Contributor (author):** For that algorithm to work, it would strictly require pairs of compute and comm blocks, but we don't have a guarantee that they would even exist. More formally, regular-expression-wise, this is what the current algorithm is implementing:
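To make the reviewer's suggestion concrete, here is a minimal sketch of the segmentation step it describes (not part of the PR; `split_into_segments` is a hypothetical helper, and it assumes nodes already carry the `meta["custom"]` region annotations added by `_add_compute_annotations`):

```python
import torch.fx as fx


def split_into_segments(gm: fx.GraphModule) -> list[tuple[str, list[fx.Node]]]:
    """Group non-placeholder/output nodes into contiguous
    ("compute" | "comm", nodes) runs based on their annotations."""
    segments: list[tuple[str, list[fx.Node]]] = []
    for n in gm.graph.nodes:
        if n.op in ("placeholder", "output"):
            continue
        kind = "comm" if "comm_region" in n.meta["custom"] else "compute"
        if segments and segments[-1][0] == kind:
            segments[-1][1].append(n)  # extend the current run
        else:
            segments.append((kind, [n]))  # start a new run
    return segments
```

The author's objection is that these runs need not come in neat compute/comm pairs (a graph may, for instance, start with a comm region with no compute before it), so the merged pass tracks region transitions instead of zipping fixed pairs.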
The diff continues:

```diff
+        # Interleave forward and backward nodes to create overlap pattern:
+        # bw_compute (if any) -> bw_comm -> fw_compute (if any) -> fw_comm -> [repeat]
+        # This allows bw_comm to overlap with fw_compute, and fw_comm to overlap with bw_compute
+        bw_in_comm = False
+        for bn in multiplexed_gm.graph.nodes:
+            if bn.op == "placeholder" or bn.op == "output":
+                continue
+            # Track when we enter a backward comm region
+            if "comm_region" in bn.meta["custom"] and not bw_in_comm:
+                bw_in_comm = True
+            # When we transition from bw_comm to bw_compute, insert forward nodes
+            elif "compute_region" in bn.meta["custom"] and bw_in_comm:
+                bw_in_comm = False
+                fw_in_comm = False
+                insert_point = bn
+                # Insert forward nodes before this bw_compute node
+                # Note: We cannot reorder nodes within a graph, only their relative order between graphs
+                while fn.op != "output":
+                    if "comm_region" in fn.meta["custom"] and not fw_in_comm:
+                        fw_in_comm = True
+                    elif "compute_region" in fn.meta["custom"] and fw_in_comm:
+                        # Stop when we reach the next fw_compute after fw_comm
+                        # This ensures we insert one fw_compute + fw_comm cycle per bw_comm -> bw_compute transition
+                        # If fw starts with comm (no compute before it), we still insert it to overlap with future bw_compute
+                        fw_in_comm = False
+                        break
+                    with multiplexed_gm.graph.inserting_before(insert_point):
+                        # Copy node and remap its arguments using the node mapping
+                        new_node = multiplexed_gm.graph.node_copy(
+                            fn, lambda x: old_node_to_new_node[x]
+                        )
+                    new_node.meta = copy.copy(fn.meta)
+                    old_node_to_new_node[fn] = new_node
+                    fn = next(fw_nodes_iter)
+    # Insert any remaining forward nodes at the end
+    # If overlap_with_annotations is False, this concatenates all fw nodes after bw nodes
+    insert_point = multiplexed_gm.graph.find_nodes(op="output")[-1]
+    while fn.op != "output":
+        with multiplexed_gm.graph.inserting_before(insert_point):
             # Copy node and remap its arguments using the node mapping
             new_node = multiplexed_gm.graph.node_copy(
-                n, lambda x: old_node_to_new_node[x]
+                fn, lambda x: old_node_to_new_node[x]
             )
-            new_node.meta = n.meta
-            old_node_to_new_node[n] = new_node
-            insert_point = new_node
+        new_node.meta = copy.copy(fn.meta)
+        old_node_to_new_node[fn] = new_node
+        fn = next(fw_nodes_iter)
 
-    # Collect output arguments from backward graph, remapping to new nodes
-    bw_op_node_args = [
+    # Collect output arguments from forward graph, remapping to new nodes
+    fw_outputs = fw_gm.graph.find_nodes(op="output")
+    multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output")
+    assert len(multiplexed_graph_outputs) == 1 and len(fw_outputs) == 1
+    fw_graph_op_node = fw_outputs[0]
+    fw_op_node_args = [
         old_node_to_new_node[n] if n is not None else None
-        for n in bw_graph_op_node.args[0]
+        for n in fw_graph_op_node.args[0]
     ]
 
-    # Collect output arguments from multiplexed graph (will contain only fwd_outs)
-    multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output")
-    assert len(multiplexed_graph_outputs) == 1
+    # Collect output arguments from multiplexed graph (will contain only bwd_outs)
     multiplexed_graph_op_node = multiplexed_graph_outputs[0]
-    fw_op_node_args = list(multiplexed_graph_op_node.args[0])
+    bw_op_node_args = list(multiplexed_graph_op_node.args[0])
 
     # Update output node args to prepend backward outputs before forward outputs
     multiplexed_graph_op_node.args = (tuple(bw_op_node_args + fw_op_node_args),)
 
     multiplexed_gm.graph.eliminate_dead_code()
     multiplexed_gm.graph.lint()
     multiplexed_gm.recompile()
     trace_structured(
         "artifact",
         metadata_fn=lambda: {
             "name": "autoparallel_multiplexed_graph",
             "encoding": "string",
         },
         payload_fn=lambda: multiplexed_gm.print_readable(
             print_output=False, include_stride=True, include_device=True
         ),
     )
     return multiplexed_gm
```
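For orientation, a minimal usage sketch of the pass as it stands after this diff (the toy modules, shapes, and the missing import are illustrative assumptions, not part of the PR):

```python
import torch
import torch.fx as fx

# multiplex_fw_bw_graph comes from the module this PR modifies
# (its import path is not shown in the diff).


class ToyFwd(torch.nn.Module):
    def forward(self, x):
        return (torch.relu(x) + 1,)  # tuple output: the pass iterates the output args


class ToyBwd(torch.nn.Module):
    def forward(self, g):
        return (g * 2,)


fw_gm = fx.symbolic_trace(ToyFwd())
bw_gm = fx.symbolic_trace(ToyBwd())

# With no comm_region annotations, _add_compute_annotations tags every node as
# compute, so the pass degenerates to concatenation: bw nodes first, then fw nodes.
multiplexed = multiplex_fw_bw_graph(fw_gm, bw_gm, overlap_with_annotations=True)

# Placeholder order is bw first, then fw (placeholders are copied by name, so
# distinct names across the two graphs are assumed); outputs are bw_outs + fw_outs.
bw_out, fw_out = multiplexed(torch.randn(4), torch.randn(4))
```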
A second review thread, on the node metadata copy:

**Reviewer:** probably want to deepcopy it?

**Author:** I get this warning:

```
pytorch/torch/_tensor.py:167: UserWarning: Accessing the data pointer of FakeTensor is deprecated and will error in PyTorch 2.5. This is almost definitely a bug in your code and will cause undefined behavior with subsystems like torch.compile. Please wrap calls to tensor.data_ptr() in an opaque custom op; If all else fails, you can guard accesses to tensor.data_ptr() on isinstance(tensor, FakeTensor). (Triggered internally at pytorch/c10/core/StorageImpl.cpp:34.)
    or (type(self) is not Tensor and self.data_ptr() == 0)
```
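For context on this thread, a toy sketch of the shallow-vs-deep tradeoff (the dict stands in for `node.meta`; using a plain tensor in place of a FakeTensor is an assumption for illustration):

```python
import copy
import torch

meta = {"val": torch.empty(1)}  # stands in for node.meta carrying a (Fake)Tensor value
shallow = copy.copy(meta)       # what the PR uses: a fresh dict with shared values
assert shallow is not meta and shallow["val"] is meta["val"]

# copy.deepcopy(meta) would route through Tensor.__deepcopy__ on the value,
# which inspects storage via data_ptr(); with FakeTensors that is exactly the
# deprecated access the warning above complains about.
```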