[XLA:GPU] Improve memory bandwidth utilization of column reduction #11018
Conversation
@jreiffers given that you worked on the reduction emitter quite a bit and know the latest state, can you please review this?
Hi @jreiffers, can you take a look at this? Thanks.
Thanks for your PR and sorry for the delay, I didn't see this. I just took a quick look and have a few comments:
I'll take a deeper look later.
Special thanks to github user lingzhi98 who experimented with this in openxla/xla#11018. I tried to make the logic as similar for vectorized and non-vectorized reductions as I could. The vectorized logic looks like this:
- produce N reduced elements per thread, store the intermediate results in a vector V
- loop over the N elements of V, writing each one to shmem
- loop over N elements, reading them from shmem and writing the result to global memory

PiperOrigin-RevId: 636130464
PiperOrigin-RevId: 636243118
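A minimal CUDA sketch of that three-step structure, assuming a simplified tiling (each block covers TX*N contiguous columns, with TY threads accumulating rows per column); the kernel name, shapes, and shared-memory layout are illustrative, not the emitter's actual output:

```cuda
// Sketch only: each thread reduces N contiguous columns into a vector V,
// stages the N partials through shared memory, then thread row y == 0
// combines the TY partials per column and writes N results to global memory.
// Assumes cols is divisible by TX * N (cf. the divisibility condition below).
template <int N, int TX, int TY>
__global__ void vectorized_col_reduce(const float* in, float* out,
                                      int rows, int cols) {
  __shared__ float shmem[TY][TX * N];
  int col0 = blockIdx.x * TX * N + threadIdx.x * N;  // first of N columns

  // 1) Produce N reduced elements per thread; keep them in a vector V.
  float v[N];
  for (int i = 0; i < N; ++i) v[i] = 0.f;
  for (int r = threadIdx.y; r < rows; r += TY)
    for (int i = 0; i < N; ++i)  // contiguous, so this can be a vector load
      v[i] += in[r * cols + col0 + i];

  // 2) Loop over the N elements of V, writing each one to shmem.
  for (int i = 0; i < N; ++i)
    shmem[threadIdx.y][threadIdx.x * N + i] = v[i];
  __syncthreads();

  // 3) Read the partials back from shmem and write results to global memory.
  if (threadIdx.y == 0) {
    for (int i = 0; i < N; ++i) {
      float acc = 0.f;
      for (int y = 0; y < TY; ++y)
        acc += shmem[y][threadIdx.x * N + i];
      out[col0 + i] = acc;
    }
  }
}
```

With N = 1 this degenerates into the same structure as the non-vectorized path, which matches the stated goal of keeping the two variants as similar as possible.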
I'm currently finishing the last indexing optimizations, so I think we should be in a launchable state soon.
@olegshyshkov Can you give an update on the cost model changes?
This atomic is not the same as tree_reduction_rewriter. You can see the HLO below:
Ah, gotcha. This optimization doesn't currently exist in XLA, as far as I know. Is the implementation already pushed? I don't see it in the diff for this PR.
The code is here. We only need to change tile_y and do nothing else.
But this won't automatically emit atomics. What am I missing?
If the reduction is not race free, the emitter will emit an atomic.
Ah, you're talking about the legacy emitter. I don't think we want to make any more changes to that one.
As far as I know, the reduction MLIR emitter will also support atomics in the future, so this change will also be useful. And I didn't make any changes to the legacy LLVM emitter; I just reused the previous implementation.
The legacy emitter is very brittle. Changing a race-free reduction to an atomic one is not generally safe, since there might be an epilogue fusion. I asked @pifon2a to look into whether this transformation can be done as part of tree_reduction_rewriter.
I have checked whether there is an epilogue (https://github.com/openxla/xla/pull/11018/files#diff-78597727ea860fa53ae0dfcaa961c8b1fd1f2a4f94297def7676586ed5db158eR163). I don't think there is anything that can be done in the tree reduction rewriter, but we can wait for his conclusion.
I've prepared a cost model PR: #13044, based on the idea from #12208.
Any updates on the atomic part? Maybe we can do something like the below if we don't want to involve atomics:
I thought about it a bit more and I see the problem now. We'll probably have to implement atomics then.
Sorry, I'm a bit confused: where are the atomics coming from? We strongly need to ensure determinism, so atomics for addition aren't great.
Right, then this particular optimization simply can't be done (at least with the default flags). For min/max we could do it.
Uh, actually, the tree rewriter idea should work. E.g.: 2048x256 -> 256x32 (up to 64x8 blocks) -> 256. We'll still want to adjust the launch grid for the first reduction. The second reduction would be a row reduction.
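For illustration, a hedged CUDA sketch of that two-stage shape plan (kernel names, launch tiling, and the partials layout are assumptions, not the rewriter's output): stage 1 reduces disjoint 64-row chunks of the 2048x256 input into a 256x32 buffer of partials with no atomics, so the result stays deterministic, and stage 2 is a row reduction over the 32 partials per output element.

```cuda
// Stage 1: each (column, chunk) pair reduces 64 rows; every block writes a
// distinct partial, so there is no race and no atomic. Output is 256x32.
__global__ void stage1_partial_reduce(const float* in, float* partials) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;  // 0..255
  int chunk = blockIdx.y;                           // 0..31, 64 rows each
  float acc = 0.f;
  for (int r = 0; r < 64; ++r)
    acc += in[(chunk * 64 + r) * 256 + col];
  partials[col * 32 + chunk] = acc;                 // laid out as 256x32
}

// Stage 2: a plain row reduction over the 32 partials of each column.
__global__ void stage2_row_reduce(const float* partials, float* out) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;  // 0..255
  float acc = 0.f;
  for (int chunk = 0; chunk < 32; ++chunk)
    acc += partials[col * 32 + chunk];
  out[col] = acc;
}
```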
The atomic comes from this idea (#11018 (comment)). But if we want to ensure determinism, splitting the reduction into multiple smaller reductions may be better.
So OpenXLA has a plan to support this optimization pattern, right? If yes, maybe I can close this PR and wait for the OpenXLA update.
@pifon2a is still looking into this, AFAIU.
Finally, I want to understand one thing: do you and I have the same idea (#11018 (comment))?
I think the ideas are different. My idea is essentially this:
So no slicing or concat.
Thanks, got it.
Column reduction previously supported vectorization; it was removed after this PR. The reason for disabling it was that no performance gain was found, and the vectorization heuristic was fairly complex and different from the one for row reductions. But I find that vectorization in column reduction is really helpful in the Gemma model. To re-enable it, I simplified the previous vectorization heuristic. Column reduction vectorization now checks only two conditions:
(1) whether vectorization would introduce large overhead, checked by MayPreventVectorization as for row reductions;
(2) whether the last dimension is divisible by the vectorization factor.
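Condensed, the decision amounts to something like the sketch below. Only the name MayPreventVectorization comes from XLA; the wrapper function, its signature, and the default vector size of 2 are made up for illustration.

```cuda
// Hypothetical condensation of the simplified heuristic; not XLA's real API.
// may_prevent_vectorization would come from MayPreventVectorization (check 1);
// minor_dim is the size of the last (minor-most kept) dimension (check 2).
int ChooseColumnVectorSize(bool may_prevent_vectorization,
                           long long minor_dim, int vector_size = 2) {
  if (may_prevent_vectorization) return 1;      // (1) overhead too large
  if (minor_dim % vector_size != 0) return 1;   // (2) divisibility check
  return vector_size;
}
```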
Though vectorization can achieve better memory bandwidth, it decreases the number of active SM cores, which in turn hurts memory bandwidth. The previous column reduction vectorization heuristic did not enable vectorization if the number of active SM cores would be smaller than the maximum core count. But if we split the reduced dimension into multiple parts, we can increase the number of active SM cores and thereby improve memory bandwidth utilization. Though an atomic and an initializer thunk are then needed, I still find a performance improvement compared with no vectorization.
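A hedged CUDA sketch of that scheme, with made-up kernel names and tiling: the grid gains a dimension over chunks of the reduced dimension, a separate initializer kernel (standing in for the initializer thunk) zeroes the output, and each block contributes its chunk's partial sum with an atomic.

```cuda
// Initializer kernel: mirrors the "initializer thunk" mentioned above.
__global__ void init_output(float* out, int cols) {
  int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < cols) out[c] = 0.f;
}

// Split column reduction: blockIdx.y selects a chunk of the reduced
// dimension, so more blocks (and thus SM cores) are active; the chunks'
// partial sums meet at the output via atomicAdd.
__global__ void split_col_reduce(const float* in, float* out,
                                 int rows, int cols, int rows_per_chunk) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= cols) return;
  int r0 = blockIdx.y * rows_per_chunk;
  int r1 = min(r0 + rows_per_chunk, rows);
  float acc = 0.f;
  for (int r = r0; r < r1; ++r) acc += in[r * cols + col];
  atomicAdd(&out[col], acc);  // float atomics: order is non-deterministic,
                              // which is the concern raised in the discussion
}
```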
Also, for a small minor kept dimension, the current column reduction codegen uses only a small number of SM cores. Splitting the reduced dimension also gives better performance in this situation.
After enabling the above optimizations, Gemma 2B (bs1, bfloat16, greedy search, 1024 input tokens, 128 output tokens) end-to-end performance shows a 1.1x speedup on an A100 40GB.