56 changes: 28 additions & 28 deletions benchmarks/benchmark_forward_performance.py
@@ -3,10 +3,10 @@
Performance Benchmark for Dynamic Mask Attention

This script measures and compares the performance of multiple Dynamic Mask Attention
implementations against Flash Attention baseline across various configurations.
implementations against SDPA baseline across various configurations.

Implementations tested:
- Flash Attention (PyTorch SDPA Flash Attention backend) - Baseline
- PyTorch SDPA - Baseline
- Dynamic Mask Attention CUDA - Custom CUDA kernel implementation
- Dynamic Mask Attention Triton - Triton kernel implementation
- Dynamic Mask Attention Flex - Flex Attention implementation
@@ -150,7 +150,7 @@ def calculate_zoh_states(value_states, dt_proj, A):
return zoh_states


def flash_attention_cuda(
def scaled_dot_product_attention_cuda(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
@@ -159,7 +159,7 @@ def flash_attention_cuda(
is_causal=True,
):
"""
CUDA implementation of Flash Attention baseline.
CUDA implementation of SDPA baseline.

Args:
query_states: [batch_size, num_heads, query_len, head_dim]
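The body of the renamed baseline falls outside these hunks; a minimal sketch of what `scaled_dot_product_attention_cuda` presumably reduces to, assuming it is a thin wrapper over `torch.nn.functional.scaled_dot_product_attention` (the default argument values and the mask/`is_causal` precedence below are assumptions, not taken from the file):

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention_cuda(
    query_states: torch.Tensor,   # [batch_size, num_heads, query_len, head_dim]
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    scaling: float,
    causal_mask: torch.Tensor = None,
    is_causal: bool = True,
) -> torch.Tensor:
    # An explicit mask takes precedence: SDPA rejects attn_mask combined with
    # is_causal=True, so the flag is dropped when a mask is provided.
    return F.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=causal_mask,
        scale=scaling,
        is_causal=is_causal if causal_mask is None else False,
    )
```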
@@ -532,19 +532,19 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
}

# Determine which implementations to run
run_flash = test_type in ['all', 'flash', 'flash-vs-cuda', 'flash-vs-triton', 'flash-vs-flex']
run_cuda = test_type in ['all', 'cuda', 'flash-vs-cuda']
run_triton = test_type in ['all', 'triton', 'flash-vs-triton']
run_flex = test_type in ['all', 'flex', 'flash-vs-flex']
run_flash = test_type in ['all', 'sdpa', 'sdpa-vs-cuda', 'sdpa-vs-triton', 'sdpa-vs-flex']
run_cuda = test_type in ['all', 'cuda', 'sdpa-vs-cuda']
run_triton = test_type in ['all', 'triton', 'sdpa-vs-triton']
run_flex = test_type in ['all', 'flex', 'sdpa-vs-flex']

# Benchmark Flash Attention
# Benchmark SDPA
if run_flash:
gc.collect()
torch.cuda.empty_cache()

# Warmup runs
for _ in range(warmup_runs):
result = flash_attention_cuda(
result = scaled_dot_product_attention_cuda(
query_states, key_states, value_states,
scaling, causal_mask, is_causal
)
@@ -559,7 +559,7 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_

# Actual benchmark runs
for _ in range(num_runs):
result = flash_attention_cuda(
result = scaled_dot_product_attention_cuda(
query_states, key_states, value_states,
scaling, causal_mask, is_causal
)
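The timing calls themselves are not part of this diff; the warmup-then-measure pattern around these loops typically looks like the sketch below (the `time_gpu_fn` helper and millisecond units are illustrative, not names from the benchmark):

```python
import time
import torch

def time_gpu_fn(fn, *args, num_runs=3, warmup_runs=2):
    # Warm up first so one-off compilation and cache effects are excluded,
    # then synchronize around the timed region so the CPU timer captures
    # the full GPU execution of each run.
    for _ in range(warmup_runs):
        fn(*args)
    torch.cuda.synchronize()
    times_ms = []
    for _ in range(num_runs):
        start = time.perf_counter()
        fn(*args)
        torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1000.0)
    return times_ms
```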
@@ -712,15 +712,15 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):

# Update title based on test type
if test_type == 'all':
title = "⚡ Performance Benchmark: Flash vs CUDA vs Triton vs Flex ⚡"
elif test_type == 'flash-vs-cuda':
title = "⚡ Performance Benchmark: Flash Attention vs CUDA ⚡"
elif test_type == 'flash-vs-triton':
title = "⚡ Performance Benchmark: Flash Attention vs Triton ⚡"
elif test_type == 'flash-vs-flex':
title = "⚡ Performance Benchmark: Flash Attention vs Flex ⚡"
elif test_type == 'flash':
title = "⚡ Performance Benchmark: Flash Attention Only ⚡"
title = "⚡ Performance Benchmark: SDPA vs CUDA vs Triton vs Flex ⚡"
elif test_type == 'sdpa-vs-cuda':
title = "⚡ Performance Benchmark: SDPA Attention vs CUDA ⚡"
elif test_type == 'sdpa-vs-triton':
title = "⚡ Performance Benchmark: SDPA Attention vs Triton ⚡"
elif test_type == 'sdpa-vs-flex':
title = "⚡ Performance Benchmark: SDPA Attention vs Flex ⚡"
elif test_type == 'sdpa':
title = "⚡ Performance Benchmark: SDPA Attention Only ⚡"
elif test_type == 'cuda':
title = "⚡ Performance Benchmark: CUDA Implementations ⚡"
elif test_type == 'triton':
@@ -797,7 +797,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
]

print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):")
print(f"🔧 {'Configuration':<60} ⚡ {'Flash':<10} 🚀 {'CUDA':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} 📈 {'Speedup':<15}")
print(f"🔧 {'Configuration':<60} ⚡ {'SDPA':<10} 🚀 {'CUDA':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} 📈 {'Speedup':<15}")
print("🔄" + "-" * 150 + "🔄")

all_results = []
@@ -810,7 +810,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):

# Calculate averages for all implementations
implementations = {
'flash': ('flash_attention', results['flash_attention_status'], results['flash_attention_times']),
'sdpa': ('flash_attention', results['flash_attention_status'], results['flash_attention_times']),
'cuda': ('dynamic_mask_attention', results['dynamic_mask_attention_status'], results['dynamic_mask_attention_times']),
'triton': ('dynamic_mask_attention_triton', results['dynamic_mask_attention_triton_status'], results['dynamic_mask_attention_triton_times']),
'flex': ('dynamic_mask_attention_flex', results['dynamic_mask_attention_flex_status'], results['dynamic_mask_attention_flex_times'])
@@ -829,9 +829,9 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
time_strs[impl_key] = status[:8] # Truncate status for display
time_avgs[impl_key] = float('inf')

# Calculate speedups (compared to Flash Attention baseline)
# Calculate speedups
speedup_strs = {}
flash_avg = time_avgs.get('flash', float('inf'))
flash_avg = time_avgs.get('sdpa', float('inf'))

for impl_key in ['cuda', 'triton', 'flex']:
impl_avg = time_avgs.get(impl_key, float('inf'))
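The speedup columns divide the SDPA baseline average by each implementation's average, so values above 1.0x favor the candidate implementation; a small sketch of that arithmetic (the helper name is hypothetical):

```python
def format_speedup(sdpa_avg_ms: float, impl_avg_ms: float) -> str:
    # Speedup = baseline time / implementation time; inf marks a failed run.
    if sdpa_avg_ms == float("inf") or impl_avg_ms == float("inf"):
        return "N/A"
    return f"{sdpa_avg_ms / impl_avg_ms:.2f}x"
```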
@@ -871,7 +871,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):

speedup_summary = f"{best_impl}:{best_speedup}" if best_speedup != "N/A" else "N/A"

print(f"{icons} {config_short:<48} {time_strs['flash']:<12} {time_strs['cuda']:<12} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}")
print(f"{icons} {config_short:<48} {time_strs['sdpa']:<12} {time_strs['cuda']:<12} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}")

print("🔄" + "-" * 150 + "🔄")

@@ -923,9 +923,9 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
icon = "😐"

impl_name = impl_key.replace('_', '-').upper()
print(f" {icon} {impl_name:10} vs Flash - Avg: {avg_speedup:.2f}x (Best: {max_speedup:.2f}x, Worst: {min_speedup:.2f}x)")
print(f" {icon} {impl_name:10} vs SDPA - Avg: {avg_speedup:.2f}x (Best: {max_speedup:.2f}x, Worst: {min_speedup:.2f}x)")
else:
print(f" ❌ {impl_key.replace('_', '-').upper():10} vs Flash - No successful runs")
print(f" ❌ {impl_key.replace('_', '-').upper():10} vs SDPA - No successful runs")


def main():
@@ -947,7 +947,7 @@ def main():
parser.add_argument('--runs', type=int, default=3, help='Number of benchmark runs')
parser.add_argument('--warmup', type=int, default=2, help='Number of warmup runs')
parser.add_argument('--test-type', type=str, default='all',
choices=['all', 'flash', 'cuda', 'triton', 'flex', 'flash-vs-cuda', 'flash-vs-triton', 'flash-vs-flex'],
choices=['all', 'sdpa', 'cuda', 'triton', 'flex', 'sdpa-vs-cuda', 'sdpa-vs-triton', 'sdpa-vs-flex'],
help='Type of benchmark to run (default: all)')

args = parser.parse_args()
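With the renamed choices, a typical invocation would presumably be `python benchmarks/benchmark_forward_performance.py --test-type sdpa-vs-cuda --runs 3 --warmup 2`, benchmarking only the SDPA baseline against the CUDA kernel.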