120 changes: 77 additions & 43 deletions README.md
@@ -17,12 +17,16 @@ Flash-DMA is a high-performance attention implementation that integrates Flash A

## Key Features

- **Dynamic Sparse Attention**: Dynamically selects the most relevant keys for each query, reducing computational complexity from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$, supporting trainable sparse patterns.
- **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without instantiating the full attention matrix.
- **CUDA Deep Optimization**: Utilizes custom CUDA kernels with shared memory aliasing, pipelined prefetching, and block skipping for high throughput and low memory access overhead.
- **Extremely Long Context Support**: Handles 128K+ token sequences efficiently through dynamic mask windowing while preserving accuracy.
- **Learnable Bias**: Built-in learnable attention bias and its gradient path dbias, eliminating the need for additional external operators.
- **Fusion-Friendly Training**: Both forward and backward passes support block-level zero-mask skipping, further reducing computation in sparse scenarios.
### 🎯 Core Kernel Advantages
- **4D Mask & Bias Support**: Native support for `(batch_size, num_kv_heads, query_len, key_len)` shaped attention mask and attention bias tensors
- **Intelligent Computation Skipping**: Block-level automatic skipping mechanism based on masks, completely bypassing computation and memory access for zero-mask blocks
- **Complete Gradient Support**: Built-in full gradient computation path for attention bias, supporting end-to-end training

### 🚀 Performance & Efficiency
- **Dynamic Sparse Attention**: Dynamically selects the most relevant keys for each query, reducing computational complexity from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$, supporting trainable sparse structures
- **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without instantiating the full attention matrix
- **CUDA Deep Optimization**: Custom CUDA kernels with shared memory aliasing, pipelined prefetching, and block skipping for high throughput and low memory access overhead
- **Extremely Long Context Support**: Handles 128K+ token sequences efficiently through dynamic mask windowing while preserving accuracy


## Performance
@@ -145,74 +149,104 @@ MAX_JOBS=4 pip install . --no-build-isolation

## Quick Start

### Basic Usage

```python
import torch
from flash_dmattn import flash_dmattn_func_auto
import math

# Setup
batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128
batch_size, seq_len, num_heads, num_kv_heads, head_dim = 1, 256, 2, 1, 64
keep_window_size = 128
device = torch.device('cuda')
dtype = torch.bfloat16
min_dtype = torch.finfo(dtype).min # dtype minimum value

# Input tensors
query = torch.randn(batch_size, seq_len, num_heads, head_dim,
                    device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_heads, head_dim,
                  device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_heads, head_dim,
                    device=device, dtype=dtype)
query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)

# Create mask and bias for sparse attention
attention_bias = torch.randn(batch_size, num_heads, seq_len, seq_len,
                             device=device, dtype=dtype)
attention_mask = torch.ones(batch_size, num_heads, seq_len, seq_len,
                            device=device, dtype=dtype)
attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)

# Apply dynamic masking (keep top-k for long sequences)
keep_window_size = 2048
# Generate sparse mask based on bias
if seq_len > keep_window_size:
    # Select top-k most important keys for each query
    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
                              largest=True, sorted=False).indices
    attention_mask.zero_()
    attention_mask.scatter(-1, topk_indices, 1.0)

# Select backend
    topk_values, topk_indices = torch.topk(
        attention_bias, keep_window_size, dim=-1,
        largest=True, sorted=False
    )
    # Generate valid top-k mask
    valid_topk = (topk_values != min_dtype).to(dtype)
    attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device)
    attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk)
    attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype)

# Select FDMA kernel
flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")

# Run Flash Dynamic Mask Attention
output = flash_dmattn_func(
    q=query,
    k=key,
    v=value,
    query=query,
    key=key,
    value=value,
    attn_mask=attention_mask,
    attn_bias=attention_bias,
    is_causal=True,
    scale=1.0/math.sqrt(head_dim),
)

print(f"Output shape: {output.shape}") # [2, 4096, 16, 128]
print(f"Output shape: {output.shape}") # [1, 256, 2, 64]
```
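
For small shapes, the kernel output can be sanity-checked against a plain PyTorch implementation of the same masked attention. The sketch below is a reference only, not part of the library API; it reuses the tensors from the example above and assumes the mask is interpreted as 0 = masked out, the bias is added to the scores before softmax, K/V are repeated across the query heads of each KV group, and fully masked rows are zeroed (the kernel's convention for such rows may differ).

```python
import math
import torch

def reference_attention(query, key, value, attn_mask, attn_bias, scale, is_causal=True):
    # query: (B, L, H, D); key/value: (B, L, H_kv, D); mask/bias: (B, H_kv, L, L)
    batch, q_len, num_heads, _ = query.shape
    num_kv = key.shape[2]
    groups = num_heads // num_kv
    # Repeat K/V and mask/bias across the query heads of each KV group (GQA layout)
    k = key.repeat_interleave(groups, dim=2).transpose(1, 2)      # (B, H, L, D)
    v = value.repeat_interleave(groups, dim=2).transpose(1, 2)
    mask = attn_mask.repeat_interleave(groups, dim=1)
    bias = attn_bias.repeat_interleave(groups, dim=1)
    q = query.transpose(1, 2)                                     # (B, H, L, D)
    scores = torch.matmul(q, k.transpose(-2, -1)).float() * scale + bias.float()
    if is_causal:
        causal = torch.tril(torch.ones(q_len, q_len, dtype=torch.bool, device=query.device))
        scores = scores.masked_fill(~causal, float("-inf"))
    scores = scores.masked_fill(mask == 0.0, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    # Rows with no visible keys become NaN in this naive reference; zero them out
    probs = torch.nan_to_num(probs).to(value.dtype)
    return torch.matmul(probs, v).transpose(1, 2)                 # back to (B, L, H, D)

ref = reference_attention(query, key, value, attention_mask, attention_bias,
                          scale=1.0 / math.sqrt(head_dim))
print(f"max abs diff vs. reference: {(output.float() - ref.float()).abs().max().item():.4f}")
```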

### Gradient Computation Example

## How It Works
```python
# Enable gradient computation
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
attention_bias.requires_grad_(True)

Flash-DMA combines two complementary techniques:
# Forward pass
output = flash_dmattn_func(
query=query, key=key, value=value,
attn_mask=attention_mask,
attn_bias=attention_bias,
is_causal=True,
scale=1.0/math.sqrt(head_dim)
)

# Backward pass
loss = output.sum()
loss.backward()

print(f"Query gradient shape: {query.grad.shape}")
print(f"Key gradient shape: {key.grad.shape}")
print(f"Value gradient shape: {value.grad.shape}")
print(f"Bias gradient shape: {attention_bias.grad.shape}")
```
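
Because the bias gradient is produced inside the kernel, the bias can also be an ordinary learnable parameter. The following training-step sketch is illustrative only; `bias_param`, the optimizer settings, and the placeholder loss are assumptions rather than library API, and the tensors and shapes are reused from the Quick Start above.

```python
import math
import torch
from flash_dmattn import flash_dmattn_func_auto

flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")

# Hypothetical learnable bias, kept in fp32 and cast to the compute dtype per step
bias_param = torch.nn.Parameter(
    torch.zeros(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=torch.float32)
)
optimizer = torch.optim.AdamW([bias_param], lr=1e-3)

for step in range(10):
    optimizer.zero_grad()
    out = flash_dmattn_func(
        query=query, key=key, value=value,
        attn_mask=attention_mask,
        attn_bias=bias_param.to(dtype),   # dbias flows back through this cast
        is_causal=True,
        scale=1.0 / math.sqrt(head_dim),
    )
    loss = out.float().pow(2).mean()      # placeholder loss
    loss.backward()
    optimizer.step()
```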

- **Dynamic Mask Attention**: Computes relevance scores for keys and selects only the most important ones for attention computation
- **Flash Attention**: Processes attention in blocks to reduce memory usage and HBM access

### The Integration Approach
## How It Works

Flash-DMA integrates the efficient memory access patterns of Flash Attention with the sparse computation capabilities of dynamic mask attention to achieve an efficient attention mechanism.

The integration happens at the CUDA kernel level with several key components:
### Core Technology Integration

- **ZOH States**: Pre-computed importance scores for key selection
- **Active Masks**: Binary masks indicating which keys should be considered for each query
- **Sparse Skipping**: Custom CUDA kernels for efficient sparse attention computation
- **Block-Based Processing**: Maintains Flash Attention's block-based approach for memory efficiency
- **🎯 Native 4D Mask & Bias Support**: Kernels directly process `(batch_size, num_kv_heads, query_len, key_len)` shaped tensors
- **⚡ Block-level Intelligent Skipping**: Unified OR-reduction skipping logic based on masks, completely avoiding computation and memory access for zero blocks
- **🔄 Complete Gradient Chain**: Built-in attention bias gradient computation (dbias) supporting end-to-end differentiable training

This creates a hybrid attention mechanism that achieves both memory and computational efficiency for long sequences.
### Key Optimization Strategies

1. **Unified Skip Logic**: Forward and backward passes use the same block-level skip decisions
2. **Memory Access Optimization**: K/V data is loaded only when `OR(mask_block) == true` (see the sketch below)
3. **Gradient Path Completeness**: dbias gradient computation fully fused in backward kernels
4. **Shared Memory Reuse**: sMask ↔ sP, sBias ↔ sdS intelligent aliasing
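
The skip decision can be modeled outside the kernel. The sketch below is an illustrative Python model, not the CUDA implementation; it assumes a tile is skipped exactly when every mask entry in the `(block_m, block_n)` tile is zero, and the tile sizes are hypothetical.

```python
import torch

def count_skipped_tiles(mask_2d, block_m=128, block_n=128):
    """Illustrative model of the block-level skip decision for one (batch, head):
    a (block_m, block_n) tile is processed only if OR over its mask entries is true;
    otherwise both the K/V load and the tile computation are skipped."""
    q_len, k_len = mask_2d.shape
    computed, skipped = 0, 0
    for i in range(0, q_len, block_m):
        for j in range(0, k_len, block_n):
            tile = mask_2d[i:i + block_m, j:j + block_n]
            if tile.any():       # OR-reduction, as in the kernel's skip check
                computed += 1    # load K/V block, compute scores, softmax, PV
            else:
                skipped += 1     # all-zero tile: no memory traffic, no compute
    return computed, skipped

# Using the sparse mask built in the Quick Start (hypothetical tile sizes)
computed, skipped = count_skipped_tiles(attention_mask[0, 0].bool())
print(f"computed tiles: {computed}, skipped tiles: {skipped}")
```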


## Documentation
@@ -229,7 +263,7 @@ This creates a hybrid attention mechanism that achieves both memory and computat

```bash
# Clone with submodules
git clone --recursive https://github.com/SmallDoges/flash-dmattn.git
git clone https://github.com/SmallDoges/flash-dmattn.git
cd flash-dmattn

# Build in development mode
@@ -297,8 +331,8 @@ Tests backward pass implementation and gradient equivalence.
**Compilation Errors**
```bash
# Ensure CUDA_HOME is set correctly
echo $CUDA_HOME # Linux/Mac
echo $env:CUDA_HOME # Windows PowerShell
echo $CUDA_HOME # Linux/Mac
echo $env:CUDA_HOME # Windows PowerShell

# Check CUDA toolkit version
nvcc --version
122 changes: 78 additions & 44 deletions README_zh.md
@@ -17,12 +17,16 @@ Flash-DMA is a high-performance attention implementation that combines Flash Attention's memory

## Key Features

- **Dynamic Sparse Attention**: Dynamically selects the most important keys for each query, reducing computational complexity from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$, supporting trainable sparse structures.
- **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without instantiating the full attention matrix.
- **Deep CUDA Optimization**: Custom CUDA kernels with shared memory aliasing, pipelined prefetching, and per-block skipping for high throughput and low memory-access overhead.
- **Extremely Long Context Support**: Handles 128K+ token contexts through dynamic mask windowing while preserving accuracy.
- **Learnable Bias**: Built-in learnable attention bias and its backward gradient path (dbias), with no extra external operators required.
- **Fusion-Friendly Training**: Both forward and backward passes support block-level all-zero-mask skipping, further reducing computation in sparse scenarios.
### 🎯 Core Kernel Advantages
- **4D Mask & Bias Support**: Native support for `(batch_size, num_kv_heads, query_len, key_len)` shaped attention_mask and attention_bias tensors
- **Intelligent Computation Skipping**: Block-level automatic skipping based on the attention_mask, completely skipping computation and memory access for all-zero mask blocks
- **Complete Gradient Support**: Built-in full gradient computation path for attention_bias, supporting end-to-end training

### 🚀 Performance & Efficiency
- **Dynamic Sparse Attention**: Dynamically selects the most important keys for each query, reducing computational complexity from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$, supporting trainable sparse structures
- **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without instantiating the full attention matrix
- **Deep CUDA Optimization**: Custom CUDA kernels with shared memory aliasing, pipelined prefetching, and per-block skipping for high throughput and low memory-access overhead
- **Extremely Long Context Support**: Handles 128K+ token contexts through dynamic mask windowing while preserving accuracy


## Performance
@@ -145,43 +149,46 @@ MAX_JOBS=4 pip install . --no-build-isolation

## Quick Start

### Basic Usage

```python
import torch
from flash_dmattn import flash_dmattn_func_auto
import math

# Setup
batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128
batch_size, seq_len, num_heads, num_kv_heads, head_dim = 1, 256, 2, 1, 64
keep_window_size = 128
device = torch.device('cuda')
dtype = torch.bfloat16
min_dtype = torch.finfo(dtype).min  # minimum value of the dtype

# Input tensors
query = torch.randn(batch_size, seq_len, num_heads, head_dim,
                    device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_heads, head_dim,
                  device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_heads, head_dim,
                    device=device, dtype=dtype)

# Create mask and bias for sparse attention
attention_bias = torch.randn(batch_size, num_heads, seq_len, seq_len,
                             device=device, dtype=dtype)
attention_mask = torch.ones(batch_size, num_heads, seq_len, seq_len,
                            device=device, dtype=dtype)

# Apply dynamic masking (keep top-k for long sequences)
keep_window_size = 2048
query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)

# Create mask and bias for sparse attention
attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)

# Generate sparse mask based on bias
if seq_len > keep_window_size:
    # Select the top-k most important keys for each query
    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
                              largest=True, sorted=False).indices
    attention_mask.zero_()
    attention_mask.scatter(-1, topk_indices, 1.0)

# Select backend
    topk_values, topk_indices = torch.topk(
        attention_bias, keep_window_size, dim=-1,
        largest=True, sorted=False
    )
    # Generate a valid top-k mask
    valid_topk = (topk_values != min_dtype).to(dtype)
    attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device)
    attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk)
    attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype)

# Select the FDMA kernel
flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")

# Run Flash Dynamic Mask Attention
# Run FDMA
output = flash_dmattn_func(
    query=query,
    key=key,
@@ -192,27 +199,54 @@ output = flash_dmattn_func(
    scale=1.0/math.sqrt(head_dim),
)

print(f"Output shape: {output.shape}") # [2, 4096, 16, 128]
print(f"Output shape: {output.shape}") # [1, 256, 2, 64]
```

### Gradient Computation Example

## How It Works
```python
# Enable gradient computation
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
attention_bias.requires_grad_(True)

Flash-DMA combines two complementary techniques:
# Forward pass
output = flash_dmattn_func(
query=query, key=key, value=value,
attn_mask=attention_mask,
attn_bias=attention_bias,
is_causal=True,
scale=1.0/math.sqrt(head_dim)
)

- **Dynamic Mask Attention**: Computes relevance scores for keys and selects only the most important ones for attention computation
- **Flash Attention**: Processes attention in blocks to reduce memory usage and HBM access
# Backward pass
loss = output.sum()
loss.backward()

### The Integration Approach
print(f"Query gradient shape: {query.grad.shape}")
print(f"Key gradient shape: {key.grad.shape}")
print(f"Value gradient shape: {value.grad.shape}")
print(f"Bias gradient shape: {attention_bias.grad.shape}")
```

The integration happens at the CUDA kernel level, with several key components:

- **ZOH States**: Pre-computed importance scores for key selection
- **Active Masks**: Binary masks indicating which keys each query should consider
- **Sparse Skipping**: Custom CUDA kernels for efficient sparse attention computation
- **Block-Based Processing**: Maintains Flash Attention's block-based approach for memory efficiency
## How It Works

Flash-DMA combines Flash Attention's efficient memory-access patterns with the sparse computation capabilities of dynamic mask attention to achieve an efficient attention mechanism.

This creates a hybrid attention mechanism that achieves both memory and computational efficiency for long sequences.
### Core Technology Integration

- **🎯 Native 4D Mask & Bias Support**: Kernels directly process `(batch_size, num_kv_heads, query_len, key_len)` shaped tensors
- **⚡ Block-level Intelligent Skipping**: Unified OR-reduction skip logic based on the mask, completely avoiding computation and memory access for all-zero blocks
- **🔄 Complete Gradient Chain**: Built-in attention bias gradient computation, supporting end-to-end differentiable training

### Key Optimization Strategies

1. **Unified Skip Logic**: Forward and backward passes use the same block-level skip decisions
2. **Memory Access Optimization**: K/V data is loaded only when `OR(mask_block) == true`
3. **Gradient Path Completeness**: dbias gradient computation is fully fused into the backward kernels
4. **Shared Memory Reuse**: Intelligent aliasing of sMask ↔ sP and sBias ↔ sdS


## Documentation
@@ -229,7 +263,7 @@ Flash-DMA combines two complementary techniques:

```bash
# Clone with submodules
git clone --recursive https://github.com/SmallDoges/flash-dmattn.git
git clone https://github.com/SmallDoges/flash-dmattn.git
cd flash-dmattn

# Build in development mode
@@ -296,8 +330,8 @@ python benchmarks/grad_equivalence.py
**Compilation Errors**
```bash
# Make sure CUDA_HOME is set correctly
echo $CUDA_HOME # Linux/Mac
echo $env:CUDA_HOME # Windows PowerShell
echo $CUDA_HOME # Linux/Mac
echo $env:CUDA_HOME # Windows PowerShell

# Check the CUDA toolkit version
nvcc --version