Conversation

@jiachengjason
Contributor

Enabled WMMA-MMQ kernels for RDNA 4 architecture on AMD GPUs

Following a similar approach to #14624.

Used ./build/bin/llama-bench to collect the following performance results.

Performance results were collected on ggml/llama.cpp master up to and including commit 5b180c3.

Build command for the following performance results:
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" cmake -S . -B build -DGGML_HIP=ON -DGGML_CUDA_FORCE_MMQ=OFF -DGGML_HIP_UMA=OFF -DGGML_HIP_ROCWMMA_FATTN=OFF -DGPU_TARGETS="gfx1201" -DGGML_HIP_GRAPHS=OFF -DLLAMA_CURL=OFF -DGGML_CUDA_FORCE_CUBLAS=OFF -DCMAKE_BUILD_TYPE=Release && cmake --build build --config Release -- -j 32

[image: llama-bench results for the build above (rocWMMA FlashAttention off)]

Build command for the following performance results:
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" cmake -S . -B build -DGGML_HIP=ON -DGGML_HIP_UMA=OFF -DGGML_HIP_ROCWMMA_FATTN=ON -DGPU_TARGETS=gfx1201 -DGGML_HIP_GRAPHS=OFF -DLLAMA_CURL=OFF -DGGML_CUDA_FORCE_CUBLAS=OFF -DCMAKE_BUILD_TYPE=Release && cmake --build build --config Release -- -j 32

[image: llama-bench results for the build above (rocWMMA FlashAttention on)]

@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Nov 11, 2025
@JohannesGaessler
Collaborator

Can you give me a quick summary of what you would consider to still be missing from this PR for it to be ready for review?

@jiachengjason jiachengjason marked this pull request as ready for review November 11, 2025 18:15
@jiachengjason
Contributor Author

Can you give me a quick summary of what you would consider to still be missing from this PR for it to be ready for review?

Hi, I've opened it up for review now, thanks!

Contributor

I think GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 is not used anymore

Collaborator

It's not.

Collaborator

@IMbackK IMbackK left a comment

@jiachengjason As mentioned by @slojosic-amd, there is an accidental change in the CMake file.
You are also changing the permissions of build-xcframework.sh by accident with this PR. Please revert these changes.

Since you improved the performance of MMQ on RDNA4, you should also change

return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
to increase the use of these kernels. As the performance difference between this and BLAS seems quite uneven, I would recommend doing so on a per-datatype and shape basis, as is done for MFMA. For CDNA I just selected whatever is faster, but for RDNA4 I would err on the side of using MMQ unless the performance difference is large, as consumer RDNA4 devices have only 16 GiB of VRAM, making the size of the dequant buffers more relevant.
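
A minimal sketch of how RDNA4 could be carved out into its own per-datatype/shape decision; rdna4_mmq_preferred is a hypothetical helper named only for illustration, not existing code:

if (GGML_CUDA_CC_IS_RDNA4(cc)) {
    // hypothetical per-datatype/shape check, to be filled in based on benchmarks
    return rdna4_mmq_preferred(type, ne11);
}
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;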

@IMbackK
Collaborator

IMbackK commented Nov 12, 2025

Oh, and using scripts/compare-llama-bench.py with llama-bench [...] -o sql | sqlite3 llama-bench.sqlite, as well as test-backend-ops perf, provides more readable ways to compare performance changes - useful for finding the right shapes to enable.

@IMbackK IMbackK self-assigned this Nov 12, 2025
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
option(GGML_HIP_MMQ_WMMA "ggml: enable WMMA MMA for RDNA4 in MMQ" ON)
Collaborator

For now this is fine but long-term, after the kernels have been fully optimized and tested per datatype, it would be preferable to re-use the FORCE_CUBLAS and FORCE_MMQ options.

Collaborator

@IMbackK IMbackK Nov 12, 2025

This does not replace those, but makes it use the dp4a MMQ kernels instead. I added this for CDNA/MFMA because it allows testing for GCN performance regressions on CDNA. Similarly, this allows testing for RDNA1/2 performance regressions on RDNA4.

I would prefer this to be kept.

EDIT: I guess testing for RDNA1/2 performance on RDNA4 is less useful than testing for GCN performance on CDNA, as RDNA4 has more VGPRs and some new VALU instructions compared to RDNA1/2, unlike CDNA/GCN, which have fewer differences.

Collaborator

I added this for CDNA/MFMA because it allows testing for GCN performance regressions on CDNA.

My experience so far has been that the portability of performance across GPUs is so poor that something like this is of little utility. In the rare cases where emulating old hardware is needed, one should just edit the code temporarily. If options like this are exposed to users, they are going to use them, and that increases the amount of work that needs to be put into maintenance. So long-term I still intend to remove those options. My current AMD lineup consists of RDNA2, RDNA3.5, RDNA4, GCN5.1, and CDNA1, and in the coming months I intend to add RDNA3 and CDNA2. I would just test the performance using those GPUs directly.

Collaborator

@IMbackK IMbackK Nov 12, 2025

Not everyone has a huge selection of hardware to choose from. Across GCN5.1/gfx906 and CDNA, in my experience the performance portability is extremely close. This is no surprise, as the changes made in CDNA that are relevant to ggml are very slight:

  1. MFMA was added, with a special 256-wide register file usable only by these instructions and by loads and stores.
  2. An instruction was added to load from global memory directly into LDS, but the compiler doesn't generate it.

The only practical difference in the generated assembly is that under register pressure the compiler will spill to the MFMA register space instead of scratch memory, which very slightly reduces the cost of spills.
The CUs themselves are also extremely similar, and cache, local memory, and global memory latency are essentially unchanged.

The picture changes only slightly with CDNA2, where the physical (but not logical) register space between the VALU and MFMA instructions is now shared, meaning the minimum occupancy for a VALU kernel allocating all 256 registers is 2, and packed 32-bit instructions were added; but again, in my experience the performance on CDNA2 predicts the performance on GCN extremely closely.

I don't have much experience with RDNA, and it's true that the changes between RDNA generations are larger.

Collaborator

In any case, to my knowledge we don't have anyone who would be using the GGML_HIP_MMQ_WMMA option with the intent you laid out, so it should be removed. I fundamentally don't want to add extra compilation options unless there is a good reason for them, because each one is one more variable that potentially needs to be accounted for in bug reports.

Collaborator

Why are you changing this?

Contributor Author

It was a mistake; it is reverted now. Thanks!

Collaborator

There are still changes to this file.

Comment on lines 231 to 233
#if defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_MMQ_WMMA)
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_MMQ_WMMA)
Collaborator

This is going to be in conflict with #17077. Instead of making the availability of AMD WMMA contingent on GGML_HIP_NO_MMQ_WMMA, check that macro in mmq.cuh to decide whether or not to use AMD WMMA.
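
A minimal sketch of that split, assuming the macro names from the diff above (the mmq.cuh gate shown here is only an illustration, not the final code):

// common.cuh: availability depends on the hardware only
#if defined(GGML_USE_HIP) && defined(RDNA4)
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4)

// mmq.cuh: the compilation option only gates the use of WMMA in MMQ
#if defined(AMD_WMMA_AVAILABLE) && !defined(GGML_HIP_NO_MMQ_WMMA)
// ... take the WMMA code path ...
#endif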

}

if (amd_mfma_available(cc)) {
if (amd_mfma_available(cc)||amd_wmma_available(cc)) {
Collaborator

Make a separate branch for AMD WMMA instead. Since as of right now it's an explicit opt-in via a compilation option, simply return true if and only if the cc is RDNA4 and the compilation option has been enabled.
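
A minimal sketch of the separate branch as described; its placement in the existing selection function and the surrounding code are assumptions:

if (amd_wmma_available(cc)) {
    // WMMA MMQ is currently an explicit opt-in on RDNA4, so simply use it when available.
    return true;
}
if (amd_mfma_available(cc)) {
    // the existing MFMA/CDNA selection logic stays unchanged here
}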

Collaborator

@IMbackK IMbackK Nov 12, 2025

@JohannesGaessler I think you misunderstood what the compilation option is supposed to achieve.

Collaborator

@IMbackK IMbackK Nov 12, 2025

@jiachengjason Formatting: it should be if (amd_mfma_available(cc) || amd_wmma_available(cc)) (with spaces around ||).

static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE)
Collaborator

Suggested change
#if defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE)

Comment on lines 157 to 177
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 4) {
return threadIdx.x % 32;
} else if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 32) {
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 4) {
return 2 * (threadIdx.x / 32) + l;
} else if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 32) {
return threadIdx.x % 32;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
Collaborator

Please define only the actually used shapes of 16x4 and 16x16 and use NO_DEVICE_CODE instead of a static assert as was recently changed in the surrounding code.
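
A minimal sketch of the requested structure, reusing only the 16x16 mapping from the code above; the 16x4 mapping is omitted here and would be added analogously:

static __device__ __forceinline__ int get_i(const int l) {
    if constexpr (I == 16 && J == 16) {
        return 8 * (threadIdx.x / 16) + l;
    } else {
        // the 16x4 case would go here; all other shapes are unsupported
        NO_DEVICE_CODE;
        return -1;
    }
}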

Comment on lines 714 to 817

#elif defined(AMD_WMMA_AVAILABLE)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;

using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;

#if defined(RDNA4)

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
true,
b_vec[0],
acc[0],
true
);

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[1],
true,
b_vec[1],
acc[0],
true
);
#endif // defined(RDNA4)

Collaborator

This is to my understanding currently unused, so please remove it.

Contributor Author

Hi, I believe this is used in the vec_dot_q8_0_q8_1_mma function, which is called for Q4_0, Q5_0, Q8_0, MXFP4, etc.

Collaborator

In that function I'm only seeing 16x4 and 16x16 tiles, not 16x8.

Contributor Author

@jiachengjason jiachengjason Nov 14, 2025

For example, in the Q4_0 case, execution flows into vec_dot_q8_0_q8_1_mma in mmq.cuh. Inside the vec_dot_q8_0_q8_1_mma function, tile A and tile B are shaped as 16×8 blocks. These tiles are forwarded to the mma function (the one shown here), where they are processed by the WMMA instructions.

// performs better but is currently suffering from a crash on this architecture.
// TODO: Revisit when hipblaslt is fixed on CDNA3
if (GGML_CUDA_CC_IS_CDNA3(cc)) {
if (GGML_CUDA_CC_IS_CDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
Collaborator

Could we have test-backend-ops perf -o MUL_MAT results for this PR and for master, to better see whether always enabling this is the way to go?

Contributor Author

Here are the results from test-backend-ops perf -o MUL_MAT when FA is not enabled, for n > 8:

[images: test-backend-ops MUL_MAT results]

Collaborator

What about other batch sizes? Could we have just the full raw output for both cases?


Contributor Author

I don't think we can enable -fa 1 when running test-backend-ops perf -o MUL_MAT?

Collaborator

No, but your PR doesn't change the FA kernels, so the only change to the relative performance between FA on and off is that the non-FA path does include matmuls.

Collaborator

A thing to check would be batch sizes greater than 512, but you can't do that with test-backend-ops; something like the following can be used for that:

$ git checkout master
$ ...build...
$ llama-bench [...] -ub 128,256,512,1024,2048 -n 0 -p 2048 -o sql|sqlite3 llama-bench.sqlite
$ git checkout pr
$ ...build...
$ llama-bench [...] -ub 128,256,512,1024,2048 -n 0 -p 2048 -o sql|sqlite3 llama-bench.sqlite
$ python scripts/compare-llama-bench.py -b master -c pr

Contributor Author

Here are the results I gathered:
[images: compare-llama-bench.py comparison results]

return GGML_CUDA_CC_IS_RDNA4(cc);
}

// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
Collaborator

Suggested change
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.

int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
} else {
Collaborator

Guard this branch with an explicit check for the expected shape and put NO_DEVICE_CODE into the else branch at the end.
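
A minimal sketch of that guard; the concrete shape condition is a placeholder, since the expected shape is not stated here:

if constexpr (t.I == 16 && t.J == 8) { // placeholder: explicit check for the expected shape
    int64_t       * xi = (int64_t *) t.x;
    const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
    xi[0] = xs[0];
} else {
    NO_DEVICE_CODE;
}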

}

if (amd_mfma_available(cc)) {
if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
Collaborator

Make a separate branch for the AMD WMMA selection logic that for now simply returns true. Prior to merging this PR we should then test the performance as a function of data type and tensor shape and decide how exactly to do the selection logic.

Collaborator

There are still changes to this file.

@JohannesGaessler
Collaborator

I'm currently traveling and I'm unfortunately unable to start my machine with an RDNA4 GPU remotely via wake-on-LAN. I'll be back on Saturday and I'll test the performance then.

@IMbackK
Collaborator

IMbackK commented Nov 19, 2025

OK, so overall I would say that we should make the should-use-MMQ logic do the following (sketched after the list):

  1. return false for q2_k, q3_k, q5_k, iq2_xs, iq2_s
  2. return true for mxfp4, Q8_0
  3. return true for batch sizes <=1024 for all except those in 1.
  4. return false for batch sizes > 1024 except those in 2.
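
A minimal sketch of what these rules could look like as a helper, assuming ggml's type enum names; the function name and its placement are hypothetical:

static bool rdna4_should_use_mmq(const ggml_type type, const int64_t ne11) {
    switch (type) {
        // 1. always prefer the dequant path:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
            return false;
        // 2. always prefer MMQ:
        case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q8_0:
            return true;
        // 3./4. everything else: MMQ only up to a batch size of 1024:
        default:
            return ne11 <= 1024;
    }
}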

@JohannesGaessler
Collaborator

Performance
GPU Model Microbatch size Test t/s b6968 t/s 9cdf36e Speedup
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 1 pp2048 69.25 69.23 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 2 pp2048 121.22 120.71 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 4 pp2048 187.76 187.85 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 8 pp2048 207.71 207.80 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 16 pp2048 484.95 563.50 1.16
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 32 pp2048 596.67 756.66 1.27
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 64 pp2048 179.19 1274.58 7.11
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 128 pp2048 315.04 1704.01 5.41
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 256 pp2048 335.42 1838.65 5.48
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 512 pp2048 346.57 1887.70 5.45
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 1024 pp2048 353.25 1920.43 5.44
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 2048 pp2048 354.21 1901.36 5.37
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 1 pp2048 56.67 56.55 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 2 pp2048 97.84 97.50 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 4 pp2048 155.61 154.45 0.99
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 8 pp2048 208.24 206.86 0.99
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 16 pp2048 395.11 349.52 0.88
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 32 pp2048 512.88 645.99 1.26
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 64 pp2048 178.76 1161.13 6.50
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 128 pp2048 314.98 1496.31 4.75
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 256 pp2048 336.68 1616.53 4.80
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 512 pp2048 348.16 1662.30 4.77
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 1024 pp2048 355.25 1680.76 4.73
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 2048 pp2048 356.36 1657.48 4.65
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 1 pp2048 58.51 58.41 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 2 pp2048 100.19 99.59 0.99
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 4 pp2048 153.89 153.78 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 8 pp2048 208.55 208.14 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 16 pp2048 399.64 350.13 0.88
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 32 pp2048 513.08 651.42 1.27
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 64 pp2048 178.86 1141.63 6.38
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 128 pp2048 314.72 1442.75 4.58
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 256 pp2048 336.23 1551.07 4.61
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 512 pp2048 348.22 1594.66 4.58
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 1024 pp2048 355.48 1625.77 4.57
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 2048 pp2048 357.08 1609.75 4.51
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 1 pp2048 51.14 51.08 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 2 pp2048 92.62 92.61 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 4 pp2048 154.33 154.02 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 8 pp2048 180.80 180.57 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 16 pp2048 390.86 443.82 1.14
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 32 pp2048 523.94 576.05 1.10
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 64 pp2048 179.31 1158.43 6.46
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 128 pp2048 315.85 1707.58 5.41
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 256 pp2048 336.98 1846.91 5.48
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 512 pp2048 348.59 1894.87 5.44
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 1024 pp2048 356.03 1933.77 5.43
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 2048 pp2048 357.59 1913.28 5.35
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 1 pp2048 47.20 47.15 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 2 pp2048 86.78 86.73 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 4 pp2048 158.41 158.34 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 8 pp2048 179.87 179.83 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 16 pp2048 382.16 413.24 1.08
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 32 pp2048 519.29 589.77 1.14
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 64 pp2048 176.43 1169.05 6.63
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 128 pp2048 309.44 1735.75 5.61
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 256 pp2048 331.73 1868.42 5.63
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 512 pp2048 343.66 1921.32 5.59
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 1024 pp2048 351.05 1946.41 5.54
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 2048 pp2048 352.43 1910.72 5.42
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 1 pp2048 46.34 46.42 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 2 pp2048 86.81 87.13 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 4 pp2048 154.33 155.43 1.01
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 8 pp2048 176.71 177.10 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 16 pp2048 387.57 424.99 1.10
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 32 pp2048 522.41 610.31 1.17
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 64 pp2048 176.34 1163.72 6.60
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 128 pp2048 310.27 1749.13 5.64
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 256 pp2048 332.76 1884.11 5.66
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 512 pp2048 344.66 1932.79 5.61
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 1024 pp2048 352.08 1968.20 5.59
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 2048 pp2048 353.48 1928.23 5.45
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 1 pp2048 51.77 51.95 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 2 pp2048 92.84 93.07 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 4 pp2048 152.81 153.36 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 8 pp2048 181.30 181.77 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 16 pp2048 397.46 449.30 1.13
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 32 pp2048 529.20 600.71 1.14
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 64 pp2048 175.25 1211.05 6.91
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 128 pp2048 308.34 1786.84 5.80
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 256 pp2048 330.27 1930.81 5.85
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 512 pp2048 341.93 1988.71 5.82
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 1024 pp2048 349.24 2022.71 5.79
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 2048 pp2048 350.71 1982.03 5.65
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 1 pp2048 54.53 54.57 1.00
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 2 pp2048 94.68 94.67 1.00
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 4 pp2048 151.38 151.17 1.00
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 8 pp2048 184.31 184.52 1.00
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 16 pp2048 406.25 454.51 1.12
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 32 pp2048 542.71 614.05 1.13
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 64 pp2048 177.39 1224.91 6.91
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 128 pp2048 312.31 1788.60 5.73
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 256 pp2048 334.30 1941.26 5.81
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 512 pp2048 346.31 2003.03 5.78
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 1024 pp2048 353.27 2033.53 5.76
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 2048 pp2048 354.57 1993.13 5.62
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 1 pp2048 47.66 47.74 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 2 pp2048 90.98 91.13 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 4 pp2048 171.17 171.13 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 8 pp2048 197.40 196.92 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 16 pp2048 463.53 555.07 1.20
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 32 pp2048 587.90 784.33 1.33
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 64 pp2048 174.53 1331.95 7.63
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 128 pp2048 307.83 1955.61 6.35
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 256 pp2048 330.75 2113.76 6.39
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 512 pp2048 343.49 2178.24 6.34
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 1024 pp2048 351.17 2215.71 6.31
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 2048 pp2048 353.36 2174.19 6.15
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 1 pp2048 50.33 50.22 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 2 pp2048 95.61 95.55 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 4 pp2048 171.58 171.78 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 8 pp2048 210.43 211.05 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 16 pp2048 482.81 578.85 1.20
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 32 pp2048 600.36 776.78 1.29
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 64 pp2048 174.94 1378.66 7.88
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 128 pp2048 308.07 1973.03 6.40
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 256 pp2048 331.07 2133.50 6.44
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 512 pp2048 343.73 2196.40 6.39
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 1024 pp2048 351.35 2238.96 6.37
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 2048 pp2048 353.12 2204.32 6.24
RX 9060 XT llama 8B Q2_K_S 1 pp2048 66.15 66.29 1.00
RX 9060 XT llama 8B Q2_K_S 2 pp2048 102.84 102.55 1.00
RX 9060 XT llama 8B Q2_K_S 4 pp2048 128.41 128.04 1.00
RX 9060 XT llama 8B Q2_K_S 8 pp2048 144.22 143.92 1.00
RX 9060 XT llama 8B Q2_K_S 16 pp2048 414.83 348.74 0.84
RX 9060 XT llama 8B Q2_K_S 32 pp2048 414.07 488.63 1.18
RX 9060 XT llama 8B Q2_K_S 64 pp2048 176.21 762.09 4.32
RX 9060 XT llama 8B Q2_K_S 128 pp2048 311.92 914.99 2.93
RX 9060 XT llama 8B Q2_K_S 256 pp2048 334.82 875.72 2.62
RX 9060 XT llama 8B Q2_K_S 512 pp2048 347.59 896.05 2.58
RX 9060 XT llama 8B Q2_K_S 1024 pp2048 355.38 970.91 2.73
RX 9060 XT llama 8B Q2_K_S 2048 pp2048 357.46 968.77 2.71
RX 9060 XT llama 8B Q3_K_S 1 pp2048 49.42 49.80 1.01
RX 9060 XT llama 8B Q3_K_S 2 pp2048 78.13 77.52 0.99
RX 9060 XT llama 8B Q3_K_S 4 pp2048 114.46 113.37 0.99
RX 9060 XT llama 8B Q3_K_S 8 pp2048 142.05 140.69 0.99
RX 9060 XT llama 8B Q3_K_S 16 pp2048 404.35 482.63 1.19
RX 9060 XT llama 8B Q3_K_S 32 pp2048 507.08 700.31 1.38
RX 9060 XT llama 8B Q3_K_S 64 pp2048 171.80 1223.35 7.12
RX 9060 XT llama 8B Q3_K_S 128 pp2048 302.84 1643.20 5.43
RX 9060 XT llama 8B Q3_K_S 256 pp2048 329.55 1767.48 5.36
RX 9060 XT llama 8B Q3_K_S 512 pp2048 344.41 1822.14 5.29
RX 9060 XT llama 8B Q3_K_S 1024 pp2048 353.87 1853.60 5.24
RX 9060 XT llama 8B Q3_K_S 2048 pp2048 356.54 1836.60 5.15
RX 9060 XT llama 8B Q4_0 1 pp2048 48.02 48.16 1.00
RX 9060 XT llama 8B Q4_0 2 pp2048 92.18 92.70 1.01
RX 9060 XT llama 8B Q4_0 4 pp2048 171.45 172.24 1.00
RX 9060 XT llama 8B Q4_0 8 pp2048 202.46 202.86 1.00
RX 9060 XT llama 8B Q4_0 16 pp2048 480.42 561.16 1.17
RX 9060 XT llama 8B Q4_0 32 pp2048 613.28 766.81 1.25
RX 9060 XT llama 8B Q4_0 64 pp2048 176.38 1443.81 8.19
RX 9060 XT llama 8B Q4_0 128 pp2048 315.53 2062.85 6.54
RX 9060 XT llama 8B Q4_0 256 pp2048 339.53 2270.94 6.69
RX 9060 XT llama 8B Q4_0 512 pp2048 353.68 2349.09 6.64
RX 9060 XT llama 8B Q4_0 1024 pp2048 361.36 2413.84 6.68
RX 9060 XT llama 8B Q4_0 2048 pp2048 363.04 2389.10 6.58
RX 9060 XT llama 8B Q4_1 1 pp2048 45.80 45.79 1.00
RX 9060 XT llama 8B Q4_1 2 pp2048 86.96 86.54 1.00
RX 9060 XT llama 8B Q4_1 4 pp2048 163.59 163.04 1.00
RX 9060 XT llama 8B Q4_1 8 pp2048 215.88 215.60 1.00
RX 9060 XT llama 8B Q4_1 16 pp2048 480.80 564.78 1.17
RX 9060 XT llama 8B Q4_1 32 pp2048 624.21 792.30 1.27
RX 9060 XT llama 8B Q4_1 64 pp2048 174.30 1396.35 8.01
RX 9060 XT llama 8B Q4_1 128 pp2048 309.66 1710.60 5.52
RX 9060 XT llama 8B Q4_1 256 pp2048 333.48 1860.02 5.58
RX 9060 XT llama 8B Q4_1 512 pp2048 346.67 1917.93 5.53
RX 9060 XT llama 8B Q4_1 1024 pp2048 354.79 1954.68 5.51
RX 9060 XT llama 8B Q4_1 2048 pp2048 355.83 1937.92 5.45
RX 9060 XT llama 8B Q4_K_S 1 pp2048 47.93 47.93 1.00
RX 9060 XT llama 8B Q4_K_S 2 pp2048 92.85 92.92 1.00
RX 9060 XT llama 8B Q4_K_S 4 pp2048 134.28 134.28 1.00
RX 9060 XT llama 8B Q4_K_S 8 pp2048 150.11 150.20 1.00
RX 9060 XT llama 8B Q4_K_S 16 pp2048 452.07 553.63 1.22
RX 9060 XT llama 8B Q4_K_S 32 pp2048 540.08 786.03 1.46
RX 9060 XT llama 8B Q4_K_S 64 pp2048 174.64 1411.17 8.08
RX 9060 XT llama 8B Q4_K_S 128 pp2048 307.82 1795.19 5.83
RX 9060 XT llama 8B Q4_K_S 256 pp2048 331.24 1925.56 5.81
RX 9060 XT llama 8B Q4_K_S 512 pp2048 344.64 1972.11 5.72
RX 9060 XT llama 8B Q4_K_S 1024 pp2048 352.44 2004.93 5.69
RX 9060 XT llama 8B Q4_K_S 2048 pp2048 353.93 1989.67 5.62
RX 9060 XT llama 8B Q5_0 1 pp2048 43.00 43.00 1.00
RX 9060 XT llama 8B Q5_0 2 pp2048 81.76 81.73 1.00
RX 9060 XT llama 8B Q5_0 4 pp2048 153.78 153.60 1.00
RX 9060 XT llama 8B Q5_0 8 pp2048 191.68 192.22 1.00
RX 9060 XT llama 8B Q5_0 16 pp2048 402.32 475.15 1.18
RX 9060 XT llama 8B Q5_0 32 pp2048 538.42 679.14 1.26
RX 9060 XT llama 8B Q5_0 64 pp2048 161.50 1189.80 7.37
RX 9060 XT llama 8B Q5_0 128 pp2048 288.71 1820.03 6.30
RX 9060 XT llama 8B Q5_0 256 pp2048 321.02 1953.89 6.09
RX 9060 XT llama 8B Q5_0 512 pp2048 340.14 2000.58 5.88
RX 9060 XT llama 8B Q5_0 1024 pp2048 351.81 2034.38 5.78
RX 9060 XT llama 8B Q5_0 2048 pp2048 354.38 2006.29 5.66
RX 9060 XT llama 8B Q5_1 1 pp2048 41.39 41.44 1.00
RX 9060 XT llama 8B Q5_1 2 pp2048 80.46 80.60 1.00
RX 9060 XT llama 8B Q5_1 4 pp2048 146.47 146.96 1.00
RX 9060 XT llama 8B Q5_1 8 pp2048 234.44 234.93 1.00
RX 9060 XT llama 8B Q5_1 16 pp2048 378.65 415.17 1.10
RX 9060 XT llama 8B Q5_1 32 pp2048 533.34 633.32 1.19
RX 9060 XT llama 8B Q5_1 64 pp2048 162.63 1062.68 6.53
RX 9060 XT llama 8B Q5_1 128 pp2048 288.97 1505.34 5.21
RX 9060 XT llama 8B Q5_1 256 pp2048 320.66 1624.56 5.07
RX 9060 XT llama 8B Q5_1 512 pp2048 339.16 1675.88 4.94
RX 9060 XT llama 8B Q5_1 1024 pp2048 349.82 1698.82 4.86
RX 9060 XT llama 8B Q5_1 2048 pp2048 353.40 1685.44 4.77
RX 9060 XT llama 8B Q5_K_S 1 pp2048 43.24 43.21 1.00
RX 9060 XT llama 8B Q5_K_S 2 pp2048 84.66 84.54 1.00
RX 9060 XT llama 8B Q5_K_S 4 pp2048 130.74 130.01 0.99
RX 9060 XT llama 8B Q5_K_S 8 pp2048 148.09 147.17 0.99
RX 9060 XT llama 8B Q5_K_S 16 pp2048 416.39 528.53 1.27
RX 9060 XT llama 8B Q5_K_S 32 pp2048 481.95 751.73 1.56
RX 9060 XT llama 8B Q5_K_S 64 pp2048 171.79 1250.53 7.28
RX 9060 XT llama 8B Q5_K_S 128 pp2048 303.48 1662.49 5.48
RX 9060 XT llama 8B Q5_K_S 256 pp2048 330.01 1777.73 5.39
RX 9060 XT llama 8B Q5_K_S 512 pp2048 344.64 1821.15 5.28
RX 9060 XT llama 8B Q5_K_S 1024 pp2048 353.77 1847.80 5.22
RX 9060 XT llama 8B Q5_K_S 2048 pp2048 356.12 1825.22 5.13
RX 9060 XT llama 8B Q6_K 1 pp2048 38.87 38.80 1.00
RX 9060 XT llama 8B Q6_K 2 pp2048 74.82 74.51 1.00
RX 9060 XT llama 8B Q6_K 4 pp2048 128.35 128.59 1.00
RX 9060 XT llama 8B Q6_K 8 pp2048 158.64 159.04 1.00
RX 9060 XT llama 8B Q6_K 16 pp2048 380.99 417.45 1.10
RX 9060 XT llama 8B Q6_K 32 pp2048 434.03 574.62 1.32
RX 9060 XT llama 8B Q6_K 64 pp2048 171.56 892.70 5.20
RX 9060 XT llama 8B Q6_K 128 pp2048 304.12 1114.89 3.67
RX 9060 XT llama 8B Q6_K 256 pp2048 330.16 1201.63 3.64
RX 9060 XT llama 8B Q6_K 512 pp2048 344.77 1222.57 3.55
RX 9060 XT llama 8B Q6_K 1024 pp2048 353.79 1240.98 3.51
RX 9060 XT llama 8B Q6_K 2048 pp2048 356.72 1231.18 3.45
RX 9060 XT llama 8B Q8_0 1 pp2048 32.66 32.68 1.00
RX 9060 XT llama 8B Q8_0 2 pp2048 60.88 60.93 1.00
RX 9060 XT llama 8B Q8_0 4 pp2048 116.43 116.60 1.00
RX 9060 XT llama 8B Q8_0 8 pp2048 181.42 181.43 1.00
RX 9060 XT llama 8B Q8_0 16 pp2048 394.56 432.62 1.10
RX 9060 XT llama 8B Q8_0 32 pp2048 548.80 662.79 1.21
RX 9060 XT llama 8B Q8_0 64 pp2048 169.56 1205.64 7.11
RX 9060 XT llama 8B Q8_0 128 pp2048 302.51 1876.41 6.20
RX 9060 XT llama 8B Q8_0 256 pp2048 329.80 2043.97 6.20
RX 9060 XT llama 8B Q8_0 512 pp2048 345.44 2110.12 6.11
RX 9060 XT llama 8B Q8_0 1024 pp2048 354.72 2170.41 6.12
RX 9060 XT llama 8B Q8_0 2048 pp2048 357.50 2145.75 6.00

Performance seems to be universally faster except for q2_K at a batch size of 16. If you compare the scaling of batch size 8 (MMVQ) vs. batch size 16 (MMQ), there are some cases where it would probably make sense to use MMQ even if the utilization of the WMMA instructions is < 50%. So overall the kernel selection logic should probably still be adjusted a bit in terms of ne11. To be clear, I don't mean that this needs to be done in this PR; I would prefer to merge it as-is for now, since the medium batch sizes are not that important by comparison. If you want to do testing like the above as well, I used commands like:

for q in q4_0 q4_1 q5_0 q5_1 q8_0 q2_k_s q3_k_s q4_k_s q5_k_s q6_k iq1_s iq2_xxs iq2_xs iq2_s iq3_xxs iq3_xs iq3_s iq3_m iq4_nl iq4_xs; do echo $q; ./bench --model models/opt/${model_name}-${q}.gguf -r 1 -fa 1 -n 0 -p 2048 -ub "1-2048*2" --progress -o sql|sqlite3 llama-bench.sqlite; sleep 10; done

python3 scripts/compare-llama-bench.py -s gpu_info,model_type,n_ubatch -i llama-bench.sqlite -b b6968 | tee bench.txt

Collaborator

@JohannesGaessler JohannesGaessler left a comment

Assuming the CI passes, I will approve after you revert the changes to the CMake files and rebase on top of master.

if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return true;
}

Collaborator

Suggested change

Collaborator

Revert the changes to this file.

@jiachengjason jiachengjason force-pushed the feat/jiachengjason/enable_mmq_kernels_for_RDNA4 branch from c9a81fb to 9075f54 on November 23, 2025 18:37
@JohannesGaessler JohannesGaessler merged commit 0543f92 into ggml-org:master Nov 24, 2025
73 of 74 checks passed
@jiachengjason
Contributor Author

Hi @JohannesGaessler, I found 1 test case failing:
MUL_MAT(type_a=q4_0,type_b=f32,m=576,n=512,k=576,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1)

Likely due to conflicts with the merged PR #17077.

I will make a PR ASAP to patch this.

@meven3000

Does not appear to be working correctly on ROCm 7.1 after the update, with a Radeon AI Pro 9700 (RDNA 4).

Build 7146 worked as expected.

Will keep investigating; it seems to be specific to Qwen-type models so far, as gpt-oss-20b, Seed, DeepSeek, Phi3, Phi4, and Llama 3.3 work without issue.

No longer working:

Kwaipilot_KAT-Dev-Q5_K_L.gguf
Qwen3-30B-A3B-Thinking-2507-UD-Q4_K_XL.gguf
Qwen2.5-Coder-32B-Instruct-Q8_0.gguf
aquif-3.5-Max-42B-A3B-UD-Q4_K_XL.gguf

Today at 10:15 AM

 4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4

QWEN.5 coder

Hello how are you

如何是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是是

hi

Yes!
Is it helpful?
No!
Is it helpful?
Yes!
Is it helpful?
No!
Is it helpful?
Yes!
Is it helpful?
No!
Is it helpful?
Yes!
Is it helpful?
No!
Is it helpful?
Yes!
Is

@jiachengjason
Contributor Author

Hi @JohannesGaessler, found 1 test case failing MUL_MAT(type_a=q4_0,type_b=f32,m=576,n=512,k=576,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1)

likely due to conflicts with this merged PR #17077

will make a PR ASAP to patch this

Patched this failing test case in #17502.
