Make attention parameters optional with defaults and simplify API documentation #95
Conversation
Improves flexibility by making attn_mask, attn_bias, is_causal, and scale parameters optional with sensible defaults. Creates default attention mask and bias tensors when not provided, sets causal attention to true by default, and calculates scale from head dimension when not specified. Adds proper null checking before tensor slicing operations to prevent errors when optional parameters are None.
Streamlines parameter descriptions and removes verbose explanations to improve readability. Updates code examples to use the simplified high-level interface consistently across all sections. Clarifies that the auto function returns a callable rather than direct output, reducing potential confusion for new users.
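The description above implies a two-step call pattern: obtain a callable from the auto selector, then invoke it. A minimal usage sketch follows; the import path and the name `flash_dmattn_func_auto` are assumptions (only the "returns a callable" behavior comes from the PR text), and the argument order mirrors the parameters named above.

```python
import torch
from flash_dmattn import flash_dmattn_func_auto  # assumed entry point; name not confirmed by this PR

# The auto function returns a callable attention implementation, not attention output.
attn = flash_dmattn_func_auto()

q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)

# attn_mask, attn_bias, is_causal, and scale can all be omitted; the defaults
# described above (causal attention, scale = 1/sqrt(head_dim)) then apply.
out = attn(q, k, v)
```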
Pull Request Overview
This PR enhances the Flash Dynamic Mask Attention API by making attention parameters optional with sensible defaults and improving documentation clarity. The main purpose is to simplify the API while maintaining backward compatibility and improving developer experience.
- Made core attention parameters (attn_mask, attn_bias, is_causal) optional with sensible defaults
- Added automatic scale calculation based on head dimension when not provided
- Significantly simplified and reorganized API documentation for better readability
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| flash_dmattn/flash_dmattn_flex.py | Added optional parameters with default tensor creation and automatic scale calculation |
| docs/api_reference.md | Comprehensive documentation rewrite with simplified examples and clearer formatting |
```python
if attn_mask is not None:
    attn_mask = attn_mask[:, :, :, : key.shape[-2]]
else:
    attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
```
Copilot AI commented on Aug 8, 2025
Creating large default tensors with torch.ones can be memory-intensive for long sequences. Consider using a more memory-efficient approach or lazy evaluation for default masks.
Suggested change:
```diff
-    attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
+    attn_mask = None  # Avoid allocating a large dense tensor; treat None as "no mask"
```
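The lazy-evaluation idea from the comment above could be expressed with flex_attention's block-mask utilities instead of a dense default mask. This is a hedged sketch assuming PyTorch 2.5+'s `torch.nn.attention.flex_attention` module; it is not code from this PR.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# The causal constraint is a predicate over indices, so no dense
# (batch, nheads, seqlen_q, seqlen_k) tensor of ones is ever materialized.
def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 2, 8, 4096, 64
q = torch.randn(B, H, S, D, device="cuda")
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")

# The block mask stores only block-level metadata, not a full mask tensor.
block_mask = create_block_mask(causal_mask_mod, B, H, S, S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
```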
```python
else:
    attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
if is_causal is None:
    is_causal = True
if scale is None:
    scale = 1.0 / math.sqrt(dhead)

def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
    score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx]
```
Copilot AI commented on Aug 8, 2025
Creating large default tensors with torch.zeros can be memory-intensive for long sequences. Consider using a more memory-efficient approach or lazy evaluation for default biases.
Suggested change:
```diff
-else:
-    attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
-if is_causal is None:
-    is_causal = True
-if scale is None:
-    scale = 1.0 / math.sqrt(dhead)
-def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
-    score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx]
+# else: leave attn_bias as None to avoid allocating a large zero tensor
+if is_causal is None:
+    is_causal = True
+if scale is None:
+    scale = 1.0 / math.sqrt(dhead)
+def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
+    if attn_bias is not None:
+        score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx]
```
Enhance flexibility by making attention parameters optional with sensible defaults, including default tensors for masks and biases. Improve API documentation for better readability and clarity, ensuring consistent use of the simplified interface.
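Taken together, the default handling this PR describes could be consolidated into a helper roughly like the one below. This is a hedged sketch: the helper name `resolve_defaults`, the standalone packaging, and the bias-side slicing (which mirrors the mask path but is not shown in the excerpt) are assumptions, while the individual defaults (dense ones/zeros tensors, causal by default, `1/sqrt(dhead)` scale, and the None check before slicing) come from the diff above.

```python
import math
import torch

def resolve_defaults(query, key, attn_mask=None, attn_bias=None, is_causal=None, scale=None):
    # Assumes (batch, nheads, seqlen, head_dim) layout, matching the shapes used in the diff.
    batch, nheads, seqlen_q, dhead = query.shape
    seqlen_k = key.shape[-2]

    # Null check before slicing: only trim a caller-provided mask/bias to the key length.
    if attn_mask is not None:
        attn_mask = attn_mask[:, :, :, :seqlen_k]
    else:
        attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k),
                               device=query.device, dtype=query.dtype)
    if attn_bias is not None:
        attn_bias = attn_bias[:, :, :, :seqlen_k]
    else:
        attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k),
                                device=query.device, dtype=query.dtype)

    if is_causal is None:
        is_causal = True                    # causal attention by default
    if scale is None:
        scale = 1.0 / math.sqrt(dhead)      # softmax scale from head dimension

    return attn_mask, attn_bias, is_causal, scale
```

As the review comments note, the dense default tensors trade memory for simplicity; the suggested alternative is to keep the defaults as None and guard their use where the mask and bias are applied.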