[XLA:GPU] Fine-grained remat policy makes async/pipelined collectives execute in the main stream #14397
Comments
Corresponding JAX issue:
@golechwierowicz could you take a look?
In general, if the latency hiding scheduler is not able to schedule any computation within the async collective, then this collective will be executed on the compute stream. Can you provide an HLO-based, single-host, 2-GPU repro? You can do this via
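The exact command referenced above is not captured in this thread. As a rough sketch of how an HLO dump can be produced from a JAX run, the standard `--xla_dump_to` flag can be passed through the `XLA_FLAGS` environment variable (the dump path below is just an example):

```python
import os

# Dump the HLO of every compiled module as text to a directory.
# --xla_dump_to and --xla_dump_hlo_as_text are standard XLA debug flags;
# the directory is arbitrary.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_dump_to=/tmp/xla_dump"
    + " --xla_dump_hlo_as_text"
)

import jax  # import after setting XLA_FLAGS so the flags are picked up
```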
First of all, let me clarify my training setup. Last week I ran a large number of experiments and noticed that the reason behind this behavior might not be fine-grained activation checkpointing but rather some odd XLA behavior. I tested the following setups:

- For the 7B and 30B models, as I mentioned in the issue, asynchronous communications work correctly without activation checkpointing and execute in the compute stream as soon as I try to checkpoint anything.
- For the 70B model, all communications are executed in the compute stream even without selective activation checkpointing.

On the other hand, for the single-node repro I tried the following setups:
xla_dump_selective_saveable.tgz
xla_dump_nothing_saveable.tgz

In the single-node selective-checkpointing repro I noticed that during the forward pass the weights' all-gathers are performed asynchronously, while all communications during the backward pass are executed in the compute stream, which is even stranger.
8 GPUs is fine. I don't see ... Alternatively, I can also work with a small single-node JAX-based repro.
Sorry, I see; it looks like with the JAX compilation cache enabled, cached modules are not being dumped.
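One way around this, sketched below under the assumption that the cache was enabled via the `JAX_COMPILATION_CACHE_DIR` environment variable, is to disable the persistent compilation cache for the repro run so modules are recompiled and therefore dumped:

```python
import os

# Sketch: make sure the persistent compilation cache is off for this run,
# so every module is recompiled and shows up in the XLA dump.
# If the cache was enabled programmatically, unset the corresponding
# config option instead of the environment variable.
os.environ.pop("JAX_COMPILATION_CACHE_DIR", None)

import jax  # import after the environment is cleaned up
```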
Unfortunately, I am missing some custom kernels to run it purely at the XLA level.
I was, however, able to look at the intermediate passes and have an idea. Can you try the following flags in addition to the flags you already provide?
and then
and see if it leads to an improvement? If not, can you get rid of the custom kernels so I can debug this further? I want to run it via
Thank you, I've tried the flag. Here is the XLA dump with it enabled, in case you need it. Meanwhile, I'll try to reproduce this behavior without the custom Triton kernels.
Unfortunately, I can't reproduce this behavior without the FlashAttention Triton kernel; communications are performed asynchronously both with and without activation checkpointing. Maybe the reason is that I had to make the batch size twice as small, because otherwise it wouldn't fit in memory.
Also, can you please elaborate on how exactly this helps?
It prevents overly aggressive copy elision, which happens at the cost of inserting control predecessors/successors. These additional control-flow constructs break the flow of the latency hiding scheduler (LHS). So disabling region analysis creates extra copies, but this is still beneficial because it unlocks compute-overlap opportunities for the LHS in the compiler. The XLA documentation in this regard is lacking, but I have some code pointers:
Copy insertion: xla/xla/service/copy_insertion.h, lines 27 to 45 at 6deb462
Latency hiding scheduler: xla/xla/service/latency_hiding_scheduler.h, line 737 at 6deb462
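The specific copy-insertion flags suggested earlier were not captured in this thread, so they are not reproduced here. Purely as an illustration of how such XLA debug options reach a JAX program, the latency hiding scheduler itself can be toggled through `XLA_FLAGS` (the flag shown is a standard XLA option; the overall setup is only a sketch):

```python
import os

# Enable the latency hiding scheduler explicitly so async collectives can be
# overlapped with compute. The copy-insertion / region-analysis flag discussed
# above is intentionally omitted because its exact name is not shown here.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_gpu_enable_latency_hiding_scheduler=true"
)

import jax  # set XLA_FLAGS before JAX initializes its backends
```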
Also, for GPUs, this behavior should be the default in the next JAX release.
Description
Hi, I have the following setup:
I'm using the following flags:
To speed up the backward pass through fine-grained reduction of activation recomputation, I marked each dense layer's output in the transformer block with a specific name:
So, for example, in the attention layer I have "dot_attention_query", "dot_attention_key", "dot_attention_value", and "dot_attention_out".
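The actual model code is not included in the issue; a minimal sketch of how such named activations are typically tagged with `jax.ad_checkpoint.checkpoint_name` (the projection function and parameter names here are made up, the activation names are the ones listed above):

```python
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def attention_projections(x, wq, wk, wv):
    # Tag each dense output with a name so a checkpoint policy can refer to it.
    q = checkpoint_name(jnp.dot(x, wq), "dot_attention_query")
    k = checkpoint_name(jnp.dot(x, wk), "dot_attention_key")
    v = checkpoint_name(jnp.dot(x, wv), "dot_attention_value")
    return q, k, v
```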
I then apply a checkpoint policy to the scanned function, which accepts a list of activation names to checkpoint:
and then scan it over the embeddings:
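The scanned-layer code is likewise not shown in the issue; below is a rough sketch of applying `jax.checkpoint` with `save_only_these_names` to a scanned transformer block, where the block function, parameter stacking, and the `save_names_for_bwd` argument are hypothetical stand-ins for the original setup:

```python
import jax

def make_scanned_blocks(block_fn, save_names_for_bwd):
    # With an empty name list this is effectively the "nothing_saveable" policy;
    # otherwise only the listed activations are saved for the backward pass.
    if save_names_for_bwd:
        policy = jax.checkpoint_policies.save_only_these_names(*save_names_for_bwd)
    else:
        policy = jax.checkpoint_policies.nothing_saveable

    rematted_block = jax.checkpoint(block_fn, policy=policy)

    def apply(params_stack, embeddings):
        # Scan the rematted block over the stacked per-layer parameters.
        def body(carry, layer_params):
            return rematted_block(carry, layer_params), None

        out, _ = jax.lax.scan(body, embeddings, params_stack)
        return out

    return apply
```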
If I set self.config.save_names_for_bwd to an empty list (which is basically equivalent to the "nothing_saveable" policy), then communication works correctly: all-gathers/reduce-scatters/all-reduces are overlapped with computations, as can be seen in this Perfetto trace:
(Screenshot 2024-07-03 at 14:33:27: Perfetto trace with nothing saveable, collectives overlapped with compute.)
nothing_saveable.tgz
But as soon as I start to specify some names in self.config.save_names_for_bwd, for example,
While these activations are indeed not recomputed during the backward pass, all communications are executed in the main stream without any overlap with computations:
(Screenshot 2024-07-03 at 14:35:06: Perfetto trace with named activations saved, collectives executed in the main stream.)
save_only_these_names_trace.tgz
System info (python version, jaxlib version, accelerator, etc.)