diff --git a/tests/test_functional.py b/tests/test_functional.py index 1cca04511..efa5fe4a7 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -22,6 +22,8 @@ torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) k = 20 +random.seed(42) + def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True): idx = torch.isclose(a, b, rtol=rtol, atol=atol)