Refactors API to use unified flash attention interface #84
Changes from all commits
````diff
@@ -47,8 +47,7 @@ pip install .
 ```python
 import torch
-import flash_dma_cuda
+import torch.nn.functional as F
+from flash_dmattn import flash_dmattn_func
 import math

 # Setup
@@ -63,19 +62,32 @@ key = torch.randn(batch_size, seq_len, num_heads, head_dim,
                    device=device, dtype=dtype)
 value = torch.randn(batch_size, seq_len, num_heads, head_dim,
                     device=device, dtype=dtype)
-zoh_states = torch.randn(batch_size, num_heads, seq_len, seq_len,
-                         device=device, dtype=dtype)
-active_mask = torch.ones(batch_size, num_heads, seq_len, seq_len,
-                         device=device, dtype=dtype)
-
-# Run Flash-DMA
-output = flash_dma_cuda.fwd(
-    q=query, k=key, v=value,
-    zoh=zoh_states, active_mask=active_mask,
+
+# Create mask and bias for sparse attention
+attention_bias = torch.randn(batch_size, num_heads, seq_len, seq_len,
+                             device=device, dtype=dtype)
+attention_mask = torch.ones(batch_size, num_heads, seq_len, seq_len,
+                            device=device, dtype=dtype)
+
+# Apply dynamic masking (keep top-k for long sequences)
+keep_window_size = 2048
+if seq_len > keep_window_size:
+    # Select top-k most important keys for each query
+    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
+                              largest=True, sorted=False).indices
+    attention_mask.zero_()
+    attention_mask.scatter(-1, topk_indices, 1.0)
````
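The hunk is cut off before the refactored attention call itself. As a rough sketch of how the unified interface might be invoked with the mask and bias prepared above (the keyword names `attn_mask`, `attn_bias`, and `scale` are assumptions for illustration, not taken from this diff):

```python
# Sketch only: keyword names below are assumed, not confirmed by this diff;
# check the flash_dmattn documentation for the actual signature.
output = flash_dmattn_func(
    query, key, value,                # (batch_size, seq_len, num_heads, head_dim)
    attn_mask=attention_mask,         # 1.0 keeps a position, 0.0 masks it out
    attn_bias=attention_bias,         # added to attention scores before softmax
    scale=1.0 / math.sqrt(head_dim),  # assumed explicit scaling, hence `import math`
)
print(output.shape)  # expected: (batch_size, seq_len, num_heads, head_dim)
```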
||||||||||||
| attention_mask.scatter(-1, topk_indices, 1.0) | |
| topk_mask = torch.zeros_like(attention_mask, dtype=torch.bool) | |
| topk_mask.scatter_(-1, topk_indices, True) | |
| attention_mask.masked_fill_(topk_mask, 1.0) |
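Presumably the point of this suggestion is that `Tensor.scatter` is out-of-place: it returns a new tensor and leaves `attention_mask` untouched, so after `attention_mask.zero_()` the mask stays all zeros and every position ends up masked. The in-place variants (`scatter_`, `masked_fill_`) modify the mask directly. A minimal standalone check:

```python
import torch

mask = torch.zeros(2, 4)
idx = torch.tensor([[0, 2], [1, 3]])

# Out-of-place: the result is discarded, so `mask` is unchanged.
_ = mask.scatter(-1, idx, 1.0)
print(mask.sum().item())  # 0.0

# In-place: `mask` itself is updated.
mask.scatter_(-1, idx, 1.0)
print(mask.sum().item())  # 4.0
```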
Copilot AI (Jul 30, 2025):
The `get_available_backends` function is imported but never used in the example. Consider either using it in the example or removing it from the import to avoid confusion.
Suggested change:

```diff
-from flash_dmattn import flash_dmattn_func, get_available_backends
-print("✅ Flash Dynamic Mask Attention imported successfully")
+from flash_dmattn import flash_dmattn_func
+print("✅ Flash Dynamic Mask Attention imported successfully")
+from flash_dmattn import get_available_backends
```
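If the import is kept instead, one way to put it to use is a backend check at import time. This is a sketch only, assuming `get_available_backends()` returns an iterable of backend name strings, which this PR does not confirm:

```python
from flash_dmattn import get_available_backends

# Assumption: get_available_backends() returns an iterable of backend name
# strings; the exact return type is not shown in this PR.
backends = list(get_available_backends())
print(f"✅ Flash Dynamic Mask Attention imported successfully (backends: {backends})")

if not backends:
    raise RuntimeError("No flash_dmattn backend is available on this system")
```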
The `keep_window_size` variable is used but not defined in this context. It was removed from the function call but still referenced in the top-k selection logic.
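A minimal way to resolve this would be to define `keep_window_size` next to the other setup values before the top-k selection. The default of 2048 below mirrors the README example above and is an assumption, not a value taken from the changed function:

```python
# Hypothetical fix sketch, continuing the example above
# (seq_len, attention_bias, attention_mask are already defined there).
keep_window_size = 2048  # assumed default, mirroring the README example

if seq_len > keep_window_size:
    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
                              largest=True, sorted=False).indices
    attention_mask.zero_()
    topk_mask = torch.zeros_like(attention_mask, dtype=torch.bool)
    topk_mask.scatter_(-1, topk_indices, True)
    attention_mask.masked_fill_(topk_mask, 1.0)
```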