From 3a05b76f441e66c9223a4a795b6f7a6dce577b64 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Thu, 13 Nov 2025 12:24:54 +0800 Subject: [PATCH 1/6] Unifies mask creation via shared utility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the ad‑hoc mask builder with a shared utility to centralize top‑k/causal masking across CUDA, Triton, and Flex paths. Passes explicit tensor shapes and dtype min to ensure correctness, and adds block‑sparse support (block size 64) for dynamic CUDA. Handles optional masks safely when repeating KV for GQA. Updates benchmark bias to broadcast over queries (B,H,1,K) to reduce memory and match masking expectations. Improves consistency, reduces duplication, and prepares for extensible masking strategies. --- benchmarks/forward_performance.py | 123 ++++++++++++++---------------- 1 file changed, 57 insertions(+), 66 deletions(-) diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py index 05e75c4..851fc88 100644 --- a/benchmarks/forward_performance.py +++ b/benchmarks/forward_performance.py @@ -26,6 +26,8 @@ import time import gc +from flash_sparse_attn.utils.mask import create_mask, topk_indices + # Import the compiled CUDA extension try: from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func @@ -72,42 +74,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def prepare_mask( - hidden_states: torch.Tensor, - attn_bias: torch.Tensor, - causal_mask: torch.Tensor = None, - window_size: int = None, -): - """ - Args: - hidden_states: Input hidden states to determine dtype minimum value - attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length) - causal_mask: Optional causal mask to apply - window_size: Window size of tokens not masked - - Returns: - tuple: (attn_bias, attn_mask) - """ - dtype = hidden_states.dtype - min_dtype = torch.finfo(dtype).min - - if attn_bias.shape[-1] > window_size: - if causal_mask is not None: - topk_values, topk_indices = torch.topk( - attn_bias.masked_fill(~causal_mask, min_dtype).detach(), - window_size, dim=-1, largest=True, sorted=False - ) - else: - topk_values, topk_indices = torch.topk( - attn_bias, - window_size, dim=-1, largest=True, sorted=False - ) - attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype) - else: - attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) - return attn_bias, attn_mask - - def scaled_dot_product_attention_cuda( query_states: torch.Tensor, key_states: torch.Tensor, @@ -134,19 +100,24 @@ def scaled_dot_product_attention_cuda( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - _, num_heads, _, _ = query_states.shape - _, num_kv_heads, _, _ = key_states.shape + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads - - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, + + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + 
min_dtype=torch.finfo(query_states.dtype).min, + type="topk" ) # Repeat KV for multi-head attention (GQA support) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None attn_bias = repeat_kv(attn_bias, num_queries_per_kv) query_states = query_states.contiguous() @@ -167,7 +138,7 @@ def scaled_dot_product_attention_cuda( # is_causal=is_causal, enable_gqa=True, ) - + torch.cuda.synchronize() end_time = time.time() @@ -206,11 +177,21 @@ def dynamic_mask_attention_cuda( if flash_sparse_attn_func is None: return "Not Available", 0 - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + + num_queries_per_kv = num_heads // num_kv_heads + + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + block_size=64, + type="topk" ) # Ensure correct data types and memory layout for CUDA function @@ -272,15 +253,20 @@ def dynamic_mask_attention_triton( if triton_sparse_attn_func is None: return "Not Available", 0 - _, num_heads, _, _ = query_states.shape - _, num_kv_heads, _, _ = key_states.shape + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + type="topk" ) # Repeat KV for multi-head attention (GQA support) @@ -347,15 +333,20 @@ def dynamic_mask_attention_flex( if flex_sparse_attn_func is None: return "Not Available", 0 - _, num_heads, _, _ = query_states.shape - _, num_kv_heads, _, _ = key_states.shape + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + type="topk" ) # Repeat KV for multi-head attention (GQA support) @@ -437,7 +428,7 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ device=device, dtype=torch.bfloat16 ) attn_bias = torch.randn( - batch_size, num_kv_heads, query_len, key_len, + batch_size, num_kv_heads, 1, key_len, device=device, dtype=torch.bfloat16 ) cache_position = torch.arange(key_len - query_len, key_len, device=device) From 95c7c16fed4fd1aad2ff7cc8dc798107f2a4c9dd Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Thu, 13 Nov 2025 12:25:49 +0800 Subject: [PATCH 2/6] Removes unused import from benchmark script Removes an unused symbol to clean up imports and silence linter warnings. Reduces clutter and avoids confusion with unreferenced utilities. 
Introduces no functional changes. --- benchmarks/forward_performance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py index 851fc88..f93537b 100644 --- a/benchmarks/forward_performance.py +++ b/benchmarks/forward_performance.py @@ -26,7 +26,7 @@ import time import gc -from flash_sparse_attn.utils.mask import create_mask, topk_indices +from flash_sparse_attn.utils.mask import create_mask # Import the compiled CUDA extension try: From 3f5a40f26af10953e697f00c7b10819c1aa0a1ee Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Thu, 13 Nov 2025 12:27:21 +0800 Subject: [PATCH 3/6] Unifies masking via shared utility Replaces the local mask builder with a centralized utility to standardize top-k/causal masking across Python, CUDA, Triton, and Flex paths. Passes explicit batch/query/key sizes and dtype min, repeats masks only when present, and skips masked_fill when unneeded. Reduces duplication, improves consistency and maintainability, and streamlines GQA handling. --- benchmarks/forward_equivalence.py | 118 ++++++++++++++---------------- 1 file changed, 54 insertions(+), 64 deletions(-) diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index 9b05ba3..8dfe5c8 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -19,6 +19,8 @@ import gc import sys +from flash_sparse_attn.utils.mask import create_mask + # Import the compiled CUDA extension try: from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func @@ -65,42 +67,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def prepare_mask( - hidden_states: torch.Tensor, - attn_bias: torch.Tensor, - causal_mask: torch.Tensor = None, - window_size: int = None, -): - """ - Args: - hidden_states: Input hidden states to determine dtype minimum value - attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length) - causal_mask: Optional causal mask to apply - window_size: Window size of tokens not masked - - Returns: - tuple: (attn_bias, attn_mask) - """ - dtype = hidden_states.dtype - min_dtype = torch.finfo(dtype).min - - if attn_bias.shape[-1] > window_size: - if causal_mask is not None: - topk_values, topk_indices = torch.topk( - attn_bias.masked_fill(~causal_mask, min_dtype).detach(), - window_size, dim=-1, largest=True, sorted=False - ) - else: - topk_values, topk_indices = torch.topk( - attn_bias, - window_size, dim=-1, largest=True, sorted=False - ) - attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype) - else: - attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) - return attn_bias, attn_mask - - def dynamic_mask_attention_python( query_states: torch.Tensor, key_states: torch.Tensor, @@ -127,27 +93,32 @@ def dynamic_mask_attention_python( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - _, num_heads, _, _ = query_states.shape - _, num_kv_heads, _, _ = key_states.shape + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, + 
attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + type="topk" ) key_states = repeat_kv(key_states, num_queries_per_kv) value_states = repeat_kv(value_states, num_queries_per_kv) attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None # Sparse attention weight calculation attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) # Dot product weights attn_weights = attn_weights * scaling + attn_bias # Apply scaling and bias - attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf')) # Apply mask + if attn_mask is not None: + attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf')) # Apply mask attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization attn_outputs = torch.matmul(attn_weights, value_states) # Weighted sum of values attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] @@ -184,11 +155,20 @@ def dynamic_mask_attention_cuda( if flash_sparse_attn_func is None: raise RuntimeError("flash_sparse_attn_func not available") - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + + num_queries_per_kv = num_heads // num_kv_heads + + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + type="topk" ) # Ensure correct data types and memory layout for CUDA function @@ -242,15 +222,20 @@ def dynamic_mask_attention_triton( if triton_sparse_attn_func is None: raise RuntimeError("Triton implementation not available") - _, num_heads, _, _ = query_states.shape - _, num_kv_heads, _, _ = key_states.shape + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + type="topk" ) # Repeat KV for multi-head attention (GQA support) @@ -309,15 +294,20 @@ def dynamic_mask_attention_flex( if flex_sparse_attn_func is None: raise RuntimeError("Flex Attention implementation not available") - _, num_heads, _, _ = query_states.shape - _, num_kv_heads, _, _ = key_states.shape + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, 
+ min_dtype=torch.finfo(query_states.dtype).min, + type="topk" ) # Repeat KV for multi-head attention (GQA support) From 8ce664287ea6dbf5235ae0e83382fe7355814c73 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Thu, 13 Nov 2025 12:29:25 +0800 Subject: [PATCH 4/6] Refactors masking via utility; fixes Flex grads MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces ad-hoc mask construction with a shared mask utility to unify top‑k/causal masking across Python, CUDA, Triton, and Flex paths. Reduces duplication and allows safely skipping masking when not required. Fixes gradient reporting in the Flex path by returning grads w.r.t. the original input tensors. Also clarifies shape handling and guards masked fill, improving robustness. --- benchmarks/backward_equivalence.py | 150 +++++++++++++---------------- 1 file changed, 69 insertions(+), 81 deletions(-) diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index a10da1e..049510b 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -19,6 +19,8 @@ import gc import sys +from flash_sparse_attn.utils.mask import create_mask + # Import the compiled CUDA extension try: from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func @@ -65,42 +67,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def prepare_mask( - hidden_states: torch.Tensor, - attn_bias: torch.Tensor, - causal_mask: torch.Tensor = None, - window_size: int = None, -): - """ - Args: - hidden_states: Input hidden states to determine dtype minimum value - attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length) - causal_mask: Optional causal mask to apply - window_size: Window size of tokens not masked - - Returns: - tuple: (attn_bias, attn_mask) - """ - dtype = hidden_states.dtype - min_dtype = torch.finfo(dtype).min - - if attn_bias.shape[-1] > window_size: - if causal_mask is not None: - topk_values, topk_indices = torch.topk( - attn_bias.masked_fill(~causal_mask, min_dtype).detach(), - window_size, dim=-1, largest=True, sorted=False - ) - else: - topk_values, topk_indices = torch.topk( - attn_bias, - window_size, dim=-1, largest=True, sorted=False - ) - attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype) - else: - attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) - return attn_bias, attn_mask - - def dynamic_mask_attention_python( query_states: torch.Tensor, key_states: torch.Tensor, @@ -127,32 +93,38 @@ def dynamic_mask_attention_python( Returns: tuple: (attn_outputs, dq, dk, dv, dbias) """ - _, num_heads, _, _ = query_states.shape - _, num_kv_heads, _, _ = key_states.shape + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + type="topk" + ) + query_states_leaf = query_states key_states_leaf = key_states value_states_leaf = value_states - - attn_bias, attn_mask = prepare_mask( - 
query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, - ) attn_bias_leaf = attn_bias attn_bias_leaf.retain_grad() key_states = repeat_kv(key_states, num_queries_per_kv) value_states = repeat_kv(value_states, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv) # Sparse attention weight calculation attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) # Dot product weights attn_weights = attn_weights * scaling + attn_bias # Apply scaling and bias - attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf')) # Apply mask + if attn_mask is not None: + attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf')) # Apply mask attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization attn_outputs = torch.matmul(attn_weights, value_states) # Weighted sum of values attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] @@ -192,16 +164,25 @@ def dynamic_mask_attention_cuda( if flash_sparse_attn_func is None: raise ImportError("CUDA implementation not available") + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + + num_queries_per_kv = num_heads // num_kv_heads + + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + type="topk" + ) + query_states_leaf = query_states key_states_leaf = key_states value_states_leaf = value_states - - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, - ) attn_bias_leaf = attn_bias attn_bias_leaf.retain_grad() @@ -259,29 +240,28 @@ def dynamic_mask_attention_triton( if triton_sparse_attn_func is None: raise RuntimeError("Triton implementation not available") - _, num_heads, _, _ = query_states.shape - _, num_kv_heads, _, _ = key_states.shape + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + type="topk" + ) + query_states_leaf = query_states key_states_leaf = key_states value_states_leaf = value_states - - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, - ) attn_bias_leaf = attn_bias attn_bias_leaf.retain_grad() - # Repeat KV for multi-head attention (GQA support) - key_states = repeat_kv(key_states, num_queries_per_kv) - value_states = repeat_kv(value_states, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) - attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv) - # Ensure correct data types and memory layout for Triton function query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2) # [batch, key_len, num_heads, head_dim] @@ -333,30 +313,38 @@ def dynamic_mask_attention_flex( if flex_sparse_attn_func is None: raise RuntimeError("Flex 
Attention implementation not available") - _, num_heads, _, _ = query_states.shape - _, num_kv_heads, _, _ = key_states.shape + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + type="topk" ) - attn_bias.retain_grad() + + query_states_leaf = query_states + key_states_leaf = key_states + value_states_leaf = value_states + attn_bias_leaf = attn_bias + attn_bias_leaf.retain_grad() # Repeat KV for multi-head attention (GQA support) key_states = repeat_kv(key_states, num_queries_per_kv) value_states = repeat_kv(value_states, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None attn_bias = repeat_kv(attn_bias, num_queries_per_kv) # Ensure correct data types and memory layout for Flex function query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] # Call the Flex Attention implementation attn_outputs = flex_sparse_attn_func( @@ -372,7 +360,7 @@ def dynamic_mask_attention_flex( # Backward pass attn_outputs.sum().backward() - return attn_outputs, query_states.grad, key_states.grad, value_states.grad, attn_bias.grad + return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad def analyze_differences(original_result, cuda_result, accuracy_threshold=0.95): From 2951e2443b0e08a266c5a199106490c1af974465 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Thu, 13 Nov 2025 12:36:30 +0800 Subject: [PATCH 5/6] =?UTF-8?q?Adopts=20shared=20top=E2=80=91k=20masking;?= =?UTF-8?q?=20aligns=20bias=20shape?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces duplicated mask logic with a shared utility to standardize top‑k/causal masking and dtype handling across CUDA, Triton, and Flex backends. Aligns attention bias to a broadcastable per‑query shape to cut memory and simplify kernel expectations. Removes redundant KV/mask/bias repetition in the Triton path, repeats conditionally for Flex, and makes GQA fan‑out explicit for correctness and performance. 
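As a shape illustration of the broadcastable per-query bias described above (tensor sizes follow the benchmarks in this series; values and the scaling factor are arbitrary):

    import torch
    scores = torch.randn(2, 4, 128, 256)   # (batch, num_heads, query_len, key_len) attention scores
    bias   = torch.randn(2, 4, 1, 256)     # (batch, num_heads, 1, key_len) bias after repeat_kv
    out    = scores * 0.125 + bias         # broadcasts over the query axis; no per-query copy of the bias
    assert out.shape == (2, 4, 128, 256)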
--- benchmarks/backward_performance.py | 129 ++++++++++++----------------- 1 file changed, 55 insertions(+), 74 deletions(-) diff --git a/benchmarks/backward_performance.py b/benchmarks/backward_performance.py index 59daf16..68aebbc 100644 --- a/benchmarks/backward_performance.py +++ b/benchmarks/backward_performance.py @@ -26,6 +26,8 @@ import gc import sys +from flash_sparse_attn.utils.mask import create_mask + # Import the compiled CUDA extension try: from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func @@ -72,42 +74,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def prepare_mask( - hidden_states: torch.Tensor, - attn_bias: torch.Tensor, - causal_mask: torch.Tensor = None, - window_size: int = None, -): - """ - Args: - hidden_states: Input hidden states to determine dtype minimum value - attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length) - causal_mask: Optional causal mask to apply - window_size: Window size of tokens not masked - - Returns: - tuple: (attn_bias, attn_mask) - """ - dtype = hidden_states.dtype - min_dtype = torch.finfo(dtype).min - - if attn_bias.shape[-1] > window_size: - if causal_mask is not None: - topk_values, topk_indices = torch.topk( - attn_bias.masked_fill(~causal_mask, min_dtype).detach(), - window_size, dim=-1, largest=True, sorted=False - ) - else: - topk_values, topk_indices = torch.topk( - attn_bias, - window_size, dim=-1, largest=True, sorted=False - ) - attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype) - else: - attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) - return attn_bias, attn_mask - - def scaled_dot_product_attention_backward_cuda( query_states: torch.Tensor, key_states: torch.Tensor, @@ -133,15 +99,20 @@ def scaled_dot_product_attention_backward_cuda( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - _, num_heads, _, _ = query_states.shape - _, num_kv_heads, _, _ = key_states.shape + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + type="topk" ) # Repeat KV for multi-head attention (GQA support) @@ -210,11 +181,20 @@ def dynamic_mask_attention_backward_cuda( if flash_sparse_attn_func is None: return "Not Available", 0 - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + + num_queries_per_kv = num_heads // num_kv_heads + + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + type="topk" ) # Ensure correct data types and 
memory layout for CUDA function @@ -280,29 +260,27 @@ def dynamic_mask_attention_backward_triton( if triton_sparse_attn_func is None: return "Not Available", 0 - _, num_heads, _, _ = query_states.shape - _, num_kv_heads, _, _ = key_states.shape + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + block_size=64, + type="topk" ) - # Repeat KV for multi-head attention (GQA support) - key_states = repeat_kv(key_states, num_queries_per_kv) - value_states = repeat_kv(value_states, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) - attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - # Ensure correct data types and memory layout for Triton function query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] try: attn_outputs = triton_sparse_attn_func( @@ -359,29 +337,32 @@ def dynamic_mask_attention_backward_flex( if flex_sparse_attn_func is None: return "Not Available", 0 - _, num_heads, _, _ = query_states.shape - _, num_kv_heads, _, _ = key_states.shape + batch_size, num_heads, query_len, _ = query_states.shape + _, num_kv_heads, key_len, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads - attn_bias, attn_mask = prepare_mask( - query_states, - attn_bias, - causal_mask if is_causal else None, - window_size, + attn_mask = create_mask( + attention_bias=attn_bias, + attention_mask=causal_mask if is_causal else None, + batch_size=batch_size, + query_len=query_len, + key_len=key_len, + window_size=window_size, + min_dtype=torch.finfo(query_states.dtype).min, + type="topk" ) # Repeat KV for multi-head attention (GQA support) key_states = repeat_kv(key_states, num_queries_per_kv) value_states = repeat_kv(value_states, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None attn_bias = repeat_kv(attn_bias, num_queries_per_kv) # Ensure correct data types and memory layout for Flex function query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] try: attn_outputs = flex_sparse_attn_func( @@ -453,7 +434,7 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 device=device, dtype=torch.bfloat16, requires_grad=True ) attn_bias = torch.randn( - batch_size, num_kv_heads, query_len, key_len, + 
batch_size, num_kv_heads, 1, key_len, device=device, dtype=torch.bfloat16 ) cache_position = torch.arange(key_len - query_len, key_len, device=device) From ade4ef3bc26e50b11c116cedcce3112cffb698cf Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Thu, 13 Nov 2025 12:37:11 +0800 Subject: [PATCH 6/6] Broadcasts attention bias over query dimension Updates forward/backward equivalence benchmarks to create attention bias with a singleton query dimension so it broadcasts across queries. Aligns shapes with kernel expectations during cached decoding, reduces memory footprint, and prevents shape mismatches across CUDA, Triton, and Flex paths. --- benchmarks/backward_equivalence.py | 4 ++-- benchmarks/forward_equivalence.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index 049510b..e50a35a 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -597,7 +597,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): device=device, dtype=dtype, requires_grad=True ) attn_bias = torch.randn( - batch_size, num_kv_heads, query_len, key_len, + batch_size, num_kv_heads, 1, key_len, device=device, dtype=torch.bfloat16 ) cache_position = torch.arange(key_len - query_len, key_len, device=device) @@ -831,7 +831,7 @@ def test_triton_backward_equivalence(accuracy_threshold=0.95): device=device, dtype=dtype, requires_grad=True ) attn_bias = torch.randn( - batch_size, num_kv_heads, query_len, key_len, + batch_size, num_kv_heads, 1, key_len, device=device, dtype=torch.bfloat16 ) cache_position = torch.arange(key_len - query_len, key_len, device=device) diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index 8dfe5c8..aebb2de 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -570,7 +570,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): device=device, dtype=torch.bfloat16 ) attn_bias = torch.randn( - batch_size, num_kv_heads, query_len, key_len, + batch_size, num_kv_heads, 1, key_len, device=device, dtype=torch.bfloat16 ) cache_position = torch.arange(key_len - query_len, key_len, device=device) @@ -758,7 +758,7 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): device=device, dtype=torch.bfloat16 ) attn_bias = torch.randn( - batch_size, num_kv_heads, query_len, key_len, + batch_size, num_kv_heads, 1, key_len, device=device, dtype=torch.bfloat16 ) cache_position = torch.arange(key_len - query_len, key_len, device=device) @@ -963,7 +963,7 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95): device=device, dtype=torch.bfloat16 ) attn_bias = torch.randn( - batch_size, num_kv_heads, query_len, key_len, + batch_size, num_kv_heads, 1, key_len, device=device, dtype=torch.bfloat16 ) cache_position = torch.arange(key_len - query_len, key_len, device=device)
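Note on the shared utility: the prepare_mask helper removed throughout this series documents what create_mask(type="topk") is expected to produce at these call sites. The sketch below is a minimal approximation reconstructed from that removed code and from the new call sites; the function name and the None-return convention are assumptions, and the real flash_sparse_attn.utils.mask.create_mask additionally accepts block_size (e.g. 64) for the block-sparse CUDA/Triton paths, which is not modeled here.

    import torch

    def create_mask_topk_sketch(
        attention_bias: torch.Tensor,          # (batch, kv_heads, 1 or query_len, key_len)
        attention_mask: torch.Tensor = None,   # optional boolean causal mask, broadcastable to the bias
        batch_size: int = None,
        query_len: int = None,
        key_len: int = None,
        window_size: int = None,
        min_dtype: float = None,
    ):
        # When every key already fits in the window, only the causal mask (if any)
        # constrains attention; the new call sites treat a None mask as "no masking".
        if key_len <= window_size:
            if attention_mask is None:
                return None
            return attention_mask.expand(batch_size, attention_bias.shape[1], query_len, key_len)

        # Otherwise keep the window_size largest bias entries per query, pushing
        # causally masked positions to the dtype minimum before the top-k.
        bias = attention_bias
        if attention_mask is not None:
            bias = bias.masked_fill(~attention_mask, min_dtype)
        topk_values, topk_idx = torch.topk(
            bias.detach(), window_size, dim=-1, largest=True, sorted=False
        )
        return torch.zeros_like(bias, dtype=torch.bool).scatter_(
            -1, topk_idx, topk_values != min_dtype
        )

Usage mirrors the call sites added in these patches: attn_mask = create_mask_topk_sketch(attention_bias=attn_bias, attention_mask=causal_mask if is_causal else None, batch_size=batch_size, query_len=query_len, key_len=key_len, window_size=window_size, min_dtype=torch.finfo(query_states.dtype).min), followed by repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None.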