<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/Broyden_Lagrange.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import jax
import jax.numpy as jnp
# The `jax.config` submodule is no longer importable in recent JAX releases;
# the same config object is exposed directly on the top-level `jax` package.
from jax import config

# JAX defaults to 32-bit floats; switch to 64-bit before any array is created.
config.update("jax_enable_x64", True)
# Small offset used below to keep bounds strictly inside the open domain of sqrt terms.
eps=1e-12

In [47]:
# Limit step size to ensure a decrease in norm

def broyden3(func, x, J=None, tol=1e-10, max_iter=100, verbose=0, xmax=jnp.inf, xmin=-jnp.inf):
    """Solve func(x) = 0 with a bounded, damped Broyden (quasi-Newton) iteration.

    The inverse Jacobian is initialised from a Jacobian evaluated at the
    starting point and afterwards corrected with rank-one Broyden updates
    applied through the Sherman-Morrison formula.  Each proposed step is first
    shortened so that it stays inside [xmin, xmax] and then halved until the
    residual norm does not increase.  If no acceptable step longer than 1% of
    the proposed one exists, the inverse Jacobian is recomputed at the current
    point and the iteration is retried.

    Args:
        func: Residual function taking a 1-D array and returning a 1-D array
            of the same length.
        x: Starting point (1-D array).
        J: Optional callable returning the Jacobian matrix of ``func`` at a
            point.  Defaults to ``jax.jacobian(func)``.
        tol: Stop as soon as every component of the residual satisfies
            ``|f_i| < tol``.
        max_iter: Maximum number of outer iterations; iterations spent on a
            Jacobian re-evaluation count towards this limit.
        verbose: 0 is silent; >0 prints the proposed step, accepted point and
            Jacobian re-evaluations; >1 additionally prints every line-search
            trial.
        xmax: Upper bound(s), scalar or array broadcastable against ``x``.
        xmin: Lower bound(s), scalar or array broadcastable against ``x``.

    Returns:
        Tuple ``(x, f)`` with the last accepted point and its residual.  No
        convergence flag is returned; inspect ``f`` to check convergence.
    """
    Jf = jax.jacobian(func) if J is None else J
    Jinv = jnp.linalg.inv(Jf(x))
    f = func(x)

    for i in range(max_iter):

        dx = - Jinv @ f
        if verbose>0:
            print(f"\nIter: {i}  dx: {dx}")

        # Largest fraction of dx that keeps every component within its bounds.
        alpha_max_limits = jnp.min(jnp.where(x + dx > xmax, (xmax - x) / (dx), 1))
        alpha_min_limits = jnp.min(jnp.where(x + dx < xmin, (xmin - x) / (dx), 1))
        alpha = min(alpha_max_limits, alpha_min_limits)

        # Backtracking: halve alpha until the residual norm does not increase.
        accepted = False
        while alpha > 0.01:
            dx_try = alpha*dx
            xp = x + dx_try
            fp = func(xp)
            dnorm = jnp.linalg.norm(fp)-jnp.linalg.norm(f)
            if verbose>1:
                print(f"Alpha {alpha}   dnorm {dnorm}  dx_try {dx_try}   f {f}    fp {fp}")
            # `dnorm <= 0` is False when dnorm is NaN, so a trial point where
            # func is undefined is rejected instead of being accepted.
            if dnorm <= 0:
                accepted = True
                break
            alpha *= 0.5
        if not accepted:
            # No usable step (or alpha itself is NaN): the approximate inverse
            # Jacobian is discarded and rebuilt at the current point.
            if verbose>0:
                print("reevaluate J")
            Jinv = jnp.linalg.inv(Jf(x))
            continue

        f_prev = f
        dx = dx_try
        f = fp
        x = xp
        if verbose>0:
          print(x, f)
        if jnp.all(jnp.abs(f)<tol):
          break

        # Broyden update J_new = J + u v^T with u = (f - f_prev) - J @ dx.
        # Because dx = alpha * (-Jinv @ f_prev), J @ dx = -alpha * f_prev, so
        # u = f - (1 - alpha) * f_prev.  This reduces to u = f only for a full
        # step (alpha == 1); the damped case needs the extra term to satisfy
        # the secant condition J_new @ dx = f - f_prev.
        u = jnp.expand_dims(f - (1 - alpha)*f_prev, 1)
        v = jnp.expand_dims(dx,1)/jnp.linalg.norm(dx)**2
        Jinv = Jinv - Jinv @ u @ v.T @ Jinv / (1 + v.T @ Jinv @ u)  #Sherman-Morrison
    return x, f

