
[XLA:GPU] Improve memory bandwidth utilization of column reduction #11018

Closed

Conversation

lingzhi98
Contributor

@lingzhi98 lingzhi98 commented Mar 28, 2024

Column reduction supported vectorization previously, but it was removed after this PR. The reason for disabling vectorization was that no performance gain was found and the vectorization heuristic was fairly complex and different from the one for
row reductions. However, I find that vectorization in column reductions is really helpful for the Gemma model. To re-enable it, I have simplified the previous vectorization heuristic. Column reduction vectorization now checks only two conditions:
(1) whether vectorization would introduce large overhead, checked by MayPreventVectorization as for row reductions;
(2) whether the last dimension is divisible by the vectorization factor.
Although vectorization can achieve better memory bandwidth, it also decreases the number of active SM cores, which can in turn hurt memory bandwidth. The previous column reduction vectorization heuristic did not enable vectorization if the number of active SM cores was smaller than the maximum core count. But if we split the reduced dimension into multiple parts, we can increase the number of active SM cores and thereby improve memory bandwidth utilization. Although an atomic combine and an initializer thunk are then needed, I still see a performance improvement compared with no vectorization.
For a small minor kept dimension, the current column reduction codegen also uses only a few SM cores, so splitting the reduced dimension gives better performance in that situation as well.
After enabling the above optimizations, Gemma 2B (bs1, bfloat16, greedy search, 1024 input tokens, 128 output tokens) end-to-end performance shows a 1.1x speedup on an A100 40GB.
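
In rough pseudocode, the simplified heuristic looks like this (a sketch only; the function and parameter names are illustrative and not the actual XLA API, while the MayPreventVectorization check and the divisibility condition come from the description above):

#include <cstdint>

// Sketch of the simplified column-reduction vectorization heuristic.
// Returns the vectorization factor to use (1 means "do not vectorize").
int ColumnReductionVectorSize(bool may_prevent_vectorization,
                              int64_t minor_kept_dim) {
  constexpr int kVectorSize = 2;
  // (1) Skip vectorization when it would introduce large overhead
  //     (the MayPreventVectorization check, shared with row reductions).
  if (may_prevent_vectorization) return 1;
  // (2) Only vectorize when the last dimension is divisible by the factor,
  //     so no tail handling is needed.
  if (minor_kept_dim % kVectorSize != 0) return 1;
  return kVectorSize;
}

In this sketch, a factor of 1 simply means the existing unvectorized codegen path is used.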

@kamaljeeti kamaljeeti requested a review from reedwm March 28, 2024 09:13
@reedwm reedwm requested review from jreiffers and removed request for reedwm March 28, 2024 21:38
@penpornk penpornk requested a review from akuegel April 4, 2024 19:09
@akuegel
Member

akuegel commented Apr 8, 2024

@jreiffers given that you worked on the reduction emitter quite a bit and know the latest state, can you please review this?

@akuegel akuegel removed their request for review April 8, 2024 13:29
@kamaljeeti
Contributor

Hi @jreiffers, can you please take a look at this? Thanks.

@jreiffers
Member

Thanks for your PR and sorry for the delay, I didn't see this. I just took a quick look and have a few comments:

  • We need to make sure this is modeled properly in the cost model. For this, the ReductionInfo::ComputeThreadIdToInputIndexing function in reduction_base.cc may need to be updated. We'll definitely need a new test for it.
  • The reduction emitter is in maintenance mode, the plan is to replace it with the MLIR version. Therefore, if this change is beneficial, we should update the latter as well (which should be significantly easier, it shouldn't be much more than updating the thread -> input indexing).
  • Please add integration tests for the new logic (service/gpu/tests/reduction_vectorization_test.cc).

I'll take a deeper look later.

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request May 22, 2024
Special thanks to github user lingzhi98 who experimented with this in
openxla/xla#11018.

I tried to make the logic as similar for vectorized and non-vectorized
reductions as I could. The vectorized logic looks like this:

- produce N reduced elements per thread, store the intermediate results in
  a vector V
- loop over the N elements of V, writing each one to shmem
- loop over N elements, reading them from shmem and writing the result to
  global memory

PiperOrigin-RevId: 636130464
copybara-service bot pushed a commit that referenced this pull request May 22, 2024
Special thanks to github user lingzhi98 who experimented with this in
#11018.

I tried to make the logic as similar for vectorized and non-vectorized
reductions as I could. The vectorized logic looks like this:

- produce N reduced elements per thread, store the intermediate results in
  a vector V
- loop over the N elements of V, writing each one to shmem
- loop over N elements, reading them from shmem and writing the result to
  global memory

PiperOrigin-RevId: 636243118
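
For illustration, the scheme described in this commit message corresponds roughly to the CUDA kernel below. This is a sketch, not the code XLA generates: the block shape, the grid-stride loop over the reduced dimension, and the atomicAdd used to combine partial results when the reduced dimension is split across blocks are all assumptions.

#include <cuda_runtime.h>

constexpr int kVec = 2;      // N: reduction results produced per thread
constexpr int kBlockX = 32;  // threads along the kept (column) dimension
constexpr int kBlockY = 32;  // threads along the reduced (row) dimension

// Reduces `in` of shape [rows, cols] over the rows into `out` of shape [cols].
// `out` must be zero-initialized before launch.
__global__ void ColumnReduceVectorized(const float* in, float* out,
                                       int rows, int cols) {
  // Each thread owns kVec adjacent columns.
  int col0 = (blockIdx.x * kBlockX + threadIdx.x) * kVec;

  // 1) Produce kVec reduced elements per thread, kept in registers (V).
  float acc[kVec] = {};
  for (int r = blockIdx.y * kBlockY + threadIdx.y; r < rows;
       r += gridDim.y * kBlockY) {
    for (int v = 0; v < kVec; ++v) {
      if (col0 + v < cols) acc[v] += in[r * cols + col0 + v];
    }
  }

  // 2) Loop over the kVec elements of V, writing each one to shmem.
  __shared__ float smem[kBlockY][kBlockX * kVec];
  for (int v = 0; v < kVec; ++v) {
    smem[threadIdx.y][threadIdx.x * kVec + v] = acc[v];
  }
  __syncthreads();

  // 3) Loop over the kVec elements, reading them from shmem and writing the
  //    result to global memory. (Here a single row of threads combines the
  //    block's partials serially, which keeps the sketch short.)
  if (threadIdx.y == 0) {
    for (int v = 0; v < kVec; ++v) {
      float sum = 0.f;
      for (int y = 0; y < kBlockY; ++y) {
        sum += smem[y][threadIdx.x * kVec + v];
      }
      // atomicAdd combines contributions when gridDim.y > 1, i.e. when the
      // reduced dimension is split across blocks.
      if (col0 + v < cols) atomicAdd(&out[col0 + v], sum);
    }
  }
}

Splitting the reduced dimension across blocks (gridDim.y > 1) is what raises the number of active SMs, at the cost of the atomic combine and the separate output initialization discussed elsewhere in this thread.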
@lingzhi98
Contributor Author

  1. This line is weird. It will lead to tiled_size [x, y, z, 2, 1] if a row reduction enables vectorization. Maybe that is not what you expected.
  2. The vector dialect is converted to LLVM vector types after lowering. I am not sure whether that is well supported by SPIR-V.
  3. I am looking for a new vectorization heuristic, because I find that vectorization leads to performance regressions for some real-world model workloads. You can keep the current vectorization heuristic for now.
  4. So this PR will only keep the atomic part and wait for updates to the cost model PR, right?

@jreiffers
Member

  1. Yes, that's working as intended. It's explained in the comment above: the last dimension is for the number of independent reduction results computed per thread. In a row reduction, that's 1.
  2. Good to know, but I only have access to Nvidia hardware, and that's all I can test with, unfortunately.
  3. That's great, I'm aware the heuristic I put in is not good, so improvements would be very much appreciated.
  4. I didn't even notice you implemented atomics, sorry. Currently we're rewriting all reductions to not require atomics in tree_reduction_rewriter. If you know of cases where that leads to bad performance, that would be great to know.

I'm currently finishing the last indexing optimizations, so I think we should be in launchable state soon.

@jreiffers
Member

@olegshyshkov Can you give an update about the cost model changes?

@lingzhi98
Contributor Author

This atomic is not the same as the one in tree_reduction_rewriter. You can see the HLO below:
fused_reduce {
param_1.15 = bf16[1,2048]{1,0} parameter(1)
bitcast.86.8 = bf16[2048]{0} bitcast(param_1.15)
convert.90.5 = f32[2048]{0} convert(bitcast.86.8)
broadcast.6.6 = f32[2048,256]{1,0} broadcast(convert.90.5), dimensions={0}, metadata={op_name="jit(func)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=bfloat16]"}
param_0.29 = bf16[2048,256]{1,0} parameter(0)
convert.83.3 = f32[2048,256]{1,0} convert(param_0.29)
multiply.8.3 = f32[2048,256]{1,0} multiply(broadcast.6.6, convert.83.3)
constant_9 = f32[] constant(0)
reduce.5 = f32[256]{0} reduce(multiply.8.3, constant_9), dimensions={0}, to_apply=scalar_add_computation
param_2.12 = bf16[2048,256]{1,0} parameter(2)
convert.87.3.clone.1 = f32[2048,256]{1,0} convert(param_2.12)
multiply.9.3.clone.1 = f32[2048,256]{1,0} multiply(broadcast.6.6, convert.87.3.clone.1)
reduce.1.1.clone.1 = f32[256]{0} reduce(multiply.9.3.clone.1, constant_9), dimensions={0}, to_apply=scalar_add_computation
ROOT tuple = (f32[256]{0}, f32[256]{0}) tuple(reduce.5, reduce.1.1.clone.1)
} // fused_reduce
The current implementation will only launch 256 / (32 or 64) = 8 or 4 blocks, which cannot occupy the full device (far fewer than the 108 SMs on an A100). Using atomics lets us launch more blocks and reach better performance: tested on an A100, this gives a 2x improvement. As far as I know, tree_reduction_rewriter splits one reduction into two reductions when the reduced dimension is large in order to avoid atomic overhead, so it may not serve the same purpose.

@jreiffers
Member

Ah, gotcha. This optimization doesn't currently exist in XLA, as far as I know. Is the implementation already pushed? I don't see it in the diff for this PR.

@lingzhi98
Contributor Author

The code is here. We only need to change tile_y; nothing else is required.

@jreiffers
Member

But this won't automatically emit atomics. What am I missing?

@lingzhi98
Contributor Author

If the reduction is not race-free, the emitter will emit atomics.

@jreiffers
Member

Ah, you're talking about the legacy emitter. I don't think we want to make any more changes to that one.

@lingzhi98
Contributor Author

lingzhi98 commented May 23, 2024

As far as I know, the MLIR reduction emitter will also support atomics in the future, so this change will still be useful. And I don't make any changes to the legacy LLVM emitter; I just reuse the previous implementation.

@jreiffers
Member

The legacy emitter is very brittle. Changing a race-free reduction to an atomic one is not generally safe, since there might be an epilogue fusion. I asked @pifon2a to look into whether this transformation can be done as part of tree_reduction_rewriter.

@lingzhi98
Contributor Author

lingzhi98 commented May 23, 2024

I have checked whether there is an epilogue (https://github.com/openxla/xla/pull/11018/files#diff-78597727ea860fa53ae0dfcaa961c8b1fd1f2a4f94297def7676586ed5db158eR163). I don't think there is anything that can be done in the tree reduction rewriter, but we can wait for his conclusion.

@olegshyshkov
Contributor

@olegshyshkov Can you give an update about the cost model changes?

I've prepared a Cost Model PR: #13044 based on the idea from #12208.

@lingzhi98
Contributor Author

Is there any update on the atomic part? If we don't want to involve atomics, maybe we can do something like the following:
(1) add hlo pass to split reduce:
before hlo pass:
param = [2048, 256]
reduce = [256] reduce(param)
after hlo pass:
param = [2048, 256]
slice0 = [1024, 256] slice(param)
slice1 = [1024, 256] slice(param)
reduce0 = [256] reduce(slice0)
reduce1 = [256] reduce(slice1)
concat = [2, 256] concat(reduce0, reduce1)
reduce2 = [256] reduce(concat)
(2) hope to get the fusion below after the fusion pipeline:
fusion.0 {
param = [2048, 256]
slice0 = [1024, 256] slice(param)
slice1 = [1024, 256] slice(param)
reduce0 = [256] reduce(slice0)
reduce1 = [256] reduce(slice1)
}
fusion.1 {
param0 = [256]
param1 = [256]
concat = [2, 256] concat(param0, param1)
reduce = [256] reduce(concat)
}
(3) make a small change to GroupDisjointReductions to put each slice in a different group, so that we can also launch more blocks.

@jreiffers
Member

I thought about it a bit more and I see the problem now. We'll probably have to implement atomics then.

@cheshire
Member

Sorry I'm a bit confused, where are the atomics coming from? We strongly need to ensure determinism, so atomics for addition aren't great.

@jreiffers
Member

Right, then this particular optimization simply can't be done (at least with the default flags). For min/max we could do it.

@jreiffers
Member

Uh, actually, the tree rewriter idea should work. E.g.:

2048x256 -> 256x32 (up to 64x8 blocks) -> 256

We'll still want to adjust the launch grid for the first reduction. The second reduction would be a row reduction.

@lingzhi98
Contributor Author

lingzhi98 commented May 31, 2024

Sorry I'm a bit confused, where are the atomics coming from? We strongly need to ensure determinism, so atomics for addition aren't great.

The atomics come from this idea (#11018 (comment)). But if we want to ensure determinism, splitting the reduction into multiple smaller reductions may be better.

@lingzhi98
Contributor Author

Uh, actually, the tree rewriter idea should work. E.g.:

2048x256 -> 256x32 (up to 64x8 blocks) -> 256

We'll still want to adjust the launch grid for the first reduction. The second reduction would be a row reduction.

So OpenXLA has a plan to support this optimization pattern, right? If so, maybe I can close this PR and wait for the OpenXLA update.

@jreiffers
Member

@pifon2a is still looking into this, AFAIU.

@lingzhi98
Contributor Author

Finally, I want to understand one thing: do you and I have the same idea (#11018 (comment))?
How do we ensure that the fusion is what we want, as shown in the second point? It doesn't matter if you don't have a clear answer right now.

@jreiffers
Member

I think the ideas are different. My idea is essentially this:

2048,256 -> bitcast -> 32,64,256 -> reduce -> 256,32 -> reduce -> 256

So no slicing or concat.

@lingzhi98
Contributor Author

Thanks, got it.

@lingzhi98 lingzhi98 closed this May 31, 2024
@lingzhi98 lingzhi98 deleted the lingzhi/optimize_column_reduction branch June 19, 2024 13:34