From 76a9d9900a14ae9da01ad72b8cdd37f5cd997acc Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 27 Jun 2025 16:00:38 +0800 Subject: [PATCH] Updates README with improved technical accuracy and examples MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Improves technical terminology by replacing "computational efficiency" with "sparse computation capabilities" for better precision. Updates complexity notation from O(N·k) to O(N·w) to align with implementation variable naming. Raises minimum requirements to Python 3.8+, PyTorch 2.0+, CUDA 11.8+, and compute capability 8.0+ to reflect actual supported configurations. Adds comprehensive Quick Start section with complete working code example including proper tensor shapes and device setup. Expands documentation with detailed benchmarking section covering forward pass equivalence, performance testing, gradient computation, and multi-query associative recall. Enhances troubleshooting guide with specific commands for verifying CUDA setup, handling import errors, monitoring memory usage, and addressing numerical stability issues. --- README.md | 196 +++++++++++++++++++++++++++++++++--------------------- 1 file changed, 122 insertions(+), 74 deletions(-) diff --git a/README.md b/README.md index e418a47..114c5a0 100644 --- a/README.md +++ b/README.md @@ -2,25 +2,25 @@ ![Flash-DMA Banner](assets/flash_dmattn_banner.jpg) -Flash-DMA is a high-performance attention implementation that integrates Flash Attention's memory efficiency with Dynamic Mask Attention's computational efficiency for processing extremely long sequences in transformer models. +Flash-DMA is a high-performance attention implementation that integrates Flash Attention's memory efficiency with Dynamic Mask Attention's sparse computation capabilities for processing extremely long sequences in transformer models. ## Key Features -- **Sparse Attention Computation**: Dynamically selects the most important keys for each query, reducing computation from $O(N^2)$ to $O(N \cdot k)$ where $k \ll N$. +- **Sparse Attention Computation**: Dynamically selects the most important keys for each query, reducing computation from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$. - **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without materializing the full attention matrix. -- **CUDA-Accelerated**: Deep integration at the CUDA kernel level for maximum performance. -- **Long Sequence Support**: Efficiently handles sequences of 128K+ tokens that would be impractical with standard attention. -- **Backward Compatible**: API compatible with existing Flash Attention implementations. +- **CUDA-Accelerated**: Deep integration at the CUDA kernel level with custom sparse GEMM operations for maximum performance. +- **Long Sequence Support**: Efficiently handles sequences of 128K+ tokens through dynamic masking when sequence length exceeds `keep_window_size`. +- **Advanced Integration**: Complete integration from Python frontend to CUDA backend with optimized memory layouts and sparse computation strategies. ## Installation ### Prerequisites -- **Python**: 3.7 or later -- **PyTorch**: 1.10.0 or later -- **CUDA**: 11.0 or later (for GPU acceleration) -- **NVIDIA GPU**: Compute Capability 6.0 or higher -- **C++ Compiler**: GCC 7+ or compatible +- **Python**: 3.8 or later +- **PyTorch**: 2.0.0 or later +- **CUDA**: 11.8 or later +- **NVIDIA GPU**: Compute Capability 8.0 or higher +- **C++ Compiler**: GCC 7+ ### CUDA Environment Setup @@ -43,44 +43,43 @@ git submodule update --init --recursive pip install . ``` - - ## How It Works @@ -89,7 +88,6 @@ Flash-DMA combines two complementary techniques: - **Dynamic Mask Attention**: Computes relevance scores for keys and selects only the most important ones for attention computation - **Flash Attention**: Processes attention in blocks to reduce memory usage and HBM access - ### The Integration Approach The integration happens at the CUDA kernel level with several key components: @@ -99,18 +97,15 @@ The integration happens at the CUDA kernel level with several key components: - **Sparse Matrix Multiplication**: Custom CUDA kernels for efficient sparse attention computation - **Block-Based Processing**: Maintains Flash Attention's block-based approach for memory efficiency -This creates a hybrid attention mechanism that achieves both memory and computational efficiency. +This creates a hybrid attention mechanism that achieves both memory and computational efficiency for long sequences. -## Documentation -For detailed technical documentation, see: -- [Integration Guide](docs/integration.md) - Comprehensive technical details -- [API Reference](#api-reference) - Function signatures and parameters +## Documentation -### API Reference +📚 **Complete documentation is available in the [docs](docs/) directory:** -> [!IMPORTANT] -> TODO +- **[API Reference](docs/api_reference.md)** - Complete function documentation and usage examples +- **[Integration Guide](docs/integration.md)** - Detailed technical documentation of the Flash Attention integration ## Building from Source @@ -124,36 +119,54 @@ cd flash-dmattn # Build in development mode pip install -e . + +# Run tests to verify installation +python -c "import flash_dma_cpp; print('✅ Flash DMA CUDA extension imported successfully')" ``` ### Build Requirements -- CUDA Toolkit 11.0+ -- CUTLASS library (included as submodule) -- CUB library (included as submodule) +- CUDA Toolkit 11.8+ +- CUTLASS library +- PyTorch with CUDA support ### Supported Architectures -- SM 6.0+ (Pascal, Volta, Turing, Ampere, Ada Lovelace) -- Optimized for SM 8.0+ (Ampere and newer) +- **SM 8.0** +- **SM 9.0** +- **SM 10.0** +- **SM 12.0** + +**Note**: Flash Dynamic Mask Attention requires CUDA compute capability 8.0+ for optimal performance. Earlier architectures are not supported. + +## Benchmarking + +Flash-DMA provides comprehensive benchmarking tools to evaluate performance across different configurations: -## Testing +### Forward Pass Equivalence +```bash +python benchmarks/benchmark_forward_equivalence.py +``` +Validates numerical consistency between Python reference and CUDA implementation. -### Run Tests +### Performance Benchmarking +```bash +python benchmarks/benchmark_forward_performance.py +``` +Compares Flash-DMA against standard Flash Attention across various sequence lengths and batch sizes. +### Gradient Computation ```bash -# Gradient equivalent benchmarks python benchmarks/benchmark_grad.py ``` +Tests backward pass implementation and gradient equivalence. -### Compatibility +### Multi-Query Associative Recall +```bash +python benchmarks/benchmark_mqar.py +``` +Evaluates performance on long-range reasoning tasks. -| Component | Supported Versions | -|-----------|-------------------| -| PyTorch | 1.10.0+ | -| CUDA | 11.0+ | -| Python | 3.7+ | -| GPU Arch | SM 6.0+ | ## Troubleshooting @@ -161,16 +174,48 @@ python benchmarks/benchmark_grad.py **Compilation Errors** ```bash -# Ensure CUDA_HOME is set -export CUDA_HOME=/usr/local/cuda -# Update NVCC if needed -which nvcc +# Ensure CUDA_HOME is set correctly +echo $CUDA_HOME # Linux/Mac +echo $env:CUDA_HOME # Windows PowerShell + +# Check CUDA toolkit version +nvcc --version + +# Verify PyTorch CUDA support +python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" +``` + +**Import Errors** +```python +# Test basic import +try: + import flash_dma_cpp + print("✅ Flash DMA CUDA extension imported successfully") +except ImportError as e: + print(f"❌ Import failed: {e}") + print("Please ensure the package is properly installed with: pip install -e .") ``` **Performance Issues** -- Ensure GPU has sufficient compute capability (6.0+) -- Use appropriate data types (float16 recommended) -- Verify CUDA kernels are being used (not CPU fallback) +- Ensure GPU has compute capability 8.0+ for optimal performance +- Use `torch.bfloat16` for better numerical stability +- Adjust `keep_window_size` based on available GPU memory +- Verify CUDA kernels are being used + +**Memory Issues** +```python +# Monitor GPU memory usage +torch.cuda.memory_summary() +torch.cuda.max_memory_allocated() + +# Clear cache if needed +torch.cuda.empty_cache() +``` + +**Numerical Issues** +- Use `torch.bfloat16` instead of `torch.float16` for better stability +- Check input tensor ranges for NaN or infinite values +- Validate ZOH states and active mask values are in expected ranges ## License @@ -191,6 +236,9 @@ If you use Flash-DMA in your research, please cite: ## Acknowledgments -This project builds upon the excellent work of: -- [Flash-Attention](https://github.com/Dao-AILab/flash-attention) by Tri Dao et al. -- [NVIDIA CUTLASS](https://github.com/NVIDIA/cutlass) library for efficient matrix operations \ No newline at end of file +This project builds upon and integrates several excellent works: + +- **[Flash-Attention](https://github.com/Dao-AILab/flash-attention)** - Memory-efficient attention computation +- **[NVIDIA CUTLASS](https://github.com/NVIDIA/cutlass)** - High-performance matrix operations library + +We thank the open-source community for their contributions to efficient transformer implementations.