In [1]:
import LinearAlgebra: I, ⋅
import Base.MathConstants: φ
abstract type DescentMethod end

# Metody optymalizacji

## Metoda najszybszego spadku z bezwładnością

In [2]:
mutable struct Momentum <: DescentMethod
  α # learning rate
  β # momentum decay
  v # momentum
end
Momentum(α, β, n::Integer) = Momentum(α, β, zeros(n))

function step!(M::Momentum, f, ∇f, x) 
  α, β, v, g = M.α, M.β, M.v, ∇f(x)
  v[:] = β*v .- α*g
  return x + v
end

step! (generic function with 1 method)

## BFGS

In [3]:
mutable struct BFGS <: DescentMethod
  Q
end
BFGS(n::Integer) = BFGS(Matrix(1.0I, n, n))

function strong_backtracking(f, ∇, x, d; α=1, β=1e-4, σ=0.1)
  y0, g0, y_prev, α_prev = f(x), ∇(x)⋅d, NaN, 0
  αlo, αhi = NaN, NaN
  # bracket phase
  while true
    y = f(x + α*d)
    if y > y0 + β*α*g0 || (!isnan(y_prev) && y ≥ y_prev)
      αlo, αhi = α_prev, α
      break
    end
    g = ∇(x + α*d)⋅d
    if abs(g) ≤ -σ*g0
      return α
    elseif g ≥ 0
      αlo, αhi = α, α_prev
      break
    end
    y_prev, α_prev, α = y, α, 2α
  end
  # zoom phase
  ylo = f(x + αlo*d)
  while true
    α = (αlo + αhi)/2
    y = f(x + α*d)
    if y > y0 + β*α*g0 || y ≥ ylo
      αhi = α
    else
      g = ∇(x + α*d)⋅d
      if abs(g) ≤ -σ*g0
        return α
      elseif g*(αhi - αlo) ≥ 0
        αhi = αlo
      end
      αlo = α
    end
  end
end

function step!(M::BFGS, f, ∇f, x)
  if f(x) ≈ 0.0
    return x
  end

  Q, g = M.Q, ∇f(x)
  α = strong_backtracking(f, ∇f, x, -Q*g)
  x′ = x + α*(-Q*g)
  g′ = ∇f(x′)
  δ = x′ - x
  γ = g′ - g
  Q[:] = Q - (δ*γ'*Q + Q*γ*δ')/(δ'*γ) + (1 + (γ'*Q*γ)/(δ'*γ))[1]*(δ*δ')/(δ'*γ)
  return x′
end

step! (generic function with 2 methods)

## L-BFGS

In [35]:
using LinearAlgebra

mutable struct LBFGS
  m
  δs
  γs
  qs
  LBFGS() = new()
end

function init!(M::LBFGS, m) 
  M.m = m
  M.δs = [] 
  M.γs = [] 
  M.qs = []
  return M
end

function step!(M::LBFGS, f, ∇f, θ) 
    δs, γs, qs = M.δs, M.γs, M.qs 
    m, g = length(δs), ∇f(θ)
    d = -g
    if m > 0 
        q = g
        for i in m:-1:1
            qs[i] = copy(q)
            q -= (δs[i]⋅q) / (γs[i]⋅δs[i]) * γs[i]
        end
        z = (γs[m] .* δs[m] .* q) / (γs[m]⋅γs[m]) 
        for i in 1:+1:m
            z += δs[i]*(δs[i]⋅qs[i]-γs[i]⋅z)/(γs[i]⋅δs[i]) 
        end
        d = -z;
    end
    φ =α-> f(θ+α*d); φ′=α->∇f(θ+α*d)⋅d 
    α = line_search(φ, φ′, d)
    θ′ = θ + α*d; g′ = ∇f(θ′)
    δ =θ′-θ;γ =g′-g
    push!(δs, δ); push!(γs, γ); push!(qs, zero(θ)) 
    while length(δs) > M.m
        popfirst!(δs); popfirst!(γs); popfirst!(qs) 
    end
    return θ′ 
end

step! (generic function with 2 methods)

# Test

In [44]:
function main()
    f(x)  = 100*(x[2] - x[1]^2)^2 + (1-x[1])^2 # funkcja
    ∇f(x) = [400x[1]^3 - 400x[1]*x[2] + 2x[1] - 2,
        200x[2] - 200x[1]^2] # pochodne
    
    x₀  = [1.1, 2.0] # wektor początkowy x
    pts = [x₀] # tutaj zbieramy kolejne wektory x
    val = Float64[] # tutaj zbieramy wartości f. straty
    # opt = BFGS(2) # optymalizator
    # opt = Momentum(0.001, 0.00001, 2)
    opt = LBFGS()
    for i=1:25
        push!(val, f(pts[end])) # odłóż wynik funkcji dla najnowszych x (chcemy żeby 0)
        push!(pts, step!(opt, f, ∇f, pts[end])) # odłóż nowe x - nowe x to wynik step
    end

    pts, val
end

pts, val = main()
val

LoadError: UndefRefError: access to undefined reference

In [33]:
pts

26-element Vector{Vector{Float64}}:
 [1.1, 2.0]
 [1.4474, 1.842]
 [1.3000510386304, 1.892591772]
 [1.4047323122140916, 1.8521004641265164]
 [1.3358379774247768, 1.876334540183804]
 [1.384255755030436, 1.8579604948738675]
 [1.3512603155077207, 1.8696010112256174]
 [1.3741756418268958, 1.8608618134388624]
 [1.3583133019746658, 1.8663611022771842]
 [1.3691943803136428, 1.8620919420788997]
 [1.3615546437748907, 1.86461216118801]
 [1.3667030879886406, 1.862455963749593]
 [1.3630059737065408, 1.8635402155812373]
 [1.365417527041267, 1.8623892401794513]
 [1.363607608622858, 1.8627843852641048]
 [1.3647123412039972, 1.8621126542215645]
 [1.3638043576814343, 1.8621780715068401]
 [1.3642854772204929, 1.861734923065819]
 [1.3638079721902345, 1.8616429066921203]
 [1.363991768890755, 1.8613087614354604]
 [1.3637194781999222, 1.861141714927262]
 [1.3637616650846283, 1.860859533315718]
 [1.3635870952510727, 1.8606567996616379]
 [1.3635619121288758, 1.8603993909690253]
 [1.3634338295776516, 1.86017972