<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/Broyden.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax
jax.config.update('jax_enable_x64',True)
import jax.numpy as jnp
from plotly.subplots import make_subplots
from scipy.optimize import minimize

In [2]:
def f(x):
    return 0.3*(x-0.9)*(x-0.3)*(x+0.5)*(x+0.8)

In [3]:
xplot = np.linspace(-1,1,100)
fig=make_subplots()
fig.add_scatter(x=xplot, y=f(xplot), mode='lines')
fig.update_layout(width=600, height=400, template='plotly_dark')

In [4]:
def newton(f, xguess, maxiter=100, tol=1e-10):

    fprime = jax.grad(f)
    x1=xguess
    for i in range(maxiter):
        x2 = x1 - f(x1)/fprime(x1)
        if jnp.abs(f(x2))<tol:
            break
        x1 = x2

    return x2

In [5]:
newton(f, 0.)



Array(0.9, dtype=float64, weak_type=True)

In [6]:
@jax.jit
def f2(x):
    return jnp.array([x[0]  + 0.5 * (x[0] - x[1])**3 - 1.0,
            0.5 * (x[1] - x[0])**3 + x[1]])

In [7]:
def newton_Nd(f, xguess, maxiter=100, tol=1e-10):

    J = jax.jit(jax.jacobian(f))
    x1=xguess
    for i in range(maxiter):
        x2 = x1 - jnp.linalg.inv(J(x1)) @ f(x1)
        print(i, f(x2))
        if jnp.linalg.norm(f(x2))<tol:
            break
        x1 = x2

    return x2

In [8]:
guess = jnp.array([0.,0.])
newton_Nd(f2, guess)

0 [ 0.5 -0.5]
1 [ 0.0859375 -0.0859375]
2 [ 0.00447052 -0.00447052]
3 [ 1.41153108e-05 -1.41153108e-05]
4 [ 1.41997303e-10 -1.41997386e-10]
5 [ 0.00000000e+00 -5.55111512e-17]


Array([0.8411639, 0.1588361], dtype=float64)

In [23]:
def broyden(f, x1, maxiter=100, tol=1e-10):

    fval1=f(x1)
    fval1_norm = jnp.linalg.norm(fval1)
    Jinv = np.eye(x1.size)
    for i in range(maxiter):
        dx= - Jinv @ fval1
        for j in range(10):
            x2 = x1 + dx
            fval2 = f(x2)
            fval2_norm = jnp.linalg.norm(fval2)
            if fval2_norm < fval1_norm:
                break
            dx /=2
        print(i, x2, fval2_norm)
        if jnp.linalg.norm(fval2)<tol:
            break
        dx = dx.reshape(-1,1)
        df = (fval2 - fval1).reshape(-1,1)
        Jinv = Jinv + (dx - Jinv@df) @ dx.T @ Jinv/ (dx.T @ Jinv @ df)
        x1 = x2
        fval1 = fval2

    return x2

In [24]:
def rosen(v):
    x,y = v
    return (1-x)**2 + 100.*(y-x**2)**2

In [25]:
def constraint(v):
    x,y = v
    return (x-1.)**2 + (y-1.)**2 - 0.5**2

In [26]:
xguess = jnp.array([0.,0.])
res=minimize(rosen, xguess, constraints=dict(type='eq', fun=constraint))
res

 message: Optimization terminated successfully
 success: True
  status: 0
     fun: 0.06128195914165823
       x: [ 7.528e-01  5.654e-01]
     nit: 10
     jac: [-1.197e-01 -2.489e-01]
    nfev: 35
    njev: 10

In [27]:
constraint(res.x)

1.0398111749410077e-07

In [28]:
rosen_jac = jax.jacobian(rosen)

In [None]:
broyden(rosen_jac, xguess)

In [30]:
#Lagrangian of Rosenbrock Function with a Constraint on Circle centered at 1,1 radius 0.5
def lagrange(v):
    L = v[-1]
    return rosen(v[:2]) + L*constraint(v[:2])

In [31]:
lagrange_jac = jax.jacobian(lagrange)

In [32]:
jnp.r_[xguess, 1.]

Array([0., 0., 1.], dtype=float64)

In [None]:
broyden(lagrange_jac, jnp.r_[xguess, 1.])

In [37]:
#parabolic inequality constraint
def parabola_constraint(v):
    x,y,s = v
    return y-x**2 - s**2

In [38]:
#Lagrangian of Rosenbrock Function with a Constraint on Circle centered at 1,1 radius 0.5
# and an inequality constraint y > x**2
def lagrange2(v):
    L1, L2 = v[-2:]
    return rosen(v[:2]) + L1*constraint(v[:2]) + L2*parabola_constraint(v[:3])

In [39]:
lagrange2_jac = jax.jacobian(lagrange2)

In [40]:
broyden(lagrange2_jac, jnp.r_[xguess, 0.1, 1.,1.])

0 [0.25     0.0625   0.1125   0.890625 1.000625] 3.6123150061440215
1 [0.26288025 0.05918255 0.11545834 0.89563156 1.00082843] 3.7085345366763853
2 [0.26477779 0.06945628 0.11685641 0.89679417 1.00097888] 3.54875801763982
3 [0.24815104 0.05166616 0.08897419 0.82796663 0.99798483] 3.6253000468789076
4 [0.27523805 0.06013933 0.00908654 0.71826703 0.98985882] 3.90013204249985
5 [ 0.31016713  0.07780855 -0.04969049  0.70178862  0.98736541] 4.185305955543495
6 [ 0.32768285  0.08841656 -0.07783315  0.72348621  0.99065589] 4.277865846119309
7 [ 0.3262072   0.08630204 -0.06637135  0.7007806   0.98673464] 4.45189301563122
8 [ 0.32733326  0.08686143 -0.06544423  0.69904145  0.98654875] 4.479450059307884
9 [ 0.32797418  0.08723989 -0.06518355  0.69890056  0.98660472] 4.485617854573112
10 [ 0.35202848  0.10382335 -0.06095285  0.72057239  0.9938417 ] 4.428921986963451
11 [ 0.37602063  0.12135101 -0.05635568  0.74577158  1.00147452] 4.4145943686548375
12 [ 0.39829508  0.13891477 -0.05195915  0.77279

Array([ 7.52161310e-01,  5.65746636e-01, -9.38500726e-13, -2.75041429e-01,
       -2.38875332e-01], dtype=float64)

In [41]:
(7.52161310e-01)**2

0.565746636260916