In [None]:
using Plots, LaTeXStrings
using LinearAlgebra
using Printf

In [None]:
# Synthetic least-squares problem: m noisy observations of an n-vector.
m, n = 100, 2

A = rand(m, n)           # design matrix with uniform [0, 1) entries
x̄ = [1.0, 1.0]           # parameter vector used to generate the data
b = A*x̄ + randn(m)       # observations with standard normal noise added

"""
    loss(x)
    loss(x, y)

Return the least-squares objective `0.5*norm(A*x - b)^2` for the global `A`, `b`.
The two-argument method evaluates the objective at the point `[x, y]`.
"""
function loss(x)
    r = A*x - b
    return 0.5*norm(r)^2
end
loss(x, y) = loss([x, y])

"""
    gloss(x)

Return the gradient `A'*(A*x - b)` of `loss` at `x`.
"""
function gloss(x)
    r = A*x - b
    return A'*r
end

In [None]:
# Plot window: both coordinates span [-10, 10].
(ax, bx) = (-10, 10)
(ay, by) = (-10, 10)

# 200-point evaluation grids along each axis, and the contour levels to draw.
xx = range(ax; stop=bx, length=200)
yy = range(ay; stop=by, length=200)
flevels = [0, 2, 20, 50, 100, 200, 500, 1000]

# Empty square canvas first, then the labelled level sets of `loss`,
# then a black marker at x̄.
plot(;
    size=(600, 600),
    aspect_ratio=:equal,
    colorbar=:none,
    xlims=(ax, bx),
    ylims=(ay, by),
    xlabel=L"x",
    ylabel=L"y",
)
contour!(xx, yy, loss; color=:black, levels=flevels, contour_labels=true)
scatter!(x̄[1:1], x̄[2:2]; label=:none, c=:black)

In [None]:
"""
    sgd(A, b, x0, xs; tol=1e-8, verbose=true, α=0.1, β=0.5, epochs=50, batchsize=10)

Minimize the least-squares objective `0.5*norm(A*x - b)^2` with minibatch steepest
descent, choosing each step length by a backtracking line search on the current
minibatch objective. Return the vector of all iterates, starting with `x0`.

Rows of `A` and `b` are visited in order, in consecutive chunks of `batchsize` rows.
The last chunk of an epoch is shorter when `size(A, 1)` is not a multiple of
`batchsize`.

# Arguments
- `A`, `b`: matrix and right-hand side of the least-squares problem.
- `x0`: initial iterate. It is not mutated; integer input is converted to floating
  point.
- `xs`: reference point, used only for the `||xk - xs||` column of the log.

# Keywords
- `tol`: accepted for interface compatibility. No stopping test is applied, so all
  `epochs` passes are always performed.
- `verbose`: print the full loss, full gradient norm and distance to `xs` before the
  first step and after every epoch.
- `α`: sufficient-decrease constant of the line search.
- `β`: factor by which the trial step is shrunk while the decrease test fails.
- `epochs`: number of passes over the rows of `A`.
- `batchsize`: number of rows per minibatch.

# Throws
- `DimensionMismatch` if `length(b) != size(A, 1)`.
- `ArgumentError` if `batchsize < 1`.
"""
function sgd(A, b, x0, xs; tol=1e-8, verbose=true, α=0.1, β=0.5,
             epochs=50, batchsize=10)
    m = size(A, 1)
    length(b) == m ||
        throw(DimensionMismatch("A has $m rows but b has $(length(b)) entries"))
    batchsize >= 1 ||
        throw(ArgumentError("batchsize must be positive, got $batchsize"))
    nbatches = cld(m, batchsize)

    # Log against the problem that was passed in, not against notebook globals.
    fullloss(z) = 0.5*norm(A*z - b)^2
    fullgrad(z) = A'*(A*z - b)
    report(iter, z) = @printf("%4d %12.4e %12.4e %12.4e\n",
                              iter, fullloss(z), norm(fullgrad(z)), norm(z - xs))

    # Fresh floating-point copy, so integer starting points work and x0 is untouched.
    x = float.(x0)

    xtrace = [x]
    k = 0
    if verbose
        @printf("%4s %12s %12s %12s\n", "k", "loss", "||g(x)||", "||xk - xs||")
        report(k, x)
    end
    for epoch in 1:epochs
        for mbi in 1:nbatches
            # Minibatch rows; the final batch is clipped to the last row of A.
            mb = (mbi - 1)*batchsize + 1:min(mbi*batchsize, m)
            Ak = A[mb, :]
            bk = b[mb]

            k += 1
            gk = Ak'*(Ak*x - bk)
            Δx = -gk  # Steepest descent

            # Quantities at x do not change during the line search; compute them once.
            fk = 0.5*norm(Ak*x - bk)^2
            slope = dot(gk, Δx)

            # Perform a backtracking line search
            t = 1e-1
            while 0.5*norm(Ak*(x + t*Δx) - bk)^2 > fk + α*t*slope
                t *= β
                # Give up shrinking once the step is negligibly small.
                t < 1e-10 && break
            end

            # Rebinding (not in-place update) keeps earlier trace entries distinct.
            x += t*Δx

            push!(xtrace, x)
        end

        if verbose
            report(k, x)
        end
    end

    return xtrace
