diff --git a/.github/ISSUE_TEMPLATE/performance_issue.yml b/.github/ISSUE_TEMPLATE/performance_issue.yml index c3ecca4..2586dbe 100644 --- a/.github/ISSUE_TEMPLATE/performance_issue.yml +++ b/.github/ISSUE_TEMPLATE/performance_issue.yml @@ -1,5 +1,5 @@ name: Performance issue -description: Report performance problems or optimisation opportunities +description: Report performance problems or optimization opportunities title: "[PERFORMANCE] " labels: - performance diff --git a/.github/PULL_REQUEST_TEMPLATE/performance_optimization.yml b/.github/PULL_REQUEST_TEMPLATE/performance_optimization.yml index 5a00c15..8d6f0b0 100644 --- a/.github/PULL_REQUEST_TEMPLATE/performance_optimization.yml +++ b/.github/PULL_REQUEST_TEMPLATE/performance_optimization.yml @@ -7,7 +7,7 @@ body: - type: markdown attributes: value: | - Document the optimisation, methodology, and results so reviewers can validate gains and correctness. + Document the optimization, methodology, and results so reviewers can validate gains and correctness. - type: textarea id: summary attributes: diff --git a/README.md b/README.md index 13c80f7..89c6082 100644 --- a/README.md +++ b/README.md @@ -195,7 +195,7 @@ output = flash_dmattn_func( attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, - scale=1.0/math.sqrt(head_dim), + softmax_scale=1.0/math.sqrt(head_dim), ) print(f"Output shape: {output.shape}") # [1, 256, 2, 64] @@ -216,7 +216,7 @@ output = flash_dmattn_func( attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, - scale=1.0/math.sqrt(head_dim) + softmax_scale=1.0/math.sqrt(head_dim) ) # Backward pass diff --git a/README_zh.md b/README_zh.md index 484a1a8..8e16c41 100644 --- a/README_zh.md +++ b/README_zh.md @@ -195,7 +195,7 @@ output = flash_dmattn_func( attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, - scale=1.0/math.sqrt(head_dim), + softmax_scale=1.0/math.sqrt(head_dim), ) print(f"输出形状: {output.shape}") # [1, 256, 2, 64] @@ -216,7 +216,7 @@ output = flash_dmattn_func( attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, - scale=1.0/math.sqrt(head_dim) + softmax_scale=1.0/math.sqrt(head_dim) ) # 反向传播 diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index 804433f..fd8ebef 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -266,7 +266,7 @@ def dynamic_mask_attention_cuda( attn_mask=attn_mask, # mask: [batch, num_kv_heads, query_len, key_len] attn_bias=attn_bias, # bias: [batch, num_kv_heads, query_len, key_len] is_causal=is_causal, # causal masking - scale=scaling, # scaling factor + softmax_scale=scaling, # scaling factor softcap=0.0, deterministic=False, return_attn_probs=False @@ -351,7 +351,7 @@ def dynamic_mask_attention_triton( attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] is_causal=is_causal, # causal masking - scale=scaling # scaling factor + softmax_scale=scaling # scaling factor ) # Backward pass @@ -424,7 +424,7 @@ def dynamic_mask_attention_flex( attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len] is_causal=is_causal, # is_causal: whether to apply causal masking - scale=scaling # scaling factor + softmax_scale=scaling # scaling factor ) # Backward pass diff --git a/benchmarks/backward_performance.py b/benchmarks/backward_performance.py index afccef9..03d5018 100644 --- a/benchmarks/backward_performance.py +++ b/benchmarks/backward_performance.py @@ -183,7 +183,7 @@ def scaled_dot_product_attention_backward( key_states, # [batch, num_kv_heads, key_len, head_dim] value_states, # [batch, num_kv_heads, key_len, head_dim] attn_mask=causal_mask, - scale=scaling, + softmax_scale=scaling, # is_causal=is_causal if query_len == key_len else False, enable_gqa=True ) @@ -262,7 +262,7 @@ def dynamic_mask_attention_backward_cuda( attn_mask=attn_mask, # mask: [batch, num_kv_heads, query_len, key_len] attn_bias=attn_bias, # bias: [batch, num_kv_heads, query_len, key_len] is_causal=is_causal, # causal masking - scale=scaling, # scaling factor + softmax_scale=scaling, # scaling factor softcap=0.0, deterministic=False, return_attn_probs=False @@ -351,7 +351,7 @@ def dynamic_mask_attention_backward_triton( attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] is_causal=is_causal, # causal masking - scale=scaling # scaling factor + softmax_scale=scaling # scaling factor ) torch.cuda.synchronize() @@ -433,7 +433,7 @@ def dynamic_mask_attention_backward_flex( attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len] is_causal=is_causal, # is_causal: whether to apply causal masking - scale=scaling # scaling factor + softmax_scale=scaling # scaling factor ) torch.cuda.synchronize() diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index 94fc0cf..1da6f2a 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -253,7 +253,7 @@ def dynamic_mask_attention_cuda( attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len] attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len] is_causal=is_causal, - scale=scaling, + softmax_scale=scaling, softcap=0.0, deterministic=True, return_attn_probs=return_softmax @@ -329,7 +329,7 @@ def dynamic_mask_attention_triton( attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] is_causal=is_causal, # causal masking - scale=scaling # scaling factor + softmax_scale=scaling # scaling factor ) return attn_outputs # [batch, query_len, num_heads, head_dim] @@ -398,7 +398,7 @@ def dynamic_mask_attention_flex( attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len] is_causal=is_causal, # is_causal: whether to apply causal masking - scale=scaling # scaling factor + softmax_scale=scaling # scaling factor ) return attn_outputs # [batch, query_len, num_heads, head_dim] diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py index b540bf6..0d48c1a 100644 --- a/benchmarks/forward_performance.py +++ b/benchmarks/forward_performance.py @@ -186,7 +186,7 @@ def scaled_dot_product_attention_cuda( key_states, value_states, attn_mask=causal_mask, - scale=scaling, + softmax_scale=scaling, # is_causal=is_causal if query_len == key_len else False, enable_gqa=True ) @@ -262,7 +262,7 @@ def dynamic_mask_attention_cuda( attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len] attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len] is_causal=is_causal, - scale=scaling, + softmax_scale=scaling, softcap=0.0, deterministic=False, return_attn_probs=return_softmax @@ -348,7 +348,7 @@ def dynamic_mask_attention_triton( attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] is_causal=is_causal, # causal masking - scale=scaling # scaling factor + softmax_scale=scaling # scaling factor ) torch.cuda.synchronize() @@ -427,7 +427,7 @@ def dynamic_mask_attention_flex( attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len] is_causal=is_causal, # is_causal: whether to apply causal masking - scale=scaling # scaling factor + softmax_scale=scaling # scaling factor ) torch.cuda.synchronize() diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index d26ecde..0cc7cdf 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -350,17 +350,18 @@ std::tuple set_params_splitkv( std::vector mha_fwd( - at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k - std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k - std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const float softmax_scale, bool is_causal, const float softcap, const bool return_softmax ) { + // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); @@ -443,7 +444,7 @@ mha_fwd( if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); if (has_mask) { - mask = num_heads_mask == 1 + mask = num_heads_mask == 1 ? mask.expand({batch_size, 1, ngroups, seqlen_k}) : ( num_heads_mask == num_heads_k @@ -452,7 +453,7 @@ mha_fwd( ); } if (has_bias) { - bias = num_heads_bias == 1 + bias = num_heads_bias == 1 ? bias.expand({batch_size, 1, ngroups, seqlen_k}) : ( num_heads_bias == num_heads_k @@ -493,9 +494,9 @@ mha_fwd( at::Tensor p; if (return_softmax) { - p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + p = torch::empty({batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts); } else { - p = torch::empty({ 0 }, opts); + p = torch::empty({0}, opts); } Flash_fwd_params params; @@ -553,227 +554,227 @@ mha_fwd( return {out, softmax_lse, p}; } +std::vector +mha_varlen_fwd( + at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + std::optional &out_, // total_q x num_heads x head_size + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + std::optional &leftpad_k_, // batch_size + std::optional &block_table_, // batch_size x max_num_blocks_per_seq + int max_seqlen_q, + const int max_seqlen_k, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + const float softcap, + const bool return_softmax +) { -// std::vector -// mha_varlen_fwd( -// at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i -// const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. -// const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. -// const at::Tensor &mask, // total_q x {1|num_heads_k|num_heads} x max_seqlen_k or total_k x {1|num_heads_k|num_heads} -// const at::Tensor &bias, // total_q x {1|num_heads_k|num_heads} x max_seqlen_k or total_k x {1|num_heads_k|num_heads} -// std::optional &out_, // total_q x num_heads x head_size -// const at::Tensor &cu_seqlens_q, // b+1 -// const at::Tensor &cu_seqlens_k, // b+1 -// std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. -// std::optional &leftpad_k_, // batch_size -// std::optional &block_table_, // batch_size x max_num_blocks_per_seq -// int max_seqlen_q, -// const int max_seqlen_k, -// const float softmax_scale, -// const bool zero_tensors, -// bool is_causal, -// const float softcap, -// const bool return_softmax -// ) { -// // Otherwise the kernel will be launched from cuda:0 device -// at::cuda::CUDAGuard device_guard{q.device()}; -// auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); -// bool is_sm8x_min = cc_major >= 8; -// TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); - -// auto q_dtype = q.dtype(); -// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); -// TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); -// TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); -// TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); -// TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); -// TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); -// TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); - -// CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); -// CHECK_DEVICE(cu_seqlens_q); -// CHECK_DEVICE(cu_seqlens_k); - -// at::Tensor block_table; -// // const bool paged_KV = block_table_.has_value(); -// const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed. -// if (paged_KV) { -// block_table = block_table_.value(); -// CHECK_DEVICE(block_table); -// TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); -// TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); -// } - -// TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// CHECK_CONTIGUOUS(cu_seqlens_q); -// CHECK_CONTIGUOUS(cu_seqlens_k); - -// const auto sizes = q.sizes(); - -// const int batch_size = cu_seqlens_q.numel() - 1; -// int num_heads = sizes[1]; -// const int head_size = sizes[2]; -// const int num_heads_k = paged_KV ? k.size(2) : k.size(1); - -// const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); -// const int num_blocks = !paged_KV ? 0 : k.size(0); -// const int page_block_size = !paged_KV ? 1 : k.size(1); -// TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); - -// if (max_seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case - -// void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); - -// // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case -// // H/t Daniel Haziza -// const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; -// const int ngroups = num_heads / num_heads_k; -// if (seqlenq_ngroups_swapped) { -// q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); -// max_seqlen_q = ngroups; -// num_heads = num_heads_k; -// cu_seqlens_q_d = nullptr; -// } - -// const int total_q = q.sizes()[0]; - -// TORCH_CHECK(batch_size > 0, "batch size must be positive"); -// TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); -// TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); -// TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - -// CHECK_SHAPE(q, total_q, num_heads, head_size); -// if (!paged_KV) { -// const int total_k = k.size(0); -// CHECK_SHAPE(k, total_k, num_heads_k, head_size); -// CHECK_SHAPE(v, total_k, num_heads_k, head_size); -// CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); -// CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); -// } else { -// CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); -// CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); -// CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); -// } - -// CHECK_SHAPE(cu_seqlens_q, batch_size + 1); -// CHECK_SHAPE(cu_seqlens_k, batch_size + 1); -// if (seqused_k.has_value()){ -// auto seqused_k_ = seqused_k.value(); -// TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); -// TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); -// TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); -// CHECK_SHAPE(seqused_k_, batch_size); -// } - -// at::Tensor out; -// if (out_.has_value()) { -// out = out_.value(); -// TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); -// CHECK_DEVICE(out); -// TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); -// CHECK_SHAPE(out, sizes[0], sizes[1], head_size); -// if (seqlenq_ngroups_swapped) { -// out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); -// } -// } else { -// out = torch::empty_like(q); -// } - -// auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; -// const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); -// const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); -// const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); - -// auto opts = q.options(); -// auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); -// at::Tensor p; - -// if (return_softmax) { -// p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); -// } else { -// p = torch::empty({ 0 }, opts); -// } - -// if (zero_tensors) { -// out.zero_(); -// softmax_lse.fill_(-std::numeric_limits::infinity()); -// if (return_softmax) { p.zero_(); } -// } - -// Flash_fwd_params params; -// set_params_fprop( -// params, -// batch_size, -// max_seqlen_q, max_seqlen_k, -// seqlen_q_rounded, seqlen_k_rounded, -// num_heads, num_heads_k, -// head_size, head_size_rounded, -// q, k, v, mask, bias, out, -// cu_seqlens_q_d, -// cu_seqlens_k.data_ptr(), -// seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, -// return_softmax ? p.data_ptr() : nullptr, -// softmax_lse.data_ptr(), -// softmax_scale, -// is_causal, -// softcap, -// seqlenq_ngroups_swapped, -// /*unpadded_lse*/true -// ); -// params.total_q = total_q; - -// if (paged_KV) { -// params.block_table = block_table.data_ptr(); -// params.block_table_batch_stride = block_table.stride(0); -// params.k_batch_stride = k.stride(0); -// params.v_batch_stride = v.stride(0); -// } -// params.page_block_size = page_block_size; -// // Keep references to these tensors to extend their lifetime -// at::Tensor softmax_lse_accum, out_accum; -// if (seqlenq_ngroups_swapped) { -// // Only apply split-k for decoding -// std::tie(softmax_lse_accum, out_accum) = -// set_params_splitkv( -// params, batch_size, num_heads, head_size, -// max_seqlen_k, max_seqlen_q, head_size_rounded, -// /*num_splits*/ 0, get_num_sm(get_current_device()), opts -// ); -// } - -// if (leftpad_k_.has_value()) { -// auto leftpad_k = leftpad_k_.value(); -// TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); -// TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); -// CHECK_DEVICE(leftpad_k); -// CHECK_CONTIGUOUS(leftpad_k); -// CHECK_SHAPE(leftpad_k, batch_size); -// params.leftpad_k = static_cast(leftpad_k.data_ptr()); -// } - -// if (max_seqlen_k > 0) { -// auto stream = at::cuda::getCurrentCUDAStream().stream(); -// run_mha_fwd(params, stream, paged_KV); -// } else { -// // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. -// out.zero_(); -// softmax_lse.fill_(std::numeric_limits::infinity()); -// } - -// if (seqlenq_ngroups_swapped) { -// int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size}; -// int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size}; -// out = out.reshape(size_before).transpose(1, 2).reshape(size_after); -// q = q.reshape(size_before).transpose(1, 2).reshape(size_after); -// softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); -// } - -// return {out, softmax_lse, p}; -// } + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x_min = cc_major >= 8; + TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + + at::Tensor block_table; + const bool paged_KV = block_table_.has_value(); + if (paged_KV) { + block_table = block_table_.value(); + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + } + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_k); + + auto opts = q.options(); + + bool has_mask = false; + at::Tensor mask; + mask = torch::empty({0}, opts); + bool has_bias = false; + at::Tensor bias; + bias = torch::empty({0}, opts); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size = sizes[2]; + const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + int num_heads_mask = has_mask ? mask.size(1) : 1; + int num_heads_bias = has_bias ? bias.size(1) : 1; + + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); + const int num_blocks = !paged_KV ? 0 : k.size(0); + const int page_block_size = !paged_KV ? 1 : k.size(1); + TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + + if (max_seqlen_q == 1) { is_causal = false; } // is_causal=true is the same as is_causal=false in this case + + void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; + const int ngroups = num_heads / num_heads_k; + if (seqlenq_ngroups_swapped) { + q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); + max_seqlen_q = ngroups; + num_heads = num_heads_k; + cu_seqlens_q_d = nullptr; + } + + const int total_q = q.sizes()[0]; + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); + TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + CHECK_SHAPE(q, total_q, num_heads, head_size); + if (!paged_KV) { + const int total_k = k.size(0); + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size); + } else { + CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + } + + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (seqused_k.has_value()){ + auto seqused_k_ = seqused_k.value(); + TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); + TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); + CHECK_SHAPE(seqused_k_, batch_size); + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, sizes[0], sizes[1], head_size); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); + } + } else { + out = torch::empty_like(q); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + at::Tensor p; + if (return_softmax) { + p = torch::empty({batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts); + } else { + p = torch::empty({0}, opts); + } + + if (zero_tensors) { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_softmax) { p.zero_(); } + } + + Flash_fwd_params params; + set_params_fprop( + params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, num_heads_mask, num_heads_bias, + head_size, head_size_rounded, + q, k, v, mask, bias, out, + cu_seqlens_q_d, + cu_seqlens_k.data_ptr(), + seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, + return_softmax ? p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + softmax_scale, + is_causal, + softcap, + has_mask, + has_bias, + seqlenq_ngroups_swapped, + /*unpadded_lse*/true + ); + params.total_q = total_q; + + if (paged_KV) { + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + } + params.page_block_size = page_block_size; + // Keep references to these tensors to extend their lifetime + at::Tensor softmax_lse_accum, out_accum; + if (seqlenq_ngroups_swapped) { + // Only apply split-k for decoding + std::tie(softmax_lse_accum, out_accum) = + set_params_splitkv( + params, batch_size, num_heads, head_size, + max_seqlen_k, max_seqlen_q, head_size_rounded, + /*num_splits*/ 0, get_num_sm(get_current_device()), opts + ); + } + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); + CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + params.leftpad_k = static_cast(leftpad_k.data_ptr()); + } + + if (max_seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream, paged_KV); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + if (seqlenq_ngroups_swapped) { + int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size}; + int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size}; + out = out.reshape(size_before).transpose(1, 2).reshape(size_after); + q = q.reshape(size_before).transpose(1, 2).reshape(size_after); + softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); + } + + return { out, softmax_lse, p }; +} void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { FP16_SWITCH(!params.is_bf16, [&] { @@ -791,18 +792,18 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { std::vector mha_bwd( - const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) - const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k - const std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k - const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &softmax_lse, // b x h x seqlen_q - std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size - std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &dbias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + const std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dbias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k const float softmax_scale, const bool is_causal, const float softcap, @@ -987,7 +988,7 @@ mha_bwd( : dk; dv_expanded = num_heads_k != num_heads // MQA / GQA ? torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts) - : dv; + : dv; dbias_expanded = has_bias ? ( (num_heads_bias != num_heads) || (bias_.has_value() && bias_.value().dim() == 3) // MQA / GQA or bias has no seqlen_q dimension @@ -1070,235 +1071,228 @@ mha_bwd( return { dq, dk, dv, dbias, softmax_d }; } -// TODO: At present, we don't have a good strategy to handle the mask and bias of the varlen variant. -// std::vector -// mha_varlen_bwd( -// const at::Tensor &dout, // total_q x num_heads, x head_size -// const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i -// const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i -// const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i -// const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k -// const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k -// const at::Tensor &out, // total_q x num_heads x head_size -// const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp -// std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i -// std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i -// std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i -// std::optional &dbias_, // total_q x num_heads_k x max_seqlen_k -// const at::Tensor &cu_seqlens_q, // b+1 -// const at::Tensor &cu_seqlens_k, // b+1 -// const int max_seqlen_q, -// const int max_seqlen_k, // max sequence length to choose the kernel -// const float softmax_scale, -// const bool zero_tensors, -// const bool is_causal, -// const float softcap, -// const bool deterministic -// ) { - -// #ifdef FLASHATTENTION_DISABLE_BACKWARD -// TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward."); -// #endif - -// // Otherwise the kernel will be launched from cuda:0 device -// at::cuda::CUDAGuard device_guard{q.device()}; - -// auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); -// bool is_sm8x_min = cc_major >= 8; -// TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); - -// auto stream = at::cuda::getCurrentCUDAStream().stream(); - -// auto q_dtype = q.dtype(); -// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); -// TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); -// TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); -// TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); -// TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype"); -// TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); -// TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); -// TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); -// TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); - -// CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); -// CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); -// CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); - -// TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); -// TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); -// CHECK_CONTIGUOUS(cu_seqlens_q); -// CHECK_CONTIGUOUS(cu_seqlens_k); - -// const auto sizes = q.sizes(); -// auto opts = q.options(); - -// const int total_q = sizes[0]; -// const int batch_size = cu_seqlens_q.numel() - 1; -// const int num_heads = sizes[1]; -// const int head_size = sizes[2]; -// const int total_k = k.size(0); -// const int num_heads_k = k.size(1); -// TORCH_CHECK(batch_size > 0, "batch size must be positive"); -// TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); -// TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); -// TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - -// auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; -// const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); -// const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); -// const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); - -// CHECK_SHAPE(q, total_q, num_heads, head_size); -// CHECK_SHAPE(k, total_k, num_heads_k, head_size); -// CHECK_SHAPE(v, total_k, num_heads_k, head_size); -// CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); -// CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); -// CHECK_SHAPE(out, total_q, num_heads, head_size); -// CHECK_SHAPE(dout, total_q, num_heads, head_size); -// CHECK_SHAPE(cu_seqlens_q, batch_size + 1); -// CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - -// at::Tensor dq, dk, dv, dbias; -// if (dq_.has_value()) { -// dq = dq_.value(); -// TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); -// CHECK_DEVICE(dq); -// TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); -// CHECK_SHAPE(dq, total_q, num_heads, head_size); -// } else { -// dq = torch::empty_like(q); -// } -// if (dk_.has_value()) { -// dk = dk_.value(); -// TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); -// CHECK_DEVICE(dk); -// TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); -// CHECK_SHAPE(dk, total_k, num_heads_k, head_size); -// } else { -// dk = torch::empty_like(k); -// } -// if (dv_.has_value()) { -// dv = dv_.value(); -// TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); -// CHECK_DEVICE(dv); -// TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); -// CHECK_SHAPE(dv, total_k, num_heads_k, head_size); -// } else { -// dv = torch::empty_like(v); -// } -// if (dbias_.has_value()) { -// dbias = dbias_.value(); -// TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); -// CHECK_DEVICE(dbias); -// TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); -// CHECK_SHAPE(dbias, total_q, num_heads_k, max_seqlen_k); -// } else { -// dbias = torch::empty({total_q, num_heads_k, max_seqlen_k}, opts); -// } - -// // bool loop = max_seqlen_k > blocksize_c; -// // TODO: change later, for now set to true for simplicity -// bool loop = true; - -// auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat)); -// at::Tensor dq_accum; -// if (loop) { -// // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded) -// // because that would be too large if there is a very long sequence and the rest of the sequences are short. -// // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded). -// // Note that 128 is the max block size on the seqlen_q dimension. -// // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to -// // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will -// // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally -// // allowed to do. So we won't have to do any bound checking, and performance should stay the same. -// // Same holds for softmax_d, since LSE is stored in unpadded format. -// if (!deterministic) { -// dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); -// } else { -// const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads); -// dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); -// } -// } - -// at::Tensor dk_expanded, dv_expanded, dbias_expanded; -// if (num_heads_k != num_heads) { // MQA / GQA -// dk_expanded = torch::empty({total_k, num_heads, head_size}, opts); -// dv_expanded = torch::empty({total_k, num_heads, head_size}, opts); -// dbias_expanded = torch::empty({total_q, num_heads, max_seqlen_k}, opts); -// } else { -// dk_expanded = dk; -// dv_expanded = dv; -// dbias_expanded = dbias; -// } - -// if( zero_tensors ) { -// dq.zero_(); -// dk_expanded.zero_(); -// dv_expanded.zero_(); -// dbias_expanded.zero_(); -// softmax_d.zero_(); -// } - -// Flash_bwd_params params; - -// set_params_dgrad( -// params, -// batch_size, -// max_seqlen_q, max_seqlen_k, -// seqlen_q_rounded, seqlen_k_rounded, -// num_heads, num_heads_k, -// head_size, head_size_rounded, -// q, k, v, mask, bias, out, -// dout, dq, dk_expanded, dv_expanded, dbias_expanded, -// cu_seqlens_q.data_ptr(), -// cu_seqlens_k.data_ptr(), -// loop ? dq_accum.data_ptr() : nullptr, -// nullptr, -// nullptr, -// softmax_lse.data_ptr(), -// softmax_d.data_ptr(), -// softmax_scale, -// is_causal, -// softcap, -// deterministic, -// /*unpadded_lse*/true -// ); -// params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); -// params.total_q = total_q; - -// auto launch = &run_mha_bwd; - -// if (max_seqlen_q > 0) { -// launch(params, stream); -// } else { -// // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. -// dk_expanded.zero_(); -// dv_expanded.zero_(); -// dbias_expanded.zero_(); -// softmax_d.zero_(); -// } - -// // For MQA/GQA we need to sum dK and dV across the groups -// if (num_heads_k != num_heads) { -// at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); -// at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); -// at::sum_out(dbias, at::reshape(dbias_expanded, {total_q, num_heads_k, num_heads / num_heads_k, max_seqlen_k}), {2}); -// } - -// return { dq, dk, dv, dbias, softmax_d }; -// } +std::vector +mha_varlen_bwd( + const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + const float softcap, + const bool deterministic +) { + + #ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward."); + #endif + + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x_min = cc_major >= 8; + TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_k); + + auto opts = q.options(); + + bool has_mask = false; + at::Tensor mask; + mask = torch::empty({0}, opts); + bool has_bias = false; + at::Tensor bias; + bias = torch::empty({0}, opts); + + const auto sizes = q.sizes(); + + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + const int num_heads_mask = has_mask ? mask.size(1) : 1; + const int num_heads_bias = has_bias ? bias.size(1) : 1; + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(dout, total_q, num_heads, head_size); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor dq, dk, dv, dbias; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, total_q, num_heads, head_size); + } else { + dq = torch::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, total_k, num_heads_k, head_size); + } else { + dk = torch::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size); + } else { + dv = torch::empty_like(v); + } + dbias = torch::empty({0}, opts); + + // bool loop = max_seqlen_k > blocksize_c; + // TODO: change later, for now set to true for simplicity + bool loop = true; + + auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + if (loop) { + // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded) + // because that would be too large if there is a very long sequence and the rest of the sequences are short. + // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded). + // Note that 128 is the max block size on the seqlen_q dimension. + // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to + // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will + // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally + // allowed to do. So we won't have to do any bound checking, and performance should stay the same. + // Same holds for softmax_d, since LSE is stored in unpadded format. + if (!deterministic) { + dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + } else { + const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads); + dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + } + } + + at::Tensor dk_expanded, dv_expanded, dbias_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = torch::empty({total_k, num_heads, head_size}, opts); + dv_expanded = torch::empty({total_k, num_heads, head_size}, opts); + // dbias_expanded = torch::empty({total_q, num_heads, max_seqlen_k}, opts); + dbias_expanded = torch::empty({0}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + dbias_expanded = dbias; + } + + if( zero_tensors ) { + dq.zero_(); + dk_expanded.zero_(); + dv_expanded.zero_(); + // dbias_expanded.zero_(); + softmax_d.zero_(); + } + + Flash_bwd_params params; + + set_params_dgrad( + params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, num_heads_mask, num_heads_bias, + head_size, head_size_rounded, + q, k, v, mask, bias, out, + dout, dq, dk_expanded, dv_expanded, dbias_expanded, + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + loop ? dq_accum.data_ptr() : nullptr, + nullptr, + nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + softmax_scale, + is_causal, + softcap, + has_mask, + has_bias, + deterministic, + /*unpadded_lse*/true + ); + params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); + params.total_q = total_q; + + auto launch = &run_mha_bwd; + + if (max_seqlen_q > 0) { + launch(params, stream); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk_expanded.zero_(); + dv_expanded.zero_(); + // dbias_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); + at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); + // at::sum_out(dbias, at::reshape(dbias_expanded, {total_q, num_heads_k, num_heads / num_heads_k, max_seqlen_k}), {2}); + } + + return { dq, dk, dv, softmax_d }; +} } // namespace FLASH_NAMESPACE PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashDynamicMaskAttention"; m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass"); - // m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length"); + m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length"); m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass"); - // m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length"); + m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length"); } diff --git a/docs/api_reference.md b/docs/api_reference.md index e73eb1a..65bbebf 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -159,7 +159,7 @@ def flash_dmattn_func( value: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) attn_mask: Optional[torch.Tensor] = None, # (batch, {num_heads, num_kv_heads, 1}, {seqlen_q, 0}, seqlen_k) attn_bias: Optional[torch.Tensor] = None, # (batch, {num_heads, num_kv_heads, 1}, {seqlen_q, 0}, seqlen_k) - scale: Optional[float] = None, # score scaling, defaults to 1/sqrt(head_dim) + softmax_scale: Optional[float] = None, # score scaling, defaults to 1/sqrt(head_dim) is_causal: Optional[bool] = None, # causal mask softcap: Optional[float] = None, # CUDA-only deterministic: Optional[bool] = None, # CUDA-only @@ -174,7 +174,7 @@ def flash_dmattn_func( - value: (B, K, H_kv, D). Same dtype/device as query; GQA when H_kv <= H - attn_mask: (B, {H, H_kv, 1}, {Q, 0}, K). 1.0 = visible, 0.0 = masked. None to disable - attn_bias: (B, {H, H_kv, 1}, {Q, 0}, K). Added to scores before softmax. None to disable -- scale: score scaling; default 1/sqrt(D) +- softmax_scale: score scaling; default 1/sqrt(D) - is_causal: apply lower-triangular mask - softcap, deterministic, return_attn_probs: only effective on the CUDA backend; ignored on others @@ -194,7 +194,7 @@ def triton_dmattn_func( attn_mask: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) attn_bias: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) is_causal: bool = False, # causal mask - scale: Optional[float] = None, # score scaling, defaults to 1/sqrt(head_dim) + softmax_scale: Optional[float] = None, # score scaling, defaults to 1/sqrt(head_dim) ) -> torch.Tensor ``` @@ -210,7 +210,7 @@ def flex_dmattn_func( attn_mask: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) attn_bias: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) is_causal: Optional[bool] = None, # causal mask - scale: Optional[float] = None, # score scaling, defaults to 1/sqrt(head_dim) + softmax_scale: Optional[float] = None, # score scaling, defaults to 1/sqrt(head_dim) ) -> torch.Tensor ``` @@ -341,7 +341,7 @@ class DynamicMaskAttention(nn.Module): value_states, attention_mask=attention_mask, attention_bias=attn_bias, - scale=self.scaling, + softmax_scale=self.scaling, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() diff --git a/docs/api_reference_zh.md b/docs/api_reference_zh.md index efd308a..83d1617 100644 --- a/docs/api_reference_zh.md +++ b/docs/api_reference_zh.md @@ -159,7 +159,7 @@ def flash_dmattn_func( value: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) attn_mask: Optional[torch.Tensor] = None, # (batch, {num_heads, num_kv_heads, 1}, {seqlen_q, 0}, seqlen_k) attn_bias: Optional[torch.Tensor] = None, # (batch, {num_heads, num_kv_heads, 1}, {seqlen_q, 0}, seqlen_k) - scale: Optional[float] = None, # 分数缩放,默认为 1/sqrt(head_dim) + softmax_scale: Optional[float] = None, # 分数缩放,默认为 1/sqrt(head_dim) is_causal: Optional[bool] = None, # 因果掩码 softcap: Optional[float] = None, # 仅 CUDA 支持 deterministic: Optional[bool] = None, # 仅 CUDA 支持 @@ -174,7 +174,7 @@ def flash_dmattn_func( - value: (B, K, H_kv, D). 与 query 相同的数据类型/设备;当 H_kv <= H 时为 GQA - attn_mask: (B, {H, H_kv, 1}, {Q, 0}, K). 1.0 = 可见,0.0 = 被掩码。None 表示禁用 - attn_bias: (B, {H, H_kv, 1}, {Q, 0}, K). 在 softmax 前加到分数上。None 表示禁用 -- scale: 分数缩放;默认为 1/sqrt(D) +- softmax_scale: 分数缩放;默认为 1/sqrt(D) - is_causal: 应用因果掩码 - softcap, deterministic, return_attn_probs: 仅在 CUDA 后端有效;在其他后端被忽略 @@ -194,7 +194,7 @@ def triton_dmattn_func( attn_mask: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) attn_bias: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) is_causal: bool = False, # 因果掩码 - scale: Optional[float] = None, # 分数缩放,默认为 1/sqrt(head_dim) + softmax_scale: Optional[float] = None, # 分数缩放,默认为 1/sqrt(head_dim) ) -> torch.Tensor ``` @@ -210,7 +210,7 @@ def flex_dmattn_func( attn_mask: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) attn_bias: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) is_causal: Optional[bool] = None, # 因果掩码 - scale: Optional[float] = None, # 分数缩放,默认为 1/sqrt(head_dim) + softmax_scale: Optional[float] = None, # 分数缩放,默认为 1/sqrt(head_dim) ) -> torch.Tensor ``` @@ -340,7 +340,7 @@ class DynamicMaskAttention(nn.Module): value_states, attention_mask=attention_mask, attention_bias=attn_bias, - scale=self.scaling, + softmax_scale=self.scaling, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() diff --git a/docs/integration.md b/docs/integration.md index feb5ffa..31a57d1 100644 --- a/docs/integration.md +++ b/docs/integration.md @@ -1108,7 +1108,7 @@ def _flash_dynamic_mask_attention_forward( out = flash_dmattn_func( query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, - scale=softmax_scale, is_causal=is_causal + softmax_scale=softmax_scale, is_causal=is_causal ) return out[0] if isinstance(out, tuple) else out @@ -1551,7 +1551,7 @@ def _flash_dynamic_mask_attention_forward( out = flash_dmattn_func( query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, - scale=softmax_scale, is_causal=is_causal + softmax_scale=softmax_scale, is_causal=is_causal ) return out[0] if isinstance(out, tuple) else out @@ -1994,7 +1994,7 @@ def _flash_dynamic_mask_attention_forward( out = flash_dmattn_func( query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, - scale=softmax_scale, is_causal=is_causal + softmax_scale=softmax_scale, is_causal=is_causal ) return out[0] if isinstance(out, tuple) else out @@ -2421,7 +2421,7 @@ def _flash_dynamic_mask_attention_forward( out = flash_dmattn_func( query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, - scale=softmax_scale, is_causal=is_causal + softmax_scale=softmax_scale, is_causal=is_causal ) return out[0] if isinstance(out, tuple) else out diff --git a/docs/integration_zh.md b/docs/integration_zh.md new file mode 100644 index 0000000..56ed0c1 --- /dev/null +++ b/docs/integration_zh.md @@ -0,0 +1,522 @@ +# Flash 动态掩码注意力集成指南 + +## 概述 + +本文档阐述了如何在 Flash Attention 框架中集成 Dynamic Mask Attention(动态掩码注意力)。通过将 Flash Attention 的高效显存利用方式与动态稀疏掩码结合,这一集成能够在极长序列场景下实现稀疏注意力的高效计算。 + +该集成方案采用统一的稀疏计算路径:Python 端负责预计算注意力掩码与偏置张量,CUDA 后端在前向与反向两个阶段执行基于块的跳过逻辑与稀疏算子调度。 + +## 目录 + +1. [集成架构](#集成架构) +2. [核心改动](#核心改动) +3. [实现细节](#实现细节) +4. [稀疏计算策略](#稀疏计算策略) +5. [内存布局](#内存布局) +6. [性能考量](#性能考量) +7. [API 变化](#api-变化) + +## 集成架构 + +### 高层设计 + +动态掩码注意力的集成在前向与反向过程中统一采用块级稀疏执行路径: + +1. **动态掩码计算**:Python 端预先生成注意力掩码(mask)与注意力偏置(bias)张量。 +2. **统一稀疏执行**:CUDA 后端在块粒度上决定是否跳过计算,并执行稀疏化的注意力与梯度算子。 +3. **内存优化**:通过共享内存别名与显式同步实现更高的共享内存复用率。 + +### 关键组件 + +- **注意力掩码**:形状为 `(batch, num_kv_heads, query_len, key_len)` 的二值张量(1.0 表示保留,0.0 表示跳过)。 +- **注意力偏置**:与掩码形状一致的张量,在 Softmax 前加性注入。 +- **块级跳过逻辑**:对 `(BlockM × BlockN)` tile 做 OR 归约判断是否执行计算。 +- **LSE 缓存**:前向阶段缓存 log-sum-exp 结果,反向阶段复用以保持数值稳定。 +- **共享内存别名**:动态复用共享内存缓冲区,配合 `__syncthreads()` 控制生命周期。 +- **完备梯度链路**:在保留稀疏跳过能力的同时,确保梯度流动正确。 + +## 核心改动 + +### 1. 参数结构扩展(`flash.h`) + +**目的**:扩展参数结构体以支持动态掩码与偏置信息,同时保留对 QKV 的统一访问接口。 + +```cpp +struct QKV_params { + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + index_t q_batch_stride, k_batch_stride, v_batch_stride; + index_t q_row_stride, k_row_stride, v_row_stride; + index_t q_head_stride, k_head_stride, v_head_stride; + int h, h_k; + int h_h_k_ratio; +}; + +struct Mask_params { + void *__restrict__ mask_ptr; + index_t mask_batch_stride; + index_t mask_head_stride; + index_t mask_row_stride; +}; + +struct Bias_params { + void *__restrict__ bias_ptr; + index_t bias_batch_stride; + index_t bias_head_stride; + index_t bias_row_stride; +}; + +struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params { + // ...existing code... + bool seqlenq_ngroups_swapped; +}; +``` + +**设计要点**: +- 多重继承将 QKV、掩码、偏置的参数维度拆分,保持接口清晰。 +- 为掩码与偏置提供完整的 stride 信息,以便在 CUDA 中高效寻址。 +- 与原有 Flash Attention 的内存布局保持兼容,避免性能回退。 + +### 2. 内核特性与内存布局(`kernel_traits.h`) + +**目的**:根据架构(SM75 / SM80+)选择合适的 MMA 原子与内存拷贝路径,为动态掩码操作提供最佳性能。 + +```cpp +template +struct Flash_kernel_traits { + using Element = elem_type; + using ElementAccum = float; + using index_t = int64_t; + static constexpr int kHeadDim = kHeadDim_; + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kNWarps = kNWarps_; + // ...existing code... + using SmemCopyAtomMask = SmemCopyAtom; + using SmemCopyAtomBias = SmemCopyAtom; +}; +``` + +**设计要点**: +- 根据编译目标自动选择 `cp.async` 与 LDSM 指令路径。 +- 统一掩码与偏置的共享内存加载策略,避免额外的 bank conflict。 +- 模板化的类型安全保证不同精度(FP16/BF16)路径一致。 + +### 3. 块级信息扩展(`block_info.h`) + +**目的**:在可变长度场景下计算掩码与偏置的块级偏移量,保证全局内存访问有序。 + +```cpp +template +struct BlockInfo { + template + __device__ BlockInfo(const Params ¶ms, const int bidb) { + // ...existing code... + } + + template + __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k); + return offset; + } + + // ...existing code... +}; +``` + +**设计要点**: +- 提供统一的偏移量计算方法,简化内核中的地址计算。 +- 同时支持固定长度与可变长度两种输入形式。 +- 将左侧填充(left pad)纳入偏移量,保证稀疏掩码与 KV 缓存对齐。 + +### 4. 内存拷贝与算子工具(`utils.h`) + +**目的**:提供布局转换、类型转换、warp 归约与通用 GEMM 包装,适配 Flash Attention 的内存层次结构。 + +```cpp +namespace FLASH_NAMESPACE { + +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + // ...existing code... + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +}; + +// ...existing code... + +template +__forceinline__ __device__ void gemm(/* ... */) { + // ...existing code... +} + +} // namespace FLASH_NAMESPACE +``` + +**设计要点**: +- 通过布局转换统一 MMA 累加器的访问方式,方便掩码逻辑在寄存器中操作。 +- 提供针对 BF16 的专用类型转换,避免额外的精度损耗。 +- Warp 归约与 GEMM 包装均支持将数据留在寄存器中,降低共享内存压力。 + +### 5. 动态掩码核心逻辑(`mask.h`) + +**目的**:在寄存器层面将掩码与偏置应用到注意力得分上,同时处理因果掩码与边界情况。 + +```cpp +template +__forceinline__ __device__ void apply_mask( + TensorType &tensor, + MaskType &mask, + BiasType &bias, + const float scale_softmax, + const int col_idx_offset_, + const int max_seqlen_k, + const int row_idx_offset, + const int max_seqlen_q, + const int warp_row_stride) { + // ...existing code... +} +``` + +**设计要点**: +- 在 `tensor` 保持 MMA 布局的情况下,逐元素应用掩码、偏置与缩放因子。 +- 因果掩码通过列索引上限裁剪实现,与动态掩码兼容。 +- 被掩盖的位置直接写入 `-INFINITY`,防止 Softmax 后出现数值污染。 + +### 6. 反向链路扩展(`flash_bwd_kernel.h`) + +**目的**:在反向传播中复用动态掩码逻辑,确保梯度仅在活跃 tile 上计算。 + +```cpp +struct Flash_bwd_params : public Flash_fwd_params { + // ...existing code... +}; + +template +inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, + const int bidh, const int n_block) { + // ...existing code... +} +``` + +**设计要点**: +- 反向路径沿用前向阶段的 tile 活跃性判断,跳过完全被掩码的块。 +- 结合 LSE 缓存,重算前向 Softmax 时保持数值稳定。 +- 保证五个梯度 GEMM 在活跃 tile 上依旧串联执行,避免梯度缺失。 + +### 7. 前向内核改造(`flash_fwd_kernel.h`) + +**目的**:在主注意力内核中插入动态掩码流程,同时保持 Flash Attention 的高并发与共享内存利用率。 + +```cpp +template +inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, + const int bidh, const int m_block) { + using Element = typename Kernel_traits::Element; + // ...existing code... +} +``` + +**设计要点**: +- 按 tile 裁剪逻辑提前判断是否加载 K/V,降低无效内存访问。 +- 仅在提供掩码/偏置时启用相应的分支,保持向后兼容。 +- 通过模板参数在编译期裁剪分支,减少运行期开销。 + +### 8. 启动模板更新(`flash_fwd_launch_template.h`) + +**目的**:在 kernel launch 阶段配置共享内存需求、模板实例化与错误处理,适配动态掩码的新资源需求。 + +```cpp +#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ +template \ +__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, + bool Is_causal, bool Is_even_MN, bool Is_even_K, + bool Is_softcap, bool Return_softmax) { + // ...existing code... +} + +// ...existing code... + +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + // ...existing code... +} +``` + +**设计要点**: +- 统一宏定义减少重复代码,便于扩展到新的 kernel 变体。 +- 针对不支持的架构给出明确的构建期/运行期错误提示。 +- 在 launch 前计算共享内存需求,必要时启用 `cudaFuncSetAttribute` 进行配置。 + +### 9. Python 接口扩展(`flash_api.cpp`) + +**目的**:扩展 C++/PyBind11 接口以接受掩码与偏置张量,并提供全面的数据校验。 + +```cpp +void set_params_fprop( + Flash_fwd_params ¶ms, + // ...existing code... +) { + // ...existing code... +} + +std::vector mha_fwd( + at::Tensor &q, + // ...existing code... + const bool return_softmax) { + // ...existing code... + return {out, softmax_lse}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashDynamicMaskAttention"; + // ...existing code... +} +``` + +**设计要点**: +- 对输入张量的形状、dtype、device 进行全面校验。 +- 保持原有参数顺序,新增参数保持向后兼容的默认行为。 +- 当掩码或偏置未提供时,自动填充零值张量以保证接口易用性。 + +## 实现细节 + +### C++ API 接口 + +C++ 端对外暴露如下核心函数,用于前向、可变长度前向与反向计算: + +```cpp +namespace FLASH_NAMESPACE { + +std::vector mha_fwd( + at::Tensor &q, + at::Tensor &k, + at::Tensor &v, + // ...existing code... + const bool return_softmax); + +std::vector mha_varlen_fwd(/* ... */); + +std::vector mha_bwd(/* ... */); + +} // namespace FLASH_NAMESPACE +``` + +- `mha_fwd`:标准批量前向,支持稀疏掩码与偏置。 +- `mha_varlen_fwd`:支持变长序列并使用累计长度数组。 +- `mha_bwd`:完成梯度计算,返回 dQ / dK / dV / dBias / dMask 等张量。 + +### 参数设置与校验 + +`set_params_fprop` 会在调用前: + +- 重置 `Flash_fwd_params` 并写入基本维度信息。 +- 将掩码与偏置的设备指针、stride、批次数等全部注册。 +- 基于输入 `dtype` 设置缩放因子与 `softcap`,同时准备缓存指针。 + +### Python 绑定与接口 + +PyBind11 模块对外暴露 `mha_fwd`、`mha_bwd`、`varlen_fwd` 等接口,文档字符串说明了参数要求与返回值。用户可通过 Python 直接调用 C++/CUDA 实现。 + +### Python 前端集成示例 + +```python +import torch +import torch.nn as nn +import flash_dmattn_cuda as flash_dmattn + +class DynamicMaskAttention(nn.Module): + def __init__(self, config): + super().__init__() + # ...existing code... + + def forward(self, query_states, key_states, value_states, attn_mask, attn_bias): + out, softmax_lse = flash_dmattn.fwd( + query_states, key_states, value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + return_softmax=True, + ) + return out, softmax_lse +``` + +- 前端模块负责生成 `attn_mask`(布尔)与 `attn_bias`(与 Q/K/V dtype 相同)。 +- 内部 `_flash_dynamic_mask_attention_forward` 会根据需要补零偏置并调用后端。 +- 输入张量默认为 `(batch, seq_len, num_heads, head_dim)` 排列,内部会自动转置到后端期望格式。 + +## 稀疏计算策略 + +### 块级跳过逻辑 + +- 在加载 Q tile 后,先将掩码 tile 拷贝到共享内存并执行 OR 归约。 +- 若整块被掩盖,则跳过 K/V 加载与后续计算,只推进指针。 +- 对活跃块执行常规注意力流程,并复用共享内存保存 Softmax 结果。 + +### 前向算法 + +```pseudo +for m_block in M_tiles: + load Q_tile + load mask_tile -> shared + any_active = or_reduce(mask_tile) + if not any_active: + continue + load K_tile, V_tile + compute scaled dot product + apply mask & bias in registers + softmax -> write O_tile +``` + +- 掩码裁剪保证 Tile 内所有无效位置直接输出 `-INF`。 +- Softmax 前的缩放与偏置添加与密集版本保持一致。 +- 通过共享内存别名(sMask ↔ sP)减少显存占用。 + +### 反向算法 + +```pseudo +for m_block in reversed(M_tiles): + load Q_tile, dO_tile + load mask_tile -> shared + if tile inactive: + continue + recompute scores with cached LSE + propagate gradients for dS, dV, dK, dQ +``` + +- 仅对活跃块执行五个 GEMM 组合,减少稀疏场景下的冗余计算。 +- 使用前向缓存的 LSE 确保 Softmax 反向的数值稳定性。 +- 对被跳过的块梯度自然为零,避免写入污染。 + +### 跳过逻辑正确性 + +- 若 tile 全部被掩码,输出必为零,跳过计算不会影响结果。 +- 反向阶段活跃性与前向保持一致,保证梯度对应关系不被破坏。 +- 由于被掩盖位置在 Softmax 前已写入 `-INF`,LSE 亦不受影响。 + +## 内存布局 + +### 全局内存组织 + +``` +Q: [batch, seqlen_q, num_heads, head_dim] +K: [batch, seqlen_k, num_kv_heads, head_dim] +V: [batch, seqlen_k, num_kv_heads, head_dim] +Mask: [batch, num_kv_heads, seqlen_q, seqlen_k] +Bias: [batch, num_kv_heads, seqlen_q, seqlen_k] +Output: [batch, seqlen_q, num_heads, head_dim] +``` + +### 共享内存布局(每个线程块) + +``` +Q Tile : [kBlockM, head_dim] +K Tile : [kBlockN, head_dim] +V Tile : [kBlockN, head_dim] +S Tile : [kBlockM, kBlockN] +Mask Tile: [kBlockM, kBlockN] +Bias Tile: [kBlockM, kBlockN] +``` + +### 寄存器布局(每个线程) + +``` +Q Frag : [MMA_M, head_dim / N] +K Frag : [MMA_N, head_dim / N] +V Frag : [MMA_N, head_dim / N] +S Frag : [MMA_M, MMA_N] +Mask Frag: [MMA_M, MMA_N] +Bias Frag: [MMA_M, MMA_N] +Acc Frag : [MMA_M, head_dim / N] +``` + +### 内存访问模式 + +- 掩码与偏置与 K/V 共享相同的 `Copy_Atom` 配置,确保 128-bit 对齐、最大化带宽。 +- 共享内存拷贝后通过 `local_partition` 分配给线程,避免 bank conflict。 +- `convert_layout_acc_rowcol` 将 MMA 布局转换为行/列布局,方便寄存器操作。 + +### 共享内存优化 + +- **别名复用**:`sMask` 在使用后可重用为 `sP`(Softmax 输出),`sBias` 可重用为 `sdS`。 +- **同步屏障**:在重用前使用 `__syncthreads()` 确保所有线程完成对旧数据的使用。 +- **块尺寸选择**:根据稀疏度与共享内存限制调整 tile 尺寸,提高 SM 占用率。 + +## 性能考量 + +- **共享内存复用**:别名策略可将共享内存占用削减约 30%。 +- **块级跳过**:当稀疏度为 75% 时,可获得约 3× 的前向提速;稀疏度 90% 时可达到 ~6×。 +- **带宽优化**:跳过无效 tile 可以线性降低全局内存带宽需求。 +- **同步开销**:跳过路径的额外 OR 归约占总时间 <5%,可忽略不计。 +- **硬件自适应**:针对 SM75/SM80+ 的不同指令集做了专门优化,确保跨架构稳定收益。 + +## API 变化 + +### 新增必要参数 + +- `attn_mask` (`torch.Tensor`): 形状 `(batch, num_kv_heads, seqlen_q, seqlen_k)` 的布尔张量,决定稀疏模式。 +- `attn_bias` (`torch.Tensor`): 形状与掩码一致的加性偏置张量,dtype 与 Q/K/V 保持一致。 + +### 更新的函数签名 + +```python +def fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + attn_bias: torch.Tensor, + is_causal: bool = False, + return_softmax: bool = False, + **kwargs +) -> List[torch.Tensor]: + ... +``` + +### 向后兼容说明 + +- 这是一个破坏性更新,旧的 Flash Attention 调用需显式提供掩码与偏置。 +- 若业务场景不需要稀疏掩码,可传入全 1 掩码与全 0 偏置实现与旧版一致的行为。 +- 缺省值在 Python 前端会自动补齐,降低迁移的代码改动。 + +### 完整用法示例 + +```python +import torch +from flash_dmattn.integrations.flash_dynamic_mask_attention import ( + flash_dynamic_mask_attention_forward, +) + +batch, seq_q, seq_k, n_heads, head_dim = 2, 4096, 4096, 16, 128 +q = torch.randn(batch, seq_q, n_heads, head_dim, device="cuda", dtype=torch.float16) +k = torch.randn_like(q) +v = torch.randn_like(q) +mask = torch.ones(batch, n_heads, seq_q, seq_k, device=q.device, dtype=torch.bool) +bias = torch.zeros(batch, n_heads, seq_q, seq_k, device=q.device, dtype=q.dtype) + +out = flash_dynamic_mask_attention_forward( + query_states=q, + key_states=k, + value_states=v, + attention_mask=mask, + attention_bias=bias, + return_attn_probs=False, +) +``` + +- `flash_dynamic_mask_attention_forward` 会自动完成张量转置、补零偏置等准备工作。 +- 若指定 `return_attn_probs=True`,将返回经过 Softmax 的注意力概率,用于调试或可视化。 +- 稀疏模式的 mask 可通过 `flash_dmattn.utils.mask.MaskMod` 组合生成。 + +## 附加建议 + +- 修改 CUDA 核心代码后,至少运行 `benchmarks/forward_equivalence.py` 与 `benchmarks/grad_equivalence.py` 进行回归验证。 +- 构建扩展时可使用 `pip install -e . --no-build-isolation`,必要时设置 `FLASH_DMATTN_CUDA_ARCHS` 指定目标架构。 +- 若仅依赖 Triton/Flex 后端,可通过环境变量 `FLASH_DMATTN_SKIP_CUDA_BUILD=1` 跳过 CUDA 构建。 diff --git a/flash_dmattn/flash_dmattn_flex.py b/flash_dmattn/flash_dmattn_flex.py index 7be4d01..379f984 100644 --- a/flash_dmattn/flash_dmattn_flex.py +++ b/flash_dmattn/flash_dmattn_flex.py @@ -12,7 +12,7 @@ def flex_attention_forward( attn_mask: Optional[torch.Tensor] = None, attn_bias: Optional[torch.Tensor] = None, is_causal: Optional[bool] = None, - scale: Optional[float] = None, + softmax_scale: Optional[float] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: batch, seqlen_q, nheads, dhead = query.shape @@ -30,8 +30,8 @@ def flex_attention_forward( attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype) if is_causal is None: is_causal = True - if scale is None: - scale = 1.0 / math.sqrt(dhead) + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(dhead) def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx] @@ -66,7 +66,7 @@ def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): value, score_mod=score_mod, block_mask=block_mask if is_causal else None, - scale=scale, + scale=softmax_scale, kernel_options=kernel_options, # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. # For simplification, we thus always return it as no additional computations are introduced. diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 861184b..893c638 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -131,8 +131,6 @@ def _flash_dmattn_varlen_forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - mask: Optional[torch.Tensor], - bias: Optional[torch.Tensor], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, @@ -146,13 +144,11 @@ def _flash_dmattn_varlen_forward( seqused_k: Optional[torch.Tensor] = None, zero_tensors: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)] + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, softmax_lse, S_dmask = flash_dmattn_gpu.varlen_fwd( q, k, v, - mask, - bias, None, cu_seqlens_q, cu_seqlens_k, @@ -176,8 +172,6 @@ def _flash_dmattn_varlen_forward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - mask: Optional[torch.Tensor], - bias: Optional[torch.Tensor], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, @@ -191,7 +185,7 @@ def _flash_dmattn_varlen_forward_fake( seqused_k: Optional[torch.Tensor] = None, zero_tensors: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)] + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] paged_kv = block_table is not None batch_size = cu_seqlens_q.numel() - 1 total_q, num_heads, _ = q.shape @@ -294,20 +288,17 @@ def _flash_dmattn_backward_fake( _wrapped_flash_dmattn_backward = _flash_dmattn_backward -@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_backward", mutates_args=("dq", "dk", "dv", "dbias"), device_types="cuda") +@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") def _flash_dmattn_varlen_backward( dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - mask: Optional[torch.Tensor], - bias: Optional[torch.Tensor], out: torch.Tensor, softmax_lse: torch.Tensor, dq: Optional[torch.Tensor], dk: Optional[torch.Tensor], dv: Optional[torch.Tensor], - dbias: Optional[torch.Tensor], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, @@ -318,20 +309,17 @@ def _flash_dmattn_varlen_backward( deterministic: bool, zero_tensors: bool = False, ) -> torch.Tensor: - dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] ( dq, dk, dv, - dbias, softmax_d, ) = flash_dmattn_gpu.varlen_bwd( dout, q, k, v, - mask, - bias, out, softmax_lse, dq, @@ -347,7 +335,7 @@ def _flash_dmattn_varlen_backward( softcap, deterministic, ) - _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=0.0, neginf=0.0) + _sanitize_tensors(dq, dk, dv, nan=0.0, posinf=0.0, neginf=0.0) return softmax_d @@ -357,8 +345,6 @@ def _flash_dmattn_varlen_backward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - mask: Optional[torch.Tensor], - bias: Optional[torch.Tensor], out: torch.Tensor, softmax_lse: torch.Tensor, dq: Optional[torch.Tensor], @@ -375,7 +361,7 @@ def _flash_dmattn_varlen_backward_fake( deterministic: bool, zero_tensors: bool = False, ) -> torch.Tensor: - dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] batch_size = cu_seqlens_q.numel() - 1 total_q, num_heads, _ = q.shape @@ -385,8 +371,6 @@ def _flash_dmattn_varlen_backward_fake( dk = torch.empty_like(k) if dv is None: dv = torch.empty_like(v) - if dbias is None and bias is not None: - dbias = torch.empty_like(bias) softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) return softmax_d @@ -422,7 +406,7 @@ def forward( if softcap is None: softcap = 0.0 if deterministic is None: - deterministic = True + deterministic = False if return_softmax is None: return_softmax = False @@ -521,8 +505,6 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - mask: Optional[torch.Tensor], - bias: Optional[torch.Tensor], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, @@ -545,7 +527,7 @@ def forward( if softcap is None: softcap = 0.0 if deterministic is None: - deterministic = True + deterministic = False if return_softmax is None: return_softmax = False @@ -555,21 +537,11 @@ def forward( q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - seqlen_k_og = k.shape[1] - if seqlen_k_og % 8 != 0: - k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) - v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) - if mask is not None: - mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False) - if bias is not None: - bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0) out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward( q, k, v, - mask, - bias, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -583,7 +555,7 @@ def forward( if is_grad: ctx.save_for_backward( - q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k ) ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k @@ -598,9 +570,8 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - dbias = torch.zeros_like(bias).contiguous() if bias is not None else None head_size_og = dout.size(2) dout_padded = dout @@ -612,14 +583,11 @@ def backward(ctx, dout, *args): q, k, v, - mask, - bias, out, softmax_lse, dq, dk, dv, - dbias, cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, @@ -635,13 +603,7 @@ def backward(ctx, dout, *args): dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - if ctx.seqlen_k_og % 8 != 0: - dk = dk[:, : ctx.seqlen_k_og, :, :] - dv = dv[:, : ctx.seqlen_k_og, :, :] - if dbias is not None: - dbias = dbias[..., : ctx.seqlen_k_og] - - return dq, dk, dv, None, dbias, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None def flash_dmattn_func( @@ -725,8 +687,6 @@ def flash_dmattn_varlen_func( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attn_mask: Optional[torch.Tensor], - attn_bias: Optional[torch.Tensor], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, @@ -744,11 +704,6 @@ def flash_dmattn_varlen_func( For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - Similarity, also supports attn_mask and attn_bias with head dimension of 1, nheads_k or nheads for MQA/GQA. - For example, if Q has 6 heads, K, V have 2 heads, then attn_mask and attn_bias can have head dimension - of 1, 2 or 6. If it is 1, all heads use the same mask/bias; if it is 2, head 0, 1, 2 of Q use head 0 - of mask/bias, head 3, 4, 5 of Q use head 1 of mask/bias. If it is 6, each head uses its own mask/bias. - If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: 1 1 1 1 0 @@ -765,12 +720,6 @@ def flash_dmattn_varlen_func( query: torch.Tensor. The query tensor of shape (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. key: torch.Tensor. The key tensor of shape (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. value: torch.Tensor. The value tensor of shape (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - attn_mask: torch.Tensor, optional. The attention mask boolean tensor of - shape (total_q, {nheads|nheads_k|1}, max_seqlen_k) or (total_k, {nheads|nheads_k|1}) to apply to the attention scores. - If None, no mask is applied. - attn_bias: torch.Tensor, optional. The attention bias float tensor of - shape (total_q, {nheads|nheads_k|1}, max_seqlen_k) or (total_k, {nheads|nheads_k|1}) to add to the attention scores. - If None, no bias is applied. cu_seqlens_q: torch.Tensor. The cumulative sequence lengths of the sequences in the batch, used to index into q. cu_seqlens_k: torch.Tensor. The cumulative sequence lengths of the sequences in the batch, used to index into kv. max_seqlen_q: int. Maximum query sequence length in the batch. @@ -796,8 +745,6 @@ def flash_dmattn_varlen_func( query, key, value, - attn_mask, - attn_bias, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index e61a0cf..c94500b 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -1108,5 +1108,5 @@ def backward(ctx, do): return dq, dk, dv, None, dbias, None, None -def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, is_causal=False, scale=None): - return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, is_causal, scale) +def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, is_causal=False, softmax_scale=None): + return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, is_causal, softmax_scale) diff --git a/flash_dmattn/integrations/flash_dynamic_mask_attention.py b/flash_dmattn/integrations/flash_dynamic_mask_attention.py index c37d367..b583a29 100644 --- a/flash_dmattn/integrations/flash_dynamic_mask_attention.py +++ b/flash_dmattn/integrations/flash_dynamic_mask_attention.py @@ -17,6 +17,7 @@ def flash_dynamic_mask_attention_forward( attention_mask: Optional[torch.Tensor], attention_bias: Optional[torch.Tensor], scaling: Optional[float] = None, + window_size: Optional[int] = None, softcap: Optional[float] = None, **kwargs, ) -> tuple[torch.Tensor, None]: @@ -29,14 +30,16 @@ def flash_dynamic_mask_attention_forward( query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim). key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim). value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim). - attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, {num_heads|num_kv_heads|1}, query_len, key_len). - attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, {num_heads|num_kv_heads|1}, query_len, key_len), if attention_mask is None, also supports (batch_size, {num_heads|num_kv_heads|1}, key_len). + attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape + (batch_size, seq_len) or (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len). + attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape + (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len). scaling (Optional[float]): The scaling factor for the attention scores. + window_size (Optional[int]): The size of the window to keep. softcap (Optional[float]): The softcap value for the attention scores. **kwargs: Additional keyword arguments. Includes: - is_causal (bool): Whether to apply a causal mask. - - window_size (int): The size of the window to keep. - layer_idx (int): The index of the layer (for logging purposes). - implementation (str): The implementation to use ("flash_dmattn" or None). @@ -82,9 +85,10 @@ def flash_dynamic_mask_attention_forward( else: target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype - # FDMA always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice - kwargs.pop("is_causal", None) - kwargs.pop("window_size", None) + # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented + is_causal = kwargs.pop("is_causal", None) + if is_causal is None: + is_causal = module.is_causal attn_output = _flash_dynamic_mask_attention_forward( query, @@ -94,10 +98,10 @@ def flash_dynamic_mask_attention_forward( attention_bias, query_length=query_len, key_length=key_len, - is_causal=module.is_causal, + is_causal=is_causal, softmax_scale=scaling, softcap=softcap, - window_size=module.window_size, + window_size=window_size, target_dtype=target_dtype, implementation="flash_dmattn", layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None, diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py index de96347..538acd0 100644 --- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py +++ b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py @@ -11,17 +11,365 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import inspect +import os +from functools import partial from typing import Optional, TypedDict + import torch -from .import_utils import is_flash_dmattn_available +import torch.nn.functional as F +from .import_utils import is_flash_dmattn_available from transformers.utils import logging logger = logging.get_logger(__name__) +# `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves +_flash_fn = None +_flash_varlen_fn = None +_pad_fn = None +_unpad_fn = None + +# function that processes kwargs, generalized to handle any supported kwarg within the function +_process_flash_kwargs_fn = None +# exceptions where hf API doesn't match the original FDMA API +_hf_api_to_flash_mapping = { + "dropout": None, + "sliding_window": None, +} + + +def _lazy_imports(implementation: Optional[str]): + """ + Lazy loads the respective flash dynamic mask attention implementations. + + Return: + flash_attn_func: The base flash dynamic mask attention function. + flash_attn_varlen_func: The flash dynamic mask attention function supporting variable sequence lengths, e.g. for padding-free training. + pad_input: The function to pad inputs into one sequence and returning the respective kwargs. + unpad_input: The function to unpad outputs based on the kwargs (from pad_input). + """ + is_fdma = is_flash_dmattn_available() + + pad_input, unpad_input = _pad_input, _unpad_input + + if (implementation == "flash_dmattn" and is_fdma) or (implementation is None and is_fdma): + from flash_dmattn import flash_dmattn_func, flash_dmattn_varlen_func + from flash_dmattn.utils.padding import pad_input, unpad_input + + return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input + + +def _lazy_define_process_function(flash_function): + """ + Depending on the version and kernel some features are not supported. Due to limitations in + `torch.compile`, we opt to statically type which (optional) kwarg parameters are supported + within `_process_flash_dynamic_mask_attention_kwargs`. + + NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`. + This might be confusing for kwargs that we use in any case, e.g. `is_causal`. + """ + + flash_parameters = inspect.signature(flash_function).parameters + process_parameters = inspect.signature(_process_flash_dynamic_mask_attention_kwargs).parameters + + supports_mapping = {} + for param in process_parameters: + fdma_param = _hf_api_to_flash_mapping.get(param, param) + supports_mapping[fdma_param] = fdma_param in flash_parameters + + return partial(_process_flash_dynamic_mask_attention_kwargs, supports_mapping=supports_mapping) + + +def lazy_import_flash_dynamic_mask_attention(implementation: Optional[str], force_import: Optional[bool] = False): + """ + Lazily import flash dmattn and return the respective functions + flags. + + NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can + work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`. + """ + global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn + if force_import or any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]): + _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation) + + global _process_flash_kwargs_fn + if force_import or _process_flash_kwargs_fn is None: + _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn) + + return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn + + +def _index_first_axis(tensor, indices): + """ + A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, + after flattening the first two dimensions of the tensor. This is functionally equivalent to + FA2's `index_first_axis` and replaces the need to import it. + """ + # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first + # two dimensions to get (total_tokens, ...) before indexing. + reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) + return reshaped_tensor[indices] + + +def _unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. + + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. + """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + return ( + _index_first_axis(hidden_states, indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def _pad_input(hidden_states, indices, batch, seqlen): + """ + pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. + + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) + output[indices] = hidden_states + return output.view(batch, seqlen, *dim) + + +def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + # NOTE: Similar to the `.item()` in prepare_fdma_kwargs_from_position_ids, with torch compile, + # this might cause a graph break + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + unpad_input_func, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + unpad_input_func: + The function to use for unpadding the input tensors. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage + # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores + if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): + key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] + + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = _index_first_axis(key_layer, indices_k) + value_layer = _index_first_axis(value_layer, indices_k) + if query_length == kv_seq_len: + query_layer = _index_first_axis(query_layer, indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def prepare_fdma_kwargs_from_position_ids(position_ids): + """ + This function returns all the necessary kwargs to call `flash_attn_varlen_func` extracted from position_ids. + + Arguments: + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into + ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, + `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device} + + position_ids = position_ids.view(-1) + indices_q = (position_ids == 0).nonzero().view(-1) + + cu_seq_lens_q = torch.cat( + ( + indices_q.to(**tensor_kwargs), + torch.tensor(position_ids.size(), **tensor_kwargs), + ) + ) + cu_seq_lens_k = cu_seq_lens_q + + # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424 + # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing + # for some models (e.g. qwen2-vl). + max_length_q = cu_seq_lens_q.diff().max() + # NOTE: With torch compile, this will cause a graph break if you don't set + # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call + # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. + # This is a limitation of flash attention API, as the function `flash_attn_varlen_func` + # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`. + max_length_q = max_length_q.item() + max_length_k = max_length_q + + return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) + + +def _prepare_from_posids(query, key, value, position_ids): + """ + This function returns necessary arguments to call `flash_attn_varlen_func`. + All three query, key, value states will be flattened. + Cumulative lengths of each examples in the batch will be extracted from position_ids. + NOTE: ideally cumulative lengths should be prepared at the data collator stage + + Arguments: + query (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + query (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + query = query.contiguous().view(-1, query.size(-2), query.size(-1)) + key = key.contiguous().view(-1, key.size(-2), key.size(-1)) + value = value.contiguous().view(-1, value.size(-2), value.size(-1)) + + (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fdma_kwargs_from_position_ids(position_ids) + + return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)) + + +def _is_packed_sequence(position_ids, batch_size): + """ + Check the position ids whether packed sequences are indicated or not + 1. Position ids exist + 2. Flattened sequences only are supported + 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences + """ + if position_ids is None: + return False + + increasing_position_sequences = ( + torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min() + ) + return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool() + + def fdma_peft_integration_check( q: torch.Tensor, k: torch.Tensor, @@ -43,18 +391,6 @@ def fdma_peft_integration_check( return q, k, v, bias -def _lazy_imports(impl: Optional[str]): - # returns funcs based on impl - is_fdma = is_flash_dmattn_available() - - if impl == "flash_dmattn" or (impl is None and is_fdma): - from flash_dmattn import flash_dmattn_func - return flash_dmattn_func - - else: - return getattr(impl, "flash_dmattn_func", None) - - class FlashDynamicMaskAttentionKwargs(TypedDict, total=False): """ Keyword arguments for Flash Dynamic Mask Attention with Compile. @@ -74,7 +410,69 @@ class FlashDynamicMaskAttentionKwargs(TypedDict, total=False): cu_seq_lens_k: Optional[torch.LongTensor] max_length_q: Optional[int] max_length_k: Optional[int] - + + +def _process_flash_dynamic_mask_attention_kwargs( + query_length: int, + key_length: int, + is_causal: bool, + softmax_scale: Optional[float] = None, + window_size: Optional[int] = None, + softcap: Optional[float] = None, + deterministic: Optional[bool] = None, + s_aux: Optional[torch.Tensor] = None, + supports_mapping: Optional[dict[str, bool]] = None, + **kwargs, +): + """ + Returns a set of kwargs that are passed down to the according flash attention function based on + requested features and whether it is supported - depends on the version and kernel implementation + which is dynamically configured at `lazy_import_flash_dynamic_mask_attention`. The (un)supported features can be + inspected in `supports_mapping`, see `_lazy_define_process_function` for more details. + + Args: + query_length (`int`): + Length of the query states + key_length (`int`): + Length of the key states + is_causal (`bool`): + Whether we perform causal (decoder) attention or full attention. + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to `1 / sqrt(head_dim)`. + window_size (`int`, *optional*): + If set, only the `window_size` largest key/value pairs per query are kept. + softcap (`float`, *optional*): + Softcap for the attention logits, used e.g. in gemma2. + deterministic (`bool`, *optional*): + Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. + s_aux (`torch.Tensor`, *optional*): + Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head. + Return: + flash_kwargs (`dict`): + A dict of kwargs that are requested and supported. + """ + flash_kwargs = { + "is_causal": is_causal and not query_length == 1, + "softmax_scale": softmax_scale, + } + + if supports_mapping["window_size"] and window_size is not None and key_length > window_size: + flash_kwargs["window_size"] = window_size + + if supports_mapping["deterministic"]: + flash_kwargs["deterministic"] = ( + deterministic if deterministic is not None else os.getenv("FLASH_DMATTN_DETERMINISTIC", "0") == "1" + ) + + if supports_mapping["softcap"] and softcap is not None: + flash_kwargs["softcap"] = softcap + + # Only within kernel implementation atm + if supports_mapping["s_aux"] and s_aux is not None: + flash_kwargs["s_aux"] = s_aux + + return flash_kwargs + def _flash_dynamic_mask_attention_forward( query_states: torch.Tensor, @@ -85,51 +483,178 @@ def _flash_dynamic_mask_attention_forward( query_length: int, key_length: int, is_causal: bool, + position_ids: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, - softcap: Optional[float] = None, window_size: Optional[int] = None, + softcap: Optional[float] = None, deterministic: Optional[bool] = None, + cu_seq_lens_q: Optional[torch.LongTensor] = None, + cu_seq_lens_k: Optional[torch.LongTensor] = None, + max_length_q: Optional[int] = None, + max_length_k: Optional[int] = None, target_dtype: Optional[torch.dtype] = None, implementation: Optional[str] = None, **kwargs, ): - dtype = query_states.dtype - min_dtype = torch.finfo(dtype).min + """ + Calls the forward method of Flash Dynamic Mask Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. - if not all(k in globals() for k in ("_flash_fn")): - flash_fn = _lazy_imports(implementation) - globals()["_flash_fn"] = flash_fn - else: - flash_fn = globals()["_flash_fn"] + (Optional) kwargs are described further in `_process_flash_dynamic_mask_attention_kwargs` and `FlashDynamicMaskAttentionKwargs`. - is_causal = is_causal and not query_length == 1 - flash_kwargs = {} - if deterministic is not None: - flash_kwargs["deterministic"] = deterministic - if softcap is not None: - flash_kwargs["softcap"] = softcap + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash DMATTN API + key_states (`torch.Tensor`): + Input key states to be passed to Flash DMATTN API + value_states (`torch.Tensor`): + Input value states to be passed to Flash DMATTN API + attention_mask (`torch.Tensor`, *optional*): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + attention_bias (`torch.Tensor`, *optional*): + The attention bias tensor to add to attention scores. + implementation (`str`, *optional*): + The attention implementation to use. If None, will default to the one based on the environment. + """ + if ( + attention_mask is not None + and attention_mask.dim() == 2 + and attention_bias is not None + ): + raise ValueError( + "If shape of attention_mask is (batch_size, seq_len), attention_bias has to be None." + ) + + (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention(implementation) + + # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op query_states, key_states, value_states, attention_bias = fdma_peft_integration_check( query_states, key_states, value_states, attention_bias, target_dtype ) - if attention_mask is not None and attention_mask.dim() == 4: - if attention_bias.dim() == 3: - attention_bias = attention_bias.unsqueeze(-2) - attention_bias = attention_bias.masked_fill( - ~attention_mask, - min_dtype + # Extract the flash attention kwargs that have been requested (and are supported by the implementation) + flash_kwargs = process_flash_kwargs_fn( + query_length=query_length, + key_length=key_length, + is_causal=is_causal, + softmax_scale=softmax_scale, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + **kwargs, + ) + + # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases: + # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`. + # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to + # use `flash_varlen_fn` knowing we already have all necessary the kwargs. + # + # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model. + # See #39121 for more information. + is_fdma_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0)) + is_fdma_with_varlen_kwargs = all( + kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) + ) + + # Contains at least one padding token in the sequence + if attention_mask is not None and attention_mask.dim() == 2: + q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input( + query_states, key_states, value_states, attention_mask, query_length, unpad_fn ) - if window_size is not None and key_length > window_size: - topk_values, topk_indices = torch.topk( - attention_bias, window_size, dim=-1, largest=True, sorted=False + # TODO for now this is required to work with + # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py + if "mps" in str(q.device): + cu_seq_lens_k = cu_seq_lens_k.clone() + + out_unpad = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + **flash_kwargs, ) - attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device) - attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype) + if isinstance(out_unpad, tuple): + out_unpad = out_unpad[0] - out = flash_fn( - query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, softmax_scale=softmax_scale, is_causal=is_causal - ) + out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length) + + # Padding free, i.e. sequences flattened into one total sequence + elif is_fdma_with_varlen_kwargs or is_fdma_with_position_ids: + if cu_seq_lens_q is None or cu_seq_lens_k is None: + q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids( + query_states, key_states, value_states, position_ids + ) + else: + q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) + k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) + v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) + + # TODO for now this is required to work with + # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py + if "mps" in str(q.device): + cu_seq_lens_k = cu_seq_lens_k.clone() + + out = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + **flash_kwargs, + ) + if isinstance(out, tuple): + out = out[0] + + out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1)) + + # No padding + else: + + # Generate a combined attention mask if `attention_bias` are provided + if ( + attention_bias is not None + and window_size is not None + and key_length > window_size + ): + min_dtype = torch.finfo(query_states.dtype).min + if attention_mask is not None: + if attention_mask.dim() == 4 and attention_bias.dim() == 3: + attention_bias = attention_bias.unsqueeze(-2).expand(-1, -1, query_length, -1) + if attention_mask.dim() == 3 and attention_bias.dim() == 4: + attention_mask = attention_mask.unsqueeze(-2).expand(-1, -1, query_length, -1) + + topk_values, topk_indices = torch.topk( + attention_bias.masked_fill(~attention_mask, min_dtype).detach(), + window_size, dim=-1, largest=True, sorted=False + ) + attention_mask = torch.zeros_like( + attention_bias, dtype=torch.bool, device=attention_bias.device + ).scatter_(-1, topk_indices, topk_values != min_dtype) + else: + topk_values, topk_indices = torch.topk( + attention_bias.detach(), window_size, dim=-1, largest=True, sorted=False + ) + attention_mask = torch.zeros_like( + attention_bias, dtype=torch.bool, device=attention_bias.device + ).scatter_(-1, topk_indices, topk_values != min_dtype) + + out = flash_fn( + query_states, + key_states, + value_states, + attention_mask, + attention_bias, + **flash_kwargs, + ) + if isinstance(out, tuple): + out = out[0] - return out[0] if isinstance(out, tuple) else out + return out diff --git a/flash_dmattn/utils/padding.py b/flash_dmattn/utils/padding.py new file mode 100644 index 0000000..27350f4 --- /dev/null +++ b/flash_dmattn/utils/padding.py @@ -0,0 +1,170 @@ +# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py + +import torch +import torch.nn.functional as F + + +def index_first_axis(tensor, indices): + """ + A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, + after flattening the first two dimensions of the tensor. + """ + # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first + # two dimensions to get (total_tokens, ...) before indexing. + reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) + return reshaped_tensor[indices] + + +def unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. + """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + return ( + index_first_axis(hidden_states, indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) + output[indices] = hidden_states + return output.view(batch, seqlen, *dim) + + +def get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + # NOTE: Similar to the `.item()` in prepare_fdma_from_position_ids, with torch compile, + # this might cause a graph break + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + unpad_input_func, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + unpad_input_func: + The function to use for unpadding the input tensors. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask) + + # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage + # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores + if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): + key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] + + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer, indices_k) + value_layer = index_first_axis(value_layer, indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer, indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) \ No newline at end of file