In [None]:
using LinearAlgebra
using Plots, LaTeXStrings
using Printf

---

# [Rosenbrock's banana function](https://en.wikipedia.org/wiki/Rosenbrock_function)

$$
f(x) = \big(1 - x_1\big)^2 + 100\big(x_2 - x_1^2\big)^2
$$

$$
\nabla f(x) =
\begin{bmatrix}
-2\big(1 - x_1\big) - 400x_1\big(x_2 - x_1^2\big)\\
200\big(x_2 - x_1^2\big)
\end{bmatrix}
$$

$$
\nabla^2 f(x) = 
\begin{bmatrix}
2 - 400\big(-3x_1^2 + x_2\big) & -400x_1 \\
-400x_1 & 200
\end{bmatrix}
$$

In [None]:
# Rosenbrock objective evaluated at a point `x`; only `x[1]` and `x[2]` are read.
function f(x)
    a = 1 - x[1]           # distance from x[1] to 1
    b = x[2] - x[1]^2      # offset from the parabola x[2] = x[1]^2
    return a^2 + 100*b^2
end

# Gradient of the Rosenbrock function, returned as a 2-element vector.
function g(x)
    d1 = -2*(1 - x[1]) - 400*(x[2] - x[1]^2)*x[1]   # partial derivative w.r.t. x[1]
    d2 = 200*(x[2] - x[1]^2)                        # partial derivative w.r.t. x[2]
    return [d1, d2]
end

# Hessian of the Rosenbrock function, returned as a symmetric 2×2 matrix.
function H(x)
    h11 = 2 - 400*(-3x[1]^2 + x[2])
    h12 = -400*x[1]                 # shared by both off-diagonal entries
    return [h11 h12; h12 200]
end

# Two-argument convenience method: packs the coordinates into a vector and
# forwards to `f(x)`, so `f` can be passed directly where a function of
# `(x, y)` is expected (as the `contour!` calls in this notebook do).
function f(x, y)
    return f([x, y])
end

In [None]:
# Evaluation grid on [-3, 3] × [-3, 3] and the contour levels to draw.
xx = range(-3, 3; step=0.01)
yy = range(-3, 3; step=0.01)
flevels = [0, 2, 20, 100, 500, 1500, 3000]

# Labelled contour lines of f, with the point (1, 1) marked in black.
plot(; xlabel=L"x", ylabel=L"y", aspect_ratio=:equal, colorbar=:none, size=(600, 600))
contour!(xx, yy, f; levels=flevels, color=:black, contour_labels=true)
scatter!([1.0], [1.0]; c=:black, label=:none)

---

# Gradient descent

In [None]:
"""
    gradient_descent(f, g, x0, xs; tol=1e-8, verbose=true, α=0.1, β=0.5, maxiter=200)

Minimize `f` by steepest descent with a backtracking line search and return the
list of iterates.

# Arguments
- `f`: objective, called as `f(x)`.
- `g`: gradient of `f`, called as `g(x)`.
- `x0`: starting point. It is not modified; a floating-point copy is used so that
  integer starting points work as well.
- `xs`: reference point, used only for the `||xk - xs||` column of the verbose log.

# Keywords
- `tol`: stop once the accepted step satisfies `norm(t*Δx) <= tol`.
- `verbose`: print a header and one line per iteration.
- `α`: sufficient-decrease constant of the line search.
- `β`: factor by which the trial step length `t` is shrunk.
- `maxiter`: iteration cap. At least one iteration is always performed.

# Returns
A vector of iterates whose first entry is the starting point and whose last entry
is the final iterate.
"""
function gradient_descent(f, g, x0, xs;
                          tol=1e-8, verbose=true, α=0.1, β=0.5, maxiter=200)
    # Floating-point copy: keeps the caller's x0 untouched and guarantees that
    # later (generally non-integral) iterates can be pushed onto the trace.
    x = float.(x0)

    xtrace = [x]
    k = 0
    if verbose
        @printf("%4s %12s %12s %12s %12s\n", "k", "t", "f(x)", "||g(x)||", "||xk - xs||")
        @printf("%4d %12.4e %12.4e %12.4e %12.4e\n", k, 0.0, f(x), norm(g(x)), norm(x-xs))
    end

    while true
        k += 1

        # Evaluate the objective and gradient once per iteration; the line search
        # below only needs fresh values of f at the trial points.
        fx = f(x)
        gx = g(x)
        Δx = -gx  # Steepest descent
        slope = dot(gx, Δx)

        # Backtracking line search: shrink t until the sufficient-decrease test
        # passes. If t falls below 1e-10 the search gives up and that last trial
        # step is taken anyway.
        t = 1.0
        while f(x + t*Δx) > fx + α*t*slope
            t *= β
            if t < 1e-10
                break
            end
        end

        step = t*Δx
        x = x + step    # rebinding (not mutating) keeps earlier trace entries intact
        push!(xtrace, x)

        if verbose
            @printf("%4d %12.4e %12.4e %12.4e %12.4e\n", k, t, f(x), norm(g(x)), norm(x-xs))
        end
        if norm(step) <= tol || k >= maxiter
            break
        end
    end

    return xtrace
