<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/LevenbergMarquardtMatrixFree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [197]:
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg, bicgstab, bicg
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
from functools import partial


In [198]:
# Synthetic, noise-free observations: ten evenly spaced abscissae on [0, 1]
# and the sine model evaluated at (a, b, c) = (3, 5, 1).
x_data = np.linspace(0.0, 1.0, num=10)
y_data = 3.0 * np.sin(5.0 * x_data + 1.0)

In [214]:
def f(p):
    """Evaluate the sine model ``a*sin(b*x + c)`` on the module-level ``x_data``.

    ``p`` is unpacked into exactly three values ``(a, b, c)``: amplitude,
    frequency and phase. The result has one entry per element of ``x_data``.
    """
    amplitude, frequency, phase = p
    angle = frequency * x_data + phase
    return amplitude * jnp.sin(angle)

In [215]:
# Starting guess (a, b, c) for the fit; only the phase (0.99) differs from the
# values (3, 5, 1) used to generate y_data.
x0 = jnp.array([3.,5.0,0.99])
# NOTE(review): v0 is not referenced by any other cell visible here --
# presumably a test direction for Jacobian-vector products; confirm before removing.
v0 = jnp.array([0.5,-0.8,0.2])
# Last expression of the cell: display the model prediction at the starting guess.
f(x0)

DeviceArray([ 2.50807794,  2.99904441,  2.58794368,  1.39842864,
             -0.21171258, -1.75817386, -2.77580304, -2.95851265,
             -2.2513464 , -0.86700921], dtype=float64)

In [216]:
# Dense Levenberg-Marquardt reference iteration with a fixed damping factor L.
# Every step solves (J^T J + L*diag(J^T J)) dx = J^T r with the explicit
# Jacobian J of f and the residual r = y_data - f(x1).
x1 = x0
L = 10
jac_f = jax.jacobian(f)  # build the Jacobian function once, not on every pass
for i in range(10):
    residual = y_data - f(x1)  # evaluated once, reused for the report and the RHS
    print(f'{i:3d} {jnp.linalg.norm(residual):9.3e}')
    J = jac_f(x1)
    JTJ = J.T @ J  # formed once; the original computed this product twice
    dx = np.linalg.solve(JTJ + L*jnp.diag(jnp.diagonal(JTJ)), J.T @ residual)
    x1 = x1 + dx

  0 6.378e-02
  1 5.424e-02
  2 4.631e-02
  3 3.975e-02
  4 3.435e-02
  5 2.991e-02
  6 2.630e-02
  7 2.337e-02
  8 2.102e-02
  9 1.914e-02


In [237]:
def LM(f, x1, y_data, L=0.1, max_iter=500, tol=1e-25, verbose=True):
    """Minimise ``||y_data - f(x)||`` with a matrix-free Levenberg-Marquardt loop.

    Each trial step ``dx`` solves

        (J^T J + L * diag(J^T J)) dx = J^T (y_data - f(x))

    with ``bicgstab``. The Jacobian ``J`` of ``f`` is never formed: only
    Jacobian-vector (``jax.jvp``) and vector-Jacobian (``jax.vjp``) products
    are evaluated.

    Parameters
    ----------
    f : callable
        JAX-traceable function of a 1-D parameter array.
    x1 : array
        Initial 1-D parameter guess.
    y_data : array
        Target values; must be compatible with ``f(x1)`` for subtraction.
    L : float, optional
        Initial damping factor. It is halved after an accepted step and
        doubled after a rejected one.
    max_iter : int, optional
        Maximum number of trial steps.
    tol : float, optional
        The loop stops once the accepted residual norm drops below ``tol``.
    verbose : bool, optional
        When true, print ``iteration, trial norm, accepted norm, L`` after
        every trial step.

    Returns
    -------
    array
        The last accepted parameter vector.
    """
    n = x1.size
    zeros = jnp.zeros(n, dtype=x1.dtype)
    idx = jnp.arange(n)

    def res_norm(x):
        return jnp.linalg.norm(y_data - f(x))

    res_norm_jit = jax.jit(res_norm)

    def JTf(x):
        # Reuse the primal value returned by vjp instead of evaluating f twice.
        fval, fvjp = jax.vjp(f, x)
        # The cotangent must carry the dtype of f's output.
        return fvjp((y_data - fval).astype(fval.dtype))[0]

    JTf_jit = jax.jit(JTf)

    def JTJv(x, v):
        Jv = jax.jvp(f, (x,), (v,))[1]
        _, fvjp = jax.vjp(f, x)
        # Jv already has the shape of f's output; no transpose is needed.
        return fvjp(Jv)[0]

    JTJv_jit = jax.jit(JTJv)

    def JTJ_diag(x):
        # diag(J^T J)[i] is the squared norm of column i of J, i.e. of J @ e_i,
        # so one JVP per unit vector suffices.
        def col_sq_norm(i):
            Je = jax.jvp(f, (x,), (zeros.at[i].set(1),))[1]
            return jnp.sum(Je * Je)

        return jax.lax.map(col_sq_norm, idx)

    JTJ_diag_jit = jax.jit(JTJ_diag)

    def LM_LO(x, L, d):
        # d = diag(J^T J) at x is computed by the caller once per accepted
        # point rather than inside every matvec of the linear solver.
        def matvec(v):
            v = jnp.asarray(v, dtype=x.dtype)
            return np.asarray(JTJv_jit(x, v) + L * d * v)

        # An explicit dtype stops SciPy from probing matvec to infer one.
        return LinearOperator((n, n), matvec=matvec, dtype=np.dtype(x.dtype))

    f1 = float(res_norm_jit(x1))
    g = np.asarray(JTf_jit(x1))
    d = JTJ_diag_jit(x1)
    for i in range(max_iter):
        # info is deliberately not treated as fatal: a non-converged solve
        # still gives a trial step, which the residual test below judges.
        dx, info = bicgstab(LM_LO(x1, L, d), g, atol=1e-25)
        x2 = x1 + jnp.asarray(dx, dtype=x1.dtype)
        f2 = float(res_norm_jit(x2))
        # `f2 <= f1` is False for NaN, so a NaN trial point is rejected
        # (the former `f2 > f1` test accepted it).
        accepted = f2 <= f1
        if accepted:
            L = L / 2
            x1 = x2
            f1 = f2
        else:
            L = L * 2
        if verbose:
            print(f'{i:3d} {f2:9.3e} {f1:9.3e} {L:9.3e}')
        if f1 < tol:
            break
        if accepted:
            # Gradient and damping diagonal depend only on the accepted point.
            g = np.asarray(JTf_jit(x1))
            d = JTJ_diag_jit(x1)
    return x1

