In [1]:
using LinearAlgebra, BenchmarkTools, TimerOutputs

---

In [2]:
@views @fastmath function f(x)
    val = 0.0
     for i = eachindex(x)
        val += 1/4 * x[i]^4 + 1/2 * x[i]^2 + x[i]
    end
    return val
end

@show f(ones(3))
@benchmark f($(ones(Float64, Int(1e4)) .+ 0.121214))

f(ones(3)) = 5.25


BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     24.300 μs (0.00% GC)
  median time:      24.400 μs (0.00% GC)
  mean time:        25.013 μs (0.00% GC)
  maximum time:     106.200 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

In [3]:
@views @fastmath ∇f(x) = [x[i]^3 + x[i] + 1 for i = eachindex(x)]
@show ∇f(ones(3))
@benchmark ∇f($(ones(Float64, Int(1e4))))

∇f(ones(3)) = [3.0, 3.0, 3.0]


BenchmarkTools.Trial: 
  memory estimate:  78.20 KiB
  allocs estimate:  2
  --------------
  minimum time:     9.200 μs (0.00% GC)
  median time:      10.400 μs (0.00% GC)
  mean time:        14.428 μs (10.33% GC)
  maximum time:     3.170 ms (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

---

In [4]:
@views function backtrack(xk, dk, f, ∇fk, α0, c1, ρ, btmax) :: Tuple{Int, Float64}
    α = α0
    bt = 1
    fk = f(xk)
    ∇f_dk = ∇fk' * dk
    while bt < btmax
         if (f(xk + α * dk) < fk + c1 * α * ∇f_dk)
            break
        end
        α = ρ * α
        bt += 1
    end
    return bt, α
end

backtrack (generic function with 1 method)

In [5]:
function steepest_descent(
        x0,
        f, ∇f,
        rel_diff=1e-8, kmax=1000,
        α0=5.0, c1=1e-4, ρ=0.8, btmax=50
    )
    k = 0
    xk = x0
    
    while k < kmax
        ∇fk = ∇f(xk)
        dk = -∇fk
        
        bt, αk = backtrack(xk, dk, f, ∇fk, α0, c1, ρ, btmax)
        
        xk_new = xk + αk * dk
        if norm(xk_new - xk) / norm(xk) < rel_diff
            break
        end
        xk = xk_new
            
        k += 1
    end
    
    return xk, k
end

x0 = ones(Int(1e4));
xk, k = steepest_descent(x0, f, ∇f)
for x in xk @assert isapprox(-0.68233, x;atol=1e-3) end
@benchmark steepest_descent($x0, $f, $∇f)

BenchmarkTools.Trial: 
  memory estimate:  91.19 MiB
  allocs estimate:  2388
  --------------
  minimum time:     27.973 ms (6.03% GC)
  median time:      29.720 ms (6.26% GC)
  mean time:        30.476 ms (6.68% GC)
  maximum time:     47.177 ms (17.66% GC)
  --------------
  samples:          165
  evals/sample:     1

In [6]:
function fletcher_reeves(
        x0,
        f, ∇f,
        rel_diff=1e-8, kmax=1000,
        α0=5.0, c1=1e-4, ρ=0.8, btmax=50
    )
    xk = x0
    ∇fk = ∇f(xk)
    
    pk = -∇fk
    k = 0
    
    while k < kmax
        bt, αk = backtrack(xk, pk, f, ∇fk, α0, c1, ρ, btmax)
        
        xkp1 = xk + αk * pk
        if norm(xkp1 - xk) / norm(xkp1) < rel_diff
            break
        end
        
        ∇fkp1 = ∇f(xkp1)
        βkp1 = (∇fkp1' * ∇fkp1) / (∇fk' * ∇fk)
        
        pk = -∇fkp1 + βkp1 * pk
        
        xk = xkp1
        ∇fk = ∇fkp1
        
        k = k + 1
    end
    
    return xk, k
end

x0 = ones(Int(1e4));
xk, k = fletcher_reeves(x0, f, ∇f)
for x in xk @assert isapprox(-0.68233, x;atol=1e-3) end
@benchmark fletcher_reeves($x0, $f, $∇f)

BenchmarkTools.Trial: 
  memory estimate:  123.11 MiB
  allocs estimate:  3224
  --------------
  minimum time:     35.476 ms (4.97% GC)
  median time:      38.324 ms (7.24% GC)
  mean time:        38.633 ms (7.12% GC)
  maximum time:     46.043 ms (6.73% GC)
  --------------
  samples:          130
  evals/sample:     1

In [7]:
function polak_ribiere(
        x0,
        f, ∇f,
        rel_diff=1e-8, kmax=1000,
        α0=5.0, c1=1e-4, ρ=0.8, btmax=50
    )
    xk = x0
    ∇fk = ∇f(xk)
    
    pk = -∇fk
    k = 0
    
    while k < kmax
        bt, αk = backtrack(xk, pk, f, ∇fk, α0, c1, ρ, btmax)
        
        xkp1 = xk + αk * pk
        if norm(xkp1 - xk) / norm(xkp1) < rel_diff
            break
        end
        
        ∇fkp1 = ∇f(xkp1)
        βkp1 = (∇fkp1' * (∇fkp1 - ∇fk)) / (∇fk' * ∇fk)
        
        pk = -∇fkp1 + βkp1 * pk
        
        xk = xkp1
        ∇fk = ∇fkp1
        
        k = k + 1
    end
    
    return xk, k
end

x0 = ones(Int(1e4));
xk, k = polak_ribiere(x0, f, ∇f)
for x in xk @assert isapprox(-0.68233, x;atol=1e-3) end
@benchmark polak_ribiere($x0, $f, $∇f)

BenchmarkTools.Trial: 
  memory estimate:  335.65 MiB
  allocs estimate:  8790
  --------------
  minimum time:     105.260 ms (7.12% GC)
  median time:      109.923 ms (6.93% GC)
  mean time:        111.349 ms (6.81% GC)
  maximum time:     134.725 ms (5.49% GC)
  --------------
  samples:          46
  evals/sample:     1

---

## Finite differences

In [8]:
g(x) = 1/4 * x^4 + 1/2 * x ^ 2 + x

g (generic function with 1 method)

In [9]:
@views function ∇f_fwd_diff(x; k=8)
    h = 10.0^(-k) * norm(x)
    return [(g(x[i] + h) - g(x[i])) / h for i = eachindex(x)]
end

@views function ∇f_cnt_diff(x; k=8)
    h = 10.0^(-k) * norm(x)
    return [(g(x[i] + h) - g(x[i] - h)) / 2h for i = eachindex(x)]
end

@benchmark ∇f_fwd_diff($(ones(1_000)); k=8)

BenchmarkTools.Trial: 
  memory estimate:  7.94 KiB
  allocs estimate:  1
  --------------
  minimum time:     86.000 μs (0.00% GC)
  median time:      86.400 μs (0.00% GC)
  mean time:        95.116 μs (0.16% GC)
  maximum time:     7.116 ms (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

---

In [10]:
function pretty_print2(d::Dict, pre=1)
    todo = Vector{Tuple}()
    for (k,v) in d
        if typeof(v) <: Dict
            push!(todo, (k,v))
        else
            println(join(fill(" ", pre)) * "$(repr(k)) => $(repr(v))")
        end
    end

    for (k,d) in todo
        s = "$(repr(k)) => "
        println(join(fill(" ", pre)) * s)
        pretty_print2(d, pre+1+length(s))
    end
    nothing
end


pretty_print2 (generic function with 2 methods)

In [22]:
results = Dict()
for optimization_method in [steepest_descent, fletcher_reeves, polak_ribiere]
    results[Symbol(optimization_method)] = Dict()
    for input_arr in [ones(Int(1e4)), ones(Int(1e5))]
        results[Symbol(optimization_method)][Int(log10(length(input_arr)))] = Dict()
        for gradient_method in [∇f, ∇f_fwd_diff, ∇f_cnt_diff]
            results[Symbol(optimization_method)][Int(log10(length(input_arr)))][Symbol(gradient_method)] = Dict()
            if gradient_method == ∇f_fwd_diff || gradient_method == ∇f_cnt_diff
                for k = 2:2:14
                    results[Symbol(optimization_method)][Int(log10(length(input_arr)))][Symbol(gradient_method)][k] = -1
                end
            else
                results[Symbol(optimization_method)][Int(log10(length(input_arr)))][Symbol(gradient_method)] = -1
            end
        end
    end
end

for optimization_method in [steepest_descent, fletcher_reeves, polak_ribiere]
    for input_arr in [ones(Int(1e4)), ones(Int(1e5))] 
        for gradient_method in [∇f, ∇f_fwd_diff, ∇f_cnt_diff]
            if gradient_method == ∇f_fwd_diff || gradient_method == ∇f_cnt_diff
                for k = 2:2:14
                    println("$(Symbol(optimization_method)), 1e$(Int(log10(length(input_arr)))), $(Symbol(gradient_method)), $k")
                    stats = @timed optimization_method(input_arr, f, x -> gradient_method(x; k=k))
                    
                    results[Symbol(optimization_method)][Int(log10(length(input_arr)))][Symbol(gradient_method)][k] = (stats.value[2], stats.time, all(isapprox(-0.68233, x; atol=1e-3) for x in stats.value[1]))
                end
            else
                println("$(Symbol(optimization_method)), 1e$(Int(log10(length(input_arr)))), $(Symbol(gradient_method))")
                stats = @timed optimization_method(input_arr, f, gradient_method)
                
                # Only check for correctness for exact methods for now
                for x in stats.value[1] @assert isapprox(-0.68233, x; atol=1e-3) end
                
                results[Symbol(optimization_method)][Int(log10(length(input_arr)))][Symbol(gradient_method)] = (stats.value[2], stats.time, all(isapprox(-0.68233, x; atol=1e-3) for x in stats.value[1]))
            end
        end
    end
end
pretty_print2(results)

steepest_descent, 1e4, ∇f
steepest_descent, 1e4, ∇f_fwd_diff, 2
steepest_descent, 1e4, ∇f_fwd_diff, 4
steepest_descent, 1e4, ∇f_fwd_diff, 6
steepest_descent, 1e4, ∇f_fwd_diff, 8
steepest_descent, 1e4, ∇f_fwd_diff, 10
steepest_descent, 1e4, ∇f_fwd_diff, 12
steepest_descent, 1e4, ∇f_fwd_diff, 14
steepest_descent, 1e4, ∇f_cnt_diff, 2
steepest_descent, 1e4, ∇f_cnt_diff, 4
steepest_descent, 1e4, ∇f_cnt_diff, 6
steepest_descent, 1e4, ∇f_cnt_diff, 8
steepest_descent, 1e4, ∇f_cnt_diff, 10
steepest_descent, 1e4, ∇f_cnt_diff, 12
steepest_descent, 1e4, ∇f_cnt_diff, 14
steepest_descent, 1e5, ∇f
steepest_descent, 1e5, ∇f_fwd_diff, 2
steepest_descent, 1e5, ∇f_fwd_diff, 4
steepest_descent, 1e5, ∇f_fwd_diff, 6
steepest_descent, 1e5, ∇f_fwd_diff, 8
steepest_descent, 1e5, ∇f_fwd_diff, 10
steepest_descent, 1e5, ∇f_fwd_diff, 12
steepest_descent, 1e5, ∇f_fwd_diff, 14
steepest_descent, 1e5, ∇f_cnt_diff, 2
steepest_descent, 1e5, ∇f_cnt_diff, 4
steepest_descent, 1e5, ∇f_cnt_diff, 6
steepest_descent, 1e5, ∇f_c

In [25]:
import JSON

In [26]:
JSON.print(results, 2)

{
  "steepest_descent": {
    "4": {
      "∇f_fwd_diff": {
        "4": [
          1000,
          4.264075,
          false
        ],
        "14": [
          43,
          0.0892592,
          true
        ],
        "10": [
          48,
          0.1023416,
          true
        ],
        "2": [
          1000,
          3.8399208,
          false
        ],
        "8": [
          57,
          0.1208062,
          true
        ],
        "6": [
          42,
          0.0864973,
          true
        ],
        "12": [
          49,
          0.1061888,
          true
        ]
      },
      "∇f": [
        47,
        0.0479317,
        true
      ],
      "∇f_cnt_diff": {
        "4": [
          51,
          0.1143326,
          true
        ],
        "14": [
          43,
          0.0907704,
          true
        ],
        "10": [
          49,
          0.1042263,
          true
        ],
        "2": [
          1000,
          4.2710722,
          false
    