From 43fba8d656347ee499988f48e030abc96085216d Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Thu, 10 Jul 2025 13:41:02 +0800 Subject: [PATCH 1/2] Removes no-topk CUDA implementation from benchmarks Eliminates the dynamic mask attention CUDA implementation without topk computation to simplify benchmark comparisons and reduce code complexity. Updates test configurations to use head dimension of 128 instead of 32 for more realistic performance testing scenarios. Adjusts benchmark output formatting to accommodate the reduced number of implementations being compared. --- benchmarks/benchmark_forward_performance.py | 175 ++------------------ 1 file changed, 13 insertions(+), 162 deletions(-) diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py index d9ed6f5..997f7fb 100644 --- a/benchmarks/benchmark_forward_performance.py +++ b/benchmarks/benchmark_forward_performance.py @@ -8,7 +8,6 @@ Implementations tested: - Flash Attention (PyTorch SDPA Flash Attention backend) - Baseline - Dynamic Mask Attention CUDA - Custom CUDA kernel implementation -- Dynamic Mask Attention CUDA (No TopK) - CUDA kernel without TopK computation - Dynamic Mask Attention Triton - Triton kernel implementation - Dynamic Mask Attention Flex - Flex Attention implementation @@ -293,99 +292,6 @@ def dynamic_mask_attention_cuda( return "OOM", 0 -def dynamic_mask_attention_cuda_no_topk( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, - scaling: float, - causal_mask: torch.Tensor, - keep_window_size=2048, - is_causal=True, - return_softmax=False -): - """ - CUDA implementation of dynamic mask attention without topk computation. - This version skips the topk calculation for more accurate kernel performance testing. - - Args: - query_states: [batch_size, num_heads, query_len, head_dim] - key_states: [batch_size, num_kv_heads, key_len, head_dim] - value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] - scaling: Attention scaling factor - causal_mask: Causal attention mask - keep_window_size: Number of tokens to keep in attention window - is_causal: Whether to apply causal masking - return_softmax: Whether to return softmax weights - - Returns: - attn_outputs: [batch_size, query_len, num_heads, head_dim] - """ - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Create a simplified mask without topk computation - batch_size, _, query_len, _ = query_states.shape - _, num_kv_heads, key_len, _ = key_states.shape - dtype = query_states.dtype - device = query_states.device - - # Create full attn mask (no topk selection) - attn_mask = torch.zeros( - (batch_size, num_kv_heads, query_len, key_len), - dtype=dtype, - device=device - ) - - # Ensure correct data types and memory layout - query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] - key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] - value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] - attn_bias = zoh_states[:, :, None, :].expand( - -1, -1, query_states.shape[1], -1 - ).contiguous() # [batch, num_kv_heads, query_len, key_len] - # Create full active mask (no topk selection) - attn_mask = torch.zeros_like( - attn_bias, - dtype=dtype, - device=device - ).contiguous() # [batch, num_kv_heads, query_len, key_len] - - try: - out_tensor = None # Let the function allocate the output tensor - - # Only measure the core CUDA kernel computation - torch.cuda.synchronize() - start_time = time.time() - - result = flash_dmattn_cuda.fwd( # type: ignore - query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] - value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] - attn_mask, # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] - attn_bias, # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k] - out_tensor, # out: None to auto-allocate - 0.0, # p_dropout - scaling, # softmax_scale - is_causal, # is_causal - keep_window_size, # keep_window_size - 0.0, # softcap - return_softmax, # return_softmax - None # gen (generator) - ) - - torch.cuda.synchronize() - end_time = time.time() - - attn_outputs = result[0] - return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms - except torch.cuda.OutOfMemoryError: - return "OOM", 0 - - def dynamic_mask_attention_triton( query_states: torch.Tensor, key_states: torch.Tensor, @@ -614,17 +520,14 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ 'config': config, 'flash_attention_times': [], 'dynamic_mask_attention_times': [], - 'dynamic_mask_attention_no_topk_times': [], 'dynamic_mask_attention_triton_times': [], 'dynamic_mask_attention_flex_times': [], 'flash_attention_memory': 0, 'dynamic_mask_attention_memory': 0, - 'dynamic_mask_attention_no_topk_memory': 0, 'dynamic_mask_attention_triton_memory': 0, 'dynamic_mask_attention_flex_memory': 0, 'flash_attention_status': 'success', 'dynamic_mask_attention_status': 'success', - 'dynamic_mask_attention_no_topk_status': 'success', 'dynamic_mask_attention_triton_status': 'success', 'dynamic_mask_attention_flex_status': 'success' } @@ -632,7 +535,6 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ # Determine which implementations to run run_flash = test_type in ['all', 'flash', 'flash-vs-cuda', 'flash-vs-triton', 'flash-vs-flex'] run_cuda = test_type in ['all', 'cuda', 'flash-vs-cuda'] - run_no_topk = test_type in ['all', 'cuda'] run_triton = test_type in ['all', 'triton', 'flash-vs-triton'] run_flex = test_type in ['all', 'flex', 'flash-vs-flex'] @@ -718,48 +620,6 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ else: results['dynamic_mask_attention_status'] = 'N/A' - # Benchmark Dynamic Mask Attention (No TopK) - if run_no_topk: - gc.collect() - torch.cuda.empty_cache() - - # Warmup runs - for _ in range(warmup_runs): - result = dynamic_mask_attention_cuda_no_topk( - query_states, key_states, value_states, - dt_proj, A, scaling, causal_mask, - keep_window_size, is_causal - ) - if result[0] == "OOM": - results['dynamic_mask_attention_no_topk_status'] = 'OOM' - break - torch.cuda.synchronize() - - if results['dynamic_mask_attention_no_topk_status'] == 'success': - # Measure memory before benchmark - mem_before = measure_memory_usage() - - # Actual benchmark runs - for _ in range(num_runs): - result = dynamic_mask_attention_cuda_no_topk( - query_states, key_states, value_states, - dt_proj, A, scaling, causal_mask, - keep_window_size, is_causal - ) - - if result[0] == "OOM": - results['dynamic_mask_attention_no_topk_status'] = 'OOM' - break - - # Use the timing from the function instead of measuring here - results['dynamic_mask_attention_no_topk_times'].append(result[1]) # ms - - # Measure memory after - mem_after = measure_memory_usage() - results['dynamic_mask_attention_no_topk_memory'] = mem_after[0] - mem_before[0] - else: - results['dynamic_mask_attention_no_topk_status'] = 'N/A' - # Benchmark Dynamic Mask Attention (Triton) if run_triton: gc.collect() @@ -877,14 +737,14 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) configs = [ # Vary sequence length - (1, 2, 1, 256, 256, 32, 2048, True), - (1, 2, 1, 512, 512, 32, 2048, True), - (1, 2, 1, 1024, 1024, 32, 2048, True), - (1, 2, 1, 2048, 2048, 32, 2048, True), - (1, 2, 1, 4096, 4096, 32, 2048, True), - (1, 2, 1, 8192, 8192, 32, 2048, True), - (1, 2, 1, 16384, 16384, 32, 2048, True), - (1, 2, 1, 32768, 32768, 32, 2048, True), + (1, 2, 1, 256, 256, 128, 2048, True), + (1, 2, 1, 512, 512, 128, 2048, True), + (1, 2, 1, 1024, 1024, 128, 2048, True), + (1, 2, 1, 2048, 2048, 128, 2048, True), + (1, 2, 1, 4096, 4096, 128, 2048, True), + (1, 2, 1, 8192, 8192, 128, 2048, True), + (1, 2, 1, 16384, 16384, 128, 2048, True), + (1, 2, 1, 32768, 32768, 128, 2048, True), # Inference (1, 2, 1, 2, 256, 128, 2048, True), @@ -940,8 +800,8 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): num_runs = 3 # Run 3 times and take average print(f"\nšŸ“Š Benchmark Results (averaged over {num_runs} runs):") - print(f"šŸ”§ {'Configuration':<60} ⚔ {'Flash':<10} šŸš€ {'CUDA':<10} šŸš€ {'No-TopK':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} šŸ“ˆ {'Speedup':<15}") - print("šŸ”„" + "-" * 160 + "šŸ”„") + print(f"šŸ”§ {'Configuration':<60} ⚔ {'Flash':<10} šŸš€ {'CUDA':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} šŸ“ˆ {'Speedup':<15}") + print("šŸ”„" + "-" * 150 + "šŸ”„") all_results = [] @@ -955,7 +815,6 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): implementations = { 'flash': ('flash_attention', results['flash_attention_status'], results['flash_attention_times']), 'cuda': ('dynamic_mask_attention', results['dynamic_mask_attention_status'], results['dynamic_mask_attention_times']), - 'no_topk': ('dynamic_mask_attention_no_topk', results['dynamic_mask_attention_no_topk_status'], results['dynamic_mask_attention_no_topk_times']), 'triton': ('dynamic_mask_attention_triton', results['dynamic_mask_attention_triton_status'], results['dynamic_mask_attention_triton_times']), 'flex': ('dynamic_mask_attention_flex', results['dynamic_mask_attention_flex_status'], results['dynamic_mask_attention_flex_times']) } @@ -977,7 +836,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): speedup_strs = {} flash_avg = time_avgs.get('flash', float('inf')) - for impl_key in ['cuda', 'no_topk', 'triton', 'flex']: + for impl_key in ['cuda', 'triton', 'flex']: impl_avg = time_avgs.get(impl_key, float('inf')) if flash_avg != float('inf') and impl_avg != float('inf') and impl_avg > 0: speedup = flash_avg / impl_avg @@ -1015,14 +874,13 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): speedup_summary = f"{best_impl}:{best_speedup}" if best_speedup != "N/A" else "N/A" - print(f"{icons} {config_short:<48} {time_strs['flash']:<12} {time_strs['cuda']:<12} {time_strs['no_topk']:<14} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}") + print(f"{icons} {config_short:<48} {time_strs['flash']:<12} {time_strs['cuda']:<12} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}") - print("šŸ”„" + "-" * 160 + "šŸ”„") + print("šŸ”„" + "-" * 150 + "šŸ”„") # Summary statistics implementation_speedups = { 'cuda': [], - 'no_topk': [], 'triton': [], 'flex': [] } @@ -1071,13 +929,6 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): print(f" {icon} {impl_name:10} vs Flash - Avg: {avg_speedup:.2f}x (Best: {max_speedup:.2f}x, Worst: {min_speedup:.2f}x)") else: print(f" āŒ {impl_key.replace('_', '-').upper():10} vs Flash - No successful runs") - - # Calculate overhead comparison - if implementation_speedups['cuda'] and implementation_speedups['no_topk']: - avg_cuda = np.mean(implementation_speedups['cuda']) - avg_no_topk = np.mean(implementation_speedups['no_topk']) - topk_overhead = ((avg_no_topk - avg_cuda) / avg_no_topk * 100) if avg_no_topk > 0 else 0 - print(f" šŸ’” TopK overhead: ~{topk_overhead:.1f}% performance impact") def main(): From 5329304d133549203f0649332fb3c565260eaf56 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Thu, 10 Jul 2025 13:43:22 +0800 Subject: [PATCH 2/2] Simplifies attention benchmark and comments out test configs Removes flash attention backend specification to use default SDPA behavior and enables attention mask usage. Comments out most benchmark configurations to focus testing on window size variations, reducing benchmark execution time while maintaining core functionality testing. --- benchmarks/benchmark_forward_performance.py | 109 ++++++++++---------- 1 file changed, 54 insertions(+), 55 deletions(-) diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py index 997f7fb..a38036b 100644 --- a/benchmarks/benchmark_forward_performance.py +++ b/benchmarks/benchmark_forward_performance.py @@ -186,16 +186,15 @@ def flash_attention_cuda( torch.cuda.synchronize() start_time = time.time() - with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION]): - attn_outputs = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - # attn_mask=causal_mask, - scale=scaling, - is_causal=is_causal if query_len == key_len else False, - enable_gqa=True - ) + attn_outputs = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + scale=scaling, + # is_causal=is_causal if query_len == key_len else False, + enable_gqa=True + ) torch.cuda.synchronize() end_time = time.time() @@ -736,49 +735,49 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) configs = [ - # Vary sequence length - (1, 2, 1, 256, 256, 128, 2048, True), - (1, 2, 1, 512, 512, 128, 2048, True), - (1, 2, 1, 1024, 1024, 128, 2048, True), - (1, 2, 1, 2048, 2048, 128, 2048, True), - (1, 2, 1, 4096, 4096, 128, 2048, True), - (1, 2, 1, 8192, 8192, 128, 2048, True), - (1, 2, 1, 16384, 16384, 128, 2048, True), - (1, 2, 1, 32768, 32768, 128, 2048, True), - - # Inference - (1, 2, 1, 2, 256, 128, 2048, True), - (1, 2, 1, 2, 512, 128, 2048, True), - (1, 2, 1, 2, 1024, 128, 2048, True), - (1, 2, 1, 2, 2048, 128, 2048, True), - (1, 2, 1, 2, 4096, 128, 2048, True), - (1, 2, 1, 2, 8192, 128, 2048, True), - (1, 2, 1, 2, 16384, 128, 2048, True), - (1, 2, 1, 2, 32768, 128, 2048, True), - (1, 2, 1, 2, 65536, 128, 2048, True), - (1, 2, 1, 2, 131072, 128, 2048, True), - (1, 2, 1, 2, 262144, 128, 2048, True), - (1, 2, 1, 2, 524288, 128, 2048, True), - - # Vary batch size - (1, 2, 1, 1024, 1024, 32, 2048, True), - (2, 2, 1, 1024, 1024, 32, 2048, True), - (4, 2, 1, 1024, 1024, 32, 2048, True), - (8, 2, 1, 1024, 1024, 32, 2048, True), - - # Vary head count - (1, 1, 1, 1024, 1024, 32, 2048, True), - (1, 2, 1, 1024, 1024, 32, 2048, True), - (1, 4, 1, 1024, 1024, 32, 2048, True), - (1, 8, 2, 1024, 1024, 32, 2048, True), - - # Vary head dimension - (1, 2, 1, 1024, 1024, 32, 2048, True), - (1, 2, 1, 1024, 1024, 64, 2048, True), - (1, 2, 1, 1024, 1024, 96, 2048, True), - (1, 2, 1, 1024, 1024, 128, 2048, True), - (1, 2, 1, 1024, 1024, 192, 2048, True), - (1, 2, 1, 1024, 1024, 256, 2048, True), + # # Vary sequence length + # (1, 2, 1, 256, 256, 128, 2048, True), + # (1, 2, 1, 512, 512, 128, 2048, True), + # (1, 2, 1, 1024, 1024, 128, 2048, True), + # (1, 2, 1, 2048, 2048, 128, 2048, True), + # (1, 2, 1, 4096, 4096, 128, 2048, True), + # (1, 2, 1, 8192, 8192, 128, 2048, True), + # (1, 2, 1, 16384, 16384, 128, 2048, True), + # (1, 2, 1, 32768, 32768, 128, 2048, True), + + # # Inference + # (1, 2, 1, 2, 256, 128, 2048, True), + # (1, 2, 1, 2, 512, 128, 2048, True), + # (1, 2, 1, 2, 1024, 128, 2048, True), + # (1, 2, 1, 2, 2048, 128, 2048, True), + # (1, 2, 1, 2, 4096, 128, 2048, True), + # (1, 2, 1, 2, 8192, 128, 2048, True), + # (1, 2, 1, 2, 16384, 128, 2048, True), + # (1, 2, 1, 2, 32768, 128, 2048, True), + # (1, 2, 1, 2, 65536, 128, 2048, True), + # (1, 2, 1, 2, 131072, 128, 2048, True), + # (1, 2, 1, 2, 262144, 128, 2048, True), + # (1, 2, 1, 2, 524288, 128, 2048, True), + + # # Vary batch size + # (1, 2, 1, 1024, 1024, 32, 2048, True), + # (2, 2, 1, 1024, 1024, 32, 2048, True), + # (4, 2, 1, 1024, 1024, 32, 2048, True), + # (8, 2, 1, 1024, 1024, 32, 2048, True), + + # # Vary head count + # (1, 1, 1, 1024, 1024, 32, 2048, True), + # (1, 2, 1, 1024, 1024, 32, 2048, True), + # (1, 4, 1, 1024, 1024, 32, 2048, True), + # (1, 8, 2, 1024, 1024, 32, 2048, True), + + # # Vary head dimension + # (1, 2, 1, 1024, 1024, 32, 2048, True), + # (1, 2, 1, 1024, 1024, 64, 2048, True), + # (1, 2, 1, 1024, 1024, 96, 2048, True), + # (1, 2, 1, 1024, 1024, 128, 2048, True), + # (1, 2, 1, 1024, 1024, 192, 2048, True), + # (1, 2, 1, 1024, 1024, 256, 2048, True), # Vary keep_window_size (1, 2, 1, 32768, 32768, 128, 32, True), @@ -793,8 +792,8 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): (1, 2, 1, 32768, 32768, 128, 16384, True), (1, 2, 1, 32768, 32768, 128, 32768, True), - # Test non-causal - (1, 2, 1, 1024, 1024, 128, 2048, False), + # # Test non-causal + # (1, 2, 1, 1024, 1024, 128, 2048, False), ] num_runs = 3 # Run 3 times and take average