<a href="https://colab.research.google.com/github/profteachkids/CHE2064/blob/master/NewtonTrust.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
import numpy as np
from functools import partial
from scipy.optimize._lsq import common
import scipy
from scipy.optimize import root
from scipy.sparse.linalg import spsolve, bicgstab, LinearOperator

In [111]:
def get_J(f):
    def J(x,v):
        return jax.jvp(f, (x,),(v,))[1]
    return jax.jit(J)

def get_line_f(func):
    def line_f(t,x,dx):
        return func(x+t*dx)
    return jax.jit(line_f)

def getLO(f):
    def mv(x,v):
        return jax.jvp(f,(x,),(jnp.squeeze(v),))[1]

    def rmv(x,v):
        return jax.vjp(f,x)[1](v)

    jit_mv = jax.jit(mv)
    jit_rmv = jax.jit(rmv)

    def LO(x):
        # return LinearOperator((x.size,x.size), matvec = partial(jit_mv,x), rmatvec=partial(jit_rmv, x))
        return LinearOperator((x.size,x.size), matvec = partial(jit_mv,x))

    return LO

In [183]:
def get_scaled(f):
    jac=jax.jit(jax.jacobian(f))
    def scaled(x):
        Jrow_max=jnp.max(jnp.abs(jac(jax.lax.stop_gradient(x))),axis=1)
        return f(x)/Jrow_max

    return scaled

In [251]:
rng=np.random.RandomState(123)
N=300
N_offdiagonal = int(3*N)
xguess=jnp.asarray(np.random.uniform(size=N))
triples = rng.randint(0,N,size=(N_offdiagonal,3))
eq = rng.randint(0,N,N_offdiagonal)
coeff = rng.uniform(-1,1,size=N)

def f_orig(x):
    x=jnp.asarray(x)+coeff
    res=x**2
    res=res.at[eq].add(-x[triples[:,0]]*x[triples[:,1]]*x[triples[:,2]])
    return res

func=f_orig

In [252]:

jit_f = jax.jit(func)
line_f = get_line_f(func)
J = get_J(func)

fsq = jax.jit(lambda x: jnp.sum(func(x)**2))
fsq_grad = jax.jit(jax.grad(fsq))
def fsq_hvp(x,v):
    return jax.grad(lambda x: jnp.vdot(jax.grad(fsq)(x), v))(x)

fsq_hvp_jit=jax.jit(fsq_hvp)

def fsq_LO(x):
    return LinearOperator((N,N), matvec=partial(fsq_hvp_jit,x))

In [255]:
x1=xguess
f1 = fsq(x1)
LO=getLO(func)
trust_radius=1.
for i in range(500):
    hvp=fsq_LO(x1)
    grad=fsq_grad(x1)
    dx=bicgstab(hvp, -grad,atol=1e-14)[0]


    while True:
        dx_norm=np.linalg.norm(dx)
        if dx_norm>trust_radius:
            dx=dx*trust_radius/dx_norm

        x2=x1+dx
        f2=fsq(x2)


        exp_deltaf=jnp.dot(grad,dx) + jnp.dot(dx,hvp(dx))/2

        rho = (f2-f1)/exp_deltaf
        if exp_deltaf<0:
            if rho>0.75 and dx_norm>trust_radius:
                trust_radius*=2
                break
            elif rho>0.75:
                trust_radius=dx_norm*2
                break
            elif rho>0.25:
                break
            trust_radius*=0.5

        if trust_radius<0.05 or exp_deltaf<0:
            if trust_radius<0.05:
                trust_radius*=10
            dx=-grad
            for i in range(32):
                x2=x1+dx
                f2=fsq(x2)
                if f2 > f1:
                    dx*=0.5
                else:
                    break
            if i==31:
                print('grad failed')
                break
                
            print(f'{i:3d}, grad dx: {np.linalg.norm(dx):7.2e}')
            break



    print(f'{i:3d}, {jnp.sum(f_orig(x2)**2):8.3e}, {f2:8.3e}, {(f2-f1): 7.3e}, {exp_deltaf: 7.3e}, {rho: 7.2f}, {np.linalg.norm(dx):7.2e}, {trust_radius: 7.3e}')
    x1=x2
    f1=f2
    if f1<1e-15:
        break
    

print(x1)
print(func(x1))

  0, 3.818e+02, 3.818e+02, -3.536e+01, -3.541e+01,    1.00, 1.00e+00,  2.000e+00
  1, 3.368e+02, 3.368e+02, -4.501e+01, -4.553e+01,    0.99, 2.00e+00,  4.000e+00
  2, 2.584e+02, 2.584e+02, -7.844e+01, -8.649e+01,    0.91, 4.00e+00,  8.000e+00
  3, grad dx: 1.53e+01
  3, 1.739e+02, 1.739e+02, -8.450e+01, -8.315e+01,   -0.41, 1.53e+01,  4.000e+00
  4, 1.248e+02, 1.248e+02, -4.909e+01, -5.622e+01,    0.87, 4.00e+00,  8.000e+00
  3, grad dx: 9.22e+00
  3, 3.748e+01, 3.748e+01, -8.728e+01, -9.145e+00,  -97.81, 9.22e+00,  4.000e+00
  4, grad dx: 2.83e+00
  4, 7.092e+00, 7.092e+00, -3.039e+01, -1.270e+00,   -4.24, 2.83e+00,  2.000e+00
  7, 6.743e+00, 6.743e+00, -3.488e-01, -1.113e+00,    0.31, 2.00e+00,  2.000e+00
  1, grad dx: 2.89e+00
  1, 1.154e+00, 1.154e+00, -5.589e+00, -5.067e-01,   -0.16, 2.89e+00,  1.000e+00
  9, 1.124e+00, 1.124e+00, -2.988e-02, -8.098e-02,    0.37, 1.00e+00,  1.000e+00
  0, grad dx: 1.42e+00
  0, 2.723e-01, 2.723e-01, -8.520e-01, -1.364e-01,    0.15, 1.42e+00,  5.00

In [234]:
def f(x):
    return jnp.array([jnp.sin(x[0])*jnp.cos(x[1]),jnp.cos(2*x[0])+jnp.sin(x[1])])

In [28]:
x0=jnp.array([0.5,1])
jax.jacobian(f)(x0)

DeviceArray([[ 0.47415988, -0.40342268],
             [-1.68294197,  0.54030231]], dtype=float64)

In [29]:
f0,vjp=jax.vjp(f,x0)

In [31]:
vjp(jnp.array([1.,0.]))

(DeviceArray([ 0.47415988, -0.40342268], dtype=float64),)

In [76]:
bicgstab?