Merged
benchmarks/benchmark_forward_performance.py — 268 changes: 59 additions & 209 deletions
@@ -8,7 +8,6 @@
Implementations tested:
- Flash Attention (PyTorch SDPA Flash Attention backend) - Baseline
- Dynamic Mask Attention CUDA - Custom CUDA kernel implementation
- Dynamic Mask Attention CUDA (No TopK) - CUDA kernel without TopK computation
- Dynamic Mask Attention Triton - Triton kernel implementation
- Dynamic Mask Attention Flex - Flex Attention implementation

@@ -187,16 +186,15 @@ def flash_attention_cuda(
torch.cuda.synchronize()
start_time = time.time()

with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION]):
attn_outputs = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
# attn_mask=causal_mask,
scale=scaling,
is_causal=is_causal if query_len == key_len else False,
enable_gqa=True
)
attn_outputs = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
scale=scaling,
# is_causal=is_causal if query_len == key_len else False,
enable_gqa=True
)

torch.cuda.synchronize()
end_time = time.time()
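
Side note on the change above: the benchmark now passes an explicit `attn_mask` to SDPA instead of pinning the Flash Attention backend and relying on `is_causal`. A minimal sketch of how such an additive causal mask could be constructed — the helper name and defaults below are illustrative, not part of this PR:

```python
import torch

def build_additive_causal_mask(query_len, key_len, dtype=torch.bfloat16, device="cuda"):
    # Illustrative helper (not from this PR): disallowed key positions get -inf,
    # so softmax assigns them zero weight; allowed positions contribute 0.0 bias.
    allowed = torch.ones(query_len, key_len, dtype=torch.bool, device=device).tril(
        diagonal=key_len - query_len  # align the diagonal when key_len > query_len (decoding)
    )
    mask = torch.zeros(query_len, key_len, dtype=dtype, device=device)
    return mask.masked_fill(~allowed, float("-inf"))  # broadcastable to [B, H, Q, K]
```

When `query_len == key_len`, passing a mask built this way via `attn_mask` (with `enable_gqa=True`) should reproduce the causal behavior previously obtained with `is_causal=True`.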
@@ -293,99 +291,6 @@ def dynamic_mask_attention_cuda(
return "OOM", 0


def dynamic_mask_attention_cuda_no_topk(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
dt_proj: torch.Tensor,
A: torch.Tensor,
scaling: float,
causal_mask: torch.Tensor,
keep_window_size=2048,
is_causal=True,
return_softmax=False
):
"""
CUDA implementation of dynamic mask attention without topk computation.
This version skips the topk calculation for more accurate kernel performance testing.

Args:
query_states: [batch_size, num_heads, query_len, head_dim]
key_states: [batch_size, num_kv_heads, key_len, head_dim]
value_states: [batch_size, num_kv_heads, key_len, head_dim]
dt_proj: [num_kv_heads, num_kv_heads * head_dim]
A: [num_kv_heads]
scaling: Attention scaling factor
causal_mask: Causal attention mask
keep_window_size: Number of tokens to keep in attention window
is_causal: Whether to apply causal masking
return_softmax: Whether to return softmax weights

Returns:
attn_outputs: [batch_size, query_len, num_heads, head_dim]
"""
# Calculate zoh_states
zoh_states = calculate_zoh_states(value_states, dt_proj, A)

# Create a simplified mask without topk computation
batch_size, _, query_len, _ = query_states.shape
_, num_kv_heads, key_len, _ = key_states.shape
dtype = query_states.dtype
device = query_states.device

# Create full attn mask (no topk selection)
attn_mask = torch.zeros(
(batch_size, num_kv_heads, query_len, key_len),
dtype=dtype,
device=device
)

# Ensure correct data types and memory layout
query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]
value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]
attn_bias = zoh_states[:, :, None, :].expand(
-1, -1, query_states.shape[1], -1
).contiguous() # [batch, num_kv_heads, query_len, key_len]
# Create full active mask (no topk selection)
attn_mask = torch.zeros_like(
attn_bias,
dtype=dtype,
device=device
).contiguous() # [batch, num_kv_heads, query_len, key_len]

try:
out_tensor = None # Let the function allocate the output tensor

# Only measure the core CUDA kernel computation
torch.cuda.synchronize()
start_time = time.time()

result = flash_dmattn_cuda.fwd( # type: ignore
query_states, # q: [batch, seqlen_q, num_heads, head_dim]
key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim]
value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim]
attn_mask, # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
attn_bias, # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k]
out_tensor, # out: None to auto-allocate
0.0, # p_dropout
scaling, # softmax_scale
is_causal, # is_causal
keep_window_size, # keep_window_size
0.0, # softcap
return_softmax, # return_softmax
None # gen (generator)
)

torch.cuda.synchronize()
end_time = time.time()

attn_outputs = result[0]
return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms
except torch.cuda.OutOfMemoryError:
return "OOM", 0


def dynamic_mask_attention_triton(
query_states: torch.Tensor,
key_states: torch.Tensor,
@@ -614,25 +519,21 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
'config': config,
'flash_attention_times': [],
'dynamic_mask_attention_times': [],
'dynamic_mask_attention_no_topk_times': [],
'dynamic_mask_attention_triton_times': [],
'dynamic_mask_attention_flex_times': [],
'flash_attention_memory': 0,
'dynamic_mask_attention_memory': 0,
'dynamic_mask_attention_no_topk_memory': 0,
'dynamic_mask_attention_triton_memory': 0,
'dynamic_mask_attention_flex_memory': 0,
'flash_attention_status': 'success',
'dynamic_mask_attention_status': 'success',
'dynamic_mask_attention_no_topk_status': 'success',
'dynamic_mask_attention_triton_status': 'success',
'dynamic_mask_attention_flex_status': 'success'
}

# Determine which implementations to run
run_flash = test_type in ['all', 'flash', 'flash-vs-cuda', 'flash-vs-triton', 'flash-vs-flex']
run_cuda = test_type in ['all', 'cuda', 'flash-vs-cuda']
run_no_topk = test_type in ['all', 'cuda']
run_triton = test_type in ['all', 'triton', 'flash-vs-triton']
run_flex = test_type in ['all', 'flex', 'flash-vs-flex']

@@ -718,48 +619,6 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
else:
results['dynamic_mask_attention_status'] = 'N/A'

# Benchmark Dynamic Mask Attention (No TopK)
if run_no_topk:
gc.collect()
torch.cuda.empty_cache()

# Warmup runs
for _ in range(warmup_runs):
result = dynamic_mask_attention_cuda_no_topk(
query_states, key_states, value_states,
dt_proj, A, scaling, causal_mask,
keep_window_size, is_causal
)
if result[0] == "OOM":
results['dynamic_mask_attention_no_topk_status'] = 'OOM'
break
torch.cuda.synchronize()

if results['dynamic_mask_attention_no_topk_status'] == 'success':
# Measure memory before benchmark
mem_before = measure_memory_usage()

# Actual benchmark runs
for _ in range(num_runs):
result = dynamic_mask_attention_cuda_no_topk(
query_states, key_states, value_states,
dt_proj, A, scaling, causal_mask,
keep_window_size, is_causal
)

if result[0] == "OOM":
results['dynamic_mask_attention_no_topk_status'] = 'OOM'
break

# Use the timing from the function instead of measuring here
results['dynamic_mask_attention_no_topk_times'].append(result[1]) # ms

# Measure memory after
mem_after = measure_memory_usage()
results['dynamic_mask_attention_no_topk_memory'] = mem_after[0] - mem_before[0]
else:
results['dynamic_mask_attention_no_topk_status'] = 'N/A'

# Benchmark Dynamic Mask Attention (Triton)
if run_triton:
gc.collect()
@@ -876,49 +735,49 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):

# Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal)
configs = [
# Vary sequence length
(1, 2, 1, 256, 256, 32, 2048, True),
(1, 2, 1, 512, 512, 32, 2048, True),
(1, 2, 1, 1024, 1024, 32, 2048, True),
(1, 2, 1, 2048, 2048, 32, 2048, True),
(1, 2, 1, 4096, 4096, 32, 2048, True),
(1, 2, 1, 8192, 8192, 32, 2048, True),
(1, 2, 1, 16384, 16384, 32, 2048, True),
(1, 2, 1, 32768, 32768, 32, 2048, True),

# Inference
(1, 2, 1, 2, 256, 128, 2048, True),
(1, 2, 1, 2, 512, 128, 2048, True),
(1, 2, 1, 2, 1024, 128, 2048, True),
(1, 2, 1, 2, 2048, 128, 2048, True),
(1, 2, 1, 2, 4096, 128, 2048, True),
(1, 2, 1, 2, 8192, 128, 2048, True),
(1, 2, 1, 2, 16384, 128, 2048, True),
(1, 2, 1, 2, 32768, 128, 2048, True),
(1, 2, 1, 2, 65536, 128, 2048, True),
(1, 2, 1, 2, 131072, 128, 2048, True),
(1, 2, 1, 2, 262144, 128, 2048, True),
(1, 2, 1, 2, 524288, 128, 2048, True),

# Vary batch size
(1, 2, 1, 1024, 1024, 32, 2048, True),
(2, 2, 1, 1024, 1024, 32, 2048, True),
(4, 2, 1, 1024, 1024, 32, 2048, True),
(8, 2, 1, 1024, 1024, 32, 2048, True),

# Vary head count
(1, 1, 1, 1024, 1024, 32, 2048, True),
(1, 2, 1, 1024, 1024, 32, 2048, True),
(1, 4, 1, 1024, 1024, 32, 2048, True),
(1, 8, 2, 1024, 1024, 32, 2048, True),

# Vary head dimension
(1, 2, 1, 1024, 1024, 32, 2048, True),
(1, 2, 1, 1024, 1024, 64, 2048, True),
(1, 2, 1, 1024, 1024, 96, 2048, True),
(1, 2, 1, 1024, 1024, 128, 2048, True),
(1, 2, 1, 1024, 1024, 192, 2048, True),
(1, 2, 1, 1024, 1024, 256, 2048, True),
# # Vary sequence length
# (1, 2, 1, 256, 256, 128, 2048, True),
# (1, 2, 1, 512, 512, 128, 2048, True),
# (1, 2, 1, 1024, 1024, 128, 2048, True),
# (1, 2, 1, 2048, 2048, 128, 2048, True),
# (1, 2, 1, 4096, 4096, 128, 2048, True),
# (1, 2, 1, 8192, 8192, 128, 2048, True),
# (1, 2, 1, 16384, 16384, 128, 2048, True),
# (1, 2, 1, 32768, 32768, 128, 2048, True),

# # Inference
# (1, 2, 1, 2, 256, 128, 2048, True),
# (1, 2, 1, 2, 512, 128, 2048, True),
# (1, 2, 1, 2, 1024, 128, 2048, True),
# (1, 2, 1, 2, 2048, 128, 2048, True),
# (1, 2, 1, 2, 4096, 128, 2048, True),
# (1, 2, 1, 2, 8192, 128, 2048, True),
# (1, 2, 1, 2, 16384, 128, 2048, True),
# (1, 2, 1, 2, 32768, 128, 2048, True),
# (1, 2, 1, 2, 65536, 128, 2048, True),
# (1, 2, 1, 2, 131072, 128, 2048, True),
# (1, 2, 1, 2, 262144, 128, 2048, True),
# (1, 2, 1, 2, 524288, 128, 2048, True),

# # Vary batch size
# (1, 2, 1, 1024, 1024, 32, 2048, True),
# (2, 2, 1, 1024, 1024, 32, 2048, True),
# (4, 2, 1, 1024, 1024, 32, 2048, True),
# (8, 2, 1, 1024, 1024, 32, 2048, True),

# # Vary head count
# (1, 1, 1, 1024, 1024, 32, 2048, True),
# (1, 2, 1, 1024, 1024, 32, 2048, True),
# (1, 4, 1, 1024, 1024, 32, 2048, True),
# (1, 8, 2, 1024, 1024, 32, 2048, True),

# # Vary head dimension
# (1, 2, 1, 1024, 1024, 32, 2048, True),
# (1, 2, 1, 1024, 1024, 64, 2048, True),
# (1, 2, 1, 1024, 1024, 96, 2048, True),
# (1, 2, 1, 1024, 1024, 128, 2048, True),
# (1, 2, 1, 1024, 1024, 192, 2048, True),
# (1, 2, 1, 1024, 1024, 256, 2048, True),

# Vary keep_window_size
(1, 2, 1, 32768, 32768, 128, 32, True),
@@ -933,15 +792,15 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
(1, 2, 1, 32768, 32768, 128, 16384, True),
(1, 2, 1, 32768, 32768, 128, 32768, True),

# Test non-causal
(1, 2, 1, 1024, 1024, 128, 2048, False),
# # Test non-causal
# (1, 2, 1, 1024, 1024, 128, 2048, False),
]

num_runs = 3 # Run 3 times and take average

print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):")
print(f"🔧 {'Configuration':<60} ⚡ {'Flash':<10} 🚀 {'CUDA':<10} 🚀 {'No-TopK':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} 📈 {'Speedup':<15}")
print("🔄" + "-" * 160 + "🔄")
print(f"🔧 {'Configuration':<60} ⚡ {'Flash':<10} 🚀 {'CUDA':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} 📈 {'Speedup':<15}")
print("🔄" + "-" * 150 + "🔄")

all_results = []

Expand All @@ -955,7 +814,6 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
implementations = {
'flash': ('flash_attention', results['flash_attention_status'], results['flash_attention_times']),
'cuda': ('dynamic_mask_attention', results['dynamic_mask_attention_status'], results['dynamic_mask_attention_times']),
'no_topk': ('dynamic_mask_attention_no_topk', results['dynamic_mask_attention_no_topk_status'], results['dynamic_mask_attention_no_topk_times']),
'triton': ('dynamic_mask_attention_triton', results['dynamic_mask_attention_triton_status'], results['dynamic_mask_attention_triton_times']),
'flex': ('dynamic_mask_attention_flex', results['dynamic_mask_attention_flex_status'], results['dynamic_mask_attention_flex_times'])
}
@@ -977,7 +835,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
speedup_strs = {}
flash_avg = time_avgs.get('flash', float('inf'))

for impl_key in ['cuda', 'no_topk', 'triton', 'flex']:
for impl_key in ['cuda', 'triton', 'flex']:
Copilot AI commented on Jul 10, 2025:
The hardcoded list of implementation keys should be extracted to a constant or derived from the implementations dictionary to avoid maintenance issues when adding or removing implementations.

Suggested change:
- for impl_key in ['cuda', 'triton', 'flex']:
+ for impl_key in [key for key in implementations.keys() if key != 'flash']:

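To make the reviewer's suggestion concrete, a small self-contained sketch — the dictionary values are placeholders, not timings from this benchmark — showing the comparison keys derived from `implementations` rather than hardcoded:

```python
# Placeholder dict mirroring the benchmark's `implementations` mapping;
# the statuses and times here are made up for illustration only.
implementations = {
    'flash': ('flash_attention', 'success', [1.20, 1.18]),
    'cuda': ('dynamic_mask_attention', 'success', [0.85, 0.87]),
    'triton': ('dynamic_mask_attention_triton', 'success', [1.02, 1.05]),
    'flex': ('dynamic_mask_attention_flex', 'OOM', []),
}

# Derive the non-baseline keys once, so adding or removing an implementation
# never requires editing a hardcoded list like ['cuda', 'triton', 'flex'].
comparison_keys = [key for key in implementations if key != 'flash']
print(comparison_keys)  # ['cuda', 'triton', 'flex']
```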
impl_avg = time_avgs.get(impl_key, float('inf'))
if flash_avg != float('inf') and impl_avg != float('inf') and impl_avg > 0:
speedup = flash_avg / impl_avg
@@ -1015,14 +873,13 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):

speedup_summary = f"{best_impl}:{best_speedup}" if best_speedup != "N/A" else "N/A"

print(f"{icons} {config_short:<48} {time_strs['flash']:<12} {time_strs['cuda']:<12} {time_strs['no_topk']:<14} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}")
print(f"{icons} {config_short:<48} {time_strs['flash']:<12} {time_strs['cuda']:<12} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}")

print("🔄" + "-" * 160 + "🔄")
print("🔄" + "-" * 150 + "🔄")

# Summary statistics
implementation_speedups = {
'cuda': [],
'no_topk': [],
'triton': [],
'flex': []
}
@@ -1071,13 +928,6 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
print(f" {icon} {impl_name:10} vs Flash - Avg: {avg_speedup:.2f}x (Best: {max_speedup:.2f}x, Worst: {min_speedup:.2f}x)")
else:
print(f" ❌ {impl_key.replace('_', '-').upper():10} vs Flash - No successful runs")

# Calculate overhead comparison
if implementation_speedups['cuda'] and implementation_speedups['no_topk']:
avg_cuda = np.mean(implementation_speedups['cuda'])
avg_no_topk = np.mean(implementation_speedups['no_topk'])
topk_overhead = ((avg_no_topk - avg_cuda) / avg_no_topk * 100) if avg_no_topk > 0 else 0
print(f" 💡 TopK overhead: ~{topk_overhead:.1f}% performance impact")


def main():