end

In [None]:
# Start from the origin and track the distance to xs = (1, 1).
# For a random start in [-3, 3]^2 use instead:  x0 = 6*rand(2) .- 3
x0 = zeros(2)
xs = ones(2)
xtrace = gradient_descent(f, g, x0, xs; tol=1e-8);

In [None]:
# One iterate per column, plus the displacement between consecutive iterates.
xtr = reduce(hcat, xtrace)
q = diff(xtr; dims=2)

# Evaluation grid and contour levels for the background.
xx = range(-3, 3; step=0.01)
yy = range(-3, 3; step=0.01)
flevels = [0, 2, 20, 100, 500, 1500, 3000]

plot(; xlabel=L"x", ylabel=L"y", aspect_ratio=:equal, colorbar=:none, size=(600, 600),
    xlims=(-3, 3), ylims=(-3, 3))
contour!(xx, yy, f; levels=flevels, color=:black, contour_labels=true)
# Red arrows join each iterate to the next; red dots mark the iterates themselves.
quiver!(xtr[1, 1:end-1], xtr[2, 1:end-1]; quiver=(q[1, :], q[2, :]), label=:none, c=:red)
scatter!(xtr[1, :], xtr[2, :]; label=:none, c=:red)
# Black dot at (1, 1).
scatter!([1.0], [1.0]; c=:black, label=:none)

In [None]:
# Convergence history on a log axis: objective value and gradient norm per
# iteration. The y-range is fixed to (1e-5, 1e2), so smaller values are cut off.
plt1 = plot(0:length(xtrace)-1, f.(xtrace), yaxis=:log, label=L"f(x_k)", xlabel=L"k", ylims=(1e-5, 1e2))
# In LaTeX `\|` is already the double-bar norm delimiter, so it appears once per
# side; writing `\|\|` would draw the bars twice.
plot!(0:length(xtrace)-1, norm.(g.(xtrace)), label=L"\|\nabla f(x_k)\|_2")
title!("Gradient descent")

---

# Heavy ball

In [None]:
"""
    heavy_ball(f, g, x0, xs; tol=1e-8, verbose=true, α=0.1, β=0.5,
               momentum=60, maxiter=200)

Minimize `f` with a gradient step plus a momentum ("heavy ball") term, using a
backtracking line search, and return the list of iterates.

From the second iteration on, the search direction is
`Δx = -g(x) + momentum*(x - xold)`, where `xold` is the previous iterate. The
whole direction, momentum included, is scaled by the line-search step `t`.

Note that the sufficient-decrease test uses `dot(g(x), Δx)`, which is not
guaranteed to be negative once the momentum term is included, so the accepted
step need not decrease `f`.

# Arguments
- `f`: objective, called as `f(x)`.
- `g`: gradient of `f`, called as `g(x)`.
- `x0`: starting point. It is not modified; a floating-point copy is used so that
  integer starting points work as well.
- `xs`: reference point, used only for the `||xk - xs||` column of the verbose log.

# Keywords
- `tol`: stop once the accepted step satisfies `norm(t*Δx) <= tol`.
- `verbose`: print a header and one line per iteration.
- `α`: sufficient-decrease constant of the line search.
- `β`: factor by which the trial step length `t` is shrunk.
- `momentum`: weight of `x - xold` in the search direction.
- `maxiter`: iteration cap. At least one iteration is always performed.

# Returns
A vector of iterates whose first entry is the starting point and whose last entry
is the final iterate.
"""
function heavy_ball(f, g, x0, xs;
                    tol=1e-8, verbose=true, α=0.1, β=0.5, momentum=60, maxiter=200)
    # Floating-point copy: keeps the caller's x0 untouched and guarantees that
    # later (generally non-integral) iterates can be pushed onto the trace.
    x = float.(x0)
    xold = x    # iterates are only ever rebound, never mutated, so sharing is safe

    xtrace = [x]
    k = 0
    if verbose
        @printf("%4s %12s %12s %12s %12s\n", "k", "t", "f(x)", "||g(x)||", "||xk - xs||")
        @printf("%4d %12.4e %12.4e %12.4e %12.4e\n", k, 0.0, f(x), norm(g(x)), norm(x-xs))
    end

    while true
        k += 1

        # Evaluate the objective and gradient once per iteration; the line search
        # below only needs fresh values of f at the trial points.
        fx = f(x)
        gx = g(x)
        Δx = -gx  # Steepest descent

        # Add momentum (there is no previous step on the first iteration)
        if k > 1
            Δx += momentum*(x - xold)
        end
        slope = dot(gx, Δx)

        # Backtracking line search: shrink t until the sufficient-decrease test
        # passes. If t falls below 1e-10 the search gives up and that last trial
        # step is taken anyway.
        t = 1.0
        while f(x + t*Δx) > fx + α*t*slope
            t *= β
            if t < 1e-10
                break
            end
        end

        step = t*Δx
        xold = x
        x = x + step
        push!(xtrace, x)

        if verbose
            @printf("%4d %12.4e %12.4e %12.4e %12.4e\n", k, t, f(x), norm(g(x)), norm(x-xs))
        end
        if norm(step) <= tol || k >= maxiter
            break
        end
    end

    return xtrace
