101 changes: 43 additions & 58 deletions csrc/flash_dmattn/src/flash_bwd_kernel.h
@@ -80,6 +80,7 @@ template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, b
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

using Element = typename Kernel_traits::Element;
using ElementMask = typename Kernel_traits::ElementMask;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;

@@ -245,25 +246,29 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
sBias.data(),
typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
);
Tensor sMask = make_tensor(
Tensor sMaskPlace = make_tensor(
sBias.data() + size(sBias),
typename Kernel_traits::SmemLayoutMaskBiasPdS{}
); // For pointer alignment only
Tensor sMask = make_tensor(
make_smem_ptr(reinterpret_cast<ElementMask *>(sMaskPlace.data().get())),
typename Kernel_traits::SmemLayoutMaskBiasPdS{}
);

Copilot AI (Oct 1, 2025): [nitpick] The pattern of creating sMaskPlace for pointer alignment and then reinterpreting it to ElementMask* is repeated multiple times. Consider extracting this into a helper function to reduce code duplication.
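A minimal sketch of the helper the review suggests, assuming the cute make_tensor/make_smem_ptr API already used above; the name make_aliased_smem_tensor and its signature are illustrative, not part of the PR:

    // Hypothetical helper: view a placement tensor's shared-memory bytes
    // under a different element type, keeping the given layout.
    template <typename ToElement, typename PlaceTensor, typename Layout>
    inline __device__ auto make_aliased_smem_tensor(PlaceTensor &place, Layout layout) {
        return make_tensor(
            make_smem_ptr(reinterpret_cast<ToElement *>(place.data().get())),
            layout
        );
    }

    // The construction sites would then collapse to calls like:
    // Tensor sMask = make_aliased_smem_tensor<ElementMask>(
    //     sMaskPlace, typename Kernel_traits::SmemLayoutMaskBiasPdS{});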
Tensor sP = make_tensor(
sMask.data(),
sMaskPlace.data(),
typename Kernel_traits::SmemLayoutMaskBiasPdS{}
);
Tensor sPt = make_tensor(
sMask.data(),
sMaskPlace.data(),
typename Kernel_traits::SmemLayoutPdStransposed{}
);
Tensor sPtNoSwizzle = make_tensor(
sMask.data(),
sMaskPlace.data(),
typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
);
// sMask, sP and sdQ share the same memory so be careful
Tensor sdQ = make_tensor(
sMask.data(),
sMaskPlace.data(),
typename Kernel_traits::SmemLayoutdQ{}
);
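The aliasing note above is the key invariant in this region: sMaskPlace, sP, sPt, sPtNoSwizzle, and sdQ are all views over the same shared-memory bytes. A standalone CUDA sketch (not from this PR) of why overlapping views need barriers between phases:

    #include <cuda_runtime.h>

    __global__ void aliased_smem_demo() {
        __shared__ float smem[128];            // one allocation ...
        float *sP  = smem;                     // ... first used as P
        float *sdQ = smem;                     // ... later reused for dQ

        sP[threadIdx.x] = float(threadIdx.x);  // phase 1: write P
        __syncthreads();                       // writes visible before neighbor reads
        float p = sP[(threadIdx.x + 1) % blockDim.x];
        __syncthreads();                       // all P reads done before bytes are reused
        sdQ[threadIdx.x] = 2.0f * p;           // phase 2: overwrite the same bytes as dQ
    }

    int main() {
        aliased_smem_demo<<<1, 128>>>();
        return cudaDeviceSynchronize() != cudaSuccess;
    }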

@@ -278,6 +283,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
GmemTiledCopydO gmem_tiled_copy_dO;
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias;
auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx);
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
using GmemLayoutAtomdQaccum = std::conditional_t<
!Seq_parallel,
@@ -300,7 +307,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask);
Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N)
Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias);
Tensor tdBiasgdBias = gmem_thr_copy_Bias.partition_D(gdBias);
Tensor tdBiasgdBias = gmem_thr_copy_dBias.partition_D(gdBias);

Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom, AtomNum), ATOM_M, ATOM_N)
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
@@ -350,20 +357,17 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); }
Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);

// auto smem_tiled_copy_Mask = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp);
// auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
// Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
// auto smem_tiled_copy_Bias = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma_sdp);
// auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
// Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);

// Partition sP and sdS to match the accumulator partitioning
// This has to be tiled_mma_sdp, not tiled_mma_dkv
// auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx);
Tensor tSsMask = smem_thr_copy_PdS.partition_S(sMask);
Tensor tSsBias = smem_thr_copy_PdS.partition_S(sBias);
auto smem_tiled_copy_Mask = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp);
auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
auto smem_tiled_copy_Bias = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma_sdp);
auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom, AtomNum), PIPE_M, PIPE_N)
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom, AtomNum), PIPE_M, PIPE_N)

@@ -573,24 +577,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// // if (cute::thread(1, 0)) { print(tKrK); }

if constexpr (Has_mask) {
// FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
// gmem_tiled_copy_Mask,
// tMaskgMask, tMasksMask,
// tMaskcMask, tMaskpMask,
// binfo.actual_seqlen_q - m_block * kBlockM
// );
// cute::cp_async_fence();
// FLASH_NAMESPACE::cp_async_wait<0>();
// // Do OR-reduce on the mask to see if any active threads

FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_Mask,
tMaskgMask, tMasksMask,
any_active,
tMaskcMask, tMaskpMask,
binfo.actual_seqlen_q - m_block * kBlockM
);
// We don't need a __syncthreads here because copy_mask already does an or_syncthreads.
__syncthreads();
// Do OR-reduce on the mask to see if any threads are active for the current iteration.
FLASH_NAMESPACE::mask_or_reduce(
tMasksMask,
any_active,
smem_thr_copy_Mask
);
}
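The rewritten path copies the mask tile first, then derives a block-wide "any active" flag so fully masked iterations can be skipped. A hedged sketch of that reduction in plain CUDA (mask_or_reduce itself operates on cute tensor fragments; this only shows the OR semantics):

    // Each thread ORs its share of the mask tile; __syncthreads_or then
    // combines the per-thread flags across the block. All threads must call it.
    __device__ bool block_any_active(const bool *tile_mask, int elems_per_thread) {
        int local_any = 0;
        for (int i = 0; i < elems_per_thread; ++i) {
            local_any |= tile_mask[threadIdx.x * elems_per_thread + i] ? 1 : 0;
        }
        return __syncthreads_or(local_any) != 0;  // true iff any thread saw an active element
    }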

FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
@@ -602,15 +601,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

if (any_active) {
if constexpr (Has_bias) {
FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_Bias,
tBiasgBias, tBiassBias,
tBiascBias, tBiaspBias,
binfo.actual_seqlen_q - m_block * kBlockM
);
// Because copy_bias currently uses scalar loads, we need to sync here.
// TODO: Remove sync after fixing to vectorized loads.
__syncthreads();
}
}
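The deleted TODO ties the old barrier to copy_bias's scalar loads, presumably because those element-wise accesses were not covered by the cp.async fences that order the kernel's vectorized copies; with copy_MN the extra __syncthreads goes away. A simplified contrast of the two access patterns (illustrative only, assuming 16-byte-aligned half tiles; not copy_MN's actual implementation):

    #include <cuda_fp16.h>

    // One element per iteration: an explicit barrier is needed before any
    // other thread consumes dst.
    __device__ void copy_row_scalar(const half *src, half *dst, int n) {
        for (int i = threadIdx.x; i < n; i += blockDim.x) dst[i] = src[i];
    }

    // 8 halves (16 bytes) per transaction; assumes n is a multiple of 8 and
    // both pointers are 16-byte aligned.
    __device__ void copy_row_vec128(const half *src, half *dst, int n) {
        const float4 *s = reinterpret_cast<const float4 *>(src);
        float4       *d = reinterpret_cast<float4 *>(dst);
        for (int i = threadIdx.x; i < n / 8; i += blockDim.x) d[i] = s[i];
    }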

Expand All @@ -627,9 +623,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
if (Is_first) {
cute::copy(tdOrdO, tdOsdO);
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
dot_do_o<Kernel_traits::kGmemThreadsPerRowQKVO>(
tdOrdO, tdOrO, gdPsum,
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow)
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO)
);
}
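dot_do_o computes the per-row correction term dP_sum[m] = sum_k dO[m,k] * O[m,k], which the backward pass later subtracts when converting dP into dS. A plain-CUDA reference for the quantity (a sketch, not the kernel's tiled implementation):

    // One thread per query row; the kernel instead splits each row across
    // kGmemThreadsPerRowQKVO threads and reduces within that group.
    __global__ void dot_do_o_reference(const float *dO, const float *O,
                                       float *dP_sum, int M, int K) {
        int m = blockIdx.x * blockDim.x + threadIdx.x;
        if (m >= M) return;
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) acc += dO[m * K + k] * O[m * K + k];
        dP_sum[m] = acc;  // one scalar per query row
    }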

@@ -852,15 +848,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
__syncthreads();
if constexpr (Has_bias) {
// Write dS to dBias
FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/false>(
gmem_tiled_copy_Bias,
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
gmem_tiled_copy_dBias,
tBiassBias, tdBiasgdBias,
tBiascBias, tBiaspBias,
binfo.actual_seqlen_q - m_block * kBlockM
);
// Because copy_bias currently uses scalar loads, we need to sync here.
// TODO: Remove sync after fixing to vectorized loads.
__syncthreads();
}

// if (cute::thread0()) { print(tPrP); }
@@ -886,24 +879,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
if constexpr (Has_mask) {
// Advance gMask
tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
// FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
// gmem_tiled_copy_Mask,
// tMaskgMask, tMasksMask,
// tMaskcMask, tMaskpMask,
// binfo.actual_seqlen_q - (m_block - 1) * kBlockM
// );
// FLASH_NAMESPACE::cp_async_fence();
// FLASH_NAMESPACE::cp_async_wait<0>();
// // Do OR-reduce on the mask to see if any active threads for next iteration

FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_Mask,
tMaskgMask, tMasksMask,
any_active_next,
tMaskcMask, tMaskpMask,
binfo.actual_seqlen_q - (m_block - 1) * kBlockM
);
// We don't need to syncthreads here because copy_mask is already or_syncthreads.
__syncthreads();
// Do OR-reduce on the mask to see if any threads are active for the next iteration
FLASH_NAMESPACE::mask_or_reduce(
tMasksMask,
any_active_next,
smem_thr_copy_Mask
);
}
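The negative offsets above implement the backward walk over m blocks, stepping each global tile pointer back one block-row per iteration. The same arithmetic in isolation (a hypothetical helper, not in the PR):

    // Equivalent to ptr.data() += -int(kBlockM * row_stride) in the kernel.
    template <typename T>
    __device__ inline T *retreat_one_block(T *ptr, int kBlockM, long long row_stride) {
        return ptr - (long long)kBlockM * row_stride;
    }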

// Advance gdO
@@ -1007,15 +995,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
if (any_active_next) {
FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_Bias,
tBiasgBias, tBiassBias,
tBiascBias, tBiaspBias,
binfo.actual_seqlen_q - (m_block - 1) * kBlockM
);
// Because copy_bias currently uses scalar loads, we need to sync here.
// TODO: Remove sync after fixing to vectorized loads.
__syncthreads();
}
}
}
@@ -1024,9 +1009,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

if (Is_first && m_block > m_block_min) {
cute::copy(tdOrdO, tdOsdO);
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
dot_do_o<Kernel_traits::kGmemThreadsPerRowQKVO>(
tdOrdO, tdOrO, gdPsum,
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow)
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO)
);
}

4 changes: 2 additions & 2 deletions csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h
@@ -148,9 +148,9 @@ inline __device__ void compute_dot_do_o(
tdOcdO, tdOpdO,
binfo.actual_seqlen_q - m_block * kBlockM
);
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
dot_do_o<Kernel_traits::kGmemThreadsPerRowQKVO>(
tdOrdO, tdOrO, dP_sum,
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow)
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO)
);
if (Clear_dQaccum) {
// We're actually not zero'ing out all of dQaccum, but only the part that we're going to