From 17f191e9d7c090e387a60ea2de869e2932476d1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 10 Nov 2025 23:17:47 +0100 Subject: [PATCH 1/6] CUDA: generalized (mma) FA, add Volta support --- ggml/include/ggml.h | 2 +- ggml/src/ggml-cuda/fattn-common.cuh | 22 +- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 1236 +++++++++++++------------ ggml/src/ggml-cuda/fattn-tile.cuh | 36 +- ggml/src/ggml-cuda/fattn-vec.cuh | 18 +- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 17 +- ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 4 +- ggml/src/ggml-cuda/fattn.cu | 21 +- ggml/src/ggml-cuda/mma.cuh | 257 +++-- ggml/src/ggml-cuda/mmf.cuh | 59 +- 10 files changed, 936 insertions(+), 736 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 4dbca868bc7..c18dd75b4a5 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2278,7 +2278,7 @@ extern "C" { float stop, float step); -#define GGML_KQ_MASK_PAD 64 +#define GGML_KQ_MASK_PAD 1 // q: [n_embd_k, n_batch, n_head, ne3 ] // k: [n_embd_k, n_kv, n_head_kv, ne3 ] diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 5cdd4bb2114..02443b8c638 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -25,7 +25,7 @@ typedef void (* fattn_kernel_t)( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -621,7 +621,8 @@ static __global__ void flash_attn_mask_to_KV_max( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) { + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11, + const int nbatch_fa) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -632,8 +633,8 @@ static __global__ void flash_attn_stream_k_fixup( const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - const int iter_k = ne11 / FATTN_KQ_STRIDE; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; @@ -765,7 +766,7 @@ static __global__ void flash_attn_combine_results( template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, - const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE + const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE ) { constexpr int ncols = ncols1 * ncols2; @@ -790,8 +791,6 @@ void launch_fattn( GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); - GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && - "the Flash-Attention CUDA kernel 
requires the mask to be padded to 16 and at least n_queries big"); ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); @@ -915,7 +914,7 @@ void launch_fattn( dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); } else { - const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. // parallel_blocks must not be larger than what the tensor size allows: parallel_blocks = std::min(parallel_blocks, ntiles_KQ); @@ -970,6 +969,9 @@ void launch_fattn( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + // TODO other tensor dimensions after removal of WMMA kernel: + const uint3 ne01 = init_fastdiv_values(Q->ne[1]); + GGML_ASSERT(block_dim.x % warp_size == 0); fattn_kernel<<>>( (const char *) Q->data, @@ -980,7 +982,7 @@ void launch_fattn( KV_max.ptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], + Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, @@ -995,7 +997,7 @@ void launch_fattn( flash_attn_stream_k_fixup <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]); + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 57defb0c629..62e3d0a673b 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -5,14 +5,6 @@ using namespace ggml_cuda_mma; -typedef tile<16, 8, half2> tile_A; -typedef tile< 8, 8, half2> tile_B; -typedef tile<16, 8, half2> tile_B_16; -typedef tile<16, 8, float> tile_C_KQ; -typedef tile<16, 16, float> tile_C_KQ_16; -typedef tile<16, 4, half2> tile_C_VKQ; -typedef tile<16, 8, half2> tile_C_VKQ_16; - // Config options for specific head sizes. // Should not affect results, only speed/register pressure/shared memory use. // @@ -24,265 +16,202 @@ typedef tile<16, 8, half2> tile_C_VKQ_16; // nbatch_V2: number of V half2 values in direction of DV to load in parallel. // nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel. 
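The per-head-size fattn_mma_f16_config template specializations that used to follow are replaced below by one packed uint32_t per (DKQ, DV, ncols) combination, since the ROCm compiler cannot handle templated expressions inside __launch_bounds__. As a rough standalone sketch of the encode/decode scheme (the pack_config/unpack_* helpers here are illustrative only; they use the bit offsets of the GGML_CUDA_FATTN_MMA_CONFIG_CASE macro that follows):

#include <cstdint>

// Each field is stored as (value/granularity - 1) in a fixed bit range:
//   bits  0-3  nthreads/32 - 1      bits 10-16 nbatch_K2/8 - 1       bit 28 nstages_target - 1
//   bits  4-6  occupancy - 1        bits 17-22 nbatch_V2/8 - 1       bit 29 Q_in_reg
//   bits  7-9  nbatch_fa/32 - 1     bits 23-27 nbatch_combine/8 - 1
constexpr uint32_t pack_config(int nthreads, int occupancy, int nbatch_fa, int nbatch_K2,
                               int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) {
    return uint32_t((nthreads/32 - 1)      <<  0) |
           uint32_t((occupancy - 1)        <<  4) |
           uint32_t((nbatch_fa/32 - 1)     <<  7) |
           uint32_t((nbatch_K2/8 - 1)      << 10) |
           uint32_t((nbatch_V2/8 - 1)      << 17) |
           uint32_t((nbatch_combine/8 - 1) << 23) |
           uint32_t((nstages_target - 1)   << 28) |
           uint32_t((Q_in_reg ? 1 : 0)     << 29);
}

constexpr int unpack_nthreads (uint32_t cfg) { return (int((cfg >> 0) & 0x0F) + 1) * 32; }
constexpr int unpack_nbatch_fa(uint32_t cfg) { return (int((cfg >> 7) & 0x07) + 1) * 32; }

// The Ampere DKQ = DV = 128, ncols = 32 entry from the tables below round-trips correctly:
static_assert(unpack_nthreads (pack_config(128, 2, 64, 64, 64, 64, 2, true)) == 128, "decode mismatch");
static_assert(unpack_nbatch_fa(pack_config(128, 2, 64, 64, 64, 64, 2, true)) ==  64, "decode mismatch");

Keeping the whole configuration in one integer presumably lets the kernel recover nthreads and occupancy as plain constexpr values for __launch_bounds__ without templating on a config struct type.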
-template -struct fattn_mma_f16_config; - -template <> -struct fattn_mma_f16_config< 64, 64> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 32; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 32; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 32; - } -}; - -template <> -struct fattn_mma_f16_config< 80, 80> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 40; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 40; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 40; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 40; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 40; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 40; - } -}; - -template <> -struct fattn_mma_f16_config< 96, 96> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 48; - } +// The ROCm compiler cannot handle templating in __launch_bounds__. +// As a workaround, define a macro to package the kernel parameters as uint32_t: +#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K2, nbatch_V2, nbatch_combine, nstages_target, Q_in_reg) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads) % 32 == 0 && (nthreads) <= 512, "bad nthreads"); \ + static_assert( (occupancy) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa) % 32 == 0 && (nbatch_fa) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K2) % 4 == 0 && (nbatch_K2) <= 512, "bad nbatch_K2"); \ + static_assert((nbatch_V2) % 4 == 0 && (nbatch_V2) <= 256, "bad nbatch_V2"); \ + static_assert((nbatch_combine) % 4 == 0 && (nbatch_combine) <= 128, "bad nbatch_combine"); \ + static_assert((nstages_target) >= 1 && (nstages_target) <= 2, "bad nstages_target"); \ + return ((((nthreads) / 32) - 1) << 0) | \ + ((((occupancy) / 1) - 1) << 4) | \ + ((((nbatch_fa) / 32) - 1) << 7) | \ + ((((nbatch_K2) / 8) - 1) << 10) | \ + ((((nbatch_V2) / 8) - 1) << 17) | \ + ((((nbatch_combine) / 8) - 1) << 23) | \ + ((((nstages_target) / 1) - 1) << 28) | \ + (((Q_in_reg) ? 
1 : 0) << 29); \ + } \ + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + + return 0; +} - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 48; - } + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) 
{
-        return 48;
-    }
-
-    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
-        return 48;
-    }
-};
-
-template <>
-struct fattn_mma_f16_config<112, 112> {
-    static constexpr int  nbatch_fa      = 64;
-    static constexpr int  nwarps_max     = 4;
-    static constexpr bool Q_in_reg       = true;
-    static constexpr int  nstages_target = 2;
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
 
-    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 56;
-    }
+static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256,  64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32, 288, 256,  64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128,  64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128,  64, 1, false);
 
-    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-        return 56;
-    }
+    // TODO tune specifically for Volta
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
 
-    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
-        return 56;
+static __host__ uint32_t ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
+    if (ampere_mma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
     }
-
-    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
-        return 56;
-    }
-
-    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
-        return 56;
+    if (turing_mma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
    }
+    GGML_ASSERT(volta_mma_available(cc));
+    return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+}
 
-    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
-        return 56;
-    }
-};
+static constexpr __device__ uint32_t ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
+#if defined(AMPERE_MMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+#elif defined(TURING_MMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
+#elif defined(VOLTA_MMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+#else
+    GGML_UNUSED_VARS(DKQ, DV, ncols);
+    return 0;
+#endif // defined(AMPERE_MMA_AVAILABLE)
+}
 
-template <>
-struct fattn_mma_f16_config<128, 128> {
-    static constexpr int  nbatch_fa      = 64;
-    static constexpr int  nwarps_max     = 4;
-    static constexpr bool Q_in_reg       = true;
-    static constexpr int  nstages_target = 2;
+static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 4) - 1)) + 1) * 32;
+}
 
-    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 64;
-    }
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) {
+    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 0) & ((1 << 4) - 1)) + 1) * 32;
+}
 
-    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-        return 64;
-    }
+static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 4) & ((1 << 3) - 1)) + 1) * 1;
+}
 
-    static int get_nbatch_V2_host(const int /*cc*/,
const int /*ncols*/) { - return 64; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) { + return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 4) & ((1 << 3) - 1)) + 1) * 1; +} - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 64; - } +static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 7) & ((1 << 3) - 1)) + 1) * 32; +} - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 64; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 7) & ((1 << 3) - 1)) + 1) * 32; +} - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 64; - } -}; +static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) { + return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 7) - 1)) + 1) * 8; +} -template <> -struct fattn_mma_f16_config<256, 256> { - static constexpr int nbatch_fa = 32; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) { + return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 10) & ((1 << 7) - 1)) + 1) * 8; +} - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } +static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) { + return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 17) & ((1 << 6) - 1)) + 1) * 8; +} - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 128; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) { + return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 17) & ((1 << 6) - 1)) + 1) * 8; +} - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } +static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) { + return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 5) - 1)) + 1) * 8; +} - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 128; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) { + return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 23) & ((1 << 5) - 1)) + 1) * 8; +} - static int get_nbatch_combine_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 128 : 64; - } - return 64; - } +static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) { + return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 28) & ((1 << 2) - 1)) + 1) * 1; +} - static constexpr __device__ int get_nbatch_combine_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 
128 : 64; -#else - GGML_UNUSED(ncols); - return 128; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } -}; +static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) { + return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 28) & ((1 << 2) - 1)) + 1) * 1; +} -template <> -struct fattn_mma_f16_config<576, 512> { - static constexpr int nbatch_fa = 32; - static constexpr int nwarps_max = 8; - static constexpr bool Q_in_reg = false; - static constexpr int nstages_target = 1; +static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 29) & 1; +} - static int get_nbatch_K2_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 96 : 160; - } - return ncols <= 16 ? 288 : 160; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 29) & 1; +} - static constexpr __device__ int get_nbatch_K2_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 96 : 160; -#else - return ncols <= 16 ? 288 : 160; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } +// ------------------------------------------------------------------------------------------------------------------ - static int get_nbatch_V2_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 64 : 128; - } - return ncols <= 16 ? 256 : 128; - } +static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) { + return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0; +} - static constexpr __device__ int get_nbatch_V2_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 64 : 128; +static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) { +#ifdef CP_ASYNC_AVAILABLE + return ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0; #else - return ncols <= 16 ? 256 : 128; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 128; - } -}; + GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2); + return 0; +#endif // CP_ASYNC_AVAILABLE +} // ------------------------------------------------------------------------------------------------------------------ -template +template static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( - const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { - + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { // K/V data is loaded with decreasing granularity for D for better memory bandwidth. // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. 
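    // For example, with a head size of 80 (D2 = 40 half2 per row) the synchronous path below first
    // copies columns 0..31 of every row with stride_k = 32 (one half2 per thread) and then the
    // remaining columns 32..39 with stride_k = 8, so only the short row tail uses small transfers.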
- - if (use_cp_async) { + if constexpr (use_cp_async) { + static_assert(!oob_check, "OOB check not compatible with cp_async"); constexpr int preload = 64; constexpr int h2_per_chunk = 16/sizeof(half2); const int chunks_per_row = D2 / h2_per_chunk; @@ -315,9 +244,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( } } }; - ggml_cuda_unroll<5>{}(load); + // 1: max 32*16=512 bytes, 256 half + // 2: max 16*16=256 bytes, 128 half + // 3: max 8*16=128 bytes, 64 half + // 4: max 4*16= 64 bytes, 32 half + // 5: max 2*16= 32 bytes, 16 half + // 6: max 1*16= 16 bytes, 8 half + ggml_cuda_unroll<6>{}(load); } else { - static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); + // TODO use ggml_cuda_memcpy_1 auto load = [&] __device__ (const int n) { const int stride_k = WARP_SIZE >> n; const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); @@ -340,20 +275,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; + tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); } } }; - ggml_cuda_unroll<3>{}(load); + // 1: max 32* 4=128 bytes, 64 half + // 2: max 16* 4= 64 bytes, 32 half + // 3: max 8* 4= 32 bytes, 16 half + // 4: max 4* 4= 16 bytes, 8 half + ggml_cuda_unroll<4>{}(load); } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( - const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { - static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter"); - - if (use_cp_async) { + const half * const __restrict__ mask_h, half * const __restrict__ tile_mask, + const int stride_mask, const int i_sup, const int j0, const uint3 ne01) { + if constexpr (use_cp_async) { + static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa"); + static_assert(!oob_check, "OOB check incompatible with cp_async"); constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; @@ -361,50 +301,85 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); #pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + - (nbatch_fa == 2*WARP_SIZE ? 
threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp)); + for (int j1 = 0; j1 < ncols1; j1 += stride_j) { + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_vram = fastmodulo(j0 + j_sram, ne01); - if (j0 + stride_j > ncols1 && j >= ncols1) { + if (j1 + stride_j > ncols1 && j_sram >= ncols1) { break; } - const int i = 4 * (threadIdx.x % (nbatch_fa/8)); + const int i = 8 * (threadIdx.x % (nbatch_fa/8)); - cp_async_cg_16(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + cp_async_cg_16(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i); } - return; - } + } else if constexpr (oob_check) { +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += nwarps) { + const int j_sram = j1 + threadIdx.y; + const int j_vram = fastmodulo(j0 + j_sram, ne01); + + if (j1 + nwarps > ncols1 && j_sram >= ncols1) { + break; + } - constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; - constexpr int stride_j = nwarps * cols_per_warp; #pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp)); + for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - if (j0 + stride_j > ncols1 && j >= ncols1) { - break; + tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f); + } } + } else if constexpr (nbatch_fa < 2*WARP_SIZE) { + constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += stride_j) { + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_vram = fastmodulo(j0 + j_sram, ne01); - const int i = nbatch_fa == 2*WARP_SIZE ? 
threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp); + if (j1 + stride_j > ncols1 && j_sram >= ncols1) { + break; + } + + const int i = threadIdx.x % (WARP_SIZE/cols_per_warp); - tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i]; + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); + } + } else { +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += nwarps) { + const int j_sram = j1 + threadIdx.y; + const int j_vram = fastmodulo(j0 + j_sram, ne01); + + if (j1 + nwarps > ncols1 && j_sram >= ncols1) { + break; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) { + const int i = i0 + 2*threadIdx.x; + + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); + } + } } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half2 * const __restrict__ mask_h2, + const half * const __restrict__ mask_h, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, const float slope, const float logit_softcap, - const int ne01, + const uint3 ne01, const int ne02, const int stride_K, const int stride_V, @@ -412,27 +387,24 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( half2 * const __restrict__ tile_Q, half2 * const __restrict__ tile_K, half2 * const __restrict__ tile_V, - half2 * const __restrict__ tile_mask, - const tile_B * const __restrict__ Q_B, - tile_C_VKQ * const __restrict__ VKQ_C, + half * const __restrict__ tile_mask, + T_B_KQ * const __restrict__ Q_B, + T_C_VKQ * const __restrict__ VKQ_C, float * const __restrict__ KQ_max, float * const __restrict__ KQ_rowsum, - const int kb0) { -#ifdef TURING_MMA_AVAILABLE - typedef fattn_mma_f16_config c; - -#ifdef CP_ASYNC_AVAILABLE - constexpr int nstages = c::nstages_target; -#else - constexpr int nstages = 0; -#endif // CP_ASYNC_AVAILABLE - - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int ncols = ncols1 * ncols2; - constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); - constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); + const int jt, + const int kb0, + const int k_VKQ_sup) { +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); + constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); + constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols); + constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); + constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; @@ -440,26 +412,27 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); constexpr int stride_tile_V = mla ? 
stride_tile_K : nbatch_V2 + 4; - const int k_VKQ_0 = kb0 * c::nbatch_fa; - tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; - - // Use wide variants of tiles if ntiles >= 2. - tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; - tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; - tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; + const int k_VKQ_0 = kb0 * nbatch_fa; +#if defined(TURING_MMA_AVAILABLE) + T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))]; +#else // Volta + T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; +#endif // defined(TURING_MMA_AVAILABLE) if constexpr (nstages > 1) { + static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline"); static_assert(!mla, "multi-stage loading not implemented for MLA"); static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); - flash_attn_ext_f16_load_tile - (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V); + flash_attn_ext_f16_load_tile + (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup); } else { constexpr bool use_cp_async = nstages == 1; - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } } @@ -468,10 +441,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; - if (nstages <= 1) { + if constexpr (nstages <= 1) { constexpr bool use_cp_async = nstages == 1; - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup); if (use_cp_async) { cp_async_wait_all(); } @@ -479,55 +452,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } // Calculate tile of KQ: - if constexpr (c::Q_in_reg) { + if constexpr (Q_in_reg) { #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I; #pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { - tile_A K_A; + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { + T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); - if (ntiles == 1) { - mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + if constexpr (cols_per_warp == 8) { + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); - } + // Wide version of KQ_C is column-major => swap A and B. 
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A); } } } } else { - static_assert(ntiles == 2, "ntiles != 2 not implemented"); + static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented"); #pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { - load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { + load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I; - tile_A K_A; + T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A); + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); } } } - if (nstages <= 1) { + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. } } if (use_logit_softcap) { - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J; + static_assert(nbatch_fa % stride == 0, "bad loop size"); #pragma unroll - for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { + for (int i = 0; i < nbatch_fa/stride; ++i) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { + for (int l = 0; l < T_C_KQ::ne; ++l) { KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); } } @@ -540,34 +511,35 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } float KQ_rowsum_add[cols_per_thread] = {0.0f}; - if (ntiles == 1) { - if (ncols2 > 1 || mask_h2) { + if constexpr (cols_per_warp == 8) { + if (ncols2 > 1 || mask_h) { #pragma unroll - for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { - const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; + for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I; #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - const int i = i0 + tile_C_KQ::get_i(l); - const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; + for (int l = 0; l < T_C_KQ::ne; ++l) { + const int i = i0 + T_C_KQ::get_i(l); + const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2; - KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * - __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]); + KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]); } } } // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. 
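        // Online softmax: whenever a new K/V chunk raises the running row maximum, the rowsum and the
        // accumulated VKQ values are rescaled by expf(KQ_max_old - KQ_max_new) further down, so the
        // final division by KQ_rowsum yields correctly normalized attention weights regardless of chunking.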
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) { + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l]); + } } } - // Values per KQ column are spread across 8 threads, does not need full warp reduce: + // Values per KQ column are spread across 8 threads: #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll @@ -576,73 +548,78 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); - - KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { + KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]); + KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; + } else { + KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f; + } } } - } else { // ntiles > 1 - if (ncols2 > 1 || mask_h2) { -#pragma unroll - for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) { - const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; + } else { // not Turing mma or T_B_KQ::I > 8 + if (ncols2 > 1 || mask_h) { #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { + for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) { + const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J; #pragma unroll - for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { - const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; - const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; + for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) { + const int i = (i0 + T_C_KQ::get_j(l0)) / 2; + const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2; - const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]); - const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; - KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; - KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; - } + const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]); + KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x; + KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y; } } } // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. 
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); -#pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { + static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size"); #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll - for (int l = 0; l < tile_C_KQ_16::ne; ++l) { - const int KQ_index = 2*t + (l/2) % 2; - KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) { + // Turing + Volta: + KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l]); } } } - // Values per KQ column are spread across 4 threads, does not need full warp reduce: #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { +#if defined(TURING_MMA_AVAILABLE) + // Values per KQ column are spread across 4 threads: + constexpr int offset_first = 2; + constexpr int offset_last = 1; +#else + // Values per KQ column are spread across 2 threads: + constexpr int offset_first = 2; + constexpr int offset_last = 2; +#endif // defined(TURING_MMA_AVAILABLE) #pragma unroll - for (int offset = 2; offset >= 1; offset >>= 1) { + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); } } - static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { -#pragma unroll - for (int l = 0; l < tile_C_KQ_16::ne; ++l) { - const int KQ_index = 2*t + (l/2) % 2; - - KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); - - KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; + for (int l = 0; l < T_C_KQ::ne; ++l) { + // Turing + Volta: + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]); + KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; + } else { + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f; } } } @@ -662,12 +639,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; } - if (ntiles == 1) { +#if defined(TURING_MMA_AVAILABLE) + if constexpr (cols_per_warp == 8) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll - for (int l = 0; l < tile_C_VKQ::ne; ++l) { + for (int l = 0; l < T_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; } } @@ -676,46 +654,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { - VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) { + VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2; } } 
} } +#else // Volta + const half2 KQ_max_scale_h2 = make_half2( + KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#endif // defined(TURING_MMA_AVAILABLE) } // Convert KQ C tiles into B tiles for VKQ calculation: - tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; - tile_B_16 * B_16 = (tile_B_16 *) B; - static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size"); - if (ntiles == 1) { + T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)]; + static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size"); + if constexpr (cols_per_warp == 8) { #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { + for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) { B[k] = get_transposed(get_half2(KQ_C[k])); } } else { - for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); - } + for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) { + B[k] = get_half2(KQ_C[k]); } } - if (nstages > 1) { + if constexpr (nstages > 1) { // Preload K tile for next iteration: constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); if (!last_iter) { - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask - (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup); } } @@ -724,72 +709,119 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Therefore, iterate over V in reverse and re-use the data if possible. static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; + + // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; const int i0_diff = i0_stop - i0_start; - if (nstages <= 1 && i0_start < reusable_cutoff) { - constexpr bool use_cp_async = nstages == 1; - flash_attn_ext_f16_load_tile - (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); - if (use_cp_async) { - cp_async_wait_all(); + if constexpr (nstages <= 1) { + if (i0_start < reusable_cutoff) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); } - __syncthreads(); } const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; - // Calculate VKQ tile: +#if defined(TURING_MMA_AVAILABLE) + constexpr int i0_stride = cols_per_warp == 8 ? 
T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll - for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) { - static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size"); + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size"); #pragma unroll - for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) { - const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) { + const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; - tile_A A; + T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); - if (ntiles == 1) { - mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + if constexpr (T_B_KQ::I == 8) { + mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of VKQ_C is column-major => swap A and B. - mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); - } + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); } } } +#else // Volta + constexpr int i0_stride = 2*T_C_VKQ::J; +#pragma unroll + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size"); + static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes"); +#pragma unroll + for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) { + const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I; - if (nstages <= 1) { + T_A_VKQ A; // Transposed in both SRAM and registers, load normally. + load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A); + } + } +#endif // defined(TURING_MMA_AVAILABLE) + + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. 
} } #else - GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // TURING_MMA_AVAILABLE +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -template +#if defined(TURING_MMA_AVAILABLE) +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +template<> struct mma_tile_sizes<8> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile< 8, 8, half2>; // column-major + using T_C_KQ = tile<16, 8, float>; // row-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile< 8, 8, half2>; // column-major + using T_C_VKQ = tile<16, 4, half2>; // row-major +}; +#else // Volta +template struct mma_tile_sizes { + using T_A_KQ = tile< 8, 4, half2, DATA_SPLIT_MIRRORED, false>; // row-major + using T_B_KQ = tile<32, 4, half2, DATA_SPLIT_NONE, false>; // column-major + using T_C_KQ = tile<32, 8, float, DATA_SPLIT_NONE, false>; // column-major + using T_A_VKQ = tile< 8, 4, half2, DATA_SPLIT_MIRRORED, true>; // column-major + using T_B_VKQ = tile<32, 4, half2, DATA_SPLIT_NONE, false>; // column-major + using T_C_VKQ = tile<32, 4, half2, DATA_SPLIT_NONE, false>; // column-major +}; +#endif // defined(TURING_MMA_AVAILABLE) + +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half2 * const __restrict__ mask_h2, + const half * const __restrict__ mask_h, const float * const __restrict__ sinks_f, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, const float slope, const float logit_softcap, - const int ne01, + const uint3 ne01, const int ne02, + const int ne11, const int stride_Q1, const int stride_Q2, const int stride_K, @@ -798,23 +830,31 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jt, const int kb0_start, const int kb0_stop) { -#ifdef TURING_MMA_AVAILABLE +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - typedef fattn_mma_f16_config c; - -#ifdef CP_ASYNC_AVAILABLE - constexpr int nstages = c::nstages_target; -#else - constexpr int nstages = 0; -#endif // CP_ASYNC_AVAILABLE - - constexpr int ncols = ncols1 * ncols2; - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. 
- constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); - constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); + constexpr int ncols = ncols1 * ncols2; + using T_A_KQ = typename mma_tile_sizes::T_A_KQ; + using T_B_KQ = typename mma_tile_sizes::T_B_KQ; + using T_C_KQ = typename mma_tile_sizes::T_C_KQ; + using T_A_VKQ = typename mma_tile_sizes::T_A_VKQ; + using T_B_VKQ = typename mma_tile_sizes::T_B_VKQ; + using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; + + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols); + constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols); + constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols); + constexpr int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols); + constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); + constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); + + if (cols_per_warp > ncols) { + NO_DEVICE_CODE; + return; + } static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); @@ -826,15 +866,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; - half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; - half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K; - half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max; + half2 * tile_K = Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; + half2 * tile_V = nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K; + half * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max); - tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles]; - tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles]; - - tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; - tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; +#if defined(TURING_MMA_AVAILABLE) + T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; +#else // Volta + T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; +#endif // defined(TURING_MMA_AVAILABLE) float KQ_rowsum[cols_per_thread] = {0.0f}; float KQ_max[cols_per_thread]; @@ -868,7 +909,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (jt*ncols1 + j < ne01) { + if (jt*ncols1 + j < int(ne01.z)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); @@ -889,63 +930,96 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - if (c::Q_in_reg) { + if (Q_in_reg) { const int j0 = (threadIdx.y / np) * cols_per_warp; #pragma unroll - for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { - if (ntiles == 1) { - load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); - } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], - tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q); - } - } + for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) { + load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); } } __syncthreads(); + int kb0 = kb0_start; + // Preload mask and K data for first iteration when using cp_async with multiple stages: if constexpr (nstages > 1) { static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); constexpr bool use_cp_async = true; - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask - (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup); } - // Iterate over ne11 == previous tokens: - int kb0 = kb0_start; for (; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); - } - { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } + // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. 
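    // With ncols2 == 1 the K/V length ne11 is not required to be a multiple of nbatch_fa, so the final
    // chunk may be partial: it runs with oob_check enabled and k_VKQ_sup set to the number of valid
    // K/V rows. For ncols2 > 1 the kernel assumes padded K/V data and skips the bounds check.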
+ if constexpr (ncols2 == 1) { + if (ne11 % nbatch_fa == 0) { + constexpr bool last_iter = true; + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } else { + constexpr bool last_iter = true; + constexpr bool oob_check = true; + const int k_VKQ_sup = ne11 - kb0*nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } + } else { constexpr bool last_iter = true; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); } // With multi-stage loading there is no __syncthreads at the end of the iter, // there can be a race condition on shared memory access for combining/writing back results. - if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { + if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) { __syncthreads(); } // Finally, sum up partial KQ rowsums. - // The partial sums are spread across 8/4 threads each, does not need full reduce. { - constexpr int offset_first = ntiles == 1 ? 16 : 2; - constexpr int offset_last = ntiles == 1 ? 4 : 1; +#if defined(TURING_MMA_AVAILABLE) + // The partial sums are spread across 8/4 threads. + constexpr int offset_first = cols_per_warp == 8 ? 16 : 2; + constexpr int offset_last = cols_per_warp == 8 ? 4 : 1; +#else // Volta + // The partial sums are spread across 2 threads. + constexpr int offset_first = 2; + constexpr int offset_last = 2; +#endif // defined(TURING_MMA_AVAILABLE) #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll @@ -962,8 +1036,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented"); - const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); + const int jc = cols_per_warp == 8 ? 
T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col); const float sink = sinks_f[jc % ncols2]; const float KQ_max_new = fmaxf(KQ_max[col], sink); @@ -977,12 +1050,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add; } - if (ntiles == 1) { +#if defined(TURING_MMA_AVAILABLE) + if constexpr (cols_per_warp == 8) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll - for (int l = 0; l < tile_C_VKQ::ne; ++l) { + for (int l = 0; l < T_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; } } @@ -991,30 +1065,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { - VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) { + VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2; } } } } +#else // Volta + const int col = (threadIdx.x / 2) % 2; + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#endif // defined(TURING_MMA_AVAILABLE) } // Combine VKQ accumulator values if np > 1. // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // So also write VKQ accumulators to shared memory in column-major format if np == 1. - constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols); - constexpr int tile_stride = nbatch_combine + 4; + constexpr int tile_stride = nbatch_combine + 4; static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); - if constexpr (ntiles == 1) { - const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset - const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta + if constexpr (cols_per_warp == 8) { + const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset + const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum - if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } @@ -1023,24 +1107,30 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (np == 1) { // No combination is needed, the meta data can be directly written from registers to VRAM. 
- if (needs_fixup && threadIdx.x < tile_B::I) { + if (needs_fixup && threadIdx.x < T_B_KQ::I) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } - if (is_fixup && threadIdx.x < tile_B::I) { + if (is_fixup && threadIdx.x < T_B_KQ::I) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } } } else { - static_assert(ntiles == 2 || ntiles == 4, "bad ntiles"); - const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta - + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) - + tile_C_VKQ_16::get_i(threadIdx.x % 4); - const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum - - if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { - // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + // jc_cwm = jc combine write meta + // KQ_cmr = KQ combine max rowsum + // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale. +#if defined(TURING_MMA_AVAILABLE) + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4); + const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); + const bool thread_should_write = threadIdx.x % 4 < cols_per_thread; +#else // Volta + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2); + const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]); + const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8; +#endif // defined(TURING_MMA_AVAILABLE) + + if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) { ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } @@ -1048,18 +1138,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (np == 1) { // No combination is needed, the meta data can be directly written from registers to VRAM. - if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + if (needs_fixup && thread_should_write) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } - if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + if (is_fixup && thread_should_write) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } } } - static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles"); if (np > 1 && threadIdx.y % np == 0) { // Combine the meta data for parallel warps via shared memory. // Warps with threadIdx.y % np != 0 must NOT return early. @@ -1135,32 +1224,29 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { - if (ntiles == 1) { - const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data + if constexpr (cols_per_warp == 8) { + const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data #pragma unroll - for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { - const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. 
+ for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) { + const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format. #pragma unroll - for (int l = 0; l < tile_B::ne; ++l) { - const int k = k0 + tile_B::get_j(l); + for (int l = 0; l < T_B_KQ::ne; ++l) { + const int k = k1 + T_B_KQ::get_j(l); tile_Q[jc_cwd*tile_stride + k] = B.x[l]; } } } else { + const int j0 = threadIdx.y*cols_per_warp; #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { #pragma unroll - for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) { -#pragma unroll - for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { - const int j = j0 + tile_C_VKQ_16::get_i(l); - const int k = k0 + tile_C_VKQ_16::get_j(l); + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = k1 + T_C_VKQ::get_j(l); - tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; - } + tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; } } } @@ -1195,7 +1281,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j_dst = jc_dst / ncols2; const int c_dst = jc_dst % ncols2; - if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) { continue; } @@ -1233,16 +1319,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } #else - GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dstk_fixup, + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // TURING_MMA_AVAILABLE +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -template -__launch_bounds__(nwarps*WARP_SIZE, 1) +template +__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -1258,14 +1344,14 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { @@ -1281,23 +1367,22 @@ static __global__ void flash_attn_ext_f16( static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); - typedef fattn_mma_f16_config c; - - static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config::nbatch_fa == 0, "bad nbatch_fa"); + constexpr int ncols = ncols1 * ncols2; + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, 
DV, ncols); + constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols); + constexpr int nwarps = nthreads / WARP_SIZE; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); const int stride_K = nb11 / sizeof(half2); - const int stride_mask = nb31 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); const int stride_V = mla ? stride_K : nb21 / sizeof(half2); - const int iter_k = ne11 / FATTN_KQ_STRIDE; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - - constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; // kbc == k block continuous, current index in continuous ijk space. int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; @@ -1318,35 +1403,31 @@ static __global__ void flash_attn_ext_f16( const int head0 = zt * ncols2; - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); - const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half * mask_h = ncols2 == 1 && !mask ? nullptr : + (const half *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - int kb0_stop_kernel = kb0_stop * kb_niter; - if (KV_max) { - kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); + kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); } - constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } else { - constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile. 
+ flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } kbc += iter_k; @@ -1366,29 +1447,26 @@ static __global__ void flash_attn_ext_f16( const int head0 = zt * ncols2; - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); - const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half * mask_h = ncols2 == 1 && !mask ? nullptr : + (const half *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - int kb0_stop_kernel = kb0_stop * kb_niter; - if (KV_max) { - kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); + kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); } constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, @@ -1400,7 +1478,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) } template @@ -1409,36 +1487,30 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const int id = ggml_cuda_get_device(); const int cc = ggml_cuda_info().devices[id].cc; - typedef fattn_mma_f16_config c; + constexpr int ncols = ncols1 * ncols2; - const int nstages = cp_async_available(cc) ? 
c::nstages_target : 0; + const int nthreads = ggml_cuda_fattn_mma_get_nthreads (DKQ, DV, ncols, cc); + const int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols, cc); + const int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols, cc); + const int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols, cc); + const int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc); + const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc); + const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc); - constexpr int ncols = ncols1 * ncols2; - constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int nwarps_max_x = ncols / cols_per_warp; - constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; - constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32); + const int nwarps = nthreads / WARP_SIZE; constexpr bool mla = DKQ == 576; - const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols); - const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols); - const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols); - - static_assert(DKQ % tile_B::J == 0, "bad DKQ"); - static_assert(DV % tile_A::J == 0, "bad DV"); - static_assert(ncols % cols_per_warp == 0, "bad ncols"); - - const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); - const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (nbatch_fa/2 + 4) * sizeof(half2); const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; - const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ? + const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ? 
std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) : nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask); @@ -1448,7 +1520,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1459,7 +1531,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1471,7 +1543,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml } launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true); } diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 3e58d64ff9d..63b235674eb 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -501,6 +501,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, const half * const __restrict__ mask, + const uint3 ne01, const float logit_softcap, const float slope, T_KQ * const KQ, @@ -512,7 +513,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter( float * const KQ_sum, T_acc * const VKQ, const int k_VKQ_0, - const int k_VKQ_max) { + const int k_VKQ_max, + const int col_Q_0) { constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; @@ -556,7 +558,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( // Apply logit softcap + mask, update KQ_max: #pragma unroll for (int jc0 = 0; jc0 < cpw; ++jc0) { - const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2; + const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01); #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { @@ -736,7 +738,7 @@ static __global__ void flash_attn_tile( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -781,11 +783,11 @@ static __global__ void flash_attn_tile( const int sequence = blockIdx.z / (ne02/ncols2); const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2) const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
- const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0); + const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0); const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape - const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr; + const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr; const int stride_K2 = nb11 / sizeof(half2); const int stride_V2 = nb21 / sizeof(half2); @@ -842,11 +844,9 @@ static __global__ void flash_attn_tile( for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { float tmp_f[cpy_ne_D] = {0.0f}; - if (ncols1 == 1 || col_Q_0 + j < ne01) { - ggml_cuda_memcpy_1 - (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float)) - + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); - } + ggml_cuda_memcpy_1 + (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float)) + + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { @@ -881,23 +881,23 @@ static __global__ void flash_attn_tile( while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); k_VKQ_0 += gridDim.y*nbatch_fa; } if (k_VKQ_0 < k_VKQ_max) { constexpr bool oob_check = true; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } else { // Branch without out-of-bounds checks. for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } @@ -1010,13 +1010,13 @@ static __global__ void flash_attn_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (ncols1 > 1 && col_Q_0 + j >= ne01) { + if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) { return; } const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f; - const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; + const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; #ifdef FAST_FP16_AVAILABLE constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? 
cpy_ne/2 : (DVp/2)/warp_size; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 67aa67ecb94..0bae9849a96 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -33,7 +33,7 @@ static __global__ void flash_attn_ext_vec( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -150,7 +150,7 @@ static __global__ void flash_attn_ext_vec( float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); // Set memory to zero if out of bounds: - if (ncols > 1 && ic0 + j >= ne01) { + if (ncols > 1 && ic0 + j >= int(ne01.z)) { #pragma unroll for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; @@ -201,7 +201,7 @@ static __global__ void flash_attn_ext_vec( const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; - if (ncols == 1 || ic0 + j < ne01) { + if (ncols == 1 || ic0 + j < int(ne01.z)) { ggml_cuda_memcpy_1(tmp, &Q_j[i]); ggml_cuda_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); } @@ -222,7 +222,7 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; - if (ncols == 1 || ic0 + j < ne01) { + if (ncols == 1 || ic0 + j < int(ne01.z)) { ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]); ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]); } @@ -266,7 +266,7 @@ static __global__ void flash_attn_ext_vec( sum = logit_softcap*tanhf(sum); } - if (mask) { + if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) { sum += slope*__half2float(maskh[j*ne11 + i_KQ]); } @@ -412,7 +412,7 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { - if (ncols > 1 && ic0 + j_VKQ >= ne01) { + if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) { break; } @@ -479,7 +479,7 @@ static __global__ void flash_attn_ext_vec( if (gridDim.y == 1) { dst_val /= KQ_sum[j_VKQ]; } - dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; + dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; } } @@ -489,8 +489,8 @@ static __global__ void flash_attn_ext_vec( } - if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) { - dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); + if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) { + dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); } #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 6c90d6d52b3..0d81f0aae0a 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -38,14 +38,14 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - 
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -149,7 +149,7 @@ static __global__ void flash_attn_ext_f16( if (i0 + warp_size > D && i >= D) { break; } - KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f; } } @@ -218,7 +218,8 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? + __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]); } KQ_max_new = warp_reduce_max(KQ_max_new); @@ -270,7 +271,7 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? 
slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]); } KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); @@ -431,7 +432,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j_VKQ = j0 + threadIdx.y; - if (ic0 + j_VKQ >= ne01) { + if (ic0 + j_VKQ >= int(ne01.z)) { return; } @@ -442,7 +443,7 @@ static __global__ void flash_attn_ext_f16( KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); } - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; #pragma unroll for (int i0 = 0; i0 < D; i0 += warp_size) { @@ -481,7 +482,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) } constexpr int get_max_power_of_2(int x) { diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index 7235f1b77ae..cd3bfd4051a 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -2,9 +2,9 @@ #include "common.cuh" -#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#if defined(GGML_USE_MUSA) #define GGML_USE_WMMA_FATTN -#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#endif // defined(GGML_USE_MUSA) #if defined(GGML_HIP_ROCWMMA_FATTN) #if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 82405991cea..1358faf6837 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -12,13 +12,13 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con const ggml_tensor * Q = dst->src[0]; if constexpr (ncols2 <= 8) { - if (Q->ne[1] <= 8/ncols2) { + if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } } - if (Q->ne[1] <= 16/ncols2) { + if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } @@ -41,7 +41,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - const bool use_gqa_opt = mask && max_bias == 0.0f; + const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; @@ -275,8 +275,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; - // If Turing tensor cores available, use them: - if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) { + // If Turing tensor 
cores are available, use them: + if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { @@ -297,7 +297,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_VEC; } } + return BEST_FATTN_KERNEL_MMA_F16; + } + if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { + int gqa_ratio_eff = 1; + const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (Q->ne[1] * gqa_ratio_eff <= 16) { + return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices. + } return BEST_FATTN_KERNEL_MMA_F16; } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 0ed42e87d3d..0007506ec9d 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -68,10 +68,25 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) { namespace ggml_cuda_mma { + // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel, + // effectively the warp is being split into subgroups of threads that each perform a single mma instruction. + // In those cases the data can be split in different ways across the warp. + enum data_split { + DATA_SPLIT_NONE = 0, // Each data value is held exactly once per warp (always applies to Turing, Ampere, Ada Lovelace, consumer Blackwell). + DATA_SPLIT_MIRRORED = 10, // Each data value is held exactly once per subgroup. + }; + // Implemented mma combinations are: + // - (NONE, NONE) -> NONE + // - (NONE, MIRRORED) -> NONE + + template + struct tile {}; + template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_split ds = DATA_SPLIT_NONE; #if defined(AMD_MFMA_AVAILABLE) static constexpr int ne = I * J / 64; @@ -131,9 +146,9 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 32 && J == 8) { #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM - return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2); + return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2); #else - return (l & 2) | (threadIdx.x & ~2); + return (l & 2) + (threadIdx.x & ~2); #endif // GGML_CUDA_MMA_NO_VOLTA_PERM } else { NO_DEVICE_CODE; @@ -143,7 +158,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 32 && J == 8) { - return (threadIdx.x & 2) | (l & (4 + 1)); + return (threadIdx.x & 2) + (l & (4 + 1)); } else { NO_DEVICE_CODE; return -1; @@ -196,9 +211,9 @@ namespace ggml_cuda_mma { } else if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 8) | (threadIdx.x / 4); + return ((l / 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 16) { - return (((l / 2) % 2) * 8) | (threadIdx.x / 4); + return (((l / 2) % 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 32 && J == 8) { return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction. 
} else { @@ -211,11 +226,11 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 8) { - return ((threadIdx.x % 4) * 2) | (l % 2); + return ((threadIdx.x % 4) * 2) + (l % 2); } else if constexpr (I == 16 && J == 16) { - return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2); + return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2); } else if constexpr (I == 32 && J == 8) { return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction. } else { @@ -227,26 +242,24 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_split ds = DATA_SPLIT_NONE; #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA - static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE; + static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 8 && J == 8) return true; - if (I == 32 && J == 8) return true; + if (I == 32 && J == 4) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 8 && J == 8) { - return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4); - } else if constexpr (I == 32 && J == 8) { + if constexpr (I == 32 && J == 4) { #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM - return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4); + return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4); #else return threadIdx.x; #endif // GGML_CUDA_MMA_NO_VOLTA_PERM @@ -257,7 +270,7 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr ((I == 8 || I == 32) && J == 8) { + if constexpr (I == 32 && J == 4) { return l; } else { NO_DEVICE_CODE; @@ -307,11 +320,11 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 4) { - return (l * 8) | (threadIdx.x / 4); + return (l * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 8) { - return ((l % 2) * 8) | (threadIdx.x / 4); + return ((l % 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 32 && J == 8) { - return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4); + return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4); } else { NO_DEVICE_CODE; return -1; @@ -320,13 +333,13 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 4) | (threadIdx.x % 4); + return ((l / 2) * 4) + (threadIdx.x % 4); } else if constexpr (I == 32 && J == 8) { - return ((l & 2) * 2) | (threadIdx.x % 4); + return ((l & 2) * 2) + (threadIdx.x % 4); } else { NO_DEVICE_CODE; return -1; @@ -336,14 +349,15 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_split ds = DATA_SPLIT_NONE; + static constexpr int ne = I * J / WARP_SIZE; -#if defined(AMD_WMMA_AVAILABLE) - static 
constexpr int ne = I * J / 32; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; +#if defined(AMD_WMMA_AVAILABLE) static constexpr __device__ bool supported() { if (I == 16 && J == 8) return true; return false; @@ -367,9 +381,6 @@ namespace ggml_cuda_mma { } } #else - static constexpr int ne = I * J / WARP_SIZE; - nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; - static constexpr __device__ bool supported() { if (I == 8 && J == 8) return true; if (I == 16 && J == 4) return true; @@ -381,9 +392,9 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 4) { - return (l * 8) | (threadIdx.x / 4); + return (l * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 8) { - return ((l % 2) * 8) | (threadIdx.x / 4); + return ((l % 2) * 8) + (threadIdx.x / 4); } else { NO_DEVICE_CODE; return -1; @@ -392,11 +403,11 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 4) | (threadIdx.x % 4); + return ((l / 2) * 4) + (threadIdx.x % 4); } else { NO_DEVICE_CODE; return -1; @@ -405,6 +416,73 @@ namespace ggml_cuda_mma { #endif // defined(AMD_WMMA_AVAILABLE) }; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_split ds = DATA_SPLIT_MIRRORED; + static constexpr int ne = I * J / (WARP_SIZE/4); + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int /*l*/) { + if constexpr (I == 8 && J == 4) { + return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_split ds = DATA_SPLIT_MIRRORED; + static constexpr int ne = I * J / (WARP_SIZE/4); + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 4) { + return ((l / 2) * 4) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return ((threadIdx.x / 16) * 2) + (l % 2); + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + +#if defined(TURING_MMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { tile ret; @@ -422,9 +500,26 @@ namespace ggml_cuda_mma { return ret; } +#else // Volta + template + static __device__ __forceinline__ tile get_half2(const tile & tile_float) { + tile ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 4) { + ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]); + + // On Volta FP16 and FP32 tiles have a different memory layout, + // for the conversion threads with an offset of 2 need to exchange half their values: + ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 
1)] = __shfl_xor_sync( + 0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE); + } + return ret; + } +#endif // defined(TURING_MMA_AVAILABLE) - template - static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { + template + static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { #if defined(AMD_MFMA_AVAILABLE) if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> #pragma unroll @@ -511,18 +606,6 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); #else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA - GGML_UNUSED_VARS(t, xs0, stride); - NO_DEVICE_CODE; -#else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA -#endif // TURING_MMA_AVAILABLE - } - - template - static __device__ __forceinline__ void load_ldmatrix( - tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) { #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #if 1 // TODO: more generic handling @@ -533,9 +616,31 @@ namespace ggml_cuda_mma { load_generic(t, xs0, stride); #endif // 1 #else - tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t; - load_ldmatrix(t16[0], xs0 + 0*stride, stride); - load_ldmatrix(t16[1], xs0 + 16*stride, stride); + load_generic(t, xs0, stride); +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 4, half2, DATA_SPLIT_MIRRORED, false> & t, const half2 * __restrict__ xs0, const int stride) { + ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 4, half2, DATA_SPLIT_MIRRORED, true> & t, const half2 * __restrict__ xs0, const int stride) { +#pragma unroll + for (int l0 = 0; l0 < t.ne; l0 += 2) { + ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0)); + } + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); +#else + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA } @@ -867,7 +972,7 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ void mma( - tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) { + tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_SPLIT_MIRRORED, false> & B) { #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; @@ -880,20 +985,30 @@ namespace ggml_cuda_mma { "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3])); - asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" - : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) - : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5])); - asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" - : 
"+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) - : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7])); #else - tile <16, 8, float> * D16 = reinterpret_cast *>(&D); - const tile<16, 8, half2> * A16 = reinterpret_cast *>(&A); - mma(D16[0], A16[0], B); - mma(D16[1], A16[1], B); -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + } + + static __device__ __forceinline__ void mma( + tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_SPLIT_MIRRORED, true> & B) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 " + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1])); + asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 " + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3])); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA } static __device__ __forceinline__ void mma( diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index c2a0a2e42fe..507022ff710 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -37,23 +37,19 @@ static __global__ void mul_mat_f( typedef tile<16, 8, T> tile_A; typedef tile tile_B; typedef tile<16, tile_C_J, float> tile_C; - - constexpr bool a_supported = tile_A::supported(); - constexpr bool b_supported = tile_B::supported(); - constexpr bool c_supported = tile_C::supported(); - constexpr bool supported = a_supported && b_supported && c_supported; #else - constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); - constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - constexpr bool supported = I_16_supported || I_32_supported; - - constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. 
-
- typedef tile<I_preferred, 8, T> tile_A;
- typedef tile<8, 8, T> tile_B;
- typedef tile<I_preferred, 8, float> tile_C;
+#ifdef VOLTA_MMA_AVAILABLE
+ if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+ typedef tile<32, 4, T, DATA_SPLIT_NONE, false> tile_A;
+ typedef tile< 8, 4, T, DATA_SPLIT_MIRRORED, false> tile_B;
+ typedef tile<32, 8, float, DATA_SPLIT_NONE, false> tile_C;
+#else
+ typedef tile<16, 8, T> tile_A;
+ typedef tile<8, 8, T> tile_B;
+ typedef tile<16, 8, float> tile_C;
+#endif // VOLTA_MMA_AVAILABLE
 #endif // defined(AMD_WMMA_AVAILABLE)
- if constexpr (!supported) {
+ if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
 NO_DEVICE_CODE;
 return;
 }
@@ -248,6 +244,9 @@ static __global__ void mul_mat_f(
 }
 }
 }
+#ifdef VOLTA_MMA_AVAILABLE
+ }
+#endif //VOLTA_MMA_AVAILABLE
 #else
 GGML_UNUSED_VARS(x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
@@ -278,27 +277,24 @@ static __global__ void mul_mat_f_ids(
 typedef tile<16, 8, T> tile_A;
 typedef tile tile_B;
 typedef tile<16, tile_C_J, float> tile_C;
-
- constexpr bool a_supported = tile_A::supported();
- constexpr bool b_supported = tile_B::supported();
- constexpr bool c_supported = tile_C::supported();
- constexpr bool supported = a_supported && b_supported && c_supported;
 #else
- constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
- constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
- constexpr bool supported = I_16_supported || I_32_supported;
-
- constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
-
- typedef tile<I_preferred, 8, T> tile_A;
- typedef tile<8, 8, T> tile_B;
- typedef tile<I_preferred, 8, float> tile_C;
+#ifdef VOLTA_MMA_AVAILABLE
+ if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+ typedef tile<32, 4, T, DATA_SPLIT_NONE, false> tile_A;
+ typedef tile< 8, 4, T, DATA_SPLIT_MIRRORED, false> tile_B;
+ typedef tile<32, 8, float, DATA_SPLIT_NONE, false> tile_C;
+#else
+ typedef tile<16, 8, T> tile_A;
+ typedef tile<8, 8, T> tile_B;
+ typedef tile<16, 8, float> tile_C;
+#endif // VOLTA_MMA_AVAILABLE
 #endif // defined(AMD_WMMA_AVAILABLE)
- if constexpr (!supported) {
+ if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
 NO_DEVICE_CODE;
 return;
 }
+
 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 constexpr int tile_k_padded = warp_size + 4;
 constexpr int ntA = rows_per_block / tile_A::I;
@@ -517,6 +513,9 @@ static __global__ void mul_mat_f_ids(
 }
 }
 }
+#ifdef VOLTA_MMA_AVAILABLE
+ }
+#endif // VOLTA_MMA_AVAILABLE
 #else
 GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,

From e2c50b1db2f2f23fb51b51fb8d012d658f445956 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Thu, 27 Nov 2025 15:55:31 +0100
Subject: [PATCH 2/6] fix const correctness

---
 ggml/src/ggml-cuda/mma.cuh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index 0007506ec9d..628a256b699 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -965,8 +965,8 @@ namespace ggml_cuda_mma {
 template
 static __device__ __forceinline__ void mma(
 tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile & B) {
- tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
- tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
+ tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D);
+ const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
 mma(D16[0], A16[0], B);
 mma(D16[1], A16[1], B);
 }

From 301ae300571099c33e0c70a2869f0a54cf67e469 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Fri, 28 Nov 2025 16:48:35 +0100
Subject: [PATCH 3/6] fix turing config lookup

---
 ggml/src/ggml-cuda/fattn-mma-f16.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 62e3d0a673b..be1e1b5ae74 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -104,7 +104,7 @@ static __host__ uint32_t ggml_cuda_fattn_mma_get_config(const int DKQ, const int
 if (ampere_mma_available(cc)) {
 return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 }
- if (ampere_mma_available(cc)) {
+ if (turing_mma_available(cc)) {
 return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
 }
 GGML_ASSERT(volta_mma_available(cc));

From 13500e8cca032138277cba9b4a63c003e92701d7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Sat, 29 Nov 2025 14:34:59 +0100
Subject: [PATCH 4/6] refactor template parameters

---
 ggml/src/ggml-cuda/fattn-mma-f16.cuh | 12 ++---
 ggml/src/ggml-cuda/mma.cuh | 76 +++++++++++++++-------------
 ggml/src/ggml-cuda/mmf.cuh | 12 ++---
 3 files changed, 53 insertions(+), 47 deletions(-)

diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index be1e1b5ae74..cb5443a9ac6 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -798,12 +798,12 @@ template<> struct mma_tile_sizes<8> {
 };
 #else // Volta
 template struct mma_tile_sizes {
- using T_A_KQ = tile< 8, 4, half2, DATA_SPLIT_MIRRORED, false>; // row-major
- using T_B_KQ = tile<32, 4, half2, DATA_SPLIT_NONE, false>; // column-major
- using T_C_KQ = tile<32, 8, float, DATA_SPLIT_NONE, false>; // column-major
- using T_A_VKQ = tile< 8, 4, half2, DATA_SPLIT_MIRRORED, true>; // column-major
- using T_B_VKQ = tile<32, 4, half2, DATA_SPLIT_NONE, false>; // column-major
- using T_C_VKQ = tile<32, 4, half2, DATA_SPLIT_NONE, false>; // column-major
+ using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
+ using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
+ using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major
+ using T_A_VKQ = tile< 8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major
+ using T_B_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
+ using T_C_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
 };
 #endif // defined(TURING_MMA_AVAILABLE)

diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index 628a256b699..08991644fe5 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -71,22 +71,28 @@ namespace ggml_cuda_mma {
 // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
 // effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
 // In those cases the data can be split in different ways across the warp.
- enum data_split {
- DATA_SPLIT_NONE = 0, // Each data value is held exactly once per warp (always applies to Turing, Ampere, Ada Lovelace, consumer Blackwell).
- DATA_SPLIT_MIRRORED = 10, // Each data value is held exactly once per subgroup.
+ enum data_layout { + // By default the data uses the I direction as its major dimension and the J direction as its minor dimension. + // For the A/C matrices this means I major == row major, J major == column major. + // For the B matrix this means I major == column major, J major == row major. + // MIRRORED == Each data value is held exactly once per thread subgroup. + DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell. + DATA_LAYOUT_I_MAJOR_MIRRORED = 10, + DATA_LAYOUT_J_MAJOR_MIRRORED = 20, }; // Implemented mma combinations are: - // - (NONE, NONE) -> NONE - // - (NONE, MIRRORED) -> NONE + // - (I_MAJOR, I_MAJOR) -> I_MAJOR + // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR + // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR - template + template struct tile {}; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; - static constexpr data_split ds = DATA_SPLIT_NONE; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if defined(AMD_MFMA_AVAILABLE) static constexpr int ne = I * J / 64; @@ -242,10 +248,10 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; - static constexpr data_split ds = DATA_SPLIT_NONE; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA static constexpr int ne = I * J / WARP_SIZE; @@ -349,11 +355,11 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; - static constexpr data_split ds = DATA_SPLIT_NONE; - static constexpr int ne = I * J / WARP_SIZE; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; + static constexpr int ne = I * J / WARP_SIZE; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; @@ -417,11 +423,11 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; - static constexpr data_split ds = DATA_SPLIT_MIRRORED; - static constexpr int ne = I * J / (WARP_SIZE/4); + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; + static constexpr int ne = I * J / (WARP_SIZE/4); half2 x[ne] = {{0.0f, 0.0f}}; @@ -450,11 +456,11 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; - static constexpr data_split ds = DATA_SPLIT_MIRRORED; - static constexpr int ne = I * J / (WARP_SIZE/4); + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED; + static constexpr int ne = I * J / (WARP_SIZE/4); half2 x[ne] = {{0.0f, 0.0f}}; @@ -518,8 +524,8 @@ namespace ggml_cuda_mma { } #endif // defined(TURING_MMA_AVAILABLE) - template - static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { + template + static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { #if defined(AMD_MFMA_AVAILABLE) if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> #pragma unroll @@ -622,12 +628,12 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ void load_ldmatrix( - tile<8, 4, half2, 
DATA_SPLIT_MIRRORED, false> & t, const half2 * __restrict__ xs0, const int stride) {
+        tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
         ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
     }
 
     static __device__ __forceinline__ void load_ldmatrix(
-        tile<8, 4, half2, DATA_SPLIT_MIRRORED, true> & t, const half2 * __restrict__ xs0, const int stride) {
+        tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
 #pragma unroll
         for (int l0 = 0; l0 < t.ne; l0 += 2) {
             ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
@@ -972,7 +978,7 @@ namespace ggml_cuda_mma {
     }
 
     static __device__ __forceinline__ void mma(
-            tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_SPLIT_MIRRORED, false> & B) {
+            tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -992,7 +998,7 @@ namespace ggml_cuda_mma {
     }
 
     static __device__ __forceinline__ void mma(
-            tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_SPLIT_MIRRORED, true> & B) {
+            tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index 507022ff710..e1c695c5c0f 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -40,9 +40,9 @@ static __global__ void mul_mat_f(
 #else
 #ifdef VOLTA_MMA_AVAILABLE
     if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
-        typedef tile<32, 4, T, DATA_SPLIT_NONE, false> tile_A;
-        typedef tile< 8, 4, T, DATA_SPLIT_MIRRORED, false> tile_B;
-        typedef tile<32, 8, float, DATA_SPLIT_NONE, false> tile_C;
+        typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
+        typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
+        typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
 #else
     typedef tile<16, 8, T> tile_A;
     typedef tile<8, 8, T> tile_B;
@@ -280,9 +280,9 @@ static __global__ void mul_mat_f_ids(
 #else
 #ifdef VOLTA_MMA_AVAILABLE
     if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
-        typedef tile<32, 4, T, DATA_SPLIT_NONE, false> tile_A;
-        typedef tile< 8, 4, T, DATA_SPLIT_MIRRORED, false> tile_B;
-        typedef tile<32, 8, float, DATA_SPLIT_NONE, false> tile_C;
+        typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
+        typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
+        typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
 #else
     typedef tile<16, 8, T> tile_A;
    typedef tile<8, 8, T> tile_B;

From 394ced5ff8e427325464701ae79233a4a157042e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Sat, 29 Nov 2025 18:39:58 +0100
Subject: [PATCH 5/6] adjust kernel selection logic

---
 ggml/src/ggml-cuda/fattn.cu | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 1358faf6837..de5e76ff023 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -306,6 +306,9 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
         gqa_ratio_eff *= 2;
     }
+    if (Q->ne[1] * gqa_ratio_eff <= 2) {
+        return 
BEST_FATTN_KERNEL_VEC; + } if (Q->ne[1] * gqa_ratio_eff <= 16) { return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices. } From 3e1ca0c62f301e5d3ba96c4f47b1371c243dbde2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 29 Nov 2025 19:06:57 +0100 Subject: [PATCH 6/6] fix trailing whitespace --- ggml/src/ggml-cuda/mma.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 08991644fe5..6ea7a809a47 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -77,7 +77,7 @@ namespace ggml_cuda_mma { // For the B matrix this means I major == column major, J major == row major. // MIRRORED == Each data value is held exactly once per thread subgroup. DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell. - DATA_LAYOUT_I_MAJOR_MIRRORED = 10, + DATA_LAYOUT_I_MAJOR_MIRRORED = 10, DATA_LAYOUT_J_MAJOR_MIRRORED = 20, }; // Implemented mma combinations are:
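
Usage sketch (illustration only, not part of the patch series above): on Volta the new layout-tagged tiles are combined exactly as in the mmf.cuh hunks, i.e. the A and C fragments keep the default DATA_LAYOUT_I_MAJOR while the B fragment uses DATA_LAYOUT_I_MAJOR_MIRRORED, and the matching (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR overload of mma() is selected by the tile types alone. The helper below is hypothetical: its name, its parameters, and the choice of load_generic/load_ldmatrix for the loads are assumptions made for this sketch; it only strings together declarations that the patches above introduce in ggml-cuda/mma.cuh.

// Illustration only -- not part of the patches above. Assumes the ggml-cuda headers
// ("common.cuh" for NO_DEVICE_CODE/GGML_UNUSED_VARS, "mma.cuh" as modified by this series)
// are in scope; volta_mma_example and its parameters are made up for this sketch.
#include "common.cuh"
#include "mma.cuh"

using namespace ggml_cuda_mma;

static __device__ __forceinline__ void volta_mma_example(
        const half2 * __restrict__ A_src, const half2 * __restrict__ B_src, const int stride,
        tile<32, 8, float> & C) {                           // C: default I-major layout, one copy per warp
#ifdef VOLTA_MMA_AVAILABLE
    tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>          A;     // A: I-major, one copy per warp
    tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> B;     // B: mirrored, one full copy per 8-thread subgroup

    load_generic(A, A_src, stride);                         // generic element-wise load of the A fragment (assumed here)
    load_ldmatrix(B, B_src, stride);                        // mirrored-load overload added by this series

    // The tile types pick the (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR overload,
    // which on Volta lowers to mma.sync.aligned.m8n8k4 instructions.
    mma(C, A, B);
#else
    GGML_UNUSED_VARS(A_src, B_src, stride, C);
    NO_DEVICE_CODE;
#endif // VOLTA_MMA_AVAILABLE
}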