fix getitem handling in existing SAC tag pass, add attention back to example SAC run #123
This fixes the "infinite flow" error in the partitioner when we run `example_llama3.py` with SAC turned on + marking attention as `must_save`.

With my changes, here's the tlparse from `example_llama3.py`: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmprbE9ts/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 (you can see us saving a bunch of primals, matmuls, and attention for bw, nothing else)

I left more details in the comments, but the root cause is that our pass is tagging nodes in the graph slightly differently compared to how compile normally handles tagging of nodes for SAC:
- Our pass tags `flash_attention_fw` as `MUST_SAVE`, but the `getitem` output node from it (which is used in the backward compute) as `PREFER_RECOMPUTE`.
- In normal compile, `getitem` never gets a tag. This allows the partitioner to not need to handle `getitem`, and have the decision about whether to recompute `getitem` be determined solely by its (multi-output) source node.

One question I have around SAC handling in autoparallel: it's still not exactly clear to me why we need our graph pass for marking recompute tags in the long run. With a better graph capture frontend (Simon's changes), can we just re-use the compiler to do the tagging and kill our pass?
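For readers who want the tagging difference in concrete terms, here is a minimal toy sketch of the idea, not the code in this PR. `Node`, `Policy`, `strip_getitem_tags`, `effective_policy`, and the `"recompute"` meta key are all hypothetical stand-ins chosen for illustration; the sketch uses only the standard library and does not call the real `torch.fx` or partitioner APIs.

```python
import operator
from dataclasses import dataclass, field
from enum import Enum, auto


class Policy(Enum):
    # Hypothetical stand-in for the SAC policies named in the description.
    MUST_SAVE = auto()
    PREFER_RECOMPUTE = auto()


@dataclass
class Node:
    # Toy graph node for illustration only; this is not the torch.fx Node API.
    name: str
    target: object
    args: tuple = ()
    meta: dict = field(default_factory=dict)


def strip_getitem_tags(nodes):
    """Drop the SAC tag from every getitem node so its source node decides."""
    for node in nodes:
        if node.target is operator.getitem:
            node.meta.pop("recompute", None)


def effective_policy(node):
    """Resolve a node's policy; an untagged getitem defers to its source node."""
    if node.target is operator.getitem and "recompute" not in node.meta:
        return effective_policy(node.args[0])
    return node.meta.get("recompute")


# The conflicting tagging described above: the multi-output op is MUST_SAVE,
# but the getitem that reads its output is PREFER_RECOMPUTE.
attn = Node("flash_attention_fw", "flash_attention_fw",
            meta={"recompute": Policy.MUST_SAVE})
out = Node("getitem", operator.getitem, args=(attn, 0),
           meta={"recompute": Policy.PREFER_RECOMPUTE})
assert effective_policy(out) is Policy.PREFER_RECOMPUTE

# After stripping, the getitem follows its (multi-output) source node.
strip_getitem_tags([attn, out])
assert effective_policy(out) is Policy.MUST_SAVE
```

In this toy model, leaving `getitem` untagged means there is only one place where the save-versus-recompute decision lives, which is the property the description attributes to the normal compile path.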