
Conversation

@bdhirsh commented Aug 28, 2025

This fixes the "infinite flow" error in the partitioner when we run example_llama3.py with SAC turned on and attention marked as must_save.

With my changes, here's the tlparse from example_llama3.py: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmprbE9ts/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 (you can see us saving a bunch of primals, matmuls, and attention for bw, and nothing else)
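For context, here is a minimal sketch of the kind of SAC setup that hits this path (hypothetical, not the actual example_llama3.py code; `Block` and `policy_fn` are illustrative): attention is marked MUST_SAVE via `create_selective_checkpoint_contexts`, and everything else prefers recompute.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def policy_fn(ctx, op, *args, **kwargs):
    # Save flash attention outputs for backward; prefer recomputing all else.
    # (On CPU a different SDPA overload may be dispatched; this is illustrative.)
    if op is torch.ops.aten._scaled_dot_product_flash_attention.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

class Block(nn.Module):  # stand-in for a llama3 transformer block
    def __init__(self):
        super().__init__()
        self.qkv = nn.Linear(64, 3 * 64)
        self.proj = nn.Linear(64, 64)

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        return self.proj(attn)

x = torch.randn(2, 8, 64, requires_grad=True)
out = checkpoint(
    Block(), x,
    use_reentrant=False,
    context_fn=lambda: create_selective_checkpoint_contexts(policy_fn),
)
out.sum().backward()
```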

I left more details in the comments, but the root cause is that our pass tags nodes in the graph slightly differently from how compile normally tags nodes for SAC:

  • Our pass takes every op in the forward that was not explicitly marked as "save" and marks it as "recompute". This means we end up tagging flash_attention_fw as MUST_SAVE, but the getitem output node from it (which is what the backward compute actually uses) as PREFER_RECOMPUTE.
  • In vanilla compile, we use TorchDispatchModes to do the tagging. These modes only ever intercept OpOverloads, so getitem never gets a tag. This lets the partitioner avoid handling getitem specially, and the decision about whether to recompute a getitem is determined solely by its (multi-output) source node. (A sketch of this getitem-skipping behavior follows below.)
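To make the two behaviors concrete, here is a hedged sketch of a tagging pass that skips getitem so the partitioner falls back to the source node's tag. The names (`fw_graph`, the `"recompute"` meta key) are assumptions about the pass, not the exact autoparallel code:

```python
import operator
from torch.utils.checkpoint import CheckpointPolicy

def tag_recomputable_nodes(fw_graph):
    """Tag every untagged forward op as prefer-recompute, mirroring what the
    TorchDispatchMode-based flow produces (sketch; meta key names assumed)."""
    for node in fw_graph.nodes:
        if node.op != "call_function":
            continue
        # getitem is plain operator.getitem, not an OpOverload, so the
        # dispatch modes never see it. Leave it untagged and let the
        # partitioner derive its fate from its multi-output source node.
        if node.target is operator.getitem:
            continue
        # Anything not explicitly marked (e.g. MUST_SAVE) prefers recompute.
        if node.meta.get("recompute") is None:
            node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
```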

One question I have around SAC handling in autoparallel is: it's still not exactly clear to me why we need our graph pass for marking recompute tags in the long run. With a better graph capture frontend (Simon's changes), can we just re-use the compiler to do the tagging and kill our pass?

@meta-cla bot added the CLA Signed label Aug 28, 2025
@bdhirsh requested review from fmassa and wconstab August 28, 2025 14:14
@fmassa left a comment


This is awesome, thanks for the fix!

cc @xuanzhang816 as we might want to fix this in the upstream version as well

@fmassa commented Aug 28, 2025

> One question I have around SAC handling in autoparallel is: it's still not exactly clear to me why we need our graph pass for marking recompute tags in the long run. With a better graph capture frontend (Simon's changes), can we just re-use the compiler to do the tagging and kill our pass?

We won't need this once Simon's changes land. We added it as a workaround to get good memory without the tags being propagated.

That being said, I still think it would be useful to have the mark_nodes_as_must_save_to_stage_recomputation pass upstreamed in PyTorch somehow, as it is complementary to AutoAC and can be used in conjunction with it.
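For readers unfamiliar with that pass, an illustrative sketch of the idea only (the signature and the `cut_every` heuristic below are assumptions, not the real API): force periodic cut points in the forward graph to be saved, so no backward recompute chain spans more than one "stage".

```python
from torch.utils.checkpoint import CheckpointPolicy

def mark_nodes_as_must_save_to_stage_recomputation(fw_graph, cut_every=10):
    # Hypothetical signature and heuristic: save every Nth compute node so
    # recomputation in the backward is bounded to one stage at a time.
    n_ops = 0
    for node in fw_graph.nodes:
        if node.op != "call_function":
            continue
        n_ops += 1
        if n_ops % cut_every == 0:
            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
```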

@fmassa merged commit 2c573ce into main Aug 28, 2025
6 checks passed
@fmassa deleted the ac_attention_fix branch August 28, 2025 14:49
