# Chapter 5: First-Order Methods

In [1]:
from typing import Type
from dataclasses import dataclass
import numpy as np
import jax

## Algorithm 5.1

In [2]:
@dataclass
class GradientDescent:
    alpha : float # α

def step(M:Type[GradientDescent], f, gradient, x):
    alpha, g = M.alpha, gradient(x)
    return x - alpha*g

### Example

In [4]:
def fun(x):
    return jax.numpy.sin(x[0]*x[1])+jax.numpy.exp(x[1]+x[2])-x[2]

x_0 = np.array([1.0, 2.0, 3.0])
gradient = jax.grad(fun)
M = GradientDescent(0.1)
x_1 = step(M, fun, gradient, x_0) 
print(x_1)

[  1.0832294 -12.799701  -11.741317 ]


## Algorithm 5.2

In [24]:
class ConjugateGradientDescent:

    def __init__(self, g):
        self.g = g 
        self.d = -self.g
    
    def step(self, f, gradient, x):
        g_prime = gradient(x)
        beta = max(0, np.dot(g_prime, g_prime - self.g)/(np.dot(self.g, self.g)))
        d_prime = -g_prime + beta*self.d
        x_prime = line_search(f, x, d_prime)
        self.d = d_prime
        self.g = g_prime
        
        return x_prime
    
    def first_step(self, f, x):
        x_prime = line_search(f, x, self.d)
        return x_prime

### Example

In [27]:
def bracket_minimum(f, x=0, s=1e-2, k=2.0):
    a, ya = x, f(x)
    b, yb = a + s, f(a+s)
    
    if yb > ya:
        a, b = b, a
        ya, yb = yb, ya
        s = -s
    
    while True:
        c, yc = b + s, f(b+s)
        if yc > yb:
            return (a,c) if a < c else (c, a)
        a, ya, b, yb = b, yb, c, yc

def minimize(f_deriv, a, b, epsilon):
    '''
    Bisection algorithm for univariate optmization
    '''
    if a > b:
        a, b = b, a
    
    ya, yb = f_deriv(a), f_deriv(b)

    if ya == 0:
        b = a
    if yb == 0:
        a = b

    while b - a > epsilon:
        x = (a+b)/2
        y = f_deriv(x)
        if y == 0:
            a, b = x, x
        elif y*ya> 0:
            a = x
        else:
            b = x
    
    return (a+b)/2

def line_search(f, x, d):
    '''
    One step of Exact Line Search
    '''
    objective = lambda alpha: f(x+alpha*d)

    a, b = bracket_minimum(objective)

    alpha = minimize(objective, a, b, 0.0001)

    return x + alpha*d

def func(x):
    return jax.numpy.sin(x[0]*x[1])+jax.numpy.exp(x[1]+x[2])-x[2]


x_0 = np.array([1.0, 2.0, 3.0])

gradient =  jax.grad(func) # ∇f
g = gradient(x_0) # ∇f(x)

model = ConjugateGradientDescent(g)

x_1 = model.first_step(func, x_0)
x_2 = model.step(func, gradient, x_1)
print(x_2)

[ 1.2423997  -1.4850699   0.50486636]


## Algorithm 5.3

In [33]:
class Momentum:

    def __init__(self, alpha, beta, dimension):
        self.alpha = alpha # α
        self.beta = beta # β
        self.v = np.zeros(dimension)

    def step(self, gradient, x):
        g = gradient(x)
        self.v = self.beta*self.v - self.alpha*g
        return x + self.v

### Example

In [38]:
def func(x):
    return jax.numpy.sin(x[0]*x[1])+jax.numpy.exp(x[1]+x[2])-x[2]

dim = 3
alpha = 0.1
beta = 0.3
model = Momentum(alpha, beta, dim)

x_0 = np.array([1.0, 2.0, 3.0])
gradient = jax.grad(func)
x_1 = model.step(gradient, x_0)
print(x_1)

[  1.0832294 -12.799701  -11.741317 ]


## Algorithm 5.4

In [36]:
class NesterovMomentum:

    def __init__(self, alpha, beta, dimension):
        self.alpha = alpha # α
        self.beta = beta # β
        self.v = np.zeros(dimension)

    def step(self, gradient, x):
        g = gradient(x+self.beta*self.v)
        self.v = self.beta*self.v - self.alpha*g
        return x + self.v

## Algorithm 5.5

