From 5b214608f0b2f0c37bccd9af212d38e73139d9fe Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 27 Jun 2025 22:08:22 +0800 Subject: [PATCH 1/2] Removes unused causal_mask parameter from function Simplifies function signature by eliminating the causal_mask parameter that was not being utilized in the implementation, reducing unnecessary complexity in both the function definition and all call sites. --- benchmarks/benchmark_forward_equivalence.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py index 4fb8cc3..b7af351 100644 --- a/benchmarks/benchmark_forward_equivalence.py +++ b/benchmarks/benchmark_forward_equivalence.py @@ -76,7 +76,7 @@ def prepare_dynamic_mask( return attn_mask, active_mask -def calculate_zero_hold_states(value_states, dt_proj, A, causal_mask=None): +def calculate_zero_hold_states(value_states, dt_proj, A): """ Calculate zero hold states for dynamic mask attention. @@ -152,7 +152,7 @@ def dynamic_mask_attention_python( num_queries_per_kv = num_heads // num_kv_heads - zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask) + zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A) # Use prepare_dynamic_mask function to process dynamic mask attn_mask, _ = prepare_dynamic_mask( @@ -208,7 +208,7 @@ def dynamic_mask_attention_cuda( """ # Calculate zero_hold_states - zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask) + zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A) # Use prepare_dynamic_mask to get the processed attention mask _, active_mask = prepare_dynamic_mask( From 3281c01ce17c14bc8525fe9203166660eb7b9e69 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 27 Jun 2025 22:09:37 +0800 Subject: [PATCH 2/2] Updates CUDA extension integration and function calls Replaces old flash_dma_cpp import with flash_dma_cuda module and adds proper error handling for import failures. Updates function calls to use the new flash_dma_cuda.fwd API with expanded parameter list including dropout, softcap, and generator parameters. Removes unused causal_mask parameter from calculate_zero_hold_states function and fixes active_mask initialization logic in prepare_dynamic_mask. Re-enables commented head dimension test cases in performance benchmark configuration. --- benchmarks/benchmark_forward_performance.py | 99 +++++++++++++-------- 1 file changed, 61 insertions(+), 38 deletions(-) diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py index f1b3f88..f93de7f 100644 --- a/benchmarks/benchmark_forward_performance.py +++ b/benchmarks/benchmark_forward_performance.py @@ -19,8 +19,15 @@ import argparse import time import gc -from flash_dma_cpp import apply_dynamic_mask_attention # type: ignore -from typing import cast + +# Import the compiled CUDA extension +try: + import flash_dma_cuda + print("✅ Successfully imported flash_dma_cuda") +except ImportError as e: + print(f"❌ Failed to import flash_dma_cuda: {e}") + print("Please make sure the package is properly installed with: pip install .") + exit(1) def prepare_dynamic_mask( @@ -32,6 +39,8 @@ def prepare_dynamic_mask( """ Calculate dynamic attention mask to mask tokens for sparse attention. + Combine `dt_states` with `attention_mask` to generate the final `attn_mask`. 
+ Args: hidden_states: Input hidden states to determine dtype minimum value dt_states: dt_states of shape (batch_size, num_kv_heads, key_sequence_length) @@ -46,7 +55,6 @@ def prepare_dynamic_mask( attn_mask = dt_states[:, :, None, :].expand( -1, -1, hidden_states.shape[2], -1 ) # [batch_size, num_kv_heads, query_len, key_len] - active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device) if attention_mask is not None: if attention_mask.dtype == torch.bool: @@ -66,11 +74,12 @@ def prepare_dynamic_mask( active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device) active_mask = active_mask.scatter(-1, topk_indices, 1.0) attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype) - + else: + active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device) return attn_mask, active_mask -def calculate_zero_hold_states(value_states, dt_proj, A, causal_mask=None): +def calculate_zero_hold_states(value_states, dt_proj, A): """ Calculate zero hold states for dynamic mask attention. @@ -175,7 +184,7 @@ def dynamic_mask_attention_cuda( attn_outputs: [batch_size, query_len, num_heads, head_dim] """ # Calculate zero_hold_states - zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask) + zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A) _, active_mask = prepare_dynamic_mask( query_states, @@ -184,7 +193,8 @@ def dynamic_mask_attention_cuda( causal_mask if is_causal else None ) # [batch_size, num_kv_heads, query_len, key_len] - # Ensure correct data types and memory layout + # Ensure correct data types and memory layout for CUDA function + # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] @@ -194,20 +204,24 @@ def dynamic_mask_attention_cuda( active_mask = active_mask.contiguous() # [batch, num_kv_heads, query_len, key_len] try: - result = apply_dynamic_mask_attention( - query_states=query_states, - key_states=key_states, - value_states=value_states, - zoh_states=zero_hold_states, - active_mask=active_mask, - scale=scaling, - keep_window_size=keep_window_size, - is_causal=is_causal, - return_softmax=return_softmax + # Call the CUDA implementation using the mha_fwd function signature + out_tensor = None # Let the function allocate the output tensor + result = flash_dma_cuda.fwd( # type: ignore + query_states, # q: [batch, seqlen_q, num_heads, head_dim] + key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] + value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] + zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask + active_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] + out_tensor, # out: None to auto-allocate + 0.0, # p_dropout + scaling, # softmax_scale + is_causal, # is_causal + keep_window_size, # keep_window_size + 0.0, # softcap + return_softmax, # return_softmax + None # gen (generator) ) - - # Convert result back to original data type - attn_outputs = result[0] + attn_outputs = result[0] # [batch, query_len, num_heads, head_dim] return attn_outputs except torch.cuda.OutOfMemoryError: return "OOM" @@ -245,10 +259,10 @@ def dynamic_mask_attention_cuda_no_topk( attn_outputs: [batch_size, query_len, 
num_heads, head_dim] """ # Calculate zero_hold_states - zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask) + zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A) # Create a simplified mask without topk computation - batch_size, num_heads, query_len, head_dim = query_states.shape + batch_size, _, query_len, _ = query_states.shape _, num_kv_heads, key_len, _ = key_states.shape dtype = query_states.dtype device = query_states.device @@ -267,22 +281,31 @@ def dynamic_mask_attention_cuda_no_topk( zero_hold_states = zero_hold_states[:, :, None, :].expand( -1, -1, query_states.shape[1], -1 ).contiguous() # [batch, num_kv_heads, query_len, key_len] + # Create full active mask (no topk selection) + active_mask = torch.zeros_like( + zero_hold_states, + dtype=dtype, + device=device + ) active_mask = active_mask.contiguous() # [batch, num_kv_heads, query_len, key_len] try: - result = apply_dynamic_mask_attention( - query_states=query_states, - key_states=key_states, - value_states=value_states, - zoh_states=zero_hold_states, - active_mask=active_mask, - scale=scaling, - keep_window_size=0, - is_causal=is_causal, - return_softmax=return_softmax + out_tensor = None # Let the function allocate the output tensor + result = flash_dma_cuda.fwd( # type: ignore + query_states, # q: [batch, seqlen_q, num_heads, head_dim] + key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] + value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] + zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask + active_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] + out_tensor, # out: None to auto-allocate + 0.0, # p_dropout + scaling, # softmax_scale + is_causal, # is_causal + keep_window_size, # keep_window_size + 0.0, # softcap + return_softmax, # return_softmax + None # gen (generator) ) - - # Convert result back to original data type attn_outputs = result[0] return attn_outputs except torch.cuda.OutOfMemoryError: @@ -529,10 +552,10 @@ def run_performance_benchmark(): (1, 4, 1, 1024, 1024, 32), (1, 8, 2, 1024, 1024, 32), - # # Vary head dimension - # (1, 2, 1, 1024, 1024, 32), - # (1, 2, 1, 1024, 1024, 64), - # (1, 2, 1, 1024, 1024, 128), + # Vary head dimension + (1, 2, 1, 1024, 1024, 32), + (1, 2, 1, 1024, 1024, 64), + (1, 2, 1, 1024, 1024, 128), ] num_runs = 3 # Run 3 times and take average
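
Note on the masking step these patches exercise: the top-k selection that prepare_dynamic_mask performs (topk_indices -> scatter into active_mask -> masked_fill of attn_mask, with the all-ones fallback moved into the else branch by patch 2) can be reproduced in pure PyTorch without the flash_dma_cuda extension. The sketch below is illustrative only: the function name topk_active_mask_sketch is made up for this note, the shapes follow the docstrings in the diffs, and the top-k is gated here on keep_window_size for simplicity, omitting the optional attention_mask handling that the benchmark applies first.

    import torch

    def topk_active_mask_sketch(dt_states, query_len, keep_window_size=2048):
        # dt_states: [batch_size, num_kv_heads, key_len], the same layout that
        # prepare_dynamic_mask broadcasts over the query dimension.
        batch_size, num_kv_heads, key_len = dt_states.shape
        min_dtype = torch.finfo(dt_states.dtype).min

        # attn_mask: [batch_size, num_kv_heads, query_len, key_len]
        attn_mask = dt_states[:, :, None, :].expand(-1, -1, query_len, -1)

        if key_len > keep_window_size:
            # Keep only the keep_window_size largest entries per query row.
            topk_indices = torch.topk(
                attn_mask, keep_window_size, dim=-1, largest=True, sorted=False
            ).indices
            active_mask = torch.zeros_like(attn_mask)
            active_mask = active_mask.scatter(-1, topk_indices, 1.0)
            attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)
        else:
            # Key length fits inside the window: every position stays active.
            active_mask = torch.ones_like(attn_mask)
        return attn_mask, active_mask

    # Example: 1 batch, 2 KV heads, 16 queries, 1024 keys, keep the top 256.
    attn_mask, active_mask = topk_active_mask_sketch(
        torch.randn(1, 2, 1024), query_len=16, keep_window_size=256
    )
    assert active_mask.sum(dim=-1).eq(256).all()

The resulting attn_mask and active_mask have the [batch, num_kv_heads, query_len, key_len] layout that both dynamic_mask_attention_python and the new flash_dma_cuda.fwd call expect for the zoh and active_mask arguments.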
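
On the benchmark configurations re-enabled at the end of run_performance_benchmark(): each tuple is (batch, num_heads, num_kv_heads, q_len, k_len, head_dim), and each configuration is averaged over num_runs = 3. A minimal, self-contained timing harness along those lines might look like the following; time_config and the attention_fn stand-in are hypothetical names for this note, not part of the benchmark script, which additionally constructs dt_proj, A, causal masks, and the other arguments shown in the diffs.

    import time
    import torch

    def time_config(attention_fn, config, num_runs=3, device="cuda", dtype=torch.bfloat16):
        # Unpack one benchmark configuration tuple.
        batch, num_heads, num_kv_heads, q_len, k_len, head_dim = config
        q = torch.randn(batch, num_heads, q_len, head_dim, device=device, dtype=dtype)
        k = torch.randn(batch, num_kv_heads, k_len, head_dim, device=device, dtype=dtype)
        v = torch.randn(batch, num_kv_heads, k_len, head_dim, device=device, dtype=dtype)

        timings = []
        for _ in range(num_runs):
            torch.cuda.synchronize()
            start = time.perf_counter()
            attention_fn(q, k, v)  # stand-in for a dynamic_mask_attention_* call
            torch.cuda.synchronize()
            timings.append(time.perf_counter() - start)
        # Average over the runs, mirroring "Run 3 times and take average".
        return sum(timings) / len(timings)

    # e.g. average latency for one of the re-enabled head-dimension configs:
    # avg_s = time_config(some_attention_fn, (1, 2, 1, 1024, 1024, 64))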