diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py
index 7f465cf..1d4e0bc 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/benchmark_forward_performance.py
@@ -3,10 +3,10 @@
 Performance Benchmark for Dynamic Mask Attention
 
 This script measures and compares the performance of multiple Dynamic Mask Attention
-implementations against Flash Attention baseline across various configurations.
+implementations against the SDPA baseline across various configurations.
 
 Implementations tested:
-- Flash Attention (PyTorch SDPA Flash Attention backend) - Baseline
+- PyTorch SDPA - Baseline
 - Dynamic Mask Attention CUDA - Custom CUDA kernel implementation
 - Dynamic Mask Attention Triton - Triton kernel implementation
 - Dynamic Mask Attention Flex - Flex Attention implementation
@@ -150,7 +150,7 @@ def calculate_zoh_states(value_states, dt_proj, A):
     return zoh_states
 
 
-def flash_attention_cuda(
+def scaled_dot_product_attention_cuda(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
@@ -159,7 +159,7 @@ def flash_attention_cuda(
     is_causal=True,
 ):
     """
-    CUDA implementation of Flash Attention baseline.
+    CUDA implementation of the SDPA baseline.
 
     Args:
         query_states: [batch_size, num_heads, query_len, head_dim]
@@ -532,19 +532,19 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
     }
 
     # Determine which implementations to run
-    run_flash = test_type in ['all', 'flash', 'flash-vs-cuda', 'flash-vs-triton', 'flash-vs-flex']
-    run_cuda = test_type in ['all', 'cuda', 'flash-vs-cuda']
-    run_triton = test_type in ['all', 'triton', 'flash-vs-triton']
-    run_flex = test_type in ['all', 'flex', 'flash-vs-flex']
+    run_flash = test_type in ['all', 'sdpa', 'sdpa-vs-cuda', 'sdpa-vs-triton', 'sdpa-vs-flex']
+    run_cuda = test_type in ['all', 'cuda', 'sdpa-vs-cuda']
+    run_triton = test_type in ['all', 'triton', 'sdpa-vs-triton']
+    run_flex = test_type in ['all', 'flex', 'sdpa-vs-flex']
 
-    # Benchmark Flash Attention
+    # Benchmark SDPA
     if run_flash:
         gc.collect()
         torch.cuda.empty_cache()
 
         # Warmup runs
         for _ in range(warmup_runs):
-            result = flash_attention_cuda(
+            result = scaled_dot_product_attention_cuda(
                 query_states, key_states, value_states,
                 scaling, causal_mask, is_causal
             )
@@ -559,7 +559,7 @@
 
         # Actual benchmark runs
         for _ in range(num_runs):
-            result = flash_attention_cuda(
+            result = scaled_dot_product_attention_cuda(
                 query_states, key_states, value_states,
                 scaling, causal_mask, is_causal
             )
@@ -712,15 +712,15 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
 
     # Update title based on test type
     if test_type == 'all':
-        title = "⚡ Performance Benchmark: Flash vs CUDA vs Triton vs Flex ⚡"
-    elif test_type == 'flash-vs-cuda':
-        title = "⚡ Performance Benchmark: Flash Attention vs CUDA ⚡"
-    elif test_type == 'flash-vs-triton':
-        title = "⚡ Performance Benchmark: Flash Attention vs Triton ⚡"
-    elif test_type == 'flash-vs-flex':
-        title = "⚡ Performance Benchmark: Flash Attention vs Flex ⚡"
-    elif test_type == 'flash':
-        title = "⚡ Performance Benchmark: Flash Attention Only ⚡"
+        title = "⚡ Performance Benchmark: SDPA vs CUDA vs Triton vs Flex ⚡"
+    elif test_type == 'sdpa-vs-cuda':
+        title = "⚡ Performance Benchmark: SDPA vs CUDA ⚡"
+    elif test_type == 'sdpa-vs-triton':
+        title = "⚡ Performance Benchmark: SDPA vs Triton ⚡"
+    elif test_type == 'sdpa-vs-flex':
+        title = "⚡ Performance Benchmark: SDPA vs Flex ⚡"
+    elif test_type == 'sdpa':
+        title = "⚡ Performance Benchmark: SDPA Only ⚡"
     elif test_type == 'cuda':
         title = "⚡ Performance Benchmark: CUDA Implementations ⚡"
     elif test_type == 'triton':
@@ -797,7 +797,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
     ]
 
     print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):")
-    print(f"🔧 {'Configuration':<60} ⚡ {'Flash':<10} 🚀 {'CUDA':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} 📈 {'Speedup':<15}")
+    print(f"🔧 {'Configuration':<60} ⚡ {'SDPA':<10} 🚀 {'CUDA':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} 📈 {'Speedup':<15}")
     print("🔄" + "-" * 150 + "🔄")
 
     all_results = []
@@ -810,7 +810,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
 
         # Calculate averages for all implementations
         implementations = {
-            'flash': ('flash_attention', results['flash_attention_status'], results['flash_attention_times']),
+            'sdpa': ('flash_attention', results['flash_attention_status'], results['flash_attention_times']),
             'cuda': ('dynamic_mask_attention', results['dynamic_mask_attention_status'], results['dynamic_mask_attention_times']),
             'triton': ('dynamic_mask_attention_triton', results['dynamic_mask_attention_triton_status'], results['dynamic_mask_attention_triton_times']),
             'flex': ('dynamic_mask_attention_flex', results['dynamic_mask_attention_flex_status'], results['dynamic_mask_attention_flex_times'])
@@ -829,9 +829,9 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
                 time_strs[impl_key] = status[:8]  # Truncate status for display
                 time_avgs[impl_key] = float('inf')
 
-        # Calculate speedups (compared to Flash Attention baseline)
+        # Calculate speedups (compared to the SDPA baseline)
         speedup_strs = {}
-        flash_avg = time_avgs.get('flash', float('inf'))
+        flash_avg = time_avgs.get('sdpa', float('inf'))
 
         for impl_key in ['cuda', 'triton', 'flex']:
             impl_avg = time_avgs.get(impl_key, float('inf'))
@@ -871,7 +871,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
 
         speedup_summary = f"{best_impl}:{best_speedup}" if best_speedup != "N/A" else "N/A"
 
-        print(f"{icons} {config_short:<48} {time_strs['flash']:<12} {time_strs['cuda']:<12} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}")
+        print(f"{icons} {config_short:<48} {time_strs['sdpa']:<12} {time_strs['cuda']:<12} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}")
 
     print("🔄" + "-" * 150 + "🔄")
 
@@ -923,9 +923,9 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
                 icon = "😐"
 
             impl_name = impl_key.replace('_', '-').upper()
-            print(f"  {icon} {impl_name:10} vs Flash - Avg: {avg_speedup:.2f}x (Best: {max_speedup:.2f}x, Worst: {min_speedup:.2f}x)")
+            print(f"  {icon} {impl_name:10} vs SDPA - Avg: {avg_speedup:.2f}x (Best: {max_speedup:.2f}x, Worst: {min_speedup:.2f}x)")
         else:
-            print(f"  ❌ {impl_key.replace('_', '-').upper():10} vs Flash - No successful runs")
+            print(f"  ❌ {impl_key.replace('_', '-').upper():10} vs SDPA - No successful runs")
 
 
 def main():
@@ -947,7 +947,7 @@ def main():
     parser.add_argument('--runs', type=int, default=3, help='Number of benchmark runs')
    parser.add_argument('--warmup', type=int, default=2, help='Number of warmup runs')
     parser.add_argument('--test-type', type=str, default='all',
-                        choices=['all', 'flash', 'cuda', 'triton', 'flex', 'flash-vs-cuda', 'flash-vs-triton', 'flash-vs-flex'],
+                        choices=['all', 'sdpa', 'cuda', 'triton', 'flex', 'sdpa-vs-cuda', 'sdpa-vs-triton', 'sdpa-vs-flex'],
                         help='Type of benchmark to run (default: all)')
 
     args = parser.parse_args()
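
The patch renames flash_attention_cuda to scaled_dot_product_attention_cuda, but only the function's signature and docstring are visible in these hunks. A minimal sketch of what such a baseline presumably wraps, assuming torch.nn.functional.scaled_dot_product_attention with PyTorch >= 2.1 (for the scale keyword) — the body below is an illustration, not the repository's actual implementation:

from typing import Optional

import torch
import torch.nn.functional as F


def scaled_dot_product_attention_cuda(
    query_states: torch.Tensor,   # [batch_size, num_heads, query_len, head_dim]
    key_states: torch.Tensor,     # [batch_size, num_heads, key_len, head_dim]
    value_states: torch.Tensor,   # [batch_size, num_heads, key_len, head_dim]
    scaling: float,
    causal_mask: Optional[torch.Tensor] = None,
    is_causal: bool = True,
) -> torch.Tensor:
    # Hypothetical body. F.scaled_dot_product_attention raises an error when
    # both attn_mask and is_causal are set, so the explicit mask takes
    # precedence when one is supplied.
    if causal_mask is not None:
        return F.scaled_dot_product_attention(
            query_states, key_states, value_states,
            attn_mask=causal_mask, scale=scaling,
        )
    return F.scaled_dot_product_attention(
        query_states, key_states, value_states,
        is_causal=is_causal, scale=scaling,
    )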
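
The hunks around line 532 show the measurement discipline per implementation: gc.collect() and torch.cuda.empty_cache() before each candidate, warmup_runs untimed calls, then num_runs timed calls that are later averaged. The timing code itself falls outside these hunks; a generic sketch of that pattern, with an illustrative measure() helper that is not a function from this file:

import gc
import time

import torch


def measure(fn, num_runs: int = 3, warmup_runs: int = 2) -> float:
    """Illustrative timing helper, not part of the benchmark script itself."""
    gc.collect()
    torch.cuda.empty_cache()
    for _ in range(warmup_runs):
        fn()                      # warm up: kernel compilation, caches, autotuning
    torch.cuda.synchronize()      # don't let pending warmup work leak into the timing
    start = time.perf_counter()
    for _ in range(num_runs):
        fn()
    torch.cuda.synchronize()      # CUDA kernels launch asynchronously; wait for completion
    return (time.perf_counter() - start) * 1000.0 / num_runs  # average ms per run

Speedups are then computed against the SDPA average (the flash_avg = time_avgs.get('sdpa', ...) line), presumably as sdpa_avg / impl_avg, so values above 1.0 mean the candidate beats the baseline.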
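
With the renamed --test-type choices from the argparse hunk, a head-to-head run would be invoked as, for example:

python benchmarks/benchmark_forward_performance.py --test-type sdpa-vs-cuda --runs 3 --warmup 2

Here --runs 3 and --warmup 2 simply restate the defaults shown in the patch.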