diff --git a/README.md b/README.md
index e418a47..114c5a0 100644
--- a/README.md
+++ b/README.md
@@ -2,25 +2,25 @@
 ![Flash-DMA Banner](assets/flash_dmattn_banner.jpg)
 
-Flash-DMA is a high-performance attention implementation that integrates Flash Attention's memory efficiency with Dynamic Mask Attention's computational efficiency for processing extremely long sequences in transformer models.
+Flash-DMA is a high-performance attention implementation that integrates Flash Attention's memory efficiency with Dynamic Mask Attention's sparse computation capabilities for processing extremely long sequences in transformer models.
 
 ## Key Features
 
-- **Sparse Attention Computation**: Dynamically selects the most important keys for each query, reducing computation from $O(N^2)$ to $O(N \cdot k)$ where $k \ll N$.
+- **Sparse Attention Computation**: Dynamically selects the most important keys for each query, reducing computation from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$.
 - **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without materializing the full attention matrix.
-- **CUDA-Accelerated**: Deep integration at the CUDA kernel level for maximum performance.
-- **Long Sequence Support**: Efficiently handles sequences of 128K+ tokens that would be impractical with standard attention.
-- **Backward Compatible**: API compatible with existing Flash Attention implementations.
+- **CUDA-Accelerated**: Deep integration at the CUDA kernel level with custom sparse GEMM operations for maximum performance.
+- **Long Sequence Support**: Efficiently handles sequences of 128K+ tokens through dynamic masking when sequence length exceeds `keep_window_size`.
+- **Advanced Integration**: Complete integration from Python frontend to CUDA backend with optimized memory layouts and sparse computation strategies.
 
 ## Installation
 
 ### Prerequisites
 
-- **Python**: 3.7 or later
-- **PyTorch**: 1.10.0 or later
-- **CUDA**: 11.0 or later (for GPU acceleration)
-- **NVIDIA GPU**: Compute Capability 6.0 or higher
-- **C++ Compiler**: GCC 7+ or compatible
+- **Python**: 3.8 or later
+- **PyTorch**: 2.0.0 or later
+- **CUDA**: 11.8 or later
+- **NVIDIA GPU**: Compute Capability 8.0 or higher
+- **C++ Compiler**: GCC 7+
 
 ### CUDA Environment Setup
 
@@ -43,44 +43,43 @@ git submodule update --init --recursive
 pip install .
 ```
 
-
-
 ## How It Works
 
@@ -89,7 +88,6 @@ Flash-DMA combines two complementary techniques:
 - **Dynamic Mask Attention**: Computes relevance scores for keys and selects only the most important ones for attention computation
 - **Flash Attention**: Processes attention in blocks to reduce memory usage and HBM access
 
-
 ### The Integration Approach
 
 The integration happens at the CUDA kernel level with several key components:
 
@@ -99,18 +97,15 @@ The integration happens at the CUDA kernel level with several key components:
 - **Sparse Matrix Multiplication**: Custom CUDA kernels for efficient sparse attention computation
 - **Block-Based Processing**: Maintains Flash Attention's block-based approach for memory efficiency
 
-This creates a hybrid attention mechanism that achieves both memory and computational efficiency.
+This creates a hybrid attention mechanism that achieves both memory and computational efficiency for long sequences.
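+
+To make the selection step concrete, the sketch below shows the top-`w` masking idea in plain PyTorch. It is illustrative only: it materializes the full score matrix that the fused CUDA kernel deliberately avoids, the function name and the default window size are placeholders, and raw $QK^T$ scores stand in for the relevance scores that Flash-DMA actually uses for selection.
+
+```python
+import torch
+
+def topw_attention_reference(q, k, v, keep_window_size=2048):
+    # q, k, v: (batch, heads, seq_len, head_dim).
+    # Conceptual reference only -- the fused kernel never builds this (N x N) matrix.
+    scale = q.shape[-1] ** -0.5
+    scores = torch.matmul(q, k.transpose(-2, -1)) * scale   # (B, H, N, N)
+    w = min(keep_window_size, k.shape[-2])
+    keep = scores.topk(w, dim=-1).indices                   # top-w keys per query
+    mask = torch.full_like(scores, float("-inf")).scatter(-1, keep, 0.0)
+    probs = torch.softmax(scores + mask, dim=-1)            # dropped keys get zero weight
+    return torch.matmul(probs, v)
+```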
 
-## Documentation
-For detailed technical documentation, see:
-- [Integration Guide](docs/integration.md) - Comprehensive technical details
-- [API Reference](#api-reference) - Function signatures and parameters
+## Documentation
 
-### API Reference
+📚 **Complete documentation is available in the [docs](docs/) directory:**
 
-> [!IMPORTANT]
-> TODO
+- **[API Reference](docs/api_reference.md)** - Complete function documentation and usage examples
+- **[Integration Guide](docs/integration.md)** - Detailed technical documentation of the Flash Attention integration
 
 ## Building from Source
 
@@ -124,36 +119,54 @@ cd flash-dmattn
 # Build in development mode
 pip install -e .
+
+# Verify that the CUDA extension imports correctly
+python -c "import flash_dma_cpp; print('✅ Flash DMA CUDA extension imported successfully')"
 ```
 
 ### Build Requirements
 
-- CUDA Toolkit 11.0+
-- CUTLASS library (included as submodule)
-- CUB library (included as submodule)
+- CUDA Toolkit 11.8+
+- CUTLASS library
+- PyTorch with CUDA support
 
 ### Supported Architectures
 
-- SM 6.0+ (Pascal, Volta, Turing, Ampere, Ada Lovelace)
-- Optimized for SM 8.0+ (Ampere and newer)
+- **SM 8.0**
+- **SM 9.0**
+- **SM 10.0**
+- **SM 12.0**
+
+**Note**: Flash Dynamic Mask Attention requires CUDA compute capability 8.0 or higher; earlier architectures are not supported.
+
+## Benchmarking
+
+Flash-DMA provides comprehensive benchmarking tools to evaluate performance across different configurations:
 
-## Testing
+### Forward Pass Equivalence
+```bash
+python benchmarks/benchmark_forward_equivalence.py
+```
+Validates numerical consistency between the Python reference and the CUDA implementation.
 
-### Run Tests
+### Performance Benchmarking
+```bash
+python benchmarks/benchmark_forward_performance.py
+```
+Compares Flash-DMA against standard Flash Attention across various sequence lengths and batch sizes.
 
+### Gradient Computation
 ```bash
-# Gradient equivalent benchmarks
 python benchmarks/benchmark_grad.py
 ```
+Tests the backward pass implementation and gradient equivalence.
 
-### Compatibility
+### Multi-Query Associative Recall
+```bash
+python benchmarks/benchmark_mqar.py
+```
+Evaluates performance on long-range reasoning tasks.
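+
+For quick one-off measurements outside the bundled scripts, a CUDA-event timer such as the sketch below is usually enough. `attn_fn` and the shapes in the commented example are placeholders; substitute whichever attention entry point you actually import (see the [API Reference](docs/api_reference.md) for the real signatures).
+
+```python
+import torch
+
+def time_op(fn, *args, warmup=10, iters=50):
+    """Average milliseconds per call, measured with CUDA events."""
+    for _ in range(warmup):
+        fn(*args)
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    torch.cuda.synchronize()
+    start.record()
+    for _ in range(iters):
+        fn(*args)
+    end.record()
+    torch.cuda.synchronize()
+    return start.elapsed_time(end) / iters
+
+# Example usage (attn_fn is hypothetical -- replace it with the function you call):
+# q = k = v = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.bfloat16)
+# print(f"{time_op(attn_fn, q, k, v):.3f} ms per forward pass")
+```
+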
-| Component | Supported Versions |
-|-----------|-------------------|
-| PyTorch | 1.10.0+ |
-| CUDA | 11.0+ |
-| Python | 3.7+ |
-| GPU Arch | SM 6.0+ |
 
 ## Troubleshooting
 
@@ -161,16 +174,48 @@ python benchmarks/benchmark_grad.py
 
 **Compilation Errors**
 ```bash
-# Ensure CUDA_HOME is set
-export CUDA_HOME=/usr/local/cuda
-# Update NVCC if needed
-which nvcc
+# Ensure CUDA_HOME is set correctly
+echo $CUDA_HOME       # Linux/Mac
+echo $env:CUDA_HOME   # Windows PowerShell
+
+# Check CUDA toolkit version
+nvcc --version
+
+# Verify PyTorch CUDA support
+python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
+```
+
+**Import Errors**
+```python
+# Test basic import
+try:
+    import flash_dma_cpp
+    print("✅ Flash DMA CUDA extension imported successfully")
+except ImportError as e:
+    print(f"❌ Import failed: {e}")
+    print("Please ensure the package is properly installed with: pip install -e .")
 ```
 
 **Performance Issues**
-- Ensure GPU has sufficient compute capability (6.0+)
-- Use appropriate data types (float16 recommended)
-- Verify CUDA kernels are being used (not CPU fallback)
+- Ensure your GPU has compute capability 8.0 or higher
+- Use `torch.bfloat16` for better numerical stability
+- Adjust `keep_window_size` based on available GPU memory
+- Verify that the CUDA kernels are being used
+
+**Memory Issues**
+```python
+# Monitor GPU memory usage
+print(torch.cuda.memory_summary())
+print(torch.cuda.max_memory_allocated())
+
+# Clear cache if needed
+torch.cuda.empty_cache()
+```
+
+**Numerical Issues**
+- Use `torch.bfloat16` instead of `torch.float16` for better stability
+- Check input tensors for NaN or infinite values
+- Validate that ZOH states and active mask values are in the expected ranges
 
 ## License
 
@@ -191,6 +236,9 @@ If you use Flash-DMA in your research, please cite:
 
 ## Acknowledgments
 
-This project builds upon the excellent work of:
-- [Flash-Attention](https://github.com/Dao-AILab/flash-attention) by Tri Dao et al.
-- [NVIDIA CUTLASS](https://github.com/NVIDIA/cutlass) library for efficient matrix operations
\ No newline at end of file
+This project builds upon and integrates several excellent works:
+
+- **[Flash-Attention](https://github.com/Dao-AILab/flash-attention)** - Memory-efficient attention computation
+- **[NVIDIA CUTLASS](https://github.com/NVIDIA/cutlass)** - High-performance matrix operations library
+
+We thank the open-source community for their contributions to efficient transformer implementations.