From 473672a28cc7ed21e0f0b7b6a77eedfb197da0cf Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Thu, 10 Jul 2025 20:56:17 +0800
Subject: [PATCH] Renames Flash Attention to SDPA in benchmark suite

Updates terminology throughout the performance benchmark to use "SDPA"
(Scaled Dot Product Attention) instead of "Flash Attention" for consistency
with PyTorch's official naming convention.

Changes function names, variable references, command-line options, and
display text to reflect that the baseline implementation uses PyTorch's SDPA
rather than specifically Flash Attention, improving clarity about the actual
underlying implementation being benchmarked.
---
 benchmarks/benchmark_forward_performance.py | 56 ++++++++++-----------
 1 file changed, 28 insertions(+), 28 deletions(-)

diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py
index 7f465cf..1d4e0bc 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/benchmark_forward_performance.py
@@ -3,10 +3,10 @@
 Performance Benchmark for Dynamic Mask Attention
 
 This script measures and compares the performance of multiple Dynamic Mask Attention
-implementations against Flash Attention baseline across various configurations.
+implementations against SDPA baseline across various configurations.
 
 Implementations tested:
-- Flash Attention (PyTorch SDPA Flash Attention backend) - Baseline
+- PyTorch SDPA - Baseline
 - Dynamic Mask Attention CUDA - Custom CUDA kernel implementation
 - Dynamic Mask Attention Triton - Triton kernel implementation
 - Dynamic Mask Attention Flex - Flex Attention implementation
@@ -150,7 +150,7 @@ def calculate_zoh_states(value_states, dt_proj, A):
     return zoh_states
 
 
-def flash_attention_cuda(
+def scaled_dot_product_attention_cuda(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
@@ -159,7 +159,7 @@ def flash_attention_cuda(
     is_causal=True,
 ):
     """
-    CUDA implementation of Flash Attention baseline.
+    CUDA implementation of SDPA baseline.
 
     Args:
         query_states: [batch_size, num_heads, query_len, head_dim]
@@ -532,19 +532,19 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
     }
 
     # Determine which implementations to run
-    run_flash = test_type in ['all', 'flash', 'flash-vs-cuda', 'flash-vs-triton', 'flash-vs-flex']
-    run_cuda = test_type in ['all', 'cuda', 'flash-vs-cuda']
-    run_triton = test_type in ['all', 'triton', 'flash-vs-triton']
-    run_flex = test_type in ['all', 'flex', 'flash-vs-flex']
+    run_flash = test_type in ['all', 'sdpa', 'sdpa-vs-cuda', 'sdpa-vs-triton', 'sdpa-vs-flex']
+    run_cuda = test_type in ['all', 'cuda', 'sdpa-vs-cuda']
+    run_triton = test_type in ['all', 'triton', 'sdpa-vs-triton']
+    run_flex = test_type in ['all', 'flex', 'sdpa-vs-flex']
 
-    # Benchmark Flash Attention
+    # Benchmark SDPA
     if run_flash:
         gc.collect()
         torch.cuda.empty_cache()
 
         # Warmup runs
         for _ in range(warmup_runs):
-            result = flash_attention_cuda(
+            result = scaled_dot_product_attention_cuda(
                 query_states, key_states, value_states,
                 scaling, causal_mask, is_causal
             )
@@ -559,7 +559,7 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
 
         # Actual benchmark runs
         for _ in range(num_runs):
-            result = flash_attention_cuda(
+            result = scaled_dot_product_attention_cuda(
                 query_states, key_states, value_states,
                 scaling, causal_mask, is_causal
             )
@@ -712,15 +712,15 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
 
     # Update title based on test type
     if test_type == 'all':
-        title = "⚡ Performance Benchmark: Flash vs CUDA vs Triton vs Flex ⚡"
-    elif test_type == 'flash-vs-cuda':
-        title = "⚡ Performance Benchmark: Flash Attention vs CUDA ⚡"
-    elif test_type == 'flash-vs-triton':
-        title = "⚡ Performance Benchmark: Flash Attention vs Triton ⚡"
-    elif test_type == 'flash-vs-flex':
-        title = "⚡ Performance Benchmark: Flash Attention vs Flex ⚡"
-    elif test_type == 'flash':
-        title = "⚡ Performance Benchmark: Flash Attention Only ⚡"
+        title = "⚡ Performance Benchmark: SDPA vs CUDA vs Triton vs Flex ⚡"
+    elif test_type == 'sdpa-vs-cuda':
+        title = "⚡ Performance Benchmark: SDPA vs CUDA ⚡"
+    elif test_type == 'sdpa-vs-triton':
+        title = "⚡ Performance Benchmark: SDPA vs Triton ⚡"
+    elif test_type == 'sdpa-vs-flex':
+        title = "⚡ Performance Benchmark: SDPA vs Flex ⚡"
+    elif test_type == 'sdpa':
+        title = "⚡ Performance Benchmark: SDPA Only ⚡"
     elif test_type == 'cuda':
         title = "⚡ Performance Benchmark: CUDA Implementations ⚡"
     elif test_type == 'triton':
@@ -797,7 +797,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
     ]
 
     print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):")
-    print(f"🔧 {'Configuration':<60} ⚡ {'Flash':<10} 🚀 {'CUDA':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} 📈 {'Speedup':<15}")
+    print(f"🔧 {'Configuration':<60} ⚡ {'SDPA':<10} 🚀 {'CUDA':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} 📈 {'Speedup':<15}")
     print("🔄" + "-" * 150 + "🔄")
 
     all_results = []
@@ -810,7 +810,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
 
         # Calculate averages for all implementations
        implementations = {
-            'flash': ('flash_attention', results['flash_attention_status'], results['flash_attention_times']),
+            'sdpa': ('flash_attention', results['flash_attention_status'], results['flash_attention_times']),
            'cuda': ('dynamic_mask_attention', results['dynamic_mask_attention_status'], results['dynamic_mask_attention_times']),
            'triton': ('dynamic_mask_attention_triton', results['dynamic_mask_attention_triton_status'], results['dynamic_mask_attention_triton_times']),
            'flex': ('dynamic_mask_attention_flex', results['dynamic_mask_attention_flex_status'], results['dynamic_mask_attention_flex_times'])
@@ -829,9 +829,9 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
                time_strs[impl_key] = status[:8]  # Truncate status for display
                time_avgs[impl_key] = float('inf')
 
-        # Calculate speedups (compared to Flash Attention baseline)
+        # Calculate speedups (compared to SDPA baseline)
         speedup_strs = {}
-        flash_avg = time_avgs.get('flash', float('inf'))
+        flash_avg = time_avgs.get('sdpa', float('inf'))
 
         for impl_key in ['cuda', 'triton', 'flex']:
             impl_avg = time_avgs.get(impl_key, float('inf'))
@@ -871,7 +871,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
 
         speedup_summary = f"{best_impl}:{best_speedup}" if best_speedup != "N/A" else "N/A"
 
-        print(f"{icons} {config_short:<48} {time_strs['flash']:<12} {time_strs['cuda']:<12} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}")
+        print(f"{icons} {config_short:<48} {time_strs['sdpa']:<12} {time_strs['cuda']:<12} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}")
 
     print("🔄" + "-" * 150 + "🔄")
 
@@ -923,9 +923,9 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
                icon = "😐"
 
            impl_name = impl_key.replace('_', '-').upper()
-            print(f" {icon} {impl_name:10} vs Flash - Avg: {avg_speedup:.2f}x (Best: {max_speedup:.2f}x, Worst: {min_speedup:.2f}x)")
+            print(f" {icon} {impl_name:10} vs SDPA - Avg: {avg_speedup:.2f}x (Best: {max_speedup:.2f}x, Worst: {min_speedup:.2f}x)")
        else:
-            print(f" ❌ {impl_key.replace('_', '-').upper():10} vs Flash - No successful runs")
+            print(f" ❌ {impl_key.replace('_', '-').upper():10} vs SDPA - No successful runs")
 
 
 def main():
@@ -947,7 +947,7 @@ def main():
     parser.add_argument('--runs', type=int, default=3, help='Number of benchmark runs')
     parser.add_argument('--warmup', type=int, default=2, help='Number of warmup runs')
     parser.add_argument('--test-type', type=str, default='all',
-                        choices=['all', 'flash', 'cuda', 'triton', 'flex', 'flash-vs-cuda', 'flash-vs-triton', 'flash-vs-flex'],
+                        choices=['all', 'sdpa', 'cuda', 'triton', 'flex', 'sdpa-vs-cuda', 'sdpa-vs-triton', 'sdpa-vs-flex'],
                         help='Type of benchmark to run (default: all)')
 
     args = parser.parse_args()
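
Reviewer note: the renamed baseline wrapper is expected to reduce to a call of PyTorch's
public SDPA entry point, which dispatches at runtime to one of several backends (Flash
Attention, memory-efficient attention, or the math fallback) rather than to Flash Attention
specifically; that is the motivation for the rename. The sketch below only illustrates that
API call. The helper name sdpa_baseline_sketch and the shapes/sizes are illustrative
assumptions, not code from benchmark_forward_performance.py.

import torch
import torch.nn.functional as F

def sdpa_baseline_sketch():
    # Illustrative shapes only: [batch_size, num_heads, seq_len, head_dim],
    # matching the tensor layout documented in the patched docstring.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    q = torch.randn(1, 2, 128, 64, device=device)
    k = torch.randn(1, 2, 128, 64, device=device)
    v = torch.randn(1, 2, 128, 64, device=device)
    # torch.nn.functional.scaled_dot_product_attention is the SDPA API the
    # commit message refers to; PyTorch selects the attention backend itself.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

The renamed command-line options added in this patch can be exercised with, for example,
"python benchmarks/benchmark_forward_performance.py --test-type sdpa-vs-cuda --runs 3 --warmup 2".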