Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mganahl committed Aug 7, 2020
1 parent 04388a6 commit 6a8d4c3
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tensornetwork/backends/jax/jax_backend_test.py
Expand Up @@ -653,13 +653,13 @@ def test_eigs_raises():
##################################################################
############# This test should just not crash ################
##################################################################
@pytest.mark.parametrize("dtype", [np.float64, np.complex128])
@pytest.mark.parametrize("dtype",
[np.float64, np.complex128, np.float32, np.complex64])
def test_eigs_bugfix(dtype):
backend = jax_backend.JaxBackend()
D = 200
dtype = np.complex128
mat = np.random.rand(D, D).astype(dtype)
x = np.random.rand(D).astype(dtype)
mat = jax.numpy.array(np.random.rand(D, D).astype(dtype))
x = jax.numpy.array(np.random.rand(D).astype(dtype))

def matvec_jax(vector, matrix):
return matrix @ vector
Expand All @@ -672,6 +672,7 @@ def matvec_jax(vector, matrix):
maxiter=10,
num_krylov_vecs=100,
tol=0.0001)
#this test will cause some annoying output to std buffer
with pytest.raises(np.linalg.LinAlgError):
backend.eigs(
matvec_jax, [mat],
Expand All @@ -681,7 +682,7 @@ def matvec_jax(vector, matrix):
maxiter=10,
num_krylov_vecs=100,
tol=0.0001,
QR_thresh=0.0)
res_thresh=0.0)


def test_sum():
Expand Down

0 comments on commit 6a8d4c3

Please sign in to comment.