<a href="https://colab.research.google.com/github/profteachkids/CHE2064/blob/master/Broyden_inclass.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
eps=1e-12

In [4]:
def func(x):
    return jnp.array([jnp.sin(x[0])  + 0.5 * (x[0] - x[1])**3 - 1.0,
            0.5 * (x[1] - x[0])**3 + x[1]])

In [5]:
# Broyden update with no Sherman-Morrison
x = jnp.zeros(2)
J = jax.jacobian(func)(x)

f = func(x)

for i in range(20):
  xp = jnp.linalg.solve(J, -f) + x
  dx = xp - x
  fp = func(xp)
  f= fp
  x= xp
  print(x,fp)
  if jnp.linalg.norm(fp) < 1e-12:
    break

  J = J + jnp.outer(fp,dx)/jnp.linalg.norm(dx)**2



[1. 0.] [ 0.34147098 -0.5       ]
[0.74545034 0.37272517] [-0.29580697  0.34683492]
[0.87452582 0.23424262] [-0.10151446  0.10299655]
[0.94426608 0.18266999] [ 0.03094058 -0.03820376]
[0.92886637 0.19874589] [-0.00445347  0.00414106]
[0.931058   0.19761267] [-0.00047246  0.00033713]
[0.93134054 0.1975642 ] [-3.66025399e-05  2.14376618e-05]
[0.93136529 0.19756394] [-1.63573041e-06  9.88074612e-07]
[0.93136644 0.19756391] [ 9.80559234e-09 -5.98158895e-09]
[0.93136643 0.19756391] [-1.81729076e-11  1.10534082e-11]
[0.93136643 0.19756391] [-1.01030295e-14  6.16173779e-15]


In [66]:
def broyden(func, x, J=None, tol=1e-10, max_iter=100, verbose=0):
  J = jax.jacobian(func)(x) if J is None else J(x)
  Jinv = jnp.linalg.inv(J)
  f = func(x)

  for i in range(max_iter):
    xp = x - Jinv @ f
    dx = xp - x
    fp = func(xp)
    f= fp
    x= xp
    if verbose>0:
      print(x, f)
    if jnp.all(fp<tol):
      break

    u = fp.reshape(-1,1)
    v = dx.reshape(-1,1)/jnp.linalg.norm(dx)**2
    Jinv = Jinv - Jinv @ u @ v.T @ Jinv / (1 + v.T @ Jinv @ u)  #Sherman-Morrison
  return x, f

broyden(func, jnp.zeros(2))

(DeviceArray([0.93136643, 0.19756391], dtype=float64),
 DeviceArray([-1.81729076e-11,  1.10534082e-11], dtype=float64))

In [80]:
# Accomodate box bounds on variables

def broyden2(func, x, J=None, tol=1e-10, max_iter=100, verbose=0, xmax=jnp.inf, xmin=-jnp.inf):
  J = jax.jacobian(func) if J is None else J(x)
  Jinv = jnp.linalg.inv(J(x))
  f = func(x)

  for i in range(max_iter):
    dx = - Jinv @ f
    print(f'dx {dx}')
    alpha = jnp.min(jnp.concatenate([jnp.abs((xmax-x)/dx), jnp.abs((xmin-x)/dx), jnp.array([1.])]))
    dx =alpha*dx
    f= fp
    x= xp
    if verbose>0:
      print(i, x, f)
      print()
    if jnp.all(fp<tol):
      break

    u = fp.reshape(-1,1)
    v = dx.reshape(-1,1)/jnp.linalg.norm(dx)**2
    Jinv = Jinv - Jinv @ u @ v.T @ Jinv / (1 + v.T @ Jinv @ u)  #Sherman-Morrison
  return x, f

In [81]:
def func2(x):
    return jnp.array([ jnp.sin(x[0])+ 0.5 * (x[0] - x[1])**3 - 0.01*jnp.sqrt(x[1]-0.1) - 1.0,
            0.5 * (x[1] - x[0])**3 + x[1] + 0.001*jnp.sqrt(1.-x[0])])

In [82]:
xguess = jnp.array([0.93136643, 0.19756391])
func2(xguess)

DeviceArray([-0.00312353,  0.00026199], dtype=float64)

In [83]:
broyden2(func2, xguess, tol=1e-10, verbose=1, max_iter=5, xmin=jnp.array([-jnp.inf, 0.1]), xmax = jnp.array([1., jnp.inf]))

dx [0.00290108 0.00115437]
0 [0.93136643 0.19756391] [-1.01030295e-14  6.16173779e-15]



(DeviceArray([0.93136643, 0.19756391], dtype=float64),
 DeviceArray([-1.01030295e-14,  6.16173779e-15], dtype=float64))