
Jax support #7

Closed
bamos opened this issue Oct 28, 2019 · 19 comments
Labels
help wanted Extra attention is needed

Comments

@bamos
Collaborator

bamos commented Oct 28, 2019

This would broadly involve:

  1. Adding a jax directory to our main module, following the same structure as the PyTorch/TensorFlow ones, and appropriately wiring in that layer: https://github.com/cvxgrp/cvxpylayers/tree/master/cvxpylayers
  2. Filling in the details of the module and tests
  3. Updating the README/image to include Jax
  4. Adding some Jax examples: https://github.com/cvxgrp/cvxpylayers/tree/master/examples

cc @gsp-27, who has expressed interest in helping with this -- please let us know if you make any progress here!

@bamos bamos added the help wanted Extra attention is needed label Oct 30, 2019
@joaogui1

I would like to try, is that ok?

@bamos
Collaborator Author

bamos commented Oct 31, 2019

> I would like to try, is that ok?

Yes! Maybe you or @gsp-27 can take a lock on the initial version, send in a PR, and we can all coordinate afterwards? What do you all think?

@joaogui1

> I would like to try, is that ok?
>
> Yes! Maybe you or @gsp-27 can take a lock on the initial version, send in a PR, and we can all coordinate afterwards? What do you all think?

Sounds great

@gsp-27

gsp-27 commented Nov 1, 2019

Hi @joaogui1, sure give it a shot. I can also help.

Thanks,
Gaurav Pathak

@rodrigob

rodrigob commented Jan 8, 2020

any update on this ?

@joaogui1

joaogui1 commented Jan 8, 2020

I didn't manage to do it, sorry

@rodrigob

rodrigob commented Jan 9, 2020

Lack of time, or too high a technical difficulty?

@joaogui1

joaogui1 commented Jan 9, 2020

High technical difficulty; I don't understand enough about CVXPY to make sense of the errors.

@sbarratt
Collaborator

I'm taking a crack at this

@akshayka
Member

@sbarratt, you should be able to use

https://jax.readthedocs.io/en/latest/jax.html#jax.custom_gradient

to define the gradient.
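For reference, here is a minimal sketch of attaching a custom derivative to a function in JAX. It uses `jax.custom_vjp`, the same custom-differentiation machinery that underlies the linked `jax.custom_gradient`; the `solve` function below is a stand-in placeholder, not cvxpylayers code:

```python
import jax
import jax.numpy as jnp

# Stand-in for a layer whose forward pass calls an external solver:
# JAX can't trace through it, so we supply the VJP by hand.
@jax.custom_vjp
def solve(x):
    return jnp.sin(x)  # pretend this value came from a solver

def solve_fwd(x):
    # Run the forward pass and save residuals for the backward pass.
    return solve(x), x

def solve_bwd(residual, cotangent):
    x = residual
    # Return a tuple with one entry per primal argument.
    return (cotangent * jnp.cos(x),)

solve.defvjp(solve_fwd, solve_bwd)

print(jax.grad(solve)(0.0))  # gradient of sin at 0 is 1
```

For cvxpylayers, the backward pass would apply the derivative of the cone program's solution map instead of the hand-written `cos` above.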

@sbarratt
Collaborator

sbarratt commented Jan 16, 2020 via email

@bodono
Member

bodono commented Feb 17, 2021

Are there any updates on this?

@bamos
Collaborator Author

bamos commented Feb 17, 2021 via email

@bamos
Collaborator Author

bamos commented Feb 20, 2021

@bodono, do you (or anyone else) have any recommendations/preferences on what a JAX implementation would look like? Maybe it would make the most sense to have a pure JAX version that is easily wrappable with at least flax/haiku? That way we could (hopefully) easily connect to any future JAX libraries/frameworks.

@bodono
Member

bodono commented Feb 22, 2021

I don't feel strongly about it, and I'm very far from a JAX expert, but I agree that having a pure JAX version would probably be best.

@eserie

eserie commented Apr 15, 2021

Hi all, have you considered using eagerpy to factor out the three tensor implementations (TensorFlow, PyTorch, and JAX)?

@bamos
Collaborator Author

bamos commented Apr 15, 2021

> Hi all, have you considered using eagerpy to factor out the three tensor implementations (TensorFlow, PyTorch, and JAX)?

That looks useful! Here we'd need to define new operations in each of the frameworks with custom derivatives, and I don't immediately see how to do that from the eagerpy docs.

@bamos
Collaborator Author

bamos commented Apr 15, 2021

Also, this issue is resolved by #100.

@bamos bamos closed this as completed Apr 15, 2021
@eserie

eserie commented Apr 16, 2021

> That looks useful! Here we'd need to define new operations in each of the frameworks with custom derivatives, and I don't immediately see how to do that from the eagerpy docs.

Yes, sure. It's true that eagerpy is designed more for the other direction: you write code against a unified tensor API and compute derivatives of it with the underlying framework. So I'm not sure it could help here. However, having a lightweight unified tensor API might still help factor things out a bit.

I'll have a look at the JAX integration in #100!
