diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index f971ac8..468ab2c 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -598,23 +598,22 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 4096, 4096, 96, True), (1, 2, 1, 4096, 4096, 96, False), - # Not support head_dim => 128 in sm89 yet - # Because fwd uses splitkv branch, this branch does not support head_dim>=128 for now # Head dim 128 - # (1, 2, 1, 128, 128, 128, True), - # (1, 2, 1, 128, 128, 128, True), - # (1, 2, 1, 256, 256, 128, True), - # (1, 2, 1, 256, 256, 128, False), - # (1, 2, 1, 512, 512, 128, True), - # (1, 2, 1, 512, 512, 128, False), - # (1, 2, 1, 1024, 1024, 128, True), - # (1, 2, 1, 1024, 1024, 128, False), - # (1, 2, 1, 2048, 2048, 128, True), - # (1, 2, 1, 2048, 2048, 128, False), - # (1, 2, 1, 4096, 4096, 128, True), - # (1, 2, 1, 4096, 4096, 128, False), + (1, 2, 1, 128, 128, 128, True), + (1, 2, 1, 128, 128, 128, False), + (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 256, 256, 128, False), + (1, 2, 1, 512, 512, 128, True), + (1, 2, 1, 512, 512, 128, False), + (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 1024, 1024, 128, False), + (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 2048, 2048, 128, False), + (1, 2, 1, 4096, 4096, 128, True), + (1, 2, 1, 4096, 4096, 128, False), # Head dim 256 + # Because fwd uses splitkv branch, this branch does not support head_dim=256 for now # For head_dim=256, besides the reason of splitkv branch, bwd itself does not support it, not enough shared memory # (1, 2, 1, 128, 128, 256, True), # (1, 2, 1, 128, 128, 256, False), diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index 0f8acf8..5aac129 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -561,22 +561,22 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 4096, 4096, 96, True), (1, 2, 1, 4096, 4096, 96, False), - 
# Not support head_dim >= 128 in sm89 yet - # Because fwd uses splitkv branch by default, and shared memory is not enough for sm89 # Head dim 128 - # (1, 2, 1, 128, 128, 128, True), - # (1, 2, 1, 128, 128, 128, False), - # (1, 2, 1, 256, 256, 128, True), - # (1, 2, 1, 256, 256, 128, False), - # (1, 2, 1, 512, 512, 128, True), - # (1, 2, 1, 512, 512, 128, False), - # (1, 2, 1, 1024, 1024, 128, True), - # (1, 2, 1, 1024, 1024, 128, False), - # (1, 2, 1, 2048, 2048, 128, True), - # (1, 2, 1, 2048, 2048, 128, False), - # (1, 2, 1, 4096, 4096, 128, True), - # (1, 2, 1, 4096, 4096, 128, False), + (1, 2, 1, 128, 128, 128, True), + (1, 2, 1, 128, 128, 128, False), + (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 256, 256, 128, False), + (1, 2, 1, 512, 512, 128, True), + (1, 2, 1, 512, 512, 128, False), + (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 1024, 1024, 128, False), + (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 2048, 2048, 128, False), + (1, 2, 1, 4096, 4096, 128, True), + (1, 2, 1, 4096, 4096, 128, False), + # head_dim = 256 is not supported on sm89 yet, + # because fwd uses the splitkv branch by default and shared memory on sm89 is not sufficient # Head dim 256 # (1, 2, 1, 128, 128, 256, True), # (1, 2, 1, 128, 128, 256, False),