diff --git a/docs/api_reference.md b/docs/api_reference.md index d91ca4f..435231d 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -2,13 +2,13 @@ ## Overview -Flash Dynamic Mask Attention is a high-performance implementation that combines Flash Attention's memory efficiency with Dynamic Mask Attention's sparse computation capabilities. This API provides CUDA-accelerated attention computation with dynamic masking for handling extremely long sequences efficiently. +Flash Dynamic Mask Attention is a high-performance attention implementation that combines the memory efficiency of Flash Attention with the sparse compute benefits of Dynamic Mask Attention. It supports CUDA, Triton, and Flex Attention backends and dynamic masking for very long sequences. -The library provides multiple interfaces: -- **High-level Functions**: Easy-to-use functions with automatic backend selection -- **Specific Implementations**: Direct access to CUDA, Triton, and Flex Attention backends -- **Packed Variants**: Optimized functions for QKV-packed and KV-packed tensors -- **Variable Length**: Support for variable sequence lengths within batches +Interfaces provided: +- High-level: simple entry point with automatic backend selection +- Backend-specific: direct access to CUDA, Triton, and Flex implementations +- Packed variants: optimized paths for QKV-packed and KV-packed inputs +- Variable length: support for batches with different sequence lengths ## Table of Contents @@ -23,11 +23,11 @@ The library provides multiple interfaces: ### Prerequisites -- **Python**: 3.8 or later -- **PyTorch**: 2.0.0 or later with CUDA support -- **CUDA**: 11.8 or later -- **NVIDIA GPU**: Compute Capability 8.0 or higher -- **Dependencies**: `packaging`, `torch` +- Python: 3.8+ +- PyTorch: 2.0.0+ with CUDA +- CUDA: 11.8+ +- NVIDIA GPU: Compute Capability 8.0+ +- Dependencies: `packaging`, `torch` ### Install from Source @@ -42,6 +42,8 @@ pip install -e . ### Automatic Backend Selection +Note: `flash_dmattn_func_auto` returns a callable attention function, not the attention output. + ```python from flash_dmattn import flash_dmattn_func_auto, get_available_backends @@ -49,171 +51,112 @@ from flash_dmattn import flash_dmattn_func_auto, get_available_backends backends = get_available_backends() print(f"Available backends: {backends}") -# Use with automatic backend selection -output = flash_dmattn_func_auto( - q=query, k=key, v=value, - attn_mask=attention_mask, - attn_bias=attention_bias -) +# Auto-select (priority: cuda > triton > flex) +attn = flash_dmattn_func_auto() +output = attn(q, k, v, attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, scale=None) -# Force specific backend -output = flash_dmattn_func_auto( - backend="cuda", # or "triton", "flex" - q=query, k=key, v=value, - attn_mask=attention_mask, - attn_bias=attention_bias -) +# Force a specific backend +attn = flash_dmattn_func_auto(backend="cuda") # or "triton", "flex" +output = attn(q, k, v, attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, scale=None) ``` ## Core Functions -### flash_dmattn_func +### flash_dmattn_func (CUDA backend) -The main attention function supporting multi-head and grouped-query attention. +Main attention function. Supports multi-head and grouped-query attention (when the number of KV heads is smaller than the number of Q heads). Requires the CUDA extension to be built and available. ```python def flash_dmattn_func( - q: torch.Tensor, # Query tensor - k: torch.Tensor, # Key tensor - v: torch.Tensor, # Value tensor - attn_mask: Optional[torch.Tensor] = None, # Attention mask - attn_bias: Optional[torch.Tensor] = None, # Attention bias - dropout_p: Optional[float] = None, # Dropout probability - softmax_scale: Optional[float] = None, # Scaling factor - is_causal: Optional[bool] = None, # Causal masking - softcap: Optional[float] = None, # Soft capping - deterministic: Optional[bool] = None, # Deterministic mode - return_attn_probs: Optional[bool] = None, # Return attention weights + q: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) + k: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) + v: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) + attn_mask: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) + attn_bias: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) + scale: Optional[float] = None, # score scaling, defaults to 1/sqrt(head_dim) + is_causal: Optional[bool] = None, # causal mask + softcap: Optional[float] = None, # CUDA-only + deterministic: Optional[bool] = None, # CUDA-only ) -> torch.Tensor ``` #### Parameters -- **q** (`torch.Tensor`): Query tensor of shape `(batch_size, seqlen_q, num_heads, head_dim)` - - Must be contiguous and on CUDA device - - Supported dtypes: `torch.float16`, `torch.bfloat16` - -- **k** (`torch.Tensor`): Key tensor of shape `(batch_size, seqlen_k, num_heads_k, head_dim)` - - Same dtype and device as `q` - - Supports grouped-query attention when `num_heads_k < num_heads` - -- **v** (`torch.Tensor`): Value tensor of shape `(batch_size, seqlen_k, num_heads_k, head_dim)` - - Same dtype and device as `q` - - Supports grouped-query attention when `num_heads_k < num_heads` - -- **attn_mask** (`Optional[torch.Tensor]`): Attention mask of shape `(batch_size, num_heads, seqlen_q, seqlen_k)` - - Binary mask: 1.0 for positions to attend, 0.0 for masked positions - - If `None`, no masking is applied - -- **attn_bias** (`Optional[torch.Tensor]`): Attention bias of shape `(batch_size, num_heads, seqlen_q, seqlen_k)` - - Added to attention scores before softmax - - If `None`, no bias is applied - -- **dropout_p** (`Optional[float]`): Dropout probability (default: 0.0) - - Range: [0.0, 1.0] - - Applied to attention weights - -- **softmax_scale** (`Optional[float]`): Scaling factor for attention scores - - If `None`, defaults to `1.0 / sqrt(head_dim)` - -- **is_causal** (`Optional[bool]`): Whether to apply causal masking (default: False) - - When True, applies lower triangular mask - -- **softcap** (`Optional[float]`): Soft capping value (default: 0.0) - - If > 0, applies `softcap * tanh(score / softcap)` - -- **deterministic** (`Optional[bool]`): Use deterministic backward pass (default: True) - - Slightly slower but more memory efficient - -- **return_attn_probs** (`Optional[bool]`): Return attention probabilities (default: False) - - For debugging only +- q: (B, Q, H, D). CUDA tensor, fp16/bf16, last dim contiguous +- k, v: (B, K, H_kv, D). Same dtype/device as q; GQA when H_kv < H +- attn_mask: (B, H, Q, K). 1.0 = visible, 0.0 = masked. None to disable +- attn_bias: (B, H, Q, K). Added to scores before softmax. None to disable +- scale: score scaling; default 1/sqrt(D) +- is_causal: apply lower-triangular mask +- softcap, deterministic: only effective on the CUDA backend; ignored on others #### Returns -- **output** (`torch.Tensor`): Attention output of shape `(batch_size, seqlen_q, num_heads, head_dim)` -- **softmax_lse** (optional): Log-sum-exp of attention weights -- **attn_probs** (optional): Attention probabilities (if `return_attn_probs=True`) +- output: (B, Q, H, D) -## Packed Variants +## Packed Variants (CUDA backend) ### flash_dmattn_qkvpacked_func -Optimized function for QKV-packed tensors. +Optimized function for QKV-packed input. ```python def flash_dmattn_qkvpacked_func( - qkv: torch.Tensor, # Packed QKV tensor - attn_mask: Optional[torch.Tensor] = None, # Attention mask - attn_bias: Optional[torch.Tensor] = None, # Attention bias - dropout_p: Optional[float] = None, # Dropout probability - softmax_scale: Optional[float] = None, # Scaling factor - is_causal: Optional[bool] = None, # Causal masking - softcap: Optional[float] = None, # Soft capping - deterministic: Optional[bool] = None, # Deterministic mode - return_attn_probs: Optional[bool] = None, # Return attention weights + qkv: torch.Tensor, # (batch, seqlen, 3, num_heads, head_dim) + attn_mask: Optional[torch.Tensor] = None, + attn_bias: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + is_causal: Optional[bool] = None, + softcap: Optional[float] = None, # CUDA-only + deterministic: Optional[bool] = None, # CUDA-only ) -> torch.Tensor ``` -**Parameters:** -- **qkv** (`torch.Tensor`): Packed tensor of shape `(batch_size, seqlen, 3, num_heads, head_dim)` - - Contains query, key, and value tensors stacked along dimension 2 - ### flash_dmattn_kvpacked_func -Optimized function for KV-packed tensors. +Optimized function for KV-packed input. ```python def flash_dmattn_kvpacked_func( - q: torch.Tensor, # Query tensor - kv: torch.Tensor, # Packed KV tensor - attn_mask: Optional[torch.Tensor] = None, # Attention mask - attn_bias: Optional[torch.Tensor] = None, # Attention bias - dropout_p: Optional[float] = None, # Dropout probability - softmax_scale: Optional[float] = None, # Scaling factor - is_causal: Optional[bool] = None, # Causal masking - softcap: Optional[float] = None, # Soft capping - deterministic: Optional[bool] = None, # Deterministic mode - return_attn_probs: Optional[bool] = None, # Return attention weights + q: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) + kv: torch.Tensor, # (batch, seqlen_k, 2, num_kv_heads, head_dim) + attn_mask: Optional[torch.Tensor] = None, + attn_bias: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + is_causal: Optional[bool] = None, + softcap: Optional[float] = None, # CUDA-only + deterministic: Optional[bool] = None, # CUDA-only ) -> torch.Tensor ``` -**Parameters:** -- **q** (`torch.Tensor`): Query tensor of shape `(batch_size, seqlen_q, num_heads, head_dim)` -- **kv** (`torch.Tensor`): Packed tensor of shape `(batch_size, seqlen_k, 2, num_heads_k, head_dim)` - -## Variable Length Functions +## Variable Length Functions (CUDA backend) ### flash_dmattn_varlen_func -Attention function supporting variable sequence lengths within a batch. +Variable length attention for batches with mixed sequence lengths. ```python def flash_dmattn_varlen_func( - q: torch.Tensor, # Query tensor - k: torch.Tensor, # Key tensor - v: torch.Tensor, # Value tensor - attn_mask: Optional[torch.Tensor] = None, # Attention mask - attn_bias: Optional[torch.Tensor] = None, # Attention bias - cu_seqlens_q: torch.Tensor = None, # Cumulative sequence lengths (query) - cu_seqlens_k: torch.Tensor = None, # Cumulative sequence lengths (key) - max_seqlen_q: int = None, # Maximum query sequence length - max_seqlen_k: int = None, # Maximum key sequence length - dropout_p: Optional[float] = None, # Dropout probability - softmax_scale: Optional[float] = None, # Scaling factor - is_causal: Optional[bool] = None, # Causal masking - softcap: Optional[float] = None, # Soft capping - deterministic: Optional[bool] = None, # Deterministic mode - return_attn_probs: Optional[bool] = None, # Return attention weights - block_table: Optional[torch.Tensor] = None, # Block table for paged attention + q: torch.Tensor, # (total_q, H, D) or (B, Q, H, D) + k: torch.Tensor, # same layout as q + v: torch.Tensor, # same layout as q + attn_mask: Optional[torch.Tensor] = None, # (B, H, Q, K) + attn_bias: Optional[torch.Tensor] = None, # (B, H, Q, K) + cu_seqlens_q: torch.Tensor = None, # (B+1,) + cu_seqlens_k: torch.Tensor = None, # (B+1,) + max_seqlen_q: int = None, + max_seqlen_k: int = None, + scale: Optional[float] = None, + is_causal: Optional[bool] = None, + softcap: Optional[float] = None, # CUDA-only + deterministic: Optional[bool] = None, # CUDA-only + block_table: Optional[torch.Tensor] = None, # experimental: paged attention ) -> torch.Tensor ``` -**Additional Parameters:** -- **cu_seqlens_q** (`torch.Tensor`): Cumulative sequence lengths for queries, shape `(batch_size + 1,)` -- **cu_seqlens_k** (`torch.Tensor`): Cumulative sequence lengths for keys, shape `(batch_size + 1,)` -- **max_seqlen_q** (`int`): Maximum sequence length in the batch for queries -- **max_seqlen_k** (`int`): Maximum sequence length in the batch for keys -- **block_table** (`Optional[torch.Tensor]`): Block table for paged attention (experimental) +- cu_seqlens_q/k: cumulative sequence lengths for query/key +- max_seqlen_q/k: max sequence lengths per batch +- block_table: experimental support for paged attention ## Backend Selection @@ -222,272 +165,227 @@ def flash_dmattn_varlen_func( ```python from flash_dmattn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE -# Check which backends are available -backends = get_available_backends() -print(f"Available: {backends}") - -# Check individual backend availability -print(f"CUDA: {CUDA_AVAILABLE}") -print(f"Triton: {TRITON_AVAILABLE}") -print(f"Flex: {FLEX_AVAILABLE}") +print(get_available_backends()) # e.g., ["cuda", "triton", "flex"] +print(CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE) ``` ### Backend-Specific Functions ```python -# Direct access to specific implementations -from flash_dmattn import flash_dmattn_func # CUDA backend -from flash_dmattn import triton_dmattn_func # Triton backend +# Direct access to specific backends +from flash_dmattn import flash_dmattn_func # CUDA backend (requires compiled extension) +from flash_dmattn import triton_dmattn_func # Triton backend from flash_dmattn import flex_dmattn_func # Flex Attention backend + +# Unified call signature (public layer) +# query/key/value: (B, L{q/k}, H, D) +# attn_mask/attn_bias: (B, H, Lq, Lk) +# is_causal: bool, scale: Optional[float] +output = triton_dmattn_func(q, k, v, attn_mask=mask, attn_bias=bias, is_causal=True, scale=None) +output = flex_dmattn_func(q, k, v, attn_mask=mask, attn_bias=bias, is_causal=True, scale=None) ``` +Notes: +- Triton returns only the attention output tensor. +- Flex currently uses causal masking and score_mod with bias; provided attn_mask is not applied in the kernel at the moment (subject to change in future versions). + ### Data Types and Memory Layout -- **Supported dtypes**: `torch.float16`, `torch.bfloat16` -- **Recommended**: `torch.bfloat16` for better numerical stability -- **Device**: CUDA tensors only -- **Memory**: All tensors must be contiguous in the last dimension +- dtypes: `torch.float16`, `torch.bfloat16` (bf16 recommended for stability) +- device: CUDA tensors only +- memory: last dimension must be contiguous (`stride(-1) == 1`); call `.contiguous()` if needed -### Basic Usage Examples +## Basic Usage Examples -#### Standard Attention +Prefer the high-level automatic interface for cross-backend portability. + +### Standard Attention ```python import torch -from flash_dmattn import flash_dmattn_func +from flash_dmattn import flash_dmattn_func_auto -# Setup -batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128 +B, L, H, D = 2, 4096, 12, 128 device = torch.device('cuda') dtype = torch.bfloat16 -# Create input tensors -q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) -k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) -v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) +q = torch.randn(B, L, H, D, device=device, dtype=dtype) +k = torch.randn(B, L, H, D, device=device, dtype=dtype) +v = torch.randn(B, L, H, D, device=device, dtype=dtype) -# Basic attention -output = flash_dmattn_func(q=q, k=k, v=v, is_causal=True) -print(f"Output shape: {output.shape}") # [2, 4096, 12, 128] +attn = flash_dmattn_func_auto() +output = attn(q, k, v, is_causal=True) +print(output.shape) # [2, 4096, 12, 128] ``` -#### Dynamic Mask Attention +### Dynamic Mask Attention ```python -import torch -from flash_dmattn import flash_dmattn_func -import math +import torch, math +from flash_dmattn import flash_dmattn_func_auto -# Create attention mask and bias for dynamic masking -batch_size, num_heads, seq_len = 2, 12, 4096 +B, H, L = 2, 12, 4096 keep_window_size = 1024 device = torch.device('cuda') dtype = torch.bfloat16 -# Create sparse attention mask (attend to top-k positions) -attention_bias = torch.randn(batch_size, num_heads, seq_len, seq_len, device=device, dtype=dtype) +q = torch.randn(B, L, H, 128, device=device, dtype=dtype) +k = torch.randn(B, L, H, 128, device=device, dtype=dtype) +v = torch.randn(B, L, H, 128, device=device, dtype=dtype) + +attention_bias = torch.randn(B, H, L, L, device=device, dtype=dtype) attention_mask = torch.zeros_like(attention_bias) -# Keep top-k positions per query -if seq_len > keep_window_size: +if L > keep_window_size: topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1, largest=True).indices - attention_mask.scatter(-1, topk_indices, 1.0) + attention_mask.scatter_(-1, topk_indices, 1.0) else: attention_mask.fill_(1.0) -# Run attention with dynamic masking -output = flash_dmattn_func( - q=q, k=k, v=v, - attn_mask=attention_mask, - attn_bias=attention_bias, - is_causal=True, - softmax_scale=1.0/math.sqrt(head_dim) -) +attn = flash_dmattn_func_auto() +output = attn(q, k, v, attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, scale=1.0/math.sqrt(128)) ``` -#### Grouped-Query Attention (GQA) +### Grouped-Query Attention (GQA) ```python import torch -from flash_dmattn import flash_dmattn_func +from flash_dmattn import flash_dmattn_func_auto -# GQA setup: fewer key/value heads than query heads -batch_size, seq_len, num_heads, num_kv_heads, head_dim = 2, 2048, 32, 8, 128 +B, L, H, H_kv, D = 2, 2048, 32, 8, 128 device = torch.device('cuda') dtype = torch.bfloat16 -q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) -k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype) -v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype) +q = torch.randn(B, L, H, D, device=device, dtype=dtype) +k = torch.randn(B, L, H_kv, D, device=device, dtype=dtype) +v = torch.randn(B, L, H_kv, D, device=device, dtype=dtype) -# Attention mask for GQA -attn_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype) +attn_mask = torch.ones(B, H, L, L, device=device, dtype=dtype) -output = flash_dmattn_func(q=q, k=k, v=v, attn_mask=attn_mask, is_causal=True) +attn = flash_dmattn_func_auto() +output = attn(q, k, v, attn_mask=attn_mask, is_causal=True) ``` -#### Variable Length Sequences +### Variable Length Sequences (CUDA backend) ```python import torch from flash_dmattn import flash_dmattn_varlen_func -# Variable length setup -batch_size = 3 -seq_lens = [512, 1024, 768] # Different lengths per batch -total_tokens = sum(seq_lens) -num_heads, head_dim = 16, 64 +B = 3 +seq_lens = [512, 1024, 768] +T = sum(seq_lens) +H, D = 16, 64 device = torch.device('cuda') dtype = torch.bfloat16 -# Concatenated tensors -q = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype) -k = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype) -v = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype) +q = torch.randn(T, H, D, device=device, dtype=dtype) +k = torch.randn(T, H, D, device=device, dtype=dtype) +v = torch.randn(T, H, D, device=device, dtype=dtype) -# Cumulative sequence lengths -cu_seqlens = torch.tensor([0] + seq_lens, device=device, dtype=torch.int32).cumsum(0) +cu = torch.tensor([0] + seq_lens, device=device, dtype=torch.int32).cumsum(0) -# Variable length attention output = flash_dmattn_varlen_func( q=q, k=k, v=v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max(seq_lens), - max_seqlen_k=max(seq_lens), + cu_seqlens_q=cu, cu_seqlens_k=cu, + max_seqlen_q=max(seq_lens), max_seqlen_k=max(seq_lens), is_causal=True ) ``` -### Performance Optimization +## Performance Optimization -#### Memory Efficiency +### Memory Efficiency ```python -# Use gradient checkpointing for long sequences +# Gradient checkpointing for long sequences import torch.utils.checkpoint as checkpoint +from flash_dmattn import flash_dmattn_func_auto + +attn = flash_dmattn_func_auto() def attention_checkpoint(q, k, v, *args, **kwargs): - return checkpoint.checkpoint(flash_dmattn_func, q, k, v, *args, **kwargs) + return checkpoint.checkpoint(lambda *a, **kw: attn(*a, **kw), q, k, v, *args, **kwargs) # Process very long sequences in chunks -def chunked_attention(q, k, v, chunk_size=8192): - seq_len = q.shape[1] - outputs = [] - - for i in range(0, seq_len, chunk_size): - q_chunk = q[:, i:i+chunk_size] - output_chunk = flash_dmattn_func(q=q_chunk, k=k, v=v, is_causal=True) - outputs.append(output_chunk) - - return torch.cat(outputs, dim=1) +def chunked_attention(q, k, v, chunk_size=8192, **kwargs): + L = q.shape[1] + outs = [] + for i in range(0, L, chunk_size): + outs.append(attn(q[:, i:i+chunk_size], k, v, **kwargs)) + return torch.cat(outs, dim=1) ``` -#### Backend Selection for Performance +### Backend Selection for Performance ```python +import torch from flash_dmattn import flash_dmattn_func_auto -# Automatic selection (CUDA > Triton > Flex) -output = flash_dmattn_func_auto(q=q, k=k, v=v) - -# Force specific backend for performance testing backends = ["cuda", "triton", "flex"] for backend in backends: try: - start_time = torch.cuda.Event(enable_timing=True) - end_time = torch.cuda.Event(enable_timing=True) - - start_time.record() - output = flash_dmattn_func_auto(backend=backend, q=q, k=k, v=v) - end_time.record() - + attn = flash_dmattn_func_auto(backend=backend) + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + _ = attn(q, k, v, is_causal=True) + end.record() torch.cuda.synchronize() - elapsed = start_time.elapsed_time(end_time) - print(f"{backend}: {elapsed:.2f} ms") + print(f"{backend}: {start.elapsed_time(end):.2f} ms") except RuntimeError as e: print(f"{backend}: not available - {e}") ``` -### Common Issues and Solutions +## Common Issues and Solutions -#### Import Errors +### Import Errors ```python -# Test basic import try: - from flash_dmattn import flash_dmattn_func, get_available_backends - print("✅ Flash Dynamic Mask Attention imported successfully") - print(f"Available backends: {get_available_backends()}") + from flash_dmattn import flash_dmattn_func_auto, get_available_backends + print("✅ Imported successfully", get_available_backends()) except ImportError as e: print(f"❌ Import failed: {e}") - print("Please ensure the package is properly installed with: pip install -e .") + print("Please install with: pip install -e .") ``` -#### Performance Issues - -1. **Slow Execution** - - Ensure tensors are contiguous and on the same GPU - - Use optimal head dimensions (multiples of 8) - - Check that CUDA backend is being used - -2. **High Memory Usage** - - Use gradient checkpointing for training - - Process sequences in chunks for very long sequences - - Consider using variable length functions for batches with mixed lengths +### Performance Issues -3. **Numerical Instability** - - Use `torch.bfloat16` instead of `torch.float16` - - Check attention mask and bias values for NaN/Inf - - Monitor gradient norms during training +1. Slow execution: ensure all tensors are on the same GPU and last dim is contiguous; use head dims multiple of 8; prefer CUDA backend when available +2. High memory: use gradient checkpointing; chunk long sequences; use varlen for mixed-length batches +3. Numerical stability: prefer bfloat16; check mask/bias for NaN/Inf; monitor gradient norms -#### Debugging +### Debugging ```python -# Enable anomaly detection -torch.autograd.set_detect_anomaly(True) - -# Check intermediate values -output = flash_dmattn_func( - q=q, k=k, v=v, - attn_mask=attn_mask, - attn_bias=attn_bias, - return_attn_probs=True # Get attention weights for debugging -) - -if isinstance(output, tuple): - attn_output, softmax_lse, attn_weights = output - print(f"Attention weights range: [{attn_weights.min():.6f}, {attn_weights.max():.6f}]") - print(f"LSE stats: mean={softmax_lse.mean():.6f}, std={softmax_lse.std():.6f}") -else: - attn_output = output +import torch +from flash_dmattn import flash_dmattn_func_auto -# Check for NaN values -if torch.isnan(attn_output).any(): +torch.autograd.set_detect_anomaly(True) +attn = flash_dmattn_func_auto() +output = attn(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True) +if torch.isnan(output).any(): print("⚠️ NaN detected in attention output") ``` -#### Memory Monitoring +### Memory Monitoring ```python def print_memory_stats(): if torch.cuda.is_available(): - print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated") - print(f"GPU Memory: {torch.cuda.memory_reserved() / 1e9:.2f} GB reserved") - print(f"GPU Memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB max allocated") + print(f"allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB") + print(f"reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB") + print(f"max alloc: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB") -# Monitor memory usage print_memory_stats() -output = flash_dmattn_func(q=q, k=k, v=v) +attn = flash_dmattn_func_auto() +output = attn(q, k, v) print_memory_stats() -# Clear cache if needed torch.cuda.empty_cache() ``` - - - --- -For more information, see the [integration documentation](integration.md) and [benchmarking results](../benchmarks/). +See also: `docs/integration.md` and `benchmarks/`. diff --git a/flash_dmattn/flash_dmattn_flex.py b/flash_dmattn/flash_dmattn_flex.py index 1c160ad..7be4d01 100644 --- a/flash_dmattn/flash_dmattn_flex.py +++ b/flash_dmattn/flash_dmattn_flex.py @@ -1,4 +1,5 @@ from typing import Optional, Tuple +import math import torch from torch.nn.attention.flex_attention import create_block_mask from transformers.integrations.flex_attention import compile_friendly_flex_attention @@ -8,17 +9,29 @@ def flex_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attn_mask: torch.Tensor, - attn_bias: torch.Tensor, - is_causal: bool = True, + attn_mask: Optional[torch.Tensor] = None, + attn_bias: Optional[torch.Tensor] = None, + is_causal: Optional[bool] = None, scale: Optional[float] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: + batch, seqlen_q, nheads, dhead = query.shape + _, seqlen_k, _, _ = key.shape query = query.transpose(1, 2).contiguous() # [B, H, Q_LEN, D] key = key.transpose(1, 2).contiguous() # [B, H, KV_LEN, D] value = value.transpose(1, 2).contiguous() # [B, H, KV_LEN, D] - attn_mask = attn_mask[:, :, :, : key.shape[-2]] - attn_bias = attn_bias[:, :, :, : key.shape[-2]] + if attn_mask is not None: + attn_mask = attn_mask[:, :, :, : key.shape[-2]] + else: + attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype) + if attn_bias is not None: + attn_bias = attn_bias[:, :, :, : key.shape[-2]] + else: + attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype) + if is_causal is None: + is_causal = True + if scale is None: + scale = 1.0 / math.sqrt(dhead) def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx]