13 changes: 8 additions & 5 deletions README.md
@@ -72,11 +72,11 @@ pip install .
 
 ```python
 import torch
-from flash_dmattn import flash_dmattn_func
+from flash_dmattn import flash_dmattn_func_auto
 import math
 
 # Setup
-batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128
+batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128
 device = torch.device('cuda')
 dtype = torch.bfloat16
 
@@ -103,18 +103,21 @@ if seq_len > keep_window_size:
     attention_mask.zero_()
     attention_mask.scatter(-1, topk_indices, 1.0)
 
+# Select backend
+flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
+
 # Run Flash Dynamic Mask Attention
 output = flash_dmattn_func(
     q=query,
     k=key,
     v=value,
     attn_mask=attention_mask,
     attn_bias=attention_bias,
-    softmax_scale=1.0/math.sqrt(head_dim),
-    is_causal=True
+    is_causal=True,
+    scale=1.0/math.sqrt(head_dim),
 )
 
-print(f"Output shape: {output.shape}") # [2, 4096, 12, 128]
+print(f"Output shape: {output.shape}") # [2, 4096, 16, 128]
 ```
 
 
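A note on the backend selector introduced above: only `backend="cuda"` appears in this diff, but the benchmark files further down also exercise Triton and Flex implementations. Below is a minimal sketch of picking a backend with a fallback; the `"triton"` and `"flex"` backend names are assumptions inferred from those benchmarks, not confirmed by this PR.

```python
import math
import torch
from flash_dmattn import flash_dmattn_func_auto

def pick_attention_func(preferred=("cuda", "triton", "flex")):
    # Try backends in order; only "cuda" is shown in this PR,
    # "triton" and "flex" are assumed names.
    for name in preferred:
        try:
            return name, flash_dmattn_func_auto(backend=name)
        except Exception:
            continue
    raise RuntimeError("no flash_dmattn backend available")

backend, attn_func = pick_attention_func()

# Same layout as the README example: [batch, seq_len, num_heads, head_dim]
batch, seq_len, num_heads, head_dim = 1, 256, 16, 128
device, dtype = torch.device("cuda"), torch.bfloat16
query = torch.randn(batch, seq_len, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch, seq_len, num_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch, seq_len, num_heads, head_dim, device=device, dtype=dtype)
# Float mask/bias, following the README example.
attention_mask = torch.ones(batch, num_heads, seq_len, seq_len, device=device, dtype=dtype)
attention_bias = torch.zeros(batch, num_heads, seq_len, seq_len, device=device, dtype=dtype)

output = attn_func(
    query, key, value,
    attn_mask=attention_mask,
    attn_bias=attention_bias,
    is_causal=True,
    scale=1.0 / math.sqrt(head_dim),
)
print(backend, output.shape)  # expected: [1, 256, 16, 128]
```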
19 changes: 11 additions & 8 deletions README_zh.md
@@ -72,11 +72,11 @@ pip install .
 
 ```python
 import torch
-from flash_dmattn import flash_dmattn_func
+from flash_dmattn import flash_dmattn_func_auto
 import math
 
 # 设置
-batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128
+batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128
 device = torch.device('cuda')
 dtype = torch.bfloat16
 
@@ -103,18 +103,21 @@ if seq_len > keep_window_size:
     attention_mask.zero_()
     attention_mask.scatter(-1, topk_indices, 1.0)
 
+# 选择后端
+flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
+
 # 运行 Flash 动态掩码注意力
 output = flash_dmattn_func(
-    q=query,
-    k=key,
-    v=value,
+    query=query,
+    key=key,
+    value=value,
     attn_mask=attention_mask,
     attn_bias=attention_bias,
-    softmax_scale=1.0/math.sqrt(head_dim),
-    is_causal=True
+    is_causal=True,
+    scale=1.0/math.sqrt(head_dim),
 )
 
-print(f"输出形状: {output.shape}") # [2, 4096, 12, 128]
+print(f"输出形状: {output.shape}") # [2, 4096, 16, 128]
 ```
 
 
26 changes: 13 additions & 13 deletions benchmarks/benchmark_forward_equivalence.py
@@ -257,8 +257,8 @@ def dynamic_mask_attention_cuda(
         attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len]
         attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len]
         dropout_p=0.0,
-        softmax_scale=scaling,
         is_causal=is_causal,
+        scale=scaling,
         softcap=0.0,
         deterministic=True,
         return_attn_probs=return_softmax
@@ -331,10 +331,10 @@ def dynamic_mask_attention_triton(
         query_states, # q: [batch, seqlen_q, num_heads, head_dim]
         key_states, # k: [batch, seqlen_k, num_heads, head_dim]
         value_states, # v: [batch, seqlen_k, num_heads, head_dim]
-        attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k]
-        attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k]
-        is_causal, # causal masking
-        scaling # scaling factor
+        attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k]
+        attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k]
+        is_causal=is_causal, # causal masking
+        scale=scaling # scaling factor
     )
 
     return attn_outputs # [batch, query_len, num_heads, head_dim]
@@ -396,14 +396,14 @@ def dynamic_mask_attention_flex(
     # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format
 
     # Call the Flex Attention implementation
-    attn_outputs, _ = flex_dmattn_func(
-        query_states, # q: [batch, num_heads, query_len, head_dim]
-        key_states, # k: [batch, num_heads, key_len, head_dim]
-        value_states, # v: [batch, num_heads, key_len, head_dim]
-        attention_mask=attn_mask, # attention_mask: [batch, num_heads, query_len, key_len]
-        attention_bias=attn_bias, # attention_bias: [batch, num_heads, query_len, key_len]
-        is_causal=is_causal, # is_causal: whether to apply causal masking
-        scaling=scaling # scaling factor
+    attn_outputs = flex_dmattn_func(
+        query_states.transpose(1, 2), # q: [batch, query_len, num_heads, head_dim]
+        key_states.transpose(1, 2), # k: [batch, key_len, num_heads, head_dim]
+        value_states.transpose(1, 2), # v: [batch, key_len, num_heads, head_dim]
+        attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len]
+        attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len]
+        is_causal=is_causal, # is_causal: whether to apply causal masking
+        scale=scaling # scaling factor
     )
 
     return attn_outputs # [batch, query_len, num_heads, head_dim]
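With this change, the CUDA, Triton, and Flex wrappers in the equivalence benchmark all receive the same keyword arguments (`attn_mask`, `attn_bias`, `is_causal`, `scale`) and the same `[batch, seq_len, num_heads, head_dim]` input layout. A rough sketch of the kind of cross-backend comparison this benchmark performs follows; it assumes `"triton"` is an accepted backend name for `flash_dmattn_func_auto`, which this PR does not confirm.

```python
import math
import torch
from flash_dmattn import flash_dmattn_func_auto

# Small shapes for a quick parity check; the "triton" backend name is
# an assumption based on this benchmark file.
B, S, H, D = 1, 256, 8, 64
device, dtype = torch.device("cuda"), torch.bfloat16

q = torch.randn(B, S, H, D, device=device, dtype=dtype)
k = torch.randn(B, S, H, D, device=device, dtype=dtype)
v = torch.randn(B, S, H, D, device=device, dtype=dtype)
mask = torch.ones(B, H, S, S, device=device, dtype=dtype)   # float mask, as in the README example
bias = torch.zeros(B, H, S, S, device=device, dtype=dtype)

outs = {}
for name in ("cuda", "triton"):
    fn = flash_dmattn_func_auto(backend=name)
    outs[name] = fn(
        q, k, v,
        attn_mask=mask,
        attn_bias=bias,
        is_causal=True,
        scale=1.0 / math.sqrt(D),
    )

# Compare the two backends on identical inputs.
max_diff = (outs["cuda"].float() - outs["triton"].float()).abs().max()
print(f"max |cuda - triton| = {max_diff.item():.3e}")
```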
26 changes: 13 additions & 13 deletions benchmarks/benchmark_forward_performance.py
@@ -266,8 +266,8 @@ def dynamic_mask_attention_cuda(
         attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len]
         attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len]
         dropout_p=0.0,
-        softmax_scale=scaling,
         is_causal=is_causal,
+        scale=scaling,
         softcap=0.0,
         deterministic=True,
         return_attn_probs=return_softmax
@@ -350,10 +350,10 @@ def dynamic_mask_attention_triton(
         query_states, # q: [batch, seqlen_q, num_heads, head_dim]
         key_states, # k: [batch, seqlen_k, num_heads, head_dim]
         value_states, # v: [batch, seqlen_k, num_heads, head_dim]
-        attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k]
-        attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k]
-        is_causal, # causal masking
-        scaling # scaling factor
+        attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k]
+        attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k]
+        is_causal=is_causal, # causal masking
+        scale=scaling # scaling factor
     )
 
     torch.cuda.synchronize()
@@ -425,14 +425,14 @@ def dynamic_mask_attention_flex(
     start_time = time.time()
 
     # Call the Flex Attention implementation
-    attn_outputs, _ = flex_dmattn_func(
-        query_states, # q: [batch, num_heads, query_len, head_dim]
-        key_states, # k: [batch, num_heads, key_len, head_dim]
-        value_states, # v: [batch, num_heads, key_len, head_dim]
-        attention_mask=attn_mask, # attention_mask: [batch, num_heads, query_len, key_len]
-        attention_bias=attn_bias, # attention_bias: [batch, num_heads, query_len, key_len]
-        is_causal=is_causal, # is_causal: Whether to apply causal masking
-        scaling=scaling # scaling factor
+    attn_outputs = flex_dmattn_func(
+        query_states.transpose(1, 2), # q: [batch, query_len, num_heads, head_dim]
+        key_states.transpose(1, 2), # k: [batch, key_len, num_heads, head_dim]
+        value_states.transpose(1, 2), # v: [batch, key_len, num_heads, head_dim]
+        attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len]
+        attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len]
+        is_causal=is_causal, # is_causal: whether to apply causal masking
+        scale=scaling # scaling factor
     )
 
     torch.cuda.synchronize()
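The performance benchmark brackets each call with `torch.cuda.synchronize()` and wall-clock timestamps, as the surrounding context lines show. A stripped-down version of that timing pattern for the updated keyword API, a sketch rather than the benchmark's exact harness, could look like this:

```python
import math
import time
import torch
from flash_dmattn import flash_dmattn_func_auto

def time_forward(fn, q, k, v, mask, bias, scale, iters=10):
    # Warm up once so kernel compilation is not measured.
    fn(q, k, v, attn_mask=mask, attn_bias=bias, is_causal=True, scale=scale)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn(q, k, v, attn_mask=mask, attn_bias=bias, is_causal=True, scale=scale)
    torch.cuda.synchronize()
    return (time.time() - start) / iters

B, S, H, D = 1, 2048, 16, 128
device, dtype = torch.device("cuda"), torch.bfloat16
q = torch.randn(B, S, H, D, device=device, dtype=dtype)
k = torch.randn(B, S, H, D, device=device, dtype=dtype)
v = torch.randn(B, S, H, D, device=device, dtype=dtype)
mask = torch.ones(B, H, S, S, device=device, dtype=dtype)
bias = torch.zeros(B, H, S, S, device=device, dtype=dtype)

fn = flash_dmattn_func_auto(backend="cuda")
ms = time_forward(fn, q, k, v, mask, bias, scale=1.0 / math.sqrt(D)) * 1e3
print(f"forward: {ms:.2f} ms/iter")
```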
2 changes: 1 addition & 1 deletion flash_dmattn/flash_dmattn_flex.py
@@ -10,8 +10,8 @@ def flex_attention_forward(
     value: torch.Tensor,
     attn_mask: torch.Tensor,
     attn_bias: torch.Tensor,
-    scale: Optional[float] = None,
     is_causal: bool = True,
+    scale: Optional[float] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     query = query.transpose(1, 2).contiguous() # [B, H, Q_LEN, D]
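The only change to `flex_attention_forward` is moving `scale` after `is_causal` in the signature. The callers updated elsewhere in this PR pass both by keyword, so the reorder does not break them. Below is a sketch of such a keyword-only call; the input shapes follow the benchmark comments, and the tuple handling at the end is an assumption based on the annotated `Tuple[torch.Tensor, torch.Tensor]` return type.

```python
import math
import torch
from flash_dmattn.flash_dmattn_flex import flex_attention_forward

B, S, H, D = 1, 256, 8, 64
device, dtype = torch.device("cuda"), torch.bfloat16
# [batch, seq_len, num_heads, head_dim]; the function transposes to [B, H, S, D] internally.
query = torch.randn(B, S, H, D, device=device, dtype=dtype)
key = torch.randn(B, S, H, D, device=device, dtype=dtype)
value = torch.randn(B, S, H, D, device=device, dtype=dtype)
attn_mask = torch.ones(B, H, S, S, device=device, dtype=dtype)   # float mask, as in the README example
attn_bias = torch.zeros(B, H, S, S, device=device, dtype=dtype)

result = flex_attention_forward(
    query, key, value,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    is_causal=True,                # keyword arguments make the parameter reorder a no-op for callers
    scale=1.0 / math.sqrt(D),
)
# The attention output is assumed to be the first element if a tuple is returned,
# with layout matching the benchmark's [batch, query_len, num_heads, head_dim] comment.
attn_out = result[0] if isinstance(result, tuple) else result
print(attn_out.shape)
```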