
Conversation

@LoserCheems
Collaborator

Implements variable-length sequence processing for multi-head attention with dynamic masking support.

Enables handling of batched sequences of different lengths via cumulative sequence length tensors, improving memory efficiency for variable-length inputs.

Includes support for optional features such as paged key-value caching, left padding, attention masking and biasing, and dropout, with proper CUDA device management.
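
For a rough picture of the packed layout that the cumulative sequence length tensors describe, here is a minimal standalone sketch with made-up sizes; the names seqlens, cu_seqlens, and total_q only mirror the wording above and this is not code from the PR:

#include <iostream>
#include <torch/torch.h>

int main() {
    const int64_t num_heads = 8, head_size = 64;
    // Per-sequence lengths of a batch of two variable-length sequences.
    auto seqlens = torch::tensor({3, 5}, torch::kInt32);
    // Cumulative sequence lengths {0, 3, 8}: entries i and i+1 delimit sequence i
    // inside the packed token dimension, so no padding tokens are stored or computed.
    auto cu_seqlens = torch::cat({torch::zeros({1}, torch::kInt32),
                                  torch::cumsum(seqlens, 0).to(torch::kInt32)});
    const int64_t total_q = seqlens.sum().item<int64_t>();
    // Tokens of all sequences are concatenated along dim 0: (total_q, num_heads, head_size).
    auto q = torch::randn({total_q, num_heads, head_size});
    std::cout << cu_seqlens << "\n" << q.sizes() << std::endl;
    return 0;
}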

Copilot AI review requested due to automatic review settings July 25, 2025 15:47


Clarifies tensor dimensionality documentation for attention mask and bias parameters to reflect the actual expected shape of total_q x num_heads_k x max_seqlen_k, replacing the previous ambiguous description that mentioned block table variations.
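
For illustration, a bias with that documented layout could be allocated as in the standalone sketch below; the sizes are made up and the variable names only echo the shape description above, they are not the extension's API:

#include <iostream>
#include <torch/torch.h>

int main() {
    const int64_t total_q = 8;       // total packed query tokens across the batch
    const int64_t num_heads_k = 2;   // number of key/value heads
    const int64_t max_seqlen_k = 5;  // longest key sequence in the batch
    // Additive attention bias with shape (total_q, num_heads_k, max_seqlen_k);
    // per the commit message, the attention mask parameter documents the same shape.
    auto attn_bias = torch::zeros({total_q, num_heads_k, max_seqlen_k}, torch::kHalf);
    std::cout << attn_bias.sizes() << std::endl;
    return 0;
}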
@LoserCheems LoserCheems requested a review from Copilot July 25, 2025 15:53


Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@LoserCheems LoserCheems requested a review from Copilot July 25, 2025 16:01
Contributor

Copilot AI left a comment


Pull Request Overview

This PR implements variable length sequence processing for multi-head attention with dynamic masking support. The implementation enables handling of batched sequences with different lengths using cumulative sequence length tensors, improving memory efficiency for variable-length inputs.

Key changes include:

  • Addition of mha_varlen_fwd function that processes variable-length sequences with cumulative sequence length tracking
  • Support for optional features including paged key-value caching, left padding, attention masking and biasing
  • Proper CUDA device management and data type validation for the new functionality (see the sketch after this list)
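
A rough sketch of what such checks typically look like, patterned on the CHECK_DEVICE/TORCH_CHECK calls visible in the diff below; the tensor names, the supported dtypes, and the CHECK_DEVICE definition are assumptions made only so the fragment is self-contained, not the exact merged code:

#include <torch/torch.h>

// Assumed to match the repository's helper macro; defined here so the sketch compiles on its own.
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")

void validate_varlen_inputs(const at::Tensor& q, const at::Tensor& cu_seqlens_q) {
    // Inputs must live on the CUDA device the kernel will be launched on.
    CHECK_DEVICE(q);
    CHECK_DEVICE(cu_seqlens_q);
    // Cumulative sequence lengths are expected as contiguous int32.
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
    TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
    // Half-precision inputs are assumed here for the sake of the example.
    TORCH_CHECK(q.dtype() == torch::kHalf || q.dtype() == torch::kBFloat16,
                "query must have dtype fp16 or bf16");
}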
Comments suppressed due to low confidence (1)

csrc/flash_api.cpp:515

  • [nitpick] The variable name seqlenq_ngroups_swapped is not descriptive. Consider renaming to something clearer like should_transpose_for_decode or use_decode_optimization to better convey its purpose.
    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0;
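
A minimal sketch of the rename proposed above; the condition is copied verbatim from the line in question and only the variable name changes (the name is the reviewer's suggestion, not the merged code):

    // True only for single-token decode steps with grouped-query attention
    // (more query heads than KV heads), no dropout, and a head size divisible by 8;
    // these are the cases where the query-length and head-group dimensions get swapped.
    const bool use_decode_optimization =
        max_seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0;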

Comment on lines +478 to +484
// const bool paged_KV = block_table_.has_value();
const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed.
if (paged_KV) {
    block_table = block_table_.value();
    CHECK_DEVICE(block_table);
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");

Copilot AI Jul 25, 2025


Hard-coding paged_KV = false ignores the block_table_ parameter and makes the related validation code unreachable. Consider removing the paged KV logic entirely or adding a runtime check to reject paged KV requests with a clear error message.

Suggested change
// const bool paged_KV = block_table_.has_value();
const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed.
if (paged_KV) {
    block_table = block_table_.value();
    CHECK_DEVICE(block_table);
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
if (block_table_.has_value()) {
    TORCH_CHECK(false, "Paged KV is currently not supported. Please disable the block_table_ parameter.");

CHECK_DEVICE(cu_seqlens_k);

at::Tensor block_table;
// const bool paged_KV = block_table_.has_value();

Copilot AI Jul 25, 2025


This commented-out line should be removed as it's replaced by the hard-coded false value below. Keeping dead code reduces readability.

Suggested change
// const bool paged_KV = block_table_.has_value();

Comment on lines +671 to +674
int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
q = q.reshape(size_before).transpose(1, 2).reshape(size_after);

Copilot AI Jul 25, 2025


[nitpick] Magic array initialization with raw arrays makes the code less maintainable. Consider using std::array or initializer lists with clear variable names to improve readability.

Suggested change
int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
std::array<int64_t, 4> size_before = {batch_size, max_seqlen_q, num_heads_k, head_size};
std::array<int64_t, 3> size_after = {batch_size, num_heads_k * max_seqlen_q, head_size};
// std::array converts implicitly to IntArrayRef, so it can be passed to reshape directly.
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
q = q.reshape(size_before).transpose(1, 2).reshape(size_after);

@LoserCheems LoserCheems merged commit 02b769d into main Jul 25, 2025