
Conversation

@LoserCheems
Collaborator

Replaces low-level CUDA extension calls with simplified function interface
that handles dynamic masking internally. Removes manual ZOH state and
active mask management in favor of attention bias and mask parameters.

Adds dynamic top-k selection for long sequences to improve memory
efficiency. Simplifies troubleshooting documentation by removing
CUDA-specific debugging steps and focusing on memory monitoring.

Copilot AI (Contributor) left a comment

Pull Request Overview

This PR refactors the Flash DMA attention API to use a unified interface that abstracts away low-level CUDA operations. The changes simplify the API by replacing manual state and mask management with attention bias and mask parameters, while adding dynamic top-k selection for improved memory efficiency on long sequences. A usage sketch follows the summary below.

  • Replaces flash_dma_cuda with flash_dmattn_func unified interface
  • Removes manual ZOH state and active mask management in favor of attention bias/mask parameters
  • Adds dynamic top-k selection for long sequences to improve memory efficiency
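
A minimal sketch of the unified call summarized above. The tensor layout and the attn_mask/attn_bias keyword names are assumptions for illustration, not the library's documented signature:

import torch
from flash_dmattn import flash_dmattn_func

# Assumed layout: (batch, seq_len, num_heads, head_dim) for q/k/v and
# (batch, num_heads, seq_len, seq_len) for the bias and mask tensors.
batch, seq_len, num_heads, head_dim = 1, 4096, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
attention_bias = torch.zeros(batch, num_heads, seq_len, seq_len, device="cuda", dtype=torch.bfloat16)
attention_mask = torch.ones_like(attention_bias)

# Dynamic masking (ZOH state, active-mask bookkeeping) is handled inside the
# call; the caller only supplies the bias and mask tensors.
out = flash_dmattn_func(q, k, v, attn_mask=attention_mask, attn_bias=attention_bias)

The review comments below refer to the top-k selection snippet that the PR adds to the documentation.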

keep_window_size = 2048
if seq_len > keep_window_size:
    # Select top-k most important keys for each query
    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
                              largest=True, sorted=False).indices

Copilot AI Jul 30, 2025

The keep_window_size variable is used but not defined in this context. It was removed from the function call but still referenced in the top-k selection logic.

    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
                              largest=True, sorted=False).indices
    attention_mask.zero_()
    attention_mask.scatter(-1, topk_indices, 1.0)

Copilot AI Jul 30, 2025

The scatter operation may fail because topk_indices has shape [batch_size, num_heads, seq_len, keep_window_size] but attention_mask expects indices for the last dimension of size seq_len. The indices need to be properly shaped or the scatter operation needs to specify the correct dimensions.

Suggested change
-    attention_mask.scatter(-1, topk_indices, 1.0)
+    topk_mask = torch.zeros_like(attention_mask, dtype=torch.bool)
+    topk_mask.scatter_(-1, topk_indices, True)
+    attention_mask.masked_fill_(topk_mask, 1.0)
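
For reference, a self-contained sketch of the suggested masking step on dummy tensors; the shapes below are illustrative assumptions, not values taken from this PR:

import torch

# Illustrative shapes only.
batch, heads, seq_len, keep_window_size = 1, 2, 8, 4

attention_bias = torch.randn(batch, heads, seq_len, seq_len)
attention_mask = torch.zeros(batch, heads, seq_len, seq_len)

# Indices of the keep_window_size largest bias entries per query position.
topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
                          largest=True, sorted=False).indices

# Build a boolean top-k mask in place, then write it into the float mask.
# (Note that the non-in-place Tensor.scatter in the original snippet returns
# a new tensor rather than modifying attention_mask.)
topk_mask = torch.zeros_like(attention_mask, dtype=torch.bool)
topk_mask.scatter_(-1, topk_indices, True)
attention_mask.masked_fill_(topk_mask, 1.0)

# Each query row now keeps exactly keep_window_size positions.
assert attention_mask.sum(dim=-1).eq(keep_window_size).all()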

Comment on lines +204 to +205
from flash_dmattn import flash_dmattn_func, get_available_backends
print("✅ Flash Dynamic Mask Attention imported successfully")

Copilot AI Jul 30, 2025

The get_available_backends function is imported but never used in the example. Consider either using it in the example or removing it from the import to avoid confusion.

Suggested change
-from flash_dmattn import flash_dmattn_func, get_available_backends
-print("✅ Flash Dynamic Mask Attention imported successfully")
+from flash_dmattn import flash_dmattn_func
+print("✅ Flash Dynamic Mask Attention imported successfully")
+from flash_dmattn import get_available_backends
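
If the backends list is worth surfacing instead, a sketch of the other option the comment mentions, assuming get_available_backends() returns an iterable of backend name strings (not confirmed by this PR):

from flash_dmattn import flash_dmattn_func, get_available_backends

# Assumed behavior: report which kernels are usable in this environment.
backends = get_available_backends()
print(f"✅ Flash Dynamic Mask Attention imported successfully (backends: {list(backends)})")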

@LoserCheems merged commit 08fdcb4 into main on Jul 30, 2025