Refactors API to use unified flash attention interface #84
Conversation
Replaces low-level CUDA extension calls with simplified function interface that handles dynamic masking internally. Removes manual ZOH state and active mask management in favor of attention bias and mask parameters. Adds dynamic top-k selection for long sequences to improve memory efficiency. Simplifies troubleshooting documentation by removing CUDA-specific debugging steps and focusing on memory monitoring.
Pull Request Overview
This PR refactors the Flash DMA attention API to use a unified interface that abstracts away low-level CUDA operations. The changes simplify the API by replacing manual state and mask management with attention bias and mask parameters, while adding dynamic top-k selection for improved memory efficiency on long sequences.
- Replaces `flash_dma_cuda` with the unified `flash_dmattn_func` interface (see the sketch below)
- Removes manual ZOH state and active mask management in favor of attention bias/mask parameters
- Adds dynamic top-k selection for long sequences to improve memory efficiency
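As a rough illustration of the new call pattern, the sketch below builds an attention bias and mask and passes them to `flash_dmattn_func`. The tensor layout (`[batch, num_heads, seq_len, head_dim]`) and the keyword names `attn_mask`/`attn_bias` are assumptions for illustration, not confirmed signatures; check the package documentation for the exact interface.

```python
import torch
from flash_dmattn import flash_dmattn_func

# Illustrative shapes only; the expected layout is an assumption.
batch, num_heads, seq_len, head_dim = 1, 2, 1024, 64
device, dtype = "cuda", torch.bfloat16  # a CUDA build of the package is assumed

q = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Attention bias and mask replace the old manual ZOH state / active mask management.
attention_bias = torch.randn(batch, num_heads, seq_len, seq_len, device=device, dtype=dtype)
attention_mask = torch.ones(batch, num_heads, seq_len, seq_len, device=device, dtype=dtype)

# Assumed keyword names for the unified interface.
out = flash_dmattn_func(q, k, v, attn_mask=attention_mask, attn_bias=attention_bias)
```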
```python
keep_window_size = 2048
if seq_len > keep_window_size:
    # Select top-k most important keys for each query
    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
```
Copilot AI · Jul 30, 2025
The keep_window_size variable is used but not defined in this context. It was removed from the function call but still referenced in the top-k selection logic.
```python
    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
                              largest=True, sorted=False).indices
    attention_mask.zero_()
    attention_mask.scatter(-1, topk_indices, 1.0)
```
Copilot AI · Jul 30, 2025
The scatter operation may fail because topk_indices has shape [batch_size, num_heads, seq_len, keep_window_size] but attention_mask expects indices for the last dimension of size seq_len. The indices need to be properly shaped or the scatter operation needs to specify the correct dimensions.
Suggested change:

```diff
-    attention_mask.scatter(-1, topk_indices, 1.0)
+    topk_mask = torch.zeros_like(attention_mask, dtype=torch.bool)
+    topk_mask.scatter_(-1, topk_indices, True)
+    attention_mask.masked_fill_(topk_mask, 1.0)
```
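Putting the two review points together, a minimal self-contained sketch of the dynamic top-k mask construction might look like the following, with `keep_window_size` defined in scope and the mask rebuilt via the in-place boolean scatter from the suggestion above. The tensor shapes are assumptions for illustration.

```python
import torch

# Assumed shapes: bias and mask are [batch, num_heads, seq_len, seq_len].
batch, num_heads, seq_len = 1, 2, 4096
attention_bias = torch.randn(batch, num_heads, seq_len, seq_len)
attention_mask = torch.ones(batch, num_heads, seq_len, seq_len)

keep_window_size = 2048  # defined explicitly, addressing the first review comment
if seq_len > keep_window_size:
    # Select the top-k most important keys for each query position.
    topk_indices = torch.topk(
        attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
    ).indices
    # Rebuild the mask in place: zero everything, then mark only the selected keys.
    attention_mask.zero_()
    topk_mask = torch.zeros_like(attention_mask, dtype=torch.bool)
    topk_mask.scatter_(-1, topk_indices, True)
    attention_mask.masked_fill_(topk_mask, 1.0)
```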
```python
from flash_dmattn import flash_dmattn_func, get_available_backends
print("✅ Flash Dynamic Mask Attention imported successfully")
```
Copilot AI · Jul 30, 2025
The get_available_backends function is imported but never used in the example. Consider either using it in the example or removing it from the import to avoid confusion.
Suggested change:

```diff
- from flash_dmattn import flash_dmattn_func, get_available_backends
- print("✅ Flash Dynamic Mask Attention imported successfully")
+ from flash_dmattn import flash_dmattn_func
+ print("✅ Flash Dynamic Mask Attention imported successfully")
+ from flash_dmattn import get_available_backends
```
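If the backend helper is kept in the import, one way to actually exercise it in the example is sketched below; this assumes `get_available_backends()` takes no arguments and returns something printable, which is an assumption about its interface rather than documented behavior.

```python
from flash_dmattn import flash_dmattn_func, get_available_backends

print("✅ Flash Dynamic Mask Attention imported successfully")
# Assumption: get_available_backends() returns a collection of backend names.
print(f"Available backends: {get_available_backends()}")
```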