Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] JAX implementation #100

Merged
merged 1 commit into from
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ build_script:
- pip install .

test_script:
- pytest
- pytest cvxpylayers/torch cvxpylayers/tensorflow
9 changes: 1 addition & 8 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
matrix:
include:
- os: linux
dist: xenial
language: python
python: "3.5"
- os: linux
dist: xenial
language: python
Expand All @@ -12,10 +8,6 @@ matrix:
dist: xenial
language: python
python: "3.7"
- os: linux
dist: bionic
language: python
python: "3.5"
- os: linux
dist: bionic
language: python
Expand All @@ -31,6 +23,7 @@ before_install:

install:
- pip install --upgrade pip
- pip install jax==0.2.12 jaxlib==0.1.64
- pip install tensorflow pytest flake8 jupyter matplotlib sklearn tqdm
- pip install torch==1.3.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
- pip install .
Expand Down
63 changes: 46 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# cvxpylayers
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At some point we should update the CVXPY Layers logo to include JAX :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done :)


cvxpylayers is a Python library for constructing differentiable convex
optimization layers in PyTorch and TensorFlow using CVXPY.
optimization layers in PyTorch, JAX, and TensorFlow using CVXPY.
A convex optimization layer solves a parametrized convex optimization problem
in the forward pass to produce a solution.
It computes the derivative of the solution with respect to
Expand Down Expand Up @@ -36,28 +36,32 @@ cvxpylayers.
pip install cvxpylayers
```

Our package includes convex optimization layers for PyTorch and TensorFlow 2.0;
Our package includes convex optimization layers for
PyTorch, JAX, and TensorFlow 2.0;
the layers are functionally equivalent. You will need to install
[PyTorch](https://pytorch.org) or [TensorFlow](https://www.tensorflow.org)
[PyTorch](https://pytorch.org),
[JAX](https://github.com/google/jax), or
[TensorFlow](https://www.tensorflow.org)
separately, which can be done by following the instructions on their websites.

cvxpylayers has the following dependencies:
* Python 3
* [NumPy](https://pypi.org/project/numpy/)
* [CVXPY](https://github.com/cvxgrp/cvxpy) >= 1.1.a4
* [TensorFlow](https://tensorflow.org) >= 2.0 or [PyTorch](https://pytorch.org) >= 1.0
* [PyTorch](https://pytorch.org) >= 1.0, [JAX](https://github.com/google/jax) >= 0.2.12, or [TensorFlow](https://tensorflow.org) >= 2.0
* [diffcp](https://github.com/cvxgrp/diffcp) >= 1.0.13

## Usage
Below are usage examples of our PyTorch and TensorFlow layers. Note that
the parametrized convex optimization problems must be constructed in CVXPY,
using [DPP](https://www.cvxpy.org/tutorial/advanced/index.html#disciplined-parametrized-programming).
Below are usage examples of our PyTorch, JAX, and TensorFlow layers.
Note that the parametrized convex optimization problems must be constructed
in CVXPY, using
[DPP](https://www.cvxpy.org/tutorial/advanced/index.html#disciplined-parametrized-programming).

### PyTorch

```python
import cvxpy as cp
import torch
import torch
from cvxpylayers.torch import CvxpyLayer

n, m = 2, 3
Expand All @@ -82,6 +86,36 @@ solution.sum().backward()

Note: `CvxpyLayer` cannot be traced with `torch.jit`.

### JAX
```python
import cvxpy as cp
import jax
from cvxpylayers.jax import CvxpyLayer

n, m = 2, 3
x = cp.Variable(n)
A = cp.Parameter((m, n))
b = cp.Parameter(m)
constraints = [x >= 0]
objective = cp.Minimize(0.5 * cp.pnorm(A @ x - b, p=1))
problem = cp.Problem(objective, constraints)
assert problem.is_dpp()

cvxpylayer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
key = jax.random.PRNGKey(0)
key, k1, k2 = jax.random.split(key, 3)
A_jax = jax.random.normal(k1, shape=(m, n))
b_jax = jax.random.normal(k2, shape=(m,))

solution, = cvxpylayer(A_jax, b_jax)

# compute the gradient of the summed solution with respect to A, b
dcvxpylayer = jax.grad(lambda A, b: sum(cvxpylayer(A, b)[0]), argnums=[0, 1])
gradA, gradb = dcvxpylayer(A_jax, b_jax)
```

Note: `CvxpyLayer` cannot be traced with the JAX `jit` or `vmap` operations.

### TensorFlow 2
```python
import cvxpy as cp
Expand Down Expand Up @@ -118,11 +152,11 @@ Starting with version 0.1.3, cvxpylayers can also differentiate through log-log
import cvxpy as cp
import torch
from cvxpylayers.torch import CvxpyLayer

x = cp.Variable(pos=True)
y = cp.Variable(pos=True)
z = cp.Variable(pos=True)

a = cp.Parameter(pos=True, value=2.)
b = cp.Parameter(pos=True, value=1.)
c = cp.Parameter(value=0.5)
Expand Down Expand Up @@ -168,14 +202,9 @@ To install `pytest`, run:
pip install pytest
```

To run the tests for `torch`, in the main directory of this repository, run:
```bash
pytest cvxpylayers/torch
```

To run the tests for `tensorflow`, in the main directory of this repository, run:
Execute the tests from the main directory of this repository with:
```bash
pytest cvxpylayers/tensorflow
pytest cvxpylayers/{torch,jax,tensorflow}
```

## Projects using cvxpylayers
Expand Down
1 change: 1 addition & 0 deletions cvxpylayers/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from cvxpylayers.jax.cvxpylayer import CvxpyLayer # noqa: F401
Loading