
remove is_smooth attribute (#89) #184

Merged
merged 7 commits on Jan 27, 2022
Changes from 1 commit
8 changes: 3 additions & 5 deletions docs/source/functional.rst
@@ -39,8 +39,7 @@ An instance of :class:`.Functional`, ``f``, may provide three core operations.
- ``f.grad(x)`` returns the gradient of the functional evaluated at ``x``.
- Gradients are calculated using JAX reverse-mode automatic differentiation,
exposed through :func:`scico.grad`.
- A functional that is smooth has the attribute ``f.is_smooth == True``.
- NOTE: The gradient of a functional ``f`` can be evaluated even if ``f.is_smooth == False``.
- NOTE: The gradient of a functional ``f`` can be evaluated even if that functional is not smooth.
Collaborator

Can we refer to part of the jax docs to expand on this? Also, consider using markup (e.g. Note) rather than all caps for emphasis.

Contributor

I had a look this morning and could not find any clear discussion of this in the jax docs. I can at least confirm that the statement is sometimes true, e.g., jax.grad(lambda x: jnp.abs(x))(0.0) == 1.0.

Contributor Author

> the statement is sometimes true, e.g., jax.grad(lambda x: jnp.abs(x))(0.0) == 1.0

I think this is an easier example because it is defined everywhere except at 0.

Where it breaks is, e.g., at x / (1 + jnp.sqrt(x)), whose gradient evaluates to nan at zero. See

I don't know what the general rule for scico is, though.
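A minimal sketch of the two cases discussed in this thread (assumes only jax; the expected values are the ones reported in the comments above):

```python
# Both functions are non-smooth at 0, but jax.grad behaves differently there.
import jax
import jax.numpy as jnp

# abs: jax.grad returns a usable value (a valid subgradient in this case).
print(jax.grad(lambda x: jnp.abs(x))(0.0))                # 1.0

# x / (1 + sqrt(x)): reverse-mode AD hits 0 * inf and yields nan.
print(jax.grad(lambda x: x / (1.0 + jnp.sqrt(x)))(0.0))   # nan
```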

Contributor Author

> Also, consider using markup

See 0fdfcf6

All that is required is that the functional can be evaluated, ``f.has_eval == True``.
However, the result may not be a valid gradient (or subgradient) for all inputs.
* Proximal operator
@@ -92,7 +91,7 @@ in the parameterized form :math:`\mathrm{prox}_{c f}`.

In SCICO, multiplying a :class:`.Functional` by a scalar
will return a :class:`.ScaledFunctional`.
This :class:`.ScaledFunctional` retains the ``has_eval``, ``is_smooth``, and ``has_prox`` attributes
This :class:`.ScaledFunctional` retains the ``has_eval`` and ``has_prox`` attributes
from the original :class:`.Functional`,
but the proximal method is modified to accommodate the additional scalar.

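A minimal sketch of the scaling behaviour described above, using the L1 norm (the only assumption is the standard scaled-prox identity prox_{c(αf)} = prox_{(cα)f}; the array values are illustrative):

```python
# Multiplying a Functional by a scalar yields a ScaledFunctional whose
# prox absorbs the scale: (2 * f).prox(v, c) should match f.prox(v, 2 * c).
import scico.numpy as snp
from scico import functional

f = functional.L1Norm()
h = 2.0 * f  # ScaledFunctional with scale 2.0

v = snp.array([-3.0, -0.5, 0.5, 3.0])
print(h.prox(v, 1.0))  # soft thresholding with threshold 2.0
print(f.prox(v, 2.0))  # same result
```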
@@ -129,7 +128,7 @@ To add a new functional,
create a class which

1. inherits from base :class:`.Functional`;
2. has ``has_eval``, ``is_smooth``, and ``has_prox`` flags;
2. has ``has_eval`` and ``has_prox`` flags;
3. has ``_eval`` and ``prox`` methods, as necessary.

For example,
@@ -139,7 +138,6 @@ For example,
class MyFunctional(scico.functional.Functional):

has_eval = True
is_smooth = False
has_prox = True

def _eval(self, x: JaxArray) -> float:
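For orientation, a complete functional along these lines might look like the following sketch (illustrative only; the ``_eval``/``prox`` pair below implements the l1 norm and is not the example from the documentation):

```python
import scico.numpy as snp
from scico import functional
from scico.typing import JaxArray


class MyFunctional(functional.Functional):

    has_eval = True
    has_prox = True

    def _eval(self, x: JaxArray) -> float:
        return snp.abs(x).sum()

    def prox(self, v: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray:
        # Soft thresholding: the proximal operator of lam * ||.||_1.
        return snp.sign(v) * snp.maximum(snp.abs(v) - lam, 0.0)
```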
4 changes: 0 additions & 4 deletions examples/scripts/denoise_tv_iso_pgm.py
@@ -113,7 +113,6 @@ class IsoProjector(functional.Functional):

has_eval = True
has_prox = True
is_smooth = False

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return 0.0
@@ -136,7 +135,6 @@ def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray:
"""
reg_weight_iso = 1.4e0
f_iso = DualTVLoss(y=y, A=A, lmbda=reg_weight_iso)
f_iso.is_smooth = True
g_iso = IsoProjector()

solver_iso = AcceleratedPGM(
@@ -168,7 +166,6 @@ class AnisoProjector(functional.Functional):

has_eval = True
has_prox = True
is_smooth = False

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return 0.0
@@ -186,7 +183,6 @@ def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray:

reg_weight_aniso = 1.2e0
f = DualTVLoss(y=y, A=A, lmbda=reg_weight_aniso)
f.is_smooth = True
g = AnisoProjector()

solver = AcceleratedPGM(
1 change: 0 additions & 1 deletion examples/scripts/sparsecode_poisson_blkarr_pgm.py
@@ -95,7 +95,6 @@ def _eval(self, x: BlockArray) -> BlockArray:
Set up the loss function and the regularization.
"""
f = loss.PoissonLoss(y=y, A=A)
f.is_smooth = True
g = functional.NonNegativeIndicator()


1 change: 0 additions & 1 deletion examples/scripts/sparsecode_poisson_pgm.py
@@ -65,7 +65,6 @@
"""
A = linop.MatrixOperator(D)
f = loss.PoissonLoss(y=y, A=A)
f.is_smooth = True
g = functional.NonNegativeIndicator()


12 changes: 0 additions & 12 deletions scico/_generic_operators.py
@@ -79,7 +79,6 @@ def __init__(
input_dtype: DType = np.float32,
output_dtype: Optional[DType] = None,
jit: bool = False,
is_smooth: bool = None,
):
r"""Operator init method.

@@ -160,9 +159,6 @@ def __init__(
self.shape = (self.output_shape, self.input_shape)
self.matrix_shape = (self.output_size, self.input_size)

#: True if this is a smooth mapping; false otherwise
self.is_smooth = is_smooth

if jit:
self.jit()

@@ -192,7 +188,6 @@ def __call__(
eval_fn=lambda z: self(x(z)),
input_dtype=self.input_dtype,
output_dtype=x.output_dtype,
is_smooth=(self.is_smooth and x.is_smooth),
)
raise ValueError(f"""Incompatible shapes {self.shape}, {x.shape} """)

@@ -216,7 +211,6 @@ def __add__(self, other):
eval_fn=lambda x: self(x) + other(x),
input_dtype=self.input_dtype,
output_dtype=result_type(self.output_dtype, other.output_dtype),
is_smooth=(self.is_smooth and other.is_smooth),
)
raise ValueError(f"shapes {self.shape} and {other.shape} do not match")
raise TypeError(f"Operation __add__ not defined between {type(self)} and {type(other)}")
@@ -230,7 +224,6 @@ def __sub__(self, other):
eval_fn=lambda x: self(x) - other(x),
input_dtype=self.input_dtype,
output_dtype=result_type(self.output_dtype, other.output_dtype),
is_smooth=(self.is_smooth and other.is_smooth),
)
raise ValueError(f"shapes {self.shape} and {other.shape} do not match")
raise TypeError(f"Operation __sub__ not defined between {type(self)} and {type(other)}")
@@ -243,7 +236,6 @@ def __mul__(self, other):
eval_fn=lambda x: other * self(x),
input_dtype=self.input_dtype,
output_dtype=result_type(self.output_dtype, other),
is_smooth=self.is_smooth,
)

def __neg__(self):
@@ -258,7 +250,6 @@ def __rmul__(self, other):
eval_fn=lambda x: other * self(x),
input_dtype=self.input_dtype,
output_dtype=result_type(self.output_dtype, other),
is_smooth=self.is_smooth,
)

@_wrap_mul_div_scalar
@@ -269,7 +260,6 @@ def __truediv__(self, other):
eval_fn=lambda x: self(x) / other,
input_dtype=self.input_dtype,
output_dtype=result_type(self.output_dtype, other),
is_smooth=self.is_smooth,
)

def jvp(self, primals, tangents):
@@ -354,7 +344,6 @@ def concat_args(args):
input_shape=input_shape,
output_shape=self.output_shape,
eval_fn=lambda x: self(concat_args(x)),
is_smooth=self.is_smooth,
)


@@ -467,7 +456,6 @@ def __init__(
input_dtype=input_dtype,
output_dtype=output_dtype,
jit=False,
is_smooth=True,
)

if not hasattr(self, "_adj"):
1 change: 0 additions & 1 deletion scico/functional/_denoiser.py
@@ -38,7 +38,6 @@ class BM3D(Functional):

has_eval = False
has_prox = True
is_smooth = False

def __init__(self, is_rgb: Optional[bool] = False):
r"""Initialize a :class:`BM3D` object.
1 change: 0 additions & 1 deletion scico/functional/_flax.py
@@ -25,7 +25,6 @@ class FlaxMap(Functional):

has_eval = False
has_prox = True
is_smooth = False

def __init__(self, model: Callable[..., nn.Module], variables: PyTree):
r"""Initialize a :class:`FlaxMap` object.
13 changes: 0 additions & 13 deletions scico/functional/_functional.py
@@ -7,7 +7,6 @@

"""Functional base class."""

import warnings
from typing import List, Optional, Union

import jax
@@ -38,19 +37,13 @@ class Functional:
#: This attribute must be overridden and set to True or False in any derived classes.
has_prox: Optional[bool] = None

#: True if this functional is differentiable, False otherwise.
#: Note that ``is_smooth = False`` does not preclude the use of the :func:`.grad` method.
#: This attribute must be overridden and set to True or False in any derived classes.
is_smooth: Optional[bool] = None

def __init__(self):
self._grad = scico.grad(self.__call__)

def __repr__(self):
return f"""{type(self)}
has_eval = {self.has_eval}
has_prox = {self.has_prox}
is_smooth = {self.is_smooth}
"""

def __mul__(self, other):
@@ -136,9 +129,6 @@ def grad(self, x: Union[JaxArray, BlockArray]):
Args:
x: Point at which to evaluate gradient.
"""
if not self.is_smooth: # could be True, False, or None
warnings.warn("This functional isn't smooth!", stacklevel=2)

return self._grad(x)


@@ -151,7 +141,6 @@ def __repr__(self):
def __init__(self, functional: Functional, scale: float):
self.functional = functional
self.scale = scale
self.is_smooth = functional.is_smooth
self.has_eval = functional.has_eval
self.has_prox = functional.has_prox
super().__init__()
@@ -209,7 +198,6 @@ def __init__(self, functional_list: List[Functional]):

self.has_eval: bool = all(fi.has_eval for fi in functional_list)
self.has_prox: bool = all(fi.has_prox for fi in functional_list)
self.is_smooth: bool = all(fi.is_smooth for fi in functional_list)

super().__init__()

@@ -256,7 +244,6 @@ class ZeroFunctional(Functional):

has_eval = True
has_prox = True
is_smooth = True

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return 0.0
2 changes: 0 additions & 2 deletions scico/functional/_indicator.py
@@ -37,7 +37,6 @@ class NonNegativeIndicator(Functional):

has_eval = True
has_prox = True
is_smooth = False

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
if snp.iscomplexobj(x):
@@ -87,7 +86,6 @@ class L2BallIndicator(Functional):

has_eval = True
has_prox = True
is_smooth = False

def __init__(self, radius: float = 1):
r"""Initialize a :class:`L2BallIndicator` object.
5 changes: 0 additions & 5 deletions scico/functional/_norm.py
@@ -31,7 +31,6 @@ class L0Norm(Functional):

has_eval = True
has_prox = True
is_smooth = False

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return count_nonzero(x)
@@ -71,7 +70,6 @@ class L1Norm(Functional):

has_eval = True
has_prox = True
is_smooth = False

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return snp.abs(x).sum()
@@ -118,7 +116,6 @@ class SquaredL2Norm(Functional):

has_eval = True
has_prox = True
is_smooth = True

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
# Directly implement the squared l2 norm to avoid nondifferentiable
@@ -152,7 +149,6 @@ class L2Norm(Functional):

has_eval = True
has_prox = True
is_smooth = False

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return norm(x)
@@ -210,7 +206,6 @@ class L21Norm(Functional):

has_eval = True
has_prox = True
is_smooth = False

def __init__(self, l2_axis: int = 0):
r"""
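The truncated comment in ``SquaredL2Norm.__call__`` above points at a standard JAX pitfall: composing ``norm`` with a square gives a nan gradient at zero, even though the squared norm itself is smooth. A quick check (assumes only jax):

```python
# Why a direct sum-of-squares form is preferable to norm(x)**2: the
# intermediate sqrt has an infinite derivative at 0, so reverse-mode AD
# produces nan there, while the elementwise form differentiates cleanly.
import jax
import jax.numpy as jnp

x0 = jnp.zeros(3)
print(jax.grad(lambda x: jnp.linalg.norm(x) ** 2)(x0))   # [nan nan nan]
print(jax.grad(lambda x: jnp.sum(jnp.abs(x) ** 2))(x0))  # [0. 0. 0.]
```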
8 changes: 0 additions & 8 deletions scico/loss.py
@@ -160,11 +160,6 @@ def __init__(
prox_kwargs = dict
self.prox_kwargs = prox_kwargs

if isinstance(A, operator.Operator):
self.is_smooth = A.is_smooth
else:
self.is_smooth = None

if isinstance(self.A, linop.Diagonal) and isinstance(self.W, linop.Diagonal):
self.has_prox = True

@@ -289,9 +284,6 @@ def __init__(
#: Constant term in Poisson log likelihood; equal to ln(y!)
self.const: float = gammaln(self.y + 1) # ln(y!)

# The Poisson Loss is only smooth in the positive quadrant.
self.is_smooth = None

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
Ax = self.A(x)
return self.scale * snp.sum(Ax - self.y * snp.log(Ax) + self.const)
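The removed remark about smoothness on the positive quadrant corresponds to the log term above: wherever an entry of Ax reaches zero, the gradient is no longer finite. A scalar sketch (assumes only jax; A is taken as the identity and y = 1 purely for illustration):

```python
# With A = I and y = 1 the objective reduces to x - log(x) (plus a constant);
# its derivative 1 - 1/x is finite for x > 0 but blows up as x -> 0.
import jax
import jax.numpy as jnp


def poisson(x):
    return x - 1.0 * jnp.log(x)


print(jax.grad(poisson)(1.0))  # 0.0
print(jax.grad(poisson)(0.0))  # -inf
```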
11 changes: 3 additions & 8 deletions scico/optimize/pgm.py
@@ -387,8 +387,7 @@ class PGM:

Minimize a function of the form :math:`f(\mb{x}) + g(\mb{x})`.

The function :math:`f` must be smooth and :math:`g` must have a
defined prox.
The function :math:`g` must have a defined prox.
Collaborator

It's worth mentioning that f should be smooth, but that the algorithm will work for f with points at which it's non-smooth as long as those points are not visited during the optimization.

Contributor

Will it also work if valid subgradients are returned for nonsmooth points? It would be valuable to say so (with a ref) if it is true.

Contributor Author

Perhaps @crstngc can comment.

Contributor Author

> It's worth mentioning that f should be smooth, but that the algorithm will work for f with points at which it's non-smooth as long as those points are not visited during the optimization.

Generally agree, however,

> as long as those points are not visited

Does "as long as" mean the same as "if and only if"?

Contributor Author

Let's say

> guaranteed to work if globally smooth

(leaving it open to the user's interpretation what happens with losses that are not globally smooth)

Contributor Author

See 02fe9bb

Contributor Author

@bwohlberg please merge if you are happy with this.


Uses helper :class:`StepSize` to provide an estimate of the Lipschitz
constant :math:`L` of :math:`f`. The step size :math:`\alpha` is the
@@ -428,9 +427,6 @@ def __init__(
this parameter.
"""

if f.is_smooth is not True:
raise Exception(f"The functional f ({type(f)}) must be smooth.")

#: Functional or Loss to minimize; must have grad method defined.
self.f: Union[Loss, Functional] = f

@@ -556,9 +552,8 @@ class AcceleratedPGM(PGM):

Minimize a function of the form :math:`f(\mb{x}) + g(\mb{x})`.

The function :math:`f` must be smooth and :math:`g` must have a
defined prox. The accelerated form of PGM is also known as FISTA
:cite:`beck-2009-fast`.
The function :math:`g` must have a defined prox. The accelerated
form of PGM is also known as FISTA :cite:`beck-2009-fast`.

For documentation on inherited attributes, see :class:`.PGM`.
"""