From 92b84713bfc79d812d26ebdb6246720711ae5955 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Mon, 30 Jun 2025 15:19:51 +0800
Subject: [PATCH 1/2] Fixes attention benchmarking and expands test coverage

Corrects the OOM condition logic to require that both the query and
key lengths exceed the threshold, instead of either one.

Captures the attention mask returned by the dynamic mask preparation
function and passes it to the CUDA kernel instead of incorrectly
passing the active mask.

Replaces hardcoded boolean literals with the is_causal variable for
better code maintainability.

Expands the benchmark configurations with larger sequence lengths,
additional head dimensions, and higher-dimensional embeddings for
more comprehensive performance testing coverage.
---
 benchmarks/benchmark_forward_performance.py | 43 +++++++++++++--------
 1 file changed, 26 insertions(+), 17 deletions(-)

diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py
index f93de7f..edb9e65 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/benchmark_forward_performance.py
@@ -131,7 +131,7 @@ def flash_attention_cuda(
     """
     _, _, query_len, _ = query_states.shape
     _, _, key_len, _ = key_states.shape
-    if query_len > 32768 or key_len > 32768:
+    if query_len > 32768 and key_len > 32768:
         return "OOM"
 
     query_states = query_states.contiguous()
@@ -186,7 +186,7 @@ def dynamic_mask_attention_cuda(
     # Calculate zero_hold_states
     zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)
 
-    _, active_mask = prepare_dynamic_mask(
+    attn_mask, active_mask = prepare_dynamic_mask(
         query_states,
         zero_hold_states,
         keep_window_size,
@@ -201,6 +201,7 @@ def dynamic_mask_attention_cuda(
     zero_hold_states = zero_hold_states[:, :, None, :].expand(
         -1, -1, query_states.shape[1], -1
     ).contiguous()  # [batch, num_kv_heads, query_len, key_len]
+    attn_mask = attn_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]
     active_mask = active_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]
 
     try:
@@ -211,7 +212,7 @@ def dynamic_mask_attention_cuda(
             key_states,        # k: [batch, seqlen_k, num_kv_heads, head_dim]
             value_states,      # v: [batch, seqlen_k, num_kv_heads, head_dim]
             zero_hold_states,  # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
-            active_mask,       # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
+            attn_mask,         # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
             out_tensor,        # out: None to auto-allocate
             0.0,               # p_dropout
             scaling,           # softmax_scale
@@ -374,6 +375,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
     # Set scaling factor and keep window size
     scaling = head_dim ** -0.5
     keep_window_size = 2048
+    is_causal = True
 
     results = {
         'config': config,
@@ -396,7 +398,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
     for _ in range(warmup_runs):
         result = flash_attention_cuda(
             query_states, key_states, value_states,
-            scaling, causal_mask, True
+            scaling, causal_mask, is_causal
         )
         if result == "OOM":
             results['flash_attention_status'] = 'OOM'
@@ -412,7 +414,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
         start_time = time.time()
         result = flash_attention_cuda(
             query_states, key_states, value_states,
-            scaling, causal_mask, True
+            scaling, causal_mask, is_causal
        )
         torch.cuda.synchronize()
         end_time = time.time()
@@ -436,7 +438,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
         result = dynamic_mask_attention_cuda(
             query_states, key_states, value_states,
             dt_proj, A, scaling, causal_mask,
-            keep_window_size, True
+            keep_window_size, is_causal
         )
         if result == "OOM":
             results['dynamic_mask_attention_status'] = 'OOM'
@@ -453,7 +455,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
         result = dynamic_mask_attention_cuda(
             query_states, key_states, value_states,
             dt_proj, A, scaling, causal_mask,
-            keep_window_size, True
+            keep_window_size, is_causal
         )
         torch.cuda.synchronize()
         end_time = time.time()
@@ -477,7 +479,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
         result = dynamic_mask_attention_cuda_no_topk(
             query_states, key_states, value_states,
             dt_proj, A, scaling, causal_mask,
-            keep_window_size, True
+            keep_window_size, is_causal
         )
         if result == "OOM":
             results['dynamic_mask_attention_no_topk_status'] = 'OOM'
@@ -494,7 +496,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
         result = dynamic_mask_attention_cuda_no_topk(
             query_states, key_states, value_states,
             dt_proj, A, scaling, causal_mask,
-            keep_window_size, True
+            keep_window_size, is_causal
         )
         torch.cuda.synchronize()
         end_time = time.time()
@@ -531,14 +533,18 @@ def run_performance_benchmark():
         (1, 2, 1, 32768, 32768, 32),
 
         # Inference
-        (1, 2, 1, 64, 256, 32),
-        (1, 2, 1, 64, 512, 32),
-        (1, 2, 1, 64, 1024, 32),
-        (1, 2, 1, 64, 2048, 32),
-        (1, 2, 1, 64, 4096, 32),
-        (1, 2, 1, 64, 8192, 32),
-        (1, 2, 1, 64, 16384, 32),
-        (1, 2, 1, 64, 32768, 32),
+        (1, 2, 1, 64, 256, 128),
+        (1, 2, 1, 64, 512, 128),
+        (1, 2, 1, 64, 1024, 128),
+        (1, 2, 1, 64, 2048, 128),
+        (1, 2, 1, 64, 4096, 128),
+        (1, 2, 1, 64, 8192, 128),
+        (1, 2, 1, 64, 16384, 128),
+        (1, 2, 1, 64, 32768, 128),
+        (1, 2, 1, 64, 65536, 128),
+        (1, 2, 1, 64, 131072, 128),
+        (1, 2, 1, 64, 262144, 128),
+        (1, 2, 1, 64, 524288, 128),
 
         # Vary batch size
         (1, 2, 1, 1024, 1024, 32),
@@ -555,7 +561,10 @@ def run_performance_benchmark():
         # Vary head dimension
         (1, 2, 1, 1024, 1024, 32),
         (1, 2, 1, 1024, 1024, 64),
+        (1, 2, 1, 1024, 1024, 96),
         (1, 2, 1, 1024, 1024, 128),
+        (1, 2, 1, 1024, 1024, 192),
+        (1, 2, 1, 1024, 1024, 256),
     ]
 
     num_runs = 3  # Run 3 times and take average

From 4d42a848abe9a1394f95c22abb00a6fb26bfb9b5 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Mon, 30 Jun 2025 17:34:57 +0800
Subject: [PATCH 2/2] Improves benchmark memory management and test coverage

Adds comprehensive memory cleanup with garbage collection and CUDA
cache clearing between test configurations to prevent memory issues
during extended benchmarking.

Expands the test configurations to cover more diverse scenarios, with
sequence lengths ranging from 4 to 4096 tokens and different head
dimensions.

Fixes inconsistent attention mask handling by ensuring that both the
Python and CUDA implementations use the processed attention mask from
prepare_dynamic_mask.

Adds proper CUDA synchronization around timing measurements to ensure
accurate performance comparisons.
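
As an illustration only, the cleanup performed between test
configurations amounts to the pattern sketched below;
reset_between_configs is a placeholder name, not a helper defined in
the benchmark script:

    import gc

    import torch

    def reset_between_configs():
        # Mirror the per-configuration hygiene this commit adds:
        # release cached CUDA allocator blocks, collect dangling Python
        # references, and wait for outstanding GPU work to finish
        # before the next configuration runs.
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.synchronize()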
---
 benchmarks/benchmark_forward_equivalence.py | 56 +++++++++++++++------
 1 file changed, 42 insertions(+), 14 deletions(-)

diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py
index b7af351..bc10bec 100644
--- a/benchmarks/benchmark_forward_equivalence.py
+++ b/benchmarks/benchmark_forward_equivalence.py
@@ -16,6 +16,7 @@
 import torch.nn.functional as F
 import argparse
 import time
+import gc
 
 # Import the compiled CUDA extension
 try:
@@ -155,7 +156,7 @@ def dynamic_mask_attention_python(
     zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)
 
     # Use prepare_dynamic_mask function to process dynamic mask
-    attn_mask, _ = prepare_dynamic_mask(
+    attn_mask, active_mask = prepare_dynamic_mask(
         query_states,
         zero_hold_states,
         keep_window_size,
@@ -211,7 +212,7 @@ def dynamic_mask_attention_cuda(
     zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)
 
     # Use prepare_dynamic_mask to get the processed attention mask
-    _, active_mask = prepare_dynamic_mask(
+    attn_mask, active_mask = prepare_dynamic_mask(
         query_states,
         zero_hold_states,
         keep_window_size,
@@ -226,6 +227,7 @@ def dynamic_mask_attention_cuda(
     zero_hold_states = zero_hold_states[:, :, None, :].expand(
         -1, -1, query_states.shape[1], -1
     ).contiguous()  # [batch, num_kv_heads, query_len, key_len]
+    attn_mask = attn_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]
     active_mask = active_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]
 
     # Call the CUDA implementation using the mha_fwd function signature
@@ -234,7 +236,7 @@ def dynamic_mask_attention_cuda(
         query_states,      # q: [batch, seqlen_q, num_heads, head_dim]
         key_states,        # k: [batch, seqlen_k, num_kv_heads, head_dim]
         value_states,      # v: [batch, seqlen_k, num_kv_heads, head_dim]
-        zero_hold_states,  # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
+        attn_mask,         # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
         active_mask,       # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
         out_tensor,        # out: None to auto-allocate
         0.0,               # p_dropout
@@ -352,16 +354,30 @@ def test_forward_equivalence(accuracy_threshold=0.95):
     torch.manual_seed(0)
 
     # Test different parameter configurations
+    # If you encounter NAN issues when running multiple configurations, try running a single configuration
     test_configs = [
         # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
-        (1, 1, 1, 64, 64, 32, True),     # Small scale test, causal mask
-        (1, 1, 1, 64, 64, 32, False),    # Small scale test, non-causal mask
-        (1, 1, 1, 128, 128, 32, True),   # Medium scale test, causal mask
-        (1, 1, 1, 128, 128, 32, False),  # Medium scale test, non-causal mask
-        (1, 1, 1, 256, 256, 32, True),   # Large scale test, causal mask
-        (1, 2, 1, 64, 64, 32, True),     # Medium scale test, GQA mode
-        (2, 1, 1, 128, 128, 32, True),   # Medium scale test, Multi batch
-        (2, 2, 1, 128, 128, 32, True),   # Medium scale test, Multi batch GQA mode
+        (1, 1, 1, 4, 64, 32, True),
+        (1, 1, 1, 4, 64, 32, False),
+        (1, 1, 1, 128, 128, 32, True),
+        (1, 1, 1, 128, 128, 32, False),
+        (1, 1, 1, 256, 256, 32, True),
+        (1, 1, 1, 256, 256, 32, False),
+        (1, 1, 1, 512, 512, 32, True),
+        (1, 1, 1, 512, 512, 32, False),
+        (1, 1, 1, 1024, 1024, 32, True),
+        (1, 1, 1, 1024, 1024, 32, False),
+        (1, 1, 1, 2048, 2048, 32, True),
+        (1, 1, 1, 2048, 2048, 32, False),
+        (1, 1, 1, 4096, 4096, 32, True),
+        (1, 1, 1, 4096, 4096, 32, False),
+        (1, 2, 1, 64, 64, 32, True),
+        (2, 1, 1, 128, 128, 32, True),
+        (2, 2, 1, 128, 128, 32, True),
+        (1, 2, 1, 64, 64, 128, True),
+        (1, 2, 1, 128, 128, 128, True),
+        (1, 2, 1, 256, 256, 128, True),
+        (1, 2, 1, 2, 256, 128, True),
     ]
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -371,6 +387,10 @@ def test_forward_equivalence(accuracy_threshold=0.95):
     all_passed = True
 
     for i, config in enumerate(test_configs):
+        torch.cuda.empty_cache()
+        gc.collect()
+        torch.cuda.synchronize()
+
         batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal = config
 
         # Progress indicator
@@ -425,7 +445,7 @@ def test_forward_equivalence(accuracy_threshold=0.95):
             dt_proj, A, scaling, causal_mask,
             keep_window_size, is_causal
         )
-        py_output_copy = py_output.clone()
+        torch.cuda.synchronize()
         py_time = time.time() - start_time
 
         # Run CUDA implementation
@@ -435,10 +455,14 @@ def test_forward_equivalence(accuracy_threshold=0.95):
             dt_proj, A, scaling, causal_mask,
             keep_window_size, is_causal
         )
+        torch.cuda.synchronize()
         cuda_time = time.time() - start_time
-
+
+
         # Analyze differences
-        is_close, max_diff, mean_diff = analyze_differences(py_output_copy, cuda_output, accuracy_threshold)
+        py_output_copy = py_output.clone()
+        cuda_output_copy = cuda_output.clone()
+        is_close, max_diff, mean_diff = analyze_differences(py_output_copy, cuda_output_copy, accuracy_threshold)
 
         # Report performance difference
         speedup = py_time / cuda_time if cuda_time > 0 else float('inf')
@@ -457,6 +481,10 @@ def test_forward_equivalence(accuracy_threshold=0.95):
         if not is_close and max_diff > 1e-2:
             print(" ⚠️ Difference too large, stopping subsequent tests.")
             break
+        del query_states, key_states, value_states, dt_proj, A, causal_mask, py_output, cuda_output, py_output_copy, cuda_output_copy
+        torch.cuda.empty_cache()
+        gc.collect()
+        torch.cuda.synchronize()
 
     print("\n" + "🏁" + "=" * 76 + "🏁")
     summary_icon = "🎉" if all_passed else "😞"
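
For orientation, the measurement pattern the benchmark scripts converge
on after these two patches looks roughly like the sketch below.
benchmark_config and attention_fn are illustrative placeholders rather
than names from the scripts, and the "OOM" string mirrors the sentinel
returned by the guarded wrappers in the diffs:

    import time

    import torch

    def benchmark_config(attention_fn, inputs, num_runs=3, warmup_runs=2):
        # Warmup runs populate kernel caches and surface OOM cases early.
        for _ in range(warmup_runs):
            if attention_fn(*inputs) == "OOM":
                return None  # skip configurations that exceed device memory
        torch.cuda.synchronize()

        # Timed runs: synchronize before and after each call so the measured
        # interval covers the kernels launched by attention_fn.
        timings = []
        for _ in range(num_runs):
            torch.cuda.synchronize()
            start = time.time()
            if attention_fn(*inputs) == "OOM":
                return None
            torch.cuda.synchronize()
            timings.append(time.time() - start)
        return sum(timings) / len(timings)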