87 changes: 31 additions & 56 deletions csrc/flash_dmattn/src/flash_bwd_kernel.h
@@ -644,17 +644,37 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
clear(acc_dv);
if constexpr (Has_bias) { if (accum_dbias) { clear(acc_dbias); } }

cute::cp_async_wait<0>();
__syncthreads();

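// Folding the softmax scale into K up front means the S = Q·Kᵀ MMA below
// produces already-scaled scores, which is why the per-element
// `acc_dq(i) *= params.scale_softmax` passes later in this kernel and in
// convert_dQ can be dropped.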
// Scale K once before the streaming loop over Q blocks
#pragma unroll
for (int k = 0; k < size(tKsK); ++k) {
tKsK(k) = static_cast<Element>(tKsK(k) * params.scale_softmax);
}

for (; m_block >= m_block_min; --m_block) {
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
cute::cp_async_wait<0>();
__syncthreads();

Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (MMA=4, MMA_M, MMA_K)
cute::cp_async_wait<0>();
__syncthreads();

if (any_active) {
clear(acc_s);
if constexpr (Has_bias) {
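// Seeding acc_s with the bias (instead of zero) lets the S = Q·Kᵀ GEMM
// accumulate directly on top of it, replacing the separate bias add that
// apply_mask used to perform.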
// Copy bias from smem to registers
Tensor tSrBias = make_tensor<Element>(shape(acc_s));
Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);
#pragma unroll
for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); }
} else {
clear(acc_s);
}
}


Comment on lines +675 to 676
Copilot AI (Nov 4, 2025): Empty line with trailing whitespace. Remove the trailing whitespace for consistency.

if (any_active) {
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }
@@ -686,71 +706,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap);
}

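// With bias pre-accumulated into acc_s and the softmax scale folded into K,
// the four Has_mask/Has_bias branches below collapse into a single Has_mask
// check, and apply_mask no longer takes bias or scale arguments.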
if constexpr (Has_mask && Has_bias) {
// Copy mask and bias from smem to registers
Tensor tSrMask = make_tensor<Element>(shape(acc_s));
Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask);
cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view);
Tensor tSrBias = make_tensor<Element>(shape(acc_s));
Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);

// Reshape mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout()));
Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout()));

// TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
// actual_seqlen_k, because acc_s would be some finite value for those indices.
// In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
// so the result would still be correct.
// However, it's possible that the values in acc_s are so large that they overflow
// when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
// So we need to mask out the elements beyond actual_seqlen_k.
FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
scores, mask, bias, params.scale_softmax,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q,
AtomLayoutMS * 16
);
} else if constexpr (Has_mask && !Has_bias) {
if constexpr (Has_mask) {
// Copy mask from smem to registers
Tensor tSrMask = make_tensor<Element>(shape(acc_s));
Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask);
cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view);

// Reshape mask from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout()));

FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
scores, mask, /*bias=*/nullptr, params.scale_softmax,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q,
AtomLayoutMS * 16
);
} else if constexpr (!Has_mask && Has_bias) {
// Copy bias from smem to registers
Tensor tSrBias = make_tensor<Element>(shape(acc_s));
Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);

// Reshape bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout()));

FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
scores, /*mask=*/nullptr, bias, params.scale_softmax,
FLASH_NAMESPACE::apply_mask<Is_causal, Has_mask>(
scores, mask,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q,
AtomLayoutMS * 16
);
} else {
FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
scores, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax,
FLASH_NAMESPACE::apply_mask<Is_causal, Has_mask>(
scores, /*mask=*/nullptr,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
@@ -965,8 +940,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); }
}
} else {
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
// #pragma unroll
// for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
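// Scaling here is no longer needed: K was pre-scaled by params.scale_softmax
// before the m_block loop, so acc_dq = dS·K already carries the scale.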
// Convert acc_dq from fp32 to fp16
Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom, AtomNum), MMA_M, MMA_K)
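
For reference, a minimal host-side sketch (illustrative, not part of this PR) of the identity the K pre-scaling relies on: folding `scale_softmax` into K before the S = Q·Kᵀ GEMM yields the same scores, and hence the same dQ = dS·K, as applying the scale afterwards, up to fp32 rounding. All names and sizes below are made up for the demo.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int M = 4, N = 4, D = 8;                  // tiny Q (MxD) and K (NxD) tiles
    const float scale = 1.0f / std::sqrt(float(D)); // stand-in for params.scale_softmax
    std::vector<float> Q(M * D), K(N * D);
    for (int i = 0; i < M * D; ++i) Q[i] = 0.01f * float(i % 7) - 0.02f;
    for (int i = 0; i < N * D; ++i) K[i] = 0.03f * float(i % 5) - 0.05f;

    float max_diff = 0.0f;
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float s = 0.0f, s_folded = 0.0f;
            for (int d = 0; d < D; ++d) {
                s        += Q[m * D + d] * K[n * D + d];           // scale applied after the dot product
                s_folded += Q[m * D + d] * (scale * K[n * D + d]); // scale folded into K up front
            }
            max_diff = std::max(max_diff, std::fabs(scale * s - s_folded));
        }
    }
    std::printf("max |scale*S - S_folded| = %g\n", max_diff); // ~0 up to fp32 rounding
    return 0;
}
```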
4 changes: 2 additions & 2 deletions csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h
@@ -279,8 +279,8 @@ inline __device__ void convert_dQ(
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
}
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
// #pragma unroll
// for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
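// As in flash_bwd_kernel.h above, the scale is already folded into K, so the
// accumulated dQ does not need a second multiply by scale_softmax here.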
// Convert acc_dq from fp32 to fp16
Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom, AtomNum), MMA_N, MMA_N)