Skip to content

Commit

Permalink
Bump maximum jaxlib / jax version (#504)
Browse files Browse the repository at this point in the history
* Bump maximum jax version

* Update actions versions

* Fix error in docstring

* Improve equation spacing

* Bump maximum jax version

* Bump min flax version to avoid errors with jax>=0.4.25

* Add some error checking, ensure ValueError raised as expected by test_misc.py for functionals

* Fix bad formatting due to use of incorrect black version

* Remove apparently-spurious pytest.fixture decorator

* Bump jaxlib/jax max version

* Resolve TypeError in functools.wraps call

* Update training data generation

* Change conda source for astra-toolbox

* Debug macos ci failure

* Remove scipy version upper bound

* Bump flax version upper bound

* Fix black format version

* Fix merge error

* Clean up exception handling

* More robust handling of ray.available_resources() output

---------

Co-authored-by: crstngc <cristina.cgarcia@gmail.com>
  • Loading branch information
bwohlberg and crstngc committed Apr 30, 2024
1 parent 738b4a0 commit 19d37c9
Show file tree
Hide file tree
Showing 15 changed files with 163 additions and 135 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check_files.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v3
uses: actions/checkout@v4
- id: files
uses: Ana06/get-changed-files@v2.2.0
continue-on-error: true
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Black code formatter
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ jobs:
mypy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
submodules: recursive
- name: Install Python 3
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install dependencies
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/pypi_upload.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ jobs:
name: Upload package to PyPI
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
submodules: recursive
- name: Install Python 3
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install dependencies
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/pytest_latest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ jobs:
pytest-latest-jax:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
submodules: recursive
- name: Install Python 3
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install lastversion
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pytest_macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
shell: bash -l {0}
steps:
# Check-out the repository under $GITHUB_WORKSPACE
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
submodules: recursive
# Set up conda/mamba environment
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/pytest_ubuntu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
shell: bash -l {0}
steps:
# Check-out the repository under $GITHUB_WORKSPACE
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
submodules: recursive
# Set up conda/mamba environment
Expand Down Expand Up @@ -63,7 +63,7 @@ jobs:
pip install -r requirements.txt
pip install -r dev_requirements.txt
mamba install -c conda-forge svmbir>=0.3.3
mamba install -c astra-toolbox astra-toolbox
mamba install -c conda-forge astra-toolbox
mamba install -c conda-forge pyyaml
pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version
pip install bm3d>=4.0.0
Expand Down Expand Up @@ -96,9 +96,9 @@ jobs:
needs: test
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install deps
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
shell: bash -l {0}
steps:
# Check-out the repository under $GITHUB_WORKSPACE
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
submodules: recursive
# Set up conda/mamba environment
Expand Down Expand Up @@ -59,7 +59,7 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r dev_requirements.txt
mamba install -c astra-toolbox astra-toolbox
mamba install -c conda-forge astra-toolbox
mamba install -c conda-forge pyyaml
pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version
pip install -r examples/examples_requirements.txt
Expand Down
10 changes: 5 additions & 5 deletions examples/scripts/ct_astra_3d_tv_padmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
$$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x}
\|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,$$
where $C$ is the convolution operator and $D$ is a finite difference
where $C$ is the X-ray transform and $D$ is a finite difference
operator. This problem can be expressed as
$$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; (1/2) \| \mathbf{y} -
Expand All @@ -77,11 +77,11 @@
with
$$f = 0 \quad g = g_0 + g_1$$
$$g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \quad
$$f = 0 \qquad g = g_0 + g_1$$
$$g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \qquad
g_1(\mathbf{z}_1) = \lambda \| \mathbf{z}_1 \|_{2,1}$$
$$A = \left( \begin{array}{c} C \\ D \end{array} \right) \quad
B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \quad
$$A = \left( \begin{array}{c} C \\ D \end{array} \right) \qquad
B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \qquad
\mathbf{c} = \left( \begin{array}{c} 0 \\ 0 \end{array} \right) \;.$$
This is a more complex splitting than that used in the
Expand Down
8 changes: 4 additions & 4 deletions examples/scripts/deconv_tv_padmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@
with
$$f = 0 \quad g = g_0 + g_1$$
$$g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \quad
$$f = 0 \qquad g = g_0 + g_1$$
$$g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \qquad
g_1(\mathbf{z}_1) = \lambda \| \mathbf{z}_1 \|_{2,1}$$
$$A = \left( \begin{array}{c} C \\ D \end{array} \right) \quad
B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \quad
$$A = \left( \begin{array}{c} C \\ D \end{array} \right) \qquad
B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \qquad
\mathbf{c} = \left( \begin{array}{c} 0 \\ 0 \end{array} \right) \;.$$
This is a more complex splitting than that used in the
Expand Down
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
typing_extensions
numpy>=1.20.0
scipy>=1.6.0,<1.13
scipy>=1.6.0
imageio>=2.17
tifffile
matplotlib
jaxlib>=0.4.3,<=0.4.23
jax>=0.4.3,<=0.4.23
jaxlib>=0.4.3,<=0.4.26
jax>=0.4.3,<=0.4.26
orbax-checkpoint<=0.5.7
flax>=0.6.1,<=0.7.5
flax>=0.8.0,<=0.8.2
pyabel>=0.9.0
112 changes: 80 additions & 32 deletions scico/flax/examples/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""

