From 19d37c9b9225e301506ff870c926aae790a75424 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 30 Apr 2024 11:15:12 -0600 Subject: [PATCH] Bump maximum `jaxlib` / `jax` version (#504) * Bump maximum jax version * Update actions versions * Fix error in docstring * Improve equation spacing * Bump maximum jax version * Bump min flax version to avoid errors with jax>=0.4.25 * Add some error checking, ensure ValueError raised as expected by test_misc.py for functionals * Fix bad formatting due to use of incorrect black version * Remove apparently-spurious pytest.fixture decorator * Bump jaxlib/jax max version * Resolve TypeError in functools.wraps call * Update training data generation * Change conda source for astra-toolbox * Debug macos ci failure * Remove scipy version upper bound * Bump flax version upper bound * Fix black format version * Fix merge error * Clean up exception handling * More robust handling of ray.available_resources() output --------- Co-authored-by: crstngc --- .github/workflows/check_files.yml | 2 +- .github/workflows/lint.yml | 4 +- .github/workflows/mypy.yml | 4 +- .github/workflows/pypi_upload.yml | 4 +- .github/workflows/pytest_latest.yml | 4 +- .github/workflows/pytest_macos.yml | 2 +- .github/workflows/pytest_ubuntu.yml | 8 +- .github/workflows/test_examples.yml | 4 +- examples/scripts/ct_astra_3d_tv_padmm.py | 10 +- examples/scripts/deconv_tv_padmm.py | 8 +- requirements.txt | 8 +- scico/flax/examples/data_generation.py | 112 +++++++++++++++------ scico/functional/_norm.py | 7 +- scico/numpy/_blockarray.py | 3 + scico/test/flax/test_examples_flax.py | 118 +++++++++-------------- 15 files changed, 163 insertions(+), 135 deletions(-) diff --git a/.github/workflows/check_files.yml b/.github/workflows/check_files.yml index d6bcc408f..7fee7c27c 100644 --- a/.github/workflows/check_files.yml +++ b/.github/workflows/check_files.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - name: checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - id: files uses: Ana06/get-changed-files@v2.2.0 continue-on-error: true diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index ca4e713c9..a4eeb8cee 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -12,8 +12,8 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: "3.10" - name: Black code formatter diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index ad8c71bec..5e1d474de 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -14,11 +14,11 @@ jobs: mypy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: recursive - name: Install Python 3 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install dependencies diff --git a/.github/workflows/pypi_upload.yml b/.github/workflows/pypi_upload.yml index fb9f3223f..f60c78f47 100644 --- a/.github/workflows/pypi_upload.yml +++ b/.github/workflows/pypi_upload.yml @@ -15,11 +15,11 @@ jobs: name: Upload package to PyPI runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: recursive - name: Install Python 3 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install dependencies diff --git a/.github/workflows/pytest_latest.yml b/.github/workflows/pytest_latest.yml index 91cc81893..60754eda6 100644 --- a/.github/workflows/pytest_latest.yml +++ b/.github/workflows/pytest_latest.yml @@ -15,11 +15,11 @@ jobs: pytest-latest-jax: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: recursive - name: Install Python 3 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install lastversion diff --git a/.github/workflows/pytest_macos.yml b/.github/workflows/pytest_macos.yml index 25b85b256..ec97f4c02 100644 --- a/.github/workflows/pytest_macos.yml +++ b/.github/workflows/pytest_macos.yml @@ -24,7 +24,7 @@ jobs: shell: bash -l {0} steps: # Check-out the repository under $GITHUB_WORKSPACE - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: recursive # Set up conda/mamba environment diff --git a/.github/workflows/pytest_ubuntu.yml b/.github/workflows/pytest_ubuntu.yml index 617d914f8..ffad0c04f 100644 --- a/.github/workflows/pytest_ubuntu.yml +++ b/.github/workflows/pytest_ubuntu.yml @@ -24,7 +24,7 @@ jobs: shell: bash -l {0} steps: # Check-out the repository under $GITHUB_WORKSPACE - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: recursive # Set up conda/mamba environment @@ -63,7 +63,7 @@ jobs: pip install -r requirements.txt pip install -r dev_requirements.txt mamba install -c conda-forge svmbir>=0.3.3 - mamba install -c astra-toolbox astra-toolbox + mamba install -c conda-forge astra-toolbox mamba install -c conda-forge pyyaml pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version pip install bm3d>=4.0.0 @@ -96,9 +96,9 @@ jobs: needs: test runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install deps diff --git a/.github/workflows/test_examples.yml b/.github/workflows/test_examples.yml index aba9ebc88..74645a29a 100644 --- a/.github/workflows/test_examples.yml +++ b/.github/workflows/test_examples.yml @@ -22,7 +22,7 @@ jobs: shell: bash -l {0} steps: # Check-out the repository under $GITHUB_WORKSPACE - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: recursive # Set up conda/mamba environment @@ -59,7 +59,7 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install -r dev_requirements.txt - mamba install -c astra-toolbox astra-toolbox + mamba install -c conda-forge astra-toolbox mamba install -c conda-forge pyyaml pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version pip install -r examples/examples_requirements.txt diff --git a/examples/scripts/ct_astra_3d_tv_padmm.py b/examples/scripts/ct_astra_3d_tv_padmm.py index 62b07004d..c6c090075 100644 --- a/examples/scripts/ct_astra_3d_tv_padmm.py +++ b/examples/scripts/ct_astra_3d_tv_padmm.py @@ -62,7 +62,7 @@ $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,$$ -where $C$ is the convolution operator and $D$ is a finite difference +where $C$ is the X-ray transform and $D$ is a finite difference operator. This problem can be expressed as $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; (1/2) \| \mathbf{y} - @@ -77,11 +77,11 @@ with - $$f = 0 \quad g = g_0 + g_1$$ - $$g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \quad + $$f = 0 \qquad g = g_0 + g_1$$ + $$g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \qquad g_1(\mathbf{z}_1) = \lambda \| \mathbf{z}_1 \|_{2,1}$$ - $$A = \left( \begin{array}{c} C \\ D \end{array} \right) \quad - B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \quad + $$A = \left( \begin{array}{c} C \\ D \end{array} \right) \qquad + B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \qquad \mathbf{c} = \left( \begin{array}{c} 0 \\ 0 \end{array} \right) \;.$$ This is a more complex splitting than that used in the diff --git a/examples/scripts/deconv_tv_padmm.py b/examples/scripts/deconv_tv_padmm.py index d1df0df46..29ce0534a 100644 --- a/examples/scripts/deconv_tv_padmm.py +++ b/examples/scripts/deconv_tv_padmm.py @@ -75,11 +75,11 @@ with - $$f = 0 \quad g = g_0 + g_1$$ - $$g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \quad + $$f = 0 \qquad g = g_0 + g_1$$ + $$g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \qquad g_1(\mathbf{z}_1) = \lambda \| \mathbf{z}_1 \|_{2,1}$$ - $$A = \left( \begin{array}{c} C \\ D \end{array} \right) \quad - B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \quad + $$A = \left( \begin{array}{c} C \\ D \end{array} \right) \qquad + B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \qquad \mathbf{c} = \left( \begin{array}{c} 0 \\ 0 \end{array} \right) \;.$$ This is a more complex splitting than that used in the diff --git a/requirements.txt b/requirements.txt index 5b7e68043..227948c20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ typing_extensions numpy>=1.20.0 -scipy>=1.6.0,<1.13 +scipy>=1.6.0 imageio>=2.17 tifffile matplotlib -jaxlib>=0.4.3,<=0.4.23 -jax>=0.4.3,<=0.4.23 +jaxlib>=0.4.3,<=0.4.26 +jax>=0.4.3,<=0.4.26 orbax-checkpoint<=0.5.7 -flax>=0.6.1,<=0.7.5 +flax>=0.8.0,<=0.8.2 pyabel>=0.9.0 diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index f6eaea002..f3d306ed3 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -12,6 +12,7 @@ """ import os +import warnings from time import time from typing import Callable, List, Tuple, Union @@ -114,14 +115,13 @@ def generate_foam2_images(seed: float, size: int, ndata: int) -> Array: if not have_xdesign: raise RuntimeError("Package xdesign is required for use of this function.") - np.random.seed(seed) - saux = np.zeros((ndata, size, size, 1)) + # np.random.seed(seed) + saux = jnp.zeros((ndata, size, size, 1)) for i in range(ndata): foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) - saux[i, ..., 0] = discrete_phantom(foam, size=size) - + saux = saux.at[i, ..., 0].set(discrete_phantom(foam, size=size)) # normalize - saux = saux / np.max(saux, axis=(1, 2), keepdims=True) + saux = saux / jnp.max(saux, axis=(1, 2), keepdims=True) return saux @@ -143,15 +143,47 @@ def generate_foam1_images(seed: float, size: int, ndata: int) -> Array: if not have_xdesign: raise RuntimeError("Package xdesign is required for use of this function.") - np.random.seed(seed) - saux = np.zeros((ndata, size, size, 1)) + # np.random.seed(seed) + saux = jnp.zeros((ndata, size, size, 1)) for i in range(ndata): foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) - saux[i, ..., 0] = discrete_phantom(foam, size=size) + saux = saux.at[i, ..., 0].set(discrete_phantom(foam, size=size)) return saux +def vector_f(f_: Callable, v: Array) -> Array: + """Vectorize application of operator. + + Args: + f_: Operator to apply. + v: Array to evaluate. + + Returns: + Result of evaluating operator over given arrays. + """ + lf = lambda x: jnp.atleast_3d(f_(x.squeeze())) + auto_batch = jax.vmap(lf) + return auto_batch(v) + + +def batched_f(f_: Callable, vr: Array) -> Array: + """Distribute application of operator over a batch of vectors + among available processes. + + Args: + f_: Operator to apply. + vr: Batch of arrays to evaluate. + + Returns: + Result of evaluating operator over given batch of arrays. This + evaluation preserves the batch axis. + """ + nproc = jax.device_count() + res = jax.pmap(lambda i: vector_f(f_, vr[i]))(jnp.arange(nproc)) + return res + + def generate_ct_data( nimg: int, size: int, @@ -194,37 +226,44 @@ def generate_ct_data( time_dtgen = time() - start_time else: start_time = time() - img = imgfunc(seed, size, nimg) + img = distributed_data_generation(imgfunc, size, nimg, False) time_dtgen = time() - start_time # Clip to [0,1] range. img = jnp.clip(img, a_min=0, a_max=1) - # Shard array + nproc = jax.device_count() - imgshd = img.reshape((nproc, -1, size, size, 1)) # Configure a CT projection operator to generate synthetic measurements. angles = np.linspace(0, jnp.pi, nproj) # evenly spaced projection angles gt_sh = (size, size) detector_spacing = 1 - A = XRayTransform2D(gt_sh, size, detector_spacing, angles) # Radon transform operator + A = XRayTransform2D(gt_sh, size, detector_spacing, angles) # X-ray transform operator # Compute sinograms in parallel. - a_map = lambda v: jnp.atleast_3d(A @ v.squeeze()) start_time = time() - sinoshd = jax.pmap(lambda i: jax.lax.map(a_map, imgshd[i]))(jnp.arange(nproc)) + if nproc > 1: + # Shard array + imgshd = img.reshape((nproc, -1, size, size, 1)) + sinoshd = batched_f(A, imgshd) + sino = sinoshd.reshape((-1, nproj, size, 1)) + else: + sino = vector_f(A, img) + time_sino = time() - start_time - sino = sinoshd.reshape((-1, nproj, size, 1)) - # Normalize sinogram - sino = sino / size - # Compute filtered back projection in parallel. - afbp_map = lambda v: jnp.atleast_3d(A.fbp(v.squeeze())) + # Compute filtered back-projection in parallel. start_time = time() - fbpshd = jax.pmap(lambda i: jax.lax.map(afbp_map, sinoshd[i]))(jnp.arange(nproc)) + if nproc > 1: + fbpshd = batched_f(A.fbp, sinoshd) + fbp = fbpshd.reshape((-1, size, size, 1)) + else: + fbp = vector_f(A.fbp, sino) time_fbp = time() - start_time - # Clip to [0,1] range. - fbpshd = jnp.clip(fbpshd, a_min=0, a_max=1) - fbp = fbpshd.reshape((-1, size, size, 1)) + + # Normalize sinogram. + sino = sino / size + # Shift FBP to [0,1] range. + fbp = (fbp - fbp.min()) / (fbp.max() - fbp.min()) if verbose: # pragma: no cover platform = jax.lib.xla_bridge.get_backend().platform @@ -276,24 +315,27 @@ def generate_blur_data( time_dtgen = time() - start_time else: start_time = time() - img = imgfunc(seed, size, nimg) + img = distributed_data_generation(imgfunc, size, nimg, False) time_dtgen = time() - start_time + # Clip to [0,1] range. img = jnp.clip(img, a_min=0, a_max=1) - # Shard array nproc = jax.device_count() - imgshd = img.reshape((nproc, -1, size, size, 1)) # Configure blur operator ishape = (size, size) A = CircularConvolve(h=blur_kernel, input_shape=ishape) # Compute blurred images in parallel - a_map = lambda v: jnp.atleast_3d(A @ v.squeeze()) start_time = time() - blurshd = jax.pmap(lambda i: jax.lax.map(a_map, imgshd[i]))(jnp.arange(nproc)) + if nproc > 1: + # Shard array + imgshd = img.reshape((nproc, -1, size, size, 1)) + blurshd = batched_f(A, imgshd) + blur = blurshd.reshape((-1, size, size, 1)) + else: + blur = vector_f(A, img) time_blur = time() - start_time - blur = blurshd.reshape((-1, size, size, 1)) # Normalize blurred images blur = blur / jnp.max(blur, axis=(1, 2), keepdims=True) # Add Gaussian noise @@ -336,7 +378,10 @@ def distributed_data_generation( ndata_per_proc = int(nimg // nproc) - imgs = jax.pmap(imgenf, static_broadcasted_argnums=(1, 2))(seeds, size, ndata_per_proc) + idx = np.arange(nproc) + imgs = jax.vmap(imgenf, (0, None, None))(idx, size, ndata_per_proc) + + # imgs = jax.pmap(imgenf, static_broadcasted_argnums=(1, 2))(seeds, size, ndata_per_proc) if not sharded: imgs = imgs.reshape((-1, size, size, 1)) @@ -365,9 +410,12 @@ def ray_distributed_data_generation( def data_gen(seed, size, ndata, imgf): return imgf(seed, size, ndata) + # Use half of available CPU resources. ar = ray.available_resources() - # Usage of half available CPU resources. - nproc = max(int(ar["CPU"]) // 2, 1) + if "CPU" not in ar: + warnings.warn("No CPU key in ray.available_resources() output") + nproc = max(int(ar.get("CPU", "1")) // 2, 1) + # nproc = max(int(ar["CPU"]) // 2, 1) if nproc > nimg: nproc = nimg if nproc > 1 and nimg % nproc > 0: diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 70cae546e..0df46e9ab 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -479,6 +479,8 @@ class NuclearNorm(Functional): has_prox = True def __call__(self, x: Union[Array, BlockArray]) -> float: + if x.ndim != 2: + raise ValueError("Input array must be two dimensional.") return snp.sum(snp.linalg.svd(x, full_matrices=False, compute_uv=False)) def prox( @@ -490,12 +492,13 @@ def prox( :cite:`cai-2010-singular`. Args: - v: Input array :math:`\mb{v}`. + v: Input array :math:`\mb{v}`. Required to be two-dimensional. lam: Proximal parameter :math:`\lambda`. kwargs: Additional arguments that may be used by derived classes. """ - + if v.ndim != 2: + raise ValueError("Input array must be two dimensional.") svdU, svdS, svdV = snp.linalg.svd(v, full_matrices=False) svdS = snp.maximum(0, svdS - lam) return svdU @ snp.diag(svdS) @ svdV diff --git a/scico/numpy/_blockarray.py b/scico/numpy/_blockarray.py index 5c00f4f64..c01b25d2f 100644 --- a/scico/numpy/_blockarray.py +++ b/scico/numpy/_blockarray.py @@ -174,6 +174,9 @@ def prop_ba(self): def _da_method_wrapper(method_name): method = getattr(Array, method_name) + if method.__name__ is None: + return method + @wraps(method) def method_ba(self, *args, **kwargs): result = tuple(getattr(x, method_name)(*args, **kwargs) for x in self) diff --git a/scico/test/flax/test_examples_flax.py b/scico/test/flax/test_examples_flax.py index 0949d593f..aafa2359b 100644 --- a/scico/test/flax/test_examples_flax.py +++ b/scico/test/flax/test_examples_flax.py @@ -3,7 +3,7 @@ import numpy as np -from jax import device_count +import jax import pytest @@ -85,7 +85,7 @@ def test_distdatagen_flag(): @pytest.mark.skipif( - device_count() == 1, reason="no processes for checking failure of distributed computing" + jax.device_count() == 1, reason="no processes for checking failure of distributed computing" ) def test_distdatagen_exception(): N = 16 @@ -95,11 +95,15 @@ def test_distdatagen_exception(): @pytest.mark.skipif(not have_ray, reason="ray package not installed") -@pytest.fixture(scope="module") def test_ray_distdatagen(): N = 16 nimg = 8 - dt = ray_distributed_data_generation(fake_data_gen, N, nimg) + + def random_data_gen(seed, N, ndata): + dt, key = random.randn((ndata, N, N, 1), seed=seed) + return dt + + dt = ray_distributed_data_generation(random_data_gen, N, nimg) assert dt.ndim == 4 assert dt.shape == (nimg, N, N, 1) @@ -111,18 +115,15 @@ def test_ct_data_generation(): nproj = 45 def random_img_gen(seed, size, ndata): - np.random.seed(seed) - return np.random.randn(ndata, size, size, 1) + key = jax.random.PRNGKey(seed) + key, subkey = jax.random.split(key) + shape = (ndata, size, size, 1) + return jax.random.normal(subkey, shape) - try: - img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen) - except Exception as e: - print(e) - assert 0 - else: - assert img.shape == (nimg, N, N, 1) - assert sino.shape == (nimg, nproj, N, 1) - assert fbp.shape == (nimg, N, N, 1) + img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen) + assert img.shape == (nimg, N, N, 1) + assert sino.shape == (nimg, nproj, N, 1) + assert fbp.shape == (nimg, N, N, 1) @pytest.mark.skipif(not have_astra, reason="astra package not installed") @@ -132,18 +133,15 @@ def test_ct_data_generation_jax(): nproj = 45 def random_img_gen(seed, size, ndata): - np.random.seed(seed) - return np.random.randn(ndata, size, size, 1) + key = jax.random.PRNGKey(seed) + key, subkey = jax.random.split(key) + shape = (ndata, size, size, 1) + return jax.random.normal(subkey, shape) - try: - img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen, prefer_ray=False) - except Exception as e: - print(e) - assert 0 - else: - assert img.shape == (nimg, N, N, 1) - assert sino.shape == (nimg, nproj, N, 1) - assert fbp.shape == (nimg, N, N, 1) + img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen, prefer_ray=False) + assert img.shape == (nimg, N, N, 1) + assert sino.shape == (nimg, nproj, N, 1) + assert fbp.shape == (nimg, N, N, 1) def test_blur_data_generation(): @@ -156,20 +154,9 @@ def random_img_gen(seed, size, ndata): np.random.seed(seed) return np.random.randn(ndata, size, size, 1) - try: - img, blurn = generate_blur_data( - nimg, - N, - blur_kernel, - noise_sigma=0.01, - imgfunc=random_img_gen, - ) - except Exception as e: - print(e) - assert 0 - else: - assert img.shape == (nimg, N, N, 1) - assert blurn.shape == (nimg, N, N, 1) + img, blurn = generate_blur_data(nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen) + assert img.shape == (nimg, N, N, 1) + assert blurn.shape == (nimg, N, N, 1) def test_blur_data_generation_jax(): @@ -179,19 +166,16 @@ def test_blur_data_generation_jax(): blur_kernel = np.ones((n, n)) / (n * n) def random_img_gen(seed, size, ndata): - np.random.seed(seed) - return np.random.randn(ndata, size, size, 1) + key = jax.random.PRNGKey(seed) + key, subkey = jax.random.split(key) + shape = (ndata, size, size, 1) + return jax.random.normal(subkey, shape) - try: - img, blurn = generate_blur_data( - nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen, prefer_ray=False - ) - except Exception as e: - print(e) - assert 0 - else: - assert img.shape == (nimg, N, N, 1) - assert blurn.shape == (nimg, N, N, 1) + img, blurn = generate_blur_data( + nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen, prefer_ray=False + ) + assert img.shape == (nimg, N, N, 1) + assert blurn.shape == (nimg, N, N, 1) def test_rotation90(): @@ -369,19 +353,14 @@ def test_build_image_dataset(testobj, augment): dtconf = dict(testobj.dtconf) dtconf["augment"] = augment - try: - train_ds, test_ds = build_image_dataset(img_train, img_test, dtconf) - except Exception as e: - print(e) - assert 0 + train_ds, test_ds = build_image_dataset(img_train, img_test, dtconf) + assert train_ds["image"].shape == train_ds["label"].shape + assert test_ds["image"].shape == test_ds["label"].shape + assert test_ds["label"].shape[0] == num_test + if augment: + assert train_ds["label"].shape[0] == num_train * 3 else: - assert train_ds["image"].shape == train_ds["label"].shape - assert test_ds["image"].shape == test_ds["label"].shape - assert test_ds["label"].shape[0] == num_test - if augment: - assert train_ds["label"].shape[0] == num_train * 3 - else: - assert train_ds["label"].shape[0] == num_train + assert train_ds["label"].shape[0] == num_train def test_padded_circular_convolve(): @@ -392,14 +371,9 @@ def test_padded_circular_convolve(): x, key = random.randn((N, N, C), seed=2468) - try: - pcc_op = PaddedCircularConvolve(N, C, kernel_size, blur_sigma) - xblur = pcc_op(x) - except Exception as e: - print(e) - assert 0 - else: - assert xblur.shape == x.shape + pcc_op = PaddedCircularConvolve(N, C, kernel_size, blur_sigma) + xblur = pcc_op(x) + assert xblur.shape == x.shape def test_runtime_error_scalar():