36 changes: 19 additions & 17 deletions benchmarks/benchmark_forward_equivalence.py
@@ -26,27 +26,28 @@
 except ImportError as e:
     print(f"❌ Failed to import flash_dmattn_cuda: {e}")
     print("Please make sure the package is properly installed with: pip install .")
-    exit(1)
+    # Don't exit here, just warn
+    flash_dmattn_cuda = None
 
 # Import the Triton implementation
 try:
-    from flash_dmattn.flash_dmattn_triton import flash_dmattn_func
+    from flash_dmattn.flash_dmattn_triton import triton_dmattn_func
     print("✅ Successfully imported flash_dmattn_triton")
 except ImportError as e:
     print(f"❌ Failed to import flash_dmattn_triton: {e}")
     print("Please make sure the Triton implementation is available.")
     # Don't exit here, just warn
-    flash_dmattn_func = None
+    triton_dmattn_func = None
 
 # Import the Flex Attention implementation
 try:
-    from flash_dmattn.flash_dmattn_flex import flex_attention_forward
+    from flash_dmattn.flash_dmattn_flex import flex_dmattn_func
     print("✅ Successfully imported flash_dmattn_flex")
 except ImportError as e:
     print(f"❌ Failed to import flash_dmattn_flex: {e}")
     print("Please make sure the Flex Attention implementation is available.")
     # Don't exit here, just warn
-    flex_attention_forward = None
+    flex_dmattn_func = None
 
 
 def prepare_dynamic_mask(
@@ -301,7 +302,7 @@ def dynamic_mask_attention_triton(
     Returns:
         attn_outputs: [batch_size, query_len, num_heads, head_dim]
     """
-    if flash_dmattn_func is None:
+    if triton_dmattn_func is None:
         raise RuntimeError("Triton implementation not available")
 
     _, num_heads, _, _ = query_states.shape
@@ -333,14 +334,14 @@ def dynamic_mask_attention_triton(
     attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
 
     # Call the Triton implementation
-    attn_outputs = flash_dmattn_func(
+    attn_outputs = triton_dmattn_func(
         query_states,           # q: [batch, seqlen_q, num_heads, head_dim]
         key_states,             # k: [batch, seqlen_k, num_heads, head_dim]
         value_states,           # v: [batch, seqlen_k, num_heads, head_dim]
-        mask=attn_mask,         # mask: [batch, num_heads, seqlen_q, seqlen_k]
-        bias=attn_bias,         # bias: [batch, num_heads, seqlen_q, seqlen_k]
-        causal=is_causal,       # causal masking
-        softmax_scale=scaling   # scaling factor
+        attn_mask,              # mask: [batch, num_heads, seqlen_q, seqlen_k]
+        attn_bias,              # bias: [batch, num_heads, seqlen_q, seqlen_k]
+        is_causal,              # causal masking
+        scaling                 # scaling factor
Review comment from Copilot AI (Jul 7, 2025) on lines 338 to +344:
[nitpick] Switching from named to positional arguments reduces readability and risks mismatches; prefer using keyword arguments when calling triton_dmattn_func to make the mapping explicit.
     )
 
     return attn_outputs  # [batch, query_len, num_heads, head_dim]
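
Regarding the review note above: one way to act on it without changing the underlying entry point is a thin keyword-argument wrapper around the positional call. This is only a sketch; the helper name is hypothetical, and it assumes the positional order (q, k, v, mask, bias, causal flag, scale) used in the call above.

```python
# Hypothetical helper, not part of this PR: exposes keyword arguments and
# forwards them to the positional Triton entry point in the order shown above.
def triton_dmattn_kw(q, k, v, *, mask=None, bias=None, is_causal=True, scale=None):
    return triton_dmattn_func(q, k, v, mask, bias, is_causal, scale)

# The benchmark call site would then read:
# attn_outputs = triton_dmattn_kw(
#     query_states, key_states, value_states,
#     mask=attn_mask, bias=attn_bias, is_causal=is_causal, scale=scaling,
# )
```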
@@ -374,7 +375,7 @@ def dynamic_mask_attention_flex(
     Returns:
         attn_outputs: [batch_size, query_len, num_heads, head_dim]
     """
-    if flex_attention_forward is None:
+    if flex_dmattn_func is None:
         raise RuntimeError("Flex Attention implementation not available")
 
     _, num_heads, _, _ = query_states.shape
@@ -402,12 +403,13 @@
     # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format
 
     # Call the Flex Attention implementation
-    attn_outputs, _ = flex_attention_forward(
+    attn_outputs, _ = flex_dmattn_func(
         query_states,                   # q: [batch, num_heads, query_len, head_dim]
         key_states,                     # k: [batch, num_heads, key_len, head_dim]
         value_states,                   # v: [batch, num_heads, key_len, head_dim]
         attention_mask=attn_mask,       # attention_mask: [batch, num_heads, query_len, key_len]
         attention_bias=attn_bias,       # attention_bias: [batch, num_heads, query_len, key_len]
+        is_causal=is_causal,            # is_causal: whether to apply causal masking
         scaling=scaling                 # scaling factor
     )
 
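As the inline comments above note, the Triton call takes q/k/v in [batch, seqlen, num_heads, head_dim] layout while the Flex call takes [batch, num_heads, seqlen, head_dim]; the mask and bias are [batch, num_heads, query_len, key_len] in both cases. A minimal sketch of converting between the two tensor layouts (hypothetical helpers, not part of the benchmark):

```python
import torch

def to_flex_layout(x: torch.Tensor) -> torch.Tensor:
    # [batch, seqlen, num_heads, head_dim] -> [batch, num_heads, seqlen, head_dim]
    return x.transpose(1, 2).contiguous()

def to_triton_layout(x: torch.Tensor) -> torch.Tensor:
    # [batch, num_heads, seqlen, head_dim] -> [batch, seqlen, num_heads, head_dim]
    return x.transpose(1, 2).contiguous()
```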
@@ -662,14 +664,14 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95):
     print("🔬 Testing Forward Pass Equivalence: Python vs Triton 🔬")
     print("🔥" + "=" * 76 + "🔥")
 
-    if flash_dmattn_func is None:
+    if triton_dmattn_func is None:
         print("❌ Triton implementation not available, skipping Triton tests")
         return False
 
     # Set random seed for reproducibility
     torch.manual_seed(0)
-    # Smaller test configurations for Triton (to avoid memory issues)
 
+    # If you encounter NAN issues when running multiple configurations, try running a single configuration
     test_configs = [
         # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
         (1, 1, 1, 64, 64, 32, True),
@@ -833,7 +835,7 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95):
     print("🔬 Testing Forward Pass Equivalence: Python vs Flex Attention 🔬")
     print("🌟" + "=" * 76 + "🌟")
 
-    if flex_attention_forward is None:
+    if flex_dmattn_func is None:
         print("❌ Flex Attention implementation not available, skipping Flex Attention tests")
         return False
 