import os
import warnings
from time import time
from typing import Callable, List, Tuple, Union

Expand Down Expand Up @@ -114,14 +115,13 @@ def generate_foam2_images(seed: float, size: int, ndata: int) -> Array:
if not have_xdesign:
raise RuntimeError("Package xdesign is required for use of this function.")

np.random.seed(seed)
saux = np.zeros((ndata, size, size, 1))
# np.random.seed(seed)
saux = jnp.zeros((ndata, size, size, 1))
for i in range(ndata):
foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1)
saux[i, ..., 0] = discrete_phantom(foam, size=size)

saux = saux.at[i, ..., 0].set(discrete_phantom(foam, size=size))
# normalize
saux = saux / np.max(saux, axis=(1, 2), keepdims=True)
saux = saux / jnp.max(saux, axis=(1, 2), keepdims=True)

return saux

Expand All @@ -143,15 +143,47 @@ def generate_foam1_images(seed: float, size: int, ndata: int) -> Array:
if not have_xdesign:
raise RuntimeError("Package xdesign is required for use of this function.")

np.random.seed(seed)
saux = np.zeros((ndata, size, size, 1))
# np.random.seed(seed)
saux = jnp.zeros((ndata, size, size, 1))
for i in range(ndata):
foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1)
saux[i, ..., 0] = discrete_phantom(foam, size=size)
saux = saux.at[i, ..., 0].set(discrete_phantom(foam, size=size))

return saux


def vector_f(f_: Callable, v: Array) -> Array:
"""Vectorize application of operator.
Args:
f_: Operator to apply.
v: Array to evaluate.
Returns:
Result of evaluating operator over given arrays.
"""
lf = lambda x: jnp.atleast_3d(f_(x.squeeze()))
auto_batch = jax.vmap(lf)
return auto_batch(v)


def batched_f(f_: Callable, vr: Array) -> Array:
"""Distribute application of operator over a batch of vectors
among available processes.
Args:
f_: Operator to apply.
vr: Batch of arrays to evaluate.
Returns:
Result of evaluating operator over given batch of arrays. This
evaluation preserves the batch axis.
"""
nproc = jax.device_count()
res = jax.pmap(lambda i: vector_f(f_, vr[i]))(jnp.arange(nproc))
return res


