From 1e6f46deb61ac2949726b754dd30e8d557eae20f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 10 Oct 2025 09:40:01 +0000 Subject: [PATCH 1/5] Initial plan From dcd700b7c5cbdeb53d247af92d224508fdffff96 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 10 Oct 2025 09:46:00 +0000 Subject: [PATCH 2/5] Add layout flags and kernel support for total_k-based mask/bias Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com> --- csrc/flash_dmattn/flash_api.cpp | 4 ++ csrc/flash_dmattn/src/block_info.h | 28 ++++++++++---- csrc/flash_dmattn/src/flash.h | 2 + csrc/flash_dmattn/src/flash_fwd_kernel.h | 48 +++++++++++++++++------- 4 files changed, 60 insertions(+), 22 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index d26ecde..1249d21 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -52,6 +52,8 @@ void set_params_fprop( const float softcap, bool has_mask, bool has_bias, + bool mask_layout_is_k_based=false, + bool bias_layout_is_k_based=false, bool seqlenq_ngroups_swapped=false, const bool unpadded_lse=false ) { @@ -142,6 +144,8 @@ void set_params_fprop( params.is_causal = is_causal; params.has_mask = has_mask; params.has_bias = has_bias; + params.mask_layout_is_k_based = mask_layout_is_k_based; + params.bias_layout_is_k_based = bias_layout_is_k_based; params.is_seqlens_k_cumulative = true; #ifdef FLASHATTENTION_DISABLE_UNEVEN_K diff --git a/csrc/flash_dmattn/src/block_info.h b/csrc/flash_dmattn/src/block_info.h index 38fefca..bcf595a 100644 --- a/csrc/flash_dmattn/src/block_info.h +++ b/csrc/flash_dmattn/src/block_info.h @@ -36,17 +36,29 @@ struct BlockInfo { } template - __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; - sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k); - return offset; + __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, const index_t row_stride, const int bidb, const bool is_k_based = false) const { + if (is_k_based) { + // For total_k-based layouts, only use k offset (broadcast across query positions) + return sum_s_k == -1 ? bidb * batch_stride + leftpad_k : uint32_t(sum_s_k + leftpad_k); + } else { + // For total_q-based layouts, use both q and k offsets + index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k); + return offset; + } } template - __forceinline__ __device__ index_t bias_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; - sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k); - return offset; + __forceinline__ __device__ index_t bias_offset(const index_t batch_stride, const index_t row_stride, const int bidb, const bool is_k_based = false) const { + if (is_k_based) { + // For total_k-based layouts, only use k offset (broadcast across query positions) + return sum_s_k == -1 ? bidb * batch_stride + leftpad_k : uint32_t(sum_s_k + leftpad_k); + } else { + // For total_q-based layouts, use both q and k offsets + index_t offset = sum_s_q == -1 ? 
bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k); + return offset; + } } const int sum_s_q; diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_dmattn/src/flash.h index a1c9bf1..43c40a5 100644 --- a/csrc/flash_dmattn/src/flash.h +++ b/csrc/flash_dmattn/src/flash.h @@ -56,6 +56,7 @@ struct Mask_params { int h_h_mask_ratio; // precompute h / h_mask bool has_mask; + bool mask_layout_is_k_based; // If true, mask is shaped (total_k, num_heads_variant) for broadcasting }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -73,6 +74,7 @@ struct Bias_params { int h_h_bias_ratio; // precompute h / h_bias bool has_bias; + bool bias_layout_is_k_based; // If true, bias is shaped (total_k, num_heads_variant) for broadcasting }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 576de90..c03f17a 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -170,9 +170,13 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_coord(_, 0) ); // (kBlockN, kHeadDim, nblocksN) Tensor mMask = make_tensor( - make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)), - make_shape(params.h_mask, binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.mask_head_stride, params.mask_row_stride, _1{}) + make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb, params.mask_layout_is_k_based)), + params.mask_layout_is_k_based + ? make_shape(params.h_mask, _1{}, binfo.actual_seqlen_k) // (h, 1, k) for k-based layout + : make_shape(params.h_mask, binfo.actual_seqlen_q, binfo.actual_seqlen_k), // (h, q, k) for q-based layout + params.mask_layout_is_k_based + ? make_stride(params.mask_head_stride, _0{}, _1{}) // Broadcast across q dimension + : make_stride(params.mask_head_stride, params.mask_row_stride, _1{}) ); Tensor gMask = local_tile( mMask(bidh / params.h_h_mask_ratio, _, _), @@ -180,9 +184,13 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) Tensor mBias = make_tensor( - make_gmem_ptr(reinterpret_cast(params.bias_ptr) + binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)), - make_shape(params.h_bias, binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.bias_head_stride, params.bias_row_stride, _1{}) + make_gmem_ptr(reinterpret_cast(params.bias_ptr) + binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb, params.bias_layout_is_k_based)), + params.bias_layout_is_k_based + ? make_shape(params.h_bias, _1{}, binfo.actual_seqlen_k) // (h, 1, k) for k-based layout + : make_shape(params.h_bias, binfo.actual_seqlen_q, binfo.actual_seqlen_k), // (h, q, k) for q-based layout + params.bias_layout_is_k_based + ? 
make_stride(params.bias_head_stride, _0{}, _1{}) // Broadcast across q dimension + : make_stride(params.bias_head_stride, params.bias_row_stride, _1{}) ); Tensor gBias = local_tile( mBias(bidh / params.h_h_bias_ratio, _, _), @@ -928,15 +936,23 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; const index_t col_offset_mask = (block_table == nullptr) - ? binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb_cache) - + (bidh / params.h_h_mask_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN + ? binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb_cache, params.mask_layout_is_k_based) + + (bidh / params.h_h_mask_ratio) * params.mask_head_stride + + (params.mask_layout_is_k_based ? 0 : m_block * kBlockM * params.mask_row_stride) // No row offset for k-based + + (n_block_max - 1) * kBlockN : binfo.q_offset(/*batch_stride=*/index_t(0), params.mask_row_stride, bidb_cache) - + (bidh / params.h_h_mask_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset; + + (bidh / params.h_h_mask_ratio) * params.mask_head_stride + + (params.mask_layout_is_k_based ? 0 : m_block * kBlockM * params.mask_row_stride) // No row offset for k-based + + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset; const index_t col_offset_bias = (block_table == nullptr) - ? binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb_cache) - + (bidh / params.h_h_bias_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN + ? binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb_cache, params.bias_layout_is_k_based) + + (bidh / params.h_h_bias_ratio) * params.bias_head_stride + + (params.bias_layout_is_k_based ? 0 : m_block * kBlockM * params.bias_row_stride) // No row offset for k-based + + (n_block_max - 1) * kBlockN : binfo.q_offset(/*batch_stride=*/index_t(0), params.bias_row_stride, bidb_cache) - + (bidh / params.h_h_bias_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset; + + (bidh / params.h_h_bias_ratio) * params.bias_head_stride + + (params.bias_layout_is_k_based ? 0 : m_block * kBlockM * params.bias_row_stride) // No row offset for k-based + + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset; // Global memory tensor configuration Tensor mQ = make_tensor( @@ -962,12 +978,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor gMask = make_tensor( make_gmem_ptr(reinterpret_cast(params.mask_ptr) + col_offset_mask), Shape, Int>{}, - make_stride(params.mask_row_stride, _1{}) + params.mask_layout_is_k_based + ? make_stride(_0{}, _1{}) // Broadcast across M (query) dimension for k-based layout + : make_stride(params.mask_row_stride, _1{}) ); Tensor gBias = make_tensor( make_gmem_ptr(reinterpret_cast(params.bias_ptr) + col_offset_bias), Shape, Int>{}, - make_stride(params.bias_row_stride, _1{}) + params.bias_layout_is_k_based + ? 
make_stride(_0{}, _1{}) // Broadcast across M (query) dimension for k-based layout + : make_stride(params.bias_row_stride, _1{}) ); // Shared memory layout configuration From 3953a4f3ef34aca569197f86f11d82d8f2d22389 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 10 Oct 2025 09:49:51 +0000 Subject: [PATCH 3/5] Uncomment and update mha_varlen_fwd with broadcastable mask/bias support Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com> --- csrc/flash_dmattn/flash_api.cpp | 468 ++++++++++++++++++-------------- 1 file changed, 267 insertions(+), 201 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 1249d21..dd2f007 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -557,227 +557,293 @@ mha_fwd( return {out, softmax_lse, p}; } +std::vector +mha_varlen_fwd( + at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + std::optional &mask_, // total_q x {1|num_heads_k|num_heads} x max_seqlen_k or total_k x {1|num_heads_k|num_heads} + std::optional &bias_, // total_q x {1|num_heads_k|num_heads} x max_seqlen_k or total_k x {1|num_heads_k|num_heads} + std::optional &out_, // total_q x num_heads x head_size + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + std::optional &leftpad_k_, // batch_size + std::optional &block_table_, // batch_size x max_num_blocks_per_seq + int max_seqlen_q, + const int max_seqlen_k, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + const float softcap, + const bool return_softmax +) { + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x_min = cc_major >= 8; + TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); -// std::vector -// mha_varlen_fwd( -// at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i -// const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. -// const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. -// const at::Tensor &mask, // total_q x {1|num_heads_k|num_heads} x max_seqlen_k or total_k x {1|num_heads_k|num_heads} -// const at::Tensor &bias, // total_q x {1|num_heads_k|num_heads} x max_seqlen_k or total_k x {1|num_heads_k|num_heads} -// std::optional &out_, // total_q x num_heads x head_size -// const at::Tensor &cu_seqlens_q, // b+1 -// const at::Tensor &cu_seqlens_k, // b+1 -// std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
-// std::optional &leftpad_k_, // batch_size -// std::optional &block_table_, // batch_size x max_num_blocks_per_seq -// int max_seqlen_q, -// const int max_seqlen_k, -// const float softmax_scale, -// const bool zero_tensors, -// bool is_causal, -// const float softcap, -// const bool return_softmax -// ) { -// // Otherwise the kernel will be launched from cuda:0 device -// at::cuda::CUDAGuard device_guard{q.device()}; -// auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); -// bool is_sm8x_min = cc_major >= 8; -// TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); -// auto q_dtype = q.dtype(); -// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); -// TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); -// TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); -// TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); -// TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); -// TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); -// TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); -// CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); -// CHECK_DEVICE(cu_seqlens_q); -// CHECK_DEVICE(cu_seqlens_k); - -// at::Tensor block_table; -// // const bool paged_KV = block_table_.has_value(); -// const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed. 
-// if (paged_KV) { -// block_table = block_table_.value(); -// CHECK_DEVICE(block_table); -// TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); -// TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); -// } + auto opts = q.options(); + + bool has_mask = mask_.has_value(); + at::Tensor mask; + bool mask_layout_is_k_based = false; + if (has_mask) { + mask = mask_.value(); + TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); + CHECK_DEVICE(mask); + TORCH_CHECK(mask.stride(-1) == 1, "mask must have contiguous last dimension"); + } else { + mask = torch::empty({0}, opts); + } + + bool has_bias = bias_.has_value(); + at::Tensor bias; + bool bias_layout_is_k_based = false; + if (has_bias) { + bias = bias_.value(); + TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); + CHECK_DEVICE(bias); + TORCH_CHECK(bias.stride(-1) == 1, "bias must have contiguous last dimension"); + } else { + bias = torch::empty({0}, opts); + } -// TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// CHECK_CONTIGUOUS(cu_seqlens_q); -// CHECK_CONTIGUOUS(cu_seqlens_k); + at::Tensor block_table; + // const bool paged_KV = block_table_.has_value(); + const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed. + if (paged_KV) { + block_table = block_table_.value(); + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + } -// const auto sizes = q.sizes(); + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); -// const int batch_size = cu_seqlens_q.numel() - 1; -// int num_heads = sizes[1]; -// const int head_size = sizes[2]; -// const int num_heads_k = paged_KV ? k.size(2) : k.size(1); - -// const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); -// const int num_blocks = !paged_KV ? 0 : k.size(0); -// const int page_block_size = !paged_KV ? 
1 : k.size(1); -// TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); - -// if (max_seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case - -// void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); - -// // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case -// // H/t Daniel Haziza -// const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; -// const int ngroups = num_heads / num_heads_k; -// if (seqlenq_ngroups_swapped) { -// q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); -// max_seqlen_q = ngroups; -// num_heads = num_heads_k; -// cu_seqlens_q_d = nullptr; -// } + const auto sizes = q.sizes(); -// const int total_q = q.sizes()[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size = sizes[2]; + const int num_heads_k = paged_KV ? k.size(2) : k.size(1); -// TORCH_CHECK(batch_size > 0, "batch size must be positive"); -// TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); -// TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); -// TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); + const int num_blocks = !paged_KV ? 0 : k.size(0); + const int page_block_size = !paged_KV ? 1 : k.size(1); + TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); -// CHECK_SHAPE(q, total_q, num_heads, head_size); -// if (!paged_KV) { -// const int total_k = k.size(0); -// CHECK_SHAPE(k, total_k, num_heads_k, head_size); -// CHECK_SHAPE(v, total_k, num_heads_k, head_size); -// CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); -// CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); -// } else { -// CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); -// CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); -// CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); -// } + if (max_seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case -// CHECK_SHAPE(cu_seqlens_q, batch_size + 1); -// CHECK_SHAPE(cu_seqlens_k, batch_size + 1); -// if (seqused_k.has_value()){ -// auto seqused_k_ = seqused_k.value(); -// TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); -// TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); -// TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); -// CHECK_SHAPE(seqused_k_, batch_size); -// } + void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); -// at::Tensor out; -// if (out_.has_value()) { -// out = out_.value(); -// TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); -// CHECK_DEVICE(out); -// TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); -// CHECK_SHAPE(out, sizes[0], sizes[1], head_size); -// if (seqlenq_ngroups_swapped) { -// out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); -// } -// } else { -// out = torch::empty_like(q); -// } + // Faster to transpose q from 
(b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; + const int ngroups = num_heads / num_heads_k; + if (seqlenq_ngroups_swapped) { + q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); + max_seqlen_q = ngroups; + num_heads = num_heads_k; + cu_seqlens_q_d = nullptr; + } -// auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; -// const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); -// const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); -// const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + const int total_q = q.sizes()[0]; -// auto opts = q.options(); -// auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); -// at::Tensor p; + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); + TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); -// if (return_softmax) { -// p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); -// } else { -// p = torch::empty({ 0 }, opts); -// } + CHECK_SHAPE(q, total_q, num_heads, head_size); + if (!paged_KV) { + const int total_k = k.size(0); + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size); + + // Detect mask/bias layout based on first dimension + if (has_mask) { + if (mask.dim() == 2 && mask.size(0) == total_k) { + // total_k-based layout: (total_k, num_heads_variant) + mask_layout_is_k_based = true; + int num_heads_mask = mask.size(1); + TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, + "Number of heads in k-based mask must be 1, h_k or h"); + } else if (mask.dim() == 3 && mask.size(0) == total_q) { + // total_q-based layout: (total_q, num_heads_variant, max_seqlen_k) + mask_layout_is_k_based = false; + int num_heads_mask = mask.size(1); + TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, + "Number of heads in q-based mask must be 1, h_k or h"); + CHECK_SHAPE(mask, total_q, num_heads_mask, max_seqlen_k); + } else { + TORCH_CHECK(false, "mask must be (total_k, num_heads_variant) or (total_q, num_heads_variant, max_seqlen_k)"); + } + } + + if (has_bias) { + if (bias.dim() == 2 && bias.size(0) == total_k) { + // total_k-based layout: (total_k, num_heads_variant) + bias_layout_is_k_based = true; + int num_heads_bias = bias.size(1); + TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, + "Number of heads in k-based bias must be 1, h_k or h"); + } else if (bias.dim() == 3 && bias.size(0) == total_q) { + // total_q-based layout: (total_q, num_heads_variant, max_seqlen_k) + bias_layout_is_k_based = false; + int num_heads_bias = bias.size(1); + TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, + "Number of heads in q-based bias must be 1, h_k or h"); + CHECK_SHAPE(bias, total_q, num_heads_bias, max_seqlen_k); + } else { + TORCH_CHECK(false, "bias must be (total_k, num_heads_variant) or 
(total_q, num_heads_variant, max_seqlen_k)"); + } + } + } else { + CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + } -// if (zero_tensors) { -// out.zero_(); -// softmax_lse.fill_(-std::numeric_limits::infinity()); -// if (return_softmax) { p.zero_(); } -// } + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (seqused_k.has_value()){ + auto seqused_k_ = seqused_k.value(); + TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); + TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); + CHECK_SHAPE(seqused_k_, batch_size); + } -// Flash_fwd_params params; -// set_params_fprop( -// params, -// batch_size, -// max_seqlen_q, max_seqlen_k, -// seqlen_q_rounded, seqlen_k_rounded, -// num_heads, num_heads_k, -// head_size, head_size_rounded, -// q, k, v, mask, bias, out, -// cu_seqlens_q_d, -// cu_seqlens_k.data_ptr(), -// seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, -// return_softmax ? p.data_ptr() : nullptr, -// softmax_lse.data_ptr(), -// softmax_scale, -// is_causal, -// softcap, -// seqlenq_ngroups_swapped, -// /*unpadded_lse*/true -// ); -// params.total_q = total_q; + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, sizes[0], sizes[1], head_size); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); + } + } else { + out = torch::empty_like(q); + } -// if (paged_KV) { -// params.block_table = block_table.data_ptr(); -// params.block_table_batch_stride = block_table.stride(0); -// params.k_batch_stride = k.stride(0); -// params.v_batch_stride = v.stride(0); -// } -// params.page_block_size = page_block_size; -// // Keep references to these tensors to extend their lifetime -// at::Tensor softmax_lse_accum, out_accum; -// if (seqlenq_ngroups_swapped) { -// // Only apply split-k for decoding -// std::tie(softmax_lse_accum, out_accum) = -// set_params_splitkv( -// params, batch_size, num_heads, head_size, -// max_seqlen_k, max_seqlen_q, head_size_rounded, -// /*num_splits*/ 0, get_num_sm(get_current_device()), opts -// ); -// } + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); -// if (leftpad_k_.has_value()) { -// auto leftpad_k = leftpad_k_.value(); -// TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); -// TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); -// CHECK_DEVICE(leftpad_k); -// CHECK_CONTIGUOUS(leftpad_k); -// CHECK_SHAPE(leftpad_k, batch_size); -// params.leftpad_k = static_cast(leftpad_k.data_ptr()); -// } + auto opts = q.options(); + auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + at::Tensor p; -// if (max_seqlen_k > 0) { -// auto stream = at::cuda::getCurrentCUDAStream().stream(); -// run_mha_fwd(params, stream, paged_KV); -// } else { -// // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. -// out.zero_(); -// softmax_lse.fill_(std::numeric_limits::infinity()); -// } + if (return_softmax) { + p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + } else { + p = torch::empty({ 0 }, opts); + } -// if (seqlenq_ngroups_swapped) { -// int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size}; -// int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size}; -// out = out.reshape(size_before).transpose(1, 2).reshape(size_after); -// q = q.reshape(size_before).transpose(1, 2).reshape(size_after); -// softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); -// } + if (zero_tensors) { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_softmax) { p.zero_(); } + } -// return {out, softmax_lse, p}; -// } + int num_heads_mask = has_mask ? mask.size(mask_layout_is_k_based ? 1 : 1) : 1; + int num_heads_bias = has_bias ? bias.size(bias_layout_is_k_based ? 1 : 1) : 1; + + Flash_fwd_params params; + set_params_fprop( + params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + num_heads_mask, num_heads_bias, + head_size, head_size_rounded, + q, k, v, mask, bias, out, + cu_seqlens_q_d, + cu_seqlens_k.data_ptr(), + seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, + return_softmax ? 
p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + softmax_scale, + is_causal, + softcap, + has_mask, + has_bias, + mask_layout_is_k_based, + bias_layout_is_k_based, + seqlenq_ngroups_swapped, + /*unpadded_lse*/true + ); + params.total_q = total_q; + + if (paged_KV) { + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + } + params.page_block_size = page_block_size; + // Keep references to these tensors to extend their lifetime + at::Tensor softmax_lse_accum, out_accum; + if (seqlenq_ngroups_swapped) { + // Only apply split-k for decoding + std::tie(softmax_lse_accum, out_accum) = + set_params_splitkv( + params, batch_size, num_heads, head_size, + max_seqlen_k, max_seqlen_q, head_size_rounded, + /*num_splits*/ 0, get_num_sm(get_current_device()), opts + ); + } + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); + CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + params.leftpad_k = static_cast(leftpad_k.data_ptr()); + } + + if (max_seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream, paged_KV); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + if (seqlenq_ngroups_swapped) { + int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size}; + int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size}; + out = out.reshape(size_before).transpose(1, 2).reshape(size_after); + q = q.reshape(size_before).transpose(1, 2).reshape(size_after); + softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); + } + + return {out, softmax_lse, p}; +} void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { FP16_SWITCH(!params.is_bf16, [&] { @@ -1302,7 +1368,7 @@ mha_bwd( PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashDynamicMaskAttention"; m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass"); - // m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length"); + m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length"); m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass"); // m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length"); } From f69933d555b98d68462595979246bb75e474e74e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 10 Oct 2025 09:53:28 +0000 Subject: [PATCH 4/5] Add documentation and examples for broadcastable mask/bias feature Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com> --- docs/varlen_broadcastable_mask.md | 211 +++++++++++++++++++ examples/varlen_broadcastable_example.py | 256 +++++++++++++++++++++++ 2 files changed, 467 insertions(+) create mode 100644 docs/varlen_broadcastable_mask.md create mode 100644 examples/varlen_broadcastable_example.py diff --git a/docs/varlen_broadcastable_mask.md b/docs/varlen_broadcastable_mask.md new file mode 100644 index 0000000..8eb477a --- /dev/null +++ b/docs/varlen_broadcastable_mask.md @@ -0,0 +1,211 @@ +# Variable Length Batch Inference with 
Broadcastable Key-based Masks/Bias + +## Overview + +This feature enables efficient batch inference with variable-length sequences using key-side broadcastable masks and bias tensors. Instead of materializing per-query masks/bias of shape `(total_q, num_heads, max_seqlen_k)`, you can now provide compact key-side tensors of shape `(total_k, num_heads_variant)` that broadcast across query positions. + +## Motivation + +In autoregressive decoding with dynamic sparsity: +- Queries are typically short (1-8 tokens per batch element) +- Keys/values can be thousands of tokens from the KV cache +- Precomputed key-side gating scores are naturally shaped `(total_k, num_heads)` +- Reshaping to per-query layout wastes O(total_q * num_heads) memory +- Streaming workloads cannot backfill materialized copies + +## Supported Layouts + +### Traditional Query-based Layout (existing) +```python +# Mask: (total_q, {1|num_heads_k|num_heads}, max_seqlen_k) +# Bias: (total_q, {1|num_heads_k|num_heads}, max_seqlen_k) +``` +Each query position has its own mask/bias slice. This is the default when the first dimension equals `total_q`. + +### New Key-based Broadcastable Layout +```python +# Mask: (total_k, {1|num_heads_k|num_heads}) +# Bias: (total_k, {1|num_heads_k|num_heads}) +``` +A single mask/bias value per key position, broadcast across all query positions. Automatically detected when the first dimension equals `total_k`. + +## Usage Example + +```python +import torch +from flash_dmattn import flash_dmattn_varlen_func + +batch_size = 4 +max_seqlen_q = 2 # Typical for decoding +max_seqlen_k = 1024 # Large KV cache +num_heads = 32 +num_heads_k = 8 # GQA +head_dim = 128 + +# Create variable length sequences +cu_seqlens_q = torch.tensor([0, 1, 3, 4, 6], dtype=torch.int32, device='cuda') # total_q = 6 +cu_seqlens_k = torch.tensor([0, 256, 512, 768, 1024], dtype=torch.int32, device='cuda') # total_k = 1024 + +# Query, key, value tensors +q = torch.randn(6, num_heads, head_dim, dtype=torch.float16, device='cuda') +k = torch.randn(1024, num_heads_k, head_dim, dtype=torch.float16, device='cuda') +v = torch.randn(1024, num_heads_k, head_dim, dtype=torch.float16, device='cuda') + +# Key-based broadcastable mask and bias (NEW!) +# Shape: (total_k, num_heads_variant) - broadcasts across query positions +attn_mask = torch.randint(0, 2, (1024, num_heads_k), dtype=torch.bool, device='cuda') +attn_bias = torch.randn(1024, num_heads_k, dtype=torch.float16, device='cuda') + +# Call varlen function - layout detection is automatic +output = flash_dmattn_varlen_func( + query=q, + key=k, + value=v, + attn_mask=attn_mask, # Automatically detected as k-based + attn_bias=attn_bias, # Automatically detected as k-based + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, +) +``` + +## Layout Detection + +The implementation automatically detects which layout is being used: + +```python +if mask.dim() == 2 and mask.size(0) == total_k: + # Key-based layout: (total_k, num_heads_variant) + # Broadcast across query positions +elif mask.dim() == 3 and mask.size(0) == total_q: + # Query-based layout: (total_q, num_heads_variant, max_seqlen_k) + # Per-query mask slices +``` + +The same logic applies independently to both mask and bias tensors. 
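+
+The same check can be mirrored on the caller side to validate shapes before dispatch. The helper below is a minimal illustrative sketch (the `detect_layout` name is not part of the library API); it assumes only the packed varlen shapes described above:
+
+```python
+import torch
+
+def detect_layout(t: torch.Tensor, total_q: int, total_k: int) -> str:
+    """Classify a mask/bias tensor as key-based ("k") or query-based ("q")."""
+    if t.dim() == 2 and t.size(0) == total_k:
+        return "k"  # (total_k, num_heads_variant): broadcast across query positions
+    if t.dim() == 3 and t.size(0) == total_q:
+        return "q"  # (total_q, num_heads_variant, max_seqlen_k): per-query slices
+    raise ValueError(f"unsupported mask/bias shape {tuple(t.shape)}")
+```
+
+Running it on both `attn_mask` and `attn_bias` makes mixed layouts explicit before the kernel call.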
+
+## Head Dimension Broadcasting
+
+Both layouts support flexible head dimensions:
+- `1`: Single mask/bias for all heads (broadcast across all heads)
+- `num_heads_k`: One per KV head (broadcast across query head groups in GQA)
+- `num_heads`: One per query head (no broadcasting)
+
+## Performance Benefits
+
+### Memory Savings
+For typical decoding scenarios:
+- Query-based: `total_q × num_heads × max_seqlen_k` elements
+- Key-based: `total_k × num_heads` elements
+- **Savings**: roughly a `(total_q × max_seqlen_k) / total_k` reduction factor
+
+Example: With `total_q=8`, `total_k=2048`, `max_seqlen_k=2048`, `num_heads=32`:
+- Query-based: 524,288 elements
+- Key-based: 65,536 elements
+- **87.5% memory reduction**
+
+### Computational Efficiency
+- No host-side tensor reshaping or copying
+- Direct key-side indexing in CUDA kernels
+- Maintains streaming-friendly data layout
+- Zero materialization overhead
+
+## Implementation Details
+
+### Kernel Changes
+The CUDA kernels handle broadcasting by:
+1. Using stride `_0{}` for the query dimension in key-based tensors
+2. Adjusting offset calculations to skip query-position indexing
+3. Reading from the same key-position value for all queries
+
+```cpp
+// Key-based layout
+Tensor gMask = make_tensor(
+    ptr,
+    Shape<Int<kBlockM>, Int<kBlockN>>{},
+    make_stride(_0{}, _1{})  // Zero stride = broadcast across M
+);
+
+// Query-based layout
+Tensor gMask = make_tensor(
+    ptr,
+    Shape<Int<kBlockM>, Int<kBlockN>>{},
+    make_stride(mask_row_stride, _1{})  // Normal 2D indexing
+);
+```
+
+### API Parameters
+
+The C++ API signature is:
+```cpp
+std::vector<at::Tensor> mha_varlen_fwd(
+    at::Tensor &q,                      // total_q x num_heads x head_size
+    const at::Tensor &k,                // total_k x num_heads_k x head_size
+    const at::Tensor &v,                // total_k x num_heads_k x head_size
+    std::optional<at::Tensor> &mask_,   // (total_q, h, k) or (total_k, h)
+    std::optional<at::Tensor> &bias_,   // (total_q, h, k) or (total_k, h)
+    std::optional<at::Tensor> &out_,    // total_q x num_heads x head_size
+    const at::Tensor &cu_seqlens_q,     // batch_size + 1
+    const at::Tensor &cu_seqlens_k,     // batch_size + 1
+    ...
+);
+```
+
+Layout detection happens automatically based on tensor shapes.
+
+## Use Cases
+
+### Autoregressive Decoding with Dynamic Sparsity
+```python
+# Precompute key-side attention scores from dependency graph
+key_scores = compute_dependency_scores(kv_cache)  # (total_k, num_heads)
+key_mask = key_scores > threshold
+
+# Use directly without reshaping
+output = flash_dmattn_varlen_func(..., attn_mask=key_mask)
+```
+
+### Batch Decode with Shared Key Filtering
+```python
+# Apply same key filtering to all queries in batch
+key_importance = model.compute_key_importance(keys)  # (total_k, 1)
+key_mask = key_importance > threshold
+
+# Broadcast to all heads
+output = flash_dmattn_varlen_func(..., attn_mask=key_mask)
+```
+
+### MaskMod Pipelines
+```python
+# Dependency-aware masking from MaskMod
+from torch.nn.attention.flex_attention import create_mask
+
+# Generate key-side mask efficiently
+key_mask = create_mask_mod_k_based(...) 
# (total_k, num_heads) + +# Direct usage without conversion +output = flash_dmattn_varlen_func(..., attn_mask=key_mask) +``` + +## Limitations + +- Only supported in `mha_varlen_fwd` (variable length forward pass) +- Backward pass (gradient computation) uses query-based layout +- Paged KV cache support is experimental +- Both mask and bias can independently use either layout + +## Compatibility + +- GPU: Requires Ampere (SM80) or newer +- PyTorch: Compatible with existing Flash Attention interfaces +- Mixed Layouts: Mask and bias can use different layouts in the same call +- GQA/MQA: Full support for grouped-query and multi-query attention + +## Related Work + +This feature aligns with: +- Sparse attention patterns in modern LLMs +- Efficient KV cache management +- Streaming inference workloads +- MaskMod and FlexAttention paradigms diff --git a/examples/varlen_broadcastable_example.py b/examples/varlen_broadcastable_example.py new file mode 100644 index 0000000..6b81636 --- /dev/null +++ b/examples/varlen_broadcastable_example.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 +""" +Example demonstrating variable-length batch inference with broadcastable key-based masks and bias. + +This example shows how to use the new total_k-based mask/bias layout for efficient decoding +with variable-length sequences and large KV caches. +""" + +import torch +from flash_dmattn import flash_dmattn_varlen_func + + +def create_varlen_sequences(batch_size, max_seqlen_q, max_seqlen_k): + """Create variable-length sequences for demonstration.""" + import random + + # Generate random sequence lengths + seqlens_q = [random.randint(1, max_seqlen_q) for _ in range(batch_size)] + seqlens_k = [random.randint(max_seqlen_k // 2, max_seqlen_k) for _ in range(batch_size)] + + # Create cumulative sequence length tensors + cu_seqlens_q = torch.tensor([0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()), + dtype=torch.int32, device='cuda') + cu_seqlens_k = torch.tensor([0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()), + dtype=torch.int32, device='cuda') + + total_q = cu_seqlens_q[-1].item() + total_k = cu_seqlens_k[-1].item() + + return cu_seqlens_q, cu_seqlens_k, total_q, total_k, max(seqlens_q), max(seqlens_k) + + +def example_k_based_mask_bias(): + """Example using key-based broadcastable mask and bias (NEW FEATURE).""" + print("=" * 80) + print("Example: Key-based Broadcastable Mask and Bias") + print("=" * 80) + + # Configuration + batch_size = 4 + max_seqlen_q = 8 # Short queries (typical for decoding) + max_seqlen_k = 1024 # Large KV cache + num_heads = 32 + num_heads_k = 8 # Grouped-query attention + head_dim = 128 + + print(f"Batch size: {batch_size}") + print(f"Max query length: {max_seqlen_q}") + print(f"Max key length: {max_seqlen_k}") + print(f"Num query heads: {num_heads}") + print(f"Num KV heads: {num_heads_k}") + print(f"Head dimension: {head_dim}") + print() + + # Create variable-length sequences + cu_seqlens_q, cu_seqlens_k, total_q, total_k, actual_max_q, actual_max_k = \ + create_varlen_sequences(batch_size, max_seqlen_q, max_seqlen_k) + + print(f"Total query tokens: {total_q}") + print(f"Total key tokens: {total_k}") + print(f"Actual max query length: {actual_max_q}") + print(f"Actual max key length: {actual_max_k}") + print() + + # Create query, key, value tensors + q = torch.randn(total_q, num_heads, head_dim, dtype=torch.float16, device='cuda') + k = torch.randn(total_k, num_heads_k, head_dim, dtype=torch.float16, device='cuda') + v = torch.randn(total_k, num_heads_k, head_dim, 
dtype=torch.float16, device='cuda') + + # KEY FEATURE: Key-based broadcastable mask and bias + # Shape: (total_k, num_heads_k) - broadcasts across ALL query positions + # This saves memory compared to (total_q, num_heads_k, max_seqlen_k) + attn_mask = torch.randint(0, 2, (total_k, num_heads_k), dtype=torch.bool, device='cuda') + attn_bias = torch.randn(total_k, num_heads_k, dtype=torch.float16, device='cuda') + + print(f"Query shape: {q.shape}") + print(f"Key shape: {k.shape}") + print(f"Value shape: {v.shape}") + print(f"Mask shape (key-based): {attn_mask.shape}") + print(f"Bias shape (key-based): {attn_bias.shape}") + print() + + # Memory comparison + q_based_elements = total_q * num_heads_k * actual_max_k + k_based_elements = total_k * num_heads_k + memory_saving = (1 - k_based_elements / q_based_elements) * 100 + + print("Memory comparison:") + print(f" Query-based layout would need: {q_based_elements:,} elements") + print(f" Key-based layout needs: {k_based_elements:,} elements") + print(f" Memory savings: {memory_saving:.1f}%") + print() + + # Call flash attention with automatic layout detection + print("Running flash attention with key-based mask/bias...") + output = flash_dmattn_varlen_func( + query=q, + key=k, + value=v, + attn_mask=attn_mask, # Automatically detected as key-based + attn_bias=attn_bias, # Automatically detected as key-based + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=actual_max_q, + max_seqlen_k=actual_max_k, + softmax_scale=1.0 / (head_dim ** 0.5), + ) + + print(f"Output shape: {output.shape}") + print("Success! ✓") + print() + + +def example_q_based_mask_bias(): + """Example using traditional query-based mask and bias (for comparison).""" + print("=" * 80) + print("Example: Traditional Query-based Mask and Bias") + print("=" * 80) + + # Configuration + batch_size = 4 + max_seqlen_q = 8 + max_seqlen_k = 1024 + num_heads = 32 + num_heads_k = 8 + head_dim = 128 + + # Create variable-length sequences + cu_seqlens_q, cu_seqlens_k, total_q, total_k, actual_max_q, actual_max_k = \ + create_varlen_sequences(batch_size, max_seqlen_q, max_seqlen_k) + + print(f"Total query tokens: {total_q}") + print(f"Total key tokens: {total_k}") + print() + + # Create query, key, value tensors + q = torch.randn(total_q, num_heads, head_dim, dtype=torch.float16, device='cuda') + k = torch.randn(total_k, num_heads_k, head_dim, dtype=torch.float16, device='cuda') + v = torch.randn(total_k, num_heads_k, head_dim, dtype=torch.float16, device='cuda') + + # Traditional query-based mask and bias + # Shape: (total_q, num_heads_k, max_seqlen_k) + attn_mask = torch.randint(0, 2, (total_q, num_heads_k, actual_max_k), + dtype=torch.bool, device='cuda') + attn_bias = torch.randn(total_q, num_heads_k, actual_max_k, + dtype=torch.float16, device='cuda') + + print(f"Query shape: {q.shape}") + print(f"Key shape: {k.shape}") + print(f"Value shape: {v.shape}") + print(f"Mask shape (query-based): {attn_mask.shape}") + print(f"Bias shape (query-based): {attn_bias.shape}") + print() + + # Call flash attention + print("Running flash attention with query-based mask/bias...") + output = flash_dmattn_varlen_func( + query=q, + key=k, + value=v, + attn_mask=attn_mask, + attn_bias=attn_bias, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=actual_max_q, + max_seqlen_k=actual_max_k, + softmax_scale=1.0 / (head_dim ** 0.5), + ) + + print(f"Output shape: {output.shape}") + print("Success! 
✓") + print() + + +def example_mixed_layouts(): + """Example using key-based mask with query-based bias (mixed layouts).""" + print("=" * 80) + print("Example: Mixed Layouts - Key-based Mask + Query-based Bias") + print("=" * 80) + + # Configuration + batch_size = 2 + max_seqlen_q = 4 + max_seqlen_k = 512 + num_heads = 16 + num_heads_k = 4 + head_dim = 64 + + # Create variable-length sequences + cu_seqlens_q, cu_seqlens_k, total_q, total_k, actual_max_q, actual_max_k = \ + create_varlen_sequences(batch_size, max_seqlen_q, max_seqlen_k) + + print(f"Total query tokens: {total_q}") + print(f"Total key tokens: {total_k}") + print() + + # Create query, key, value tensors + q = torch.randn(total_q, num_heads, head_dim, dtype=torch.float16, device='cuda') + k = torch.randn(total_k, num_heads_k, head_dim, dtype=torch.float16, device='cuda') + v = torch.randn(total_k, num_heads_k, head_dim, dtype=torch.float16, device='cuda') + + # Key-based mask (broadcast across queries) + attn_mask = torch.randint(0, 2, (total_k, num_heads_k), dtype=torch.bool, device='cuda') + + # Query-based bias (per-query values) + attn_bias = torch.randn(total_q, num_heads_k, actual_max_k, dtype=torch.float16, device='cuda') + + print(f"Mask shape (key-based): {attn_mask.shape}") + print(f"Bias shape (query-based): {attn_bias.shape}") + print() + + # Call flash attention - each tensor uses its own layout + print("Running flash attention with mixed layouts...") + output = flash_dmattn_varlen_func( + query=q, + key=k, + value=v, + attn_mask=attn_mask, # Key-based + attn_bias=attn_bias, # Query-based + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=actual_max_q, + max_seqlen_k=actual_max_k, + softmax_scale=1.0 / (head_dim ** 0.5), + ) + + print(f"Output shape: {output.shape}") + print("Success! ✓") + print() + + +def main(): + """Run all examples.""" + if not torch.cuda.is_available(): + print("CUDA is not available. 
These examples require a CUDA-capable GPU.") + return + + print("Flash Dynamic Mask Attention - Variable Length Broadcastable Examples") + print() + + # Example 1: Key-based mask and bias (NEW FEATURE) + example_k_based_mask_bias() + + # Example 2: Traditional query-based (for comparison) + example_q_based_mask_bias() + + # Example 3: Mixed layouts + example_mixed_layouts() + + print("=" * 80) + print("All examples completed successfully!") + print("=" * 80) + + +if __name__ == "__main__": + main() From 54d05fc1bc43b59c59860ac3ffb327d9b665db2f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 10 Oct 2025 09:54:48 +0000 Subject: [PATCH 5/5] Update README with new broadcastable mask/bias feature Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com> --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 13c80f7..d3960fa 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Flash-DMA is a high-performance attention implementation that integrates Flash A ### 🎯 Core Kernel Advantages - **Mask & Bias Support**: Native support for `(batch_size, {1|num_kv_heads|num_heads}, {0|query_len}, key_len)` shaped attention mask and attention bias tensors +- **Broadcastable Key-based Layouts** 🆕: Variable-length batch inference now supports efficient `(total_k, num_heads_variant)` mask/bias tensors that broadcast across query positions, eliminating redundant materialization for autoregressive decoding - **Intelligent Computation Skipping**: Block-level automatic skipping mechanism based on masks, completely bypassing computation and memory access for zero-mask blocks - **Complete Gradient Support**: Built-in full gradient computation path for attention bias, supporting end-to-end training @@ -254,6 +255,7 @@ Flash-DMA integrates the efficient memory access patterns of Flash Attention wit - **[API Reference](docs/api_reference.md)** - Complete function documentation and usage examples - **[Integration Guide](docs/integration.md)** - Detailed technical documentation of the Flash Attention integration +- **[Variable Length Broadcastable Mask/Bias](docs/varlen_broadcastable_mask.md)** 🆕 - Guide to using key-based broadcastable masks and bias for efficient batch inference ## Building from Source