<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/NewtonTrust.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
import numpy as np
from functools import partial
from scipy.optimize._lsq import common
import scipy
from scipy.optimize import root,minimize
from scipy.sparse.linalg import spsolve, bicgstab, LinearOperator

In [2]:
def get_J(f):
    def J(x,v):
        return jax.jvp(f, (x,),(v,))[1]
    return jax.jit(J)

def get_line_f(func):
    def line_f(t,x,dx):
        return func(x+t*dx)
    return jax.jit(line_f)

def getLO(f):
    def mv(x,v):
        return jax.jvp(f,(x,),(jnp.squeeze(v),))[1]

    def rmv(x,v):
        return jax.vjp(f,x)[1](v)

    jit_mv = jax.jit(mv)
    jit_rmv = jax.jit(rmv)

    def LO(x):
        # return LinearOperator((x.size,x.size), matvec = partial(jit_mv,x), rmatvec=partial(jit_rmv, x))
        return LinearOperator((x.size,x.size), matvec = partial(jit_mv,x))

    return LO

In [315]:
def get_scaled(f):
    def scaled(x,scale):
        return jnp.sum((f(x)/scale)**2)

    return jax.jit(scaled)

In [348]:
rng=np.random.RandomState(123)
N=1000
N_offdiagonal = int(3*N)
xguess=jnp.asarray(np.random.uniform(size=N))
triples = rng.randint(0,N,size=(N_offdiagonal,3))
eq = rng.randint(0,N,N_offdiagonal)
coeff = rng.uniform(-1,1,size=N)
mult = np.c_[1e9*(1+rng.uniform(size=N)),np.ones(N)][np.arange(N),rng.choice([0,1],size=N,p=[0.1,0.9])]

def f_orig(x):
    x=jnp.asarray(x)+coeff
    res=x**2
    res=res.at[eq].add(-x[triples[:,0]]*x[triples[:,1]])
    return res*mult

fsq=get_scaled(f_orig)
fsq_grad = jax.jit(jax.grad(fsq))
def fsq_hvp(x,scale,v):
    return jax.grad(lambda x: jnp.vdot(jax.grad(fsq)(x,scale), v))(x)

fsq_hvp_jit=jax.jit(fsq_hvp)

def fsq_LO(x, scale):
    return LinearOperator((N,N), matvec=lambda v: fsq_hvp_jit(x=x,v=v,scale=scale))

## Dog Leg

In [349]:
x1=xguess
f1 = fsq(x1, 1)
trust_radius=1.
scale=np.ones_like(x1)
for i in range(500):
    lr=0.1
    scale=(1-lr)*scale+lr*jnp.abs(f_orig(x1))
    hvp=fsq_LO(x=x1, scale=scale)
    grad=fsq_grad(x1, scale=scale)
    newton=bicgstab(hvp, -grad,atol=1e-15)[0]

    grad_direction_failed=False
    while True:
        grad_direction=False
        newton_norm=np.linalg.norm(newton)
        grad_norm=np.linalg.norm(grad)
        cauchy = jnp.sum(grad**2)/jnp.sum(((hvp(-grad))**2))

        if newton_norm<trust_radius:
            x2=x1+newton


        elif newton_norm>trust_radius and cauchy*grad_norm > trust_radius:
            x2=x1+grad*trust_radius/grad_norm

        else: 
            sd=cauchy*(-grad)
            res=root(lambda s: jnp.linalg.norm(sd+s*(newton-sd))-trust_radius,0.5)
            x2=x1+sd+res.x*(newton-sd)

        f2=fsq(x2, scale=scale)
        dx=x2-x1

        exp_deltaf=jnp.dot(grad,dx) + jnp.dot(dx,hvp(dx))/2
        rho = (f2-f1)/exp_deltaf

        if newton_norm>=trust_radius and rho>0.75:
            trust_radius*=2
        elif rho<0.25:
            trust_radius*=0.25

        if f2<f1:
            break
        

    print(f'{i:3d}, {f2:8.3e}, {exp_deltaf: 7.3e}, {rho: 7.2f}, {np.linalg.norm(x2-x1):7.2e}, {trust_radius: 7.3e}')
    x1=x2
    f1=f2
    if f1<1e-15:
        break
    

# print(x1)
# print(func(x1))

  0, 9.943e+03, -1.781e+03,  210664756092518592.00, 1.00e+00,  2.000e+00
  1, 2.557e+03, -1.106e+03,    6.68, 2.00e+00,  4.000e+00
  2, 7.534e+02, -9.351e+02,    1.93, 4.00e+00,  8.000e+00
  3, 1.325e+02, -4.584e+02,    1.35, 6.43e+00,  8.000e+00
  4, 2.775e+01, -9.685e+01,    1.08, 4.79e+00,  8.000e+00
  5, 2.169e+01, -2.134e+01,    0.28, 5.48e+00,  8.000e+00
  6, 4.974e+00, -1.672e+01,    1.00, 3.00e+00,  8.000e+00
  7, 1.204e+00, -3.962e+00,    0.95, 2.15e+00,  8.000e+00
  8, 2.901e-01, -9.769e-01,    0.94, 1.45e+00,  8.000e+00
  9, 7.042e-02, -2.370e-01,    0.93, 1.02e+00,  8.000e+00
 10, 1.725e-02, -5.769e-02,    0.92, 6.85e-01,  8.000e+00
 11, 4.215e-03, -1.417e-02,    0.92, 4.68e-01,  8.000e+00
 12, 1.039e-03, -3.464e-03,    0.92, 3.20e-01,  8.000e+00
 13, 2.526e-04, -8.586e-04,    0.92, 2.33e-01,  8.000e+00
 14, 6.278e-05, -2.078e-04,    0.91, 1.46e-01,  8.000e+00
 15, 1.545e-05, -5.167e-05,    0.92, 1.01e-01,  8.000e+00
 16, 3.766e-06, -1.274e-05,    0.92, 7.10e-02,  8.000e+00