def generate_ct_data(
nimg: int,
size: int,
Expand Down Expand Up @@ -194,37 +226,44 @@ def generate_ct_data(
time_dtgen = time() - start_time
else:
start_time = time()
img = imgfunc(seed, size, nimg)
img = distributed_data_generation(imgfunc, size, nimg, False)
time_dtgen = time() - start_time
# Clip to [0,1] range.
img = jnp.clip(img, a_min=0, a_max=1)
# Shard array

nproc = jax.device_count()
imgshd = img.reshape((nproc, -1, size, size, 1))

# Configure a CT projection operator to generate synthetic measurements.
angles = np.linspace(0, jnp.pi, nproj) # evenly spaced projection angles
gt_sh = (size, size)
detector_spacing = 1
A = XRayTransform2D(gt_sh, size, detector_spacing, angles) # Radon transform operator
A = XRayTransform2D(gt_sh, size, detector_spacing, angles) # X-ray transform operator

# Compute sinograms in parallel.
a_map = lambda v: jnp.atleast_3d(A @ v.squeeze())
start_time = time()
sinoshd = jax.pmap(lambda i: jax.lax.map(a_map, imgshd[i]))(jnp.arange(nproc))
if nproc > 1:
# Shard array
imgshd = img.reshape((nproc, -1, size, size, 1))
sinoshd = batched_f(A, imgshd)
sino = sinoshd.reshape((-1, nproj, size, 1))
else:
sino = vector_f(A, img)

time_sino = time() - start_time
sino = sinoshd.reshape((-1, nproj, size, 1))
# Normalize sinogram
sino = sino / size

# Compute filtered back projection in parallel.
afbp_map = lambda v: jnp.atleast_3d(A.fbp(v.squeeze()))
# Compute filtered back-projection in parallel.
start_time = time()
fbpshd = jax.pmap(lambda i: jax.lax.map(afbp_map, sinoshd[i]))(jnp.arange(nproc))
if nproc > 1:
fbpshd = batched_f(A.fbp, sinoshd)
fbp = fbpshd.reshape((-1, size, size, 1))
else:
fbp = vector_f(A.fbp, sino)
time_fbp = time() - start_time
# Clip to [0,1] range.
fbpshd = jnp.clip(fbpshd, a_min=0, a_max=1)
fbp = fbpshd.reshape((-1, size, size, 1))

# Normalize sinogram.
sino = sino / size
# Shift FBP to [0,1] range.
fbp = (fbp - fbp.min()) / (fbp.max() - fbp.min())

if verbose: # pragma: no cover
platform = jax.lib.xla_bridge.get_backend().platform
Expand Down Expand Up @@ -276,24 +315,27 @@ def generate_blur_data(
time_dtgen = time() - start_time
else:
start_time = time()
img = imgfunc(seed, size, nimg)
img = distributed_data_generation(imgfunc, size, nimg, False)
time_dtgen = time() - start_time

# Clip to [0,1] range.
img = jnp.clip(img, a_min=0, a_max=1)
# Shard array
nproc = jax.device_count()
imgshd = img.reshape((nproc, -1, size, size, 1))

# Configure blur operator
ishape = (size, size)
A = CircularConvolve(h=blur_kernel, input_shape=ishape)

# Compute blurred images in parallel
a_map = lambda v: jnp.atleast_3d(A @ v.squeeze())
start_time = time()
blurshd = jax.pmap(lambda i: jax.lax.map(a_map, imgshd[i]))(jnp.arange(nproc))
if nproc > 1:
# Shard array
imgshd = img.reshape((nproc, -1, size, size, 1))
blurshd = batched_f(A, imgshd)
blur = blurshd.reshape((-1, size, size, 1))
else:
blur = vector_f(A, img)
time_blur = time() - start_time
blur = blurshd.reshape((-1, size, size, 1))
# Normalize blurred images
blur = blur / jnp.max(blur, axis=(1, 2), keepdims=True)
# Add Gaussian noise
Expand Down Expand Up @@ -336,7 +378,10 @@ def distributed_data_generation(

ndata_per_proc = int(nimg // nproc)

imgs = jax.pmap(imgenf, static_broadcasted_argnums=(1, 2))(seeds, size, ndata_per_proc)
idx = np.arange(nproc)
imgs = jax.vmap(imgenf, (0, None, None))(idx, size, ndata_per_proc)

# imgs = jax.pmap(imgenf, static_broadcasted_argnums=(1, 2))(seeds, size, ndata_per_proc)

if not sharded:
imgs = imgs.reshape((-1, size, size, 1))
Expand Down Expand Up @@ -365,9 +410,12 @@ def ray_distributed_data_generation(
def data_gen(seed, size, ndata, imgf):
return imgf(seed, size, ndata)

# Use half of available CPU resources.
ar = ray.available_resources()
# Usage of half available CPU resources.
nproc = max(int(ar["CPU"]) // 2, 1)
if "CPU" not in ar:
warnings.warn("No CPU key in ray.available_resources() output")
nproc = max(int(ar.get("CPU", "1")) // 2, 1)
# nproc = max(int(ar["CPU"]) // 2, 1)
if nproc > nimg:
nproc = nimg
if nproc > 1 and nimg % nproc > 0:
Expand Down
7 changes: 5 additions & 2 deletions scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ class NuclearNorm(Functional):
has_prox = True

def __call__(self, x: Union[Array, BlockArray]) -> float:
if x.ndim != 2:
raise ValueError("Input array must be two dimensional.")
return snp.sum(snp.linalg.svd(x, full_matrices=False, compute_uv=False))

def prox(
Expand All @@ -490,12 +492,13 @@ def prox(
:cite:`cai-2010-singular`.
Args:
v: Input array :math:`\mb{v}`.
v: Input array :math:`\mb{v}`. Required to be two-dimensional.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
classes.
"""

if v.ndim != 2:
raise ValueError("Input array must be two dimensional.")
svdU, svdS, svdV = snp.linalg.svd(v, full_matrices=False)
svdS = snp.maximum(0, svdS - lam)
return svdU @ snp.diag(svdS) @ svdV
Loading

0 comments on commit 19d37c9

Please sign in to comment.