diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py index 2c73a53..5d89419 100644 --- a/benchmarks/benchmark_forward_equivalence.py +++ b/benchmarks/benchmark_forward_equivalence.py @@ -377,7 +377,7 @@ def test_forward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 64, 64, 128, True), (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 256, 256, 128, True), - (1, 2, 1, 512, 512, 128, False), + (1, 2, 1, 512, 512, 128, True), ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu")