diff --git a/tensornetwork/tests/ncon_interface_test.py b/tensornetwork/tests/ncon_interface_test.py index c083f05dc..eb7ce93d2 100644 --- a/tensornetwork/tests/ncon_interface_test.py +++ b/tensornetwork/tests/ncon_interface_test.py @@ -19,6 +19,7 @@ from tensornetwork.ncon_interface import (_get_cont_out_labels, _canonicalize_network_structure) from tensornetwork.backends.backend_factory import get_backend +from tensornetwork.backends.jax.jax_backend import JaxBackend from tensornetwork.contractors import greedy @@ -57,11 +58,10 @@ def test_return_type(backend): result_2 = ncon_interface.ncon([n1, n2], [(-1, 1), (1, -2)], backend=backend) result_3 = ncon_interface.ncon([n1, t2], [(-1, 1), (1, -2)], backend=backend) assert isinstance(result_2, Tensor) - if backend not in ('jax', get_backend('jax')): - # jitted functions return jaxlib.xla_extension.Buffer, - # convert_to_tensor returns jax.interpreters.xla._DeviceArray now. - assert isinstance(result_1, type(n1.backend.convert_to_tensor(t1))) - assert isinstance(result_3, type(n1.backend.convert_to_tensor(t1))) + if isinstance(backend, JaxBackend) or backend == 'jax': + pytest.skip('return-type tests for jax not implemented') + assert isinstance(result_1, type(n1.backend.convert_to_tensor(t1))) + assert isinstance(result_3, type(n1.backend.convert_to_tensor(t1))) def test_order_spec(backend):