6 changes: 3 additions & 3 deletions benchmarks/benchmark_forward_equivalence.py
@@ -76,7 +76,7 @@ def prepare_dynamic_mask(
return attn_mask, active_mask


def calculate_zero_hold_states(value_states, dt_proj, A, causal_mask=None):
def calculate_zero_hold_states(value_states, dt_proj, A):
"""
Calculate zero hold states for dynamic mask attention.

@@ -152,7 +152,7 @@ def dynamic_mask_attention_python(

num_queries_per_kv = num_heads // num_kv_heads

zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask)
zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)

# Use prepare_dynamic_mask function to process dynamic mask
attn_mask, _ = prepare_dynamic_mask(
@@ -208,7 +208,7 @@ def dynamic_mask_attention_cuda(
"""

# Calculate zero_hold_states
zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask)
zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)

# Use prepare_dynamic_mask to get the processed attention mask
_, active_mask = prepare_dynamic_mask(
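Note on the hunks above: the only functional change in this file is that `calculate_zero_hold_states` no longer takes a `causal_mask` parameter, and both call sites drop the argument. A minimal before/after sketch of the call-site migration, assuming (as the diff suggests) that the argument was unused inside the helper:

```python
# Before this PR: callers passed causal_mask, which the helper apparently ignored
zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask)

# After this PR: the parameter is removed from the signature and from every call site
zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)
```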
99 changes: 61 additions & 38 deletions benchmarks/benchmark_forward_performance.py
@@ -19,8 +19,15 @@
import argparse
import time
import gc
from flash_dma_cpp import apply_dynamic_mask_attention # type: ignore
from typing import cast

# Import the compiled CUDA extension
try:
import flash_dma_cuda
print("✅ Successfully imported flash_dma_cuda")
except ImportError as e:
print(f"❌ Failed to import flash_dma_cuda: {e}")
print("Please make sure the package is properly installed with: pip install .")
exit(1)


def prepare_dynamic_mask(
@@ -32,6 +39,8 @@ def prepare_dynamic_mask(
"""
Calculate dynamic attention mask to mask tokens for sparse attention.

Combine `dt_states` with `attention_mask` to generate the final `attn_mask`.

Args:
hidden_states: Input hidden states to determine dtype minimum value
dt_states: dt_states of shape (batch_size, num_kv_heads, key_sequence_length)
@@ -46,7 +55,6 @@
attn_mask = dt_states[:, :, None, :].expand(
-1, -1, hidden_states.shape[2], -1
) # [batch_size, num_kv_heads, query_len, key_len]
active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device)

if attention_mask is not None:
if attention_mask.dtype == torch.bool:
@@ -66,11 +74,12 @@
active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device)
active_mask = active_mask.scatter(-1, topk_indices, 1.0)
attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)

else:
active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device)
return attn_mask, active_mask


def calculate_zero_hold_states(value_states, dt_proj, A, causal_mask=None):
def calculate_zero_hold_states(value_states, dt_proj, A):
"""
Calculate zero hold states for dynamic mask attention.

@@ -175,7 +184,7 @@ def dynamic_mask_attention_cuda(
attn_outputs: [batch_size, query_len, num_heads, head_dim]
"""
# Calculate zero_hold_states
zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask)
zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)

_, active_mask = prepare_dynamic_mask(
query_states,
@@ -184,7 +193,8 @@
causal_mask if is_causal else None
) # [batch_size, num_kv_heads, query_len, key_len]

# Ensure correct data types and memory layout
# Ensure correct data types and memory layout for CUDA function
# CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format
query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]
value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]
@@ -194,20 +204,24 @@
active_mask = active_mask.contiguous() # [batch, num_kv_heads, query_len, key_len]

try:
result = apply_dynamic_mask_attention(
query_states=query_states,
key_states=key_states,
value_states=value_states,
zoh_states=zero_hold_states,
active_mask=active_mask,
scale=scaling,
keep_window_size=keep_window_size,
is_causal=is_causal,
return_softmax=return_softmax
# Call the CUDA implementation using the mha_fwd function signature
out_tensor = None # Let the function allocate the output tensor
result = flash_dma_cuda.fwd( # type: ignore
query_states, # q: [batch, seqlen_q, num_heads, head_dim]
key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim]
value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim]
zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
active_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
out_tensor, # out: None to auto-allocate
0.0, # p_dropout
scaling, # softmax_scale
is_causal, # is_causal
keep_window_size, # keep_window_size
0.0, # softcap
return_softmax, # return_softmax
None # gen (generator)
)

# Convert result back to original data type
attn_outputs = result[0]
attn_outputs = result[0] # [batch, query_len, num_heads, head_dim]
return attn_outputs
except torch.cuda.OutOfMemoryError:
return "OOM"
@@ -245,10 +259,10 @@ def dynamic_mask_attention_cuda_no_topk(
attn_outputs: [batch_size, query_len, num_heads, head_dim]
"""
# Calculate zero_hold_states
zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask)
zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)

# Create a simplified mask without topk computation
batch_size, num_heads, query_len, head_dim = query_states.shape
batch_size, _, query_len, _ = query_states.shape
_, num_kv_heads, key_len, _ = key_states.shape
dtype = query_states.dtype
device = query_states.device
@@ -267,22 +281,31 @@
zero_hold_states = zero_hold_states[:, :, None, :].expand(
-1, -1, query_states.shape[1], -1
).contiguous() # [batch, num_kv_heads, query_len, key_len]
# Create full active mask (no topk selection)
active_mask = torch.zeros_like(
zero_hold_states,
dtype=dtype,
device=device
)
active_mask = active_mask.contiguous() # [batch, num_kv_heads, query_len, key_len]

try:
result = apply_dynamic_mask_attention(
query_states=query_states,
key_states=key_states,
value_states=value_states,
zoh_states=zero_hold_states,
active_mask=active_mask,
scale=scaling,
keep_window_size=0,
is_causal=is_causal,
return_softmax=return_softmax
out_tensor = None # Let the function allocate the output tensor
result = flash_dma_cuda.fwd( # type: ignore
query_states, # q: [batch, seqlen_q, num_heads, head_dim]
key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim]
value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim]
zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
active_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
out_tensor, # out: None to auto-allocate
0.0, # p_dropout
scaling, # softmax_scale
is_causal, # is_causal
keep_window_size, # keep_window_size
0.0, # softcap
return_softmax, # return_softmax
None # gen (generator)
)

# Convert result back to original data type
attn_outputs = result[0]
return attn_outputs
except torch.cuda.OutOfMemoryError:
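Both CUDA paths return the string sentinel "OOM" when the kernel raises `torch.cuda.OutOfMemoryError`, so downstream code has to check for it before treating the result as a tensor. A small illustrative helper (not part of this PR) that a benchmark loop could use:

```python
def run_config_safely(attention_fn, *args, **kwargs):
    """Hypothetical guard: run one benchmark configuration, tolerating OOM.

    The attention functions above return the string "OOM" instead of raising
    when the CUDA kernel runs out of memory, so the sentinel is checked here.
    """
    result = attention_fn(*args, **kwargs)
    if isinstance(result, str) and result == "OOM":
        print("Skipping configuration: CUDA out of memory")
        return None
    return result  # attn_outputs: [batch_size, query_len, num_heads, head_dim]
```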
@@ -529,10 +552,10 @@ def run_performance_benchmark():
(1, 4, 1, 1024, 1024, 32),
(1, 8, 2, 1024, 1024, 32),

# # Vary head dimension
# (1, 2, 1, 1024, 1024, 32),
# (1, 2, 1, 1024, 1024, 64),
# (1, 2, 1, 1024, 1024, 128),
# Vary head dimension
(1, 2, 1, 1024, 1024, 32),
(1, 2, 1, 1024, 1024, 64),
(1, 2, 1, 1024, 1024, 128),
]

num_runs = 3 # Run 3 times and take average
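For readers scanning the re-enabled configurations above: judging from the surrounding comments ("Vary head dimension") and the positions that change between entries, each tuple appears to encode (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim). The unpacking inside `run_performance_benchmark` is not shown in these hunks, so treat the field order as an assumption. A minimal sketch under that assumption:

```python
# Assumed tuple layout: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim)
configs = [
    (1, 4, 1, 1024, 1024, 32),
    (1, 8, 2, 1024, 1024, 32),
    # Vary head dimension
    (1, 2, 1, 1024, 1024, 32),
    (1, 2, 1, 1024, 1024, 64),
    (1, 2, 1, 1024, 1024, 128),
]

for batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim in configs:
    print(f"batch={batch_size} heads={num_heads}/{num_kv_heads} "
          f"q_len={query_len} k_len={key_len} head_dim={head_dim}")
```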