In [41]:
def func2(x):
    """Two-equation nonlinear test system in x[0], x[1].

    The square-root terms are real only for x[1] >= 0.1 and x[0] <= 1.
    Returns the two residuals as a length-2 array.
    """
    first, second = x[0], x[1]
    res0 = jnp.sin(first)  + 0.5 * (first - second)**3 - 0.01*jnp.sqrt(second-0.1) - 1.0
    res1 = 0.5 * (second - first)**3 + second + 0.001*jnp.sqrt(1.-first)
    return jnp.array([res0, res1])

In [None]:
# Bounds keep x[0] just below 1 and x[1] just above 0.1, so both sqrt
# arguments in func2 stay positive during the line search.
broyden3(
    func2,
    0.95 * jnp.ones(2),
    xmin=jnp.array([-jnp.inf, 0.1 + eps]),
    xmax=jnp.array([1. - eps, jnp.inf]),
    max_iter=20,
    verbose=2,
)

In [None]:
# No constraints
def rosen(x):
    """Rosenbrock function 100*(x[1] - x[0]**2)**2 + (1 - x[0])**2 of the first two entries of x."""
    curved_term = x[1] - x[0]**2
    linear_term = 1 - x[0]
    return 100*curved_term**2 + linear_term**2

# A root of the gradient is a stationary point of rosen.
grad_rosen = jax.grad(rosen)
x0 = jnp.array([0.5, 0.5])
broyden3(
    grad_rosen,
    x0,
    max_iter=500,
    verbose=1,
)

In [35]:
# One equality constraint

def rosen(x):
    """Rosenbrock function 100*(x[1] - x[0]**2)**2 + (1 - x[0])**2 of the first two entries of x."""
    curved_term = x[1] - x[0]**2
    linear_term = 1 - x[0]
    return 100*curved_term**2 + linear_term**2

def constr(x):
    """Equality-constraint residual 2*x[0] + x[1] - 1 (zero when the constraint holds)."""
    left_side = 2*x[0] + x[1]
    return left_side - 1

def L(x):
    """Lagrangian rosen(x) - x[2]*constr(x); x[2] plays the role of the multiplier."""
    multiplier = x[2]
    weighted_constraint = multiplier*constr(x)
    return rosen(x)-weighted_constraint

# A root of dL is a stationary point of the Lagrangian in (x[0], x[1], multiplier).
dL = jax.jit(jax.grad(L))
x0 = jnp.array([0., 0., 1.])
broyden3(func=dL, x=x0)


(DeviceArray([ 0.41494432,  0.17011137, -0.41348319], dtype=float64),
 DeviceArray([ 1.21760241e-14, -4.87804241e-15,  0.00000000e+00], dtype=float64))

In [33]:
# One inequality constraint (x[0]**2 + x[1] <= 1); x[2] is a slack variable and x[3] is the Lagrange multiplier

def rosen(x):
    """Rosenbrock function 100*(x[1] - x[0]**2)**2 + (1 - x[0])**2 of the first two entries of x."""
    curved_term = x[1] - x[0]**2
    linear_term = 1 - x[0]
    return 100*curved_term**2 + linear_term**2

def constr(x):
    """Constraint residual x[0]**2 + x[1] + x[2]**2 - 1, with x[2] entering squared as a slack term."""
    left_side = x[0]**2 + x[1] + x[2]**2
    return left_side - 1

def L(x):
    """Lagrangian rosen(x) - x[3]*constr(x); x[3] plays the role of the multiplier."""
    multiplier = x[3]
    weighted_constraint = multiplier*constr(x)
    return rosen(x)-weighted_constraint

def grads(x):
    """Hand-assembled gradient of the Lagrangian rosen(x) - x[3]*constr(x).

    Matches the sign convention of ``L`` in this cell: the constraint
    gradient is subtracted, and the derivative with respect to the multiplier
    x[3] is -constr(x).  Expects a length-4 array and returns a length-4 array.
    """
    # Unit vector selecting the multiplier entry, where dL/dx[3] = -constr(x).
    multiplier_axis = jnp.array([0., 0., 0., 1.])
    return (jax.jacobian(rosen)(x)
            - x[3]*jax.jacobian(constr)(x)
            - constr(x)*multiplier_axis)

# A root of dL is a stationary point of the Lagrangian in (x[0], x[1], slack, multiplier).
dL = jax.jit(jax.grad(L))
x0 = jnp.array([0.1, 0.1, 10., 1.])
x, f = broyden3(dL, x0, max_iter=500)
print(x)
print(rosen(x))

[ 7.07472158e-01  4.99483146e-01 -1.54800687e-18 -2.06741593e-01]
0.08567939371082824
