<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/LevenbergMarquardt_LBFGSB_MatrixFree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg, bicgstab, bicg
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
from functools import partial
from scipy.optimize import minimize

In [7]:
x_data = np.linspace(0,1,10)
y_data = 3*np.sin(5*x_data+1)

In [8]:
def f(p):
    a,b,c = p
    return a*jnp.sin(b*x_data+c)

In [9]:
x0 = jnp.array([3.,5.0,0.99])
v0 = jnp.array([0.5,-0.8,0.2])
f(x0)

DeviceArray([ 2.50807794,  2.99904441,  2.58794368,  1.39842864,
             -0.21171258, -1.75817386, -2.77580304, -2.95851265,
             -2.2513464 , -0.86700921], dtype=float64)

In [5]:
x1=x0
L=0.1
for i in range(20):
    print(f'{i:3d} {jnp.linalg.norm(y_data-f(x1)):9.3e}')
    J = jax.jacobian(f)(x1)
    dx = np.linalg.solve(J.T@J + L*jnp.diag(jnp.diagonal(J.T@J)), J.T@(y_data-f(x1)))
    x1=x1+dx

  0 6.378e-02


KeyboardInterrupt: ignored

In [None]:
def LM_iterative(f,x1,y_data=0.,max_iter=100, tol=1e-20, cgtol=1e-25):

    zeros=jnp.zeros(x1.size)
    idx=jnp.arange(x1.size)

    def JTf(x):
        fval, fvjp = jax.vjp(f,x)
        return fvjp(y_data-f(x))[0]

    JTf_jit=jax.jit(JTf)

    def JTJv(x,v):
        Av=(jax.jvp(f,(x,),(v,))[1])
        fval, fvjp=jax.vjp(f,x)
        return fvjp(Av.T)[0]

    JTJv_jit=jax.jit(JTJv)

    def JTJ_diag(x,v):
        def JTJ_diag_i(i, _):
            return i+1, JTJv_jit(x, zeros.at[i].set(v[i]))[i]

        return jax.lax.scan(JTJ_diag_i, 0, idx)[1]

    JTJ_diag_jit=jax.jit(JTJ_diag)

    def LM_LO(x,L):
        return LinearOperator((x.size,x.size), matvec=lambda v: JTJv_jit(x,v) + L*JTJ_diag_jit(x,v))

    L=0.1
    f1=jnp.linalg.norm(y_data-f(x1))
    for i in range(max_iter):
        LO=LM_LO(x1, L)
        dx = bicgstab(LO, JTf_jit(x1), tol=cgtol)[0]
        x2=x1+dx
        f2=jnp.linalg.norm(y_data-f(x2))
        if f2>f1:
            L=L*1.5
        else:
            L=L/2
            x1=x2
            f1=f2
        print(f'{i:3d} {f2:9.3e} {f1:9.3e} {L:9.3e}')
        if f2<tol:
            break
    return x1

In [None]:
def LM_full(f,x1,y_data=0.,max_iter=100, tol=1e-20, cgtol=1e-25, L=1.):

    jac = jax.jacobian(f)
    fjit = jax.jit(f)
    f1=jnp.linalg.norm(y_data-fjit(x1))
    for i in range(max_iter):
        J=jac(x1)
        M = np.diag(1./((1.+L)*jnp.diagonal(J.T@J)))
        dx = bicgstab(np.array(J.T@J+ L*jnp.diag(jnp.diagonal(J.T@J))), J.T @ (y_data-fjit(x1)), tol=cgtol,M=M)[0]
        # dx = np.linalg.solve(J.T@J + L*jnp.diag(jnp.diagonal(J.T@J)), J.T@(y_data-f(x1)))
        x2=x1+dx
        f2=jnp.linalg.norm(y_data-fjit(x2))
        if f2>f1:
            L=L*1.5
        else:
            L=L/2
            x1=x2
            f1=f2
        print(f'{i:3d} {f2:9.3e} {f1:9.3e} {L:9.3e}')
        if f2<tol:
            break
    return x1

In [None]:
LM_full(f,x0,y_data, L=0.1)

In [None]:
def rosen(x):
    return jnp.sum(100.0*(x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0)
rosen_grad=jax.jit(jax.grad(rosen))

In [None]:
rosen_x0=jnp.full(10,0.8)

In [None]:
LM_full(rosen_grad,rosen_x0)

In [28]:
rng=np.random.RandomState(123)
N=1000
N_offdiagonal = int(3*N)
xguess=jnp.asarray(np.random.uniform(size=N))
triples = rng.randint(0,N,size=(N_offdiagonal,3))
eq = rng.randint(0,N,N_offdiagonal)
coeff = rng.uniform(-1,1,size=N)
mult = np.c_[1e12*(1+rng.uniform(size=N)),np.ones(N)][np.arange(N),rng.choice([0,1],size=N,p=[0.2,0.8])]

def f_orig(x):
    x=jnp.asarray(x)+coeff
    res=x**2
    res=res.at[eq].add(-x[triples[:,0]]*x[triples[:,1]]*x[triples[:,2]])
    return res*mult

In [29]:
_, f_vjp= jax.vjp(f_orig,xguess)

In [30]:
idx=jnp.zeros(xguess.size)

def vjp_max(carry, _):
    return carry+1, jnp.max(jnp.abs(f_vjp(idx.at[carry].set(1.))[0]))

fac=jax.lax.scan(vjp_max,0,jnp.arange(xguess.size))[1]

In [31]:
sqerr = jax.jit(lambda x: jnp.sum((f_orig(x)/fac)**2))
jacobian = jax.jit(jax.jacobian(sqerr))
jac=lambda x: np.array(jacobian(x))
res=minimize(sqerr, xguess, jac=jac, method='L-BFGS-B', tol=1e-25)

In [32]:
res.fun

DeviceArray(9.5420855e-26, dtype=float64)

In [None]:
f_orig(res.x)