end

In [None]:
# Reference point for the error column of the log, and the initial iterate.
xs = x̄
x0 = [-30.0, 20.0]

# Run SGD and keep every iterate; the trailing `;` suppresses the cell output.
xtrace = sgd(A, b, x0, xs; tol=1e-8);

In [None]:
# Stack the iterates as columns. `reduce(hcat, …)` avoids splatting the whole
# trace into a single varargs call.
xtr = reduce(hcat, xtrace)
# Step between consecutive iterates, one column per step.
q = diff(xtr; dims=2)

# Plot window: both coordinates span [-10, 10].
ax, bx = -10, 10
ay, by = -10, 10

# 200-point evaluation grids along each axis, and the contour levels to draw.
xx = range(ax, bx, length=200)
yy = range(ay, by, length=200)
flevels = [0, 2, 20, 50, 100, 200, 500, 1000]

plot(xlabel=L"x", ylabel=L"y", aspect_ratio=:equal, colorbar=:none, size=(600,600),
    xlims=(ax,bx), ylims=(ay,by))
contour!(xx, yy, loss, levels=flevels, color=:black, contour_labels=true)
# Red arrows from each iterate to the next, red markers on the iterates,
# and a black marker at x̄.
quiver!(xtr[1,1:end-1], xtr[2,1:end-1], quiver=(q[1,:],q[2,:]), label=:none, c=:red)
scatter!(xtr[1,:], xtr[2,:], label=:none, c=:red)
scatter!([x̄[1]], [x̄[2]], c=:black, label=:none)

In [None]:
# Convergence history on a log scale: full loss and full gradient norm
# at every iterate in `xtrace`.
plt1 = plot(0:length(xtrace)-1, map(loss, xtrace);
    yaxis=:log,
    ylims=(1e-0, 1e4),
    xlabel=L"k",
    label=L"f(x_k)")
plot!(plt1, 0:length(xtrace)-1, map(x -> norm(gloss(x)), xtrace);
    label=L"\|\|\nabla f(x_k)\|\|_2")
title!(plt1, "Stochastic gradient descent")

---

# Adam

Kingma and Ba, [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf), 2015.

