Skip to content

Commit

Permalink
Resolve numerous deprecation warnings (#485)
Browse files Browse the repository at this point in the history
* Address deprecation warnings

* Coding standards compliance

* Move test modules to more coherent locations

* Remove tests that are greatly complicated by the deprecation of jax device method

* Bug fix

* Test should not print to stdout

* Suppress deprecation warnings in other packages
  • Loading branch information
bwohlberg committed Dec 14, 2023
1 parent 9b32636 commit 4eb21c1
Show file tree
Hide file tree
Showing 15 changed files with 24 additions and 42 deletions.
8 changes: 7 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,10 @@
testpaths = scico/test docs
addopts = --doctest-glob="*rst"
doctest_optionflags = NORMALIZE_WHITESPACE NUMBER
filterwarnings = ignore::DeprecationWarning:.*.compat
filterwarnings =
ignore::DeprecationWarning:.*pkg_resources.*
ignore::DeprecationWarning:.*flax.*
ignore::DeprecationWarning:.*.tensorboardx.*
ignore::DeprecationWarning:.*xdesign.*
ignore:.*pkg_resources.*:DeprecationWarning
ignore:.*imp module.*:DeprecationWarning
10 changes: 6 additions & 4 deletions scico/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,10 @@ def wrapper(x, *args):
# apply val_grad_func to un-vectorized input
val = val_func(snp.reshape(x, shape).astype(dtype), *args)

# Convert val into numpy array, then cast to float
# Convert 'val' into a scalar, rather than ndarray of shape (1,)
val = np.array(val).astype(float).item()
# Convert val into numpy array, cast to float, convert to scalar
val = np.array(val).astype(float)
val = val.item() if val.ndim == 0 else val[0].item()

return val

return wrapper
Expand Down Expand Up @@ -280,7 +281,8 @@ def minimize_scalar(
def f(x, *args):
# Wrap jax-based function `func` to return a numpy float rather
# than a jax array of size (1,)
return func(x, *args).item()
y = func(x, *args)
return y.item() if y.ndim == 0 else y[0].item()

res = spopt.minimize_scalar(
fun=f,
Expand Down
2 changes: 0 additions & 2 deletions scico/test/flax/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,10 @@ def test_apply_from_checkpoint(testobj):
model,
testobj.test_ds,
)
print("variables: ", variables)
except Exception as e:
print(e)
assert 0
else:

flat_params2 = flatten_dict(variables["params"])
flat_bstats2 = flatten_dict(variables["batch_stats"])
params2 = [t[1] for t in sorted(flat_params2.items())]
Expand Down
2 changes: 1 addition & 1 deletion scico/test/functional/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

import jax.numpy as jnp
from jax.config import config
from jax import config

# enable 64-bit mode for output dtype checks
config.update("jax_enable_x64", True)
Expand Down
6 changes: 3 additions & 3 deletions scico/test/functional/test_loss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from jax.config import config
from jax import config

import pytest

Expand Down Expand Up @@ -31,7 +31,7 @@ def setup_method(self):
self.v, key = randn((n,), key=key, dtype=dtype) # point for prox eval
scalar, key = randn((1,), key=key, dtype=dtype)
self.key = key
self.scalar = scalar.item()
self.scalar = scalar[0].item()

def test_generic_squared_l2(self):
A = linop.Identity(input_shape=self.y.shape)
Expand Down Expand Up @@ -146,7 +146,7 @@ def setup_method(self):
self.x, key = randn((n,), key=key, dtype=complex_dtype(dtype))
self.v, key = randn((n,), key=key, dtype=complex_dtype(dtype)) # point for prox eval
scalar, key = randn((1,), key=key, dtype=dtype)
self.scalar = scalar.item()
self.scalar = scalar[0].item()

@pytest.mark.parametrize("loss_tuple", abs_loss)
def test_properties(self, loss_tuple):
Expand Down
2 changes: 1 addition & 1 deletion scico/test/functional/test_separable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from jax.config import config
from jax import config

# enable 64-bit mode for output dtype checks
config.update("jax_enable_x64", True)
Expand Down
7 changes: 1 addition & 6 deletions scico/test/linop/test_diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from jax.config import config
from jax import config

import pytest

Expand Down Expand Up @@ -75,7 +75,6 @@ def test_adjoint(self, input_shape, diagonal_dtype):
@pytest.mark.parametrize("input_shape1", input_shapes)
@pytest.mark.parametrize("input_shape2", input_shapes)
def test_binary_op(self, input_shape1, input_shape2, diagonal_dtype, operator):

diagonal1, key = randn(input_shape1, dtype=diagonal_dtype, key=self.key)
diagonal2, key = randn(input_shape2, dtype=diagonal_dtype, key=key)
x, key = randn(input_shape1, dtype=diagonal_dtype, key=key)
Expand All @@ -96,7 +95,6 @@ def test_binary_op(self, input_shape1, input_shape2, diagonal_dtype, operator):
@pytest.mark.parametrize("input_shape1", input_shapes)
@pytest.mark.parametrize("input_shape2", input_shapes)
def test_matmul(self, input_shape1, input_shape2, diagonal_dtype):

diagonal1, key = randn(input_shape1, dtype=diagonal_dtype, key=self.key)
diagonal2, key = randn(input_shape2, dtype=diagonal_dtype, key=key)
x, key = randn(input_shape1, dtype=diagonal_dtype, key=key)
Expand Down Expand Up @@ -161,7 +159,6 @@ def test_scalar_left(self, operator):

@pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64])
def test_gram_op(self, diagonal_dtype):

