
remove is_smooth attribute (#89) #184

Merged: 7 commits merged into main on Jan 27, 2022

Conversation

@tbalke (Contributor) commented Jan 19, 2022

  • requires re-building the examples

@tbalke tbalke requested a review from bwohlberg January 19, 2022 23:16
codecov bot commented Jan 19, 2022

Codecov Report

Merging #184 (5473521) into main (5a2ae28) will increase coverage by 0.00%.
The diff coverage is n/a.


@@           Coverage Diff           @@
##             main     #184   +/-   ##
=======================================
  Coverage   92.22%   92.23%           
=======================================
  Files          48       48           
  Lines        3370     3347   -23     
=======================================
- Hits         3108     3087   -21     
+ Misses        262      260    -2     
Flag Coverage Δ
unittests 92.23% <ø> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
scico/_generic_operators.py 91.81% <ø> (-0.04%) ⬇️
scico/functional/_denoiser.py 90.90% <ø> (-0.14%) ⬇️
scico/functional/_flax.py 92.00% <ø> (-0.31%) ⬇️
scico/functional/_functional.py 88.52% <ø> (-1.19%) ⬇️
scico/functional/_indicator.py 100.00% <ø> (ø)
scico/functional/_norm.py 100.00% <ø> (ø)
scico/loss.py 89.10% <ø> (+0.53%) ⬆️
scico/optimize/pgm.py 95.31% <ø> (+0.46%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 5a2ae28...5473521.

@@ -39,8 +39,7 @@ An instance of :class:`.Functional`, ``f``, may provide three core operations.
- ``f.grad(x)`` returns the gradient of the functional evaluated at ``x``.
- Gradients are calculated using JAX reverse-mode automatic differentiation,
exposed through :func:`scico.grad`.
- A functional that is smooth has the attribute ``f.is_smooth == True``.
- NOTE: The gradient of a functional ``f`` can be evaluated even if ``f.is_smooth == False``.
+ NOTE: The gradient of a functional ``f`` can be evaluated even if that functional is not smooth.
Collaborator

Can we refer to part of the jax docs to expand on this? Also, consider using markup (e.g. `Note`) rather than all caps for emphasis.

Contributor

I had a look this morning and could not find any clear discussion of this in the jax docs. I can at least confirm that the statement is sometimes true, e.g., `jax.grad(lambda x: jnp.abs(x))(0.0) == 1.0`.

Contributor Author

> the statement is sometimes true, e.g., `jax.grad(lambda x: jnp.abs(x))(0.0) == 1.0`

I think this is an easier example because the gradient is defined everywhere except at 0.

Where it breaks is, e.g., `x / (1 + jnp.sqrt(x))`, which grads to `nan` at zero.

I don't know what the general rule for scico is, though.
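Both behaviors discussed in this thread can be checked directly; a minimal sketch (the specific outputs below are the ones reported in the comments above):

```python
import jax
import jax.numpy as jnp

# abs is non-smooth at 0, but jax.grad still returns a finite value there:
# JAX's differentiation rule for abs yields 1.0 at x == 0.
grad_abs = jax.grad(lambda x: jnp.abs(x))
print(grad_abs(0.0))  # 1.0

# In contrast, x / (1 + sqrt(x)) grads to nan at 0, because the chain
# rule propagates the infinite derivative of sqrt at 0 (0 * inf == nan).
f = lambda x: x / (1.0 + jnp.sqrt(x))
print(jax.grad(f)(0.0))  # nan
```

So a value being returned does not by itself mean the returned value is a useful (sub)gradient at a non-smooth point; it depends on how each primitive's differentiation rule handles the boundary case.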

Contributor Author

> Also, considering using markup

See 0fdfcf6

@@ -387,8 +387,7 @@ class PGM:

Minimize a function of the form :math:`f(\mb{x}) + g(\mb{x})`.

-     The function :math:`f` must be smooth and :math:`g` must have a
-     defined prox.
+     The function :math:`g` must have a defined prox.
Collaborator

It's worth mentioning that f should be smooth, but that the algorithm will work for f with points at which it's non-smooth as long as those points are not visited during the optimization.
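The algorithm in question is the standard proximal gradient iteration, x ← prox_{γg}(x − γ∇f(x)). A minimal NumPy sketch, using a least-squares term for f and the ℓ1 norm for g (whose prox is soft thresholding) as illustrative stand-ins for scico's Loss and Functional classes (this is not scico's implementation):

```python
import numpy as np

def soft_threshold(v, t):
    """Prox of t * ||.||_1: elementwise soft thresholding."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

# Illustrative problem: min_x 0.5 ||A x - y||^2 + lam ||x||_1
rng = np.random.default_rng(0)
A = rng.standard_normal((20, 10))
x_true = np.zeros(10)
x_true[:3] = [1.0, -2.0, 0.5]
y = A @ x_true
lam = 0.1

# Step size 1/L, where L is the Lipschitz constant of grad f.
step = 1.0 / np.linalg.norm(A.T @ A, 2)
x = np.zeros(10)
for _ in range(500):
    grad_f = A.T @ (A @ x - y)                         # gradient of the smooth term f
    x = soft_threshold(x - step * grad_f, step * lam)  # prox of the nonsmooth term g
```

Note that only f is ever differentiated; g enters solely through its prox, which is why g may be nonsmooth while f is required to be smooth (at least at the iterates actually visited).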

Contributor

Will it also work if valid subgradients are returned for nonsmooth points? It would be valuable to say so (with a ref) if it is true.

Contributor Author

Perhaps @crstngc can comment.

Contributor Author

> It's worth mentioning that f should be smooth, but that the algorithm will work for f with points at which it's non-smooth as long as those points are not visited during the optimization.

Generally agree, however,

> as long as those points are not visited

Does "as long as" mean the same as "if and only if"?

Contributor Author

Let's say

guaranteed to work if globally smooth

(leaving it open to the interpretation of the user what happens with not globally smooth losses)

Contributor Author

See 02fe9bb

Contributor Author

@bwohlberg please merge if you are happy with this.

@bwohlberg (Collaborator) left a comment

Approved, modulo comments.

@bwohlberg bwohlberg linked an issue Jan 20, 2022 that may be closed by this pull request
@bwohlberg (Collaborator)

> requires re-building the examples

In principle this is true, but we've been a bit slack about this recently (a better policy is probably to just do this prior to a new release, or after really major changes), so I think we can let this slide for this PR.

@bwohlberg bwohlberg added the improvement Improvement of existing code, including addressing of omissions or inconsistencies label Jan 21, 2022
@@ -387,7 +387,8 @@ class PGM:

Minimize a function of the form :math:`f(\mb{x}) + g(\mb{x})`.

-     The function :math:`g` must have a defined prox.
+     The function :math:`g` must have a defined prox and convergence is
+     guaranteed if :math:`f` is smooth.
Collaborator

This isn't really true since f and g are usually also assumed to be convex. It's worth noting that the ADMM docs simply specify what the required types of these arguments are, without getting into convergence requirements. For now, perhaps better just to follow the ADMM example, i.e.

Solve an optimization problem of the form

    .. math::
        \argmin_{\mb{x}} \; f(\mb{x}) + g(\mb{x}) \;,

    where :math:`f` and :math:`g` are instances of :class:`.Functional`.

and also open a new issue as a reminder that more careful specifications (such as when functionals need to have proximal operators defined) should be added to all the optimizer classes.

@tbalke (Contributor Author) commented Jan 27, 2022

see 5473521

@@ -552,7 +553,8 @@ class AcceleratedPGM(PGM):

Minimize a function of the form :math:`f(\mb{x}) + g(\mb{x})`.

-     The function :math:`g` must have a defined prox. The accelerated
+     The function :math:`g` must have a defined prox and convergence is
+     guaranteed if :math:`f` is smooth. The accelerated
Collaborator

See comment above.

@tbalke (Contributor Author) commented Jan 27, 2022

see 5473521

@tbalke tbalke merged commit d3b433f into main Jan 27, 2022
@tbalke tbalke deleted the thilo/is_smooth branch January 27, 2022 23:21
Successfully merging this pull request may close these issues.

Loss is_smooth attribute
3 participants