Commit e44ebb0

revise OOB check in j direction
1 parent 6ec87ab commit e44ebb0

5 files changed, +80 -87 lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 5 additions & 2 deletions
@@ -25,7 +25,7 @@ typedef void (* fattn_kernel_t)(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
         const int32_t nb01, const int32_t nb02, const int32_t nb03,
         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
         const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -969,6 +969,9 @@ void launch_fattn(
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

+    // TODO other tensor dimensions after removal of WMMA kernel:
+    const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
+
     GGML_ASSERT(block_dim.x % warp_size == 0);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
@@ -979,7 +982,7 @@
         KV_max.ptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+        Q->ne[0], ne01,     Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
        nb21, nb22, nb23,
        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,

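Note on the new ne01 type: launch_fattn now packs Q->ne[1] into a uint3 via init_fastdiv_values so the kernel can reduce indices modulo ne01 with fastmodulo/fastdiv instead of hardware integer division; the plain row count stays readable as ne01.z. Below is a minimal, self-contained sketch of how such magic-number division helpers are commonly implemented. It is an illustration only, not the ggml-cuda implementation; the <multiplier, shift, divisor> packing is inferred from how this commit reads ne01.z, and the helper names here (make_fastdiv_values, fast_div, fast_mod, check_fastdiv) are hypothetical.

// Hedged sketch of Granlund/Montgomery-style "fastdiv": precompute a magic
// multiplier and shift on the host so the GPU can compute n / d and n % d
// with a multiply-high plus a shift.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Host side: valid for d >= 1 and 0 <= n < 2^31 (enough for tensor dims here).
static uint3 make_fastdiv_values(uint32_t d) {
    uint32_t L = 0;
    while (L < 32 && (uint32_t(1) << L) < d) {
        L++;
    }
    const uint32_t mp = uint32_t(((uint64_t(1) << 32)*((uint64_t(1) << L) - d))/d + 1);
    return make_uint3(mp, L, d); // .z keeps the plain divisor for direct use
}

static __device__ __forceinline__ uint32_t fast_div(uint32_t n, const uint3 fd) {
    const uint32_t hi = __umulhi(n, fd.x); // high 32 bits of n*mp
    return (hi + n) >> fd.y;
}

static __device__ __forceinline__ uint32_t fast_mod(uint32_t n, const uint3 fd) {
    return n - fast_div(n, fd)*fd.z;
}

// Each thread checks one value against the hardware division.
__global__ void check_fastdiv(const uint3 fd, int * n_errors) {
    const uint32_t n = blockIdx.x*blockDim.x + threadIdx.x;
    if (fast_div(n, fd) != n/fd.z || fast_mod(n, fd) != n % fd.z) {
        atomicAdd(n_errors, 1);
    }
}

int main() {
    int * n_errors;
    cudaMallocManaged(&n_errors, sizeof(int));
    *n_errors = 0;
    const uint3 fd = make_fastdiv_values(1000); // e.g. ne01 == 1000 query rows
    check_fastdiv<<<4096, 256>>>(fd, n_errors);
    cudaDeviceSynchronize();
    const int errors = *n_errors;
    printf("errors: %d\n", errors);
    cudaFree(n_errors);
    return errors == 0 ? 0 : 1;
}
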
ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 42 additions & 50 deletions
@@ -285,7 +285,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
 static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
         const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
-        const int stride_mask, const int i_sup, const int j_sup) {
+        const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
     if constexpr (use_cp_async) {
         static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
         static_assert(!oob_check, "OOB check incompatible with cp_async");
@@ -296,73 +296,66 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
         const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);

 #pragma unroll
-        for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
-            const int j = j0 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+        for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);

-            if (j0 + stride_j > ncols1 && j >= ncols1) {
+            if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
                 break;
             }

             const int i = 8 * (threadIdx.x % (nbatch_fa/8));

-            if (ncols1 <= 2 || j < j_sup) {
-                cp_async_cg_16<preload>(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j*stride_mask + i);
-            } else {
-                const half zero[8] = {0.0f};
-                ggml_cuda_memcpy_1<16>(tile_mask + j*(nbatch_fa + 8) + i, zero);
-            }
+            cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
         }
     } else if constexpr (oob_check) {
 #pragma unroll
-        for (int j0 = 0; j0 < ncols1; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
+        for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
+            const int j_sram = j1 + threadIdx.y;
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);

-            if (j0 + nwarps > ncols1 && j >= ncols1) {
+            if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
                 break;
             }

 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;

-                tile_mask[j*(nbatch_fa + 8) + i] = i < i_sup && (ncols1 <= 2 || j < j_sup) ?
-                    mask_h[j*stride_mask + i] : half(0.0f);
+                tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
             }
         }
     } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
         constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
         constexpr int stride_j = nwarps * cols_per_warp;
 #pragma unroll
-        for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
-            const int j = j0 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+        for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);

-            if (j0 + stride_j > ncols1 && j >= ncols1) {
+            if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
                 break;
             }

             const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);

-            // TODO bigger chunks
-            const half zero[2] = {0.0f};
-            ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j*(nbatch_fa + 8) + 2*i,
-                ncols1 <= 2 || j < j_sup ? mask_h + j*stride_mask + 2*i : zero);
+            ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
         }
     } else {
 #pragma unroll
-        for (int j0 = 0; j0 < ncols1; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
+        for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
+            const int j_sram = j1 + threadIdx.y;
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);

-            if (j0 + nwarps > ncols1 && j >= ncols1) {
+            if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
                 break;
             }

 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
                 const int i = i0 + 2*threadIdx.x;

-                const half zero[2] = {0.0f};
-                ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j*(nbatch_fa + 8) + i,
-                    ncols1 <= 2 || j < j_sup ? mask_h + j*stride_mask + i : zero);
+                ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
             }
         }
     }
@@ -381,7 +374,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float scale,
         const float slope,
         const float logit_softcap,
-        const int ne01,
+        const uint3 ne01,
         const int ne02,
         const int stride_K,
         const int stride_V,
@@ -394,9 +387,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         T_C_VKQ * const __restrict__ VKQ_C,
         float * const __restrict__ KQ_max,
         float * const __restrict__ KQ_rowsum,
+        const int jt,
         const int kb0,
-        const int k_VKQ_sup,
-        const int j_VKQ_sup) {
+        const int k_VKQ_sup) {
 #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     constexpr int ncols = ncols1 * ncols2;
     constexpr int cols_per_warp = T_B_KQ::I;
@@ -434,7 +427,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             constexpr bool use_cp_async = nstages == 1;
             if (ncols2 > 1 || mask_h) {
                 flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
-                    (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, j_VKQ_sup);
+                    (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
             }
         }

@@ -699,7 +692,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         if (!last_iter) {
             if (ncols2 > 1 || mask_h) {
                 flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
-                    (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, j_VKQ_sup);
+                    (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
             }
             flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
                 (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
@@ -821,7 +814,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float scale,
         const float slope,
         const float logit_softcap,
-        const int ne01,
+        const uint3 ne01,
         const int ne02,
         const int ne11,
         const int stride_Q1,
@@ -911,7 +904,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             const int j = jc / ncols2;
             const int c = jc % ncols2;

-            if (jt*ncols1 + j < ne01) {
+            if (jt*ncols1 + j < int(ne01.z)) {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -943,8 +936,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(

     __syncthreads();

-    int kb0 = kb0_start;
-    const int j_VKQ_sup = ne01 - jt*ncols1;
+    int kb0 = kb0_start;

     // Preload mask and K data for first iteration when using cp_async with multiple stages:
     if constexpr (nstages > 1) {
@@ -954,7 +946,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr int k_VKQ_sup = nbatch_fa;
         if (ncols2 > 1 || mask_h) {
             flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
-                (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, j_VKQ_sup);
+                (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
         }
         flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
             (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
@@ -969,7 +961,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
             (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
-             KQ_max, KQ_rowsum, kb0, k_VKQ_sup, j_VKQ_sup);
+             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
     }
     // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
     if constexpr (ncols2 == 1) {
@@ -982,7 +974,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
             (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
-             KQ_max, KQ_rowsum, kb0, k_VKQ_sup, j_VKQ_sup);
+             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
         } else {
             constexpr bool last_iter = true;
             constexpr bool oob_check = true;
@@ -992,7 +984,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
             (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
-             KQ_max, KQ_rowsum, kb0, k_VKQ_sup, j_VKQ_sup);
+             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
         }
     } else {
         constexpr bool last_iter = true;
@@ -1003,7 +995,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
             (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
-             KQ_max, KQ_rowsum, kb0, k_VKQ_sup, j_VKQ_sup);
+             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
     }

     // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1284,7 +1276,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int j_dst = jc_dst / ncols2;
         const int c_dst = jc_dst % ncols2;

-        if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
+        if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
             continue;
         }

@@ -1347,7 +1339,7 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
         const int32_t nb01, const int32_t nb02, const int32_t nb03,
         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
         const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -1384,8 +1376,8 @@

     const int stride_V = mla ? stride_K : nb21 / sizeof(half2);

-    const int iter_k = (ne11   + (nbatch_fa - 1)) / nbatch_fa;
-    const int iter_j = (ne01   + (ncols1    - 1)) / ncols1;
+    const int iter_k = (ne11   + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j = (ne01.z + (ncols1    - 1)) / ncols1;

     // kbc == k block continuous, current index in continuous ijk space.
     int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
@@ -1409,8 +1401,8 @@
         const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
         const half2  * K_h2 = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
         const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
-            (const half *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-        float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
+            (const half *) (mask + nb33*(sequence % ne33));
+        float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);

         const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
         const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
@@ -1453,8 +1445,8 @@
         const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
         const half2  * K_h2 = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
         const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
-            (const half *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-        float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
+            (const half *) (mask + nb33*(sequence % ne33));
+        float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);

         const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
         const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
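Taken together, these hunks replace the old zero-fill handling of out-of-bounds rows in the j direction with a wrap-and-discard scheme: the mask base pointer no longer carries the per-tile row offset nb31*jt*ncols1; instead jt*ncols1 is passed into flash_attn_ext_f16_load_mask as j0, the global row index is wrapped into [0, ne01) with fastmodulo so every mask load stays in bounds (which also keeps the cp_async path branch-free), and rows that do not correspond to real queries are dropped only at writeback via the jt*ncols1 + j_dst >= int(ne01.z) check. The following is a minimal host-side sketch of that load-wrapped / store-guarded pattern; the helper names are illustrative only, not kernel code, and plain % stands in for fastmodulo.

#include <cstddef>

// Load a tile of mask rows: rows past ne01 re-read a valid (wrapped) row
// instead of being zero-filled, so the read never goes out of bounds.
static void load_mask_rows(const float * mask, int ne01, int stride_mask,
                           int jt, int ncols1, int nbatch, float * tile) {
    for (int j_sram = 0; j_sram < ncols1; ++j_sram) {
        const int j_vram = (jt*ncols1 + j_sram) % ne01; // fastmodulo(...) on the GPU
        for (int i = 0; i < nbatch; ++i) {
            tile[(size_t) j_sram*nbatch + i] = mask[(size_t) j_vram*stride_mask + i];
        }
    }
}

// Write back results: anything computed on a wrapped row is discarded here.
static void store_rows(const float * tile, int ne01, int jt, int ncols1,
                       int nbatch, float * dst) {
    for (int j_sram = 0; j_sram < ncols1; ++j_sram) {
        if (jt*ncols1 + j_sram >= ne01) {
            continue; // row does not correspond to a real query
        }
        for (int i = 0; i < nbatch; ++i) {
            dst[(size_t) (jt*ncols1 + j_sram)*nbatch + i] = tile[(size_t) j_sram*nbatch + i];
        }
    }
}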
