Conversation

@LoserCheems
Collaborator

Summary

  • Adds dynamic top‑k attention mask utilities and centralizes mask creation for Flash Dynamic Mask Attention (FDMA).
  • Replaces ad‑hoc top‑k logic in the integration with a single, dtype‑aware create_mask helper, improving numerical robustness and readability.
  • Resolves issue #195 ([FEATURE REQUEST] Centralize windowed mask generation in utils.mask) by normalizing mask construction across callsites and aligning shapes/broadcast semantics with FDMA expectations.

Linked issue: #195

Design

  • Introduce mask.py with:
    • dynamic_mask: builds boolean masks from top‑k over an attention bias, honoring an optional pre‑mask and excluding invalid positions using a configurable minimum value (a minimal sketch follows this list).
    • create_mask: normalizes user‑provided masks (2D or 4D) to match bias shape, pads as needed to align query/key lengths, then delegates to dynamic_mask.
  • Integration wiring in modeling_flash_dynamic_mask_attention_utils.py:
    • Extend the lazy loader to return create_mask alongside core FDMA functions and pad/unpad.
    • Replace inlined top‑k scatter logic with a single call to the shared create_mask, using the attention bias’ dtype for the minimum sentinel.
    • Rename internal variables to fdma_* to clarify the backend and remove local pad/unpad fallbacks in favor of package implementations.
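
A minimal sketch of the dynamic_mask flow described above, assuming the argument names from the design bullets (illustrative only, not the exact source):

import torch

def dynamic_mask(attention_bias, attention_mask, window_size, min_dtype):
    # Exclude invalid positions up front so top-k can never select them.
    if attention_mask is not None:
        attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype)
    # Unsorted top-k along the key axis, then a single scatter into a boolean mask.
    _, topk_indices = torch.topk(
        attention_bias.detach(), window_size, dim=-1, largest=True, sorted=False
    )
    mask = torch.zeros_like(attention_bias, dtype=torch.bool)
    mask.scatter_(-1, topk_indices, True)
    return mask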

Why this approach:

  • Centralizing mask logic avoids duplicated, slightly different implementations across callsites and ensures consistent dtype/shape handling.
  • Using the bias’ actual dtype for the “min” sentinel avoids subtle mixed‑precision corner cases.
  • Keeping shape normalization in one place makes future mask patterns (e.g., composite sparsity) easier to maintain.

Changes

  • New:
    • mask.py
      • dynamic_mask(attention_bias, attention_mask, window_size, min_dtype)
      • create_mask(attention_bias, attention_mask, batch_size, query_len, key_len, window_size, min_dtype)
  • Updated:
    • modeling_flash_dynamic_mask_attention_utils.py
      • Lazy import now returns (fdma_fn, fdma_varlen_fn, pad_fn, unpad_fn, create_mask_fn).
      • _flash_dynamic_mask_attention_forward uses create_mask_fn(...) instead of local top‑k scatter logic (see the usage sketch after this list).
      • Internal naming standardized to fdma_*.
      • Local pad/unpad fallbacks removed; use flash_dmattn.utils.padding.
      • The minimum sentinel is now torch.finfo(attention_bias.dtype).min (previously derived from query dtype).
  • No public API changes; behavior is functionally equivalent aside from more robust dtype handling and consistent shape normalization.
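
For illustration, the dtype‑aware call now made inside the integration looks roughly like the standalone usage below (a hedged sketch that assumes the create_mask signature listed above; the tensor shapes and keyword style are chosen for the example, not taken from the source):

import torch
from flash_dmattn.utils.mask import create_mask

B, H, Q, K, W = 2, 4, 8, 16, 4
attention_bias = torch.randn(B, H, Q, K, dtype=torch.bfloat16)
attention_mask = torch.ones(B, K, dtype=torch.bool)  # 2D (batch, key_len) mask
attention_mask[:, -3:] = False                       # mark trailing keys invalid

bool_mask = create_mask(
    attention_bias,
    attention_mask,
    batch_size=B,
    query_len=Q,
    key_len=K,
    window_size=W,
    min_dtype=torch.finfo(attention_bias.dtype).min,  # sentinel taken from the bias dtype
)
# bool_mask broadcasts against attention_bias and keeps at most W keys per query row;
# invalid keys are pushed to min_dtype before top-k, so they are never selected.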

Implementation Notes

  • create_mask supports:
    • 2D masks shaped (batch_size, seq_len): reshaped to (B, 1, 1, K) if seq_len == key_len, or padded/sliced when seq_len == query_len to align with key_len (see the sketch after these notes).
    • 4D masks already broadcastable to the bias shape ({B|1}, {H|KVH|1}, {Q|1}, K).
  • When an attention_mask is provided, dynamic_mask first masks invalid positions (set to min_dtype) before top‑k, ensuring they’re never selected.
  • Top‑k indices are produced without sorting for efficiency; the boolean mask is built via a single scatter_ on the last axis.
  • Internal variable renames (_flash_* → _fdma_*) are purely organizational.
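
A rough sketch of the 2D‑mask normalization path described in these notes (_normalize_2d_mask is a hypothetical helper for illustration, not a real function in the package; whether padded key positions should count as valid or invalid is discussed in the review comments below):

import torch

def _normalize_2d_mask(attention_mask, batch_size, query_len, key_len):
    # attention_mask: (batch_size, seq_len) boolean -> (batch_size, 1, 1, key_len)
    seq_len = attention_mask.shape[-1]
    if seq_len == key_len:
        return attention_mask[:, None, None, :]
    if seq_len == query_len and query_len < key_len:
        pad_len = key_len - query_len
        # Padded key positions are treated as invalid (False) in this sketch.
        pad = torch.zeros(
            (batch_size, pad_len), dtype=attention_mask.dtype, device=attention_mask.device
        )
        return torch.cat([attention_mask, pad], dim=-1)[:, None, None, :]
    # Otherwise slice down to key_len, per the padded/sliced behavior above.
    return attention_mask[:, :key_len][:, None, None, :]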

Tests

  • Sanity and regression checks performed using the existing benchmark harness:
    • forward_equivalence.py — PASS (no numerical drift observed vs. previous implementation).
    • backward_equivalence.py — PASS.
    • grad_equivalence.py — PASS.
  • Smoke test:
    • python -c "from flash_dmattn import get_available_backends; print(get_available_backends())" — confirms backend availability.
  • Manual checks:
    • Verified masks respect provided 2D attention masks and correct broadcast/alignment to bias shapes.
    • Verified dtype‑aware min sentinel avoids spurious top‑k selection in mixed precision.

Note: The project currently relies on benchmark scripts rather than pytest. Results above are from running the included scripts locally.

Docs

  • Inline docstrings added for dynamic_mask and create_mask with shape contracts.
  • No external docs changed in this PR; follow‑up can add a short section in integration.md describing centralized mask construction and usage in integrations.

Checklist

  • Linked issue provided (fixes [FEATURE REQUEST] Centralize windowed mask generation in utils.mask #195)
  • API stable (no breaking public API changes)
  • Tests added or updated (covered via benchmark equivalence scripts)
  • Docs added or updated (docstrings; follow‑up suggested for integration docs)
  • No known performance regressions (hot path simplified; no additional allocations beyond mask creation already required)

Additional Notes

  • This refactor sets up a single entry point for future mask patterns (e.g., union of local window + learned bias), making it straightforward to extend without touching multiple callsites.
  • Windows users building the CUDA extension should follow the existing guidance in README/setup; no extra steps are introduced by this change.

Introduces utilities to build boolean masks for Flash Dynamic Mask Attention by selecting top‑k positions from an attention bias, improving sparsity and compute efficiency.

Handles 2D mask reshaping and padding to align query/key lengths, respects existing masks, and excludes invalid positions via a configurable minimum value.
Renames internal attention callsites to FDMA-prefixed names for clarity and consistency.

Adds lazy import and wiring for a mask creation utility and uses it to build sliding‑window masks instead of ad‑hoc top‑k logic, improving readability and numerical correctness by using the attention bias dtype for the min sentinel.

Removes local pad/unpad fallbacks in favor of package implementations.

Updates lazy loader return signature and processing hook accordingly.
Contributor

Copilot AI left a comment


Pull Request Overview

This PR centralizes dynamic mask creation for Flash Dynamic Mask Attention (FDMA) by introducing utility functions in a new mask.py module and refactoring the integration layer to use them. The changes improve code maintainability by eliminating duplicated top-k logic and enhance numerical robustness by using dtype-aware sentinel values.

Key Changes:

  • Introduced dynamic_mask and create_mask utilities to centralize mask construction logic with consistent shape normalization and dtype handling
  • Replaced inline top-k scatter logic in the integration with calls to the centralized create_mask function
  • Updated variable naming from _flash_* to _fdma_* for clarity and removed local pad/unpad fallbacks in favor of package implementations

Reviewed Changes

Copilot reviewed 2 out of 3 changed files in this pull request and generated 2 comments.

File Description
  • flash_dmattn/utils/mask.py: New utility module providing dynamic_mask and create_mask functions for centralized mask generation
  • flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py: Refactored to use the centralized mask utilities, renamed internal variables to fdma_*, removed local pad/unpad implementations


Comment on lines +90 to +94
pad_mask = torch.ones(
    (batch_size, 1, 1, pad_len),
    dtype=torch.bool,
    device=attention_mask.device,
)

Copilot AI Oct 23, 2025


Padding mask is initialized with torch.ones (all True), which marks padded positions as valid. This contradicts the typical attention mask convention where False/0 indicates invalid positions. Consider using torch.zeros instead to mark padding as invalid.

@LoserCheems LoserCheems requested a review from Copilot October 23, 2025 13:38
Contributor

Copilot AI left a comment


Pull Request Overview

Copilot reviewed 2 out of 3 changed files in this pull request and generated 2 comments.



Comment on lines +41 to +47
attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
topk_values, topk_indices = torch.topk(
    attention_bias.detach(),
    window_size, dim=-1, largest=True, sorted=False
)
attention_mask = torch.zeros_like(
    attention_bias, dtype=torch.bool, device=attention_bias.device

Copilot AI Oct 23, 2025


[nitpick] This line creates a potentially unnecessary copy of attention_bias when attention_mask is None. Consider using an early return pattern or storing the result in a new variable to make the intent clearer and avoid reassignment.

Suggested change
- attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
- topk_values, topk_indices = torch.topk(
-     attention_bias.detach(),
-     window_size, dim=-1, largest=True, sorted=False
- )
- attention_mask = torch.zeros_like(
-     attention_bias, dtype=torch.bool, device=attention_bias.device
+ masked_attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
+ topk_values, topk_indices = torch.topk(
+     masked_attention_bias.detach(),
+     window_size, dim=-1, largest=True, sorted=False
+ )
+ attention_mask = torch.zeros_like(
+     masked_attention_bias, dtype=torch.bool, device=masked_attention_bias.device

# Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
# Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
- # use `flash_varlen_fn` knowing we already have all necessary the kwargs.
+ # use `fdma_varlen_fn` knowing we already have all necessary the kwargs.

Copilot AI Oct 23, 2025


Remove the extra article 'the' before 'kwargs' in line 501. Should read 'all necessary kwargs' instead of 'all necessary the kwargs'.

Suggested change
- # use `fdma_varlen_fn` knowing we already have all necessary the kwargs.
+ # use `fdma_varlen_fn` knowing we already have all necessary kwargs.

@LoserCheems LoserCheems merged commit 2ce1efe into main Oct 23, 2025