This notebook does not perform any benchmark; it simply runs the inference procedure and generates some nice plots for the thesis.

_Author: Dmitry Bagaev_

In [None]:
using DrWatson

In [None]:
@quickactivate "RxInferThesisExperiments"

In [None]:
using RxInferThesisExperiments, RxInfer, StaticArrays, Plots, PGFPlotsX, LaTeXStrings
using LinearAlgebra, StableRNGs, Random, BenchmarkTools, ColorSchemes, Dates

In [None]:
# The thesis figures use the `pgfplotsx` backend of `Plots`, since it generates
# high-quality plots. It is really slow, though: to generate plots faster, comment out
# the `pgfplotsx()` call and uncomment the `gr()` line below.
pgfplotsx()

# Use fast plotting backend
# gr()

In [None]:
# Pre-create the output directory for the HGF plots (no-op if it already exists),
# so that figures can later be saved under `plotsdir("hgf", ...)`
mkpath(plotsdir("hgf"));

In [None]:
# Environment constructed with its default settings. Later cells sample the synthetic
# states and observations from it (`rand`) and read its `kappa` and `omega` fields
# when building the model.
const environment = HGFEnvironment()

In [None]:
# Include the model specification from the project's `src/models/rxinfer` folder
# NOTE(review): presumably this file defines the `hgf` model constructor used in the
# inference section below — confirm
include(srcdir("models", "rxinfer", "hgf.jl"));

In [None]:
# Generate a synthetic data set from the environment. The RNG is seeded explicitly
# so that the generated figures are reproducible.
T = 10_000
seed = 42
rng = StableRNG(seed)

# NOTE(review): presumably `zstates` is the upper layer and `xstates` the lower layer of
# the hierarchy (they are labelled s^{(2)} and s^{(1)} in the plots below) — confirm
zstates, xstates, observations = rand(rng, environment, T);

# `plotting` range: every 100th time step, starting from step 20. The upper bound
# follows `T` instead of a hard-coded literal, so the range stays within the data
# when the number of generated samples is changed.
prange = 20:100:T
colors = ColorSchemes.tableau_10

# Some default settings for plotting
pfontsettings = (
    titlefontsize = 18,
    guidefontsize = 16,
    tickfontsize = 14,
    legendfontsize = 14,
    legend = :bottomleft,
    size = (400, 300),
)

# First figure: the `zstates` trajectory. The target plot is always passed explicitly
# to `plot!`/`scatter!`, so the result does not depend on the implicit "current plot".
p1 = plot(xlabel = "Time step index", ylabel = ""; pfontsettings...)
p1 = plot!(p1, prange, zstates[prange], color = colors[1], linewidth = 2, label = L"s^{(2)}")

# Second figure: the `xstates` trajectory together with the observations.
# The explicit `legend` keyword after the splat overrides the default from `pfontsettings`.
p2 = plot(xlabel = "Time step index", ylabel = ""; pfontsettings..., legend = :bottomright)
p2 = plot!(p2, prange, xstates[prange], color = colors[7], linewidth = 2, label = L"s^{(1)}")
p2 = scatter!(
    p2,
    prange,
    observations[prange],
    color = colors[5],
    ms = 2,
    alpha = 0.5,
    msw = 0,
    label = L"y",
)

# Save each figure both as `.tex` and as `.pdf`
savefig(p1, plotsdir("hgf", "04-hierarchical_example_states_1.tex"))
savefig(p1, plotsdir("hgf", "04-hierarchical_example_states_1.pdf"))
savefig(p2, plotsdir("hgf", "04-hierarchical_example_states_2.tex"))
savefig(p2, plotsdir("hgf", "04-hierarchical_example_states_2.pdf"))

# Side-by-side preview of both figures, shown in the notebook as a PNG
p = plot(p1, p2, layout = @layout([ a b ]), size = (800, 300))

display("image/png", p)

# Inference

In [None]:
# Build the model from the environment's `kappa` and `omega` parameters and run the
# inference procedure on the generated observations with `iterations = 5`.
# NOTE(review): `free_energy = true` presumably enables tracking of the free energy
# values (`results.free_energy_history` is plotted later) — confirm against `run_inference`
model = hgf(environment.kappa, environment.omega)
results = run_inference(model, observations, free_energy = true, iterations = 5);

In [None]:
# Extract the posteriors for the `:z` and `:x` states from the inference results
e_states = extract_posteriors(T, results)

# Posterior means and standard deviations.
# Note: despite the `v` in their names, `evz` and `evx` hold standard deviations (`std`).
emz = mean.(e_states[:z])
evz = std.(e_states[:z])

emx = mean.(e_states[:x])
evx = std.(e_states[:x])

# First figure: `zstates` against the estimated posterior mean; the ribbon spans
# three standard deviations around the mean
p1 = plot(xlabel = "Time step index", ylabel = ""; pfontsettings...)
p1 = plot!(p1, prange, zstates[prange], color = colors[1], linewidth = 2, label = L"s^{(2)}")
p1 = plot!(
    p1,
    prange,
    emz[prange],
    ribbon = 3evz[prange],
    color = colors[2],
    linewidth = 2,
    label = L"q(s^{(2)})",
)

# Second figure: `xstates`, the estimated posterior (mean with a three standard
# deviations ribbon) and the observations
p2 = plot(xlabel = "Time step index", ylabel = ""; pfontsettings..., legend = :bottomright)
p2 = plot!(p2, prange, xstates[prange], color = colors[7], linewidth = 2, label = L"s^{(1)}")
p2 = plot!(
    p2,
    prange,
    emx[prange],
    ribbon = 3evx[prange],
    color = colors[3],
    linewidth = 2,
    label = L"q(s^{(1)})",
)
p2 = scatter!(
    p2,
    prange,
    observations[prange],
    color = colors[5],
    ms = 2,
    alpha = 0.5,
    msw = 0,
    label = L"y",
)

# Third figure: free energy history from the inference results. `p3` is passed to
# `plot!` explicitly (rather than relying on the implicit "current plot"), consistent
# with the other figures in this cell.
p3 = plot(xlabel = "Variational iteration index", ylabel = "Bethe Free Energy"; pfontsettings...)
p3 = plot!(p3, results.free_energy_history, label = "Bethe Free Energy", legend = :topright)

# Save each figure both as `.tex` and as `.pdf`
savefig(p1, plotsdir("hgf", "04-hierarchical_example_inference_states_1.tex"))
savefig(p1, plotsdir("hgf", "04-hierarchical_example_inference_states_1.pdf"))
savefig(p2, plotsdir("hgf", "04-hierarchical_example_inference_states_2.tex"))
savefig(p2, plotsdir("hgf", "04-hierarchical_example_inference_states_2.pdf"))
savefig(p3, plotsdir("hgf", "04-hierarchical_example_inference_free_energy.tex"))
savefig(p3, plotsdir("hgf", "04-hierarchical_example_inference_free_energy.pdf"))

# Side-by-side preview of all three figures, shown in the notebook as a PNG
p = plot(p1, p2, p3, layout = @layout([ a b c ]), size = (1200, 300))

display("image/png", p)

In [None]:
# Report the AMSE between the generated states and the estimated posteriors,
# first for the `z` states and then for the `x` states
for (label, truth, key) in (("AMSE Z: ", zstates, :z), ("AMSE X: ", xstates, :x))
    println(label, compute_amse(truth, e_states[key]))
end

# Versions

In [None]:
# Print the Julia version and platform information used to produce the results
versioninfo()

In [None]:
] status