6 changes: 5 additions & 1 deletion ggml/src/ggml-cuda/fattn-tile.cuh
@@ -721,9 +721,12 @@ static __global__ void flash_attn_tile(

// Skip unused kernel variants for faster compilation:

// Optionally disable pruning to keep all TILE variants for testing.
#if !defined(GGML_USE_HIP)
if (
#ifdef GGML_USE_WMMA_FATTN
(ncols2 != 1 && DV != 40 && DV != 512) ||
// On CUDA WMMA builds, prune some TILE variants to reduce compile time/binary size.
(ncols2 != 1 && DV != 40 && DV != 64 && DV != 128 && DV != 256 && DV != 512) ||
#endif // GGML_USE_WMMA_FATTN
(use_logit_softcap && !(DV == 128 || DV == 256))
) {
@@ -739,6 +742,7 @@ static __global__ void flash_attn_tile(
NO_DEVICE_CODE;
return;
}
#endif // !defined(GGML_USE_HIP)

static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");

56 changes: 39 additions & 17 deletions ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -20,9 +20,24 @@ namespace wmma = rocwmma;
#endif // !defined(GGML_USE_HIP)
#endif // GGML_USE_WMMA_FATTN

#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
static constexpr int GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM = 2;
#else
static constexpr int GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM = 1;
#endif

template <int D>
constexpr int ggml_wmma_fattn_kq_stride() {
#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
return D <= 128 ? 128 : FATTN_KQ_STRIDE;
#else
return FATTN_KQ_STRIDE;
#endif
}

// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM)
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
@@ -55,10 +70,12 @@ static __global__ void flash_attn_ext_f16(
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.

constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int fattn_kq_stride = ggml_wmma_fattn_kq_stride<D>();

static_assert(D <= fattn_kq_stride, "D must be <= fattn_kq_stride.");

const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.

static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
@@ -75,7 +92,7 @@ static __global__ void flash_attn_ext_f16(

// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
constexpr int D_padded = D + 8;
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
constexpr int kqs_padded = fattn_kq_stride + 8;
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);

const int sequence = blockIdx.z / ne02;
@@ -168,10 +185,10 @@ static __global__ void flash_attn_ext_f16(

// Iterate over ne11 == previous tokens:
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
for (int k_VKQ_0 = blockIdx.y*fattn_kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*fattn_kq_stride) {
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
for (int i_KQ_0 = 0; i_KQ_0 < fattn_kq_stride; i_KQ_0 += KQ_stride_tc) {
frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
@@ -201,9 +218,9 @@ static __global__ void flash_attn_ext_f16(
const int j = j0 + threadIdx.y;

if (std::is_same<KQ_acc_t, float>::value) {
float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
float KQ_f_tmp[fattn_kq_stride / warp_size];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
const int k = k0 + threadIdx.x;

KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
@@ -215,7 +232,7 @@

float KQ_max_new = KQ_max_f[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
const int k = k0 + threadIdx.x;

KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
@@ -232,7 +249,7 @@

float KQ_rowsum_add = 0.0f;
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
const int k = k0 + threadIdx.x;

const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
@@ -248,9 +265,9 @@
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
} else {
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
half2 KQ2_tmp[fattn_kq_stride/(2*warp_size)];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
const int k = k0 + threadIdx.x;

KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
@@ -267,7 +284,7 @@

half2 KQ_max_new = KQ_max_h2[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
const int k = k0 + threadIdx.x;

KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
@@ -282,7 +299,7 @@

half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
const int k = k0 + threadIdx.x;

const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
@@ -301,11 +318,11 @@

__syncthreads();

frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
frag_b KQ_b[fattn_kq_stride/(VKQ_ratio*16)][ncols/frag_n];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
for (int k0 = 0; k0 < fattn_kq_stride; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
@@ -323,7 +340,7 @@
}

#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
for (int k0 = 0; k0 < fattn_kq_stride; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

frag_a_V v_a;
@@ -512,7 +529,12 @@ template <int D, int cols_per_block, typename KQ_acc_t>
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;

#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
constexpr int nwarps = D <= 96 ? 8 : 4;
#else
constexpr int nwarps = 4;
#endif
constexpr int fattn_kq_stride = ggml_wmma_fattn_kq_stride<D>();

constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
@@ -530,7 +552,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, fattn_kq_stride, true, true, false, warp_size);
}

void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
55 changes: 54 additions & 1 deletion ggml/src/ggml-cuda/fattn.cu
@@ -301,13 +301,66 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}

// Use the WMMA kernel if possible:
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
const bool hip_wmma_decode = Q->ne[1] == 1;
#else
const bool hip_wmma_decode = false;
#endif
if (!hip_wmma_decode && ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
if (can_use_vector_kernel && Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
return BEST_FATTN_KERNEL_WMMA_F16;
}

// HIP decode path (Q->ne[1] == 1): fall through to generic HIP selection below (VEC/TILE),
// with a guard to avoid selecting a TILE shape that has no config.
if (hip_wmma_decode) {
#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
// Mirror the ncols2 selection from launch_fattn_tile_switch_ncols2 to predict if
// a multi-column TILE kernel (ncols2 != 1) would be chosen.
const bool nvidia_arch = GGML_CUDA_CC_IS_NVIDIA(cc);
const int gqa_limit = (nvidia_arch && gqa_ratio <= 4) ? 16 : INT_MAX;
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;

int predicted_ncols2 = 1;
if (V->ne[0] == 512) {
if (use_gqa_opt && gqa_ratio % 16 == 0) predicted_ncols2 = 16;
} else if (V->ne[0] <= 256) {
if (use_gqa_opt && gqa_ratio % 8 == 0) predicted_ncols2 = 8;
else if (use_gqa_opt && gqa_ratio % 4 == 0) predicted_ncols2 = 4;
else if (use_gqa_opt && gqa_ratio % 2 == 0) predicted_ncols2 = 2;
}

// Predict cols_per_block like launch_fattn_tile_switch_ncols1 does (HIP path):
int predicted_cols_per_block = 2;
if (predicted_ncols2 <= 2) {
predicted_cols_per_block = 2;
}
if (predicted_ncols2 <= 4 && Q->ne[1] > 2/predicted_ncols2) {
predicted_cols_per_block = 4;
}
if (predicted_ncols2 <= 8 && Q->ne[1] > 4/predicted_ncols2) {
predicted_cols_per_block = 8;
}
if (Q->ne[1] > 8/predicted_ncols2) {
predicted_cols_per_block = 16;
}
if (Q->ne[1] > 16/predicted_ncols2) {
predicted_cols_per_block = 32;
}
if (V->ne[0] <= 128 && Q->ne[1] > 32/predicted_ncols2) {
predicted_cols_per_block = 64;
}

const uint32_t cfg = ggml_cuda_fattn_tile_get_config((int)Q->ne[0], (int)V->ne[0], predicted_cols_per_block, cc);
if (predicted_ncols2 != 1 && cfg == 0) {
return BEST_FATTN_KERNEL_VEC;
}
#endif // defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
// Otherwise, fall through.
}

// If there are no tensor cores available, use the generic tile kernel:
if (can_use_vector_kernel) {
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {