<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/LevenbergMarquardtMatrixFree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg, bicgstab, bicg
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
from functools import partial


In [2]:
x_data = np.linspace(0,1,10)
y_data = 3*np.sin(5*x_data+1)

In [93]:
def f(p):
    a,b,c = p
    return a*jnp.sin(b*x_data+c)

In [94]:
x0 = jnp.array([3.,5.0,0.99])

v0 = jnp.array([0.5,-0.8,0.2])
f(x0)

DeviceArray([ 2.50807794,  2.99904441,  2.58794368,  1.39842864,
             -0.21171258, -1.75817386, -2.77580304, -2.95851265,
             -2.2513464 , -0.86700921], dtype=float64)

In [149]:
x1=x0
L=10
for i in range(10):
    print(f'{i:3d} {jnp.linalg.norm(y_data-f(x1)):9.3e}')
    J = jax.jacobian(f)(x1)
    dx = np.linalg.solve(J.T@J + L*jnp.diag(jnp.diagonal(J.T@J)), J.T@(y_data-f(x1)))
    x1=x1+dx





  0 6.378e-02
  1 5.424e-02
  2 4.631e-02
  3 3.975e-02
  4 3.435e-02
  5 2.991e-02
  6 2.630e-02
  7 2.337e-02
  8 2.102e-02
  9 1.914e-02


In [172]:
def get_JTJ(f,N):

    zeros=jnp.zeros(N)
    idx=jnp.arange(N)

    def JTf(x):
        fval, fvjp = jax.vjp(f,x)
        return fvjp(y_data-f(x))[0]

    def JTJv(x,v):
        Av=(jax.jvp(f,(x,),(v,))[1])
        fval, fvjp=jax.vjp(f,x)
        return fvjp(Av.T)[0]

    JTJv_jit=jax.jit(JTJv)
    def get_LO(x):
        return LinearOperator((x.size,x.size), matvec=lambda v: JTJv_jit(x,v))

    def JTJ_diag(x,v):
        def JTJ_diag_i(i, _):
            return i+1, JTJv_jit(x, zeros.at[i].set(v[i]))[i]

        return jax.lax.scan(JTJ_diag_i, 0, idx)[1]

    JTJ_diag_jit=jax.jit(JTJ_diag)

    def LM_LO(x,L):
        return LinearOperator((x.size,x.size), matvec=lambda v: JTJv_jit(x,v) + L*JTJ_diag_jit(x,v))

    
    
    return get_LO, JTJv, JTJ_diag, LM_LO, jax.jit(JTf)

In [177]:
L=0.1
x1=x0
f1=jnp.linalg.norm(y_data-f(x1))
N=x1.size
print(f1)
for i in range(10):
    LO=LM_LO(x1, L)
    dx = bicgstab(LO, JTf(x1), atol=1e-25)[0]
    x2=x1+dx
    f2=jnp.linalg.norm(y_data-f(x2))
    if f2>f1:
        L=L*2
    else:
        L=L/2
        x1=x2
        f1=f2
    print(f'{i:3d} {f2:9.3e} {f1:9.3e} {L:9.3e}')
    if f2<1e-25:
        break

0.06378430618993293
  0 7.747e-03 7.747e-03 5.000e-02
  1 2.180e-03 2.180e-03 2.500e-02
  2 3.982e-04 3.982e-04 1.250e-02
  3 4.006e-05 4.006e-05 6.250e-03
  4 2.122e-06 2.122e-06 3.125e-03
  5 5.771e-08 5.771e-08 1.563e-03
  6 7.957e-10 7.957e-10 7.813e-04
  7 5.524e-12 5.524e-12 3.906e-04
  8 1.950e-14 1.950e-14 1.953e-04
  9 1.848e-15 1.848e-15 9.766e-05


In [None]:
def f(x):
    return 3*jnp.sin(2*x[0]*x[1]+x[2])+ jnp.cos(3*x[1]*x[2])*jnp.exp(1.5*x[2]*x[0])

