In [1]:
using ForwardDiff
using Plots
using Random

In [2]:
function norm(a)
    sqrt.(sum(a.^2))
end

norm (generic function with 1 method)

In [3]:
function horner(a,x)
    r = 0.0
    for i = length(a):-1:2
       r = x*(a[i]+r)
    end
    return a[1]+r
end

horner (generic function with 1 method)

In [4]:
xdata = collect(range(-10,10,10000))
a = [-10,1,2]
a0 = zeros(3)
ydata = (x->horner(a,x)).(xdata);

In [5]:
E(a) = sum(( ((xx->horner(a,xx)).(xdata)) - ydata).^2)
∇E(a) = ForwardDiff.gradient(E,a)

∇E (generic function with 1 method)

In [6]:
function grad_descent(lr,a0,it=10000)
    a = a0
    i = 0
    while i<it
        grad = ∇E(a)
        grad /= norm(grad)
        a -= lr*grad
        i += 1
    end
    a
end

grad_descent (generic function with 2 methods)

In [7]:
grad_descent(0.1,a0)

3-element Vector{Float64}:
 -9.881805283858665
  0.9999999999999983
  1.9480280088393882

# Optimizador Adam

In [23]:
"""
Optimizador Adam, tal como está descrito en https://arxiv.org/abs/1412.6980

Los parámetros son α, β1 y β2 y ϵ que por defecto son los sugeridos en el paper:
    α = 1e-4
    β1 = 0.9
    β2 = 0.999
    ϵ = 1e-8
"""
function adam_optimizer(a0,α=1e-4,β1=0.9,β2=0.999,tolerance=1e-6,ϵ=1e-8)
    a = a0
    t = 1
    
    m = zeros(length(a0))
    v = zeros(length(a0))
    
    e0 = E(a)
    Δe = e0

    while Δe > tolerance
        grad = ∇E(a)
        
        m = (β1 * m) + (1-β1)*grad
        v = (β2 * v) + (1-β2)*(grad .^ 2)
        
        m_ = m/(1-β1^t)
        v_ = v/(1-β2^t)
        
        Δa = α * (m_./(sqrt.(v_) .+ ϵ))
        
        a -= Δa
        t += 1
        
        e1 = E(a)
       
        Δe = abs(e1 - e0)
        e0 = e1
    end
    println("The optimization has converged. Total iterations: $t")
    a
end

adam_optimizer

In [24]:
@time adam_optimizer(a0)

The optimization has converged. Total iterations: 24207
 11.978742 seconds (1.69 M allocations: 27.111 GiB, 6.81% gc time, 0.50% compilation time)


3-element Vector{Float64}:
 0.9888725479561656
 0.9999999999999997
 1.816439836674675