gfx1151 nwarps, tile sizing to curb VGPR pressure #21344
pedapudi wants to merge 3 commits into ggml-org:master from
Conversation
Strangely, there's been no improvement on my machine. Is there any other information required?
@IIIIIllllIIIIIlllll thanks for looking. I'm also on Ubuntu 25.10 using ROCm 7.2.1. I'm surprised that you aren't seeing any improvement. I don't know how you applied the patch, built llama.cpp, or how you tested. I'll update the PR with my cmake flags.
Here is my command:
The model: https://huggingface.co/Ex0bit/Qwen3.5-122B-A10B-PRISM-LITE-GGUF
My testing method: I started the model using
I tested this PR with more models (Arch Linux 6.19, ROCm 7.2.1, Vulkan 1.4.341): Mistral-Small-4-119B-2603-UD-Q4_K_XL, Qwen3.5-122B-A10B-UD-Q4_K_XL, and NVIDIA-Nemotron-3-Super-120B-A12B-UD-Q4_K_XL. (I added Vulkan since I was also curious about the ROCm/Vulkan comparison.) For ROCm, this PR performs consistently better for the three tested models, Mistral, Qwen3.5, and NVIDIA-Nemotron (with the exception of -1% for Qwen, which may also be noise). In summary (ROCm before/after):
Mistral-Small-4-119B (ROCm)
Qwen3.5-122B-A10B (ROCm)
NVIDIA-Nemotron-3-Super-120B-A12B (ROCm)
Master (before): This PR (after): And just for curiosity, the Vulkan/ROCm comparison (after this PR): ROCm vs Vulkan highlights/lowlights (after PR)
@tbocek Thank you for doing the cross-model testing! I'm glad the PR shows uplift across models. I have some questions about your empirical numbers, which look lower than what I see with llama-bench for the Qwen3.5 122B model. I re-ran llama-bench with Unsloth's UD Q4_K_XL to match your run, and I see:

BEFORE

$ ./bin/llama-bench --model /home/sunil/models/unsloth/qwen35-122b-ud-q4_k_xl/unsloth_Qwen3.5-122B-A10B-GGUF_UD-Q4_K_XL_Qwen3.5-122B-A10B-UD-Q4_K_XL-00001-of-00003.gguf -p 128,256,512,1024,2048,4096 -n 0 --n-gpu-layers 99 --flash-attn 1 --mmap 0 --direct-io 1 --ubatch-size 2048 --batch-size 2048 -r 5
ggml_cuda_init: found 1 ROCm devices (Total VRAM: 126976 MiB):
Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32, VRAM: 126976 MiB
| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --: | --------------: | -------------------: |
| qwen35moe 122B.A10B Q4_K - Medium | 71.73 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp128 | 181.31 ± 4.92 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.73 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp256 | 269.33 ± 3.56 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.73 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp512 | 355.28 ± 1.51 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.73 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp1024 | 418.42 ± 1.23 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.73 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp2048 | 429.09 ± 5.65 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.73 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp4096 | 405.33 ± 2.56 |
build: 7c7d6ce5c (8642)

AFTER

$ ./bin/llama-bench --model /home/sunil/models/unsloth/qwen35-122b-ud-q4_k_xl/unsloth_Qwen3.5-122B-A10B-GGUF_UD-Q4_K_XL_Qwen3.5-122B-A10B-UD-Q4_K_XL-00001-of-00003.gguf -p 128,256,512,1024,2048,4096 -n 0 --n-gpu-layers 99 --flash-attn 1 --mmap 0 --direct-io 1 --ubatch-size 2048 --batch-size 2048 -r 5
ggml_cuda_init: found 1 ROCm devices (Total VRAM: 126976 MiB):
Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32, VRAM: 126976 MiB
| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --: | --------------: | -------------------: |
| qwen35moe 122B.A10B Q4_K - Medium | 71.73 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp128 | 314.28 ± 5.58 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.73 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp256 | 411.32 ± 3.45 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.73 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp512 | 488.98 ± 2.14 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.73 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp1024 | 442.81 ± 1.63 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.73 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp2048 | 553.57 ± 5.64 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.73 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp4096 | 494.13 ± 3.09 |
build: e0e3c3fc6 (8643)

(Note the differing build IDs, but they are not material here.) Both at baseline and with this PR, my prefill t/s is higher than yours. It's possibly worth normalizing for whatever differences exist between the setups, but it seems your findings corroborate some of the relative improvement with the PR. The falloff you are noticing at higher lengths seems to have a different slope, though.
@pedapudi I just realized I had different build flags: I was using ROCWMMA_FATTN=ON. Re-running Qwen3.5-122B-A10B-UD-Q4_K_XL with ROCWMMA_FATTN=OFF, the numbers are closer to yours. Summary:
before PR: and after PR:
@tbocek So glad that you were able to identify the discrepancy! Thanks again for testing.
@IMbackK @JohannesGaessler does this look okay? I don't have the hardware to test, but a lot of people seem to have confirmed it looks good.
I have yet to find the time for a proper review, but on a surface level I'm not convinced by the current state of this. For starters, this is trying to solve register pressure by tuning the values for RDNA 3.5 using gfx1151 as a model, but other RDNA 3.5 GPUs have different-size register files.
Thank you, @am17an and @IMbackK.
You may be right, and it's a reasonable point to raise here, but I would like to offer a few things for you to consider before holding back on incremental improvements (within the current structure of the code, at least):
As I mentioned in my original issue, there is no static sizing that's going to work across architectures. Sidebar: one idea I'd love for llama.cpp is to have a
What are your intentions with the changes to mmvq.cu? They look wrong.
I wanted to update the structure to support RDNA 3.5 more natively and not reuse preexisting configurations. I've reverted the changes for now to not distract from the other changes. Thanks!
Looking at this table, in my opinion the correct way to do it is to use the value for discrete RDNA3 GPUs for those APUs that only have 512 kiB of registers.
Well, right now everything is tuned for RDNA2, which is even worse. Incremental improvements always make sense, especially if they focus on the more popular cases (Strix Halo) over the less popular ones (Strix Point).
You mean RDNA2 dGPUs, presumably; the large-register-file RDNA3 dGPUs (gfx1100 and gfx1101) also have >512 kiB of registers.
Regarding register pressure we really only have 3 cases: gfx1100/gfx1101, gfx1151, and gfx12 with 768 32-wide vector registers. By the way, the table is wrong: the unit is not kiB, it's the number of vector registers, a single register being 1024 bits for RDNA and 2048 bits for GCN/CDNA.
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Thanks for the discussion so far. I'm a little confused as to where we are. Is the suggestion to abandon this change? If so, would someone else pick up realizing the performance opportunity? If not, what are the next steps? Thank you! (PS: that table indeed looks misleading :))
I strongly disagree with this notion
If the resulting kernels are indeed below 512 registers (see GGML_HIP_EXPORT_METRICS), I think this is conceptually fine and likely to be positive for the other gfx115x GPUs. Ideally someone would benchmark this, however. Further, it makes sense to try these values on the other RDNA devices, as they would also spill, and to widen the filter further; I will check this on gfx1100 soon. This doesn't have to be part of this PR.
Thank you, @IMbackK
Yes, reasonable people can disagree :) I respect your position, especially from a maintainer, that no hardware will be alienated. After all, that is the crux of this PR as well. For a later time: I think there is a structural opportunity in the implementation to support different GPU architectures with less maintenance burden, e.g., organizing different config files, or an adaptive approach like the one my prior comment was conceptualizing, using a "probe the host hardware" step.
Thank you, this is all very reasonable. AFAIK we're targeting 256 VGPRs for gfx1151 (ideally slightly below, to leave room for system registers). I attached the remarks from GGML_HIP_EXPORT_METRICS (which has been a useful utility in validating some additional changes I'm trying outside this PR as well!). There is still some VGPR spill (especially with IQ quantization). At least with mmq_x, lowering it further (even for just Q8) does not seem to yield any improvement in performance and has high potential for regressions.
Performance changes
In my testing the performance changes from this PR are very inconsistent across batch sizes and data types. It cannot be merged like this.
@JohannesGaessler thank you for the sweep! I understand your position that the variance is not desirable. Asking naively: is the sweep impacted by the issue described in this PR: #21282? Let me know if you have suggestions on how to reduce the variance. Thanks. I'd also appreciate your eyes on the sweep @tbocek did, showing material benefit on more modern and larger models than Llama 8B.
Perhaps corroborating your findings: this PR appears to be a benefit for MoE models, but a wildcard for dense models (e.g., Qwen 3.5 9B). If this change is categorically not beneficial for dense models, there might be a path forward in simply reverting to the old values in the mmq code path when the model is dense. I don't quite know why this isn't beneficial (I have not had a chance to look more closely). What's your opinion?
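The dense-fallback idea amounts to a one-flag dispatch on the host side. A hypothetical sketch; the struct, the `is_moe` predicate, and the "old" dense values are illustrative, not llama.cpp's actual code (only the MoE-side numbers come from this PR):

```cpp
// Tile configuration a kernel launch would be specialized on
struct tile_cfg {
    int mmq_x_max;
    int mmq_y;
    int nwarps;
};

// Pick the RDNA 3.5 tile config based on model type. is_moe could be
// derived from the loaded model's expert count (n_expert > 1).
constexpr tile_cfg rdna35_cfg(bool is_moe) {
    return is_moe ? tile_cfg{48, 64, 4}   // this PR's RDNA 3.5 values
                  : tile_cfg{64, 64, 8};  // placeholder "old" dense values
}
```

The cost of such a split is that both kernel variants must be compiled, so binary size and build time grow; that trade-off would need weighing against the dense-model regressions.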
👋 In all these benchmarks, main is
First, here are some other benchmark results with 3 MoE models at various context sizes: Gemma4 26B-A4B, Nemotron3-Super, and GLM4.7-Flash. They are looking great! Then, here are some results with Qwen 3.5 9B; I've tried to use roughly the same params you were using in the Llama 8B benchmarks you shared earlier, @JohannesGaessler, varying the quants and the ubatch size. I'm wondering why you're interested in the different ubatch sizes, if you have some pointers? HTH
Tested this on a Framework Desktop (Ryzen AI MAX+ 395, Radeon 8060S / gfx1151, 128 GB RAM, Fedora 43, ROCm 7.2.1) with Qwen3.5-122B-A10B-REAP-20 Q6_K. I used Codex to build a patched ROCm container from the PR diff and compared it against the stock ROCm
So at least on gfx1151 + a large MoE model, this looks very real and very useful. |
I have added a toolbox with this PR and run the benchmark: https://kyuz0.github.io/amd-strix-halo-toolboxes/ The benefits seem to be mostly for short context; if you switch to the 30k-context tests, the results do not look great: sometimes better, sometimes worse, but not by much either way.
Applies the six-edit ggml-cuda/mmq.cuh change from upstream PR ggml-org#21344 (pedapudi/llama.cpp@gfx1151-opt) that gives RDNA 3.5 its own MMQ tile and warp sizing — mmq_x_max=48, mmq_y=64, nwarps=4 — instead of inheriting the discrete RDNA3 values tuned for 7900 XTX-class hardware. Hypothesis, expected numbers (from kyuz0's independent A/B logs), and bench plan in strix-halo/mmq-rdna3_5.md. Also includes the previously uncommitted docs (codex-insights, rocm-config) and updates to NOTES, README, uma-integrated reflecting the UMA-deprioritization decision, plus a .gitignore entry for useful-repos/. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Follow up on issue #21284
1. MMQ:
   a. mmq_x_max=48, mmq_y=64, nwarps=4 for RDNA3_5, to balance VGPR usage and occupancy.
   b. Note: I took the opportunity for a minor refactor replacing nested ternary operators, to improve readability and reduce opportunity for errors (especially after I made a mistake while piling on the ternary operations).
2. mmvq: added RDNA3_5 to mmvq_parameter_table_id instead of falling back to RDNA2.
   a. Falling back to RDNA2 results in the nwarps calculation falling to 1.
1 is more important than 2, but 2 is still helpful on the mmvq paths. And it sets up for future per-quant tuning.
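To make the nested-ternary refactor concrete, here is a sketch of the shape such a selection could take — an explicit per-architecture switch instead of chained ternaries. This is not the actual mmq.cuh code; only the RDNA 3.5 numbers (mmq_x_max=48, mmq_y=64, nwarps=4) come from this PR, and the other entries are placeholders:

```cpp
// Hypothetical arch enum; llama.cpp's real code keys off compile-time
// __gfx*__ defines instead.
enum class amd_arch { rdna2, rdna3, rdna3_5 };

struct mmq_config {
    int mmq_x_max;
    int mmq_y;
    int nwarps;
};

// One switch per architecture is easier to audit and extend than a chain
// of nested ternaries: each case is a visible, self-describing row.
constexpr mmq_config pick_mmq_config(amd_arch arch) {
    switch (arch) {
        case amd_arch::rdna3_5: return {48, 64, 4};  // this PR's values
        case amd_arch::rdna3:   return {64, 64, 8};  // placeholder
        case amd_arch::rdna2:
        default:                return {64, 64, 8};  // placeholder
    }
}
```

Because the function is constexpr, the compiler still folds the selection away for a fixed architecture, so the readability gain costs nothing at runtime.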
Benchmarks
Built with cmake flags
Before (build 7c7d6ce / 8642)
After (build 955df3551 / 8643)
Speedup
Requirements