diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py
index b7af351..bc10bec 100644
--- a/benchmarks/benchmark_forward_equivalence.py
+++ b/benchmarks/benchmark_forward_equivalence.py
@@ -16,6 +16,7 @@
 import torch.nn.functional as F
 import argparse
 import time
+import gc
 
 # Import the compiled CUDA extension
 try:
@@ -155,7 +156,7 @@ def dynamic_mask_attention_python(
     zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)
 
     # Use prepare_dynamic_mask function to process dynamic mask
-    attn_mask, _ = prepare_dynamic_mask(
+    attn_mask, active_mask = prepare_dynamic_mask(
         query_states,
         zero_hold_states,
         keep_window_size,
@@ -211,7 +212,7 @@ def dynamic_mask_attention_cuda(
     zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)
 
     # Use prepare_dynamic_mask to get the processed attention mask
-    _, active_mask = prepare_dynamic_mask(
+    attn_mask, active_mask = prepare_dynamic_mask(
         query_states,
         zero_hold_states,
         keep_window_size,
@@ -226,6 +227,7 @@ def dynamic_mask_attention_cuda(
     zero_hold_states = zero_hold_states[:, :, None, :].expand(
         -1, -1, query_states.shape[1], -1
     ).contiguous()  # [batch, num_kv_heads, query_len, key_len]
+    attn_mask = attn_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]
     active_mask = active_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]
 
     # Call the CUDA implementation using the mha_fwd function signature
@@ -234,7 +236,7 @@ def dynamic_mask_attention_cuda(
         query_states,      # q: [batch, seqlen_q, num_heads, head_dim]
         key_states,        # k: [batch, seqlen_k, num_kv_heads, head_dim]
         value_states,      # v: [batch, seqlen_k, num_kv_heads, head_dim]
-        zero_hold_states,  # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
+        attn_mask,         # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
         active_mask,       # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
         out_tensor,        # out: None to auto-allocate
         0.0,               # p_dropout
@@ -352,16 +354,30 @@ def test_forward_equivalence(accuracy_threshold=0.95):
     torch.manual_seed(0)
 
     # Test different parameter configurations
+    # If you encounter NaN issues when running multiple configurations, try running a single configuration
     test_configs = [
         # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
-        (1, 1, 1, 64, 64, 32, True),     # Small scale test, causal mask
-        (1, 1, 1, 64, 64, 32, False),    # Small scale test, non-causal mask
-        (1, 1, 1, 128, 128, 32, True),   # Medium scale test, causal mask
-        (1, 1, 1, 128, 128, 32, False),  # Medium scale test, non-causal mask
-        (1, 1, 1, 256, 256, 32, True),   # Large scale test, causal mask
-        (1, 2, 1, 64, 64, 32, True),     # Medium scale test, GQA mode
-        (2, 1, 1, 128, 128, 32, True),   # Medium scale test, Multi batch
-        (2, 2, 1, 128, 128, 32, True),   # Medium scale test, Multi batch GQA mode
+        (1, 1, 1, 4, 64, 32, True),
+        (1, 1, 1, 4, 64, 32, False),
+        (1, 1, 1, 128, 128, 32, True),
+        (1, 1, 1, 128, 128, 32, False),
+        (1, 1, 1, 256, 256, 32, True),
+        (1, 1, 1, 256, 256, 32, False),
+        (1, 1, 1, 512, 512, 32, True),
+        (1, 1, 1, 512, 512, 32, False),
+        (1, 1, 1, 1024, 1024, 32, True),
+        (1, 1, 1, 1024, 1024, 32, False),
+        (1, 1, 1, 2048, 2048, 32, True),
+        (1, 1, 1, 2048, 2048, 32, False),
+        (1, 1, 1, 4096, 4096, 32, True),
+        (1, 1, 1, 4096, 4096, 32, False),
+        (1, 2, 1, 64, 64, 32, True),
+        (2, 1, 1, 128, 128, 32, True),
+        (2, 2, 1, 128, 128, 32, True),
+        (1, 2, 1, 64, 64, 128, True),
+        (1, 2, 1, 128, 128, 128, True),
+        (1, 2, 1, 256, 256, 128, True),
+        (1, 2, 1, 2, 256, 128, True),
     ]
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -371,6 +387,10 @@ def test_forward_equivalence(accuracy_threshold=0.95):
     all_passed = True
 
     for i, config in enumerate(test_configs):
+        torch.cuda.empty_cache()
+        gc.collect()
+        torch.cuda.synchronize()
+
         batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal = config
 
         # Progress indicator
@@ -425,7 +445,7 @@ def test_forward_equivalence(accuracy_threshold=0.95):
             dt_proj, A, scaling, causal_mask,
             keep_window_size, is_causal
         )
-        py_output_copy = py_output.clone()
+        torch.cuda.synchronize()
         py_time = time.time() - start_time
 
         # Run CUDA implementation
@@ -435,10 +455,14 @@ def test_forward_equivalence(accuracy_threshold=0.95):
             dt_proj, A, scaling, causal_mask,
             keep_window_size, is_causal
         )
+        torch.cuda.synchronize()
         cuda_time = time.time() - start_time
-        
+
+
         # Analyze differences
-        is_close, max_diff, mean_diff = analyze_differences(py_output_copy, cuda_output, accuracy_threshold)
+        py_output_copy = py_output.clone()
+        cuda_output_copy = cuda_output.clone()
+        is_close, max_diff, mean_diff = analyze_differences(py_output_copy, cuda_output_copy, accuracy_threshold)
 
         # Report performance difference
         speedup = py_time / cuda_time if cuda_time > 0 else float('inf')
@@ -457,6 +481,10 @@ def test_forward_equivalence(accuracy_threshold=0.95):
         if not is_close and max_diff > 1e-2:
             print("  ⚠️ Difference too large, stopping subsequent tests.")
             break
+        del query_states, key_states, value_states, dt_proj, A, causal_mask, py_output, cuda_output, py_output_copy, cuda_output_copy
+        torch.cuda.empty_cache()
+        gc.collect()
+        torch.cuda.synchronize()
 
     print("\n" + "🏁" + "=" * 76 + "🏁")
     summary_icon = "🎉" if all_passed else "😞"
diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py
index f93de7f..edb9e65 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/benchmark_forward_performance.py
@@ -131,7 +131,7 @@ def flash_attention_cuda(
     """
     _, _, query_len, _ = query_states.shape
     _, _, key_len, _ = key_states.shape
-    if query_len > 32768 or key_len > 32768:
+    if query_len > 32768 and key_len > 32768:
         return "OOM"
 
     query_states = query_states.contiguous()
@@ -186,7 +186,7 @@ def dynamic_mask_attention_cuda(
     # Calculate zero_hold_states
     zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)
 
-    _, active_mask = prepare_dynamic_mask(
+    attn_mask, active_mask = prepare_dynamic_mask(
         query_states,
         zero_hold_states,
         keep_window_size,
@@ -201,6 +201,7 @@ def dynamic_mask_attention_cuda(
     zero_hold_states = zero_hold_states[:, :, None, :].expand(
         -1, -1, query_states.shape[1], -1
     ).contiguous()  # [batch, num_kv_heads, query_len, key_len]
+    attn_mask = attn_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]
     active_mask = active_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]
 
     try:
@@ -211,7 +212,7 @@ def dynamic_mask_attention_cuda(
             key_states,        # k: [batch, seqlen_k, num_kv_heads, head_dim]
             value_states,      # v: [batch, seqlen_k, num_kv_heads, head_dim]
-            zero_hold_states,  # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
+            attn_mask,         # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
             active_mask,       # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
             out_tensor,        # out: None to auto-allocate
             0.0,               # p_dropout
             scaling,           # softmax_scale
@@ -374,6 +375,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
     # Set scaling factor and keep window size
     scaling = head_dim ** -0.5
     keep_window_size = 2048
+    is_causal = True
 
     results = {
         'config': config,
@@ -396,7 +398,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
         for _ in range(warmup_runs):
             result = flash_attention_cuda(
                 query_states, key_states, value_states,
-                scaling, causal_mask, True
+                scaling, causal_mask, is_causal
             )
             if result == "OOM":
                 results['flash_attention_status'] = 'OOM'
@@ -412,7 +414,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
             start_time = time.time()
             result = flash_attention_cuda(
                 query_states, key_states, value_states,
-                scaling, causal_mask, True
+                scaling, causal_mask, is_causal
             )
             torch.cuda.synchronize()
             end_time = time.time()
@@ -436,7 +438,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
             result = dynamic_mask_attention_cuda(
                 query_states, key_states, value_states,
                 dt_proj, A, scaling, causal_mask,
-                keep_window_size, True
+                keep_window_size, is_causal
             )
             if result == "OOM":
                 results['dynamic_mask_attention_status'] = 'OOM'
@@ -453,7 +455,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
             result = dynamic_mask_attention_cuda(
                 query_states, key_states, value_states,
                 dt_proj, A, scaling, causal_mask,
-                keep_window_size, True
+                keep_window_size, is_causal
             )
             torch.cuda.synchronize()
             end_time = time.time()
@@ -477,7 +479,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
             result = dynamic_mask_attention_cuda_no_topk(
                 query_states, key_states, value_states,
                 dt_proj, A, scaling, causal_mask,
-                keep_window_size, True
+                keep_window_size, is_causal
             )
             if result == "OOM":
                 results['dynamic_mask_attention_no_topk_status'] = 'OOM'
@@ -494,7 +496,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
             result = dynamic_mask_attention_cuda_no_topk(
                 query_states, key_states, value_states,
                 dt_proj, A, scaling, causal_mask,
-                keep_window_size, True
+                keep_window_size, is_causal
             )
             torch.cuda.synchronize()
             end_time = time.time()
@@ -531,14 +533,18 @@ def run_performance_benchmark():
         (1, 2, 1, 32768, 32768, 32),
 
         # Inference
-        (1, 2, 1, 64, 256, 32),
-        (1, 2, 1, 64, 512, 32),
-        (1, 2, 1, 64, 1024, 32),
-        (1, 2, 1, 64, 2048, 32),
-        (1, 2, 1, 64, 4096, 32),
-        (1, 2, 1, 64, 8192, 32),
-        (1, 2, 1, 64, 16384, 32),
-        (1, 2, 1, 64, 32768, 32),
+        (1, 2, 1, 64, 256, 128),
+        (1, 2, 1, 64, 512, 128),
+        (1, 2, 1, 64, 1024, 128),
+        (1, 2, 1, 64, 2048, 128),
+        (1, 2, 1, 64, 4096, 128),
+        (1, 2, 1, 64, 8192, 128),
+        (1, 2, 1, 64, 16384, 128),
+        (1, 2, 1, 64, 32768, 128),
+        (1, 2, 1, 64, 65536, 128),
+        (1, 2, 1, 64, 131072, 128),
+        (1, 2, 1, 64, 262144, 128),
+        (1, 2, 1, 64, 524288, 128),
 
         # Vary batch size
         (1, 2, 1, 1024, 1024, 32),
@@ -555,7 +561,10 @@ def run_performance_benchmark():
         # Vary head dimension
         (1, 2, 1, 1024, 1024, 32),
         (1, 2, 1, 1024, 1024, 64),
+        (1, 2, 1, 1024, 1024, 96),
         (1, 2, 1, 1024, 1024, 128),
+        (1, 2, 1, 1024, 1024, 192),
+        (1, 2, 1, 1024, 1024, 256),
     ]
 
     num_runs = 3  # Run 3 times and take average
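
A few notes on the patterns this diff relies on.

First, the timing fix in benchmark_forward_equivalence.py. CUDA kernels launch asynchronously, so reading time.time() right after the Python-side call mostly measures launch overhead; the old code even queued py_output.clone() inside the timed region. The added torch.cuda.synchronize() calls make both timers measure actual kernel completion. A minimal sketch of the pattern (the time_cuda helper is illustrative, not from the source):

    import time
    import torch

    def time_cuda(fn, *args, **kwargs):
        """Wall-clock a CUDA operation correctly."""
        torch.cuda.synchronize()   # drain previously queued work
        start = time.time()
        out = fn(*args, **kwargs)
        torch.cuda.synchronize()   # wait for fn's kernels to finish
        return out, time.time() - start

torch.cuda.Event(enable_timing=True) would avoid the host clock entirely, but synchronize-then-time.time() is the idiom both benchmark scripts already use.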
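Second, the per-configuration cleanup in the equivalence test. With shapes now reaching 4096x4096, tensors left over from one configuration can push the next one out of memory, so the loop drops its references with del and flushes the allocator before and after each run. The relaxed guard in flash_attention_cuda (`or` changed to `and`) serves the same scaling goal: the new inference configs keep query_len at 64 while key_len grows to 524288, so only shapes where both dimensions exceed 32768 are skipped. A sketch of the cleanup as a helper (reclaim_cuda_memory is an illustrative name):

    import gc
    import torch

    def reclaim_cuda_memory():
        """Give the next benchmark configuration a clean slate."""
        gc.collect()              # destroy unreachable tensors first
        torch.cuda.empty_cache()  # return freed blocks held by the caching allocator
        torch.cuda.synchronize()  # ensure queued work and frees have completed

Ordering matters slightly: empty_cache() can only release blocks whose tensors are already dead, so running gc.collect() first (the loop header uses the opposite order) tends to return more memory to the driver.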
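Finally, the mask plumbing that both files now share. prepare_dynamic_mask returns the processed attention mask alongside the active mask, and it is the processed mask, not the raw zero_hold_states, that belongs in the kernel's zoh argument; the performance script's mha_fwd call is aligned with the equivalence script on this point. Both masks must arrive as contiguous [batch, num_kv_heads, query_len, key_len] tensors. A hedged sketch of the broadcast-and-contiguous step, assuming only the shapes stated in the diff's comments (the helper name is not from the source):

    import torch

    def broadcast_over_queries(mask: torch.Tensor, query_len: int) -> torch.Tensor:
        """[batch, num_kv_heads, key_len] -> [batch, num_kv_heads, query_len, key_len]."""
        # expand() is a zero-copy stride trick (stride 0 on the query dim), so the
        # view is non-contiguous for query_len > 1; the C++ extension indexes with
        # flat strides, hence the explicit contiguous() copy before the call.
        return mask[:, :, None, :].expand(-1, -1, query_len, -1).contiguous()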