From 4edb7a81e255135fea68baafe99d7c731b67787a Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Sat, 9 Aug 2025 11:34:06 +0800
Subject: [PATCH] Improves API parameter naming consistency

Renames q/k/v parameters to query/key/value in flash attention functions
for better readability and standardization.

Updates parameter documentation to reflect the new naming convention and
fixes the GQA condition description to use <= instead of <.

Removes outdated footer reference to integration docs.
---
 docs/api_reference.md | 46 ++++++++++++++++++++++------------------------
 1 file changed, 22 insertions(+), 24 deletions(-)

diff --git a/docs/api_reference.md b/docs/api_reference.md
index 435231d..b1a1307 100644
--- a/docs/api_reference.md
+++ b/docs/api_reference.md
@@ -68,22 +68,23 @@ Main attention function. Supports multi-head and grouped-query attention (when t
 
 ```python
 def flash_dmattn_func(
-    q: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim)
-    k: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim)
-    v: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim)
-    attn_mask: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k)
-    attn_bias: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k)
-    scale: Optional[float] = None, # score scaling, defaults to 1/sqrt(head_dim)
-    is_causal: Optional[bool] = None, # causal mask
-    softcap: Optional[float] = None, # CUDA-only
-    deterministic: Optional[bool] = None, # CUDA-only
+    query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim)
+    key: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim)
+    value: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim)
+    attn_mask: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k)
+    attn_bias: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k)
+    scale: Optional[float] = None, # score scaling, defaults to 1/sqrt(head_dim)
+    is_causal: Optional[bool] = None, # causal mask
+    softcap: Optional[float] = None, # CUDA-only
+    deterministic: Optional[bool] = None, # CUDA-only
 ) -> torch.Tensor
 ```
 
 #### Parameters
 
-- q: (B, Q, H, D). CUDA tensor, fp16/bf16, last dim contiguous
-- k, v: (B, K, H_kv, D). Same dtype/device as q; GQA when H_kv < H
+- query: (B, Q, H, D). CUDA tensor, fp16/bf16, last dim contiguous
+- key: (B, K, H_kv, D). Same dtype/device as query; GQA when H_kv <= H
+- value: (B, K, H_kv, D). Same dtype/device as query; GQA when H_kv <= H
 - attn_mask: (B, H, Q, K). 1.0 = visible, 0.0 = masked. None to disable
 - attn_bias: (B, H, Q, K). Added to scores before softmax. None to disable
 - scale: score scaling; default 1/sqrt(D)
@@ -137,20 +138,20 @@ Variable length attention for batches with mixed sequence lengths.
 
 ```python
 def flash_dmattn_varlen_func(
-    q: torch.Tensor, # (total_q, H, D) or (B, Q, H, D)
-    k: torch.Tensor, # same layout as q
-    v: torch.Tensor, # same layout as q
-    attn_mask: Optional[torch.Tensor] = None, # (B, H, Q, K)
-    attn_bias: Optional[torch.Tensor] = None, # (B, H, Q, K)
-    cu_seqlens_q: torch.Tensor = None, # (B+1,)
-    cu_seqlens_k: torch.Tensor = None, # (B+1,)
+    query: torch.Tensor, # (total_q, H, D) or (B, Q, H, D)
+    key: torch.Tensor, # same layout as query
+    value: torch.Tensor, # same layout as query
+    attn_mask: Optional[torch.Tensor] = None, # (B, H, Q, K)
+    attn_bias: Optional[torch.Tensor] = None, # (B, H, Q, K)
+    cu_seqlens_q: torch.Tensor = None, # (B+1,)
+    cu_seqlens_k: torch.Tensor = None, # (B+1,)
     max_seqlen_q: int = None,
     max_seqlen_k: int = None,
     scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
-    softcap: Optional[float] = None, # CUDA-only
-    deterministic: Optional[bool] = None, # CUDA-only
-    block_table: Optional[torch.Tensor] = None, # experimental: paged attention
+    softcap: Optional[float] = None, # CUDA-only
+    deterministic: Optional[bool] = None, # CUDA-only
+    block_table: Optional[torch.Tensor] = None, # experimental: paged attention
 ) -> torch.Tensor
 ```
 
@@ -386,6 +387,3 @@ print_memory_stats()
 torch.cuda.empty_cache()
 ```
 
----
-
-See also: `docs/integration.md` and `benchmarks/`.
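
For reviewers, here is a minimal sketch of how the renamed `flash_dmattn_func` keywords read at a call site. The import path is an assumption (it is not part of this patch), and the sketch assumes a CUDA device with fp16 support; shapes and keyword names come from the signature above.

```python
# Minimal call-site sketch for the renamed query/key/value keywords.
# The import path below is an assumption; adjust it to the actual package layout.
import torch
from flash_dmattn import flash_dmattn_func

batch, seqlen_q, seqlen_k = 2, 128, 128
num_heads, num_kv_heads, head_dim = 8, 2, 64   # GQA: H_kv <= H

device, dtype = "cuda", torch.float16
query = torch.randn(batch, seqlen_q, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch, seqlen_k, num_kv_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch, seqlen_k, num_kv_heads, head_dim, device=device, dtype=dtype)

# Optional additive bias over attention scores, (batch, num_heads, seqlen_q, seqlen_k).
attn_bias = torch.zeros(batch, num_heads, seqlen_q, seqlen_k, device=device, dtype=dtype)

out = flash_dmattn_func(
    query=query,
    key=key,
    value=value,
    attn_mask=None,            # None disables masking
    attn_bias=attn_bias,       # added to scores before softmax
    is_causal=True,            # causal masking
    scale=head_dim ** -0.5,    # matches the documented default of 1/sqrt(head_dim)
)
print(out.shape)               # (batch, seqlen_q, num_heads, head_dim)
```

Passing the tensors by keyword, as above, is what makes the q/k/v -> query/key/value rename visible (and potentially breaking) for downstream callers.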
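A similar sketch for `flash_dmattn_varlen_func` with the packed `(total_q, H, D)` layout. The import path and the int32 dtype for `cu_seqlens_*` are assumptions for illustration; they are not spelled out in this patch.

```python
# Variable-length sketch: three sequences packed along the first dimension.
# Import path and cu_seqlens dtype (int32) are assumptions for illustration.
import torch
from flash_dmattn import flash_dmattn_varlen_func

seqlens = [5, 9, 3]                     # per-sequence lengths in the batch
total = sum(seqlens)
num_heads, head_dim = 8, 64
device, dtype = "cuda", torch.float16

query = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype)

# Cumulative sequence lengths, shape (B+1,): boundaries of each packed sequence.
cu_seqlens = torch.tensor([0, 5, 14, 17], dtype=torch.int32, device=device)

out = flash_dmattn_varlen_func(
    query=query,
    key=key,
    value=value,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    is_causal=True,
)
print(out.shape)                        # (total_q, num_heads, head_dim)
```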