92 changes: 45 additions & 47 deletions benchmarks/benchmark_forward_performance.py
@@ -735,49 +735,49 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):

# Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal)
configs = [
# # Vary sequence length
# (1, 2, 1, 256, 256, 128, 2048, True),
# (1, 2, 1, 512, 512, 128, 2048, True),
# (1, 2, 1, 1024, 1024, 128, 2048, True),
# (1, 2, 1, 2048, 2048, 128, 2048, True),
# (1, 2, 1, 4096, 4096, 128, 2048, True),
# (1, 2, 1, 8192, 8192, 128, 2048, True),
# (1, 2, 1, 16384, 16384, 128, 2048, True),
# (1, 2, 1, 32768, 32768, 128, 2048, True),

# # Inference
# (1, 2, 1, 2, 256, 128, 2048, True),
# (1, 2, 1, 2, 512, 128, 2048, True),
# (1, 2, 1, 2, 1024, 128, 2048, True),
# (1, 2, 1, 2, 2048, 128, 2048, True),
# (1, 2, 1, 2, 4096, 128, 2048, True),
# (1, 2, 1, 2, 8192, 128, 2048, True),
# (1, 2, 1, 2, 16384, 128, 2048, True),
# (1, 2, 1, 2, 32768, 128, 2048, True),
# (1, 2, 1, 2, 65536, 128, 2048, True),
# (1, 2, 1, 2, 131072, 128, 2048, True),
# (1, 2, 1, 2, 262144, 128, 2048, True),
# (1, 2, 1, 2, 524288, 128, 2048, True),

# # Vary batch size
# (1, 2, 1, 1024, 1024, 32, 2048, True),
# (2, 2, 1, 1024, 1024, 32, 2048, True),
# (4, 2, 1, 1024, 1024, 32, 2048, True),
# (8, 2, 1, 1024, 1024, 32, 2048, True),

# # Vary head count
# (1, 1, 1, 1024, 1024, 32, 2048, True),
# (1, 2, 1, 1024, 1024, 32, 2048, True),
# (1, 4, 1, 1024, 1024, 32, 2048, True),
# (1, 8, 2, 1024, 1024, 32, 2048, True),

# # Vary head dimension
# (1, 2, 1, 1024, 1024, 32, 2048, True),
# (1, 2, 1, 1024, 1024, 64, 2048, True),
# (1, 2, 1, 1024, 1024, 96, 2048, True),
# (1, 2, 1, 1024, 1024, 128, 2048, True),
# (1, 2, 1, 1024, 1024, 192, 2048, True),
# (1, 2, 1, 1024, 1024, 256, 2048, True),
# Vary sequence length
(1, 2, 1, 256, 256, 128, 2048, True),
(1, 2, 1, 512, 512, 128, 2048, True),
(1, 2, 1, 1024, 1024, 128, 2048, True),
(1, 2, 1, 2048, 2048, 128, 2048, True),
(1, 2, 1, 4096, 4096, 128, 2048, True),
(1, 2, 1, 8192, 8192, 128, 2048, True),
(1, 2, 1, 16384, 16384, 128, 2048, True),
(1, 2, 1, 32768, 32768, 128, 2048, True),

# Inference
(1, 2, 1, 2, 256, 128, 2048, True),
(1, 2, 1, 2, 512, 128, 2048, True),
(1, 2, 1, 2, 1024, 128, 2048, True),
(1, 2, 1, 2, 2048, 128, 2048, True),
(1, 2, 1, 2, 4096, 128, 2048, True),
(1, 2, 1, 2, 8192, 128, 2048, True),
Comment on lines 737 to +754
Copilot AI (Jul 10, 2025)
[nitpick] The manually unrolled configs list is quite large and repetitive. Consider generating these configurations programmatically using loops or list comprehensions to improve readability and reduce duplication.
(1, 2, 1, 2, 16384, 128, 2048, True),
(1, 2, 1, 2, 32768, 128, 2048, True),
(1, 2, 1, 2, 65536, 128, 2048, True),
(1, 2, 1, 2, 131072, 128, 2048, True),
(1, 2, 1, 2, 262144, 128, 2048, True),
(1, 2, 1, 2, 524288, 128, 2048, True),

# Vary batch size
(1, 2, 1, 4096, 4096, 32, 2048, True),
(2, 2, 1, 4096, 4096, 32, 2048, True),
(4, 2, 1, 4096, 4096, 32, 2048, True),
(8, 2, 1, 4096, 4096, 32, 2048, True),

# Vary head count
(1, 1, 1, 4096, 4096, 32, 2048, True),
(1, 2, 1, 4096, 4096, 32, 2048, True),
(1, 4, 1, 4096, 4096, 32, 2048, True),
(1, 8, 2, 4096, 4096, 32, 2048, True),

# Vary head dimension
(1, 2, 1, 4096, 4096, 32, 2048, True),
(1, 2, 1, 4096, 4096, 64, 2048, True),
(1, 2, 1, 4096, 4096, 96, 2048, True),
(1, 2, 1, 4096, 4096, 128, 2048, True),
(1, 2, 1, 4096, 4096, 192, 2048, True),
(1, 2, 1, 4096, 4096, 256, 2048, True),

# Vary keep_window_size
(1, 2, 1, 32768, 32768, 128, 32, True),
@@ -792,11 +792,9 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
(1, 2, 1, 32768, 32768, 128, 16384, True),
(1, 2, 1, 32768, 32768, 128, 32768, True),

# # Test non-causal
# (1, 2, 1, 1024, 1024, 128, 2048, False),
# Test non-causal
(1, 2, 1, 4096, 4096, 128, 2048, False),
Comment on lines +795 to +796
Copilot AI (Jul 10, 2025)
[nitpick] Instead of hardcoding a single non-causal test tuple, you could extend your programmatic generation approach (e.g., include is_causal as a parameter) so adding or modifying test modes scales more cleanly.

Suggested change:
-    # Test non-causal
-    (1, 2, 1, 4096, 4096, 128, 2048, False),
+    # Generate configurations for both causal and non-causal cases
+    *[(1, 2, 1, 4096, 4096, 128, 2048, is_causal) for is_causal in [True, False]],
]

num_runs = 3 # Run 3 times and take average

print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):")
print(f"🔧 {'Configuration':<60} ⚡ {'Flash':<10} 🚀 {'CUDA':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} 📈 {'Speedup':<15}")