-
Notifications
You must be signed in to change notification settings - Fork 45
Consolidate mask and bias memory operations #129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
395e9a4
29912d4
56c7c9b
834e334
4c7c27f
56d25fb
0899062
14f45ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -88,11 +88,12 @@ def prepare_dynamic_mask( | |||||||||
| ) | ||||||||||
|
|
||||||||||
| if attn_bias.shape[-1] > keep_window_size: | ||||||||||
| topk_indices = torch.topk( | ||||||||||
| topk_values, topk_indices = torch.topk( | ||||||||||
| attn_bias, keep_window_size, dim=-1, largest=True, sorted=False | ||||||||||
| ).indices | ||||||||||
| ) | ||||||||||
| valid_topk = topk_values != min_dtype | ||||||||||
| attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) | ||||||||||
| attn_mask = attn_mask.scatter(-1, topk_indices, 1.0) | ||||||||||
| attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype)) | ||||||||||
| attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) | ||||||||||
| else: | ||||||||||
| attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) | ||||||||||
|
|
@@ -518,28 +519,70 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): | |||||||||
| # If you encounter NAN issues when running multiple configurations, try running a single configuration | ||||||||||
| test_configs = [ | ||||||||||
| # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) | ||||||||||
| (1, 1, 1, 64, 64, 32, True), | ||||||||||
| (1, 1, 1, 64, 64, 32, False), | ||||||||||
| (1, 1, 1, 128, 128, 32, True), | ||||||||||
| (1, 1, 1, 128, 128, 32, False), | ||||||||||
| (1, 1, 1, 256, 256, 32, True), | ||||||||||
| (1, 1, 1, 256, 256, 32, False), | ||||||||||
| (1, 1, 1, 512, 512, 32, True), | ||||||||||
| (1, 1, 1, 512, 512, 32, False), | ||||||||||
| (1, 1, 1, 1024, 1024, 32, True), | ||||||||||
| (1, 1, 1, 1024, 1024, 32, False), | ||||||||||
| (1, 1, 1, 2048, 2048, 32, True), | ||||||||||
| (1, 1, 1, 2048, 2048, 32, False), | ||||||||||
| (1, 1, 1, 4096, 4096, 32, True), | ||||||||||
| (1, 1, 1, 4096, 4096, 32, False), | ||||||||||
| (1, 2, 1, 64, 64, 32, True), | ||||||||||
| (2, 1, 1, 128, 128, 32, True), | ||||||||||
| (2, 2, 1, 128, 128, 32, True), | ||||||||||
| (1, 2, 1, 64, 64, 128, True), | ||||||||||
| (1, 2, 1, 128, 128, 32, True), | ||||||||||
| (1, 2, 1, 128, 128, 32, False), | ||||||||||
| (1, 2, 1, 256, 256, 32, True), | ||||||||||
| (1, 2, 1, 256, 256, 32, False), | ||||||||||
| (1, 2, 1, 512, 512, 32, True), | ||||||||||
| (1, 2, 1, 512, 512, 32, False), | ||||||||||
| (1, 2, 1, 1024, 1024, 32, True), | ||||||||||
| (1, 2, 1, 1024, 1024, 32, False), | ||||||||||
| (1, 2, 1, 2048, 2048, 32, True), | ||||||||||
| (1, 2, 1, 2048, 2048, 32, False), | ||||||||||
| (1, 2, 1, 4096, 4096, 32, True), | ||||||||||
| (1, 2, 1, 4096, 4096, 32, False), | ||||||||||
|
|
||||||||||
| (1, 2, 1, 128, 128, 64, True), | ||||||||||
| (1, 2, 1, 128, 128, 64, False), | ||||||||||
| (1, 2, 1, 256, 256, 64, True), | ||||||||||
| (1, 2, 1, 256, 256, 64, False), | ||||||||||
| (1, 2, 1, 512, 512, 64, True), | ||||||||||
| (1, 2, 1, 512, 512, 64, False), | ||||||||||
| (1, 2, 1, 1024, 1024, 64, True), | ||||||||||
| (1, 2, 1, 1024, 1024, 64, False), | ||||||||||
| (1, 2, 1, 2048, 2048, 64, True), | ||||||||||
| (1, 2, 1, 2048, 2048, 64, False), | ||||||||||
| (1, 2, 1, 4096, 4096, 64, True), | ||||||||||
| (1, 2, 1, 4096, 4096, 64, False), | ||||||||||
|
|
||||||||||
| (1, 2, 1, 128, 128, 96, True), | ||||||||||
| (1, 2, 1, 128, 128, 96, False), | ||||||||||
| (1, 2, 1, 256, 256, 96, True), | ||||||||||
| (1, 2, 1, 256, 256, 96, False), | ||||||||||
| (1, 2, 1, 512, 512, 96, True), | ||||||||||
| (1, 2, 1, 512, 512, 96, False), | ||||||||||
| (1, 2, 1, 1024, 1024, 96, True), | ||||||||||
| (1, 2, 1, 1024, 1024, 96, False), | ||||||||||
| (1, 2, 1, 2048, 2048, 96, True), | ||||||||||
| (1, 2, 1, 2048, 2048, 96, False), | ||||||||||
| (1, 2, 1, 4096, 4096, 96, True), | ||||||||||
| (1, 2, 1, 4096, 4096, 96, False), | ||||||||||
|
|
||||||||||
| (1, 2, 1, 128, 128, 128, True), | ||||||||||
| (1, 2, 1, 128, 128, 128, True), | ||||||||||
| (1, 2, 1, 256, 256, 128, True), | ||||||||||
| (1, 2, 1, 3, 512, 128, True), | ||||||||||
| (1, 2, 1, 1, 512, 128, True), | ||||||||||
| (1, 2, 1, 256, 256, 128, False), | ||||||||||
| (1, 2, 1, 512, 512, 128, True), | ||||||||||
| (1, 2, 1, 512, 512, 128, False), | ||||||||||
| (1, 2, 1, 1024, 1024, 128, True), | ||||||||||
| (1, 2, 1, 1024, 1024, 128, False), | ||||||||||
| (1, 2, 1, 2048, 2048, 128, True), | ||||||||||
| (1, 2, 1, 2048, 2048, 128, False), | ||||||||||
| (1, 2, 1, 4096, 4096, 128, True), | ||||||||||
| (1, 2, 1, 4096, 4096, 128, False), | ||||||||||
|
|
||||||||||
| (1, 2, 1, 128, 128, 128, True), | ||||||||||
| (1, 2, 1, 128, 128, 128, False), | ||||||||||
| (1, 2, 1, 256, 256, 256, True), | ||||||||||
| (1, 2, 1, 256, 256, 256, False), | ||||||||||
| (1, 2, 1, 512, 512, 256, True), | ||||||||||
| (1, 2, 1, 512, 512, 256, False), | ||||||||||
| (1, 2, 1, 1024, 1024, 256, True), | ||||||||||
| (1, 2, 1, 1024, 1024, 256, False), | ||||||||||
| (1, 2, 1, 2048, 2048, 256, True), | ||||||||||
| (1, 2, 1, 2048, 2048, 256, False), | ||||||||||
| (1, 2, 1, 4096, 4096, 256, True), | ||||||||||
| (1, 2, 1, 4096, 4096, 256, False), | ||||||||||
| ] | ||||||||||
|
|
||||||||||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||||||||
|
|
@@ -672,27 +715,71 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): | |||||||||
| # If you encounter NAN issues when running multiple configurations, try running a single configuration | ||||||||||
| test_configs = [ | ||||||||||
| # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) | ||||||||||
| (1, 1, 1, 64, 64, 32, True), | ||||||||||
| (1, 1, 1, 64, 64, 32, False), | ||||||||||
| (1, 1, 1, 128, 128, 32, True), | ||||||||||
| (1, 1, 1, 128, 128, 32, False), | ||||||||||
| (1, 1, 1, 256, 256, 32, True), | ||||||||||
| (1, 1, 1, 256, 256, 32, False), | ||||||||||
| (1, 1, 1, 512, 512, 32, True), | ||||||||||
| (1, 1, 1, 512, 512, 32, False), | ||||||||||
| (1, 1, 1, 1024, 1024, 32, True), | ||||||||||
| (1, 1, 1, 1024, 1024, 32, False), | ||||||||||
| (1, 1, 1, 2048, 2048, 32, True), | ||||||||||
| (1, 1, 1, 2048, 2048, 32, False), | ||||||||||
| (1, 1, 1, 4096, 4096, 32, True), | ||||||||||
| (1, 1, 1, 4096, 4096, 32, False), | ||||||||||
| (1, 2, 1, 64, 64, 32, True), | ||||||||||
| (2, 1, 1, 128, 128, 32, True), | ||||||||||
| (2, 2, 1, 128, 128, 32, True), | ||||||||||
| (1, 2, 1, 64, 64, 128, True), | ||||||||||
| (1, 2, 1, 128, 128, 32, True), | ||||||||||
| (1, 2, 1, 128, 128, 32, False), | ||||||||||
| (1, 2, 1, 256, 256, 32, True), | ||||||||||
| (1, 2, 1, 256, 256, 32, False), | ||||||||||
| (1, 2, 1, 512, 512, 32, True), | ||||||||||
| (1, 2, 1, 512, 512, 32, False), | ||||||||||
| (1, 2, 1, 1024, 1024, 32, True), | ||||||||||
| (1, 2, 1, 1024, 1024, 32, False), | ||||||||||
| (1, 2, 1, 2048, 2048, 32, True), | ||||||||||
| (1, 2, 1, 2048, 2048, 32, False), | ||||||||||
| (1, 2, 1, 4096, 4096, 32, True), | ||||||||||
| (1, 2, 1, 4096, 4096, 32, False), | ||||||||||
|
|
||||||||||
| (1, 2, 1, 128, 128, 64, True), | ||||||||||
| (1, 2, 1, 128, 128, 64, False), | ||||||||||
| (1, 2, 1, 256, 256, 64, True), | ||||||||||
| (1, 2, 1, 256, 256, 64, False), | ||||||||||
| (1, 2, 1, 512, 512, 64, True), | ||||||||||
| (1, 2, 1, 512, 512, 64, False), | ||||||||||
| (1, 2, 1, 1024, 1024, 64, True), | ||||||||||
| (1, 2, 1, 1024, 1024, 64, False), | ||||||||||
| (1, 2, 1, 2048, 2048, 64, True), | ||||||||||
| (1, 2, 1, 2048, 2048, 64, False), | ||||||||||
| (1, 2, 1, 4096, 4096, 64, True), | ||||||||||
| (1, 2, 1, 4096, 4096, 64, False), | ||||||||||
|
|
||||||||||
| (1, 2, 1, 128, 128, 96, True), | ||||||||||
| (1, 2, 1, 128, 128, 96, False), | ||||||||||
| (1, 2, 1, 256, 256, 96, True), | ||||||||||
| (1, 2, 1, 256, 256, 96, False), | ||||||||||
| (1, 2, 1, 512, 512, 96, True), | ||||||||||
| (1, 2, 1, 512, 512, 96, False), | ||||||||||
| (1, 2, 1, 1024, 1024, 96, True), | ||||||||||
| (1, 2, 1, 1024, 1024, 96, False), | ||||||||||
| (1, 2, 1, 2048, 2048, 96, True), | ||||||||||
| (1, 2, 1, 2048, 2048, 96, False), | ||||||||||
| (1, 2, 1, 4096, 4096, 96, True), | ||||||||||
| (1, 2, 1, 4096, 4096, 96, False), | ||||||||||
|
|
||||||||||
| (1, 2, 1, 128, 128, 128, True), | ||||||||||
| (1, 2, 1, 128, 128, 128, True), | ||||||||||
|
||||||||||
| (1, 2, 1, 128, 128, 128, True), |
Copilot
AI
Aug 26, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Duplicate test configuration in flex test function. The same test case (1, 2, 1, 128, 128, 128, True) appears twice in the test_configs list, which is redundant and increases test execution time unnecessarily.
| (1, 2, 1, 128, 128, 128, True), |
Copilot
AI
Aug 26, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Duplicate test configurations in the flex test function. The test case (1, 2, 1, 128, 128, 128, True) appears again, and (1, 2, 1, 128, 128, 128, False) is duplicated as well; both repeated entries should be removed to avoid redundant testing.
| (1, 2, 1, 128, 128, 128, True), | |
| (1, 2, 1, 128, 128, 128, False), | |
| # (1, 2, 1, 128, 128, 128, True), # Removed duplicate | |
| # (1, 2, 1, 128, 128, 128, False), # Removed duplicate |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Duplicate test configuration. The same test case (1, 2, 1, 128, 128, 128, True) appears twice in the test_configs list, which is redundant and increases test execution time unnecessarily.