diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 46b7d9f..bcc17cf 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -1,6 +1,6 @@ --- name: Bug report -about: Create a report to help us improve Flash-DMA +about: Create a report to help us improve FSA title: '[BUG REPORT] ' labels: ["bug"] assignees: @@ -39,7 +39,7 @@ python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: { **Additional context** - OS: [e.g. Ubuntu 20.04, Windows 10, macOS 12] - Python version: [e.g. 3.9.7] -- Flash-DMA version: [e.g. 0.1.0] +- FSA version: [e.g. 0.1.0] - CUDA Compute Capability: [e.g. 8.6] **Error traceback** diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index dfb007e..65c74de 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,5 +1,5 @@ name: Bug report -description: Create a report to help us improve Flash-DMA +description: Create a report to help us improve FSA title: "[BUG REPORT] " labels: - bug diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 2816d8f..1db7b8e 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -1,6 +1,6 @@ --- name: Feature request -about: Suggest an idea for Flash-DMA +about: Suggest an idea for FSA title: '[FEATURE REQUEST] ' labels: ["feature"] assignees: @@ -44,4 +44,4 @@ Add any other context or screenshots about the feature request here. If this feature is inspired by a paper or existing implementation, please provide: - Link to paper/implementation - Brief explanation of the technique -- Why it would be valuable for Flash-DMA users +- Why it would be valuable for FSA users diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 46a7d39..8ac591e 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -1,5 +1,5 @@ name: Feature request -description: Suggest an idea for FDMA +description: Suggest an idea for FSA title: "[FEATURE REQUEST] " labels: - feature @@ -16,7 +16,7 @@ body: - type: markdown attributes: value: | - Help us understand the feature you are proposing and why it matters for Flash-DMA workflows. + Help us understand the feature you are proposing and why it matters for FSA workflows. 
- type: textarea id: problem attributes: diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml index 8ad6f3f..80e09a9 100644 --- a/.github/workflows/_build.yml +++ b/.github/workflows/_build.yml @@ -172,12 +172,12 @@ jobs: export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2) export NVCC_THREADS=2 - export FLASH_DMATTN_FORCE_BUILD="TRUE" - export FLASH_DMATTN_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }} + export FLASH_SPARSE_ATTENTION_FORCE_BUILD="TRUE" + export FLASH_SPARSE_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }} # If specified, limit to a single compute capability to speed up build if [ -n "${MATRIX_ARCH}" ]; then - export FLASH_DMATTN_CUDA_ARCHS="${MATRIX_ARCH}" + export FLASH_SPARSE_ATTENTION_CUDA_ARCHS="${MATRIX_ARCH}" fi # GH allows max 6h diff --git a/.github/workflows/manual_publish.yml b/.github/workflows/manual_publish.yml index c1dae8e..258bfae 100644 --- a/.github/workflows/manual_publish.yml +++ b/.github/workflows/manual_publish.yml @@ -38,7 +38,7 @@ jobs: - name: Build core package env: - FLASH_DMATTN_SKIP_CUDA_BUILD: "TRUE" + FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD: "TRUE" run: | python setup.py sdist --dist-dir=dist ls -l dist diff --git a/CITATION.cff b/CITATION.cff index d8f3d0e..4aaeee9 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -1,8 +1,8 @@ cff-version: "1.2.0" date-released: 2025-06 message: "If you use this software, please cite it using these metadata." -title: "Flash Dynamic Mask Attention: Trainable Dynamic Mask Sparse Attention" -url: "https://github.com/SmallDoges/flash-dmattn" +title: "Flash Sparse Attention: Trainable Dynamic Mask Sparse Attention" +url: "https://github.com/SmallDoges/flash-sparse-attention" authors: - family-names: Shi given-names: Jingze diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d0368d1..ba79358 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,7 +4,7 @@ Everyone is welcome to contribute, and we value everybody's contribution. Code c It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you. -However you choose to contribute, please be mindful and respect our [code of conduct](https://github.com/SmallDoges/flash-dmattn/blob/main/CODE_OF_CONDUCT.md). +However you choose to contribute, please be mindful and respect our [code of conduct](https://github.com/SmallDoges/flash-sparse-attention/blob/main/CODE_OF_CONDUCT.md). ## Ways to contribute @@ -16,7 +16,7 @@ There are several ways you can contribute to Flash-DMA: * Contribute to the examples, benchmarks, or documentation. * Improve CUDA kernel performance. -If you don't know where to start, there is a special [Good First Issue](https://github.com/SmallDoges/flash-dmattn/contribute) listing. It will give you a list of open issues that are beginner-friendly and help you start contributing to open-source. +If you don't know where to start, there is a special [Good First Issue](https://github.com/SmallDoges/flash-sparse-attention/contribute) listing. It will give you a list of open issues that are beginner-friendly and help you start contributing to open-source. > All contributions are equally valuable to the community. 🥰 @@ -81,14 +81,14 @@ You will need basic `git` proficiency to contribute to Flash-DMA. You'll need ** ### Development Setup -1. Fork the [repository](https://github.com/SmallDoges/flash-dmattn) by clicking on the **Fork** button. +1. 
Fork the [repository](https://github.com/SmallDoges/flash-sparse-attention) by clicking on the **Fork** button. 2. Clone your fork to your local disk, and add the base repository as a remote: ```bash - git clone https://github.com//flash-dmattn.git - cd flash-dmattn - git remote add upstream https://github.com/SmallDoges/flash-dmattn.git + git clone https://github.com//flash-sparse-attention.git + cd flash-sparse-attention + git remote add upstream https://github.com/SmallDoges/flash-sparse-attention.git ``` 3. Create a new branch to hold your development changes: @@ -157,7 +157,7 @@ You will need basic `git` proficiency to contribute to Flash-DMA. You'll need ** ### Tests -An extensive test suite is included to test the library behavior and performance. Tests can be found in the [tests](https://github.com/SmallDoges/flash-dmattn/tree/main/tests) folder and benchmarks in the [benchmarks](https://github.com/SmallDoges/flash-dmattn/tree/main/benchmarks) folder. +An extensive test suite is included to test the library behavior and performance. Tests can be found in the [tests](https://github.com/SmallDoges/flash-sparse-attention/tree/main/tests) folder and benchmarks in the [benchmarks](https://github.com/SmallDoges/flash-sparse-attention/tree/main/benchmarks) folder. We use `pytest` for testing. From the root of the repository, run: @@ -200,6 +200,6 @@ If you discover a security vulnerability, please send an e-mail to the maintaine ## Questions? -If you have questions about contributing, feel free to ask in the [GitHub Discussions](https://github.com/SmallDoges/flash-dmattn/discussions) or open an issue. +If you have questions about contributing, feel free to ask in the [GitHub Discussions](https://github.com/SmallDoges/flash-sparse-attention/discussions) or open an issue. -Thank you for contributing to Flash Dynamic Mask Attention! 🚀 +Thank you for contributing to Flash Sparse Attention! 🚀 diff --git a/README.md b/README.md index 736c184..03a817e 100644 --- a/README.md +++ b/README.md @@ -45,95 +45,6 @@ Thus, a more effective approach is sparse attention: interacting each query with - Further performance improvements for skipping memory access and computation -## Performance - -We present the expected speedup of FSA over standard PyTorch SDPA under mask and bias conditions. - -![FSA Performance Overview](assets/performance_overview.png) - ---- - -### Forward Pass Performance - -The following table shows the forward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs. 
- -| Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup | -|--------|-------|--------|----------|-----------|-----------|---------| -| Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x | -| Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x | -| Train | 1024 | 1024 | 1024 | 0.51 | 0.18 | 2.81x | -| Train | 2048 | 2048 | 1024 | 1.04 | 0.18 | 5.68x | -| Train | 4096 | 4096 | 1024 | 2.53 | 0.24 | 10.41x | -| Train | 8192 | 8192 | 1024 | 9.38 | 0.36 | 25.93x | -| Train | 16384 | 16384 | 1024 | 28.39 | 0.81 | 35.25x | -| Train | 32768 | 32768 | 1024 | 111.87 | 2.25 | 49.78x | -| Train | 32768 | 32768 | 32 | 113.19 | 2.10 | 53.97x | -| Train | 32768 | 32768 | 64 | 113.17 | 2.12 | 53.32x | -| Train | 32768 | 32768 | 128 | 113.14 | 2.10 | 53.78x | -| Train | 32768 | 32768 | 256 | 113.18 | 2.13 | 53.18x | -| Train | 32768 | 32768 | 512 | 113.19 | 2.17 | 52.17x | -| Train | 32768 | 32768 | 1024 | 113.19 | 2.24 | 50.45x | -| Train | 32768 | 32768 | 2048 | 113.15 | 2.39 | 47.35x | -| Train | 32768 | 32768 | 4096 | 113.16 | 2.67 | 42.39x | -| Train | 32768 | 32768 | 8192 | 113.11 | 3.20 | 35.29x | -| Train | 32768 | 32768 | 16384 | 113.15 | 3.97 | 28.51x | -| Train | 32768 | 32768 | 32768 | 113.11 | 4.90 | 23.10x | -| Infer | 1 | 256 | 1024 | 0.25 | 0.19 | 1.28x | -| Infer | 1 | 512 | 1024 | 0.25 | 0.19 | 1.27x | -| Infer | 1 | 1024 | 1024 | 0.25 | 0.20 | 1.28x | -| Infer | 1 | 2048 | 1024 | 0.25 | 0.20 | 1.24x | -| Infer | 1 | 4096 | 1024 | 0.25 | 0.19 | 1.29x | -| Infer | 1 | 8192 | 1024 | 0.25 | 0.20 | 1.25x | -| Infer | 1 | 16384 | 1024 | 0.25 | 0.19 | 1.29x | -| Infer | 1 | 32768 | 1024 | 0.27 | 0.20 | 1.33x | -| Infer | 1 | 65536 | 1024 | 0.42 | 0.20 | 2.10x | -| Infer | 1 | 131072 | 1024 | 0.72 | 0.20 | 3.65x | -| Infer | 1 | 262144 | 1024 | 1.31 | 0.22 | 6.06x | -| Infer | 1 | 524288 | 1024 | 2.49 | 0.24 | 10.45x | -| Infer | 1 | 524288 | 32 | 2.48 | 0.21 | 11.60x | -| Infer | 1 | 524288 | 64 | 2.44 | 0.21 | 11.66x | -| Infer | 1 | 524288 | 128 | 2.45 | 0.21 | 11.47x | -| Infer | 1 | 524288 | 256 | 2.43 | 0.21 | 11.47x | -| Infer | 1 | 524288 | 512 | 2.44 | 0.22 | 10.89x | -| Infer | 1 | 524288 | 1024 | 2.44 | 0.24 | 10.31x | -| Infer | 1 | 524288 | 2048 | 2.44 | 0.27 | 9.07x | -| Infer | 1 | 524288 | 4096 | 2.45 | 0.33 | 7.41x | -| Infer | 1 | 524288 | 8192 | 2.44 | 0.35 | 6.93x | -| Infer | 1 | 524288 | 16384 | 2.44 | 0.35 | 6.93x | -| Infer | 1 | 524288 | 32768 | 2.45 | 0.35 | 6.96x | -| Infer | 1 | 524288 | 65536 | 2.44 | 0.35 | 6.88x | - ---- - -### Backward Pass Performance - -The following table shows the backward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs. 
- -| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup | -|-------|-------|--------|----------|---------------|---------------|---------| -| Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x | -| Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x | -| Train | 1024 | 1024 | 1024 | 0.94 | 0.61 | 1.5x | -| Train | 2048 | 2048 | 1024 | 1.79 | 0.69 | 2.6x | -| Train | 4096 | 4096 | 1024 | 3.76 | 1.08 | 3.5x | -| Train | 8192 | 8192 | 1024 | 14.39 | 2.06 | 7.0x | -| Train | 16384 | 16384 | 1024 | 39.56 | 4.97 | 8.0x | -| Train | 32768 | 32768 | 1024 | 142.07 | 25.63 | 5.5x | -| Train | 32768 | 32768 | 32 | 142.70 | 21.91 | 6.5x | -| Train | 32768 | 32768 | 64 | 142.65 | 22.29 | 6.4x | -| Train | 32768 | 32768 | 128 | 142.69 | 23.04 | 6.2x | -| Train | 32768 | 32768 | 256 | 142.69 | 24.27 | 5.9x | -| Train | 32768 | 32768 | 512 | 142.67 | 25.12 | 5.7x | -| Train | 32768 | 32768 | 1024 | 142.55 | 25.58 | 5.6x | -| Train | 32768 | 32768 | 2048 | 142.75 | 25.64 | 5.6x | -| Train | 32768 | 32768 | 4096 | 142.61 | 24.84 | 5.7x | -| Train | 32768 | 32768 | 8192 | 142.33 | 25.63 | 5.6x | -| Train | 32768 | 32768 | 16384 | 142.40 | 25.62 | 5.6x | -| Train | 32768 | 32768 | 32768 | 142.43 | 25.63 | 5.6x | - ---- - - ## Installation ### Requirements @@ -150,14 +61,14 @@ The following table shows the backward pass performance comparison between FSA a You can install FSA via pre-compiled wheels: ```bash -pip install flash_sparse_attn --no-build-isolation +pip install flash-sparse-attn --no-build-isolation ``` Alternatively, you can compile and install from source: ```bash -git clone https://github.com/SmallDoges/flash_sparse_attn.git -cd flash_sparse_attn +git clone https://github.com/SmallDoges/flash-sparse-attn.git +cd flash-sparse-attn pip install . --no-build-isolation ``` @@ -245,6 +156,95 @@ print(f"Bias gradient shape: {attn_bias.grad.shape}") ``` +## Performance + +We present the expected speedup of FSA over standard PyTorch SDPA under mask and bias conditions. + +![FSA Performance Overview](assets/performance_overview.png) + +--- + +### Forward Pass Performance + +The following table shows the forward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs. 
+ +| Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup | +|--------|-------|--------|----------|-----------|-----------|---------| +| Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x | +| Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x | +| Train | 1024 | 1024 | 1024 | 0.51 | 0.18 | 2.81x | +| Train | 2048 | 2048 | 1024 | 1.04 | 0.18 | 5.68x | +| Train | 4096 | 4096 | 1024 | 2.53 | 0.24 | 10.41x | +| Train | 8192 | 8192 | 1024 | 9.38 | 0.36 | 25.93x | +| Train | 16384 | 16384 | 1024 | 28.39 | 0.81 | 35.25x | +| Train | 32768 | 32768 | 1024 | 111.87 | 2.25 | 49.78x | +| Train | 32768 | 32768 | 32 | 113.19 | 2.10 | 53.97x | +| Train | 32768 | 32768 | 64 | 113.17 | 2.12 | 53.32x | +| Train | 32768 | 32768 | 128 | 113.14 | 2.10 | 53.78x | +| Train | 32768 | 32768 | 256 | 113.18 | 2.13 | 53.18x | +| Train | 32768 | 32768 | 512 | 113.19 | 2.17 | 52.17x | +| Train | 32768 | 32768 | 1024 | 113.19 | 2.24 | 50.45x | +| Train | 32768 | 32768 | 2048 | 113.15 | 2.39 | 47.35x | +| Train | 32768 | 32768 | 4096 | 113.16 | 2.67 | 42.39x | +| Train | 32768 | 32768 | 8192 | 113.11 | 3.20 | 35.29x | +| Train | 32768 | 32768 | 16384 | 113.15 | 3.97 | 28.51x | +| Train | 32768 | 32768 | 32768 | 113.11 | 4.90 | 23.10x | +| Infer | 1 | 256 | 1024 | 0.25 | 0.19 | 1.28x | +| Infer | 1 | 512 | 1024 | 0.25 | 0.19 | 1.27x | +| Infer | 1 | 1024 | 1024 | 0.25 | 0.20 | 1.28x | +| Infer | 1 | 2048 | 1024 | 0.25 | 0.20 | 1.24x | +| Infer | 1 | 4096 | 1024 | 0.25 | 0.19 | 1.29x | +| Infer | 1 | 8192 | 1024 | 0.25 | 0.20 | 1.25x | +| Infer | 1 | 16384 | 1024 | 0.25 | 0.19 | 1.29x | +| Infer | 1 | 32768 | 1024 | 0.27 | 0.20 | 1.33x | +| Infer | 1 | 65536 | 1024 | 0.42 | 0.20 | 2.10x | +| Infer | 1 | 131072 | 1024 | 0.72 | 0.20 | 3.65x | +| Infer | 1 | 262144 | 1024 | 1.31 | 0.22 | 6.06x | +| Infer | 1 | 524288 | 1024 | 2.49 | 0.24 | 10.45x | +| Infer | 1 | 524288 | 32 | 2.48 | 0.21 | 11.60x | +| Infer | 1 | 524288 | 64 | 2.44 | 0.21 | 11.66x | +| Infer | 1 | 524288 | 128 | 2.45 | 0.21 | 11.47x | +| Infer | 1 | 524288 | 256 | 2.43 | 0.21 | 11.47x | +| Infer | 1 | 524288 | 512 | 2.44 | 0.22 | 10.89x | +| Infer | 1 | 524288 | 1024 | 2.44 | 0.24 | 10.31x | +| Infer | 1 | 524288 | 2048 | 2.44 | 0.27 | 9.07x | +| Infer | 1 | 524288 | 4096 | 2.45 | 0.33 | 7.41x | +| Infer | 1 | 524288 | 8192 | 2.44 | 0.35 | 6.93x | +| Infer | 1 | 524288 | 16384 | 2.44 | 0.35 | 6.93x | +| Infer | 1 | 524288 | 32768 | 2.45 | 0.35 | 6.96x | +| Infer | 1 | 524288 | 65536 | 2.44 | 0.35 | 6.88x | + +--- + +### Backward Pass Performance + +The following table shows the backward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs. 
+ +| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup | +|-------|-------|--------|----------|---------------|---------------|---------| +| Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x | +| Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x | +| Train | 1024 | 1024 | 1024 | 0.94 | 0.61 | 1.5x | +| Train | 2048 | 2048 | 1024 | 1.79 | 0.69 | 2.6x | +| Train | 4096 | 4096 | 1024 | 3.76 | 1.08 | 3.5x | +| Train | 8192 | 8192 | 1024 | 14.39 | 2.06 | 7.0x | +| Train | 16384 | 16384 | 1024 | 39.56 | 4.97 | 8.0x | +| Train | 32768 | 32768 | 1024 | 142.07 | 25.63 | 5.5x | +| Train | 32768 | 32768 | 32 | 142.70 | 21.91 | 6.5x | +| Train | 32768 | 32768 | 64 | 142.65 | 22.29 | 6.4x | +| Train | 32768 | 32768 | 128 | 142.69 | 23.04 | 6.2x | +| Train | 32768 | 32768 | 256 | 142.69 | 24.27 | 5.9x | +| Train | 32768 | 32768 | 512 | 142.67 | 25.12 | 5.7x | +| Train | 32768 | 32768 | 1024 | 142.55 | 25.58 | 5.6x | +| Train | 32768 | 32768 | 2048 | 142.75 | 25.64 | 5.6x | +| Train | 32768 | 32768 | 4096 | 142.61 | 24.84 | 5.7x | +| Train | 32768 | 32768 | 8192 | 142.33 | 25.63 | 5.6x | +| Train | 32768 | 32768 | 16384 | 142.40 | 25.62 | 5.6x | +| Train | 32768 | 32768 | 32768 | 142.43 | 25.63 | 5.6x | + +--- + + ## Benchmarking FSA provides comprehensive benchmarking tools to evaluate performance across different configurations: diff --git a/README_zh.md b/README_zh.md index 149b108..8bea29c 100644 --- a/README_zh.md +++ b/README_zh.md @@ -45,95 +45,6 @@ Flash-Sparse-Attention 是一个高性能的可训练稀疏注意力实现, 将 - 进一步提升跳过访存与计算的性能 -## 性能 - -我们展示了带有mask与bias条件下 FSA 相对于标准 PyTorch SDPA 的预期加速效果. - -![FSA Performance Overview](assets/performance_overview.png) - ---- - -### 前向传播性能 - -以下表格是我们在NVIDIA A100-SXM4-80GB上对FSA与标准PyTorch SDPA在不同配置下的前向性能对比测试结果. 结果为预热两次, 运行三次的平均值. 
- -| Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup | -|--------|-------|--------|----------|-----------|-----------|---------| -| Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x | -| Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x | -| Train | 1024 | 1024 | 1024 | 0.51 | 0.18 | 2.81x | -| Train | 2048 | 2048 | 1024 | 1.04 | 0.18 | 5.68x | -| Train | 4096 | 4096 | 1024 | 2.53 | 0.24 | 10.41x | -| Train | 8192 | 8192 | 1024 | 9.38 | 0.36 | 25.93x | -| Train | 16384 | 16384 | 1024 | 28.39 | 0.81 | 35.25x | -| Train | 32768 | 32768 | 1024 | 111.87 | 2.25 | 49.78x | -| Train | 32768 | 32768 | 32 | 113.19 | 2.10 | 53.97x | -| Train | 32768 | 32768 | 64 | 113.17 | 2.12 | 53.32x | -| Train | 32768 | 32768 | 128 | 113.14 | 2.10 | 53.78x | -| Train | 32768 | 32768 | 256 | 113.18 | 2.13 | 53.18x | -| Train | 32768 | 32768 | 512 | 113.19 | 2.17 | 52.17x | -| Train | 32768 | 32768 | 1024 | 113.19 | 2.24 | 50.45x | -| Train | 32768 | 32768 | 2048 | 113.15 | 2.39 | 47.35x | -| Train | 32768 | 32768 | 4096 | 113.16 | 2.67 | 42.39x | -| Train | 32768 | 32768 | 8192 | 113.11 | 3.20 | 35.29x | -| Train | 32768 | 32768 | 16384 | 113.15 | 3.97 | 28.51x | -| Train | 32768 | 32768 | 32768 | 113.11 | 4.90 | 23.10x | -| Infer | 1 | 256 | 1024 | 0.25 | 0.19 | 1.28x | -| Infer | 1 | 512 | 1024 | 0.25 | 0.19 | 1.27x | -| Infer | 1 | 1024 | 1024 | 0.25 | 0.20 | 1.28x | -| Infer | 1 | 2048 | 1024 | 0.25 | 0.20 | 1.24x | -| Infer | 1 | 4096 | 1024 | 0.25 | 0.19 | 1.29x | -| Infer | 1 | 8192 | 1024 | 0.25 | 0.20 | 1.25x | -| Infer | 1 | 16384 | 1024 | 0.25 | 0.19 | 1.29x | -| Infer | 1 | 32768 | 1024 | 0.27 | 0.20 | 1.33x | -| Infer | 1 | 65536 | 1024 | 0.42 | 0.20 | 2.10x | -| Infer | 1 | 131072 | 1024 | 0.72 | 0.20 | 3.65x | -| Infer | 1 | 262144 | 1024 | 1.31 | 0.22 | 6.06x | -| Infer | 1 | 524288 | 1024 | 2.49 | 0.24 | 10.45x | -| Infer | 1 | 524288 | 32 | 2.48 | 0.21 | 11.60x | -| Infer | 1 | 524288 | 64 | 2.44 | 0.21 | 11.66x | -| Infer | 1 | 524288 | 128 | 2.45 | 0.21 | 11.47x | -| Infer | 1 | 524288 | 256 | 2.43 | 0.21 | 11.47x | -| Infer | 1 | 524288 | 512 | 2.44 | 0.22 | 10.89x | -| Infer | 1 | 524288 | 1024 | 2.44 | 0.24 | 10.31x | -| Infer | 1 | 524288 | 2048 | 2.44 | 0.27 | 9.07x | -| Infer | 1 | 524288 | 4096 | 2.45 | 0.33 | 7.41x | -| Infer | 1 | 524288 | 8192 | 2.44 | 0.35 | 6.93x | -| Infer | 1 | 524288 | 16384 | 2.44 | 0.35 | 6.93x | -| Infer | 1 | 524288 | 32768 | 2.45 | 0.35 | 6.96x | -| Infer | 1 | 524288 | 65536 | 2.44 | 0.35 | 6.88x | - ---- - -### 反向传播性能 - -以下表格是我们在NVIDIA A100-SXM4-80GB上对FSA与标准PyTorch SDPA在不同配置下的反向性能对比测试结果. 结果为预热两次, 运行三次的平均值. 
- -| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup | -|-------|-------|--------|----------|---------------|---------------|---------| -| Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x | -| Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x | -| Train | 1024 | 1024 | 1024 | 0.94 | 0.61 | 1.5x | -| Train | 2048 | 2048 | 1024 | 1.79 | 0.69 | 2.6x | -| Train | 4096 | 4096 | 1024 | 3.76 | 1.08 | 3.5x | -| Train | 8192 | 8192 | 1024 | 14.39 | 2.06 | 7.0x | -| Train | 16384 | 16384 | 1024 | 39.56 | 4.97 | 8.0x | -| Train | 32768 | 32768 | 1024 | 142.07 | 25.63 | 5.5x | -| Train | 32768 | 32768 | 32 | 142.70 | 21.91 | 6.5x | -| Train | 32768 | 32768 | 64 | 142.65 | 22.29 | 6.4x | -| Train | 32768 | 32768 | 128 | 142.69 | 23.04 | 6.2x | -| Train | 32768 | 32768 | 256 | 142.69 | 24.27 | 5.9x | -| Train | 32768 | 32768 | 512 | 142.67 | 25.12 | 5.7x | -| Train | 32768 | 32768 | 1024 | 142.55 | 25.58 | 5.6x | -| Train | 32768 | 32768 | 2048 | 142.75 | 25.64 | 5.6x | -| Train | 32768 | 32768 | 4096 | 142.61 | 24.84 | 5.7x | -| Train | 32768 | 32768 | 8192 | 142.33 | 25.63 | 5.6x | -| Train | 32768 | 32768 | 16384 | 142.40 | 25.62 | 5.6x | -| Train | 32768 | 32768 | 32768 | 142.43 | 25.63 | 5.6x | - ---- - - ## 安装 ### 依赖 @@ -150,14 +61,14 @@ Flash-Sparse-Attention 是一个高性能的可训练稀疏注意力实现, 将 您可以通过预编译的轮子安装 FSA: ```bash -pip install flash_sparse_attn --no-build-isolation +pip install flash-sparse-attn --no-build-isolation ``` 或者, 您可以从源代码编译和安装: ```bash -git clone https://github.com/SmallDoges/flash_sparse_attn.git -cd flash_sparse_attn +git clone https://github.com/SmallDoges/flash-sparse-attn.git +cd flash-sparse-attn pip install . --no-build-isolation ``` @@ -185,7 +96,7 @@ key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dt value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype) # 为稀疏注意力创建 bias -attn_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype) +attn_bias = torch.randn(batch_size, num_kv_heads, 1, seq_len, device=device, dtype=dtype) # 基于 bias 生成动态 mask if seq_len > window_size: @@ -245,6 +156,95 @@ print(f"Bias 梯度形状: {attn_bias.grad.shape}") ``` +## 性能 + +我们展示了带有mask与bias条件下 FSA 相对于标准 PyTorch SDPA 的预期加速效果. + +![FSA Performance Overview](assets/performance_overview.png) + +--- + +### 前向传播性能 + +以下表格是我们在NVIDIA A100-SXM4-80GB上对FSA与标准PyTorch SDPA在不同配置下的前向性能对比测试结果. 结果为预热两次, 运行三次的平均值. 
+ +| Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup | +|--------|-------|--------|----------|-----------|-----------|---------| +| Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x | +| Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x | +| Train | 1024 | 1024 | 1024 | 0.51 | 0.18 | 2.81x | +| Train | 2048 | 2048 | 1024 | 1.04 | 0.18 | 5.68x | +| Train | 4096 | 4096 | 1024 | 2.53 | 0.24 | 10.41x | +| Train | 8192 | 8192 | 1024 | 9.38 | 0.36 | 25.93x | +| Train | 16384 | 16384 | 1024 | 28.39 | 0.81 | 35.25x | +| Train | 32768 | 32768 | 1024 | 111.87 | 2.25 | 49.78x | +| Train | 32768 | 32768 | 32 | 113.19 | 2.10 | 53.97x | +| Train | 32768 | 32768 | 64 | 113.17 | 2.12 | 53.32x | +| Train | 32768 | 32768 | 128 | 113.14 | 2.10 | 53.78x | +| Train | 32768 | 32768 | 256 | 113.18 | 2.13 | 53.18x | +| Train | 32768 | 32768 | 512 | 113.19 | 2.17 | 52.17x | +| Train | 32768 | 32768 | 1024 | 113.19 | 2.24 | 50.45x | +| Train | 32768 | 32768 | 2048 | 113.15 | 2.39 | 47.35x | +| Train | 32768 | 32768 | 4096 | 113.16 | 2.67 | 42.39x | +| Train | 32768 | 32768 | 8192 | 113.11 | 3.20 | 35.29x | +| Train | 32768 | 32768 | 16384 | 113.15 | 3.97 | 28.51x | +| Train | 32768 | 32768 | 32768 | 113.11 | 4.90 | 23.10x | +| Infer | 1 | 256 | 1024 | 0.25 | 0.19 | 1.28x | +| Infer | 1 | 512 | 1024 | 0.25 | 0.19 | 1.27x | +| Infer | 1 | 1024 | 1024 | 0.25 | 0.20 | 1.28x | +| Infer | 1 | 2048 | 1024 | 0.25 | 0.20 | 1.24x | +| Infer | 1 | 4096 | 1024 | 0.25 | 0.19 | 1.29x | +| Infer | 1 | 8192 | 1024 | 0.25 | 0.20 | 1.25x | +| Infer | 1 | 16384 | 1024 | 0.25 | 0.19 | 1.29x | +| Infer | 1 | 32768 | 1024 | 0.27 | 0.20 | 1.33x | +| Infer | 1 | 65536 | 1024 | 0.42 | 0.20 | 2.10x | +| Infer | 1 | 131072 | 1024 | 0.72 | 0.20 | 3.65x | +| Infer | 1 | 262144 | 1024 | 1.31 | 0.22 | 6.06x | +| Infer | 1 | 524288 | 1024 | 2.49 | 0.24 | 10.45x | +| Infer | 1 | 524288 | 32 | 2.48 | 0.21 | 11.60x | +| Infer | 1 | 524288 | 64 | 2.44 | 0.21 | 11.66x | +| Infer | 1 | 524288 | 128 | 2.45 | 0.21 | 11.47x | +| Infer | 1 | 524288 | 256 | 2.43 | 0.21 | 11.47x | +| Infer | 1 | 524288 | 512 | 2.44 | 0.22 | 10.89x | +| Infer | 1 | 524288 | 1024 | 2.44 | 0.24 | 10.31x | +| Infer | 1 | 524288 | 2048 | 2.44 | 0.27 | 9.07x | +| Infer | 1 | 524288 | 4096 | 2.45 | 0.33 | 7.41x | +| Infer | 1 | 524288 | 8192 | 2.44 | 0.35 | 6.93x | +| Infer | 1 | 524288 | 16384 | 2.44 | 0.35 | 6.93x | +| Infer | 1 | 524288 | 32768 | 2.45 | 0.35 | 6.96x | +| Infer | 1 | 524288 | 65536 | 2.44 | 0.35 | 6.88x | + +--- + +### 反向传播性能 + +以下表格是我们在NVIDIA A100-SXM4-80GB上对FSA与标准PyTorch SDPA在不同配置下的反向性能对比测试结果. 结果为预热两次, 运行三次的平均值. 
+ +| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup | +|-------|-------|--------|----------|---------------|---------------|---------| +| Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x | +| Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x | +| Train | 1024 | 1024 | 1024 | 0.94 | 0.61 | 1.5x | +| Train | 2048 | 2048 | 1024 | 1.79 | 0.69 | 2.6x | +| Train | 4096 | 4096 | 1024 | 3.76 | 1.08 | 3.5x | +| Train | 8192 | 8192 | 1024 | 14.39 | 2.06 | 7.0x | +| Train | 16384 | 16384 | 1024 | 39.56 | 4.97 | 8.0x | +| Train | 32768 | 32768 | 1024 | 142.07 | 25.63 | 5.5x | +| Train | 32768 | 32768 | 32 | 142.70 | 21.91 | 6.5x | +| Train | 32768 | 32768 | 64 | 142.65 | 22.29 | 6.4x | +| Train | 32768 | 32768 | 128 | 142.69 | 23.04 | 6.2x | +| Train | 32768 | 32768 | 256 | 142.69 | 24.27 | 5.9x | +| Train | 32768 | 32768 | 512 | 142.67 | 25.12 | 5.7x | +| Train | 32768 | 32768 | 1024 | 142.55 | 25.58 | 5.6x | +| Train | 32768 | 32768 | 2048 | 142.75 | 25.64 | 5.6x | +| Train | 32768 | 32768 | 4096 | 142.61 | 24.84 | 5.7x | +| Train | 32768 | 32768 | 8192 | 142.33 | 25.63 | 5.6x | +| Train | 32768 | 32768 | 16384 | 142.40 | 25.62 | 5.6x | +| Train | 32768 | 32768 | 32768 | 142.43 | 25.63 | 5.6x | + +--- + + ## 基准测试 FSA 提供全面的基准测试工具, 用于评估不同配置下的性能: diff --git a/SECURITY.md b/SECURITY.md index 020430e..1abb585 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -13,7 +13,7 @@ We actively maintain and provide security updates for the following versions: ### CUDA Code Execution -Flash Dynamic Mask Attention includes CUDA kernels and C++ extensions that execute on your GPU. When using this library: +Flash Sparse Attention includes CUDA kernels and C++ extensions that execute on your GPU. When using this library: - Only install from trusted sources (official PyPI releases or verified builds) - Be cautious when building from source with modifications @@ -46,11 +46,11 @@ If you discover a security vulnerability, please report it responsibly: **For security issues:** - Email: losercheems@gmail.com -- Subject: [SECURITY] Flash-DMA Vulnerability Report +- Subject: [SECURITY] FSA Vulnerability Report - Include: Detailed description, reproduction steps, and potential impact **For general bugs:** -- Use our [GitHub Issues](https://github.com/SmallDoges/flash-dmattn/issues) +- Use our [GitHub Issues](https://github.com/SmallDoges/flash-sparse-attention/issues) - Follow our [contributing guidelines](CONTRIBUTING.md) ## Response Timeline @@ -63,21 +63,21 @@ Critical security issues will be prioritized and may result in emergency release ## Security Best Practices -When using Flash Dynamic Mask Attention: +When using Flash Sparse Attention: 1. **Environment Isolation** ```bash # Use virtual environments - python -m venv flash_dma_env - source flash_dma_env/bin/activate # Linux/Mac + python -m venv fsa_env + source fsa_env/bin/activate # Linux/Mac # or - flash_dma_env\Scripts\activate # Windows + fsa_env\Scripts\activate # Windows ``` 2. **Dependency Management** ```bash # Keep dependencies updated - pip install --upgrade torch flash-dmattn + pip install --upgrade torch flash_sparse_attn ``` 3. 
**Input Validation** @@ -108,5 +108,5 @@ For security-related questions or concerns: - Project maintainers: See [AUTHORS](AUTHORS) file For general support: -- GitHub Issues: https://github.com/SmallDoges/flash-dmattn/issues -- Documentation: https://github.com/SmallDoges/flash-dmattn/tree/main/docs/ +- GitHub Issues: https://github.com/SmallDoges/flash-sparse-attention/issues +- Documentation: https://github.com/SmallDoges/flash-sparse-attention/tree/main/docs/ \ No newline at end of file diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index 7c34ba2..a10da1e 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -21,33 +21,33 @@ # Import the compiled CUDA extension try: - from flash_dmattn.flash_dmattn_interface import flash_dmattn_func - print("✅ Successfully imported flash_dmattn interface") + from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn interface") except ImportError as e: - print(f"❌ Failed to import flash_dmattn interface: {e}") + print(f"❌ Failed to import flash_sparse_attn interface: {e}") print("Please make sure the package is properly installed with: pip install .") # Don't exit here, just warn - flash_dmattn_func = None + flash_sparse_attn_func = None # Import the Triton implementation try: - from flash_dmattn.flash_dmattn_triton import triton_dmattn_func - print("✅ Successfully imported flash_dmattn_triton") + from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_triton") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_triton: {e}") + print(f"❌ Failed to import flash_sparse_attn_triton: {e}") print("Please make sure the Triton implementation is available.") # Don't exit here, just warn - triton_dmattn_func = None + triton_sparse_attn_func = None # Import the Flex Attention implementation try: - from flash_dmattn.flash_dmattn_flex import flex_dmattn_func - print("✅ Successfully imported flash_dmattn_flex") + from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_flex") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_flex: {e}") + print(f"❌ Failed to import flash_sparse_attn_flex: {e}") print("Please make sure the Flex Attention implementation is available.") # Don't exit here, just warn - flex_dmattn_func = None + flex_sparse_attn_func = None def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -189,7 +189,7 @@ def dynamic_mask_attention_cuda( Returns: tuple: (attn_outputs, dq, dk, dv, dbias) """ - if flash_dmattn_func is None: + if flash_sparse_attn_func is None: raise ImportError("CUDA implementation not available") query_states_leaf = query_states @@ -210,8 +210,8 @@ def dynamic_mask_attention_cuda( key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] - # Call the flash_dmattn_func interface - attn_outputs = flash_dmattn_func( + # Call the flash_sparse_attn_func interface + attn_outputs = flash_sparse_attn_func( query=query_states, key=key_states, value=value_states, @@ -256,7 +256,7 @@ def dynamic_mask_attention_triton( Returns: tuple: (attn_outputs, dq, dk, dv, dbias) """ - if triton_dmattn_func is None: + if triton_sparse_attn_func is None: raise RuntimeError("Triton 
implementation not available") _, num_heads, _, _ = query_states.shape @@ -288,7 +288,7 @@ def dynamic_mask_attention_triton( value_states = value_states.transpose(1, 2) # [batch, key_len, num_heads, head_dim] # Call the Triton implementation - attn_outputs = triton_dmattn_func( + attn_outputs = triton_sparse_attn_func( query=query_states, key=key_states, value=value_states, @@ -330,7 +330,7 @@ def dynamic_mask_attention_flex( Returns: tuple: (attn_outputs, dq, dk, dv, dbias) """ - if flex_dmattn_func is None: + if flex_sparse_attn_func is None: raise RuntimeError("Flex Attention implementation not available") _, num_heads, _, _ = query_states.shape @@ -359,7 +359,7 @@ def dynamic_mask_attention_flex( attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] # Call the Flex Attention implementation - attn_outputs = flex_dmattn_func( + attn_outputs = flex_sparse_attn_func( query_states, key_states, value_states, @@ -474,7 +474,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): print("🚀" + "=" * 76 + "🚀") # Check if CUDA implementation is available - if flash_dmattn_func is None: + if flash_sparse_attn_func is None: print("❌ CUDA implementation not available, skipping test.") return False @@ -734,7 +734,7 @@ def test_triton_backward_equivalence(accuracy_threshold=0.95): print("🚀" + "=" * 76 + "🚀") # Check if Triton implementation is available - if triton_dmattn_func is None: + if triton_sparse_attn_func is None: print("❌ Triton implementation not available, skipping test.") return False diff --git a/benchmarks/backward_performance.py b/benchmarks/backward_performance.py index 82deb8c..59daf16 100644 --- a/benchmarks/backward_performance.py +++ b/benchmarks/backward_performance.py @@ -28,33 +28,33 @@ # Import the compiled CUDA extension try: - from flash_dmattn.flash_dmattn_interface import flash_dmattn_func - print("✅ Successfully imported flash_dmattn interface") + from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn interface") except ImportError as e: - print(f"❌ Failed to import flash_dmattn interface: {e}") + print(f"❌ Failed to import flash_sparse_attn interface: {e}") print("Please make sure the package is properly installed with: pip install .") # Don't exit here, just warn - flash_dmattn_func = None + flash_sparse_attn_func = None # Import the Triton implementation try: - from flash_dmattn.flash_dmattn_triton import triton_dmattn_func - print("✅ Successfully imported flash_dmattn_triton") + from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_triton") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_triton: {e}") + print(f"❌ Failed to import flash_sparse_attn_triton: {e}") print("Please make sure the Triton implementation is available.") # Don't exit here, just warn - triton_dmattn_func = None + triton_sparse_attn_func = None # Import the Flex Attention implementation try: - from flash_dmattn.flash_dmattn_flex import flex_dmattn_func - print("✅ Successfully imported flash_dmattn_flex") + from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_flex") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_flex: {e}") + print(f"❌ Failed to import flash_sparse_attn_flex: {e}") print("Please make sure the Flex Attention implementation is available.") # Don't exit here, just warn - flex_dmattn_func = None 
+ flex_sparse_attn_func = None def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -207,7 +207,7 @@ def dynamic_mask_attention_backward_cuda( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - if flash_dmattn_func is None: + if flash_sparse_attn_func is None: return "Not Available", 0 attn_bias, attn_mask = prepare_mask( @@ -223,7 +223,7 @@ def dynamic_mask_attention_backward_cuda( value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] try: - attn_outputs = flash_dmattn_func( + attn_outputs = flash_sparse_attn_func( query=query_states, key=key_states, value=value_states, @@ -277,7 +277,7 @@ def dynamic_mask_attention_backward_triton( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - if triton_dmattn_func is None: + if triton_sparse_attn_func is None: return "Not Available", 0 _, num_heads, _, _ = query_states.shape @@ -305,7 +305,7 @@ def dynamic_mask_attention_backward_triton( attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] try: - attn_outputs = triton_dmattn_func( + attn_outputs = triton_sparse_attn_func( query=query_states, key=key_states, value=value_states, @@ -356,7 +356,7 @@ def dynamic_mask_attention_backward_flex( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - if flex_dmattn_func is None: + if flex_sparse_attn_func is None: return "Not Available", 0 _, num_heads, _, _ = query_states.shape @@ -384,7 +384,7 @@ def dynamic_mask_attention_backward_flex( attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] try: - attn_outputs = flex_dmattn_func( + attn_outputs = flex_sparse_attn_func( query_states, key_states, value_states, diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index 8baff70..9b05ba3 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -21,33 +21,33 @@ # Import the compiled CUDA extension try: - from flash_dmattn.flash_dmattn_interface import flash_dmattn_func - print("✅ Successfully imported flash_dmattn interface") + from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn interface") except ImportError as e: - print(f"❌ Failed to import flash_dmattn interface: {e}") + print(f"❌ Failed to import flash_sparse_attn interface: {e}") print("Please make sure the package is properly installed with: pip install .") # Don't exit here, just warn - flash_dmattn_func = None + flash_sparse_attn_func = None # Import the Triton implementation try: - from flash_dmattn.flash_dmattn_triton import triton_dmattn_func - print("✅ Successfully imported flash_dmattn_triton") + from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_triton") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_triton: {e}") + print(f"❌ Failed to import flash_sparse_attn_triton: {e}") print("Please make sure the Triton implementation is available.") # Don't exit here, just warn - triton_dmattn_func = None + triton_sparse_attn_func = None # Import the Flex Attention implementation try: - from flash_dmattn.flash_dmattn_flex import flex_dmattn_func - print("✅ Successfully imported flash_dmattn_flex") + from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_flex") except ImportError as e: - 
print(f"❌ Failed to import flash_dmattn_flex: {e}") + print(f"❌ Failed to import flash_sparse_attn_flex: {e}") print("Please make sure the Flex Attention implementation is available.") # Don't exit here, just warn - flex_dmattn_func = None + flex_sparse_attn_func = None def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -181,8 +181,8 @@ def dynamic_mask_attention_cuda( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - if flash_dmattn_func is None: - raise RuntimeError("flash_dmattn_func not available") + if flash_sparse_attn_func is None: + raise RuntimeError("flash_sparse_attn_func not available") attn_bias, attn_mask = prepare_mask( query_states, @@ -196,8 +196,8 @@ def dynamic_mask_attention_cuda( key_states = key_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] - # Call the flash_dmattn_func interface - attn_outputs = flash_dmattn_func( + # Call the flash_sparse_attn_func interface + attn_outputs = flash_sparse_attn_func( query_states, key_states, value_states, @@ -239,7 +239,7 @@ def dynamic_mask_attention_triton( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - if triton_dmattn_func is None: + if triton_sparse_attn_func is None: raise RuntimeError("Triton implementation not available") _, num_heads, _, _ = query_states.shape @@ -267,7 +267,7 @@ def dynamic_mask_attention_triton( attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] # Call the Triton implementation - attn_outputs = triton_dmattn_func( + attn_outputs = triton_sparse_attn_func( query_states, key_states, value_states, @@ -306,7 +306,7 @@ def dynamic_mask_attention_flex( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - if flex_dmattn_func is None: + if flex_sparse_attn_func is None: raise RuntimeError("Flex Attention implementation not available") _, num_heads, _, _ = query_states.shape @@ -334,7 +334,7 @@ def dynamic_mask_attention_flex( attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] # Call the Flex Attention implementation - attn_outputs = flex_dmattn_func( + attn_outputs = flex_sparse_attn_func( query_states, key_states, value_states, @@ -446,7 +446,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): print("🚀" + "=" * 76 + "🚀") # Check if CUDA implementation is available - if flash_dmattn_func is None: + if flash_sparse_attn_func is None: print("❌ CUDA implementation not available, skipping test.") return False @@ -653,7 +653,7 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): print("🔬 Testing Forward Pass Equivalence: Python vs Triton 🔬") print("🔥" + "=" * 76 + "🔥") - if triton_dmattn_func is None: + if triton_sparse_attn_func is None: print("❌ Triton implementation not available, skipping Triton tests") return False @@ -859,7 +859,7 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95): print("🔬 Testing Forward Pass Equivalence: Python vs Flex Attention 🔬") print("🌟" + "=" * 76 + "🌟") - if flex_dmattn_func is None: + if flex_sparse_attn_func is None: print("❌ Flex Attention implementation not available, skipping Flex Attention tests") return False diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py index 5730e0e..05e75c4 100644 --- a/benchmarks/forward_performance.py +++ b/benchmarks/forward_performance.py @@ -28,33 +28,33 @@ # Import the compiled CUDA extension try: - from 
flash_dmattn.flash_dmattn_interface import flash_dmattn_func - print("✅ Successfully imported flash_dmattn interface") + from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn interface") except ImportError as e: - print(f"❌ Failed to import flash_dmattn interface: {e}") + print(f"❌ Failed to import flash_sparse_attn interface: {e}") print("Please make sure the package is properly installed with: pip install .") # Don't exit here, just warn - flash_dmattn_func = None + flash_sparse_attn_func = None # Import the Triton implementation try: - from flash_dmattn.flash_dmattn_triton import triton_dmattn_func - print("✅ Successfully imported flash_dmattn_triton") + from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_triton") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_triton: {e}") + print(f"❌ Failed to import flash_sparse_attn_triton: {e}") print("Please make sure the Triton implementation is available.") # Don't exit here, just warn - triton_dmattn_func = None + triton_sparse_attn_func = None # Import the Flex Attention implementation try: - from flash_dmattn.flash_dmattn_flex import flex_dmattn_func - print("✅ Successfully imported flash_dmattn_flex") + from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_flex") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_flex: {e}") + print(f"❌ Failed to import flash_sparse_attn_flex: {e}") print("Please make sure the Flex Attention implementation is available.") # Don't exit here, just warn - flex_dmattn_func = None + flex_sparse_attn_func = None def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -203,7 +203,7 @@ def dynamic_mask_attention_cuda( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - if flash_dmattn_func is None: + if flash_sparse_attn_func is None: return "Not Available", 0 attn_bias, attn_mask = prepare_mask( @@ -222,7 +222,7 @@ def dynamic_mask_attention_cuda( torch.cuda.synchronize() start_time = time.time() - attn_outputs = flash_dmattn_func( + attn_outputs = flash_sparse_attn_func( query_states, key_states, value_states, @@ -269,7 +269,7 @@ def dynamic_mask_attention_triton( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - if triton_dmattn_func is None: + if triton_sparse_attn_func is None: return "Not Available", 0 _, num_heads, _, _ = query_states.shape @@ -300,7 +300,7 @@ def dynamic_mask_attention_triton( torch.cuda.synchronize() start_time = time.time() - attn_outputs = triton_dmattn_func( + attn_outputs = triton_sparse_attn_func( query_states, key_states, value_states, @@ -344,7 +344,7 @@ def dynamic_mask_attention_flex( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - if flex_dmattn_func is None: + if flex_sparse_attn_func is None: return "Not Available", 0 _, num_heads, _, _ = query_states.shape @@ -376,7 +376,7 @@ def dynamic_mask_attention_flex( start_time = time.time() # Call the Flex Attention implementation - attn_outputs = flex_dmattn_func( + attn_outputs = flex_sparse_attn_func( query_states, key_states, value_states, diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_sparse_attn/flash_api.cpp similarity index 97% rename from csrc/flash_dmattn/flash_api.cpp rename to csrc/flash_sparse_attn/flash_api.cpp index 
4a67ec1..2bfa20f 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_sparse_attn/flash_api.cpp @@ -126,7 +126,7 @@ void set_params_fprop( // Set the different scale values. #ifdef FLASHATTENTION_DISABLE_SOFTCAP - TORCH_CHECK(softcap <= 0.0, "This flash dynamic mask attention build does not support softcap."); + TORCH_CHECK(softcap <= 0.0, "This flash sparse attention build does not support softcap."); #endif if (softcap > 0.0) { params.softcap = softmax_scale / softcap; @@ -145,7 +145,7 @@ void set_params_fprop( params.is_seqlens_k_cumulative = true; #ifdef FLASHATTENTION_DISABLE_UNEVEN_K - TORCH_CHECK(d == d_rounded, "This flash dynamic mask attention build does not support headdim not being a multiple of 32."); + TORCH_CHECK(d == d_rounded, "This flash sparse attention build does not support headdim not being a multiple of 32."); #endif params.unpadded_lse = unpadded_lse; @@ -366,10 +366,10 @@ mha_fwd( at::cuda::CUDAGuard device_guard{q.device()}; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer."); auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -420,7 +420,7 @@ mha_fwd( const int seqlen_k_rounded = round_multiple(seqlen_k, 128); TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); + TORCH_CHECK(head_size <= 256, "FlashSparseAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -577,10 +577,10 @@ mha_varlen_fwd( at::cuda::CUDAGuard device_guard{q.device()}; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer."); auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); @@ -644,7 +644,7 @@ mha_varlen_fwd( const int total_q = q.sizes()[0]; TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); + TORCH_CHECK(head_size <= 256, "FlashSparseAttention forward only supports head 
dimension at most 256"); TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -810,19 +810,19 @@ mha_bwd( ) { #ifdef FLASHATTENTION_DISABLE_BACKWARD - TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward."); + TORCH_CHECK(false, "This flash sparse attention build does not support backward."); #endif // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer."); auto stream = at::cuda::getCurrentCUDAStream().stream(); auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); @@ -881,7 +881,7 @@ mha_bwd( TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); + TORCH_CHECK(head_size <= 256, "FlashSparseAttention backward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (has_mask) { @@ -1072,19 +1072,19 @@ mha_varlen_bwd( ) { #ifdef FLASHATTENTION_DISABLE_BACKWARD - TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward."); + TORCH_CHECK(false, "This flash sparse attention build does not support backward."); #endif // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer."); auto stream = at::cuda::getCurrentCUDAStream().stream(); auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); @@ -1124,7 +1124,7 @@ mha_varlen_bwd( const int num_heads_bias = has_bias ? 
bias.size(1) : 1; TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); + TORCH_CHECK(head_size <= 256, "FlashSparseAttention backward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; @@ -1268,7 +1268,7 @@ mha_varlen_bwd( } // namespace FLASH_NAMESPACE PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashDynamicMaskAttention"; + m.doc() = "FlashSparseAttention"; m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass"); m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length"); m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass"); diff --git a/csrc/flash_dmattn/src/block_info.h b/csrc/flash_sparse_attn/src/block_info.h similarity index 100% rename from csrc/flash_dmattn/src/block_info.h rename to csrc/flash_sparse_attn/src/block_info.h diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_sparse_attn/src/flash.h similarity index 100% rename from csrc/flash_dmattn/src/flash.h rename to csrc/flash_sparse_attn/src/flash.h diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_sparse_attn/src/flash_bwd_kernel.h similarity index 100% rename from csrc/flash_dmattn/src/flash_bwd_kernel.h rename to csrc/flash_sparse_attn/src/flash_bwd_kernel.h diff --git a/csrc/flash_dmattn/src/flash_bwd_launch_template.h b/csrc/flash_sparse_attn/src/flash_bwd_launch_template.h similarity index 98% rename from csrc/flash_dmattn/src/flash_bwd_launch_template.h rename to csrc/flash_sparse_attn/src/flash_bwd_launch_template.h index 00712b8..a6a3717 100644 --- a/csrc/flash_dmattn/src/flash_bwd_launch_template.h +++ b/csrc/flash_sparse_attn/src/flash_bwd_launch_template.h @@ -24,7 +24,7 @@ namespace FLASH_NAMESPACE { #endif // Define a macro for unsupported architecture handling to centralize the error message -#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashDynamicMaskAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashSparseAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); // Use a macro to clean up kernel definitions #define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) 
\ diff --git a/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_sparse_attn/src/flash_bwd_preprocess_kernel.h similarity index 100% rename from csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h rename to csrc/flash_sparse_attn/src/flash_bwd_preprocess_kernel.h diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_sparse_attn/src/flash_fwd_kernel.h similarity index 100% rename from csrc/flash_dmattn/src/flash_fwd_kernel.h rename to csrc/flash_sparse_attn/src/flash_fwd_kernel.h diff --git a/csrc/flash_dmattn/src/flash_fwd_launch_template.h b/csrc/flash_sparse_attn/src/flash_fwd_launch_template.h similarity index 99% rename from csrc/flash_dmattn/src/flash_fwd_launch_template.h rename to csrc/flash_sparse_attn/src/flash_fwd_launch_template.h index 9c3d94b..412db39 100644 --- a/csrc/flash_dmattn/src/flash_fwd_launch_template.h +++ b/csrc/flash_sparse_attn/src/flash_fwd_launch_template.h @@ -23,7 +23,7 @@ namespace FLASH_NAMESPACE { #endif // Define a macro for unsupported architecture handling to centralize the error message -#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashDynamicMaskAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashSparseAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); // Use a macro to clean up kernel definitions #define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ diff --git a/csrc/flash_dmattn/src/generate_kernels.py b/csrc/flash_sparse_attn/src/generate_kernels.py similarity index 97% rename from csrc/flash_dmattn/src/generate_kernels.py rename to csrc/flash_sparse_attn/src/generate_kernels.py index 54d5a72..3626243 100644 --- a/csrc/flash_dmattn/src/generate_kernels.py +++ b/csrc/flash_sparse_attn/src/generate_kernels.py @@ -113,7 +113,7 @@ def main(output_dir: Optional[str]) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser( prog="generate_kernels", - description="Generate the flash_dmattn kernels template instantiations", + description="Generate the flash_sparse_attn kernels template instantiations", ) parser.add_argument( "-o", diff --git a/csrc/flash_dmattn/src/hardware_info.h b/csrc/flash_sparse_attn/src/hardware_info.h similarity index 100% rename from csrc/flash_dmattn/src/hardware_info.h rename to csrc/flash_sparse_attn/src/hardware_info.h diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu rename to 
csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu rename to 
csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu rename to 
csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu rename to 
csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu rename to 
csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu 
diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu diff --git 
a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu diff --git 
a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu diff --git 
a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu 
b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu 
b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu 
b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu 
b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu 
b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu 
b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu diff --git 
a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu diff --git 
a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu 
b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu 
b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu 
b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu similarity index 
100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu similarity index 100% rename from 
csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu diff --git 
a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% 
rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu diff 
--git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu similarity index 100% rename from 
csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu diff --git 
a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu similarity index 100% rename from 
csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu diff --git 
a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu similarity index 100% rename from 
csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu diff --git 
a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from 
csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu diff --git 
a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/kernel_traits.h b/csrc/flash_sparse_attn/src/kernel_traits.h similarity index 100% rename from csrc/flash_dmattn/src/kernel_traits.h rename to csrc/flash_sparse_attn/src/kernel_traits.h diff --git a/csrc/flash_dmattn/src/mask.h b/csrc/flash_sparse_attn/src/mask.h similarity index 100% rename from csrc/flash_dmattn/src/mask.h rename to csrc/flash_sparse_attn/src/mask.h diff --git a/csrc/flash_dmattn/src/namespace_config.h b/csrc/flash_sparse_attn/src/namespace_config.h similarity index 100% rename from csrc/flash_dmattn/src/namespace_config.h rename to csrc/flash_sparse_attn/src/namespace_config.h diff --git a/csrc/flash_dmattn/src/softmax.h b/csrc/flash_sparse_attn/src/softmax.h similarity index 100% rename from csrc/flash_dmattn/src/softmax.h rename to csrc/flash_sparse_attn/src/softmax.h diff --git a/csrc/flash_dmattn/src/static_switch.h b/csrc/flash_sparse_attn/src/static_switch.h similarity index 100% rename from csrc/flash_dmattn/src/static_switch.h rename to csrc/flash_sparse_attn/src/static_switch.h diff --git a/csrc/flash_dmattn/src/utils.h b/csrc/flash_sparse_attn/src/utils.h similarity index 100% rename from csrc/flash_dmattn/src/utils.h rename to csrc/flash_sparse_attn/src/utils.h diff --git a/docs/api_reference.md b/docs/api_reference.md index 65bbebf..f6b4316 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -1,9 +1,9 @@ -# Flash Dynamic Mask Attention API Reference +# Flash Sparse Attention API Reference ## Overview -Flash Dynamic Mask Attention is a high-performance attention implementation that combines the memory efficiency of Flash Attention with the sparse compute benefits of Dynamic Mask Attention. It supports CUDA, Triton, and Flex Attention backends and dynamic masking for very long sequences. +Flash Sparse Attention is a high-performance attention implementation that combines the memory efficiency of Flash Attention with the sparse compute benefits of Dynamic Mask Attention. It supports CUDA, Triton, and Flex Attention backends and dynamic masking for very long sequences. ## Table of Contents @@ -12,9 +12,9 @@ Flash Dynamic Mask Attention is a high-performance attention implementation that 2. [Quick Start](#quick-start) 3. [Backend Selection and Comparison](#backend-selection-and-comparison) 4. [API Reference](#api-reference) - - [CUDA Backend: flash_dmattn_func](#flash_dmattn_func-cuda-backend) - - [Triton Backend: triton_dmattn_func](#triton_dmattn_func-triton-backend) - - [Flex Backend: flex_dmattn_func](#flex_dmattn_func-flex-backend) + - [CUDA Backend: flash_sparse_attn_func](#flash_sparse_attn_func-cuda-backend) + - [Triton Backend: triton_sparse_attn_func](#triton_sparse_attn_func-triton-backend) + - [Flex Backend: flex_sparse_attn_func](#flex_sparse_attn_func-flex-backend) 5. 
[Integrations](#integrations) - [Transformers Integration](#transformers-integration) 6. [Common Issues and Solutions](#common-issues-and-solutions) @@ -22,27 +22,27 @@ Flash Dynamic Mask Attention is a high-performance attention implementation that ## Installation -Please refer to the [README](https://github.com/SmallDoges/flash-dmattn/blob/main/README.md#install) for detailed installation instructions. +Please refer to the [README](https://github.com/SmallDoges/flash-sparse-attention/blob/main/README.md#install) for detailed installation instructions. ```bash # With CUDA backend -pip install flash-dmattn +pip install flash-sparse-attn # Or install from source pip install -e . # Triton/Flex only -FLASH_DMATTN_SKIP_CUDA_BUILD=1 pip install -e . +FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD=1 pip install -e . ``` ## Quick Start -Use `flash_dmattn_func_auto` to automatically select the best available backend without manual checking. +Use `flash_sparse_attn_func_auto` to automatically select the best available backend without manual checking. ```python import torch -from flash_dmattn import flash_dmattn_func_auto +from flash_sparse_attn import flash_sparse_attn_func_auto # Prepare input tensors batch, seqlen, num_heads, head_dim = 2, 1024, 8, 64 @@ -51,19 +51,19 @@ k = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device v = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda') # Get attention function (auto-select backend, priority: cuda > triton > flex) -attn_func = flash_dmattn_func_auto() +attn_func = flash_sparse_attn_func_auto() # Compute attention output = attn_func(q, k, v, is_causal=True) print(f"Output shape: {output.shape}") # (2, 1024, 8, 64) # Or force a specific backend -attn_func = flash_dmattn_func_auto(backend="cuda") # or "triton", "flex" +attn_func = flash_sparse_attn_func_auto(backend="cuda") # or "triton", "flex" output = attn_func(q, k, v, is_causal=True) ``` > [!NOTE] -> `flash_dmattn_func_auto` returns a callable attention function, not the attention output. +> `flash_sparse_attn_func_auto` returns a callable attention function, not the attention output. 
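Since the auto helper only hands back a callable, an explicit fallback is sometimes clearer than relying on its internal priority. A minimal sketch, assuming `q`, `k`, `v` are the tensors prepared above and that the preference order below is your own choice rather than anything mandated by the library:

```python
from flash_sparse_attn import flash_sparse_attn_func_auto, get_available_backends

# Pick the first preferred backend that is actually available on this machine.
preferred = ["cuda", "triton", "flex"]
available = set(get_available_backends())
backend = next(b for b in preferred if b in available)

attn_func = flash_sparse_attn_func_auto(backend=backend)  # returns a callable, not the output
output = attn_func(q, k, v, is_causal=True)               # the output is produced by this call
```

If none of the preferred backends is available, the `next(...)` call raises `StopIteration`; in that case fall back to `flash_sparse_attn_func_auto()` with no argument and let it choose on its own.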
## Backend Selection and Comparison @@ -71,7 +71,7 @@ output = attn_func(q, k, v, is_causal=True) ### Check Available Backends ```python -from flash_dmattn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE +from flash_sparse_attn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE # List all available backends print(get_available_backends()) # e.g., ["cuda", "triton", "flex"] @@ -101,19 +101,19 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL ### When to Use Each Backend -**CUDA Backend** ([details](#flash_dmattn_func-cuda-backend)) +**CUDA Backend** ([details](#flash_sparse_attn_func-cuda-backend)) - ✅ Training workloads requiring full gradient support - ✅ Production inference requiring maximum performance - ✅ Applications needing deterministic behavior - ❌ Avoid: when custom CUDA extensions cannot be built -**Triton Backend** ([details](#triton_dmattn_func-triton-backend)) +**Triton Backend** ([details](#triton_sparse_attn_func-triton-backend)) - ✅ Training when CUDA extension unavailable - ✅ Development and prototyping - ✅ Cross-platform compatibility needs - ✅ Good balance of performance and ease of installation -**Flex Backend** ([details](#flex_dmattn_func-flex-backend)) +**Flex Backend** ([details](#flex_sparse_attn_func-flex-backend)) - ✅ Inference-only applications - ✅ Research with latest PyTorch features - ✅ Quick experimentation without custom builds @@ -123,15 +123,15 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL ### Import Available Functions ```python -from flash_dmattn import ( +from flash_sparse_attn import ( # Automatic backend selection get_available_backends, - flash_dmattn_func_auto, + flash_sparse_attn_func_auto, # Backend-specific functions - flash_dmattn_func, # CUDA backend - triton_dmattn_func, # Triton backend - flex_dmattn_func, # Flex backend + flash_sparse_attn_func, # CUDA backend + triton_sparse_attn_func, # Triton backend + flex_sparse_attn_func, # Flex backend # Backend availability flags CUDA_AVAILABLE, @@ -140,20 +140,20 @@ from flash_dmattn import ( ) # Transformers integration -from flash_dmattn.integrations.flash_dynamic_mask_attention import ( - flash_dynamic_mask_attention_forward +from flash_sparse_attn.integrations.flash_sparse_attention import ( + flash_sparse_attention_forward ) ``` ## API Reference -### flash_dmattn_func (CUDA backend) +### flash_sparse_attn_func (CUDA backend) Main attention function. Supports multi-head and grouped-query attention (when the number of KV heads is smaller than the number of Q heads). Requires the CUDA extension to be built and available. ```python -def flash_dmattn_func( +def flash_sparse_attn_func( query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) key: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) value: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) @@ -182,12 +182,12 @@ def flash_dmattn_func( - output: (B, Q, H, D) -### triton_dmattn_func (Triton backend) +### triton_sparse_attn_func (Triton backend) Triton-based implementation that provides good performance without requiring custom CUDA kernels. 
```python -def triton_dmattn_func( +def triton_sparse_attn_func( query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) @@ -198,12 +198,12 @@ def triton_dmattn_func( ) -> torch.Tensor ``` -### flex_dmattn_func (Flex Attention backend) +### flex_sparse_attn_func (Flex Attention backend) Flex Attention-based implementation using PyTorch's native flex attention with dynamic masking support. ```python -def flex_dmattn_func( +def flex_sparse_attn_func( query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) @@ -221,13 +221,13 @@ def flex_dmattn_func( Integration function for HuggingFace Transformers models that provides seamless flash dynamic mask attention support. -#### flash_dynamic_mask_attention_forward +#### flash_sparse_attention_forward ```python -from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward +from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward -def flash_dynamic_mask_attention_forward( +def flash_sparse_attention_forward( module: torch.nn.Module, # The attention module query: torch.Tensor, # (batch_size, num_heads, query_len, head_dim) key: torch.Tensor, # (batch_size, num_kv_heads, key_len, head_dim) @@ -254,7 +254,7 @@ def flash_dynamic_mask_attention_forward( - is_causal: Whether to apply causal mask - window_size: Size of window to keep - layer_idx: Layer index for logging - - implementation: Implementation to use ("flash_dmattn" or None) + - implementation: Implementation to use ("flash_sparse_attn" or None) #### Returns @@ -268,7 +268,7 @@ import torch.nn as nn import torch.nn.functional as F from typing import Optional, Callable, tuple from transformers.cache_utils import Cache -from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward +from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward class DynamicMaskAttention(nn.Module): def __init__(self, config, layer_idx: Optional[int] = None): @@ -332,7 +332,7 @@ class DynamicMaskAttention(nn.Module): attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype) # Choose attention implementation - attention_interface: Callable = flash_dynamic_mask_attention_forward + attention_interface: Callable = flash_sparse_attention_forward attn_output, attn_weights = attention_interface( self, @@ -362,7 +362,7 @@ This example shows: ```python try: - from flash_dmattn import flash_dmattn_func_auto, get_available_backends + from flash_sparse_attn import flash_sparse_attn_func_auto, get_available_backends print("✅ Imported successfully", get_available_backends()) except ImportError as e: print(f"❌ Import failed: {e}") @@ -385,10 +385,10 @@ except ImportError as e: ```python import torch -from flash_dmattn import flash_dmattn_func_auto +from flash_sparse_attn import flash_sparse_attn_func_auto torch.autograd.set_detect_anomaly(True) -attn = flash_dmattn_func_auto() +attn = flash_sparse_attn_func_auto() output = attn(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True) if torch.isnan(output).any(): print("⚠️ NaN detected in attention output") @@ -404,7 +404,7 @@ def print_memory_stats(): print(f"max alloc: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB") 
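# Optional step, not in the original snippet: reset the peak-memory counter so the
# "max alloc" figure printed below reflects only the attention call being measured.
torch.cuda.reset_peak_memory_stats()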
print_memory_stats() -attn = flash_dmattn_func_auto() +attn = flash_sparse_attn_func_auto() output = attn(q, k, v) print_memory_stats() ``` diff --git a/docs/api_reference_zh.md b/docs/api_reference_zh.md index 83d1617..3676d16 100644 --- a/docs/api_reference_zh.md +++ b/docs/api_reference_zh.md @@ -1,9 +1,9 @@ -# Flash Dynamic Mask Attention API 参考文档 +# Flash Sparse Attention API 参考文档 ## 概述 -Flash Dynamic Mask Attention 是一个高性能注意力实现,结合了 Flash Attention 的内存效率和 Dynamic Mask Attention 的稀疏计算优势。它支持 CUDA、Triton 和 Flex Attention 后端,并支持超长序列的动态掩码。 +Flash Sparse Attention 是一个高性能注意力实现,结合了 Flash Attention 的内存效率和 Dynamic Mask Attention 的稀疏计算优势。它支持 CUDA、Triton 和 Flex Attention 后端,并支持超长序列的动态掩码。 ## 目录 @@ -12,9 +12,9 @@ Flash Dynamic Mask Attention 是一个高性能注意力实现,结合了 Flash 2. [快速开始](#快速开始) 3. [后端选择与比较](#后端选择与比较) 4. [接口函数详解](#接口函数详解) - - [CUDA 后端:flash_dmattn_func](#flash_dmattn_func-cuda-后端) - - [Triton 后端:triton_dmattn_func](#triton_dmattn_func-triton-后端) - - [Flex 后端:flex_dmattn_func](#flex_dmattn_func-flex-后端) + - [CUDA 后端:flash_sparse_attn_func](#flash_sparse_attn_func-cuda-后端) + - [Triton 后端:triton_sparse_attn_func](#triton_sparse_attn_func-triton-后端) + - [Flex 后端:flex_sparse_attn_func](#flex_sparse_attn_func-flex-后端) 5. [集成](#集成) - [Transformers 集成](#transformers-集成) 6. [常见问题与解决方案](#常见问题与解决方案) @@ -22,27 +22,26 @@ Flash Dynamic Mask Attention 是一个高性能注意力实现,结合了 Flash ## 安装 -请参考 [README](https://github.com/SmallDoges/flash-dmattn/blob/main/README_zh.md#%E5%AE%89%E8%A3%85-1) 以获取详细的安装说明和依赖项。 +请参考 [README](https://github.com/SmallDoges/flash-sparse-attention/blob/main/README_zh.md#%E5%AE%89%E8%A3%85-1) 以获取详细的安装说明和依赖项。 ```bash # 使用 CUDA 后端 -pip install flash-dmattn - +pip install flash-sparse-attn # 或从源码安装 pip install -e . # 仅使用 Triton/Flex 后端 -FLASH_DMATTN_SKIP_CUDA_BUILD=1 pip install -e . +FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD=1 pip install -e . 
``` ## 快速开始 -使用 `flash_dmattn_func_auto` 可以自动选择最佳可用后端,无需手动判断。 +使用 `flash_sparse_attn_func_auto` 可以自动选择最佳可用后端,无需手动判断。 ```python import torch -from flash_dmattn import flash_dmattn_func_auto +from flash_sparse_attn import flash_sparse_attn_func_auto # 准备输入张量 batch, seqlen, num_heads, head_dim = 2, 1024, 8, 64 @@ -51,19 +50,19 @@ k = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device v = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda') # 获取注意力函数(自动选择后端,优先级: cuda > triton > flex) -attn_func = flash_dmattn_func_auto() +attn_func = flash_sparse_attn_func_auto() # 调用注意力计算 output = attn_func(q, k, v, is_causal=True) print(f"输出形状: {output.shape}") # (2, 1024, 8, 64) # 也可以强制使用特定后端 -attn_func = flash_dmattn_func_auto(backend="cuda") # 或 "triton", "flex" +attn_func = flash_sparse_attn_func_auto(backend="cuda") # 或 "triton", "flex" output = attn_func(q, k, v, is_causal=True) ``` > [!NOTE] -> `flash_dmattn_func_auto` 返回一个可调用的注意力函数,而不是注意力输出。 +> `flash_sparse_attn_func_auto` 返回一个可调用的注意力函数,而不是注意力输出。 ## 后端选择与比较 @@ -71,7 +70,7 @@ output = attn_func(q, k, v, is_causal=True) ### 可用后端检查 ```python -from flash_dmattn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE +from flash_sparse_attn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE # 查看所有可用后端 print(get_available_backends()) # 例如:["cuda", "triton", "flex"] @@ -101,19 +100,19 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL ### 何时使用各个后端 -**CUDA 后端** ([详细说明](#flash_dmattn_func-cuda-后端)) +**CUDA 后端** ([详细说明](#flash_sparse_attn_func-cuda-后端)) - ✅ 完整梯度支持的训练工作负载 - ✅ 最大性能生产推理 - ✅ 需要确定性行为的应用 - ❌ 避免:无法构建自定义 CUDA 扩展时 -**Triton 后端** ([详细说明](#triton_dmattn_func-triton-后端)) +**Triton 后端** ([详细说明](#triton_sparse_attn_func-triton-后端)) - ✅ CUDA 扩展不可用时的训练工作负载 - ✅ 开发和原型设计 - ✅ 跨平台兼容性需求 - ✅ 性能和易安装性的良好平衡 -**Flex 后端** ([详细说明](#flex_dmattn_func-flex-后端)) +**Flex 后端** ([详细说明](#flex_sparse_attn_func-flex-后端)) - ✅ 仅推理应用 - ✅ 使用最新 PyTorch 特性的研究 - ✅ 无需自定义构建的快速实验 @@ -123,15 +122,15 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL ### 导入可用函数 ```python -from flash_dmattn import ( +from flash_sparse_attn import ( # 自动后端选择 get_available_backends, - flash_dmattn_func_auto, + flash_sparse_attn_func_auto, # 后端特定函数 - flash_dmattn_func, # CUDA 后端 - triton_dmattn_func, # Triton 后端 - flex_dmattn_func, # Flex 后端 + flash_sparse_attn_func, # CUDA 后端 + triton_sparse_attn_func, # Triton 后端 + flex_sparse_attn_func, # Flex 后端 # 后端可用性标志 CUDA_AVAILABLE, @@ -140,20 +139,20 @@ from flash_dmattn import ( ) # Transformers 集成 -from flash_dmattn.integrations.flash_dynamic_mask_attention import ( - flash_dynamic_mask_attention_forward +from flash_sparse_attn.integrations.flash_sparse_attention import ( + flash_sparse_attention_forward ) ``` ## 接口函数详解 -### flash_dmattn_func (CUDA 后端) +### flash_sparse_attn_func (CUDA 后端) 主要的注意力函数。支持多头注意力和分组查询注意力(当 KV 头数少于 Q 头数时)。需要 CUDA 扩展已构建并可用。 ```python -def flash_dmattn_func( +def flash_sparse_attn_func( query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) key: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) value: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) @@ -182,12 +181,12 @@ def flash_dmattn_func( - output: (B, Q, H, D) -### triton_dmattn_func (Triton 后端) +### triton_sparse_attn_func (Triton 后端) 基于 Triton 的实现,无需自定义 CUDA 内核即可提供良好性能。 ```python -def triton_dmattn_func( +def triton_sparse_attn_func( query: torch.Tensor, # (batch, seqlen_q, num_heads, 
head_dim) key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) @@ -198,12 +197,12 @@ def triton_dmattn_func( ) -> torch.Tensor ``` -### flex_dmattn_func (Flex Attention 后端) +### flex_sparse_attn_func (Flex Attention 后端) 基于 Flex Attention 的实现,使用 PyTorch 原生 flex attention 并支持动态掩码。 ```python -def flex_dmattn_func( +def flex_sparse_attn_func( query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) @@ -219,14 +218,14 @@ def flex_dmattn_func( ### Transformers 集成 -为 HuggingFace Transformers 模型提供的集成函数,提供无缝的 flash dynamic mask attention 支持。 +为 HuggingFace Transformers 模型提供的集成函数,提供无缝的 flash sparse attention 支持。 -#### flash_dynamic_mask_attention_forward +#### flash_sparse_attention_forward ```python -from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward +from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward -def flash_dynamic_mask_attention_forward( +def flash_sparse_attention_forward( module: torch.nn.Module, # 注意力模块 query: torch.Tensor, # (batch_size, num_heads, query_len, head_dim) key: torch.Tensor, # (batch_size, num_kv_heads, key_len, head_dim) @@ -253,7 +252,7 @@ def flash_dynamic_mask_attention_forward( - is_causal: 是否应用因果掩码 - window_size: 保持的窗口大小 - layer_idx: 用于日志的层索引 - - implementation: 使用的实现("flash_dmattn" 或 None) + - implementation: 使用的实现("flash_sparse_attn" 或 None) #### 返回值 @@ -267,7 +266,7 @@ import torch.nn as nn import torch.nn.functional as F from typing import Optional, Callable, tuple from transformers.cache_utils import Cache -from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward +from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward class DynamicMaskAttention(nn.Module): def __init__(self, config, layer_idx: Optional[int] = None): @@ -331,7 +330,7 @@ class DynamicMaskAttention(nn.Module): attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype) # 选择注意力实现 - attention_interface: Callable = flash_dynamic_mask_attention_forward + attention_interface: Callable = flash_sparse_attention_forward attn_output, attn_weights = attention_interface( self, @@ -361,7 +360,7 @@ class DynamicMaskAttention(nn.Module): ```python try: - from flash_dmattn import flash_dmattn_func_auto, get_available_backends + from flash_sparse_attn import flash_sparse_attn_func_auto, get_available_backends print("✅ 导入成功", get_available_backends()) except ImportError as e: print(f"❌ 导入失败: {e}") @@ -384,10 +383,10 @@ except ImportError as e: ```python import torch -from flash_dmattn import flash_dmattn_func_auto +from flash_sparse_attn import flash_sparse_attn_func_auto torch.autograd.set_detect_anomaly(True) -attn = flash_dmattn_func_auto() +attn = flash_sparse_attn_func_auto() output = attn(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True) if torch.isnan(output).any(): print("⚠️ 注意力输出中检测到 NaN") @@ -403,7 +402,7 @@ def print_memory_stats(): print(f"最大分配: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB") print_memory_stats() -attn = flash_dmattn_func_auto() +attn = flash_sparse_attn_func_auto() output = attn(q, k, v) print_memory_stats() ``` diff --git a/docs/integration.md b/docs/integration.md deleted file mode 100644 index 31a57d1..0000000 --- a/docs/integration.md +++ /dev/null @@ 
-1,2872 +0,0 @@ -# Flash Dynamic Mask Attention Integration Guide - -## Overview - -This document describes the integration of Dynamic Mask Attention into the Flash Attention framework. The integration enables efficient sparse attention computation by combining Flash Attention's memory-efficient approach with dynamic masking capabilities for handling extremely long sequences. - -The integration implements a unified sparse computation approach with block-level skip logic: Python frontend pre-computes Attention Mask and Attention Bias tensors, while the CUDA backend performs block-level skip decisions and sparse attention computation for both forward and backward passes. - -## Table of Contents - -1. [Integration Architecture](#integration-architecture) -2. [Core Modifications](#core-modifications) -3. [Implementation Details](#implementation-details) -4. [Sparse Computation Strategy](#sparse-computation-strategy) -5. [Memory Layout](#memory-layout) -6. [Performance Considerations](#performance-considerations) -7. [API Changes](#api-changes) - -## Integration Architecture - -### High-Level Design - -The Dynamic Mask Attention integration implements a unified sparse computation approach with block-level skip logic for both forward and backward passes: - -1. **Dynamic Mask Computation**: Python frontend pre-computes Attention Mask and Attention Bias tensors -2. **Unified Sparse Execution**: CUDA backend performs block-level skip decisions for both forward and backward passes -3. **Memory Optimization**: Smart shared memory aliasing and barrier synchronization - - -### Key Components - -- **Attention Mask**: Binary mask `(batch, num_kv_heads, query_len, key_len)` indicating which positions should be computed (1.0) or skipped (0.0) -- **Attention Bias**: Dynamic attention bias values `(batch, num_kv_heads, query_len, key_len)` applied to attention scores before softmax -- **Block-level Skip Logic**: Unified OR-reduction over (BlockM × BlockN) tiles to determine if computation should be performed -- **LSE Caching**: Log-sum-exp values cached during forward pass for numerically stable backward recomputation -- **Shared Memory Aliasing**: Smart memory reuse with explicit barrier synchronization -- **Complete Gradient Chain**: Full gradient computation pipeline with sparse skip capability -- **Memory Optimization**: Reduced shared memory footprint enabling larger tile sizes and higher occupancy - -## Core Modifications - -### 1. Parameter Structure Extensions (`flash.h`) - -**Purpose**: Extended parameter structures to support dynamic masking tensors with proper memory layout information. - -**Changes Made**: -```cpp -struct QKV_params { - // The QKV matrices. - void *__restrict__ q_ptr; // Query tensor [batch_size, num_heads, query_len, head_dim] - void *__restrict__ k_ptr; // Key tensor [batch_size, num_kv_heads, key_len, head_dim] - void *__restrict__ v_ptr; // Value tensor [batch_size, num_kv_heads, key_len, head_dim] - - // The stride between rows of the Q, K and V matrices. - index_t q_batch_stride, k_batch_stride, v_batch_stride; - index_t q_row_stride, k_row_stride, v_row_stride; - index_t q_head_stride, k_head_stride, v_head_stride; - - // The number of heads. - int h, h_k; - int h_h_k_ratio; // precompute h / h_k -}; - -struct Mask_params { - void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len] - - // The stride of the attention mask tensors. 
- index_t mask_batch_stride; // Stride between batches of attention mask - index_t mask_head_stride; // Stride between heads of attention mask - index_t mask_row_stride; // Stride between rows of attention mask -}; - -struct Bias_params { - void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len] - - // The stride of the attention bias tensor. - index_t bias_batch_stride; // Stride between batches of attention bias - index_t bias_head_stride; // Stride between heads of attention bias - index_t bias_row_stride; // Stride between rows of attention bias -}; - -struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params { - - // The O matrix (output). - void * __restrict__ o_ptr; - void * __restrict__ oaccum_ptr; - - // The stride between rows of O. - index_t o_batch_stride; - index_t o_row_stride; - index_t o_head_stride; - - // The pointer to the P matrix. - void * __restrict__ p_ptr; - - // The pointer to the softmax sum. - void * __restrict__ softmax_lse_ptr; - void * __restrict__ softmax_lseaccum_ptr; - - // The dimensions. - int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q; - - // The scaling factors for the kernel. - float scale_softmax; - float scale_softmax_log2; - float softcap; - - // array of length b+1 holding starting offset of each sequence. - int * __restrict__ cu_seqlens_q; - int * __restrict__ cu_seqlens_k; - int * __restrict__ leftpad_k; - - // If provided, the actual length of each k sequence. - int * __restrict__ seqused_k; - - int *__restrict__ blockmask; - - // The K_new and V_new matrices. - void * __restrict__ knew_ptr; - void * __restrict__ vnew_ptr; - - // The stride between rows of the K_new and V_new matrices. - index_t knew_batch_stride; - index_t vnew_batch_stride; - index_t knew_row_stride; - index_t vnew_row_stride; - index_t knew_head_stride; - index_t vnew_head_stride; - - // The cos and sin matrices for rotary embedding. - void * __restrict__ rotary_cos_ptr; - void * __restrict__ rotary_sin_ptr; - - // The indices to index into the KV cache. - int * __restrict__ cache_batch_idx; - - // Paged KV cache - int * __restrict__ block_table; - index_t block_table_batch_stride; - int page_block_size; - - bool is_bf16; - bool is_causal; - - // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. - // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. - bool is_seqlens_k_cumulative; - - int num_splits; // For split-KV version - - bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. - bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). -}; -``` - -**Rationale**: -- **Multiple Inheritance Design**: Cleanly separates QKV parameters from Mask/Bias parameters while maintaining unified access -- **Comprehensive Stride Information**: Provides all necessary stride information for efficient tensor indexing in CUDA kernels -- **Memory Layout Optimization**: Enables optimal memory access patterns for both regular and sparse tensors - -### 2. Kernel Traits and Memory Layout (`kernel_traits.h`) - -**Purpose**: Define kernel characteristics and memory layouts optimized for dynamic masking operations, supporting both SM75 and SM80+ architectures. 
- -**Changes Made**: -```cpp -template -struct Flash_kernel_traits { - using Element = elem_type; - using ElementAccum = float; - using index_t = int64_t; - - static constexpr int kHeadDim = kHeadDim_; - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kNWarps = kNWarps_; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - static constexpr bool Has_cp_async = true; - using MMA_Atom_Arch = std::conditional_t< - std::is_same_v, - MMA_Atom, - MMA_Atom - >; -#else - static constexpr bool Has_cp_async = false; - using MMA_Atom_Arch = MMA_Atom; -#endif - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#else - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#endif - - // Specialized traits for mask and bias operations - using SmemCopyAtomMask = SmemCopyAtom; - using SmemCopyAtomBias = SmemCopyAtom; -}; -``` - -**Rationale**: -- **Architecture Adaptation**: Automatically selects optimal MMA atoms and copy operations based on GPU architecture -- **Type Safety**: Template-based design ensures type consistency across mask, bias, and attention operations -- **Performance Optimization**: Leverages specialized load/store instructions (LDSM) for maximum memory bandwidth - -### 3. Block Information Extension (`block_info.h`) - -**Purpose**: Calculate memory offsets for attention bias and attention masks within thread blocks, enabling efficient global memory access. - -**Changes Made**: -```cpp -template -struct BlockInfo { - template - __device__ BlockInfo(const Params ¶ms, const int bidb) - : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) - , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) - , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) - , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) - , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : - (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k) - , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : - seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) - { - } - - template - __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; - } - - template - __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride; - } - - template - __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; - sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k); - return offset; - } - - template - __forceinline__ __device__ index_t bias_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; - sum_s_k == -1 ? 
offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k); - return offset; - } - - const int sum_s_q, sum_s_k; - const int actual_seqlen_q; - const int leftpad_k; - const int seqlen_k_cache; - const int actual_seqlen_k; -}; -``` - -**Rationale**: -- **Unified Offset Calculation**: Provides dedicated methods for calculating mask and bias tensor offsets -- **Variable Length Support**: Handles both fixed and variable length sequences through template specialization -- **Memory Access Optimization**: Encapsulates complex address arithmetic for efficient global memory access - -### 4. Memory Copy Operations (`utils.h`) - -**Purpose**: Implement efficient tensor operations and layout conversions optimized for Flash Attention's memory hierarchy. - -**Changes Made**: -```cpp -namespace FLASH_NAMESPACE { - -// Convert accumulator layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) -template -__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(rank(acc_layout))::value == 3); - auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); -}; - -// Type conversion utilities for different precisions -template -__forceinline__ __device__ T convert_type(float x) { - return T(x); -} - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template<> -__forceinline__ __device__ cutlass::bfloat16_t convert_type(float x) { - return cutlass::bfloat16_t(x); -} -#endif - -// Warp-level reduction operations -template -__forceinline__ __device__ float warp_reduce_sum(float x) { -#pragma unroll - for (int mask = THREADS / 2; mask > 0; mask >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, mask); - } - return x; -} - -// GEMM operations with register and shared memory variants -template < - bool A_in_regs=false, bool B_in_regs=false, - typename Tensor0, typename Tensor1, typename Tensor2, - typename Tensor3, typename Tensor4, - typename TiledMma, typename TiledCopyA, typename TiledCopyB, - typename ThrCopyA, typename ThrCopyB -> -__forceinline__ __device__ void gemm( - Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, - Tensor3 &tCsA, Tensor4 &tCsB, - TiledMma tiled_mma, - TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, - ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B -) { - if constexpr (!A_in_regs) { - copy(smem_tiled_copy_A, tCsA, tCrA); - } - if constexpr (!B_in_regs) { - copy(smem_tiled_copy_B, tCsB, tCrB); - } - - // Perform matrix multiplication - gemm(tiled_mma, acc, tCrA, tCrB, acc); -} - -} // namespace FLASH_NAMESPACE -``` - -**Rationale**: -- **Layout Conversion**: Efficient transformation between MMA and row-column layouts for easier tensor manipulation -- **Multi-Precision Support**: Proper type conversion utilities for FP16 and BF16 operations -- **Memory Hierarchy Management**: Flexible GEMM operations supporting different data residency patterns -- **Performance Optimization**: Warp-level reductions and vectorized operations for maximum throughput - -### 5. Dynamic Masking Logic (`mask.h`) - -**Purpose**: Implement the core dynamic masking functionality that applies attention bias and attention masks during attention computation. 
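For readers skimming the CUDA code that follows, this is a naive PyTorch restatement of the masking semantics (a sketch for clarity only, not the register-level implementation): positions past the causal column limit or with `mask == 0` are set to `-inf`, and every other score becomes `score * scale_softmax + bias`.

```python
import torch

def reference_apply_mask(scores, mask, bias, scale, causal=True):
    # scores, mask, bias: (seqlen_q, seqlen_k); mask uses 1.0 = compute, 0.0 = skip
    seqlen_q, seqlen_k = scores.shape
    row = torch.arange(seqlen_q, device=scores.device).unsqueeze(-1)
    col = torch.arange(seqlen_k, device=scores.device).unsqueeze(0)
    # col_idx_limit = causal ? min(seqlen_k, row + 1 + seqlen_k - seqlen_q) : seqlen_k
    col_limit = torch.clamp(row + 1 + seqlen_k - seqlen_q, max=seqlen_k) if causal else seqlen_k
    inactive = (col >= col_limit) | (mask == 0.0)
    return torch.where(inactive, torch.full_like(scores, float("-inf")), scores * scale + bias)
```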
- -**Changes Made**: -```cpp -template -__forceinline__ __device__ void apply_mask( - TensorType &tensor, - MaskType &mask, - BiasType &bias, - const float scale_softmax, - const int col_idx_offset_, - const int max_seqlen_k, - const int row_idx_offset, - const int max_seqlen_q, - const int warp_row_stride -) { - // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) - static_assert(TensorType::rank == 2, "Only support 2D Tensor"); - static_assert(MaskType::rank == 2, "Only support 2D Mask"); - static_assert(BiasType::rank == 2, "Only support 2D Bias"); - - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = Causal_mask ? - std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : - max_seqlen_k; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - // Apply scaling and bias or masking - tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f) - ? -INFINITY - : tensor(coord) * scale_softmax + bias(coord); - } - } - } - } -} - -template -struct Mask { - const int max_seqlen_k, max_seqlen_q; - - __forceinline__ __device__ Mask( - const int max_seqlen_k, - const int max_seqlen_q - ) // Constructor - : max_seqlen_k(max_seqlen_k) - , max_seqlen_q(max_seqlen_q) { - }; - - template - __forceinline__ __device__ void apply_mask( - TensorType &tensor_, // acc_s (attention scores, MMA=4, MMA_M, MMA_N) - MaskType &tSrMask, // Attention Mask (MMA=4, MMA_M, MMA_N) - BiasType &tSrBias, // Attention Bias (MMA=4, MMA_M, MMA_N) - const float scale_softmax, // Scale for softmax - const int col_idx_offset_, // Column index offset - const int row_idx_offset, // Row index offset - const int warp_row_stride // Warp row stride - ) { - // Reshape tensors from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout())); - Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); - Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); - - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = Causal_mask ? - std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : - max_seqlen_k; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - // Apply scaling and bias or masking - tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f) - ? 
-INFINITY - : tensor(coord) * scale_softmax + bias(coord); - } - } - } - } - } -}; -``` - -**Rationale**: -- **Register-Level Operations**: All masking operations performed in registers for maximum efficiency -- **Unified Masking Logic**: Combines causal masking, boundary checking, and dynamic masking in a single pass -- **Layout Conversion**: Properly handles MMA tensor layout conversion for efficient indexing -- **Numerical Stability**: Proper handling of infinity values for masked positions ensures stable softmax computation - -### 6. Backward Pass Integration (`flash_bwd_kernel.h`) - -**Purpose**: Extend backward pass computation to support dynamic masking with proper gradient computation for masked positions. - -**Changes Made**: -```cpp -struct Flash_bwd_params : public Flash_fwd_params { - - // The dO and dQKV and dBias matrices. - void *__restrict__ do_ptr; - void *__restrict__ dq_ptr; - void *__restrict__ dk_ptr; - void *__restrict__ dv_ptr; - void *__restrict__ dbias_ptr; - - // To accumulate dQ, dK, dV - void *__restrict__ dq_accum_ptr; - void *__restrict__ dk_accum_ptr; - void *__restrict__ dv_accum_ptr; - - // The stride between rows of the dO, dQ, dK and dV matrices. - index_t do_batch_stride; - index_t do_row_stride; - index_t do_head_stride; - index_t dq_batch_stride; - index_t dk_batch_stride; - index_t dv_batch_stride; - index_t dq_row_stride; - index_t dk_row_stride; - index_t dv_row_stride; - index_t dq_head_stride; - index_t dk_head_stride; - index_t dv_head_stride; - index_t dbias_batch_stride; - index_t dbias_head_stride; - index_t dbias_row_stride; - - // The pointer to the softmax d sum. - void *__restrict__ dsoftmax_sum; - - bool deterministic; - index_t dq_accum_split_stride; -}; - -template -inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { - // Backward pass computation with dynamic masking support - // Includes proper gradient computation through masked attention scores - // Maintains numerical stability for masked positions - - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Initialize block information and tensor views - const BlockInfo binfo(params, bidb); - - // Set up gradient computation with masking awareness - // Load bias and mask gradients when computing dBias - // Apply masking logic consistently with forward pass -} -``` - -**Rationale**: -- **Gradient Consistency**: Ensures gradients are computed consistently with forward pass masking logic -- **Memory Layout Preservation**: Maintains the same memory layout and stride patterns as forward pass -- **Numerical Stability**: Proper handling of gradients at masked positions to prevent NaN propagation - -### 7. Attention Kernel Modifications (`flash_fwd_kernel.h`) - -**Purpose**: Integrate dynamic masking into the core attention computation kernels while maintaining Flash Attention's memory efficiency and optimization strategies. 
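As a rough mental model of this main loop, the sketch below (an assumption-laden simplification: single head, no variable-length handling, float32 accumulation) streams K/V tiles with an online softmax and applies the block-level skip described under "Sparse Computation Strategy" later in this guide; fully masked rows are not handled.

```python
import torch

def forward_one_row_block(q, k, v, mask, bias, scale, block_n=128):
    # q: (M, d); k, v: (N, d); mask, bias: (M, N). mask > 0 marks active positions.
    M, _ = q.shape
    out = torch.zeros_like(q, dtype=torch.float32)
    row_max = torch.full((M,), float("-inf"), device=q.device)
    row_sum = torch.zeros(M, device=q.device)
    for n0 in range(0, k.shape[0], block_n):
        m_blk = mask[:, n0:n0 + block_n]
        if not m_blk.any():            # block-level skip: whole tile inactive,
            continue                   # so no K/V load and no GEMMs for this tile
        s = q.float() @ k[n0:n0 + block_n].float().T * scale + bias[:, n0:n0 + block_n].float()
        s = s.masked_fill(m_blk == 0, float("-inf"))
        new_max = torch.maximum(row_max, s.max(dim=-1).values)
        correction = torch.exp(row_max - new_max)   # rescale previously accumulated partials
        p = torch.exp(s - new_max[:, None])
        row_sum = row_sum * correction + p.sum(dim=-1)
        out = out * correction[:, None] + p @ v[n0:n0 + block_n].float()
        row_max = new_max
    return out / row_sum[:, None]
```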
- -**Changes Made**: -```cpp -template -inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Initialize block information - const BlockInfo binfo(params, bidb); - - // Set up tensor views for Q, K, V matrices - Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, Int{}), - make_stride(params.q_row_stride, _1{})); - - Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), - make_shape(binfo.actual_seqlen_k, Int{}), - make_stride(params.k_row_stride, _1{})); - - // Set up mask and bias tensor views if available - Tensor mMask, mBias; - if (params.mask_ptr != nullptr) { - mMask = make_tensor(make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.mask_row_stride, _1{})); - } - - if (params.bias_ptr != nullptr) { - mBias = make_tensor(make_gmem_ptr(reinterpret_cast(params.bias_ptr) + binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.bias_row_stride, _1{})); - } - - // Main computation loop with dynamic masking integration - for (int n_block = n_block_min; n_block < n_block_max; ++n_block) { - // Standard Flash Attention computation: Q*K^T - gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, - smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K); - - // Apply dynamic masking if mask/bias tensors are provided - if (params.mask_ptr != nullptr || params.bias_ptr != nullptr) { - Mask mask(params.seqlen_k, params.seqlen_q); - mask.apply_mask(acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * Kernel_traits::kBlockN, m_block * Kernel_traits::kBlockM, - Kernel_traits::kBlockM); - } - - // Continue with softmax computation - softmax.template softmax_rescale_o( - acc_s, acc_o, params.scale_softmax_log2 - ); - - // Attention * V computation - gemm(acc_o, acc_s, tSrV, acc_s, tSsV, tiled_mma, - smem_tiled_copy_S, smem_tiled_copy_V, - smem_thr_copy_S, smem_thr_copy_V); - } -} -``` - -**Rationale**: -- **Seamless Integration**: Dynamic masking logic integrated into existing Flash Attention computation flow without affecting core performance -- **Memory Efficiency Preservation**: Maintains Flash Attention's tiling and shared memory optimization strategies -- **Conditional Execution**: Only applies masking operations when mask/bias tensors are actually provided -- **Template Specialization**: Compile-time optimization eliminates runtime branching for better performance - -### 8. Launch Template Updates (`flash_fwd_launch_template.h`) - -**Purpose**: Update kernel launch templates to support dynamic masking functionality with proper template instantiation and dispatch logic. 
- -**Changes Made**: -```cpp -// Determine if the architecture supports FLASH and define parameter modifiers -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#define ARCH_SUPPORTS_FLASH -#define KERNEL_PARAM_MODIFIER __grid_constant__ -#else -#define KERNEL_PARAM_MODIFIER -#endif - -// Define unsupported architecture error handling -#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashDynamicMaskAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); - -// Kernel definition macro for cleaner code -#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ -template \ -__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) - -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { - #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn(params); - #else - FLASH_UNSUPPORTED_ARCH - #endif -} - -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split) { - #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn_splitkv(params); - #else - FLASH_UNSUPPORTED_ARCH - #endif -} - -template -void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr size_t smem_size = Kernel_traits::kSmemSize; - - // Handle different precision types and head dimensions - BOOL_SWITCH(params.is_bf16, Is_Bf16, [&] { - using elem_type = std::conditional_t; - HEADDIM_SWITCH(params.d, [&] { - BOOL_SWITCH(params.seqlen_k % Kernel_traits::kBlockN == 0, Is_even_N, [&] { - BOOL_SWITCH(params.d == kHeadDim, Is_even_K, [&] { - SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - auto kernel = &flash_fwd_kernel; - // Launch kernel with appropriate grid and block dimensions - kernel<<>>(params); - }); - }); - }); - }); - }); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -// Template instantiations for different configurations -template -void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream); -template -void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream); -template -void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream); -template -void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream); -template -void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream); -template -void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream); -``` - -**Rationale**: -- **Template Dispatch**: Efficient compile-time branching based on runtime parameters for optimal performance -- **Architecture Support**: Proper handling of different GPU architectures with appropriate error messages -- **Memory Management**: Correct shared memory allocation based on kernel requirements -- **Type Safety**: Strong typing through template parameters ensures correctness across different precisions - -**Purpose**: Update kernel launch functions to properly configure and validate dynamic masking parameters, ensuring correct shared memory allocation and kernel selection. 
- -**Changes Made**: -```cpp -template -void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - // Calculate shared memory requirements - constexpr size_t smem_size = Kernel_traits::kSmemSize; - - // Set up grid dimensions - const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; - dim3 grid(num_m_block, params.b, params.h); - - // Determine kernel variant based on sequence lengths and alignment - const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && - params.seqlen_k % Kernel_traits::kBlockN == 0 && - params.seqlen_q % Kernel_traits::kBlockM == 0; - const bool is_even_K = params.d == Kernel_traits::kHeadDim; - const bool return_softmax = params.p_ptr != nullptr; - - // Launch appropriate kernel variant with dynamic masking support - BOOL_SWITCH(is_even_MN, IsEvenMN, [&] { - BOOL_SWITCH(is_even_K, IsEvenK, [&] { - BOOL_SWITCH(return_softmax, ReturnSoftmax, [&] { - auto kernel = &flash_fwd_kernel; - - // Configure dynamic shared memory if needed - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - - // Launch kernel with extended parameter set - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); -} - -template -void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - // Split-K variant launch with dynamic masking support - // Handles cases where sequence length exceeds single kernel capacity - static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); - static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); - - // Configure split parameters based on sequence length and hardware capabilities - const int num_splits = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; - // ... split-K launch logic with dynamic masking support -} -``` - -**Rationale**: -- **Resource Management**: Proper shared memory allocation and validation for extended tensor requirements -- **Kernel Selection**: Intelligent kernel variant selection based on problem size and hardware capabilities -- **Error Handling**: Comprehensive validation of parameters and device limits -- **Performance Optimization**: Compile-time optimizations through template specialization - -### 9. API Interface Extensions (`flash_api.cpp`) - -**Purpose**: Extend the Python-facing API to support dynamic masking tensors with comprehensive validation and backward compatibility. - -**Changes Made**: -```cpp -void set_params_fprop( - Flash_fwd_params ¶ms, - // ... existing parameters ... - const at::Tensor mask, // Attention mask tensor - const at::Tensor bias, // Attention bias tensor - // ... other parameters ... -) { - // Reset parameters and set basic properties - params = {}; - params.is_bf16 = q.dtype() == torch::kBFloat16; - - // Set attention mask pointers and strides - params.mask_ptr = mask.data_ptr(); - params.mask_batch_stride = mask.stride(-4); - params.mask_head_stride = mask.stride(-3); - params.mask_row_stride = mask.stride(-2); - - // Set attention bias pointers and strides - params.bias_ptr = bias.data_ptr(); - params.bias_batch_stride = bias.stride(-4); - params.bias_head_stride = bias.stride(-3); - params.bias_row_stride = bias.stride(-2); - - // ... existing parameter setup ... 
-} - -std::vector mha_fwd( - at::Tensor &q, // Query tensor - const at::Tensor &k, // Key tensor - const at::Tensor &v, // Value tensor - const at::Tensor &mask, // Attention mask tensor - const at::Tensor &bias, // Attention bias tensor - std::optional &out_, // Optional output tensor - const float softmax_scale, - bool is_causal, - const float softcap, - const bool return_softmax -) { - // Comprehensive input validation - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - CHECK_DEVICE(mask); CHECK_DEVICE(bias); - CHECK_CONTIGUOUS(q); CHECK_CONTIGUOUS(k); CHECK_CONTIGUOUS(v); - CHECK_CONTIGUOUS(mask); CHECK_CONTIGUOUS(bias); - - // Validate tensor shapes - auto batch_size = q.size(0); - auto seqlen_q = q.size(1); - auto num_heads = q.size(2); - auto head_dim = q.size(3); - auto seqlen_k = k.size(1); - auto num_heads_k = k.size(2); - - CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); - CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); - - // Validate data types consistency - TORCH_CHECK(q.dtype() == k.dtype() && k.dtype() == v.dtype(), - "All QKV tensors must have the same dtype"); - TORCH_CHECK(mask.dtype() == q.dtype(), - "Attention mask must have the same dtype as QKV tensors"); - TORCH_CHECK(bias.dtype() == q.dtype(), - "Attention bias must have the same dtype as QKV tensors"); - - // Set up parameters and launch computation - Flash_fwd_params params; - set_params_fprop(params, batch_size, seqlen_q, seqlen_k, /* ... */, - q, k, v, mask, bias, /* ... */); - - // Launch kernel with appropriate configuration - run_mha_fwd(params, at::cuda::getCurrentCUDAStream()); - - // Return results - return {out, softmax_lse}; -} - -// Python binding -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashDynamicMaskAttention"; - m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass with dynamic masking", - py::arg("q"), py::arg("k"), py::arg("v"), - py::arg("mask"), py::arg("bias"), // Updated arguments - py::arg("out") = py::none(), - py::arg("softmax_scale") = 0.0f, - py::arg("is_causal") = false, - py::arg("softcap") = 0.0f, - py::arg("return_softmax") = false); -} -``` - -**Rationale**: -- **Comprehensive Validation**: Thorough validation of all input tensors for shape, type, and device consistency -- **Backward Compatibility**: Maintains existing parameter order while adding new functionality -- **Error Handling**: Clear error messages for common usage mistakes -- **Type Safety**: Strict type checking to prevent runtime errors -- **Documentation**: Clear parameter documentation for Python users - -## Implementation Details - -### C++ API Interface (`flash_api.cpp`) - -The core C++ API provides the following main functions for Dynamic Mask Attention: - -```cpp -namespace FLASH_NAMESPACE { - -std::vector mha_fwd( - at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k - const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k - std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const float softmax_scale, - bool is_causal, - const float softcap, - const bool return_softmax -); - -std::vector mha_varlen_fwd( - at::Tensor &q, // total_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &k, // total_k x 
num_heads_k x round_multiple(head_size, 8) - const at::Tensor &v, // total_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k - const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k - std::optional &out_, // total_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &cu_seqlens_q, // batch_size + 1 - const at::Tensor &cu_seqlens_k, // batch_size + 1 - std::optional &seqused_k, - std::optional &leftpad_k, - const int max_seqlen_q, - const int max_seqlen_k, - const float softmax_scale, - bool is_causal, - const float softcap, - const bool return_softmax -); - -std::vector mha_bwd( - const at::Tensor &dout, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k - const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k - const at::Tensor &out, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &softmax_lse, // batch_size x num_heads x seqlen_q - std::optional &dq_, - std::optional &dk_, - std::optional &dv_, - std::optional &dbias_, - const float softmax_scale, - bool is_causal, - const float softcap, - bool deterministic, - std::optional gen_ -); - -} // namespace FLASH_NAMESPACE -``` - -### Parameter Setup and Validation - -The implementation includes comprehensive parameter validation and setup: - -```cpp -void set_params_fprop( - Flash_fwd_params ¶ms, - const size_t b, const size_t seqlen_q, const size_t seqlen_k, - const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, - const size_t h, const size_t h_k, const size_t d, const size_t d_rounded, - const at::Tensor q, const at::Tensor k, const at::Tensor v, - const at::Tensor mask, const at::Tensor bias, at::Tensor out, - void *cu_seqlens_q_d, void *cu_seqlens_k_d, void *seqused_k, - void *p_d, void *softmax_lse_d, float softmax_scale, bool is_causal, - const float softcap, bool seqlenq_ngroups_swapped=false, - const bool unpadded_lse=false -) { - // Reset parameters - params = {}; - params.is_bf16 = q.dtype() == torch::kBFloat16; - - // Set tensor pointers - params.q_ptr = q.data_ptr(); - params.k_ptr = k.data_ptr(); - params.v_ptr = v.data_ptr(); - params.mask_ptr = mask.data_ptr(); - params.bias_ptr = bias.data_ptr(); - params.o_ptr = out.data_ptr(); - - // Set stride information (all strides are in elements, not bytes) - params.q_row_stride = q.stride(-3); - params.k_row_stride = k.stride(-3); - params.v_row_stride = v.stride(-3); - params.mask_row_stride = mask.stride(-2); - params.bias_row_stride = bias.stride(-2); - params.o_row_stride = out.stride(-3); - - params.q_head_stride = q.stride(-2); - params.k_head_stride = k.stride(-2); - params.v_head_stride = v.stride(-2); - params.mask_head_stride = mask.stride(-3); - params.bias_head_stride = bias.stride(-3); - params.o_head_stride = out.stride(-2); - - // Set batch stride information - if (cu_seqlens_q_d == nullptr) { - params.q_batch_stride = q.stride(0); - params.k_batch_stride = k.stride(0); - params.v_batch_stride = v.stride(0); - params.mask_batch_stride = mask.stride(0); - params.bias_batch_stride = bias.stride(0); - params.o_batch_stride = out.stride(0); - } - - // Set sequence 
length and dimension parameters - params.b = b; params.h = h; params.h_k = h_k; - params.h_h_k_ratio = h / h_k; - params.seqlen_q = seqlen_q; params.seqlen_k = seqlen_k; - params.seqlen_q_rounded = seqlen_q_rounded; - params.seqlen_k_rounded = seqlen_k_rounded; - params.d = d; params.d_rounded = d_rounded; - - // Set scaling and control parameters - params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = softmax_scale * M_LOG2E; - params.softcap = softcap; - params.is_causal = is_causal; - params.unpadded_lse = unpadded_lse; - params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped; -} -``` - -### Python Binding and Interface - -The C++ functions are exposed to Python through PyBind11: - -```cpp -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashDynamicMaskAttention"; - m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass"); - m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length"); - m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass"); - m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length"); -} -``` - -### Python Frontend Integration Example - -Dynamic Mask Attention can be integrated into transformer models as follows: - -```python -import torch -import torch.nn as nn -import flash_dmattn_cuda as flash_dmattn - -class DynamicMaskAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_attention_heads - self.scaling = 1.0 / math.sqrt(self.head_dim) - - # Standard attention projections - self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) - - def forward(self, hidden_states, attention_mask=None, attention_bias=None): - batch_size, seq_len, _ = hidden_states.shape - - # Project to Q, K, V - query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - - # Prepare mask and bias tensors with proper shapes - if attention_mask is None: - attention_mask = torch.ones((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - if attention_bias is None: - attention_bias = torch.zeros((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - # Call Flash Dynamic Mask Attention - output, _ = flash_dmattn.fwd( - query_states, key_states, value_states, - attention_mask, attention_bias, - None, # out - self.scaling, # softmax_scale - False, # is_causal - 0.0, # softcap - False # return_softmax - ) - - # Output projection - output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) - return self.o_proj(output) -``` - - # Call attention implementation - attn_output, attn_weights = flash_dynamic_mask_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - attention_bias=attn_bias, - 
scaling=self.scaling, - ) - - return attn_output, attn_weights -``` - -The attention bias generation process: - -1. **Value-based Dynamic States**: - ```python - dt_states = self.dt_proj(value_states_flattened) - dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) - ``` - -2. **Bias Expansion**: - ```python - attn_bias = dt_states[:, :, None, :].expand(-1, -1, query_len, -1) - ``` - -3. **Mask Processing**: Done internally in `_flash_dynamic_mask_attention_forward` - - -### CUDA Backend: Sparse Attention Computation - -The CUDA backend implements the sparse attention computation through `_flash_dynamic_mask_attention_forward`: - -```python -def _flash_dynamic_mask_attention_forward( - query_states, key_states, value_states, - attention_mask, attention_bias, - query_length, key_length, - is_causal, softmax_scale=None, softcap=None, - target_dtype=None, implementation=None, **kwargs -): - dtype = query_states.dtype - min_dtype = torch.finfo(dtype).min - batch_size, _, num_kv_heads, _ = key_states.shape - - # Initialize attention bias if not provided - if attention_bias is None: - attention_bias = torch.zeros( - (batch_size, num_kv_heads, query_length, key_length), - dtype=dtype, device=query_states.device - ) - - # Apply attention mask to bias - if attention_mask is not None: - attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) - attention_mask = attention_mask.to(dtype) - - # Call Flash Attention with dynamic masking - out = flash_dmattn_func( - query_states, key_states, value_states, - attn_mask=attention_mask, attn_bias=attention_bias, - softmax_scale=softmax_scale, is_causal=is_causal - ) - - return out[0] if isinstance(out, tuple) else out -``` - -The backend processing stages: - -1. **Bias Initialization**: Create zero bias tensor if not provided -2. **Mask Application**: Apply boolean attention mask to bias tensor -3. 
**Flash Attention Call**: Execute optimized CUDA kernels with sparse patterns - -#### Updated Forward Algorithm - -The implementation introduces unified block-level skip logic that optimizes computation by skipping entire tiles when they are fully masked: - -```cpp -// Forward pass with unified skip logic -for m_block in M_tiles: - load Q_tile - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Block-level skip decision - if !any_active: - advance_pointers() // Skip computation, advance to next tile - continue - - // Only execute for active tiles - load K_tile, V_tile // Load data only when needed - S = Q_tile @ K_tile^T + bias_block // Sparse Q*K^T GEMM - S_masked = apply_mask(S, mask_block) // Apply dynamic masking - P = softmax(S_masked, LSE_cache) // Softmax with LSE caching - O_partial += P @ V_tile // Sparse Score*V GEMM -write O -``` - -Key improvements: -- **Block-level Skip Logic**: OR-reduction over entire (BlockM × BlockN) tile determines if computation is needed -- **Early Skip Decision**: Mask evaluation happens before expensive K/V loading and computation -- **Pointer Management**: Safe pointer advancement ensures correct memory layout for subsequent tiles - -#### Updated Backward Algorithm - -The backward pass also benefits from the unified skip logic, maintaining numerical correctness while significantly reducing computation for sparse patterns: - -```cpp -// Backward pass with unified skip logic -for m_block in reversed(M_tiles): - load Q_tile, dO_tile - init accum_dQ - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Same skip decision as forward - if !any_active: - advance_pointers_zero_side_outputs() // Skip computation, zero side outputs - continue - - // Only execute for active tiles - load K_tile, V_tile - - # Recompute (identical to forward for active tiles) - S = Q_tile @ K_tile^T + bias_block - P = softmax(S, LSE_cache) // Use cached LSE for stability - - # Gradient computation chain (5 GEMMs) - dV += P^T @ dO_tile // Accumulate dV - dP = dO_tile @ V_tile^T // Compute dP - dS = g(P, dP) // dS = (dP - (P ⊙ dP).sum(axis)) * P - dQ += dS @ K_tile // Accumulate dQ - dK += dS^T @ Q_tile // Accumulate dK - write dQ, accumulate dK, dV -``` - -Key features: -- **Recomputation Strategy**: Forward computation is recomputed only for active tiles to maintain numerical precision -- **LSE Caching**: Uses cached log-sum-exp values from forward pass for stable softmax recomputation -- **Gradient Chain**: All five gradient GEMMs are skipped for fully masked tiles, maintaining mathematical correctness -- **Zero Handling**: Properly handles zero contributions from skipped tiles in accumulation - -#### Skip Logic Correctness - -The mathematical correctness of the skip logic relies on the following principles: - -1. **Forward Skip**: If a tile is entirely masked (active_mask = 0), its contribution to the output is exactly zero: - ``` - O_contribution = P @ V = 0 @ V = 0 - ``` - -2. **Backward Skip**: For fully masked tiles, all intermediate gradients are zero: - ``` - P = 0 ⟹ dS = 0 ⟹ dQ = dK = dV = 0 (from this tile) - ``` - -3. **LSE Preservation**: Skipped tiles don't contribute to the log-sum-exp, maintaining numerical stability. - -### Sparse Computation Strategy - -### Block-level Skip Logic - -The implementation introduces unified block-level skip logic that operates at the tile granularity rather than individual elements: - -1. 
**Tile-level Active Detection**: - ```cpp - any_active = OR_reduce(mask_block) // Single bit indicating if any position in tile is active - ``` - -2. **Skip Decision**: Binary branch based on tile activity: - ```cpp - if (!any_active) { - advance_pointers(); // Forward: skip all computation - advance_pointers_zero_outputs(); // Backward: skip computation, zero side outputs - continue; - } - ``` - -3. **Computational Benefits**: - - Skip entire K/V loads for inactive tiles - - Eliminate all 5 GEMMs in backward pass for inactive tiles - - Reduce memory bandwidth and arithmetic operations proportional to sparsity - -### Sparsity Pattern Recognition - -The Dynamic Mask Attention implements structured sparsity based on learned importance scores: - -1. **Attention Bias Computation**: Attention bias values are computed based on dynamic states derived from value tensors - - Learned projection matrices map value features to importance scores - - Coefficient parameters control the dynamic range of importance values - - Activation functions ensure appropriate bias magnitude - -2. **Binary Attention Mask**: - - 1.0 for positions that should be computed - - 0.0 for positions that should be skipped - -### Performance Model (Updated) - -For block-level sparsity with active tile fraction $p$, skip overhead ratio $\varepsilon$, and early-exit efficiency $\eta$: - -$$ -\text{Speedup} \approx \frac{1}{p + (1-p)(\varepsilon + \eta \cdot \text{LoadOverhead})} -$$ - -Where: -- $p$: fraction of active tiles -- $\varepsilon$: skip branching overhead -- $\eta$: efficiency of early memory load exit -- $\text{LoadOverhead}$: relative cost of K/V loading vs computation - -Upper bound as $\varepsilon, \eta \to 0$: $1/p$ - -### Shared Memory Aliasing - -The implementation introduces smart shared memory aliasing to reduce footprint and enable larger tile sizes: - -1. **sMask ↔ sP Aliasing**: Mask shared memory region is reused for storing softmax probabilities P after mask consumption -2. **sBias ↔ sdS Aliasing**: Bias shared memory region is reused for gradient computations dS -3. **Barrier Synchronization**: Explicit `__syncthreads()` calls ensure safe transitions between aliased usage - -```cpp -// Example aliasing pattern -load mask -> sMask -any_active = or_reduce(sMask) -if any_active: - compute S - __syncthreads() // ensure mask fully consumed - softmax -> write P into aliased region (sP) // reuse sMask region as sP - ... -__syncthreads() // ensure dS consumed -// reuse sBias region as sdS in next iteration -``` - -### Memory Efficiency Optimizations - -1. **Shared Memory Aliasing**: Smart reuse of memory regions (sMask ↔ sP, sBias ↔ sdS) with explicit barrier synchronization -2. **Block-level Skip**: Early exit from computation and memory loading for inactive tiles -3. **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -4. **Register-Optimized Operations**: Critical masking and gradient operations performed in register memory -5. **Coalesced Memory Access**: Optimized access patterns for GPU memory hierarchy -6. 
**Template Specialization**: Compile-time optimization eliminates runtime branching overhead - -## Memory Layout - -### Tensor Memory Organization - -The Dynamic Mask Attention extends Flash Attention's memory layout to include attention masks and attention bias: - -``` -Global Memory Layout: -┌─────────────────────────────────────────────────────────────────┐ -│ Q: [batch, seqlen_q, num_heads, head_dim] │ -│ K: [batch, seqlen_k, num_heads_k, head_dim] │ -│ V: [batch, seqlen_k, num_heads_k, head_dim] │ -│ Mask: [batch, num_heads_k, seqlen_q, seqlen_k] │ -│ Bias: [batch, num_heads_k, seqlen_q, seqlen_k] │ -│ Output: [batch, seqlen_q, num_heads, head_dim] │ -└─────────────────────────────────────────────────────────────────┘ - -Shared Memory Layout (per thread block): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Tile: [kBlockM, head_dim] │ K Tile: [kBlockN, head_dim] │ -│ V Tile: [kBlockN, head_dim] │ S Tile: [kBlockM, kBlockN] │ -│ AM Tile: [kBlockM, kBlockN] │ Bias Tile: [kBlockM, kBlockN] │ -└─────────────────────────────────────────────────────────────────────┘ - -Register Memory (per thread): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Frag: [MMA_M, head_dim/N] │ K Frag: [MMA_N, head_dim/N] │ -│ V Frag: [MMA_N, head_dim/N] │ S Frag: [MMA_M, MMA_N] │ -│ AM Frag: [MMA_M, MMA_N] │ Bias Frag: [MMA_M, MMA_N] │ -│ Acc Frag: [MMA_M, head_dim/N] │ │ -└─────────────────────────────────────────────────────────────────────┘ -``` - -### Memory Access Patterns - -#### Attention Mask and Attention Bias Loading -```cpp -// Global to Shared Memory (coalesced access) -Tensor tSgBias = local_partition(mBias, smem_tiled_copy_Bias, thread_idx); -Tensor tSsBias = local_partition(sBias, smem_tiled_copy_Bias, thread_idx); - -// Each thread loads a contiguous chunk to maximize memory bandwidth -copy(smem_tiled_copy_Bias, tSgBias, tSsBias); - -// Shared to Register Memory (bank-conflict-free) -Tensor tSrBias = local_partition(sBias, smem_thr_copy_Bias, thread_idx); -copy(smem_thr_copy_Bias, tSsBias, tSrBias); -``` - -#### Memory Layout Transformations -```cpp -// Convert MMA accumulator layout to row-column layout for masking -// From: (MMA=4, MMA_M, MMA_N) -> (nrow=(2, MMA_M), ncol=(2, MMA_N)) -auto convert_layout_acc_rowcol = [](auto layout) { - return make_layout( - make_layout(make_shape(Int<2>{}, get<1>(layout.shape())), - make_stride(Int(layout.stride())* 2>{}, get<1>(layout.stride()))), - make_layout(make_shape(Int<2>{}, get<2>(layout.shape())), - make_stride(Int<1>{}, Int<2>{})) - ); -}; -``` - -### Shared Memory Optimization - -#### Bank Conflict Avoidance -- Attention bias and attention masks use the same copy patterns as Q/K/V to avoid bank conflicts -- Padding added when necessary to ensure 128-bit aligned access -- Thread block size chosen to maximize occupancy while maintaining memory efficiency - -#### Memory Coalescing -```cpp -// Example: Loading 128-bit aligned chunks for optimal bandwidth -using SmemCopyAtomBias = Copy_Atom; // 128-bit loads -using SmemCopyAtomAttnMask = Copy_Atom; -``` - -## Performance Considerations - -### Memory Efficiency -- **Shared Memory Aliasing**: Smart memory reuse (sMask ↔ sP, sBias ↔ sdS) reduces footprint by ~30% -- **Block-level Skip**: Early exit eliminates unnecessary memory loads for inactive tiles -- **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -- **Coalesced Access**: Optimized tensor layouts for GPU memory hierarchy - -### 
Computational Efficiency -- **Unified Skip Logic**: Both forward and backward passes benefit from block-level computation skipping -- **5-GEMM Chain Skip**: Complete gradient computation chain skipped for inactive tiles -- **Early Branch Decision**: Mask OR-reduction happens before expensive K/V loads -- **Warp-Level Optimization**: Operations optimized for GPU warp execution model - -### Scalability -- **Block-level Granularity**: Tile-level sparsity more efficient than element-level for long sequences -- **Multi-Head Support**: Efficient handling of multiple attention heads with per-head sparsity patterns -- **Barrier Optimization**: Minimal synchronization overhead through smart aliasing strategies - -### Performance Model - -Expected speedup for various sparsity levels: -- **50% sparsity**: ~1.8x speedup -- **75% sparsity**: ~3.2x speedup -- **90% sparsity**: ~6.5x speedup - -Performance factors: -- Skip overhead typically <5% of dense computation time -- Memory bandwidth reduction scales linearly with sparsity -- Shared memory aliasing enables 20-30% larger tile sizes - -## API Changes - -### New Required Parameters - -The Dynamic Mask Attention integration introduces new required parameters to the forward pass: - -- **`attn_mask`** (`torch.Tensor`): Attention mask tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Binary mask (1.0 = compute, 0.0 = skip) indicating which positions should be processed - - Determines the sparsity pattern for computational efficiency - -- **`attn_bias`** (`torch.Tensor`): Attention bias tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Contains dynamic attention bias values applied to attention scores before softmax - - Must have the same dtype and device as Q/K/V tensors - -### Updated Function Signature - -```python -def fwd( - q: torch.Tensor, # Query tensor - k: torch.Tensor, # Key tensor - v: torch.Tensor, # Value tensor - attn_mask: torch.Tensor, # Attention mask (REQUIRED) - attn_bias: torch.Tensor, # Attention bias (REQUIRED) - out: Optional[torch.Tensor] = None, # Pre-allocated output - softmax_scale: float = None, # Attention scaling - is_causal: bool = False, # Causal masking - softcap: float = 0.0, # Soft capping - return_softmax: bool = False, # Return attention weights -) -> List[torch.Tensor] -``` - -### Backward Compatibility - -**Breaking Change Notice**: The integration requires attention bias and attention mask tensors as mandatory parameters. This is a breaking change from the original Flash Attention API. - -**Migration Path**: Users need to: -1. Add attention mask and bias generation logic to attention modules -2. Implement appropriate mask and bias computation within the attention forward pass -3. 
Ensure proper tensor shapes and dtypes for mask and bias tensors - -### Complete Usage Example - -```python -import torch -import torch.nn as nn -import flash_dmattn_cuda as flash_dmattn - -class DynamicMaskAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_attention_heads - self.scaling = 1.0 / math.sqrt(self.head_dim) - - # Standard attention projections - self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) - - def forward(self, hidden_states, attention_mask=None, attention_bias=None): - batch_size, seq_len, _ = hidden_states.shape - - # Project to Q, K, V - query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - - # Prepare mask and bias tensors with proper shapes - if attention_mask is None: - attention_mask = torch.ones((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - if attention_bias is None: - attention_bias = torch.zeros((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - # Call Flash Dynamic Mask Attention - output, _ = flash_dmattn.fwd( - query_states, key_states, value_states, - attention_mask, attention_bias, - None, # out - self.scaling, # softmax_scale - False, # is_causal - 0.0, # softcap - False # return_softmax - ) - - # Output projection - output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) - return self.o_proj(output) -``` - - # Call attention implementation - attn_output, attn_weights = flash_dynamic_mask_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - attention_bias=attn_bias, - scaling=self.scaling, - ) - - return attn_output, attn_weights -``` - -The attention bias generation process: - -1. **Value-based Dynamic States**: - ```python - dt_states = self.dt_proj(value_states_flattened) - dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) - ``` - -2. **Bias Expansion**: - ```python - attn_bias = dt_states[:, :, None, :].expand(-1, -1, query_len, -1) - ``` - -3. 
**Mask Processing**: Done internally in `_flash_dynamic_mask_attention_forward` - - -### CUDA Backend: Sparse Attention Computation - -The CUDA backend implements the sparse attention computation through `_flash_dynamic_mask_attention_forward`: - -```python -def _flash_dynamic_mask_attention_forward( - query_states, key_states, value_states, - attention_mask, attention_bias, - query_length, key_length, - is_causal, softmax_scale=None, softcap=None, - target_dtype=None, implementation=None, **kwargs -): - dtype = query_states.dtype - min_dtype = torch.finfo(dtype).min - batch_size, _, num_kv_heads, _ = key_states.shape - - # Initialize attention bias if not provided - if attention_bias is None: - attention_bias = torch.zeros( - (batch_size, num_kv_heads, query_length, key_length), - dtype=dtype, device=query_states.device - ) - - # Apply attention mask to bias - if attention_mask is not None: - attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) - attention_mask = attention_mask.to(dtype) - - # Call Flash Attention with dynamic masking - out = flash_dmattn_func( - query_states, key_states, value_states, - attn_mask=attention_mask, attn_bias=attention_bias, - softmax_scale=softmax_scale, is_causal=is_causal - ) - - return out[0] if isinstance(out, tuple) else out -``` - -The backend processing stages: - -1. **Bias Initialization**: Create zero bias tensor if not provided -2. **Mask Application**: Apply boolean attention mask to bias tensor -3. **Flash Attention Call**: Execute optimized CUDA kernels with sparse patterns - -#### Updated Forward Algorithm - -The implementation introduces unified block-level skip logic that optimizes computation by skipping entire tiles when they are fully masked: - -```cpp -// Forward pass with unified skip logic -for m_block in M_tiles: - load Q_tile - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Block-level skip decision - if !any_active: - advance_pointers() // Skip computation, advance to next tile - continue - - // Only execute for active tiles - load K_tile, V_tile // Load data only when needed - S = Q_tile @ K_tile^T + bias_block // Sparse Q*K^T GEMM - S_masked = apply_mask(S, mask_block) // Apply dynamic masking - P = softmax(S_masked, LSE_cache) // Softmax with LSE caching - O_partial += P @ V_tile // Sparse Score*V GEMM -write O -``` - -Key improvements: -- **Block-level Skip Logic**: OR-reduction over entire (BlockM × BlockN) tile determines if computation is needed -- **Early Skip Decision**: Mask evaluation happens before expensive K/V loading and computation -- **Pointer Management**: Safe pointer advancement ensures correct memory layout for subsequent tiles - -#### Updated Backward Algorithm - -The backward pass also benefits from the unified skip logic, maintaining numerical correctness while significantly reducing computation for sparse patterns: - -```cpp -// Backward pass with unified skip logic -for m_block in reversed(M_tiles): - load Q_tile, dO_tile - init accum_dQ - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Same skip decision as forward - if !any_active: - advance_pointers_zero_side_outputs() // Skip computation, zero side outputs - continue - - // Only execute for active tiles - load K_tile, V_tile - - # Recompute (identical to forward for active tiles) - S = Q_tile @ K_tile^T + bias_block - P = softmax(S, LSE_cache) // Use cached LSE for stability - - # Gradient computation chain (5 GEMMs) - dV += P^T @ dO_tile // Accumulate dV - dP = 
dO_tile @ V_tile^T // Compute dP - dS = g(P, dP) // dS = (dP - (P ⊙ dP).sum(axis)) * P - dQ += dS @ K_tile // Accumulate dQ - dK += dS^T @ Q_tile // Accumulate dK - write dQ, accumulate dK, dV -``` - -Key features: -- **Recomputation Strategy**: Forward computation is recomputed only for active tiles to maintain numerical precision -- **LSE Caching**: Uses cached log-sum-exp values from forward pass for stable softmax recomputation -- **Gradient Chain**: All five gradient GEMMs are skipped for fully masked tiles, maintaining mathematical correctness -- **Zero Handling**: Properly handles zero contributions from skipped tiles in accumulation - -#### Skip Logic Correctness - -The mathematical correctness of the skip logic relies on the following principles: - -1. **Forward Skip**: If a tile is entirely masked (active_mask = 0), its contribution to the output is exactly zero: - ``` - O_contribution = P @ V = 0 @ V = 0 - ``` - -2. **Backward Skip**: For fully masked tiles, all intermediate gradients are zero: - ``` - P = 0 ⟹ dS = 0 ⟹ dQ = dK = dV = 0 (from this tile) - ``` - -3. **LSE Preservation**: Skipped tiles don't contribute to the log-sum-exp, maintaining numerical stability. - -### Sparse Computation Strategy - -### Block-level Skip Logic - -The implementation introduces unified block-level skip logic that operates at the tile granularity rather than individual elements: - -1. **Tile-level Active Detection**: - ```cpp - any_active = OR_reduce(mask_block) // Single bit indicating if any position in tile is active - ``` - -2. **Skip Decision**: Binary branch based on tile activity: - ```cpp - if (!any_active) { - advance_pointers(); // Forward: skip all computation - advance_pointers_zero_outputs(); // Backward: skip computation, zero side outputs - continue; - } - ``` - -3. **Computational Benefits**: - - Skip entire K/V loads for inactive tiles - - Eliminate all 5 GEMMs in backward pass for inactive tiles - - Reduce memory bandwidth and arithmetic operations proportional to sparsity - -### Sparsity Pattern Recognition - -The Dynamic Mask Attention implements structured sparsity based on learned importance scores: - -1. **Attention Bias Computation**: Attention bias values are computed based on dynamic states derived from value tensors - - Learned projection matrices map value features to importance scores - - Coefficient parameters control the dynamic range of importance values - - Activation functions ensure appropriate bias magnitude - -2. **Binary Attention Mask**: - - 1.0 for positions that should be computed - - 0.0 for positions that should be skipped - -### Performance Model (Updated) - -For block-level sparsity with active tile fraction $p$, skip overhead ratio $\varepsilon$, and early-exit efficiency $\eta$: - -$$ -\text{Speedup} \approx \frac{1}{p + (1-p)(\varepsilon + \eta \cdot \text{LoadOverhead})} -$$ - -Where: -- $p$: fraction of active tiles -- $\varepsilon$: skip branching overhead -- $\eta$: efficiency of early memory load exit -- $\text{LoadOverhead}$: relative cost of K/V loading vs computation - -Upper bound as $\varepsilon, \eta \to 0$: $1/p$ - -### Shared Memory Aliasing - -The implementation introduces smart shared memory aliasing to reduce footprint and enable larger tile sizes: - -1. **sMask ↔ sP Aliasing**: Mask shared memory region is reused for storing softmax probabilities P after mask consumption -2. **sBias ↔ sdS Aliasing**: Bias shared memory region is reused for gradient computations dS -3. 
**Barrier Synchronization**: Explicit `__syncthreads()` calls ensure safe transitions between aliased usage - -```cpp -// Example aliasing pattern -load mask -> sMask -any_active = or_reduce(sMask) -if any_active: - compute S - __syncthreads() // ensure mask fully consumed - softmax -> write P into aliased region (sP) // reuse sMask region as sP - ... -__syncthreads() // ensure dS consumed -// reuse sBias region as sdS in next iteration -``` - -### Memory Efficiency Optimizations - -1. **Shared Memory Aliasing**: Smart reuse of memory regions (sMask ↔ sP, sBias ↔ sdS) with explicit barrier synchronization -2. **Block-level Skip**: Early exit from computation and memory loading for inactive tiles -3. **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -4. **Register-Optimized Operations**: Critical masking and gradient operations performed in register memory -5. **Coalesced Memory Access**: Optimized access patterns for GPU memory hierarchy -6. **Template Specialization**: Compile-time optimization eliminates runtime branching overhead - -## Memory Layout - -### Tensor Memory Organization - -The Dynamic Mask Attention extends Flash Attention's memory layout to include attention masks and attention bias: - -``` -Global Memory Layout: -┌─────────────────────────────────────────────────────────────────┐ -│ Q: [batch, seqlen_q, num_heads, head_dim] │ -│ K: [batch, seqlen_k, num_heads_k, head_dim] │ -│ V: [batch, seqlen_k, num_heads_k, head_dim] │ -│ AttnMask: [batch, num_kv_heads, seqlen_q, seqlen_k] │ -│ Bias: [batch, num_kv_heads, seqlen_q, seqlen_k] │ -│ Output: [batch, seqlen_q, num_heads, head_dim] │ -└─────────────────────────────────────────────────────────────────┘ - -Shared Memory Layout (per thread block): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Tile: [kBlockM, head_dim] │ K Tile: [kBlockN, head_dim] │ -│ V Tile: [kBlockN, head_dim] │ S Tile: [kBlockM, kBlockN] │ -│ AM Tile: [kBlockM, kBlockN] │ Bias Tile: [kBlockM, kBlockN] │ -└─────────────────────────────────────────────────────────────────────┘ - -Register Memory (per thread): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Frag: [MMA_M, head_dim/N] │ K Frag: [MMA_N, head_dim/N] │ -│ V Frag: [MMA_N, head_dim/N] │ S Frag: [MMA_M, MMA_N] │ -│ AM Frag: [MMA_M, MMA_N] │ Bias Frag: [MMA_M, MMA_N] │ -│ Acc Frag: [MMA_M, head_dim/N] │ │ -└─────────────────────────────────────────────────────────────────────┘ -``` - -### Memory Access Patterns - -#### Attention Mask and Attention Bias Loading -```cpp -// Global to Shared Memory (coalesced access) -Tensor tSgBias = local_partition(mBias, smem_tiled_copy_Bias, thread_idx); -Tensor tSsBias = local_partition(sBias, smem_tiled_copy_Bias, thread_idx); - -// Each thread loads a contiguous chunk to maximize memory bandwidth -copy(smem_tiled_copy_Bias, tSgBias, tSsBias); - -// Shared to Register Memory (bank-conflict-free) -Tensor tSrBias = local_partition(sBias, smem_thr_copy_Bias, thread_idx); -copy(smem_thr_copy_Bias, tSsBias, tSrBias); -``` - -#### Memory Layout Transformations -```cpp -// Convert MMA accumulator layout to row-column layout for masking -// From: (MMA=4, MMA_M, MMA_N) -> (nrow=(2, MMA_M), ncol=(2, MMA_N)) -auto convert_layout_acc_rowcol = [](auto layout) { - return make_layout( - make_layout(make_shape(Int<2>{}, get<1>(layout.shape())), - make_stride(Int(layout.stride())* 2>{}, get<1>(layout.stride()))), - make_layout(make_shape(Int<2>{}, get<2>(layout.shape())), 
- make_stride(Int<1>{}, Int<2>{})) - ); -}; -``` - -### Shared Memory Optimization - -#### Bank Conflict Avoidance -- Attention bias and attention masks use the same copy patterns as Q/K/V to avoid bank conflicts -- Padding added when necessary to ensure 128-bit aligned access -- Thread block size chosen to maximize occupancy while maintaining memory efficiency - -#### Memory Coalescing -```cpp -// Example: Loading 128-bit aligned chunks for optimal bandwidth -using SmemCopyAtomBias = Copy_Atom; // 128-bit loads -using SmemCopyAtomAttnMask = Copy_Atom; -``` - -## Performance Considerations - -### Memory Efficiency -- **Shared Memory Aliasing**: Smart memory reuse (sMask ↔ sP, sBias ↔ sdS) reduces footprint by ~30% -- **Block-level Skip**: Early exit eliminates unnecessary memory loads for inactive tiles -- **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -- **Coalesced Access**: Optimized tensor layouts for GPU memory hierarchy - -### Computational Efficiency -- **Unified Skip Logic**: Both forward and backward passes benefit from block-level computation skipping -- **5-GEMM Chain Skip**: Complete gradient computation chain skipped for inactive tiles -- **Early Branch Decision**: Mask OR-reduction happens before expensive K/V loads -- **Warp-Level Optimization**: Operations optimized for GPU warp execution model - -### Scalability -- **Block-level Granularity**: Tile-level sparsity more efficient than element-level for long sequences -- **Multi-Head Support**: Efficient handling of multiple attention heads with per-head sparsity patterns -- **Barrier Optimization**: Minimal synchronization overhead through smart aliasing strategies - -### Performance Model - -Expected speedup for various sparsity levels: -- **50% sparsity**: ~1.8x speedup -- **75% sparsity**: ~3.2x speedup -- **90% sparsity**: ~6.5x speedup - -Performance factors: -- Skip overhead typically <5% of dense computation time -- Memory bandwidth reduction scales linearly with sparsity -- Shared memory aliasing enables 20-30% larger tile sizes - -## API Changes - -### New Required Parameters - -The Dynamic Mask Attention integration introduces new required parameters to the forward pass: - -- **`attn_mask`** (`torch.Tensor`): Attention mask tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Binary mask (1.0 = compute, 0.0 = skip) indicating which positions should be processed - - Determines the sparsity pattern for computational efficiency - -- **`attn_bias`** (`torch.Tensor`): Attention bias tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Contains dynamic attention bias values applied to attention scores before softmax - - Must have the same dtype and device as Q/K/V tensors - -### Updated Function Signature - -```python -def fwd( - q: torch.Tensor, # Query tensor - k: torch.Tensor, # Key tensor - v: torch.Tensor, # Value tensor - attn_mask: torch.Tensor, # Attention mask (REQUIRED) - attn_bias: torch.Tensor, # Attention bias (REQUIRED) - out: Optional[torch.Tensor] = None, # Pre-allocated output - softmax_scale: float = None, # Attention scaling - is_causal: bool = False, # Causal masking - softcap: float = 0.0, # Soft capping - return_softmax: bool = False, # Return attention weights -) -> List[torch.Tensor] -``` - -### Backward Compatibility - -**Breaking Change Notice**: The integration requires attention bias and attention mask tensors as mandatory parameters. This is a breaking change from the original Flash Attention API. 
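
For migration, a dense fallback is the simplest starting point: an all-ones mask selects every position and a zero bias leaves the scores untouched, so the call should reproduce ordinary dense attention. The sketch below is illustrative only; the keyword arguments mirror the `flash_dmattn_func` wrapper shown earlier, while the import path and tensor sizes are assumptions rather than values taken from this document.

```python
import math
import torch

# Assumed import path for the Python wrapper referenced above.
from flash_dmattn import flash_dmattn_func

batch, num_heads, num_kv_heads, seqlen, head_dim = 2, 12, 12, 1024, 64
device, dtype = "cuda", torch.bfloat16

# Q/K/V in the [batch, seqlen, num_heads, head_dim] layout described in Memory Layout.
q = torch.randn(batch, seqlen, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, num_kv_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, num_kv_heads, head_dim, device=device, dtype=dtype)

# Dense fallback: 1.0 everywhere = compute every position; zero bias = scores unchanged.
attn_mask = torch.ones(batch, num_kv_heads, seqlen, seqlen, device=device, dtype=dtype)
attn_bias = torch.zeros(batch, num_kv_heads, seqlen, seqlen, device=device, dtype=dtype)

out = flash_dmattn_func(
    q, k, v,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    softmax_scale=1.0 / math.sqrt(head_dim),
    is_causal=True,
)
print(out.shape)  # [batch, seqlen, num_heads, head_dim]
```

Starting from this dense configuration, sparsity can then be introduced incrementally by zeroing mask entries and supplying learned bias values, as outlined in the migration path below.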
- -**Migration Path**: Users need to: -1. Add attention mask and bias generation logic to attention modules -2. Implement appropriate mask and bias computation within the attention forward pass -3. Ensure proper tensor shapes and dtypes for mask and bias tensors - -### Complete Usage Example - -```python -import torch -import torch.nn as nn -from flash_dmattn.integration.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward - -class DynamicMaskAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_attention_heads - self.scaling = 1.0 / math.sqrt(self.head_dim) - - # Standard attention projections - self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) - - def forward(self, hidden_states, attention_mask=None, attention_bias=None): - batch_size, seq_len, _ = hidden_states.shape - - # Project to Q, K, V - query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - - # Prepare mask and bias tensors with proper shapes - if attention_mask is None: - attention_mask = torch.ones((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - if attention_bias is None: - attention_bias = torch.zeros((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - # Call attention implementation - attn_output, attn_weights = flash_dynamic_mask_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - attention_bias=attention_bias, - scaling=self.scaling, - ) - - return attn_output, attn_weights -``` - -The attention bias generation process: - -1. **Value-based Dynamic States**: - ```python - dt_states = self.dt_proj(value_states_flattened) - dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) - ``` - -2. **Bias Expansion**: - ```python - attn_bias = dt_states[:, :, None, :].expand(-1, -1, query_len, -1) - ``` - -3. 
**Mask Processing**: Done internally in `_flash_dynamic_mask_attention_forward` - - -### CUDA Backend: Sparse Attention Computation - -The CUDA backend implements the sparse attention computation through `_flash_dynamic_mask_attention_forward`: - -```python -def _flash_dynamic_mask_attention_forward( - query_states, key_states, value_states, - attention_mask, attention_bias, - query_length, key_length, - is_causal, softmax_scale=None, softcap=None, - target_dtype=None, implementation=None, **kwargs -): - dtype = query_states.dtype - min_dtype = torch.finfo(dtype).min - batch_size, _, num_kv_heads, _ = key_states.shape - - # Initialize attention bias if not provided - if attention_bias is None: - attention_bias = torch.zeros( - (batch_size, num_kv_heads, query_length, key_length), - dtype=dtype, device=query_states.device - ) - - # Apply attention mask to bias - if attention_mask is not None: - attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) - attention_mask = attention_mask.to(dtype) - - # Call Flash Attention with dynamic masking - out = flash_dmattn_func( - query_states, key_states, value_states, - attn_mask=attention_mask, attn_bias=attention_bias, - softmax_scale=softmax_scale, is_causal=is_causal - ) - - return out[0] if isinstance(out, tuple) else out -``` - -The backend processing stages: - -1. **Bias Initialization**: Create zero bias tensor if not provided -2. **Mask Application**: Apply boolean attention mask to bias tensor -3. **Flash Attention Call**: Execute optimized CUDA kernels with sparse patterns - -#### Forward Algorithm - -The implementation introduces unified block-level skip logic that optimizes computation by skipping entire tiles when they are fully masked: - -```cpp -// Forward pass with unified skip logic -for m_block in M_tiles: - load Q_tile - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Block-level skip decision - if !any_active: - advance_pointers() // Skip computation, advance to next tile - continue - - // Only execute for active tiles - load K_tile, V_tile // Load data only when needed - S = Q_tile @ K_tile^T + bias_block // Sparse Q*K^T GEMM - S_masked = apply_mask(S, mask_block) // Apply dynamic masking - P = softmax(S_masked, LSE_cache) // Softmax with LSE caching - O_partial += P @ V_tile // Sparse Score*V GEMM -write O -``` - -Key improvements: -- **Block-level Skip Logic**: OR-reduction over entire (BlockM × BlockN) tile determines if computation is needed -- **Early Skip Decision**: Mask evaluation happens before expensive K/V loading and computation -- **Pointer Management**: Safe pointer advancement ensures correct memory layout for subsequent tiles - -#### Backward Algorithm - -The backward pass also benefits from the unified skip logic, maintaining numerical correctness while significantly reducing computation for sparse patterns: - -```cpp -// Backward pass with unified skip logic -for m_block in reversed(M_tiles): - load Q_tile, dO_tile - init accum_dQ - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Same skip decision as forward - if !any_active: - advance_pointers_zero_side_outputs() // Skip computation, zero side outputs - continue - - // Only execute for active tiles - load K_tile, V_tile - - // Recompute (identical to forward for active tiles) - S = Q_tile @ K_tile^T + bias_block - P = softmax(S, LSE_cache) // Use cached LSE for stability - - // Gradient computation chain (5 GEMMs) - dV += P^T @ dO_tile // Accumulate dV - dP = dO_tile @ V_tile^T 
// Compute dP - dS = g(P, dP) // dS = (dP - (P ⊙ dP).sum(axis)) * P - dQ += dS @ K_tile // Accumulate dQ - dK += dS^T @ Q_tile // Accumulate dK - write dQ, accumulate dK, dV -``` - -Key features: -- **Recomputation Strategy**: Forward computation is recomputed only for active tiles to maintain numerical precision -- **LSE Caching**: Uses cached log-sum-exp values from forward pass for stable softmax recomputation -- **Gradient Chain**: All five gradient GEMMs are skipped for fully masked tiles, maintaining mathematical correctness -- **Zero Handling**: Properly handles zero contributions from skipped tiles in accumulation - -#### Skip Logic Correctness - -The mathematical correctness of the skip logic relies on the following principles: - -1. **Forward Skip**: If a tile is entirely masked (active_mask = 0), its contribution to the output is exactly zero: - ``` - O_contribution = P @ V = 0 @ V = 0 - ``` - -2. **Backward Skip**: For fully masked tiles, all intermediate gradients are zero: - ``` - P = 0 ⟹ dS = 0 ⟹ dQ = dK = dV = 0 (from this tile) - ``` - -3. **LSE Preservation**: Skipped tiles don't contribute to the log-sum-exp, maintaining numerical stability. - -### Sparse Computation Strategy - -### Block-level Skip Logic - -The implementation introduces unified block-level skip logic that operates at the tile granularity rather than individual elements: - -1. **Tile-level Active Detection**: - ```cpp - any_active = OR_reduce(mask_block) // Single bit indicating if any position in tile is active - ``` - -2. **Skip Decision**: Binary branch based on tile activity: - ```cpp - if (!any_active) { - advance_pointers(); // Forward: skip all computation - advance_pointers_zero_outputs(); // Backward: skip computation, zero side outputs - continue; - } - ``` - -3. **Computational Benefits**: - - Skip entire K/V loads for inactive tiles - - Eliminate all 5 GEMMs in backward pass for inactive tiles - - Reduce memory bandwidth and arithmetic operations proportional to sparsity - -### Sparsity Pattern Recognition - -The Dynamic Mask Attention implements structured sparsity based on learned importance scores: - -1. **Attention Bias Computation**: Attention bias values are computed based on dynamic states derived from value tensors - - Learned projection matrices map value features to importance scores - - Coefficient parameters control the dynamic range of importance values - - Activation functions ensure appropriate bias magnitude - -2. **Binary Attention Mask**: - - 1.0 for positions that should be computed - - 0.0 for positions that should be skipped - -### Performance Model - -For block-level sparsity with active tile fraction $p$, skip overhead ratio $\varepsilon$, and early-exit efficiency $\eta$: - -$$ -\text{Speedup} \approx \frac{1}{p + (1-p)(\varepsilon + \eta \cdot \text{LoadOverhead})} -$$ - -Where: -- $p$: fraction of active tiles -- $\varepsilon$: skip branching overhead -- $\eta$: efficiency of early memory load exit -- $\text{LoadOverhead}$: relative cost of K/V loading vs computation - -Upper bound as $\varepsilon, \eta \to 0$: $1/p$ - -### Shared Memory Aliasing - -The implementation introduces smart shared memory aliasing to reduce footprint and enable larger tile sizes: - -1. **sMask ↔ sP Aliasing**: Mask shared memory region is reused for storing softmax probabilities P after mask consumption -2. **sBias ↔ sdS Aliasing**: Bias shared memory region is reused for gradient computations dS -3. 
**Barrier Synchronization**: Explicit `__syncthreads()` calls ensure safe transitions between aliased usage - -```cpp -// Example aliasing pattern -load mask -> sMask -any_active = or_reduce(sMask) -if any_active: - compute S - __syncthreads() // ensure mask fully consumed - softmax -> write P into aliased region (sP) // reuse sMask region as sP - ... -__syncthreads() // ensure dS consumed -// reuse sBias region as sdS in next iteration -``` - -### Memory Efficiency Optimizations - -1. **Shared Memory Aliasing**: Smart reuse of memory regions (sMask ↔ sP, sBias ↔ sdS) with explicit barrier synchronization -2. **Block-level Skip**: Early exit from computation and memory loading for inactive tiles -3. **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -4. **Register-Optimized Operations**: Critical masking and gradient operations performed in register memory -5. **Coalesced Memory Access**: Optimized access patterns for GPU memory hierarchy -6. **Template Specialization**: Compile-time optimization eliminates runtime branching overhead - -## Memory Layout - -### Tensor Memory Organization - -The Dynamic Mask Attention extends Flash Attention's memory layout to include attention masks and attention bias: - -``` -Global Memory Layout: -┌─────────────────────────────────────────────────────────────────┐ -│ Q: [batch, seqlen_q, num_heads, head_dim] │ -│ K: [batch, seqlen_k, num_heads_k, head_dim] │ -│ V: [batch, seqlen_k, num_heads_k, head_dim] │ -│ AttnMask: [batch, num_kv_heads, seqlen_q, seqlen_k] │ -│ Bias: [batch, num_kv_heads, seqlen_q, seqlen_k] │ -│ Output: [batch, seqlen_q, num_heads, head_dim] │ -└─────────────────────────────────────────────────────────────────┘ - -Shared Memory Layout (per thread block): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Tile: [kBlockM, head_dim] │ K Tile: [kBlockN, head_dim] │ -│ V Tile: [kBlockN, head_dim] │ S Tile: [kBlockM, kBlockN] │ -│ AM Tile: [kBlockM, kBlockN] │ Bias Tile: [kBlockM, kBlockN] │ -└─────────────────────────────────────────────────────────────────────┘ - -Register Memory (per thread): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Frag: [MMA_M, head_dim/N] │ K Frag: [MMA_N, head_dim/N] │ -│ V Frag: [MMA_N, head_dim/N] │ S Frag: [MMA_M, MMA_N] │ -│ AM Frag: [MMA_M, MMA_N] │ Bias Frag: [MMA_M, MMA_N] │ -│ Acc Frag: [MMA_M, head_dim/N] │ │ -└─────────────────────────────────────────────────────────────────────┘ -``` - -### Memory Access Patterns - -#### Attention Mask and Attention Bias Loading -```cpp -// Global to Shared Memory (coalesced access) -Tensor tSgBias = local_partition(mBias, smem_tiled_copy_Bias, thread_idx); -Tensor tSsBias = local_partition(sBias, smem_tiled_copy_Bias, thread_idx); - -// Each thread loads a contiguous chunk to maximize memory bandwidth -copy(smem_tiled_copy_Bias, tSgBias, tSsBias); - -// Shared to Register Memory (bank-conflict-free) -Tensor tSrBias = local_partition(sBias, smem_thr_copy_Bias, thread_idx); -copy(smem_thr_copy_Bias, tSsBias, tSrBias); -``` - -#### Memory Layout Transformations -```cpp -// Convert MMA accumulator layout to row-column layout for masking -// From: (MMA=4, MMA_M, MMA_N) -> (nrow=(2, MMA_M), ncol=(2, MMA_N)) -auto convert_layout_acc_rowcol = [](auto layout) { - return make_layout( - make_layout(make_shape(Int<2>{}, get<1>(layout.shape())), - make_stride(Int(layout.stride())* 2>{}, get<1>(layout.stride()))), - make_layout(make_shape(Int<2>{}, get<2>(layout.shape())), 
- make_stride(Int<1>{}, Int<2>{})) - ); -}; -``` - -### Shared Memory Optimization - -#### Bank Conflict Avoidance -- Attention bias and attention masks use the same copy patterns as Q/K/V to avoid bank conflicts -- Padding added when necessary to ensure 128-bit aligned access -- Thread block size chosen to maximize occupancy while maintaining memory efficiency - -#### Memory Coalescing -```cpp -// Example: Loading 128-bit aligned chunks for optimal bandwidth -using SmemCopyAtomBias = Copy_Atom; // 128-bit loads -using SmemCopyAtomAttnMask = Copy_Atom; -``` - -## Performance Considerations - -### Memory Efficiency -- **Shared Memory Aliasing**: Smart memory reuse (sMask ↔ sP, sBias ↔ sdS) reduces footprint by ~30% -- **Block-level Skip**: Early exit eliminates unnecessary memory loads for inactive tiles -- **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -- **Coalesced Access**: Optimized tensor layouts for GPU memory hierarchy - -### Computational Efficiency -- **Unified Skip Logic**: Both forward and backward passes benefit from block-level computation skipping -- **5-GEMM Chain Skip**: Complete gradient computation chain skipped for inactive tiles -- **Early Branch Decision**: Mask OR-reduction happens before expensive K/V loads -- **Warp-Level Optimization**: Operations optimized for GPU warp execution model - -### Scalability -- **Block-level Granularity**: Tile-level sparsity more efficient than element-level for long sequences -- **Multi-Head Support**: Efficient handling of multiple attention heads with per-head sparsity patterns -- **Barrier Optimization**: Minimal synchronization overhead through smart aliasing strategies - -### Performance Model - -Expected speedup for various sparsity levels: -- **50% sparsity**: ~1.8x speedup -- **75% sparsity**: ~3.2x speedup -- **90% sparsity**: ~6.5x speedup - -Performance factors: -- Skip overhead typically <5% of dense computation time -- Memory bandwidth reduction scales linearly with sparsity -- Shared memory aliasing enables 20-30% larger tile sizes - -## API Changes - -### New Required Parameters - -The Dynamic Mask Attention integration introduces new required parameters to the forward pass: - -- **`attn_mask`** (`torch.Tensor`): Attention mask tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Binary mask (1.0 = compute, 0.0 = skip) indicating which positions should be processed - - Determines the sparsity pattern for computational efficiency - -- **`attn_bias`** (`torch.Tensor`): Attention bias tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Contains dynamic attention bias values applied to attention scores before softmax - - Must have the same dtype and device as Q/K/V tensors - -### Updated Function Signature - -```python -def fwd( - q: torch.Tensor, # Query tensor - k: torch.Tensor, # Key tensor - v: torch.Tensor, # Value tensor - attn_mask: torch.Tensor, # Attention mask (REQUIRED) - attn_bias: torch.Tensor, # Attention bias (REQUIRED) - out: Optional[torch.Tensor] = None, # Pre-allocated output - softmax_scale: float = None, # Attention scaling - is_causal: bool = False, # Causal masking - softcap: float = 0.0, # Soft capping - return_softmax: bool = False, # Return attention weights -) -> List[torch.Tensor] -``` - -### Backward Compatibility - -**Breaking Change Notice**: The integration requires attention bias and attention mask tensors as mandatory parameters. This is a breaking change from the original Flash Attention API. 
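-
-For callers that do not need sparsity right away, dense behavior can be recovered by passing an all-ones mask and an all-zero bias. The sketch below follows the `fwd` signature documented above; the shapes, keyword names, and list-style return are taken from this document rather than a tested build, so treat it as illustrative:
-
-```python
-import math
-import torch
-import flash_dmattn_cuda as flash_dmattn
-
-batch, seqlen, num_heads, head_dim = 2, 1024, 8, 64
-
-# Q/K/V in the (batch, seqlen, num_heads, head_dim) layout described in the memory layout section
-q = torch.randn(batch, seqlen, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
-k = torch.randn_like(q)
-v = torch.randn_like(q)
-
-# Dense fallback: 1.0 = compute everywhere, zero additive bias
-# (num_kv_heads == num_heads here; with GQA the second dimension is num_kv_heads)
-attn_mask = torch.ones(batch, num_heads, seqlen, seqlen, device=q.device, dtype=q.dtype)
-attn_bias = torch.zeros(batch, num_heads, seqlen, seqlen, device=q.device, dtype=q.dtype)
-
-outputs = flash_dmattn.fwd(
-    q, k, v,
-    attn_mask, attn_bias,
-    out=None,
-    softmax_scale=1.0 / math.sqrt(head_dim),
-    is_causal=True,
-    softcap=0.0,
-    return_softmax=False,
-)
-out = outputs[0]  # first element is the attention output, same shape as q
-```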
-
-**Migration Path**: Users need to:
-1. Add attention mask and bias generation logic to attention modules
-2. Implement appropriate mask and bias computation within the attention forward pass
-3. Ensure proper tensor shapes and dtypes for mask and bias tensors
-
-### Complete Usage Example
-
-```python
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward
-
-class DynamicMaskAttention(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.num_heads = config.num_attention_heads
-        self.num_kv_heads = config.num_key_value_heads
-        self.head_dim = config.hidden_size // config.num_attention_heads
-        self.scaling = 1.0 / math.sqrt(self.head_dim)
-
-        # Standard attention projections
-        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
-
-        # Dynamic mask parameters used to generate the attention bias
-        self.A = nn.Parameter(torch.zeros(self.num_kv_heads))
-        self.dt_proj = nn.Linear(self.num_kv_heads * self.head_dim, self.num_kv_heads)
-
-        # Consulted by the attention wrapper when is_causal is not passed explicitly
-        self.is_causal = True
-
-    def forward(self, hidden_states, attention_mask=None, attention_bias=None):
-        batch_size, seq_len, _ = hidden_states.shape
-
-        # Project to Q, K, V
-        query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
-
-        # Generate attention bias from value states
-        dt_states = self.dt_proj(
-            value_states.transpose(1, 2).reshape(batch_size, seq_len, -1)
-        )
-        dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
-        attention_bias = dt_states[:, :, None, :].expand(-1, -1, seq_len, -1).to(hidden_states.dtype)
-
-        # Prepare attention mask for multi-head
-        if attention_mask is not None:
-            attention_mask = attention_mask.expand(-1, self.num_kv_heads, -1, -1)
-
-        # Flash Dynamic Mask Attention
-        attn_output, _ = flash_dynamic_mask_attention_forward(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask=attention_mask,
-            attention_bias=attention_bias,
-            scaling=self.scaling,
-        )
-
-        # Output projection
-        attn_output = attn_output.reshape(batch_size, seq_len, -1)
-        return self.o_proj(attn_output)
-
-# Usage example
-config = type('Config', (), {
-    'hidden_size': 768,
-    'num_attention_heads': 12,
-    'num_key_value_heads': 12,
-})()
-
-attention = DynamicMaskAttention(config).to(device='cuda', dtype=torch.bfloat16)
-hidden_states = torch.randn(2, 4096, 768, device='cuda', dtype=torch.bfloat16)
-output = attention(hidden_states)
-print(f"Output shape: {output.shape}")  # [2, 4096, 768]
-```
-
-### Integration with Existing Codebases
-
-For users migrating from Flash Attention, the typical changes required are:
-
-```python
-# Before (Flash Attention)
-class StandardAttention(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim)
-        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim)
-        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim)
-        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size)
-
-    def forward(self, hidden_states):
-        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
-        output = flash_attn_func(q, k,
v, dropout_p=0.1, softmax_scale=self.scaling, causal=True) - return self.o_proj(output) - -# After (Dynamic Mask Attention) -class DynamicMaskAttention(nn.Module): - def __init__(self, config): - super().__init__() - # Same standard projections - self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim) - self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim) - self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim) - self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size) - - # Add dynamic mask parameters - self.A = nn.Parameter(torch.zeros(config.num_key_value_heads)) - self.dt_proj = nn.Linear(config.num_key_value_heads * self.head_dim, config.num_key_value_heads) - self.keep_window_size = config.keep_window_size - - def forward(self, hidden_states): - # Standard Q, K, V projections - query_states = self.q_proj(hidden_states).view(...).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(...).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(...).transpose(1, 2) - - # Generate attention bias from value states - dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(...)) - dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) - attention_bias = dt_states[:, :, None, :].expand(-1, -1, seq_len, -1) - - # Use Flash Dynamic Mask Attention - attn_output, _ = flash_dynamic_mask_attention_forward( - self, query_states, key_states, value_states, - attention_mask=attention_mask, attention_bias=attention_bias, - scaling=self.scaling - ) - - return self.o_proj(attn_output.reshape(...)) -``` \ No newline at end of file diff --git a/docs/integration_zh.md b/docs/integration_zh.md deleted file mode 100644 index 56ed0c1..0000000 --- a/docs/integration_zh.md +++ /dev/null @@ -1,522 +0,0 @@ -# Flash 动态掩码注意力集成指南 - -## 概述 - -本文档阐述了如何在 Flash Attention 框架中集成 Dynamic Mask Attention(动态掩码注意力)。通过将 Flash Attention 的高效显存利用方式与动态稀疏掩码结合,这一集成能够在极长序列场景下实现稀疏注意力的高效计算。 - -该集成方案采用统一的稀疏计算路径:Python 端负责预计算注意力掩码与偏置张量,CUDA 后端在前向与反向两个阶段执行基于块的跳过逻辑与稀疏算子调度。 - -## 目录 - -1. [集成架构](#集成架构) -2. [核心改动](#核心改动) -3. [实现细节](#实现细节) -4. [稀疏计算策略](#稀疏计算策略) -5. [内存布局](#内存布局) -6. [性能考量](#性能考量) -7. [API 变化](#api-变化) - -## 集成架构 - -### 高层设计 - -动态掩码注意力的集成在前向与反向过程中统一采用块级稀疏执行路径: - -1. **动态掩码计算**:Python 端预先生成注意力掩码(mask)与注意力偏置(bias)张量。 -2. **统一稀疏执行**:CUDA 后端在块粒度上决定是否跳过计算,并执行稀疏化的注意力与梯度算子。 -3. **内存优化**:通过共享内存别名与显式同步实现更高的共享内存复用率。 - -### 关键组件 - -- **注意力掩码**:形状为 `(batch, num_kv_heads, query_len, key_len)` 的二值张量(1.0 表示保留,0.0 表示跳过)。 -- **注意力偏置**:与掩码形状一致的张量,在 Softmax 前加性注入。 -- **块级跳过逻辑**:对 `(BlockM × BlockN)` tile 做 OR 归约判断是否执行计算。 -- **LSE 缓存**:前向阶段缓存 log-sum-exp 结果,反向阶段复用以保持数值稳定。 -- **共享内存别名**:动态复用共享内存缓冲区,配合 `__syncthreads()` 控制生命周期。 -- **完备梯度链路**:在保留稀疏跳过能力的同时,确保梯度流动正确。 - -## 核心改动 - -### 1. 
参数结构扩展(`flash.h`) - -**目的**:扩展参数结构体以支持动态掩码与偏置信息,同时保留对 QKV 的统一访问接口。 - -```cpp -struct QKV_params { - void *__restrict__ q_ptr; - void *__restrict__ k_ptr; - void *__restrict__ v_ptr; - index_t q_batch_stride, k_batch_stride, v_batch_stride; - index_t q_row_stride, k_row_stride, v_row_stride; - index_t q_head_stride, k_head_stride, v_head_stride; - int h, h_k; - int h_h_k_ratio; -}; - -struct Mask_params { - void *__restrict__ mask_ptr; - index_t mask_batch_stride; - index_t mask_head_stride; - index_t mask_row_stride; -}; - -struct Bias_params { - void *__restrict__ bias_ptr; - index_t bias_batch_stride; - index_t bias_head_stride; - index_t bias_row_stride; -}; - -struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params { - // ...existing code... - bool seqlenq_ngroups_swapped; -}; -``` - -**设计要点**: -- 多重继承将 QKV、掩码、偏置的参数维度拆分,保持接口清晰。 -- 为掩码与偏置提供完整的 stride 信息,以便在 CUDA 中高效寻址。 -- 与原有 Flash Attention 的内存布局保持兼容,避免性能回退。 - -### 2. 内核特性与内存布局(`kernel_traits.h`) - -**目的**:根据架构(SM75 / SM80+)选择合适的 MMA 原子与内存拷贝路径,为动态掩码操作提供最佳性能。 - -```cpp -template -struct Flash_kernel_traits { - using Element = elem_type; - using ElementAccum = float; - using index_t = int64_t; - static constexpr int kHeadDim = kHeadDim_; - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kNWarps = kNWarps_; - // ...existing code... - using SmemCopyAtomMask = SmemCopyAtom; - using SmemCopyAtomBias = SmemCopyAtom; -}; -``` - -**设计要点**: -- 根据编译目标自动选择 `cp.async` 与 LDSM 指令路径。 -- 统一掩码与偏置的共享内存加载策略,避免额外的 bank conflict。 -- 模板化的类型安全保证不同精度(FP16/BF16)路径一致。 - -### 3. 块级信息扩展(`block_info.h`) - -**目的**:在可变长度场景下计算掩码与偏置的块级偏移量,保证全局内存访问有序。 - -```cpp -template -struct BlockInfo { - template - __device__ BlockInfo(const Params ¶ms, const int bidb) { - // ...existing code... - } - - template - __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; - sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k); - return offset; - } - - // ...existing code... -}; -``` - -**设计要点**: -- 提供统一的偏移量计算方法,简化内核中的地址计算。 -- 同时支持固定长度与可变长度两种输入形式。 -- 将左侧填充(left pad)纳入偏移量,保证稀疏掩码与 KV 缓存对齐。 - -### 4. 内存拷贝与算子工具(`utils.h`) - -**目的**:提供布局转换、类型转换、warp 归约与通用 GEMM 包装,适配 Flash Attention 的内存层次结构。 - -```cpp -namespace FLASH_NAMESPACE { - -template -__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { - // ...existing code... - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); -}; - -// ...existing code... - -template -__forceinline__ __device__ void gemm(/* ... */) { - // ...existing code... -} - -} // namespace FLASH_NAMESPACE -``` - -**设计要点**: -- 通过布局转换统一 MMA 累加器的访问方式,方便掩码逻辑在寄存器中操作。 -- 提供针对 BF16 的专用类型转换,避免额外的精度损耗。 -- Warp 归约与 GEMM 包装均支持将数据留在寄存器中,降低共享内存压力。 - -### 5. 动态掩码核心逻辑(`mask.h`) - -**目的**:在寄存器层面将掩码与偏置应用到注意力得分上,同时处理因果掩码与边界情况。 - -```cpp -template -__forceinline__ __device__ void apply_mask( - TensorType &tensor, - MaskType &mask, - BiasType &bias, - const float scale_softmax, - const int col_idx_offset_, - const int max_seqlen_k, - const int row_idx_offset, - const int max_seqlen_q, - const int warp_row_stride) { - // ...existing code... -} -``` - -**设计要点**: -- 在 `tensor` 保持 MMA 布局的情况下,逐元素应用掩码、偏置与缩放因子。 -- 因果掩码通过列索引上限裁剪实现,与动态掩码兼容。 -- 被掩盖的位置直接写入 `-INFINITY`,防止 Softmax 后出现数值污染。 - -### 6. 
反向链路扩展(`flash_bwd_kernel.h`) - -**目的**:在反向传播中复用动态掩码逻辑,确保梯度仅在活跃 tile 上计算。 - -```cpp -struct Flash_bwd_params : public Flash_fwd_params { - // ...existing code... -}; - -template -inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, - const int bidh, const int n_block) { - // ...existing code... -} -``` - -**设计要点**: -- 反向路径沿用前向阶段的 tile 活跃性判断,跳过完全被掩码的块。 -- 结合 LSE 缓存,重算前向 Softmax 时保持数值稳定。 -- 保证五个梯度 GEMM 在活跃 tile 上依旧串联执行,避免梯度缺失。 - -### 7. 前向内核改造(`flash_fwd_kernel.h`) - -**目的**:在主注意力内核中插入动态掩码流程,同时保持 Flash Attention 的高并发与共享内存利用率。 - -```cpp -template -inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, - const int bidh, const int m_block) { - using Element = typename Kernel_traits::Element; - // ...existing code... -} -``` - -**设计要点**: -- 按 tile 裁剪逻辑提前判断是否加载 K/V,降低无效内存访问。 -- 仅在提供掩码/偏置时启用相应的分支,保持向后兼容。 -- 通过模板参数在编译期裁剪分支,减少运行期开销。 - -### 8. 启动模板更新(`flash_fwd_launch_template.h`) - -**目的**:在 kernel launch 阶段配置共享内存需求、模板实例化与错误处理,适配动态掩码的新资源需求。 - -```cpp -#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ -template \ -__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) - -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, - bool Is_causal, bool Is_even_MN, bool Is_even_K, - bool Is_softcap, bool Return_softmax) { - // ...existing code... -} - -// ...existing code... - -template -void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr size_t smem_size = Kernel_traits::kSmemSize; - // ...existing code... -} -``` - -**设计要点**: -- 统一宏定义减少重复代码,便于扩展到新的 kernel 变体。 -- 针对不支持的架构给出明确的构建期/运行期错误提示。 -- 在 launch 前计算共享内存需求,必要时启用 `cudaFuncSetAttribute` 进行配置。 - -### 9. Python 接口扩展(`flash_api.cpp`) - -**目的**:扩展 C++/PyBind11 接口以接受掩码与偏置张量,并提供全面的数据校验。 - -```cpp -void set_params_fprop( - Flash_fwd_params ¶ms, - // ...existing code... -) { - // ...existing code... -} - -std::vector mha_fwd( - at::Tensor &q, - // ...existing code... - const bool return_softmax) { - // ...existing code... - return {out, softmax_lse}; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashDynamicMaskAttention"; - // ...existing code... -} -``` - -**设计要点**: -- 对输入张量的形状、dtype、device 进行全面校验。 -- 保持原有参数顺序,新增参数保持向后兼容的默认行为。 -- 当掩码或偏置未提供时,自动填充零值张量以保证接口易用性。 - -## 实现细节 - -### C++ API 接口 - -C++ 端对外暴露如下核心函数,用于前向、可变长度前向与反向计算: - -```cpp -namespace FLASH_NAMESPACE { - -std::vector mha_fwd( - at::Tensor &q, - at::Tensor &k, - at::Tensor &v, - // ...existing code... - const bool return_softmax); - -std::vector mha_varlen_fwd(/* ... */); - -std::vector mha_bwd(/* ... */); - -} // namespace FLASH_NAMESPACE -``` - -- `mha_fwd`:标准批量前向,支持稀疏掩码与偏置。 -- `mha_varlen_fwd`:支持变长序列并使用累计长度数组。 -- `mha_bwd`:完成梯度计算,返回 dQ / dK / dV / dBias / dMask 等张量。 - -### 参数设置与校验 - -`set_params_fprop` 会在调用前: - -- 重置 `Flash_fwd_params` 并写入基本维度信息。 -- 将掩码与偏置的设备指针、stride、批次数等全部注册。 -- 基于输入 `dtype` 设置缩放因子与 `softcap`,同时准备缓存指针。 - -### Python 绑定与接口 - -PyBind11 模块对外暴露 `mha_fwd`、`mha_bwd`、`varlen_fwd` 等接口,文档字符串说明了参数要求与返回值。用户可通过 Python 直接调用 C++/CUDA 实现。 - -### Python 前端集成示例 - -```python -import torch -import torch.nn as nn -import flash_dmattn_cuda as flash_dmattn - -class DynamicMaskAttention(nn.Module): - def __init__(self, config): - super().__init__() - # ...existing code... 
- - def forward(self, query_states, key_states, value_states, attn_mask, attn_bias): - out, softmax_lse = flash_dmattn.fwd( - query_states, key_states, value_states, - attn_mask=attn_mask, - attn_bias=attn_bias, - return_softmax=True, - ) - return out, softmax_lse -``` - -- 前端模块负责生成 `attn_mask`(布尔)与 `attn_bias`(与 Q/K/V dtype 相同)。 -- 内部 `_flash_dynamic_mask_attention_forward` 会根据需要补零偏置并调用后端。 -- 输入张量默认为 `(batch, seq_len, num_heads, head_dim)` 排列,内部会自动转置到后端期望格式。 - -## 稀疏计算策略 - -### 块级跳过逻辑 - -- 在加载 Q tile 后,先将掩码 tile 拷贝到共享内存并执行 OR 归约。 -- 若整块被掩盖,则跳过 K/V 加载与后续计算,只推进指针。 -- 对活跃块执行常规注意力流程,并复用共享内存保存 Softmax 结果。 - -### 前向算法 - -```pseudo -for m_block in M_tiles: - load Q_tile - load mask_tile -> shared - any_active = or_reduce(mask_tile) - if not any_active: - continue - load K_tile, V_tile - compute scaled dot product - apply mask & bias in registers - softmax -> write O_tile -``` - -- 掩码裁剪保证 Tile 内所有无效位置直接输出 `-INF`。 -- Softmax 前的缩放与偏置添加与密集版本保持一致。 -- 通过共享内存别名(sMask ↔ sP)减少显存占用。 - -### 反向算法 - -```pseudo -for m_block in reversed(M_tiles): - load Q_tile, dO_tile - load mask_tile -> shared - if tile inactive: - continue - recompute scores with cached LSE - propagate gradients for dS, dV, dK, dQ -``` - -- 仅对活跃块执行五个 GEMM 组合,减少稀疏场景下的冗余计算。 -- 使用前向缓存的 LSE 确保 Softmax 反向的数值稳定性。 -- 对被跳过的块梯度自然为零,避免写入污染。 - -### 跳过逻辑正确性 - -- 若 tile 全部被掩码,输出必为零,跳过计算不会影响结果。 -- 反向阶段活跃性与前向保持一致,保证梯度对应关系不被破坏。 -- 由于被掩盖位置在 Softmax 前已写入 `-INF`,LSE 亦不受影响。 - -## 内存布局 - -### 全局内存组织 - -``` -Q: [batch, seqlen_q, num_heads, head_dim] -K: [batch, seqlen_k, num_kv_heads, head_dim] -V: [batch, seqlen_k, num_kv_heads, head_dim] -Mask: [batch, num_kv_heads, seqlen_q, seqlen_k] -Bias: [batch, num_kv_heads, seqlen_q, seqlen_k] -Output: [batch, seqlen_q, num_heads, head_dim] -``` - -### 共享内存布局(每个线程块) - -``` -Q Tile : [kBlockM, head_dim] -K Tile : [kBlockN, head_dim] -V Tile : [kBlockN, head_dim] -S Tile : [kBlockM, kBlockN] -Mask Tile: [kBlockM, kBlockN] -Bias Tile: [kBlockM, kBlockN] -``` - -### 寄存器布局(每个线程) - -``` -Q Frag : [MMA_M, head_dim / N] -K Frag : [MMA_N, head_dim / N] -V Frag : [MMA_N, head_dim / N] -S Frag : [MMA_M, MMA_N] -Mask Frag: [MMA_M, MMA_N] -Bias Frag: [MMA_M, MMA_N] -Acc Frag : [MMA_M, head_dim / N] -``` - -### 内存访问模式 - -- 掩码与偏置与 K/V 共享相同的 `Copy_Atom` 配置,确保 128-bit 对齐、最大化带宽。 -- 共享内存拷贝后通过 `local_partition` 分配给线程,避免 bank conflict。 -- `convert_layout_acc_rowcol` 将 MMA 布局转换为行/列布局,方便寄存器操作。 - -### 共享内存优化 - -- **别名复用**:`sMask` 在使用后可重用为 `sP`(Softmax 输出),`sBias` 可重用为 `sdS`。 -- **同步屏障**:在重用前使用 `__syncthreads()` 确保所有线程完成对旧数据的使用。 -- **块尺寸选择**:根据稀疏度与共享内存限制调整 tile 尺寸,提高 SM 占用率。 - -## 性能考量 - -- **共享内存复用**:别名策略可将共享内存占用削减约 30%。 -- **块级跳过**:当稀疏度为 75% 时,可获得约 3× 的前向提速;稀疏度 90% 时可达到 ~6×。 -- **带宽优化**:跳过无效 tile 可以线性降低全局内存带宽需求。 -- **同步开销**:跳过路径的额外 OR 归约占总时间 <5%,可忽略不计。 -- **硬件自适应**:针对 SM75/SM80+ 的不同指令集做了专门优化,确保跨架构稳定收益。 - -## API 变化 - -### 新增必要参数 - -- `attn_mask` (`torch.Tensor`): 形状 `(batch, num_kv_heads, seqlen_q, seqlen_k)` 的布尔张量,决定稀疏模式。 -- `attn_bias` (`torch.Tensor`): 形状与掩码一致的加性偏置张量,dtype 与 Q/K/V 保持一致。 - -### 更新的函数签名 - -```python -def fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - attn_mask: torch.Tensor, - attn_bias: torch.Tensor, - is_causal: bool = False, - return_softmax: bool = False, - **kwargs -) -> List[torch.Tensor]: - ... 
-``` - -### 向后兼容说明 - -- 这是一个破坏性更新,旧的 Flash Attention 调用需显式提供掩码与偏置。 -- 若业务场景不需要稀疏掩码,可传入全 1 掩码与全 0 偏置实现与旧版一致的行为。 -- 缺省值在 Python 前端会自动补齐,降低迁移的代码改动。 - -### 完整用法示例 - -```python -import torch -from flash_dmattn.integrations.flash_dynamic_mask_attention import ( - flash_dynamic_mask_attention_forward, -) - -batch, seq_q, seq_k, n_heads, head_dim = 2, 4096, 4096, 16, 128 -q = torch.randn(batch, seq_q, n_heads, head_dim, device="cuda", dtype=torch.float16) -k = torch.randn_like(q) -v = torch.randn_like(q) -mask = torch.ones(batch, n_heads, seq_q, seq_k, device=q.device, dtype=torch.bool) -bias = torch.zeros(batch, n_heads, seq_q, seq_k, device=q.device, dtype=q.dtype) - -out = flash_dynamic_mask_attention_forward( - query_states=q, - key_states=k, - value_states=v, - attention_mask=mask, - attention_bias=bias, - return_attn_probs=False, -) -``` - -- `flash_dynamic_mask_attention_forward` 会自动完成张量转置、补零偏置等准备工作。 -- 若指定 `return_attn_probs=True`,将返回经过 Softmax 的注意力概率,用于调试或可视化。 -- 稀疏模式的 mask 可通过 `flash_dmattn.utils.mask.MaskMod` 组合生成。 - -## 附加建议 - -- 修改 CUDA 核心代码后,至少运行 `benchmarks/forward_equivalence.py` 与 `benchmarks/grad_equivalence.py` 进行回归验证。 -- 构建扩展时可使用 `pip install -e . --no-build-isolation`,必要时设置 `FLASH_DMATTN_CUDA_ARCHS` 指定目标架构。 -- 若仅依赖 Triton/Flex 后端,可通过环境变量 `FLASH_DMATTN_SKIP_CUDA_BUILD=1` 跳过 CUDA 构建。 diff --git a/docs/v1.0.0_technical_report.md b/docs/v1.0.0_technical_report.md deleted file mode 100644 index f6ddf40..0000000 --- a/docs/v1.0.0_technical_report.md +++ /dev/null @@ -1,299 +0,0 @@ -# flash-dmattn v1.0.0 Technical Report - -## 1. Overview -flash-dmattn is a high-performance FlashAttention-style implementation optimized for large sequence lengths and structured sparsity via Dynamic Masks. It provides: -- Unified block-level dynamic mask (block-sparse) skip logic in both forward and backward passes. -- Fused softmax, normalization, and recomputation-friendly backward pipeline. -- Smart shared memory aliasing to reduce footprint and enhance occupancy. -- Support for bias, Log-Sum-Exp (LSE) caching, and optional softcap. -- PyTorch Autograd compatibility and downstream model integration (example: Doge model, HuggingFace-style interface). - -v1.0.0 Highlights: -1. Unified sparse skip logic for both forward and backward (eliminates redundant compute on fully masked tiles). -2. Improved numerical and performance consistency: coherent shared memory layout, aliasing, and barrier sequencing. -3. Documentation, API stabilization, and extensibility groundwork for finer-grained sparsity (bit-packed, fragment-level) later. - -Differences vs v0.3.0: -- v0.3.0 only considered backward skip conceptually; v1.0.0 fully unifies forward + backward skip execution. -- Added strict barrier ordering to prevent NaNs (notably in dK path) when reusing aliased shared memory regions. -- Enhanced documentation, tests, and benchmarking. - -## 2. Architecture -Layers: -1. Python Integration: `flash_dmattn_interface.py` exposing user-friendly APIs (mirroring standard attention calls). -2. Kernel Dispatch Layer: `flash_dmattn_flex.py` / `flash_dmattn_triton.py` selecting CUDA / Triton / hybrid code paths. -3. C++/CUDA Core: flash_api.cpp + `src/*.h` (core kernels: `flash_fwd_kernel.h`, `flash_bwd_kernel.h`). -4. Dynamic Mask Integration: `integrations/flash_dynamic_mask_attention.py` and helpers. -5. Benchmarks & Validation: `benchmarks/*_equivalence.py`, `*_performance.py`. 
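-
-Backend selection can be exercised directly from Python. A minimal sketch, assuming the package-level helpers `get_available_backends` and `flash_dmattn_func_auto` from `flash_dmattn/__init__.py` and the `(query, key, value, attn_mask, attn_bias, is_causal)` calling convention of the Triton path; not a tested snippet:
-
-```python
-import torch
-from flash_dmattn import flash_dmattn_func_auto, get_available_backends
-
-print(get_available_backends())   # subset of ["cuda", "triton", "flex"], depending on the install
-
-# backend=None lets the helper pick the first available backend
-attn_fn = flash_dmattn_func_auto(backend=None)
-
-b, s, h, d = 1, 2048, 8, 64
-q = torch.randn(b, s, h, d, device="cuda", dtype=torch.bfloat16)
-k, v = torch.randn_like(q), torch.randn_like(q)
-mask = torch.ones(b, h, s, s, device=q.device, dtype=q.dtype)  # 1.0 = keep the position
-bias = torch.zeros_like(mask)                                  # no additive bias
-
-out = attn_fn(q, k, v, attn_mask=mask, attn_bias=bias, is_causal=True)
-```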
- -Backward dataflow: -Q,K,V,dO (+ mask, bias, LSE) → block streaming → (block-sparse skip decision) → if active: recompute scores & softmax(P) → accumulate dV,dP,dQ,dK → write back. - -## 3. Key Features -- Block-level Dynamic Mask: - - OR-reduction over (BlockM × BlockN) tile; if all zeros → skip. -- Unified Skip (Forward + Backward): - - Forward: skip QK^T, softmax, and P·V for fully masked tiles; safely advances pointers / outputs zeros. - - Backward: skip recompute + the chain of 5 GEMMs (QK^T, dO·V^T, P^T·dO→dV, dP·K→dQ, dP^T·Q→dK). -- LSE Caching: - - Ensures numerical stability: P derived via stored log-sum-exp. -- Optional Softcap: - - Scaling / clamping scores pre-softmax. -- Shared Memory Aliasing: - - sMask ↔ sP; sBias ↔ sdS with explicit barriers. -- Mixed Precision: - - FP16/BF16 inputs, FP32 accumulation. -- Modular KernelTraits: - - Controls block sizes, pipeline depth (double buffering), layouts. -- Extensible Sparsity: - - Design leaves room for bit-packed masks and fragment gating. - -## 4. Algorithms & Kernels - -### 4.1 Forward (Pseudo-code) -``` -for m_block in M_tiles: - load Q_tile - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) - if !any_active: - advance_pointers() - continue - load K_tile, V_tile - S = Q_tile @ K_tile^T + bias_block - S_masked = apply_mask(S, mask_block) - P = softmax(S_masked, LSE_cache) - O_partial += P @ V_tile -write O -``` - -### 4.2 Backward (Pseudo-code) -``` -for m_block in reversed(M_tiles): - load Q_tile, dO_tile - init accum_dQ - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) - if !any_active: - advance_pointers_zero_side_outputs() - continue - load K_tile, V_tile - # Recompute - S = Q_tile @ K_tile^T + bias_block - P = softmax(S, LSE_cache) - # Grad chain - dV += P^T @ dO_tile - dP = dO_tile @ V_tile^T - dS = g(P, dP) # (dP - (P ⊙ dP).sum(axis)) * P - dQ += dS @ K_tile - dK += dS^T @ Q_tile - write dQ, accumulate dK, dV -``` - -### 4.3 Softmax & Gradient -Given $S_{ij}$ and $LSE_i = \log \sum_k e^{S_{ik}}$, - -$$ -P_{ij} = \frac{e^{S_{ij}-LSE_i}}{\sum_k e^{S_{ik}-LSE_i}} -$$ - -Backward: - -$$ -\frac{\partial \mathcal{L}}{\partial S_{ij}} = \left( \frac{\partial \mathcal{L}}{\partial P_{ij}} - \sum_{k} \frac{\partial \mathcal{L}}{\partial P_{ik}} P_{ik} \right) P_{ij} -$$ - -Fully masked tile: $P=0 \Rightarrow dS=0$, all dependent GEMMs yield zero → safe to skip. - -### 4.4 Correctness of Skip -If a tile is entirely masked: -- Forward contributions vanish (outputs zero block). -- Backward intermediate tensors (S,P,dS,dP) logically zero; linear GEMMs on zero give zero. -Therefore removing those computations preserves gradients. - -## 5. Sparsity Logic & Performance - -### 5.1 Active Tile Detection -- Load mask tile into shared memory. -- Parallel OR reduction across threads / warps. -- any_active=false triggers skip branch. - -### 5.2 Performance Model -Let active fraction $p$, skip overhead ratio $\varepsilon$: - -$$ -\text{Speedup} \approx \frac{1}{p + (1-p)\varepsilon} -$$ - -Upper bound as $\varepsilon \to 0$: $1/p$. - -### 5.3 Influencing Factors -- Reduction latency vs early placement. -- Pipeline bubbles due to frequent divergent skip branches. -- Memory bandwidth—mask format (bit-packed future) reduces load footprint. - -### 5.4 Future Enhancements -- Earlier gating (before K/V loads). -- Adaptive density threshold. -- Bit-packed + warp ballot fast OR. -- Persistent CTA / work queue for load balancing. - -## 6. 
API Summary -Primary function: -`flash_dynamic_mask_attention(q, k, v, attn_mask=None, bias=None, softcap=None, causal=False, return_lse=False, ...)` - -Inputs: -- q/k/v: [B, H, L, D] (k/v possibly different length) -- attn_mask: block-aligned or internally sliced dynamic mask -- bias: optional additive bias -- softcap: optional scaling/clamp -Outputs: -- O (and optionally LSE when requested). - -Config: -- Block sizes (e.g., 64×64) via traits -- dtype: fp16 / bf16 (fp32 accum) -- enable_skip (default on) -- softcap scalar - -## 7. Memory & Synchronization -- Double buffering for streaming Q/K/V with `cp.async` fences. -- Aliasing: - - sMask reused as sP after consumption. - - sBias reused as sdS after gradient consumption. -- Critical barriers: - 1. Ensure mask fully read before overwriting region with P. - 2. Ensure dS fully consumed (dK finished) before alias region becomes bias. -Goal: minimize shared memory to enable larger tiles and higher occupancy. - -## 8. Numerical Stability -- LSE caching prevents overflow. -- FP16/BF16 inputs + FP32 accumulation. -- Skip path doesn't touch LSE entries of masked tiles. -- Validation scripts: forward/backward/grad equivalence across lengths, densities. - -## 9. Backward Compatibility & Upgrade -- Same Python API; upgrading from v0.3.0 requires no code changes for standard use. -- Internal layout symbols not part of public contract—custom kernels should revalidate alias expectations. -- Future runtime stats API planned (non-breaking). - -## 10. Known Limitations -- Only block-aligned sparsity (no arbitrary coordinate compression yet). -- Skip decision not yet moved ahead of K/V/dO loads. -- No fragment-level (Tensor Core tile) sparsity gating yet. -- No built-in distributed (multi-GPU) attention aggregation logic. -- Triton path feature parity still evolving. - -## 11. Testing & Validation -- Numerical: compare to dense `scaled_dot_product_attention`. -- Sparsity: random masks of varying density; compare skip vs forced-dense output. -- Regression: multi-block scenarios to guard prior dK NaN issue. -- Benchmarks: measure kernel time vs density p. - -## 12. Roadmap -1. Early mask gating pre-load. -2. Bit-packed mask + warp ballot OR. -3. Adaptive skip threshold (disable when p high). -4. Fragment-level MMA gating. -5. Persistent CTA + work queue. -6. Runtime counters: active/skipped tile counts, effective density. -7. Distributed integration examples. - -## 13. Safety & Robustness -- Input validation: shapes / dtypes / device alignment. -- Mask alignment and slicing. -- LSE + FP32 mitigate overflow. -- Barriers enforce safe alias lifecycle. -- Future fallback path for anomaly detection (planned). - -## 14. Acknowledgements -- Inspired by FlashAttention research and community. -- Contributors: core maintainers & commit authors (see git history). -- Ecosystem: PyTorch / CUTLASS / Triton. - -## 15. Version Delta Summary -Changes vs v0.3.0: -- Added forward skip bringing full forward/backward symmetry. -- Fixed block size condition + enhanced documentation. -- Shared memory alias + barrier ordering refinements (resolved dK NaNs). -- Skip branch pointer advancement semantics aligned with dense path. -- Comprehensive technical documentation and math derivations. - -## 16. Formula Quick Reference -1. Softmax: - -$$ -P_{ij} = \frac{e^{S_{ij}-LSE_i}}{\sum_k e^{S_{ik}-LSE_i}}, \quad LSE_i = \log \sum_k e^{S_{ik}} -$$ - -2. dS: - -$$ -dS_{ij} = \left(dP_{ij} - \sum_k dP_{ik} P_{ik}\right) P_{ij} -$$ - -3. 
Grad propagation: - -$$ -dQ = dS K,\quad dK = dS^T Q,\quad dV = P^T dO -$$ - -4. Skip predicate: - -$$ -any\_active = \bigvee_{(i,j)\in tile} mask_{ij} -$$ - -## 17. Alias & Barrier Snippet -``` -load mask -> sMask -any_active = or_reduce(sMask) -if any_active: - # reuse sMask region as sP after consumption - compute S - softmax -> write P into aliased region (sP) - ... -__syncthreads() # ensure dS consumed -# reuse sBias region as sdS in next iteration -``` - -## 18. Glossary -- Block / Tile: matrix sub-region processed per step. -- Skip: branch eliminating compute for fully masked tile. -- LSE: log-sum-exp cache for stability. -- Aliasing: reusing shared memory region across disjoint lifetimes. -- Fragment-level: granularity of Tensor Core MMA fragments. - -## 19. Integration -- HuggingFace-style example: modeling_doge.py -- Drop-in custom attention module inside transformer blocks. -- Planned: wrapper matching `scaled_dot_product_attention` signature for rapid substitution. - -## 20. Debug & Diagnostics -Planned env toggles to print: -- Active vs skipped tile counts -- Skip hit rate -- Average tile density -Common issues: -- NaNs: verify barrier / alias ordering not altered. -- Poor speedup: density p too high; disable skip to compare. - -## 21. Release Guidance -- Users gain block-sparse skip automatically after upgrade. -- For custom builds: confirm target GPU arch (sm80+) for Tensor Core efficiency. - -## 22. References -- Dao et al., FlashAttention series -- CUTLASS docs -- PyTorch Autograd internals - ---- - -## What's Changed -* Optimize sparse logic by @LoserCheems in https://github.com/SmallDoges/flash-dmattn/pull/131 -* Fix block size condition and enhance documentation by @LoserCheems in https://github.com/SmallDoges/flash-dmattn/pull/134 - - -**Full Changelog**: https://github.com/SmallDoges/flash-dmattn/compare/v0.3.0...v1.0.0 diff --git a/examples/modeling/modeling_doge.py b/examples/modeling/modeling_doge.py index d350f71..af1f791 100644 --- a/examples/modeling/modeling_doge.py +++ b/examples/modeling/modeling_doge.py @@ -45,9 +45,9 @@ from .configuration_doge import DogeConfig try: - from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward + from flash_sparse_attn.integrations.flash_sparse_attention import flash_dynamic_mask_attention_forward except ImportError: - print("Please install flash_dmattn to use this model: pip install flash-dmattn") + print("Please install flash_sparse_attn to use this model: pip install flash-sparse-attn") if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask diff --git a/flash_dmattn/__init__.py b/flash_sparse_attn/__init__.py similarity index 72% rename from flash_dmattn/__init__.py rename to flash_sparse_attn/__init__.py index 484e8d1..23309da 100644 --- a/flash_dmattn/__init__.py +++ b/flash_sparse_attn/__init__.py @@ -2,32 +2,32 @@ from typing import Optional -__version__ = "1.2.2" +__version__ = "1.2.3" # Import CUDA functions when available try: - from flash_dmattn.flash_dmattn_interface import flash_dmattn_func, flash_dmattn_varlen_func + from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func, flash_sparse_attn_varlen_func CUDA_AVAILABLE = True except ImportError: CUDA_AVAILABLE = False - flash_dmattn_func, flash_dmattn_varlen_func = None, None + flash_sparse_attn_func, flash_sparse_attn_varlen_func = None, None # Import Triton functions when available try: - from flash_dmattn.flash_dmattn_triton import triton_dmattn_func + 
from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func TRITON_AVAILABLE = True except ImportError: TRITON_AVAILABLE = False - triton_dmattn_func = None + triton_sparse_attn_func = None # Import Flex functions when available try: - from flash_dmattn.flash_dmattn_flex import flex_dmattn_func + from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func FLEX_AVAILABLE = True except ImportError: FLEX_AVAILABLE = False - flex_dmattn_func = None + flex_sparse_attn_func = None def get_available_backends(): @@ -42,9 +42,9 @@ def get_available_backends(): return backends -def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs): +def flash_sparse_attn_func_auto(backend: Optional[str] = None, **kwargs): """ - Flash Dynamic Mask Attention function with automatic backend selection. + Flash Sparse Attention function with automatic backend selection. Args: backend (str, optional): Backend to use ('cuda', 'triton', 'flex'). @@ -68,17 +68,17 @@ def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs): if backend == "cuda": if not CUDA_AVAILABLE: raise RuntimeError("CUDA backend is not available. Please build the CUDA extension.") - return flash_dmattn_func + return flash_sparse_attn_func elif backend == "triton": if not TRITON_AVAILABLE: raise RuntimeError("Triton backend is not available. Please install triton: pip install triton") - return triton_dmattn_func + return triton_sparse_attn_func elif backend == "flex": if not FLEX_AVAILABLE: raise RuntimeError("Flex backend is not available. Please install transformers: pip install transformers") - return flex_dmattn_func + return flex_sparse_attn_func else: raise ValueError(f"Unknown backend: {backend}. Available backends: {get_available_backends()}") @@ -88,10 +88,10 @@ def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs): "CUDA_AVAILABLE", "TRITON_AVAILABLE", "FLEX_AVAILABLE", - "flash_dmattn_func", - "flash_dmattn_varlen_func", - "triton_dmattn_func", - "flex_dmattn_func", + "flash_sparse_attn_func", + "flash_sparse_attn_varlen_func", + "triton_sparse_attn_func", + "flex_sparse_attn_func", "get_available_backends", - "flash_dmattn_func_auto", + "flash_sparse_attn_func_auto", ] diff --git a/flash_dmattn/flash_dmattn_triton_special.py b/flash_sparse_attn/flash_dmattn_triton.py similarity index 100% rename from flash_dmattn/flash_dmattn_triton_special.py rename to flash_sparse_attn/flash_dmattn_triton.py diff --git a/flash_dmattn/flash_dmattn_flex.py b/flash_sparse_attn/flash_sparse_attn_flex.py similarity index 98% rename from flash_dmattn/flash_dmattn_flex.py rename to flash_sparse_attn/flash_sparse_attn_flex.py index 379f984..055c2c8 100644 --- a/flash_dmattn/flash_dmattn_flex.py +++ b/flash_sparse_attn/flash_sparse_attn_flex.py @@ -77,4 +77,4 @@ def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): return attn_output -flex_dmattn_func = flex_attention_forward \ No newline at end of file +flex_sparse_attn_func = flex_attention_forward \ No newline at end of file diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_sparse_attn/flash_sparse_attn_interface.py similarity index 92% rename from flash_dmattn/flash_dmattn_interface.py rename to flash_sparse_attn/flash_sparse_attn_interface.py index d34adb0..9f8e676 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_sparse_attn/flash_sparse_attn_interface.py @@ -4,7 +4,7 @@ from packaging import version import torch -import flash_dmattn_cuda as flash_dmattn_gpu # type: ignore +import flash_sparse_attn_cuda as 
flash_sparse_attn_gpu # type: ignore def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: @@ -70,8 +70,8 @@ def wrap(func): _torch_register_fake_wrapper = noop_register_fake_wrapper -@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_forward", mutates_args=(), device_types="cuda") -def _flash_dmattn_forward( +@_torch_custom_op_wrapper("flash_sparse_attn::_flash_sparse_attn_forward", mutates_args=(), device_types="cuda") +def _flash_sparse_attn_forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -83,7 +83,7 @@ def _flash_dmattn_forward( return_softmax: bool ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)] - out, softmax_lse, S_dmask = flash_dmattn_gpu.fwd( + out, softmax_lse, S_dmask = flash_sparse_attn_gpu.fwd( q, k, v, @@ -99,8 +99,8 @@ def _flash_dmattn_forward( return out, softmax_lse, S_dmask -@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_forward") -def _flash_dmattn_forward_fake( +@_torch_register_fake_wrapper("flash_sparse_attn::_flash_sparse_attn_forward") +def _flash_sparse_attn_forward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -123,11 +123,11 @@ def _flash_dmattn_forward_fake( return out, softmax_lse, p -_wrapped_flash_dmattn_forward = _flash_dmattn_forward +_wrapped_flash_sparse_attn_forward = _flash_sparse_attn_forward -@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_forward", mutates_args=(), device_types="cuda") -def _flash_dmattn_varlen_forward( +@_torch_custom_op_wrapper("flash_sparse_attn::_flash_sparse_attn_varlen_forward", mutates_args=(), device_types="cuda") +def _flash_sparse_attn_varlen_forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -145,7 +145,7 @@ def _flash_dmattn_varlen_forward( zero_tensors: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask = flash_dmattn_gpu.varlen_fwd( + out, softmax_lse, S_dmask = flash_sparse_attn_gpu.varlen_fwd( q, k, v, @@ -167,8 +167,8 @@ def _flash_dmattn_varlen_forward( return out, softmax_lse, S_dmask -@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_varlen_forward") -def _flash_dmattn_varlen_forward_fake( +@_torch_register_fake_wrapper("flash_sparse_attn::_flash_sparse_attn_varlen_forward") +def _flash_sparse_attn_varlen_forward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -200,11 +200,11 @@ def _flash_dmattn_varlen_forward_fake( return out, softmax_lse, p -_wrapped_flash_dmattn_varlen_forward = _flash_dmattn_varlen_forward +_wrapped_flash_sparse_attn_varlen_forward = _flash_sparse_attn_varlen_forward -@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_backward", mutates_args=("dq", "dk", "dv", "dbias"), device_types="cuda") -def _flash_dmattn_backward( +@_torch_custom_op_wrapper("flash_sparse_attn::_flash_sparse_attn_backward", mutates_args=("dq", "dk", "dv", "dbias"), device_types="cuda") +def _flash_sparse_attn_backward( dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, @@ -229,7 +229,7 @@ def _flash_dmattn_backward( dv, dbias, softmax_d, - ) = flash_dmattn_gpu.bwd( + ) = flash_sparse_attn_gpu.bwd( dout, q, k, @@ -251,8 +251,8 @@ def _flash_dmattn_backward( return softmax_d -@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_backward") -def _flash_dmattn_backward_fake( +@_torch_register_fake_wrapper("flash_sparse_attn::_flash_sparse_attn_backward") +def _flash_sparse_attn_backward_fake( dout: torch.Tensor, q: 
torch.Tensor, k: torch.Tensor, @@ -285,11 +285,11 @@ def _flash_dmattn_backward_fake( return softmax_d -_wrapped_flash_dmattn_backward = _flash_dmattn_backward +_wrapped_flash_sparse_attn_backward = _flash_sparse_attn_backward -@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") -def _flash_dmattn_varlen_backward( +@_torch_custom_op_wrapper("flash_sparse_attn::_flash_sparse_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") +def _flash_sparse_attn_varlen_backward( dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, @@ -315,7 +315,7 @@ def _flash_dmattn_varlen_backward( dk, dv, softmax_d, - ) = flash_dmattn_gpu.varlen_bwd( + ) = flash_sparse_attn_gpu.varlen_bwd( dout, q, k, @@ -339,8 +339,8 @@ def _flash_dmattn_varlen_backward( return softmax_d -@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_varlen_backward") -def _flash_dmattn_varlen_backward_fake( +@_torch_register_fake_wrapper("flash_sparse_attn::_flash_sparse_attn_varlen_backward") +def _flash_sparse_attn_varlen_backward_fake( dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, @@ -376,7 +376,7 @@ def _flash_dmattn_varlen_backward_fake( return softmax_d -_wrapped_flash_dmattn_varlen_backward = _flash_dmattn_varlen_backward +_wrapped_flash_sparse_attn_varlen_backward = _flash_sparse_attn_varlen_backward class FlashDMAttnFunc(torch.autograd.Function): @@ -429,7 +429,7 @@ def forward( else: bias = torch.nn.functional.pad(bias, [0, seqlen_k_rounded - bias.shape[-1]]) - out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( + out_padded, softmax_lse, S_dmask = _wrapped_flash_sparse_attn_forward( q, k, v, @@ -468,7 +468,7 @@ def backward( if head_size_og % 8 != 0: dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_dmattn_backward( + _wrapped_flash_sparse_attn_backward( dout_padded, q, k, @@ -539,7 +539,7 @@ def forward( k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward( + out_padded, softmax_lse, S_dmask = _wrapped_flash_sparse_attn_varlen_forward( q, k, v, @@ -579,7 +579,7 @@ def backward(ctx, dout, *args): if head_size_og % 8 != 0: dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_dmattn_varlen_backward( + _wrapped_flash_sparse_attn_varlen_backward( dout_padded, q, k, @@ -607,7 +607,7 @@ def backward(ctx, dout, *args): return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None -def flash_dmattn_func( +def flash_sparse_attn_func( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -684,7 +684,7 @@ def flash_dmattn_func( ) -def flash_dmattn_varlen_func( +def flash_sparse_attn_varlen_func( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_sparse_attn/flash_sparse_attn_triton.py similarity index 99% rename from flash_dmattn/flash_dmattn_triton.py rename to flash_sparse_attn/flash_sparse_attn_triton.py index 66141cb..eefbf76 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_sparse_attn/flash_sparse_attn_triton.py @@ -888,7 +888,7 @@ def _bwd_kernel( ) -def _flash_dmattn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False): +def _flash_sparse_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False): # shape constraints batch, seqlen_q, nheads, d = q.shape _, seqlen_k, 
nheads_k, _ = k.shape @@ -980,7 +980,7 @@ def _flash_dmattn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=Fal return o, lse, softmax_scale # softmax_scale could have been updated -def _flash_dmattn_backward( +def _flash_sparse_attn_backward( do, q, k, v, mask, bias, o, lse, softmax_scale=None, is_causal=False ): # Make sure that the last dimension is contiguous @@ -1195,7 +1195,7 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa else: attn_bias = torch.nn.functional.pad(attn_bias, [0, seqlen_k_rounded - attn_bias.shape[-1]]) - o, lse, ctx.softmax_scale = _flash_dmattn_forward( + o, lse, ctx.softmax_scale = _flash_sparse_attn_forward( query, key, value, @@ -1218,7 +1218,7 @@ def backward(ctx, do): if head_size_og % 8 != 0: do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) - dq, dk, dv, dbias = _flash_dmattn_backward( + dq, dk, dv, dbias = _flash_sparse_attn_backward( do_padded, query, key, @@ -1242,5 +1242,5 @@ def backward(ctx, do): return dq, dk, dv, None, dbias, None, None -def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, is_causal=False, softmax_scale=None): +def triton_sparse_attn_func(query, key, value, attn_mask=None, attn_bias=None, is_causal=False, softmax_scale=None): return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, is_causal, softmax_scale) diff --git a/flash_dmattn/integrations/flash_dynamic_mask_attention.py b/flash_sparse_attn/integrations/flash_sparse_attention.py similarity index 86% rename from flash_dmattn/integrations/flash_dynamic_mask_attention.py rename to flash_sparse_attn/integrations/flash_sparse_attention.py index f842ae6..1dd6550 100644 --- a/flash_dmattn/integrations/flash_dynamic_mask_attention.py +++ b/flash_sparse_attn/integrations/flash_sparse_attention.py @@ -2,14 +2,14 @@ import torch -from .modeling_flash_dynamic_mask_attention_utils import _flash_dynamic_mask_attention_forward +from .modeling_flash_sparse_attention_utils import _flash_sparse_attention_forward from transformers.utils import logging logger = logging.get_logger(__name__) -def flash_dynamic_mask_attention_forward( +def flash_sparse_attention_forward( module: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -22,8 +22,8 @@ def flash_dynamic_mask_attention_forward( **kwargs, ) -> tuple[torch.Tensor, None]: """ - A wrapper around the _flash_dynamic_mask_attention_forward function to be used in - the FlashDynamicMaskAttention class from HuggingFace Transformers. + A wrapper around the _flash_sparse_attention_forward function to be used in + the FlashSparseAttention class from HuggingFace Transformers. Args: module (torch.nn.Module): The attention module. @@ -41,7 +41,7 @@ def flash_dynamic_mask_attention_forward( Includes: - is_causal (bool): Whether to apply a causal mask. - layer_idx (int): The index of the layer (for logging purposes). - - implementation (str): The implementation to use ("flash_dmattn" or None). + - implementation (str): The implementation to use ("flash_sparse_attn" or None). Returns: tuple[torch.Tensor, None]: The output tensor of shape (batch_size, seq_len, num_heads, head_dim) @@ -50,7 +50,7 @@ def flash_dynamic_mask_attention_forward( if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None: logger.warning_once( - "`flash_dynamic_mask_attention` does not support `output_attentions=True` or `head_mask`." + "`flash_sparse_attention` does not support `output_attentions=True` or `head_mask`." 
" Please set your attention to `eager` if you want any of these features." ) @@ -61,11 +61,11 @@ def flash_dynamic_mask_attention_forward( if any(dim == 0 for dim in query.shape): raise ValueError( "Tensor query has shape with a zero dimension.\n" - "FlashDynamicMaskAttention does not support inputs with dim=0.\n" + "FlashSparseAttention does not support inputs with dim=0.\n" "Please check your input shapes or use SDPA instead." ) - # FDMA uses non-transposed inputs + # FSA uses non-transposed inputs query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) @@ -90,7 +90,7 @@ def flash_dynamic_mask_attention_forward( if is_causal is None: is_causal = module.is_causal - attn_output = _flash_dynamic_mask_attention_forward( + attn_output = _flash_sparse_attention_forward( query, key, value, @@ -103,7 +103,7 @@ def flash_dynamic_mask_attention_forward( softcap=softcap, window_size=window_size, target_dtype=target_dtype, - implementation="flash_dmattn", + implementation="flash_sparse_attn", layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None, **kwargs, ) diff --git a/flash_dmattn/integrations/import_utils.py b/flash_sparse_attn/integrations/import_utils.py similarity index 97% rename from flash_dmattn/integrations/import_utils.py rename to flash_sparse_attn/integrations/import_utils.py index 583248b..70a91fe 100644 --- a/flash_dmattn/integrations/import_utils.py +++ b/flash_sparse_attn/integrations/import_utils.py @@ -80,11 +80,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ @lru_cache -def is_flash_dmattn_available(): +def is_flash_sparse_attn_available(): if not is_torch_available(): return False - if not _is_package_available("flash_dmattn"): + if not _is_package_available("flash_sparse_attn"): return False import torch diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_sparse_attn/integrations/modeling_flash_sparse_attention_utils.py similarity index 85% rename from flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py rename to flash_sparse_attn/integrations/modeling_flash_sparse_attention_utils.py index c2638b8..b62b34e 100644 --- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py +++ b/flash_sparse_attn/integrations/modeling_flash_sparse_attention_utils.py @@ -19,7 +19,7 @@ import torch import torch.nn.functional as F -from .import_utils import is_flash_dmattn_available +from .import_utils import is_flash_sparse_attn_available from transformers.utils import logging @@ -27,15 +27,15 @@ # `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves -_fdma_fn = None -_fdma_varlen_fn = None +_fsa_fn = None +_fsa_varlen_fn = None _pad_fn = None _unpad_fn = None _create_mask_fn = None # function that processes kwargs, generalized to handle any supported kwarg within the function _process_flash_kwargs_fn = None -# exceptions where hf API doesn't match the original FDMA API +# exceptions where hf API doesn't match the original FSA API _hf_api_to_flash_mapping = { "dropout": None, "sliding_window": None, @@ -44,61 +44,60 @@ def _lazy_imports(implementation: Optional[str]): """ - Lazy loads the respective flash dynamic mask attention implementations. + Lazy loads the respective flash sparse attention implementations. Return: - flash_attn_func: The base flash dynamic mask attention function. 
- flash_attn_varlen_func: The flash dynamic mask attention function supporting variable sequence lengths, e.g. for padding-free training. + flash_sparse_attn_func: The base flash sparse attention function. + flash_sparse_attn_varlen_func: The flash sparse attention function supporting variable sequence lengths, e.g. for padding-free training. pad_input: The function to pad inputs into one sequence and returning the respective kwargs. unpad_input: The function to unpad outputs based on the kwargs (from pad_input). """ - is_fdma = is_flash_dmattn_available() + is_fsa = is_flash_sparse_attn_available() - if (implementation == "flash_dmattn" and is_fdma) or (implementation is None and is_fdma): - from flash_dmattn import flash_dmattn_func, flash_dmattn_varlen_func - from flash_dmattn.utils.padding import pad_input, unpad_input - from flash_dmattn.utils.mask import create_mask + if (implementation == "flash_sparse_attn" and is_fsa) or (implementation is None and is_fsa): + from flash_sparse_attn import flash_sparse_attn_func, flash_sparse_attn_varlen_func + from flash_sparse_attn.utils.padding import pad_input, unpad_input + from flash_sparse_attn.utils.mask import create_mask - return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input, create_mask + return flash_sparse_attn_func, flash_sparse_attn_varlen_func, pad_input, unpad_input, create_mask def _lazy_define_process_function(flash_function): """ Depending on the version and kernel some features are not supported. Due to limitations in `torch.compile`, we opt to statically type which (optional) kwarg parameters are supported - within `_process_flash_dynamic_mask_attention_kwargs`. + within `_process_flash_sparse_attention_kwargs`. NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`. This might be confusing for kwargs that we use in any case, e.g. `is_causal`. """ flash_parameters = inspect.signature(flash_function).parameters - process_parameters = inspect.signature(_process_flash_dynamic_mask_attention_kwargs).parameters + process_parameters = inspect.signature(_process_flash_sparse_attention_kwargs).parameters supports_mapping = {} for param in process_parameters: - fdma_param = _hf_api_to_flash_mapping.get(param, param) - supports_mapping[fdma_param] = fdma_param in flash_parameters + fsa_param = _hf_api_to_flash_mapping.get(param, param) + supports_mapping[fsa_param] = fsa_param in flash_parameters - return partial(_process_flash_dynamic_mask_attention_kwargs, supports_mapping=supports_mapping) + return partial(_process_flash_sparse_attention_kwargs, supports_mapping=supports_mapping) -def lazy_import_flash_dynamic_mask_attention(implementation: Optional[str], force_import: Optional[bool] = False): +def lazy_import_flash_sparse_attention(implementation: Optional[str], force_import: Optional[bool] = False): """ - Lazily import flash dmattn and return the respective functions + flags. + Lazily import flash sparse attention and return the respective functions + flags. NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`. 
""" - global _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn - if force_import or any(k is None for k in [_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn]): - _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn = _lazy_imports(implementation) + global _fsa_fn, _fsa_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn + if force_import or any(k is None for k in [_fsa_fn, _fsa_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn]): + _fsa_fn, _fsa_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn = _lazy_imports(implementation) global _process_flash_kwargs_fn if force_import or _process_flash_kwargs_fn is None: - _process_flash_kwargs_fn = _lazy_define_process_function(_fdma_varlen_fn) - - return (_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn), _process_flash_kwargs_fn + _process_flash_kwargs_fn = _lazy_define_process_function(_fsa_varlen_fn) + return (_fsa_fn, _fsa_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn), _process_flash_kwargs_fn def _index_first_axis(tensor, indices): @@ -131,7 +130,7 @@ def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.T """ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - # NOTE: Similar to the `.item()` in prepare_fdma_kwargs_from_position_ids, with torch compile, + # NOTE: Similar to the `.item()` in prepare_fsa_kwargs_from_position_ids, with torch compile, # this might cause a graph break max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) @@ -221,9 +220,9 @@ def _upad_input( ) -def prepare_fdma_kwargs_from_position_ids(position_ids): +def prepare_fsa_kwargs_from_position_ids(position_ids): """ - This function returns all the necessary kwargs to call `flash_attn_varlen_func` extracted from position_ids. + This function returns all the necessary kwargs to call `flash_sparse_attn_varlen_func` extracted from position_ids. Arguments: position_ids (`torch.Tensor`): @@ -267,7 +266,7 @@ def prepare_fdma_kwargs_from_position_ids(position_ids): def _prepare_from_posids(query, key, value, position_ids): """ - This function returns necessary arguments to call `flash_attn_varlen_func`. + This function returns necessary arguments to call `flash_sparse_attn_varlen_func`. All three query, key, value states will be flattened. Cumulative lengths of each examples in the batch will be extracted from position_ids. NOTE: ideally cumulative lengths should be prepared at the data collator stage @@ -298,7 +297,7 @@ def _prepare_from_posids(query, key, value, position_ids): key = key.contiguous().view(-1, key.size(-2), key.size(-1)) value = value.contiguous().view(-1, value.size(-2), value.size(-1)) - (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fdma_kwargs_from_position_ids(position_ids) + (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fsa_kwargs_from_position_ids(position_ids) return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)) @@ -319,7 +318,7 @@ def _is_packed_sequence(position_ids, batch_size): return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool() -def fdma_peft_integration_check( +def fsa_peft_integration_check( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -333,16 +332,16 @@ def fdma_peft_integration_check( This might slowdown training & inference so it is recommended to not cast the LayerNorms! 
""" if target_dtype and q.dtype == torch.float32: - logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-dmattn compatibility.") + logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash_sparse_attn compatibility.") q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) if bias is not None: bias = bias.to(target_dtype) return q, k, v, bias -class FlashDynamicMaskAttentionKwargs(TypedDict, total=False): +class FlashSparseAttentionKwargs(TypedDict, total=False): """ - Keyword arguments for Flash Dynamic Mask Attention with Compile. + Keyword arguments for Flash Sparse Attention with Compile. Attributes: cu_seq_lens_q (`torch.LongTensor`, *optional*) @@ -361,7 +360,7 @@ class FlashDynamicMaskAttentionKwargs(TypedDict, total=False): max_length_k: Optional[int] -def _process_flash_dynamic_mask_attention_kwargs( +def _process_flash_sparse_attention_kwargs( query_length: int, key_length: int, is_causal: bool, @@ -376,7 +375,7 @@ def _process_flash_dynamic_mask_attention_kwargs( """ Returns a set of kwargs that are passed down to the according flash attention function based on requested features and whether it is supported - depends on the version and kernel implementation - which is dynamically configured at `lazy_import_flash_dynamic_mask_attention`. The (un)supported features can be + which is dynamically configured at `lazy_import_flash_sparse_attention`. The (un)supported features can be inspected in `supports_mapping`, see `_lazy_define_process_function` for more details. Args: @@ -410,7 +409,7 @@ def _process_flash_dynamic_mask_attention_kwargs( if supports_mapping["deterministic"]: flash_kwargs["deterministic"] = ( - deterministic if deterministic is not None else os.getenv("FLASH_DMATTN_DETERMINISTIC", "0") == "1" + deterministic if deterministic is not None else os.getenv("FLASH_SPARSE_ATTENTION_DETERMINISTIC", "0") == "1" ) if supports_mapping["softcap"] and softcap is not None: @@ -423,7 +422,7 @@ def _process_flash_dynamic_mask_attention_kwargs( return flash_kwargs -def _flash_dynamic_mask_attention_forward( +def _flash_sparse_attention_forward( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, @@ -446,18 +445,18 @@ def _flash_dynamic_mask_attention_forward( **kwargs, ): """ - Calls the forward method of Flash Dynamic Mask Attention - if the input hidden states contain at least one padding token + Calls the forward method of Flash Sparse Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. - (Optional) kwargs are described further in `_process_flash_dynamic_mask_attention_kwargs` and `FlashDynamicMaskAttentionKwargs`. + (Optional) kwargs are described further in `_process_flash_sparse_attention_kwargs` and `FlashSparseAttentionKwargs`. Args: query_states (`torch.Tensor`): - Input query states to be passed to Flash DMATTN API + Input query states to be passed to FSA API key_states (`torch.Tensor`): - Input key states to be passed to Flash DMATTN API + Input key states to be passed to FSA API value_states (`torch.Tensor`): - Input value states to be passed to Flash DMATTN API + Input value states to be passed to FSA API attention_mask (`torch.Tensor`, *optional*): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. 
@@ -476,10 +475,10 @@ def _flash_dynamic_mask_attention_forward( "If shape of attention_mask is (batch_size, seq_len), attention_bias has to be None." ) - (fdma_fn, fdma_varlen_fn, pad_fn, unpad_fn, create_mask_fn), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention(implementation) + (fsa_fn, fsa_varlen_fn, pad_fn, unpad_fn, create_mask_fn), process_flash_kwargs_fn = lazy_import_flash_sparse_attention(implementation) # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op - query_states, key_states, value_states, attention_bias = fdma_peft_integration_check( + query_states, key_states, value_states, attention_bias = fsa_peft_integration_check( query_states, key_states, value_states, attention_bias, target_dtype ) @@ -495,15 +494,15 @@ def _flash_dynamic_mask_attention_forward( **kwargs, ) - # We will use `fdma_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases: + # We will use `fsa_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases: # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`. # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to - # use `fdma_varlen_fn` knowing we already have all necessary the kwargs. + # use `fsa_varlen_fn` knowing we already have all necessary the kwargs. # # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model. # See #39121 for more information. - is_fdma_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0)) - is_fdma_with_varlen_kwargs = all( + is_fsa_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0)) + is_fsa_with_varlen_kwargs = all( kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) ) @@ -518,7 +517,7 @@ def _flash_dynamic_mask_attention_forward( if "mps" in str(q.device): cu_seq_lens_k = cu_seq_lens_k.clone() - out_unpad = fdma_varlen_fn( + out_unpad = fsa_varlen_fn( q, k, v, @@ -534,7 +533,7 @@ def _flash_dynamic_mask_attention_forward( out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length) # Padding free, i.e. sequences flattened into one total sequence - elif is_fdma_with_varlen_kwargs or is_fdma_with_position_ids: + elif is_fsa_with_varlen_kwargs or is_fsa_with_position_ids: if cu_seq_lens_q is None or cu_seq_lens_k is None: q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids( query_states, key_states, value_states, position_ids @@ -549,7 +548,7 @@ def _flash_dynamic_mask_attention_forward( if "mps" in str(q.device): cu_seq_lens_k = cu_seq_lens_k.clone() - out = fdma_varlen_fn( + out = fsa_varlen_fn( q, k, v, @@ -583,7 +582,7 @@ def _flash_dynamic_mask_attention_forward( min_dtype=torch.finfo(attention_bias.dtype).min, ) - out = fdma_fn( + out = fsa_fn( query_states, key_states, value_states, diff --git a/flash_dmattn/utils/mask.py b/flash_sparse_attn/utils/mask.py similarity index 99% rename from flash_dmattn/utils/mask.py rename to flash_sparse_attn/utils/mask.py index 491e270..4905835 100644 --- a/flash_dmattn/utils/mask.py +++ b/flash_sparse_attn/utils/mask.py @@ -172,7 +172,7 @@ def create_mask( type: str = "topk", ) -> torch.Tensor: r""" - This function creates a mask tensor for Flash Dynamic Mask Attention. 
+ This function creates a mask tensor for Flash Sparse Attention. If attention_mask is not of shape (batch_size, seq_len), it needs to match the shape of attention_bias. diff --git a/flash_dmattn/utils/padding.py b/flash_sparse_attn/utils/padding.py similarity index 100% rename from flash_dmattn/utils/padding.py rename to flash_sparse_attn/utils/padding.py diff --git a/pyproject.toml b/pyproject.toml index 97ffc48..1158eec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,9 +11,9 @@ requires = [ build-backend = "setuptools.build_meta" [project] -name = "flash-dmattn" +name = "flash-sparse-attn" dynamic = ["version"] -description = "Flash Dynamic Mask Attention: Fast and Memory-Efficient Trainable Dynamic Mask Sparse Attention" +description = "Flash Sparse Attention: Fast and Memory-Efficient Trainable Dynamic Mask Sparse Attention" readme = "README.md" license = { file = "LICENSE" } authors = [ @@ -40,9 +40,9 @@ classifiers = [ ] [project.urls] -Homepage = "https://github.com/SmallDoges/flash-dmattn" -Source = "https://github.com/SmallDoges/flash-dmattn" -Issues = "https://github.com/SmallDoges/flash-dmattn/issues" +Homepage = "https://github.com/SmallDoges/flash-sparse-attention" +Source = "https://github.com/SmallDoges/flash-sparse-attention" +Issues = "https://github.com/SmallDoges/flash-sparse-attention/issues" [project.optional-dependencies] triton = [ @@ -69,11 +69,11 @@ dev = [ ] [tool.setuptools.dynamic] -version = { attr = "flash_dmattn.__version__" } +version = { attr = "flash_sparse_attn.__version__" } [tool.setuptools.packages.find] where = ["."] -include = ["flash_dmattn*"] +include = ["flash_sparse_attn*"] exclude = [ "build", "csrc", @@ -82,10 +82,10 @@ exclude = [ "dist", "docs", "benchmarks", - "flash_dmattn.egg-info" + "flash_sparse_attn.egg-info" ] [tool.setuptools.package-data] -flash_dmattn = ["*.py"] +flash_sparse_attn = ["*.py"] [tool.setuptools] diff --git a/setup.py b/setup.py index 6bda267..6c2a8b0 100644 --- a/setup.py +++ b/setup.py @@ -34,19 +34,19 @@ # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) -PACKAGE_NAME = "flash_dmattn" +PACKAGE_NAME = "flash_sparse_attn" BASE_WHEEL_URL = ( - "https://github.com/SmallDoges/flash-dmattn/releases/download/{tag_name}/{wheel_name}" + "https://github.com/SmallDoges/flash-sparse-attention/releases/download/{tag_name}/{wheel_name}" ) # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation # Also useful when user only wants Triton/Flex backends without CUDA compilation -FORCE_BUILD = os.getenv("FLASH_DMATTN_FORCE_BUILD", "FALSE") == "TRUE" -SKIP_CUDA_BUILD = os.getenv("FLASH_DMATTN_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +FORCE_BUILD = os.getenv("FLASH_SPARSE_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI -FORCE_CXX11_ABI = os.getenv("FLASH_DMATTN_FORCE_CXX11_ABI", "FALSE") == "TRUE" +FORCE_CXX11_ABI = os.getenv("FLASH_SPARSE_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" # Auto-detect if user wants only Triton/Flex backends based on pip install command # This helps avoid unnecessary CUDA compilation when user only wants Python backends @@ -69,7 +69,7 @@ def should_skip_cuda_build(): if has_triton_or_flex and not 
has_all_or_dev: print("Detected Triton/Flex-only installation. Skipping CUDA compilation.") - print("Set FLASH_DMATTN_FORCE_BUILD=TRUE to force CUDA compilation.") + print("Set FLASH_SPARSE_ATTENTION_FORCE_BUILD=TRUE to force CUDA compilation.") return True return False @@ -79,7 +79,7 @@ def should_skip_cuda_build(): @functools.lru_cache(maxsize=None) def cuda_archs(): - return os.getenv("FLASH_DMATTN_CUDA_ARCHS", "80;90;100").split(";") + return os.getenv("FLASH_SPARSE_ATTENTION_CUDA_ARCHS", "80;90;100").split(";") def detect_preferred_sm_arch() -> Optional[str]: @@ -154,14 +154,14 @@ def append_nvcc_threads(nvcc_extra_args): TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) - check_if_cuda_home_none("flash_dmattn") + check_if_cuda_home_none("flash_sparse_attn") # Check, if CUDA11 is installed for compute capability 8.0 cc_flag = [] if CUDA_HOME is not None: _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version < Version("11.7"): raise RuntimeError( - "Flash Dynamic Mask Attention is only supported on CUDA 11.7 and above. " + "Flash Sparse Attention is only supported on CUDA 11.7 and above. " "Note: make sure nvcc has a supported version by running nvcc -V." ) @@ -218,20 +218,20 @@ def append_nvcc_threads(nvcc_extra_args): ext_modules.append( CUDAExtension( - name="flash_dmattn_cuda", + name="flash_sparse_attn_cuda", sources=( [ - "csrc/flash_dmattn/flash_api.cpp", + "csrc/flash_sparse_attn/flash_api.cpp", ] - + sorted(glob.glob("csrc/flash_dmattn/src/instantiations/flash_*.cu")) + + sorted(glob.glob("csrc/flash_sparse_attn/src/instantiations/flash_*.cu")) ), extra_compile_args={ "cxx": compiler_c17_flag, "nvcc": append_nvcc_threads(nvcc_flags + cc_flag), }, include_dirs=[ - Path(this_dir) / "csrc" / "flash_dmattn", - Path(this_dir) / "csrc" / "flash_dmattn" / "src", + Path(this_dir) / "csrc" / "flash_sparse_attn", + Path(this_dir) / "csrc" / "flash_sparse_attn" / "src", Path(this_dir) / "csrc" / "cutlass" / "include", ], ) @@ -239,10 +239,10 @@ def append_nvcc_threads(nvcc_extra_args): def get_package_version(): - with open(Path(this_dir) / "flash_dmattn" / "__init__.py", "r") as f: + with open(Path(this_dir) / "flash_sparse_attn" / "__init__.py", "r") as f: version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) public_version = ast.literal_eval(version_match.group(1)) - local_version = os.environ.get("FLASH_DMATTN_LOCAL_VERSION") + local_version = os.environ.get("FLASH_SPARSE_ATTENTION_LOCAL_VERSION") if local_version: return f"{public_version}+{local_version}" else: