diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h
index 9a8b534..94f4470 100644
--- a/csrc/flash_dmattn/src/flash_bwd_kernel.h
+++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h
@@ -644,17 +644,37 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     if constexpr (Has_bias) {
         if (accum_dbias) { clear(acc_dbias); }
     }
 
+    cute::cp_async_wait<0>();
+    __syncthreads();
+
+    // Scale K once before streaming loop Q
+    #pragma unroll
+    for (int k = 0; k < size(tKsK); ++k) {
+        tKsK(k) = static_cast<Element>(tKsK(k) * params.scale_softmax);
+    }
+
     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
-        cute::cp_async_wait<0>();
-        __syncthreads();
-
         Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (MMA=4, MMA_M, MMA_K)
 
+        cute::cp_async_wait<0>();
+        __syncthreads();
+
         if (any_active) {
-            clear(acc_s);
+            if constexpr (Has_bias) {
+                // Copy bias from smem to registers
+                Tensor tSrBias = make_tensor<Element>(shape(acc_s));
+                Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
+                cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);
+                #pragma unroll
+                for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); }
+            } else {
+                clear(acc_s);
+            }
+        }
+
+        if (any_active) {
             Tensor dP_sum = make_fragment_like(lse);
             #pragma unroll
             for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }
@@ -686,35 +706,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                 FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap);
             }
 
-            if constexpr (Has_mask && Has_bias) {
-                // Copy mask and bias from smem to registers
-                Tensor tSrMask = make_tensor<Element>(shape(acc_s));
-                Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask);
-                cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view);
-                Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-                Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
-                cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);
-
-                // Reshape mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
-                Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout()));
-                Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout()));
-
-                // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
-                // actual_seqlen_k, because acc_s would be some finite value for those indices.
-                // In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
-                // so the result would still be correct.
-                // However, it's possible that the values in acc_s are so large that they overflow
-                // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
-                // So we need to mask out the elements beyond actual_seqlen_k.
-                FLASH_NAMESPACE::apply_mask(
-                    scores, mask, bias, params.scale_softmax,
-                    n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                    binfo.actual_seqlen_k,
-                    m_block * kBlockM + get<0>(taccScS_row(0)),
-                    binfo.actual_seqlen_q,
-                    AtomLayoutMS * 16
-                );
-            } else if constexpr (Has_mask && !Has_bias) {
+            if constexpr (Has_mask) {
                 // Copy mask from smem to registers
                 Tensor tSrMask = make_tensor<Element>(shape(acc_s));
                 Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask);
@@ -722,26 +714,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                 // Reshape mask from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
                 Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout()));
-
-                FLASH_NAMESPACE::apply_mask(
-                    scores, mask, /*bias=*/nullptr, params.scale_softmax,
-                    n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                    binfo.actual_seqlen_k,
-                    m_block * kBlockM + get<0>(taccScS_row(0)),
-                    binfo.actual_seqlen_q,
-                    AtomLayoutMS * 16
-                );
-            } else if constexpr (!Has_mask && Has_bias) {
-                // Copy bias from smem to registers
-                Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-                Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
-                cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);
-                // Reshape bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
-                Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout()));
-
-                FLASH_NAMESPACE::apply_mask(
-                    scores, /*mask=*/nullptr, bias, params.scale_softmax,
+
+                FLASH_NAMESPACE::apply_mask(
+                    scores, mask,
                     n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                     binfo.actual_seqlen_k,
                     m_block * kBlockM + get<0>(taccScS_row(0)),
@@ -749,8 +724,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                     AtomLayoutMS * 16
                 );
             } else {
-                FLASH_NAMESPACE::apply_mask(
-                    scores, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax,
+                FLASH_NAMESPACE::apply_mask(
+                    scores, /*mask=*/nullptr,
                     n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                     binfo.actual_seqlen_k,
                     m_block * kBlockM + get<0>(taccScS_row(0)),
@@ -965,8 +940,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                 for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); }
             }
         } else {
-            #pragma unroll
-            for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
+            // #pragma unroll
+            // for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
             // Convert acc_dq from fp32 to fp16
             Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
             Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom, AtomNum), MMA_M, MMA_K)
diff --git a/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h
index d6710b2..4e550ff 100644
--- a/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h
+++ b/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h
@@ -279,8 +279,8 @@ inline __device__ void convert_dQ(
         for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
         tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
     }
-    #pragma unroll
-    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
+    // #pragma unroll
+    // for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
     // Convert acc_dq from fp32 to fp16
     Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
     Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom, AtomNum), MMA_N, MMA_N)
diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h
index 576de90..9f7ca6b 100644
--- a/csrc/flash_dmattn/src/flash_fwd_kernel.h
+++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h
@@ -422,6 +422,22 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         binfo.actual_seqlen_k,
         binfo.actual_seqlen_q
     );
+    FLASH_NAMESPACE::cp_async_wait<0>();
+    __syncthreads();
+
+    // Scale Q once before streaming loop KV
+    if constexpr (Kernel_traits::Is_Q_in_regs) {
+        #pragma unroll
+        for (int i = 0; i < size(tSrQ); ++i) {
+            tSrQ(i) = static_cast<Element>(tSrQ(i) * params.scale_softmax);
+        }
+    } else {
+        #pragma unroll
+        for (int i = 0; i < size(tSsQ); ++i) {
+            tSsQ(i) = static_cast<Element>(tSsQ(i) * params.scale_softmax);
+        }
+    }
+
     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
@@ -437,10 +453,22 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     #pragma unroll
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
-        clear(acc_s);
         FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
 
+        if (any_active) {
+            if constexpr (Has_bias) {
+                // Copy bias from smem to acc_s registers
+                Tensor tSrBias = make_tensor<Element>(shape(acc_s));
+                Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
+                cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
+                #pragma unroll
+                for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); }
+            } else {
+                clear(acc_s);
+            }
+        }
+
         // Advance gV
         if (masking_step > 0 && any_active) {
             FLASH_NAMESPACE::copy(
@@ -473,46 +501,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
         }
 
-        if constexpr (Has_mask && Has_bias) {
-            // Copy mask and bias from smem to registers
-            Tensor tSrMask = make_tensor<Element>(shape(acc_s));
-            Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
-            cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
-            Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-            Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
-            cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
-
-            // Scale attention scores and apply mask and add bias
-            mask.template apply_mask(
-                acc_s, tSrMask, tSrBias, params.scale_softmax,
-                n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
-            );
-        } else if constexpr (Has_mask && !Has_bias) {
+        if constexpr (Has_mask) {
            // Copy mask from smem to registers
            Tensor tSrMask = make_tensor<Element>(shape(acc_s));
            Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
            cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
 
-            // Scale attention scores and apply mask
-            mask.template apply_mask(
-                acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax,
-                n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
-            );
-        } else if constexpr (!Has_mask && Has_bias) {
-            // Copy bias from smem to registers
-            Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-            Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
-            cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
-
-            // Scale attention scores and add bias
-            mask.template apply_mask(
-                acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax,
+            mask.template apply_mask(
+                acc_s, tSrMask,
                 n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
             );
         } else {
-            // Scale attention scores only
-            mask.template apply_mask(
-                acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax,
+            mask.template apply_mask(
+                acc_s, /*mask=*/nullptr,
                 n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
             );
         }
@@ -602,10 +603,22 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // These are the iterations where we don't need masking on S
     for (; n_block >= n_block_min; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
-        clear(acc_s);
         FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
 
+        if (any_active) {
+            if constexpr (Has_bias) {
+                // Copy bias from smem to acc_s registers
+                Tensor tSrBias = make_tensor<Element>(shape(acc_s));
+                Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
+                cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
+                #pragma unroll
+                for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); }
+            } else {
+                clear(acc_s);
+            }
+        }
+
         // Advance gV
         if (any_active) {
             FLASH_NAMESPACE::copy(
@@ -628,46 +641,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
         }
 
-        if constexpr (Has_mask && Has_bias) {
-            // Copy mask and bias from smem to registers
-            Tensor tSrMask = make_tensor<Element>(shape(acc_s));
-            Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
-            cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
-            Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-            Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
-            cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
-
-            // Scale attention scores and apply mask and add bias
-            mask.template apply_mask(
-                acc_s, tSrMask, tSrBias, params.scale_softmax,
-                n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
-            );
-        } else if constexpr (Has_mask && !Has_bias) {
+        if constexpr (Has_mask) {
             // Copy mask from smem to registers
             Tensor tSrMask = make_tensor<Element>(shape(acc_s));
             Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
             cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
 
-            // Scale attention scores and apply mask
-            mask.template apply_mask(
-                acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax,
-                n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
-            );
-        } else if constexpr (!Has_mask && Has_bias) {
-            // Copy bias from smem to registers
-            Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-            Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
-            cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
-
-            // Scale attention scores and apply bias
-            mask.template apply_mask(
-                acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax,
+            mask.template apply_mask(
+                acc_s, tSrMask,
                 n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
             );
         } else {
-            // Scale attention scores only
-            mask.template apply_mask(
-                acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax,
+            mask.template apply_mask(
+                acc_s, /*mask=*/nullptr,
                 n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
            );
         }
@@ -1157,6 +1143,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         binfo.actual_seqlen_k,
         binfo.actual_seqlen_q
     );
+    FLASH_NAMESPACE::cp_async_wait<0>();
+    __syncthreads();
+
+    // Scale Q once before streaming loop KV
+    #pragma unroll
+    for (int i = 0; i < size(tSrQ); ++i) {
+        tSsQ(i) = static_cast<Element>(tSsQ(i) * params.scale_softmax);
+    }
+
     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
@@ -1172,10 +1167,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     #pragma unroll
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
-        clear(acc_s);
         FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
 
+        if (any_active) {
+            if constexpr (Has_bias) {
+                // Copy bias from smem to acc_s registers
+                Tensor tSrBias = make_tensor<Element>(shape(acc_s));
+                Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
+                cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
+                #pragma unroll
+                for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); }
+            } else {
+                clear(acc_s);
+            }
+        }
+
         // Advance gV
         if (masking_step > 0) {
             if (block_table == nullptr) {
@@ -1219,46 +1226,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
         }
 
-        if constexpr (Has_mask && Has_bias) {
-            // Copy mask and bias from smem to registers
-            Tensor tSrMask = make_tensor<Element>(shape(acc_s));
-            Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
-            cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
-            Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-            Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
-            cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
-
-            // Scale attention scores and apply mask and bias
-            mask.template apply_mask(
-                acc_s, tSrMask, tSrBias, params.scale_softmax,
-                n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
-            );
-        } else if constexpr (Has_mask && !Has_bias) {
+        if constexpr (Has_mask) {
             // Copy mask from smem to registers
             Tensor tSrMask = make_tensor<Element>(shape(acc_s));
             Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
             cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
 
-            // Scale attention scores and apply mask
-            mask.template apply_mask(
-                acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax,
-                n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
-            );
-        } else if constexpr (!Has_mask && Has_bias) {
-            // Copy bias from smem to registers
-            Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-            Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
-            cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
-
-            // Scale attention scores and add bias
-            mask.template apply_mask(
-                acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax,
+            mask.template apply_mask(
+                acc_s, tSrMask,
                 n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
            );
         } else {
-            // Scale attention scores only
-            mask.template apply_mask(
-                acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax,
+            mask.template apply_mask(
+                acc_s, /*mask=*/nullptr,
                 n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
             );
         }
@@ -1368,10 +1348,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // These are the iterations where we don't need masking on S
     for (; n_block >= n_block_min; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
-        clear(acc_s);
         FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
 
+        if (any_active) {
+            if constexpr (Has_bias) {
+                // Copy bias from smem to acc_s registers
+                Tensor tSrBias = make_tensor<Element>(shape(acc_s));
+                Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
+                cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
+                #pragma unroll
+                for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); }
+            } else {
+                clear(acc_s);
+            }
+        }
+
         // Advance gV
         if (block_table == nullptr) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
@@ -1403,46 +1395,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
         }
 
-        if constexpr (Has_mask && Has_bias) {
-            // Copy mask and bias from smem to registers
-            Tensor tSrMask = make_tensor<Element>(shape(acc_s));
-            Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
-            cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
-            Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-            Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
-            cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
-
-            // Scale attention scores and apply mask and bias
-            mask.template apply_mask(
-                acc_s, tSrMask, tSrBias, params.scale_softmax,
-                n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
-            );
-        } else if constexpr (Has_mask && !Has_bias) {
+        if constexpr (Has_mask) {
             // Copy mask from smem to registers
             Tensor tSrMask = make_tensor<Element>(shape(acc_s));
             Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
             cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
 
-            // Scale attention scores and apply mask
-            mask.template apply_mask(
-                acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax,
-                n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
-            );
-        } else if constexpr (!Has_mask && Has_bias) {
-            // Copy bias from smem to registers
-            Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-            Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
-            cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
-
-            // Scale attention scores and add bias
-            mask.template apply_mask(
-                acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax,
+            mask.template apply_mask(
+                acc_s, tSrMask,
                 n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
             );
         } else {
-            // Scale attention scores only
-            mask.template apply_mask(
-                acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax,
+            mask.template apply_mask(
+                acc_s, /*mask=*/nullptr,
                 n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
             );
         }
diff --git a/csrc/flash_dmattn/src/mask.h b/csrc/flash_dmattn/src/mask.h
index 9f81cae..9b0d81c 100644
--- a/csrc/flash_dmattn/src/mask.h
+++ b/csrc/flash_dmattn/src/mask.h
@@ -11,12 +11,10 @@ namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
-template
+template
 __forceinline__ __device__ void apply_mask(
     TensorType &tensor,
     const MaskType &mask,
-    const BiasType &bias,
-    const float scale_softmax,
     const int col_idx_offset_,
     const int max_seqlen_k,
     const int row_idx_offset,
@@ -27,13 +25,11 @@ __forceinline__ __device__ void apply_mask(
     static_assert(TensorType::rank == 2, "Only support 2D Tensor");
     if constexpr (Has_mask)
         static_assert(MaskType::rank == 2, "Only support 2D Mask");
-    if constexpr (Has_bias)
-        static_assert(BiasType::rank == 2, "Only support 2D Bias");
 
     const int lane_id = threadIdx.x % 32;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
 
-    if constexpr (Has_mask && Has_bias) {
+    if constexpr (Has_mask) {
         #pragma unroll
         for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
             const int row_idx_base = row_idx_offset + mi * warp_row_stride;
@@ -49,63 +45,15 @@ __forceinline__ __device__ void apply_mask(
                     const int col_idx = col_idx_base + j;
                     // Without the "make_coord" we get wrong results
                     auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                    // Apply scaling and bias or masking
-                    tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord))
-                        ? -INFINITY
-                        : tensor(coord) * scale_softmax + bias(coord);
-                }
-            }
-        }
-    }
-    } else if constexpr (Has_mask && !Has_bias) {
-        #pragma unroll
-        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
-            #pragma unroll
-            for (int i = 0; i < size<0, 0>(tensor); ++i) {
-                const int row_idx = row_idx_base + i * 8;
-                const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
-                #pragma unroll
-                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                    const int col_idx_base = col_idx_offset + nj * 8;
-                    #pragma unroll
-                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                        const int col_idx = col_idx_base + j;
-                        // Without the "make_coord" we get wrong results
-                        auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                        // Apply scaling or masking
-                        tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord))
-                            ? -INFINITY
-                            : tensor(coord) * scale_softmax;
-                    }
-                }
-            }
-        }
-    } else if constexpr (!Has_mask && Has_bias) {
-        #pragma unroll
-        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
-            #pragma unroll
-            for (int i = 0; i < size<0, 0>(tensor); ++i) {
-                const int row_idx = row_idx_base + i * 8;
-                const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
-                #pragma unroll
-                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                    const int col_idx_base = col_idx_offset + nj * 8;
-                    #pragma unroll
-                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                        const int col_idx = col_idx_base + j;
-                        // Without the "make_coord" we get wrong results
-                        auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                        // Apply scaling and bias
-                        tensor(coord) = (col_idx >= col_idx_limit)
-                            ? -INFINITY
-                            : tensor(coord) * scale_softmax + bias(coord);
+                    // Apply masking
+                    if (col_idx >= col_idx_limit || !mask(coord)) {
+                        tensor(coord) = -INFINITY;
+                    }
                 }
             }
         }
     }
-    } else { // !Has_mask && !Has_bias
+    } else {
         #pragma unroll
         for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
             const int row_idx_base = row_idx_offset + mi * warp_row_stride;
@@ -121,10 +69,10 @@ __forceinline__ __device__ void apply_mask(
                     const int col_idx = col_idx_base + j;
                     // Without the "make_coord" we get wrong results
                     auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                    // Apply scaling
-                    tensor(coord) = (col_idx >= col_idx_limit)
-                        ? -INFINITY
-                        : tensor(coord) * scale_softmax;
+                    // Apply masking
+                    if (col_idx >= col_idx_limit) {
+                        tensor(coord) = -INFINITY;
+                    }
                 }
             }
         }
@@ -143,12 +91,10 @@ struct Mask {
         , max_seqlen_q(max_seqlen_q) {
     };
 
-    template
+    template
     __forceinline__ __device__ void apply_mask(
         TensorType &tensor_,            // acc_s (attention scores, MMA=4, MMA_M, MMA_N)
         const MaskType &mask_,          // Attention Mask (MMA=4, MMA_M, MMA_N)
-        const BiasType &bias_,          // Attention Bias (MMA=4, MMA_M, MMA_N)
-        const float scale_softmax,      // Scale for softmax
         const int col_idx_offset_,      // Column index offset
         const int row_idx_offset,       // Row index offset
         const int warp_row_stride       // Warp row stride
@@ -156,8 +102,6 @@ struct Mask {
         static_assert(TensorType::rank == 3, "tensor_ must be 3D Tensor");
         if constexpr (Has_mask)
             static_assert(MaskType::rank == 3, "mask_ must be 3D Tensor");
-        if constexpr (Has_bias)
-            static_assert(BiasType::rank == 3, "Bias must be 3D Tensor");
         static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
 
        // Reshape tensors from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
@@ -166,32 +110,7 @@
         const int lane_id = threadIdx.x % 32;
         const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
 
-        if constexpr (Has_mask && Has_bias) {
-            Tensor mask = make_tensor(mask_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(mask_.layout()));
-            Tensor bias = make_tensor(bias_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(bias_.layout()));
-            #pragma unroll
-            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-                const int row_idx_base = row_idx_offset + mi * warp_row_stride;
-                #pragma unroll
-                for (int i = 0; i < size<0, 0>(tensor); ++i) {
-                    const int row_idx = row_idx_base + i * 8;
-                    const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
-                    #pragma unroll
-                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                        const int col_idx_base = col_idx_offset + nj * 8;
-                        #pragma unroll
-                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                            const int col_idx = col_idx_base + j;
-                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                            // Apply scaling and bias or masking
-                            tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord))
-                                ? -INFINITY
-                                : tensor(coord) * scale_softmax + bias(coord);
-                        }
-                    }
-                }
-            }
-        } else if constexpr (Has_mask && !Has_bias) {
+        if constexpr (Has_mask) {
             Tensor mask = make_tensor(mask_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(mask_.layout()));
             #pragma unroll
             for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
@@ -207,39 +126,15 @@ struct Mask {
                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
                            const int col_idx = col_idx_base + j;
                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                            // Apply scaling or masking
-                            tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord))
-                                ? -INFINITY
-                                : tensor(coord) * scale_softmax;
-                        }
-                    }
-                }
-            }
-        } else if constexpr (!Has_mask && Has_bias) {
-            Tensor bias = make_tensor(bias_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(bias_.layout()));
-            #pragma unroll
-            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-                const int row_idx_base = row_idx_offset + mi * warp_row_stride;
-                #pragma unroll
-                for (int i = 0; i < size<0, 0>(tensor); ++i) {
-                    const int row_idx = row_idx_base + i * 8;
-                    const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
-                    #pragma unroll
-                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                        const int col_idx_base = col_idx_offset + nj * 8;
-                        #pragma unroll
-                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                            const int col_idx = col_idx_base + j;
-                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                            // Apply scaling and bias
-                            tensor(coord) = (col_idx >= col_idx_limit)
-                                ? -INFINITY
-                                : tensor(coord) * scale_softmax + bias(coord);
+                            // Apply masking
+                            if (col_idx >= col_idx_limit || !mask(coord)) {
+                                tensor(coord) = -INFINITY;
+                            }
                         }
                     }
                 }
             }
-        } else { // !Has_mask && !Has_bias
+        } else {
             #pragma unroll
             for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
                 const int row_idx_base = row_idx_offset + mi * warp_row_stride;
@@ -254,10 +149,10 @@ struct Mask {
                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
                            const int col_idx = col_idx_base + j;
                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                            // Apply scaling
-                            tensor(coord) = (col_idx >= col_idx_limit)
-                                ? -INFINITY
-                                : tensor(coord) * scale_softmax;
+                            // Apply masking
+                            if (col_idx >= col_idx_limit) {
+                                tensor(coord) = -INFINITY;
+                            }
                        }
                    }
                }
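Note: the sketch below is illustrative only and is not part of the patch; all names (q, k, bias, keep, attention_scores) are hypothetical. It shows, as a plain CPU reference, why the restructuring above is equivalent to the old path: the softmax scale is folded once into Q (or into K in the backward pass), the score accumulator starts from the bias instead of being cleared, and apply_mask is reduced to writing -INFINITY. This is also why the trailing acc_dq rescaling could be commented out: the scale already rides along in the pre-scaled operand.

#include <cmath>
#include <limits>
#include <vector>

// Reference semantics of the new path: scale folded into Q, bias used as the
// accumulator's initial value, masking applied afterwards with -inf only.
std::vector<float> attention_scores(std::vector<float> q,            // M x D queries
                                    const std::vector<float>& k,     // N x D keys
                                    const std::vector<float>& bias,  // M x N additive bias
                                    const std::vector<bool>& keep,   // M x N mask (true = keep)
                                    float scale, int M, int N, int D) {
    for (float& x : q) x *= scale;              // "Scale Q once before streaming loop KV"
    std::vector<float> s(bias);                 // acc_s seeded with bias instead of clear(acc_s)
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
            for (int d = 0; d < D; ++d)
                s[m * N + n] += q[m * D + d] * k[n * D + d];  // GEMM accumulates on top of bias
    for (int i = 0; i < M * N; ++i)             // apply_mask: masking only, no scale or bias
        if (!keep[i]) s[i] = -std::numeric_limits<float>::infinity();
    return s;                                   // equals (q_original · k^T) * scale + bias, masked
}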