From 6a8d4c3654884b434493aeb27d365da68206b8e9 Mon Sep 17 00:00:00 2001 From: mganahl Date: Fri, 7 Aug 2020 13:48:17 -0400 Subject: [PATCH] fix tests --- tensornetwork/backends/jax/jax_backend_test.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tensornetwork/backends/jax/jax_backend_test.py b/tensornetwork/backends/jax/jax_backend_test.py index 5e5f5ced8..95f85dae9 100644 --- a/tensornetwork/backends/jax/jax_backend_test.py +++ b/tensornetwork/backends/jax/jax_backend_test.py @@ -653,13 +653,13 @@ def test_eigs_raises(): ################################################################## ############# This test should just not crash ################ ################################################################## -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) +@pytest.mark.parametrize("dtype", + [np.float64, np.complex128, np.float32, np.complex64]) def test_eigs_bugfix(dtype): backend = jax_backend.JaxBackend() D = 200 - dtype = np.complex128 - mat = np.random.rand(D, D).astype(dtype) - x = np.random.rand(D).astype(dtype) + mat = jax.numpy.array(np.random.rand(D, D).astype(dtype)) + x = jax.numpy.array(np.random.rand(D).astype(dtype)) def matvec_jax(vector, matrix): return matrix @ vector @@ -672,6 +672,7 @@ def matvec_jax(vector, matrix): maxiter=10, num_krylov_vecs=100, tol=0.0001) + #this test will cause some annoying output to std buffer with pytest.raises(np.linalg.LinAlgError): backend.eigs( matvec_jax, [mat], @@ -681,7 +682,7 @@ def matvec_jax(vector, matrix): maxiter=10, num_krylov_vecs=100, tol=0.0001, - QR_thresh=0.0) + res_thresh=0.0) def test_sum():