In [1]:
import jax
import jax.numpy as jnp
from jax.config import config


from fax.implicit.twophase import two_phase_solver
from scipy.optimize import newton


# Use 64-bit floats everywhere so the Newton residual can be driven well below float32 precision.
jax.config.update("jax_enable_x64", True)


In [3]:
def phi(z):
    """Return the shifted cubic ``(z - 2)**3 + 0.4`` used by ``lighthouse``."""
    offset = z - 2
    return offset ** 3 + 0.4


# Lighthouse example; see Griewank & Walther, "Evaluating Derivatives".
def lighthouse(x, params):
    """Residual ``phi(z) - z * tan(t)`` evaluated at ``z = x[0]``.

    Only the first coordinate of ``x`` is read. The angle ``t`` is fixed at 2,
    and ``params`` is accepted (for the solver interface) but not used.
    NOTE(review): the driver cell passes ``params = [2]``, which suggests ``t``
    was meant to come from ``params[0]``; confirm before relying on gradients
    with respect to ``params``.
    """
    angle = 2
    slope = jnp.tan(angle)
    z = x[0]
    return phi(z) - z * slope


def grad_lighthouse(z, t=2):
    """Analytic derivative of the lighthouse residual with respect to ``z``.

    d/dz [ (z - 2)**3 + 0.4 - z * tan(t) ] = 3*z**2 - 12*z + 12 - tan(t)

    Args:
        z: point (scalar or array) at which to evaluate the derivative.
        t: angle. It defaults to 2, the value hard-coded in ``lighthouse``.
            The previous version read ``t`` as a free global, but ``t`` only
            exists as a local inside ``lighthouse``, so calling it raised
            NameError on a fresh kernel.

    Returns:
        The derivative, with the same broadcast shape as ``z``.
    """
    return 3 * z ** 2 - 12 * z + 12 - jnp.tan(t)


def make_operator(params):
    """Bind ``params`` and return an operator ``(i, x) -> lighthouse(x, params)``.

    The first argument ``i`` (presumably an iteration index supplied by
    ``two_phase_solver``) is accepted and ignored.
    NOTE(review): the operator returns the lighthouse residual rather than an
    explicit fixed-point map; confirm this matches what two_phase_solver's
    backward phase expects.
    """
    def residual_op(i, x):
        return lighthouse(x, params)

    return residual_op

def newton_solver(x, params, epsilon=1e-5, max_iter=1000):
    """Solve ``lighthouse(x, params) == 0`` for ``x`` with Newton's method.

    Used as the ``forward_solver`` of ``two_phase_solver``.

    ``lighthouse`` is scalar-valued while ``x`` is an array, so each step is
    the minimum-norm Newton update

        x <- x - f(x) * g / (g . g),    g = grad_x f(x).

    When f depends on a single coordinate (``lighthouse`` reads only ``x[0]``)
    this is exactly the scalar Newton step in that coordinate, and the other
    coordinates are left unchanged. An element-wise ``f / g`` would instead
    divide by their zero partial derivatives and fill them with inf/nan.

    The iteration runs inside ``jax.lax.while_loop``. It stops as soon as
    ``|f(x)| < epsilon``, or after ``max_iter`` steps, rather than always
    executing ``max_iter`` Python-level iterations.

    Args:
        x: initial guess (floating-point array).
        params: forwarded unchanged to ``lighthouse``.
        epsilon: absolute tolerance on the residual ``|f(x)|``.
        max_iter: maximum number of Newton steps.

    Returns:
        The last iterate, with the same shape and dtype as ``x``.
    """
    f = lighthouse
    grad_f = jax.grad(lighthouse)

    def not_done(state):
        n, _, fxn = state
        # `not (|f| < eps)` rather than `|f| >= eps`, so a NaN residual keeps
        # iterating exactly as the freeze-on-convergence loop did.
        return jnp.logical_and(n < max_iter,
                               jnp.logical_not(jnp.abs(fxn) < epsilon))

    def newton_step(state):
        n, xn, fxn = state
        g = grad_f(xn, params)
        g_sq = jnp.vdot(g, g)
        # Guard the division: a zero gradient gives a zero step (for finite f)
        # instead of inf/nan.
        step = (fxn / jnp.where(g_sq > 0, g_sq, 1.0)) * g
        # Keep the loop-carry dtype fixed, as lax.while_loop requires.
        xn_next = (xn - step).astype(xn.dtype)
        return n + 1, xn_next, f(xn_next, params)

    x_init = jnp.asarray(x)
    _, x_root, _ = jax.lax.while_loop(
        not_done, newton_step, (0, x_init, f(x_init, params)))
    return x_root



In [4]:
# Implicit solver: Newton's method provides the forward (root-finding) phase.
fc = two_phase_solver(make_operator, forward_solver=newton_solver)

# Parameters and initial guess; lighthouse reads only the first coordinate of x0.
params = [2]
x0 = jnp.zeros(2)

x_star = fc(x0, params)



In [6]:
# Residual at the computed root; newton_solver stops once |residual| < 1e-5.
lighthouse(x_star, params=params)

DeviceArray(-1.6875389e-07, dtype=float64)