diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py index 39e64e0..2c73a53 100644 --- a/benchmarks/benchmark_forward_equivalence.py +++ b/benchmarks/benchmark_forward_equivalence.py @@ -357,8 +357,8 @@ def test_forward_equivalence(accuracy_threshold=0.95): # If you encounter NAN issues when running multiple configurations, try running a single configuration test_configs = [ # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) - (1, 1, 1, 4, 64, 32, True), - (1, 1, 1, 4, 64, 32, False), + (1, 1, 1, 64, 64, 32, True), + (1, 1, 1, 64, 64, 32, False), (1, 1, 1, 128, 128, 32, True), (1, 1, 1, 128, 128, 32, False), (1, 1, 1, 256, 256, 32, True), @@ -377,7 +377,7 @@ def test_forward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 64, 64, 128, True), (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 256, 256, 128, True), - (1, 2, 1, 511, 512, 128, True), + (1, 2, 1, 512, 512, 128, False), ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu")