benchmarks/backward_equivalence.py (12 additions, 11 deletions)

@@ -88,11 +88,12 @@ def prepare_dynamic_mask(
     )
 
     if attn_bias.shape[-1] > keep_window_size:
-        topk_indices = torch.topk(
+        topk_values, topk_indices = torch.topk(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
-        ).indices
+        )
+        valid_topk = topk_values != min_dtype
         attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, 1.0)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
         attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
     else:
         attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
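The scatter change above replaces the hard-coded 1.0 with valid_topk.to(dtype), so a top-k slot is only activated when its bias value is not already the min_dtype sentinel. Previously, a row whose positions were all masked could still have keep_window_size entries re-enabled, simply because torch.topk always returns k indices. A minimal self-contained sketch of the new behaviour (the function name and shapes are illustrative, not the benchmark's actual code):

import torch

def topk_mask(attn_bias: torch.Tensor, keep_window_size: int) -> torch.Tensor:
    # Keep the top-k bias entries per row, but only those that are not already
    # at the "fully masked" sentinel value (the minimum of the dtype).
    min_dtype = torch.finfo(attn_bias.dtype).min
    topk_values, topk_indices = torch.topk(
        attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
    )
    valid_topk = topk_values != min_dtype  # False where top-k picked a masked slot
    attn_mask = torch.zeros_like(attn_bias)
    return attn_mask.scatter(-1, topk_indices, valid_topk.to(attn_bias.dtype))

# A fully masked row stays all-zero instead of getting keep_window_size
# spurious ones (which is what the old scatter of 1.0 produced):
bias = torch.full((1, 1, 1, 8), torch.finfo(torch.float32).min)
print(topk_mask(bias, keep_window_size=4))  # all zeros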
@@ -561,19 +562,19 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         (1, 2, 1, 256, 256, 32, False),
         (1, 2, 1, 512, 512, 32, True),
         (1, 2, 1, 512, 512, 32, False),
-        (1, 2, 1, 1024, 1024, 32, True), # some -Inf and Inf in dbias, Idk why
+        (1, 2, 1, 1024, 1024, 32, True),
         (1, 2, 1, 1024, 1024, 32, False),
         (1, 2, 1, 2048, 2048, 32, True),
         (1, 2, 1, 2048, 2048, 32, False),
-        (1, 2, 1, 4096, 4096, 32, True), # some NAN in dbias, Idk why
+        (1, 2, 1, 4096, 4096, 32, True),
         (1, 2, 1, 4096, 4096, 32, False),
         (1, 2, 1, 128, 128, 64, True),
         (1, 2, 1, 128, 128, 64, False),
-        (1, 2, 1, 256, 256, 64, True), # some NAN in dbias, Idk why
+        (1, 2, 1, 256, 256, 64, True),
         (1, 2, 1, 256, 256, 64, False),
         (1, 2, 1, 512, 512, 64, True),
         (1, 2, 1, 512, 512, 64, False),
-        (1, 2, 1, 1024, 1024, 64, True), # some NAN in dbias, Idk why
+        (1, 2, 1, 1024, 1024, 64, True), # some INF in dbias, Idk why
         (1, 2, 1, 1024, 1024, 64, False),
         (1, 2, 1, 2048, 2048, 64, True),
         (1, 2, 1, 2048, 2048, 64, False),
@@ -585,26 +586,26 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         (1, 2, 1, 256, 256, 96, False),
         (1, 2, 1, 512, 512, 96, True),
         (1, 2, 1, 512, 512, 96, False),
-        (1, 2, 1, 1024, 1024, 96, True), # some NAN in dbias, Idk why
+        (1, 2, 1, 1024, 1024, 96, True), # some INF in dbias, Idk why
         (1, 2, 1, 1024, 1024, 96, False),
         (1, 2, 1, 2048, 2048, 96, True),
         (1, 2, 1, 2048, 2048, 96, False),
         (1, 2, 1, 4096, 4096, 96, True),
         (1, 2, 1, 4096, 4096, 96, False),
-        (1, 2, 1, 128, 128, 128, True), # some NAN in dbias, Idk why
+        (1, 2, 1, 128, 128, 128, True),
         (1, 2, 1, 128, 128, 128, True),
         (1, 2, 1, 256, 256, 128, True),
         (1, 2, 1, 256, 256, 128, False),
         (1, 2, 1, 512, 512, 128, True),
         (1, 2, 1, 512, 512, 128, False),
-        (1, 2, 1, 1024, 1024, 128, True), # some NAN in dbias, Idk why
+        (1, 2, 1, 1024, 1024, 128, True), # some INF in dbias, Idk why
         (1, 2, 1, 1024, 1024, 128, False),
         (1, 2, 1, 2048, 2048, 128, True),
         (1, 2, 1, 2048, 2048, 128, False),
         (1, 2, 1, 4096, 4096, 128, True),
         (1, 2, 1, 4096, 4096, 128, False),
 
-        # Not support head_dim > 128 yet in sm 80
+        # Not support head_dim > 128 in sm80 yet
         # (1, 2, 1, 128, 128, 256, True),
         # (1, 2, 1, 128, 128, 128, False),
         # (1, 2, 1, 256, 256, 256, True),
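Several configurations above are annotated with NaN/Inf observations in dbias. If the benchmark should flag the offending configuration automatically, a small check on the bias gradient is enough; a hedged sketch (check_grad_health is a hypothetical helper, and dbias stands for whichever bias-gradient tensor the backward pass under test returns):

import torch

def check_grad_health(name: str, grad: torch.Tensor) -> None:
    # Count non-finite entries so problematic (seq_len, head_dim, is_causal)
    # configurations stand out in the benchmark output.
    nan_count = torch.isnan(grad).sum().item()
    inf_count = torch.isinf(grad).sum().item()
    if nan_count or inf_count:
        print(f"{name}: {nan_count} NaN, {inf_count} Inf out of {grad.numel()} elements")

# Hypothetical usage inside the backward-equivalence loop:
# check_grad_health("dbias", attn_bias.grad)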
benchmarks/forward_equivalence.py (195 additions, 65 deletions)

@@ -88,11 +88,12 @@ def prepare_dynamic_mask(
     )
 
     if attn_bias.shape[-1] > keep_window_size:
-        topk_indices = torch.topk(
+        topk_values, topk_indices = torch.topk(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
-        ).indices
+        )
+        valid_topk = topk_values != min_dtype
         attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, 1.0)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
         attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
     else:
         attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
@@ -518,28 +519,70 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
# If you encounter NAN issues when running multiple configurations, try running a single configuration
test_configs = [
# (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
(1, 1, 1, 64, 64, 32, True),
(1, 1, 1, 64, 64, 32, False),
(1, 1, 1, 128, 128, 32, True),
(1, 1, 1, 128, 128, 32, False),
(1, 1, 1, 256, 256, 32, True),
(1, 1, 1, 256, 256, 32, False),
(1, 1, 1, 512, 512, 32, True),
(1, 1, 1, 512, 512, 32, False),
(1, 1, 1, 1024, 1024, 32, True),
(1, 1, 1, 1024, 1024, 32, False),
(1, 1, 1, 2048, 2048, 32, True),
(1, 1, 1, 2048, 2048, 32, False),
(1, 1, 1, 4096, 4096, 32, True),
(1, 1, 1, 4096, 4096, 32, False),
(1, 2, 1, 64, 64, 32, True),
(2, 1, 1, 128, 128, 32, True),
(2, 2, 1, 128, 128, 32, True),
(1, 2, 1, 64, 64, 128, True),
(1, 2, 1, 128, 128, 32, True),
(1, 2, 1, 128, 128, 32, False),
(1, 2, 1, 256, 256, 32, True),
(1, 2, 1, 256, 256, 32, False),
(1, 2, 1, 512, 512, 32, True),
(1, 2, 1, 512, 512, 32, False),
(1, 2, 1, 1024, 1024, 32, True),
(1, 2, 1, 1024, 1024, 32, False),
(1, 2, 1, 2048, 2048, 32, True),
(1, 2, 1, 2048, 2048, 32, False),
(1, 2, 1, 4096, 4096, 32, True),
(1, 2, 1, 4096, 4096, 32, False),

(1, 2, 1, 128, 128, 64, True),
(1, 2, 1, 128, 128, 64, False),
(1, 2, 1, 256, 256, 64, True),
(1, 2, 1, 256, 256, 64, False),
(1, 2, 1, 512, 512, 64, True),
(1, 2, 1, 512, 512, 64, False),
(1, 2, 1, 1024, 1024, 64, True),
(1, 2, 1, 1024, 1024, 64, False),
(1, 2, 1, 2048, 2048, 64, True),
(1, 2, 1, 2048, 2048, 64, False),
(1, 2, 1, 4096, 4096, 64, True),
(1, 2, 1, 4096, 4096, 64, False),

(1, 2, 1, 128, 128, 96, True),
(1, 2, 1, 128, 128, 96, False),
(1, 2, 1, 256, 256, 96, True),
(1, 2, 1, 256, 256, 96, False),
(1, 2, 1, 512, 512, 96, True),
(1, 2, 1, 512, 512, 96, False),
(1, 2, 1, 1024, 1024, 96, True),
(1, 2, 1, 1024, 1024, 96, False),
(1, 2, 1, 2048, 2048, 96, True),
(1, 2, 1, 2048, 2048, 96, False),
(1, 2, 1, 4096, 4096, 96, True),
(1, 2, 1, 4096, 4096, 96, False),

(1, 2, 1, 128, 128, 128, True),
(1, 2, 1, 128, 128, 128, True),
[Review comment: Copilot AI, Aug 26, 2025]
Duplicate test configuration: (1, 2, 1, 128, 128, 128, True) appears twice in the
test_configs list, which is redundant and increases test execution time unnecessarily.
Suggested change: remove the duplicated line (a deduplication sketch follows this list).
(1, 2, 1, 256, 256, 128, True),
(1, 2, 1, 3, 512, 128, True),
(1, 2, 1, 1, 512, 128, True),
(1, 2, 1, 256, 256, 128, False),
(1, 2, 1, 512, 512, 128, True),
(1, 2, 1, 512, 512, 128, False),
(1, 2, 1, 1024, 1024, 128, True),
(1, 2, 1, 1024, 1024, 128, False),
(1, 2, 1, 2048, 2048, 128, True),
(1, 2, 1, 2048, 2048, 128, False),
(1, 2, 1, 4096, 4096, 128, True),
(1, 2, 1, 4096, 4096, 128, False),

(1, 2, 1, 128, 128, 128, True),
(1, 2, 1, 128, 128, 128, False),
(1, 2, 1, 256, 256, 256, True),
(1, 2, 1, 256, 256, 256, False),
(1, 2, 1, 512, 512, 256, True),
(1, 2, 1, 512, 512, 256, False),
(1, 2, 1, 1024, 1024, 256, True),
(1, 2, 1, 1024, 1024, 256, False),
(1, 2, 1, 2048, 2048, 256, True),
(1, 2, 1, 2048, 2048, 256, False),
(1, 2, 1, 4096, 4096, 256, True),
(1, 2, 1, 4096, 4096, 256, False),
]
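As the review comment above points out, the list contains duplicate entries. Since each configuration is a plain tuple, an order-preserving deduplication pass just after the list is defined would remove them; a minimal sketch:

# Drop exact duplicate configurations while preserving order; tuples are
# hashable, so dict.fromkeys keeps the first occurrence of each config.
test_configs = list(dict.fromkeys(test_configs))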

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -672,27 +715,71 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95):
# If you encounter NAN issues when running multiple configurations, try running a single configuration
test_configs = [
# (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
(1, 1, 1, 64, 64, 32, True),
(1, 1, 1, 64, 64, 32, False),
(1, 1, 1, 128, 128, 32, True),
(1, 1, 1, 128, 128, 32, False),
(1, 1, 1, 256, 256, 32, True),
(1, 1, 1, 256, 256, 32, False),
(1, 1, 1, 512, 512, 32, True),
(1, 1, 1, 512, 512, 32, False),
(1, 1, 1, 1024, 1024, 32, True),
(1, 1, 1, 1024, 1024, 32, False),
(1, 1, 1, 2048, 2048, 32, True),
(1, 1, 1, 2048, 2048, 32, False),
(1, 1, 1, 4096, 4096, 32, True),
(1, 1, 1, 4096, 4096, 32, False),
(1, 2, 1, 64, 64, 32, True),
(2, 1, 1, 128, 128, 32, True),
(2, 2, 1, 128, 128, 32, True),
(1, 2, 1, 64, 64, 128, True),
(1, 2, 1, 128, 128, 32, True),
(1, 2, 1, 128, 128, 32, False),
(1, 2, 1, 256, 256, 32, True),
(1, 2, 1, 256, 256, 32, False),
(1, 2, 1, 512, 512, 32, True),
(1, 2, 1, 512, 512, 32, False),
(1, 2, 1, 1024, 1024, 32, True),
(1, 2, 1, 1024, 1024, 32, False),
(1, 2, 1, 2048, 2048, 32, True),
(1, 2, 1, 2048, 2048, 32, False),
(1, 2, 1, 4096, 4096, 32, True),
(1, 2, 1, 4096, 4096, 32, False),

(1, 2, 1, 128, 128, 64, True),
(1, 2, 1, 128, 128, 64, False),
(1, 2, 1, 256, 256, 64, True),
(1, 2, 1, 256, 256, 64, False),
(1, 2, 1, 512, 512, 64, True),
(1, 2, 1, 512, 512, 64, False),
(1, 2, 1, 1024, 1024, 64, True),
(1, 2, 1, 1024, 1024, 64, False),
(1, 2, 1, 2048, 2048, 64, True),
(1, 2, 1, 2048, 2048, 64, False),
(1, 2, 1, 4096, 4096, 64, True),
(1, 2, 1, 4096, 4096, 64, False),

(1, 2, 1, 128, 128, 96, True),
(1, 2, 1, 128, 128, 96, False),
(1, 2, 1, 256, 256, 96, True),
(1, 2, 1, 256, 256, 96, False),
(1, 2, 1, 512, 512, 96, True),
(1, 2, 1, 512, 512, 96, False),
(1, 2, 1, 1024, 1024, 96, True),
(1, 2, 1, 1024, 1024, 96, False),
(1, 2, 1, 2048, 2048, 96, True),
(1, 2, 1, 2048, 2048, 96, False),
(1, 2, 1, 4096, 4096, 96, True),
(1, 2, 1, 4096, 4096, 96, False),

(1, 2, 1, 128, 128, 128, True),
(1, 2, 1, 128, 128, 128, True),
[Review comment: Copilot AI, Aug 26, 2025]
Duplicate test configuration in the Triton test function: (1, 2, 1, 128, 128, 128, True)
appears twice in the test_configs list, which is redundant and increases test execution
time unnecessarily. Suggested change: remove the duplicated line.
(1, 2, 1, 256, 256, 128, True),
(1, 2, 1, 256, 256, 128, False),
(1, 2, 1, 512, 512, 128, True),
(1, 2, 1, 512, 512, 128, False),
(1, 2, 1, 1024, 1024, 128, True),
(1, 2, 1, 1024, 1024, 128, False),
(1, 2, 1, 2048, 2048, 128, True),
(1, 2, 1, 2048, 2048, 128, False),
(1, 2, 1, 4096, 4096, 128, True),
(1, 2, 1, 4096, 4096, 128, False),

# Not support head_dim > 128 in triton yet
# (1, 2, 1, 128, 128, 128, True),
# (1, 2, 1, 128, 128, 128, False),
# (1, 2, 1, 256, 256, 256, True),
# (1, 2, 1, 256, 256, 256, False),
# (1, 2, 1, 512, 512, 256, True),
# (1, 2, 1, 512, 512, 256, False),
# (1, 2, 1, 1024, 1024, 256, True),
# (1, 2, 1, 1024, 1024, 256, False),
# (1, 2, 1, 2048, 2048, 256, True),
# (1, 2, 1, 2048, 2048, 256, False),
# (1, 2, 1, 4096, 4096, 256, True),
# (1, 2, 1, 4096, 4096, 256, False),
]
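The commented-out entries above are excluded because the Triton path does not yet support head_dim > 128. If the configurations were generated rather than hand-written, the unsupported cases could be filtered programmatically instead of commented out; a hedged sketch (the 128 limit is taken from the comment above, and index 5 is the head_dim field of the tuple layout documented at the top of the list):

TRITON_MAX_HEAD_DIM = 128  # per the comment above; adjust if kernel support grows

test_configs = [
    cfg for cfg in test_configs
    if cfg[5] <= TRITON_MAX_HEAD_DIM  # cfg[5] is head_dim
]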

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -843,27 +930,70 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95):
# Test configurations for Flex Attention
test_configs = [
# (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
(1, 1, 1, 64, 64, 32, True),
(1, 1, 1, 64, 64, 32, False),
(1, 1, 1, 128, 128, 32, True),
(1, 1, 1, 128, 128, 32, False),
(1, 1, 1, 256, 256, 32, True),
(1, 1, 1, 256, 256, 32, False),
(1, 1, 1, 512, 512, 32, True),
(1, 1, 1, 512, 512, 32, False),
(1, 1, 1, 1024, 1024, 32, True),
(1, 1, 1, 1024, 1024, 32, False),
(1, 1, 1, 2048, 2048, 32, True),
(1, 1, 1, 2048, 2048, 32, False),
(1, 1, 1, 4096, 4096, 32, True),
(1, 1, 1, 4096, 4096, 32, False),
(1, 2, 1, 64, 64, 32, True),
(2, 1, 1, 128, 128, 32, True),
(2, 2, 1, 128, 128, 32, True),
(1, 2, 1, 64, 64, 128, True),
(1, 2, 1, 128, 128, 32, True),
(1, 2, 1, 128, 128, 32, False),
(1, 2, 1, 256, 256, 32, True),
(1, 2, 1, 256, 256, 32, False),
(1, 2, 1, 512, 512, 32, True),
(1, 2, 1, 512, 512, 32, False),
(1, 2, 1, 1024, 1024, 32, True),
(1, 2, 1, 1024, 1024, 32, False),
(1, 2, 1, 2048, 2048, 32, True),
(1, 2, 1, 2048, 2048, 32, False),
(1, 2, 1, 4096, 4096, 32, True),
(1, 2, 1, 4096, 4096, 32, False),

(1, 2, 1, 128, 128, 64, True),
(1, 2, 1, 128, 128, 64, False),
(1, 2, 1, 256, 256, 64, True),
(1, 2, 1, 256, 256, 64, False),
(1, 2, 1, 512, 512, 64, True),
(1, 2, 1, 512, 512, 64, False),
(1, 2, 1, 1024, 1024, 64, True),
(1, 2, 1, 1024, 1024, 64, False),
(1, 2, 1, 2048, 2048, 64, True),
(1, 2, 1, 2048, 2048, 64, False),
(1, 2, 1, 4096, 4096, 64, True),
(1, 2, 1, 4096, 4096, 64, False),

(1, 2, 1, 128, 128, 96, True),
(1, 2, 1, 128, 128, 96, False),
(1, 2, 1, 256, 256, 96, True),
(1, 2, 1, 256, 256, 96, False),
(1, 2, 1, 512, 512, 96, True),
(1, 2, 1, 512, 512, 96, False),
(1, 2, 1, 1024, 1024, 96, True),
(1, 2, 1, 1024, 1024, 96, False),
(1, 2, 1, 2048, 2048, 96, True),
(1, 2, 1, 2048, 2048, 96, False),
(1, 2, 1, 4096, 4096, 96, True),
(1, 2, 1, 4096, 4096, 96, False),

(1, 2, 1, 128, 128, 128, True),
(1, 2, 1, 128, 128, 128, True),
[Review comment: Copilot AI, Aug 26, 2025]
Duplicate test configuration in the flex test function: (1, 2, 1, 128, 128, 128, True)
appears twice in the test_configs list, which is redundant and increases test execution
time unnecessarily. Suggested change: remove the duplicated line.
(1, 2, 1, 256, 256, 128, True),
(1, 2, 1, 256, 256, 128, False),
(1, 2, 1, 512, 512, 128, True),
(1, 2, 1, 512, 512, 128, False),
(1, 2, 1, 1024, 1024, 128, True),
(1, 2, 1, 1024, 1024, 128, False),
(1, 2, 1, 2048, 2048, 128, True),
(1, 2, 1, 2048, 2048, 128, False),
(1, 2, 1, 4096, 4096, 128, True),
(1, 2, 1, 4096, 4096, 128, False),

(1, 2, 1, 128, 128, 128, True),
(1, 2, 1, 128, 128, 128, False),
[Review comment: Copilot AI, Aug 26, 2025, on lines +985 to +986]
Duplicate test configuration in the flex test function: (1, 2, 1, 128, 128, 128, True)
appears again, and there is also a duplicate (1, 2, 1, 128, 128, 128, False); both should
be removed to avoid redundant testing.
(1, 2, 1, 256, 256, 256, True),
(1, 2, 1, 256, 256, 256, False),
(1, 2, 1, 512, 512, 256, True),
(1, 2, 1, 512, 512, 256, False),
(1, 2, 1, 1024, 1024, 256, True),
(1, 2, 1, 1024, 1024, 256, False),
(1, 2, 1, 2048, 2048, 256, True),
(1, 2, 1, 2048, 2048, 256, False),
(1, 2, 1, 4096, 4096, 256, True),
(1, 2, 1, 4096, 4096, 256, False),
]
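For context on what the flex test exercises: PyTorch's flex_attention expresses the same causal-plus-additive-bias attention through a score_mod and a block mask. A minimal hedged sketch, assuming PyTorch 2.5+ and purely illustrative shapes (this is not the benchmark's actual implementation):

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, Q, K, D = 1, 2, 128, 128, 64
q = torch.randn(B, H, Q, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, K, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, K, D, device="cuda", dtype=torch.float16)
attn_bias = torch.randn(B, H, Q, K, device="cuda", dtype=torch.float16)

def bias_mod(score, b, h, q_idx, kv_idx):
    # Add the dynamic attention bias to each score before softmax.
    return score + attn_bias[b, h, q_idx, kv_idx]

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, B, H, Q, K, device="cuda")
out = flex_attention(q, k, v, score_mod=bias_mod, block_mask=block_mask)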

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -1051,13 +1181,13 @@ def main():
print("\n" + "📍" + " Starting Standard Forward Pass Tests " + "📍")
test_results['cuda'] = test_cuda_forward_equivalence(args.accuracy_threshold)

-    # if args.test_type in ['all', 'triton']:
-    #     print("\n" + "🔥" + " Starting Python vs Triton Tests " + "🔥")
-    #     test_results['triton'] = test_triton_forward_equivalence(args.accuracy_threshold)
+    if args.test_type in ['all', 'triton']:
+        print("\n" + "🔥" + " Starting Python vs Triton Tests " + "🔥")
+        test_results['triton'] = test_triton_forward_equivalence(args.accuracy_threshold)
 
-    # if args.test_type in ['all', 'flex']:
-    #     print("\n" + "🌟" + " Starting Python vs Flex Attention Tests " + "🌟")
-    #     test_results['flex'] = test_flex_forward_equivalence(args.accuracy_threshold)
+    if args.test_type in ['all', 'flex']:
+        print("\n" + "🌟" + " Starting Python vs Flex Attention Tests " + "🌟")
+        test_results['flex'] = test_flex_forward_equivalence(args.accuracy_threshold)


# Print overall summary
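With the Triton and Flex branches re-enabled above, main() collects one result per backend in test_results. Assuming each test_*_forward_equivalence function returns a pass/fail boolean (as the accuracy_threshold argument suggests), the overall summary could be printed along these lines; a hedged sketch:

# Hypothetical overall-summary report over the per-backend results gathered above.
all_passed = all(test_results.values())
for backend, passed in sorted(test_results.items()):
    print(f"{backend:>6}: {'PASS' if passed else 'FAIL'}")
print("Overall:", "PASS" if all_passed else "FAIL")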
benchmarks/forward_performance.py (4 additions, 3 deletions)

@@ -110,11 +110,12 @@ def prepare_dynamic_mask(
     )
 
     if attn_bias.shape[-1] > keep_window_size:
-        topk_indices = torch.topk(
+        topk_values, topk_indices = torch.topk(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
-        ).indices
+        )
+        valid_topk = topk_values != min_dtype
         attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, 1.0)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
         attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
     else:
         attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)