KeyboardInterrupt: ignored

# previous TN

In [281]:
x1=xguess
f1 = fsq(x1, 1)
trust_radius=1.
for i in range(100):
    hvp=fsq_LO(x=x1, scale=1.)
    grad=fsq_grad(x1, scale=1)
    dx=bicgstab(hvp, -grad,atol=1e-14)[0]

    grad_direction_failed=False
    while True:
        grad_direction=False
        dx_norm=np.linalg.norm(dx)
        if dx_norm>trust_radius:
            dx=dx*trust_radius/dx_norm

        x2=x1+dx
        f2=fsq(x2, scale=1.)


        exp_deltaf=jnp.dot(grad,dx) + jnp.dot(dx,hvp(dx))/2

        rho = (f2-f1)/exp_deltaf

        if exp_deltaf<0:
            if rho>0.75 and dx_norm>trust_radius:
                trust_radius*=2
                break
            elif rho>0.75:
                trust_radius=dx_norm*2
                break
            elif rho>0.25:
                break
            trust_radius*=0.5

        else:
            print('grad')
            dx=-grad
            for j in range(32):
                x2=x1+dx
                f2=fsq(x2)
                if f2 > f1:
                    dx*=0.5
                else:
                    break
            if j==31:
                grad_direction_failed=True
                break
            grad_direction=True
            break



    print(f'{i:3d}, {jnp.sum(f_orig(x2)**2):8.3e}, {f2:8.3e}, {(f2-f1): 7.3e}, {exp_deltaf: 7.3e}, {rho: 7.2f}, {np.linalg.norm(dx):7.2e}, {trust_radius: 7.3e}',end='')
    if grad_direction:
        print(f',  grad: {j:3d}',end='')
    print('')
    x1=x2
    f1=f2
    if f1<1e-15:
        break
    

# print(x1)
# print(func(x1))

  0, 1.245e+02, 1.245e+02, -8.245e+01, -8.108e+01,    1.02, 1.00e+00,  2.000e+00
  1, 3.516e+01, 3.516e+01, -8.931e+01, -8.007e+01,    1.12, 2.00e+00,  4.000e+00
  2, 6.945e+00, 6.945e+00, -2.822e+01, -2.344e+01,    1.20, 1.79e+00,  3.587e+00
  3, 1.372e+00, 1.372e+00, -5.573e+00, -4.630e+00,    1.20, 1.20e+00,  2.391e+00
  4, 2.710e-01, 2.710e-01, -1.101e+00, -9.146e-01,    1.20, 7.97e-01,  1.594e+00
  5, 5.353e-02, 5.353e-02, -2.175e-01, -1.807e-01,    1.20, 5.31e-01,  1.063e+00
  6, 1.057e-02, 1.057e-02, -4.296e-02, -3.569e-02,    1.20, 3.54e-01,  7.087e-01
  7, 2.089e-03, 2.089e-03, -8.485e-03, -7.049e-03,    1.20, 2.36e-01,  4.724e-01
  8, 4.126e-04, 4.126e-04, -1.676e-03, -1.392e-03,    1.20, 1.57e-01,  3.150e-01
  9, 8.150e-05, 8.150e-05, -3.311e-04, -2.750e-04,    1.20, 1.05e-01,  2.100e-01
 10, 1.610e-05, 1.610e-05, -6.540e-05, -5.433e-05,    1.20, 7.00e-02,  1.400e-01
 11, 3.180e-06, 3.180e-06, -1.292e-05, -1.073e-05,    1.20, 4.67e-02,  9.332e-02
 12, 6.281e-07, 6.281e-07, -

## Line Search

In [None]:
x1=xguess
f1 = fsq(x1, 1)
for i in range(100):
    gradient_direction=False
    hvp=fsq_LO(x=x1, scale=1.)
    grad=fsq_grad(x1, scale=1)
    dx=bicgstab(hvp, -grad,atol=1e-14)[0]
    alpha, fc, f2 = scipy.optimize.linesearch.line_search_armijo(lambda x: fsq(x,scale=1), x1, dx, grad, f1, c1=0.5)
    if alpha is None or alpha <1e-10:
        dx=-grad
        alpha, fc, f2 = scipy.optimize.linesearch.line_search_armijo(lambda x: fsq(x, scale=1), x1, dx, grad, f1, c1=0.5)
        gradient_direction=True
    x2 = x1 + alpha*dx
    f2 = fsq(x2, scale=1)
    print(f'{i:3d}, {fc:3d},{f2:9.3e}, {np.linalg.norm(alpha*dx):9.2e}, {alpha: 9.3e}',end='')
    if gradient_direction:
        print(' gradient direction', end='')
    print()
    x1=x2
    f1=f2
    if f1<1e-15:
        break