AutoSP: fix torch 2.9 fake propagation issues #2
spikerheado1234 merged 4 commits into neeldani:autosp from
Conversation
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
neeldani left a comment
Thank you for the fixes, Masahiro!
@@ -184,15 +184,52 @@ def pass_canonicalize(gm: GraphModule, real_inputs):

def pass_propagate_shapes(gm: torch.fx.GraphModule, real_inputs):
Suggestion: should we use `detect_fake_mode` here? It could be cleaner. Something like this (using `get(..., default)` rather than `or`, since truthiness checks on multi-element tensors raise):

    from torch._guards import detect_fake_mode
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    all_metadata = [node.meta.get("val", node.meta.get("example_value"))
                    for node in gm.graph.nodes]
    fake_mode = detect_fake_mode(all_metadata)
    if fake_mode is None:
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
`detect_fake_mode()` didn't work with PyTorch 2.8.0 or 2.9.1. It inspects Dynamo's tracing-context fake mode, and its assertion failed because that mode did not match the fake mode attached to the graph metadata. In this backend path we specifically need to reuse the graph-owned fake mode.
# Torch 2.9 can fail fake propagation through SDPA's masked fake-CUDA path,
# even though this pass only needs output metadata. Temporarily clear
# attn_mask so shape propagation can proceed, then restore it immediately;
# SDPA output shapes are still determined by Q/K/V shapes, not mask values.
For my own knowledge: why does shape propagation fail for masked SDPA? I'm curious because it worked in torch 2.7.
I don't have a clear idea, but I suspect they have been introducing stricter checks without much concern for backward compatibility.
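The clear-then-restore workaround from the comment above could be sketched roughly like this (a minimal illustration, not the actual AutoSP pass; the helper name and the name-based node match are assumptions): stash each SDPA node's `attn_mask`, run shape propagation while it is cleared, and restore it so the runtime graph is unchanged.

```python
import contextlib

import torch


@contextlib.contextmanager
def sdpa_masks_cleared(gm: torch.fx.GraphModule):
    """Temporarily set attn_mask=None on SDPA call_function nodes."""
    saved = {}
    for node in gm.graph.nodes:
        name = getattr(node.target, "__name__", "")
        if node.op == "call_function" and "scaled_dot_product_attention" in name:
            if node.kwargs.get("attn_mask") is not None:
                saved[node] = node.kwargs["attn_mask"]
                node.kwargs = {**node.kwargs, "attn_mask": None}
    try:
        # Run fake shape propagation inside this block; SDPA output shapes
        # depend only on Q/K/V, so clearing the mask does not change metadata.
        yield
    finally:
        # Restore immediately so the graph actually executed is untouched.
        for node, mask in saved.items():
            node.kwargs = {**node.kwargs, "attn_mask": mask}
```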
LGTM!
* Fix AutoSP shape propagation fake mode reuse
* Fix AutoSP torch 2.9 fake propagation
* Fix AutoSP shard slice ordering
* Add comments for AutoSP torch 2.9 fixes
* Change AutoSP PyTorch requirement to 2.9+

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Ahan Gupta <ahangupta.96@gmail.com>
This PR updates AutoSP to work with PyTorch v2.9 by fixing several fake-tensor and symbolic-shape issues in the graph rewrite and shape-propagation path.
* Reuse the existing `FakeTensorMode`/`ShapeEnv` during `pass_propagate_shapes()`. The FX graph may already carry fake tensor metadata from tracing, and PyTorch 2.9 is sensitive to mixing that symbolic state with a newly created fake mode. Reusing the graph-owned mode keeps the fake inputs and graph metadata in the same symbolic shape environment.
* Run `FakeTensorProp(...).propagate_dont_convert_inputs(...)` with fake inputs that were already created in the selected fake mode. This avoids a second conversion step and keeps propagation aligned with the same `FakeTensorMode`. In practice this fixes the original `.item()`/`aten._local_scalar_dense.default` fake propagation failure seen on torch 2.9.
* Canonicalize the fake `autosp::all_to_all` output shape back to the original sequence symbol when the local dimension is `FloorDiv(s, P)`. On torch 2.9, `P * (s // P)` may be treated as symbolically distinct from `s` even when they are numerically equal. Restoring the original symbol keeps downstream fake shape reasoning consistent.
* Temporarily clear `attn_mask` only during fake shape propagation for SDPA nodes, and restore it immediately afterward. This works around a separate torch 2.9 failure in the masked fake-CUDA SDPA path while preserving the actual runtime graph. It is acceptable here because this pass only needs output metadata, and SDPA output shapes are determined by Q/K/V shapes rather than mask values.
* Preserve topological ordering when inserting the sharded `getitem` node in `shard_tensor_node()`. Torch 2.9 bf16 traces can place the symbolic sequence placeholder after the tensor placeholder, which previously allowed `getitem` to be inserted before its symbolic slice dependencies. Anchoring insertion after the later dependency keeps the rewritten FX graph lint-clean.
* Add regression coverage for the reordered-placeholder case. The new unit test rebuilds a graph where the `SymInt` placeholder follows `input_ids` and verifies that `shard_tensor_node()` still produces a valid graph. This gives us a deterministic repro for the ordering bug without depending on the full distributed torch 2.9 bf16 path.
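The "anchor after the later dependency" idea from the ordering fix can be sketched in a few lines (the helper name is illustrative, not the actual `shard_tensor_node()` code): find whichever dependency appears latest in the graph's node order and insert the new node right after it, so the rewritten graph stays topologically ordered even when placeholders are reordered.

```python
import torch


def insert_after_last_dep(graph: torch.fx.Graph, deps, make_node):
    """Insert the node built by make_node() after the latest of its deps."""
    # Map each node to its position in the graph's current ordering.
    order = {n: i for i, n in enumerate(graph.nodes)}
    # The safe anchor is the dependency that appears last; inserting any
    # earlier could place the new node before one of its inputs.
    anchor = max(deps, key=lambda n: order[n])
    with graph.inserting_after(anchor):
        return make_node()
```

With this anchoring, `graph.lint()` passes even when a symbolic placeholder comes after the tensor placeholder, mirroring the reordered-placeholder case the new unit test covers.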