# Chapter 6: Second-Order Methods

In [4]:
from copy import copy
from dataclasses import dataclass
import jax
import numpy as np
from typing import Type

## Algorithm 6.1

In [18]:
def newtons_method(gradient, hessian, x, epsilon, k_max):
        delta = np.infty*np.ones(x.shape[0])
        k = 1
        while np.linalg.norm(delta) > epsilon and k <= k_max:
                delta = np.linalg.solve(hessian(x),-gradient(x))
                x -= delta
                k += 1
        return x

### Example

In [20]:
def fun(x):
    return jax.numpy.sin(x[0]*x[1])+jax.numpy.exp(x[1]+x[2])-x[2]

x_0 = np.array([1.0, 2.0, 3.0])
gradient = jax.grad(fun)
hessian = jax.hessian(fun)
epsilon = 0.0001
k_max = 2
x_1 = newtons_method(gradient, hessian, x_0, epsilon, k_max)
print(x_1)

[  0.29836422  28.81422853 -21.8234632 ]


## Algorithm 6.2

In [52]:
def secant_method(derivative, x_0, x_1, epsilon):
    g_0 = derivative(x_0)
    delta = np.infty
    while np.abs(delta) > epsilon:
        g_1 = derivative(x_1)
        delta = ((x_1-x_0)/(g_1-g_0))*g_1
        x_0, x_1, g_0 = x_1, x_1 - delta, g_1

    return x_1

## Algorithm 6.3

In [82]:
@dataclass
class DFP:
    Q = np.identity(1)

def init(M:Type[DFP], x):
    m = x.shape[0]
    M.Q = np.identity(m)
    return M

def step(M:Type[DFP], f, gradient, x):
    Q, g = M.Q, gradient(x)
    x_prime = line_search(f, x, np.dot(-Q, g) )
    g_prime = gradient(x_prime)
    # -- Vectors as columns -- #
    x = x.reshape(-1,1)
    x_prime = x_prime.reshape(-1,1)
    g = g.reshape(-1,1)
    g_prime = g_prime.reshape(-1,1)
    # ------------------------ #
    delta = x_prime - x
    gamma = g_prime - g
    Q = Q - np.dot(Q,np.dot(gamma,np.dot(gamma.T,Q)))/np.dot(gamma.T,np.dot(Q, gamma)) + np.dot(delta, delta.T)/np.dot(delta.T,gamma)
    M.Q = Q
    return x_prime.reshape(-1,)

### Example

In [11]:
def bracket_minimum(f, x=0, s=1e-2, k=2.0):
    a, ya = x, f(x)
    b, yb = a + s, f(a+s)
    
    if yb > ya:
        a, b = b, a
        ya, yb = yb, ya
        s = -s
    
    while True:
        c, yc = b + s, f(b+s)
        if yc > yb:
            return (a,c) if a < c else (c, a)
        a, ya, b, yb = b, yb, c, yc

def minimize(f_deriv, a, b, epsilon):
    '''
    Bisection algorithm for univariate optmization
    '''
    if a > b:
        a, b = b, a
    
    ya, yb = f_deriv(a), f_deriv(b)

    if ya == 0:
        b = a
    if yb == 0:
        a = b

    while b - a > epsilon:
        x = (a+b)/2
        y = f_deriv(x)
        if y == 0:
            a, b = x, x
        elif y*ya> 0:
            a = x
        else:
            b = x
    
    return (a+b)/2

def line_search(f, x, d):
    '''
    One step of Exact Line Search
    '''
    objective = lambda alpha: f(x+alpha*d)

    a, b = bracket_minimum(objective)

    alpha = minimize(objective, a, b, 0.0001)

    return x + alpha*d

In [81]:
def func(x):
    return jax.numpy.sin(x[0]*x[1])+jax.numpy.exp(x[1]+x[2])-x[2]

x_0 = np.array([1.0, 2.0, 3.0])
gradient = jax.grad(func)
M = DFP()
M = init(M, x_0)
x_1 = step(M, func, gradient, x_0)
print(x_1)

[[-0.43911442]
 [ 1.8882942 ]
 [ 0.00341964]]
[ 1.01407747 -0.50323066  0.50664456]


## Algorithm 6.4

In [71]:
@dataclass
class BFGS:
    Q = np.identity(1)

def init(M:Type[BFGS], x):
    m = x.shape[0]
    M.Q = np.identity(m)
    return M

def step(M:Type[BFGS], f, gradient, x):
    Q, g = M.Q, gradient(x)
    x_prime = line_search(f, x, np.dot(-Q, g) )
    g_prime = gradient(x_prime)
    # -- Vectors as columns -- #
    x = x.reshape(-1,1)
    x_prime = x_prime.reshape(-1,1)
    g = g.reshape(-1,1)
    g_prime = g_prime.reshape(-1,1)
    # ------------------------ #
    delta = x_prime - x
    gamma = g_prime - g
    prt1 = (np.dot(delta, np.dot(gamma.T,Q)) + np.dot(Q, np.dot(gamma,delta.T)))/np.dot(delta.T, gamma)
    prt2 = ( 1 + np.dot(gamma.T,np.dot(Q,gamma))/np.dot(delta.T, gamma) )
    prt3 = np.dot(delta, delta.T)/np.dot(delta.T, gamma)
    Q = Q - prt1 + prt2*prt3
    M.Q = Q
    return x_prime.reshape(-1,)

## Algorithm 6.5

In [71]:
@dataclass
class LimitedMemoryBFGS:
    m: int
    deltas = [] #ẟs
    gammas = [] #γs
    qs = []

def step(M:Type[LimitedMemoryBFGS], f, gradient, x):
    deltas, gammas, qs = M.deltas, M.gammas, M.qs
    g = gradient(x)
    m = len(M.deltas)
    print(m)
    if m >0:
        q = g.reshape(-1,1)
        for i in range(m-1,-1,-1):
            qs[i] = copy(q)
            q -=  (np.dot(deltas[i].T, q) / np.dot(gammas[i].T, deltas[i]))[0,0] * gammas[i] 
        z = (gammas[m-1]*deltas[m-1]*q) / np.dot(gammas[m-1].T,gammas[m-1])
        for i in range(m):
            z += np.dot(deltas[i], ( np.dot(deltas[i].T,qs[i]) - np.dot(gammas[i].T,z) ) / np.dot(gammas[i].T, deltas[i]) )
        x_prime = line_search(f, x, -z.reshape(-1,))
    else:
        x_prime = line_search(f, x, -g)
    g_prime = gradient(x_prime)
    deltas.insert(-1, x_prime.reshape(-1,1) - x.reshape(-1,1) )
    gammas.insert(-1, g_prime.reshape(-1,1) - g.reshape(-1,1))
    qs.insert(-1,np.zeros((x.shape[0],1)))
    while len(deltas) > M.m:
        deltas.pop(0)
        gammas.pop(0)
        qs.pop(0)
    M.gammas = gammas
    M.deltas = deltas
    M.qs = qs
    return x_prime

In [70]:
def func(x):
    return jax.numpy.sin(x[0]*x[1])+jax.numpy.exp(x[1]+x[2])-x[2]

x_0 = np.array([1.0, 2.0, 3.0])
gradient = jax.grad(func)
M = LimitedMemoryBFGS(m=10)
x_1 = step(M, func, gradient, x_0)
x_2 = step(M, func, gradient, x_1)
print(x_2)

0
1
[[-9.1543436e-05]
 [ 2.4279181e-02]
 [ 8.1006093e-03]]
[ 1.0180019  -1.5440781   0.15937167]
