Adds variable length forward pass support #75
Changes from all commits
```cpp
@@ -430,9 +430,258 @@ mha_fwd(
    }
    return {out, softmax_lse, p, rng_state};
}

std::vector<at::Tensor>
mha_varlen_fwd(
        at::Tensor &q,                                // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
        const at::Tensor &k,                          // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
        const at::Tensor &v,                          // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
        const at::Tensor &attn_mask,                  // total_q x num_heads_k x max_seqlen_k
        const at::Tensor &attn_bias,                  // total_q x num_heads_k x max_seqlen_k
        std::optional<at::Tensor> &out_,              // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        const at::Tensor &cu_seqlens_q,               // b+1
        const at::Tensor &cu_seqlens_k,               // b+1
        std::optional<at::Tensor> &seqused_k,         // b. If given, only this many elements of each batch element's keys are used.
        std::optional<const at::Tensor> &leftpad_k_,  // batch_size
        std::optional<at::Tensor> &block_table_,      // batch_size x max_num_blocks_per_seq
        int max_seqlen_q,
        const int max_seqlen_k,
        const float p_dropout,
        const float softmax_scale,
        const bool zero_tensors,
        bool is_causal,
        const int keep_window_size,
        const float softcap,
        const bool return_softmax,
        std::optional<at::Generator> gen_
) {
    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.device()};
    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    bool is_sm8x_min = cc_major >= 8;
    TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(attn_mask.dtype() == q_dtype, "attn_mask must have the same dtype as inputs");
    TORCH_CHECK(attn_bias.dtype() == q_dtype, "attn_bias must have the same dtype as inputs");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(attn_mask); CHECK_DEVICE(attn_bias);
    CHECK_DEVICE(cu_seqlens_q);
    CHECK_DEVICE(cu_seqlens_k);

    at::Tensor block_table;
    // const bool paged_KV = block_table_.has_value();
```
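For readers new to the varlen layout: q, k, and v pack all b sequences along their first dimension, and cu_seqlens_q / cu_seqlens_k hold the cumulative offsets of each sequence (hence length b+1), with total_q equal to the last entry of cu_seqlens_q. Below is a minimal host-side sketch of how a caller might build such offsets; the helper is illustrative only and not part of this diff:

```cpp
#include <torch/torch.h>
#include <vector>

// Illustrative only: build the cumulative sequence-length tensor expected by
// mha_varlen_fwd. For per-sequence lengths s_0..s_{b-1}, the result has b+1
// int32 entries [0, s_0, s_0+s_1, ...]; the last entry equals total_q.
at::Tensor build_cu_seqlens(const std::vector<int32_t> &seqlens) {
    auto cu = torch::zeros({static_cast<int64_t>(seqlens.size()) + 1},
                           torch::dtype(torch::kInt32));
    auto acc = cu.accessor<int32_t, 1>();
    int32_t running = 0;
    for (size_t i = 0; i < seqlens.size(); ++i) {
        running += seqlens[i];
        acc[i + 1] = running;  // acc[0] stays 0
    }
    return cu;  // move to the same CUDA device as q/k/v before calling the kernel
}
```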
Copilot AI, Jul 25, 2025, on `// const bool paged_KV = block_table_.has_value();`:
This commented-out line should be removed as it's replaced by the hard-coded false value below. Keeping dead code reduces readability.
Copilot AI, Jul 25, 2025, on `// const bool paged_KV = block_table_.has_value();`:
Hard-coding paged_KV to false makes the subsequent paged_KV conditional logic unreachable. Consider removing the paged_KV related code blocks or adding a feature flag instead of hard-coding false.
Suggested change:
```diff
-const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed.
+const bool paged_KV = ENABLE_PAGED_KV; // Use feature flag to control Paged KV functionality.
```
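For reference, one way such a flag could be wired in is a compile-time macro. `ENABLE_PAGED_KV` is only the name used in the suggestion above and is not defined anywhere in this repository, so the following is an illustrative sketch rather than the PR's implementation:

```cpp
// Illustrative sketch only: gate the paged-KV path behind a build flag
// (e.g. passed as -DENABLE_PAGED_KV=1) instead of a hard-coded false.
// The macro name comes from the suggestion above and does not exist in this codebase.
#ifndef ENABLE_PAGED_KV
#define ENABLE_PAGED_KV 0   // keep paged KV disabled until the remaining bugs are fixed
#endif

// Inside mha_varlen_fwd, the gate could then read:
//   const bool paged_KV = (ENABLE_PAGED_KV != 0) && block_table_.has_value();
```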
Copilot AI, Jul 25, 2025:
The paged_KV feature is hardcoded to false but the function still accepts and validates the block_table parameter. Consider either removing the block_table parameter entirely or adding an early return/error when block_table is provided while paged_KV is disabled to avoid confusion and unnecessary validation overhead.
Suggested change (replace the unreachable `if (paged_KV)` validation block with an up-front rejection):
```diff
-if (paged_KV) {
-    block_table = block_table_.value();
-    CHECK_DEVICE(block_table);
-    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
-    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+if (block_table_.has_value()) {
+    TORCH_CHECK(false, "block_table is not supported because paged_KV is currently disabled.");
```
Copilot AI, Jul 25, 2025:
Hard-coding paged_KV = false ignores the block_table_ parameter and makes the related validation code unreachable. Consider removing the paged KV logic entirely or adding a runtime check to reject paged KV requests with a clear error message.
Suggested change:
```diff
 // const bool paged_KV = block_table_.has_value();
 const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed.
-if (paged_KV) {
-    block_table = block_table_.value();
-    CHECK_DEVICE(block_table);
-    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
-    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+if (block_table_.has_value()) {
+    TORCH_CHECK(false, "Paged KV is currently not supported. Please disable the block_table_ parameter.");
```
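Taken together, the two suggestions amount to failing fast when a `block_table` is supplied while the paged-KV path is switched off. A minimal sketch of that guard, illustrative only (the message wording follows the first suggestion above, not the PR's actual code):

```cpp
// Sketch of the suggested behaviour inside mha_varlen_fwd: while paged KV is
// disabled, reject a supplied block_table up front instead of keeping
// unreachable validation code behind `if (paged_KV)`.
const bool paged_KV = false;  // temporarily disabled
TORCH_CHECK(!block_table_.has_value(),
            "block_table is not supported because paged_KV is currently disabled.");
```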
Copilot AI, Jul 25, 2025:
The shape check for attn_mask appears incorrect. The comment on line 439 indicates attn_mask should be total_mask x num_heads_k where total_mask := total_q x total_k, but this check expects total_q x num_heads_k x max_seqlen_k. This mismatch could cause runtime errors when the attention mask has the correct shape according to the documentation.
Suggested change:
```diff
-CHECK_SHAPE(attn_mask, total_q, num_heads_k, max_seqlen_k);
+CHECK_SHAPE(attn_mask, total_q * total_k, num_heads_k);
```
Copilot AI, Jul 25, 2025:
The shape check for attn_bias appears incorrect. The comment on line 440 indicates attn_bias should be total_bias x num_heads_k where total_bias := total_q x total_k, but this check expects total_q x num_heads_k x max_seqlen_k. This mismatch could cause runtime errors when the attention bias has the correct shape according to the documentation.
Suggested change:
```diff
-CHECK_SHAPE(attn_bias, total_q, num_heads_k, max_seqlen_k);
+CHECK_SHAPE(attn_bias, total_q * total_k, num_heads_k);
```
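To make the shape mismatch concrete, here is a small self-contained example with made-up sizes (two sequences of lengths 3 and 5, and num_heads_k = 4 are illustrative assumptions, not values from the PR):

```cpp
#include <iostream>
#include <torch/torch.h>

int main() {
    // Illustrative numbers only: b = 2 sequences with lengths {3, 5}.
    const int64_t total_q = 8, total_k = 8;           // 3 + 5 on both the query and key side
    const int64_t max_seqlen_k = 5, num_heads_k = 4;

    // Shape the CHECK_SHAPE in this diff expects for attn_mask / attn_bias:
    auto mask_3d = torch::zeros({total_q, num_heads_k, max_seqlen_k});   // (8, 4, 5)

    // Shape implied by the documentation the reviewer cites (total_q x total_k rows):
    auto mask_2d = torch::zeros({total_q * total_k, num_heads_k});       // (64, 4)

    // The two layouts differ even in element count, so a tensor built for one
    // reading cannot pass the check written for the other.
    std::cout << mask_3d.numel() << " vs " << mask_2d.numel() << std::endl;  // 160 vs 256
    return 0;
}
```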
Copilot AI, Jul 25, 2025:
[nitpick] Magic array initialization with raw arrays makes the code less maintainable. Consider using std::array or initializer lists with clear variable names to improve readability.
Suggested change:
```diff
-int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
-int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
-out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
-q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
+std::array<int64_t, 4> size_before = {batch_size, max_seqlen_q, num_heads_k, head_size};
+std::array<int64_t, 3> size_after = {batch_size, num_heads_k * max_seqlen_q, head_size};
+out = out.reshape(size_before.data()).transpose(1, 2).reshape(size_after.data());
+q = q.reshape(size_before.data()).transpose(1, 2).reshape(size_after.data());
```
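One note on the suggested snippet: `at::Tensor::reshape` takes a `c10::IntArrayRef`, which converts implicitly from a `std::array<int64_t, N>` (and from the raw C arrays in the current code) but does not accept a bare pointer, so the `.data()` calls would likely not compile; passing the arrays directly is simpler and keeps the length information. A sketch of that variant, with a helper name made up purely for illustration:

```cpp
#include <array>
#include <torch/torch.h>

// Sketch only: the same reshape/transpose round-trip as in the diff, with the
// target sizes held in std::array and passed straight to reshape(), relying on
// the implicit std::array -> c10::IntArrayRef conversion.
at::Tensor regroup_heads(at::Tensor t, int64_t batch_size, int64_t max_seqlen_q,
                         int64_t num_heads_k, int64_t head_size) {
    const std::array<int64_t, 4> size_before = {batch_size, max_seqlen_q, num_heads_k, head_size};
    const std::array<int64_t, 3> size_after  = {batch_size, num_heads_k * max_seqlen_q, head_size};
    return t.reshape(size_before).transpose(1, 2).reshape(size_after);
}
```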
The comment incorrectly states 'total_k := \sum_{i=0}^{b} s_i' for the out_ parameter, but it should be 'total_q := \sum_{i=0}^{b} s_i' since out_ has the same shape as the query tensor q.