
Conversation

unverbraucht commented Nov 25, 2025

Based on the work by @zhang-hui-yulo for RDNA4, I attempted to backport WMMA MMF support to RDNA3. This PR also ports the RDNA4 WMMA-MMQ improvements by @jiachengjason from PR #17156 to RDNA3.

The differences from RDNA4 are:

  • RDNA3 has no FP8 support in WMMA (INT8 is supported by the hardware)
  • RDNA3 has a different tile size

The results for granite 1b 400m look great:

| GPU | Model | Microbatch size | Test | t/s master | t/s ba25661 | Speedup |
| --- | --- | ---: | --- | ---: | ---: | ---: |
| RX 7900 XT | granitemoe ?B F16 | 1 | pp512 | 283.42 | 286.08 | 1.01 |
| RX 7900 XT | granitemoe ?B F16 | 2 | pp512 | 124.99 | 668.81 | 5.35 |
| RX 7900 XT | granitemoe ?B F16 | 4 | pp512 | 205.77 | 1224.45 | 5.95 |
| RX 7900 XT | granitemoe ?B F16 | 8 | pp512 | 377.29 | 1881.51 | 4.99 |
| RX 7900 XT | granitemoe ?B F16 | 16 | pp512 | 640.67 | 3181.89 | 4.97 |
| RX 7900 XT | granitemoe ?B F16 | 32 | pp512 | 1024.92 | 5654.28 | 5.52 |
| RX 7900 XT | granitemoe ?B F16 | 64 | pp512 | 2052.33 | 9817.10 | 4.78 |
| RX 7900 XT | granitemoe ?B F16 | 128 | pp512 | 3622.50 | 15972.81 | 4.41 |
| RX 7900 XT | granitemoe ?B F16 | 256 | pp512 | 6007.40 | 22525.58 | 3.75 |
| RX 7900 XT | granitemoe ?B F16 | 512 | pp512 | 9174.28 | 27815.62 | 3.03 |

EDIT: the performance regression for GPT-OSS 20B has been fixed; now we see a moderate speed-up:

| GPU | Model | Microbatch size | Test | t/s 55ab25c | t/s 8aed111 | Speedup |
| --- | --- | ---: | --- | ---: | ---: | ---: |
| RX 7900 XT | gpt-oss 20B Q8_0 | 1 | pp512 | 184.01 | 181.51 | 0.99 |
| RX 7900 XT | gpt-oss 20B Q8_0 | 2 | pp512 | 194.39 | 216.68 | 1.11 |
| RX 7900 XT | gpt-oss 20B Q8_0 | 4 | pp512 | 331.38 | 386.97 | 1.17 |
| RX 7900 XT | gpt-oss 20B Q8_0 | 8 | pp512 | 535.49 | 656.94 | 1.23 |
| RX 7900 XT | gpt-oss 20B Q8_0 | 16 | pp512 | 683.39 | 772.00 | 1.13 |
| RX 7900 XT | gpt-oss 20B Q8_0 | 32 | pp512 | 898.96 | 1049.12 | 1.17 |
| RX 7900 XT | gpt-oss 20B Q8_0 | 64 | pp512 | 1089.26 | 1358.59 | 1.25 |
| RX 7900 XT | gpt-oss 20B Q8_0 | 128 | pp512 | 1712.74 | 1935.43 | 1.13 |
| RX 7900 XT | gpt-oss 20B Q8_0 | 256 | pp512 | 2552.29 | 2828.21 | 1.11 |
| RX 7900 XT | gpt-oss 20B Q8_0 | 512 | pp512 | 3298.97 | 3594.53 | 1.09 |

CC @jiachengjason

zhang hui and others added 20 commits November 7, 2025 21:22
Key Changes Made:

1. ggml/src/ggml-cuda/common.cuh:
   - Extended the AMD_WMMA_AVAILABLE macro to include both RDNA3 and RDNA4
   - Updated amd_wmma_available() to return true for both architectures
2. ggml/src/ggml-cuda/mma.cuh:
   - Tile structures: added RDNA3-specific tile sizes:
     - RDNA4: 4 half2 = 8 FP16 elements (compact layout)
     - RDNA3: 8 half2 = 16 FP16 elements (duplicate layout required by hardware)
   - MMA operations: added RDNA3 intrinsics:
     - FP16: __builtin_amdgcn_wmma_f32_16x16x16_f16_w32 (no gfx12 suffix)
     - BF16: __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32
     - Uses halfx16_t/bf16x16_t for RDNA3 vs halfx8_t/bf16x8_t for RDNA4
   - Load operations: added conditional handling for 32-byte RDNA3 tiles using two 16-byte copies
3. ggml/src/ggml-cuda/mmf.cu:
   - Updated to use amd_wmma_available() for both RDNA3 and RDNA4
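
To make the tile-size difference concrete, here is a minimal sketch (not the actual mma.cuh code) of how the FP16 WMMA call can be dispatched per architecture. The vector typedefs follow the names in the list above, and the intrinsic signatures are my reading of the gfx11/gfx12 clang builtins, so treat the details as assumptions to verify against the compiler headers.

```cpp
// Hedged sketch only, not the actual mma.cuh code. Assumes the RDNA3/RDNA4
// target macros from common.cuh and an 8-wide FP32 accumulator per lane.
typedef _Float16 halfx8_t  __attribute__((ext_vector_type(8)));   // RDNA4: 8 FP16 per lane (compact)
typedef _Float16 halfx16_t __attribute__((ext_vector_type(16)));  // RDNA3: 16 FP16 per lane (duplicated)
typedef float    floatx8_t __attribute__((ext_vector_type(8)));   // 16x16 FP32 result tile, 8 values per lane

static __device__ __forceinline__ floatx8_t wmma_f16_16x16x16(
        const void * a, const void * b, floatx8_t acc) {
#if defined(RDNA3)
    // gfx11 builtin (no _gfx12 suffix): operands use the duplicated 16-element layout.
    return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
        *(const halfx16_t *) a, *(const halfx16_t *) b, acc);
#elif defined(RDNA4)
    // gfx12 builtin: compact 8-element operands.
    return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
        *(const halfx8_t *) a, *(const halfx8_t *) b, acc);
#else
    return acc; // no FP16 WMMA on this target (illustrative fallback only)
#endif // defined(RDNA3)
}
```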
am17an (Collaborator) commented Nov 25, 2025

gpt-oss would not be using the MMF path (it uses MMQ); you might have some variation in your measurements.

unverbraucht (Author) commented:

@am17an you're right: since we don't have integer WMMA on RDNA3, this should not be using this code path. I might have other commits in my PR that I don't have in my master build, or maybe my changes mess with the FP16 code path.

I'll look into using the same master build, and also check with other FP16 models.

…use MMQ with integer WMMA operations (hardware-accelerated)
hjc4869 (Contributor) commented Nov 25, 2025

RDNA3 does have the __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32 intrinsic (V_WMMA_I32_16X16X16_IU8), which is a little different from RDNA4's _gfx12 variant but has the same functionality. It's the same ops/cycle as F16/BF16, though, so it will probably only save some registers / bandwidth here and there.
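
For reference, a hedged sketch of how the two integer WMMA builtins differ in operand width; the boolean signedness/clamp flags and vector types are my reading of the clang builtins, not code taken from this PR.

```cpp
// Illustrative only: vector widths and flag order (sign flag for A, operand A,
// sign flag for B, operand B, accumulator, clamp) are assumptions to be
// checked against the compiler headers / ISA docs.
typedef int int32x2_t __attribute__((ext_vector_type(2)));
typedef int int32x4_t __attribute__((ext_vector_type(4)));
typedef int int32x8_t __attribute__((ext_vector_type(8)));

static __device__ __forceinline__ int32x8_t wmma_iu8_16x16x16(
        const void * a, const void * b, int32x8_t acc) {
#if defined(RDNA3)
    // V_WMMA_I32_16X16X16_IU8 on gfx11: 4 packed ints (16 int8) per operand.
    return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
        true, *(const int32x4_t *) a, true, *(const int32x4_t *) b, acc, false);
#elif defined(RDNA4)
    // Same functionality on gfx12, but 2 packed ints per operand.
    return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
        true, *(const int32x2_t *) a, true, *(const int32x2_t *) b, acc, false);
#else
    return acc; // no integer WMMA on this target (illustrative fallback only)
#endif // defined(RDNA3)
}
```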

jiachengjason (Contributor) commented:

I am currently working on enabling __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32 for RDNA3 in a similar fashion to #17156.

unverbraucht (Author) commented:

@jiachengjason great, looking forward to it :)

Maybe you can also help me with this, since it touches on MMQ: I am trying to find the source of the GPT-OSS 20B regression. It seems to me that RDNA3 no longer uses MMQ with DP4A instructions for batches < 512, which is the fast path for RDNA3. I'm trying to debug this in my last commits.

unverbraucht marked this pull request as draft November 25, 2025 16:27
JohannesGaessler (Collaborator) left a comment:

  • According to the AMD ISA documentation RDNA3 supports integer tensor cores. Please change the comments to say that it's not implemented rather than not supported.
  • Please always add a comment for an #endif to indicate which #if/#ifdef it is closing.
  • The get_i and get_j methods are not going to work correctly if you mirror the data for RDNA3. Please either implement them correctly or replace them with NO_DEVICE_CODE for RDNA3 (see the sketch after this list).
  • The code in mma.cuh is currently in a bad state in terms of maintainability and is in dire need of a refactor. However, I consider this to be a job for me, the maintainer, rather than contributors. So no action from your side is necessary, for now it's fine to pile on hacky solutions. I just want to give you a heads-up that the code is subject to change once RDNA3, RDNA4, and CDNA are properly supported and I know what the requirements are.
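
As a rough illustration of the second and third points, a guard on an index helper could look like the sketch below; NO_DEVICE_CODE and GGML_UNUSED are existing ggml macros, while the guard condition, function body, and index formula are placeholders rather than the PR's actual code.

```cpp
// Sketch only: trap instead of silently returning a wrong index until a
// correct mapping for the mirrored RDNA3 layout exists.
static __device__ __forceinline__ int get_i(const int l) {
#if defined(AMD_WMMA_AVAILABLE) && defined(RDNA3)
    GGML_UNUSED(l);
    NO_DEVICE_CODE; // index mapping for the duplicated RDNA3 tile layout not implemented
    return -1;
#else
    return l % 16;  // placeholder mapping for other targets, not the real formula
#endif // defined(AMD_WMMA_AVAILABLE) && defined(RDNA3)
}
```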

unverbraucht (Author) commented:

@JohannesGaessler thanks for the feedback.

RDNA3 indeed supports INT8 in WMMA, and I'll investigate that. It doesn't support FP8, and the sparse WMMA is also missing. I'm looking into get_i and get_j.

Regarding your new code in #17505: does it even make sense to investigate this code here further, or should I wait for that PR to be merged and then attempt to add this to the new MMA kernel?

JohannesGaessler (Collaborator) commented:

RDNA3 support should be largely independent of the changes I'm making to mma.cuh as long as you're only working on the kernel in mmf.cuh. For the kernel in fattn-mma-f16.cuh my PR should very much be merged first and then correct implementations for get_i and get_j will be 100% required.

Kevin Read added 2 commits November 27, 2025 11:35
Details
1. Separated RDNA3 and RDNA4 integer WMMA implementations in mma.cuh:
   - RDNA4: uses __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12 with int32x2_t (original path preserved)
   - RDNA3: uses __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32 with int32x4_t (new path added)
2. Both architectures now share:
   - AMD_WMMA_INT_AVAILABLE macro (enables shared optimizations)
   - Memory layout settings (mmq_x_max=128, granularity=16, nwarps=8)
   - The tile<16, 4, int> optimizations in mmq.cuh
3. RDNA4-exclusive features remain untouched:
   - FP8/BF8 WMMA operations
   - Specific RDNA4 optimizations behind #if defined(RDNA4) guards
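
As a rough sketch of the shared part (point 2 above), the availability guard and a host-side helper might look like the following; the exact guard expression and the GGML_CUDA_CC_IS_RDNA3/RDNA4 predicates are assumptions based on the existing common.cuh conventions, not a quote of the PR's code.

```cpp
// Sketch of a common.cuh-style guard: both RDNA3 (gfx11) and RDNA4 (gfx12)
// enable the integer WMMA path, so the MMQ tuning above (mmq_x_max = 128,
// granularity = 16, nwarps = 8) can be shared between them.
#if defined(GGML_USE_HIP) && (defined(RDNA3) || defined(RDNA4))
#define AMD_WMMA_INT_AVAILABLE
#endif // defined(GGML_USE_HIP) && (defined(RDNA3) || defined(RDNA4))

// Hypothetical host-side helper mirroring amd_wmma_available().
static bool amd_wmma_int_available(const int cc) {
    return GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
}
```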
unverbraucht (Author) commented:

@JohannesGaessler I've updated the code to make use of INT8 WMMA. get_i and get_j are working now, and the #endifs now have comments indicating what they close.

@jiachengjason as far as I can tell, my changes cover all uses of __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32. This also fixes the GPT-OSS 20B regression. Please have a look at this draft.

I will hold off on resolving the merge conflicts since we want to merge #17505 first.

unverbraucht marked this pull request as ready for review November 27, 2025 11:54
jiachengjason (Contributor) commented:

Hi @unverbraucht, I don't believe the implementation for int8 is correct (tested by running HIP_VISIBLE_DEVICES=0 ./build/bin/test-backend-ops test -o MUL_MAT > output.txt). The quantization cases are failing for n > 8. I think the mapping for loading the data into the registers is incorrect for RDNA3 (load_generic).

I have attached the output of the backend-ops test: output.txt

zhang-hui-yulo (Contributor) commented Nov 28, 2025

> Hi @unverbraucht, I don't believe the implementation for int8 is correct (tested by running HIP_VISIBLE_DEVICES=0 ./build/bin/test-backend-ops test -o MUL_MAT > output.txt). The quantization cases are failing for n > 8. I think the mapping for loading the data into the registers is incorrect for RDNA3 (load_generic).
>
> I have attached the output of the backend-ops test: output.txt

I think @unverbraucht uses the original tile<I, J, T> for RDNA3 int8 WMMA and doesn't handle the data loading well in load_generic. I wonder if it's possible to move matrices A and B out of tile<I, J, T>, as it covers both row- and column-major matrices for AMD.

Also, I would suggest cleaning up get_i and get_j in tile<I, J, half2> and tile<I, J, nv_bfloat162>, as the iteration based on ne will not be correct.

My suggestion would be tile<I, J, T, transposed = true> for matrices A and B for RDNA int8 and tile<I, J, T> for matrix C. Then load_generic can use ggml_cuda_memcpy_1 at position get_i(0) + get_j(0) for all RDNA WMMA layouts, as they all have contiguous data; this would also remove the ugly int64_t copy.
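
A minimal sketch of that suggestion, assuming ggml_cuda_memcpy_1<nbytes>(dst, src) keeps its current shape as ggml's fixed-size copy helper; the tile definition and index formulas below are placeholders to show the idea of one contiguous per-thread copy, not the real mma.cuh layout.

```cpp
// Hypothetical tile shape following the suggestion above: matrices A/B would
// use transposed = true, matrix C the default. ne is the per-thread element count.
template <int I, int J, typename T, bool transposed = false>
struct tile {
    static constexpr int ne = I*J / 32;   // 32 threads per RDNA wave
    T x[ne];

    static __device__ __forceinline__ int get_i(const int /*l*/) {
        return threadIdx.x % I;           // placeholder mapping, not the real formula
    }
    static __device__ __forceinline__ int get_j(const int /*l*/) {
        return (threadIdx.x / I) * ne;    // placeholder mapping, not the real formula
    }
};

// Sketch of the simplified load: one contiguous copy per thread starting at
// get_i(0)/get_j(0), replacing the per-element int64_t copies.
template <int I, int J, typename T, bool transposed>
static __device__ __forceinline__ void load_generic(
        tile<I, J, T, transposed> & t, const T * __restrict__ xs0, const int stride) {
    ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0));
}
```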
