113 changes: 70 additions & 43 deletions csrc/flash_dmattn/src/flash_bwd_kernel.h
@@ -272,9 +272,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// Global to Shared Memory operation
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias;
auto gmem_thr_copy_Mask = gmem_tiled_copy_MaskBias.get_thread_slice(tidx);
auto gmem_thr_copy_Bias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask;
typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias;
auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx);
auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx);
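// The former single GmemTiledCopyMaskBias is split into separate GmemTiledCopyMask and
// GmemTiledCopyBias here, presumably so the bool mask tile and the Element bias tile can
// use independently chosen copy atoms; the actual atom definitions live in Kernel_traits
// and are not shown in this hunk.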
using GmemTiledCopydO = std::conditional_t<Is_first, typename Kernel_traits::GmemTiledCopydO, typename Kernel_traits::GmemTiledCopyQKV>;
GmemTiledCopydO gmem_tiled_copy_dO;
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
@@ -417,9 +418,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
}

// Allocate predicate tensors for N
Tensor tMaskpMask = make_tensor<bool>(make_shape(size<2>(tMasksMask)));
Tensor tBiaspBias = make_tensor<bool>(make_shape(size<2>(tBiassBias)));

// Set predicates for n bounds
if (!Is_even_MN) {
#pragma unroll
for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); }
#pragma unroll
for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); }
}
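// Sketch of how these N-bound predicates are assumed to be consumed by the copies below
// (scalar form for clarity; the real path goes through the cute tiled copies, and
// kNPerThread / col_pred / smem_tile / gmem_tile are illustrative names only):
//
//   #pragma unroll
//   for (int n = 0; n < kNPerThread; ++n) {
//       if (Is_even_MN || col_pred[n]) {
//           smem_tile[n] = gmem_tile[n];   // in-bounds column: copy through
//       } else {
//           smem_tile[n] = Element(0);     // OOB column: zero-fill (Clear_OOB_MN == true for these loads)
//       }
//   }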


// Prologue

bool any_active = true; // to be updated later for current iteration
bool any_active_next = true; // to be updated later for next iteration

// We'll advance gdQ, gdQaccum and gdBias before the 1st read/write.
tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;
tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
@@ -554,24 +570,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
// // if (cute::thread(1, 0)) { print(tKrK); }

FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
gmem_tiled_copy_MaskBias,
tMaskgMask, tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
);
// FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
// gmem_tiled_copy_Mask,
// tMaskgMask, tMasksMask,
// tMaskcMask, tMaskpMask,
// binfo.actual_seqlen_q - m_block * kBlockM
// );
// cute::cp_async_fence();
// FLASH_NAMESPACE::cp_async_wait<0>();
__syncthreads();
// // Do OR-reduce on the mask to see if any active threads

// Do OR-reduce on the mask to see if any active threads
Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask);
bool any_active_local = false;
bool any_active_local_next = false; // to be updated later for next iteration
#pragma unroll
for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
bool any_active = __syncthreads_or(any_active_local);
bool any_active_next = false; // to be updated later for next iteration
FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
gmem_tiled_copy_Mask,
tMaskgMask, tMasksMask,
any_active,
tMaskcMask, tMaskpMask,
binfo.actual_seqlen_q - m_block * kBlockM
);
// No separate __syncthreads() is needed here: copy_mask_with_or_reduce already performs the __syncthreads_or().
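// A minimal sketch of the fusion copy_mask_with_or_reduce is assumed to perform
// (illustrative only; kElemsPerThread, row_ok, col_pred, gmem_mask and smem_mask
// are made-up names, not identifiers from this file):
//
//   bool local_any = false;
//   #pragma unroll
//   for (int i = 0; i < kElemsPerThread; ++i) {
//       const bool keep = Is_even_MN ? bool(gmem_mask[i])
//                                    : (row_ok[i] && col_pred[i] && bool(gmem_mask[i]));
//       smem_mask[i] = keep ? Element(1) : Element(0);   // bool -> Element conversion
//       local_any |= keep;
//   }
//   any_active = __syncthreads_or(local_any);   // block-wide OR doubles as the barrier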

FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV,
Expand All @@ -581,12 +597,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
);

if (any_active) {
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_MaskBias,
FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_Bias,
tBiasgBias, tBiassBias,
tBiascBias,
binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
tBiascBias, tBiaspBias,
binfo.actual_seqlen_q - m_block * kBlockM
);
// Because copy_bias currently uses scalar loads, we need to sync here.
// TODO: Remove this sync after switching to vectorized loads.
__syncthreads();
}
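// Why the barrier above is required: plain scalar ld/st into shared memory has no
// cp.async fence/wait pair to order it, so a full __syncthreads() is the only
// guarantee that every thread's slice of sBias is visible before the tile is
// consumed by the compute that follows.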

if (!Kernel_traits::Is_V_in_regs) {
@@ -780,13 +799,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
__syncthreads();
// Write dS to dBias
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
gmem_tiled_copy_MaskBias,
FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/false>(
gmem_tiled_copy_Bias,
tBiassBias, tdBiasgdBias,
tBiascBias,
binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - n_block * kBlockN
tBiascBias, tBiaspBias,
binfo.actual_seqlen_q - m_block * kBlockM
);
// Because copy_bias currently uses scalar loads, we need to sync here.
// TODO: Remove this sync after switching to vectorized loads.
__syncthreads();
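// Unlike the loads above, this copy_bias call runs smem -> gmem (tBiassBias into
// tdBiasgdBias) and passes Clear_OOB_MN = false: out-of-bounds destinations in global
// dBias are simply skipped, since there is nothing sensible to zero-fill on a store.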

// if (cute::thread0()) { print(tPrP); }
// Layout p_l = tPrP.layout();
@@ -810,21 +831,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
if (m_block > m_block_min) {
// Advance gMask
tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
gmem_tiled_copy_MaskBias,
tMaskgMask, tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
);
// FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
// gmem_tiled_copy_Mask,
// tMaskgMask, tMasksMask,
// tMaskcMask, tMaskpMask,
// binfo.actual_seqlen_q - (m_block - 1) * kBlockM
// );
// FLASH_NAMESPACE::cp_async_fence();
// FLASH_NAMESPACE::cp_async_wait<0>();
__syncthreads();
// // Do OR-reduce on the mask to see if any active threads for next iteration

// Do OR-reduce on the mask to see if any active threads for next iteration
any_active_local_next = false;
#pragma unroll
for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); }
any_active_next = __syncthreads_or(any_active_local_next);
FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
gmem_tiled_copy_Mask,
tMaskgMask, tMasksMask,
any_active_next,
tMaskcMask, tMaskpMask,
binfo.actual_seqlen_q - (m_block - 1) * kBlockM
);
// No separate __syncthreads() is needed here either: copy_mask_with_or_reduce already performs the __syncthreads_or().
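// any_active_next caches the block-wide OR of the prefetched (m_block - 1) mask tile;
// it gates the matching bias prefetch further below and is presumably what the next
// iteration picks up as its any_active.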

// Advance gdO
tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
@@ -926,12 +950,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
if (any_active_next) {
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_MaskBias,
FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_Bias,
tBiasgBias, tBiassBias,
tBiascBias,
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
tBiascBias, tBiaspBias,
binfo.actual_seqlen_q - (m_block - 1) * kBlockM
);
// Because copy_bias currently uses scalar loads, we need to sync here.
// TODO: Remove this sync after switching to vectorized loads.
__syncthreads();
}
}