def rosen(x):
    return jnp.sum(100.0*(x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0)

In [None]:
x0=jnp.full(10,0.9)
h0=jax.hessian(rosen)(x0)
diag=jnp.diag(h0)
diag @ x0

DeviceArray(6593.4, dtype=float64, weak_type=True)

In [None]:
rosen(x0)

DeviceArray(7.38, dtype=float64)

In [None]:
h0 @ x0

DeviceArray([ 228.6,   84.6,   84.6,   84.6,   84.6,   84.6,   84.6,
               84.6,   84.6, -144. ], dtype=float64, weak_type=True)

In [None]:
def get_LM_LO(f, N):
    zeros=jnp.zeros(N)
    idx=jnp.arange(N)
    grad=jax.jit(jax.grad(f))

    def hvp(x,v):
        return jax.grad(lambda x: jnp.vdot(grad(x), v))(x)

    def hvp_diag(x,v):
        def hvp_diag_i(carry, i):
            return carry+jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), zeros.at[i].set(v[i])))(x)[i], None

        return jax.lax.scan(hvp_diag_i, 0, idx)[0]

    def LM_LO(x,v,L):
        return hvp(x,v) + L*hvp_diag(x,v)

    return jax.jit(LM_LO), hvp, hvp_diag

In [None]:
LM_LO, hvp, hvp_diag=get_LM_LO(rosen,x0.size)

In [None]:
hvp(x0,x0)

DeviceArray([ 228.6,   84.6,   84.6,   84.6,   84.6,   84.6,   84.6,
               84.6,   84.6, -144. ], dtype=float64)

In [None]:
hvp_diag(x0,x0)

DeviceArray(6593.4, dtype=float64)

In [None]:
LM_LO(x0,x0,1)

DeviceArray([6822. , 6678. , 6678. , 6678. , 6678. , 6678. , 6678. ,
             6678. , 6678. , 6449.4], dtype=float64)

In [None]:
L=1.
x1=x0
f1=rosen(x1)
print(f1)
grad=jax.jit(jax.grad(rosen))
for i in range(100):
    LO=LinearOperator((x1.size,x1.size), matvec=lambda v: LM_LO(x1,v,L))
    # LO=LinearOperator((x1.size,x1.size), matvec=lambda v: hvp(x1,v))
    dx = bicgstab(LO, -grad(x1), atol=1e-25)[0]
    x2=x1+dx
    f2=rosen(x2)
    if f2>f1:
        L=L*2
    else:
        L=L/2
        x1=x2
        f1=f2
    print(f'{i:3d} {f2:9.3e} {L:9.3e}')
    if f2<1e-25:
        break
   
    
    

7.379999999999995
  0 4.503e+00 5.000e-01
  1 3.290e-01 2.500e-01
  2 1.173e-01 1.250e-01
  3 1.144e-01 6.250e-02
  4 1.129e-01 3.125e-02
  5 1.100e-01 1.562e-02
  6 1.045e-01 7.812e-03
  7 9.464e-02 3.906e-03
  8 7.832e-02 1.953e-03
  9 5.568e-02 9.766e-04
 10 3.189e-02 4.883e-04
 11 1.414e-02 2.441e-04
 12 4.482e-03 1.221e-04
 13 8.397e-04 6.104e-05
 14 6.428e-05 3.052e-05
 15 1.123e-06 1.526e-05
 16 3.333e-09 7.629e-06
 17 2.377e-12 3.815e-06
 18 4.396e-16 1.907e-06
 19 2.062e-20 9.537e-07
 20 2.435e-25 4.768e-07
 21 5.707e-30 2.384e-07


In [None]:
x1

DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float64)

In [None]:
bicgstab(LO, -grad(x0))

(array([-0.02821009,  0.02340543,  0.19104967]), 0)