<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/NewtonTrust.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
import numpy as np
from functools import partial
from scipy.optimize._lsq import common
import scipy
from scipy.optimize import root
from scipy.sparse.linalg import spsolve, bicgstab, LinearOperator

In [111]:
def get_J(f):
    """Build a jitted Jacobian-vector product for ``f``.

    Parameters
    ----------
    f : callable
        Function to differentiate.

    Returns
    -------
    callable
        ``J(x, v)``: the tangent output of ``jax.jvp`` for ``f`` at the
        primal point ``x`` with tangent ``v``.
    """
    def J(x,v):
        # jax.jvp returns (primal_out, tangent_out); only the tangent is needed.
        _, tangent_out = jax.jvp(f, (x,), (v,))
        return tangent_out
    return jax.jit(J)

def get_line_f(func):
    """Build a jitted evaluation of ``func`` along a line.

    Parameters
    ----------
    func : callable
        Function of a single point argument.

    Returns
    -------
    callable
        ``line_f(t, x, dx)`` which evaluates ``func`` at ``x + t*dx``.
    """
    @jax.jit
    def line_f(t,x,dx):
        point = x + t*dx
        return func(point)
    return line_f

def getLO(f):
    """Build a factory of matrix-free Jacobian operators for ``f``.

    Parameters
    ----------
    f : callable
        Function whose Jacobian is wrapped. The operator is declared square,
        so ``f`` is expected to return as many values as it takes.

    Returns
    -------
    callable
        ``LO(x)`` returning a ``scipy.sparse.linalg.LinearOperator`` of shape
        ``(x.size, x.size)`` whose ``matvec`` is the Jacobian-vector product
        (forward mode) and whose ``rmatvec`` is the transposed-Jacobian-vector
        product (reverse mode), both evaluated at ``x``.
    """
    def mv(x,v):
        # SciPy may hand over a (n,) or (n,1) vector of any dtype; jax.jvp
        # needs a tangent with exactly the primal's shape and dtype.
        tangent = jnp.reshape(jnp.asarray(v, dtype=x.dtype), x.shape)
        return jax.jvp(f,(x,),(tangent,))[1]

    def rmv(x,v):
        out, pullback = jax.vjp(f,x)
        # The cotangent must match the output of f in shape and dtype.
        cotangent = jnp.reshape(jnp.asarray(v, dtype=out.dtype), out.shape)
        # The pullback returns one cotangent per primal argument (a 1-tuple here).
        return pullback(cotangent)[0]

    jit_mv = jax.jit(mv)
    jit_rmv = jax.jit(rmv)

    def LO(x):
        x = jnp.asarray(x)
        n = x.size
        # Passing dtype explicitly stops SciPy from probing matvec with a
        # test vector just to discover the operator's dtype.
        return LinearOperator((n,n),
                              matvec=partial(jit_mv,x),
                              rmatvec=partial(jit_rmv,x),
                              dtype=np.dtype(x.dtype))

    return LO

In [183]:
def get_scaled(f):
    """Wrap ``f`` so each output is divided by a row scale of its Jacobian.

    Parameters
    ----------
    f : callable
        Vector-valued function of a vector argument.

    Returns
    -------
    callable
        ``scaled(x)`` returning ``f(x)`` divided element-wise by the largest
        absolute entry of the corresponding Jacobian row at ``x``. The
        Jacobian is evaluated at ``stop_gradient(x)``, so the scale factors
        are treated as constants when ``scaled`` is differentiated.
    """
    jacobian_fn=jax.jit(jax.jacobian(f))

    def scaled(x):
        jac_at_x = jacobian_fn(jax.lax.stop_gradient(x))
        row_scale = jnp.max(jnp.abs(jac_at_x), axis=1)
        return f(x)/row_scale

    return scaled

In [230]:
# Random test problem: N unknowns, each residual gets a squared term plus
# randomly placed bilinear coupling terms (see f_orig).
rng=np.random.RandomState(123)
N=200
N_offdiagonal = int(3*N)
triples = rng.randint(0,N,size=(N_offdiagonal,3))
eq = rng.randint(0,N,N_offdiagonal)
coeff = rng.uniform(-1,1,size=N)
# The initial guess comes from the seeded generator too, so the whole run is
# reproducible. It is drawn last so triples/eq/coeff keep the values they had
# when the guess was taken from the global (unseeded) numpy generator.
xguess=jnp.asarray(rng.uniform(size=N))

def f_orig(x):
    """Residual vector of the random test system.

    With ``y = x + coeff`` the result starts as ``y**2``; then, for every row
    ``k`` of ``triples``, the product ``y[triples[k,0]] * y[triples[k,1]]`` is
    subtracted from entry ``eq[k]`` (repeated indices in ``eq`` accumulate).
    Only the first two columns of ``triples`` are used.
    """
    shifted=jnp.asarray(x)+coeff
    cross_terms=shifted[triples[:,0]]*shifted[triples[:,1]]
    return (shifted**2).at[eq].add(-cross_terms)

