Description
Problem statement
The top‑k/windowed attention mask generation logic currently lives in an integration path (e.g., the “no padding” branch). This causes duplication and drift risk across call sites and makes it harder to reuse the same behavior in future integrations/backends. It is also easy to get dtype min-value handling, broadcasting, and shape normalization subtly inconsistent from one place to the next.
Proposed solution
Move the windowed mask generation into `utils/mask.py` and have integrations call a single helper (both helpers are sketched after this list). Concretely:
- Provide a high-level helper that only applies a top‑k window when requested and otherwise passes through the existing mask:
  - Function (proposed): `maybe_create_mask(attention_bias, attention_mask, batch_size, query_len, key_len, window_size, min_dtype=None) -> Optional[Tensor]`
  - Behavior:
    - If `window_size is None` or `key_len <= window_size`, return the original `attention_mask` unchanged.
    - Otherwise, compute the top‑k indices from `attention_bias` (respecting any existing boolean `attention_mask` if provided), and return a boolean mask of the same shape as `attention_bias` via a scatter.
    - Default the fill value from `attention_bias.dtype` (validating that it is a floating dtype).
- Keep a low-level helper that normalizes 2D `(B, S)` masks to `(B, H_or_1, Lq_or_1, Lk_or_1)` and constructs the final boolean window mask: `create_mask(attention_bias, attention_mask, batch_size, query_len, key_len, window_size, min_dtype)`
  - Supports masking either on `key_len` or, if needed, left-padded `query_len` cases.
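A minimal sketch of the two proposed helpers, assuming PyTorch and a 4D floating `attention_bias`; the real `create_mask` would additionally handle the left-padded `query_len` case and fuller shape validation, which are omitted here:

```python
from typing import Optional

import torch
from torch import Tensor


def create_mask(
    attention_bias: Tensor,            # (B|1, H|1, Lq|1, Lk), floating
    attention_mask: Optional[Tensor],  # (B, S) or broadcastable 4D boolean
    batch_size: int,
    query_len: int,
    key_len: int,
    window_size: int,
    min_dtype: float,
) -> Tensor:
    """Low-level helper: normalize the mask, then build the boolean window mask."""
    # batch_size/query_len would drive full shape validation in the real helper.
    assert attention_bias.size(-1) == key_len
    if attention_mask is not None and attention_mask.dim() == 2:
        # Normalize a (B, S) padding mask to a broadcastable 4D shape.
        attention_mask = attention_mask[:, None, None, :].to(torch.bool)
    bias = attention_bias
    if attention_mask is not None:
        # Respect the existing mask: masked positions never win the top-k.
        bias = bias.masked_fill(~attention_mask, min_dtype)
    # Top-k indices per query row, scattered into a boolean mask with the
    # same shape as attention_bias.
    topk_indices = bias.topk(window_size, dim=-1).indices
    window_mask = torch.zeros_like(attention_bias, dtype=torch.bool)
    window_mask.scatter_(-1, topk_indices, True)
    return window_mask


def maybe_create_mask(
    attention_bias: Tensor,
    attention_mask: Optional[Tensor],
    batch_size: int,
    query_len: int,
    key_len: int,
    window_size: Optional[int],
    min_dtype: Optional[float] = None,
) -> Optional[Tensor]:
    """High-level helper: apply a top-k window only when requested."""
    if window_size is None or key_len <= window_size:
        # Early exit: pass the existing mask through unchanged.
        return attention_mask
    if not attention_bias.dtype.is_floating_point:
        raise ValueError("attention_bias must be a floating dtype when window_size is set")
    if min_dtype is None:
        # Default the fill value from the bias dtype.
        min_dtype = torch.finfo(attention_bias.dtype).min
    return create_mask(
        attention_bias, attention_mask,
        batch_size, query_len, key_len,
        window_size, min_dtype,
    )
```

The early exit is the `key_len <= window_size` pass-through the acceptance criteria refer to: the no-windowing path allocates nothing and returns the caller's mask object unchanged.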
Acceptance criteria:
- Functional parity with the current integration path: bitwise-equal boolean masks and identical attention outputs for the same inputs (a hypothetical sanity harness is sketched below).
- Handles shapes `{B|1} x {H|KVH|1} x {Lq|1} x {Lk|1}` correctly, including the case where an input `(B, S)` mask needs expansion.
- Correct dtype min-fill selection and no graph breaks in common `torch.compile` paths.
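For illustration, a hypothetical sanity harness along these lines (building on the sketch above; full bitwise parity would additionally compare against the existing inlined implementation) could check the pass-through and windowed behaviors:

```python
import torch

# Hypothetical checks against the sketch above; not the project's test suite.
torch.manual_seed(0)
B, H, Lq, Lk = 2, 4, 8, 128
bias = torch.randn(B, H, Lq, Lk)
mask_2d = torch.ones(B, Lk, dtype=torch.bool)

# Pass-through: no windowing when key_len <= window_size.
assert maybe_create_mask(bias, mask_2d, B, Lq, Lk, window_size=Lk) is mask_2d

# Windowed: boolean mask shaped like the bias, exactly window_size keys kept.
out = maybe_create_mask(bias, mask_2d, B, Lq, Lk, window_size=32)
assert out.shape == bias.shape and out.dtype == torch.bool
assert (out.sum(dim=-1) == 32).all()
```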
Alternatives considered
- Keep the logic inside the integration forward path: increases duplication and long-term maintenance cost.
- Move to `utils/padding.py`: semantically this is about mask generation rather than padding/unpadding, so `utils/mask.py` is a clearer home.
- Push into backends directly: would fragment behavior across kernels and reintroduce duplication.
Implementation details
- CUDA kernel changes: Not required.
- Python API: Add `maybe_create_mask` to `utils/mask.py`. Keep `create_mask` as the building block.
- Integrations: Replace the inlined top‑k mask creation in the no‑padding path with a call to `maybe_create_mask` (see the call-site sketch after this list).
- Performance: Neutral to positive. Centralization removes redundant code and guarantees a consistent early exit when windowing is unnecessary (`key_len <= window_size`).
- Compatibility: Preserves current shapes and broadcasting rules. Validates that `attention_bias` is a floating dtype when `window_size` is set.
- Testing/verification:
  - Run `forward_equivalence.py`, `backward_equivalence.py`, and `grad_equivalence.py` to verify identical results.
  - Sanity-check variable sequence lengths, both with and without an existing `(B, S)` mask.
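A hedged sketch of the refactored no‑padding branch (function and variable names assumed for illustration, not the actual integration code):

```python
# Hypothetical call site; names are illustrative only.
def _no_padding_attention_mask(attention_bias, attention_mask, window_size):
    batch_size, _, query_len, key_len = attention_bias.shape
    # One call replaces the previously inlined top-k mask construction.
    return maybe_create_mask(
        attention_bias, attention_mask,
        batch_size, query_len, key_len,
        window_size,  # None disables windowing; small key_len passes through
    )
```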
Use case
- Reusing the same windowed mask logic across:
  - `modeling_flash_dynamic_mask_attention_utils.py` (no‑padding branch),
  - future integrations that need top‑k windowing,
  - potential benchmarking or ablation scripts that want to synthesize masks uniformly.
- Typical sequence lengths range from short prompts to long contexts where `window_size << key_len`; correctness and consistency across backends are critical.
Related work
- Aligns with best practices in attention libraries to centralize mask/materialization helpers for consistency (e.g., centralized padding/unpadding in FlashAttention ecosystems).
- Reduces divergence between “varlen” and “no‑padding” code paths by sharing mask construction semantics.
Additional context
- Proposed home and helpers:
  - `utils/mask.py`: `maybe_create_mask`, `create_mask`, `dynamic_mask`
- Current call site to refactor:
  - `modeling_flash_dynamic_mask_attention_utils.py` (no‑padding branch for mask creation)