[FEATURE REQUEST] Centralize windowed mask generation in utils.mask #195

@LoserCheems

Description

Problem statement

The top‑k/windowed attention mask generation logic currently lives in an integration path (e.g., the “no padding” branch). This duplicates the logic across call sites, creates a risk of drift between them, and makes it harder to reuse the same behavior in future integrations/backends. It also makes it easy for dtype min-value handling, broadcasting, and shape normalization to end up subtly different in each place.

Proposed solution

Move the windowed mask generation into utils/mask.py and have integrations call a single helper. Concretely:

  • Provide a high-level helper that only applies a top‑k window when requested and otherwise passes through the existing mask (both helpers are sketched after this list):
    • Function (proposed): maybe_create_mask(attention_bias, attention_mask, batch_size, query_len, key_len, window_size, min_dtype=None) -> Optional[Tensor]
    • Behavior:
      • If window_size is None or key_len <= window_size, return the original attention_mask unchanged.
      • Otherwise, compute the top‑k indices from attention_bias (respecting any existing boolean attention_mask if provided), and return a boolean mask of the same shape as attention_bias via a scatter.
      • Default the fill value using attention_bias.dtype (validate it’s a floating dtype).
  • Keep a low-level helper to normalize 2D (B, S) masks to (B, H_or_1, Lq_or_1, Lk_or_1) and construct the final boolean window mask:
    • create_mask(attention_bias, attention_mask, batch_size, query_len, key_len, window_size, min_dtype)
    • Supports masking along the key_len dimension and, if needed, left-padded query_len cases.
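
For discussion, here is a minimal sketch of what the two helpers in utils/mask.py could look like. The names and signatures follow the proposal above; the bodies, the exact broadcasting rules, and the use of torch.finfo for the default fill value are assumptions, not existing code.

```python
from typing import Optional

import torch


def create_mask(
    attention_bias: torch.Tensor,            # floating bias, e.g. (B|1, H|1, Lq|1, Lk)
    attention_mask: Optional[torch.Tensor],  # (B, S) bool, broadcastable 4D bool, or None
    batch_size: int,
    query_len: int,
    key_len: int,
    window_size: int,
    min_dtype: float,
) -> torch.Tensor:
    """Low-level helper: normalize the incoming mask and build the boolean window mask."""
    # batch_size / query_len are kept for signature parity; the real helper would
    # use them for the left-padded query_len case mentioned above.
    bias = attention_bias
    if attention_mask is not None:
        if attention_mask.dim() == 2:
            # Expand a (B, S) padding mask to a broadcastable (B, 1, 1, Lk) mask.
            attention_mask = attention_mask[:, None, None, :]
        attention_mask = attention_mask.to(torch.bool)
        # Push masked-out keys to the dtype minimum so top-k never selects them.
        bias = attention_bias.masked_fill(~attention_mask, min_dtype)

    # Keep the window_size largest bias entries along the key dimension.
    topk_indices = bias.topk(window_size, dim=-1).indices
    window_mask = torch.zeros_like(bias, dtype=torch.bool).scatter(
        -1, topk_indices, torch.ones_like(topk_indices, dtype=torch.bool)
    )
    if attention_mask is not None:
        window_mask = window_mask & attention_mask
    return window_mask


def maybe_create_mask(
    attention_bias: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    batch_size: int,
    query_len: int,
    key_len: int,
    window_size: Optional[int],
    min_dtype: Optional[float] = None,
) -> Optional[torch.Tensor]:
    """High-level helper: apply top-k windowing only when it is actually needed."""
    # Early exit: no windowing requested, or the window already covers all keys.
    if window_size is None or key_len <= window_size:
        return attention_mask

    if not attention_bias.dtype.is_floating_point:
        raise ValueError("attention_bias must be floating point when window_size is set")
    if min_dtype is None:
        min_dtype = torch.finfo(attention_bias.dtype).min

    return create_mask(
        attention_bias, attention_mask, batch_size, query_len, key_len, window_size, min_dtype
    )
```

With this split, integrations only ever call maybe_create_mask, and the key_len <= window_size early exit lives in exactly one place.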

Acceptance criteria:

  • Functional parity with the current integration path (bitwise-equal boolean masks and identical attention outputs for the same inputs); a hypothetical parity check is sketched after this list.
  • Handles shapes {B|1} x {H|KVH|1} x {Lq|1} x {Lk|1} correctly, including the case where an input (B, S) mask needs expansion.
  • Correct dtype min fill selection and no graph breaks in common torch.compile paths.
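
As a concrete way to exercise the first criterion, a parity check along the following lines could be added; check_mask_parity and reference_fn are hypothetical names, with reference_fn standing in for the current inlined no-padding logic.

```python
import torch


def check_mask_parity(maybe_create_mask, reference_fn, B=2, H=4, Lq=8, Lk=128, window=32):
    """Assert the centralized helper matches the existing inline logic bit-for-bit."""
    torch.manual_seed(0)
    bias = torch.randn(B, H, Lq, Lk, dtype=torch.float32)
    pad = torch.ones(B, Lk, dtype=torch.bool)
    pad[:, -5:] = False  # simulate right-padded keys

    new_mask = maybe_create_mask(bias, pad, B, Lq, Lk, window)
    ref_mask = reference_fn(bias, pad, B, Lq, Lk, window)

    assert new_mask.dtype == torch.bool and new_mask.shape == bias.shape
    assert torch.equal(new_mask, ref_mask), "centralized mask diverges from the inline path"
```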

Alternatives considered

  • Keep the logic inside the integration forward path: increases duplication and long-term maintenance cost.
  • Move to utils/padding.py: semantically this is about mask generation rather than padding/unpadding, so utils/mask.py is a clearer home.
  • Push into backends directly: would fragment behavior across kernels and reintroduce duplication.

Implementation details

  • CUDA kernel changes: Not required.
  • Python API: Add maybe_create_mask to utils/mask.py. Keep create_mask as the building block.
  • Integrations: Replace the inlined top‑k mask creation in the no‑padding path with a call to maybe_create_mask (a call-site sketch follows this list).
  • Performance: Neutral to positive. Centralization removes redundant code and ensures consistent early‑exit when windowing is unnecessary (key_len <= window_size).
  • Compatibility: Preserves current shapes and broadcasting rules. Validates attention_bias is floating when window_size is set.
  • Testing/verification:
    • Run forward_equivalence.py, backward_equivalence.py, and grad_equivalence.py to verify identical results.
    • Sanity-check variable sequence lengths, both with and without an existing (B, S) mask.
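
For concreteness, the refactored no-padding branch might reduce to something like the snippet below; the import path and surrounding variable names are assumptions.

```python
# Hypothetical call site in modeling_flash_dynamic_mask_attention_utils.py (no-padding branch).
from flash_dmattn.utils.mask import maybe_create_mask  # assumed module path

attention_mask = maybe_create_mask(
    attention_bias,   # floating bias, e.g. (B, H, Lq, Lk)
    attention_mask,   # existing (B, S) or broadcastable 4D bool mask, or None
    batch_size,
    query_len,
    key_len,
    window_size,
)
# When window_size is None or key_len <= window_size, the helper returns the
# original attention_mask untouched, so no extra branching is needed here.
```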

Use case

  • Reusing the same windowed mask logic across:
    • modeling_flash_dynamic_mask_attention_utils.py (no‑padding branch),
    • Future integrations that need top‑k windowing,
    • Potential benchmarking or ablation scripts that wish to synthesize masks uniformly.
  • Typical sequence lengths range from short prompts to long contexts where window_size << key_len, and correctness and consistency across backends are critical.

Related work

  • Aligns with the common practice in attention libraries of centralizing mask/materialization helpers for consistency (e.g., the centralized padding/unpadding utilities in the FlashAttention ecosystem).
  • Reduces divergence between “varlen” and “no‑padding” code paths by sharing mask construction semantics.

Additional context

  • Proposed home and helpers:
    • utils/mask.py: maybe_create_mask, create_mask, dynamic_mask
  • Current call site to refactor:
    • modeling_flash_dynamic_mask_attention_utils.py (no‑padding branch for mask creation)

Metadata

Labels: feature (New feature request)
