diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index 780e616..e8578e2 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -80,6 +80,7 @@ template(sMaskPlace.data().get())), + typename Kernel_traits::SmemLayoutMaskBiasPdS{} ); Tensor sP = make_tensor( - sMask.data(), + sMaskPlace.data(), typename Kernel_traits::SmemLayoutMaskBiasPdS{} ); Tensor sPt = make_tensor( - sMask.data(), + sMaskPlace.data(), typename Kernel_traits::SmemLayoutPdStransposed{} ); Tensor sPtNoSwizzle = make_tensor( - sMask.data(), + sMaskPlace.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{} ); // sMask, sP and sdQ share the same memory so be careful Tensor sdQ = make_tensor( - sMask.data(), + sMaskPlace.data(), typename Kernel_traits::SmemLayoutdQ{} ); @@ -278,6 +283,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in GmemTiledCopydO gmem_tiled_copy_dO; auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; + typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias; + auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx); auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); using GmemLayoutAtomdQaccum = std::conditional_t< !Seq_parallel, @@ -300,7 +307,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask); Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias); - Tensor tdBiasgdBias = gmem_thr_copy_Bias.partition_D(gdBias); + Tensor tdBiasgdBias = gmem_thr_copy_dBias.partition_D(gdBias); Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom, AtomNum), ATOM_M, ATOM_N) Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); @@ -350,20 +357,17 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); } Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); - // auto smem_tiled_copy_Mask = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp); - // auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); - // Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); - // auto smem_tiled_copy_Bias = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma_sdp); - // auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); - // Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); - // Partition sP and sdS to match the accumulator partitioning // This has to be tiled_mma_sdp, not tiled_mma_dkv // auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp); auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp); auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx); - Tensor tSsMask = smem_thr_copy_PdS.partition_S(sMask); - Tensor tSsBias = smem_thr_copy_PdS.partition_S(sBias); + auto smem_tiled_copy_Mask = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp); + auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); + auto smem_tiled_copy_Bias = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomBias{}, 
tiled_mma_sdp); + auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); + Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); + Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom, AtomNum), PIPE_M, PIPE_N) Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom, AtomNum), PIPE_M, PIPE_N) @@ -573,24 +577,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // // if (cute::thread(1, 0)) { print(tKrK); } if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask, tMasksMask, - any_active, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for current interation. + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active, + smem_thr_copy_Mask + ); } FLASH_NAMESPACE::copy( @@ -602,15 +601,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if (any_active) { if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } } @@ -627,9 +623,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); } if (Is_first) { cute::copy(tdOrdO, tdOsdO); - dot_do_o( + dot_do_o( tdOrdO, tdOrO, gdPsum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow) + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO) ); } @@ -852,15 +848,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in __syncthreads(); if constexpr (Has_bias) { // Write dS to dBias - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, + FLASH_NAMESPACE::copy_MN( + gmem_tiled_copy_dBias, tBiassBias, tdBiasgdBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. 
- __syncthreads(); } // if (cute::thread0()) { print(tPrP); } @@ -886,24 +879,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if constexpr (Has_mask) { // Advance gMask tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - (m_block - 1) * kBlockM - // ); - // FLASH_NAMESPACE::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask, tMasksMask, - any_active_next, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - (m_block - 1) * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for next iteration + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active_next, + smem_thr_copy_Mask + ); } // Advance gdO @@ -1007,15 +995,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride)); tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride)); if (any_active_next) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - (m_block - 1) * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } } } @@ -1024,9 +1009,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if (Is_first && m_block > m_block_min) { cute::copy(tdOrdO, tdOsdO); - dot_do_o( + dot_do_o( tdOrdO, tdOrO, gdPsum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow) + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO) ); } diff --git a/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h index c8a35c0..d6710b2 100644 --- a/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h @@ -148,9 +148,9 @@ inline __device__ void compute_dot_do_o( tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM ); - dot_do_o( + dot_do_o( tdOrdO, tdOrO, dP_sum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow) + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO) ); if (Clear_dQaccum) { // We're actually not zero'ing out all of dQaccum, but only the part that we're going to diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 0e4d69b..576de90 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -54,6 +54,7 @@ template(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)), + make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)), make_shape(params.h_mask, binfo.actual_seqlen_q, binfo.actual_seqlen_k), make_stride(params.mask_head_stride, params.mask_row_stride, _1{}) ); @@ -216,13 +217,17 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi sV.data().get(), typename 
Kernel_traits::SmemLayoutVtransposedNoSwizzle{} ); - Tensor sMask = make_tensor( + Tensor sMaskPlace = make_tensor( Has_mask ? sV.data() + size(sV) : sV.data(), - typename Kernel_traits::SmemLayoutAtomPS{} + typename Kernel_traits::SmemLayoutPS{} + ); // For pointers alignment only + Tensor sMask = make_tensor( + make_smem_ptr(reinterpret_cast(sMaskPlace.data().get())), + typename Kernel_traits::SmemLayoutPS{} ); Tensor sBias = make_tensor( - Has_bias ? (Has_mask ? sMask.data() + size(sMask) : sV.data() + size(sV)) : sV.data(), - typename Kernel_traits::SmemLayoutAtomPS{} + Has_bias ? (Has_mask ? sMaskPlace.data() + size(sMaskPlace) : sV.data() + size(sV)) : sV.data(), + typename Kernel_traits::SmemLayoutPS{} ); // Global to Shared Memory operation @@ -267,10 +272,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); + auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma); auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); - auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); + auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma); auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); @@ -364,25 +369,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask(_, _, _, n_block), tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads - - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask(_, _, _, n_block), tMasksMask, - any_active, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for current iteration. + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active, + smem_thr_copy_Mask + ); } // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. @@ -394,15 +393,12 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi binfo.actual_seqlen_k - n_block * kBlockN ); if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block), tBiassBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. 
- __syncthreads(); } cute::cp_async_fence(); } @@ -527,24 +523,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi if (n_block > n_block_min) { if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask(_, _, _, n_block - 1), tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask(_, _, _, n_block - 1), tMasksMask, - any_active_next, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for next iteration. + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active_next, + smem_thr_copy_Mask + ); } if (any_active_next) { @@ -554,15 +545,12 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKVcKV, tKVpKV ); if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. @@ -690,24 +678,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi if (n_block > n_block_min) { if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask(_, _, _, n_block - 1), tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask(_, _, _, n_block - 1), tMasksMask, - any_active_next, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for next iteration. + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active_next, + smem_thr_copy_Mask + ); } if (any_active_next) { @@ -717,15 +700,12 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKVcKV, tKVpKV ); if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
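
Note: throughout these hunks the fused copy_mask_with_or_reduce is split into a plain copy_MN into shared memory, an explicit __syncthreads(), and a separate mask_or_reduce that decides whether the tile has any active element. mask_or_reduce itself is not part of this diff; the sketch below is only a plausible shape inferred from the call sites, mirroring the OR-reduction that copy_mask_with_or_reduce keeps in utils.h further down (per-thread OR over the shared-memory tile, then a block-wide __syncthreads_or), and assumes cute's Tensor is in scope as in the surrounding kernels.

// Sketch only — not taken from this patch. Each thread ORs the mask bytes it owns in the
// shared-memory tile, then the block combines the per-thread results.
template <typename Engine0, typename Layout0, typename ThrCopy>
__forceinline__ __device__ void mask_or_reduce(
    Tensor<Engine0, Layout0> const &tMasksMask,   // thread-partitioned view of sMask
    bool &any_active,                             // block-uniform result
    ThrCopy const &smem_thr_copy_Mask             // kept to match the call sites; unused in this sketch
) {
    bool active_local = false;
    #pragma unroll
    for (int i = 0; i < size(tMasksMask); ++i) {
        active_local |= bool(tMasksMask(i));
    }
    // __syncthreads_or returns nonzero iff the predicate is nonzero for any thread in the
    // block, so every thread observes the same any_active value.
    any_active = __syncthreads_or(active_local);
}
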
@@ -843,6 +823,7 @@ template(params.mask_ptr) + col_offset_mask), + make_gmem_ptr(reinterpret_cast(params.mask_ptr) + col_offset_mask), Shape, Int>{}, make_stride(params.mask_row_stride, _1{}) ); @@ -1010,13 +991,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{} ); - Tensor sMask = make_tensor( + Tensor sMaskPlace = make_tensor( Has_mask ? sV.data() + size(sV) : sV.data(), - typename Kernel_traits::SmemLayoutAtomPS{} + typename Kernel_traits::SmemLayoutPS{} + ); // For pointers alignment only + Tensor sMask = make_tensor( + make_smem_ptr(reinterpret_cast(sMaskPlace.data().get())), + typename Kernel_traits::SmemLayoutPS{} ); Tensor sBias = make_tensor( - Has_bias ? (Has_mask ? sMask.data() + size(sMask) : sV.data() + size(sV)) : sV.data(), - typename Kernel_traits::SmemLayoutAtomPS{} + Has_bias ? (Has_mask ? sMaskPlace.data() + size(sMaskPlace) : sV.data() + size(sV)) : sV.data(), + typename Kernel_traits::SmemLayoutPS{} ); // Global to Shared Memory operation @@ -1026,7 +1011,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); @@ -1059,10 +1043,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); + auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma); auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); - auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); + auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma); auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); @@ -1125,24 +1109,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ); if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask, tMasksMask, - any_active, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for current iteration. + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active, + smem_thr_copy_Mask + ); } // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
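
Note: the sMaskPlace/sMask pair above (and its twin in compute_attn_1rowblock) exists because the mask now lives in shared memory as ElementMask (uint8_t) while the neighbouring buffers are still sized and offset in Element units. The angle-bracketed template arguments are stripped in this patch view, so the reinterpret_cast target is assumed to be Kernel_traits::ElementMask. A condensed sketch of the intent, with the assumptions called out in comments:

// Condensed sketch of the double view, not verbatim from the patch. Assumes the stripped
// reinterpret_cast target is Kernel_traits::ElementMask (uint8_t, added in kernel_traits.h).
using Element     = typename Kernel_traits::Element;      // e.g. half_t / bfloat16_t
using ElementMask = typename Kernel_traits::ElementMask;  // uint8_t

// Placeholder typed as Element, "for pointer alignment only": offsetting with
// sMaskPlace.data() + size(sMaskPlace) places the following sBias tile exactly where the old
// Element-typed sMask would have put it.
Tensor sMaskPlace = make_tensor(
    Has_mask ? sV.data() + size(sV) : sV.data(),
    typename Kernel_traits::SmemLayoutPS{}
);
// The tile that is actually read and written as mask data reinterprets those same bytes as
// uint8_t, which is what GmemTiledCopyMask / SmemCopyAtomMask now expect.
Tensor sMask = make_tensor(
    make_smem_ptr(reinterpret_cast<ElementMask *>(sMaskPlace.data().get())),
    typename Kernel_traits::SmemLayoutPS{}
);
Tensor sBias = make_tensor(
    Has_bias ? (Has_mask ? sMaskPlace.data() + size(sMaskPlace) : sV.data() + size(sV)) : sV.data(),
    typename Kernel_traits::SmemLayoutPS{}
);
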
@@ -1154,15 +1133,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons binfo.actual_seqlen_k - n_block * kBlockN ); if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } cute::cp_async_fence(); } @@ -1318,24 +1294,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons } if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask, tMasksMask, - any_active_next, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for next iteration. + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active_next, + smem_thr_copy_Mask + ); } if (any_active_next) { @@ -1345,15 +1316,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKVcKV, tKVpKV ); if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, + tBiasgBias, tBiassBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. @@ -1508,24 +1476,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons } if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask, tMasksMask, - any_active_next, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for next iteration. + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active_next, + smem_thr_copy_Mask + ); } if (any_active_next) { @@ -1535,15 +1498,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKVcKV, tKVpKV ); if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, + tBiasgBias, tBiassBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. 
- __syncthreads(); } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. diff --git a/csrc/flash_dmattn/src/kernel_traits.h b/csrc/flash_dmattn/src/kernel_traits.h index f6a4e42..4412984 100644 --- a/csrc/flash_dmattn/src/kernel_traits.h +++ b/csrc/flash_dmattn/src/kernel_traits.h @@ -23,6 +23,7 @@ struct Flash_kernel_traits { static constexpr bool Has_cp_async = false; #endif + using ElementMask = uint8_t; using ElementAccum = float; using index_t = int64_t; @@ -55,6 +56,7 @@ struct Flash_fwd_kernel_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; + using ElementMask = typename Base::ElementMask; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; @@ -83,7 +85,7 @@ struct Flash_fwd_kernel_traits : public Base { Tile, _16, _16> >; - // Shared memory layout for Q matrix and Mask matrix + // Shared memory layout for Q matrix using SmemLayoutAtomQ = decltype( composition( Swizzle{}, @@ -92,14 +94,6 @@ struct Flash_fwd_kernel_traits : public Base { Stride, _1>>{} ) ); - using SmemLayoutAtomPS = decltype( - composition( - Swizzle{}, - Layout, Int>, - Stride, _1>>{} - ) - ); - using SmemLayoutQ = decltype( tile_to_shape( SmemLayoutAtomQ{}, @@ -123,13 +117,22 @@ struct Flash_fwd_kernel_traits : public Base { ); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); + // Shared memory layout for Mask and Bias matrices + using SmemLayoutAtomPS = decltype( + composition( + Swizzle{}, + Layout, Int>, + Stride, _1>>{} + ) + ); using SmemLayoutPS = decltype( tile_to_shape( SmemLayoutAtomPS{}, Shape, Int>{} ) ); - using SmemCopyAtomPS = Copy_Atom, Element>; + using SmemCopyAtomMask = Copy_Atom, ElementMask>; + using SmemCopyAtomBias = Copy_Atom, Element>; // Shared memory layout for output using SmemLayoutAtomO = decltype( @@ -157,18 +160,34 @@ struct Flash_fwd_kernel_traits : public Base { // Shared memory size with QKV matrices and mask/bias matrices static constexpr int kSmemSize = (Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + (Has_mask ? kSmemMaskSize : 0) + (Has_bias ? kSmemBiasSize : 0); - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemElemsPerLoadQKVO = sizeof(cute::uint128_t) / sizeof(Element); + static constexpr int kGmemElemsPerLoadMask = sizeof(cute::uint128_t) / sizeof(ElementMask); + static constexpr int kGmemElemsPerLoadBias = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoadQKVO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadQKVO"); + static_assert(kBlockN % kGmemElemsPerLoadMask == 0, "kBlockN must be a multiple of kGmemElemsPerLoadMask"); + static_assert(kBlockN % kGmemElemsPerLoadBias == 0, "kBlockN must be a multiple of kGmemElemsPerLoadBias"); // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. // For example, for d=128, smem is split into 2 "pages", each page takes care of columns // 0-63 and 64-127. 
If we have 16 threads per row for gmem read, when we write to smem, // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, // to the same banks. - static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout< - Shape , Int>, - Stride, _1> + static constexpr int kGmemThreadsPerRowQKVO = kBlockKSmem / kGmemElemsPerLoadQKVO; + static constexpr int kGmemThreadsPerRowMask = kBlockN / kGmemElemsPerLoadMask; + static constexpr int kGmemThreadsPerRowBias = kBlockN / kGmemElemsPerLoadBias; + static_assert(kNThreads % kGmemThreadsPerRowQKVO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowQKVO"); + static_assert(kNThreads % kGmemThreadsPerRowMask == 0, "kNThreads must be a multiple of kGmemThreadsPerRowMask"); + static_assert(kNThreads % kGmemThreadsPerRowBias == 0, "kNThreads must be a multiple of kGmemThreadsPerRowBias"); + using GmemLayoutAtomQKVO = Layout< + Shape , Int>, + Stride, _1> + >; + using GmemLayoutAtomMask = Layout< + Shape , Int>, + Stride, _1> + >; + using GmemLayoutAtomBias = Layout< + Shape , Int>, + Stride, _1> >; // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading @@ -181,28 +200,28 @@ struct Flash_fwd_kernel_traits : public Base { using GmemTiledCopyQKV = decltype( make_tiled_copy( Copy_Atom{}, - GmemLayoutAtom{}, + GmemLayoutAtomQKVO{}, Layout>{} ) ); // Val layout, 8 vals per read using GmemTiledCopyMask = decltype( make_tiled_copy( - Copy_Atom, Element>{}, - GmemLayoutAtom{}, - Layout>{} + Copy_Atom, ElementMask>{}, + GmemLayoutAtomMask{}, + Layout>{} ) - ); // Val layout, 8 vals per read + ); // Val layout, 16 vals per read using GmemTiledCopyBias = decltype( make_tiled_copy( - Copy_Atom, Element>{}, - GmemLayoutAtom{}, + Copy_Atom{}, + GmemLayoutAtomBias{}, Layout>{} ) ); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy( Copy_Atom, Element>{}, - GmemLayoutAtom{}, + GmemLayoutAtomQKVO{}, Layout>{} ) ); // Val layout, 8 vals per store @@ -210,15 +229,15 @@ struct Flash_fwd_kernel_traits : public Base { // Accumulator layout for output using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, - Layout, Stride< _8, _1>>, // Thread layout, 8 threads per row - Layout, Stride< _16, _1>> // Thread layout, 16 threads per row + Layout, Stride<_8, _1>>, // Thread layout, 8 threads per row + Layout, Stride<_16, _1>> // Thread layout, 16 threads per row >; using GmemTiledCopyOaccum = decltype( make_tiled_copy( Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, - Layout>{} + Layout>{} ) ); // Val layout, 4 vals per store }; @@ -233,6 +252,7 @@ template< > struct Flash_bwd_kernel_traits : public Base { using Element = typename Base::Element; + using ElementMask = typename Base::ElementMask; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; @@ -347,6 +367,8 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); using SmemCopyAtomPdS = Copy_Atom, elem_type>; + using SmemCopyAtomMask = Copy_Atom, ElementMask>; + using SmemCopyAtomBias = Copy_Atom, elem_type>; using SmemLayoutQdOtransposed = decltype( composition( @@ -405,15 +427,31 @@ struct Flash_bwd_kernel_traits : public Base { : std::max(kSmemKVSize, kSmemKVSize / 2 + 
kSmemBiasdSSize + kSmemPMaskSize) ); - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemElemsPerLoadQKVO = sizeof(cute::uint128_t) / sizeof(Element); + static constexpr int kGmemElemsPerLoadMask = sizeof(cute::uint128_t) / sizeof(ElementMask); + static constexpr int kGmemElemsPerLoadBias = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoadQKVO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadQKVO"); + static_assert(kBlockN % kGmemElemsPerLoadMask == 0, "kBlockN must be a multiple of kGmemElemsPerLoadMask"); + static_assert(kBlockN % kGmemElemsPerLoadBias == 0, "kBlockN must be a multiple of kGmemElemsPerLoadBias"); // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem // to affect speed in practice. - static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout< - Shape , Int>, - Stride, _1> + static constexpr int kGmemThreadsPerRowQKVO = kBlockKSmem / kGmemElemsPerLoadQKVO; + static constexpr int kGmemThreadsPerRowMask = kBlockN / kGmemElemsPerLoadMask; + static constexpr int kGmemThreadsPerRowBias = kBlockN / kGmemElemsPerLoadBias; + static_assert(kNThreads % kGmemThreadsPerRowQKVO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowQKVO"); + static_assert(kNThreads % kGmemThreadsPerRowMask == 0, "kNThreads must be a multiple of kGmemThreadsPerRowMask"); + static_assert(kNThreads % kGmemThreadsPerRowBias == 0, "kNThreads must be a multiple of kGmemThreadsPerRowBias"); + using GmemLayoutAtomQKVO = Layout< + Shape, Int>, + Stride, _1> + >; + using GmemLayoutAtomMask = Layout< + Shape, Int>, + Stride, _1> + >; + using GmemLayoutAtomBias = Layout< + Shape, Int>, + Stride, _1> >; // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading @@ -426,65 +464,71 @@ struct Flash_bwd_kernel_traits : public Base { using GmemTiledCopyQKV = decltype( make_tiled_copy( Copy_Atom{}, - GmemLayoutAtom{}, + GmemLayoutAtomQKVO{}, Layout>{} ) ); // Val layout, 8 vals per read - using GmemTiledCopydO = decltype( + using GmemTiledCopyMask = decltype( make_tiled_copy( - Copy_Atom, elem_type>{}, - GmemLayoutAtom{}, - Layout>{} + Copy_Atom, ElementMask>{}, + GmemLayoutAtomMask{}, + Layout>{} ) - ); // Val layout, 8 vals per store - using GmemTiledCopydKV = decltype( + ); // Val layout, 16 vals per read + using GmemTiledCopyBias = decltype( make_tiled_copy( - Copy_Atom, elem_type>{}, - GmemLayoutAtom{}, - Layout>{} + Copy_Atom{}, + GmemLayoutAtomBias{}, + Layout>{} ) - ); // Val layout, 8 vals per store - using GmemTiledCopydQ = decltype( + ); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype( make_tiled_copy( Copy_Atom, elem_type>{}, - GmemLayoutAtom{}, - Layout>{} + GmemLayoutAtomQKVO{}, + Layout>{} ) ); // Val layout, 8 vals per store - using GmemTiledCopyMask = decltype( + using GmemTiledCopydBias = decltype( make_tiled_copy( Copy_Atom, elem_type>{}, - GmemLayoutAtom{}, + GmemLayoutAtomBias{}, Layout>{} ) ); // Val layout, 8 vals per read - using GmemTiledCopyBias = decltype( + using GmemTiledCopydKV = decltype( make_tiled_copy( Copy_Atom, elem_type>{}, - GmemLayoutAtom{}, + GmemLayoutAtomQKVO{}, Layout>{} ) - ); // Val layout, 8 vals per read + ); // Val layout, 8 vals 
per store + using GmemTiledCopydQ = decltype( + make_tiled_copy( + Copy_Atom, elem_type>{}, + GmemLayoutAtomQKVO{}, + Layout>{} + ) + ); // Val layout, 8 vals per store using GmemLayoutAtomdQaccum = std::conditional_t< kBlockKSmem == 32, - Layout, Stride< _8, _1>>, // Thread layout, 8 threads per row - Layout, Stride< _16, _1>> // Thread layout, 16 threads per row - + Layout, Stride<_8, _1>>, // Thread layout, 8 threads per row + Layout, Stride<_16, _1>> // Thread layout, 16 threads per row >; using GmemTiledCopydQaccum = decltype( make_tiled_copy( Copy_Atom, ElementAccum>{}, GmemLayoutAtomdQaccum{}, - Layout>{} + Layout>{} ) ); // Val layout, 4 vals per store using GmemTiledCopydQaccumAtomicAdd = decltype( make_tiled_copy( Copy_Atom, ElementAccum>{}, - Layout, // Thread layout, 8 threads per row + Layout, // Thread layout, 8 threads per row Stride<_32, _1>>{}, - Layout>{} + Layout>{} ) ); // Val layout, 1 val per store }; diff --git a/csrc/flash_dmattn/src/utils.h b/csrc/flash_dmattn/src/utils.h index 52fb7a0..13d204b 100644 --- a/csrc/flash_dmattn/src/utils.h +++ b/csrc/flash_dmattn/src/utils.h @@ -614,7 +614,7 @@ __forceinline__ __device__ void copy_mask( //////////////////////////////////////////////////////////////////////////////////////////////////// template < - bool Is_even_MN=true, bool Clear_OOB_MN=false, typename To_type=void, + bool Is_even_MN=true, bool Clear_OOB_MN=false, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2, typename Engine3, typename Layout3 @@ -622,7 +622,7 @@ template < __forceinline__ __device__ void copy_mask_with_or_reduce( TiledCopy tiled_copy, Tensor const &S, Tensor &D, - bool &block_active, + bool &active, Tensor const &identity_MN, Tensor const &predicate_N, const int max_M=0 ) { @@ -632,18 +632,13 @@ __forceinline__ __device__ void copy_mask_with_or_reduce( CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N - bool any_active = false; #pragma unroll for (int m = 0; m < size<1>(S); ++m) { if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { #pragma unroll for (int n = 0; n < size<2>(S); ++n) { if (Is_even_MN || predicate_N(n)) { - #pragma unroll - for (int i = 0; i < size<0>(S); ++i) { - any_active |= S(i, m, n); - D(i, m, n) = static_cast(S(i, m, n)); - } + cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); } else if (Clear_OOB_MN) { cute::clear(D(_, m, n)); } @@ -653,7 +648,14 @@ __forceinline__ __device__ void copy_mask_with_or_reduce( } } - block_active = __syncthreads_or(any_active); + __syncthreads(); + + bool active_local = false; + #pragma unroll + for (int i = 0; i < size(D); ++i) { + active_local |= D(i); + } + active = __syncthreads_or(active_local); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -675,18 +677,14 @@ __forceinline__ __device__ void copy_bias( CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N - + #pragma unroll for (int m = 0; m < size<1>(S); ++m) { if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { #pragma unroll for (int n = 0; n < size<2>(S); ++n) { if (Is_even_MN || predicate_N(n)) { - // cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); - #pragma unroll - for (int i = 0; i < size<0>(S); ++i) { - D(i, m, n) = S(i, m, n); - } + cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); } else if (Clear_OOB_MN) { 
cute::clear(D(_, m, n)); } diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 5b5a889..2c267d4 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -95,7 +95,7 @@ def _flash_dmattn_forward( softcap, return_softmax, ) - _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min) + # _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min) return out, softmax_lse, S_dmask @@ -170,7 +170,7 @@ def _flash_dmattn_backward( softcap, deterministic, ) - _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=torch.finfo(dq.dtype).max, neginf=torch.finfo(dq.dtype).min) + # _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=torch.finfo(dq.dtype).max, neginf=torch.finfo(dq.dtype).min) return softmax_d @@ -247,14 +247,14 @@ def forward( q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - # seqlen_k_og = k.shape[1] - # if seqlen_k_og % 8 != 0: - # k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) - # v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) - # if mask is not None: - # mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False) - # if bias is not None: - # bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0) + seqlen_k_og = k.shape[1] + if seqlen_k_og % 8 != 0: + k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) + v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) + if mask is not None: + mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False) + if bias is not None: + bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0) out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( q, @@ -274,7 +274,7 @@ def forward( ctx.is_causal = is_causal ctx.softcap = softcap ctx.deterministic = deterministic - # ctx.seqlen_k_og = seqlen_k_og + ctx.seqlen_k_og = seqlen_k_og out = out_padded[..., :head_size_og] @@ -318,10 +318,11 @@ def backward( dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - # if ctx.seqlen_k_og % 8 != 0: - # dk = dk[:, : ctx.seqlen_k_og, :, :] - # dv = dv[:, : ctx.seqlen_k_og, :, :] - # dbias = dbias[..., : ctx.seqlen_k_og] + if ctx.seqlen_k_og % 8 != 0: + dk = dk[:, : ctx.seqlen_k_og, :, :] + dv = dv[:, : ctx.seqlen_k_og, :, :] + if dbias is not None: + dbias = dbias[..., : ctx.seqlen_k_og] return dq, dk, dv, None, dbias, None, None, None, None, None, None
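
Note: the interface change above re-enables the previously commented-out padding of the key/value length to a multiple of 8, with the corresponding gradients sliced back to the original length in backward. A standalone illustration of that round trip is sketched below; the helper names (pad_kv_seqlen, unpad_kv_grads) are invented for this example and do not exist in flash_dmattn_interface.py — the real logic is inline in FlashDMAttnFunc.forward/backward as shown in the hunks above.

# Standalone sketch of the seqlen_k padding round trip.
import torch
import torch.nn.functional as F

def pad_kv_seqlen(k, v, mask=None, bias=None):
    # k, v: (batch, seqlen_k, nheads, headdim); mask/bias have seqlen_k as their last dimension.
    seqlen_k_og = k.shape[1]
    pad = (8 - seqlen_k_og % 8) % 8
    if pad:
        k = F.pad(k, [0, 0, 0, 0, 0, pad])            # pad the seqlen_k dimension
        v = F.pad(v, [0, 0, 0, 0, 0, pad])
        if mask is not None:
            mask = F.pad(mask, [0, pad], value=False)  # padded keys are masked out
        if bias is not None:
            bias = F.pad(bias, [0, pad], value=0.0)    # padded keys get zero bias
    return k, v, mask, bias, seqlen_k_og

def unpad_kv_grads(dk, dv, dbias, seqlen_k_og):
    # Slice gradients back to the caller's original key length.
    if seqlen_k_og % 8 != 0:
        dk = dk[:, :seqlen_k_og]
        dv = dv[:, :seqlen_k_og]
        if dbias is not None:
            dbias = dbias[..., :seqlen_k_og]
    return dk, dv, dbias
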