Conversation

@LoserCheems
Collaborator

Enhance flexibility by making attention parameters optional with sensible defaults, including default tensors for masks and biases. Improve API documentation for better readability and clarity, ensuring consistent use of the simplified interface.

Improves flexibility by making attn_mask, attn_bias, is_causal, and scale parameters optional with sensible defaults.

Creates default attention mask and bias tensors when none are provided, enables causal attention by default, and computes the softmax scale from the head dimension when not specified.

Adds proper null checking before tensor slicing operations to prevent errors when optional parameters are None.

Streamlines parameter descriptions and removes verbose explanations to improve readability.

Updates code examples to use the simplified high-level interface consistently across all sections.

Clarifies that the auto function returns a callable rather than direct output, reducing potential confusion for new users.
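
For illustration, a minimal sketch of the simplified call this description promises. The constructor name flash_dmattn_func_auto and the tensor layout are assumptions for this sketch, not confirmed by the PR; the point is that the auto function returns a callable, which you then invoke:

    import torch
    from flash_dmattn import flash_dmattn_func_auto  # name assumed for this sketch

    # Illustrative shapes: (batch, nheads, seqlen, head_dim)
    q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

    attn_func = flash_dmattn_func_auto()  # returns a callable, not attention output
    out = attn_func(q, k, v)              # attn_mask, attn_bias, is_causal, scale all defaulted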

Copilot AI (Contributor) left a comment

Pull Request Overview

This PR enhances the Flash Dynamic Mask Attention API by making attention parameters optional with sensible defaults and improving documentation clarity. The main purpose is to simplify the API while maintaining backward compatibility and improving developer experience.

  • Made core attention parameters (attn_mask, attn_bias, is_causal) optional with sensible defaults
  • Added automatic scale calculation based on head dimension when not provided (see the sketch after this list)
  • Significantly simplified and reorganized API documentation for better readability
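
As a rough sketch of the defaulting behavior referenced in the list above (variable names follow the reviewed diff below; illustrative, not the exact implementation):

    import math

    def resolve_defaults(query, is_causal=None, scale=None):
        # query assumed to be (batch, nheads, seqlen_q, dhead), as in the reviewed code
        dhead = query.shape[-1]
        if is_causal is None:
            is_causal = True                # causal attention unless the caller opts out
        if scale is None:
            scale = 1.0 / math.sqrt(dhead)  # softmax scale derived from the head dimension
        return is_causal, scale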

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

  • flash_dmattn/flash_dmattn_flex.py: Added optional parameters with default tensor creation and automatic scale calculation
  • docs/api_reference.md: Comprehensive documentation rewrite with simplified examples and clearer formatting

    if attn_mask is not None:
        attn_mask = attn_mask[:, :, :, : key.shape[-2]]
    else:
        attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)

Copilot AI (Aug 8, 2025):

Creating large default tensors with torch.ones can be memory-intensive for long sequences. Consider using a more memory-efficient approach or lazy evaluation for default masks.

Suggested change:

    -    attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
    +    attn_mask = None  # Avoid allocating a large dense tensor; treat None as "no mask"
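
To put a number on that concern, a back-of-the-envelope cost for the dense default mask (sizes are illustrative, not taken from the PR):

    # A (batch, nheads, seqlen_q, seqlen_k) mask in the query's dtype, e.g. fp16
    batch, nheads, seqlen = 1, 32, 8192
    mask_bytes = batch * nheads * seqlen * seqlen * 2  # 2 bytes per fp16 element
    print(f"{mask_bytes / 2**30:.1f} GiB")             # ~4.0 GiB for a single default mask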

Comment on lines +29 to 37
    else:
        attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
    if is_causal is None:
        is_causal = True
    if scale is None:
        scale = 1.0 / math.sqrt(dhead)

    def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
        score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx]

Copilot AI (Aug 8, 2025):

Creating large default tensors with torch.zeros can be memory-intensive for long sequences. Consider using a more memory-efficient approach or lazy evaluation for default biases.

Suggested change:

    -else:
    -    attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
    -if is_causal is None:
    -    is_causal = True
    -if scale is None:
    -    scale = 1.0 / math.sqrt(dhead)
    -def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
    -    score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx]
    +# else: leave attn_bias as None to avoid allocating a large zero tensor
    +if is_causal is None:
    +    is_causal = True
    +if scale is None:
    +    scale = 1.0 / math.sqrt(dhead)
    +def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
    +    if attn_bias is not None:
    +        score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx]
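
For context, a sketch of how a None-aware score_mod like the one suggested plugs into PyTorch's flex_attention; the actual call in flash_dmattn_flex.py may differ:

    from torch.nn.attention.flex_attention import flex_attention

    def make_score_mod(attn_bias):
        def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
            if attn_bias is not None:  # closes over the (possibly None) bias
                score = score + attn_bias[batch_idx, head_idx, q_idx, kv_idx]
            return score
        return score_mod

    # out = flex_attention(query, key, value, score_mod=make_score_mod(attn_bias), scale=scale)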

@LoserCheems merged commit a0982bc into main on Aug 8, 2025.