end

In [None]:
# Start from the origin and track the distance to xs = (1, 1).
# For a random start in [-3, 3]^2 use instead:  x0 = 6*rand(2) .- 3
x0 = zeros(2)
xs = ones(2)
xtrace = heavy_ball(f, g, x0, xs; tol=1e-8);

In [None]:
# One iterate per column, plus the displacement between consecutive iterates.
xtr = reduce(hcat, xtrace)
q = diff(xtr; dims=2)

# Evaluation grid and contour levels for the background.
xx = range(-3, 3; step=0.01)
yy = range(-3, 3; step=0.01)
flevels = [0, 2, 20, 100, 500, 1500, 3000]

plot(; xlabel=L"x", ylabel=L"y", aspect_ratio=:equal, colorbar=:none, size=(600, 600),
    xlims=(-3, 3), ylims=(-3, 3))
contour!(xx, yy, f; levels=flevels, color=:black, contour_labels=true)
# Red arrows join each iterate to the next; red dots mark the iterates themselves.
quiver!(xtr[1, 1:end-1], xtr[2, 1:end-1]; quiver=(q[1, :], q[2, :]), label=:none, c=:red)
scatter!(xtr[1, :], xtr[2, :]; label=:none, c=:red)
# Black dot at (1, 1).
scatter!([1.0], [1.0]; c=:black, label=:none)

In [None]:
# Convergence history on a log axis: objective value and gradient norm per
# iteration. The y-range is fixed to (1e-5, 1e2), so smaller values are cut off.
plt2 = plot(0:length(xtrace)-1, f.(xtrace), yaxis=:log, label=L"f(x_k)", xlabel=L"k", ylims=(1e-5, 1e2))
# In LaTeX `\|` is already the double-bar norm delimiter, so it appears once per
# side; writing `\|\|` would draw the bars twice.
plot!(0:length(xtrace)-1, norm.(g.(xtrace)), label=L"\|\nabla f(x_k)\|_2")
title!("Heavy ball")

---

# Nesterov acceleration

