From a3c9d1d6990cb633d26f0a94391aebfa00c8a2f2 Mon Sep 17 00:00:00 2001
From: lhl
Date: Tue, 28 Oct 2025 17:33:47 +0000
Subject: [PATCH 1/2] HIP/WMMA: retune WMMA FlashAttention on RDNA3

- Increase block residency on HIP via __launch_bounds__ (min 2 blocks/SM)
- Adaptive KQ stride on HIP: 128 for D<=128 to reduce LDS footprint
- Update loops and launch to use the adaptive stride; bump nwarps for small D
- No behavior change on CUDA; improves prefill perf on RDNA3
---
 ggml/src/ggml-cuda/fattn-wmma-f16.cu | 56 +++++++++++++++++++---------
 1 file changed, 39 insertions(+), 17 deletions(-)
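
Reviewer note, not part of the diff (text below the diffstat is ignored by git am):
the effect of the new launch bounds and the nwarps bump can be sanity-checked on a
hipcc build by querying kernel attributes and predicted occupancy. The sketch below
is a minimal standalone illustration: dummy_kernel only stands in for a real
flash_attn_ext_f16 instantiation, and 256 threads corresponds to nwarps = 8 at
wave32; treat the printed numbers as illustrative, not as results from this patch.

    // sketch.hip -- build with: hipcc sketch.hip -o sketch
    #include <hip/hip_runtime.h>
    #include <cstdio>

    // Same residency hint as the patch: request at least 2 resident blocks per CU.
    __launch_bounds__(256, 2)
    __global__ void dummy_kernel(float * out) {
        out[blockIdx.x*blockDim.x + threadIdx.x] = 0.0f;
    }

    int main() {
        hipFuncAttributes attr{};
        if (hipFuncGetAttributes(&attr, (const void *) dummy_kernel) != hipSuccess) {
            return 1;
        }
        int blocks_per_cu = 0;
        // Static LDS is accounted for automatically; this kernel uses no dynamic LDS, so pass 0.
        if (hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &blocks_per_cu, (const void *) dummy_kernel, 256, 0) != hipSuccess) {
            return 1;
        }
        printf("regs/thread=%d staticLDS=%zu bytes blocks/CU@256thr=%d\n",
               attr.numRegs, attr.sharedSizeBytes, blocks_per_cu);
        return 0;
    }
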
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index 6c90d6d52b335..67e5993de0f77 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -20,9 +20,24 @@ namespace wmma = rocwmma;
 #endif // !defined(GGML_USE_HIP)
 #endif // GGML_USE_WMMA_FATTN
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+static constexpr int GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM = 2;
+#else
+static constexpr int GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM = 1;
+#endif
+
+template <int D>
+constexpr int ggml_wmma_fattn_kq_stride() {
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+    return D <= 128 ? 128 : FATTN_KQ_STRIDE;
+#else
+    return FATTN_KQ_STRIDE;
+#endif
+}
+
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
-__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
+__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
         const char * __restrict__ V,
@@ -55,10 +70,12 @@ static __global__ void flash_attn_ext_f16(
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int fattn_kq_stride = ggml_wmma_fattn_kq_stride<D>();
+
+    static_assert(D <= fattn_kq_stride, "D must be <= fattn_kq_stride.");
 
     const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
 
-    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
     static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
@@ -75,7 +92,7 @@ static __global__ void flash_attn_ext_f16(
 
     // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
     constexpr int D_padded = D + 8;
-    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+    constexpr int kqs_padded = fattn_kq_stride + 8;
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
     const int sequence = blockIdx.z / ne02;
@@ -168,10 +185,10 @@ static __global__ void flash_attn_ext_f16(
 
     // Iterate over ne11 == previous tokens:
    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
+    for (int k_VKQ_0 = blockIdx.y*fattn_kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*fattn_kq_stride) {
         // Calculate tile of KQ:
 #pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+        for (int i_KQ_0 = 0; i_KQ_0 < fattn_kq_stride; i_KQ_0 += KQ_stride_tc) {
             frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
@@ -201,9 +218,9 @@ static __global__ void flash_attn_ext_f16(
             const int j = j0 + threadIdx.y;
 
             if (std::is_same<KQ_acc_t, float>::value) {
-                float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
+                float KQ_f_tmp[fattn_kq_stride / warp_size];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
@@ -215,7 +232,7 @@ static __global__ void flash_attn_ext_f16(
 
                 float KQ_max_new = KQ_max_f[j0/nwarps];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
@@ -232,7 +249,7 @@ static __global__ void flash_attn_ext_f16(
 
                 float KQ_rowsum_add = 0.0f;
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
@@ -248,9 +265,9 @@ static __global__ void flash_attn_ext_f16(
                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                 KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
             } else {
-                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
+                half2 KQ2_tmp[fattn_kq_stride/(2*warp_size)];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
@@ -267,7 +284,7 @@ static __global__ void flash_attn_ext_f16(
 
                 half2 KQ_max_new = KQ_max_h2[j0/nwarps];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
@@ -282,7 +299,7 @@ static __global__ void flash_attn_ext_f16(
 
                 half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
@@ -301,11 +318,11 @@ static __global__ void flash_attn_ext_f16(
 
         __syncthreads();
 
-        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+        frag_b KQ_b[fattn_kq_stride/(VKQ_ratio*16)][ncols/frag_n];
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+            for (int k0 = 0; k0 < fattn_kq_stride; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
                 wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
@@ -323,7 +340,7 @@ static __global__ void flash_attn_ext_f16(
         }
 
 #pragma unroll
-        for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+        for (int k0 = 0; k0 < fattn_kq_stride; k0 += VKQ_ratio*16) {
             const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
             frag_a_V v_a;
@@ -512,7 +529,12 @@ template <int D, int cols_per_block, typename KQ_acc_t>
 void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
 
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+    constexpr int nwarps = D <= 96 ? 8 : 4;
+#else
     constexpr int nwarps = 4;
+#endif
+    constexpr int fattn_kq_stride = ggml_wmma_fattn_kq_stride<D>();
 
     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
@@ -530,7 +552,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
+    launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, fattn_kq_stride, true, true, false, warp_size);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

From a45e1cd6e9f306a4708cb98912b2bd37e8b70fff Mon Sep 17 00:00:00 2001
From: lhl
Date: Tue, 28 Oct 2025 17:33:47 +0000
Subject: [PATCH 2/2] HIP: use WMMA for prefill only; fix decode regression by
 enabling TILE and adding a safe fallback

- Do not select WMMA for decode on HIP; fall through to VEC/TILE
- Remove WMMA TILE pruning on HIP to avoid device traps; keep for CUDA WMMA
- Add decode-time guard: if predicted TILE split has no config, select VEC
- Remove ad-hoc env overrides and debug prints
---
 ggml/src/ggml-cuda/fattn-tile.cuh |  6 +++-
 ggml/src/ggml-cuda/fattn.cu       | 55 ++++++++++++++++++++++++++++++-
 2 files changed, 59 insertions(+), 2 deletions(-)
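
Not part of the diff, for review only: the decode-time guard below only mirrors the
TILE launcher's shape selection, so the fallback can be reasoned about on the host in
isolation. The following sketch reproduces that prediction with a stubbed config
lookup; has_tile_config() stands in for the real ggml_cuda_fattn_tile_get_config()
table, and the head sizes and GQA ratios in main() are illustrative examples only.

    // predict.cpp -- standalone sketch, not the selection code in fattn.cu
    #include <cstdio>

    // Stub: pretend only cols_per_block <= 8 TILE variants are compiled, regardless of
    // head size. The real answer comes from ggml_cuda_fattn_tile_get_config().
    static bool has_tile_config(int dkq, int dv, int cols_per_block) {
        (void) dkq; (void) dv;
        return cols_per_block <= 8;
    }

    // Mirrors the patch's predicted_ncols2 / predicted_cols_per_block logic for decode
    // (a single query column).
    static const char * predict_decode_kernel(int dkq, int dv, int gqa_ratio, bool use_gqa_opt) {
        const int n_q = 1;

        int ncols2 = 1;
        if (dv == 512) {
            if (use_gqa_opt && gqa_ratio % 16 == 0) ncols2 = 16;
        } else if (dv <= 256) {
            if (use_gqa_opt && gqa_ratio % 8 == 0) ncols2 = 8;
            else if (use_gqa_opt && gqa_ratio % 4 == 0) ncols2 = 4;
            else if (use_gqa_opt && gqa_ratio % 2 == 0) ncols2 = 2;
        }

        int cols_per_block = 2;
        if (ncols2 <= 4 && n_q > 2/ncols2)  cols_per_block = 4;
        if (ncols2 <= 8 && n_q > 4/ncols2)  cols_per_block = 8;
        if (n_q > 8/ncols2)                 cols_per_block = 16;
        if (n_q > 16/ncols2)                cols_per_block = 32;
        if (dv <= 128 && n_q > 32/ncols2)   cols_per_block = 64;

        // The guard: a multi-column TILE split without a compiled config falls back to VEC.
        if (ncols2 != 1 && !has_tile_config(dkq, dv, cols_per_block)) {
            return "VEC";
        }
        return "TILE";
    }

    int main() {
        printf("DKQ=128 DV=128 gqa=8  -> %s\n", predict_decode_kernel(128, 128,  8, true));
        printf("DKQ=576 DV=512 gqa=16 -> %s\n", predict_decode_kernel(576, 512, 16, true));
        return 0;
    }
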
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
index 2b60b3bb13563..830f111daf872 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -721,9 +721,12 @@ static __global__ void flash_attn_tile(
 
     // Skip unused kernel variants for faster compilation:
+    // Optionally disable pruning to keep all TILE variants for testing.
+#if !defined(GGML_USE_HIP)
     if (
 #ifdef GGML_USE_WMMA_FATTN
-        (ncols2 != 1 && DV != 40 && DV != 512) ||
+        // On CUDA WMMA builds, prune some TILE variants to reduce compile time/binary size.
+        (ncols2 != 1 && DV != 40 && DV != 64 && DV != 128 && DV != 256 && DV != 512) ||
 #endif // GGML_USE_WMMA_FATTN
         (use_logit_softcap && !(DV == 128 || DV == 256))
     ) {
@@ -739,6 +742,7 @@ static __global__ void flash_attn_tile(
         NO_DEVICE_CODE;
         return;
     }
+#endif // !defined(GGML_USE_HIP)
 
     static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");
 
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 7dee032c29137..04acf2772841f 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -301,13 +301,66 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     }
 
     // Use the WMMA kernel if possible:
-    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+    const bool hip_wmma_decode = Q->ne[1] == 1;
+#else
+    const bool hip_wmma_decode = false;
+#endif
+    if (!hip_wmma_decode && ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
         if (can_use_vector_kernel && Q->ne[1] <= 2) {
             return BEST_FATTN_KERNEL_VEC;
         }
         return BEST_FATTN_KERNEL_WMMA_F16;
     }
 
+    // HIP decode path (Q->ne[1] == 1): fall through to generic HIP selection below (VEC/TILE),
+    // with a guard to avoid selecting a TILE shape that has no config.
+    if (hip_wmma_decode) {
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+        // Mirror the ncols2 selection from launch_fattn_tile_switch_ncols2 to predict if
+        // a multi-column TILE kernel (ncols2 != 1) would be chosen.
+        const bool nvidia_arch = GGML_CUDA_CC_IS_NVIDIA(cc);
+        const int gqa_limit = (nvidia_arch && gqa_ratio <= 4) ? 16 : INT_MAX;
+        const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
+
+        int predicted_ncols2 = 1;
+        if (V->ne[0] == 512) {
+            if (use_gqa_opt && gqa_ratio % 16 == 0) predicted_ncols2 = 16;
+        } else if (V->ne[0] <= 256) {
+            if (use_gqa_opt && gqa_ratio % 8 == 0) predicted_ncols2 = 8;
+            else if (use_gqa_opt && gqa_ratio % 4 == 0) predicted_ncols2 = 4;
+            else if (use_gqa_opt && gqa_ratio % 2 == 0) predicted_ncols2 = 2;
+        }
+
+        // Predict cols_per_block like launch_fattn_tile_switch_ncols1 does (HIP path):
+        int predicted_cols_per_block = 2;
+        if (predicted_ncols2 <= 2) {
+            predicted_cols_per_block = 2;
+        }
+        if (predicted_ncols2 <= 4 && Q->ne[1] > 2/predicted_ncols2) {
+            predicted_cols_per_block = 4;
+        }
+        if (predicted_ncols2 <= 8 && Q->ne[1] > 4/predicted_ncols2) {
+            predicted_cols_per_block = 8;
+        }
+        if (Q->ne[1] > 8/predicted_ncols2) {
+            predicted_cols_per_block = 16;
+        }
+        if (Q->ne[1] > 16/predicted_ncols2) {
+            predicted_cols_per_block = 32;
+        }
+        if (V->ne[0] <= 128 && Q->ne[1] > 32/predicted_ncols2) {
+            predicted_cols_per_block = 64;
+        }
+
+        const uint32_t cfg = ggml_cuda_fattn_tile_get_config((int)Q->ne[0], (int)V->ne[0], predicted_cols_per_block, cc);
+        if (predicted_ncols2 != 1 && cfg == 0) {
+            return BEST_FATTN_KERNEL_VEC;
+        }
+#endif // defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+        // Otherwise, fall through.
+    }
+
     // If there are no tensor cores available, use the generic tile kernel:
     if (can_use_vector_kernel) {
         if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {