In [25]:
using LinearAlgebra, BenchmarkTools, TimerOutputs

---

In [4]:
@views @fastmath function f(x)
    val = 0.0
     for i = eachindex(x)
        val += 1/4 * x[i]^4 + 1/2 * x[i]^2 + x[i]
    end
    return val
end

@show f(ones(3))
@benchmark f($(ones(Float64, Int(1e4)) .+ 0.121214))

f(ones(3)) = 5.25


BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     34.300 μs (0.00% GC)
  median time:      34.400 μs (0.00% GC)
  mean time:        36.490 μs (0.00% GC)
  maximum time:     196.100 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

In [5]:
@views @fastmath ∇f(x) = [x[i]^3 + x[i] + 1 for i = eachindex(x)]
@show ∇f(ones(3))
@benchmark ∇f($(ones(Float64, Int(1e4))))

∇f(ones(3)) = [3.0, 3.0, 3.0]


BenchmarkTools.Trial: 
  memory estimate:  78.20 KiB
  allocs estimate:  2
  --------------
  minimum time:     8.800 μs (0.00% GC)
  median time:      9.300 μs (0.00% GC)
  mean time:        12.123 μs (6.53% GC)
  maximum time:     589.800 μs (98.27% GC)
  --------------
  samples:          10000
  evals/sample:     1

---

In [6]:
@views function backtrack(xk, dk, f, ∇fk, α0, c1, ρ, btmax) :: Tuple{Int, Float64}
    α = α0
    bt = 1
    fk = f(xk)
    ∇f_dk = ∇fk' * dk
    while bt < btmax
         if (f(xk + α * dk) < fk + c1 * α * ∇f_dk)
            break
        end
        α = ρ * α
        bt += 1
    end
    return bt, α
end

backtrack (generic function with 1 method)

In [7]:
function steepest_descent(
        x0,
        f, ∇f,
        rel_diff=1e-8, kmax=1000,
        α0=5.0, c1=1e-4, ρ=0.8, btmax=50
    )
    k = 0
    xk = x0
    
    while k < kmax
        ∇fk = ∇f(xk)
        dk = -∇fk
        
        bt, αk = backtrack(xk, dk, f, ∇fk, α0, c1, ρ, btmax)
        
        xk_new = xk + αk * dk
        if norm(xk_new - xk) / norm(xk) < rel_diff
            break
        end
        xk = xk_new
            
        k += 1
    end
    
    return xk, k
end

x0 = ones(Int(1e4));
xk, k = steepest_descent(x0, f, ∇f)
for x in xk @assert isapprox(-0.68233, x;atol=1e-3) end
@benchmark steepest_descent($x0, $f, $∇f)

BenchmarkTools.Trial: 
  memory estimate:  91.19 MiB
  allocs estimate:  2388
  --------------
  minimum time:     34.617 ms (3.02% GC)
  median time:      38.344 ms (3.16% GC)
  mean time:        41.732 ms (3.47% GC)
  maximum time:     81.641 ms (2.82% GC)
  --------------
  samples:          120
  evals/sample:     1

In [8]:
function fletcher_reeves(
        x0,
        f, ∇f,
        rel_diff=1e-8, kmax=1000,
        α0=5.0, c1=1e-4, ρ=0.8, btmax=50
    )
    xk = x0
    ∇fk = ∇f(xk)
    
    pk = -∇fk
    k = 0
    
    while k < kmax
        bt, αk = backtrack(xk, pk, f, ∇fk, α0, c1, ρ, btmax)
        
        xkp1 = xk + αk * pk
        if norm(xkp1 - xk) / norm(xkp1) < rel_diff
            break
        end
        
        ∇fkp1 = ∇f(xkp1)
        βkp1 = (∇fkp1' * ∇fkp1) / (∇fk' * ∇fk)
        
        pk = -∇fkp1 + βkp1 * pk
        
        xk = xkp1
        ∇fk = ∇fkp1
        
        k = k + 1
    end
    
    return xk, k
end

x0 = ones(Int(1e4));
xk, k = fletcher_reeves(x0, f, ∇f)
for x in xk @assert isapprox(-0.68233, x;atol=1e-3) end
@benchmark fletcher_reeves($x0, $f, $∇f)

BenchmarkTools.Trial: 
  memory estimate:  113.64 MiB
  allocs estimate:  2976
  --------------
  minimum time:     40.870 ms (4.19% GC)
  median time:      47.334 ms (3.94% GC)
  mean time:        53.634 ms (3.91% GC)
  maximum time:     107.098 ms (4.45% GC)
  --------------
  samples:          94
  evals/sample:     1

In [9]:
function polak_ribiere(
        x0,
        f, ∇f,
        rel_diff=1e-8, kmax=1000,
        α0=5.0, c1=1e-4, ρ=0.8, btmax=50
    )
    xk = x0
    ∇fk = ∇f(xk)
    
    pk = -∇fk
    k = 0
    
    while k < kmax
        bt, αk = backtrack(xk, pk, f, ∇fk, α0, c1, ρ, btmax)
        
        xkp1 = xk + αk * pk
        if norm(xkp1 - xk) / norm(xkp1) < rel_diff
            break
        end
        
        ∇fkp1 = ∇f(xkp1)
        βkp1 = (∇fkp1' * (∇fkp1 - ∇fk)) / (∇fk' * ∇fk)
        
        pk = -∇fkp1 + βkp1 * pk
        
        xk = xkp1
        ∇fk = ∇fkp1
        
        k = k + 1
    end
    
    return xk, k
end

x0 = ones(Int(1e4));
xk, k = polak_ribiere(x0, f, ∇f)
for x in xk @assert isapprox(-0.68233, x;atol=1e-3) end
@benchmark polak_ribiere($x0, $f, $∇f)

BenchmarkTools.Trial: 
  memory estimate:  335.65 MiB
  allocs estimate:  8790
  --------------
  minimum time:     131.658 ms (3.49% GC)
  median time:      134.745 ms (3.52% GC)
  mean time:        136.238 ms (3.43% GC)
  maximum time:     152.757 ms (2.92% GC)
  --------------
  samples:          37
  evals/sample:     1

---

## Finite differences

In [10]:
g(x) = 1/4 * x^4 + 1/2 * x ^ 2 + x

g (generic function with 1 method)

In [11]:
@views function ∇f_fwd_diff(x; k=8)
    h = 10.0^(-k)
    return [(g(x[i] + h) - g(x[i])) / h for i = eachindex(x)]
end

@views function ∇f_cnt_diff(x; k=8)
    h = 10.0^(-k) * norm(x)
    return [(g(x[i] + h) - g(x[i] - h)) / 2h for i = eachindex(x)]
end

@benchmark ∇f_fwd_diff($(ones(1_000)); k=8)

BenchmarkTools.Trial: 
  memory estimate:  7.94 KiB
  allocs estimate:  1
  --------------
  minimum time:     31.700 μs (0.00% GC)
  median time:      31.800 μs (0.00% GC)
  mean time:        33.795 μs (0.30% GC)
  maximum time:     1.131 ms (91.11% GC)
  --------------
  samples:          10000
  evals/sample:     1

---

In [12]:
function pretty_print2(d::Dict, pre=1)
    todo = Vector{Tuple}()
    for (k,v) in d
        if typeof(v) <: Dict
            push!(todo, (k,v))
        else
            println(join(fill(" ", pre)) * "$(repr(k)) => $(repr(v))")
        end
    end

    for (k,d) in todo
        s = "$(repr(k)) => "
        println(join(fill(" ", pre)) * s)
        pretty_print2(d, pre+1+length(s))
    end
    nothing
end


pretty_print2 (generic function with 2 methods)

In [38]:
results = Dict()
for optimization_method in [steepest_descent, fletcher_reeves, polak_ribiere]
    results[Symbol(optimization_method)] = Dict()
    for input_arr in [ones(Int(1e3)), ones(Int(1e4))]
        results[Symbol(optimization_method)][Int(log10(length(input_arr)))] = Dict()
        for gradient_method in [∇f, ∇f_fwd_diff, ∇f_cnt_diff]
            results[Symbol(optimization_method)][Int(log10(length(input_arr)))][Symbol(gradient_method)] = Dict()
            if gradient_method == ∇f_fwd_diff || gradient_method == ∇f_cnt_diff
                for k = 2:2:14
                    results[Symbol(optimization_method)][Int(log10(length(input_arr)))][Symbol(gradient_method)][k] = -1
                end
            else
                results[Symbol(optimization_method)][Int(log10(length(input_arr)))][Symbol(gradient_method)] = -1
            end
        end
    end
end

for optimization_method in [steepest_descent, fletcher_reeves, polak_ribiere]
    for input_arr in [ones(Int(1e3)), ones(Int(1e4))] 
        for gradient_method in [∇f, ∇f_fwd_diff, ∇f_cnt_diff]
            if gradient_method == ∇f_fwd_diff || gradient_method == ∇f_cnt_diff
                for k = 2:2:14
                    println("$(Symbol(optimization_method)), 1e$(Int(log10(length(input_arr)))), $(Symbol(gradient_method)), $k")
                    stats = @timed optimization_method(input_arr, f, x -> gradient_method(x; k=k))
                    
                    results[Symbol(optimization_method)][Int(log10(length(input_arr)))][Symbol(gradient_method)][k] = (stats.value[2], stats.time)
                end
            else
                println("$(Symbol(optimization_method)), 1e$(Int(log10(length(input_arr)))), $(Symbol(gradient_method))")
                stats = @timed optimization_method(input_arr, f, gradient_method)
                
                # Only check for correctness for exact methods for now
                for x in stats.value[1] @assert isapprox(-0.68233, x; atol=1e-3) end
                
                results[Symbol(optimization_method)][Int(log10(length(input_arr)))][Symbol(gradient_method)] = (stats.value[2], stats.time)
            end
        end
    end
end
pretty_print2(results)

steepest_descent, 1e3, ∇f
steepest_descent, 1e3, ∇f_fwd_diff, 2
steepest_descent, 1e3, ∇f_fwd_diff, 4
steepest_descent, 1e3, ∇f_fwd_diff, 6
steepest_descent, 1e3, ∇f_fwd_diff, 8
steepest_descent, 1e3, ∇f_fwd_diff, 10
steepest_descent, 1e3, ∇f_fwd_diff, 12
steepest_descent, 1e3, ∇f_fwd_diff, 14
steepest_descent, 1e3, ∇f_cnt_diff, 2
steepest_descent, 1e3, ∇f_cnt_diff, 4
steepest_descent, 1e3, ∇f_cnt_diff, 6
steepest_descent, 1e3, ∇f_cnt_diff, 8
steepest_descent, 1e3, ∇f_cnt_diff, 10
steepest_descent, 1e3, ∇f_cnt_diff, 12
steepest_descent, 1e3, ∇f_cnt_diff, 14
steepest_descent, 1e4, ∇f
steepest_descent, 1e4, ∇f_fwd_diff, 2
steepest_descent, 1e4, ∇f_fwd_diff, 4
steepest_descent, 1e4, ∇f_fwd_diff, 6
steepest_descent, 1e4, ∇f_fwd_diff, 8
steepest_descent, 1e4, ∇f_fwd_diff, 10
steepest_descent, 1e4, ∇f_fwd_diff, 12
steepest_descent, 1e4, ∇f_fwd_diff, 14
steepest_descent, 1e4, ∇f_cnt_diff, 2
steepest_descent, 1e4, ∇f_cnt_diff, 4
steepest_descent, 1e4, ∇f_cnt_diff, 6
steepest_descent, 1e4, ∇f_c