# Function the solver cells below operate on.
func=f_orig

In [231]:

# Jitted residual, line evaluation and Jacobian-vector product of func.
jit_f = jax.jit(func)
line_f = get_line_f(func)
J = get_J(func)

def _sum_sq_residuals(x):
    """Scalar merit function: sum of squared residuals of func at x."""
    return jnp.sum(func(x)**2)

# Merit function and its gradient, both jitted.
fsq = jax.jit(_sum_sq_residuals)
fsq_grad = jax.jit(jax.grad(fsq))
def fsq_hvp(x,v):
    """Hessian-vector product of ``fsq`` at ``x`` applied to ``v``.

    Computed as the gradient with respect to the point of the scalar
    ``vdot(grad fsq(point), v)``, with ``v`` held fixed.
    """
    def grad_dot_v(point):
        return jnp.vdot(jax.grad(fsq)(point), v)

    return jax.grad(grad_dot_v)(x)

fsq_hvp_jit=jax.jit(fsq_hvp)

def fsq_LO(x):
    """Matrix-free Hessian of ``fsq`` at ``x`` as a SciPy ``LinearOperator``.

    The operator is square with one row/column per entry of ``x``; its
    ``matvec`` is the jitted Hessian-vector product ``fsq_hvp_jit(x, v)``.
    """
    # Size the operator from x itself rather than from the module-level N,
    # so it stays correct if the problem size and N ever disagree.
    n = np.size(x)
    return LinearOperator((n,n), matvec=partial(fsq_hvp_jit,x))

In [235]:
# Trust-region Newton minimization of fsq(x) = sum(func(x)**2), starting at xguess.
x1=xguess
f1 = fsq(x1)
LO=getLO(func)  # Jacobian operator factory; not used by the loop below
trust_radius=1.
for i in range(500):
    hvp=fsq_LO(x1)
    grad=fsq_grad(x1)
    # Newton step: solve H dx = -grad matrix-free with BiCGSTAB.
    dx=bicgstab(hvp, -grad,atol=1e-14)[0]


    while True:
        # Clip the step to the current trust radius. dx_norm keeps the
        # pre-clipping length so the tests below can tell whether the step
        # was limited by the radius.
        dx_norm=np.linalg.norm(dx)
        if dx_norm>trust_radius:
            dx=dx*trust_radius/dx_norm

        x2=x1+dx
        f2=fsq(x2)


        # Change in fsq predicted by the local quadratic model.
        exp_deltaf=jnp.dot(grad,dx) + jnp.dot(dx,hvp(dx))/2

        # rho = actual change / predicted change. The ratio only measures
        # model quality when the model predicts a decrease: if the predicted
        # change is non-negative (the Hessian of fsq may be indefinite), an
        # uphill step would produce a large positive rho and be accepted.
        # Such steps are rejected instead.
        if exp_deltaf<0:
            rho = (f2-f1)/exp_deltaf
        else:
            rho = -np.inf

        if rho>0.75 and dx_norm>trust_radius:
            # Good agreement and the step was clipped: enlarge the region.
            trust_radius*=2
            break
        elif rho>0.75:
            # Good agreement with an interior step: size the region from it.
            trust_radius=dx_norm*2
            break
        elif rho>0.25:
            # Acceptable agreement: take the step, keep the radius.
            break
        # Poor agreement: shrink the region and retry with a shorter step.
        trust_radius*=0.5
        if trust_radius<0.05:
            # The model is unreliable even for small steps: fall back to a
            # steepest-descent step, halved until fsq no longer increases.
            trust_radius*=10
            dx=-grad
            while True:
                x2=x1+dx
                f2=fsq(x2)
                if f2 > f1:
                    dx*=0.5
                else:
                    break
            print(f'grad dx: {np.linalg.norm(dx):7.2e}')
            break



    # Columns: iteration, sum of squares of f_orig, fsq, actual change,
    # predicted change, rho, step length, trust radius. After a gradient
    # fallback the predicted change and rho still refer to the last rejected
    # trust-region step.
    print(f'{i:3d}, {jnp.sum(f_orig(x2)**2):8.3e}, {f2:8.3e}, {(f2-f1): 7.3e}, {exp_deltaf: 7.3e}, {rho: 7.2f}, {np.linalg.norm(dx):7.2e}, {trust_radius: 7.3e}')
    x1=x2
    f1=f2
    if f1<1e-15:
        break
    

