10 changes: 2 additions & 8 deletions csrc/flash_dmattn/flash_api.cpp
@@ -190,7 +190,6 @@ void set_params_dgrad(
const float softcap,
bool has_mask,
bool has_bias,
- bool accum_dbias,
bool deterministic,
const bool unpadded_lse
) {
@@ -246,8 +245,6 @@ void set_params_dgrad(
// Softmax sum
params.dsoftmax_sum = dsoftmax_sum_d;

- params.accum_dbias = accum_dbias;
-
params.deterministic = deterministic;
}

@@ -982,11 +979,10 @@ mha_bwd(
dbias_expanded = has_bias
? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1) // MQA / GQA or dbias has different batch size or seqlen_q
? (seqlen_q_bias == 1)
- ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
+ ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts)
: torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
Comment on lines +982 to 983

Copilot AI Oct 22, 2025

The dtype for dbias_expanded allocation has changed from at::kFloat to the default opts dtype. This may cause precision loss if opts is not Float32, since the accumulation in the kernel uses ElementAccum (fp32). Consider explicitly preserving opts.dtype(at::kFloat) to match the kernel's accumulation precision.

Suggested change
- ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts)
- : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
+ ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
+ : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts.dtype(at::kFloat))

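For intuition on the precision concern: a minimal standalone CUDA sketch (illustrative only, not part of this change) comparing a half-precision accumulator against an fp32 one, the same trade-off as ElementAccum versus the default opts dtype.

#include <cstdio>
#include <cuda_fp16.h>

// Accumulate 16384 * 0.25 (= 4096) in half and in float side by side.
__global__ void accum_demo(float *out) {
    __half hsum = __float2half(0.0f);  // half-precision running sum
    float fsum = 0.0f;                 // fp32 running sum, like ElementAccum
    for (int i = 0; i < 16384; ++i) {
        hsum = __hadd(hsum, __float2half(0.25f));  // addend is lost to rounding once the sum is large
        fsum += 0.25f;
    }
    out[0] = __half2float(hsum);  // stalls well below 4096
    out[1] = fsum;                // exactly 4096.0
}

int main() {
    float *d_out = nullptr, h_out[2];
    cudaMalloc(&d_out, 2 * sizeof(float));
    accum_demo<<<1, 1>>>(d_out);
    cudaMemcpy(h_out, d_out, 2 * sizeof(float), cudaMemcpyDeviceToHost);
    printf("half accum: %.1f  float accum: %.1f\n", h_out[0], h_out[1]);
    cudaFree(d_out);
    return 0;
}

Compiled with nvcc, the half sum should stall around 512 (once the sum's ulp reaches 0.5, adding 0.25 ties and rounds back down), while the float sum reaches 4096 exactly.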
: dbias
: torch::empty({0}, opts);
- bool accum_dbias = has_bias && (seqlen_q_bias == 1 && seqlen_q != 1);

Flash_bwd_params params;

@@ -1013,7 +1009,6 @@ mha_bwd(
softcap,
has_mask,
has_bias,
- accum_dbias,
deterministic,
/*unpadded_lse*/false
);
@@ -1041,7 +1036,7 @@
if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
} else {
- if (accum_dbias) {
+ if (seqlen_q_bias == 1) {
dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, 1, seqlen_k_rounded}), {2});
} else {
dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
@@ -1244,7 +1239,6 @@ mha_varlen_bwd(
softcap,
has_mask,
has_bias,
- /*accum_dbias*/false,
deterministic,
/*unpadded_lse*/true
);
2 changes: 0 additions & 2 deletions csrc/flash_dmattn/src/flash.h
@@ -195,8 +195,6 @@ struct Flash_bwd_params : public Flash_fwd_params {

bool deterministic;
index_t dq_accum_split_stride;
-
- bool accum_dbias;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
92 changes: 57 additions & 35 deletions csrc/flash_dmattn/src/flash_bwd_kernel.h
@@ -101,6 +101,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
+ bool accum_dbias = Has_bias && (params.dbias_row_stride == 0) && (binfo.actual_seqlen_q > 1);

const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
@@ -159,10 +160,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.dbias_row_stride, _1{})
);
- [[maybe_unused]] ElementAccum *gdBias_accum_ptr = nullptr;
- if constexpr (Has_bias) {
- gdBias_accum_ptr = reinterpret_cast<ElementAccum *>(params.dbias_ptr) + row_offset_dbias;
- }
Tensor gdO = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -287,8 +284,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
GmemTiledCopydO gmem_tiled_copy_dO;
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
- typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias;
- auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx);
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
using GmemLayoutAtomdQaccum = std::conditional_t<
!Seq_parallel,
@@ -297,6 +292,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
>;
GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum;
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
+ typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias;
+ auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx);

Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -346,6 +343,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (MMA, MMA_N, MMA_K)
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (MMA, MMA_N, MMA_K)
+ [[maybe_unused]] auto acc_dbias = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});
+ [[maybe_unused]] auto acc_dbias_rowcol = make_tensor(acc_dbias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dbias.layout()));

// Copy Atom retiling
auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
@@ -641,8 +640,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
}

- clear(acc_dv);
clear(acc_dk);
+ clear(acc_dv);
+ if constexpr (Has_bias) { if (accum_dbias) { clear(acc_dbias); } }

for (; m_block >= m_block_min; --m_block) {
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
@@ -806,6 +806,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
dS(mi, ni) = scaled_ds;
+ if constexpr (Has_bias) {
+ if (accum_dbias) {
+ acc_dbias_rowcol(mi, ni) += scaled_ds;
+ }
+ }
}
}
// if (cute::thread0()) { print(dS); }
@@ -852,36 +857,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
__syncthreads();
if constexpr (Has_bias) {
// Write dS to dBias
- if (!params.accum_dbias) {
+ if (!accum_dbias) {
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
gmem_tiled_copy_dBias,
tBiassBias, tdBiasgdBias,
tBiascBias, tBiaspBias,
binfo.actual_seqlen_q - m_block * kBlockM
);
- } else {
- #pragma unroll
- for (int m = 0; m < size<1>(tBiassBias); ++m) {
- if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
- #pragma unroll
- for (int n = 0; n < size<2>(tBiassBias); ++n) {
- if (Is_even_MN || tBiaspBias(n)) {
- #pragma unroll
- for (int i = 0; i < size<0>(tBiassBias); ++i) {
- const auto coord = tBiascBias(i, m, n);
- const int row = get<0>(coord);
- const int col = get<1>(coord);
- if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
- atomicAdd(
- gdBias_accum_ptr + row * params.dbias_row_stride + col,
- static_cast<ElementAccum>(tBiassBias(i, m, n))
- );
- }
- }
- }
- }
- }
- }
}
}

@@ -1023,9 +1005,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// Advance gBias and gdBias
tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
- if (params.accum_dbias) {
- gdBias_accum_ptr -= int(kBlockM * params.dbias_row_stride);
- }
if (any_active_next) {
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_Bias,
@@ -1069,10 +1048,53 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

// Epilogue

+ if constexpr (Has_bias) {
+ if (accum_dbias) {
+ const int actual_block_n = Is_even_MN ? kBlockN : std::max(0, std::min(kBlockN, binfo.actual_seqlen_k - n_block * kBlockN));
+
+ // Convert acc_dbias from fp32 to fp16
+ Tensor tdBiasrdBias = FLASH_NAMESPACE::convert_type<Element>(acc_dbias);
+
+ // Partition sBias to match the accumulator partitioning
+ Tensor tdBiasadBias = smem_thr_copy_Bias.retile_S(tdBiasrdBias); // ((Atom, AtomNum), MMA_M, MMA_N)
+
+ // We need syncthreads here since we're writing to the same location as sBias.
+ // Without syncthreads, some thread might modify the location of sBias while another thread
+ // is reading it for dQ gemm, leading to a race condition.
+ // If Is_last, there's already a __syncthreads() at the end of the loop.
+ if (!Is_last) { __syncthreads(); }
+
+ cute::copy(smem_tiled_copy_PdS, tdBiasadBias, tdSsdS);
+
+ __syncthreads();

Copilot AI Oct 22, 2025

The reduction loop should document that it's performing a column-wise sum across all M rows into row 0 of shared memory, since this is the critical step for correct dBias accumulation when dbias_row_stride == 0.

Suggested change
- __syncthreads();
+ __syncthreads();
+ // Perform a column-wise sum across all M rows into row 0 of shared memory.
+ // This is the critical step for correct dBias accumulation when dbias_row_stride == 0.

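As a standalone illustration of the reduction under discussion, here is a hedged sketch; the plain 2-D shared-memory array and the function name are stand-ins for the kernel's CuTe-tiled sdS, not the real implementation.

template <int kBlockM, int kBlockN, typename Element, typename ElementAccum>
__device__ void column_sum_into_row0(Element (&sdS)[kBlockM][kBlockN],
                                     int actual_block_n) {
    // Threads stride across columns; all kBlockM rows of each in-bounds column
    // are folded into row 0, the only row written back when accum_dbias holds.
    for (int col = threadIdx.x; col < kBlockN; col += blockDim.x) {
        if (col < actual_block_n) {
            ElementAccum rowsum = ElementAccum(0);
            #pragma unroll
            for (int row = 0; row < kBlockM; ++row) {
                rowsum += static_cast<ElementAccum>(sdS[row][col]);
            }
            sdS[0][col] = static_cast<Element>(rowsum);
        }
    }
    __syncthreads();  // all columns reduced before row 0 is copied out
}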
+ for (int col = threadIdx.x; col < kBlockN; col += blockDim.x) {
+ if (col < actual_block_n) {
+ ElementAccum rowsum = 0.f;
+ #pragma unroll
+ for (int row = 0; row < kBlockM; ++row) {
+ rowsum += static_cast<ElementAccum>(sdS(row, col));
+ }
+ sdS(0, col) = static_cast<Element>(rowsum);
+ }
+ }
+ __syncthreads();
+
+ #pragma unroll
+ for (int ni = 0; ni < size(tBiaspBias); ++ni) { tBiaspBias(ni) = ni < actual_block_n; }
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
+ FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/false, /*Clear_OOB_MN=*/false>(
+ gmem_tiled_copy_dBias,
+ tBiassBias, tdBiasgdBias,
+ tBiascBias, tBiaspBias,
+ /*max_M=*/1

Copilot AI Oct 22, 2025

The max_M=1 parameter is critical to the correctness of this implementation, as it ensures only the reduced row is written. This should have a comment explaining that we're writing only row 0, which contains the sum across all M rows.

+ );
+ }
+ }
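To make the max_M=1 behavior flagged above concrete, here is a simplified, hypothetical stand-in for the predicated copy; the real copy_MN is a CuTe tiled copy, and only the row/column predication is sketched.

template <int Rows, int Cols, typename Element>
__device__ void copy_rows_below_max_m(const Element (&src)[Rows][Cols],
                                      Element *dst, int dst_row_stride,
                                      const bool (&col_ok)[Cols], int max_m) {
    for (int m = 0; m < Rows; ++m) {
        for (int n = 0; n < Cols; ++n) {
            // With max_m == 1, only row 0 (the reduced column sums) is written.
            if (m < max_m && col_ok[n]) {
                dst[m * dst_row_stride + n] = src[m][n];
            }
        }
    }
}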

#pragma unroll
for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax; }

- // Convert acc_dv from fp32 to fp16
+ // Convert acc_dk, acc_dv from fp32 to fp16
Tensor rdK = FLASH_NAMESPACE::convert_type<Element>(acc_dk);
Tensor rdV = FLASH_NAMESPACE::convert_type<Element>(acc_dv);
