In [5]:
import cvxpy as cp
import jax
from cvxpylayers.jax import CvxpyLayer

n, m = 2, 3
x = cp.Variable(n)
A = cp.Parameter((m, n))
b = cp.Parameter(m)
constraints = [x >= 0]
objective = cp.Minimize(0.5 * cp.pnorm(A @ x - b, p=1))
problem = cp.Problem(objective, constraints)
assert problem.is_dpp()

cvxpylayer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
key = jax.random.PRNGKey(0)
key, k1, k2 = jax.random.split(key, 3)
A_jax = jax.random.normal(k1, shape=(m, n))
b_jax = jax.random.normal(k2, shape=(m,))

solution, = cvxpylayer(A_jax, b_jax)

# compute the gradient of the summed solution with respect to A, b
dcvxpylayer = jax.grad(lambda A, b: sum(cvxpylayer(A, b)[0]), argnums=[0, 1])
%timeit gradA, gradb = dcvxpylayer(A_jax, b_jax)

4.01 ms ± 7.34 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [2]:
# Gradient of the summed solution with respect to A (same shape as A_jax).
gradA

Array([[-5.3574888e-07,  2.1834301e-06],
       [ 2.6866549e-06,  1.0437016e+00],
       [ 5.3580015e-07,  1.9128194e-05]], dtype=float32)

In [3]:
# Gradient of the summed solution with respect to b (same shape as b_jax).
gradb

Array([ 4.9595337e-07, -9.4603354e-01, -1.9813251e-05], dtype=float32)

In [10]:
dcvxpylayer_jit = jax.jit( dcvxpylayer )
%timeit gradA, gradb = dcvxpylayer(A_jax, b_jax)

4.05 ms ± 3.81 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
