From a347663c8aca748599c90b40b19a11c0191f50dc Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 29 Aug 2025 16:17:14 +0800 Subject: [PATCH 1/2] Enables head dimension 128 test cases in backward equivalence Uncomments previously disabled test configurations for head dimension 128, allowing comprehensive testing of backward pass equivalence across various sequence lengths and causal modes. Moves limitation comment to head dimension 256 section where the restriction still applies due to splitkv branch constraints and shared memory limitations. --- benchmarks/backward_equivalence.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index f971ac8..468ab2c 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -598,23 +598,22 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 4096, 4096, 96, True), (1, 2, 1, 4096, 4096, 96, False), - # Not support head_dim => 128 in sm89 yet - # Because fwd uses splitkv branch, this branch does not support head_dim>=128 for now # Head dim 128 - # (1, 2, 1, 128, 128, 128, True), - # (1, 2, 1, 128, 128, 128, True), - # (1, 2, 1, 256, 256, 128, True), - # (1, 2, 1, 256, 256, 128, False), - # (1, 2, 1, 512, 512, 128, True), - # (1, 2, 1, 512, 512, 128, False), - # (1, 2, 1, 1024, 1024, 128, True), - # (1, 2, 1, 1024, 1024, 128, False), - # (1, 2, 1, 2048, 2048, 128, True), - # (1, 2, 1, 2048, 2048, 128, False), - # (1, 2, 1, 4096, 4096, 128, True), - # (1, 2, 1, 4096, 4096, 128, False), + (1, 2, 1, 128, 128, 128, True), + (1, 2, 1, 128, 128, 128, False), + (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 256, 256, 128, False), + (1, 2, 1, 512, 512, 128, True), + (1, 2, 1, 512, 512, 128, False), + (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 1024, 1024, 128, False), + (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 2048, 2048, 128, False), + (1, 2, 1, 4096, 4096, 128, 
True), + (1, 2, 1, 4096, 4096, 128, False), # Head dim 256 + # Because fwd uses splitkv branch, this branch does not support head_dim=256 for now # For head_dim=256, besides the reason of splitkv branch, bwd itself does not support it, not enough shared memory # (1, 2, 1, 128, 128, 256, True), # (1, 2, 1, 128, 128, 256, False), From 1df70ebf71ae480ae4e7ab18f8c444183c32a8ac Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 29 Aug 2025 16:17:45 +0800 Subject: [PATCH 2/2] Enables head dimension 128 test cases in CUDA forward equivalence Uncomments previously disabled test cases for head dimension 128 configurations across various sequence lengths. Moves the comment about sm89 shared memory limitations to head dimension 256 section where it remains applicable. --- benchmarks/forward_equivalence.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index 0f8acf8..5aac129 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -561,22 +561,22 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 4096, 4096, 96, True), (1, 2, 1, 4096, 4096, 96, False), - # Not support head_dim >= 128 in sm89 yet - # Because fwd uses splitkv branch by default, and shared memory is not enough for sm89 # Head dim 128 - # (1, 2, 1, 128, 128, 128, True), - # (1, 2, 1, 128, 128, 128, False), - # (1, 2, 1, 256, 256, 128, True), - # (1, 2, 1, 256, 256, 128, False), - # (1, 2, 1, 512, 512, 128, True), - # (1, 2, 1, 512, 512, 128, False), - # (1, 2, 1, 1024, 1024, 128, True), - # (1, 2, 1, 1024, 1024, 128, False), - # (1, 2, 1, 2048, 2048, 128, True), - # (1, 2, 1, 2048, 2048, 128, False), - # (1, 2, 1, 4096, 4096, 128, True), - # (1, 2, 1, 4096, 4096, 128, False), + (1, 2, 1, 128, 128, 128, True), + (1, 2, 1, 128, 128, 128, False), + (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 256, 256, 128, False), + (1, 2, 1, 512, 512, 128, 
True), + (1, 2, 1, 512, 512, 128, False), + (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 1024, 1024, 128, False), + (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 2048, 2048, 128, False), + (1, 2, 1, 4096, 4096, 128, True), + (1, 2, 1, 4096, 4096, 128, False), + # head_dim = 256 is not supported on sm89 yet + # Because fwd uses the splitkv branch by default, and sm89 does not have enough shared memory for it # Head dim 256 # (1, 2, 1, 128, 128, 256, True), # (1, 2, 1, 128, 128, 256, False),