Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement stateless wrapper. #246

Merged
merged 7 commits into from
Dec 21, 2021
Merged

Implement stateless wrapper. #246

merged 7 commits into from
Dec 21, 2021

Conversation

n2cholas
Copy link
Contributor

Closes #104.

Not sure if this belongs in wrappers.py, please let me know if I should move it elsewhere.

@google-cla google-cla bot added the cla: yes copybara label for automatic import label Nov 26, 2021
Copy link
Member

@mkunesch mkunesch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! Thank you so much for this PR!

I'll do a detailed review tomorrow, but I have two quick high-level questions/comments:

  • I think this code could actually be part of base.py. It's not really a wrapper in that it doesn't wrap a gradient transformation; instead, I think it's a really fundamental tool for creating gradient transformation so could go into base.py. What do you think?
  • Could we replace StatelessState by base.EmptyState or is there an advantage to stateless having its own state?

Thanks a lot again and as I said, I'll take a detailed look tomorrow!

@n2cholas
Copy link
Contributor Author

  • I agree, I'll move it to base.py.
  • Agree with this as well, completely forgot about base.EmptyState!

Copy link
Collaborator

@rosshemsley rosshemsley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I like this idea, I can imagine using this for creating quick transformations :)

I think I would be inclined to simplify this a bit, and to remove the on_leaves part, (and also, I don't htink we need to check for params being None).

We could also add an explicit example to the docs of using this as a decorator:

def stateless_gradient_transformation(f) -> optax.GradientTransformation:
  def init_fn(_):
    return optax.EmptyState()

  def update_fn(updates, state, params=None):
      return f(updates, params), state

  return optax.GradientTransformation(init_fn, update_fn)


@stateless_gradient_transformation
def double_grads(updates, params):
  return jax.tree_map(lambda x: 2*x, updates)


opt = optax.chain(optax.adam(1e-4), double_grads)
state = opt.init(jax.numpy.array([]))

What do you think?

(Note, for a parametrized example, one could write:

def gradient_multiplier(factor):

    @stateless_gradient_transformation
    def transformation(updates, params):
         return jax.tree_map(lambda x: x * factor, updates)

    return transformation


opt = optax.chain(optax.adam(1e-4), gradient_multiplier(2.0))

optax/_src/base.py Outdated Show resolved Hide resolved
optax/_src/base.py Show resolved Hide resolved
optax/_src/base.py Outdated Show resolved Hide resolved
optax/_src/base.py Outdated Show resolved Hide resolved
optax/__init__.py Show resolved Hide resolved
@rosshemsley
Copy link
Collaborator

(@mkunesch pointed out to me that there's already a thread on this in the issues - sorry I hadn't read that! I'll let you both figure out what to do here :) )

@n2cholas
Copy link
Contributor Author

n2cholas commented Dec 1, 2021

Thanks @rosshemsley for the feedback! I resolved a couple of your comments and left responses to others in the code review.

Copy link
Member

@mkunesch mkunesch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! I think this looks great - thanks a lot! I added comments on the choice of names and documentation layout and some minor thoughts on the test.
Thanks again!

optax/__init__.py Show resolved Hide resolved
optax/_src/base.py Show resolved Hide resolved
optax/_src/base.py Outdated Show resolved Hide resolved
optax/_src/base_test.py Outdated Show resolved Hide resolved
optax/_src/base_test.py Show resolved Hide resolved
Copy link
Member

@mkunesch mkunesch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making all the changes, this looks great! The only comment I have is on the positioning in the documentation but otherwise LGTM!

optax/__init__.py Show resolved Hide resolved
@n2cholas
Copy link
Contributor Author

n2cholas commented Dec 9, 2021

Agreed, fixed!

@n2cholas
Copy link
Contributor Author

n2cholas commented Dec 9, 2021

Would it be possible to rerun the tests? The failures were in control_variates_test.py, which are unaffected by this PR.

Copy link
Member

@mkunesch mkunesch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I have triggered the test workflow again now that we have fixed the broken test you mentioned.

Copy link
Member

@mkunesch mkunesch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! The checks still don't pass ... could you sync the latest changes? In theory we fixed this error on Monday. I've also flagged a line which gives a wrong import order error in pylint.
Once the checks pass I think it's ready to merge! Thanks a lot again!

@@ -15,8 +15,11 @@
"""Tests for base.py."""

from absl.testing import absltest

import chex
import numpy as np
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Pylint will raise a wrong import order error here)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll fix this change, but running test.sh locally did not raise this issue for me (tried running pylint --rcfile=.pylintrc optax/_src/base_test.py separately as well).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for letting me know! This would definitely be blocked internally as we try to merge so I'll look into adding an import order check to test.sh.

@copybara-service copybara-service bot merged commit aedf82a into google-deepmind:master Dec 21, 2021
@mkunesch
Copy link
Member

mkunesch commented Dec 22, 2021

Thanks a lot again for suggesting and implementing the stateless optimizer. This was such a great idea! We could think about whether to implement a few of the existing gradient transformations in optax using stateless ... good targets could be recently added gradient transformations that have good test coverage.

Thanks again!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes copybara label for automatic import
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Stateless Transformation
3 participants