In [None]:
"""
    adam(A, b, x0, xs; tol=1e-8, verbose=true, δ=0.9, β=0.999, η=2e-1,
         ε=1e-8, epochs=50, batchsize=10)

Minimize the least-squares objective `0.5*norm(A*x - b)^2` with the Adam update
(Kingma and Ba, 2015) applied to minibatch gradients. Return the vector of all
iterates, starting with `x0`.

Rows of `A` and `b` are visited in order, in consecutive chunks of `batchsize` rows.
The last chunk of an epoch is shorter when `size(A, 1)` is not a multiple of
`batchsize`.

# Arguments
- `A`, `b`: matrix and right-hand side of the least-squares problem.
- `x0`: initial iterate. It is not mutated; integer input is converted to floating
  point.
- `xs`: reference point, used only for the `||xk - xs||` column of the log.

# Keywords
- `tol`: accepted for interface compatibility. No stopping test is applied, so all
  `epochs` passes are always performed.
- `verbose`: print the full loss, full gradient norm and distance to `xs` before the
  first step and after every epoch.
- `δ`: decay rate of the running average of the gradient (β₁ in the paper).
- `β`: decay rate of the running average of the squared gradient (β₂ in the paper).
- `η`: step size.
- `ε`: small constant added to the denominator of the update.
- `epochs`: number of passes over the rows of `A`.
- `batchsize`: number of rows per minibatch.

# Throws
- `DimensionMismatch` if `length(b) != size(A, 1)`.
- `ArgumentError` if `batchsize < 1`.
"""
function adam(A, b, x0, xs; tol=1e-8, verbose=true, δ=0.9, β=0.999, η=2e-1,
              ε=1e-8, epochs=50, batchsize=10)
    m = size(A, 1)
    length(b) == m ||
        throw(DimensionMismatch("A has $m rows but b has $(length(b)) entries"))
    batchsize >= 1 ||
        throw(ArgumentError("batchsize must be positive, got $batchsize"))
    nbatches = cld(m, batchsize)

    # Log against the problem that was passed in, not against notebook globals.
    fullloss(z) = 0.5*norm(A*z - b)^2
    fullgrad(z) = A'*(A*z - b)
    report(iter, z) = @printf("%4d %12.4e %12.4e %12.4e\n",
                              iter, fullloss(z), norm(fullgrad(z)), norm(z - xs))

    # Fresh floating-point copy, so integer starting points work and x0 is untouched.
    x = float.(x0)

    # Running averages of the gradient and of its elementwise square.
    mk = zero(x)
    vk = zero(x)

    xtrace = [x]
    k = 0
    if verbose
        @printf("%4s %12s %12s %12s\n", "k", "loss", "||g(x)||", "||xk - xs||")
        report(k, x)
    end
    for epoch in 1:epochs
        for mbi in 1:nbatches
            # Minibatch rows; the final batch is clipped to the last row of A.
            mb = (mbi - 1)*batchsize + 1:min(mbi*batchsize, m)
            Ak = A[mb, :]
            bk = b[mb]

            k += 1
            gk = Ak'*(Ak*x - bk)
            mk = δ*mk + (1-δ)*gk
            vk = β*vk + (1-β)*gk.^2

            # Bias correction: both averages start at zero.
            m̂k = mk./(1 - δ^k)
            v̂k = vk./(1 - β^k)

            d = m̂k./(sqrt.(v̂k) .+ ε)

            # Rebinding (not in-place update) keeps earlier trace entries distinct.
            x -= η*d

            push!(xtrace, x)
        end

        if verbose
            report(k, x)
        end
    end

    return xtrace
end

In [None]:
# Reference point for the error column of the log, and the initial iterate.
xs = x̄
x0 = [-30.0, 20.0]

# Run Adam and keep every iterate; the trailing `;` suppresses the cell output.
xtrace = adam(A, b, x0, xs; tol=1e-8);

In [None]:
# Stack the iterates as columns. `reduce(hcat, …)` avoids splatting the whole
# trace into a single varargs call.
xtr = reduce(hcat, xtrace)
# Step between consecutive iterates, one column per step.
q = diff(xtr; dims=2)

# Plot window: both coordinates span [-10, 10].
ax, bx = -10, 10
ay, by = -10, 10

# 200-point evaluation grids along each axis, and the contour levels to draw.
xx = range(ax, bx, length=200)
yy = range(ay, by, length=200)
flevels = [0, 2, 20, 50, 100, 200, 500, 1000]

plot(xlabel=L"x", ylabel=L"y", aspect_ratio=:equal, colorbar=:none, size=(600,600),
    xlims=(ax,bx), ylims=(ay,by))
contour!(xx, yy, loss, levels=flevels, color=:black, contour_labels=true)
# Red arrows from each iterate to the next, red markers on the iterates,
# and a black marker at x̄.
quiver!(xtr[1,1:end-1], xtr[2,1:end-1], quiver=(q[1,:],q[2,:]), label=:none, c=:red)
scatter!(xtr[1,:], xtr[2,:], label=:none, c=:red)
scatter!([x̄[1]], [x̄[2]], c=:black, label=:none)

In [None]:
# Convergence history on a log scale: full loss and full gradient norm
# at every iterate in `xtrace`.
plt2 = plot(0:length(xtrace)-1, map(loss, xtrace);
    yaxis=:log,
    ylims=(1e0, 1e4),
    xlabel=L"k",
    label=L"f(x_k)")
plot!(plt2, 0:length(xtrace)-1, map(x -> norm(gloss(x)), xtrace);
    label=L"\|\|\nabla f(x_k)\|\|_2")
title!(plt2, "Adam")

In [None]:
# Show the two convergence plots side by side in a single figure.
plot(plt1, plt2; size=(900, 500), layout=(1, 2))