23 changes: 15 additions & 8 deletions csrc/flash_dmattn/flash_api.cpp
@@ -190,6 +190,7 @@ void set_params_dgrad(
const float softcap,
bool has_mask,
bool has_bias,
bool accum_dbias,
bool deterministic,
const bool unpadded_lse
) {
@@ -245,6 +246,8 @@ void set_params_dgrad(
// Softmax sum
params.dsoftmax_sum = dsoftmax_sum_d;

params.accum_dbias = accum_dbias;

params.deterministic = deterministic;
}

@@ -977,12 +980,13 @@ mha_bwd(
? torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts)
: dv;
dbias_expanded = has_bias
- ? (
-     (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q) // MQA / GQA or dbias has different batch size or seqlen_q
-     ? torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
-     : dbias
- )
? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1) // MQA / GQA or dbias has different batch size or seqlen_q
? (seqlen_q_bias == 1)
? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
: torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
: dbias
: torch::empty({0}, opts);
bool accum_dbias = has_bias && (seqlen_q_bias == 1 && seqlen_q != 1);

Flash_bwd_params params;

@@ -1009,6 +1013,7 @@ mha_bwd(
softcap,
has_mask,
has_bias,
accum_dbias,
deterministic,
/*unpadded_lse*/false
);
@@ -1036,9 +1041,10 @@ mha_bwd(
if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
} else {
- dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
- if (seqlen_q_bias == 1) {
-     dbias_expanded = at::sum(dbias_expanded, {2}, true);
if (accum_dbias) {
dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, 1, seqlen_k_rounded}), {2});
} else {
dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
}
if (batch_size_bias == 1) {
dbias_expanded = at::sum(dbias_expanded, {0}, true);
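For orientation, here is a hedged reference sketch (written for this note, not code from the PR) of what the accum_dbias path ends up computing for a bias that is broadcast over query positions: the bias gradient is just dS summed over the query dimension and the repeated-head dimension. The function name dbias_reference and the [B, H, Sq, Sk] layout of dS are assumptions, and the extra batch-broadcast reduction (batch_size_bias == 1) is omitted.

#include <torch/torch.h>

// Illustrative reference only: a bias of shape [B, Hb, 1, Sk] added to scores of
// shape [B, H, Sq, Sk] with H = Hb * group; its gradient folds the query dim
// and the head-repeat dim of dS.
torch::Tensor dbias_reference(const torch::Tensor &dS, int64_t num_heads_bias) {
    const auto B  = dS.size(0);
    const auto H  = dS.size(1);
    const auto Sq = dS.size(2);
    const auto Sk = dS.size(3);
    // Expose the head-repeat dimension, then fold it together with the query dim.
    auto grouped = dS.reshape({B, num_heads_bias, H / num_heads_bias, Sq, Sk});
    return grouped.sum(/*dim=*/{2, 3}).unsqueeze(2);   // -> [B, Hb, 1, Sk]
}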
@@ -1238,6 +1244,7 @@ mha_varlen_bwd(
softcap,
has_mask,
has_bias,
/*accum_dbias*/false,
deterministic,
/*unpadded_lse*/true
);
2 changes: 2 additions & 0 deletions csrc/flash_dmattn/src/flash.h
@@ -195,6 +195,8 @@ struct Flash_bwd_params : public Flash_fwd_params {

bool deterministic;
index_t dq_accum_split_stride;

bool accum_dbias;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
44 changes: 38 additions & 6 deletions csrc/flash_dmattn/src/flash_bwd_kernel.h
@@ -159,6 +159,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.dbias_row_stride, _1{})
);
[[maybe_unused]] ElementAccum *gdBias_accum_ptr = nullptr;
if constexpr (Has_bias) {
gdBias_accum_ptr = reinterpret_cast<ElementAccum *>(params.dbias_ptr) + row_offset_dbias;
}
Tensor gdO = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -848,12 +852,37 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
__syncthreads();
if constexpr (Has_bias) {
// Write dS to dBias
- FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
-     gmem_tiled_copy_dBias,
-     tBiassBias, tdBiasgdBias,
-     tBiascBias, tBiaspBias,
-     binfo.actual_seqlen_q - m_block * kBlockM
- );
if (!params.accum_dbias) {
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
gmem_tiled_copy_dBias,
tBiassBias, tdBiasgdBias,
tBiascBias, tBiaspBias,
binfo.actual_seqlen_q - m_block * kBlockM
);
} else {
#pragma unroll
for (int m = 0; m < size<1>(tBiassBias); ++m) {
if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
#pragma unroll
for (int n = 0; n < size<2>(tBiassBias); ++n) {
if (Is_even_MN || tBiaspBias(n)) {
#pragma unroll
for (int i = 0; i < size<0>(tBiassBias); ++i) {
const auto coord = tBiascBias(i, m, n);
const int row = get<0>(coord);
const int col = get<1>(coord);
if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
atomicAdd(
gdBias_accum_ptr + row * params.dbias_row_stride + col,
static_cast<ElementAccum>(tBiassBias(i, m, n))
);
}
}
}
Comment on lines +863 to +881

Copilot AI commented on Oct 16, 2025:

[nitpick] This path performs an atomicAdd per element, which can cause significant contention when seqlen_q is large (many M-tiles accumulate into the same broadcasted bias row). Consider reducing within the threadblock first (e.g., per-(row,col) partial sums in shared memory or warp-level reductions) and issuing a single atomicAdd per (row,col) per block. This typically cuts the number of atomics by a factor of size<0>(tBiassBias) and improves throughput.
}
}
}
}
}

// if (cute::thread0()) { print(tPrP); }
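Following up on the review comment above: below is a minimal, hypothetical sketch of the block-level fold it suggests, written for this note and not part of the PR or its CuTe tiling. It assumes the block's dS tile is already staged in shared memory (as the tBiassBias copy above suggests) and that the bias row being accumulated is broadcast over query positions, so a block can fold its kBlockM rows locally and emit one atomicAdd per column instead of one per element; kBlockM, kBlockN, and fold_dbias_column_wise are illustrative names only.

// Hypothetical sketch (not the PR's kernel code): fold a block's dS tile over
// its rows in shared memory, then issue one fp32 atomicAdd per column.
template <int kBlockM, int kBlockN>
__device__ void fold_dbias_column_wise(
    float *gdbias_accum_row,   // global fp32 accumulator for the broadcast bias row (this block's kBlockN window)
    const float *s_ds_tile,    // shared-memory dS tile, kBlockM x kBlockN, row-major
    int rows_valid,            // rows of this tile inside actual_seqlen_q
    int cols_valid             // columns of this tile inside actual_seqlen_k
) {
    // Each thread owns a strided subset of the tile's columns.
    for (int col = threadIdx.x; col < kBlockN; col += blockDim.x) {
        if (col >= cols_valid) { continue; }
        float acc = 0.f;
        for (int row = 0; row < rows_valid; ++row) {
            acc += s_ds_tile[row * kBlockN + col];   // fold the M dimension locally
        }
        // One atomic per column per block, instead of one per element.
        atomicAdd(gdbias_accum_row + col, acc);
    }
}

Whether this wins depends on how many M-tiles contend for the same broadcast bias row; the merged diff above keeps the simpler per-element atomicAdd.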
Expand Down Expand Up @@ -994,6 +1023,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// Advance gBias and gdBias
tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
if (params.accum_dbias) {
gdBias_accum_ptr -= int(kBlockM * params.dbias_row_stride);
}
if (any_active_next) {
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_Bias,