CUDA: add fast walsh-hadamard transform#23615
Conversation
|
|
||
| cudaStream_t stream = ctx.stream(); | ||
| dim3 grid_dims(num_blocks, 1, 1); | ||
| dim3 block_dims(WARP_SIZE, rows_per_block, 1); |
There was a problem hiding this comment.
| dim3 block_dims(WARP_SIZE, rows_per_block, 1); | |
| dim3 block_dims(WARP_SIZE, rows_per_block, 1); // TODO support for warp size 64 |
Unless you want to implement it in this PR. It would need a bit of extra logic for warp size selection due to potential out-of-bounds memory accesses for e.g. head size 96.
There was a problem hiding this comment.
I think the code would only pass pow-of-2 N here
There was a problem hiding this comment.
Oh right, in that case it should be unproblematic to use a warp size of 64. It should be a simple change so it would make sense to include it from the get-go - I'll push a quick commit.
|
Sorry, I accidentally pushed to the wrong branch. I'm currently still prototyping because the performance impact on CDNA seems to be negative both with the original kernel and with the warp size 64 patch I made. |
22c360f to
6ee12a2
Compare
|
That maybe due to register spilling I guess |
|
Sorry, the previous report was wrong. I had accidentally swapped the commits when I compared the performance so I incorrectly thought the code had gotten slower. This is the correct performance: Performance
LLaMA 3 1b has a head size of 64, LLaMA 3 8b a head size of 128, Gemma 2b 256, Gemma 4 26b 512. All tests are done with |
|
I forgot: on the MI100 I implemented a warp size of 64 but this only provided a speedup of like 1%, the bulk of the speedup comes from the work of @am17an . |
This reverts commit c1f1e28.
|
Repro / narrow down: ./build/bin/llama-completion -m ../gemma-4-E2B-it-UD-Q8_K_XL.gguf -ngl 999 -ctk q8_0 -ctv q8_0 -fa on --jinja -p "salut" -n 32 --temp 0 -s 1 ./build/bin/llama-completion -m ../gemma-4-E2B-it-UD-Q8_K_XL.gguf -ngl 999 -fa on --jinja -p "salut" -n 32 --temp 0 -s 1 So it's the KV cache quantization that triggers the bug! |
Overview
Implement FWHT for CUDA, speed-up for cases when we quantize the kv-cache.
Performance on a 5090 with
-ctk q8_0 -ctv q8_0Additional information
Requirements