benchmarks/benchmark_forward_equivalence.py (58 changes: 28 additions & 30 deletions)
@@ -21,13 +21,13 @@

 # Import the compiled CUDA extension
 try:
-    import flash_dmattn_cuda  # type: ignore[import]
-    print("✅ Successfully imported flash_dmattn_cuda")
+    from flash_dmattn.flash_dmattn_interface import flash_dmattn_func
+    print("✅ Successfully imported flash_dmattn interface")
 except ImportError as e:
-    print(f"❌ Failed to import flash_dmattn_cuda: {e}")
+    print(f"❌ Failed to import flash_dmattn interface: {e}")
     print("Please make sure the package is properly installed with: pip install .")
     # Don't exit here, just warn
-    flash_dmattn_cuda = None
+    flash_dmattn_func = None

 # Import the Triton implementation
 try:
@@ -229,6 +229,8 @@ def dynamic_mask_attention_cuda(
     Returns:
         attn_outputs: [batch_size, query_len, num_heads, head_dim]
     """
+    if flash_dmattn_func is None:
+        raise RuntimeError("flash_dmattn_func not available")

     # Calculate zoh_states
     zoh_states = calculate_zoh_states(value_states, dt_proj, A)
@@ -243,35 +245,26 @@

     # Ensure correct data types and memory layout for CUDA function
-    # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format
-    query_states = query_states.transpose(1, 2).contiguous()    # [batch, query_len, num_heads, head_dim]
-    key_states = key_states.transpose(1, 2).contiguous()        # [batch, key_len, num_kv_heads, head_dim]
-    value_states = value_states.transpose(1, 2).contiguous()    # [batch, key_len, num_kv_heads, head_dim]
-    zoh_states = zoh_states[:, :, None, :].expand(
-        -1, -1, query_states.shape[1], -1
-    ).contiguous()    # [batch, num_kv_heads, query_len, key_len]
-    attn_bias = attn_bias.contiguous()    # [batch, num_kv_heads, query_len, key_len]
-    attn_mask = attn_mask.contiguous()    # [batch, num_kv_heads, query_len, key_len]
+    query_states = query_states.transpose(1, 2)    # [batch, query_len, num_heads, head_dim]
+    key_states = key_states.transpose(1, 2)        # [batch, key_len, num_kv_heads, head_dim]
+    value_states = value_states.transpose(1, 2)    # [batch, key_len, num_kv_heads, head_dim]
Comment on lines +248 to +250

Copilot AI (Jul 30, 2025)

Removing .contiguous() calls may cause performance issues if the tensors are not contiguous in memory. CUDA kernels typically require contiguous tensors for optimal performance. Consider adding .contiguous() back or verify that the new interface handles non-contiguous tensors efficiently.

Suggested change:
-    query_states = query_states.transpose(1, 2)    # [batch, query_len, num_heads, head_dim]
-    key_states = key_states.transpose(1, 2)        # [batch, key_len, num_kv_heads, head_dim]
-    value_states = value_states.transpose(1, 2)    # [batch, key_len, num_kv_heads, head_dim]
+    query_states = query_states.transpose(1, 2).contiguous()    # [batch, query_len, num_heads, head_dim]
+    key_states = key_states.transpose(1, 2).contiguous()        # [batch, key_len, num_kv_heads, head_dim]
+    value_states = value_states.transpose(1, 2).contiguous()    # [batch, key_len, num_kv_heads, head_dim]
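Editor's note: a minimal sketch of the middle-ground option the comment hints at, copying only when a tensor is actually non-contiguous. The helper name is hypothetical and not part of this PR.

import torch

def ensure_contiguous(t: torch.Tensor) -> torch.Tensor:
    # No-op for tensors that are already contiguous; copies once otherwise.
    return t if t.is_contiguous() else t.contiguous()

# Hypothetical usage on the transposed views from the hunk above:
# query_states = ensure_contiguous(query_states.transpose(1, 2))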
Copilot AI (Jul 30, 2025)

The zoh_states calculation and expansion logic has been removed, but the calculate_zoh_states function is still called earlier in the function. This may cause the attention mechanism to not work correctly as zoh_states appear to be required for the dynamic mask attention computation.
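Editor's note: for reference, a standalone sketch of the removed expansion, using made-up sizes, to show the shape the old entry point expected.

import torch

batch, num_kv_heads, query_len, key_len = 2, 4, 8, 16  # illustrative sizes only
zoh_states = torch.randn(batch, num_kv_heads, key_len)

# Removed logic: insert a query_len axis, broadcast it without copying, then materialize.
expanded = zoh_states[:, :, None, :].expand(-1, -1, query_len, -1).contiguous()
assert expanded.shape == (batch, num_kv_heads, query_len, key_len)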

-    # Call the CUDA implementation using the mha_fwd function signature
-    out_tensor = None  # Let the function allocate the output tensor
-    result = flash_dmattn_cuda.fwd(  # type: ignore
-        query_states,      # q: [batch, seqlen_q, num_heads, head_dim]
-        key_states,        # k: [batch, seqlen_k, num_kv_heads, head_dim]
-        value_states,      # v: [batch, seqlen_k, num_kv_heads, head_dim]
-        attn_mask,         # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
-        attn_bias,         # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k]
-        out_tensor,        # out: None to auto-allocate
-        0.0,               # p_dropout
-        scaling,           # softmax_scale
-        is_causal,         # is_causal
-        keep_window_size,  # keep_window_size
-        0.0,               # softcap
-        return_softmax,    # return_softmax
-        None               # gen (generator)
+    # Call the new flash_dmattn_func interface
+    attn_outputs = flash_dmattn_func(
+        query_states,                # [batch, query_len, num_heads, head_dim]
+        key_states,                  # [batch, key_len, num_kv_heads, head_dim]
+        value_states,                # [batch, key_len, num_kv_heads, head_dim]
+        attn_mask=attn_mask,         # [batch, num_kv_heads, query_len, key_len]
+        attn_bias=attn_bias,         # [batch, num_kv_heads, query_len, key_len]
+        dropout_p=0.0,
+        softmax_scale=scaling,
+        is_causal=is_causal,
+        softcap=0.0,
+        deterministic=True,
+        return_attn_probs=return_softmax
     )

-    attn_outputs = result[0]  # [batch, query_len, num_heads, head_dim]
-    return attn_outputs
+    return attn_outputs  # [batch, query_len, num_heads, head_dim]


def dynamic_mask_attention_triton(
@@ -514,6 +507,11 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
print("🔬 Testing Forward Pass Equivalence: Python Prototype vs CUDA Implementation 🔬")
print("🚀" + "=" * 76 + "🚀")

# Check if CUDA implementation is available
if flash_dmattn_func is None:
print("❌ CUDA implementation not available, skipping test.")
return False

# Set random seed for reproducibility
torch.manual_seed(0)

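Editor's note: for context on what test_cuda_forward_equivalence ultimately measures, here is a minimal sketch of an output-equivalence check against an accuracy threshold. The function name and tolerances are placeholders, not the benchmark's actual helpers.

import torch

def outputs_equivalent(ref_out: torch.Tensor, cuda_out: torch.Tensor,
                       accuracy_threshold: float = 0.95) -> bool:
    # Fraction of elements that agree within a loose elementwise tolerance.
    close = torch.isclose(ref_out.float(), cuda_out.float(), rtol=1e-3, atol=1e-3)
    return close.float().mean().item() >= accuracy_threshold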