In [1]:
# Minimize f(x, y, z) = x^2 + y^2 + z^2
# subject to the equality constraint 2x - y + z = 3

In [2]:
import numpy as np
from scipy.optimize import minimize

In [3]:
def objective(x):
    """Return the sum of squares of the three components of ``x``.

    ``x`` must unpack into exactly three values; this is the quantity
    handed to the minimizer.
    """
    first, second, third = x
    squares = (first**2, second**2, third**2)
    return squares[0] + squares[1] + squares[2]

def constraint(x):
    """Return the residual of the equality constraint ``2x - y + z = 3``.

    ``x`` must unpack into exactly three values; the result is zero
    exactly when those values satisfy the constraint.
    """
    first, second, third = x
    lhs = 2*first - second + third
    return lhs - 3

In [4]:
# Feasible starting point: 2*1 - 1 + 2 = 3, so the constraint residual is zero.
x0 = [1, 1, 2]
assert constraint(x0) == 0

# One equality constraint, in the dict form passed to `minimize` below.
cons = [dict(type='eq', fun=constraint)]

In [5]:
# Solve the equality-constrained problem with SLSQP, starting from x0.
solution = minimize(objective, x0=x0, constraints=cons, method='SLSQP')
# Fail loudly if the solver reports non-convergence instead of silently
# displaying (and later reusing) a point that is not a minimizer.
assert solution.success, solution.message
solution

     fun: 1.5
     jac: array([ 2.00000003, -0.99999999,  1.00000001])
 message: 'Optimization terminated successfully.'
    nfev: 11
     nit: 2
    njev: 2
  status: 0
 success: True
       x: array([ 1. , -0.5,  0.5])

In [6]:
# Display only the solution vector (the `x` field of the result from `minimize`).
solution.x

array([ 1. , -0.5,  0.5])

In [7]:
import jax
import jax.numpy as np
from jax import grad
from jax.ops import index, index_add, index_update
from scipy.optimize import fsolve

# Show which devices JAX can see in this session.
jax.devices()

In [8]:
# Initial guess for the root solve: [0, 0, 0, 1], i.e. three coordinates
# followed by the multiplier (the layout that loss_fn slices apart).
# Built directly rather than with jax.ops.index_update, which newer JAX
# releases removed; the functional-update replacement there is
# x.at[3].set(1.), but a literal is enough for a constant vector.
x = np.array([0., 0., 0., 1.])
print(x)

[0. 0. 0. 1.]


In [9]:
def f(x):
    """Return the sum of the squared entries of the array ``x``."""
    squared = x ** 2
    return squared.sum()

def g(x):
    """Return the constraint residual ``2*x[0] - x[1] + x[2] - 3``.

    Only the first three entries of ``x`` are read.
    """
    first, second, third = x[0], x[1], x[2]
    return 2*first - second + third - 3

In [10]:
def loss_fn(x):
    """Evaluate the Lagrangian ``f(p) - lam * g(p)`` for a packed vector.

    ``x`` is laid out as ``[p0, p1, p2, lam]``: the first three entries are
    the point and the fourth is the multiplier on the constraint residual.
    """
    point, multiplier = x[:3], x[3]
    return f(point) - multiplier * g(point)

In [11]:
# Gradient of loss_fn with respect to its packed argument, built by jax.grad.
dldx = grad(loss_fn)
# Printing shows only the function object's repr, not any gradient values.
print(dldx)

<function grad.<locals>.grad_f at 0x7fda86409510>


In [12]:
def obj(L):
    """Return the four partial derivatives of ``loss_fn`` at ``L``.

    ``L`` is the packed vector ``[a, b, c, lambda]``. The result is a list
    with one gradient component per entry of ``L``; it is the system handed
    to ``fsolve``, whose root is where every component vanishes.
    """
    # Unpack into exactly four components so a gradient of any other
    # length raises instead of being passed along silently.
    dFda, dFdb, dFdc, dFdlam = dldx(L)
    return [dFda, dFdb, dFdc, dFdlam]

In [13]:
# Find the root of the gradient system, starting from the guess x, and
# split it into the three coordinates and the multiplier.
a, b, c, lamb = fsolve(obj, x0=x)
print(a, b, c, lamb)

1.0000000335283827 -0.5000000167641913 0.4999999662080479 0.9999999934182814


[GpuDevice(id=0)]