Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to admm and svmbir wrapper to address issue #57 #88

Merged
merged 70 commits into from
Nov 16, 2021
Merged
Show file tree
Hide file tree
Changes from 63 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
717e1c6
Rough in approximate prox function
Michael-T-McCann Nov 3, 2021
9a5cfa6
Add kwargs to all proxes
Michael-T-McCann Nov 3, 2021
123d296
change svmbir roi radius to have the full rectangle as the imaging do…
smajee Nov 3, 2021
f8d2cfb
update gitignore to ignore macos hidden files
smajee Nov 6, 2021
66fa592
pass initial image to prox
smajee Nov 6, 2021
5118df1
add stopping conditions to SVMBIRWeightedSquaredL2Loss
smajee Nov 6, 2021
08c5ba7
add optional is_masked argument to projector. SVMBIRWeightedSquaredL2…
smajee Nov 8, 2021
812d3dc
Use bug-fixed branch of svmbir
smajee Nov 12, 2021
9e0ae41
use small number of iterations for proximal map in example
smajee Nov 12, 2021
0bcc1b3
Merge branch 'main' into mike/approx_prox
bwohlberg Nov 12, 2021
3a6d97e
change CT angle generation so that pi is not included
smajee Nov 12, 2021
0cb6003
Merge branch 'mike/approx_prox' of github.com:lanl/scico into mike/ap…
smajee Nov 12, 2021
8313a04
Merge branch 'main' into mike/approx_prox
bwohlberg Nov 12, 2021
7640be4
Apply black
bwohlberg Nov 12, 2021
56189bd
Merge branch 'main' into mike/approx_prox
smajee Nov 12, 2021
1220f75
fix merge conflicts
smajee Nov 12, 2021
93e03ac
Fix lint error
bwohlberg Nov 12, 2021
4a8842b
Fix reshape edge case to enable all tests to pass
smajee Nov 12, 2021
903d792
Resolve #90
bwohlberg Nov 13, 2021
955677c
add initial test script
tbalke Nov 13, 2021
570a856
adds test
tbalke Nov 13, 2021
ae8910d
change tolerance
tbalke Nov 13, 2021
1f0887b
remove comment
tbalke Nov 13, 2021
7ebad4c
rename
tbalke Nov 13, 2021
4276e0a
Add missing test
bwohlberg Nov 13, 2021
1df1275
Fix additional bugs and correct docs
bwohlberg Nov 13, 2021
47c0f6a
improve documentation
smajee Nov 13, 2021
88dc430
Merge remote-tracking branch 'origin/brendt/admm-bug' into mike/appro…
smajee Nov 13, 2021
196bdf8
fix lint error with black
smajee Nov 13, 2021
70a041c
Fix math error
bwohlberg Nov 13, 2021
0f1cb51
Merge branch 'brendt/admm-bug' into mike/approx_prox
bwohlberg Nov 13, 2021
c1f26d8
Improved parameters
bwohlberg Nov 13, 2021
aa6703b
Improved parameters
bwohlberg Nov 13, 2021
a42b143
Improve docstring
smajee Nov 13, 2021
557b8ab
grammatical change
smajee Nov 13, 2021
e00f1ef
add tests for masked mode
smajee Nov 13, 2021
c07dcd2
Add tests for prox v0
smajee Nov 13, 2021
d496b64
Docstring changes
smajee Nov 13, 2021
f1c9ed4
change svmbir branch to master since nanfix branch was merged and del…
smajee Nov 13, 2021
964fca5
bugfix in test
smajee Nov 13, 2021
5fc99ed
Docs edits
bwohlberg Nov 13, 2021
b3d9367
Merge branch 'mike/approx_prox' of github.com:lanl/scico into mike/ap…
bwohlberg Nov 13, 2021
a8ab873
use svmbir to generate weights in test
smajee Nov 14, 2021
fa1df17
update svmbir requirements to pypi since bug fix has been pushed to pypi
smajee Nov 14, 2021
bb2b42b
add version num to svmbir
smajee Nov 14, 2021
238e5f9
perform cg svmbir-prox comparison test for both masked on and off
smajee Nov 14, 2021
c81e3dd
expose svmbir positivity to users
smajee Nov 14, 2021
629fa4e
update docstring
smajee Nov 14, 2021
86b2a37
Fix bug resulting from specification of svmbir version number
bwohlberg Nov 15, 2021
3fdc0cb
adds cg-prox to generic loss
tbalke Nov 15, 2021
06631b2
change args to maxiter and ctol
tbalke Nov 15, 2021
7ea5a11
tighter test bounds
tbalke Nov 15, 2021
b342a7e
update docstring for svmbir projector
smajee Nov 15, 2021
497e1ca
cast to Device
tbalke Nov 15, 2021
d6c6d85
Merge branch 'mike/approx_prox' of github.com:lanl/scico into mike/ap…
smajee Nov 15, 2021
72c77f8
prox_kwargs
tbalke Nov 15, 2021
9e45bdc
add generic prox for unweighted loss
tbalke Nov 15, 2021
f4f268d
Update error messages
bwohlberg Nov 15, 2021
57a06f2
handle prox_args=None
tbalke Nov 15, 2021
f318cd1
deal better with non-set dictionary entries
tbalke Nov 16, 2021
57e0df0
remove testing file
tbalke Nov 16, 2021
53713dd
add test for non-diagonal prox
tbalke Nov 16, 2021
0f5175f
Some docstring improvements
bwohlberg Nov 16, 2021
079ac16
add docstring
tbalke Nov 16, 2021
a640fa7
added ctol
tbalke Nov 16, 2021
db28d3f
Merge branch 'main' into mike/approx_prox
bwohlberg Nov 16, 2021
219e5ed
update ParallelBeamProjector docstring
smajee Nov 16, 2021
4c02266
Some docstring improvements
bwohlberg Nov 16, 2021
9b614bb
Improve parameters
bwohlberg Nov 16, 2021
2cc51b3
Improve parameters
bwohlberg Nov 16, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,6 @@ dmypy.json

