diff --git a/pytest.ini b/pytest.ini index 639d8e002..b81dc7cc1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,4 +2,10 @@ testpaths = scico/test docs addopts = --doctest-glob="*rst" doctest_optionflags = NORMALIZE_WHITESPACE NUMBER -filterwarnings = ignore::DeprecationWarning:.*.compat +filterwarnings = + ignore::DeprecationWarning:.*pkg_resources.* + ignore::DeprecationWarning:.*flax.* + ignore::DeprecationWarning:.*.tensorboardx.* + ignore::DeprecationWarning:.*xdesign.* + ignore:.*pkg_resources.*:DeprecationWarning + ignore:.*imp module.*:DeprecationWarning diff --git a/scico/solver.py b/scico/solver.py index f93cd710e..c897a55eb 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -93,9 +93,10 @@ def wrapper(x, *args): # apply val_grad_func to un-vectorized input val = val_func(snp.reshape(x, shape).astype(dtype), *args) - # Convert val into numpy array, then cast to float - # Convert 'val' into a scalar, rather than ndarray of shape (1,) - val = np.array(val).astype(float).item() + # Convert val into numpy array, cast to float, convert to scalar + val = np.array(val).astype(float) + val = val.item() if val.ndim == 0 else val[0].item() + return val return wrapper @@ -280,7 +281,8 @@ def minimize_scalar( def f(x, *args): # Wrap jax-based function `func` to return a numpy float rather # than a jax array of size (1,) - return func(x, *args).item() + y = func(x, *args) + return y.item() if y.ndim == 0 else y[0].item() res = spopt.minimize_scalar( fun=f, diff --git a/scico/test/flax/test_apply.py b/scico/test/flax/test_apply.py index 43f54bcca..fbe53bb1d 100644 --- a/scico/test/flax/test_apply.py +++ b/scico/test/flax/test_apply.py @@ -114,12 +114,10 @@ def test_apply_from_checkpoint(testobj): model, testobj.test_ds, ) - print("variables: ", variables) except Exception as e: print(e) assert 0 else: - flat_params2 = flatten_dict(variables["params"]) flat_bstats2 = flatten_dict(variables["batch_stats"]) params2 = [t[1] for t in sorted(flat_params2.items())] diff --git a/scico/test/functional/test_core.py b/scico/test/functional/test_core.py index 6cf0c0393..d0b0c6475 100644 --- a/scico/test/functional/test_core.py +++ b/scico/test/functional/test_core.py @@ -3,7 +3,7 @@ import numpy as np import jax.numpy as jnp -from jax.config import config +from jax import config # enable 64-bit mode for output dtype checks config.update("jax_enable_x64", True) diff --git a/scico/test/functional/test_loss.py b/scico/test/functional/test_loss.py index f5b4a3dd9..0a3392f48 100644 --- a/scico/test/functional/test_loss.py +++ b/scico/test/functional/test_loss.py @@ -1,6 +1,6 @@ import numpy as np -from jax.config import config +from jax import config import pytest @@ -31,7 +31,7 @@ def setup_method(self): self.v, key = randn((n,), key=key, dtype=dtype) # point for prox eval scalar, key = randn((1,), key=key, dtype=dtype) self.key = key - self.scalar = scalar.item() + self.scalar = scalar[0].item() def test_generic_squared_l2(self): A = linop.Identity(input_shape=self.y.shape) @@ -146,7 +146,7 @@ def setup_method(self): self.x, key = randn((n,), key=key, dtype=complex_dtype(dtype)) self.v, key = randn((n,), key=key, dtype=complex_dtype(dtype)) # point for prox eval scalar, key = randn((1,), key=key, dtype=dtype) - self.scalar = scalar.item() + self.scalar = scalar[0].item() @pytest.mark.parametrize("loss_tuple", abs_loss) def test_properties(self, loss_tuple): diff --git a/scico/test/functional/test_separable.py b/scico/test/functional/test_separable.py index 0af94d2ca..f0947e2f9 100644 --- a/scico/test/functional/test_separable.py +++ b/scico/test/functional/test_separable.py @@ -1,6 +1,6 @@ import numpy as np -from jax.config import config +from jax import config # enable 64-bit mode for output dtype checks config.update("jax_enable_x64", True) diff --git a/scico/test/linop/test_diag.py b/scico/test/linop/test_diag.py index 55c32ea5a..303228170 100644 --- a/scico/test/linop/test_diag.py +++ b/scico/test/linop/test_diag.py @@ -2,7 +2,7 @@ import numpy as np -from jax.config import config +from jax import config import pytest @@ -75,7 +75,6 @@ def test_adjoint(self, input_shape, diagonal_dtype): @pytest.mark.parametrize("input_shape1", input_shapes) @pytest.mark.parametrize("input_shape2", input_shapes) def test_binary_op(self, input_shape1, input_shape2, diagonal_dtype, operator): - diagonal1, key = randn(input_shape1, dtype=diagonal_dtype, key=self.key) diagonal2, key = randn(input_shape2, dtype=diagonal_dtype, key=key) x, key = randn(input_shape1, dtype=diagonal_dtype, key=key) @@ -96,7 +95,6 @@ def test_binary_op(self, input_shape1, input_shape2, diagonal_dtype, operator): @pytest.mark.parametrize("input_shape1", input_shapes) @pytest.mark.parametrize("input_shape2", input_shapes) def test_matmul(self, input_shape1, input_shape2, diagonal_dtype): - diagonal1, key = randn(input_shape1, dtype=diagonal_dtype, key=self.key) diagonal2, key = randn(input_shape2, dtype=diagonal_dtype, key=key) x, key = randn(input_shape1, dtype=diagonal_dtype, key=key) @@ -161,7 +159,6 @@ def test_scalar_left(self, operator): @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) def test_gram_op(self, diagonal_dtype): - input_shape = (7,) diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key) @@ -174,7 +171,6 @@ def test_gram_op(self, diagonal_dtype): @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("ord", [None, "fro", "nuc", -np.inf, np.inf, 1, -1, 2, -2]) def test_norm(self, diagonal_dtype, ord): - input_shape = (5,) diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key) @@ -185,7 +181,6 @@ def test_norm(self, diagonal_dtype, ord): snp.testing.assert_allclose(n1, n2, rtol=1e-6) def test_norm_except(self): - input_shape = (5,) diagonal, key = randn(input_shape, dtype=np.float32, key=self.key) diff --git a/scico/test/linop/test_linop.py b/scico/test/linop/test_linop.py index ebd0f2406..7f97f9caa 100644 --- a/scico/test/linop/test_linop.py +++ b/scico/test/linop/test_linop.py @@ -2,7 +2,7 @@ import numpy as np -from jax.config import config +from jax import config import pytest diff --git a/scico/test/linop/test_linop_util.py b/scico/test/linop/test_linop_util.py index 4f618c90a..f04bdb7f8 100644 --- a/scico/test/linop/test_linop_util.py +++ b/scico/test/linop/test_linop_util.py @@ -1,6 +1,6 @@ import numpy as np -from jax.config import config +from jax import config import pytest diff --git a/scico/test/test_blockarray.py b/scico/test/numpy/test_blockarray.py similarity index 99% rename from scico/test/test_blockarray.py rename to scico/test/numpy/test_blockarray.py index 101201e7d..ed6003df7 100644 --- a/scico/test/test_blockarray.py +++ b/scico/test/numpy/test_blockarray.py @@ -24,7 +24,7 @@ class OperatorsTestObj: def __init__(self, dtype): key = None scalar, key = randn(shape=(1,), dtype=dtype, key=key) - self.scalar = scalar.item() # convert to float + self.scalar = scalar[0].item() # convert to float self.a0, key = randn(shape=(2, 3), dtype=dtype, key=key) self.a1, key = randn(shape=(2, 3, 4), dtype=dtype, key=key) diff --git a/scico/test/test_numpy.py b/scico/test/numpy/test_numpy.py similarity index 100% rename from scico/test/test_numpy.py rename to scico/test/numpy/test_numpy.py diff --git a/scico/test/test_numpy_util.py b/scico/test/numpy/test_numpy_util.py similarity index 100% rename from scico/test/test_numpy_util.py rename to scico/test/numpy/test_numpy_util.py diff --git a/scico/test/test_biconvolve.py b/scico/test/operator/test_biconvolve.py similarity index 100% rename from scico/test/test_biconvolve.py rename to scico/test/operator/test_biconvolve.py diff --git a/scico/test/operator/test_operator.py b/scico/test/operator/test_operator.py index 80ecc2082..f76343ff1 100644 --- a/scico/test/operator/test_operator.py +++ b/scico/test/operator/test_operator.py @@ -2,7 +2,7 @@ import numpy as np -from jax.config import config +from jax import config import pytest diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py index ebbcbf9c5..68d670c41 100644 --- a/scico/test/test_solver.py +++ b/scico/test/test_solver.py @@ -1,6 +1,5 @@ import numpy as np -import jax from jax.scipy.linalg import block_diag import pytest @@ -208,7 +207,7 @@ def test_minimize_scalar(self): def test_minimize_vector(dtype, method): B, M, N = (4, 3, 2) - # Models a 12x8 block-diagonal matrix with 4x3 blocks + # model a 12x8 block-diagonal matrix with 4x3 blocks A, key = random.randn((B, M, N), dtype=dtype) x, key = random.randn((B, N), dtype=dtype, key=key) y = snp.sum(A * x[:, None], axis=2) # contract along the N axis @@ -225,24 +224,6 @@ def f(x): assert out.x.shape == x.shape np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4) - # Check if minimize returns the object to the proper device - devices = jax.devices() - - # For default device: - x0 = jax.device_put(snp.zeros_like(x), devices[0]) - out = solver.minimize(f, x0=x0, method=method) - assert out.x.device() == devices[0] - assert out.x.shape == x0.shape - np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4) - - # If more than one device is present: - if len(devices) > 1: - x0 = jax.device_put(snp.zeros_like(x), devices[1]) - out = solver.minimize(f, x0=x0, method=method) - assert out.x.device() == devices[1] - assert out.x.shape == x0.shape - np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4) - def test_split_join_array(): x, key = random.randn((4, 4), dtype=np.complex64)