56 changes: 23 additions & 33 deletions benchmarks/benchmark_forward_performance.py
@@ -20,7 +20,6 @@
"""

import torch
import torch.nn.backends
import torch.nn.functional as F
import numpy as np
import argparse
@@ -29,13 +28,13 @@

# Import the compiled CUDA extension
try:
import flash_dmattn_cuda # type: ignore[import]
print("✅ Successfully imported flash_dmattn_cuda")
from flash_dmattn.flash_dmattn_interface import flash_dmattn_func
print("✅ Successfully imported flash_dmattn interface")
except ImportError as e:
print(f"❌ Failed to import flash_dmattn_cuda: {e}")
print(f"❌ Failed to import flash_dmattn interface: {e}")
print("Please make sure the package is properly installed with: pip install .")
# Don't exit here, just warn
flash_dmattn_cuda = None
flash_dmattn_func = None

# Import the Triton implementation
try:
@@ -235,6 +234,8 @@ def dynamic_mask_attention_cuda(
Returns:
attn_outputs: [batch_size, query_len, num_heads, head_dim]
"""
if flash_dmattn_func is None:
return "Not Available", 0

# Calculate zoh_states
zoh_states = calculate_zoh_states(value_states, dt_proj, A)
@@ -249,43 +250,32 @@

# Ensure correct data types and memory layout for CUDA function
# CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format
query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]
value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]
zoh_states = zoh_states[:, :, None, :].expand(
-1, -1, query_states.shape[1], -1
).contiguous() # [batch, num_kv_heads, query_len, key_len]
attn_bias = attn_bias.contiguous() # [batch, num_kv_heads, query_len, key_len]
attn_mask = attn_mask.contiguous() # [batch, num_kv_heads, query_len, key_len]
query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim]
value_states = value_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim]
Copilot AI Jul 30, 2025

The zoh_states calculation and usage have been removed, but the calculate_zoh_states function is still called on line 241. The function will now compute zoh_states that are never used, and the new flash_dmattn_func interface may not provide the dynamic masking behavior that zoh_states were intended to supply.

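One hedged way to act on this, sketched under the assumption that the dynamic-mask behavior should survive the interface change (query_len here is an illustrative placeholder, not a name taken from the PR): either delete the unused call, or route its result into attn_bias, which the new interface does accept.

# Sketch only, not part of the PR; shapes mirror the expansion this PR removed.
zoh_states = calculate_zoh_states(value_states, dt_proj, A)  # assumed [batch, num_kv_heads, key_len]
attn_bias = zoh_states[:, :, None, :].expand(
    -1, -1, query_len, -1                                    # query_len: illustrative placeholder
)  # -> [batch, num_kv_heads, query_len, key_len], then passed as attn_bias below

If dynamic masking is intentionally dropped from this benchmark, simply removing the calculate_zoh_states call is the cheaper fix.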
Comment on lines +253 to +255
Copilot AI Jul 30, 2025

The .contiguous() calls have been removed from the tensor operations. If the new flash_dmattn_func interface expects contiguous tensors, this could cause performance issues or errors. Consider adding .contiguous() calls back if the interface requires them.

Suggested change
query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim]
value_states = value_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim]
query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]
value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]

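A middle ground, sketched here as an assumption rather than a documented requirement of flash_dmattn_func, is to copy only when a tensor is not already contiguous:

# Sketch only: force contiguity once, and only when needed, so an
# already-contiguous tensor avoids an extra copy.
def ensure_contiguous(*tensors):
    return tuple(t if t.is_contiguous() else t.contiguous() for t in tensors)

query_states, key_states, value_states = ensure_contiguous(
    query_states, key_states, value_states
)

Since transpose(1, 2) generally produces non-contiguous views, this guard behaves like the reviewer's suggestion in practice while making the intent explicit.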

try:
# Call the CUDA implementation using the mha_fwd function signature
out_tensor = None # Let the function allocate the output tensor

# Only measure the core CUDA kernel computation
torch.cuda.synchronize()
start_time = time.time()

result = flash_dmattn_cuda.fwd( # type: ignore
query_states, # q: [batch, seqlen_q, num_heads, head_dim]
key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim]
value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim]
attn_mask, # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
attn_bias, # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k]
out_tensor, # out: None to auto-allocate
0.0, # p_dropout
scaling, # softmax_scale
is_causal, # is_causal
keep_window_size, # keep_window_size
0.0, # softcap
return_softmax, # return_softmax
None # gen (generator)

# Call the new flash_dmattn_func interface
attn_outputs = flash_dmattn_func(
query_states, # [batch, query_len, num_heads, head_dim]
key_states, # [batch, key_len, num_kv_heads, head_dim]
value_states, # [batch, key_len, num_kv_heads, head_dim]
attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len]
attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len]
dropout_p=0.0,
softmax_scale=scaling,
is_causal=is_causal,
softcap=0.0,
deterministic=True,
return_attn_probs=return_softmax
)

torch.cuda.synchronize()
end_time = time.time()

attn_outputs = result[0] # [batch, query_len, num_heads, head_dim]
return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms
except torch.cuda.OutOfMemoryError:
return "OOM", 0