<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/Broyden.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [58]:
import numpy as np
import jax
jax.config.update('jax_enable_x64',True)
import jax.numpy as jnp
from plotly.subplots import make_subplots
from scipy.optimize import minimize

In [2]:
def f(x):
    return 0.3*(x-0.9)*(x-0.3)*(x+0.5)*(x+0.8)

In [3]:
xplot = np.linspace(-1,1,100)
fig=make_subplots()
fig.add_scatter(x=xplot, y=f(xplot), mode='lines')
fig.update_layout(width=600, height=400, template='plotly_dark')

In [4]:
def newton(f, xguess, maxiter=100, tol=1e-10):

    fprime = jax.grad(f)
    x1=xguess
    for i in range(maxiter):
        x2 = x1 - f(x1)/fprime(x1)
        if jnp.abs(f(x2))<tol:
            break
        x1 = x2

    return x2

In [5]:
newton(f, 0.)



Array(0.9, dtype=float64, weak_type=True)

In [6]:
@jax.jit
def f2(x):
    return jnp.array([x[0]  + 0.5 * (x[0] - x[1])**3 - 1.0,
            0.5 * (x[1] - x[0])**3 + x[1]])

In [7]:
def newton_Nd(f, xguess, maxiter=100, tol=1e-10):

    J = jax.jit(jax.jacobian(f))
    x1=xguess
    for i in range(maxiter):
        x2 = x1 - jnp.linalg.inv(J(x1)) @ f(x1)
        print(i, f(x2))
        if jnp.linalg.norm(f(x2))<tol:
            break
        x1 = x2

    return x2

In [8]:
guess = jnp.array([0.,0.])
newton_Nd(f2, guess)

0 [ 0.5 -0.5]
1 [ 0.0859375 -0.0859375]
2 [ 0.00447052 -0.00447052]
3 [ 1.41153108e-05 -1.41153108e-05]
4 [ 1.41997303e-10 -1.41997386e-10]
5 [ 0.00000000e+00 -5.55111512e-17]


Array([0.8411639, 0.1588361], dtype=float64)

In [53]:
def broyden(f, x1, maxiter=100, tol=1e-10):

    fval1=f(x1)
    fval1_norm = jnp.linalg.norm(fval1)
    Jinv = np.eye(x1.size)
    for i in range(maxiter):
        dx= - Jinv @ fval1
        for j in range(10):
            x2 = x1 + dx
            fval2 = f(x2)
            fval2_norm = jnp.linalg.norm(fval2)
            if fval2_norm < fval1_norm:
                break
            dx /= 2
        print(i, x2, fval2_norm)
        if jnp.linalg.norm(fval2)<tol:
            break
        dx = dx.reshape(-1,1)
        df = (fval2 - fval1).reshape(-1,1)
        Jinv = Jinv + (dx - Jinv@df) @ dx.T @ Jinv/ (dx.T @ Jinv @ df)
        x1 = x2
        fval1 = fval2

    return x2

In [54]:
def rosen(v):
    x,y = v
    return (1-x)**2 + 100.*(y-x**2)**2

In [57]:
def constraint(v):
    x,y = v
    return (x-1.)**2 + (y-1.)**2 - 0.5**2

In [60]:
xguess = jnp.array([0.,0.])
res=minimize(rosen, xguess, constraints=dict(type='eq', fun=constraint))
res

 message: Optimization terminated successfully
 success: True
  status: 0
     fun: 0.06128195914165823
       x: [ 7.528e-01  5.654e-01]
     nit: 10
     jac: [-1.197e-01 -2.489e-01]
    nfev: 35
    njev: 10

In [61]:
constraint(res.x)

1.0398111749410077e-07

In [55]:
rosen_jac = jax.jacobian(rosen)

In [56]:
broyden(rosen_jac, xguess)

0 [0.0625 0.    ] 1.9414691262453962
1 [0.06347442 0.01370614] 2.869666864893233
2 [0.12270822 0.01107265] 1.7508832903664424
3 [0.13700356 0.01215966] 1.8993771713859868
4 [0.14755163 0.0141538 ] 1.9740649652903446
5 [0.15328993 0.01564233] 1.9841082174606424
6 [0.16677257 0.01968447] 1.9765656874064566
7 [0.18068134 0.02426756] 1.9685294774461997
8 [0.20623822 0.03348335] 1.9959418355652714
9 [0.21598372 0.03743634] 1.9977664690406087
10 [0.23492013 0.0459107 ] 1.9687254417278823
11 [0.25381406 0.05503692] 1.9529529252865176
12 [0.28545597 0.0717264 ] 1.9769684912126209
13 [0.30606593 0.08411987] 1.9236783172342664
14 [0.3528506  0.11477275] 1.947765168585583
15 [0.37964706 0.13823151] 1.2293843071883834
16 [0.33305311 0.10551092] 1.2440372951992476
17 [0.40199537 0.15605275] 1.1503922045015966
18 [0.42751206 0.17566174] 1.4226874337500652
19 [0.43509633 0.18641768] 0.8526559520876754
20 [0.42214773 0.17549019] 0.8837115374925402
21 [0.50162598 0.24421994] 1.5605934289178163
22 [0.48

Array([1., 1.], dtype=float64)