# Pyre type checker
.pyre/

# macos files
*.DS_Store
3 changes: 2 additions & 1 deletion docs/source/team.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ Contributors
============

- `Oleg Korobkin <https://github.com/korobkin>`_ (BlockArray improvements)
- `Yanpeng Yuan <https://github.com/yanpeng7>`_ (Improvements to ASTRA interface)
- `Yanpeng Yuan <https://github.com/yanpeng7>`_ (ASTRA interface improvements)
- `Soumendu Majee <https://github.com/smajee>`_ (SVMBIR interface improvements)
4 changes: 2 additions & 2 deletions examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"""
num_angles = int(N / 2)
num_channels = N
angles = snp.linspace(0, snp.pi, num_angles, dtype=snp.float32)
angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32)
A = ParallelBeamProjector(x_gt.shape, angles, num_channels)
sino = A @ x_gt

Expand Down Expand Up @@ -88,7 +88,7 @@
"""
y, x0, weights = jax.device_put([y, x_mrf, weights])

ρ = 100 # ADMM penalty parameter
ρ = 20 # ADMM penalty parameter
σ = density * 0.2 # denoiser sigma

f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
Expand Down
6 changes: 4 additions & 2 deletions examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"""
num_angles = int(N / 2)
num_channels = N
angles = snp.linspace(0, snp.pi, num_angles, dtype=snp.float32)
angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32)
A = ParallelBeamProjector(x_gt.shape, angles, num_channels)
sino = A @ x_gt

Expand Down Expand Up @@ -91,7 +91,9 @@
ρ = 10 # ADMM penalty parameter
σ = density * 0.26 # denoiser sigma

f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
f = SVMBIRWeightedSquaredL2Loss(
y=y, A=A, W=Diagonal(weights), scale=0.5, prox_kwargs={"maxiter": 5}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it really desirable that we just drop the default stopping tolerance here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would choose the tolerance to be 0 (or very) low so that always 5 iterations are being performed (unless there is really almost no change)

Copy link
Contributor

@tbalke tbalke Nov 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in a640fa7

)
g0 = σ * ρ * BM3D()
g1 = NonNegativeIndicator()

Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/denoise_tv_iso_pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class IsoProjector(functional.Functional):
def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return 0.0

def prox(self, x: JaxArray, lam: float) -> JaxArray:
def prox(self, x: JaxArray, lam: float, **kwargs) -> JaxArray:
norm_x_ptp = jnp.sqrt(jnp.sum(jnp.abs(x) ** 2, axis=0))

x_out = x / jnp.maximum(jnp.ones(x.shape), norm_x_ptp)
Expand Down Expand Up @@ -172,7 +172,7 @@ class AnisoProjector(functional.Functional):
def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return 0.0

def prox(self, x: JaxArray, lam: float) -> JaxArray:
def prox(self, x: JaxArray, lam: float, **kwargs) -> JaxArray:

return x / jnp.maximum(jnp.ones(x.shape), jnp.abs(x))

Expand Down
2 changes: 1 addition & 1 deletion misc/conda/make_conda_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ sort $ALLREQUIRE | uniq | $SED -E 's/(>|<|\|)/\\\1/g' \
| $SED -E '/^-r.*|^jaxlib.*|^jax.*|^astra-toolbox.*/d' > $FLTREQUIRE
# Remove requirements that cannot be installed via conda
for nc in $NOCONDA; do
$SED -i "/^$nc\$/d" $FLTREQUIRE
$SED -i "/^$nc.*\$/d" $FLTREQUIRE
done
# Get list of requirements to be installed via conda
CONDAREQ=$(cat $FLTREQUIRE | xargs)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ jax==0.2.19
jaxlib==0.1.70
flax
bm3d
svmbir
svmbir>=0.2.6
6 changes: 3 additions & 3 deletions scico/admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def compute_rhs(self) -> Union[JaxArray, BlockArray]:

if self.admm.f is not None:
if isinstance(self.admm.f, WeightedSquaredL2Loss):
ATWy = self.admm.f.A.adj(self.admm.f.W.diagonal @ self.admm.f.y)
ATWy = self.admm.f.A.adj(self.admm.f.W.diagonal * self.admm.f.y)
rhs += 2.0 * self.admm.f.scale * ATWy
else:
ATy = self.admm.f.A.adj(self.admm.f.y)
Expand All @@ -231,7 +231,7 @@ def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
"""
x0 = ensure_on_device(x0)
rhs = self.compute_rhs()
x, self.info = self.cg(self.lhs_op, rhs, x0=x0, **self.cg_kwargs)
x, self.info = self.cg(self.lhs_op, rhs, x0, **self.cg_kwargs)
return x


