
![Flash-DMA Banner](assets/flash_dmattn_banner.jpg)

Flash-DMA is a high-performance attention implementation that integrates Flash Attention's memory efficiency with Dynamic Mask Attention's sparse computation capabilities for processing extremely long sequences in transformer models.

## Key Features

- **Sparse Attention Computation**: Dynamically selects the most important keys for each query, reducing computation from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$ (see the sketch after this list).
- **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without materializing the full attention matrix.
- **CUDA-Accelerated**: Deep integration at the CUDA kernel level with custom sparse GEMM operations for maximum performance.
- **Long Sequence Support**: Efficiently handles sequences of 128K+ tokens through dynamic masking when sequence length exceeds `keep_window_size`.
- **Advanced Integration**: Complete integration from Python frontend to CUDA backend with optimized memory layouts and sparse computation strategies.
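
The sparse selection behind the first bullet, as a minimal dense PyTorch sketch (illustration only, not the library API; `scores` stands in for the per-pair relevance values, which in Flash-DMA come from the ZOH states). It keeps the top-`w` keys per query; the CUDA kernels realize the $O(N \cdot w)$ savings by skipping the masked work instead of materializing full score matrices as this reference does.

```python
import torch

def topw_reference(q, k, v, scores, w):
    # q, k, v: [batch, heads, seq_len, head_dim]; scores: [batch, heads, seq_len, seq_len]
    # Keep the w highest-scoring keys for each query and mask out everything else.
    idx = scores.topk(w, dim=-1).indices
    mask = torch.full_like(scores, float("-inf"))
    mask.scatter_(-1, idx, 0.0)  # 0 for kept keys, -inf for dropped keys
    attn = torch.softmax(q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5 + mask, dim=-1)
    return attn @ v
```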

## Installation

### Prerequisites

- **Python**: 3.8 or later
- **PyTorch**: 2.0.0 or later
- **CUDA**: 11.8 or later
- **NVIDIA GPU**: Compute Capability 8.0 or higher
- **C++ Compiler**: GCC 7+

### CUDA Environment Setup

```bash
git submodule update --init --recursive
pip install .
```


## Quick Start

```python
import torch
import flash_dma_cuda
import math

# Setup
batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128
keep_window_size = 2048  # number of keys kept per query
device = torch.device('cuda')
dtype = torch.bfloat16

# Input tensors
query = torch.randn(batch_size, seq_len, num_heads, head_dim,
                    device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_heads, head_dim,
                  device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_heads, head_dim,
                    device=device, dtype=dtype)
zoh_states = torch.randn(batch_size, num_heads, seq_len, seq_len,
                         device=device, dtype=dtype)
active_mask = torch.ones(batch_size, num_heads, seq_len, seq_len,
                         device=device, dtype=dtype)

# Run Flash-DMA
output = flash_dma_cuda.fwd(
    q=query, k=key, v=value,
    zoh=zoh_states, active_mask=active_mask,
    softmax_scale=1.0 / math.sqrt(head_dim),
    keep_window_size=keep_window_size,
    is_causal=True
)[0]

print(f"Output shape: {output.shape}")  # [2, 4096, 12, 128]
```

## How It Works

Flash-DMA combines two complementary techniques:
- **Dynamic Mask Attention**: Computes relevance scores for keys and selects only the most important ones for attention computation
- **Flash Attention**: Processes attention in blocks to reduce memory usage and HBM access


### The Integration Approach

The integration happens at the CUDA kernel level with several key components:
- **Sparse Matrix Multiplication**: Custom CUDA kernels for efficient sparse attention computation
- **Block-Based Processing**: Maintains Flash Attention's block-based approach for memory efficiency

This creates a hybrid attention mechanism that achieves both memory and computational efficiency for long sequences.
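
To make the hybrid concrete, here is a dense PyTorch sketch (assumptions: float32 tensors, a boolean `keep_mask`, and every query keeping at least one key) of block-based processing with an online softmax that skips key blocks the dynamic mask has disabled entirely. The actual CUDA kernels implement this with tiled shared-memory GEMMs rather than dense tensors.

```python
import torch

def blockwise_masked_attention(q, k, v, keep_mask, block_size=128):
    # q, k, v: [batch, heads, seq_len, head_dim]; keep_mask: [batch, heads, seq_len, seq_len] (bool)
    b, h, n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = q.new_full((b, h, n, 1), float("-inf"))   # running softmax max
    row_sum = q.new_zeros((b, h, n, 1))                 # running softmax denominator
    for start in range(0, n, block_size):
        end = min(start + block_size, n)
        m = keep_mask[..., start:end]
        if not m.any():                                  # whole key block masked out: skip it
            continue
        s = (q @ k[..., start:end, :].transpose(-2, -1)) * scale
        s = s.masked_fill(~m, -1e9)                      # large negative instead of -inf keeps the math NaN-free
        new_max = torch.maximum(row_max, s.amax(dim=-1, keepdim=True))
        correction = torch.exp(row_max - new_max)        # rescale earlier partial sums (online softmax)
        p = torch.exp(s - new_max)
        out = out * correction + p @ v[..., start:end, :]
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        row_max = new_max
    return out / row_sum.clamp_min(1e-20)
```

Up to numerical tolerance this matches dense softmax attention restricted to the kept keys, while never touching key blocks that the mask disables.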

## Documentation

📚 **Complete documentation is available in the [docs](docs/) directory:**

- **[API Reference](docs/api_reference.md)** - Complete function documentation and usage examples
- **[Integration Guide](docs/integration.md)** - Detailed technical documentation of the Flash Attention integration


## Building from Source
```bash
cd flash-dmattn

# Build in development mode
pip install -e .
```

### Build Requirements

- CUDA Toolkit 11.8+
- CUTLASS library
- PyTorch with CUDA support

### Supported Architectures

- **SM 8.0**
- **SM 9.0**
- **SM 10.0**
- **SM 12.0**

**Note**: Flash Dynamic Mask Attention requires CUDA compute capability 8.0+ for optimal performance. Earlier architectures are not supported.
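
A quick sanity check before building (assumes a CUDA-enabled PyTorch install):

```python
import torch

# Verify the GPU meets the SM 8.0+ requirement
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")
assert (major, minor) >= (8, 0), "Flash Dynamic Mask Attention requires compute capability 8.0+"
```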

## Benchmarking

Flash-DMA provides comprehensive benchmarking tools to evaluate performance across different configurations:

### Forward Pass Equivalence
```bash
python benchmarks/benchmark_forward_equivalence.py
```
Validates numerical consistency between Python reference and CUDA implementation.
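
The kind of check such a script performs looks roughly like this (names are illustrative, not the script's actual API):

```python
import torch

def report_equivalence(ref_out: torch.Tensor, cuda_out: torch.Tensor) -> None:
    # Compare a Python reference output against the CUDA kernel output.
    diff = (ref_out.float() - cuda_out.float()).abs()
    print(f"max abs diff:  {diff.max().item():.3e}")
    print(f"mean abs diff: {diff.mean().item():.3e}")
    print("allclose:", torch.allclose(ref_out.float(), cuda_out.float(), atol=1e-2, rtol=1e-3))
```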

### Performance Benchmarking
```bash
python benchmarks/benchmark_forward_performance.py
```
Compares Flash-DMA against standard Flash Attention across various sequence lengths and batch sizes.

### Gradient Computation
```bash
python benchmarks/benchmark_grad.py
```
Tests backward pass implementation and gradient equivalence.
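
A sketch of what a gradient-equivalence check can look like (`attn_ref` and `attn_cuda` are placeholder callables, not the benchmark's actual API):

```python
import torch

def compare_grads(attn_ref, attn_cuda, q, k, v):
    # Run both implementations, backpropagate a scalar loss, and compare input gradients.
    q1, k1, v1 = (t.clone().requires_grad_(True) for t in (q, k, v))
    q2, k2, v2 = (t.clone().requires_grad_(True) for t in (q, k, v))
    attn_ref(q1, k1, v1).sum().backward()
    attn_cuda(q2, k2, v2).sum().backward()
    for name, a, b in [("dq", q1.grad, q2.grad), ("dk", k1.grad, k2.grad), ("dv", v1.grad, v2.grad)]:
        print(f"{name} max abs diff: {(a - b).abs().max().item():.3e}")
```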

### Multi-Query Associative Recall
```bash
python benchmarks/benchmark_mqar.py
```
Evaluates performance on long-range reasoning tasks.


## Troubleshooting

### Common Issues

**Compilation Errors**
```bash
# Ensure CUDA_HOME is set correctly
echo $CUDA_HOME # Linux/Mac
echo $env:CUDA_HOME # Windows PowerShell

# Check CUDA toolkit version
nvcc --version

# Verify PyTorch CUDA support
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
```

**Import Errors**
```python
# Test basic import
try:
    import flash_dma_cpp
    print("✅ Flash DMA CUDA extension imported successfully")
except ImportError as e:
    print(f"❌ Import failed: {e}")
    print("Please ensure the package is properly installed with: pip install -e .")
```

**Performance Issues**
- Ensure GPU has compute capability 8.0+ for optimal performance
- Use `torch.bfloat16` for better numerical stability
- Adjust `keep_window_size` based on available GPU memory
- Verify CUDA kernels are being used (see the timing sketch below)
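
One way to time the GPU work itself rather than Python overhead, assuming a CUDA setup (the attention call is a placeholder):

```python
import torch

start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
# output = flash_dma_cuda.fwd(...)   # your attention call goes here
end.record()
torch.cuda.synchronize()
print(f"Elapsed: {start.elapsed_time(end):.2f} ms")
```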

**Memory Issues**
```python
# Monitor GPU memory usage
print(torch.cuda.memory_summary())
print(f"Peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

# Clear cache if needed
torch.cuda.empty_cache()
```

**Numerical Issues**
- Use `torch.bfloat16` instead of `torch.float16` for better stability
- Check input tensor ranges for NaN or infinite values (see the helper below)
- Validate ZOH states and active mask values are in expected ranges
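
A small helper for the range/NaN checks above (usage names are illustrative):

```python
import torch

def check_finite(name: str, t: torch.Tensor) -> None:
    # Report value range and flag NaN/Inf entries in an input tensor.
    has_nan = torch.isnan(t).any().item()
    has_inf = torch.isinf(t).any().item()
    print(f"{name}: min={t.min().item():.4g} max={t.max().item():.4g} NaN={has_nan} Inf={has_inf}")

# e.g. check_finite("zoh_states", zoh_states); check_finite("active_mask", active_mask)
```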

## License


## Acknowledgments

This project builds upon and integrates several excellent works:

- **[Flash-Attention](https://github.com/Dao-AILab/flash-attention)** - Memory-efficient attention computation
- **[NVIDIA CUTLASS](https://github.com/NVIDIA/cutlass)** - High-performance matrix operations library

We thank the open-source community for their contributions to efficient transformer implementations.