56 changes: 42 additions & 14 deletions benchmarks/benchmark_forward_equivalence.py
@@ -16,6 +16,7 @@
import torch.nn.functional as F
import argparse
import time
import gc

# Import the compiled CUDA extension
try:
@@ -155,7 +156,7 @@ def dynamic_mask_attention_python(
zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)

# Use prepare_dynamic_mask function to process dynamic mask
attn_mask, _ = prepare_dynamic_mask(
attn_mask, active_mask = prepare_dynamic_mask(
query_states,
zero_hold_states,
keep_window_size,
@@ -211,7 +212,7 @@ def dynamic_mask_attention_cuda(
zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)

# Use prepare_dynamic_mask to get the processed attention mask
_, active_mask = prepare_dynamic_mask(
attn_mask, active_mask = prepare_dynamic_mask(
query_states,
zero_hold_states,
keep_window_size,
@@ -226,6 +227,7 @@
zero_hold_states = zero_hold_states[:, :, None, :].expand(
-1, -1, query_states.shape[1], -1
).contiguous() # [batch, num_kv_heads, query_len, key_len]
attn_mask = attn_mask.contiguous() # [batch, num_kv_heads, query_len, key_len]
active_mask = active_mask.contiguous() # [batch, num_kv_heads, query_len, key_len]

# Call the CUDA implementation using the mha_fwd function signature
@@ -234,7 +236,7 @@
query_states, # q: [batch, seqlen_q, num_heads, head_dim]
key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim]
value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim]
zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
attn_mask, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
active_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
out_tensor, # out: None to auto-allocate
0.0, # p_dropout
@@ -352,16 +354,30 @@ def test_forward_equivalence(accuracy_threshold=0.95):
torch.manual_seed(0)

# Test different parameter configurations
# If you encounter NaN issues when running multiple configurations, try running a single configuration at a time
test_configs = [
# (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
(1, 1, 1, 64, 64, 32, True), # Small scale test, causal mask
(1, 1, 1, 64, 64, 32, False), # Small scale test, non-causal mask
(1, 1, 1, 128, 128, 32, True), # Medium scale test, causal mask
(1, 1, 1, 128, 128, 32, False), # Medium scale test, non-causal mask
(1, 1, 1, 256, 256, 32, True), # Large scale test, causal mask
(1, 2, 1, 64, 64, 32, True), # Medium scale test, GQA mode
(2, 1, 1, 128, 128, 32, True), # Medium scale test, Multi batch
(2, 2, 1, 128, 128, 32, True), # Medium scale test, Multi batch GQA mode
(1, 1, 1, 4, 64, 32, True),
(1, 1, 1, 4, 64, 32, False),
(1, 1, 1, 128, 128, 32, True),
(1, 1, 1, 128, 128, 32, False),
(1, 1, 1, 256, 256, 32, True),
(1, 1, 1, 256, 256, 32, False),
(1, 1, 1, 512, 512, 32, True),
(1, 1, 1, 512, 512, 32, False),
(1, 1, 1, 1024, 1024, 32, True),
(1, 1, 1, 1024, 1024, 32, False),
(1, 1, 1, 2048, 2048, 32, True),
(1, 1, 1, 2048, 2048, 32, False),
(1, 1, 1, 4096, 4096, 32, True),
(1, 1, 1, 4096, 4096, 32, False),
(1, 2, 1, 64, 64, 32, True),
(2, 1, 1, 128, 128, 32, True),
(2, 2, 1, 128, 128, 32, True),
(1, 2, 1, 64, 64, 128, True),
(1, 2, 1, 128, 128, 128, True),
(1, 2, 1, 256, 256, 128, True),
(1, 2, 1, 2, 256, 128, True),
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -371,6 +387,10 @@ def test_forward_equivalence(accuracy_threshold=0.95):
all_passed = True

for i, config in enumerate(test_configs):
torch.cuda.empty_cache()
gc.collect()
torch.cuda.synchronize()

batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal = config

# Progress indicator
@@ -425,7 +445,7 @@ def test_forward_equivalence(accuracy_threshold=0.95):
dt_proj, A, scaling, causal_mask,
keep_window_size, is_causal
)
py_output_copy = py_output.clone()
torch.cuda.synchronize()
py_time = time.time() - start_time

# Run CUDA implementation
@@ -435,10 +455,14 @@ def test_forward_equivalence(accuracy_threshold=0.95):
dt_proj, A, scaling, causal_mask,
keep_window_size, is_causal
)
torch.cuda.synchronize()
cuda_time = time.time() - start_time

# Analyze differences
is_close, max_diff, mean_diff = analyze_differences(py_output_copy, cuda_output, accuracy_threshold)
py_output_copy = py_output.clone()
cuda_output_copy = cuda_output.clone()
is_close, max_diff, mean_diff = analyze_differences(py_output_copy, cuda_output_copy, accuracy_threshold)

# Report performance difference
speedup = py_time / cuda_time if cuda_time > 0 else float('inf')
@@ -457,6 +481,10 @@ def test_forward_equivalence(accuracy_threshold=0.95):
if not is_close and max_diff > 1e-2:
print(" ⚠️ Difference too large, stopping subsequent tests.")
break
del query_states, key_states, value_states, dt_proj, A, causal_mask, py_output, cuda_output, py_output_copy, cuda_output_copy
torch.cuda.empty_cache()
gc.collect()
torch.cuda.synchronize()

print("\n" + "🏁" + "=" * 76 + "🏁")
summary_icon = "🎉" if all_passed else "😞"
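The diff above changes how the equivalence test is timed and isolated: `torch.cuda.synchronize()` is called before each timestamp is read, outputs are cloned before comparison, and tensors are explicitly freed between configurations. As a minimal sketch of that pattern (not the benchmark's exact code; `timed_run` and `free_gpu_memory` are hypothetical helper names, and a CUDA device is assumed):

```python
import gc
import time

import torch


def timed_run(fn, *args):
    # Kernel launches return before the GPU finishes, so synchronize
    # around the timed region to measure real execution time.
    torch.cuda.synchronize()
    start = time.time()
    out = fn(*args)
    torch.cuda.synchronize()
    return out, time.time() - start


def free_gpu_memory():
    # Drop Python references first, then release cached blocks so a large
    # configuration does not push the next one into an OOM.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
```

In the diff itself the synchronization and cleanup happen inline rather than through helpers, but the ordering is the same: run, synchronize, read the clock, then free everything before the next configuration.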
43 changes: 26 additions & 17 deletions benchmarks/benchmark_forward_performance.py
@@ -131,7 +131,7 @@ def flash_attention_cuda(
"""
_, _, query_len, _ = query_states.shape
_, _, key_len, _ = key_states.shape
if query_len > 32768 or key_len > 32768:
if query_len > 32768 and key_len > 32768:
return "OOM"

query_states = query_states.contiguous()
@@ -186,7 +186,7 @@ def dynamic_mask_attention_cuda(
# Calculate zero_hold_states
zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A)

_, active_mask = prepare_dynamic_mask(
attn_mask, active_mask = prepare_dynamic_mask(
query_states,
zero_hold_states,
keep_window_size,
@@ -201,6 +201,7 @@
zero_hold_states = zero_hold_states[:, :, None, :].expand(
-1, -1, query_states.shape[1], -1
).contiguous() # [batch, num_kv_heads, query_len, key_len]
attn_mask = attn_mask.contiguous() # [batch, num_kv_heads, query_len, key_len]
active_mask = active_mask.contiguous() # [batch, num_kv_heads, query_len, key_len]

try:
@@ -211,7 +212,7 @@
key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim]
value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim]
zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
active_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
attn_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
out_tensor, # out: None to auto-allocate
0.0, # p_dropout
scaling, # softmax_scale
@@ -374,6 +375,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
# Set scaling factor and keep window size
scaling = head_dim ** -0.5
keep_window_size = 2048
is_causal = True

results = {
'config': config,
@@ -396,7 +398,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
for _ in range(warmup_runs):
result = flash_attention_cuda(
query_states, key_states, value_states,
scaling, causal_mask, True
scaling, causal_mask, is_causal
)
if result == "OOM":
results['flash_attention_status'] = 'OOM'
@@ -412,7 +414,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
start_time = time.time()
result = flash_attention_cuda(
query_states, key_states, value_states,
scaling, causal_mask, True
scaling, causal_mask, is_causal
)
torch.cuda.synchronize()
end_time = time.time()
@@ -436,7 +438,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
result = dynamic_mask_attention_cuda(
query_states, key_states, value_states,
dt_proj, A, scaling, causal_mask,
keep_window_size, True
keep_window_size, is_causal
)
if result == "OOM":
results['dynamic_mask_attention_status'] = 'OOM'
@@ -453,7 +455,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
result = dynamic_mask_attention_cuda(
query_states, key_states, value_states,
dt_proj, A, scaling, causal_mask,
keep_window_size, True
keep_window_size, is_causal
)
torch.cuda.synchronize()
end_time = time.time()
@@ -477,7 +479,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
result = dynamic_mask_attention_cuda_no_topk(
query_states, key_states, value_states,
dt_proj, A, scaling, causal_mask,
keep_window_size, True
keep_window_size, is_causal
)
if result == "OOM":
results['dynamic_mask_attention_no_topk_status'] = 'OOM'
@@ -494,7 +496,7 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
result = dynamic_mask_attention_cuda_no_topk(
query_states, key_states, value_states,
dt_proj, A, scaling, causal_mask,
keep_window_size, True
keep_window_size, is_causal
)
torch.cuda.synchronize()
end_time = time.time()
@@ -531,14 +533,18 @@ def run_performance_benchmark():
(1, 2, 1, 32768, 32768, 32),

# Inference
(1, 2, 1, 64, 256, 32),
(1, 2, 1, 64, 512, 32),
(1, 2, 1, 64, 1024, 32),
(1, 2, 1, 64, 2048, 32),
(1, 2, 1, 64, 4096, 32),
(1, 2, 1, 64, 8192, 32),
(1, 2, 1, 64, 16384, 32),
(1, 2, 1, 64, 32768, 32),
(1, 2, 1, 64, 256, 128),
(1, 2, 1, 64, 512, 128),
(1, 2, 1, 64, 1024, 128),
(1, 2, 1, 64, 2048, 128),
(1, 2, 1, 64, 4096, 128),
(1, 2, 1, 64, 8192, 128),
(1, 2, 1, 64, 16384, 128),
(1, 2, 1, 64, 32768, 128),
(1, 2, 1, 64, 65536, 128),
(1, 2, 1, 64, 131072, 128),
(1, 2, 1, 64, 262144, 128),
(1, 2, 1, 64, 524288, 128),

# Vary batch size
(1, 2, 1, 1024, 1024, 32),
@@ -555,7 +561,10 @@ def run_performance_benchmark():
# Vary head dimension
(1, 2, 1, 1024, 1024, 32),
(1, 2, 1, 1024, 1024, 64),
(1, 2, 1, 1024, 1024, 96),
(1, 2, 1, 1024, 1024, 128),
(1, 2, 1, 1024, 1024, 192),
(1, 2, 1, 1024, 1024, 256),
]

num_runs = 3 # Run 3 times and take average
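For each `(batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim)` configuration above, the performance benchmark does a few warmup runs, skips configurations whose wrapper returns `"OOM"`, and then averages `num_runs` synchronized timings. A rough sketch of that warmup-then-measure loop (the helper name `benchmark_callable` is hypothetical and a CUDA device is assumed; the real script additionally records per-implementation status fields in `results`):

```python
import time

import torch


def benchmark_callable(fn, args, num_runs=3, warmup_runs=2):
    # Warmup runs populate caches and trigger any lazy initialization
    # without contaminating the recorded timings.
    for _ in range(warmup_runs):
        if fn(*args) == "OOM":
            return None  # configuration does not fit in memory; skip it
    torch.cuda.synchronize()

    times = []
    for _ in range(num_runs):
        start = time.time()
        fn(*args)
        torch.cuda.synchronize()  # wait for the kernel before stopping the clock
        times.append(time.time() - start)
    return sum(times) / len(times)
```

A call such as `benchmark_callable(flash_attention_cuda, (query_states, key_states, value_states, scaling, causal_mask, is_causal))` mirrors the warmup and timed sections shown in the diff.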