diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h index db4f0e7..643cfcd 100644 --- a/csrc/src/flash_bwd_kernel.h +++ b/csrc/src/flash_bwd_kernel.h @@ -552,24 +552,40 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK); // // if (cute::thread(1, 0)) { print(tKrK); } - FLASH_NAMESPACE::copy( - gmem_tiled_copy_QKV, - tKgK, tKsK, - tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN - ); FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + cute::cp_async_fence(); + FLASH_NAMESPACE::cp_async_wait<0>(); + + // Do OR-reduce on the mask to see if any active threads + Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask); + bool any_active_local = false; + bool any_active_local_next = false; // to be updated later for next iteration + #pragma unroll + for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); } + bool any_active = __syncthreads_or(any_active_local); + bool any_active_next = false; // to be updated later for next iteration + + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, + tKgK, tKsK, + tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN ); + + if (any_active) { + FLASH_NAMESPACE::copy_MN( + gmem_tiled_copy_MaskBias, + tBiasgBias, tBiassBias, + tBiascBias, + binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + if (!Kernel_traits::Is_V_in_regs) { FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, @@ -602,52 +618,148 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in for (; m_block >= m_block_min; --m_block) { Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); cute::cp_async_wait<0>(); __syncthreads(); - // Copy mask from smem to registers and do OR-reduce to see if any active threads - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_view = smem_thr_copy_PdS.retile_D(tSrMask); - Tensor tSsMask_view = smem_thr_copy_PdS.retile_S(tSsMask); - bool any_active_local = false; - #pragma unroll - for (int i = 0; i < size(tSrMask_view); ++i) { - Element m = tSsMask_view(i); - any_active_local |= (m != Element(0)); - tSrMask_view(i) = m; - } - bool any_active = __syncthreads_or(any_active_local); + Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // (MMA=4, MMA_M, MMA_K) - // Early skip for fully masked blocks - if (!any_active) { + if (any_active) { + clear(acc_s); - Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA - CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA + Tensor dP_sum = make_fragment_like(lse); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); } + + // if (cute::thread0()) { print(sK); } + // Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK); + // #pragma unroll + // for (int k = 0; k < size<2>(tSrK_copy_view); ++k) { 
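// ---------------------------------------------------------------------------------------------
// Editorial sketch (not part of the patch): the first hunk above reorders the prologue so the
// mask tile is staged and OR-reduced before anything else is fetched, and the bias tile is only
// loaded when the reduction says the block has active elements. A minimal standalone CUDA
// illustration of that decision, with plain loads in place of cp.async/copy_MN and hypothetical
// tile sizes, might look like this:
#include <cuda_fp16.h>

template <int kBlockM, int kBlockN, int kThreads>
__global__ void block_skip_prologue_sketch(const __half* __restrict__ mask,
                                           const __half* __restrict__ bias,
                                           __half* __restrict__ bias_tile_out,
                                           int ld) {
    __shared__ __half sMask[kBlockM * kBlockN];
    __shared__ __half sBias[kBlockM * kBlockN];

    // 1) Stage the mask tile into shared memory (the kernel above uses cp.async via copy_MN).
    for (int i = threadIdx.x; i < kBlockM * kBlockN; i += kThreads) {
        sMask[i] = mask[(i / kBlockN) * ld + (i % kBlockN)];
    }
    __syncthreads();

    // 2) Each thread ORs its slice of the tile; __syncthreads_or combines the per-thread flags.
    bool any_active_local = false;
    for (int i = threadIdx.x; i < kBlockM * kBlockN; i += kThreads) {
        any_active_local |= (__half2float(sMask[i]) != 0.f);
    }
    const bool any_active = __syncthreads_or(any_active_local);

    // 3) Only pay for the bias tile (and, in the real kernel, the GEMMs) when something is active.
    if (any_active) {
        for (int i = threadIdx.x; i < kBlockM * kBlockN; i += kThreads) {
            sBias[i] = bias[(i / kBlockN) * ld + (i % kBlockN)];
        }
        __syncthreads();
        for (int i = threadIdx.x; i < kBlockM * kBlockN; i += kThreads) {
            bias_tile_out[i] = sBias[i];
        }
    }
}
// ---------------------------------------------------------------------------------------------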
+ // cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k)); + // } + // if (cute::thread0()) { print(tSrK); } + FLASH_NAMESPACE::gemm( + acc_s, + tSrQ, tSrK, tSsQ, tSsK, + tiled_mma_sdp, + smem_tiled_copy_QdO, smem_tiled_copy_KV, + smem_thr_copy_QdO, smem_thr_copy_KV + ); + if constexpr (Is_softcap) { + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); + } + + // Copy mask and bias from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask); + cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view); + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); + cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); + + // Reshape acc_s, mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); + Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); + Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); + // if (cute::thread(32, 0)) { print(scores); } + + // Softcapping - calculating dTanh and scaling dS later with it + [[maybe_unused]] Tensor dtanh = make_tensor_like(scores); + if constexpr (Is_softcap) { + FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap); + } + + // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond + // actual_seqlen_k, because acc_s would be some finite value for those indices. + // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, + // so the result would still be correct. + // However, it's possible that the values in acc_s are so large that they overflow + // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. + // So we need to mask out the elements beyond actual_seqlen_k. + FLASH_NAMESPACE::apply_mask( + scores, mask, bias, params.scale_softmax, + n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, + m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + AtomLayoutMS * 16 + ); + + // if (cute::thread(32, 0)) { print(scores); } + // Compute the exponential value. + // FLASH_NAMESPACE::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + FLASH_NAMESPACE::scale_apply_exp2(scores, lse, float(M_LOG2E)); + // Convert scores from fp32 to fp16/bf16 + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
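// ---------------------------------------------------------------------------------------------
// Editorial sketch (not part of the patch): the block above recomputes P for the backward pass
// by masking/biasing the raw Q*K^T scores and then exponentiating against the stored row LSE,
// which is what the apply_mask + scale_apply_exp2(scores, lse, M_LOG2E) pair amounts to
// (exp(x) == exp2(x * log2(e))). The scalar, one-row version below is a simplification: it
// assumes the mask simply zeroes fully inactive positions and that the softmax scale and bias
// are folded in exactly as the calls above suggest.
#define _USE_MATH_DEFINES
#include <cmath>

inline void recompute_P_row(const float* s_row,     // raw Q*K^T scores for one query row
                            const float* mask_row,  // 0 means the position is masked out
                            const float* bias_row,  // additive attention bias
                            float scale_softmax,
                            float lse,              // log-sum-exp saved by the forward pass
                            float* p_row,
                            int seqlen_k) {
    const float kLog2e = static_cast<float>(M_LOG2E);
    for (int j = 0; j < seqlen_k; ++j) {
        if (mask_row[j] == 0.f) {
            p_row[j] = 0.f;                           // masked position contributes nothing
        } else {
            const float s = s_row[j] * scale_softmax + bias_row[j];
            p_row[j] = exp2f((s - lse) * kLog2e);     // == expf(s - lse)
        }
    }
}
// ---------------------------------------------------------------------------------------------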
+ Tensor tPrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom, AtomNum), MMA_M, MMA_N) + cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); + // if (cute::thread0()) { print(tPaP); } + // __syncthreads(); + // if (cute::thread0()) { print(sP); } clear(acc_dp); + // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dp.layout())); + // #pragma unroll + // for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) { + // #pragma unroll + // for (int ni = 0; ni < size<1>(acc_dp_reshaped); ++ni) { + // acc_dp_reshaped(mi, ni) = -dP_sum(mi); + // } + // } + + // if (cute::thread0()) { print(dP_sum); } + + FLASH_NAMESPACE::gemm( + acc_dp, + tdPrdO, tdPrV, tdPsdO, tdPsV, + tiled_mma_sdp, + smem_tiled_copy_QdO, smem_tiled_copy_KV, + smem_thr_copy_QdO, smem_thr_copy_KV + ); + + // Reshape acc_dp from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) + Tensor dS = make_tensor(acc_dp.data(), scores.layout()); + auto pointwise_mult = [](float p, float dp, float d) { + return p * (p >= 0 ? dp - d : d); + }; + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { + float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); + if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); } + dS(mi, ni) = scaled_ds; + } + } + // if (cute::thread0()) { print(dS); } + } + + tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded)); - Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_M, MMA_K - tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded)); + if (any_active) { if (Is_first || Seq_parallel) { clear(acc_dq); } else { - Tensor acc_dq_reshaped_load = make_tensor( + // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum + Tensor acc_dq_reshaped = make_tensor( acc_dq.data(), make_layout(get<0>(acc_dq.layout()), get<2>(acc_dq.layout()), get<1>(acc_dq.layout())) ); - cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped_load); + cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped); } + } - if (Double_buffer && m_block > m_block_min) { - // Double buffer for sQ - const int sQ_offset = m_block % 2 == 0 ? size(sQ) : -size(sQ); - tQsQ.data() = tQsQ.data() + sQ_offset; - tSsQ.data() = tSsQ.data() + sQ_offset; - // Advance gQ - tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); + if (Double_buffer && m_block > m_block_min) { + // Double buffer for sQ + const int sQ_offset = m_block % 2 == 0 ? 
size(sQ) : -size(sQ); + tQsQ.data() = tQsQ.data() + sQ_offset; + tSsQ.data() = tSsQ.data() + sQ_offset; + // Advance gQ + tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); + if (any_active) { FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tQgQ, tQsQ, @@ -655,9 +767,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in ); FLASH_NAMESPACE::cp_async_fence(); } + } - Tensor tdSrdS = make_tensor(shape(acc_dp)); - clear(tdSrdS); + if (any_active) { + // Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout()); + // Convert dS from fp32 to fp16 + Tensor tdSrdS = FLASH_NAMESPACE::convert_type(acc_dp); Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom, AtomNum), MMA_M, MMA_N) cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); __syncthreads(); @@ -670,9 +785,46 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in binfo.actual_seqlen_k - n_block * kBlockN ); - if (m_block > m_block_min) { - // Advance gdO - tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); + // if (cute::thread0()) { print(tPrP); } + // Layout p_l = tPrP.layout(); + // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l))); + // FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); + // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); + // FLASH_NAMESPACE::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); + FLASH_NAMESPACE::gemm( + acc_dv, + tdVrPt, tdVrdO, tdVsPt, tdVsdOt, + tiled_mma_dkv, + smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, + smem_thr_copy_PdSt, smem_thr_copy_QdOt + ); + // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); } + // if (cute::thread0()) { print(acc_dv); } + + __syncthreads(); // Need syncthreads since we're writing to the same sdO location + } + + if (m_block > m_block_min) { + // Advance gMask + tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); + FLASH_NAMESPACE::copy_MN( + gmem_tiled_copy_MaskBias, + tMaskgMask, tMasksMask, + tMaskcMask, + binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + ); + FLASH_NAMESPACE::cp_async_fence(); + FLASH_NAMESPACE::cp_async_wait<0>(); + + // Do OR-reduce on the mask to see if any active threads for next iteration + any_active_local_next = false; + #pragma unroll + for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); } + any_active_next = __syncthreads_or(any_active_local_next); + + // Advance gdO + tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); + if (any_active_next) { if (Is_first) { tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride)); FLASH_NAMESPACE::copy( @@ -692,293 +844,21 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tQcQ, tQpQ ); } - // Advance gMask - tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask, tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); FLASH_NAMESPACE::cp_async_fence(); } - - if (m_block > m_block_min) { - gLSE.data() = gLSE.data() + (-int(kBlockM)); - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); } - gdPsum.data() = gdPsum.data() + (-int(kBlockM)); - } - - if (Double_buffer) { // 
Double buffer for sQ - tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ)); - } - if (!Double_buffer && m_block > m_block_min) { - __syncthreads(); - // Advance gQ - tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); - FLASH_NAMESPACE::copy( - gmem_tiled_copy_QKV, - tQgQ, tQsQ, - tQcQ, tQpQ - ); - FLASH_NAMESPACE::cp_async_fence(); - } - - if (m_block > m_block_min) { - // Advance gBias and gdBias - tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride)); - tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride)); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); - FLASH_NAMESPACE::cp_async_fence(); - } - - if (Is_first && m_block > m_block_min) { - cute::copy(tdOrdO, tdOsdO); - dot_do_o( - tdOrdO, tdOrO, gdPsum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow) - ); - } - - if (Is_last) { - __syncthreads(); - Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); - clear(tdQrdQ); - cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); - tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride)); - Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M, BLK_K) -> (blk_m, blk_k) - Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); - #pragma unroll - for (int m = 0; m < size<1>(tdQgdQ); ++m) { - if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) { - cute::copy(gmem_tiled_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _)); - } - } - } - - continue; - } - - Tensor dP_sum = make_fragment_like(lse); - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); } - - // if (cute::thread0()) { print(sK); } - // Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK); - // #pragma unroll - // for (int k = 0; k < size<2>(tSrK_copy_view); ++k) { - // cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k)); - // } - // if (cute::thread0()) { print(tSrK); } - FLASH_NAMESPACE::gemm( - acc_s, - tSrQ, tSrK, tSsQ, tSsK, - tiled_mma_sdp, - smem_tiled_copy_QdO, smem_tiled_copy_KV, - smem_thr_copy_QdO, smem_thr_copy_KV - ); - if constexpr (Is_softcap) { - FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); - } - - // Copy bias from smem to registers - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); - cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); - - // Reshape acc_s, mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); - Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); - Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); - // if (cute::thread(32, 0)) { print(scores); } - - // Softcapping - calculating dTanh and scaling dS later with it - [[maybe_unused]] Tensor dtanh = make_tensor_like(scores); - if constexpr (Is_softcap) { - FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap); - } - - // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond - // actual_seqlen_k, because acc_s would be some finite value for those indices. 
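// ---------------------------------------------------------------------------------------------
// Editorial sketch (not part of the patch): the element-wise rule applied after the dP = dO*V^T
// GEMM earlier in this hunk is dS_ij = P_ij * (dP_ij - D_i), where D_i is the per-row dot(dO, O)
// read from gdPsum; with softcapping, dS is additionally scaled by the stored dtanh term. The
// helper below mirrors the pointwise_mult lambda above in plain scalar form; treating softcap as
// a simple multiplier is the only simplification.
inline float dS_pointwise(float p,        // recomputed probability P_ij
                          float dp,       // dP_ij from the dO * V^T GEMM
                          float d,        // D_i = sum_k dO_ik * O_ik (dP_sum for this row)
                          bool softcap,
                          float dtanh) {  // derivative factor when softcapping, else unused
    float ds = p * (p >= 0.f ? dp - d : d);  // same branch structure as pointwise_mult
    if (softcap) { ds *= dtanh; }
    return ds;
}
// ---------------------------------------------------------------------------------------------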
- // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, - // so the result would still be correct. - // However, it's possible that the values in acc_s are so large that they overflow - // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. - // So we need to mask out the elements beyond actual_seqlen_k. - FLASH_NAMESPACE::apply_mask( - scores, mask, bias, params.scale_softmax, - n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - binfo.actual_seqlen_k, - m_block * kBlockM + get<0>(taccScS_row(0)), - binfo.actual_seqlen_q, - AtomLayoutMS * 16 - ); - - // if (cute::thread(32, 0)) { print(scores); } - // Compute the exponential value. - // FLASH_NAMESPACE::scale_apply_exp2(scores, lse, params.scale_softmax_log2); - FLASH_NAMESPACE::scale_apply_exp2(scores, lse, float(M_LOG2E)); - // Convert scores from fp32 to fp16/bf16 - Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); - // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tPrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); - Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom, AtomNum), MMA_M, MMA_N) - cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); - // if (cute::thread0()) { print(tPaP); } - // __syncthreads(); - // if (cute::thread0()) { print(sP); } - - Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA - CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA - - clear(acc_dp); - // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dp.layout())); - // #pragma unroll - // for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) { - // #pragma unroll - // for (int ni = 0; ni < size<1>(acc_dp_reshaped); ++ni) { - // acc_dp_reshaped(mi, ni) = -dP_sum(mi); - // } - // } - - // if (cute::thread0()) { print(dP_sum); } - - FLASH_NAMESPACE::gemm( - acc_dp, - tdPrdO, tdPrV, tdPsdO, tdPsV, - tiled_mma_sdp, - smem_tiled_copy_QdO, smem_tiled_copy_KV, - smem_thr_copy_QdO, smem_thr_copy_KV - ); - - // Reshape acc_dp from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) - Tensor dS = make_tensor(acc_dp.data(), scores.layout()); - auto pointwise_mult = [](float p, float dp, float d) { - return p * (p >= 0 ? dp - d : d); - }; - #pragma unroll - for (int mi = 0; mi < size<0>(dS); ++mi) { - #pragma unroll - for (int ni = 0; ni < size<1>(dS); ++ni) { - float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); - if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); } - dS(mi, ni) = scaled_ds; - } } - // if (cute::thread0()) { print(dS); } - Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_M, MMA_K - tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded)); - if (Is_first || Seq_parallel) { - clear(acc_dq); - } else { - // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum - Tensor acc_dq_reshaped = make_tensor( - acc_dq.data(), - make_layout(get<0>(acc_dq.layout()), get<2>(acc_dq.layout()), get<1>(acc_dq.layout())) - ); - cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped); - } - - if (Double_buffer && m_block > m_block_min) { - // Double buffer for sQ - const int sQ_offset = m_block % 2 == 0 ? 
size(sQ) : -size(sQ); - tQsQ.data() = tQsQ.data() + sQ_offset; - tSsQ.data() = tSsQ.data() + sQ_offset; - // Advance gQ - tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); - FLASH_NAMESPACE::copy( - gmem_tiled_copy_QKV, - tQgQ, tQsQ, - tQcQ, tQpQ + if (any_active) { + FLASH_NAMESPACE::gemm( + acc_dq, + tdQrdS, tdQrKt, tdQsdS, tdQsKt, + tiled_mma_dq, + smem_tiled_copy_dS, smem_tiled_copy_Kt, + smem_thr_copy_dS, smem_thr_copy_Kt ); - FLASH_NAMESPACE::cp_async_fence(); + // if (cute::thread0()) { print(acc_dq); } } - Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout()); - // Convert dS from fp32 to fp16 - Tensor tdSrdS = FLASH_NAMESPACE::convert_type(dS_reshaped); - Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom, AtomNum), MMA_M, MMA_N) - cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); - __syncthreads(); - // Write dS to dBias - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tBiassBias, tdBiasgdBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, - binfo.actual_seqlen_k - n_block * kBlockN - ); - - // if (cute::thread0()) { print(tPrP); } - // Layout p_l = tPrP.layout(); - // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l))); - // FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); - // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); - // FLASH_NAMESPACE::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); - FLASH_NAMESPACE::gemm( - acc_dv, - tdVrPt, tdVrdO, tdVsPt, tdVsdOt, - tiled_mma_dkv, - smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, - smem_thr_copy_PdSt, smem_thr_copy_QdOt - ); - // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); } - // if (cute::thread0()) { print(acc_dv); } - - __syncthreads(); // Need syncthreads since we're writing to the same sdO location - - if (m_block > m_block_min) { - // Advance gdO - tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); - if (Is_first) { - tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride)); - FLASH_NAMESPACE::copy( - gmem_tiled_copy_dO, - tdOgdO, tdOrdO, - tQcQ, tQpQ - ); - FLASH_NAMESPACE::copy( - gmem_tiled_copy_dO, - tdOgO, tdOrO, - tQcQ, tQpQ - ); - } else { - FLASH_NAMESPACE::copy( - gmem_tiled_copy_dO, - tdOgdO, tdOsdO, - tQcQ, tQpQ - ); - } - // Advance gMask - tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask, tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); - FLASH_NAMESPACE::cp_async_fence(); - } - - FLASH_NAMESPACE::gemm( - acc_dq, - tdQrdS, tdQrKt, tdQsdS, tdQsKt, - tiled_mma_dq, - smem_tiled_copy_dS, smem_tiled_copy_Kt, - smem_thr_copy_dS, smem_thr_copy_Kt - ); - // if (cute::thread0()) { print(acc_dq); } - if (m_block > m_block_min) { gLSE.data() = gLSE.data() + (-int(kBlockM)); #pragma unroll @@ -986,67 +866,73 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in gdPsum.data() = gdPsum.data() + (-int(kBlockM)); } - if (!Is_last) { - // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum - Tensor acc_dq_reshaped = make_tensor( - acc_dq.data(), - make_layout(get<0>(acc_dq.layout()), get<2>(acc_dq.layout()), get<1>(acc_dq.layout())) - ); - if (!Seq_parallel) { - cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum); + if (any_active) { + if (!Is_last) { + // Reshape 
acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum + Tensor acc_dq_reshaped = make_tensor( + acc_dq.data(), + make_layout(get<0>(acc_dq.layout()), get<2>(acc_dq.layout()), get<1>(acc_dq.layout())) + ); + if (!Seq_parallel) { + cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum); + } else { + // if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); } + CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); } + } } else { - // if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); } - CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); } + for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; } + // Convert acc_dq from fp32 to fp16 + Tensor rdQ = FLASH_NAMESPACE::convert_type(acc_dq); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom, AtomNum), MMA_M, MMA_K) + cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); } - } else { - #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; } - // Convert acc_dq from fp32 to fp16 - Tensor rdQ = FLASH_NAMESPACE::convert_type(acc_dq); - Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom, AtomNum), MMA_M, MMA_K) - cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); + + FLASH_NAMESPACE::gemm( + acc_dk, + tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, + tiled_mma_dkv, + smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, + smem_thr_copy_PdSt, smem_thr_copy_QdOt + ); + // if (cute::thread0()) { print(acc_dk); } + + __syncthreads(); // Need syncthreads since we're using the sBias smem for accumulating acc_dk } - FLASH_NAMESPACE::gemm( - acc_dk, - tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, - tiled_mma_dkv, - smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, - smem_thr_copy_PdSt, smem_thr_copy_QdOt - ); - // if (cute::thread0()) { print(acc_dk); } if (Double_buffer) { // Double buffer for sQ tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? 
size(sQ) : -size(sQ));
         }
         if (!Double_buffer && m_block > m_block_min) {
-            __syncthreads();
             // Advance gQ
             tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
-            FLASH_NAMESPACE::copy(
-                gmem_tiled_copy_QKV,
-                tQgQ, tQsQ,
-                tQcQ, tQpQ
-            );
-            FLASH_NAMESPACE::cp_async_fence();
+            if (any_active_next) {
+                FLASH_NAMESPACE::copy(
+                    gmem_tiled_copy_QKV,
+                    tQgQ, tQsQ,
+                    tQcQ, tQpQ
+                );
+            }
         }
 
-        __syncthreads(); // Need syncthreads since we're using the sBias smem for accumulating acc_dk
-
         if (m_block > m_block_min) {
             // Advance gBias and gdBias
             tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
             tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
-            FLASH_NAMESPACE::copy_MN(
-                gmem_tiled_copy_MaskBias,
-                tBiasgBias, tBiassBias,
-                tBiascBias,
-                binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
-            );
-            FLASH_NAMESPACE::cp_async_fence();
+            if (any_active_next) {
+                FLASH_NAMESPACE::copy_MN(
+                    gmem_tiled_copy_MaskBias,
+                    tBiasgBias, tBiassBias,
+                    tBiascBias,
+                    binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
+                );
+            }
         }
 
+        FLASH_NAMESPACE::cp_async_fence();
+
         if (Is_first && m_block > m_block_min) {
             cute::copy(tdOrdO, tdOsdO);
             dot_do_o(
@@ -1070,6 +956,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             }
         }
 
+        any_active = any_active_next;
+
     }
 
diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index 6cb5209..77a5a19 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -362,7 +362,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     #pragma unroll
     for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
     bool any_active = __syncthreads_or(any_active_local);
-    bool any_active_next = any_active; // to be updated later for next iteration
+    bool any_active_next = false; // to be updated later for next iteration
 
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
     if (any_active) {
@@ -1016,7 +1016,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     #pragma unroll
     for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
     bool any_active = __syncthreads_or(any_active_local);
-    bool any_active_next = any_active; // to be updated later for next iteration
+    bool any_active_next = false; // to be updated later for next iteration
 
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
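// ---------------------------------------------------------------------------------------------
// Editorial sketch (not part of the patch): the loop skeleton below shows the any_active /
// any_active_next hand-off that the backward hunks above introduce and that the two
// flash_fwd_kernel.h fixes make consistent: the flag starts as false, is recomputed from the
// freshly prefetched mask tile of the *next* block each iteration, and is handed over at the
// bottom of the loop instead of inheriting the current block's value. reduce_mask_tile() is a
// stand-in for the __syncthreads_or reduction over the shared-memory mask, not a real helper.
__device__ bool reduce_mask_tile(const float* sMask, int n, int tid, int nthreads) {
    bool local = false;
    for (int i = tid; i < n; i += nthreads) { local |= (sMask[i] != 0.f); }
    return __syncthreads_or(local);
}

__device__ void block_loop_sketch(const float* sMask, int tile_elems,
                                  int m_block_max, int m_block_min,
                                  int tid, int nthreads) {
    // The first mask tile is assumed to be staged before the loop; decide whether it has work.
    bool any_active = reduce_mask_tile(sMask, tile_elems, tid, nthreads);
    bool any_active_next = false;   // recomputed every iteration, never a stale copy

    for (int m_block = m_block_max; m_block >= m_block_min; --m_block) {
        if (any_active) {
            // ... recompute P and dS, run the dV/dQ/dK GEMMs for this block ...
        }
        if (m_block > m_block_min) {
            // ... prefetch the mask tile of block m_block - 1 into sMask, then ...
            any_active_next = reduce_mask_tile(sMask, tile_elems, tid, nthreads);
        }
        any_active = any_active_next;   // hand the decision to the next iteration
    }
}
// ---------------------------------------------------------------------------------------------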
     if (any_active) {
diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py
index d4f968f..1acdfdd 100644
--- a/flash_dmattn/flash_dmattn_interface.py
+++ b/flash_dmattn/flash_dmattn_interface.py
@@ -420,7 +420,9 @@ def forward(
         )
         if mask is None:
             mask = torch.ones((batch_size, num_heads_k, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device)
+        return_dbias = True
         if bias is None:
+            return_dbias = False
             bias = torch.zeros((batch_size, num_heads_k, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device)
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -465,6 +467,7 @@ def forward(
         ctx.is_causal = is_causal
         ctx.softcap = softcap
         ctx.deterministic = deterministic
+        ctx.return_dbias = return_dbias
 
         out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -510,7 +513,9 @@ def backward(
         dk = dk[:, : ctx.seqlen_k, :, :]
         dv = dv[:, : ctx.seqlen_k, :, :]
         dbias = dbias[..., : ctx.seqlen_k]
-        return dq, dk, dv, None, dbias, None, None, None, None, None, None
+        if ctx.return_dbias:
+            return dq, dk, dv, None, dbias, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None
 
 
 class FlashDMAttnVarlenFunc(torch.autograd.Function):
@@ -544,8 +549,10 @@ def forward(
         )
         if mask is None:
             mask = torch.ones((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device)
+        return_dbias = True
         if bias is None:
             bias = torch.zeros((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device)
+            return_dbias = False
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         if is_causal is None:
@@ -606,6 +613,7 @@ def forward(
         ctx.is_causal = is_causal
         ctx.softcap = softcap
         ctx.deterministic = deterministic
+        ctx.return_dbias = return_dbias
 
         out = out_padded[..., :head_size_og]
         if return_softmax:
@@ -658,7 +666,9 @@ def backward(
 
         if ctx.seqlen_k_og != ctx.max_seqlen_k:
             dbias = dbias[:, :, :ctx.seqlen_k_og]
-        return dq, dk, dv, None, dbias, None, None, None, None, None, None, None, None, None, None, None
+        if ctx.return_dbias:
+            return dq, dk, dv, None, dbias, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
 
 
 def flash_dmattn_func(