print(x1)
print(func(x1))

  0, 2.839e+02, 2.839e+02, -1.214e+02, -1.205e+02,    1.01, 1.00e+00,  2.000e+00
  1, 1.245e+02, 1.245e+02, -1.594e+02, -1.524e+02,    1.05, 2.00e+00,  4.000e+00
  2, 2.460e+01, 2.460e+01, -9.994e+01, -8.303e+01,    1.20, 2.91e+00,  5.827e+00
  3, 4.859e+00, 4.859e+00, -1.974e+01, -1.640e+01,    1.20, 1.94e+00,  3.886e+00
  4, 9.599e-01, 9.599e-01, -3.900e+00, -3.240e+00,    1.20, 1.29e+00,  2.590e+00
  5, 1.896e-01, 1.896e-01, -7.703e-01, -6.399e-01,    1.20, 8.64e-01,  1.728e+00
  6, 3.745e-02, 3.745e-02, -1.522e-01, -1.264e-01,    1.20, 5.76e-01,  1.151e+00
  7, 7.398e-03, 7.398e-03, -3.005e-02, -2.497e-02,    1.20, 3.84e-01,  7.680e-01
  8, 1.461e-03, 1.461e-03, -5.937e-03, -4.932e-03,    1.20, 2.56e-01,  5.118e-01
  9, 2.887e-04, 2.887e-04, -1.173e-03, -9.742e-04,    1.20, 1.71e-01,  3.413e-01
 10, 5.702e-05, 5.702e-05, -2.316e-04, -1.924e-04,    1.20, 1.14e-01,  2.274e-01
 11, 1.126e-05, 1.126e-05, -4.576e-05, -3.801e-05,    1.20, 7.58e-02,  1.516e-01
 12, 2.225e-06, 2.225e-06, -

In [233]:
# Cross-check: solve func(x) = 0 from the same initial guess with SciPy's
# root finder (method 'hybr'), supplying the dense Jacobian computed by JAX.
root(
    jit_f,
    xguess,
    jac=jax.jit(jax.jacobian(func)),
    method='hybr',
)

    fjac: array([[-4.35253611e-01,  1.16941707e-03,  9.71994437e-05, ...,
         3.22839674e-04,  7.41900385e-04,  1.81337700e-03],
       [-1.60918746e-02, -6.76288131e-01,  8.86013558e-04, ...,
         2.94281857e-03,  6.76273212e-03,  1.65296893e-02],
       [ 5.41117229e-02,  3.14656793e-02, -5.94600412e-01, ...,
         1.17006929e-03,  2.68887292e-03, -7.25025898e-02],
       ...,
       [-4.08071101e-02, -4.24927717e-02,  1.98221477e-02, ...,
        -3.79358541e-05, -2.09787426e-02,  1.77208550e-02],
       [ 6.33103515e-03, -2.99912728e-02, -9.74441926e-03, ...,
         1.00495778e-02, -1.02471816e-02, -3.09213125e-03],
       [-3.52899646e-02, -1.03527736e-01, -1.20816106e-01, ...,
        -1.34751518e-02, -3.69302889e-02, -8.79573544e-02]])
     fun: array([-5.26996063e-29,  5.72417194e-29,  4.61699334e-29,  1.07698002e-30,
       -9.11558050e-29,  7.46952670e-30, -8.81752358e-29,  7.32469676e-29,
       -1.00022016e-28,  3.98320831e-29,  1.24461297e-29,  2.13007852e-29

In [234]:
def f(x):
    """Small two-variable test function.

    Returns the length-2 array
    ``[sin(x[0])*cos(x[1]), cos(2*x[0]) + sin(x[1])]``.
    """
    a, b = x[0], x[1]
    first = jnp.sin(a)*jnp.cos(b)
    second = jnp.cos(2*a)+jnp.sin(b)
    return jnp.array([first, second])

In [28]:
# Evaluation point for the small test function f.
x0=jnp.array([0.5,1])
# Full 2x2 Jacobian of f at x0, for comparison with the vjp results.
jax.jacobian(f)(x0)

DeviceArray([[ 0.47415988, -0.40342268],
             [-1.68294197,  0.54030231]], dtype=float64)

In [29]:
# jax.vjp returns the primal output f(x0) together with a pullback function
# that maps an output cotangent to a tuple of input cotangents.
f0,vjp=jax.vjp(f,x0)

In [31]:
# Pull back the unit cotangent [1, 0]: the result is a 1-tuple whose entry is
# the first row of the Jacobian of f at x0.
vjp(jnp.array([1.,0.]))

(DeviceArray([ 0.47415988, -0.40342268], dtype=float64),)

In [76]:
# Show the documentation of SciPy's BiCGSTAB solver. help() is plain Python,
# so this cell also runs outside IPython (the `name?` form is IPython-only).
help(bicgstab)