Problem with W^1/2 weight exponent #78

Merged: 22 commits, Nov 12, 2021
4 changes: 2 additions & 2 deletions docs/source/functional.rst
@@ -153,9 +153,9 @@ Losses
In SCICO, a loss is a special type of functional

.. math::
f(\mb{x}) = a l( \mb{y}, A(\mb{x}) )
f(\mb{x}) = \alpha l( \mb{y}, A(\mb{x}) )

where :math:`a` is a scale parameter,
where :math:`\alpha` is a scaling parameter,
:math:`l` is a functional,
:math:`\mb{y}` is a set of measurements,
and :math:`A` is an operator.
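As an illustration of the definition above (not part of this diff), here is a minimal sketch using the scico.loss API touched by this PR; the values are arbitrary and snp is assumed to mirror jax.numpy as elsewhere in SCICO:

import scico.numpy as snp
from scico import loss

y = snp.array([1.0, 2.0, 3.0])           # measurements
f = loss.SquaredL2Loss(y=y, scale=0.5)   # A defaults to the identity operator
x = snp.array([1.5, 2.0, 2.5])
print(f(x))                              # 0.5 * ||y - x||_2^2 == 0.25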
5 changes: 2 additions & 3 deletions examples/scripts/ct_astra_weighted_tv_admm.py
@@ -137,9 +137,8 @@ def postprocess(x):
"""
lambda_weighted = 1.14e2

weights = counts / Io # scale by Io to balance the data vs regularization term
W = linop.Diagonal(snp.sqrt(weights))
f = loss.WeightedSquaredL2Loss(y=y, A=A, weight_op=W)
weights = counts / Io # scaled by Io to balance the data vs regularization term
f = loss.WeightedSquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights))

admm_weighted = ADMM(
f=f,
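Editorial sketch (not part of the diff): under the new convention, W is the full weight operator rather than its square root, so the loss can be evaluated directly from the weights. A minimal example with arbitrary values, assuming the WeightedSquaredL2Loss signature introduced later in this PR:

import scico.numpy as snp
from scico import linop, loss

y = snp.array([1.0, 2.0, 3.0])
weights = snp.array([0.5, 1.0, 2.0])                             # must be non-negative
f = loss.WeightedSquaredL2Loss(y=y, W=linop.Diagonal(weights))   # A defaults to the identity

x = snp.zeros(3)
expected = 0.5 * snp.sum(weights * (y - x) ** 2)                 # 0.5 * (y - x)^T W (y - x)
assert snp.allclose(f(x), expected)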
4 changes: 1 addition & 3 deletions examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py
@@ -91,9 +91,7 @@
ρ = 100 # ADMM penalty parameter
σ = density * 0.2 # denoiser sigma

weight_op = Diagonal(weights ** 0.5)

f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, weight_op=weight_op, scale=0.5)
f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
g0 = σ * ρ * BM3D()
g1 = NonNegativeIndicator()

4 changes: 1 addition & 3 deletions examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py
@@ -91,9 +91,7 @@
ρ = 10 # ADMM penalty parameter
σ = density * 0.26 # denoiser sigma

weight_op = Diagonal(weights ** 0.5)

f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, weight_op=weight_op, scale=0.5)
f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
g0 = σ * ρ * BM3D()
g1 = NonNegativeIndicator()

4 changes: 2 additions & 2 deletions scico/admm.py
@@ -124,7 +124,7 @@ class LinearSubproblemSolver(SubproblemSolver):
\mb{x}^{(k+1)} = \argmin_{\mb{x}} \; \frac{1}{2} \norm{\mb{y} - A x}_W^2 +
\sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;,

where :math:`W` is the weighting :class:`.LinearOperator` from the
where :math:`W` is the weighting :class:`.Diagonal` from the
:class:`.WeightedSquaredL2Loss` instance. This update step reduces to the
solution of the linear system

@@ -208,7 +208,7 @@ def compute_rhs(self) -> Union[JaxArray, BlockArray]:

if self.admm.f is not None:
if isinstance(self.admm.f, WeightedSquaredL2Loss):
ATWy = self.admm.f.A.adj(self.admm.f.weight_op @ self.admm.f.y)
ATWy = self.admm.f.A.adj(self.admm.f.W.diagonal @ self.admm.f.y)
rhs += 2.0 * self.admm.f.scale * ATWy
else:
ATy = self.admm.f.A.adj(self.admm.f.y)
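Editorial note (not part of the diff): with W now the full weight operator, the x-update objective shown above leads to the normal equations below. This is a reconstruction from the stated objective, since the hunk is truncated before the linear system itself; writing alpha for the scale of the WeightedSquaredL2Loss (the docstring's 1/2 is the default alpha):

.. math::
   \Big( 2 \alpha A^* W A + \sum_i \rho_i C_i^* C_i \Big) \mb{x}^{(k+1)} =
   2 \alpha A^* W \mb{y} + \sum_i \rho_i C_i^* \big( \mb{z}^{(k)}_i - \mb{u}^{(k)}_i \big) \;,

whose right-hand side is what compute_rhs assembles via the 2.0 * self.admm.f.scale * ATWy term.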
13 changes: 2 additions & 11 deletions scico/linop/radon_svmbir.py
@@ -28,7 +28,7 @@
from scico.loss import WeightedSquaredL2Loss
from scico.typing import JaxArray, Shape

from ._linop import Diagonal, LinearOperator
from ._linop import LinearOperator


class ParallelBeamProjector(LinearOperator):
@@ -120,21 +120,12 @@ def __init__(self, *args, **kwargs):
"to instantiate a `SVMBIRWeightedSquaredL2Loss`."
)

if not isinstance(self.weight_op, Diagonal):
raise ValueError(
f"`weight_op` must be `Diagonal` but instead got {type(self.weight_op)}"
)

self.weights = (
snp.conj(self.weight_op.diagonal) * self.weight_op.diagonal
) # because weight_op is W^{1/2}

self.has_prox = True

def prox(self, v: JaxArray, lam: float) -> JaxArray:
v = v.reshape(self.A.svmbir_input_shape)
y = self.y.reshape(self.A.svmbir_output_shape)
weights = self.weights.reshape(self.A.svmbir_output_shape)
weights = self.W.diagonal.reshape(self.A.svmbir_output_shape)
sigma_p = snp.sqrt(lam)
result = svmbir.recon(
np.array(y),
106 changes: 65 additions & 41 deletions scico/loss.py
@@ -42,7 +42,9 @@ class Loss(functional.Functional):
r"""Generic Loss function.

.. math::
\mathrm{scale} \cdot l(\mb{y}, A(\mb{x})) \;
\alpha l(\mb{y}, A(\mb{x})) \;

where :math:`\alpha` is the scaling parameter and :math:`l(\cdot)` is the loss functional.

"""

@@ -55,7 +57,7 @@ def __init__(
r"""Initialize a :class:`Loss` object.

Args:
y : Measurements
y : Measurement.
A : Forward operator. Defaults to None. If None, ``self.A`` is a :class:`.Identity`.
scale : Scaling parameter. Default: 0.5.

@@ -109,7 +111,9 @@ class SquaredL2Loss(Loss):
Squared :math:`\ell_2` loss.

.. math::
\mathrm{scale} \cdot \norm{\mb{y} - A(\mb{x})}_2^2 \;
\alpha \norm{\mb{y} - A(\mb{x})}_2^2 \;

where :math:`\alpha` is the scaling parameter.

"""

@@ -119,8 +123,14 @@ def __init__(
A: Optional[Union[Callable, operator.Operator]] = None,
scale: float = 0.5,
):
r"""Initialize a :class:`SquaredL2Loss` object.

Args:
y : Measurement.
A : Forward operator. If None, defaults to :class:`.Identity`.
scale : Scaling parameter.
"""
y = ensure_on_device(y)
self.functional = functional.SquaredL2Norm()
super().__init__(y=y, A=A, scale=scale)

if isinstance(A, operator.Operator):
@@ -140,7 +150,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
Args:
x : Point at which to evaluate loss.
"""
return self.scale * self.functional(self.y - self.A(x))
return self.scale * (snp.abs(self.y - self.A(x)) ** 2).sum()

Collaborator:

Is this change desirable? Dropping the use of self.functional seems like it may have consequences for derived classes.

Contributor Author:

I don't think it does. Does it?

I did this to keep it consistent with the weighted case, where we can avoid computing the square root of the weights.

Collaborator:

Understood, and that's a worthwhile goal, but we should make sure there aren't any undesirable consequences before we make this change.

Contributor:

I don't mind this change for the time being.

Longer term, consider: Why is functional a property at all if it is not used here? Broadly, I think this discussion is a symptom of the existence of Loss. One might hope to implement __call__ at the level of Loss (and therefore using self.functional) to reduce repeated code. But we can't do that, because Loss is so general it doesn't know that it should be A@x - y.

Collaborator:

Agreed: if this change is made, we should consider removing the functional attribute. With respect to Loss, do you recall why it's so general? If there's good reason, perhaps we should have a specialization that really is A@x - y?

Contributor:

I think the reason it is general is because of the Poisson. I think the reason it exists at all is that it used to not be a subclass of Functional and therefore having a base Loss class made sense.
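Editorial sketch of the equivalence discussed in this thread (illustrative names, plain NumPy): for a residual r = y - A(x), the functional-based evaluation scale * ||r||_2^2 and the new direct evaluation scale * sum(|r|^2) agree for both real and complex residuals:

import numpy as np

rng = np.random.default_rng(0)
r = rng.standard_normal(16) + 1j * rng.standard_normal(16)   # residual y - A(x)
scale = 0.5

old_val = scale * np.linalg.norm(r) ** 2    # scale * ||r||_2^2, as via the functional
new_val = scale * (np.abs(r) ** 2).sum()    # scale * sum(|r|^2), as in the new __call__
assert np.allclose(old_val, new_val)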

def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, BlockArray]:
if isinstance(self.A, linop.Diagonal):
@@ -154,7 +164,7 @@ def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, Bl
@property
def hessian(self) -> linop.LinearOperator:
r"""If ``self.A`` is a :class:`.LinearOperator`, returns a new :class:`.LinearOperator` corresponding
to Hessian :math:`\mathrm{A^*A}`.
to Hessian :math:`2 \alpha \mathrm{A^* A}`.

Otherwise not implemented.
"""
Expand All @@ -171,11 +181,12 @@ class WeightedSquaredL2Loss(Loss):
Weighted squared :math:`\ell_2` loss.

.. math::
\mathrm{scale} \cdot \norm{\mb{y} - A(\mb{x})}_{\mathrm{W}}^2 =
\mathrm{scale} \cdot \norm{\mathrm{W}^{1/2} \left( \mb{y} - A(\mb{x})\right)}_2^2\;
\alpha \norm{\mb{y} - A(\mb{x})}_W^2 =
\alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} - A(\mb{x})\right)\;

Where :math:`\mathrm{W}` is an instance of :class:`scico.linop.LinearOperator`. If
:math:`\mathrm{W}` is None, reverts to the behavior of :class:`.SquaredL2Loss`.
where :math:`\alpha` is the scaling parameter and :math:`W` is an
instance of :class:`scico.linop.Diagonal`. If :math:`W` is None,
reverts to the behavior of :class:`.SquaredL2Loss`.

"""

@@ -184,30 +195,32 @@ def __init__(
y: Union[JaxArray, BlockArray],
A: Optional[Union[Callable, operator.Operator]] = None,
scale: float = 0.5,
weight_op: Optional[operator.Operator] = None,
W: Optional[linop.Diagonal] = None,
):

r"""Initialize a :class:`WeightedSquaredL2Loss` object.

Args:
y : Measurements
y : Measurement.
A : Forward operator. If None, defaults to :class:`.Identity`.
scale : Scaling parameter
weight_op: Weighting linear operator. Corresponds to :math:`W^{1/2}`
in the standard definition of the weighted squared :math:`\ell_2` loss.
scale : Scaling parameter.
W: Weighting diagonal operator. Must be non-negative.
If None, defaults to :class:`.Identity`.
"""
y = ensure_on_device(y)

self.weight_op: operator.Operator
self.W: linop.Diagonal

self.functional = functional.SquaredL2Norm()
if weight_op is None:
self.weight_op = linop.Identity(y.shape)
elif isinstance(weight_op, linop.LinearOperator):
self.weight_op = weight_op
if W is None:
self.W = linop.Identity(y.shape)
elif isinstance(W, linop.Diagonal):
if snp.all(W.diagonal >= 0):
self.W = W
else:
raise Exception(f"The weights, W.diagonal, must be non-negative.")
else:
raise TypeError(f"weight_op must be None or a LinearOperator, got {type(weight_op)}")
raise TypeError(f"W must be None or a linop.Diagonal, got {type(W)}")

super().__init__(y=y, A=A, scale=scale)

if isinstance(A, operator.Operator):
@@ -218,40 +231,43 @@ def __init__(
if isinstance(self.A, linop.LinearOperator):
self.is_quadratic = True

if isinstance(self.A, linop.Diagonal) and isinstance(self.weight_op, linop.Diagonal):
if isinstance(self.A, linop.Diagonal) and isinstance(self.W, linop.Diagonal):
self.has_prox = True

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return self.scale * self.functional(self.weight_op(self.y - self.A(x)))
return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum()

Collaborator:

See earlier comment on similar lines.

Contributor Author:

CI seems to be failing because the changes would slightly reduce the test coverage percentage.

Let's see whether that is still the case when we add the test for the Hessian.

Collaborator:

Still unhappy, it seems. It would be best to address this before we merge.

Collaborator:

I added a test for loss.PoissonLoss which will, I assume, resolve this.

@Michael-T-McCann : Would you not agree that the Loss tests should be in a separate test_loss.py file rather than included in test_functional.py?

Contributor:

Loss is currently a subclass of Functional (was not always this way). Therefore I think it makes sense for the losses to get tested in test_functional.py, unless the file is much too long.

def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, BlockArray]:
if isinstance(self.A, linop.Diagonal):
c = self.scale * lam
c = 2.0 * self.scale * lam
A = self.A.diagonal
W = self.weight_op.diagonal
lhs = c * 2.0 * A.conj() * W * W.conj() * self.y + x
ATWTWA = c * 2.0 * A.conj() * W.conj() * W * A
return lhs / (ATWTWA + 1.0)
W = self.W.diagonal
lhs = c * A.conj() * W * self.y + x
ATWA = c * A.conj() * W * A
return lhs / (ATWA + 1.0)
else:
raise NotImplementedError
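Editorial sketch (not part of the diff) checking the closed-form prox above for the fully diagonal, real-valued case against a brute-force minimization; all names and values are illustrative:

import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
n, scale, lam = 8, 0.5, 0.7
a = rng.standard_normal(n)            # diagonal of A
w = rng.uniform(0.5, 1.5, n)          # diagonal of W (non-negative)
y = rng.standard_normal(n)
v = rng.standard_normal(n)            # prox evaluation point

# prox_{lam f}(v) = argmin_u lam * scale * sum(w * (y - a*u)^2) + 0.5 * ||u - v||^2
objective = lambda u: lam * scale * np.sum(w * (y - a * u) ** 2) + 0.5 * np.sum((u - v) ** 2)

c = 2.0 * scale * lam
closed_form = (c * a * w * y + v) / (c * a * w * a + 1.0)   # mirrors the code above
numeric = minimize(objective, v).x
assert np.allclose(closed_form, numeric, atol=1e-4)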

@property
def hessian(self) -> linop.LinearOperator:
r"""If ``self.A`` is a :class:`scico.linop.LinearOperator`, returns a
:class:`scico.linop.LinearOperator` corresponding to Hessian :math:`\mathrm{A^* W A}`.
:class:`scico.linop.LinearOperator` corresponding to the Hessian
:math:`2 \alpha \mathrm{A^* W A}`.

Otherwise not implemented.
"""
if isinstance(self.A, linop.LinearOperator):
A = self.A
W = self.W
if isinstance(A, linop.LinearOperator):
return linop.LinearOperator(
input_shape=self.A.input_shape,
output_shape=self.A.input_shape,
eval_fn=lambda x: 2 * self.scale * self.A.adj(self.weight_op(self.A(x))),
adj_fn=lambda x: 2 * self.scale * self.A.adj(self.weight_op(self.A(x))),
input_shape=A.input_shape,
output_shape=A.input_shape,
eval_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),
adj_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),
)
else:
raise NotImplementedError(
f"Hessian is not implemented for {type(self)} when `A` is {type(self.A)}; must be LinearOperator"
f"Hessian is not implemented for {type(self)} when `A` is {type(A)}; must be LinearOperator"
)
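Editorial sketch (not part of the diff): for a small dense example, the Hessian of f(x) = scale * (y - A x)^T W (y - A x) is 2 * scale * A^T W A, which is what the operator above applies; a finite-difference check on the gradient confirms it. Names and sizes are illustrative:

import numpy as np

rng = np.random.default_rng(1)
m, n, scale = 5, 4, 0.5
A = rng.standard_normal((m, n))
w = rng.uniform(0.5, 2.0, m)                             # diagonal of W
y = rng.standard_normal(m)
x = rng.standard_normal(n)

hess = 2 * scale * A.T @ np.diag(w) @ A                  # 2 * alpha * A^T W A

grad = lambda x: -2 * scale * A.T @ (w * (y - A @ x))    # gradient of the weighted loss
eps = 1e-6
fd = np.column_stack([(grad(x + eps * e) - grad(x - eps * e)) / (2 * eps) for e in np.eye(n)])
assert np.allclose(hess, fd, atol=1e-4)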


@@ -260,8 +276,9 @@ class PoissonLoss(Loss):
Poisson negative log likelihood loss

.. math::
\mathrm{scale} \cdot \sum_i [A(x)]_i - y_i \log\left( [A(x)]_i \right) + \log(y_i!)
\alpha \left( \sum_i [A(x)]_i - y_i \log\left( [A(x)]_i \right) + \log(y_i!) \right)

where :math:`\alpha` is the scaling parameter.
"""

def __init__(
@@ -273,17 +290,24 @@ def __init__(
r"""Initialize a :class:`Loss` object.

Args:
y : Measurements
A : Forward operator. Defaults to None. If None, ``self.A`` is a :class:`.Identity`.
scale : Scaling parameter. Default: 0.5.

y : Measurement.
A : Forward operator. Defaults to None. If None, ``self.A`` is a :class:`.Identity`.
scale : Scaling parameter. Default: 0.5.
"""
y = ensure_on_device(y)
super().__init__(y=y, A=A, scale=scale)

#: Constant term in Poisson log likehood; equal to ln(y!)
self.const: float = gammaln(self.y + 1) # ln(y!)

# The "true" Poisson loss is not differentiable at zero, but the
# ε that we add to allow evaluation at zero has the side-effect
# of making it differentiable there.
if isinstance(A, operator.Operator):
self.is_smooth = A.is_smooth
else:
self.is_smooth = None
Collaborator:

@Michael-T-McCann : Any thoughts on this addition? It's a bit ugly mathematically because we're labeling a non-smooth functional as smooth (because we smoothed it at zero), but this seems to be the right choice from a practical/computational perspective.

Contributor:

My vote: document the use of epsilon, set is_smooth=True.

The only way around the epsilon I can come up with is accounting for the dark count rate in the forward model.

Contributor Author (@tbalke, Nov 11, 2021):

Even with the epsilon it is still not smooth for Ax < (-epsilon). I think a way to get around the mathematical imprecision is to say that the attribute we are looking for is whether it can be used in pgm or not.

If A(x) = exp(x) for example, then we can use the current PoissonLoss with no problem. But if A(x) is linear then we can get into trouble even when using the epsilon since Ax < (-epsilon) is very likely.

So smoothness of the loss \alpha L(y, A(x)) is not only dependent on L but also on A(.) or perhaps y.

@bwohlberg is the is_smooth attribute used for anything other than the check in pgm?

Collaborator:

> Even with the epsilon it is still not smooth for Ax < (-epsilon).

Good point. Given that epsilon just shifts the problem, perhaps we should remove it? Alternatively, instead of adding epsilon, set values less than epsilon to epsilon?

> @bwohlberg is the is_smooth attribute used for anything other than the check in pgm?

I'm not sure. That may be the only use at the moment. Perhaps the whole is_smooth mechanism is worth a re-think. How is the issue of zeros handled in the Poisson loss example with PGM?

Contributor Author (@tbalke, Nov 11, 2021):

> Given that epsilon just shifts the problem, perhaps we should remove it? Alternatively, instead of adding epsilon, set values less than epsilon to epsilon?

Yes, this may be a better option. But see below.

> Perhaps the whole is_smooth mechanism is worth a re-think.

Yes. I think if pgm rejected all non-smooth L(y, A(x)) for all x, we would be unnecessarily impractical. A looser condition of L(y, A(x)) smooth around x_0 or in some feasible set could be more practical.

> How is the issue of zeros handled in the Poisson loss example with PGM?

(a) The initial condition is > 0 (almost surely).
(b) There is a non-negative indicator.
(c) Matrix A does not have negative entries.
However, it seems like initializing to negatives or zero breaks the example code. I guess this kind of problem is not really that realistic. If we assume y to be Poisson, then in what world would A(x) be negative for any x in the feasible set?

I am leaning towards completely removing the epsilon, and then either A or the feasible set needs to be such that A(x) > 0. (I think that currently would not break anything, but I would not bet on it.)

@Michael-T-McCann ?

Contributor:

I had assumed the epsilon was necessary for some example or another. If it isn't, sure, remove it.

Contributor Author:

Issue opened in #89


def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
ε = 1e-9 # So that loss < infinity
Ax = self.A(x)
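Editorial sketch (not part of the diff) of the epsilon-stabilized evaluation suggested by the docstring and the ε defined above; values are illustrative. It also shows why the thread above worries about negative entries in A(x):

import numpy as np
from scipy.special import gammaln

scale, eps = 0.5, 1e-9
y = np.array([0.0, 2.0, 5.0])       # Poisson counts
Ax = np.array([0.0, 1.5, 4.0])      # forward-model output, assumed >= 0
val = scale * np.sum(Ax - y * np.log(Ax + eps) + gammaln(y + 1))

# With y_i = 0 the log term contributes 0 * log(eps) == 0, so Ax_i == 0 is safe;
# a negative Ax_i would still push the argument of the log below zero, which is
# exactly the non-smoothness issue discussed in the review thread above.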
4 changes: 2 additions & 2 deletions scico/test/linop/test_radon_svmbir.py
@@ -89,6 +89,6 @@ def test_prox_weights(Nx, Ny, num_angles, num_channels, is_3d):

# test with weights
weights, _ = scico.random.uniform(sino.shape, dtype=im.dtype)
D = scico.linop.Diagonal(weights)
f = SVMBIRWeightedSquaredL2Loss(y=sino, A=A, weight_op=D)
W = scico.linop.Diagonal(weights)
f = SVMBIRWeightedSquaredL2Loss(y=sino, A=A, W=W)
prox_test(v, f, f.prox, alpha=0.25)
28 changes: 22 additions & 6 deletions scico/test/test_functional.py
@@ -338,8 +338,9 @@ def setup_method(self):
W, key = randn((n,), key=key, dtype=dtype)
W = 0.1 * W + 1.0
self.Ao = linop.MatrixOperator(A)
self.Ao_abs = linop.MatrixOperator(snp.abs(A))
self.Do = linop.Diagonal(D)
self.Wo = linop.Diagonal(W)
self.W = linop.Diagonal(W)
self.y, key = randn((n,), key=key, dtype=dtype)
self.v, key = randn((n,), key=key, dtype=dtype) # point for prox eval
scalar, key = randn((1,), key=key, dtype=dtype)
@@ -377,14 +378,14 @@ def test_squared_l2(self):
pf = prox_test(self.v, L_d, L_d.prox, 0.75)

def test_weighted_squared_l2(self):
L = loss.WeightedSquaredL2Loss(y=self.y, A=self.Ao, weight_op=self.Wo)
L = loss.WeightedSquaredL2Loss(y=self.y, A=self.Ao, W=self.W)
assert L.is_smooth == True
assert L.has_eval == True
assert L.has_prox == False # not diagonal

# test eval
np.testing.assert_allclose(
L(self.v), 0.5 * ((self.Wo @ (self.Ao @ self.v - self.y)) ** 2).sum()
L(self.v), 0.5 * (self.W @ (self.Ao @ self.v - self.y) ** 2).sum()
)

cL = self.scalar * L
@@ -393,16 +394,15 @@ def test_weighted_squared_l2(self):
assert cL(self.v) == self.scalar * L(self.v)

# SquaredL2 with Diagonal linop has a prox
Wo = self.Wo
L_d = loss.WeightedSquaredL2Loss(y=self.y, A=self.Do, weight_op=Wo)
L_d = loss.WeightedSquaredL2Loss(y=self.y, A=self.Do, W=self.W)

assert L_d.is_smooth == True
assert L_d.has_eval == True
assert L_d.has_prox == True

# test eval
np.testing.assert_allclose(
L_d(self.v), 0.5 * ((self.Wo @ (self.Do @ self.v - self.y)) ** 2).sum()
L_d(self.v), 0.5 * (self.W @ (self.Do @ self.v - self.y) ** 2).sum()
)

cL = self.scalar * L_d
@@ -412,6 +412,22 @@ def test_weighted_squared_l2(self):

pf = prox_test(self.v, L_d, L_d.prox, 0.75)

def test_poisson(self):
L = loss.PoissonLoss(y=self.y, A=self.Ao_abs)
assert L.is_smooth == True
assert L.has_eval == True
assert L.has_prox == False

# test eval
v = snp.abs(self.v)
Av = self.Ao_abs @ v
np.testing.assert_allclose(L(v), 0.5 * snp.sum(Av - self.y * snp.log(Av) + L.const))

cL = self.scalar * L
assert L.scale == 0.5 # hasn't changed
assert cL.scale == self.scalar * L.scale
assert cL(v) == self.scalar * L(v)


class TestBM3D:
def setup(self):