diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 2efc9cc880cf8..2b60b3bb13563 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -540,10 +540,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter( KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); } - KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ? - slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; + if (!oob_check || i_KQ < k_VKQ_sup) { + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ? + slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; - KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } } KQ_max_new[jc0] = warp_reduce_max(KQ_max_new[jc0]); @@ -581,10 +583,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter( float KQ_sum_add = 0.0f; #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { - const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]); - if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) { - KQ_sum_add += val; - } + const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ? + expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f; + KQ_sum_add += val; tmp[i0/(np*warp_size)][jc1] = val; } KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add; @@ -975,26 +976,6 @@ static __global__ void flash_attn_tile( } } - if (gridDim.y == 1) { -#pragma unroll - for (int jc0 = 0; jc0 < cpw; ++jc0) { -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]); -#pragma unroll - for (int i = 0; i < (DVp/2)/warp_size; ++i) { - VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv; - } -#else - const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0]; -#pragma unroll - for (int i = 0; i < (DVp/2)/warp_size; ++i) { - VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv; - VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv; - } -#endif // FAST_FP16_AVAILABLE - } - } - // Write back results: #pragma unroll for (int jc0 = 0; jc0 < cpw; ++jc0) { @@ -1007,6 +988,8 @@ static __global__ void flash_attn_tile( return; } + const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f; + const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; #ifdef FAST_FP16_AVAILABLE @@ -1017,6 +1000,8 @@ static __global__ void flash_attn_tile( #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]); + tmp[i1].x *= scale; + tmp[i1].y *= scale; } if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) { ggml_cuda_memcpy_1(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); @@ -1027,6 +1012,11 @@ static __global__ void flash_attn_tile( #pragma unroll for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) { +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale; + } ggml_cuda_memcpy_1( &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D], &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);