diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py index a38036b..7f465cf 100644 --- a/benchmarks/benchmark_forward_performance.py +++ b/benchmarks/benchmark_forward_performance.py @@ -735,49 +735,49 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) configs = [ - # # Vary sequence length - # (1, 2, 1, 256, 256, 128, 2048, True), - # (1, 2, 1, 512, 512, 128, 2048, True), - # (1, 2, 1, 1024, 1024, 128, 2048, True), - # (1, 2, 1, 2048, 2048, 128, 2048, True), - # (1, 2, 1, 4096, 4096, 128, 2048, True), - # (1, 2, 1, 8192, 8192, 128, 2048, True), - # (1, 2, 1, 16384, 16384, 128, 2048, True), - # (1, 2, 1, 32768, 32768, 128, 2048, True), - - # # Inference - # (1, 2, 1, 2, 256, 128, 2048, True), - # (1, 2, 1, 2, 512, 128, 2048, True), - # (1, 2, 1, 2, 1024, 128, 2048, True), - # (1, 2, 1, 2, 2048, 128, 2048, True), - # (1, 2, 1, 2, 4096, 128, 2048, True), - # (1, 2, 1, 2, 8192, 128, 2048, True), - # (1, 2, 1, 2, 16384, 128, 2048, True), - # (1, 2, 1, 2, 32768, 128, 2048, True), - # (1, 2, 1, 2, 65536, 128, 2048, True), - # (1, 2, 1, 2, 131072, 128, 2048, True), - # (1, 2, 1, 2, 262144, 128, 2048, True), - # (1, 2, 1, 2, 524288, 128, 2048, True), - - # # Vary batch size - # (1, 2, 1, 1024, 1024, 32, 2048, True), - # (2, 2, 1, 1024, 1024, 32, 2048, True), - # (4, 2, 1, 1024, 1024, 32, 2048, True), - # (8, 2, 1, 1024, 1024, 32, 2048, True), - - # # Vary head count - # (1, 1, 1, 1024, 1024, 32, 2048, True), - # (1, 2, 1, 1024, 1024, 32, 2048, True), - # (1, 4, 1, 1024, 1024, 32, 2048, True), - # (1, 8, 2, 1024, 1024, 32, 2048, True), - - # # Vary head dimension - # (1, 2, 1, 1024, 1024, 32, 2048, True), - # (1, 2, 1, 1024, 1024, 64, 2048, True), - # (1, 2, 1, 1024, 1024, 96, 2048, True), - # (1, 2, 1, 1024, 1024, 128, 2048, True), - # (1, 2, 1, 1024, 1024, 192, 2048, True), - # (1, 2, 1, 1024, 1024, 256, 2048, True), + # Vary sequence length + (1, 2, 1, 256, 256, 128, 2048, True), + (1, 2, 1, 512, 512, 128, 2048, True), + (1, 2, 1, 1024, 1024, 128, 2048, True), + (1, 2, 1, 2048, 2048, 128, 2048, True), + (1, 2, 1, 4096, 4096, 128, 2048, True), + (1, 2, 1, 8192, 8192, 128, 2048, True), + (1, 2, 1, 16384, 16384, 128, 2048, True), + (1, 2, 1, 32768, 32768, 128, 2048, True), + + # Inference + (1, 2, 1, 2, 256, 128, 2048, True), + (1, 2, 1, 2, 512, 128, 2048, True), + (1, 2, 1, 2, 1024, 128, 2048, True), + (1, 2, 1, 2, 2048, 128, 2048, True), + (1, 2, 1, 2, 4096, 128, 2048, True), + (1, 2, 1, 2, 8192, 128, 2048, True), + (1, 2, 1, 2, 16384, 128, 2048, True), + (1, 2, 1, 2, 32768, 128, 2048, True), + (1, 2, 1, 2, 65536, 128, 2048, True), + (1, 2, 1, 2, 131072, 128, 2048, True), + (1, 2, 1, 2, 262144, 128, 2048, True), + (1, 2, 1, 2, 524288, 128, 2048, True), + + # Vary batch size + (1, 2, 1, 4096, 4096, 32, 2048, True), + (2, 2, 1, 4096, 4096, 32, 2048, True), + (4, 2, 1, 4096, 4096, 32, 2048, True), + (8, 2, 1, 4096, 4096, 32, 2048, True), + + # Vary head count + (1, 1, 1, 4096, 4096, 32, 2048, True), + (1, 2, 1, 4096, 4096, 32, 2048, True), + (1, 4, 1, 4096, 4096, 32, 2048, True), + (1, 8, 2, 4096, 4096, 32, 2048, True), + + # Vary head dimension + (1, 2, 1, 4096, 4096, 32, 2048, True), + (1, 2, 1, 4096, 4096, 64, 2048, True), + (1, 2, 1, 4096, 4096, 96, 2048, True), + (1, 2, 1, 4096, 4096, 128, 2048, True), + (1, 2, 1, 4096, 4096, 192, 2048, True), + (1, 2, 1, 4096, 4096, 256, 2048, True), # Vary keep_window_size (1, 2, 1, 32768, 32768, 128, 32, True), @@ -792,11 +792,9 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): (1, 2, 1, 32768, 32768, 128, 16384, True), (1, 2, 1, 32768, 32768, 128, 32768, True), - # # Test non-causal - # (1, 2, 1, 1024, 1024, 128, 2048, False), + # Test non-causal + (1, 2, 1, 4096, 4096, 128, 2048, False), ] - - num_runs = 3 # Run 3 times and take average print(f"\nšŸ“Š Benchmark Results (averaged over {num_runs} runs):") print(f"šŸ”§ {'Configuration':<60} ⚔ {'Flash':<10} šŸš€ {'CUDA':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} šŸ“ˆ {'Speedup':<15}")