Problem with W^1/2 weight exponent (#78)
* fix bug in svmbir weights

* weight_op -> W Diagonal operator

* adjust other code for new interface

* fix weight in astra example

* \cdot cleanup

* remove self.functional

* fix docs .Diagonal

* Change is_smooth attribute of PoissonLoss

* Add test for PoissonLoss

* scale -> \alpha

* remove epsilon, remove smoothness

* update is_smooth test

Co-authored-by: Brendt Wohlberg <brendt@ieee.org>
tbalke and bwohlberg authored Nov 12, 2021
1 parent f1a138c commit 4684909
Showing 9 changed files with 95 additions and 75 deletions.
4 changes: 2 additions & 2 deletions docs/source/functional.rst
@@ -153,9 +153,9 @@ Losses
 In SCICO, a loss is a special type of functional
 
 .. math::
-    f(\mb{x}) = a l( \mb{y}, A(\mb{x}) )
+    f(\mb{x}) = \alpha l( \mb{y}, A(\mb{x}) )
 
-where :math:`a` is a scale parameter,
+where :math:`\alpha` is a scaling parameter,
 :math:`l` is a functional,
 :math:`\mb{y}` is a set of measurements,
 and :math:`A` is an operator.
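As background for the notation change above, here is a minimal NumPy sketch (values are illustrative only, not from SCICO) of how the scaling parameter :math:`\alpha` enters a squared-:math:`\ell_2` loss with the default `scale` of 0.5:

```python
# Minimal sketch: f(x) = alpha * l(y, A(x)) with l(y, z) = ||y - z||_2^2
# and the default alpha = 0.5, checked with plain NumPy.
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 4))   # forward operator as a plain matrix
x = rng.standard_normal(4)
y = A @ x + 0.1 * rng.standard_normal(8)

alpha = 0.5                       # the `scale` parameter
f = alpha * np.sum((y - A @ x) ** 2)
assert np.isclose(f, alpha * np.linalg.norm(y - A @ x) ** 2)
```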
5 changes: 2 additions & 3 deletions examples/scripts/ct_astra_weighted_tv_admm.py
@@ -137,9 +137,8 @@ def postprocess(x):
"""
lambda_weighted = 1.14e2

weights = counts / Io # scale by Io to balance the data vs regularization term
W = linop.Diagonal(snp.sqrt(weights))
f = loss.WeightedSquaredL2Loss(y=y, A=A, weight_op=W)
weights = counts / Io # scaled by Io to balance the data vs regularization term
f = loss.WeightedSquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights))

admm_weighted = ADMM(
f=f,
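The substance of the fix is visible in this hunk: `W` now receives the weights themselves rather than their square root, so `snp.sqrt(weights)` is gone. A quick NumPy sketch (hypothetical values, independent of the script above) of why the two conventions give the same loss value:

```python
# Old interface: weight_op held W^{1/2}, so the loss computed ||W^{1/2} r||_2^2.
# New interface: W holds the weights directly and the loss computes sum(w * r**2).
import numpy as np

rng = np.random.default_rng(1)
w = rng.uniform(0.5, 2.0, size=16)   # non-negative weights
r = rng.standard_normal(16)          # residual y - A(x)

old_style = np.sum((np.sqrt(w) * r) ** 2)  # weight_op = Diagonal(sqrt(w))
new_style = np.sum(w * r ** 2)             # W = Diagonal(w)
assert np.isclose(old_style, new_style)
```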
4 changes: 1 addition & 3 deletions examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py
@@ -91,9 +91,7 @@
 ρ = 100  # ADMM penalty parameter
 σ = density * 0.2  # denoiser sigma
 
-weight_op = Diagonal(weights ** 0.5)
-
-f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, weight_op=weight_op, scale=0.5)
+f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
 g0 = σ * ρ * BM3D()
 g1 = NonNegativeIndicator()
4 changes: 1 addition & 3 deletions examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py
@@ -91,9 +91,7 @@
 ρ = 10  # ADMM penalty parameter
 σ = density * 0.26  # denoiser sigma
 
-weight_op = Diagonal(weights ** 0.5)
-
-f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, weight_op=weight_op, scale=0.5)
+f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
 g0 = σ * ρ * BM3D()
 g1 = NonNegativeIndicator()
4 changes: 2 additions & 2 deletions scico/admm.py
@@ -124,7 +124,7 @@ class LinearSubproblemSolver(SubproblemSolver):
         \mb{x}^{(k+1)} = \argmin_{\mb{x}} \; \frac{1}{2} \norm{\mb{y} - A x}_W^2 +
         \sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;,
 
-    where :math:`W` is the weighting :class:`.LinearOperator` from the
+    where :math:`W` is the weighting :class:`.Diagonal` from the
     :class:`.WeightedSquaredL2Loss` instance. This update step reduces to the
     solution of the linear system
@@ -208,7 +208,7 @@ def compute_rhs(self) -> Union[JaxArray, BlockArray]:

         if self.admm.f is not None:
             if isinstance(self.admm.f, WeightedSquaredL2Loss):
-                ATWy = self.admm.f.A.adj(self.admm.f.weight_op @ self.admm.f.y)
+                ATWy = self.admm.f.A.adj(self.admm.f.W.diagonal @ self.admm.f.y)
                 rhs += 2.0 * self.admm.f.scale * ATWy
             else:
                 ATy = self.admm.f.A.adj(self.admm.f.y)
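To see why the right-hand side uses `2.0 * self.admm.f.scale * ATWy`: the data term `scale * ||y - A x||_W^2` has gradient `2 * scale * A^T W (A x - y)`, so the constant part contributed to the normal equations is `2 * scale * A^T W y`. A rough NumPy sketch under that assumption (not the SCICO implementation):

```python
# Data-term contribution to the RHS of the x-update normal equations,
# for f = scale * ||y - A x||_W^2 with diagonal W (weights w).
import numpy as np

rng = np.random.default_rng(2)
A = rng.standard_normal((10, 5))
y = rng.standard_normal(10)
w = rng.uniform(0.5, 2.0, size=10)  # diagonal of W
scale = 0.5

ATWy = A.T @ (w * y)                # A.adj(W y) in operator terms
rhs_data_term = 2.0 * scale * ATWy

# Sanity check against the explicit gradient at x = 0:
grad_at_zero = 2.0 * scale * A.T @ (w * (A @ np.zeros(5) - y))
assert np.allclose(rhs_data_term, -grad_at_zero)
```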
13 changes: 2 additions & 11 deletions scico/linop/radon_svmbir.py
@@ -28,7 +28,7 @@
 from scico.loss import WeightedSquaredL2Loss
 from scico.typing import JaxArray, Shape
 
-from ._linop import Diagonal, LinearOperator
+from ._linop import LinearOperator
 
 
 class ParallelBeamProjector(LinearOperator):
@@ -120,21 +120,12 @@ def __init__(self, *args, **kwargs):
"to instantiate a `SVMBIRWeightedSquaredL2Loss`."
)

if not isinstance(self.weight_op, Diagonal):
raise ValueError(
f"`weight_op` must be `Diagonal` but instead got {type(self.weight_op)}"
)

self.weights = (
snp.conj(self.weight_op.diagonal) * self.weight_op.diagonal
) # because weight_op is W^{1/2}

self.has_prox = True

def prox(self, v: JaxArray, lam: float) -> JaxArray:
v = v.reshape(self.A.svmbir_input_shape)
y = self.y.reshape(self.A.svmbir_output_shape)
weights = self.weights.reshape(self.A.svmbir_output_shape)
weights = self.W.diagonal.reshape(self.A.svmbir_output_shape)
sigma_p = snp.sqrt(lam)
result = svmbir.recon(
np.array(y),
104 changes: 61 additions & 43 deletions scico/loss.py
@@ -42,7 +42,9 @@ class Loss(functional.Functional):
r"""Generic Loss function.
.. math::
\mathrm{scale} \cdot l(\mb{y}, A(\mb{x})) \;
\alpha l(\mb{y}, A(\mb{x})) \;
where :math:`\alpha` is the scaling parameter and :math:`l(\cdot)` is the loss functional.
"""

@@ -55,7 +57,7 @@ def __init__(
r"""Initialize a :class:`Loss` object.
Args:
y : Measurements
y : Measurement.
A : Forward operator. Defaults to None. If None, ``self.A`` is a :class:`.Identity`.
scale : Scaling parameter. Default: 0.5.
@@ -109,7 +111,9 @@ class SquaredL2Loss(Loss):
     Squared :math:`\ell_2` loss.
 
     .. math::
-        \mathrm{scale} \cdot \norm{\mb{y} - A(\mb{x})}_2^2 \;
+        \alpha \norm{\mb{y} - A(\mb{x})}_2^2 \;
+
+    where :math:`\alpha` is the scaling parameter.
     """

@@ -119,8 +123,14 @@ def __init__(
         A: Optional[Union[Callable, operator.Operator]] = None,
         scale: float = 0.5,
     ):
+        r"""Initialize a :class:`SquaredL2Loss` object.
+
+        Args:
+            y : Measurement.
+            A : Forward operator. If None, defaults to :class:`.Identity`.
+            scale : Scaling parameter.
+        """
         y = ensure_on_device(y)
-        self.functional = functional.SquaredL2Norm()
         super().__init__(y=y, A=A, scale=scale)
 
         if isinstance(A, operator.Operator):
@@ -140,7 +150,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
         Args:
             x : Point at which to evaluate loss.
         """
-        return self.scale * self.functional(self.y - self.A(x))
+        return self.scale * (snp.abs(self.y - self.A(x)) ** 2).sum()
 
     def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, BlockArray]:
         if isinstance(self.A, linop.Diagonal):
@@ -154,7 +164,7 @@ def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, BlockArray]:
     @property
     def hessian(self) -> linop.LinearOperator:
         r"""If ``self.A`` is a :class:`.LinearOperator`, returns a new :class:`.LinearOperator` corresponding
-        to Hessian :math:`\mathrm{A^H A}`.
+        to Hessian :math:`2 \alpha \mathrm{A^H A}`.
 
         Otherwise not implemented.
         """
@@ -171,11 +181,12 @@ class WeightedSquaredL2Loss(Loss):
     Weighted squared :math:`\ell_2` loss.
 
     .. math::
-        \mathrm{scale} \cdot \norm{\mb{y} - A(\mb{x})}_{\mathrm{W}}^2 =
-        \mathrm{scale} \cdot \norm{\mathrm{W}^{1/2} \left( \mb{y} - A(\mb{x})\right)}_2^2\;
+        \alpha \norm{\mb{y} - A(\mb{x})}_W^2 =
+        \alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} - A(\mb{x})\right)\;
 
-    Where :math:`\mathrm{W}` is an instance of :class:`scico.linop.LinearOperator`. If
-    :math:`\mathrm{W}` is None, reverts to the behavior of :class:`.SquaredL2Loss`.
+    where :math:`\alpha` is the scaling parameter and :math:`W` is an
+    instance of :class:`scico.linop.Diagonal`. If :math:`W` is None,
+    reverts to the behavior of :class:`.SquaredL2Loss`.
     """

@@ -184,30 +195,32 @@ def __init__(
         y: Union[JaxArray, BlockArray],
         A: Optional[Union[Callable, operator.Operator]] = None,
         scale: float = 0.5,
-        weight_op: Optional[operator.Operator] = None,
+        W: Optional[linop.Diagonal] = None,
     ):
 
         r"""Initialize a :class:`WeightedSquaredL2Loss` object.
 
         Args:
-            y : Measurements
+            y : Measurement.
             A : Forward operator. If None, defaults to :class:`.Identity`.
-            scale : Scaling parameter
-            weight_op: Weighting linear operator. Corresponds to :math:`W^{1/2}`
-                in the standard definition of the weighted squared :math:`\ell_2` loss.
+            scale : Scaling parameter.
+            W: Weighting diagonal operator. Must be non-negative.
+               If None, defaults to :class:`.Identity`.
         """
         y = ensure_on_device(y)
 
-        self.weight_op: operator.Operator
+        self.W: linop.Diagonal
 
-        self.functional = functional.SquaredL2Norm()
-        if weight_op is None:
-            self.weight_op = linop.Identity(y.shape)
-        elif isinstance(weight_op, linop.LinearOperator):
-            self.weight_op = weight_op
+        if W is None:
+            self.W = linop.Identity(y.shape)
+        elif isinstance(W, linop.Diagonal):
+            if snp.all(W.diagonal >= 0):
+                self.W = W
+            else:
+                raise Exception(f"The weights, W.diagonal, must be non-negative.")
         else:
-            raise TypeError(f"weight_op must be None or a LinearOperator, got {type(weight_op)}")
+            raise TypeError(f"W must be None or a linop.Diagonal, got {type(W)}")
 
         super().__init__(y=y, A=A, scale=scale)
 
         if isinstance(A, operator.Operator):
@@ -218,40 +231,43 @@ def __init__(
         if isinstance(self.A, linop.LinearOperator):
             self.is_quadratic = True
 
-        if isinstance(self.A, linop.Diagonal) and isinstance(self.weight_op, linop.Diagonal):
+        if isinstance(self.A, linop.Diagonal) and isinstance(self.W, linop.Diagonal):
             self.has_prox = True
 
     def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
-        return self.scale * self.functional(self.weight_op(self.y - self.A(x)))
+        return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum()
 
     def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, BlockArray]:
         if isinstance(self.A, linop.Diagonal):
-            c = self.scale * lam
+            c = 2.0 * self.scale * lam
             A = self.A.diagonal
-            W = self.weight_op.diagonal
-            lhs = c * 2.0 * A.conj() * W * W.conj() * self.y + x
-            ATWTWA = c * 2.0 * A.conj() * W.conj() * W * A
-            return lhs / (ATWTWA + 1.0)
+            W = self.W.diagonal
+            lhs = c * A.conj() * W * self.y + x
+            ATWA = c * A.conj() * W * A
+            return lhs / (ATWA + 1.0)
         else:
             raise NotImplementedError
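The corrected `prox` can be sanity-checked numerically: for diagonal `A` and `W`, with `c = 2 * scale * lam`, the closed form above should be a stationary point of `scale * ||y - A x||_W² + ||x - v||² / (2 * lam)`. A real-valued NumPy sketch under those assumptions:

```python
# Verify x* = (c * a * w * y + v) / (c * a * w * a + 1), c = 2 * alpha * lam,
# is stationary for alpha * ||y - A x||_W^2 + ||x - v||^2 / (2 * lam).
import numpy as np

rng = np.random.default_rng(4)
n = 6
a = rng.uniform(0.5, 2.0, size=n)   # diagonal of A
w = rng.uniform(0.5, 2.0, size=n)   # diagonal of W
y = rng.standard_normal(n)
v = rng.standard_normal(n)          # prox input
alpha, lam = 0.5, 0.7

c = 2.0 * alpha * lam
x_star = (c * a * w * y + v) / (c * a * w * a + 1.0)

# Gradient of the prox objective at x_star should vanish elementwise:
grad = 2.0 * alpha * a * w * (a * x_star - y) + (x_star - v) / lam
assert np.allclose(grad, 0.0, atol=1e-10)
```

Note that the old code effectively applied the weights twice (`W * W.conj()`), which is exactly the `W^{1/2}` exponent bug named in the commit title.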

     @property
     def hessian(self) -> linop.LinearOperator:
         r"""If ``self.A`` is a :class:`scico.linop.LinearOperator`, returns a
-        :class:`scico.linop.LinearOperator` corresponding to Hessian :math:`\mathrm{A^H W A}`.
+        :class:`scico.linop.LinearOperator` corresponding to the Hessian
+        :math:`2 \alpha \mathrm{A^H W A}`.
 
         Otherwise not implemented.
         """
-        if isinstance(self.A, linop.LinearOperator):
+        A = self.A
+        W = self.W
+        if isinstance(A, linop.LinearOperator):
             return linop.LinearOperator(
-                input_shape=self.A.input_shape,
-                output_shape=self.A.input_shape,
-                eval_fn=lambda x: 2 * self.scale * self.A.adj(self.weight_op(self.A(x))),
-                adj_fn=lambda x: 2 * self.scale * self.A.adj(self.weight_op(self.A(x))),
+                input_shape=A.input_shape,
+                output_shape=A.input_shape,
+                eval_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),
+                adj_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),
             )
         else:
             raise NotImplementedError(
-                f"Hessian is not implemented for {type(self)} when `A` is {type(self.A)}; must be LinearOperator"
+                f"Hessian is not implemented for {type(self)} when `A` is {type(A)}; must be LinearOperator"
            )
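For reference, what the operator above applies: a NumPy sketch (illustrative real-valued data) checking that the Hessian of `scale * ||y - A x||_W²` acts as `2 * scale * Aᵀ W A`:

```python
# eval_fn(x) == 2 * alpha * A.T @ (w * (A @ x)) for real A and diagonal W.
import numpy as np

rng = np.random.default_rng(5)
A = rng.standard_normal((9, 4))
w = rng.uniform(0.5, 2.0, size=9)   # diagonal of W
x = rng.standard_normal(4)
alpha = 0.5

H = 2.0 * alpha * A.T @ np.diag(w) @ A      # explicit Hessian matrix
assert np.allclose(H @ x, 2.0 * alpha * A.T @ (w * (A @ x)))
```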


@@ -260,8 +276,9 @@ class PoissonLoss(Loss):
     Poisson negative log likelihood loss
 
     .. math::
-        \mathrm{scale} \cdot \sum_i [A(x)]_i - y_i \log\left( [A(x)]_i \right) + \log(y_i!)
+        \alpha \left( \sum_i [A(x)]_i - y_i \log\left( [A(x)]_i \right) + \log(y_i!) \right)
 
+    where :math:`\alpha` is the scaling parameter.
     """
 
     def __init__(
@@ -273,18 +290,19 @@ def __init__(
r"""Initialize a :class:`Loss` object.
Args:
y : Measurements
A : Forward operator. Defaults to None. If None, ``self.A`` is a :class:`.Identity`.
scale : Scaling parameter. Default: 0.5.
y : Measurement.
A : Forward operator. Defaults to None. If None, ``self.A`` is a :class:`.Identity`.
scale : Scaling parameter. Default: 0.5.
"""
y = ensure_on_device(y)
super().__init__(y=y, A=A, scale=scale)

#: Constant term in Poisson log likehood; equal to ln(y!)
self.const: float = gammaln(self.y + 1) # ln(y!)

# The Poisson Loss is only smooth in the positive quadrant.
self.is_smooth = None

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
ε = 1e-9 # So that loss < infinity
Ax = self.A(x)
return self.scale * snp.sum(Ax - self.y * snp.log(Ax + ε) + self.const)
return self.scale * snp.sum(Ax - self.y * snp.log(Ax) + self.const)
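With the `ε` guard removed, the loss is finite only where `A(x)` is strictly positive, consistent with setting `is_smooth = None`. A NumPy/SciPy sketch (illustrative data) of the expression as now written, using `scipy.special.gammaln(y + 1)` for `log(y!)` as the source does:

```python
# alpha * sum(Ax_i - y_i * log(Ax_i) + log(y_i!)) with strictly positive A(x).
import numpy as np
from scipy.special import gammaln

rng = np.random.default_rng(6)
Ax = rng.uniform(0.5, 5.0, size=10)   # A(x), strictly positive
y = rng.poisson(Ax).astype(float)     # Poisson counts
alpha = 0.5

loss_val = alpha * np.sum(Ax - y * np.log(Ax) + gammaln(y + 1))
assert np.isfinite(loss_val)
```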
4 changes: 2 additions & 2 deletions scico/test/linop/test_radon_svmbir.py
@@ -89,6 +89,6 @@ def test_prox_weights(Nx, Ny, num_angles, num_channels, is_3d):

     # test with weights
     weights, _ = scico.random.uniform(sino.shape, dtype=im.dtype)
-    D = scico.linop.Diagonal(weights)
-    f = SVMBIRWeightedSquaredL2Loss(y=sino, A=A, weight_op=D)
+    W = scico.linop.Diagonal(weights)
+    f = SVMBIRWeightedSquaredL2Loss(y=sino, A=A, W=W)
     prox_test(v, f, f.prox, alpha=0.25)
28 changes: 22 additions & 6 deletions scico/test/test_functional.py
@@ -338,8 +338,9 @@ def setup_method(self):
         W, key = randn((n,), key=key, dtype=dtype)
         W = 0.1 * W + 1.0
         self.Ao = linop.MatrixOperator(A)
+        self.Ao_abs = linop.MatrixOperator(snp.abs(A))
         self.Do = linop.Diagonal(D)
-        self.Wo = linop.Diagonal(W)
+        self.W = linop.Diagonal(W)
         self.y, key = randn((n,), key=key, dtype=dtype)
         self.v, key = randn((n,), key=key, dtype=dtype)  # point for prox eval
         scalar, key = randn((1,), key=key, dtype=dtype)
@@ -377,14 +378,14 @@ def test_squared_l2(self):
         pf = prox_test(self.v, L_d, L_d.prox, 0.75)
 
     def test_weighted_squared_l2(self):
-        L = loss.WeightedSquaredL2Loss(y=self.y, A=self.Ao, weight_op=self.Wo)
+        L = loss.WeightedSquaredL2Loss(y=self.y, A=self.Ao, W=self.W)
         assert L.is_smooth == True
         assert L.has_eval == True
         assert L.has_prox == False  # not diagonal
 
         # test eval
         np.testing.assert_allclose(
-            L(self.v), 0.5 * ((self.Wo @ (self.Ao @ self.v - self.y)) ** 2).sum()
+            L(self.v), 0.5 * (self.W @ (self.Ao @ self.v - self.y) ** 2).sum()
         )
 
         cL = self.scalar * L
@@ -393,16 +394,15 @@ def test_weighted_squared_l2(self):
         assert cL(self.v) == self.scalar * L(self.v)
 
         # SquaredL2 with Diagonal linop has a prox
-        Wo = self.Wo
-        L_d = loss.WeightedSquaredL2Loss(y=self.y, A=self.Do, weight_op=Wo)
+        L_d = loss.WeightedSquaredL2Loss(y=self.y, A=self.Do, W=self.W)
 
         assert L_d.is_smooth == True
         assert L_d.has_eval == True
         assert L_d.has_prox == True
 
         # test eval
         np.testing.assert_allclose(
-            L_d(self.v), 0.5 * ((self.Wo @ (self.Do @ self.v - self.y)) ** 2).sum()
+            L_d(self.v), 0.5 * (self.W @ (self.Do @ self.v - self.y) ** 2).sum()
         )
 
         cL = self.scalar * L_d
@@ -412,6 +412,22 @@

         pf = prox_test(self.v, L_d, L_d.prox, 0.75)
 
+    def test_poisson(self):
+        L = loss.PoissonLoss(y=self.y, A=self.Ao_abs)
+        assert L.is_smooth == None
+        assert L.has_eval == True
+        assert L.has_prox == False
+
+        # test eval
+        v = snp.abs(self.v)
+        Av = self.Ao_abs @ v
+        np.testing.assert_allclose(L(v), 0.5 * snp.sum(Av - self.y * snp.log(Av) + L.const))
+
+        cL = self.scalar * L
+        assert L.scale == 0.5  # hasn't changed
+        assert cL.scale == self.scalar * L.scale
+        assert cL(v) == self.scalar * L(v)
 
 
 class TestBM3D:
     def setup(self):
