Skip to content

Commit

Permalink
Bump jaxlib/jax and flax max versions (#518)
Browse files Browse the repository at this point in the history
* Fix oversight in test cleanup function

* Bump jaxlib/jax and flax max versions

* Deal with apparent changes in jax.pure_callback: argument now jax array instead of numpy array

* Resolve doctest failure of mysterious origin

* Update docstring
  • Loading branch information
bwohlberg committed May 14, 2024
1 parent 4c90b49 commit 71b15ff
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
8 changes: 6 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
have_ray = True
ray.init(num_cpus=1) # call required to be here: see ray-project/ray#44087

import jax.numpy as jnp

import scico.numpy as snp


Expand All @@ -24,7 +26,8 @@ def pytest_sessionstart(session):

def pytest_sessionfinish(session, exitstatus):
"""Clean up after end of test session."""
ray.shutdown()
if have_ray:
ray.shutdown()


@pytest.fixture(autouse=True)
Expand All @@ -33,7 +36,8 @@ def add_modules(doctest_namespace):
Necessary because `np` is used in doc strings for jax functions
(e.g. `linear_transpose`) that get pulled into `scico/__init__.py`.
Also allow `snp` to be used without explicitly importing.
Also allow `snp` and `jnp` to be used without explicitly importing.
"""
doctest_namespace["np"] = np
doctest_namespace["snp"] = snp
doctest_namespace["jnp"] = jnp
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ scipy>=1.6.0
imageio>=2.17
tifffile
matplotlib
jaxlib>=0.4.3,<=0.4.26
jax>=0.4.3,<=0.4.26
jaxlib>=0.4.3,<=0.4.28
jax>=0.4.3,<=0.4.28
orbax-checkpoint<=0.5.7
flax>=0.8.0,<=0.8.2
flax>=0.8.0,<=0.8.3
pyabel>=0.9.0
14 changes: 8 additions & 6 deletions scico/linop/xray/astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,12 @@ def angle_to_vector(det_spacing: Tuple[float, float], angles: np.ndarray) -> np.

def _ensure_writeable(x):
"""Ensure that `x.flags.writeable` is ``True``, copying if needed."""

if not x.flags.writeable:
try:
x.setflags(write=True)
except ValueError:
x = x.copy()
if hasattr(x, "flags"): # x is a numpy array
if not x.flags.writeable:
try:
x.setflags(write=True)
except ValueError:
x = x.copy()
else: # x is a jax array (which is immutable)
x = np.array(x)
return x

0 comments on commit 71b15ff

Please sign in to comment.