4 changes: 0 additions & 4 deletions .gitmodules
@@ -1,7 +1,3 @@
[submodule "csrc/cutlass"]
	path = csrc/cutlass
	url = https://github.com/NVIDIA/cutlass.git

[submodule "csrc/cub"]
	path = csrc/cub
	url = https://github.com/NVIDIA/cub.git
105 changes: 79 additions & 26 deletions benchmarks/benchmark_forward_equivalence.py
@@ -14,10 +14,17 @@

import torch
import torch.nn.functional as F
import numpy as np
import argparse
import time
from flash_dma_cpp import apply_dynamic_mask_attention # type: ignore

# Import the compiled CUDA extension
try:
    import flash_dma_cuda
    print("✅ Successfully imported flash_dma_cuda")
except ImportError as e:
    print(f"❌ Failed to import flash_dma_cuda: {e}")
    print("Please make sure the package is properly installed with: pip install .")
    exit(1)


def prepare_dynamic_mask(
@@ -45,7 +52,6 @@ def prepare_dynamic_mask(
    attn_mask = dt_states[:, :, None, :].expand(
        -1, -1, hidden_states.shape[2], -1
    )  # [batch_size, num_kv_heads, query_len, key_len]
    active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device)

    if attention_mask is not None:
        if attention_mask.dtype == torch.bool:
@@ -65,7 +71,8 @@
        active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device)
        active_mask = active_mask.scatter(-1, topk_indices, 1.0)
        attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)

    else:
        active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device)
    return attn_mask, active_mask
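
A standalone sketch of the top-k windowing step in the hunk above, for reading outside the diff context. The topk call and the [batch_size, num_kv_heads, query_len, key_len] shape are inferred from the surrounding code rather than shown in this hunk:

import torch

def topk_active_mask(attn_mask: torch.Tensor, keep_window_size: int):
    # attn_mask: additive bias, [batch_size, num_kv_heads, query_len, key_len]
    min_dtype = torch.finfo(attn_mask.dtype).min
    if attn_mask.shape[-1] > keep_window_size:
        # Keep the keep_window_size largest biases per query row (assumed topk call)
        topk_indices = attn_mask.topk(keep_window_size, dim=-1, largest=True, sorted=False).indices
        active_mask = torch.zeros_like(attn_mask)
        active_mask = active_mask.scatter(-1, topk_indices, 1.0)
        # Masked-out positions receive the most negative representable value,
        # so softmax drives their attention weights to ~0
        attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)
    else:
        # Key length already fits in the window: every position stays active
        active_mask = torch.ones_like(attn_mask)
    return attn_mask, active_mask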


@@ -140,10 +147,8 @@ def dynamic_mask_attention_python(
    Returns:
        attn_outputs: [batch_size, query_len, num_heads, head_dim]
    """
    batch_size, num_heads, query_len, head_dim = query_states.shape
    _, num_kv_heads, key_len, _ = key_states.shape
    device = query_states.device
    dtype = query_states.dtype
    _, num_heads, _, _ = query_states.shape
    _, num_kv_heads, _, _ = key_states.shape

    num_queries_per_kv = num_heads // num_kv_heads

@@ -201,17 +206,20 @@ def dynamic_mask_attention_cuda(
    Returns:
        attn_outputs: [batch_size, query_len, num_heads, head_dim]
    """

    # Calculate zero_hold_states
    zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask)

    # Use prepare_dynamic_mask to get the processed attention mask
    _, active_mask = prepare_dynamic_mask(
        query_states,
        zero_hold_states,
        zero_hold_states,
        keep_window_size,
        causal_mask if is_causal else None
    )  # [batch_size, num_kv_heads, query_len, key_len]

    # Ensure correct data types and memory layout
    # Ensure correct data types and memory layout for CUDA function
    # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format
    query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
    key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_kv_heads, head_dim]
    value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_kv_heads, head_dim]
@@ -220,20 +228,25 @@
    ).contiguous()  # [batch, num_kv_heads, query_len, key_len]
    active_mask = active_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]

    result = apply_dynamic_mask_attention(
        query_states=query_states,
        key_states=key_states,
        value_states=value_states,
        zoh_states=zero_hold_states,
        active_mask=active_mask,
        scale=scaling,
        keep_window_size=keep_window_size,
        is_causal=is_causal,
        return_softmax=return_softmax
    # Call the CUDA implementation using the mha_fwd function signature
    out_tensor = None  # Let the function allocate the output tensor
    result = flash_dma_cuda.fwd(  # type: ignore
        query_states,      # q: [batch, seqlen_q, num_heads, head_dim]
        key_states,        # k: [batch, seqlen_k, num_kv_heads, head_dim]
        value_states,      # v: [batch, seqlen_k, num_kv_heads, head_dim]
        zero_hold_states,  # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
        active_mask,       # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
        out_tensor,        # out: None to auto-allocate
        0.0,               # p_dropout
        scaling,           # softmax_scale
        is_causal,         # is_causal
        keep_window_size,  # keep_window_size
        0.0,               # softcap
        return_softmax,    # return_softmax
        None               # gen (generator)
    )

    # Convert result back to original data type
    attn_outputs = result[0]
    attn_outputs = result[0]  # [batch, query_len, num_heads, head_dim]
    return attn_outputs
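
Based on the argument comments above, a direct call to the new flash_dma_cuda.fwd entry point looks roughly like this; the bfloat16 dtype, random inputs, and zero zoh bias are illustrative assumptions, not part of the change:

import torch
import flash_dma_cuda  # compiled extension built by this PR

batch, seqlen_q, seqlen_k = 1, 64, 64
num_heads, num_kv_heads, head_dim = 1, 1, 32
kw = dict(device="cuda", dtype=torch.bfloat16)
q = torch.randn(batch, seqlen_q, num_heads, head_dim, **kw)
k = torch.randn(batch, seqlen_k, num_kv_heads, head_dim, **kw)
v = torch.randn(batch, seqlen_k, num_kv_heads, head_dim, **kw)
zoh = torch.zeros(batch, num_kv_heads, seqlen_q, seqlen_k, **kw)
active_mask = torch.ones_like(zoh)

result = flash_dma_cuda.fwd(
    q, k, v, zoh, active_mask,
    None,              # out: auto-allocate
    0.0,               # p_dropout
    head_dim ** -0.5,  # softmax_scale
    True,              # is_causal
    64,                # keep_window_size
    0.0,               # softcap
    False,             # return_softmax
    None,              # gen
)
attn_out = result[0]  # [batch, seqlen_q, num_heads, head_dim]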


@@ -343,9 +356,12 @@ def test_forward_equivalence(accuracy_threshold=0.95):
        # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
        (1, 1, 1, 64, 64, 32, True),     # Small scale test, causal mask
        (1, 1, 1, 64, 64, 32, False),    # Small scale test, non-causal mask
        (1, 2, 1, 128, 128, 32, True),   # Medium scale test, GQA mode
        (1, 1, 1, 128, 128, 32, True),   # Medium scale test, causal mask
        (1, 1, 1, 128, 128, 32, False),  # Medium scale test, non-causal mask
        (1, 1, 1, 256, 256, 32, True),   # Large scale test, causal mask
        (1, 2, 1, 64, 64, 32, True),     # Medium scale test, GQA mode
        (2, 1, 1, 128, 128, 32, True),   # Medium scale test, multi-batch
        (2, 2, 1, 128, 128, 32, True),   # Medium scale test, multi-batch GQA mode
    ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -456,6 +472,7 @@ def main():
    of dynamic mask attention.

    This script validates numerical consistency including:
    - Standard forward pass (fwd)
    - Different batch sizes, head counts, sequence lengths and dimensions
    - Causal and non-causal mask options
    - Numerical equivalence analysis
@@ -470,24 +487,60 @@
    parser.add_argument('--verbose', action='store_true', help='Verbose output')
    parser.add_argument('--accuracy-threshold', type=float, default=0.95,
                        help='Minimum accuracy ratio to pass test (default: 0.95)')
    parser.add_argument('--test-type', type=str, default='all',
                        choices=['all', 'fwd'],
                        help='Type of test to run (default: all)')

    args = parser.parse_args()

    # Set random seed
    torch.manual_seed(args.seed)
    # Print test environment information
    print("🧬" + "=" * 78 + "🧬")
    print("🔬 Dynamic Mask Attention Forward Pass Equivalence Test Suite 🔬")
    print("🧬" + "=" * 78 + "🧬")
    print(f"🐍 PyTorch version: {torch.__version__}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_icon = "🔥" if device.type == "cuda" else "💻"
    print(f"{device_icon} Device: {device}")

    if torch.cuda.is_available():
        print(f"🎮 CUDA device: {torch.cuda.get_device_name()}")
    print(f"🎲 Random seed: {args.seed}")
    print(f"🎲 Random seed: {args.seed}")
    print(f"📊 Test type: {args.test_type}")
    print(f"🎯 Accuracy threshold: {args.accuracy_threshold*100:.1f}%")

    # Run equivalence test
    test_forward_equivalence(args.accuracy_threshold)
    # Track overall test results
    test_results = {}

    # Run tests based on user selection
    if args.test_type in ['all', 'fwd']:
        print("\n" + "📍" + " Starting Standard Forward Pass Tests " + "📍")
        test_results['fwd'] = test_forward_equivalence(args.accuracy_threshold)

    # Print overall summary
    print("\n" + "🏆" + "=" * 78 + "🏆")
    print("🔬 FINAL TEST SUMMARY 🔬")
    print("🏆" + "=" * 78 + "🏆")

    all_passed = True
    for test_name, result in test_results.items():
        status_icon = "✅" if result else "❌"
        status_text = "PASSED" if result else "FAILED"
        print(f" {status_icon} {test_name.upper():12} : {status_text}")
        all_passed = all_passed and result

    # Overall result
    overall_icon = "🎉" if all_passed else "😞"
    overall_text = "ALL TESTS PASSED" if all_passed else "SOME TESTS FAILED"
    print(f"\n{overall_icon} OVERALL RESULT: {overall_text}")
    print("🏆" + "=" * 78 + "🏆")

    # Exit with appropriate code
    import sys
    sys.exit(0 if all_passed else 1)


if __name__ == "__main__":
32 changes: 0 additions & 32 deletions csrc/apply_dynamic_mask_api.cpp

This file was deleted.

67 changes: 0 additions & 67 deletions csrc/apply_dynamic_mask_attention_api.cpp

This file was deleted.