In [None]:
"""
    nesterov(f, g, x0, xs; tol=1e-8, verbose=true, α=0.1, β=0.5, γ=1.0,
             momentum=60, inertia=0.5, maxiter=200)

Minimize `f` with a Nesterov-style accelerated gradient iteration and return the
list of iterates.

Each iteration evaluates the gradient at a look-ahead point `y` rather than at
the current iterate `x`:

    Δx = -g(y) + momentum*(x - xold)      (momentum term from iteration 2 on)
    t  = backtracking step length for the trial point x + t*Δx
    x⁺ = x + inertia*(x - xold) + t*Δx
    y⁺ = x⁺ + γ*(x⁺ - x)

Two things are worth knowing about this variant as implemented:
- the previous displacement `x - xold` enters twice, once inside `Δx` (scaled by
  `t`) and once through the `inertia` term;
- the line search probes `x + t*Δx`, whereas the update that is actually applied
  also adds `inertia*(x - xold)`, so the sufficient-decrease test is not checked
  at the point the iteration moves to.

# Arguments
- `f`: objective, called as `f(x)`.
- `g`: gradient of `f`, called as `g(x)`.
- `x0`: starting point. It is not modified; a floating-point copy is used so that
  integer starting points work as well.
- `xs`: reference point, used only for the `||xk - xs||` column of the verbose log.

# Keywords
- `tol`: stop once consecutive iterates satisfy `norm(x - xold) <= tol`.
- `verbose`: print a header and one line per iteration.
- `α`: sufficient-decrease constant of the line search.
- `β`: factor by which the trial step length `t` is shrunk.
- `γ`: extrapolation weight used to form the look-ahead point `y`.
- `momentum`: weight of `x - xold` inside the search direction `Δx`.
- `inertia`: weight of `x - xold` added directly to the update.
- `maxiter`: iteration cap. At least one iteration is always performed.

# Returns
A vector of iterates whose first entry is the starting point and whose last entry
is the final iterate.
"""
function nesterov(f, g, x0, xs;
                  tol=1e-8, verbose=true, α=0.1, β=0.5, γ=1.0,
                  momentum=60, inertia=0.5, maxiter=200)
    # Floating-point copy: keeps the caller's x0 untouched and guarantees that
    # later (generally non-integral) iterates can be pushed onto the trace.
    x = float.(x0)
    xold = x    # iterates are only ever rebound, never mutated, so sharing is safe
    y = x

    xtrace = [x]
    k = 0
    if verbose
        @printf("%4s %12s %12s %12s %12s\n", "k", "t", "f(x)", "||g(x)||", "||xk - xs||")
        @printf("%4d %12.4e %12.4e %12.4e %12.4e\n", k, 0.0, f(x), norm(g(x)), norm(x-xs))
    end

    while true
        k += 1

        # Values at the current iterate are needed by the line search only; they
        # are evaluated once here instead of on every trial step.
        fx = f(x)
        gx = g(x)

        Δx = -g(y)  # Negative gradient at the look-ahead point

        # Add momentum (there is no previous step on the first iteration)
        if k > 1
            Δx += momentum*(x - xold)
        end
        slope = dot(gx, Δx)

        # Backtracking line search from x along Δx. If t falls below 1e-10 the
        # search gives up and that last trial step length is used anyway.
        t = 1.0
        while f(x + t*Δx) > fx + α*t*slope
            t *= β
            if t < 1e-10
                break
            end
        end

        v = x - xold
        xold = x
        x = x + inertia*v + t*Δx
        y = x + γ*(x - xold)

        push!(xtrace, x)

        if verbose
            @printf("%4d %12.4e %12.4e %12.4e %12.4e\n", k, t, f(x), norm(g(x)), norm(x-xs))
        end
        if norm(x - xold) <= tol || k >= maxiter
            break
        end
    end

    return xtrace
end

In [None]:
# Start from the origin and track the distance to xs = (1, 1).
# For a random start in [-3, 3]^2 use instead:  x0 = 6*rand(2) .- 3
x0 = zeros(2)
xs = ones(2)
xtrace = nesterov(f, g, x0, xs; tol=1e-8);

In [None]:
# One iterate per column, plus the displacement between consecutive iterates.
xtr = reduce(hcat, xtrace)
q = diff(xtr; dims=2)

# Evaluation grid and contour levels for the background.
xx = range(-3, 3; step=0.01)
yy = range(-3, 3; step=0.01)
flevels = [0, 2, 20, 100, 500, 1500, 3000]

plot(; xlabel=L"x", ylabel=L"y", aspect_ratio=:equal, colorbar=:none, size=(600, 600),
    xlims=(-3, 3), ylims=(-3, 3))
contour!(xx, yy, f; levels=flevels, color=:black, contour_labels=true)
# Red arrows join each iterate to the next; red dots mark the iterates themselves.
quiver!(xtr[1, 1:end-1], xtr[2, 1:end-1]; quiver=(q[1, :], q[2, :]), label=:none, c=:red)
scatter!(xtr[1, :], xtr[2, :]; label=:none, c=:red)
# Black dot at (1, 1).
scatter!([1.0], [1.0]; c=:black, label=:none)

In [None]:
# Convergence history on a log axis: objective value and gradient norm per
# iteration. The y-range is fixed to (1e-5, 1e2), so smaller values are cut off.
plt3 = plot(0:length(xtrace)-1, f.(xtrace), yaxis=:log, label=L"f(x_k)", xlabel=L"k", ylims=(1e-5, 1e2))
# In LaTeX `\|` is already the double-bar norm delimiter, so it appears once per
# side; writing `\|\|` would draw the bars twice.
plot!(0:length(xtrace)-1, norm.(g.(xtrace)), label=L"\|\nabla f(x_k)\|_2")
title!("Nesterov acceleration")

In [None]:
# Show the three convergence plots side by side in a single row.
plot(plt1, plt2, plt3; layout=(1, 3), size=(900, 500))

---