Merged
benchmarks/benchmark_forward_performance.py — 186 changes: 172 additions & 14 deletions
@@ -213,6 +213,82 @@ def dynamic_mask_attention_cuda(
return "OOM"


def dynamic_mask_attention_cuda_no_topk(
[Review comment — Copilot AI, Jun 21, 2025]
The new no-topk function largely duplicates logic from dynamic_mask_attention_cuda. Extract common setup (transposes, mask creation) into shared helpers to reduce duplication and simplify future maintenance. (A helper sketch follows this function.)
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
dt_proj: torch.Tensor,
A: torch.Tensor,
scaling: float,
causal_mask: torch.Tensor,
keep_window_size=2048,
is_causal=True,
return_softmax=False
):
"""
CUDA implementation of dynamic mask attention without top-k computation.
This version skips the top-k selection so that raw kernel performance can be measured more accurately.

Args:
query_states: [batch_size, num_heads, query_len, head_dim]
key_states: [batch_size, num_kv_heads, key_len, head_dim]
value_states: [batch_size, num_kv_heads, key_len, head_dim]
dt_proj: [num_kv_heads, num_kv_heads * head_dim]
A: [num_kv_heads]
scaling: Attention scaling factor
causal_mask: Causal attention mask
keep_window_size: Number of tokens to keep in attention window
is_causal: Whether to apply causal masking
return_softmax: Whether to return softmax weights

Returns:
attn_outputs: [batch_size, query_len, num_heads, head_dim]
"""
# Calculate zero_hold_states
zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask)

# Create a simplified mask without topk computation
batch_size, num_heads, query_len, head_dim = query_states.shape
_, num_kv_heads, key_len, _ = key_states.shape
dtype = query_states.dtype
device = query_states.device

# Create an all-zero active mask (no top-k selection), leaving every position active
active_mask = torch.zeros(
(batch_size, num_kv_heads, query_len, key_len),
dtype=dtype,
device=device
)

# Ensure correct data types and memory layout
query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]
value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]
zero_hold_states = zero_hold_states[:, :, None, :].expand(
-1, -1, query_states.shape[1], -1
).contiguous() # [batch, num_kv_heads, query_len, key_len]
active_mask = active_mask.contiguous() # [batch, num_kv_heads, query_len, key_len]

try:
result = apply_dynamic_mask_attention(
query_states=query_states,
key_states=key_states,
value_states=value_states,
zoh_states=zero_hold_states,
active_mask=active_mask,
scale=scaling,
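# keep_window_size=0 presumably tells the kernel to skip top-k selection entirely (assumption, inferred from this function's purpose)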
keep_window_size=0,
is_causal=is_causal,
return_softmax=return_softmax
)

# Extract the attention outputs from the returned tuple
attn_outputs = result[0]
return attn_outputs
except torch.cuda.OutOfMemoryError:
return "OOM"


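A minimal sketch of the shared helper the review comment above asks for. The name _prepare_attention_inputs is hypothetical, and it assumes the two CUDA paths share exactly these transposes and the zero-hold broadcast:

def _prepare_attention_inputs(query_states, key_states, value_states, zero_hold_states):
    """Shared setup: move Q/K/V to [batch, seq_len, heads, head_dim] layout and
    broadcast zero_hold_states across the query dimension."""
    query_states = query_states.transpose(1, 2).contiguous()    # [batch, query_len, num_heads, head_dim]
    key_states = key_states.transpose(1, 2).contiguous()        # [batch, key_len, num_kv_heads, head_dim]
    value_states = value_states.transpose(1, 2).contiguous()    # [batch, key_len, num_kv_heads, head_dim]
    zero_hold_states = zero_hold_states[:, :, None, :].expand(
        -1, -1, query_states.shape[1], -1
    ).contiguous()                                              # [batch, num_kv_heads, query_len, key_len]
    return query_states, key_states, value_states, zero_hold_states

Both dynamic_mask_attention_cuda and the no-topk variant could call this before building their respective masks.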
def measure_memory_usage():
"""
Measure current GPU memory usage.
@@ -280,10 +356,13 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
'config': config,
'flash_attention_times': [],
'dynamic_mask_attention_times': [],
'dynamic_mask_attention_no_topk_times': [],
'flash_attention_memory': 0,
'dynamic_mask_attention_memory': 0,
'dynamic_mask_attention_no_topk_memory': 0,
'flash_attention_status': 'success',
'dynamic_mask_attention_status': 'success',
'dynamic_mask_attention_no_topk_status': 'success'
}

# Benchmark Flash Attention
Expand Down Expand Up @@ -366,6 +445,47 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
mem_after = measure_memory_usage()
results['dynamic_mask_attention_memory'] = mem_after[0] - mem_before[0]

# Benchmark Dynamic Mask Attention (No TopK)
gc.collect()
torch.cuda.empty_cache()

# Warmup runs
for _ in range(warmup_runs):
[Review comment — Copilot AI, Jun 21, 2025]
[nitpick] The loops for warmup and benchmark runs for dynamic_mask_attention_no_topk are similar to those used for the standard dynamic_mask_attention. Consider refactoring these into a helper function to reduce duplication and improve maintainability. (A helper sketch follows this function.)
result = dynamic_mask_attention_cuda_no_topk(
query_states, key_states, value_states,
dt_proj, A, scaling, causal_mask,
keep_window_size, True
)
if result == "OOM":
results['dynamic_mask_attention_no_topk_status'] = 'OOM'
break
torch.cuda.synchronize()

if results['dynamic_mask_attention_no_topk_status'] == 'success':
# Measure memory before benchmark
mem_before = measure_memory_usage()

# Actual benchmark runs
for _ in range(num_runs):
start_time = time.time()
result = dynamic_mask_attention_cuda_no_topk(
query_states, key_states, value_states,
dt_proj, A, scaling, causal_mask,
keep_window_size, True
)
torch.cuda.synchronize()
end_time = time.time()

if result == "OOM":
results['dynamic_mask_attention_no_topk_status'] = 'OOM'
break

results['dynamic_mask_attention_no_topk_times'].append((end_time - start_time) * 1000) # ms

# Measure memory after
mem_after = measure_memory_usage()
results['dynamic_mask_attention_no_topk_memory'] = mem_after[0] - mem_before[0]

return results


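A sketch of the loop helper suggested in the review comment above. _run_timed_benchmark is a hypothetical name; it assumes the attention callable returns "OOM" on out-of-memory, as the functions in this file do:

def _run_timed_benchmark(attn_fn, args, num_runs, warmup_runs):
    """Warmup runs followed by timed runs of attn_fn(*args).
    Returns (times_ms, status, memory_delta_mb)."""
    gc.collect()
    torch.cuda.empty_cache()
    # Warmup passes; bail out early if the kernel OOMs
    for _ in range(warmup_runs):
        if attn_fn(*args) == "OOM":
            return [], "OOM", 0
    torch.cuda.synchronize()

    mem_before = measure_memory_usage()
    times_ms = []
    for _ in range(num_runs):
        start_time = time.time()
        result = attn_fn(*args)
        torch.cuda.synchronize()
        end_time = time.time()
        if result == "OOM":
            return times_ms, "OOM", 0
        times_ms.append((end_time - start_time) * 1000)  # ms
    mem_after = measure_memory_usage()
    return times_ms, "success", mem_after[0] - mem_before[0]

Each of the three benchmark sections would then reduce to one call, e.g. _run_timed_benchmark(dynamic_mask_attention_cuda_no_topk, (query_states, key_states, value_states, dt_proj, A, scaling, causal_mask, keep_window_size, True), num_runs, warmup_runs).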
@@ -386,6 +506,16 @@ def run_performance_benchmark():
(1, 2, 1, 8192, 8192, 32),
(1, 2, 1, 16384, 16384, 32),
(1, 2, 1, 32768, 32768, 32),

# Inference-style decoding: short query (64) against growing key/value lengths
(1, 2, 1, 64, 256, 32),
(1, 2, 1, 64, 512, 32),
(1, 2, 1, 64, 1024, 32),
(1, 2, 1, 64, 2048, 32),
(1, 2, 1, 64, 4096, 32),
(1, 2, 1, 64, 8192, 32),
(1, 2, 1, 64, 16384, 32),
(1, 2, 1, 64, 32768, 32),

# Vary batch size
(1, 2, 1, 1024, 1024, 32),
@@ -408,8 +538,8 @@ def run_performance_benchmark():
num_runs = 3 # Run 3 times and take average

print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):")
print(f"🔧 {'Configuration':<45}⚡ {'Flash Attn (ms)':<18}🚀 {'DMA (ms)':<17}📈 {'Speedup':<12}💾 {'Memory (MB)':<12}")
print("🔄" + "-" * 117 + "🔄")
print(f"🔧 {'Configuration':<42} ⚡ {'Flash (ms)':<12} 🚀 {'DMA (ms)':<12} 🚀 {'DMA-Skip-All (ms)':<22} 📈 {'Speedup':<12} 📈 {'Skip-All-Speedup':<20} 💾 {'Memory':<10}")
print("🔄" + "-" * 155 + "🔄")
[Review comment on lines +541 to +542 — Copilot AI, Jun 21, 2025]
[nitpick] The header printing uses hard-coded column widths. Consider extracting these widths as constants or using a formatting helper to improve flexibility and maintainability.

Suggested change:
print(f"🔧 {'Configuration':<{CONFIGURATION_WIDTH}} ⚡ {'Flash (ms)':<{FLASH_WIDTH}} 🚀 {'DMA (ms)':<{DMA_WIDTH}} 🚀 {'DMA-Skip-All (ms)':<{DMA_SKIP_ALL_WIDTH}} 📈 {'Speedup':<{SPEEDUP_WIDTH}} 📈 {'Skip-All-Speedup':<{SKIP_ALL_SPEEDUP_WIDTH}} 💾 {'Memory':<{MEMORY_WIDTH}}")
print("🔄" + "-" * (CONFIGURATION_WIDTH + FLASH_WIDTH + DMA_WIDTH + DMA_SKIP_ALL_WIDTH + SPEEDUP_WIDTH + SKIP_ALL_SPEEDUP_WIDTH + MEMORY_WIDTH + 20) + "🔄")

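The suggested change relies on width constants the diff never defines; a minimal sketch, with values taken from the hard-coded header above:

CONFIGURATION_WIDTH = 42
FLASH_WIDTH = 12
DMA_WIDTH = 12
DMA_SKIP_ALL_WIDTH = 22
SPEEDUP_WIDTH = 12
SKIP_ALL_SPEEDUP_WIDTH = 20
MEMORY_WIDTH = 10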
all_results = []

@@ -433,31 +563,46 @@ def run_performance_benchmark():
else:
dma_time_str = results['dynamic_mask_attention_status']
dma_avg = float('inf')

if results['dynamic_mask_attention_no_topk_status'] == 'success' and results['dynamic_mask_attention_no_topk_times']:
dma_nt_avg = sum(results['dynamic_mask_attention_no_topk_times']) / len(results['dynamic_mask_attention_no_topk_times'])
dma_nt_time_str = f"{dma_nt_avg:.2f}"
else:
dma_nt_time_str = results['dynamic_mask_attention_no_topk_status']
dma_nt_avg = float('inf')

# Calculate speedups
if flash_avg != float('inf') and dma_avg != float('inf') and dma_avg > 0:
speedup = flash_avg / dma_avg
speedup_str = f"{speedup:.2f}x"
else:
speedup_str = "N/A"

if flash_avg != float('inf') and dma_nt_avg != float('inf') and dma_nt_avg > 0:
kernel_speedup = flash_avg / dma_nt_avg
kernel_speedup_str = f"{kernel_speedup:.2f}x"
else:
kernel_speedup_str = "N/A"

# Memory usage
mem_diff = results['dynamic_mask_attention_memory'] - results['flash_attention_memory']
mem_str = f"{mem_diff:+.1f}"
mem_str = f"{mem_diff:+.0f}"

# Format output
config_short = f"b={batch_size}, h={num_heads}, kv={num_kv_heads}, q={query_len}, k={key_len}, d={head_dim}"
config_short = f"b={batch_size},h={num_heads},kv={num_kv_heads},q={query_len},k={key_len},d={head_dim}"

# Add status icons
flash_icon = "✅" if results['flash_attention_status'] == 'success' else "💥"
dma_icon = "✅" if results['dynamic_mask_attention_status'] == 'success' else "💥"
dma_nt_icon = "✅" if results['dynamic_mask_attention_no_topk_status'] == 'success' else "💥"

print(f"{flash_icon}{dma_icon} {config_short:<47}{flash_time_str:<20}{dma_time_str:<20}{speedup_str:<15}{mem_str:<15}")
print(f"{flash_icon}{dma_icon}{dma_nt_icon} {config_short:<42} {flash_time_str:<14} {dma_time_str:<20} {dma_nt_time_str:<20} {speedup_str:<18} {kernel_speedup_str:<20} {mem_str:<12}")

print("🔄" + "-" * 117 + "🔄")
print("🔄" + "-" * 155 + "🔄")

# Summary statistics
speedups = []
kernel_speedups = []
for results in all_results:
if (results['flash_attention_status'] == 'success' and
results['dynamic_mask_attention_status'] == 'success' and
Expand All @@ -468,15 +613,28 @@ def run_performance_benchmark():

if dma_avg > 0:
speedups.append(flash_avg / dma_avg)

if (results['flash_attention_status'] == 'success' and
results['dynamic_mask_attention_no_topk_status'] == 'success' and
results['flash_attention_times'] and results['dynamic_mask_attention_no_topk_times']):

flash_avg = sum(results['flash_attention_times']) / len(results['flash_attention_times'])
dma_nt_avg = sum(results['dynamic_mask_attention_no_topk_times']) / len(results['dynamic_mask_attention_no_topk_times'])

if dma_nt_avg > 0:
kernel_speedups.append(flash_avg / dma_nt_avg)

print(f"\n🏆 Summary:")
if speedups:
avg_speedup = np.mean(speedups)
print(f"\n🏆 Summary:")
speedup_icon = "🚀" if avg_speedup > 1.5 else "📈" if avg_speedup > 1.0 else "😐"
print(f" {speedup_icon} Average speedup: {avg_speedup:.2f}x")
print(f" ⭐ Best speedup: {np.max(speedups):.2f}x")
print(f" 📉 Worst speedup: {np.min(speedups):.2f}x")
print(f" 📊 Speedup std: {np.std(speedups):.2f}")
print(f" {speedup_icon} DMA vs Flash - Average speedup: {avg_speedup:.2f}x (Best: {np.max(speedups):.2f}x, Worst: {np.min(speedups):.2f}x)")

if kernel_speedups:
avg_kernel_speedup = np.mean(kernel_speedups)
kernel_icon = "🔥" if avg_kernel_speedup > 2.0 else "🚀" if avg_kernel_speedup > 1.5 else "📈" if avg_kernel_speedup > 1.0 else "😐"
print(f" {kernel_icon} DMA-NoTopK vs Flash - Average kernel speedup: {avg_kernel_speedup:.2f}x (Best: {np.max(kernel_speedups):.2f}x, Worst: {np.min(kernel_speedups):.2f}x)")
print(f" 💡 TopK overhead: ~{((np.mean(kernel_speedups) - np.mean(speedups) if speedups else 0) / np.mean(kernel_speedups) * 100) if kernel_speedups else 0:.1f}% performance impact")


def main():
@@ -518,4 +676,4 @@ def main():


if __name__ == "__main__":
main()