
Implement complex norm in optimizers #279
Merged · 3 commits · Jan 31, 2022

Conversation

wdphy16 (Contributor) commented Jan 14, 2022

As proposed in #196, I reviewed all transformations in transform.py and optimizers in alias.py, and replaced x**2 with (x.conj() * x).real to support complex numbers.
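
A minimal sketch of the substitution (using the name abs_sq that this thread eventually settles on):

```python
import jax.numpy as jnp

def abs_sq(x):
  # Squared magnitude: equals x**2 for real x and |x|**2 for complex x,
  # where a plain x**2 would be wrong.
  return (x.conj() * x).real

print(abs_sq(jnp.array([3.0, 4.0])))    # [ 9. 16.]
print(abs_sq(jnp.array([3.0 + 4.0j])))  # [25.]
```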

I think we need to double-check the correctness of all the changes. What tests do you think we should add?

I'm not sure how to generalize the sm3 transform to complex numbers. The fromage and noisy_sgd optimizers should work, but in the tests I couldn't find a learning rate that makes them perform well.

wdphy16 (Contributor, Author) commented Jan 14, 2022

I rewrote AliasTest.test_optimization in alias_test.py to test for complex parameters on two target functions. By the way, I added the missing tests for all optimizers in alias.py.
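
A hypothetical sketch of the shape of such a parameterized test (the optimizer subset, target function, and tolerance are illustrative, not the actual alias_test.py code, and it assumes an optax version that includes this PR):

```python
import jax
import jax.numpy as jnp
import optax
from absl.testing import absltest, parameterized


class AliasTest(parameterized.TestCase):

  @parameterized.product(
      opt_name=('sgd', 'adam'),            # illustrative subset of alias.py
      dtype=(jnp.float32, jnp.complex64),  # real and complex parameters
  )
  def test_optimization(self, opt_name, dtype):
    # Quadratic target with minimum at 1; (d.conj() * d).real keeps the
    # loss real-valued even when the parameters are complex.
    def loss(p):
      d = p - jnp.ones_like(p)
      return jnp.sum((d.conj() * d).real)

    opt = {'sgd': optax.sgd(0.1), 'adam': optax.adam(0.1)}[opt_name]
    params = jnp.zeros((4,), dtype=dtype)
    state = opt.init(params)
    for _ in range(1000):
      grads = jax.grad(loss)(params)
      updates, state = opt.update(grads, state, params)
      params = optax.apply_updates(params, updates)
    self.assertLess(float(loss(params)), 1e-2)


if __name__ == '__main__':
  absltest.main()
```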

PhilipVinc commented

This should also use the complex norm in gradient clipping; see #287.

wdphy16 (Contributor, Author) commented Jan 24, 2022

Now I've also implemented complex gradient clipping in this PR (and closed #161), and I hope it doesn't make this PR look too large.
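
A minimal sketch of the idea behind the clipping change (an assumption about its shape, not the exact optax code): the global norm that clip_by_global_norm divides by is a sum of squared magnitudes, so computing it with abs_sq keeps it real and correct for complex leaves:

```python
import jax
import jax.numpy as jnp

def abs_sq(x):
  # Real, non-negative squared magnitude for both real and complex x.
  return (x.conj() * x).real

def global_norm(updates):
  # L2 norm taken over every leaf of the gradient pytree.
  return jnp.sqrt(sum(
      jnp.sum(abs_sq(g)) for g in jax.tree_util.tree_leaves(updates)))

grads = {'w': jnp.array([3.0 + 4.0j]), 'b': jnp.array([0.0])}
print(global_norm(grads))  # 5.0
```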

PhilipVinc commented

I'd suggest renaming abs_sqr to abs2, which is the name used in Scipy, C++ Eigen, Julia ...

PhilipVinc commented

@mkunesch If you could spare the time, could we get a review on this PR?

wdphy16 (Contributor, Author) commented Jan 24, 2022

> I'd suggest renaming abs_sqr to abs2, which is the name used in Scipy, C++ Eigen, Julia ...

Yes, I agree.

hbq1 (Collaborator) left a comment

Nice work! Thanks a lot for putting this together.

(Resolved review threads on optax/_src/transform.py and optax/_src/numerics.py.)
Review comment on optax/_src/alias_test.py:
      dtype=(jnp.float32, jnp.complex64),
    )
    def test_optimization(self, opt_name, opt, target, dtype):
      if (opt_name in ['fromage', 'noisy_sgd', 'sm3']
hbq1 (Collaborator) commented:

Please use a tuple here (as per the Google Python style guide).
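
The suggested change, sketched (a hypothetical helper, not the actual test code):

```python
def _is_skipped(opt_name):
  # A tuple, not a list, for a fixed membership test,
  # as the Google Python style guide recommends.
  return opt_name in ('fromage', 'noisy_sgd', 'sm3')
```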

wdphy16 (Contributor, Author) replied:

Can we add this check to pylint in test.sh?

(Resolved review threads on optax/_src/alias_test.py.)
zerothi commented Jan 26, 2022

> I'd suggest renaming abs_sqr to abs2, which is the name used in Scipy, C++ Eigen, Julia ...

FYI, in SciPy it was decided to name it abs_sq, not abs2.

hbq1 (Collaborator) left a comment

Thanks for the changes! This PR looks good to me.

hbq1 (Collaborator) commented Jan 28, 2022

> I'd suggest renaming abs_sqr to abs2, which is the name used in Scipy, C++ Eigen, Julia ...

> In scipy it has been decided to name it abs_sq, not abs2 FYI.

@wdphy16 could you update the name to abs_sq for (forward?) consistency with scipy, please?

wdphy16 (Contributor, Author) commented Jan 28, 2022

> @wdphy16 could you update the name to abs_sq for (forward?) consistency with scipy, please?

Done.

PhilipVinc commented

@hbq1 After you merge this PR and #288 could you also tag a new release?

hbq1 (Collaborator) commented Jan 28, 2022

@PhilipVinc we'll release a new version early next week, if you don't mind (releasing something on a Friday evening might turn into a hectic "weekend" ;) )

PhilipVinc commented Jan 28, 2022

Yeah, of course!
I simply meant: after you merge this, please tag a new version sooner rather than later, so that we can benefit from it.

EDIT: Maybe it's not really important, but it might also be worth mentioning in the README that optimizers are not expected to work correctly with complex numbers on versions prior to (I guess) 0.1.1.

copybara-service bot pushed a commit that referenced this pull request on Jan 31, 2022.
hbq1 merged commit ee46275 into google-deepmind:master on Jan 31, 2022.
hbq1 (Collaborator) commented Jan 31, 2022

The PR has been merged!
I amended it a bit to make it pass all internal tests, which resulted in merge conflicts that Copybara couldn't resolve automatically, so I had to do that manually. As a result, the "Files Changed" tab shows an empty diff; nevertheless, all blames seem to be correct and up to date. The full diff is here: 8ff9ddd

@wdphy16 would it be OK if we keep it as it is, or would you prefer us to revert the changes and re-merge this PR?

Thank you once again for this great contribution! 🚀

wdphy16 (Contributor, Author) commented Feb 1, 2022

It's ok, thank you!

wdphy16 deleted the complex_optimizer branch on February 1, 2022 at 09:14.
wdphy16 (Contributor, Author) left a comment

Sorry, I forgot to submit these review comments earlier; I'm keeping them here for future reference.

(Resolved review threads on optax/_src/transform.py.)
In optax/_src/transform.py:

    @@ -145,7 +146,6 @@ def update_fn(grads, state, params):
      def _update(grad, v_row, v_col, v, param, step):
        shape = param.shape
    -   grad = grad.astype(jnp.float32)
wdphy16 (Contributor, Author) commented:

Why was there a type cast here? Is it safe to remove it, so that the gradients can stay complex?
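
For context (an illustration, not part of the diff): casting complex to real discards the imaginary part, so leaving the cast in would silently corrupt complex gradients:

```python
import jax.numpy as jnp

g = jnp.array([1.0 + 2.0j], dtype=jnp.complex64)
# The imaginary part is dropped (NumPy-style ComplexWarning).
print(g.astype(jnp.float32))  # [1.]
```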

Review comment on optax/_src/alias_test.py, on the same test_optimization snippet and with the same question as above ("Can we add this check to pylint in test.sh?").

(Resolved review threads on optax/_src/numerics.py, optax/_src/transform.py, and optax/_src/alias_test.py.)
Labels: none · Projects: none · 4 participants