input_shape = (7,)
diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key)

Expand All @@ -174,7 +171,6 @@ def test_gram_op(self, diagonal_dtype):
@pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize("ord", [None, "fro", "nuc", -np.inf, np.inf, 1, -1, 2, -2])
def test_norm(self, diagonal_dtype, ord):

input_shape = (5,)
diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key)

Expand All @@ -185,7 +181,6 @@ def test_norm(self, diagonal_dtype, ord):
snp.testing.assert_allclose(n1, n2, rtol=1e-6)

def test_norm_except(self):

input_shape = (5,)
diagonal, key = randn(input_shape, dtype=np.float32, key=self.key)

Expand Down
2 changes: 1 addition & 1 deletion scico/test/linop/test_linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from jax.config import config
from jax import config

import pytest

Expand Down
2 changes: 1 addition & 1 deletion scico/test/linop/test_linop_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from jax.config import config
from jax import config

import pytest

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class OperatorsTestObj:
def __init__(self, dtype):
key = None
scalar, key = randn(shape=(1,), dtype=dtype, key=key)
self.scalar = scalar.item() # convert to float
self.scalar = scalar[0].item() # convert to float

self.a0, key = randn(shape=(2, 3), dtype=dtype, key=key)
self.a1, key = randn(shape=(2, 3, 4), dtype=dtype, key=key)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion scico/test/operator/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from jax.config import config
from jax import config

import pytest

Expand Down
21 changes: 1 addition & 20 deletions scico/test/test_solver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np

import jax
from jax.scipy.linalg import block_diag

import pytest
Expand Down Expand Up @@ -208,7 +207,7 @@ def test_minimize_scalar(self):
def test_minimize_vector(dtype, method):
B, M, N = (4, 3, 2)

# Models a 12x8 block-diagonal matrix with 4x3 blocks
# model a 12x8 block-diagonal matrix with 4x3 blocks
A, key = random.randn((B, M, N), dtype=dtype)
x, key = random.randn((B, N), dtype=dtype, key=key)
y = snp.sum(A * x[:, None], axis=2) # contract along the N axis
Expand All @@ -225,24 +224,6 @@ def f(x):
assert out.x.shape == x.shape
np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4)

# Check if minimize returns the object to the proper device
devices = jax.devices()

# For default device:
x0 = jax.device_put(snp.zeros_like(x), devices[0])
out = solver.minimize(f, x0=x0, method=method)
assert out.x.device() == devices[0]
assert out.x.shape == x0.shape
np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4)

# If more than one device is present:
if len(devices) > 1:
x0 = jax.device_put(snp.zeros_like(x), devices[1])
out = solver.minimize(f, x0=x0, method=method)
assert out.x.device() == devices[1]
assert out.x.shape == x0.shape
np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4)


def test_split_join_array():
x, key = random.randn((4, 4), dtype=np.complex64)
Expand Down

0 comments on commit 4eb21c1

Please sign in to comment.