In [41]:
class Adagrad:

    def __init__(self, alpha, epsilon, dimension):
        self.alpha = alpha # α
        self.epsilon = epsilon # ϵ
        self.s = np.zeros(dimension)

    def step(self, gradient, x):
        g = gradient(x)
        self.s += g*g
        return x - self.alpha / (np.sqrt(self.s) + self.epsilon)

### Example

In [42]:
def func(x):
    return jax.numpy.sin(x[0]*x[1])+jax.numpy.exp(x[1]+x[2])-x[2]

alpha = 0.1
epsilon = 1e-8
dim = 3
model = Adagrad(alpha, epsilon, dim)

x_0 = np.array([1.0, 2.0, 3.0])
gradient = jax.grad(func)
x_1 = model.step(gradient, x_0)
print(x_1)

[0.8798501  1.99932431 2.99932163]


## Algorithm 5.6

In [45]:
class RMSProp:

    def __init__(self, alpha, gamma, epsilon, dimension):
        self.alpha = alpha # α
        self.gamma = gamma # γ
        self.epsilon = epsilon # ϵ
        self.s = np.zeros(dimension)

    def step(self, gradient, x):
        g = gradient(x)
        self.s = self.gamma*self.s + (1 - self.gamma)*(g*g)
        return x - self.alpha*g / (np.sqrt(self.s) + self.epsilon)

## Algorithm 5.7

In [4]:
class Adadelta:

    def __init__(self, gamma_s, gamma_x, epsilon, dim):
        self.gamma_s = gamma_s # γs
        self.gamma_x = gamma_x # γx
        self.epsilon = epsilon # ϵ
        self.u = np.zeros(dim)
        self.s = np.zeros(dim)

    def step(self, gradient, x):
        g = gradient(x)
        self.s = self.gamma_s*self.s + (1 - self.gamma_s)*(g*g)
        delta_x = ( ( np.sqrt(self.u) + self.epsilon )/( np.sqrt(self.s) + self.epsilon )) * g
        self.u = self.gamma_x*self.u + (1 - self.gamma_x)*(delta_x*delta_x)

        return x + delta_x        


## Algorithm 5.8

In [13]:
class Adam:

    def __init__(self, alpha, gamma_v, gamma_s, epsilon, dim):
        self.alpha = alpha # α
        self.gamma_v = gamma_v # γv
        self.gamma_s = gamma_s # γs
        self.epsilon = epsilon # ϵ
        self.k = 1
        self.v = np.zeros(dim)
        self.s = np.zeros(dim)


    def step(self, gradient, x):
        g = gradient(x)
        self.v = self.gamma_v*self.v + (1 - self.gamma_v)*g
        self.s = self.gamma_s*self.s + (1-self.gamma_s)*(g*g)
        self.k += 1
        v_hat = self.v / (1 - self.gamma_v**(self.k) )
        s_hat = self.s / (1 - self.gamma_s**(self.k) )

        return x - self.alpha*v_hat / (np.sqrt(s_hat) + self.epsilon)

### Example

In [18]:
def func(x):
    return jax.numpy.sin(x[0]*x[1])+jax.numpy.exp(x[1]+x[2])-x[2]

alpha = 0.0011
gamma_v = 0.9
gamma_s = 0.999
epsilon = 1e-8
dim = 3
model = Adam(alpha, gamma_v, gamma_s, epsilon, dim)

x_0 = np.array([1.0, 2.0, 3.0])
gradient = jax.grad(func)
x_1 = model.step(gradient, x_0)
print(x_1)

[1.0008185 1.9991815 2.9991815]


## Algorithm 5.9

In [19]:
class HyperGradientDescent:

    def __init__(self, alpha, mu, dim):
        self.alpha = alpha # α
        self.mu = mu # μ
        self.g_prev = np.zeros(dim) #previous gradient


    def step(self, gradient, x):
        g = gradient(x)
        self.alpha += self.mu*(np.dot(g,self.g_prev))
        self.g_prev = g

        return x - self.alpha*g

## Algorithm 5.10

In [23]:
class HyperNesterovMomentum:

    def __init__(self, alpha, mu, beta, dim):
        self.alpha = alpha # α
        self.mu = mu # μ
        self.beta = beta # β
        self.v = np.zeros(dim)
        self.g_prev = np.zeros(dim) #previous gradient

    def step(self, gradient, x):
        g = gradient(x)
        self.alpha -= self.mu*(np.dot(g,-self.g_prev-self.beta*self.v))
        self.v = self.beta*self.v + g
        self.g_prev = g

        return x - self.alpha*(g + self.beta*self.v)