LinearSolveTest.test_solve_sparse fails with jax 0.4.26 #592

Open
GaetanLepage opened this issue May 6, 2024 · 1 comment

@GaetanLepage

Context: updating jax in nixpkgs: NixOS/nixpkgs#291705 (comment)

One of the jaxopt tests fails when run with the latest jax (0.4.26):

============================= test session starts ==============================
platform linux -- Python 3.11.9, pytest-8.1.1, pluggy-1.4.0
rootdir: /build/source
plugins: xdist-3.5.0
48 workers [561 items]
...s.................................................................... [ 12%]
........................................................................ [ 25%]
.............................................s......s............s...... [ 38%]
............s........s.................................................. [ 51%]
.................F...................................................... [ 64%]
........................................................................ [ 77%]
........................................................................ [ 89%]
.........................................................                [100%]
=================================== FAILURES ===================================
______________________ LinearSolveTest.test_solve_sparse _______________________
[gw24] linux -- Python 3.11.9 /nix/store/lpi16513bai8kg2bd841745vzk72475x-python3-3.11.9/bin/python3.11

self = <linear_solve_test.LinearSolveTest testMethod=test_solve_sparse>

    def test_solve_sparse(self):
      rng = onp.random.RandomState(0)
    
      # Matrix case.
      A = rng.randn(5, 5)
      b = rng.randn(5)
    
      def matvec(x):
        return jnp.dot(A, x)
    
      x = linear_solve.solve_lu(matvec, b)
      x2 = linear_solve.solve_normal_cg(matvec, b)
      x3 = linear_solve.solve_gmres(matvec, b)
      x4 = linear_solve.solve_bicgstab(matvec, b)
      x5 = linear_solve.solve_iterative_refinement(matvec, b)
      x6 = linear_solve.solve_qr(matvec, b)
    
      self.assertArraysAllClose(x, x2, atol=1e-4)
      self.assertArraysAllClose(x, x3, atol=1e-4)
>     self.assertArraysAllClose(x, x4, atol=1e-4)

tests/linear_solve_test.py:133: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
jaxopt/_src/test_util.py:292: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jaxopt/_src/test_util.py:262: in _assert_numpy_allclose
    onp.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_allclose.<locals>.compare at 0x7ffcd857af20>, array([-6.9443436, -1.9871655,  7.7470713,  7.654949 ,...87526],
      dtype=float32), array([-6.9444494, -1.9872105,  7.7471952,  7.655079 , -7.0388584],
      dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-06, atol=0.0001', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-06, atol=0.0001
E           
E           Mismatched elements: 2 / 5 (40%)
E           Max absolute difference: 0.0001297
E           Max relative difference: 2.267556e-05
E            x: array([-6.944344, -1.987165,  7.747071,  7.654949, -7.038753],
E                 dtype=float32)
E            y: array([-6.944449, -1.987211,  7.747195,  7.655079, -7.038858],
E                 dtype=float32)

/nix/store/lpi16513bai8kg2bd841745vzk72475x-python3-3.11.9/lib/python3.11/contextlib.py:81: AssertionError
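For reference, the assertion above applies NumPy's element-wise rule `|actual - desired| <= atol + rtol * |desired|`, with `actual` the `solve_lu` result and `desired` the `solve_bicgstab` result. The sketch below recomputes that check from the values printed in the log. It assumes the rounded 6-decimal values are close enough to the float32 originals, so the figures differ slightly from the reported `0.0001297`. The names `x_lu`, `x_bicgstab` and `needed_atol` are illustrative only.

```python
import numpy as np

# Values copied from the failure log (rounded to 6 decimals as printed there).
x_lu = np.array([-6.944344, -1.987165, 7.747071, 7.654949, -7.038753])        # solve_lu
x_bicgstab = np.array([-6.944449, -1.987211, 7.747195, 7.655079, -7.038858])  # solve_bicgstab

# atol comes from the test; rtol is the value reported in the assertion header.
atol, rtol = 1e-4, 1e-6

# numpy.testing.assert_allclose(actual, desired) requires, element-wise:
#   |actual - desired| <= atol + rtol * |desired|
abs_diff = np.abs(x_lu - x_bicgstab)
allowed = atol + rtol * np.abs(x_bicgstab)
mismatched = np.flatnonzero(abs_diff > allowed)

# Smallest atol (with rtol unchanged) for which every element would pass.
needed_atol = float(np.max(abs_diff - rtol * np.abs(x_bicgstab)))

print("mismatched indices:", mismatched.tolist())  # → mismatched indices: [2, 3]
print("smallest passing atol: %.3g" % needed_atol)
```

Elements 2 and 3 are the only ones over the limit, which matches the log's `Mismatched elements: 2 / 5`. The BiCGSTAB result misses the test's `atol=1e-4` by roughly 2e-5.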

Any idea?

@GaetanLepage
Author

Very likely related to #577
