Problem with W^1/2 weight exponent #78
@@ -119,6 +119,13 @@ def __init__(
         A: Optional[Union[Callable, operator.Operator]] = None,
         scale: float = 0.5,
     ):
+        r"""Initialize a :class:`SquaredL2Loss` object.
+
+        Args:
+            y : Measurements.
+            A : Forward operator. If None, defaults to :class:`.Identity`.
+            scale : Scaling parameter.
+        """
         y = ensure_on_device(y)
         self.functional = functional.SquaredL2Norm()
         super().__init__(y=y, A=A, scale=scale)
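To make the new docstring concrete, a hypothetical usage sketch (the module path scico.loss and the example array are assumptions for illustration, not part of this diff):

import numpy as np
from scico import loss

# With A left as None the loss defaults to the Identity operator, so
# f(x) evaluates scale * ||y - x||_2^2 with the default scale = 0.5.
y = np.arange(4, dtype=np.float32)
f = loss.SquaredL2Loss(y=y)
val = f(2 * y)   # 0.5 * ||y - 2y||^2 = 0.5 * ||y||^2 = 7.0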
@@ -140,7 +147,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
         Args:
             x : Point at which to evaluate loss.
         """
-        return self.scale * self.functional(self.y - self.A(x))
+        return self.scale * (snp.abs(self.y - self.A(x)) ** 2).sum()

     def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, BlockArray]:
         if isinstance(self.A, linop.Diagonal):
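The rewritten __call__ no longer goes through self.functional but computes the same value. A quick check in plain numpy (used here as a stand-in for scico.numpy / snp, which mirrors the numpy API; shapes and values are illustrative):

import numpy as np

rng = np.random.default_rng(0)
y = rng.standard_normal(16)
Ax = rng.standard_normal(16)   # stand-in for A(x)
scale = 0.5

old_form = scale * np.linalg.norm(y - Ax) ** 2     # scale * ||y - A(x)||_2^2
new_form = scale * (np.abs(y - Ax) ** 2).sum()     # form used in this diff
assert np.isclose(old_form, new_form)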
@@ -154,7 +161,7 @@ def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, Bl
     @property
     def hessian(self) -> linop.LinearOperator:
         r"""If ``self.A`` is a :class:`.LinearOperator`, returns a new :class:`.LinearOperator` corresponding
-        to Hessian :math:`\mathrm{A^*A}`.
+        to Hessian :math:`2 \mathrm{scale} \cdot \mathrm{A^* A}`.
        (bwohlberg marked this conversation as resolved.)

         Otherwise not implemented.
         """
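For reference, the factor of 2 added to the docstring follows directly from differentiating the loss; a short derivation in the notation of the docstrings above:

.. math::
   f(\mb{x}) = \mathrm{scale} \cdot \norm{\mb{y} - A \mb{x}}_2^2, \qquad
   \nabla f(\mb{x}) = 2 \, \mathrm{scale} \cdot A^* (A \mb{x} - \mb{y}), \qquad
   \nabla^2 f(\mb{x}) = 2 \, \mathrm{scale} \cdot A^* A.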
@@ -171,11 +178,11 @@ class WeightedSquaredL2Loss(Loss):
     Weighted squared :math:`\ell_2` loss.

     .. math::
-        \mathrm{scale} \cdot \norm{\mb{y} - A(\mb{x})}_{\mathrm{W}}^2 =
-        \mathrm{scale} \cdot \norm{\mathrm{W}^{1/2} \left( \mb{y} - A(\mb{x})\right)}_2^2\;
+        \mathrm{scale} \cdot \norm{\mb{y} - A(\mb{x})}_W^2 =
+        \mathrm{scale} \cdot \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} - A(\mb{x})\right)\;

-    Where :math:`\mathrm{W}` is an instance of :class:`scico.linop.LinearOperator`. If
-    :math:`\mathrm{W}` is None, reverts to the behavior of :class:`.SquaredL2Loss`.
+    Where :math:`W` is an instance of :class:`scico.linop.Diagonal`. If
+    :math:`W` is None, reverts to the behavior of :class:`.SquaredL2Loss`.

     """
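The old W^{1/2} form and the new quadratic form agree whenever W is diagonal with non-negative entries, which is exactly what the constructor below now enforces. A small numerical check in plain numpy (illustrative values; scico.numpy mirrors this API):

import numpy as np

rng = np.random.default_rng(1)
r = rng.standard_normal(8)               # residual y - A(x)
w = rng.uniform(0.0, 2.0, size=8)        # non-negative diagonal of W

sqrt_form = np.sum((np.sqrt(w) * r) ** 2)    # ||W^{1/2} (y - A(x))||_2^2
quad_form = r @ (w * r)                      # (y - A(x))^T W (y - A(x))
assert np.isclose(sqrt_form, quad_form)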
@@ -184,30 +191,33 @@ def __init__(
         y: Union[JaxArray, BlockArray],
         A: Optional[Union[Callable, operator.Operator]] = None,
         scale: float = 0.5,
-        weight_op: Optional[operator.Operator] = None,
+        W: Optional[linop.Diagonal] = None,
     ):
         r"""Initialize a :class:`WeightedSquaredL2Loss` object.

         Args:
-            y : Measurements
+            y : Measurements.

        [Review comment on the line above] Nitpick: change to "Measurement."
        [Reply] fixed in 7db54df

             A : Forward operator. If None, defaults to :class:`.Identity`.
-            scale : Scaling parameter
-            weight_op: Weighting linear operator. Corresponds to :math:`W^{1/2}`
-                in the standard definition of the weighted squared :math:`\ell_2` loss.
+            scale : Scaling parameter.
+            W: Weighting diagonal operator. Must be non-negative.
+                If None, defaults to :class:`.Identity`.
         """
         y = ensure_on_device(y)

-        self.weight_op: operator.Operator
+        self.W: linop.Diagonal

         self.functional = functional.SquaredL2Norm()
-        if weight_op is None:
-            self.weight_op = linop.Identity(y.shape)
-        elif isinstance(weight_op, linop.LinearOperator):
-            self.weight_op = weight_op
+        if W is None:
+            self.W = linop.Identity(y.shape)
+        elif isinstance(W, linop.Diagonal):
+            if np.all(W.diagonal >= 0):
+                self.W = W
+            else:
+                raise Exception(f"The weights, W.diagonal, must be non-negative.")
         else:
-            raise TypeError(f"weight_op must be None or a LinearOperator, got {type(weight_op)}")
+            raise TypeError(f"W must be None or a linop.Diagonal, got {type(W)}")

         super().__init__(y=y, A=A, scale=scale)

         if isinstance(A, operator.Operator):
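A hypothetical usage sketch of the new W argument (the imports, array values, and the scico.loss module path are assumptions for illustration, not part of this diff):

import numpy as np
from scico import linop, loss

y = np.ones(8, dtype=np.float32)
weights = 2.0 * np.ones(8, dtype=np.float32)   # must be non-negative
W = linop.Diagonal(weights)

# A defaults to Identity; W must now be a linop.Diagonal (or None).
f = loss.WeightedSquaredL2Loss(y=y, W=W)

# Anything else passed for W raises TypeError; negative weights raise an exception.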
@@ -218,40 +228,43 @@ def __init__(
         if isinstance(self.A, linop.LinearOperator):
             self.is_quadratic = True

-        if isinstance(self.A, linop.Diagonal) and isinstance(self.weight_op, linop.Diagonal):
+        if isinstance(self.A, linop.Diagonal) and isinstance(self.W, linop.Diagonal):
             self.has_prox = True

     def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
-        return self.scale * self.functional(self.weight_op(self.y - self.A(x)))
+        return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum()
        [Review thread on the line above]
        - See earlier comment on similar lines.
        - Let's see whether that is still the case when we add the test for the Hessian.
        - Still unhappy, it seems. It would be best to address this before we merge.
        - I added a test for
        - @Michael-T-McCann : Would you not agree that the

     def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, BlockArray]:
         if isinstance(self.A, linop.Diagonal):
-            c = self.scale * lam
+            c = 2.0 * self.scale * lam
             A = self.A.diagonal
-            W = self.weight_op.diagonal
-            lhs = c * 2.0 * A.conj() * W * W.conj() * self.y + x
-            ATWTWA = c * 2.0 * A.conj() * W.conj() * W * A
-            return lhs / (ATWTWA + 1.0)
+            W = self.W.diagonal
+            lhs = c * A.conj() * W * self.y + x
+            ATWA = c * A.conj() * W * A
+            return lhs / (ATWA + 1.0)
         else:
             raise NotImplementedError

     @property
     def hessian(self) -> linop.LinearOperator:
         r"""If ``self.A`` is a :class:`scico.linop.LinearOperator`, returns a
-        :class:`scico.linop.LinearOperator` corresponding to Hessian :math:`\mathrm{A^* W A}`.
+        :class:`scico.linop.LinearOperator` corresponding to the Hessian
+        :math:`2 \mathrm{scale} \cdot \mathrm{A^* W A}`.

         Otherwise not implemented.
         """
-        if isinstance(self.A, linop.LinearOperator):
+        A = self.A
+        W = self.W
+        if isinstance(A, linop.LinearOperator):
             return linop.LinearOperator(
-                input_shape=self.A.input_shape,
-                output_shape=self.A.input_shape,
-                eval_fn=lambda x: 2 * self.scale * self.A.adj(self.weight_op(self.A(x))),
-                adj_fn=lambda x: 2 * self.scale * self.A.adj(self.weight_op(self.A(x))),
+                input_shape=A.input_shape,
+                output_shape=A.input_shape,
+                eval_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),
+                adj_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),
             )
         else:
             raise NotImplementedError(
-                f"Hessian is not implemented for {type(self)} when `A` is {type(self.A)}; must be LinearOperator"
+                f"Hessian is not implemented for {type(self)} when `A` is {type(A)}; must be LinearOperator"
             )
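The diagonal-case prox formula above can be sanity-checked against the first-order optimality condition of the prox objective. A real-valued plain-numpy sketch; the values and the prox convention assumed here (argmin over v of (1/2)||v - x||^2 + lam * loss(v)) are illustrative assumptions:

import numpy as np

rng = np.random.default_rng(2)
n = 8
x = rng.standard_normal(n)
y = rng.standard_normal(n)
a = rng.standard_normal(n)            # diagonal of A
w = rng.uniform(0.1, 2.0, size=n)     # non-negative diagonal of W
scale, lam = 0.5, 0.7

# Closed form from the diff (real case, so conj() is a no-op).
c = 2.0 * scale * lam
v = (x + c * a * w * y) / (1.0 + c * w * a ** 2)

# Gradient of (1/2)||v - x||^2 + lam * scale * sum(w * (y - a*v)**2) should vanish at v.
grad = (v - x) + 2.0 * lam * scale * w * a * (a * v - y)
assert np.allclose(grad, 0.0, atol=1e-10)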
Review discussion:

- Is this change desirable? Dropping the use of `self.functional` seems like it may have consequences for derived classes.

- I don't think it does. Does it? I have this to keep it consistent with the weighted case where we can avoid computing the square root of the weights.

- Understood, and that's a worthwhile goal, but we should make sure there aren't any undesirable consequences before we make this change.

- I don't mind this change for the time being. Longer term, consider: Why is `functional` a property at all if it is not used here? Broadly, I think this discussion is a symptom of the existence of `Loss`. One might hope to implement `__call__` at the level of `Loss` (and therefore using `self.functional`) to reduce repeated code. But we can't do that, because `Loss` is so general it doesn't know that it should be `A@x - y`.

- Agreed: if this change is made, we should consider removing the `functional` attribute. With respect to `Loss`, do you recall why it's so general? If there's good reason, perhaps we should have a specialization that really is `A@x - y`?

- I think the reason it is general is because of the Poisson. I think the reason it exists at all is that it used to not be a subclass of `Functional` and therefore having a base `Loss` class made sense.
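To illustrate the point in the last few comments, a rough sketch (an illustration only, not SCICO's actual `Loss` class) of what a residual-based base-class `__call__` would look like, and why it bakes in the `y - A(x)` structure that a fully general loss such as a Poisson log-likelihood does not have:

class ResidualLossSketch:
    """Illustrative only: a base class whose __call__ reuses self.functional."""

    def __init__(self, y, A, functional, scale=0.5):
        self.y = y
        self.A = A
        self.functional = functional
        self.scale = scale

    def __call__(self, x):
        # Hard-codes the residual structure y - A(x); a fully general Loss cannot assume this.
        return self.scale * self.functional(self.y - self.A(x))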