In [238]:
# Fit the sine model f to y_data with the matrix-free LM solver, starting from x0;
# the returned parameter vector is displayed as the cell output.
LM(f,x0,y_data)

  0 7.747e-03 7.747e-03 5.000e-02
  1 2.180e-03 2.180e-03 2.500e-02
  2 3.982e-04 3.982e-04 1.250e-02
  3 4.006e-05 4.006e-05 6.250e-03
  4 2.122e-06 2.122e-06 3.125e-03
  5 5.771e-08 5.771e-08 1.563e-03
  6 7.957e-10 7.957e-10 7.813e-04
  7 5.524e-12 5.524e-12 3.906e-04
  8 1.950e-14 1.950e-14 1.953e-04
  9 1.848e-15 1.848e-15 9.766e-05
 10 0.000e+00 0.000e+00 4.883e-05


DeviceArray([3., 5., 1.], dtype=float64)

In [239]:
def rosen(x):
    """Rosenbrock function of a 1-D array ``x``.

    Sums ``100*(x[i+1] - x[i]**2)**2 + (1 - x[i])**2`` over consecutive pairs.
    """
    head = x[:-1]
    tail = x[1:]
    valley = 100.0*(tail - head**2.0)**2.0
    offset = (1 - head)**2.0
    return jnp.sum(valley + offset)

# Compiled gradient of the Rosenbrock function.
rosen_grad = jax.jit(jax.grad(rosen))

In [242]:
# Starting point for the Rosenbrock problem: seven components, all equal to 0.7.
rosen_x0 = jnp.full((7,), 0.7)

In [243]:
# Drive the gradient of rosen towards the all-zero target with LM, i.e. search for a
# stationary point of rosen starting from rosen_x0.
LM(rosen_grad,rosen_x0,jnp.zeros_like(rosen_x0))

  0 5.986e+01 5.986e+01 5.000e-02
  1 5.357e+01 5.357e+01 2.500e-02
  2 3.391e+01 3.391e+01 1.250e-02
  3 5.838e+00 5.838e+00 6.250e-03
  4 1.581e+00 1.581e+00 3.125e-03
  5 1.527e+00 1.527e+00 1.563e-03
  6 1.512e+00 1.512e+00 7.813e-04
  7 1.485e+00 1.485e+00 3.906e-04
  8 1.432e+00 1.432e+00 1.953e-04
  9 1.339e+00 1.339e+00 9.766e-05
 10 1.297e+00 1.297e+00 4.883e-05
 11 1.904e+00 1.297e+00 9.766e-05
 12 1.131e+00 1.131e+00 4.883e-05
 13 1.370e+00 1.131e+00 9.766e-05
 14 9.507e-01 9.507e-01 4.883e-05
 15 1.031e+00 9.507e-01 9.766e-05
 16 8.363e-01 8.363e-01 4.883e-05
 17 8.532e-01 8.363e-01 9.766e-05
 18 7.617e-01 7.617e-01 4.883e-05
 19 7.495e-01 7.495e-01 2.441e-05
 20 1.009e+00 7.495e-01 4.883e-05
 21 6.371e-01 6.371e-01 2.441e-05
 22 7.430e-01 6.371e-01 4.883e-05
 23 5.531e-01 5.531e-01 2.441e-05
 24 5.976e-01 5.531e-01 4.883e-05
 25 4.989e-01 4.989e-01 2.441e-05
 26 5.126e-01 4.989e-01 4.883e-05
 27 4.594e-01 4.594e-01 2.441e-05
 28 4.568e-01 4.568e-01 1.221e-05
 29 6.161e-01 

DeviceArray([1., 1., 1., 1., 1., 1., 1.], dtype=float64)