# GFE Optimization by Newton's Method

In [1]:
using LinearAlgebra
using PositiveFactorizations
using ForwardDiff: jacobian
using Plots

"""
    softmax(v::AbstractVector)

Return the softmax of `v` as a new vector whose entries sum to one.

The maximum of `v` is subtracted before exponentiating, and the shifted values
are clamped to `[-100, 0]`, so the exponentials can neither overflow nor
underflow to an all-zero vector. `v` itself is not modified.

Any `AbstractVector` (plain vectors, views, ranges) is accepted. An empty `v`
throws, because `maximum` is undefined for an empty collection.
"""
function softmax(v::AbstractVector)
    # Shift so the largest entry is exactly 0; the clamp bounds the smallest one.
    shifted = clamp.(v .- maximum(v), -100.0, 0.0)
    # Exponentiate once and reuse the result for the normalising sum.
    weights = exp.(shifted)
    return weights ./ sum(weights)
end

# Small positive offset added to arguments of `log` (e.g. `log.(d .+ tiny)`) so that a zero entry does not produce -Inf.
tiny = 1e-12
;

[33m[1m│ [22m[39m- If you have Compat checked out for development and have
[33m[1m│ [22m[39m  added Base64 as a dependency but haven't updated your primary
[33m[1m│ [22m[39m  environment's manifest file, try `Pkg.resolve()`.
[33m[1m│ [22m[39m- Otherwise you may need to report an issue with Compat


[33m[1m│ [22m[39m  exception = Required dependency Compat [34da2185-b29b-5c13-b0c7-acf172513d20] failed to load from a cache file.
[33m[1m└ [22m[39m[90m@ Base loading.jl:1055[39m


[33m[1m│ [22m[39m- If you have Compat checked out for development and have
[33m[1m│ [22m[39m  added Base64 as a dependency but haven't updated your primary
[33m[1m│ [22m[39m  environment's manifest file, try `Pkg.resolve()`.
[33m[1m│ [22m[39m- Otherwise you may need to report an issue with Compat


[33m[1m│ [22m[39mThis may mean ChainRulesCore [d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4] does not support precompilation but is imported by a module that does.
[33m[1m└ [22m[39m[90m@ Base loading.jl:1030[39m


[33m[1m│ [22m[39mThis may mean ChainRulesCore [d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4] does not support precompilation but is imported by a module that does.
[33m[1m└ [22m[39m[90m@ Base loading.jl:1030[39m


[33m[1m│ [22m[39m  exception = Required dependency Compat [34da2185-b29b-5c13-b0c7-acf172513d20] failed to load from a cache file.
[33m[1m└ [22m[39m[90m@ Base loading.jl:1055[39m


[33m[1m│ [22m[39m- If you have Compat checked out for development and have
[33m[1m│ [22m[39m  added Base64 as a dependency but haven't updated your primary
[33m[1m│ [22m[39m  environment's manifest file, try `Pkg.resolve()`.
[33m[1m│ [22m[39m- Otherwise you may need to report an issue with Compat


[33m[1m│ [22m[39m  exception = Required dependency Compat [34da2185-b29b-5c13-b0c7-acf172513d20] failed to load from a cache file.
[33m[1m└ [22m[39m[90m@ Base loading.jl:1055[39m


[33m[1m│ [22m[39m- If you have Compat checked out for development and have
[33m[1m│ [22m[39m  added Base64 as a dependency but haven't updated your primary
[33m[1m│ [22m[39m  environment's manifest file, try `Pkg.resolve()`.
[33m[1m│ [22m[39m- Otherwise you may need to report an issue with Compat


[33m[1m│ [22m[39mThis may mean ChainRulesCore [d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4] does not support precompilation but is imported by a module that does.
[33m[1m└ [22m[39m[90m@ Base loading.jl:1030[39m


[33m[1m│ [22m[39mThis may mean ChainRulesCore [d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4] does not support precompilation but is imported by a module that does.
[33m[1m└ [22m[39m[90m@ Base loading.jl:1030[39m


│   exception = ErrorException("Required dependency Compat [34da2185-b29b-5c13-b0c7-acf172513d20] failed to load from a cache file.")
└ @ Base loading.jl:1055


[33m[1m│ [22m[39m- If you have Compat checked out for development and have
[33m[1m│ [22m[39m  added Base64 as a dependency but haven't updated your primary
[33m[1m│ [22m[39m  environment's manifest file, try `Pkg.resolve()`.
[33m[1m│ [22m[39m- Otherwise you may need to report an issue with Compat


# Model

In [None]:
# Model parameters. Each of A, c and d enters g and F below through its element-wise logarithm.
A = [0.99 0.01;
     0.01 0.99]

c = [0.5, 0.5]

d = [0.01, 0.99]

s_0 = [0.9, 0.1] # Initial coordinate

# Residual of the fixed-point equation s = softmax(...): g(s) is zero exactly at a fixed point.
# `tiny` is added inside every logarithm, including log.(A .+ tiny), so a zero entry
# in A, c, d or A*s cannot produce -Inf (and hence NaN after multiplication).
g(s) = s - softmax(log.(d .+ tiny) + diag(A'*log.(A .+ tiny)) + A'*log.(c .+ tiny) - A'*log.(A*s .+ tiny)) # Convert fixed-point equation to root-finding problem
# Scalar objective evaluated at a coordinate s (the quantity plotted as "GFE [nats]").
F(s) = -s'*log.(d .+ tiny) + s'*log.(s .+ tiny) - s'*diag(A'*log.(A .+ tiny)) - (A*s)'*log.(c .+ tiny) + (A*s)'*log.(A*s .+ tiny)
;

# Results

In [None]:
# Newton iterations for the root of g, recording the first coordinate and the
# objective value F after every step.
n_its = 5
G = zeros(n_its)
p = Vector{Float64}(undef, n_its) # Coordinates

G_0 = F(s_0)
s_k_min = s_0
for k=1:n_its
    # Declared global so the update below rebinds the top-level variable even when
    # this code is run as a plain script rather than in a notebook/REPL.
    global s_k_min

    # Newton step for multivariate root finding: solve J * delta = g(s) with `\`
    # instead of forming the explicit inverse of the Jacobian.
    s_k = s_k_min - jacobian(g, s_k_min) \ g(s_k_min)

    p[k] = s_k[1]
    G[k] = F(s_k)

    s_k_min = s_k
end

In [None]:
# Objective value against iteration index; index 0 is the value at the starting point.
plot(
    0:n_its,
    vcat(G_0, G);
    color=:black,
    grid=true,
    linewidth=2,
    legend=false,
    xlabel="Coordinate Increment",
    ylabel="GFE [nats]",
)

# Landscape

In [None]:
# Evaluate F on a regular grid over the first coordinate; the second coordinate
# is its complement to one.
ps = 0.0:0.05:1.0
m = length(ps)
Gs = Float64[F([q, 1.0 - q]) for q in ps]

In [None]:
# Objective landscape over the grid of first coordinates.
plt = plot(ps, Gs; dpi=100, xlabel="s", ylabel="GFE [nats]", color=:black, linewidth=2)

# Overlay the Newton trajectory, beginning at the initial coordinate.
p_0 = s_0[1]
plot!(vcat(p_0, p), vcat(G_0, G); color=:green, marker=:o, linewidth=2, legend=false)

# Label each visited point with its iteration number (0 marks the start).
for (k, (x, y)) in enumerate(zip(vcat(p_0, p), vcat(G_0, G)))
    label = (x, y, text(k - 1, 12, :red, :center))
    annotate!(label, linecolor=:red)
end

plt