-
Notifications
You must be signed in to change notification settings - Fork 154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] JAX implementation #100
Conversation
This looks great!! Thanks so much for taking this on. I'd love to have this merged in, once it's finished. |
+1 |
This is great! A few ideas for extending this:
|
1021eab
to
17c0470
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! It's awesome that we'll finally have a JAX layer.
With this change we now have a lot of duplicated code, across the three layers. It would be great if this code could be unified and re-used by the three layers. You don't need to do that for this PR, though you're more than welcome to do so if you have the inclination.
@@ -5,7 +5,7 @@ | |||
# cvxpylayers |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At some point we should update the CVXPY Layers logo to include JAX :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done :)
if gp: | ||
if not problem.is_dgp(dpp=True): | ||
raise ValueError('Problem must be DPP.') | ||
else: | ||
if not problem.is_dcp(dpp=True): | ||
raise ValueError('Problem must be DPP.') | ||
|
||
if not set(problem.parameters()) == set(parameters): | ||
raise ValueError("The layer's parameters must exactly match " | ||
"problem.parameters") | ||
if not set(variables).issubset(set(problem.variables())): | ||
raise ValueError("Argument variables must be a subset of " | ||
"problem.variables") | ||
if not isinstance(parameters, list) and \ | ||
not isinstance(parameters, tuple): | ||
raise ValueError("The layer's parameters must be provided as " | ||
"a list or tuple") | ||
if not isinstance(variables, list) and \ | ||
not isinstance(variables, tuple): | ||
raise ValueError("The layer's variables must be provided as " | ||
"a list or tuple") | ||
|
||
var_dict = {v.id for v in variables} | ||
|
||
# Construct compiler | ||
param_order = parameters | ||
if gp: | ||
for param in parameters: | ||
if param.value is None: | ||
raise ValueError("An initial value for each parameter is " | ||
"required when gp=True.") | ||
data, solving_chain, _ = problem.get_problem_data( | ||
solver=cp.SCS, gp=True) | ||
compiler = data[cp.settings.PARAM_PROB] | ||
dgp2dcp = solving_chain.get(cp.reductions.Dgp2Dcp) | ||
param_ids = [p.id for p in compiler.parameters] | ||
old_params_to_new_params = ( | ||
dgp2dcp.canon_methods._parameters | ||
) | ||
else: | ||
data, _, _ = problem.get_problem_data(solver=cp.SCS) | ||
compiler = data[cp.settings.PARAM_PROB] | ||
param_ids = [p.id for p in param_order] | ||
dgp2dcp = None | ||
cone_dims = dims_to_solver_dict(data["dims"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should eventually abstract this into a common function, used by all three layers. Maintaining three separate versions of the same code won't be fun.
No pressure to do this now (though you're more than welcome to do so, if you want to).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me!
I just pushed a notebook with an intro to the JAX version, squashed the commits, and added followup development issues. I may put in a few last tweaks and will merge in soon once CI passes! Let me know if you all have any other comments or anything else you want to get in before this is merged. |
To add to @shoyer's comment:
Another approach is to use XLA's custom call via jax. This tutorial covers the topic well. |
Awesome! I'll take a look tonight or tomorrow morning. Excited for this to be merged in. |
Hey! Hope you all have been doing well. I've been using JAX and was finally able to connect this in for #7. This initial version is pretty similar to the PyTorch one and was easy with some of the recent JAX developments on enabling custom JAX primitives with a non-traced vjp. The interface follows our PyTorch/TF layers has an outer function that takes the CVXPY objects and then returns a callable into a primitive for the problem.
Let me know if you have any thoughts on this initial version! There are a few things left I'll finish up over the next few days:
Some useful JAX references:
Primitive
with only impl and vjp google/jax#3415. @mattjj and @shoyer, your responses in these threads were extremely helpful! Let me know if there's anything else I should watch out for in this PR\cc @bodono