<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/Broyden_Lagrange.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
eps=1e-12

In [2]:
# Limit step size to ensure a decrease in norm

def broyden3(func, x, J=None, tol=1e-10, max_iter=100, verbose=0, xmax=jnp.inf, xmin=-jnp.inf):
    Jf = jax.jacobian(func) if J is None else J
    J = Jf(x)
    Jinv = jnp.linalg.inv(J)
    f = func(x)

    for i in range(max_iter):

        dx = - Jinv @ f
        if verbose>0:
            print(f"\nIter: {i}  dx: {dx}")

        alpha_max_limits = jnp.min(jnp.where(x + dx > xmax, (xmax - x) / (dx), 1))
        alpha_min_limits = jnp.min(jnp.where(x + dx < xmin, (xmin - x) / (dx), 1))
        alpha = min(alpha_max_limits, alpha_min_limits)

        while alpha > 0.01:
            dx_try = alpha*dx
            xp = x + dx_try
            fp = func(xp)
            dnorm = jnp.linalg.norm(fp)-jnp.linalg.norm(f)
            if verbose>1:
                print(f"Alpha {alpha}   dnorm {dnorm}  dx_try {dx_try}   f {f}    fp {fp}")
            if dnorm > 0:
                alpha *= 0.5
            else:
                break
        if alpha <= 0.01:
            if verbose>0:
                print("reevaluate J")
            Jinv = jnp.linalg.inv(Jf(x))
            continue

        dx=dx_try
        f= fp
        x= xp
        if verbose>0:
          print(x, f)
        if jnp.all(jnp.abs(f)<tol):
          break

        u = jnp.expand_dims(fp,1)
        v = jnp.expand_dims(dx,1)/jnp.linalg.norm(dx)**2
        Jinv = Jinv - Jinv @ u @ v.T @ Jinv / (1 + v.T @ Jinv @ u)  #Sherman-Morrison
    return x, f

In [3]:
def func2(x):
    return jnp.array([jnp.sin(x[0])  + 0.5 * (x[0] - x[1])**3 - 0.01*jnp.sqrt(x[1]-0.1) - 1.0,
            0.5 * (x[1] - x[0])**3 + x[1] + 0.001*jnp.sqrt(1.-x[0])])

In [4]:
broyden3(func2, 0.95*jnp.ones(2), verbose=2, max_iter=20, xmin=jnp.array([-jnp.inf, 0.1+eps]), xmax = jnp.array([1.-eps, jnp.inf]))




Iter: 0  dx: [ 0.32776387 -0.9494907 ]
Alpha 0.1525488433460805   dnorm -0.1522769994586285  dx_try [ 0.05       -0.14484371]   f [-0.19580404  0.95022361]    fp [-0.16322784  0.80145776]
[1.         0.80515629] [-0.16322784  0.80145776]

Iter: 1  dx: [-0.06040385  0.17710007]
Alpha 1.0   dnorm 0.18520243065511433  dx_try [-0.06040385  0.17710007]   f [-0.16322784  0.80145776]    fp [-0.20211182  0.98254095]
Alpha 0.5   dnorm 0.09447817784862322  dx_try [-0.03020193  0.08855003]   f [-0.16322784  0.80145776]    fp [-0.18391719  0.89365983]
Alpha 0.25   dnorm 0.04810126644389001  dx_try [-0.01510096  0.04427502]   f [-0.16322784  0.80145776]    fp [-0.17419768  0.84831118]
Alpha 0.125   dnorm 0.024328390472125316  dx_try [-0.00755048  0.02213751]   f [-0.16322784  0.80145776]    fp [-0.16890823  0.82512827]
Alpha 0.0625   dnorm 0.012250770451923909  dx_try [-0.00377524  0.01106875]   f [-0.16322784  0.80145776]    fp [-0.1661218  0.8133705]
Alpha 0.03125   dnorm 0.006155804409375354  d

(DeviceArray([0.93426899, 0.19872083], dtype=float64),
 DeviceArray([ 4.55533389e-11, -2.28623176e-11], dtype=float64))

In [5]:
# No constraints
def rosen(x):
    return 100*(x[1]-x[0]**2)**2 + (1-x[0])**2

grad_rosen = jax.grad(rosen)
x0 = jnp.array([0.5,0.5])
broyden3(grad_rosen, x0, verbose=1, max_iter=500)


Iter: 0  dx: [-0.01020408 -0.26020408]
[0.48979592 0.23979592] [-1.0000085  -0.02082466]

Iter: 1  dx: [-0.01086358 -0.01075499]
[0.47893233 0.22904093] [-0.97790937 -0.06705118]

Iter: 2  dx: [0.90823721 0.88046619]
[0.53569716 0.28407006] [-0.30690004 -0.58027715]

Iter: 3  dx: [-0.06203555 -0.06038761]
[0.52794272 0.27652161] [-0.47912325 -0.44038046]

Iter: 4  dx: [0.00869576 0.00858789]
[0.5322906  0.28081556] [-0.39935449 -0.50354478]

Iter: 5  dx: [-0.00885714 -0.00918444]
reevaluate J

Iter: 6  dx: [0.31107115 0.33367822]
[0.55173255 0.30167045] [-0.29219898 -0.54767109]

Iter: 7  dx: [-0.02081895 -0.02235933]
[0.54913018 0.29887553] [-0.31561537 -0.53368427]

Iter: 8  dx: [0.0029633  0.00319621]
[0.55209348 0.30207174] [-0.29171845 -0.54709449]

Iter: 9  dx: [0.26969081 0.29989749]
[0.56894915 0.32081533] [-0.20489487 -0.577562  ]

Iter: 10  dx: [-0.01804669 -0.02009146]
[0.55992581 0.3107696 ] [-0.26483244 -0.54946203]

Iter: 11  dx: [0.01732295 0.01951974]
[0.57724876 0.330

(DeviceArray([1., 1.], dtype=float64),
 DeviceArray([-4.21334079e-11,  9.54791801e-12], dtype=float64))

In [6]:
# One equality constraint

def rosen(x):
    return 100*(x[1]-x[0]**2)**2 + (1-x[0])**2

def constr(x):
    return 2*x[0] + x[1] - 1

def L(x):
  return rosen(x)-x[2]*constr(x)

dL = jax.jit(jax.grad(L))
x0 = jnp.array([0.,0., 1.])
broyden3(dL, x0)


(DeviceArray([ 0.41494432,  0.17011137, -0.41348319], dtype=float64),
 DeviceArray([ 1.21760241e-14, -4.87804241e-15,  0.00000000e+00], dtype=float64))

In [7]:
# One inequality constraint x[2] is a slack variable

def rosen(x):
    return 100*(x[1]-x[0]**2)**2 + (1-x[0])**2

def constr(x):
    return x[0]**2 + x[1] + x[2]**2 - 1

def L(x):
  return rosen(x)-x[3]*constr(x)

dL = jax.jit(jax.grad(L)) 
x0 = jnp.array([0.1,0.1, 10., 1.])
x,f=broyden3(dL, x0, max_iter=500)
print(x)
print(rosen(x))

[ 7.07472158e-01  4.99483146e-01 -7.74302146e-13 -2.06741593e-01]
0.08567939371085512
