From 59b775a284f5cf58c8c55c654315f74aa933ef62 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 12 Oct 2025 18:48:36 +0800 Subject: [PATCH 01/12] Supports broadcasting for mask/bias strides Sets zero stride for size-1 dims to broadcast across batch/head/row, instead of using the underlying stride. Prevents incorrect indexing when mask/bias are shared across dimensions and aligns with standard broadcasting semantics. Improves forward-pass robustness with partially broadcast inputs. --- csrc/flash_dmattn/flash_api.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 0cc7cdf..94ad7c9 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -76,10 +76,10 @@ void set_params_fprop( params.k_head_stride = k.stride(-2); params.v_row_stride = v.stride(-3); params.v_head_stride = v.stride(-2); - params.mask_head_stride = has_mask ? mask.stride(-3) : 0; - params.mask_row_stride = has_mask ? mask.stride(-2) : 0; - params.bias_head_stride = has_bias ? bias.stride(-3) : 0; - params.bias_row_stride = has_bias ? bias.stride(-2) : 0; + params.mask_head_stride = has_mask ? (mask.size(-3) == 1 ? 0 : mask.stride(-3)) : 0; + params.mask_row_stride = has_mask ? (mask.size(-2) == 1 ? 0 : mask.stride(-2)) : 0; + params.bias_head_stride = has_bias ? (bias.size(-3) == 1 ? 0 : bias.stride(-3)) : 0; + params.bias_row_stride = has_bias ? (bias.size(-2) == 1 ? 0 : bias.stride(-2)) : 0; params.o_row_stride = out.stride(-3); params.o_head_stride = out.stride(-2); @@ -87,8 +87,8 @@ void set_params_fprop( params.q_batch_stride = q.stride(0); params.k_batch_stride = k.stride(0); params.v_batch_stride = v.stride(0); - params.mask_batch_stride = has_mask ? mask.stride(0) : 0; - params.bias_batch_stride = has_bias ? bias.stride(0) : 0; + params.mask_batch_stride = has_mask ? (mask.size(0) == 1 ? 0 : mask.stride(0)) : 0; + params.bias_batch_stride = has_bias ? (bias.size(0) == 1 ? 0 : bias.stride(0)) : 0; params.o_batch_stride = out.stride(0); if (seqlenq_ngroups_swapped) { params.q_batch_stride *= seqlen_q; From 54fde5876da0590d4d74eb26e553850be42693a4 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 12 Oct 2025 18:49:04 +0800 Subject: [PATCH 02/12] Supports broadcasted bias strides in backward Sets zero stride for singleton batch/head/row dims to align with broadcasting semantics, preventing misaddressing when bias is shared across dimensions. Improves correctness and flexibility of the backward path with broadcasted bias. --- csrc/flash_dmattn/flash_api.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 94ad7c9..2326596 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -227,15 +227,15 @@ void set_params_dgrad( params.dk_head_stride = dk.stride(-2); params.dv_row_stride = dv.stride(-3); params.dv_head_stride = dv.stride(-2); - params.dbias_head_stride = has_bias ? dbias.stride(-3) : 0; - params.dbias_row_stride = has_bias ? dbias.stride(-2) : 0; + params.dbias_head_stride = has_bias ? (dbias.size(-3) == 1 ? 0 : dbias.stride(-3)) : 0; + params.dbias_row_stride = has_bias ? (dbias.size(-2) == 1 ? 0 : dbias.stride(-2)) : 0; if (cu_seqlens_q_d == nullptr) { params.do_batch_stride = dout.stride(0); params.dq_batch_stride = dq.stride(0); params.dk_batch_stride = dk.stride(0); params.dv_batch_stride = dv.stride(0); - params.dbias_batch_stride = has_bias ? dbias.stride(0) : 0; + params.dbias_batch_stride = has_bias ? (dbias.size(0) == 1 ? 0 : dbias.stride(0)) : 0; } params.dq_accum_ptr = dq_accum_d; From 59c7f3975784c0bcae1cc14943b78f601d9a6253 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 12 Oct 2025 18:51:09 +0800 Subject: [PATCH 03/12] Enforces 4D mask/bias and robust broadcasting Clarifies and validates mask/bias to a strict 4D contract with broadcastable batch, head, and query dims, and a key length rounded to 128. Removes implicit 3D unsqueeze/expand to prevent silent shape mismatches and stride issues. Reworks shape handling in the swapped-heads path to use reshape-based broadcasting and preserves original batch/head counts for correctness. Computes rounding earlier and applies consistent checks across inputs. Improves correctness and stability by aligning inputs with kernel expectations and reducing accidental expansions. --- csrc/flash_dmattn/flash_api.cpp | 73 ++++++++++++++++----------------- 1 file changed, 36 insertions(+), 37 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 2326596..576b8bb 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -353,8 +353,8 @@ mha_fwd( at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k - std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + std::optional &mask_, // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x round_multiple(seqlen_k, 128) + std::optional &bias_, // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x round_multiple(seqlen_k, 128) std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const float softmax_scale, bool is_causal, @@ -387,11 +387,8 @@ mha_fwd( mask = mask_.value(); TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); CHECK_DEVICE(mask); + TORCH_CHECK(mask.dim() == 4, "mask must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k_rounded)"); TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - if (mask.dim() == 3) { - // Add a dummy dimension for seqlen_q - mask = mask.unsqueeze(2).expand({-1, -1, q.size(1), -1}); - } } else { mask = torch::empty({0}, opts); } @@ -401,11 +398,8 @@ mha_fwd( bias = bias_.value(); TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); CHECK_DEVICE(bias); + TORCH_CHECK(bias.dim() == 4, "bias must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k_rounded)"); TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - if (bias.dim() == 3) { - // Add a dummy dimension for seqlen_q - bias = bias.unsqueeze(2).expand({-1, -1, q.size(1), -1}); - } } else { bias = torch::empty({0}, opts); } @@ -420,16 +414,27 @@ mha_fwd( const int num_heads_k = k.size(2); int num_heads_mask = has_mask ? mask.size(1) : 1; int num_heads_bias = has_bias ? bias.size(1) : 1; + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (has_mask) { + TORCH_CHECK(mask.size(0) == batch_size || mask.size(0) == 1, "Batch dimension in mask must be 1 or equal to batch size"); TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); + TORCH_CHECK(mask.size(2) == 1 || mask.size(2) == seqlen_q, "Query length dimension in mask must be 1 or equal to seqlen_q"); + TORCH_CHECK(mask.size(3) == seqlen_k_rounded, "Key length dimension in mask must be seqlen_k_rounded"); } if (has_bias) { + TORCH_CHECK(bias.size(0) == batch_size || bias.size(0) == 1, "Batch dimension in bias must be 1 or equal to batch size"); TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); + TORCH_CHECK(bias.size(2) == 1 || bias.size(2) == seqlen_q, "Query length dimension in bias must be 1 or equal to seqlen_q"); + TORCH_CHECK(bias.size(3) == seqlen_k_rounded, "Key length dimension in bias must be seqlen_k_rounded"); } // causal=true is the same as causal=false in this case @@ -439,27 +444,29 @@ mha_fwd( // H/t Daniel Haziza const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; const int ngroups = num_heads / num_heads_k; - const int orig_num_heads_mask = num_heads_mask; - const int orig_num_heads_bias = num_heads_bias; + const int batch_size_mask_og = has_mask ? mask.size(0) : batch_size; + const int batch_size_bias_og = has_bias ? bias.size(0) : batch_size; + const int num_heads_mask_og = num_heads_mask; + const int num_heads_bias_og = num_heads_bias; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); if (has_mask) { - mask = num_heads_mask == 1 - ? mask.expand({batch_size, 1, ngroups, seqlen_k}) - : ( - num_heads_mask == num_heads_k - ? mask.expand({batch_size, num_heads_k, ngroups, seqlen_k}) - : mask.reshape({batch_size, num_heads_k, ngroups, seqlen_k}) - ); + if (num_heads_mask == 1) { + mask = mask.reshape({batch_size_mask_og, 1, 1, seqlen_k_rounded}); + } else if (num_heads_mask == num_heads_k) { + mask = mask.reshape({batch_size_mask_og, num_heads_k, 1, seqlen_k_rounded}); + } else if (num_heads_mask == num_heads) { + mask = mask.reshape({batch_size_mask_og, num_heads_k, ngroups, seqlen_k_rounded}); + } } if (has_bias) { - bias = num_heads_bias == 1 - ? bias.expand({batch_size, 1, ngroups, seqlen_k}) - : ( - num_heads_bias == num_heads_k - ? bias.expand({batch_size, num_heads_k, ngroups, seqlen_k}) - : bias.reshape({batch_size, num_heads_k, ngroups, seqlen_k}) - ); + if (num_heads_bias == 1) { + bias = bias.reshape({batch_size_bias_og, 1, 1, seqlen_k_rounded}); + } else if (num_heads_bias == num_heads_k) { + bias = bias.reshape({batch_size_bias_og, num_heads_k, 1, seqlen_k_rounded}); + } else if (num_heads_bias == num_heads) { + bias = bias.reshape({batch_size_bias_og, num_heads_k, ngroups, seqlen_k_rounded}); + } } num_heads_mask = has_mask ? ((num_heads_mask == num_heads) ? num_heads_k : num_heads_mask) : 1; num_heads_bias = has_bias ? ((num_heads_bias == num_heads) ? num_heads_k : num_heads_bias) : 1; @@ -485,11 +492,6 @@ mha_fwd( out = torch::empty_like(q); } - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor p; @@ -541,16 +543,13 @@ mha_fwd( q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); if (has_mask) { - mask = (orig_num_heads_mask == 1 || orig_num_heads_mask == num_heads_k) - ? mask.narrow(2, 0, 1) - : mask.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); + mask = mask.reshape({batch_size_mask_og, num_heads_mask_og, 1, seqlen_k_rounded}); } if (has_bias) { - bias = (orig_num_heads_bias == 1 || orig_num_heads_bias == num_heads_k) - ? bias.narrow(2, 0, 1) - : bias.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); + bias = bias.reshape({batch_size_bias_og, num_heads_bias_og, 1, seqlen_k_rounded}); } } + return {out, softmax_lse, p}; } From 2a8f9ea1a6e25923251cceeb939121cb08e92019 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 12 Oct 2025 18:53:42 +0800 Subject: [PATCH 04/12] Enforces 4D mask/bias; fixes dbias broadcasting Standardizes attention aux inputs to strict 4D shapes with contiguous last dim and explicit broadcasting over batch, heads, and seqlen_q. Removes 3D mask/bias handling and validates dimensions against rounded key length. Allocates/validates dbias with broadcast-aware shapes and updates reductions to correctly sum over group, batch, and seqlen_q when broadcast, improving correctness for MQA/GQA and padded key lengths. Improves shape checks and internal consistency to prevent silent misalignment and shape-induced bugs. --- csrc/flash_dmattn/flash_api.cpp | 97 ++++++++++++++------------------- 1 file changed, 40 insertions(+), 57 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 576b8bb..e6a8b7a 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -795,14 +795,14 @@ mha_bwd( const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k - const std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + const std::optional &mask_, // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1} + const std::optional &bias_, // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1} const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &softmax_lse, // b x h x seqlen_q std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &dbias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + std::optional &dbias_, // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1} const float softmax_scale, const bool is_causal, const float softcap, @@ -845,11 +845,8 @@ mha_bwd( mask = mask_.value(); TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); CHECK_DEVICE(mask); + TORCH_CHECK(mask.dim() == 4, "mask must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)"); TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - if (mask.dim() == 3) { - // Add a dummy dimension for seqlen_q - mask = mask.unsqueeze(2).expand({-1, -1, q.size(1), -1}); - } } else { mask = torch::empty({0}, opts); } @@ -859,11 +856,8 @@ mha_bwd( bias = bias_.value(); TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); CHECK_DEVICE(bias); + TORCH_CHECK(bias.dim() == 4, "bias must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)"); TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - if (bias.dim() == 3) { - // Add a dummy dimension for seqlen_q - bias = bias.unsqueeze(2).expand({-1, -1, q.size(1), -1}); - } } else { bias = torch::empty({0}, opts); } @@ -878,29 +872,39 @@ mha_bwd( const int num_heads_k = k.size(2); int num_heads_mask = has_mask ? mask.size(1) : 1; int num_heads_bias = has_bias ? bias.size(1) : 1; + int batch_size_mask = has_mask ? mask.size(0) : batch_size; + int batch_size_bias = has_bias ? bias.size(0) : batch_size; + int seqlen_q_mask = has_mask ? mask.size(2) : seqlen_q; + int seqlen_q_bias = has_bias ? bias.size(2) : seqlen_q; + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (has_mask) { + TORCH_CHECK(mask.size(0) == batch_size || mask.size(0) == 1, "Batch dimension in mask must be 1 or equal to batch size"); TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); + TORCH_CHECK(mask.size(2) == 1 || mask.size(2) == seqlen_q, "Query length dimension in mask must be 1 or equal to seqlen_q"); + TORCH_CHECK(mask.size(3) == seqlen_k_rounded, "Key length dimension in mask must be seqlen_k_rounded"); } if (has_bias) { + TORCH_CHECK(bias.size(0) == batch_size || bias.size(0) == 1, "Batch dimension in bias must be 1 or equal to batch size"); TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); + TORCH_CHECK(bias.size(2) == 1 || bias.size(2) == seqlen_q, "Query length dimension in bias must be 1 or equal to seqlen_q"); + TORCH_CHECK(bias.size(3) == seqlen_k_rounded, "Key length dimension in bias must be seqlen_k_rounded"); } - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); - + at::Tensor dq, dk, dv, dbias; if (dq_.has_value()) { dq = dq_.value(); @@ -934,30 +938,14 @@ mha_bwd( dbias = dbias_.value(); TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); CHECK_DEVICE(dbias); + TORCH_CHECK(dbias.dim() == 4, "dbias must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)"); TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); - if (dbias.dim() == 4) { - CHECK_SHAPE(dbias, batch_size, num_heads_bias, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(dbias, batch_size, num_heads_bias, seqlen_k); - } + TORCH_CHECK(dbias.size(0) == batch_size || dbias.size(0) == 1, "Batch dimension in dbias must be 1 or equal to batch size"); + TORCH_CHECK(dbias.size(1) == num_heads || dbias.size(1) == num_heads_k || dbias.size(1) == 1, "Number of heads in dbias must be 1, h_k or h"); + TORCH_CHECK(dbias.size(2) == seqlen_q || dbias.size(2) == 1, "Query length dimension in dbias must be 1 or equal to seqlen_q"); + TORCH_CHECK(dbias.size(3) == seqlen_k_rounded, "Key length dimension in dbias must be seqlen_k_rounded"); } else { - if (bias.dim() == 4) { - if (num_heads_bias == 1) { - dbias = torch::empty({batch_size, 1, seqlen_q, seqlen_k}, opts); - } else if (num_heads_bias == num_heads_k) { - dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts); - } else { - dbias = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); - } - } else { - if (num_heads_bias == 1) { - dbias = torch::empty({batch_size, 1, seqlen_k}, opts); - } else if (num_heads_bias == num_heads_k) { - dbias = torch::empty({batch_size, num_heads_k, seqlen_k}, opts); - } else { - dbias = torch::empty({batch_size, num_heads, seqlen_k}, opts); - } - } + dbias = torch::empty({batch_size_bias, num_heads_bias, seqlen_q_bias, seqlen_k_rounded}, opts); } } else { dbias = torch::empty({0}, opts); @@ -990,8 +978,8 @@ mha_bwd( : dv; dbias_expanded = has_bias ? ( - (num_heads_bias != num_heads) || (bias_.has_value() && bias_.value().dim() == 3) // MQA / GQA or bias has no seqlen_q dimension - ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts) + (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q) // MQA / GQA or dbias has different batch size or seqlen_q + ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts) : dbias ) : torch::empty({0}, opts); @@ -1046,24 +1034,19 @@ mha_bwd( at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); } - // For MQA/GQA or num_heads_bias != num_heads, we also need to sum dbias across the heads + // For MQA/GQA or dbias has different batch size or seqlen_q, we need to sum dbias across the groups, batch and seqlen_q if (has_bias) { - bool sum_seqlen_q = bias_.has_value() && bias_.value().dim() == 3; - if (num_heads_bias != num_heads) { - if (sum_seqlen_q) { - dbias_expanded = at::sum( - at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2} - ); - } else { - at::sum_out( - dbias, - at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2} - ); + if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) { + at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2}); + } else { + dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2}); + if (seqlen_q_bias == 1) { + dbias_expanded = at::sum(dbias_expanded, {2}, true); } - } - if (sum_seqlen_q) { - // We need to sum across the seqlen_q dimension - at::sum_out(dbias, dbias_expanded, {2}); + if (batch_size_bias == 1) { + dbias_expanded = at::sum(dbias_expanded, {0}, true); + } + dbias.copy_(dbias_expanded); } } From 755e0a473cbc8e56f9547e2be171789872612dcb Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 12 Oct 2025 18:55:56 +0800 Subject: [PATCH 05/12] Aligns mask/bias to 128 and supports broadcast Updates attention mask/bias handling to round the key length to a multiple of 128 and expand length-1 tensors or pad as needed, preventing shape mismatches and reducing unnecessary padding of K/V. Simplifies backward by slicing only the bias gradient to the original key length and removing tracking of the original sequence length. Clarifies docs to allow broadcastable dimensions for mask/bias across batch, heads, and sequence. --- flash_dmattn/flash_dmattn_interface.py | 31 +++++++++++++------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 893c638..10ab385 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -416,14 +416,17 @@ def forward( q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - seqlen_k_og = k.shape[1] - if seqlen_k_og % 8 != 0: - k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) - v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) - if mask is not None: - mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False) - if bias is not None: - bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0) + seqlen_k_rounded = round_multiple(k.shape[1], 128) + if mask is not None and mask.shape[-1] != seqlen_k_rounded: + if mask.shape[-1] == 1: + mask = mask.expand(*mask.shape[:-1], seqlen_k_rounded) + else: + mask = torch.nn.functional.pad(mask, [0, seqlen_k_rounded - mask.shape[-1]]) + if bias is not None and bias.shape[-1] != seqlen_k_rounded: + if bias.shape[-1] == 1: + bias = bias.expand(*bias.shape[:-1], seqlen_k_rounded) + else: + bias = torch.nn.functional.pad(bias, [0, seqlen_k_rounded - bias.shape[-1]]) out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( q, @@ -443,7 +446,6 @@ def forward( ctx.is_causal = is_causal ctx.softcap = softcap ctx.deterministic = deterministic - ctx.seqlen_k_og = seqlen_k_og out = out_padded[..., :head_size_og] @@ -488,11 +490,8 @@ def backward( dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - if ctx.seqlen_k_og % 8 != 0: - dk = dk[:, : ctx.seqlen_k_og, :, :] - dv = dv[:, : ctx.seqlen_k_og, :, :] - if dbias is not None: - dbias = dbias[..., : ctx.seqlen_k_og] + if dbias is not None: + dbias = dbias[..., : k.shape[1]] return dq, dk, dv, None, dbias, None, None, None, None, None, None @@ -646,10 +645,10 @@ def flash_dmattn_func( key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim) value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim) attn_mask: torch.Tensor, optional. The attention mask boolean tensor of - shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|0}, seqlen_k) to apply to the attention scores. + shape ({batch_size|1}, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to apply to the attention scores. If None, no mask is applied. attn_bias: torch.Tensor, optional. The attention bias float tensor of - shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|0}, seqlen_k) to add to the attention scores. + shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to add to the attention scores. If None, no bias is applied. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). From 71f86314aa1f15b93539282e1537d0f3f4d9bef8 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 12 Oct 2025 18:58:22 +0800 Subject: [PATCH 06/12] Clarifies mask/bias shapes; updates gradient docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Corrects tensor shape examples to use 1-sized broadcastable dims for batch, heads, query, and key, improving clarity and avoiding invalid 0-length notation. Removes “dbias” jargon from the gradient description for clearer wording. Syncs English and Chinese documentation on shape semantics. --- README.md | 6 +++--- README_zh.md | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 89c6082..c163452 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Flash-DMA is a high-performance attention implementation that integrates Flash A ## Key Features ### 🎯 Core Kernel Advantages -- **Mask & Bias Support**: Native support for `(batch_size, {1|num_kv_heads|num_heads}, {0|query_len}, key_len)` shaped attention mask and attention bias tensors +- **Mask & Bias Support**: Native support for `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` shaped attention mask and attention bias tensors - **Intelligent Computation Skipping**: Block-level automatic skipping mechanism based on masks, completely bypassing computation and memory access for zero-mask blocks - **Complete Gradient Support**: Built-in full gradient computation path for attention bias, supporting end-to-end training @@ -236,9 +236,9 @@ Flash-DMA integrates the efficient memory access patterns of Flash Attention wit ### Core Technology Integration -- **🎯 Native Mask & Bias Support**: Kernels directly process `(batch_size, {1|num_kv_heads|num_heads}, {0|query_len}, key_len)` shaped tensors +- **🎯 Native Mask & Bias Support**: Kernels directly process `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` shaped tensors - **⚡ Block-level Intelligent Skipping**: Unified OR-reduction skipping logic based on masks, completely avoiding computation and memory access for zero blocks -- **🔄 Complete Gradient Chain**: Built-in attention bias gradient computation (dbias) supporting end-to-end differentiable training +- **🔄 Complete Gradient Chain**: Built-in attention bias gradient computation supporting end-to-end differentiable training ### Key Optimization Strategies diff --git a/README_zh.md b/README_zh.md index 8e16c41..2550652 100644 --- a/README_zh.md +++ b/README_zh.md @@ -18,7 +18,7 @@ Flash-DMA 是一个高性能的注意力实现,将 Flash Attention 的内存 ## 主要特性 ### 🎯 核心内核优势 -- **Mask & Bias 支持**: 原生支持 `(batch_size, {1|num_kv_heads|num_heads}, {0|query_len}, key_len)` 形状的 attention_mask 和 attention_bias 张量 +- **Mask & Bias 支持**: 原生支持 `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` 形状的 attention_mask 和 attention_bias 张量 - **智能计算跳过**: 基于 attention_mask 的 block-level 自动跳过机制,完全跳过全零 mask 区块的计算和内存访问 - **完整梯度支持**: 内置 attention_bias 的完整梯度计算路径,支持端到端训练 @@ -236,7 +236,7 @@ Flash-DMA 通过将 Flash Attention 的高效内存访问模式与动态掩码 ### 核心技术融合 -- **🎯 Mask & Bias 原生支持**: 内核直接处理 `(batch_size, {1|num_kv_heads|num_heads}, {0|query_len}, key_len)` 形状的张量 +- **🎯 Mask & Bias 原生支持**: 内核直接处理 `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` 形状的张量 - **⚡ Block-level 智能跳过**: 基于 mask 的统一 OR-reduction 跳过逻辑,完全避免全零区块的计算和内存访问 - **🔄 完整梯度链路**: 内置 attention bias 梯度计算,支持端到端可微分训练 From bcdf7adb81f94a3957a31451cccdba3f7630b167 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 12 Oct 2025 18:59:55 +0800 Subject: [PATCH 07/12] Refactors attention to explicit bias and mask Simplifies dynamic masking by accepting precomputed attention bias and an optional causal mask, removing dependence on internal ZOH/dt projection parameters and unifying the API across Python, CUDA, Triton, and Flex backends. Applies masking explicitly via a boolean mask with -inf before softmax and selects a top-k window per query (optionally respecting the causal mask), improving correctness and consistency across implementations. Aligns function signatures, renames keep_window_size to window_size, removes unused return flags, and fixes tensor layouts/contiguity where needed. Updates tests to generate attention bias and derive causal masks, improving forward-equivalence coverage and determinism while reducing coupling to value-state-derived features. --- benchmarks/forward_equivalence.py | 386 +++++++++++++----------------- 1 file changed, 163 insertions(+), 223 deletions(-) diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index 1da6f2a..fe624ab 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -50,104 +50,66 @@ flex_dmattn_func = None -def prepare_dynamic_mask( - hidden_states: torch.Tensor, - zoh_states: torch.Tensor, - keep_window_size: int = 2048, - cache_position: torch.Tensor = None, -): +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + Transform from (batch, num_key_value_heads, seqlen, head_dim) + to (batch, num_attention_heads, seqlen, head_dim) """ - Calculate dynamic attention mask to mask tokens for sparse attention. + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`. +def prepare_mask( + hidden_states: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor = None, + window_size: int = None, +): + """ Args: hidden_states: Input hidden states to determine dtype minimum value - zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length) - keep_window_size: Window size of tokens not dynamically masked - cache_position: Optional cache position for causal masking + attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length) + causal_mask: Optional causal mask to apply + window_size: Window size of tokens not masked Returns: tuple: (attn_bias, attn_mask) """ dtype = hidden_states.dtype min_dtype = torch.finfo(dtype).min - attn_bias = zoh_states[:, :, None, :].expand( - -1, -1, hidden_states.shape[2], -1 - ).to(dtype) # [batch_size, num_kv_heads, query_len, key_len] - - if cache_position is not None: - attn_bias = attn_bias.masked_fill( - torch.arange(attn_bias.shape[-1], device=attn_bias.device) > cache_position.reshape(-1, 1), - min_dtype - ) - if attn_bias.shape[-1] > keep_window_size: - topk_values, topk_indices = torch.topk( - attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ) - valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) - attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype) + if attn_bias.shape[-1] > window_size: + if causal_mask is not None: + topk_values, topk_indices = torch.topk( + attn_bias.masked_fill(~causal_mask, min_dtype).detach(), + window_size, dim=-1, largest=True, sorted=False + ) + else: + topk_values, topk_indices = torch.topk( + attn_bias, + window_size, dim=-1, largest=True, sorted=False + ) + attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype) else: - attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) + attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask -def calculate_zoh_states(value_states, dt_proj, A): - """ - Calculate zoh states for dynamic mask attention. - - Args: - value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] - causal_mask: Optional causal mask - - Returns: - zoh_states: [batch_size, num_kv_heads, key_len] - """ - batch_size, _, key_len, _ = value_states.shape - - # Transpose and reshape value_states, then matrix multiply with dt_proj.T - dt_result = torch.matmul( - value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), - dt_proj.T - ) - - # Apply softplus activation and coefficient A - dt_states = torch.exp(F.softplus(dt_result) * A) - zoh_states = dt_states.transpose(-1, -2) # [batch_size, num_kv_heads, key_len] - - return zoh_states - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). - Transform from (batch, num_key_value_heads, seqlen, head_dim) - to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - def dynamic_mask_attention_python( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Python reference implementation of dynamic mask attention. @@ -156,11 +118,10 @@ def dynamic_mask_attention_python( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -171,26 +132,25 @@ def dynamic_mask_attention_python( num_queries_per_kv = num_heads // num_kv_heads - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask function to process dynamic mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None + attn_bias, + causal_mask if is_causal else None, + window_size, ) - - # Sparse attention weight calculation + key_states = repeat_kv(key_states, num_queries_per_kv) value_states = repeat_kv(value_states, num_queries_per_kv) attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - - attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) - attn_weights = attn_weights * scaling + attn_bias # Apply scaling and zoh - attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization - attn_outputs = torch.matmul(attn_weights, value_states) - attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + + # Sparse attention weight calculation + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) # Dot product weights + attn_weights = attn_weights * scaling + attn_bias # Apply scaling and bias + attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf')) # Apply mask + attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization + attn_outputs = torch.matmul(attn_weights, value_states) # Weighted sum of values + attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] return attn_outputs @@ -199,13 +159,11 @@ def dynamic_mask_attention_cuda( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, - return_softmax=False + window_size: int, + is_causal: bool, ): """ CUDA implementation of dynamic mask attention. @@ -214,13 +172,11 @@ def dynamic_mask_attention_cuda( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking - return_softmax: Whether to return softmax weights Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] @@ -228,35 +184,30 @@ def dynamic_mask_attention_cuda( if flash_dmattn_func is None: raise RuntimeError("flash_dmattn_func not available") - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) # Ensure correct data types and memory layout for CUDA function - # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] # Call the flash_dmattn_func interface attn_outputs = flash_dmattn_func( - query_states, # [batch, query_len, num_heads, head_dim] - key_states, # [batch, key_len, num_kv_heads, head_dim] - value_states, # [batch, key_len, num_kv_heads, head_dim] - attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len] - attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len] + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, is_causal=is_causal, softmax_scale=scaling, softcap=0.0, deterministic=True, - return_attn_probs=return_softmax + return_attn_probs=False, ) return attn_outputs # [batch, query_len, num_heads, head_dim] @@ -266,12 +217,11 @@ def dynamic_mask_attention_triton( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Triton implementation of dynamic mask attention. @@ -280,11 +230,10 @@ def dynamic_mask_attention_triton( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -297,16 +246,12 @@ def dynamic_mask_attention_triton( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) # Repeat KV for multi-head attention (GQA support) key_states = repeat_kv(key_states, num_queries_per_kv) @@ -323,13 +268,13 @@ def dynamic_mask_attention_triton( # Call the Triton implementation attn_outputs = triton_dmattn_func( - query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key_states, # k: [batch, seqlen_k, num_heads, head_dim] - value_states, # v: [batch, seqlen_k, num_heads, head_dim] - attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] - attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] - is_causal=is_causal, # causal masking - softmax_scale=scaling # scaling factor + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) return attn_outputs # [batch, query_len, num_heads, head_dim] @@ -339,12 +284,11 @@ def dynamic_mask_attention_flex( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Flex Attention implementation of dynamic mask attention. @@ -353,11 +297,10 @@ def dynamic_mask_attention_flex( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -370,16 +313,12 @@ def dynamic_mask_attention_flex( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) # Repeat KV for multi-head attention (GQA support) key_states = repeat_kv(key_states, num_queries_per_kv) @@ -387,18 +326,22 @@ def dynamic_mask_attention_flex( attn_mask = repeat_kv(attn_mask, num_queries_per_kv) attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - # Flex attention expects: q, k, v in [batch, num_heads, seqlen, head_dim] format - # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format - + # Ensure correct data types and memory layout for Flex function + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + # Call the Flex Attention implementation attn_outputs = flex_dmattn_func( - query_states.transpose(1, 2), # q: [batch, query_len, num_heads, head_dim] - key_states.transpose(1, 2), # k: [batch, key_len, num_heads, head_dim] - value_states.transpose(1, 2), # v: [batch, key_len, num_heads, head_dim] - attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] - attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len] - is_causal=is_causal, # is_causal: whether to apply causal masking - softmax_scale=scaling # scaling factor + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) return attn_outputs # [batch, query_len, num_heads, head_dim] @@ -611,18 +554,18 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): torch.cuda.synchronize() batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal = config - + # Progress indicator progress_filled = "█" * (i + 1) progress_empty = "░" * (len(test_configs) - i - 1) progress_bar = f"[{progress_filled}{progress_empty}]" - + print(f"\n🧪 Test configuration {i+1}/{len(test_configs)} {progress_bar}") print(f" 📊 batch_size={batch_size}, num_heads={num_heads}, num_kv_heads={num_kv_heads}") print(f" 📏 query_len={query_len}, key_len={key_len}, head_dim={head_dim}") print(f" 🔒 is_causal={is_causal}") print(f" 🎯 Accuracy threshold: {accuracy_threshold*100:.1f}%") - + # Create random input data query_states = torch.randn( batch_size, num_heads, query_len, head_dim, @@ -636,40 +579,39 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): batch_size, num_kv_heads, key_len, head_dim, device=device, dtype=torch.bfloat16 ) - dt_proj = torch.randn( - num_kv_heads, num_kv_heads * head_dim, + attn_bias = torch.randn( + batch_size, num_kv_heads, query_len, key_len, device=device, dtype=torch.bfloat16 ) - A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16) - - # Create cache position cache_position = torch.arange(key_len - query_len, key_len, device=device) - + causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + # Set scaling factor and keep window size scaling = head_dim ** -0.5 - keep_window_size = 1024 + window_size = 1024 # Run Python implementation start_time = time.time() py_output = dynamic_mask_attention_python( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, scaling, + window_size, is_causal ) torch.cuda.synchronize() py_time = time.time() - start_time - + # Run CUDA implementation start_time = time.time() cuda_output = dynamic_mask_attention_cuda( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, scaling, + window_size, is_causal ) torch.cuda.synchronize() cuda_time = time.time() - start_time - - + + # Analyze differences py_output_copy = py_output.clone() cuda_output_copy = cuda_output.clone() @@ -692,7 +634,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): if not is_close and max_diff > 1e-2: print(" ⚠️ Difference too large, stopping subsequent tests.") break - del query_states, key_states, value_states, dt_proj, A, cache_position, py_output, cuda_output, py_output_copy, cuda_output_copy + del query_states, key_states, value_states, attn_bias, causal_mask, py_output, cuda_output, py_output_copy, cuda_output_copy torch.cuda.empty_cache() gc.collect() torch.cuda.synchronize() @@ -774,8 +716,8 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 4096, 4096, 128, False), # Not support head_dim > 128 in triton yet - # (1, 2, 1, 128, 128, 128, True), - # (1, 2, 1, 128, 128, 128, False), + # (1, 2, 1, 128, 128, 256, True), + # (1, 2, 1, 128, 128, 256, False), # (1, 2, 1, 256, 256, 256, True), # (1, 2, 1, 256, 256, 256, False), # (1, 2, 1, 512, 512, 256, True), @@ -825,25 +767,24 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): batch_size, num_kv_heads, key_len, head_dim, device=device, dtype=torch.bfloat16 ) - dt_proj = torch.randn( - num_kv_heads, num_kv_heads * head_dim, + attn_bias = torch.randn( + batch_size, num_kv_heads, query_len, key_len, device=device, dtype=torch.bfloat16 ) - A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16) - - # Create cache position - cache_position = torch.arange(0, query_len + 0, device=device) + cache_position = torch.arange(key_len - query_len, key_len, device=device) + causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) # Set scaling factor and keep window size scaling = head_dim ** -0.5 - keep_window_size = 64 + window_size = 1024 # Run Python implementation start_time = time.time() py_output = dynamic_mask_attention_python( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, scaling, + window_size, is_causal ) torch.cuda.synchronize() py_time = time.time() - start_time @@ -853,8 +794,8 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): try: triton_output = dynamic_mask_attention_triton( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, scaling, + window_size, is_causal ) torch.cuda.synchronize() triton_time = time.time() - start_time @@ -896,8 +837,8 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): if triton_max_diff > 1e-2: print(" ⚠️ Difference too large, stopping subsequent tests.") break - - del query_states, key_states, value_states, dt_proj, A, cache_position, py_output, py_output_copy + + del query_states, key_states, value_states, attn_bias, causal_mask, py_output, py_output_copy if triton_output is not None: del triton_output, triton_output_copy torch.cuda.empty_cache() @@ -1031,25 +972,24 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95): batch_size, num_kv_heads, key_len, head_dim, device=device, dtype=torch.bfloat16 ) - dt_proj = torch.randn( - num_kv_heads, num_kv_heads * head_dim, + attn_bias = torch.randn( + batch_size, num_kv_heads, query_len, key_len, device=device, dtype=torch.bfloat16 ) - A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16) - - # Create cache position - cache_position = torch.arange(0, query_len + 0, device=device) + cache_position = torch.arange(key_len - query_len, key_len, device=device) + causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) # Set scaling factor and keep window size scaling = head_dim ** -0.5 - keep_window_size = 64 + window_size = 1024 # Run Python implementation start_time = time.time() py_output = dynamic_mask_attention_python( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, scaling, cache_position, + window_size, is_causal ) torch.cuda.synchronize() py_time = time.time() - start_time @@ -1059,8 +999,8 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95): try: flex_output = dynamic_mask_attention_flex( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, scaling, + window_size, is_causal ) torch.cuda.synchronize() flex_time = time.time() - start_time @@ -1102,8 +1042,8 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95): if flex_max_diff > 1e-2: print(" ⚠️ Difference too large, stopping subsequent tests.") break - - del query_states, key_states, value_states, dt_proj, A, cache_position, py_output, py_output_copy + + del query_states, key_states, value_states, attn_bias, causal_mask, py_output, py_output_copy if flex_output is not None: del flex_output, flex_output_copy torch.cuda.empty_cache() @@ -1203,4 +1143,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From 1cbd2f990760bec505793b854617bda3ea31b9ea Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 12 Oct 2025 19:00:21 +0800 Subject: [PATCH 08/12] Unifies attention kernels with bias+mask windowing Refactors attention paths to accept external attention bias and boolean causal mask, replacing zoh/dt-based masking and cache-position logic. Introduces a generic mask preparer that applies top-k windowing (optionally causal-aware), and standardizes interfaces across SDPA, Flash, Triton, and Flex implementations. Removes zoh/dt projection and related params, repeats KV artifacts for GQA, and consistently applies additive masks. Updates benchmarks to generate bias/mask inputs, rename keep_window_size to window_size, adjust head dims, and harmonize result handling and output labeling. Improves API consistency, simplifies experimentation with custom biases, and aligns masking semantics across kernels for more reliable benchmarking. --- benchmarks/forward_performance.py | 450 +++++++++++++----------------- 1 file changed, 199 insertions(+), 251 deletions(-) diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py index 0d48c1a..5730e0e 100644 --- a/benchmarks/forward_performance.py +++ b/benchmarks/forward_performance.py @@ -72,86 +72,51 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def prepare_dynamic_mask( +def prepare_mask( hidden_states: torch.Tensor, - zoh_states: torch.Tensor, - keep_window_size: int = 2048, - cache_position: torch.Tensor = None, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor = None, + window_size: int = None, ): """ - Calculate dynamic attention mask to mask tokens for sparse attention. - - Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`. - Args: hidden_states: Input hidden states to determine dtype minimum value - zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length) - keep_window_size: Window size of tokens not dynamically masked - cache_position: Optional cache position for causal masking + attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length) + causal_mask: Optional causal mask to apply + window_size: Window size of tokens not masked Returns: tuple: (attn_bias, attn_mask) """ dtype = hidden_states.dtype min_dtype = torch.finfo(dtype).min - attn_bias = zoh_states[:, :, None, :].expand( - -1, -1, hidden_states.shape[2], -1 - ).to(dtype) # [batch_size, num_kv_heads, query_len, key_len] - - if cache_position is not None: - attn_bias = attn_bias.masked_fill( - torch.arange(attn_bias.shape[-1], device=attn_bias.device) > cache_position.reshape(-1, 1), - min_dtype - ) - if attn_bias.shape[-1] > keep_window_size: - topk_values, topk_indices = torch.topk( - attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ) - valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) - attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype) + if attn_bias.shape[-1] > window_size: + if causal_mask is not None: + topk_values, topk_indices = torch.topk( + attn_bias.masked_fill(~causal_mask, min_dtype).detach(), + window_size, dim=-1, largest=True, sorted=False + ) + else: + topk_values, topk_indices = torch.topk( + attn_bias, + window_size, dim=-1, largest=True, sorted=False + ) + attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype) else: - attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) + attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask -def calculate_zoh_states(value_states, dt_proj, A): - """ - Calculate zoh states for dynamic mask attention. - - Args: - value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] - causal_mask: Optional causal mask - - Returns: - zoh_states: [batch_size, num_kv_heads, key_len] - """ - batch_size, _, key_len, _ = value_states.shape - - # Transpose and reshape value_states, then matrix multiply with dt_proj.T - dt_result = torch.matmul( - value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), - dt_proj.T - ) - - # Apply softplus activation and coefficient A - dt_states = torch.exp(F.softplus(dt_result) * A) - zoh_states = dt_states.transpose(-1, -2) # [batch_size, num_kv_heads, key_len] - - return zoh_states - - def scaled_dot_product_attention_cuda( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - scaling: float, + attn_bias: torch.Tensor, causal_mask: torch.Tensor, - is_causal=True, + scaling: float, + window_size: int, + is_causal: bool, ): """ CUDA implementation of SDPA baseline. @@ -160,24 +125,36 @@ def scaled_dot_product_attention_cuda( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] + attn_bias: [batch_size, num_heads, query_length, key_length] + causal_mask: [batch_size, 1, query_length, key_length] or None + window_size: Number of tokens to keep in attention window scaling: Attention scaling factor - causal_mask: Causal attention mask is_causal: Whether to apply causal masking Returns: - attn_outputs or "OOM" if out of memory + tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - _, _, query_len, _ = query_states.shape - _, _, key_len, _ = key_states.shape - if query_len > 32768 and key_len > 32768: - return "OOM" + _, num_heads, _, _ = query_states.shape + _, num_kv_heads, _, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads + + attn_bias, attn_mask = prepare_mask( + query_states, + attn_bias, + causal_mask if is_causal else None, + window_size, + ) + + # Repeat KV for multi-head attention (GQA support) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() + attn_bias = attn_bias.masked_fill(~attn_mask, torch.finfo(query_states.dtype).min).contiguous() try: - # Only measure the core attention computation torch.cuda.synchronize() start_time = time.time() @@ -185,17 +162,17 @@ def scaled_dot_product_attention_cuda( query_states, key_states, value_states, - attn_mask=causal_mask, - softmax_scale=scaling, - # is_causal=is_causal if query_len == key_len else False, - enable_gqa=True + attn_mask=attn_bias, + scale=scaling, + # is_causal=is_causal, + enable_gqa=True, ) torch.cuda.synchronize() end_time = time.time() - attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] - return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms + attn_outputs = attn_outputs.transpose(1, 2).contiguous() + return attn_outputs, (end_time - start_time) * 1000 except torch.cuda.OutOfMemoryError: return "OOM", 0 @@ -204,13 +181,11 @@ def dynamic_mask_attention_cuda( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, + window_size=2048, is_causal=True, - return_softmax=False ): """ CUDA implementation of dynamic mask attention. @@ -219,33 +194,26 @@ def dynamic_mask_attention_cuda( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_heads, query_length, key_length] + causal_mask: [batch_size, 1, query_length, key_length] or None + window_size: Number of tokens to keep in attention window scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking - return_softmax: Whether to return softmax weights Returns: - attn_outputs: [batch_size, query_len, num_heads, head_dim] + tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ if flash_dmattn_func is None: return "Not Available", 0 - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) # Ensure correct data types and memory layout for CUDA function - # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] @@ -254,24 +222,23 @@ def dynamic_mask_attention_cuda( torch.cuda.synchronize() start_time = time.time() - # Call the new flash_dmattn_func interface attn_outputs = flash_dmattn_func( - query_states, # [batch, query_len, num_heads, head_dim] - key_states, # [batch, key_len, num_kv_heads, head_dim] - value_states, # [batch, key_len, num_kv_heads, head_dim] - attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len] - attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len] + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, is_causal=is_causal, softmax_scale=scaling, softcap=0.0, deterministic=False, - return_attn_probs=return_softmax + return_attn_probs=False ) torch.cuda.synchronize() end_time = time.time() - return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms + return attn_outputs, (end_time - start_time) * 1000 except torch.cuda.OutOfMemoryError: return "OOM", 0 @@ -280,11 +247,10 @@ def dynamic_mask_attention_triton( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, + window_size=2048, is_causal=True, ): """ @@ -294,15 +260,14 @@ def dynamic_mask_attention_triton( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_heads, query_length, key_length] + causal_mask: [batch_size, 1, query_length, key_length] or None + window_size: Number of tokens to keep in attention window scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: - attn_outputs: [batch_size, query_len, num_heads, head_dim] + tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ if triton_dmattn_func is None: return "Not Available", 0 @@ -311,44 +276,38 @@ def dynamic_mask_attention_triton( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - try: - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) + attn_bias, attn_mask = prepare_mask( + query_states, + attn_bias, + causal_mask if is_causal else None, + window_size, + ) - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( - query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] - - # Repeat KV for multi-head attention (GQA support) - key_states = repeat_kv(key_states, num_queries_per_kv) - value_states = repeat_kv(value_states, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) - attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - - # Triton function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format - query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] - key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - - # Only measure the core Triton kernel computation + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Ensure correct data types and memory layout for Triton function + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + + try: torch.cuda.synchronize() start_time = time.time() - # Call the Triton implementation attn_outputs = triton_dmattn_func( - query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key_states, # k: [batch, seqlen_k, num_heads, head_dim] - value_states, # v: [batch, seqlen_k, num_heads, head_dim] - attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] - attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] - is_causal=is_causal, # causal masking - softmax_scale=scaling # scaling factor + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) torch.cuda.synchronize() @@ -363,11 +322,10 @@ def dynamic_mask_attention_flex( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, + window_size=2048, is_causal=True, ): """ @@ -377,15 +335,14 @@ def dynamic_mask_attention_flex( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_heads, query_length, key_length] + causal_mask: [batch_size, 1, query_length, key_length] or None + window_size: Number of tokens to keep in attention window scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: - attn_outputs: [batch_size, query_len, num_heads, head_dim] + tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ if flex_dmattn_func is None: return "Not Available", 0 @@ -394,40 +351,39 @@ def dynamic_mask_attention_flex( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - try: - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) + attn_bias, attn_mask = prepare_mask( + query_states, + attn_bias, + causal_mask if is_causal else None, + window_size, + ) - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( - query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] - - # Repeat KV for multi-head attention (GQA support) - key_states = repeat_kv(key_states, num_queries_per_kv) - value_states = repeat_kv(value_states, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) - attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - - # Flex attention expects: q, k, v in [batch, num_heads, seqlen, head_dim] format - # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format - - # Only measure the core Flex Attention computation + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Ensure correct data types and memory layout for Flex function + query_states = query_states.transpose(1, 2).contiguous() + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() + attn_mask = attn_mask.contiguous() + attn_bias = attn_bias.contiguous() + + try: torch.cuda.synchronize() start_time = time.time() # Call the Flex Attention implementation attn_outputs = flex_dmattn_func( - query_states.transpose(1, 2), # q: [batch, query_len, num_heads, head_dim] - key_states.transpose(1, 2), # k: [batch, key_len, num_heads, head_dim] - value_states.transpose(1, 2), # v: [batch, key_len, num_heads, head_dim] - attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] - attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len] - is_causal=is_causal, # is_causal: whether to apply causal masking - softmax_scale=scaling # scaling factor + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) torch.cuda.synchronize() @@ -457,14 +413,14 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ Benchmark attention performance for a given configuration. Args: - config: Tuple of (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) + config: Tuple of (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal) num_runs: Number of benchmark runs warmup_runs: Number of warmup runs Returns: dict: Performance metrics """ - batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal = config + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal = config device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Create random input data @@ -480,21 +436,12 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ batch_size, num_kv_heads, key_len, head_dim, device=device, dtype=torch.bfloat16 ) - dt_proj = torch.randn( - num_kv_heads, num_kv_heads * head_dim, + attn_bias = torch.randn( + batch_size, num_kv_heads, query_len, key_len, device=device, dtype=torch.bfloat16 ) - A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16) - - # Create custom causal mask with cache position cache_position = torch.arange(key_len - query_len, key_len, device=device) - min_type = torch.finfo(value_states.dtype).min - causal_mask = torch.full( - (query_len, key_len), fill_value=min_type, - device=device, dtype=value_states.dtype - ) - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(key_len, device=device) > cache_position.reshape(-1, 1) + causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) # Set scaling factor from config @@ -531,7 +478,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(warmup_runs): result = scaled_dot_product_attention_cuda( query_states, key_states, value_states, - scaling, causal_mask, is_causal + attn_bias, causal_mask, + scaling, window_size, is_causal ) if result[0] == "OOM": results['sdpa_forward_status'] = 'OOM' @@ -546,7 +494,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(num_runs): result = scaled_dot_product_attention_cuda( query_states, key_states, value_states, - scaling, causal_mask, is_causal + attn_bias, causal_mask, + scaling, window_size, is_causal ) if result[0] == "OOM": @@ -571,8 +520,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(warmup_runs): result = dynamic_mask_attention_cuda( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, + scaling, window_size, is_causal ) if result[0] == "OOM": results['fdma_cuda_forward_status'] = 'OOM' @@ -587,8 +536,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(num_runs): result = dynamic_mask_attention_cuda( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, + scaling, window_size, is_causal ) if result[0] == "OOM": @@ -613,8 +562,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(warmup_runs): result = dynamic_mask_attention_triton( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: results['fdma_triton_forward_status'] = result[0] @@ -629,8 +578,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(num_runs): result = dynamic_mask_attention_triton( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: @@ -655,8 +604,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(warmup_runs): result = dynamic_mask_attention_flex( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: results['fdma_flex_forward_status'] = result[0] @@ -671,8 +620,7 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(num_runs): result = dynamic_mask_attention_flex( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: @@ -718,43 +666,43 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): print(title) print("🏆" + "=" * 76 + "🏆") - # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) + # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal) configs = [ # Vary sequence length - (1, 2, 1, 256, 256, 128, 1024, True), - (1, 2, 1, 512, 512, 128, 1024, True), - (1, 2, 1, 1024, 1024, 128, 1024, True), - (1, 2, 1, 2048, 2048, 128, 1024, True), - (1, 2, 1, 4096, 4096, 128, 1024, True), - (1, 2, 1, 8192, 8192, 128, 1024, True), - (1, 2, 1, 16384, 16384, 128, 1024, True), - (1, 2, 1, 32768, 32768, 128, 1024, True), + (1, 2, 1, 256, 256, 64, 1024, True), + (1, 2, 1, 512, 512, 64, 1024, True), + (1, 2, 1, 1024, 1024, 64, 1024, True), + (1, 2, 1, 2048, 2048, 64, 1024, True), + (1, 2, 1, 4096, 4096, 64, 1024, True), + (1, 2, 1, 8192, 8192, 64, 1024, True), + (1, 2, 1, 16384, 16384, 64, 1024, True), + (1, 2, 1, 32768, 32768, 64, 1024, True), # Inference - (1, 2, 1, 1, 256, 128, 1024, True), - (1, 2, 1, 1, 512, 128, 1024, True), - (1, 2, 1, 1, 1024, 128, 1024, True), - (1, 2, 1, 1, 2048, 128, 1024, True), - (1, 2, 1, 1, 4096, 128, 1024, True), - (1, 2, 1, 1, 8192, 128, 1024, True), - (1, 2, 1, 1, 16384, 128, 1024, True), - (1, 2, 1, 1, 32768, 128, 1024, True), - (1, 2, 1, 1, 65536, 128, 1024, True), - (1, 2, 1, 1, 131072, 128, 1024, True), - (1, 2, 1, 1, 262144, 128, 1024, True), - (1, 2, 1, 1, 524288, 128, 1024, True), + (1, 2, 1, 1, 256, 64, 1024, True), + (1, 2, 1, 1, 512, 64, 1024, True), + (1, 2, 1, 1, 1024, 64, 1024, True), + (1, 2, 1, 1, 2048, 64, 1024, True), + (1, 2, 1, 1, 4096, 64, 1024, True), + (1, 2, 1, 1, 8192, 64, 1024, True), + (1, 2, 1, 1, 16384, 64, 1024, True), + (1, 2, 1, 1, 32768, 64, 1024, True), + (1, 2, 1, 1, 65536, 64, 1024, True), + (1, 2, 1, 1, 131072, 64, 1024, True), + (1, 2, 1, 1, 262144, 64, 1024, True), + (1, 2, 1, 1, 524288, 64, 1024, True), # Vary batch size - (1, 2, 1, 4096, 4096, 32, 1024, True), - (2, 2, 1, 4096, 4096, 32, 1024, True), - (4, 2, 1, 4096, 4096, 32, 1024, True), - (8, 2, 1, 4096, 4096, 32, 1024, True), + (1, 2, 1, 4096, 4096, 64, 1024, True), + (2, 2, 1, 4096, 4096, 64, 1024, True), + (4, 2, 1, 4096, 4096, 64, 1024, True), + (8, 2, 1, 4096, 4096, 64, 1024, True), # Vary head count - (1, 1, 1, 4096, 4096, 32, 1024, True), - (1, 2, 1, 4096, 4096, 32, 1024, True), - (1, 4, 1, 4096, 4096, 32, 1024, True), - (1, 8, 2, 4096, 4096, 32, 1024, True), + (1, 1, 1, 4096, 4096, 64, 1024, True), + (1, 2, 1, 4096, 4096, 64, 1024, True), + (1, 4, 1, 4096, 4096, 64, 1024, True), + (1, 8, 2, 4096, 4096, 64, 1024, True), # Vary head dimension (1, 2, 1, 4096, 4096, 32, 1024, True), @@ -764,18 +712,18 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): (1, 2, 1, 4096, 4096, 192, 1024, True), (1, 2, 1, 4096, 4096, 256, 1024, True), - # Vary keep_window_size - (1, 2, 1, 32768, 32768, 128, 32, True), - (1, 2, 1, 32768, 32768, 128, 64, True), - (1, 2, 1, 32768, 32768, 128, 128, True), - (1, 2, 1, 32768, 32768, 128, 256, True), - (1, 2, 1, 32768, 32768, 128, 512, True), - (1, 2, 1, 32768, 32768, 128, 1024, True), - (1, 2, 1, 32768, 32768, 128, 2048, True), - (1, 2, 1, 32768, 32768, 128, 4096, True), - (1, 2, 1, 32768, 32768, 128, 8192, True), - (1, 2, 1, 32768, 32768, 128, 16384, True), - (1, 2, 1, 32768, 32768, 128, 32768, True), + # Vary window_size + (1, 2, 1, 32768, 32768, 64, 32, True), + (1, 2, 1, 32768, 32768, 64, 64, True), + (1, 2, 1, 32768, 32768, 64, 128, True), + (1, 2, 1, 32768, 32768, 64, 256, True), + (1, 2, 1, 32768, 32768, 64, 512, True), + (1, 2, 1, 32768, 32768, 64, 1024, True), + (1, 2, 1, 32768, 32768, 64, 2048, True), + (1, 2, 1, 32768, 32768, 64, 4096, True), + (1, 2, 1, 32768, 32768, 64, 8192, True), + (1, 2, 1, 32768, 32768, 64, 16384, True), + (1, 2, 1, 32768, 32768, 64, 32768, True), ] print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):") @@ -785,7 +733,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): all_results = [] for config in configs: - batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal = config + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal = config results = benchmark_attention_performance(config, test_type, num_runs, warmup_runs) all_results.append(results) @@ -824,7 +772,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): speedup_strs[impl_key] = "N/A" # Format output with shorter config string - config_short = f" B{batch_size} Hq{num_heads} Hkv{num_kv_heads} Q{query_len} K{key_len} D{head_dim} W{keep_window_size} " + config_short = f" B{batch_size} Hq{num_heads} Hkv{num_kv_heads} Q{query_len} K{key_len} D{head_dim} W{window_size} " if not is_causal: config_short += "N" else: From 53e1aa4e670d6317c08e0bc7dcc4357401a9daad Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 12 Oct 2025 19:01:00 +0800 Subject: [PATCH 09/12] Unifies dmattn to bias+mask API; expands tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces internal zoh/value-proj masking with an external attention bias plus optional causal mask and top‑k windowing, simplifying the interface and masking semantics across backends. Aligns Python, CUDA, Triton, and Flex to a shared signature, applies masking consistently, ensures contiguous layouts, and uses deterministic execution for stable gradients. Expands backward‑equivalence coverage to head dims 192/256 and updates tests to use bf16 bias and causal masks, improving reproducibility and backend parity. --- benchmarks/backward_equivalence.py | 388 ++++++++++++----------------- 1 file changed, 163 insertions(+), 225 deletions(-) diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index fd8ebef..2149a8e 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -50,104 +50,66 @@ flex_dmattn_func = None -def prepare_dynamic_mask( - hidden_states: torch.Tensor, - zoh_states: torch.Tensor, - keep_window_size: int = 2048, - cache_position: torch.Tensor = None, -): +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ - Calculate dynamic attention mask to mask tokens for sparse attention. + Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + Transform from (batch, num_key_value_heads, seqlen, head_dim) + to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`. +def prepare_mask( + hidden_states: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor = None, + window_size: int = None, +): + """ Args: hidden_states: Input hidden states to determine dtype minimum value - zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length) - keep_window_size: Window size of tokens not dynamically masked - cache_position: Optional cache position for causal masking + attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length) + causal_mask: Optional causal mask to apply + window_size: Window size of tokens not masked Returns: tuple: (attn_bias, attn_mask) """ dtype = hidden_states.dtype min_dtype = torch.finfo(dtype).min - attn_bias = zoh_states[:, :, None, :].expand( - -1, -1, hidden_states.shape[2], -1 - ).to(dtype) # [batch_size, num_kv_heads, query_len, key_len] - - if cache_position is not None: - attn_bias = attn_bias.masked_fill( - torch.arange(attn_bias.shape[-1], device=attn_bias.device) > cache_position.reshape(-1, 1), - min_dtype - ) - if attn_bias.shape[-1] > keep_window_size: - topk_values, topk_indices = torch.topk( - attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ) - valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) - attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype) + if attn_bias.shape[-1] > window_size: + if causal_mask is not None: + topk_values, topk_indices = torch.topk( + attn_bias.masked_fill(~causal_mask, min_dtype).detach(), + window_size, dim=-1, largest=True, sorted=False + ) + else: + topk_values, topk_indices = torch.topk( + attn_bias, + window_size, dim=-1, largest=True, sorted=False + ) + attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype) else: - attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) + attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask -def calculate_zoh_states(value_states, dt_proj, A): - """ - Calculate zoh states for dynamic mask attention. - - Args: - value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] - causal_mask: Optional causal mask - - Returns: - zoh_states: [batch_size, num_kv_heads, key_len] - """ - batch_size, _, key_len, _ = value_states.shape - - # Transpose and reshape value_states, then matrix multiply with dt_proj.T - dt_result = torch.matmul( - value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), - dt_proj.T - ) - - # Apply softplus activation and coefficient A - dt_states = torch.exp(F.softplus(dt_result) * A) - zoh_states = dt_states.transpose(-1, -2) # [batch_size, num_kv_heads, key_len] - - return zoh_states - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). - Transform from (batch, num_key_value_heads, seqlen, head_dim) - to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - def dynamic_mask_attention_python( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Python reference implementation of dynamic mask attention backward pass. @@ -156,11 +118,10 @@ def dynamic_mask_attention_python( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -174,29 +135,27 @@ def dynamic_mask_attention_python( key_states_leaf = key_states value_states_leaf = value_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask function to process dynamic mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None + attn_bias, + causal_mask if is_causal else None, + window_size, ) attn_bias_leaf = attn_bias attn_bias_leaf.retain_grad() - - # Sparse attention weight calculation + key_states = repeat_kv(key_states, num_queries_per_kv) value_states = repeat_kv(value_states, num_queries_per_kv) attn_mask = repeat_kv(attn_mask, num_queries_per_kv) attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv) - attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) - attn_weights = attn_weights * scaling + attn_bias # Apply scaling and zoh - attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization - attn_outputs = torch.matmul(attn_weights, value_states) - attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] + # Sparse attention weight calculation + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) # Dot product weights + attn_weights = attn_weights * scaling + attn_bias # Apply scaling and bias + attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf')) # Apply mask + attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization + attn_outputs = torch.matmul(attn_weights, value_states) # Weighted sum of values + attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] # Backward pass attn_outputs.sum().backward() @@ -208,12 +167,11 @@ def dynamic_mask_attention_cuda( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ CUDA implementation of dynamic mask attention backward pass. @@ -222,11 +180,10 @@ def dynamic_mask_attention_cuda( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -239,36 +196,31 @@ def dynamic_mask_attention_cuda( key_states_leaf = key_states value_states_leaf = value_states - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) attn_bias_leaf = attn_bias attn_bias_leaf.retain_grad() # Ensure correct data types and memory layout for CUDA function - # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] # Call the flash_dmattn_func interface attn_outputs = flash_dmattn_func( - query=query_states, # q: [batch, query_len, num_heads, head_dim] - key=key_states, # k: [batch, key_len, num_kv_heads, head_dim] - value=value_states, # v: [batch, key_len, num_kv_heads, head_dim] - attn_mask=attn_mask, # mask: [batch, num_kv_heads, query_len, key_len] - attn_bias=attn_bias, # bias: [batch, num_kv_heads, query_len, key_len] - is_causal=is_causal, # causal masking - softmax_scale=scaling, # scaling factor + query=query_states, + key=key_states, + value=value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, softcap=0.0, - deterministic=False, + deterministic=True, return_attn_probs=False ) @@ -282,12 +234,11 @@ def dynamic_mask_attention_triton( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Triton implementation of dynamic mask attention backward pass. @@ -296,11 +247,10 @@ def dynamic_mask_attention_triton( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -317,16 +267,12 @@ def dynamic_mask_attention_triton( key_states_leaf = key_states value_states_leaf = value_states - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) attn_bias_leaf = attn_bias attn_bias_leaf.retain_grad() @@ -336,7 +282,7 @@ def dynamic_mask_attention_triton( attn_mask = repeat_kv(attn_mask, num_queries_per_kv) attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv) - # Triton function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format + # Ensure correct data types and memory layout for Triton function query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] @@ -345,13 +291,13 @@ def dynamic_mask_attention_triton( # Call the Triton implementation attn_outputs = triton_dmattn_func( - query=query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key=key_states, # k: [batch, seqlen_k, num_heads, head_dim] - value=value_states, # v: [batch, seqlen_k, num_heads, head_dim] - attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] - attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] - is_causal=is_causal, # causal masking - softmax_scale=scaling # scaling factor + query=query_states, + key=key_states, + value=value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) # Backward pass @@ -364,12 +310,11 @@ def dynamic_mask_attention_flex( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Flex Attention implementation of dynamic mask attention backward pass. @@ -378,11 +323,10 @@ def dynamic_mask_attention_flex( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -395,16 +339,12 @@ def dynamic_mask_attention_flex( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) attn_bias.retain_grad() # Repeat KV for multi-head attention (GQA support) @@ -413,18 +353,22 @@ def dynamic_mask_attention_flex( attn_mask = repeat_kv(attn_mask, num_queries_per_kv) attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - # Flex attention expects: q, k, v in [batch, num_heads, seqlen, head_dim] format - # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format + # Ensure correct data types and memory layout for Flex function + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] # Call the Flex Attention implementation attn_outputs = flex_dmattn_func( - query_states.transpose(1, 2), # q: [batch, query_len, num_heads, head_dim] - key_states.transpose(1, 2), # k: [batch, key_len, num_heads, head_dim] - value_states.transpose(1, 2), # v: [batch, key_len, num_heads, head_dim] - attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] - attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len] - is_causal=is_causal, # is_causal: whether to apply causal masking - softmax_scale=scaling # scaling factor + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) # Backward pass @@ -599,35 +543,33 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 4096, 4096, 128, False), (1, 2, 1, 4096, 4096, 128, True), - # # Head dim 192 - # Not enough shared memory for head_dim=192 in bwd yet - # (1, 2, 1, 128, 128, 192, False), - # (1, 2, 1, 128, 128, 192, True), - # (1, 2, 1, 256, 256, 192, False), - # (1, 2, 1, 256, 256, 192, True), - # (1, 2, 1, 512, 512, 192, False), - # (1, 2, 1, 512, 512, 192, True), - # (1, 2, 1, 1024, 1024, 192, False), - # (1, 2, 1, 1024, 1024, 192, True), - # (1, 2, 1, 2048, 2048, 192, False), - # (1, 2, 1, 2048, 2048, 192, True), - # (1, 2, 1, 4096, 4096, 192, False), - # (1, 2, 1, 4096, 4096, 192, True), + # Head dim 192 + (1, 2, 1, 128, 128, 192, False), + (1, 2, 1, 128, 128, 192, True), + (1, 2, 1, 256, 256, 192, False), + (1, 2, 1, 256, 256, 192, True), + (1, 2, 1, 512, 512, 192, False), + (1, 2, 1, 512, 512, 192, True), + (1, 2, 1, 1024, 1024, 192, False), + (1, 2, 1, 1024, 1024, 192, True), + (1, 2, 1, 2048, 2048, 192, False), + (1, 2, 1, 2048, 2048, 192, True), + (1, 2, 1, 4096, 4096, 192, False), + (1, 2, 1, 4096, 4096, 192, True), # Head dim 256 - # Not enough shared memory for head_dim=256 in bwd yet - # (1, 2, 1, 128, 128, 256, False), - # (1, 2, 1, 128, 128, 256, True), - # (1, 2, 1, 256, 256, 256, False), - # (1, 2, 1, 256, 256, 256, True), - # (1, 2, 1, 512, 512, 256, False), - # (1, 2, 1, 512, 512, 256, True), - # (1, 2, 1, 1024, 1024, 256, False), - # (1, 2, 1, 1024, 1024, 256, True), - # (1, 2, 1, 2048, 2048, 256, False), - # (1, 2, 1, 2048, 2048, 256, True), - # (1, 2, 1, 4096, 4096, 256, False), - # (1, 2, 1, 4096, 4096, 256, True), + (1, 2, 1, 128, 128, 256, False), + (1, 2, 1, 128, 128, 256, True), + (1, 2, 1, 256, 256, 256, False), + (1, 2, 1, 256, 256, 256, True), + (1, 2, 1, 512, 512, 256, False), + (1, 2, 1, 512, 512, 256, True), + (1, 2, 1, 1024, 1024, 256, False), + (1, 2, 1, 1024, 1024, 256, True), + (1, 2, 1, 2048, 2048, 256, False), + (1, 2, 1, 2048, 2048, 256, True), + (1, 2, 1, 4096, 4096, 256, False), + (1, 2, 1, 4096, 4096, 256, True), ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -668,48 +610,48 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): batch_size, num_kv_heads, key_len, head_dim, device=device, dtype=dtype, requires_grad=True ) - dt_proj = torch.randn( - num_kv_heads, num_kv_heads * head_dim, - device=device, dtype=dtype, requires_grad=True + attn_bias = torch.randn( + batch_size, num_kv_heads, query_len, key_len, + device=device, dtype=torch.bfloat16 ) - A = torch.randn(num_kv_heads, device=device, dtype=dtype, requires_grad=True) - - # Create cache position cache_position = torch.arange(key_len - query_len, key_len, device=device) + causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) # Set scaling factor and keep window size scaling = head_dim ** -0.5 - keep_window_size = 1024 + window_size = 1024 # Clone inputs for Python implementation query_python = query_states.clone().detach().requires_grad_(True) key_python = key_states.clone().detach().requires_grad_(True) value_python = value_states.clone().detach().requires_grad_(True) - dt_proj_python = dt_proj.clone().detach().requires_grad_(True) - A_python = A.clone().detach().requires_grad_(True) + attn_bias_python = attn_bias.clone().detach().requires_grad_(True) + causal_mask_python = causal_mask.clone().detach() # Run Python implementation start_time = time.time() attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python( - query_python, key_python, value_python, dt_proj_python, A_python, - scaling, cache_position, keep_window_size, is_causal + query_python, key_python, value_python, + attn_bias_python, causal_mask_python, + scaling, window_size, is_causal ) torch.cuda.synchronize() py_time = time.time() - start_time - - + # Clone inputs for CUDA implementation query_cuda = query_states.clone().detach().requires_grad_(True) key_cuda = key_states.clone().detach().requires_grad_(True) value_cuda = value_states.clone().detach().requires_grad_(True) - dt_proj_cuda = dt_proj.clone().detach().requires_grad_(True) - A_cuda = A.clone().detach().requires_grad_(True) - + attn_bias_cuda = attn_bias.clone().detach().requires_grad_(True) + causal_mask_cuda = causal_mask.clone().detach() + # Run CUDA implementation start_time = time.time() attn_outputs_cuda, dq_cuda, dk_cuda, dv_cuda, dbias_cuda = dynamic_mask_attention_cuda( - query_cuda, key_cuda, value_cuda, dt_proj_cuda, A_cuda, - scaling, cache_position, keep_window_size, is_causal + query_cuda, key_cuda, value_cuda, + attn_bias_cuda, causal_mask_cuda, + scaling, window_size, is_causal ) torch.cuda.synchronize() cuda_time = time.time() - start_time @@ -774,7 +716,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): if not is_close and max_dbias_diff > 1e-2: print(" ⚠️ Difference too large, stopping subsequent tests.") break - del query_states, key_states, value_states, dt_proj, A, cache_position, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda + del query_states, key_states, value_states, attn_bias, causal_mask, cache_position, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda torch.cuda.empty_cache() gc.collect() torch.cuda.synchronize() @@ -872,7 +814,3 @@ def main(): if __name__ == "__main__": main() - - - - From 6f0b7c1f64a16fd99cf9c63d1d31457be4397586 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 12 Oct 2025 19:01:28 +0800 Subject: [PATCH 10/12] Standardizes mask API; removes ZOH-based path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reworks dynamic masking to consume precomputed attention bias plus optional boolean causal mask and a window size, using top‑k selection within the window and honoring causality. Removes ZOH/dt_proj/A dependency to simplify masking and reduce coupling. Aligns CUDA, Triton, Flex, and SDPA wrapper to a unified interface, adds GQA support via KV repetition, and ensures consistent tensor layout. Detaches top‑k selection to avoid unintended gradients. Updates benchmarks to generate attention bias and boolean causal masks, renames keep_window_size to window_size, and adjusts configs/loops accordingly for consistent evaluation across backends. Improves clarity, consistency, and extensibility of the attention backward benchmarks. --- benchmarks/backward_performance.py | 484 +++++++++++++---------------- 1 file changed, 217 insertions(+), 267 deletions(-) diff --git a/benchmarks/backward_performance.py b/benchmarks/backward_performance.py index 03d5018..82deb8c 100644 --- a/benchmarks/backward_performance.py +++ b/benchmarks/backward_performance.py @@ -72,124 +72,100 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def prepare_dynamic_mask( +def prepare_mask( hidden_states: torch.Tensor, - zoh_states: torch.Tensor, - keep_window_size: int = 2048, - cache_position: torch.Tensor = None, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor = None, + window_size: int = None, ): """ - Calculate dynamic attention mask to mask tokens for sparse attention. - - Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`. - Args: hidden_states: Input hidden states to determine dtype minimum value - zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length) - keep_window_size: Window size of tokens not dynamically masked - cache_position: Optional cache position for causal masking + attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length) + causal_mask: Optional causal mask to apply + window_size: Window size of tokens not masked Returns: tuple: (attn_bias, attn_mask) """ dtype = hidden_states.dtype min_dtype = torch.finfo(dtype).min - attn_bias = zoh_states[:, :, None, :].expand( - -1, -1, hidden_states.shape[2], -1 - ).to(dtype) # [batch_size, num_kv_heads, query_len, key_len] - - if cache_position is not None: - attn_bias = attn_bias.masked_fill( - torch.arange(attn_bias.shape[-1], device=attn_bias.device) > cache_position.reshape(-1, 1), - min_dtype - ) - if attn_bias.shape[-1] > keep_window_size: - topk_values, topk_indices = torch.topk( - attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ) - valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) - attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype) + if attn_bias.shape[-1] > window_size: + if causal_mask is not None: + topk_values, topk_indices = torch.topk( + attn_bias.masked_fill(~causal_mask, min_dtype).detach(), + window_size, dim=-1, largest=True, sorted=False + ) + else: + topk_values, topk_indices = torch.topk( + attn_bias, + window_size, dim=-1, largest=True, sorted=False + ) + attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype) else: - attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) + attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask -def calculate_zoh_states(value_states, dt_proj, A): - """ - Calculate zoh states for dynamic mask attention. - - Args: - value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] - causal_mask: Optional causal mask - - Returns: - zoh_states: [batch_size, num_kv_heads, key_len] - """ - batch_size, _, key_len, _ = value_states.shape - - # Transpose and reshape value_states, then matrix multiply with dt_proj.T - dt_result = torch.matmul( - value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), - dt_proj.T - ) - - # Apply softplus activation and coefficient A - dt_states = torch.exp(F.softplus(dt_result) * A) - zoh_states = dt_states.transpose(-1, -2) # [batch_size, num_kv_heads, key_len] - - return zoh_states - - -def scaled_dot_product_attention_backward( +def scaled_dot_product_attention_backward_cuda( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - scaling: float, + attn_bias: torch.Tensor, causal_mask: torch.Tensor, - is_causal=True, + scaling: float, + window_size: int, + is_causal: bool, ): """ - SDPA baseline backward pass implementation. + CUDA implementation of SDPA baseline. Args: query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] + attn_bias: [batch_size, num_heads, query_length, key_length] + causal_mask: [batch_size, 1, query_length, key_length] or None scaling: Attention scaling factor - causal_mask: Causal attention mask is_causal: Whether to apply causal masking Returns: - tuple: (output_tensor, timing_ms) or ("OOM", 0) if out of memory + tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - _, _, query_len, _ = query_states.shape - _, _, key_len, _ = key_states.shape - if query_len > 32768 and key_len > 32768: - return "OOM", 0 + _, num_heads, _, _ = query_states.shape + _, num_kv_heads, _, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads + + attn_bias, attn_mask = prepare_mask( + query_states, + attn_bias, + causal_mask if is_causal else None, + window_size, + ) + + # Repeat KV for multi-head attention (GQA support) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() + attn_bias = attn_bias.masked_fill(~attn_mask, torch.finfo(query_states.dtype).min).contiguous() try: - # Forward pass - SDPA expects q, k, v in [batch, num_heads, seq_len, head_dim] format attn_outputs = F.scaled_dot_product_attention( - query_states, # [batch, num_heads, query_len, head_dim] - key_states, # [batch, num_kv_heads, key_len, head_dim] - value_states, # [batch, num_kv_heads, key_len, head_dim] - attn_mask=causal_mask, - softmax_scale=scaling, - # is_causal=is_causal if query_len == key_len else False, + query_states, + key_states, + value_states, + attn_mask=attn_bias, + scale=scaling, + # is_causal=is_causal, enable_gqa=True ) - # Transpose to match expected output format + attn_outputs = attn_outputs.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] - + torch.cuda.synchronize() start_time = time.time() @@ -209,12 +185,11 @@ def dynamic_mask_attention_backward_cuda( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ CUDA implementation of dynamic mask attention backward pass. @@ -223,11 +198,10 @@ def dynamic_mask_attention_backward_cuda( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_length, key_length] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -236,33 +210,27 @@ def dynamic_mask_attention_backward_cuda( if flash_dmattn_func is None: return "Not Available", 0 - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) # Ensure correct data types and memory layout for CUDA function - # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] try: - # Call the flash_dmattn_func interface attn_outputs = flash_dmattn_func( - query=query_states, # q: [batch, query_len, num_heads, head_dim] - key=key_states, # k: [batch, key_len, num_kv_heads, head_dim] - value=value_states, # v: [batch, key_len, num_kv_heads, head_dim] - attn_mask=attn_mask, # mask: [batch, num_kv_heads, query_len, key_len] - attn_bias=attn_bias, # bias: [batch, num_kv_heads, query_len, key_len] - is_causal=is_causal, # causal masking - softmax_scale=scaling, # scaling factor + query=query_states, + key=key_states, + value=value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, softcap=0.0, deterministic=False, return_attn_probs=False @@ -287,12 +255,11 @@ def dynamic_mask_attention_backward_triton( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Triton implementation of dynamic mask attention backward pass. @@ -301,11 +268,10 @@ def dynamic_mask_attention_backward_triton( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_length, key_length] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -318,40 +284,35 @@ def dynamic_mask_attention_backward_triton( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - try: - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) + attn_bias, attn_mask = prepare_mask( + query_states, + attn_bias, + causal_mask if is_causal else None, + window_size, + ) - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( - query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] - - # Repeat KV for multi-head attention (GQA support) - key_states = repeat_kv(key_states, num_queries_per_kv) - value_states = repeat_kv(value_states, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) - attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - - # Triton function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format - query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] - key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - - # Call the Triton implementation + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Ensure correct data types and memory layout for Triton function + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + + try: attn_outputs = triton_dmattn_func( - query=query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key=key_states, # k: [batch, seqlen_k, num_heads, head_dim] - value=value_states, # v: [batch, seqlen_k, num_heads, head_dim] - attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] - attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] - is_causal=is_causal, # causal masking - softmax_scale=scaling # scaling factor + query=query_states, + key=key_states, + value=value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) torch.cuda.synchronize() @@ -373,12 +334,11 @@ def dynamic_mask_attention_backward_flex( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Flex Attention implementation of dynamic mask attention backward pass. @@ -387,11 +347,10 @@ def dynamic_mask_attention_backward_flex( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_length, key_length] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -404,36 +363,35 @@ def dynamic_mask_attention_backward_flex( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - try: - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) + attn_bias, attn_mask = prepare_mask( + query_states, + attn_bias, + causal_mask if is_causal else None, + window_size, + ) - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( - query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] - - # Repeat KV for multi-head attention (GQA support) - key_states = repeat_kv(key_states, num_queries_per_kv) - value_states = repeat_kv(value_states, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) - attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - - # Flex attention expects: q, k, v in [batch, num_heads, seqlen, head_dim] format - # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format - - # Call the Flex Attention implementation + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Ensure correct data types and memory layout for Flex function + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + + try: attn_outputs = flex_dmattn_func( - query_states.transpose(1, 2), # q: [batch, query_len, num_heads, head_dim] - key_states.transpose(1, 2), # k: [batch, key_len, num_heads, head_dim] - value_states.transpose(1, 2), # v: [batch, key_len, num_heads, head_dim] - attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] - attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len] - is_causal=is_causal, # is_causal: whether to apply causal masking - softmax_scale=scaling # scaling factor + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) torch.cuda.synchronize() @@ -470,7 +428,7 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 Benchmark backward attention performance for a given configuration. Args: - config: Tuple of (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) + config: Tuple of (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal) test_type: Type of test to run ('all', 'sdpa', 'cuda', 'triton', 'flex', etc.) num_runs: Number of benchmark runs warmup_runs: Number of warmup runs @@ -478,7 +436,7 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 Returns: dict: Performance metrics """ - batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal = config + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal = config device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Create random input data (requires_grad=True for backward pass) @@ -494,21 +452,12 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 batch_size, num_kv_heads, key_len, head_dim, device=device, dtype=torch.bfloat16, requires_grad=True ) - dt_proj = torch.randn( - num_kv_heads, num_kv_heads * head_dim, - device=device, dtype=torch.bfloat16, requires_grad=True + attn_bias = torch.randn( + batch_size, num_kv_heads, query_len, key_len, + device=device, dtype=torch.bfloat16 ) - A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16, requires_grad=True) - - # Create custom causal mask with cache position cache_position = torch.arange(key_len - query_len, key_len, device=device) - min_type = torch.finfo(value_states.dtype).min - causal_mask = torch.full( - (query_len, key_len), fill_value=min_type, - device=device, dtype=value_states.dtype - ) - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(key_len, device=device) > cache_position.reshape(-1, 1) + causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) # Set scaling factor from config @@ -543,13 +492,16 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Warmup runs for _ in range(warmup_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - - result = scaled_dot_product_attention_backward( - q_clone, k_clone, v_clone, scaling, causal_mask, is_causal + query_sdpa = query_states.clone().detach().requires_grad_(True) + key_sdpa = key_states.clone().detach().requires_grad_(True) + value_sdpa = value_states.clone().detach().requires_grad_(True) + attn_bias_sdpa = attn_bias.clone().detach().requires_grad_(True) + causal_mask_sdpa = causal_mask.clone().detach() + + result = scaled_dot_product_attention_backward_cuda( + query_sdpa, key_sdpa, value_sdpa, + attn_bias_sdpa, causal_mask_sdpa, + scaling, window_size, is_causal ) if result[0] == "OOM": results['sdpa_backward_status'] = 'OOM' @@ -562,13 +514,16 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Actual benchmark runs for _ in range(num_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - - result = scaled_dot_product_attention_backward( - q_clone, k_clone, v_clone, scaling, causal_mask, is_causal + query_sdpa = query_states.clone().detach().requires_grad_(True) + key_sdpa = key_states.clone().detach().requires_grad_(True) + value_sdpa = value_states.clone().detach().requires_grad_(True) + attn_bias_sdpa = attn_bias.clone().detach().requires_grad_(True) + causal_mask_sdpa = causal_mask.clone().detach() + + result = scaled_dot_product_attention_backward_cuda( + query_sdpa, key_sdpa, value_sdpa, + attn_bias_sdpa, causal_mask_sdpa, + scaling, window_size, is_causal ) if result[0] == "OOM": @@ -591,16 +546,15 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Warmup runs for _ in range(warmup_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - dt_clone = dt_proj.clone().detach().requires_grad_(True) - a_clone = A.clone().detach().requires_grad_(True) - + query_cuda = query_states.clone().detach().requires_grad_(True) + key_cuda = key_states.clone().detach().requires_grad_(True) + value_cuda = value_states.clone().detach().requires_grad_(True) + attn_bias_cuda = attn_bias.clone().detach().requires_grad_(True) + causal_mask_cuda = causal_mask.clone().detach() + result = dynamic_mask_attention_backward_cuda( - q_clone, k_clone, v_clone, dt_clone, a_clone, - scaling, cache_position, keep_window_size, is_causal + query_cuda, key_cuda, value_cuda, attn_bias_cuda, causal_mask_cuda, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: results['fdma_cuda_backward_status'] = result[0] @@ -613,16 +567,15 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Actual benchmark runs for _ in range(num_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - dt_clone = dt_proj.clone().detach().requires_grad_(True) - a_clone = A.clone().detach().requires_grad_(True) - + query_cuda = query_states.clone().detach().requires_grad_(True) + key_cuda = key_states.clone().detach().requires_grad_(True) + value_cuda = value_states.clone().detach().requires_grad_(True) + attn_bias_cuda = attn_bias.clone().detach().requires_grad_(True) + causal_mask_cuda = causal_mask.clone().detach() + result = dynamic_mask_attention_backward_cuda( - q_clone, k_clone, v_clone, dt_clone, a_clone, - scaling, cache_position, keep_window_size, is_causal + query_cuda, key_cuda, value_cuda, attn_bias_cuda, causal_mask_cuda, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: @@ -645,16 +598,15 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Warmup runs for _ in range(warmup_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - dt_clone = dt_proj.clone().detach().requires_grad_(True) - a_clone = A.clone().detach().requires_grad_(True) - + query_triton = query_states.clone().detach().requires_grad_(True) + key_triton = key_states.clone().detach().requires_grad_(True) + value_triton = value_states.clone().detach().requires_grad_(True) + attn_bias_triton = attn_bias.clone().detach().requires_grad_(True) + causal_mask_triton = causal_mask.clone().detach() + result = dynamic_mask_attention_backward_triton( - q_clone, k_clone, v_clone, dt_clone, a_clone, - scaling, cache_position, keep_window_size, is_causal + query_triton, key_triton, value_triton, attn_bias_triton, causal_mask_triton, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: results['fdma_triton_backward_status'] = result[0] @@ -667,16 +619,15 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Actual benchmark runs for _ in range(num_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - dt_clone = dt_proj.clone().detach().requires_grad_(True) - a_clone = A.clone().detach().requires_grad_(True) - + query_triton = query_states.clone().detach().requires_grad_(True) + key_triton = key_states.clone().detach().requires_grad_(True) + value_triton = value_states.clone().detach().requires_grad_(True) + attn_bias_triton = attn_bias.clone().detach().requires_grad_(True) + causal_mask_triton = causal_mask.clone().detach() + result = dynamic_mask_attention_backward_triton( - q_clone, k_clone, v_clone, dt_clone, a_clone, - scaling, cache_position, keep_window_size, is_causal + query_triton, key_triton, value_triton, attn_bias_triton, causal_mask_triton, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: @@ -699,16 +650,15 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Warmup runs for _ in range(warmup_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - dt_clone = dt_proj.clone().detach().requires_grad_(True) - a_clone = A.clone().detach().requires_grad_(True) - + query_flex = query_states.clone().detach().requires_grad_(True) + key_flex = key_states.clone().detach().requires_grad_(True) + value_flex = value_states.clone().detach().requires_grad_(True) + attn_bias_flex = attn_bias.clone().detach().requires_grad_(True) + causal_mask_flex = causal_mask.clone().detach() + result = dynamic_mask_attention_backward_flex( - q_clone, k_clone, v_clone, dt_clone, a_clone, - scaling, cache_position, keep_window_size, is_causal + query_flex, key_flex, value_flex, attn_bias_flex, causal_mask_flex, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: results['fdma_flex_backward_status'] = result[0] @@ -722,15 +672,15 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Actual benchmark runs for _ in range(num_runs): # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - dt_clone = dt_proj.clone().detach().requires_grad_(True) - a_clone = A.clone().detach().requires_grad_(True) - + query_flex = query_states.clone().detach().requires_grad_(True) + key_flex = key_states.clone().detach().requires_grad_(True) + value_flex = value_states.clone().detach().requires_grad_(True) + attn_bias_flex = attn_bias.clone().detach().requires_grad_(True) + causal_mask_flex = causal_mask.clone().detach() + result = dynamic_mask_attention_backward_flex( - q_clone, k_clone, v_clone, dt_clone, a_clone, - scaling, cache_position, keep_window_size, is_causal + query_flex, key_flex, value_flex, attn_bias_flex, causal_mask_flex, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: @@ -776,7 +726,7 @@ def run_backward_performance_benchmark(test_type='all', num_runs=3, warmup_runs= print(title) print("🏆" + "=" * 76 + "🏆") - # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) + # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal) configs = [ # Vary sequence length (1, 2, 1, 256, 256, 64, 1024, True), @@ -805,7 +755,7 @@ def run_backward_performance_benchmark(test_type='all', num_runs=3, warmup_runs= (1, 2, 1, 16384, 16384, 96, 1024, True), (1, 2, 1, 16384, 16384, 128, 1024, True), - # Vary keep_window_size + # Vary window_size (1, 2, 1, 16384, 16384, 64, 32, True), (1, 2, 1, 16384, 16384, 64, 64, True), (1, 2, 1, 16384, 16384, 64, 128, True), @@ -830,8 +780,8 @@ def run_backward_performance_benchmark(test_type='all', num_runs=3, warmup_runs= all_results.append(results) # Format configuration string - batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal = config - config_str = f"B{batch_size} Hq{num_heads} Hkv{num_kv_heads} Q{query_len} K{key_len} D{head_dim} W{keep_window_size} {'C' if is_causal else 'N'}" + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal = config + config_str = f"B{batch_size} Hq{num_heads} Hkv{num_kv_heads} Q{query_len} K{key_len} D{head_dim} W{window_size} {'C' if is_causal else 'N'}" # Calculate averages and format results sdpa_avg = f"{sum(results['sdpa_backward_times'])/len(results['sdpa_backward_times']):.2f}ms" if results['sdpa_backward_times'] else results['sdpa_backward_status'] From 3b1b4029c3b980847f12f68cbadf235bab02a292 Mon Sep 17 00:00:00 2001 From: Jingze Shi Date: Sun, 12 Oct 2025 19:18:53 +0800 Subject: [PATCH 11/12] Update benchmarks/forward_equivalence.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- benchmarks/forward_equivalence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index fe624ab..8baff70 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -988,8 +988,8 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95): start_time = time.time() py_output = dynamic_mask_attention_python( query_states, key_states, value_states, - attn_bias, causal_mask, scaling, cache_position, - window_size, is_causal + window_size, attn_bias, causal_mask, scaling, + is_causal ) torch.cuda.synchronize() py_time = time.time() - start_time From 08392c8132defaf4926b7f7351560b7139e731c8 Mon Sep 17 00:00:00 2001 From: Jingze Shi Date: Sun, 12 Oct 2025 19:19:15 +0800 Subject: [PATCH 12/12] Update flash_dmattn/flash_dmattn_interface.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- flash_dmattn/flash_dmattn_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 10ab385..663dc2e 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -648,7 +648,7 @@ def flash_dmattn_func( shape ({batch_size|1}, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to apply to the attention scores. If None, no mask is applied. attn_bias: torch.Tensor, optional. The attention bias float tensor of - shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to add to the attention scores. + shape ({batch_size|1}, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to add to the attention scores. If None, no bias is applied. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim).