In [1]:
import cvxpy as cp
import jax
from cvxpylayers.jax import CvxpyLayer

n, m = 2, 3
x = cp.Variable(n)
A = cp.Parameter((m, n))
b = cp.Parameter(m)
constraints = [x >= 0]
objective = cp.Minimize(0.5 * cp.pnorm(A @ x - b, p=1))
problem = cp.Problem(objective, constraints)
assert problem.is_dpp()

cvxpylayer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
key = jax.random.PRNGKey(0)
key, k1, k2 = jax.random.split(key, 3)
A_jax = jax.random.normal(k1, shape=(m, n))
b_jax = jax.random.normal(k2, shape=(m,))

solution, = cvxpylayer(A_jax, b_jax)

# compute the gradient of the summed solution with respect to A, b
dcvxpylayer = jax.grad(lambda A, b: sum(cvxpylayer(A, b)[0]), argnums=[0, 1])
gradA, gradb = dcvxpylayer(A_jax, b_jax)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
gradA

Array([[-5.3574888e-07,  2.1834301e-06],
       [ 2.6866549e-06,  1.0437016e+00],
       [ 5.3580015e-07,  1.9128194e-05]], dtype=float32)

In [3]:
x = cp.Variable()
a = cp.Parameter(value=3) 
const = [x <= a]
objective = cp.Maximize(x)
prob = cp.Problem(objective, const)
print(f"Simple DCP: {prob.is_dgp(dpp=True)}")
prob.solve(requires_grad=True)
prob.backward()
print(f"x: {x.value}, a_grad:{a.gradient}")

Simple DCP: False
x: 2.9999803327826546, a_grad:0.999994633701959


In [4]:
from cvxpylayers.jax import CvxpyLayer
import jax.numpy as jnp
from jax import grad
x = cp.Variable()
a = cp.Parameter(value=4) 
const = [x <= 3]
objective = cp.Minimize( cp.norm(x-a) )
prob = cp.Problem(objective, const)
# print(f"Simple DCP: {prob.is_dgp(dpp=True)}")
# prob.solve(requires_grad=True)
# prob.backward()
# print(f"x: {x.value}, a_grad:{a.gradient}")

cvx_layer = CvxpyLayer( prob, parameters=[a], variables=[x] )
def cvx_layer_simple(a):
    return cvx_layer(a)[0][0]
cvx_layer_grad = grad( cvx_layer_simple )
a_jax = jnp.array([2.8])
print(f"from layer: x: {cvx_layer_simple(a_jax)}, a_grad: {cvx_layer_grad(a_jax)}")

from layer: x: 2.800097703933716, a_grad: [1.0000271]


## Pytorch example

In [10]:
import cvxpy as cp
import torch
from cvxpylayers.torch import CvxpyLayer

x = cp.Variable()
a = cp.Parameter(value=4) 
const = [x <= 3]
objective = cp.Minimize( cp.sum_squares(x-a) )
prob = cp.Problem(objective, const)

cvxpylayer = CvxpyLayer(prob, parameters=[a], variables=[x])
a_tch = torch.tensor([2.8], requires_grad=True)

# solve the problem
solution, = cvxpylayer(a_tch)

# compute the gradient of the sum of the solution with respect to A, b
solution.sum().backward()
print(f"torch grad: {a_tch.grad}")

torch grad: tensor([1.0000])
