Adds variable length forward pass support #75
Conversation
Implements variable length sequence processing for multi-head attention with dynamic masking support. Enables handling of batched sequences with different lengths using cumulative sequence length tensors, improving memory efficiency for variable-length inputs. Includes support for optional features like paged key-value caching, left padding, attention masking and biasing, and dropout with proper CUDA device management.
Clarifies tensor dimensionality documentation for attention mask and bias parameters to reflect the actual expected shape of total_q x num_heads_k x max_seqlen_k, replacing the previous ambiguous description that mentioned block table variations.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Pull Request Overview
This PR implements variable length sequence processing for multi-head attention with dynamic masking support. The implementation enables handling of batched sequences with different lengths using cumulative sequence length tensors, improving memory efficiency for variable-length inputs.
Key changes include:
- Addition of an `mha_varlen_fwd` function that processes variable-length sequences with cumulative sequence length tracking (see the sketch after this list)
- Support for optional features including paged key-value caching, left padding, attention masking and biasing
- Proper CUDA device management and data type validation for the new functionality
Comments suppressed due to low confidence (1)
csrc/flash_api.cpp:515
- [nitpick] The variable name `seqlenq_ngroups_swapped` is not descriptive. Consider renaming to something clearer like `should_transpose_for_decode` or `use_decode_optimization` to better convey its purpose.
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0;
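A hedged sketch of the rename the reviewer is asking for, with purely illustrative intermediate names (none of them exist in the PR): splitting the condition into named booleans makes the decode-time fast path it guards self-describing.

```cpp
// Illustrative names only; not identifiers from the PR.
const bool is_single_query_decode = max_seqlen_q == 1;        // one query token per sequence
const bool is_grouped_query_attn  = num_heads > num_heads_k;  // GQA/MQA: more Q heads than KV heads
const bool dropout_disabled       = p_dropout == 0.f;
const bool head_size_aligned      = head_size % 8 == 0;
const bool seqlenq_ngroups_swapped =
    is_single_query_decode && is_grouped_query_attn && dropout_disabled && head_size_aligned;
```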
// const bool paged_KV = block_table_.has_value();
const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed.
if (paged_KV) {
    block_table = block_table_.value();
    CHECK_DEVICE(block_table);
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
Copilot AI · Jul 25, 2025
Hard-coding paged_KV = false ignores the block_table_ parameter and makes the related validation code unreachable. Consider removing the paged KV logic entirely or adding a runtime check to reject paged KV requests with a clear error message.
Suggested change: replace

// const bool paged_KV = block_table_.has_value();
const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed.
if (paged_KV) {
    block_table = block_table_.value();
    CHECK_DEVICE(block_table);
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");

with

if (block_table_.has_value()) {
    TORCH_CHECK(false, "Paged KV is currently not supported. Please disable the block_table_ parameter.");
CHECK_DEVICE(cu_seqlens_k);

at::Tensor block_table;
// const bool paged_KV = block_table_.has_value();
Copilot AI · Jul 25, 2025
This commented-out line should be removed as it's replaced by the hard-coded false value below. Keeping dead code reduces readability.
Suggested change: remove

// const bool paged_KV = block_table_.has_value();
int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
Copilot AI · Jul 25, 2025
[nitpick] Magic array initialization with raw arrays makes the code less maintainable. Consider using std::array or initializer lists with clear variable names to improve readability.
Suggested change: replace

int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
q = q.reshape(size_before).transpose(1, 2).reshape(size_after);

with

std::array<int64_t, 4> size_before = {batch_size, max_seqlen_q, num_heads_k, head_size};
std::array<int64_t, 3> size_after = {batch_size, num_heads_k * max_seqlen_q, head_size};
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
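Worth noting: `at::Tensor::reshape` takes a `c10::IntArrayRef`, which converts implicitly from both a raw `int64_t` array and a `std::array<int64_t, N>`, so the `std::array` variants can be passed directly and no `.data()` call is needed.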