[FEATURE SUPPORT] Centralize dynamic mask creation for FDMA #197
Conversation
Introduces utilities to build boolean masks for Flash Dynamic Mask Attention by selecting top‑k positions from an attention bias, improving sparsity and compute efficiency. Handles 2D mask reshaping and padding to align query/key lengths, respects existing masks, and excludes invalid positions via a configurable minimum value.
Renames internal attention callsites to FDMA-prefixed names for clarity and consistency. Adds lazy import and wiring for a mask creation utility and uses it to build sliding‑window masks instead of ad‑hoc top‑k logic, improving readability and numerical correctness by using the attention bias dtype for the minimum sentinel. Removes local pad/unpad fallbacks in favor of package implementations. Updates the lazy loader return signature and processing hook accordingly.
Pull Request Overview
This PR centralizes dynamic mask creation for Flash Dynamic Mask Attention (FDMA) by introducing utility functions in a new mask.py module and refactoring the integration layer to use them. The changes improve code maintainability by eliminating duplicated top-k logic and enhance numerical robustness by using dtype-aware sentinel values.
Key Changes:
- Introduced `dynamic_mask` and `create_mask` utilities to centralize mask construction logic with consistent shape normalization and dtype handling
- Replaced inline top‑k scatter logic in the integration with calls to the centralized `create_mask` function
- Updated variable naming from `_flash_*` to `_fdma_*` for clarity and removed local pad/unpad fallbacks in favor of package implementations
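As a rough note on the dtype-aware sentinel mentioned in the overview, the minimum is derived from the bias tensor itself (an illustration, not the PR's exact code):

```python
import torch

# Illustration: take the masking sentinel from the attention bias dtype,
# so an fp16/bf16 bias gets a representable minimum rather than float32's.
attention_bias = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
min_dtype = torch.finfo(attention_bias.dtype).min  # ~ -3.39e38 for bfloat16
```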
Reviewed Changes
Copilot reviewed 2 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| flash_dmattn/utils/mask.py | New utility module providing dynamic_mask and create_mask functions for centralized mask generation |
| flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py | Refactored to use centralized mask utilities, renamed internal variables to fdma_*, removed local pad/unpad implementations |
```python
pad_mask = torch.ones(
    (batch_size, 1, 1, pad_len),
    dtype=torch.bool,
    device=attention_mask.device,
)
```
Copilot AI · Oct 23, 2025
Padding mask is initialized with torch.ones (all True), which marks padded positions as valid. This contradicts the typical attention mask convention where False/0 indicates invalid positions. Consider using torch.zeros instead to mark padding as invalid.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…ad of {key_len|1})
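Applied literally, the torch.zeros suggestion above might look like the sketch below (the example sizes and stand-in `attention_mask` are mine; whether padded keys should count as valid is exactly what this comment questions):

```python
import torch

batch_size, pad_len = 2, 3                                    # example sizes
attention_mask = torch.ones(batch_size, 5, dtype=torch.bool)  # stand-in for the real mask

# Reviewer's suggestion: initialize padded key positions as invalid (False)
# rather than valid (True).
pad_mask = torch.zeros(
    (batch_size, 1, 1, pad_len),
    dtype=torch.bool,
    device=attention_mask.device,
)
```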
Pull Request Overview
Copilot reviewed 2 out of 3 changed files in this pull request and generated 2 comments.
```python
attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
topk_values, topk_indices = torch.topk(
    attention_bias.detach(),
    window_size, dim=-1, largest=True, sorted=False
)
attention_mask = torch.zeros_like(
    attention_bias, dtype=torch.bool, device=attention_bias.device
```
Copilot AI · Oct 23, 2025
[nitpick] This line creates a potentially unnecessary copy of attention_bias when attention_mask is None. Consider using an early return pattern or storing the result in a new variable to make the intent clearer and avoid reassignment.
Suggested change:

```diff
-attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
-topk_values, topk_indices = torch.topk(
-    attention_bias.detach(),
-    window_size, dim=-1, largest=True, sorted=False
-)
-attention_mask = torch.zeros_like(
-    attention_bias, dtype=torch.bool, device=attention_bias.device
+masked_attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
+topk_values, topk_indices = torch.topk(
+    masked_attention_bias.detach(),
+    window_size, dim=-1, largest=True, sorted=False
+)
+attention_mask = torch.zeros_like(
+    masked_attention_bias, dtype=torch.bool, device=masked_attention_bias.device
```
```diff
 # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
 # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
-# use `flash_varlen_fn` knowing we already have all necessary the kwargs.
+# use `fdma_varlen_fn` knowing we already have all necessary the kwargs.
```
Copilot AI · Oct 23, 2025
Remove the extra article 'the' before 'kwargs' in line 501. Should read 'all necessary kwargs' instead of 'all necessary the kwargs'.
Suggested change:

```diff
-# use `fdma_varlen_fn` knowing we already have all necessary the kwargs.
+# use `fdma_varlen_fn` knowing we already have all necessary kwargs.
```
Summary
- Centralizes mask construction behind a new `create_mask` helper, improving numerical robustness and readability.

Linked issue: #195
Design
- `dynamic_mask`: builds boolean masks from top‑k over an attention bias, honoring an optional pre‑mask and excluding invalid positions using a configurable minimum value (a sketch follows below).
- `create_mask`: normalizes user‑provided masks (2D or 4D) to match the bias shape, pads as needed to align query/key lengths, then delegates to `dynamic_mask`.
- In `modeling_flash_dynamic_mask_attention_utils.py`: lazily imports `create_mask` alongside the core FDMA functions and pad/unpad.
- Builds sliding‑window masks via `create_mask`, using the attention bias dtype for the minimum sentinel.
- Renames internal callsites to `fdma_*` to clarify the backend and removes local pad/unpad fallbacks in favor of package implementations.

Why this approach: centralizing mask construction removes duplicated top‑k logic from the integration layer and keeps shape and dtype handling consistent across callsites.
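A minimal sketch of that top‑k pattern, with a helper name and driver values of my own (the PR's actual `dynamic_mask` may handle head broadcasting and other details differently):

```python
import torch

def topk_bias_mask(attention_bias, attention_mask, window_size, min_dtype):
    # Keep the window_size largest bias entries per query; positions ruled out
    # by the pre-mask are pushed to min_dtype first so top-k never prefers them.
    if attention_mask is not None:
        attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype)
    topk_indices = torch.topk(
        attention_bias.detach(), window_size, dim=-1, largest=True, sorted=False
    ).indices
    mask = torch.zeros_like(attention_bias, dtype=torch.bool)
    return mask.scatter_(-1, topk_indices, True)

bias = torch.randn(2, 4, 8, 16)  # (batch, heads, query_len, key_len)
keep = topk_bias_mask(bias, None, window_size=4, min_dtype=torch.finfo(bias.dtype).min)
```

Detaching the bias before top‑k keeps the selection out of the autograd graph, matching the hunk quoted in the review above.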
Changes
- New: `dynamic_mask(attention_bias, attention_mask, window_size, min_dtype)`
- New: `create_mask(attention_bias, attention_mask, batch_size, query_len, key_len, window_size, min_dtype)` (example call below)
- Lazy loader now returns `(fdma_fn, fdma_varlen_fn, pad_fn, unpad_fn, create_mask_fn)`.
- `_flash_dynamic_mask_attention_forward` uses `create_mask_fn(...)` instead of local top‑k scatter logic.
- Internal callsites renamed to `fdma_*`.
- Pad/unpad now come from `flash_dmattn.utils.padding`.
- Minimum sentinel is `torch.finfo(attention_bias.dtype).min` (previously derived from the query dtype).
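Assuming the signature listed above, a call site might look roughly like this (shapes are illustrative, the keyword usage for the last two arguments is my choice, and I have not verified defaults):

```python
import torch
from flash_dmattn.utils.mask import create_mask  # module path per the files table above

batch_size, num_heads, query_len, key_len = 2, 4, 8, 16
attention_bias = torch.randn(batch_size, num_heads, query_len, key_len)
attention_mask = torch.ones(batch_size, key_len, dtype=torch.bool)  # 2D pre-mask

# Arguments follow the signature listed in the Changes section.
mask = create_mask(
    attention_bias,
    attention_mask,
    batch_size,
    query_len,
    key_len,
    window_size=4,
    min_dtype=torch.finfo(attention_bias.dtype).min,
)
```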
- `create_mask` supports:
  - 2D masks of shape `(batch_size, seq_len)`: reshaped to `(B, 1, 1, K)` if `seq_len == key_len`, or padded/sliced when `seq_len == query_len` to align with `key_len` (see the sketch below).
  - 4D masks of shape `({B|1}, {H|KVH|1}, {Q|1}, K)`.
- If `attention_mask` is provided, `dynamic_mask` first masks invalid positions (set to `min_dtype`) before top‑k, ensuring they're never selected.
- Top‑k indices are written back into a boolean mask via `scatter_` on the last axis.
- Renames (`_flash_*` → `_fdma_*`) are purely organizational.
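A standalone sketch of the 2D handling (the helper name, pad value, and error handling are assumptions of mine, not the PR's code):

```python
import torch

def normalize_2d_mask(mask_2d, batch_size, query_len, key_len):
    # Lift a (batch_size, seq_len) boolean mask to (B, 1, 1, K).
    seq_len = mask_2d.shape[-1]
    if seq_len == key_len:
        return mask_2d.view(batch_size, 1, 1, key_len)
    if seq_len == query_len:
        pad = key_len - seq_len
        if pad > 0:
            # Pad value True mirrors the torch.ones hunk quoted in the review;
            # whether padded keys should count as valid is the open question there.
            pad_block = torch.ones(batch_size, pad, dtype=mask_2d.dtype, device=mask_2d.device)
            mask_2d = torch.cat([mask_2d, pad_block], dim=-1)
        else:
            mask_2d = mask_2d[:, :key_len]
        return mask_2d.view(batch_size, 1, 1, key_len)
    raise ValueError("2D mask length must equal query_len or key_len")

m = normalize_2d_mask(torch.ones(2, 8, dtype=torch.bool), batch_size=2, query_len=8, key_len=16)
```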
- `python -c "from flash_dmattn import get_available_backends; print(get_available_backends())"` — confirms backend availability.

Note: The project currently relies on benchmark scripts rather than pytest. Results above are from running the included scripts locally.
Docs
- `dynamic_mask` and `create_mask` are documented with shape contracts.

Checklist
Additional Notes