# Chapter 4: Local Descent

In [1]:
import numpy as np
import jax

ModuleNotFoundError: No module named 'jax'

## Algorithm 4.1

In [25]:
def bracket_minimum(f, x=0, s=1e-2, k=2.0):
    a, ya = x, f(x)
    b, yb = a + s, f(a+s)
    
    if yb > ya:
        a, b = b, a
        ya, yb = yb, ya
        s = -s
    
    while True:
        c, yc = b + s, f(b+s)
        if yc > yb:
            return (a,c) if a < c else (c, a)
        a, ya, b, yb = b, yb, c, yc

def minimize(f_deriv, a, b, epsilon):
    '''
    Bisection algorithm for univariate optmization
    '''
    if a > b:
        a, b = b, a
    
    ya, yb = f_deriv(a), f_deriv(b)

    if ya == 0:
        b = a
    if yb == 0:
        a = b

    while b - a > epsilon:
        x = (a+b)/2
        y = f_deriv(x)
        if y == 0:
            a, b = x, x
        elif y*ya> 0:
            a = x
        else:
            b = x
    
    return (a+b)/2

def line_search(f, x, d):
    '''
    One step of Exact Line Search
    '''
    objective = lambda alpha: f(x+alpha*d)

    a, b = bracket_minimum(objective)

    alpha = minimize(objective, a, b, 0.0001)

    return x + alpha*d

### Example

In [26]:
def func(x):
    return np.sin(x[0]*x[1])+np.exp(x[1]+x[2])-x[2]

x = np.array([1.0, 2.0, 3.0])
d = np.array([0.0, -1.0, -1.0])

sol = line_search(func, x, d)
print(sol)

[ 1.         -1.13996094 -0.13996094]


## Algorithm 4.2

In [None]:
def backtracking_line_search(f, gradient, x, d, alpha, p=0.5, beta=1e-4):
    y, g = f(x), gradient
    while f(x+alpha*d) > y + beta*alpha*(np.dot(g,d)):
        alpha *= p
    return alpha

### Example

In [None]:
def aproximate_line_search(f, x, alpha):
    '''
    One Step of Aproximate Line Search with Backtracking
    '''
    gradient = jax.grad(f)
    d = -gradient(x)

    alpha = backtracking_line_search(f, gradient, x, d, alpha)

    return x + alpha*d

def func(x):
    return jax.numpy.sin(x[0]*x[1])+jax.numpy.exp(x[1]+x[2])-x[2]

x = np.array([1.0, 2.0, 3.0])
alpha = 1.0
sol = line_search(func, x, alpha)
print(sol)

## Algorithm 4.3

In [None]:
def strong_backtracking(f, gradient, x, d, alpha=1, beta=1e-4, sigma=0.1):
    y0, g0, y_prev, alpha_prev = f(x), np.dot(gradient(x),d), np.nan, 0
    alpha_lo, alpha_hi = np.nan, np.nan

    while True:
        y = f(x+alpha*d)
        if y > y0+beta*alpha*g0 or (not np.isnan(y_prev) and y >= y_prev):
            alpha_lo, alpha_hi = alpha_prev, alpha
            break
        g = np.dot(gradient(x+alpha*d),d)

        if np.abs(g) <= -sigma*g0:
            return alpha
        elif g >= 0:
            alpha_lo, alpha_hi = alpha, alpha_prev
            break
        y_prev, alpha_prev, alpha = y, alpha, 2*alpha

    ylo = f(x+alpha_lo*d)

    while True:
        alpha = (alpha_lo + alpha_hi)/2
        y = f(x+alpha*d)
        if y > y0 + beta*alpha*g0 or y >= ylo:
            alpha_hi = alpha
        else:
            g = np.dot(gradient(x+alpha*d),d)
            if np.abs(g) <= -sigma*g0:
                return alpha
            elif g*(alpha_hi - alpha_lo) >= 0:
                alpha_hi = alpha_lo

            alpha_lo = alpha

In [None]:
def aproximate_line_search(f, x, d, alpha):
    '''
    One Step of Aproximate Line Search with Strong Backtracking
    '''
    gradient = jax.grad(f)
    d = -gradient(x)

    alpha = strong_backtracking(f, gradient, x, d, alpha)

    return x + alpha*d

def func(x):
    return jax.numpy.sin(x[0]*x[1])+jax.numpy.exp(x[1]+x[2])-x[2]

x = np.array([1.0, 2.0, 3.0])
alpha = 1.0
sol = line_search(func, x, d, alpha)
print(sol)