Expand Down Expand Up @@ -604,7 +604,7 @@ def z_and_u_step(self, u_list, z_list):
zip(self.rho_list, self.g_list, self.C_list, z_list, u_list)
):
Cix = Ci(self.x)
zi = gi.prox(Cix + ui, 1 / rhoi)
zi = gi.prox(Cix + ui, 1 / rhoi, v0=zi)
ui = ui + Cix - zi
z_list[i] = zi
u_list[i] = ui
Expand Down
2 changes: 1 addition & 1 deletion scico/functional/_denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, is_rgb: Optional[bool] = False):

super().__init__()

def prox(self, x: JaxArray, lam: float = 1) -> JaxArray:
def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray:
r"""Apply BM3D denoiser with noise level ``lam``.

Args:
Expand Down
2 changes: 1 addition & 1 deletion scico/functional/_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, model: Callable[..., nn.Module], variables: PyTree):
self.variables = variables
super().__init__()

def prox(self, x: JaxArray, lam: float = 1) -> JaxArray:
def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray:
r"""Apply trained flax model.

*Warning*: The ``lam`` parameter is ignored, and has no effect on
Expand Down
26 changes: 18 additions & 8 deletions scico/functional/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
f"Functional {type(self)} cannot be evaluated; has_eval={self.has_eval}"
)

def prox(self, x: Union[JaxArray, BlockArray], lam: float = 1) -> Union[JaxArray, BlockArray]:
def prox(
self, x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
) -> Union[JaxArray, BlockArray]:
r"""Scaled proximal operator of functional.

Evaluate scaled proximal operator of this functional, with
Expand All @@ -90,15 +92,18 @@ def prox(self, x: Union[JaxArray, BlockArray], lam: float = 1) -> Union[JaxArray

Args:
x : Point at which to evaluate prox function.
lam : Proximal parameter :math:`\lambda`
lam : Proximal parameter :math:`\lambda`.
kwargs : Additional arguments that may be used by derived
classes. These include ``v0``, an initial guess for the minimizer.

"""
if not self.has_prox:
raise NotImplementedError(
f"Functional {type(self)} does not have a prox; has_prox={self.has_prox}"
)

def conj_prox(
self, x: Union[JaxArray, BlockArray], lam: float = 1
self, x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
) -> Union[JaxArray, BlockArray]:
r"""Scaled proximal operator of convex conjugate of functional.

Expand All @@ -116,9 +121,10 @@ def conj_prox(

Args:
x : Point at which to evaluate prox function.
lam : Proximal parameter :math:`\lambda`
lam : Proximal parameter :math:`\lambda`.
kwargs : additional keyword args, passed directly to ``self.prox``.
"""
return x - lam * self.prox(x / lam, 1.0 / lam)
return x - lam * self.prox(x / lam, 1.0 / lam, **kwargs)

def grad(self, x: Union[JaxArray, BlockArray]):
r"""Evaluates the gradient of this functional at point :math:`\mb{x}`.
Expand Down Expand Up @@ -174,7 +180,9 @@ def __init__(self, functional: Functional, scale: float):
def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return self.scale * self.functional(x)

def prox(self, x: Union[JaxArray, BlockArray], lam: float = 1) -> Union[JaxArray, BlockArray]:
def prox(
self, x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
) -> Union[JaxArray, BlockArray]:
return self.functional.prox(x, lam * self.scale)


Expand Down Expand Up @@ -216,7 +224,7 @@ def __call__(self, x: BlockArray) -> float:
f"Number of blocks in x, {len(x.shape)}, and length of functional_list, {len(self.functional_list)}, do not match"
)

def prox(self, x: BlockArray, lam: float = 1) -> BlockArray:
def prox(self, x: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray:
r"""Evaluate proximal operator of the separable functional.

