Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] JAX implementation #100

Merged
merged 1 commit into from
Apr 15, 2021
Merged

[WIP] JAX implementation #100

merged 1 commit into from
Apr 15, 2021

Conversation

bamos
Copy link
Collaborator

@bamos bamos commented Apr 5, 2021

Hey! Hope you all have been doing well. I've been using JAX and was finally able to connect this in for #7. This initial version is pretty similar to the PyTorch one and was easy with some of the recent JAX developments on enabling custom JAX primitives with a non-traced vjp. The interface follows our PyTorch/TF layers has an outer function that takes the CVXPY objects and then returns a callable into a primitive for the problem.

Let me know if you have any thoughts on this initial version! There are a few things left I'll finish up over the next few days:

  • Add DGP
  • Add the rest of the PyTorch tests and add JAX to the CI
  • Update the README (image and JAX example)
  • Document current limitations somewhere (e.g. can't use this with the JIT/vmap, and can't do higher-order derivatives)
  • Add a sample notebook

Some useful JAX references:

\cc @bodono

@bamos bamos requested review from akshayka and sbarratt April 5, 2021 02:22
@akshayka
Copy link
Member

akshayka commented Apr 6, 2021

This looks great!! Thanks so much for taking this on. I'd love to have this merged in, once it's finished.

@sbarratt
Copy link
Collaborator

sbarratt commented Apr 6, 2021

This looks great!! Thanks so much for taking this on. I'd love to have this merged in, once it's finished.

+1

@shoyer
Copy link

shoyer commented Apr 6, 2021

This is great!

A few ideas for extending this:

@bamos bamos force-pushed the jax branch 3 times, most recently from 1021eab to 17c0470 Compare April 8, 2021 02:55
@bamos
Copy link
Collaborator Author

bamos commented Apr 8, 2021

@akshayka @sbarratt, can you do a review through this now? The main layer and all tests are implemented and passing (with GP support), and I'll add the last piece with a sample/tutorial notebook tomorrow.

Copy link
Member

@akshayka akshayka left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! It's awesome that we'll finally have a JAX layer.

With this change we now have a lot of duplicated code, across the three layers. It would be great if this code could be unified and re-used by the three layers. You don't need to do that for this PR, though you're more than welcome to do so if you have the inclination.

@@ -5,7 +5,7 @@
# cvxpylayers
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At some point we should update the CVXPY Layers logo to include JAX :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done :)

Comment on lines +36 to +80
if gp:
if not problem.is_dgp(dpp=True):
raise ValueError('Problem must be DPP.')
else:
if not problem.is_dcp(dpp=True):
raise ValueError('Problem must be DPP.')

if not set(problem.parameters()) == set(parameters):
raise ValueError("The layer's parameters must exactly match "
"problem.parameters")
if not set(variables).issubset(set(problem.variables())):
raise ValueError("Argument variables must be a subset of "
"problem.variables")
if not isinstance(parameters, list) and \
not isinstance(parameters, tuple):
raise ValueError("The layer's parameters must be provided as "
"a list or tuple")
if not isinstance(variables, list) and \
not isinstance(variables, tuple):
raise ValueError("The layer's variables must be provided as "
"a list or tuple")

var_dict = {v.id for v in variables}

# Construct compiler
param_order = parameters
if gp:
for param in parameters:
if param.value is None:
raise ValueError("An initial value for each parameter is "
"required when gp=True.")
data, solving_chain, _ = problem.get_problem_data(
solver=cp.SCS, gp=True)
compiler = data[cp.settings.PARAM_PROB]
dgp2dcp = solving_chain.get(cp.reductions.Dgp2Dcp)
param_ids = [p.id for p in compiler.parameters]
old_params_to_new_params = (
dgp2dcp.canon_methods._parameters
)
else:
data, _, _ = problem.get_problem_data(solver=cp.SCS)
compiler = data[cp.settings.PARAM_PROB]
param_ids = [p.id for p in param_order]
dgp2dcp = None
cone_dims = dims_to_solver_dict(data["dims"])
Copy link
Member

@akshayka akshayka Apr 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should eventually abstract this into a common function, used by all three layers. Maintaining three separate versions of the same code won't be fun.

No pressure to do this now (though you're more than welcome to do so, if you want to).

Copy link
Collaborator

@sbarratt sbarratt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me!

@bamos
Copy link
Collaborator Author

bamos commented Apr 12, 2021

I just pushed a notebook with an intro to the JAX version, squashed the commits, and added followup development issues. I may put in a few last tweaks and will merge in soon once CI passes! Let me know if you all have any other comments or anything else you want to get in before this is merged.

@froystig
Copy link

To add to @shoyer's comment:

Another approach is to use XLA's custom call via jax. This tutorial covers the topic well.

@akshayka
Copy link
Member

I just pushed a notebook with an intro to the JAX version, squashed the commits, and added followup development issues. I may put in a few last tweaks and will merge in soon once CI passes! Let me know if you all have any other comments or anything else you want to get in before this is merged.

Awesome! I'll take a look tonight or tomorrow morning. Excited for this to be merged in.

@bamos bamos merged commit af3c8f0 into master Apr 15, 2021
@bamos bamos deleted the jax branch April 15, 2021 01:56
@bamos bamos mentioned this pull request Apr 15, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants