@@ -285,7 +285,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 template <int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
 static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
         const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
-        const int stride_mask, const int i_sup, const int j_sup) {
+        const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
     if constexpr (use_cp_async) {
         static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
         static_assert(!oob_check, "OOB check incompatible with cp_async");
@@ -296,73 +296,66 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
         const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
 
 #pragma unroll
-        for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
-            const int j = j0 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+        for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
-            if (j0 + stride_j > ncols1 && j >= ncols1) {
+            if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
                 break;
             }
 
             const int i = 8 * (threadIdx.x % (nbatch_fa/8));
 
-            if (ncols1 <= 2 || j < j_sup) {
-                cp_async_cg_16<preload>(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j*stride_mask + i);
-            } else {
-                const half zero[8] = {0.0f};
-                ggml_cuda_memcpy_1<16>(tile_mask + j*(nbatch_fa + 8) + i, zero);
-            }
+            cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
         }
     } else if constexpr (oob_check) {
 #pragma unroll
-        for (int j0 = 0; j0 < ncols1; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
+        for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
+            const int j_sram = j1 + threadIdx.y;
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
-            if (j0 + nwarps > ncols1 && j >= ncols1) {
+            if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
                 break;
             }
 
 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                tile_mask[j*(nbatch_fa + 8) + i] = i < i_sup && (ncols1 <= 2 || j < j_sup) ?
-                    mask_h[j*stride_mask + i] : half(0.0f);
+                tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
             }
         }
     } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
         constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
         constexpr int stride_j = nwarps * cols_per_warp;
 #pragma unroll
-        for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
-            const int j = j0 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+        for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
-            if (j0 + stride_j > ncols1 && j >= ncols1) {
+            if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
                 break;
             }
 
             const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
 
-            // TODO bigger chunks
-            const half zero[2] = {0.0f};
-            ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j*(nbatch_fa + 8) + 2*i,
-                ncols1 <= 2 || j < j_sup ? mask_h + j*stride_mask + 2*i : zero);
+            ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
         }
     } else {
 #pragma unroll
-        for (int j0 = 0; j0 < ncols1; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
+        for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
+            const int j_sram = j1 + threadIdx.y;
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
-            if (j0 + nwarps > ncols1 && j >= ncols1) {
+            if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
                 break;
             }
 
 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
                 const int i = i0 + 2*threadIdx.x;
 
-                const half zero[2] = {0.0f};
-                ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j*(nbatch_fa + 8) + i,
-                    ncols1 <= 2 || j < j_sup ? mask_h + j*stride_mask + i : zero);
+                ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
             }
         }
     }
@@ -381,7 +374,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float scale,
         const float slope,
         const float logit_softcap,
-        const int ne01,
+        const uint3 ne01,
         const int ne02,
         const int stride_K,
         const int stride_V,
@@ -394,9 +387,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         T_C_VKQ * const __restrict__ VKQ_C,
         float * const __restrict__ KQ_max,
         float * const __restrict__ KQ_rowsum,
+        const int jt,
         const int kb0,
-        const int k_VKQ_sup,
-        const int j_VKQ_sup) {
+        const int k_VKQ_sup) {
 #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     constexpr int ncols = ncols1 * ncols2;
     constexpr int cols_per_warp = T_B_KQ::I;
@@ -434,7 +427,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         constexpr bool use_cp_async = nstages == 1;
         if (ncols2 > 1 || mask_h) {
             flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
-                (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, j_VKQ_sup);
+                (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
         }
     }
 
@@ -699,7 +692,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         if (!last_iter) {
             if (ncols2 > 1 || mask_h) {
                 flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
-                    (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, j_VKQ_sup);
+                    (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
             }
             flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
                 (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
@@ -821,7 +814,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
        const float scale,
        const float slope,
        const float logit_softcap,
-       const int ne01,
+       const uint3 ne01,
        const int ne02,
        const int ne11,
        const int stride_Q1,
@@ -911,7 +904,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             const int j = jc / ncols2;
             const int c = jc % ncols2;
 
-            if (jt*ncols1 + j < ne01) {
+            if (jt*ncols1 + j < int(ne01.z)) {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -943,8 +936,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
     __syncthreads();
 
-    int kb0 = kb0_start;
-    const int j_VKQ_sup = ne01 - jt*ncols1;
+    int kb0 = kb0_start;
 
     // Preload mask and K data for first iteration when using cp_async with multiple stages:
     if constexpr (nstages > 1) {
@@ -954,7 +946,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr int k_VKQ_sup = nbatch_fa;
         if (ncols2 > 1 || mask_h) {
             flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
-                (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, j_VKQ_sup);
+                (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
         }
         flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
             (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
@@ -969,7 +961,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
             (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
-             KQ_max, KQ_rowsum, kb0, k_VKQ_sup, j_VKQ_sup);
+             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
     }
     // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
     if constexpr (ncols2 == 1) {
@@ -982,7 +974,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                     T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
                 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
-                 KQ_max, KQ_rowsum, kb0, k_VKQ_sup, j_VKQ_sup);
+                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
         } else {
             constexpr bool last_iter = true;
             constexpr bool oob_check = true;
@@ -992,7 +984,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                     T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
                 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
-                 KQ_max, KQ_rowsum, kb0, k_VKQ_sup, j_VKQ_sup);
+                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
         }
     } else {
         constexpr bool last_iter = true;
@@ -1003,7 +995,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
             (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
-             KQ_max, KQ_rowsum, kb0, k_VKQ_sup, j_VKQ_sup);
+             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
     }
 
     // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1284,7 +1276,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int j_dst = jc_dst / ncols2;
         const int c_dst = jc_dst % ncols2;
 
-        if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
+        if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
             continue;
         }
 
@@ -1347,7 +1339,7 @@ static __global__ void flash_attn_ext_f16(
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
-       const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+       const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
        const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
        const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -1384,8 +1376,8 @@ static __global__ void flash_attn_ext_f16(
 
     const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
 
-    const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
-    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+    const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
 
     // kbc == k block continuous, current index in continuous ijk space.
     int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
@@ -1409,8 +1401,8 @@ static __global__ void flash_attn_ext_f16(
         const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
         const half2  * K_h2 = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
         const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
-            (const half *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-        float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
+            (const half *) (mask + nb33*(sequence % ne33));
+        float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
 
         const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
         const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
@@ -1453,8 +1445,8 @@ static __global__ void flash_attn_ext_f16(
         const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
         const half2  * K_h2 = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
         const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
-            (const half *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-        float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
+            (const half *) (mask + nb33*(sequence % ne33));
+        float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
 
         const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
         const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
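For context on the new `const uint3 ne01` parameter: instead of a plain `int`, the Q row count is passed pre-packed for magic-number division, so `fastmodulo(j0 + j_sram, ne01)` can wrap the mask row index without an integer divide, while `ne01.z` still carries the raw value for the `jt*ncols1 + j < int(ne01.z)` bounds checks. The sketch below shows one standard way such a helper triple can be built and used; it is written against the names that appear in this diff and is not code copied from ggml's `common.cuh`.

```cuda
// Illustrative sketch only: a fastdiv/fastmodulo scheme of the kind the patch
// relies on (names mirror the diff; the exact ggml implementation may differ).
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Host side: pack {magic multiplier, shift, raw divisor} into a uint3.
// The .z component keeps the plain value, which is why device code can still
// compare against int(ne01.z). Assumes 0 < d < 2^31.
static uint3 init_fastdiv_values(uint32_t d) {
    uint32_t L = 0;
    while (L < 32 && (uint32_t{1} << L) < d) {
        L++;
    }
    const uint32_t mp = (uint32_t) (((uint64_t{1} << 32)*((uint64_t{1} << L) - d))/d + 1);
    return make_uint3(mp, L, d);
}

// Device side: division/modulo via multiply-high + shift, no hardware division.
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fd) {
    return (__umulhi(n, fd.x) + n) >> fd.y;
}

static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fd) {
    return n - fastdiv(n, fd)*fd.z;
}

// Toy kernel: wrap a running row index back into [0, ne01), as the mask loads above do.
__global__ void wrap_rows(const uint3 ne01, uint32_t * out, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        out[i] = fastmodulo((uint32_t) i, ne01);
    }
}

int main() {
    const uint3 ne01 = init_fastdiv_values(7); // e.g. 7 rows of Q
    uint32_t * out = nullptr;
    cudaMallocManaged(&out, 32*sizeof(uint32_t));
    wrap_rows<<<1, 32>>>(ne01, out, 32);
    cudaDeviceSynchronize();
    for (int i = 0; i < 32; ++i) {
        printf("%2d mod 7 = %u\n", i, out[i]);
    }
    cudaFree(out);
    return 0;
}
```

On the host, the packed value would be computed once per launch (e.g. `init_fastdiv_values(ne01)`) and handed to the kernel in place of the former `int32_t ne01`, which matches the changed kernel signature above.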