From d17e3b1f67b024987254b1c0ba733cddf52b5641 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Wed, 30 Jul 2025 21:23:18 +0800
Subject: [PATCH] Refactors CUDA interface to use simplified function

Replaces the direct CUDA module import with a higher-level interface
function that provides a cleaner API with named parameters.

Simplifies the function call by removing manual tensor preparation and
by using more intuitive parameter names such as dropout_p and
softmax_scale.

Adds a runtime check to ensure the CUDA implementation is available
before executing tests.
---
 benchmarks/benchmark_forward_equivalence.py | 58 ++++++++++-----------
 1 file changed, 28 insertions(+), 30 deletions(-)

diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py
index 60a868a..6d5177b 100644
--- a/benchmarks/benchmark_forward_equivalence.py
+++ b/benchmarks/benchmark_forward_equivalence.py
@@ -21,13 +21,13 @@
 
 # Import the compiled CUDA extension
 try:
-    import flash_dmattn_cuda  # type: ignore[import]
-    print("✅ Successfully imported flash_dmattn_cuda")
+    from flash_dmattn.flash_dmattn_interface import flash_dmattn_func
+    print("✅ Successfully imported flash_dmattn interface")
 except ImportError as e:
-    print(f"❌ Failed to import flash_dmattn_cuda: {e}")
+    print(f"❌ Failed to import flash_dmattn interface: {e}")
     print("Please make sure the package is properly installed with: pip install .")
     # Don't exit here, just warn
-    flash_dmattn_cuda = None
+    flash_dmattn_func = None
 
 # Import the Triton implementation
 try:
@@ -229,6 +229,8 @@ def dynamic_mask_attention_cuda(
     Returns:
         attn_outputs: [batch_size, query_len, num_heads, head_dim]
     """
+    if flash_dmattn_func is None:
+        raise RuntimeError("flash_dmattn_func not available")
     # Calculate zoh_states
     zoh_states = calculate_zoh_states(value_states, dt_proj, A)
 
@@ -243,35 +245,26 @@
 
     # Ensure correct data types and memory layout for CUDA function
     # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format
-    query_states = query_states.transpose(1, 2).contiguous()   # [batch, query_len, num_heads, head_dim]
-    key_states = key_states.transpose(1, 2).contiguous()       # [batch, key_len, num_kv_heads, head_dim]
-    value_states = value_states.transpose(1, 2).contiguous()   # [batch, key_len, num_kv_heads, head_dim]
-    zoh_states = zoh_states[:, :, None, :].expand(
-        -1, -1, query_states.shape[1], -1
-    ).contiguous()                                              # [batch, num_kv_heads, query_len, key_len]
-    attn_bias = attn_bias.contiguous()                          # [batch, num_kv_heads, query_len, key_len]
-    attn_mask = attn_mask.contiguous()                          # [batch, num_kv_heads, query_len, key_len]
+    query_states = query_states.transpose(1, 2)    # [batch, query_len, num_heads, head_dim]
+    key_states = key_states.transpose(1, 2)        # [batch, key_len, num_kv_heads, head_dim]
+    value_states = value_states.transpose(1, 2)    # [batch, key_len, num_kv_heads, head_dim]
 
 
-    # Call the CUDA implementation using the mha_fwd function signature
-    out_tensor = None  # Let the function allocate the output tensor
-    result = flash_dmattn_cuda.fwd(  # type: ignore
-        query_states,       # q: [batch, seqlen_q, num_heads, head_dim]
-        key_states,         # k: [batch, seqlen_k, num_kv_heads, head_dim]
-        value_states,       # v: [batch, seqlen_k, num_kv_heads, head_dim]
-        attn_mask,          # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
-        attn_bias,          # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k]
-        out_tensor,         # out: None to auto-allocate
-        0.0,                # p_dropout
-        scaling,            # softmax_scale
-        is_causal,          # is_causal
-        keep_window_size,   # keep_window_size
-        0.0,                # softcap
-        return_softmax,     # return_softmax
-        None                # gen (generator)
+    # Call the new flash_dmattn_func interface
+    attn_outputs = flash_dmattn_func(
+        query_states,               # [batch, query_len, num_heads, head_dim]
+        key_states,                 # [batch, key_len, num_kv_heads, head_dim]
+        value_states,               # [batch, key_len, num_kv_heads, head_dim]
+        attn_mask=attn_mask,        # [batch, num_kv_heads, query_len, key_len]
+        attn_bias=attn_bias,        # [batch, num_kv_heads, query_len, key_len]
+        dropout_p=0.0,
+        softmax_scale=scaling,
+        is_causal=is_causal,
+        softcap=0.0,
+        deterministic=True,
+        return_attn_probs=return_softmax
     )
-    attn_outputs = result[0]  # [batch, query_len, num_heads, head_dim]
-    return attn_outputs
+    return attn_outputs  # [batch, query_len, num_heads, head_dim]
 
 
 def dynamic_mask_attention_triton(
@@ -514,6 +507,11 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
     print("🔬 Testing Forward Pass Equivalence: Python Prototype vs CUDA Implementation 🔬")
    print("🚀" + "=" * 76 + "🚀")
 
+    # Check if CUDA implementation is available
+    if flash_dmattn_func is None:
+        print("❌ CUDA implementation not available, skipping test.")
+        return False
+
     # Set random seed for reproducibility
     torch.manual_seed(0)
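-- 
Usage sketch (trailer note, not part of the applied diff): a minimal,
self-contained call through the refactored interface. The keyword
arguments mirror the ones used in the diff above; the tensor shapes,
dtypes, and softmax scale below are illustrative assumptions.

    import torch
    from flash_dmattn.flash_dmattn_interface import flash_dmattn_func

    # Small toy sizes; num_heads == num_kv_heads here (no GQA).
    batch, q_len, k_len, num_heads, head_dim = 1, 128, 128, 8, 64
    device, dtype = "cuda", torch.bfloat16

    q = torch.randn(batch, q_len, num_heads, head_dim, device=device, dtype=dtype)
    k = torch.randn(batch, k_len, num_heads, head_dim, device=device, dtype=dtype)
    v = torch.randn(batch, k_len, num_heads, head_dim, device=device, dtype=dtype)
    # Full (all-ones) mask and zero bias: [batch, num_kv_heads, query_len, key_len]
    attn_mask = torch.ones(batch, num_heads, q_len, k_len, device=device, dtype=dtype)
    attn_bias = torch.zeros(batch, num_heads, q_len, k_len, device=device, dtype=dtype)

    out = flash_dmattn_func(
        q, k, v,
        attn_mask=attn_mask,
        attn_bias=attn_bias,
        dropout_p=0.0,
        softmax_scale=head_dim ** -0.5,  # assumed scale; the benchmark passes its own scaling value
        is_causal=True,
        softcap=0.0,
        deterministic=True,
        return_attn_probs=False,
    )
    print(out.shape)  # expected: [batch, q_len, num_heads, head_dim]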