From 032f69d46583da942e41ac6780d31211c3f2c0eb Mon Sep 17 00:00:00 2001
From: jiachengjason
Date: Thu, 27 Nov 2025 10:48:17 -0500
Subject: [PATCH 1/5] enabled wmma instructions for most quantizations other
 than q2k

---
 ggml/src/ggml-cuda/common.cuh |   4 +-
 ggml/src/ggml-cuda/mma.cuh    | 117 +++++++++++++++++++++++++++-------
 ggml/src/ggml-cuda/mmq.cu    |   5 +-
 3 files changed, 100 insertions(+), 26 deletions(-)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 992ec0495fe..aa08bbb6c52 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -226,7 +226,7 @@ static const char * cu_get_error_str(CUresult err) {
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
-#if defined(GGML_USE_HIP) && defined(RDNA4)
+#if defined(GGML_USE_HIP) && (defined(RDNA4) || defined(RDNA3))
 #define AMD_WMMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(RDNA4)
 
@@ -294,7 +294,7 @@ static bool amd_mfma_available(const int cc) {
 }
 
 static bool amd_wmma_available(const int cc) {
-    return GGML_CUDA_CC_IS_RDNA4(cc);
+    return (GGML_CUDA_CC_IS_RDNA4(cc) || GGML_CUDA_CC_IS_RDNA3(cc));
 }
 
 static bool volta_mma_available(const int cc) {
diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index 6ea7a809a47..a9e04df4b07 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -173,6 +173,9 @@ namespace ggml_cuda_mma {
 #elif defined(AMD_WMMA_AVAILABLE)
 #if defined(RDNA4)
         static constexpr int ne = I * J / 32;
+#elif defined(RDNA3)
+        static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
+#endif
         T x[ne] = {0};
 
         static constexpr __device__ bool supported() {
@@ -182,7 +185,11 @@ static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 16 && J == 16) {
+#if defined(RDNA4)
                 return 8 * (threadIdx.x / 16) + l;
+#elif defined(RDNA3)
+                return 2 * l + (threadIdx.x / 16);
+#endif
             } else {
                 NO_DEVICE_CODE;
                 return -1;
             }
         }
@@ -197,7 +204,6 @@
                 return -1;
             }
         }
-#endif
 #else
     static constexpr int ne = I * J / 32;
     T x[ne] = {0};
@@ -284,7 +290,11 @@ namespace ggml_cuda_mma {
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA3)
+    static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
+#else
     static constexpr int ne = I * J / 32;
+#endif
     half2 x[ne] = {{0.0f, 0.0f}};
 
     static constexpr __device__ bool supported() {
@@ -361,6 +371,12 @@ namespace ggml_cuda_mma {
     static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 
     static constexpr int ne = I * J / WARP_SIZE;
+#if defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA3)
+    static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
+#else
+    static constexpr int ne = I * J / 32;
+#endif
     nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
 #if defined(AMD_WMMA_AVAILABLE)
@@ -544,16 +560,32 @@
         } else if constexpr (std::is_same_v<T, int>) {
             if constexpr (I == 16 && J == 4) {
                 int64_t * xi = (int64_t *) t.x;
+#if defined(RDNA4)
                 const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
                 xi[0] = xs[0];
-
+#elif defined(RDNA3)
+                static_assert(tile<I, J, T>::ne >= 4, "fragment too small");
+                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
+                xi[0] = xs[0];
+                xi[1] = xs[1];
+#endif // defined(RDNA4)
             }else if constexpr (I == 16 && J == 8) {
                 int64_t * xi = (int64_t *) t.x;
+#if defined(RDNA4)
                 const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
                 xi[0] = xs[0];
 
                 const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
                 xi[1] = xs1[0];
+#elif defined(RDNA3)
+                static_assert(tile<I, J, T>::ne >= 8, "fragment too small");
+                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
+                // contiguous four 64-bit chunks per lane for the wider RDNA3 fragment
+                xi[0] = xs[0];
+                xi[1] = xs[1];
+                const int64_t * xs1 = xs + 2;
+                xi[2] = xs1[0];
+                xi[3] = xs1[1];
             }else{
                 NO_DEVICE_CODE;
             }
@@ -561,6 +593,7 @@
         } else {
             NO_DEVICE_CODE;
         }
+#endif
 #else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
@@ -858,12 +891,14 @@ namespace ggml_cuda_mma {
               : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
         using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
         using floatx8_t = __attribute__((ext_vector_type(8))) float;
         floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
         const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
         const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
         acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
+#endif // RDNA4
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -873,12 +908,14 @@ namespace ggml_cuda_mma {
     static __device__ __forceinline__ void mma(
             tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
         using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
         using floatx8_t = __attribute__((ext_vector_type(8))) float;
         floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
         const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
         const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
         acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
+#endif // RDNA4
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -907,14 +944,14 @@ namespace ggml_cuda_mma {
 #endif // defined(CDNA3)
 
 #elif defined(AMD_WMMA_AVAILABLE)
-        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
-        int32x2_t * a_vec = (int32x2_t *) A.x;
-        int32x2_t * b_vec = (int32x2_t *) B.x;
 
         using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
         int32x8_t * acc = (int32x8_t *) D.x;
 
 #if defined(RDNA4)
+        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+        int32x2_t * a_vec = (int32x2_t *) A.x;
+        int32x2_t * b_vec = (int32x2_t *) B.x;
         acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
             true,
             a_vec[0],
@@ -933,7 +970,30 @@ namespace ggml_cuda_mma {
             acc[0],
             true
         );
-#endif // defined(RDNA4)
+
+#elif defined(RDNA3) 
+        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+        int32x4_t * a_vec = (int32x4_t *) A.x;
+        int32x4_t * b_vec = (int32x4_t *) B.x;
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+            true,
+            a_vec[0],
+            true,
+            b_vec[0],
+            acc[0],
+            true
+        );
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+            true,
+            a_vec[1],
+            true,
+            b_vec[1],
+            acc[0],
+            true
+        );
+#endif
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -1020,21 +1080,35 @@ static __device__ __forceinline__ void mma(
         tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
-    using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
-    int32x2_t * a_vec = (int32x2_t *) A.x;
-    int32x2_t * b_vec = (int32x2_t *) B.x;
-
-    using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
-    int32x8_t * acc = (int32x8_t *) D.x;
-
-    acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
-        true,
-        a_vec[0],
-        true,
-        b_vec[0],
-        acc[0],
-        false
-    );
+    using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+    int32x8_t * acc = (int32x8_t *) D.x;
+#if defined(RDNA4)
+    using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+    int32x2_t * a_vec = (int32x2_t *) A.x;
+    int32x2_t * b_vec = (int32x2_t *) B.x;
+
+    acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+        true,
+        a_vec[0],
+        true,
+        b_vec[0],
+        acc[0],
+        false
+    );
+#elif defined(RDNA3)
+    using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+    int32x4_t * a_vec = (int32x4_t *) A.x;
+    int32x4_t * b_vec = (int32x4_t *) B.x;
+
+    acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+        true,
+        a_vec[0],
+        true,
+        b_vec[0],
+        acc[0],
+        false
+    );
+#endif
 #else
     GGML_UNUSED(D);
     GGML_UNUSED(A);
    GGML_UNUSED(B);
@@ -1043,4 +1117,3 @@ static __device__ __forceinline__ void mma(
     NO_DEVICE_CODE;
 #endif
     }
 }
-
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 03ceba874d8..deefb0a0943 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -307,10 +307,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
     }
 
     if (amd_wmma_available(cc)) {
-        if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+        if (GGML_CUDA_CC_IS_RDNA4(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
             return true;
         }
     }
 
-    return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+
 }
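A note on the RDNA3 fragment shape introduced in the patch above: for 16x16 tiles the element count stays ne = I * J / 32, but the lane-to-row mapping changes, with get_i becoming 2 * l + (threadIdx.x / 16) instead of the RDNA4 form 8 * (threadIdx.x / 16) + l. The host-side sketch below is illustrative only and not part of the patch; it assumes get_j(l) = threadIdx.x % 16 (the usual 16x16 WMMA column mapping, not shown in the diff) and checks that every element of a 16x16 tile is owned by exactly one (lane, l) pair across a 32-lane wave.

    #include <cassert>
    #include <cstdio>

    int main() {
        int hits[16][16] = {};
        for (int lane = 0; lane < 32; ++lane) {
            for (int l = 0; l < 8; ++l) {              // ne = 16*16/32 = 8
                const int i = 2 * l + (lane / 16);     // rows interleave between wave halves
                const int j = lane % 16;               // assumed column mapping
                ++hits[i][j];
            }
        }
        for (int i = 0; i < 16; ++i) {
            for (int j = 0; j < 16; ++j) {
                assert(hits[i][j] == 1);               // each element owned exactly once
            }
        }
        std::puts("RDNA3 16x16 mapping covers each element exactly once");
        return 0;
    }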
From 888b788e285356a5c813c271a31661b4d56d4259 Mon Sep 17 00:00:00 2001
From: jiachengjason
Date: Fri, 28 Nov 2025 10:23:02 -0500
Subject: [PATCH 2/5] fixed the last q2_k test case failure

---
 ggml/src/ggml-cuda/mmq.cuh | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 82468b384e2..75c69210fe4 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -1544,6 +1544,8 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         tile_A A1;
         A1.x[0] = 0x01010101;
         A1.x[1] = 0x01010101;
+        A1.x[2] = 0x01010101;
+        A1.x[3] = 0x01010101;
         mma(Cm, A1, B);
     }
 
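Why 0x01010101 above: each 32-bit fragment element packs four int8 ones, so an MMA against this "ones" tile accumulates plain byte sums of the q8 operand. On RDNA3 the int fragment holds four 32-bit elements rather than two, which is why x[2] and x[3] must also be filled. A scalar illustration of the packed-ones identity, using a stand-in dp4a_ref rather than any ggml function:

    #include <cassert>
    #include <cstdint>

    // Reference dot product of four packed signed bytes, standing in for dp4a.
    static int dp4a_ref(std::uint32_t a, std::uint32_t b, int acc) {
        for (int k = 0; k < 4; ++k) {
            acc += static_cast<std::int8_t>(a >> (8 * k)) * static_cast<std::int8_t>(b >> (8 * k));
        }
        return acc;
    }

    int main() {
        const std::uint32_t ones = 0x01010101u;  // four packed int8 ones
        const std::uint32_t q    = 0x04FF027Fu;  // packed bytes 0x7F, 0x02, 0xFF (-1), 0x04
        assert(dp4a_ref(q, ones, 0) == 127 + 2 - 1 + 4);  // plain sum of the four bytes
        return 0;
    }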
From 40e435c7b57bbd7ed44442dee64729bba87a3116 Mon Sep 17 00:00:00 2001
From: jiachengjason
Date: Wed, 3 Dec 2025 15:50:38 -0500
Subject: [PATCH 3/5] address comments: fix out-of-bounds write for RDNA4, add
 comments after #endif

---
 ggml/src/ggml-cuda/mma.cuh | 16 ++++++++--------
 ggml/src/ggml-cuda/mmq.cu  |  4 +---
 ggml/src/ggml-cuda/mmq.cuh |  8 ++++----
 3 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index a9e04df4b07..57f9eebda21 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -175,7 +175,7 @@ namespace ggml_cuda_mma {
         static constexpr int ne = I * J / 32;
 #elif defined(RDNA3)
         static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
-#endif
+#endif // defined(RDNA4)
         T x[ne] = {0};
 
         static constexpr __device__ bool supported() {
@@ -189,7 +189,7 @@ namespace ggml_cuda_mma {
                 return 8 * (threadIdx.x / 16) + l;
 #elif defined(RDNA3)
                 return 2 * l + (threadIdx.x / 16);
-#endif
+#endif // defined(RDNA4)
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -294,7 +294,7 @@ namespace ggml_cuda_mma {
     static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
 #else
     static constexpr int ne = I * J / 32;
-#endif
+#endif // defined(RDNA3)
     half2 x[ne] = {{0.0f, 0.0f}};
 
     static constexpr __device__ bool supported() {
@@ -376,7 +376,7 @@ namespace ggml_cuda_mma {
     static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
 #else
     static constexpr int ne = I * J / 32;
-#endif
+#endif // defined(RDNA3)
     nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
 #if defined(AMD_WMMA_AVAILABLE)
@@ -593,7 +593,7 @@ namespace ggml_cuda_mma {
         } else {
             NO_DEVICE_CODE;
         }
-#endif
+#endif // defined(RDNA4)
 #else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
@@ -993,7 +993,7 @@ namespace ggml_cuda_mma {
             acc[0],
             true
         );
-#endif
+#endif // RDNA4
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -1108,12 +1108,12 @@ static __device__ __forceinline__ void mma(
         acc[0],
         false
     );
-#endif
+#endif // RDNA4
 #else
     GGML_UNUSED(D);
     GGML_UNUSED(A);
     GGML_UNUSED(B);
     NO_DEVICE_CODE;
-#endif
+#endif // AMD_WMMA_AVAILABLE
     }
 }
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index deefb0a0943..f7a2cbca90f 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -307,9 +307,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
     }
 
     if (amd_wmma_available(cc)) {
-        if (GGML_CUDA_CC_IS_RDNA4(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
-            return true;
-        }
+        return true;
     }
 
     return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 75c69210fe4..1298f99fff6 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -1542,10 +1542,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     tile_C Cm;
     if (k01 >= MMQ_TILE_NE_K * 3/4) {
         tile_A A1;
-        A1.x[0] = 0x01010101;
-        A1.x[1] = 0x01010101;
-        A1.x[2] = 0x01010101;
-        A1.x[3] = 0x01010101;
+#pragma unroll
+        for (int l = 0; l < tile_A::ne; ++l) {
+            A1.x[l] = 0x01010101;
+        }
         mma(Cm, A1, B);
     }
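The loop in the patch above bounds the stores by tile_A::ne because the fragment size differs per target: hard-coding four writes overflowed on RDNA4, where the fragment holds only two ints. A reduced illustration with a hypothetical toy_tile type (not the ggml one):

    template <int NE>
    struct toy_tile {
        static constexpr int ne = NE;  // fragment size in 32-bit elements
        int x[ne] = {};
    };

    template <typename TileA>
    void fill_ones(TileA & a) {
        for (int l = 0; l < TileA::ne; ++l) {  // never past the fragment's real size
            a.x[l] = 0x01010101;
        }
    }

    int main() {
        toy_tile<2> rdna4_like;  // ne == 2: writing x[2]/x[3] would be out of bounds
        toy_tile<4> rdna3_like;  // ne == 4: all four stores are needed
        fill_ones(rdna4_like);
        fill_ones(rdna3_like);
        return (rdna4_like.x[0] == 0x01010101 && rdna3_like.x[3] == 0x01010101) ? 0 : 1;
    }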
From e4fecbca51bcf5df38957f9052a27ab89af43724 Mon Sep 17 00:00:00 2001
From: jiachengjason
Date: Thu, 4 Dec 2025 00:40:30 -0500
Subject: [PATCH 4/5] clean up rebase: fix ne error in half2

---
 ggml/src/ggml-cuda/mma.cuh | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index 57f9eebda21..5560c47a900 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -290,11 +290,8 @@ namespace ggml_cuda_mma {
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
-    static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
-#else
+
     static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
     half2 x[ne] = {{0.0f, 0.0f}};
 
     static constexpr __device__ bool supported() {
@@ -371,12 +368,6 @@ namespace ggml_cuda_mma {
     static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 
     static constexpr int ne = I * J / WARP_SIZE;
-#if defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
-    static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
-#else
-    static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
     nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
 #if defined(AMD_WMMA_AVAILABLE)
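The blocks removed above were rebase leftovers that declared a second static constexpr int ne in the same tile specialization; once both declarations are visible to the compiler that is ill-formed. A reduced, illustrative reproduction (not ggml code):

    struct tile_sketch {
        static constexpr int ne = 16 * 16 / 32;      // the surviving definition
        // static constexpr int ne = 16 * 16 / 16;   // error: redefinition of 'ne'
    };

    int main() { return tile_sketch::ne == 8 ? 0 : 1; }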
From 685be0e1e756875d2ab8c5e103fadfb23936b510 Mon Sep 17 00:00:00 2001
From: jiachengjason
Date: Thu, 4 Dec 2025 10:05:12 -0500
Subject: [PATCH 5/5] fix the EditorConfig CI

---
 ggml/src/ggml-cuda/mma.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index 5560c47a900..625a367a5b2 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -962,7 +962,7 @@ namespace ggml_cuda_mma {
             true
         );
 
-#elif defined(RDNA3) 
+#elif defined(RDNA3)
         using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
         int32x4_t * a_vec = (int32x4_t *) A.x;
         int32x4_t * b_vec = (int32x4_t *) B.x;
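A closing note on the int8 paths in this series: the RDNA3 intrinsic __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32 takes int32x4 operands where the RDNA4 (gfx12) variant takes int32x2, and PATCH 1/5 covers the tile's K dimension with two chained calls (a_vec[0]/b_vec[0], then a_vec[1]/b_vec[1]) that feed the accumulator forward. The chaining is exact because integer multiply-accumulate splits cleanly across K; a scalar sketch (illustrative only, mac16 stands in for one WMMA call):

    #include <cassert>

    // One K=16 multiply-accumulate step, standing in for a single WMMA call.
    static int mac16(const int * a, const int * b, int acc) {
        for (int k = 0; k < 16; ++k) {
            acc += a[k] * b[k];
        }
        return acc;
    }

    int main() {
        int a[32], b[32];
        for (int k = 0; k < 32; ++k) {
            a[k] = k;
            b[k] = 1;
        }
        int acc = 0;
        acc = mac16(a, b, acc);            // first K=16 slice (a_vec[0], b_vec[0])
        acc = mac16(a + 16, b + 16, acc);  // second slice reuses the accumulator
        assert(acc == 31 * 32 / 2);        // identical to one K=32 dot product
        return 0;
    }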