Jax support #7
I would like to try; is that OK?
Yes! Maybe you or @gsp-27 can take a lock on the initial version and send in a PR, and we can all coordinate afterwards? What do you all think?
Sounds great.
Hi @joaogui1, sure, give it a shot. I can also help. Thanks,
Any update on this?
I didn't manage to do it, sorry.
Lack of time, or was the technical difficulty too high?
High technical difficulty: I don't understand enough about CVXPY to make sense of the errors.
I'm taking a crack at this.
@sbarratt, you should be able to use https://jax.readthedocs.io/en/latest/jax.html#jax.custom_gradient to define the gradient. |
That doesn't work: it still tries to trace the function. We need https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html to define a custom, non-traceable Python function.
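The comment above contrasts letting JAX trace a function with treating it as an opaque operation that has a hand-written derivative. A minimal sketch of the second idea follows, assuming a recent JAX release: it uses `jax.pure_callback` and `jax.custom_vjp`, which postdate this thread, rather than the new-primitive route the linked notebook describes. `jax.pure_callback` keeps JAX from tracing into plain NumPy code, and `jax.custom_vjp` supplies the gradient JAX cannot derive on its own. A NumPy linear solve stands in for the conic solver, and its gradient comes from implicitly differentiating A x = b. The names `solve`, `_np_solve` and `_host_solve` are hypothetical; this is an illustration, not how cvxpylayers implements its JAX layer.

```python
import numpy as np
import jax
import jax.numpy as jnp


def _np_solve(A, b):
    # Plain NumPy (a stand-in for an external solver): JAX never traces this.
    b = np.asarray(b)
    return np.linalg.solve(np.asarray(A), b).astype(b.dtype)


def _host_solve(A, b):
    # Run the NumPy function as an opaque host callback; JAX only needs to
    # know the shape and dtype of the result.
    out_spec = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(_np_solve, out_spec, A, b)


@jax.custom_vjp
def solve(A, b):
    """Return x with A @ x = b, computed outside of JAX tracing."""
    return _host_solve(A, b)


def _solve_fwd(A, b):
    x = _host_solve(A, b)
    return x, (A, x)  # residuals saved for the backward pass


def _solve_bwd(residuals, g):
    A, x = residuals
    # Implicit differentiation of A x = b for upstream cotangent g:
    #   grad_b = A^{-T} g,  grad_A = -grad_b x^T
    grad_b = _host_solve(A.T, g)
    grad_A = -jnp.outer(grad_b, x)
    return grad_A, grad_b


solve.defvjp(_solve_fwd, _solve_bwd)


if __name__ == "__main__":
    A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
    b = jnp.array([9.0, 8.0])
    print(solve(A, b))
    print(jax.grad(lambda b: jnp.sum(solve(A, b)))(b))
    print(jax.jit(solve)(A, b))
```

Because the forward pass only ever calls the solver through the callback, the same wrapper also works under `jax.jit`. The backward pass reuses the solver on the transposed system instead of differentiating through the solver's iterations.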
Are there any updates on this?
I've recently been thinking about adding a JAX/Flax version but haven't started.
@bodono do you (or anyone else) have any recommendations or preferences on what a JAX implementation would look like? Maybe it would make the most sense to have a pure JAX version that is easy to wrap with at least Flax and Haiku? That way we could (hopefully) easily connect to any future JAX libraries or frameworks.
I don't feel strongly about it and I'm far from a JAX expert, but I agree that a pure JAX version would probably be best.
Hi all, have you considered using eagerpy to factor out what the three tensor implementations (TensorFlow, PyTorch, and JAX) have in common?
That looks useful! Here we'd need to define new operations in each of the frameworks with custom derivatives, and I don't immediately see how to do that from the eagerpy docs.
Also, this issue is resolved by #100.
Yes, it's true that eagerpy is designed more for the other direction: writing code on top of a tensor framework and letting that framework compute its derivatives. So I'm not sure it could help here. Still, having a lightweight unified tensor API might help factor out some shared code. I'll have a look at the JAX integration in #100!
This would broadly involve adding a jax directory to our main module that follows the same structure as the PyTorch/TensorFlow ones, and exposing that layer appropriately: https://github.com/cvxgrp/cvxpylayers/tree/master/cvxpylayers
cc @gsp-27, who has expressed interest in helping with this. Please let us know if you make any progress here!