diff --git a/README.rst b/README.rst index f665c2540..9775e7c0c 100644 --- a/README.rst +++ b/README.rst @@ -32,7 +32,11 @@ .. image:: https://badge.fury.io/py/scico.svg :target: https://badge.fury.io/py/scico - :alt: Current PyPI package version + :alt: PyPI package version + +.. image:: https://static.pepy.tech/personalized-badge/scico?period=month&left_color=grey&right_color=brightgreen + :target: https://pepy.tech/project/scico + :alt: PyPI download statistics .. image:: https://raw.githubusercontent.com/jupyter/design/master/logos/Badges/nbviewer_badge.svg :target: https://nbviewer.jupyter.org/github/lanl/scico-data/tree/main/notebooks/index.ipynb diff --git a/data b/data index 4fb7ef5d5..71f16aa2c 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit 4fb7ef5d53857f0bef28248a3a16267672b6d66b +Subproject commit 71f16aa2ca6e67fa8f6e35006e1a572386306bde diff --git a/docs/source/conf.py b/docs/source/conf.py index 0cae5f50c..6c00aa24e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -169,7 +169,7 @@ def patched_parse(self): # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ["tmp", "*.tmp.*", "*.tmp", "index.ipynb"] +exclude_patterns = ["tmp", "*.tmp.*", "*.tmp", "index.ipynb", "exampledepend.rst"] # If true, '()' will be appended to :func: etc. cross-reference text. add_function_parentheses = False diff --git a/docs/source/exampledepend.rst b/docs/source/exampledepend.rst index 2528e0330..3a7ad4ffe 100644 --- a/docs/source/exampledepend.rst +++ b/docs/source/exampledepend.rst @@ -1,4 +1,4 @@ -.. _example_dependencies: +.. _example_depend: Example Dependencies -------------------- diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 1d1e5dbf3..cdc3f340f 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -65,6 +65,7 @@ Miscellaneous :maxdepth: 1 examples/demosaic_ppp_bm3d_admm + examples/denoise_l1tv_iso_admm examples/denoise_tv_iso_admm examples/denoise_tv_iso_pgm examples/denoise_tv_iso_multi @@ -109,6 +110,7 @@ Total Variation examples/deconv_microscopy_allchn_tv_admm examples/deconv_tv_admm examples/deconv_tv_admm_tune + examples/denoise_l1tv_iso_admm examples/denoise_tv_iso_admm examples/denoise_tv_iso_pgm examples/denoise_tv_iso_multi @@ -153,6 +155,7 @@ ADMM examples/deconv_tv_admm examples/deconv_tv_admm_tune examples/demosaic_ppp_bm3d_admm + examples/denoise_l1tv_iso_admm examples/denoise_tv_iso_admm examples/denoise_tv_iso_multi examples/sparsecode_admm diff --git a/docs/source/install.rst b/docs/source/install.rst index d39aaccb7..4ac3224f2 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -61,7 +61,7 @@ The instructions above install a CPU-only version of SCICO. To install a version Additional Dependencies ----------------------- -For instructions on installing dependencies related to the examples please see :ref:`example_dependencies`. +For instructions on installing dependencies related to the examples please see :ref:`example_depend`. 
For Developers diff --git a/docs/source/references.bib b/docs/source/references.bib index cefbd80be..bae66749d 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -1,3 +1,15 @@ +@Article {alliney-1992-digital, + author = {Alliney, Stefano}, + journal = {IEEE Transactions on Signal Processing}, + title = {Digital filters as absolute norm regularizers}, + year = 1992, + volume = 40, + number = 6, + pages = {1548--1562}, + doi = {10.1109/78.139258}, + month = Jun +} + @Article {almeida-2013-deconvolving, author = {Almeida, Mariana S. C. and Figueiredo, M\'ario}, journal = {IEEE Transactions on Image Processing}, @@ -71,6 +83,19 @@ @Book {beck-2017-first isbn = 1611974984 } +@Software {bradbury-2018-jax, + author = {James Bradbury and Roy Frostig and Peter Hawkins and + Matthew James Johnson and Chris Leary and Dougal + Maclaurin and George Necula and Adam Paszke and Jake + Vander{P}las and Skye Wanderman-{M}ilne and Qiao + Zhang}, + title = {{JAX}: composable transformations of + {P}ython+{N}um{P}y programs}, + url = {http://github.com/google/jax}, + version = {0.2.5}, + year = {2018} +} + @Article {boyd-2010-distributed, title = {Distributed optimization and statistical learning via the alternating direction method of multipliers}, @@ -114,7 +139,7 @@ @Article {chambolle-2010-firstorder author = {Antonin Chambolle and Thomas Pock}, title = {A First-Order Primal-Dual Algorithm for Convex Problems with~Applications to Imaging}, - journal = {Journal of Mathematical Imaging and Vision}, + journal = {Journal of Mathematical Imaging and Vision}, doi = {10.1007/s10851-010-0251-1}, year = 2010, month = Dec, @@ -137,7 +162,7 @@ @Article {clinthorne-1993-preconditioning doi = {10.1109/42.222670} } -@InProceedings{dabov-2008-image, +@InProceedings {dabov-2008-image, author = {Kostadin Dabov and Alessandro Foi and Vladimir Katkovnik and Karen Egiazarian}, title = {Image restoration by sparse {3D} transform-domain @@ -173,7 +198,7 @@ @Article {esser-2010-general Primal-Dual Algorithms for Convex Optimization in Imaging Science}, journal = {SIAM Journal on Imaging Sciences}, - doi = {10.1137/09076934x}, + doi = {10.1137/09076934x}, year = 2010, month = Jan, volume = 3, @@ -181,6 +206,15 @@ @Article {esser-2010-general pages = {1015--1046} } +@PhDThesis {esser-2010-primal, + author = {Ernie Esser}, + title = {Primal Dual Algorithms for Convex Models and + Applications to Image Restoration, Registration and + Nonlocal Inpainting}, + school = {University of California Los Angeles}, + year = 2010 +} + @InProceedings {florea-2017-robust, title = {A Robust {FISTA}-Like Algorithm}, author = {Mihai I. Florea and Sergiy A. 
Vorobyov}, @@ -222,6 +256,18 @@ @Article {glowinski-1975-approximation url = {http://eudml.org/doc/193269} } +@Article {goldstein-2009-split, + author = {Tom Goldstein and Stanley Osher}, + title = {The Split {B}regman Method for L1-Regularized + Problems}, + journal = {SIAM Journal on Imaging Sciences}, + volume = 2, + number = 2, + pages = {323--343}, + year = 2009, + doi = {10.1137/080725891} +} + @Misc {goldstein-2014-fasta, title = {A Field Guide to Forward-Backward Splitting with a {FASTA} Implementation}, @@ -308,7 +354,7 @@ @Article {parikh-2014-proximal } @Misc {pfister-2021-scico, - author = {Luke Pfister‬ and Thilo Balke and Fernando Davis and + author = {Luke Pfister and Thilo Balke and Fernando Davis and Cristina Garcia-Cardona and Michael McCann and Brendt Wohlberg}, title = {{S}cientific {C}omputational {I}maging {CO}de @@ -324,13 +370,35 @@ @InProceedings {pock-2011-diagonal algorithms in convex optimization}, booktitle = {Proceedings of the International Conference on Computer Vision (ICCV)}, - doi = {10.1109/iccv.2011.6126441}, + doi = {10.1109/iccv.2011.6126441}, pages = {1762--1769}, year = 2011, month = Nov, address = {Barcelona, Spain} } +@Misc {pyabel-2022, + author = {Stephen Gibson and Daniel Hickstein and Roman + Yurchak and Mikhail Ryazanov and Dhrubajyoti Das and + Gilbert Shih}, + title = {PyAbel}, + howpublished = {PyAbel/PyAbel: v0.8.5}, + year = 2022, + doi = {10.5281/zenodo.5888391} +} + +@Article {rudin-1992-nonlinear, + author = {Leonid I. Rudin and Stanley Osher and Emad Fatemi}, + title = {Nonlinear total variation based noise removal + algorithms}, + journal = {Physica D: Nonlinear Phenomena}, + volume = 60, + number = {1--4}, + pages = {259--268}, + year = 1992, + doi = {10.1016/0167-2789(92)90242-F} +} + @Article {sauer-1993-local, title = {A local update strategy for iterative reconstruction from projections}, @@ -369,15 +437,6 @@ @Misc {svmbir-2020 year = 2020 } -@Misc {pyabel-2022, - author = {Stephen Gibson and Daniel Hickstein and Roman Yurchak, - Mikhail Ryazanov and Dhrubajyoti Das and Gilbert Shih}, - title = {PyAbel}, - howpublished = {PyAbel/PyAbel: v0.8.5}, - year = 2022, - doi = {10.5281/zenodo.5888391} -} - @InProceedings {venkatakrishnan-2013-plugandplay2, author = {Singanallur V. Venkatakrishnan and Charles A.
Bouman and Brendt Wohlberg}, @@ -414,14 +473,14 @@ @Book {voelz-2011-computational isbn = 9780819482044, } -@article{yang-2012-linearized, +@Article {yang-2012-linearized, author = {Junfeng Yang and Xiaoming Yuan}, title = {Linearized augmented {L}agrangian and alternating direction methods for nuclear norm minimization}, journal = {Mathematics of Computation}, - doi = {10.1090/s0025-5718-2012-02598-1}, + doi = {10.1090/s0025-5718-2012-02598-1}, year = 2012, - month = mar, + month = Mar, volume = 82, number = 281, pages = {301--329} } diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst index cad8b0a97..615e56b68 100644 --- a/examples/scripts/README.rst +++ b/examples/scripts/README.rst @@ -62,6 +62,8 @@ Miscellaneous `demosaic_ppp_bm3d_admm.py `_ Image Demosaicing (ADMM Plug-and-Play Priors w/ BM3D) + `denoise_l1tv_iso_admm.py `_ + ℓ1 Total Variation (ADMM) `denoise_tv_iso_admm.py `_ Isotropic Total Variation (ADMM) `denoise_tv_iso_pgm.py `_ @@ -118,6 +120,8 @@ Total Variation Image Deconvolution (ADMM w/ Total Variation) `deconv_tv_admm_tune.py `_ Image Deconvolution Parameter Tuning + `denoise_l1tv_iso_admm.py `_ + ℓ1 Total Variation (ADMM) `denoise_tv_iso_admm.py `_ Isotropic Total Variation (ADMM) `denoise_tv_iso_pgm.py `_ @@ -174,6 +178,8 @@ ADMM Image Deconvolution Parameter Tuning `demosaic_ppp_bm3d_admm.py `_ Image Demosaicing (ADMM Plug-and-Play Priors w/ BM3D) + `denoise_l1tv_iso_admm.py `_ + ℓ1 Total Variation (ADMM) `denoise_tv_iso_admm.py `_ Isotropic Total Variation (ADMM) `denoise_tv_iso_multi.py `_ diff --git a/examples/scripts/denoise_l1tv_iso_admm.py b/examples/scripts/denoise_l1tv_iso_admm.py new file mode 100644 index 000000000..37721a281 --- /dev/null +++ b/examples/scripts/denoise_l1tv_iso_admm.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +ℓ1 Total Variation (ADMM) +========================= + +This example demonstrates impulse noise removal via ℓ1 total variation +:cite:`alliney-1992-digital` :cite:`esser-2010-primal` (Sec. 2.4.4) +(i.e. total variation regularization with an ℓ1 data fidelity term), +minimizing the functional + + $$\mathrm{argmin}_{\mathbf{x}} \; \| \mathbf{y} - \mathbf{x} + \|_1 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ + +where $\mathbf{y}$ is the noisy image, $C$ is a 2D finite difference +operator, and $\mathbf{x}$ is the denoised image. +""" + +import jax + +from xdesign import SiemensStar, discrete_phantom + +import scico.numpy as snp +from scico import functional, linop, loss, metric, plot +from scico.examples import spnoise +from scico.optimize.admm import ADMM, LinearSubproblemSolver +from scico.util import device_info +from scipy.ndimage import median_filter + +""" +Create a ground truth image and impose salt & pepper noise to create a +noisy test image. +""" +N = 256 # image size +phantom = SiemensStar(16) +x_gt = snp.pad(discrete_phantom(phantom, 240), 8) +x_gt = 0.5 * x_gt / x_gt.max() +x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU +y = spnoise(x_gt, 0.5) + + +""" +Denoise with median filtering. +""" +x_med = median_filter(y, size=(5, 5)) + + +""" +Denoise with ℓ1 total variation. +""" +λ = 1.5e0 +g_loss = loss.Loss(y=y, f=functional.L1Norm()) +g_tv = λ * functional.L21Norm() +# The append=0 option makes the results of horizontal and vertical finite +# differences the same shape, which is required for the L21Norm.
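# A sketch of the consensus splitting that the ADMM setup below implements
# (the auxiliary variables z1 and z2 are illustrative notation, not part of
# this script): the problem
#     min_x  ||y - x||_1 + λ ||C x||_{2,1}
# is solved in the split form
#     min_{x,z1,z2}  ||y - z1||_1 + λ ||z2||_{2,1}   s.t.  z1 = x,  z2 = C x,
# which is why g_loss is paired with an Identity operator and g_tv with the
# finite difference operator C in the g_list/C_list arguments.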
+C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) + +solver = ADMM( + f=None, + g_list=[g_loss, g_tv], + C_list=[linop.Identity(input_shape=y.shape), C], + rho_list=[5e0, 5e0], + x0=y, + maxiter=100, + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}), + itstat_options={"display": True, "period": 10}, +) + +print(f"Solving on {device_info()}\n") +x_tv = solver.solve() +hist = solver.itstat_object.history(transpose=True) + + +""" +Plot results. +""" +plt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.0)) +fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(13, 12)) +plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args) +plot.imview(y, title="Noisy image", fig=fig, ax=ax[0, 1], **plt_args) +plot.imview( + x_med, + title=f"Median filtering: {metric.psnr(x_gt, x_med):.2f} (dB)", + fig=fig, + ax=ax[1, 0], + **plt_args, +) +plot.imview( + x_tv, + title=f"ℓ1-TV denoising: {metric.psnr(x_gt, x_tv):.2f} (dB)", + fig=fig, + ax=ax[1, 1], + **plt_args, +) +fig.show() + + +""" +Plot convergence statistics. +""" +fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) +plot.plot( + hist.Objective, + title="Objective function", + xlbl="Iteration", + ylbl="Functional value", + fig=fig, + ax=ax[0], +) +plot.plot( + snp.vstack((hist.Prml_Rsdl, hist.Dual_Rsdl)).T, + ptyp="semilogy", + title="Residuals", + xlbl="Iteration", + lgnd=("Primal", "Dual"), + fig=fig, + ax=ax[1], +) +fig.show() + + +input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/denoise_tv_iso_admm.py b/examples/scripts/denoise_tv_iso_admm.py index 254446f78..4d0529d01 100644 --- a/examples/scripts/denoise_tv_iso_admm.py +++ b/examples/scripts/denoise_tv_iso_admm.py @@ -9,10 +9,11 @@ ================================ This example compares denoising via isotropic and anisotropic total -variation (TV) regularization. It solves the denoising problem +variation (TV) regularization :cite:`rudin-1992-nonlinear` +:cite:`goldstein-2009-split`. It solves the denoising problem $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x} - \|^2 + \lambda R(\mathbf{x}) \;,$$ + \|_2^2 + \lambda R(\mathbf{x}) \;,$$ where $R$ is either the isotropic or anisotropic TV regularizer. In SCICO, switching between these two regularizers is a one-line @@ -53,11 +54,11 @@ """ -Denoise with isotropic total variation +Denoise with isotropic total variation. """ -reg_weight_iso = 1.4e0 +λ_iso = 1.4e0 f = loss.SquaredL2Loss(y=y) -g_iso = reg_weight_iso * functional.L21Norm() +g_iso = λ_iso * functional.L21Norm() # The append=0 option makes the results of horizontal and vertical finite # differences the same shape, which is required for the L21Norm. @@ -82,8 +83,8 @@ Denoise with anisotropic total variation for comparison. """ # Tune the weight to give the same data fidelty as the isotropic case. -reg_weight_aniso = 1.2e0 -g_aniso = reg_weight_aniso * functional.L1Norm() +λ_aniso = 1.2e0 +g_aniso = λ_aniso * functional.L1Norm() solver = ADMM( f=f, @@ -102,7 +103,7 @@ """ -Compute the data fidelity. +Compute and print the data fidelity. 
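As a sketch of why the data fidelity, rather than the regularization
weight, is the quantity matched between the two cases (writing $C_h$ and
$C_v$ for the horizontal and vertical finite differences computed by $C$):
the isotropic regularizer penalizes
$\sum_i \sqrt{(C_h \mathbf{x})_i^2 + (C_v \mathbf{x})_i^2}$ while the
anisotropic one penalizes
$\sum_i \left( |(C_h \mathbf{x})_i| + |(C_v \mathbf{x})_i| \right)$, so the
two values of $\lambda$ are not directly comparable; equalizing the data
fidelity puts the two denoised results on an equal footing.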
""" for x, name in zip((x_iso, x_aniso), ("Isotropic", "Anisotropic")): df = f(x) diff --git a/examples/scripts/denoise_tv_iso_multi.py b/examples/scripts/denoise_tv_iso_multi.py index db2c5fb27..238a9411a 100644 --- a/examples/scripts/denoise_tv_iso_multi.py +++ b/examples/scripts/denoise_tv_iso_multi.py @@ -12,7 +12,7 @@ in solving the isotropic total variation (TV) denoising problem $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x} - \|^2 + \lambda R(\mathbf{x}) \;,$$ + \|_2^2 + \lambda R(\mathbf{x}) \;,$$ where $R$ is the isotropic TV: the sum of the norms of the gradient vectors at each point in the image $\mathbf{x}$. diff --git a/examples/scripts/denoise_tv_iso_pgm.py b/examples/scripts/denoise_tv_iso_pgm.py index 28455c44f..86d349ea6 100644 --- a/examples/scripts/denoise_tv_iso_pgm.py +++ b/examples/scripts/denoise_tv_iso_pgm.py @@ -14,7 +14,7 @@ denoising problem $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x} - \|^2 + \lambda R(\mathbf{x}) + \iota_C(\mathbf{x}) \;,$$ + \|_2^2 + \lambda R(\mathbf{x}) + \iota_C(\mathbf{x}) \;,$$ where $R$ is a TV regularizer, $\iota_C(\cdot)$ is the indicator function of constraint set $C$, and $C = \{ \mathbf{x} \, | \, x_i \in [0, 1] \}$, diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst index bf2e78468..8a50ac123 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -43,6 +43,7 @@ Miscellaneous ^^^^^^^^^^^^^ - demosaic_ppp_bm3d_admm.py + - denoise_l1tv_iso_admm.py - denoise_tv_iso_admm.py - denoise_tv_iso_pgm.py - denoise_tv_iso_multi.py @@ -78,6 +79,7 @@ Total Variation - deconv_microscopy_allchn_tv_admm.py - deconv_tv_admm.py - deconv_tv_admm_tune.py + - denoise_l1tv_iso_admm.py - denoise_tv_iso_admm.py - denoise_tv_iso_pgm.py - denoise_tv_iso_multi.py @@ -113,6 +115,7 @@ ADMM - deconv_tv_admm.py - deconv_tv_admm_tune.py - demosaic_ppp_bm3d_admm.py + - denoise_l1tv_iso_admm.py - denoise_tv_iso_admm.py - denoise_tv_iso_multi.py - sparsecode_admm.py diff --git a/scico/examples.py b/scico/examples.py index a162fb677..9ce95a108 100644 --- a/scico/examples.py +++ b/scico/examples.py @@ -19,7 +19,7 @@ import imageio import scico.numpy as snp -from scico import util +from scico import random, util from scico.typing import Array, JaxArray, Shape from scipy.ndimage import zoom @@ -291,3 +291,29 @@ def create_circular_phantom( img = img.at[dist_map < r].set(val) return img + + +def spnoise(img: Array, nfrac: float, nmin: float = 0.0, nmax: float = 1.0) -> Array: + """Return image with salt & pepper noise imposed on it. + + Args: + img: Input image. + nfrac: Desired fraction of pixels corrupted by noise. + nmin: Lower value for noise (pepper). Default 0.0. + nmax: Upper value for noise (salt). Default 1.0. + + Returns: + Noisy image + """ + + if isinstance(img, np.ndarray): + spm = np.random.uniform(-1.0, 1.0, img.shape) + imgn = img.copy() + imgn[spm < nfrac - 1.0] = nmin + imgn[spm > 1.0 - nfrac] = nmax + else: + spm, key = random.uniform(shape=img.shape, minval=-1.0, maxval=1.0, seed=0) + imgn = img + imgn = imgn.at[spm < nfrac - 1.0].set(nmin) + imgn = imgn.at[spm > 1.0 - nfrac].set(nmax) + return imgn diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py index 0dfe81108..e6c69a888 100644 --- a/scico/functional/_functional.py +++ b/scico/functional/_functional.py @@ -86,8 +86,7 @@ def prox( lam: Proximal parameter :math:`\lambda`. kwargs: Additional arguments that may be used by derived classes. 
These include ``x0``, an initial guess for the - minimizer in the defintion of :math:`\mathrm{prox}`. - + minimizer in the definition of :math:`\mathrm{prox}`. """ if not self.has_prox: raise NotImplementedError( diff --git a/scico/linop/radon_svmbir.py b/scico/linop/radon_svmbir.py index e34f899c6..437fadaf8 100644 --- a/scico/linop/radon_svmbir.py +++ b/scico/linop/radon_svmbir.py @@ -57,8 +57,8 @@ def __init__( angles: Array of projection angles in radians, should be increasing. num_channels: Number of pixels in the sinogram. - center_offset: Position of the detector center relative to the center-of-rotation, - in units of pixels + center_offset: Position of the detector center relative to + the center of rotation, in units of pixels. is_masked: If ``True``, the valid region of the image is determined by a mask defined as the circle inscribed within the image boundary. Otherwise, the whole image @@ -178,9 +178,11 @@ def _bproj_hcb(self, y): class SVMBIRExtendedLoss(Loss): - r"""Extended Weighted squared :math:`\ell_2` loss with svmbir CT projector. + r"""Extended Weighted squared :math:`\ell_2` loss with svmbir CT + projector. - Generalization of the weighted squared :math:`\ell_2` loss of a CT reconstruction problem, + Generalization of the weighted squared :math:`\ell_2` loss of a CT + reconstruction problem, .. math:: \alpha \norm{\mb{y} - A(\mb{x})}_W^2 = @@ -189,21 +191,23 @@ class SVMBIRExtendedLoss(Loss): where :math:`A` is a :class:`.ParallelBeamProjector`, :math:`\alpha` is the scaling parameter and :math:`W` is an instance - of :class:`scico.linop.Diagonal`. If :math:`W` is None, it is set to - :class:`scico.linop.Identity`. + of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, it is set + to :class:`scico.linop.Identity`. - The extended loss differs from a typical weighted squared :math:`\ell_2` loss - is the following aspects. + The extended loss differs from a typical weighted squared + :math:`\ell_2` loss as follows. When ``positivity=True``, the prox projects onto the non-negative - orthant and the loss is infinite if any element of the input is negative. - When the ``is_masked`` option of the associated :class:`.ParallelBeamProjector` is `True`, - the reconstruction is computed over a masked region of the image as described + orthant and the loss is infinite if any element of the input is + negative. When the ``is_masked`` option of the associated + :class:`.ParallelBeamProjector` is ``True``, the reconstruction is + computed over a masked region of the image as described in class :class:`.ParallelBeamProjector`. """ def __init__( self, *args, + scale: float = 0.5, prox_kwargs: Optional[dict] = None, positivity: bool = False, W: Optional[Diagonal] = None, @@ -215,15 +219,15 @@ def __init__( y: Sinogram measurement. A: Forward operator. scale: Scaling parameter. - W: Weighting diagonal operator. Must be non-negative. - If ``None``, defaults to :class:`.Identity`. prox_kwargs: Dictionary of arguments passed to the - :meth:`svmbir.recon` prox routine. Defaults to - {"maxiter": 1000, "ctol": 0.001}. + :meth:`svmbir.recon` prox routine. Defaults to + {"maxiter": 1000, "ctol": 0.001}. positivity: Enforce positivity in the prox operation. The - loss is infinite if any element of the input is negative. + loss is infinite if any element of the input is negative. + W: Weighting diagonal operator. Must be non-negative. + If ``None``, defaults to :class:`.Identity`. 
""" - super().__init__(*args, **kwargs) + super().__init__(*args, scale=scale, **kwargs) if not isinstance(self.A, ParallelBeamProjector): raise ValueError("LinearOperator A must be a radon_svmbir.ParallelBeamProjector.") @@ -307,8 +311,8 @@ class SVMBIRWeightedSquaredL2Loss(SVMBIRExtendedLoss, WeightedSquaredL2Loss): where :math:`A` is a :class:`.ParallelBeamProjector`, :math:`\alpha` is the scaling parameter and :math:`W` is an instance - of :class:`scico.linop.Diagonal`. If :math:`W` is None, it is set to - :class:`scico.linop.Identity`. + of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, it is set + to :class:`scico.linop.Identity`. """ def __init__( @@ -323,17 +327,18 @@ def __init__( y: Sinogram measurement. A: Forward operator. scale: Scaling parameter. - W: Weighting diagonal operator. Must be non-negative. - If None, defaults to :class:`.Identity`. + W: Weighting diagonal operator. Must be non-negative. + If ``None``, defaults to :class:`.Identity`. prox_kwargs: Dictionary of arguments passed to the - :meth:`svmbir.recon` prox routine. Defaults to - {"maxiter": 1000, "ctol": 0.001}. + :meth:`svmbir.recon` prox routine. Defaults to + {"maxiter": 1000, "ctol": 0.001}. """ super().__init__(*args, **kwargs, prox_kwargs=prox_kwargs, positivity=False) if self.A.is_masked: raise ValueError( - "is_masked must be false for the ParallelBeamProjector in SVMBIRWeightedSquaredL2Loss." + "is_masked must be false for the ParallelBeamProjector in " + "SVMBIRWeightedSquaredL2Loss." ) diff --git a/scico/loss.py b/scico/loss.py index 498721fca..2975a111f 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -41,9 +41,9 @@ class Loss(functional.Functional): Generic loss function .. math:: - \alpha l(\mb{y}, A(\mb{x})) \;, + \alpha f(\mb{y}, A(\mb{x})) \;, - where :math:`\alpha` is the scaling parameter and :math:`l(\cdot)` is + where :math:`\alpha` is the scaling parameter and :math:`f(\cdot)` is the loss functional. """ @@ -51,14 +51,19 @@ def __init__( self, y: Union[JaxArray, BlockArray], A: Optional[Union[Callable, operator.Operator]] = None, - scale: float = 0.5, + f: Optional[functional.Functional] = None, + scale: float = 1.0, ): r""" Args: y: Measurement. A: Forward operator. Defaults to ``None``, in which case ``self.A`` is a :class:`.Identity`. - scale: Scaling parameter. Default: 0.5. + f: Functional :math:`f`. If defined, the loss function is + :math:`\alpha f(\mb{y} - A(\mb{x}))`. If ``None``, then + :meth:`__call__` and :meth:`prox` (where appropriate) must + be defined in a derived class. + scale: Scaling parameter. Default: 1.0. """ self.y = ensure_on_device(y) @@ -66,12 +71,15 @@ def __init__( # y and x must have same shape A = linop.Identity(self.y.shape) self.A = A + self.f = f self.scale = scale # Set functional-specific flags - self.has_prox = False self.has_eval = True - + if self.f is not None and isinstance(self.A, linop.Identity): + self.has_prox = True + else: + self.has_prox = False super().__init__() def __call__(self, x: Union[JaxArray, BlockArray]) -> float: @@ -80,7 +88,37 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: Args: x: Point at which to evaluate loss. """ - raise NotImplementedError + if self.f is None: + raise NotImplementedError( + "Functional l is not defined and __call__ has" " not been overridden" + ) + return self.scale * self.f(self.A(x) - self.y) + + def prox( + self, v: Union[JaxArray, BlockArray], lam: float, **kwargs + ) -> Union[JaxArray, BlockArray]: + r"""Scaled proximal operator of loss function. 
+ + Evaluate scaled proximal operator of this loss function, with + scaling :math:`\lambda` = `lam` and evaluated at point + :math:`\mb{v}` = `v`. If :meth:`prox` is not defined in a derived + class, and if operator :math:`A` is the identity operator, then + the proximal operator is computed using the proximal operator of + functional :math:`f`, via Theorem 6.11 in :cite:`beck-2017-first`. + + Args: + v: Point at which to evaluate prox function. + lam: Proximal parameter :math:`\lambda`. + kwargs: Additional arguments that may be used by derived + classes. These include ``x0``, an initial guess for the + minimizer in the definition of :math:`\mathrm{prox}`. + """ + if not self.has_prox: + raise NotImplementedError( + f"prox is not implemented for {type(self)} when A is {type(self.A)}; " + "must be Identity" + ) + return self.f.prox(v - self.y, self.scale * lam, **kwargs) + self.y @_loss_mul_div_wrapper def __mul__(self, other): diff --git a/scico/test/linop/test_radon_svmbir.py b/scico/test/linop/test_radon_svmbir.py index 5a7c8e3f0..b1f6b6c9d 100644 --- a/scico/test/linop/test_radon_svmbir.py +++ b/scico/test/linop/test_radon_svmbir.py @@ -9,7 +9,7 @@ from scico.linop import Diagonal from scico.loss import WeightedSquaredL2Loss from scico.test.linop.test_linop import adjoint_test -from scico.test.test_functional import prox_test +from scico.test.prox import prox_test try: import svmbir diff --git a/scico/test/prox.py b/scico/test/prox.py new file mode 100644 index 000000000..5f6199e5c --- /dev/null +++ b/scico/test/prox.py @@ -0,0 +1,42 @@ +import numpy as np + +import scico.numpy as snp +from scico.solver import minimize + + +def prox_func(x, v, f, alpha): + """Evaluate functional of which the proximal operator is the argmin.""" + return 0.5 * snp.sum(snp.abs(x.reshape(v.shape) - v) ** 2) + alpha * snp.array( + f(x.reshape(v.shape)), dtype=snp.float64 + ) + + +def prox_solve(v, v0, f, alpha): + """Evaluate the alpha-scaled proximal operator of f at v, using v0 as an + initial point for the optimization.""" + fnc = lambda x: prox_func(x, v, f, alpha) + fmn = minimize( + fnc, + v0, + method="Nelder-Mead", + options={"maxiter": 1000, "xatol": 1e-9, "fatol": 1e-9}, + ) + + return fmn.x.reshape(v.shape), fmn.fun + + +def prox_test(v, nrm, prx, alpha, x0=None): + """Test the alpha-scaled proximal operator function prx of norm functional nrm + at point v.""" + # Evaluate the proximal operator at v + px = snp.array(prx(v, alpha, v0=x0)) + # Proximal operator functional value (i.e.
Moreau envelope) at v + pf = prox_func(px, v, nrm, alpha) + # Brute-force solve of the proximal operator at v + mx, mf = prox_solve(v, px, nrm, alpha) + + # Compare prox functional value with brute-force solution + if pf < mf: + return # prox gave a lower cost than brute force, so it passes + + np.testing.assert_allclose(pf, mf, rtol=1e-6) diff --git a/scico/test/test_examples.py b/scico/test/test_examples.py index af8b6056e..7d11c3484 100644 --- a/scico/test/test_examples.py +++ b/scico/test/test_examples.py @@ -6,12 +6,14 @@ import imageio import pytest +import scico.numpy as snp from scico.examples import ( create_circular_phantom, create_cone, downsample_volume, epfl_deconv_data, rgb2gray, + spnoise, tile_volume_slices, volume_read, ) @@ -89,3 +91,14 @@ def test_create_cone(img_shape): assert x_gt.shape == img_shape # check symmetry assert np.abs(x_gt[(0,) * len(img_shape)] - x_gt[(-1,) * len(img_shape)]) < 1e-6 + + +def test_spnoise(): + x = 0.5 * np.ones((10, 11)) + y = spnoise(x, 0.5, nmin=0.01, nmax=0.99) + assert np.all(y >= 0.01) + assert np.all(y <= 0.99) + x = 0.5 * snp.ones((10, 11)) + y = spnoise(x, 0.5, nmin=0.01, nmax=0.99) + assert np.all(y >= 0.01) + assert np.all(y <= 0.99) diff --git a/scico/test/test_functional.py b/scico/test/test_functional.py index 1bd59067c..3ce83358f 100644 --- a/scico/test/test_functional.py +++ b/scico/test/test_functional.py @@ -10,12 +10,12 @@ import jax import pytest +from prox import prox_test import scico.numpy as snp -from scico import denoiser, functional, linop, loss +from scico import denoiser, functional from scico.blockarray import BlockArray from scico.random import randn -from scico.solver import minimize NO_BLOCK_ARRAY = [functional.L21Norm, functional.NuclearNorm] NO_COMPLEX = [ @@ -23,44 +23,6 @@ ] -def prox_func(x, v, f, alpha): - """Evaluate functional of which the proximal operator is the argmin.""" - return 0.5 * snp.sum(snp.abs(x.reshape(v.shape) - v) ** 2) + alpha * snp.array( - f(x.reshape(v.shape)), dtype=snp.float64 - ) - - -def prox_solve(v, v0, f, alpha): - """Evaluate the alpha-scaled proximal operator of f at v, using v0 as an - initial point for the optimization.""" - fnc = lambda x: prox_func(x, v, f, alpha) - fmn = minimize( - fnc, - v0, - method="Nelder-Mead", - options={"maxiter": 1000, "xatol": 1e-9, "fatol": 1e-9}, - ) - - return fmn.x.reshape(v.shape), fmn.fun - - -def prox_test(v, nrm, prx, alpha, x0=None): - """Test the alpha-scaled proximal operator function prx of norm functional nrm - at point v.""" - # Evaluate the proximal operator at v - px = snp.array(prx(v, alpha, v0=x0)) - # Proximal operator functional value (i.e. 
Moreau envelope) at v - pf = prox_func(px, v, nrm, alpha) - # Brute-force solve of the proximal operator at v - mx, mf = prox_solve(v, px, nrm, alpha) - - # Compare prox functional value with brute-force solution - if pf < mf: - return # prox gave a lower cost than brute force, so it passes - - np.testing.assert_allclose(pf, mf, rtol=1e-6) - - class ProxTestObj: def __init__(self, dtype): key = None @@ -324,105 +286,6 @@ def foo(c): np.testing.assert_allclose(non_pmap, pmapped) -class TestLoss: - def setup_method(self): - n = 4 - dtype = np.float64 - A, key = randn((n, n), key=None, dtype=dtype, seed=1234) - D, key = randn((n,), key=key, dtype=dtype) - W, key = randn((n,), key=key, dtype=dtype) - W = 0.1 * W + 1.0 - self.Ao = linop.MatrixOperator(A) - self.Ao_abs = linop.MatrixOperator(snp.abs(A)) - self.Do = linop.Diagonal(D) - self.W = linop.Diagonal(W) - self.y, key = randn((n,), key=key, dtype=dtype) - self.v, key = randn((n,), key=key, dtype=dtype) # point for prox eval - scalar, key = randn((1,), key=key, dtype=dtype) - self.scalar = scalar.copy().ravel()[0] - - def test_squared_l2(self): - L = loss.SquaredL2Loss(y=self.y, A=self.Ao) - assert L.has_eval == True - assert L.has_prox == True - - # test eval - np.testing.assert_allclose(L(self.v), 0.5 * ((self.Ao @ self.v - self.y) ** 2).sum()) - - cL = self.scalar * L - assert L.scale == 0.5 # hasn't changed - assert cL.scale == self.scalar * L.scale - assert cL(self.v) == self.scalar * L(self.v) - - # SquaredL2 with Diagonal linop has a prox - L_d = loss.SquaredL2Loss(y=self.y, A=self.Do) - - # test eval - np.testing.assert_allclose(L_d(self.v), 0.5 * ((self.Do @ self.v - self.y) ** 2).sum()) - - assert L_d.has_eval == True - assert L_d.has_prox == True - - cL = self.scalar * L_d - assert L_d.scale == 0.5 # hasn't changed - assert cL.scale == self.scalar * L_d.scale - assert cL(self.v) == self.scalar * L_d(self.v) - - pf = prox_test(self.v, L_d, L_d.prox, 0.75) - - pf = prox_test(self.v, L, L.prox, 0.75) - - def test_weighted_squared_l2(self): - L = loss.WeightedSquaredL2Loss(y=self.y, A=self.Ao, W=self.W) - assert L.has_eval == True - assert L.has_prox == True - - # test eval - np.testing.assert_allclose( - L(self.v), 0.5 * (self.W @ (self.Ao @ self.v - self.y) ** 2).sum() - ) - - cL = self.scalar * L - assert L.scale == 0.5 # hasn't changed - assert cL.scale == self.scalar * L.scale - assert cL(self.v) == self.scalar * L(self.v) - - # SquaredL2 with Diagonal linop has a prox - L_d = loss.WeightedSquaredL2Loss(y=self.y, A=self.Do, W=self.W) - - assert L_d.has_eval == True - assert L_d.has_prox == True - - # test eval - np.testing.assert_allclose( - L_d(self.v), 0.5 * (self.W @ (self.Do @ self.v - self.y) ** 2).sum() - ) - - cL = self.scalar * L_d - assert L_d.scale == 0.5 # hasn't changed - assert cL.scale == self.scalar * L_d.scale - assert cL(self.v) == self.scalar * L_d(self.v) - - pf = prox_test(self.v, L_d, L_d.prox, 0.75) - - pf = prox_test(self.v, L, L.prox, 0.75) - - def test_poisson(self): - L = loss.PoissonLoss(y=self.y, A=self.Ao_abs) - assert L.has_eval == True - assert L.has_prox == False - - # test eval - v = snp.abs(self.v) - Av = self.Ao_abs @ v - np.testing.assert_allclose(L(v), 0.5 * snp.sum(Av - self.y * snp.log(Av) + L.const)) - - cL = self.scalar * L - assert L.scale == 0.5 # hasn't changed - assert cL.scale == self.scalar * L.scale - assert cL(v) == self.scalar * L(v) - - class TestBM3D: def setup(self): key = None diff --git a/scico/test/test_loss.py b/scico/test/test_loss.py new file mode 100644 index 
000000000..39e5689c6 --- /dev/null +++ b/scico/test/test_loss.py @@ -0,0 +1,131 @@ +import numpy as np + +from jax.config import config + +import pytest + +# enable 64-bit mode for output dtype checks +config.update("jax_enable_x64", True) + + +from prox import prox_test + +import scico.numpy as snp +from scico import functional, linop, loss +from scico.random import randn + + +class TestLoss: + def setup_method(self): + n = 4 + dtype = np.float64 + A, key = randn((n, n), key=None, dtype=dtype, seed=1234) + D, key = randn((n,), key=key, dtype=dtype) + W, key = randn((n,), key=key, dtype=dtype) + W = 0.1 * W + 1.0 + self.Ao = linop.MatrixOperator(A) + self.Ao_abs = linop.MatrixOperator(snp.abs(A)) + self.Do = linop.Diagonal(D) + self.W = linop.Diagonal(W) + self.y, key = randn((n,), key=key, dtype=dtype) + self.v, key = randn((n,), key=key, dtype=dtype) # point for prox eval + scalar, key = randn((1,), key=key, dtype=dtype) + self.scalar = scalar.copy().ravel()[0] + + def test_generic_squared_l2(self): + A = linop.Identity(input_shape=self.y.shape) + f = functional.SquaredL2Norm() + L0 = loss.Loss(self.y, A=A, f=f, scale=0.5) + L1 = loss.SquaredL2Loss(y=self.y, A=A) + np.testing.assert_allclose(L0(self.v), L1(self.v)) + np.testing.assert_allclose(L0.prox(self.v, self.scalar), L1.prox(self.v, self.scalar)) + + def test_generic_exception(self): + A = linop.Diagonal(self.v) + L = loss.Loss(self.y, A=A, scale=0.5) + with pytest.raises(NotImplementedError): + L(self.v) + f = functional.L1Norm() + L = loss.Loss(self.y, A=A, f=f, scale=0.5) + assert not L.has_prox + with pytest.raises(NotImplementedError): + L.prox(self.v, self.scalar) + + def test_squared_l2(self): + L = loss.SquaredL2Loss(y=self.y, A=self.Ao) + assert L.has_eval + assert L.has_prox + + # test eval + np.testing.assert_allclose(L(self.v), 0.5 * ((self.Ao @ self.v - self.y) ** 2).sum()) + + cL = self.scalar * L + assert L.scale == 0.5 # hasn't changed + assert cL.scale == self.scalar * L.scale + assert cL(self.v) == self.scalar * L(self.v) + + # SquaredL2 with Diagonal linop has a prox + L_d = loss.SquaredL2Loss(y=self.y, A=self.Do) + + # test eval + np.testing.assert_allclose(L_d(self.v), 0.5 * ((self.Do @ self.v - self.y) ** 2).sum()) + + assert L_d.has_eval + assert L_d.has_prox + + cL = self.scalar * L_d + assert L_d.scale == 0.5 # hasn't changed + assert cL.scale == self.scalar * L_d.scale + assert cL(self.v) == self.scalar * L_d(self.v) + + pf = prox_test(self.v, L_d, L_d.prox, 0.75) + pf = prox_test(self.v, L, L.prox, 0.75) + + def test_weighted_squared_l2(self): + L = loss.WeightedSquaredL2Loss(y=self.y, A=self.Ao, W=self.W) + assert L.has_eval + assert L.has_prox + + # test eval + np.testing.assert_allclose( + L(self.v), 0.5 * (self.W @ (self.Ao @ self.v - self.y) ** 2).sum() + ) + + cL = self.scalar * L + assert L.scale == 0.5 # hasn't changed + assert cL.scale == self.scalar * L.scale + assert cL(self.v) == self.scalar * L(self.v) + + # SquaredL2 with Diagonal linop has a prox + L_d = loss.WeightedSquaredL2Loss(y=self.y, A=self.Do, W=self.W) + + assert L_d.has_eval + assert L_d.has_prox + + # test eval + np.testing.assert_allclose( + L_d(self.v), 0.5 * (self.W @ (self.Do @ self.v - self.y) ** 2).sum() + ) + + cL = self.scalar * L_d + assert L_d.scale == 0.5 # hasn't changed + assert cL.scale == self.scalar * L_d.scale + assert cL(self.v) == self.scalar * L_d(self.v) + + pf = prox_test(self.v, L_d, L_d.prox, 0.75) + pf = prox_test(self.v, L, L.prox, 0.75) + + def test_poisson(self): + L = loss.PoissonLoss(y=self.y, 
A=self.Ao_abs) + assert L.has_eval + assert not L.has_prox + + # test eval + v = snp.abs(self.v) + Av = self.Ao_abs @ v + np.testing.assert_allclose(L(v), 0.5 * snp.sum(Av - self.y * snp.log(Av) + L.const)) + + cL = self.scalar * L + assert L.scale == 0.5 # hasn't changed + assert cL.scale == self.scalar * L.scale + assert cL(v) == self.scalar * L(v)
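The generic ``Loss`` added in ``scico/loss.py`` composes a plain functional
with a forward operator and, when ``A`` is an identity, inherits a prox from
``f`` via the translation property
$\mathrm{prox}_{\lambda f(\cdot - \mathbf{y})}(\mathbf{v}) = \mathbf{y} +
\mathrm{prox}_{\lambda f}(\mathbf{v} - \mathbf{y})$ (Theorem 6.11 in
:cite:`beck-2017-first`). A minimal sketch of that path, assuming only the
interfaces that appear in this diff, with the soft-thresholding form of the
ℓ1 prox written out for the check::

    import numpy as np

    import scico.numpy as snp
    from scico import functional, loss
    from scico.random import randn

    # Random measurement y and evaluation point v.
    y, key = randn((16,), key=None, seed=1234)
    v, key = randn((16,), key=key)

    # Generic loss f(y - A(x)) with f the l1 norm and A the default
    # identity, which enables the generic prox path.
    L = loss.Loss(y, f=functional.L1Norm())
    assert L.has_prox

    # Theorem 6.11 reduces the prox to soft-thresholding shifted to y:
    #   prox_{lam ||. - y||_1}(v) = y + soft(v - y, lam).
    lam = 0.75
    soft = snp.sign(v - y) * snp.maximum(snp.abs(v - y) - lam, 0.0)
    np.testing.assert_allclose(L.prox(v, lam), y + soft, rtol=1e-6)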