Implement complex norm in optimizers #279
Conversation
I rewrote
this should also use the complex norm in gradient clipping. See #287
Now I've also implemented complex gradient clipping in this PR (and closed #161), and I hope it doesn't make this PR look too large.
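For context, here is a minimal sketch of the idea behind complex-aware global-norm clipping, using the squared modulus `(g.conj() * g).real` per leaf. This is an illustration only, not optax's actual implementation:

```python
import jax
import jax.numpy as jnp

def clip_by_global_norm(updates, max_norm):
    # Sum of squared moduli over all leaves; real-valued even for complex grads.
    sq_norm = sum(jnp.sum((g.conj() * g).real)
                  for g in jax.tree_util.tree_leaves(updates))
    g_norm = jnp.sqrt(sq_norm)
    scale = jnp.minimum(1.0, max_norm / (g_norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, updates)
```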
@mkunesch If you could spare the time, could we get a review on this PR?
Nice work! Thanks a lot for putting this together.
optax/_src/alias_test.py
Outdated
    dtype=(jnp.float32, jnp.complex64),
)
def test_optimization(self, opt_name, opt, target, dtype):
  if (opt_name in ['fromage', 'noisy_sgd', 'sm3']
Please use a tuple here (as per the Google Python style guide).
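For reference, the suggested form would look roughly like this; the `if` body is elided in the quoted diff, so the placeholder below is mine:

```python
# Tuple literal instead of a list for a fixed membership test:
if opt_name in ('fromage', 'noisy_sgd', 'sm3'):
    ...  # skip / special-case these optimizers (body not shown in the diff)
```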
Can we add this check to pylint in `test.sh`?
Thanks for the changes! This PR looks good to me.
Done that.
@PhilipVinc we'll release a new version early next week if you don't mind (releasing something Friday evening might turn into a hectic "weekend" ;) )
Yeah, of course! EDIT: Maybe it's not really important, but it might also be worth mentioning in the README that optimisers are not expected to work correctly when complex numbers are involved on versions prior to (I guess)
PiperOrigin-RevId: 425367157
The PR has been merged! @wdphy16 would it be OK if we keep it as it is, or would you prefer us to revert the changes and re-merge this PR? Thank you once again for this great contribution! 🚀
It's ok, thank you!
Sorry that I forgot to submit the reviews before; now I just want to keep them here for future reference.
optax/_src/factorized.py
Outdated
@@ -145,7 +146,6 @@ def update_fn(grads, state, params):

  def _update(grad, v_row, v_col, v, param, step):
    shape = param.shape
-   grad = grad.astype(jnp.float32)
Why was there a type cast? Is it safe to remove it, so that the gradients can stay complex?
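To illustrate the concern (my own example, not from the diff): casting a complex gradient to `float32` silently discards the imaginary part:

```python
import jax.numpy as jnp

g = jnp.array([1.0 + 2.0j], dtype=jnp.complex64)
g.astype(jnp.float32)  # -> [1.]; the imaginary part 2j is dropped
```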
As proposed in #196, I reviewed all transformations in `transform.py` and optimizers in `alias.py`, and replaced `x**2` with `(x.conj() * x).real` to support complex numbers. I think we need to double-check the correctness of all changes. What tests do you think we should add?
I'm not sure how to generalize the transform `sm3` to complex numbers. The optimizers `fromage` and `noisy_sgd` should be working, but in the tests I can't find a learning rate that makes them work well.
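One possible test along these lines, as a hedged sketch: minimize the real-valued loss `|z - target|^2` over a complex parameter, assuming the optimizer handles complex leaves after this PR (`optax.adam` is used here only as an example):

```python
import jax
import jax.numpy as jnp
import optax

target = jnp.array(1.0 - 2.0j)
# Real-valued loss of a complex parameter; jax.grad returns a complex
# gradient suitable for steepest descent on such functions.
loss = lambda z: ((z - target).conj() * (z - target)).real

opt = optax.adam(learning_rate=0.1)
z = jnp.array(0.0 + 0.0j)
state = opt.init(z)
for _ in range(200):
    updates, state = opt.update(jax.grad(loss)(z), state)
    z = optax.apply_updates(z, updates)
# z should now be close to target.
```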