<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/NewtonTrust.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
import numpy as np
from functools import partial
from scipy.optimize._lsq import common
import scipy
from scipy.optimize import root,minimize
from scipy.sparse.linalg import spsolve, bicgstab, LinearOperator

In [2]:
def get_J(f):
    def J(x,v):
        return jax.jvp(f, (x,),(v,))[1]
    return jax.jit(J)

def get_line_f(func):
    def line_f(t,x,dx):
        return func(x+t*dx)
    return jax.jit(line_f)

def getLO(f):
    def mv(x,v):
        return jax.jvp(f,(x,),(jnp.squeeze(v),))[1]

    def rmv(x,v):
        return jax.vjp(f,x)[1](v)

    jit_mv = jax.jit(mv)
    jit_rmv = jax.jit(rmv)

    def LO(x):
        # return LinearOperator((x.size,x.size), matvec = partial(jit_mv,x), rmatvec=partial(jit_rmv, x))
        return LinearOperator((x.size,x.size), matvec = partial(jit_mv,x))

    return LO

In [354]:
def get_scaled(f):
    def scaled(x,scale):
        return jnp.sum((f(x)/scale)**2)

    return jax.jit(scaled)

In [365]:
rng=np.random.RandomState(123)
N=100
N_offdiagonal = int(3*N)
xguess=jnp.asarray(np.random.uniform(size=N))
triples = rng.randint(0,N,size=(N_offdiagonal,3))
eq = rng.randint(0,N,N_offdiagonal)
coeff = rng.uniform(-1,1,size=N)
mult = np.c_[1e12*(1+rng.uniform(size=N)),np.ones(N)][np.arange(N),rng.choice([0,1],size=N,p=[0.1,0.9])]

def f_orig(x):
    x=jnp.asarray(x)+coeff
    res=x**2
    res=res.at[eq].add(-x[triples[:,0]]*x[triples[:,1]])
    return res*mult

fsq=get_scaled(f_orig)
fsq_grad = jax.jit(jax.grad(fsq))
def fsq_hvp(x,scale,v):
    return jax.grad(lambda x: jnp.vdot(jax.grad(fsq)(x,scale), v))(x)

fsq_hvp_jit=jax.jit(fsq_hvp)

def fsq_LO(x, scale):
    return LinearOperator((N,N), matvec=lambda v: fsq_hvp_jit(x=x,v=v,scale=scale))

## Dog Leg

In [None]:
x1=xguess
f1 = fsq(x1, 1)
trust_radius=1.
scale=np.ones_like(x1)
for i in range(500):
    lr=0.1
    scale=(1-lr)*scale+lr*jnp.abs(f_orig(x1))
    hvp=fsq_LO(x=x1, scale=scale)
    grad=fsq_grad(x1, scale=scale)
    newton=bicgstab(hvp, -grad,atol=1e-15)[0]

    grad_direction_failed=False
    for j in range(32):
        grad_direction=False
        newton_norm=np.linalg.norm(newton)
        grad_norm=np.linalg.norm(grad)
        cauchy = jnp.sum(grad**2)/jnp.sum(((hvp(-grad))**2))

        if newton_norm<trust_radius:
            x2=x1+newton


        elif newton_norm>trust_radius and cauchy*grad_norm > trust_radius:
            x2=x1+grad*trust_radius/grad_norm

        else: 
            sd=cauchy*(-grad)
            res=root(lambda s: jnp.linalg.norm(sd+s*(newton-sd))-trust_radius,0.5)
            x2=x1+sd+res.x*(newton-sd)

        f2=fsq(x2, scale=scale)
        dx=x2-x1

        exp_deltaf=jnp.dot(grad,dx) + jnp.dot(dx,hvp(dx))/2
        rho = (f2-f1)/exp_deltaf

        if newton_norm>=trust_radius and rho>0.75:
            trust_radius*=2
        elif rho<0.25:
            trust_radius*=0.25

        if f2<f1:
            break
        

    print(f'{i:3d}, {j:3d}, {f2:8.3e}, {exp_deltaf: 7.3e}, {rho: 7.2f}, {np.linalg.norm(x2-x1):7.2e}, {trust_radius: 7.3e}')
    x1=x2
    f1=f2
    if f1<1e-24:
        break
    

# print(x1)
# print(func(x1))

In [370]:
print(f_orig(x1))

[ 6.71391269e-21  8.17575859e-22  1.67483475e-21  3.69561078e-21
  1.75946416e-21 -3.12045541e-21  2.84773176e-21 -4.60825143e-21
  4.17255915e-21  1.66730256e-21  1.47842972e-23 -5.86337749e-22
  1.08259868e-09 -4.16074687e-21  2.26985659e-21 -3.75746249e-21
  1.67281598e-21  5.34514072e-21  2.06172100e-09  3.87311650e-21
  2.03293252e-20 -8.45846485e-22  7.15515545e-21 -5.33008718e-21
 -3.18837178e-21  4.31256253e-21 -4.29809524e-22  2.88825457e-21
 -1.32336411e-10  1.93200777e-21  1.12537970e-20 -9.64669604e-21
  2.38210397e-21  3.97339326e-21  3.18044355e-21  6.26043369e-21
  6.94344088e-22 -6.15286699e-22 -4.14388476e-21  4.49229596e-21
  4.76720991e-21  3.25149528e-21 -8.49207357e-21  5.98024401e-22
  1.18738819e-21  7.48388193e-22 -1.30409733e-21  4.04691881e-21
  7.50236848e-10  2.84381740e-21 -6.77503170e-21  2.79472115e-21
 -5.25715920e-25  1.11183306e-23 -6.45301773e-10  4.86261513e-21
 -3.77359330e-22  1.37259858e-22 -4.45395322e-21  4.10281872e-21
  4.97302106e-21  7.98537

## Line Search

In [None]:
x1=xguess
f1 = fsq(x1, 1)
for i in range(100):
    gradient_direction=False
    hvp=fsq_LO(x=x1, scale=1.)
    grad=fsq_grad(x1, scale=1)
    dx=bicgstab(hvp, -grad,atol=1e-14)[0]
    alpha, fc, f2 = scipy.optimize.linesearch.line_search_armijo(lambda x: fsq(x,scale=1), x1, dx, grad, f1, c1=0.5)
    if alpha is None or alpha <1e-10:
        dx=-grad
        alpha, fc, f2 = scipy.optimize.linesearch.line_search_armijo(lambda x: fsq(x, scale=1), x1, dx, grad, f1, c1=0.5)
        gradient_direction=True
    x2 = x1 + alpha*dx
    f2 = fsq(x2, scale=1)
    print(f'{i:3d}, {fc:3d},{f2:9.3e}, {np.linalg.norm(alpha*dx):9.2e}, {alpha: 9.3e}',end='')
    if gradient_direction:
        print(' gradient direction', end='')
    print()
    x1=x2
    f1=f2
    if f1<1e-15:
        break