Evaluate proximal operator of the separable functional (see Theorem 6.6 of :cite:`beck-2017-first`).
Expand Down Expand Up @@ -252,5 +260,7 @@ class ZeroFunctional(Functional):
def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return 0.0

def prox(self, x: Union[JaxArray, BlockArray], lam: float = 1) -> Union[JaxArray, BlockArray]:
def prox(
self, x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
) -> Union[JaxArray, BlockArray]:
return x
8 changes: 6 additions & 2 deletions scico/functional/_indicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
# snp.inf if snp.any(x < 0) else 0.0
return jax.lax.cond(snp.any(x < 0), lambda x: snp.inf, lambda x: 0.0, None)

def prox(self, x: Union[JaxArray, BlockArray], lam: float = 1) -> Union[JaxArray, BlockArray]:
def prox(
self, x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
) -> Union[JaxArray, BlockArray]:
r"""Evaluate proximal operator of indicator over non-negative orthant:

.. math::
Expand Down Expand Up @@ -95,7 +97,9 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
# snp.inf if norm(x) > self.radius else 0.0
return jax.lax.cond(norm(x) > self.radius, lambda x: snp.inf, lambda x: 0.0, None)

def prox(self, x: Union[JaxArray, BlockArray], lam: float = 1) -> Union[JaxArray, BlockArray]:
def prox(
self, x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
) -> Union[JaxArray, BlockArray]:
r"""Evaluate proximal operator of indicator over :math:`\ell_2` ball:

.. math::
Expand Down
18 changes: 13 additions & 5 deletions scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:

@staticmethod
@jit
def prox(x: Union[JaxArray, BlockArray], lam: float = 1) -> Union[JaxArray, BlockArray]:
def prox(
x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
) -> Union[JaxArray, BlockArray]:
r"""Evaluate proximal operator of :math:`\ell_0` norm


Expand Down Expand Up @@ -71,7 +73,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return snp.abs(x).sum()

@staticmethod
def prox(x: Union[JaxArray, BlockArray], lam: float = 1) -> JaxArray:
def prox(x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs) -> JaxArray:
r"""Evaluate proximal operator of :math:`\ell_1` norm

.. math::
Expand Down Expand Up @@ -115,7 +117,9 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
# behavior of snp.norm(x) at 0.
return (snp.abs(x) ** 2).sum()

def prox(self, x: Union[JaxArray, BlockArray], lam: float = 1) -> Union[JaxArray, BlockArray]:
def prox(
self, x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
) -> Union[JaxArray, BlockArray]:
r"""Evaluate proximal operator of squared :math:`\ell_2` norm:

.. math::
Expand Down Expand Up @@ -143,7 +147,9 @@ class L2Norm(Functional):
def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return norm(x)

def prox(self, x: Union[JaxArray, BlockArray], lam: float = 1) -> Union[JaxArray, BlockArray]:
def prox(
self, x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
) -> Union[JaxArray, BlockArray]:
r"""Evaluate proximal operator of :math:`\ell_2` norm:

.. math::
Expand Down Expand Up @@ -199,7 +205,9 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
l2 = norm(x, axis=self.l2_axis)
return snp.abs(l2).sum()

def prox(self, x: Union[JaxArray, BlockArray], lam: float = 1) -> Union[JaxArray, BlockArray]:
def prox(
self, x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
) -> Union[JaxArray, BlockArray]:
r"""Evaluate proximal operator of the :math:`\ell_{2,1}` norm.

In two dimensions,
Expand Down
11 changes: 6 additions & 5 deletions scico/linop/radon_astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@
try:
import astra
except ImportError:
raise ImportError(
"Could not import astra, please refer to INSTALL.rst "
"for instructions on how to install the ASTRA toolbox."
)
raise ImportError("Could not import astra; please install the ASTRA toolbox.")


from jaxlib.xla_extension import GpuDevice
Expand All @@ -38,7 +35,11 @@


class ParallelBeamProjector(LinearOperator):
r"""Parallel beam projector based on ASTRA."""
r"""Parallel beam Radon transform based on the ASTRA toolbox.

Perform tomographic projection of an image at specified angles,
using the `ASTRA toolbox <https://github.com/astra-toolbox/astra-toolbox>`_.
"""

def __init__(
self,
Expand Down