Implement stateless wrapper. #246
Conversation
Hey! Thank you so much for this PR!
I'll do a detailed review tomorrow, but I have two quick high-level questions/comments:
- I think this code could actually be part of `base.py`. It's not really a wrapper in that it doesn't wrap a gradient transformation; instead, I think it's a really fundamental tool for creating gradient transformations, so it could go into `base.py`. What do you think?
- Could we replace `StatelessState` by `base.EmptyState`, or is there an advantage to `stateless` having its own state? (A sketch of the comparison follows below.)
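A minimal sketch of the comparison, under the assumption that `StatelessState` is (like `base.EmptyState`) just an empty `NamedTuple`:

```python
from typing import NamedTuple


class EmptyState(NamedTuple):
  """Empty state, as optax already provides in base.EmptyState."""


class StatelessState(NamedTuple):
  """Hypothetical dedicated state for the stateless wrapper.

  Structurally identical to EmptyState, so unless the distinct type is
  useful (e.g. for seeing which transformation produced the state),
  reusing base.EmptyState avoids the duplication.
  """
```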
Thanks a lot again and as I said, I'll take a detailed look tomorrow!
Thanks! I like this idea; I can imagine using this for creating quick transformations :)
I think I would be inclined to simplify this a bit and to remove the `on_leaves` part (and also, I don't think we need to check for `params` being `None`).
We could also add an explicit example to the docs of using this as a decorator:
```python
import jax
import optax


def stateless_gradient_transformation(f) -> optax.GradientTransformation:
  def init_fn(_):
    return optax.EmptyState()

  def update_fn(updates, state, params=None):
    return f(updates, params), state

  return optax.GradientTransformation(init_fn, update_fn)


@stateless_gradient_transformation
def double_grads(updates, params):
  return jax.tree_map(lambda x: 2 * x, updates)


opt = optax.chain(optax.adam(1e-4), double_grads)
state = opt.init(jax.numpy.array([]))
```
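For context, here is one way the chained optimizer could then be driven (a hedged sketch continuing the snippet above; the `params`/`grads` values are illustrative):

```python
import jax.numpy as jnp

params = jnp.array([1.0, 2.0, 3.0])
grads = jnp.array([0.1, 0.1, 0.1])

state = opt.init(params)
# The chain applies adam first, then double_grads doubles adam's output.
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)
```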
What do you think?
Note: for a parametrized example, one could write:

```python
def gradient_multiplier(factor):
  @stateless_gradient_transformation
  def transformation(updates, params):
    return jax.tree_map(lambda x: x * factor, updates)

  return transformation


opt = optax.chain(optax.adam(1e-4), gradient_multiplier(2.0))
```
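A nice property of this closure-based form is that `factor` is fixed when the transformation is constructed, so the transformation itself remains stateless; anything that has to vary from step to step would instead need to live in the optimizer state.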
(@mkunesch pointed out to me that there's already a thread on this in the issues; sorry I hadn't read that! I'll let you both figure out what to do here :) )
Thanks @rosshemsley for the feedback! I resolved a couple of your comments and left responses to others in the code review.
Hi! I think this looks great - thanks a lot! I added comments on the choice of names and documentation layout and some minor thoughts on the test.
Thanks again!
Thanks for making all the changes, this looks great! The only comment I have is on the positioning in the documentation but otherwise LGTM!
Agreed, fixed!
Would it be possible to rerun the tests? The failures were in …
Thanks! I have triggered the test workflow again now that we have fixed the broken test you mentioned.
Hi! The checks still don't pass ... could you sync the latest changes? In theory we fixed this error on Monday. I've also flagged a line which gives a wrong import order error in pylint.
Once the checks pass I think it's ready to merge! Thanks a lot again!
optax/_src/base_test.py (Outdated)

```
@@ -15,8 +15,11 @@
"""Tests for base.py."""

from absl.testing import absltest

import chex
import numpy as np
```
(Pylint will raise a `wrong-import-order` error here.)
I'll fix this, but running `test.sh` locally did not raise this issue for me (I tried running `pylint --rcfile=.pylintrc optax/_src/base_test.py` separately as well).
Thanks for letting me know! This would definitely be blocked internally as we try to merge, so I'll look into adding an import-order check to `test.sh`.
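Such a check could plausibly take a form like the following (a sketch only, not the actual `test.sh` change; which pylint messages to enable is an assumption):

```sh
# Hypothetical snippet for test.sh: run pylint with only the
# import-ordering checks enabled, so violations fail the build.
pylint --rcfile=.pylintrc \
  --disable=all \
  --enable=wrong-import-order,wrong-import-position,ungrouped-imports \
  optax
```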
Thanks a lot again for suggesting and implementing the stateless optimizer. This was such a great idea! We could think about whether to implement a few of the existing gradient transformations in optax using `stateless`. Thanks again!
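As one illustration of that idea (a sketch, assuming the merged helper is exposed as `optax.stateless` and accepts an `(updates, params) -> updates` function), `scale` could be expressed as:

```python
import jax
import optax


def scale(step_size: float) -> optax.GradientTransformation:
  # Stateless reimplementation sketch of optax.scale: multiply every
  # leaf of the updates pytree by step_size; no state is carried.
  return optax.stateless(
      lambda updates, params: jax.tree_map(lambda u: step_size * u, updates))
```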
Closes #104.
Not sure if this belongs in `wrappers.py`; please let me know if I should move it elsewhere.