From 99730be16ad851fb92990efcb59a5de47a322f32 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 30 Jul 2025 21:27:42 +0800 Subject: [PATCH] Refactors CUDA implementation to use new interface Switches from direct CUDA extension import to standardized function interface for better maintainability and consistency. Simplifies function call signature by removing manual tensor operations and utilizing cleaner parameter passing through the new interface. Adds proper null check to handle cases where the function is unavailable. --- benchmarks/benchmark_forward_performance.py | 56 +++++++++------------ 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py index d64ae9a..e0f7ed3 100644 --- a/benchmarks/benchmark_forward_performance.py +++ b/benchmarks/benchmark_forward_performance.py @@ -20,7 +20,6 @@ """ import torch -import torch.nn.backends import torch.nn.functional as F import numpy as np import argparse @@ -29,13 +28,13 @@ # Import the compiled CUDA extension try: - import flash_dmattn_cuda # type: ignore[import] - print("✅ Successfully imported flash_dmattn_cuda") + from flash_dmattn.flash_dmattn_interface import flash_dmattn_func + print("✅ Successfully imported flash_dmattn interface") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_cuda: {e}") + print(f"❌ Failed to import flash_dmattn interface: {e}") print("Please make sure the package is properly installed with: pip install .") # Don't exit here, just warn - flash_dmattn_cuda = None + flash_dmattn_func = None # Import the Triton implementation try: @@ -235,6 +234,8 @@ def dynamic_mask_attention_cuda( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ + if flash_dmattn_func is None: + return "Not Available", 0 # Calculate zoh_states zoh_states = calculate_zoh_states(value_states, dt_proj, A) @@ -249,43 +250,32 @@ def dynamic_mask_attention_cuda( # Ensure correct data types and memory layout for CUDA function # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format - query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] - key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] - value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] - zoh_states = zoh_states[:, :, None, :].expand( - -1, -1, query_states.shape[1], -1 - ).contiguous() # [batch, num_kv_heads, query_len, key_len] - attn_bias = attn_bias.contiguous() # [batch, num_kv_heads, query_len, key_len] - attn_mask = attn_mask.contiguous() # [batch, num_kv_heads, query_len, key_len] + query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] + value_states = value_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] try: - # Call the CUDA implementation using the mha_fwd function signature - out_tensor = None # Let the function allocate the output tensor - - # Only measure the core CUDA kernel computation torch.cuda.synchronize() start_time = time.time() - - result = flash_dmattn_cuda.fwd( # type: ignore - query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] - value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] - attn_mask, # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] - attn_bias, # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k] - out_tensor, # out: None to auto-allocate - 0.0, # p_dropout - scaling, # softmax_scale - is_causal, # is_causal - keep_window_size, # keep_window_size - 0.0, # softcap - return_softmax, # return_softmax - None # gen (generator) + + # Call the new flash_dmattn_func interface + attn_outputs = flash_dmattn_func( + query_states, # [batch, query_len, num_heads, head_dim] + key_states, # [batch, key_len, num_kv_heads, head_dim] + value_states, # [batch, key_len, num_kv_heads, head_dim] + attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len] + attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len] + dropout_p=0.0, + softmax_scale=scaling, + is_causal=is_causal, + softcap=0.0, + deterministic=True, + return_attn_probs=return_softmax ) torch.cuda.synchronize() end_time = time.time() - attn_outputs = result[0] # [batch, query_len, num_heads, head_dim] return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms except torch.cuda.OutOfMemoryError: return "OOM", 0