From 187cbd030f22af0542623f53445b3e28e9a61c1c Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Wed, 30 Jul 2025 21:41:00 +0800
Subject: [PATCH] Refactors API to use unified flash attention interface

Replaces low-level CUDA extension calls with a simplified function interface
that handles dynamic masking internally. Removes manual ZOH state and active
mask management in favor of attention bias and mask parameters. Adds dynamic
top-k selection for long sequences to improve memory efficiency. Simplifies
the troubleshooting documentation by removing CUDA-specific debugging steps
and focusing on memory monitoring.
---
 README.md | 65 +++++++++++++++++++++++++++++++------------------------
 1 file changed, 37 insertions(+), 28 deletions(-)

diff --git a/README.md b/README.md
index 139f8ba..64ae1a6 100644
--- a/README.md
+++ b/README.md
@@ -47,8 +47,7 @@ pip install .
 
 ```python
 import torch
-import flash_dma_cuda
-import torch.nn.functional as F
+from flash_dmattn import flash_dmattn_func
 import math
 
 # Setup
@@ -63,19 +62,32 @@
 key = torch.randn(batch_size, seq_len, num_heads, head_dim,
                   device=device, dtype=dtype)
 value = torch.randn(batch_size, seq_len, num_heads, head_dim,
                     device=device, dtype=dtype)
-zoh_states = torch.randn(batch_size, num_heads, seq_len, seq_len,
-                         device=device, dtype=dtype)
-active_mask = torch.ones(batch_size, num_heads, seq_len, seq_len,
-                         device=device, dtype=dtype)
-
-# Run Flash-DMA
-output = flash_dma_cuda.fwd(
-    q=query, k=key, v=value,
-    zoh=zoh_states, active_mask=active_mask,
+
+# Create mask and bias for sparse attention
+attention_bias = torch.randn(batch_size, num_heads, seq_len, seq_len,
+                             device=device, dtype=dtype)
+attention_mask = torch.ones(batch_size, num_heads, seq_len, seq_len,
+                            device=device, dtype=dtype)
+
+# Apply dynamic masking (keep top-k for long sequences)
+keep_window_size = 2048
+if seq_len > keep_window_size:
+    # Select top-k most important keys for each query
+    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
+                              largest=True, sorted=False).indices
+    attention_mask.zero_()
+    attention_mask.scatter_(-1, topk_indices, 1.0)
+
+# Run Flash Dynamic Mask Attention
+output = flash_dmattn_func(
+    q=query,
+    k=key,
+    v=value,
+    attn_mask=attention_mask,
+    attn_bias=attention_bias,
     softmax_scale=1.0/math.sqrt(head_dim),
-    keep_window_size=keep_window_size, is_causal=True
-)[0]
+)
 
 print(f"Output shape: {output.shape}")  # [2, 4096, 12, 128]
 ```
@@ -189,34 +201,31 @@ python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
 ```python
 # Test basic import
 try:
-    import flash_dma_cuda
-    print("✅ Flash DMA CUDA extension imported successfully")
+    from flash_dmattn import flash_dmattn_func, get_available_backends
+    print("✅ Flash Dynamic Mask Attention imported successfully")
+    print(f"Available backends: {get_available_backends()}")
 except ImportError as e:
     print(f"❌ Import failed: {e}")
     print("Please ensure the package is properly installed with: pip install -e .")
 ```
 
 **Performance Issues**
-- Ensure GPU has compute capability 8.0+ for optimal performance
-- Use `torch.bfloat16` for better numerical stability
-- Adjust `keep_window_size` based on available GPU memory
-- Verify CUDA kernels are being used
-
-**Memory Issues**
 ```python
 # Monitor GPU memory usage
-torch.cuda.memory_summary()
-torch.cuda.max_memory_allocated()
+from flash_dmattn import flash_dmattn_func
+
+def print_memory_stats():
+    if torch.cuda.is_available():
+        print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
+
+print_memory_stats()
+output = flash_dmattn_func(q=query, k=key, v=value, is_causal=True)
+print_memory_stats()
 
 # Clear cache if needed
 torch.cuda.empty_cache()
 ```
 
-**Numerical Issues**
-- Use `torch.bfloat16` instead of `torch.float16` for better stability
-- Check input tensor ranges for NaN or infinite values
-- Validate ZOH states and active mask values are in expected ranges
-
 ## License
 
 This project is licensed under the BSD 3-Clause License. See [LICENSE](LICENSE) for details.
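
Note (not part of the patch): below is a minimal, self-contained sketch of the dynamic top-k masking step that this change adds to the quick-start example. It only uses standard PyTorch ops (`torch.topk` and in-place `scatter_`), so it runs on CPU with small shapes; the `build_topk_mask` helper name and the tiny tensor sizes are illustrative assumptions, not part of the flash_dmattn API.

```python
import torch

def build_topk_mask(attention_bias: torch.Tensor, keep_window_size: int) -> torch.Tensor:
    """Return a {0, 1} mask keeping the keep_window_size highest-bias keys per query.

    attention_bias: (batch, num_heads, query_len, key_len)
    """
    key_len = attention_bias.shape[-1]
    if key_len <= keep_window_size:
        # Nothing to prune: every key stays visible.
        return torch.ones_like(attention_bias)
    # Indices of the top-k bias values along the key dimension.
    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
                              largest=True, sorted=False).indices
    mask = torch.zeros_like(attention_bias)
    # In-place scatter marks the selected key positions with 1.0.
    mask.scatter_(-1, topk_indices, 1.0)
    return mask

# Small CPU-sized example (the README example uses much larger CUDA tensors).
bias = torch.randn(1, 2, 8, 8)
mask = build_topk_mask(bias, keep_window_size=4)
print(mask.sum(dim=-1))  # every query row keeps exactly 4 keys
```

In the updated README example, the resulting mask and the original bias would then be passed to `flash_dmattn_func` as `attn_mask` and `attn_bias`, as shown in the diff above.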