Skip to content

AutoSP: fix torch 2.9 fake propagation issues #2

Merged: spikerheado1234 merged 4 commits into neeldani:autosp from
tohtana:pr7860-autosp-faketensorprop-test on Mar 19, 2026
Conversation


@tohtana commented on Mar 17, 2026

This PR updates AutoSP to work with PyTorch v2.9 by fixing several fake-tensor and symbolic-shape issues in the graph rewrite and shape-propagation path.

  • Reuse the existing FakeTensorMode / ShapeEnv during pass_propagate_shapes(). The FX graph may already carry fake tensor metadata from tracing, and PyTorch 2.9 is sensitive to mixing that symbolic state with a newly created fake mode. Reusing the graph-owned mode keeps the fake inputs and graph metadata in the same symbolic shape environment.

  • Run FakeTensorProp(...).propagate_dont_convert_inputs(...) with fake inputs that were already created in the selected fake mode. This avoids a second conversion step and keeps propagation aligned with the same FakeTensorMode. In practice this fixes the original .item() / aten._local_scalar_dense.default fake propagation failure seen on torch 2.9.

  • Canonicalize the fake autosp::all_to_all output shape back to the original sequence symbol when the local dimension is FloorDiv(s, P). On torch 2.9, P * (s // P) may be treated as symbolically distinct from s even when they are numerically equal. Restoring the original symbol keeps downstream fake shape reasoning consistent.

  • Temporarily clear attn_mask only during fake shape propagation for SDPA nodes, and restore it immediately afterward. This works around a separate torch 2.9 failure in the masked fake-CUDA SDPA path while preserving the actual runtime graph. It is acceptable here because this pass only needs output metadata, and SDPA output shapes are determined by Q/K/V shapes rather than mask values.

  • Preserve topological ordering when inserting the sharded getitem node in shard_tensor_node(). Torch 2.9 bf16 traces can place the symbolic sequence placeholder after the tensor placeholder, which previously allowed getitem to be inserted before its symbolic slice dependencies. Anchoring insertion after the later dependency keeps the rewritten FX graph lint-clean.

  • Add regression coverage for the reordered-placeholder case. The new unit test rebuilds a graph where the SymInt placeholder follows input_ids and verifies that shard_tensor_node() still produces a valid graph. This gives us a deterministic repro for the ordering bug without depending on the full distributed torch 2.9 bf16 path.

tohtana added 4 commits March 16, 2026 14:29
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
@neeldani (Owner) left a comment


Thank you for the fixes Masahiro

@@ -184,15 +184,52 @@ def pass_canonicalize(gm: GraphModule, real_inputs):


def pass_propagate_shapes(gm: torch.fx.GraphModule, real_inputs):
@neeldani (Owner) commented:

Suggestion: should we use detect_fake_mode, defined here? It could be cleaner.

Something like this:

all_metadata = [node.meta.get("val") or node.meta.get("example_value")
                for node in gm.graph.nodes]

fake_mode = detect_fake_mode(all_metadata)

if fake_mode is None:
    fake_mode = FakeTensorMode(shape_env=ShapeEnv())

@tohtana (Author) replied:

detect_fake_mode() didn't work with PyTorch 2.8.0 or 2.9.1. It also inspects Dynamo's tracing-context fake mode and asserts when that mode does not match the fake mode attached to the graph metadata. In this backend path we specifically need to reuse the graph-owned fake mode.

# Torch 2.9 can fail fake propagation through SDPA's masked fake-CUDA path,
# even though this pass only needs output metadata. Temporarily clear
# attn_mask so shape propagation can proceed, then restore it immediately;
# SDPA output shapes are still determined by Q/K/V shapes, not mask values.
@neeldani (Owner) commented:

For my own understanding: why does shape prop fail for masked SDPA? I'm curious because it worked in torch 2.7.

@tohtana (Author) replied:

I don't have a clear idea, but my feeling is that they have been introducing stricter checks and don't care much about backward compatibility.

@spikerheado1234 (Collaborator):

LGTM!

@spikerheado1234 spikerheado1234 merged commit b417a72 into neeldani:autosp Mar 19, 2026
spikerheado1234 pushed a commit that referenced this pull request Mar 19, 2026
* Fix AutoSP shape propagation fake mode reuse

* Fix AutoSP torch 2.9 fake propagation

* Fix AutoSP shard slice ordering

* Add comments for AutoSP torch 2.9 fixes

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
spikerheado1234 pushed a commit that referenced this pull request Mar 19, 2026
* Fix AutoSP shape propagation fake mode reuse

* Fix AutoSP torch 2.9 fake propagation

* Fix AutoSP shard slice ordering

* Add comments for AutoSP torch 2.9 fixes

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Ahan Gupta <ahangupta.96@gmail.com>
spikerheado1234 pushed a commit that referenced this pull request Mar 19, 2026
* Fix AutoSP shape propagation fake mode reuse

* Fix AutoSP torch 2.9 fake propagation

* Fix AutoSP shard slice ordering

* Add comments for AutoSP torch 2.9 fixes

* Change AutoSP PyTorch requirement to 2.9+

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Ahan Gupta <ahangupta.96@gmail.com>

3 participants