From c8d658bc08ad3a48543395c10c107da6a3d21b7d Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 30 Jul 2025 21:32:02 +0800 Subject: [PATCH 1/2] Expands API documentation with comprehensive interface guide Restructures documentation to provide complete coverage of all available functions and interfaces including high-level auto-selection, packed variants, and variable length support. Adds detailed usage examples for standard attention, dynamic masking, grouped-query attention, and variable length sequences with performance optimization tips. Includes troubleshooting section covering common import errors, performance issues, and debugging techniques with memory monitoring utilities. --- docs/api_reference.md | 510 ++++++++++++++++++++++++++++++++---------- 1 file changed, 394 insertions(+), 116 deletions(-) diff --git a/docs/api_reference.md b/docs/api_reference.md index 832b131..1bc92a1 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -4,127 +4,257 @@ Flash Dynamic Mask Attention is a high-performance implementation that combines Flash Attention's memory efficiency with Dynamic Mask Attention's sparse computation capabilities. This API provides CUDA-accelerated attention computation with dynamic masking for handling extremely long sequences efficiently. +The library provides multiple interfaces: +- **High-level Functions**: Easy-to-use functions with automatic backend selection +- **Specific Implementations**: Direct access to CUDA, Triton, and Flex Attention backends +- **Packed Variants**: Optimized functions for QKV-packed and KV-packed tensors +- **Variable Length**: Support for variable sequence lengths within batches + ## Table of Contents 1. [Installation](#installation) -2. [Forward](#forward) +2. [High-Level Interface](#high-level-interface) +3. [Core Functions](#core-functions) +4. [Packed Variants](#packed-variants) +5. [Variable Length Functions](#variable-length-functions) +6. [Backend Selection](#backend-selection) ## Installation ### Prerequisites -- CUDA >= 11.8 -- PyTorch >= 2.0 -- CUTLASS library -- GPU with compute capability >= 8.0 (Ampere architecture or newer) +- **Python**: 3.8 or later +- **PyTorch**: 2.0.0 or later with CUDA support +- **CUDA**: 11.8 or later +- **NVIDIA GPU**: Compute Capability 8.0 or higher +- **Dependencies**: `packaging`, `torch` -### Build from Source +### Install from Source ```bash -cd flash_dma +git clone https://github.com/SmallDoges/flash-dmattn.git +cd flash-dmattn +git submodule update --init --recursive pip install -e . ``` -## Forward +## High-Level Interface + +### Automatic Backend Selection + +```python +from flash_dmattn import flash_dmattn_func_auto, get_available_backends + +# Check available backends +backends = get_available_backends() +print(f"Available backends: {backends}") + +# Use with automatic backend selection +output = flash_dmattn_func_auto( + q=query, k=key, v=value, + attn_mask=attention_mask, + attn_bias=attention_bias +) + +# Force specific backend +output = flash_dmattn_func_auto( + backend="cuda", # or "triton", "flex" + q=query, k=key, v=value, + attn_mask=attention_mask, + attn_bias=attention_bias +) +``` + +## Core Functions + +### flash_dmattn_func + +The main attention function supporting multi-head and grouped-query attention. 
```python -def fwd( - q: torch.Tensor, # Query tensor - k: torch.Tensor, # Key tensor - v: torch.Tensor, # Value tensor - zoh: torch.Tensor, # ZOH states tensor - active_mask: torch.Tensor, # Active mask tensor - out: Optional[torch.Tensor] = None, # Output tensor (optional) - p_dropout: float = 0.0, # Dropout probability - softmax_scale: float = None, # Scaling factor for attention - is_causal: bool = False, # Whether to apply causal mask - keep_window_size: int = 2048, # Window size for dynamic masking - softcap: float = 0.0, # Soft capping for attention scores - return_softmax: bool = False, # Whether to return softmax weights - gen: Optional[torch.Generator] = None # Random generator for dropout -) -> List[torch.Tensor] +def flash_dmattn_func( + q: torch.Tensor, # Query tensor + k: torch.Tensor, # Key tensor + v: torch.Tensor, # Value tensor + attn_mask: Optional[torch.Tensor] = None, # Attention mask + attn_bias: Optional[torch.Tensor] = None, # Attention bias + dropout_p: Optional[float] = None, # Dropout probability + softmax_scale: Optional[float] = None, # Scaling factor + is_causal: Optional[bool] = None, # Causal masking + softcap: Optional[float] = None, # Soft capping + deterministic: Optional[bool] = None, # Deterministic mode + return_attn_probs: Optional[bool] = None, # Return attention weights +) -> torch.Tensor ``` #### Parameters - **q** (`torch.Tensor`): Query tensor of shape `(batch_size, seqlen_q, num_heads, head_dim)` - - Must be contiguous in the last dimension + - Must be contiguous and on CUDA device - Supported dtypes: `torch.float16`, `torch.bfloat16` - - Must be on CUDA device - **k** (`torch.Tensor`): Key tensor of shape `(batch_size, seqlen_k, num_heads_k, head_dim)` - - Must be contiguous in the last dimension - Same dtype and device as `q` - - `num_heads_k` can be different from `num_heads` for grouped-query attention + - Supports grouped-query attention when `num_heads_k < num_heads` - **v** (`torch.Tensor`): Value tensor of shape `(batch_size, seqlen_k, num_heads_k, head_dim)` - - Must be contiguous in the last dimension - - Same dtype and device as `q` - -- **zoh** (`torch.Tensor`): Zero-Order Hold states tensor of shape `(batch_size, num_heads_k, seqlen_q, seqlen_k)` - - Contains the dynamic attention bias values - Same dtype and device as `q` - - Used for dynamic masking computation + - Supports grouped-query attention when `num_heads_k < num_heads` -- **active_mask** (`torch.Tensor`): Active mask tensor of shape `(batch_size, num_heads_k, seqlen_q, seqlen_k)` - - Binary mask indicating which positions should be computed - - Same dtype and device as `q` - - 1.0 for active positions, 0.0 for masked positions +- **attn_mask** (`Optional[torch.Tensor]`): Attention mask of shape `(batch_size, num_heads, seqlen_q, seqlen_k)` + - Binary mask: 1.0 for positions to attend, 0.0 for masked positions + - If `None`, no masking is applied -- **out** (`Optional[torch.Tensor]`): Pre-allocated output tensor - - If provided, must have shape `(batch_size, seqlen_q, num_heads, head_dim)` - - If `None`, will be allocated automatically +- **attn_bias** (`Optional[torch.Tensor]`): Attention bias of shape `(batch_size, num_heads, seqlen_q, seqlen_k)` + - Added to attention scores before softmax + - If `None`, no bias is applied -- **p_dropout** (`float`): Dropout probability (default: 0.0) +- **dropout_p** (`Optional[float]`): Dropout probability (default: 0.0) - Range: [0.0, 1.0] - Applied to attention weights -- **softmax_scale** (`float`): Scaling factor for attention 
scores
+- **softmax_scale** (`Optional[float]`): Scaling factor for attention scores
   - If `None`, defaults to `1.0 / sqrt(head_dim)`
-  - Applied before softmax
-
-- **is_causal** (`bool`): Whether to apply causal (lower triangular) mask (default: False)
-  - Combined with dynamic masking
-
-- **keep_window_size** (`int`): Maximum number of tokens to keep per query (default: 2048)
-  - Controls sparsity level of attention
-  - Dynamic masking only applied when `seqlen_k > keep_window_size`
+- **is_causal** (`Optional[bool]`): Whether to apply causal masking (default: False)
+  - When True, applies lower triangular mask
 
-- **softcap** (`float`): Soft capping value for attention scores (default: 0.0)
+- **softcap** (`Optional[float]`): Soft capping value (default: 0.0)
   - If > 0, applies `softcap * tanh(score / softcap)`
 
-- **return_softmax** (`bool`): Whether to return attention weights (default: False)
-  - Only for debugging purposes
+- **deterministic** (`Optional[bool]`): Use deterministic backward pass (default: True)
+  - Slightly slower, but gradients are reproducible across runs
 
-- **gen** (`Optional[torch.Generator]`): Random number generator for dropout
-  - Used for reproducible dropout
+- **return_attn_probs** (`Optional[bool]`): Return attention probabilities (default: False)
+  - For debugging only
 
 #### Returns
 
-Returns a list of tensors:
-- `output`: Attention output of shape `(batch_size, seqlen_q, num_heads, head_dim)`
-- `softmax_lse`: Log-sum-exp of attention weights for numerical stability
-- `p`: Attention weights (if `return_softmax=True`)
-- `rng_state`: Random number generator state
+- **output** (`torch.Tensor`): Attention output of shape `(batch_size, seqlen_q, num_heads, head_dim)`
+- **softmax_lse** (optional): Log-sum-exp of attention weights
+- **attn_probs** (optional): Attention probabilities (if `return_attn_probs=True`)
 
-### Data Types
+## Packed Variants
 
-- **Supported**: `torch.float16`, `torch.bfloat16`
-- **Recommended**: `torch.bfloat16` for better numerical stability
-- All input tensors must have the same dtype
+### flash_dmattn_qkvpacked_func
 
-### Memory Layout
+Optimized function for QKV-packed tensors.
 
-- All tensors must be contiguous in the last dimension
-- CUDA tensors only
-- Optimal performance with tensors already on the same GPU
+```python
+def flash_dmattn_qkvpacked_func(
+    qkv: torch.Tensor,                          # Packed QKV tensor
+    attn_mask: Optional[torch.Tensor] = None,   # Attention mask
+    attn_bias: Optional[torch.Tensor] = None,   # Attention bias
+    dropout_p: Optional[float] = None,          # Dropout probability
+    softmax_scale: Optional[float] = None,      # Scaling factor
+    is_causal: Optional[bool] = None,           # Causal masking
+    softcap: Optional[float] = None,            # Soft capping
+    deterministic: Optional[bool] = None,       # Deterministic mode
+    return_attn_probs: Optional[bool] = None,   # Return attention weights
+) -> torch.Tensor
+```
+
+**Parameters:**
+- **qkv** (`torch.Tensor`): Packed tensor of shape `(batch_size, seqlen, 3, num_heads, head_dim)`
+  - Contains query, key, and value tensors stacked along dimension 2
 
-### Basic Usage
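+
+**Example:** a minimal usage sketch based only on the signature above; the sizes are illustrative and the printed shape simply mirrors the documented output layout.
+
+```python
+import torch
+from flash_dmattn import flash_dmattn_qkvpacked_func
+
+batch_size, seq_len, num_heads, head_dim = 2, 1024, 8, 64
+device = torch.device('cuda')
+dtype = torch.bfloat16
+
+# Pack Q, K, V along dimension 2: (batch, seqlen, 3, num_heads, head_dim)
+qkv = torch.randn(batch_size, seq_len, 3, num_heads, head_dim, device=device, dtype=dtype)
+
+output = flash_dmattn_qkvpacked_func(qkv=qkv, is_causal=True)
+print(output.shape)  # expected: torch.Size([2, 1024, 8, 64])
+```
+
+Packing is convenient when Q, K, and V come from a single fused projection, since only one tensor has to be kept contiguous and passed around.
+
+### flash_dmattn_kvpacked_func
+
+Optimized function for KV-packed tensors.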
+ +```python +def flash_dmattn_kvpacked_func( + q: torch.Tensor, # Query tensor + kv: torch.Tensor, # Packed KV tensor + attn_mask: Optional[torch.Tensor] = None, # Attention mask + attn_bias: Optional[torch.Tensor] = None, # Attention bias + dropout_p: Optional[float] = None, # Dropout probability + softmax_scale: Optional[float] = None, # Scaling factor + is_causal: Optional[bool] = None, # Causal masking + softcap: Optional[float] = None, # Soft capping + deterministic: Optional[bool] = None, # Deterministic mode + return_attn_probs: Optional[bool] = None, # Return attention weights +) -> torch.Tensor +``` + +**Parameters:** +- **q** (`torch.Tensor`): Query tensor of shape `(batch_size, seqlen_q, num_heads, head_dim)` +- **kv** (`torch.Tensor`): Packed tensor of shape `(batch_size, seqlen_k, 2, num_heads_k, head_dim)` + +## Variable Length Functions + +### flash_dmattn_varlen_func + +Attention function supporting variable sequence lengths within a batch. + +```python +def flash_dmattn_varlen_func( + q: torch.Tensor, # Query tensor + k: torch.Tensor, # Key tensor + v: torch.Tensor, # Value tensor + attn_mask: Optional[torch.Tensor] = None, # Attention mask + attn_bias: Optional[torch.Tensor] = None, # Attention bias + cu_seqlens_q: torch.Tensor = None, # Cumulative sequence lengths (query) + cu_seqlens_k: torch.Tensor = None, # Cumulative sequence lengths (key) + max_seqlen_q: int = None, # Maximum query sequence length + max_seqlen_k: int = None, # Maximum key sequence length + dropout_p: Optional[float] = None, # Dropout probability + softmax_scale: Optional[float] = None, # Scaling factor + is_causal: Optional[bool] = None, # Causal masking + softcap: Optional[float] = None, # Soft capping + deterministic: Optional[bool] = None, # Deterministic mode + return_attn_probs: Optional[bool] = None, # Return attention weights + block_table: Optional[torch.Tensor] = None, # Block table for paged attention +) -> torch.Tensor +``` + +**Additional Parameters:** +- **cu_seqlens_q** (`torch.Tensor`): Cumulative sequence lengths for queries, shape `(batch_size + 1,)` +- **cu_seqlens_k** (`torch.Tensor`): Cumulative sequence lengths for keys, shape `(batch_size + 1,)` +- **max_seqlen_q** (`int`): Maximum sequence length in the batch for queries +- **max_seqlen_k** (`int`): Maximum sequence length in the batch for keys +- **block_table** (`Optional[torch.Tensor]`): Block table for paged attention (experimental) + +## Backend Selection + +### Available Backends + +```python +from flash_dmattn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE + +# Check which backends are available +backends = get_available_backends() +print(f"Available: {backends}") + +# Check individual backend availability +print(f"CUDA: {CUDA_AVAILABLE}") +print(f"Triton: {TRITON_AVAILABLE}") +print(f"Flex: {FLEX_AVAILABLE}") +``` + +### Backend-Specific Functions + +```python +# Direct access to specific implementations +from flash_dmattn import flash_dmattn_func # CUDA backend +from flash_dmattn import triton_dmattn_func # Triton backend +from flash_dmattn import flex_dmattn_func # Flex Attention backend +``` + +### Data Types and Memory Layout + +- **Supported dtypes**: `torch.float16`, `torch.bfloat16` +- **Recommended**: `torch.bfloat16` for better numerical stability +- **Device**: CUDA tensors only +- **Memory**: All tensors must be contiguous in the last dimension + +### Basic Usage Examples + +#### Standard Attention ```python import torch -import flash_dma -import math +from flash_dmattn 
import flash_dmattn_func
 
 # Setup
 batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128
@@ -132,72 +262,220 @@ device = torch.device('cuda')
 dtype = torch.bfloat16
 
 # Create input tensors
-query = torch.randn(batch_size, seq_len, num_heads, head_dim, 
-                    device=device, dtype=dtype)
-key = torch.randn(batch_size, seq_len, num_heads, head_dim, 
-                  device=device, dtype=dtype)
-value = torch.randn(batch_size, seq_len, num_heads, head_dim, 
-                    device=device, dtype=dtype)
-
-# Prepare ZOH states and active mask
-zoh_states = torch.randn(batch_size, num_heads, seq_len, seq_len, 
-                        device=device, dtype=dtype)
-active_mask = torch.zeros(batch_size, num_heads, seq_len, seq_len, 
-                         device=device, dtype=dtype)
-
-# Apply sparsity (keep top-k per row)
+q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+
+# Basic attention
+output = flash_dmattn_func(q=q, k=k, v=v, is_causal=True)
+print(f"Output shape: {output.shape}")  # [2, 4096, 12, 128]
+```
+
+#### Dynamic Mask Attention
+
+```python
+import torch
+from flash_dmattn import flash_dmattn_func
+
+# Create attention mask and bias for dynamic masking
+batch_size, num_heads, seq_len = 2, 12, 4096
 keep_window_size = 1024
+
+# Create sparse attention mask (attend to top-k positions)
+attention_bias = torch.randn(batch_size, num_heads, seq_len, seq_len, device=device, dtype=dtype)
+attention_mask = torch.zeros_like(attention_bias)
+
+# Keep top-k positions per query
 if seq_len > keep_window_size:
-    # Select top-k most important keys for each query
-    topk_indices = torch.topk(zoh_states, keep_window_size, dim=-1, 
-                             largest=True, sorted=False).indices
-    active_mask.scatter(-1, topk_indices, 1.0)
+    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1, largest=True).indices
+    attention_mask.scatter_(-1, topk_indices, 1.0)  # in-place scatter fills the kept positions
+else:
+    attention_mask.fill_(1.0)
+
+# Run attention with dynamic masking
+output = flash_dmattn_func(
+    q=q, k=k, v=v,
+    attn_mask=attention_mask,
+    attn_bias=attention_bias,
+    is_causal=True,
+    softmax_scale=1.0/math.sqrt(head_dim)
+)
+```
+
+#### Grouped-Query Attention (GQA)
+
+```python
+import torch
+from flash_dmattn import flash_dmattn_func
+
+# GQA setup: fewer key/value heads than query heads
+batch_size, seq_len, num_heads, num_kv_heads, head_dim = 2, 2048, 32, 8, 128
+
+q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
+v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
+
+# Attention mask for GQA
+attn_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
+
+output = flash_dmattn_func(q=q, k=k, v=v, attn_mask=attn_mask, is_causal=True)
+```
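+
+#### KV-Packed Attention
+
+A sketch of `flash_dmattn_kvpacked_func` in the same setting, using only what the signature documented above guarantees; sizes are illustrative, and packing K and V is simply convenient when they are produced by one fused projection.
+
+```python
+import torch
+from flash_dmattn import flash_dmattn_kvpacked_func
+
+batch_size, seq_len, num_heads, head_dim = 2, 1024, 16, 64
+device = torch.device('cuda')
+dtype = torch.bfloat16
+
+q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+# K and V packed along dimension 2: (batch, seqlen_k, 2, num_heads_k, head_dim)
+kv = torch.randn(batch_size, seq_len, 2, num_heads, head_dim, device=device, dtype=dtype)
+
+output = flash_dmattn_kvpacked_func(q=q, kv=kv, is_causal=True)
+print(output.shape)  # expected: torch.Size([2, 1024, 16, 64])
+```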
 
-# Run attention
-output = flash_dma.fwd(
-    q=query,
-    k=key,
-    v=value,
-    zoh=zoh_states,
-    active_mask=active_mask,
-    softmax_scale=1.0/math.sqrt(head_dim),
-    keep_window_size=keep_window_size
-)[0]
+#### Variable Length Sequences
 
-print(f"Output shape: {output.shape}")
+```python
+import torch
+from flash_dmattn import flash_dmattn_varlen_func
+
+# Variable length setup
+batch_size = 3
+seq_lens = [512, 1024, 768]  # Different lengths per batch
+total_tokens = sum(seq_lens)
+num_heads, head_dim = 16, 64
+
+# Concatenated tensors
+q = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype)
+k = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype)
+v = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype)
+
+# Cumulative sequence lengths (kept in int32; integer cumsum may otherwise promote to int64)
+cu_seqlens = torch.tensor([0] + seq_lens, device=device).cumsum(0, dtype=torch.int32)
+
+# Variable length attention
+output = flash_dmattn_varlen_func(
+    q=q, k=k, v=v,
+    cu_seqlens_q=cu_seqlens,
+    cu_seqlens_k=cu_seqlens,
+    max_seqlen_q=max(seq_lens),
+    max_seqlen_k=max(seq_lens),
+    is_causal=True
+)
+```
+
+### Performance Optimization
+
+#### Memory Efficiency
+
+```python
+# Use gradient checkpointing for long sequences
+import torch.utils.checkpoint as checkpoint
+
+def attention_checkpoint(q, k, v, *args, **kwargs):
+    # Non-reentrant checkpointing also accepts keyword arguments
+    return checkpoint.checkpoint(flash_dmattn_func, q, k, v, *args, use_reentrant=False, **kwargs)
+
+# Process very long causal sequences in query chunks
+def chunked_attention(q, k, v, chunk_size=8192):
+    seq_len = q.shape[1]
+    outputs = []
+    
+    for i in range(0, seq_len, chunk_size):
+        q_chunk = q[:, i:i+chunk_size]
+        # Limit keys/values to the causal prefix of this chunk so each query only
+        # attends to earlier positions (assumes bottom-right-aligned causal masking,
+        # as in Flash Attention)
+        k_chunk = k[:, :i+chunk_size]
+        v_chunk = v[:, :i+chunk_size]
+        output_chunk = flash_dmattn_func(q=q_chunk, k=k_chunk, v=v_chunk, is_causal=True)
+        outputs.append(output_chunk)
+    
+    return torch.cat(outputs, dim=1)
+```
+
+#### Backend Selection for Performance
+
+```python
+from flash_dmattn import flash_dmattn_func_auto
+
+# Automatic selection (CUDA > Triton > Flex)
+output = flash_dmattn_func_auto(q=q, k=k, v=v)
+
+# Force specific backend for performance testing
+backends = ["cuda", "triton", "flex"]
+for backend in backends:
+    try:
+        start_time = torch.cuda.Event(enable_timing=True)
+        end_time = torch.cuda.Event(enable_timing=True)
+        
+        start_time.record()
+        output = flash_dmattn_func_auto(backend=backend, q=q, k=k, v=v)
+        end_time.record()
+        
+        torch.cuda.synchronize()
+        elapsed = start_time.elapsed_time(end_time)
+        print(f"{backend}: {elapsed:.2f} ms")
+    except RuntimeError as e:
+        print(f"{backend}: not available - {e}")
 ```
 
-### Performance Issues
+### Common Issues and Solutions
+
+#### Import Errors
+
+```python
+# Test basic import
+try:
+    from flash_dmattn import flash_dmattn_func, get_available_backends
+    print("✅ Flash Dynamic Mask Attention imported successfully")
+    print(f"Available backends: {get_available_backends()}")
+except ImportError as e:
+    print(f"❌ Import failed: {e}")
+    print("Please ensure the package is properly installed with: pip install -e .")
+```
+
+#### Performance Issues
 
 1. **Slow Execution**
-   - Ensure tensors are contiguous
+   - Ensure tensors are contiguous and on the same GPU
    - Use optimal head dimensions (multiples of 8)
-   - Check GPU utilization with `nvidia-smi`
-   
+   - Check that CUDA backend is being used
+   
 2. **High Memory Usage**
-   - Reduce `keep_window_size`
-   - Use gradient checkpointing
-   - Process sequences in chunks
+   - Use gradient checkpointing for training
+   - Process sequences in chunks for very long sequences
+   - Consider using variable length functions for batches with mixed lengths
 
 3. 
**Numerical Instability** - Use `torch.bfloat16` instead of `torch.float16` - - Check attention mask values - - Monitor gradient norms + - Check attention mask and bias values for NaN/Inf + - Monitor gradient norms during training -### Debug Mode +#### Debugging ```python -# Enable debug output +# Enable anomaly detection torch.autograd.set_detect_anomaly(True) -# Check intermediate values -output, softmax_lse, attn_weights, _ = flash_dma.fwd( - query, key, value, zoh_states, active_mask, - return_softmax=True +# Check intermediate values +output = flash_dmattn_func( + q=q, k=k, v=v, + attn_mask=attn_mask, + attn_bias=attn_bias, + return_attn_probs=True # Get attention weights for debugging ) -print(f"Attention weights range: [{attn_weights.min()}, {attn_weights.max()}]") -print(f"LSE stats: mean={softmax_lse.mean()}, std={softmax_lse.std()}") +if isinstance(output, tuple): + attn_output, softmax_lse, attn_weights = output + print(f"Attention weights range: [{attn_weights.min():.6f}, {attn_weights.max():.6f}]") + print(f"LSE stats: mean={softmax_lse.mean():.6f}, std={softmax_lse.std():.6f}") +else: + attn_output = output + +# Check for NaN values +if torch.isnan(attn_output).any(): + print("⚠️ NaN detected in attention output") +``` + +#### Memory Monitoring + +```python +def print_memory_stats(): + if torch.cuda.is_available(): + print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated") + print(f"GPU Memory: {torch.cuda.memory_reserved() / 1e9:.2f} GB reserved") + print(f"GPU Memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB max allocated") + +# Monitor memory usage +print_memory_stats() +output = flash_dmattn_func(q=q, k=k, v=v) +print_memory_stats() + +# Clear cache if needed +torch.cuda.empty_cache() ``` From 68c17fca1989faa23dae8726938f95f90ee906e4 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 30 Jul 2025 21:35:54 +0800 Subject: [PATCH 2/2] Adds missing device and dtype declarations to examples Ensures all code examples include proper device and dtype variable definitions to prevent NameError when users run the sample code. Also adds missing math import for completeness. 
--- docs/api_reference.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/api_reference.md b/docs/api_reference.md index 1bc92a1..d91ca4f 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -276,10 +276,13 @@ print(f"Output shape: {output.shape}") # [2, 4096, 12, 128] ```python import torch from flash_dmattn import flash_dmattn_func +import math # Create attention mask and bias for dynamic masking batch_size, num_heads, seq_len = 2, 12, 4096 keep_window_size = 1024 +device = torch.device('cuda') +dtype = torch.bfloat16 # Create sparse attention mask (attend to top-k positions) attention_bias = torch.randn(batch_size, num_heads, seq_len, seq_len, device=device, dtype=dtype) @@ -310,6 +313,8 @@ from flash_dmattn import flash_dmattn_func # GQA setup: fewer key/value heads than query heads batch_size, seq_len, num_heads, num_kv_heads, head_dim = 2, 2048, 32, 8, 128 +device = torch.device('cuda') +dtype = torch.bfloat16 q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype) @@ -332,6 +337,8 @@ batch_size = 3 seq_lens = [512, 1024, 768] # Different lengths per batch total_tokens = sum(seq_lens) num_heads, head_dim = 16, 64 +device = torch.device('cuda') +dtype = torch.bfloat16 # Concatenated tensors q = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype)