docs/api_reference.md (46 changes: 22 additions & 24 deletions)

Main attention function. Supports multi-head and grouped-query attention (when the number of key/value heads is smaller than the number of query heads).

```python
def flash_dmattn_func(
    query: torch.Tensor,                        # (batch, seqlen_q, num_heads, head_dim)
    key: torch.Tensor,                          # (batch, seqlen_k, num_kv_heads, head_dim)
    value: torch.Tensor,                        # (batch, seqlen_k, num_kv_heads, head_dim)
    attn_mask: Optional[torch.Tensor] = None,   # (batch, num_heads, seqlen_q, seqlen_k)
    attn_bias: Optional[torch.Tensor] = None,   # (batch, num_heads, seqlen_q, seqlen_k)
    scale: Optional[float] = None,              # score scaling, defaults to 1/sqrt(head_dim)
    is_causal: Optional[bool] = None,           # causal mask
    softcap: Optional[float] = None,            # CUDA-only
    deterministic: Optional[bool] = None,       # CUDA-only
) -> torch.Tensor
```

#### Parameters

- query: (B, Q, H, D). CUDA tensor, fp16/bf16, last dim contiguous
- key: (B, K, H_kv, D). Same dtype/device as query; GQA when H_kv <= H
- value: (B, K, H_kv, D). Same dtype/device as query; GQA when H_kv <= H
> Comment on lines +86 to +87 (Copilot AI, Aug 9, 2025): The GQA condition description should clarify what 'H' refers to. It should specify 'H_kv <= num_heads' or reference the query tensor's head dimension to avoid ambiguity.
>
> Suggested change:
> - key: (B, K, H_kv, D). Same dtype/device as query; GQA when H_kv <= H (number of query heads in the query tensor)
> - value: (B, K, H_kv, D). Same dtype/device as query; GQA when H_kv <= H (number of query heads in the query tensor)
- attn_mask: (B, H, Q, K). 1.0 = visible, 0.0 = masked. None to disable
- attn_bias: (B, H, Q, K). Added to scores before softmax. None to disable
- scale: score scaling; default 1/sqrt(D)
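
For orientation, the sketch below shows one way a call with grouped-query attention might look. The import path `flash_dmattn.flash_dmattn_func`, the example shapes, and the choice of mask/bias dtype are assumptions for illustration, not guarantees from this reference; adapt them to your installation.

```python
import torch
from flash_dmattn import flash_dmattn_func  # import path assumed; adjust to your install

batch, seqlen_q, seqlen_k = 2, 128, 128
num_heads, num_kv_heads, head_dim = 8, 2, 64  # GQA: 2 KV heads shared by 8 query heads
device, dtype = "cuda", torch.float16

query = torch.randn(batch, seqlen_q, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch, seqlen_k, num_kv_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch, seqlen_k, num_kv_heads, head_dim, device=device, dtype=dtype)

# Optional mask/bias, laid out per query head: 1.0 = visible, 0.0 = masked.
attn_mask = torch.ones(batch, num_heads, seqlen_q, seqlen_k, device=device, dtype=dtype)
attn_bias = torch.zeros(batch, num_heads, seqlen_q, seqlen_k, device=device, dtype=dtype)

out = flash_dmattn_func(
    query, key, value,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    is_causal=True,
)
print(out.shape)  # expected: (batch, seqlen_q, num_heads, head_dim)
```

Note that the mask and bias follow the query head count (H) even when the key/value tensors use fewer heads (H_kv).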
Variable length attention for batches with mixed sequence lengths.

```python
def flash_dmattn_varlen_func(
    query: torch.Tensor,                        # (total_q, H, D) or (B, Q, H, D)
    key: torch.Tensor,                          # same layout as query
    value: torch.Tensor,                        # same layout as query
    attn_mask: Optional[torch.Tensor] = None,   # (B, H, Q, K)
    attn_bias: Optional[torch.Tensor] = None,   # (B, H, Q, K)
    cu_seqlens_q: torch.Tensor = None,          # (B+1,)
    cu_seqlens_k: torch.Tensor = None,          # (B+1,)
max_seqlen_q: int = None,
max_seqlen_k: int = None,
scale: Optional[float] = None,
is_causal: Optional[bool] = None,
    softcap: Optional[float] = None,            # CUDA-only
    deterministic: Optional[bool] = None,       # CUDA-only
    block_table: Optional[torch.Tensor] = None, # experimental: paged attention
) -> torch.Tensor
```
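
A minimal packed-layout sketch follows, assuming self-attention over two sequences packed into one (total_q, H, D) tensor. The import path and the int32 dtype for `cu_seqlens_q`/`cu_seqlens_k` are assumptions based on common varlen-attention conventions rather than statements from this reference.

```python
import torch
from flash_dmattn import flash_dmattn_varlen_func  # import path assumed

# Two sequences of length 5 and 11, packed along a single "total tokens" axis.
seqlens = [5, 11]
total = sum(seqlens)
num_heads, head_dim = 8, 64
device, dtype = "cuda", torch.float16

query = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype)

# Cumulative sequence lengths: shape (B + 1,), starting at 0 (dtype assumed int32).
cu_seqlens = torch.tensor([0, 5, 16], device=device, dtype=torch.int32)

out = flash_dmattn_varlen_func(
    query, key, value,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    is_causal=True,
)
print(out.shape)  # expected: (total, num_heads, head_dim)
```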

```python
print_memory_stats()
torch.cuda.empty_cache()
```

---

See also: `docs/integration.md` and `benchmarks/`.