_Author: Dmitry Bagaev_

In [None]:
using DrWatson

In [None]:
@quickactivate "RxInferThesisExperiments"

In [None]:
using RxInferThesisExperiments, Turing, StaticArrays, Plots, PGFPlotsX, LaTeXStrings
using LinearAlgebra, StableRNGs, Random, BenchmarkTools, ColorSchemes, Dates, ProgressMeter

import Distributions

In [None]:
# For the thesis we use the `pgfplotsx` backend of `Plots`, since it produces high-quality
# (PGFPlots/LaTeX) figures. It is, however, very slow: to generate plots faster, uncomment
# the `gr()` line below, which switches `Plots` to the GR backend after this call.
pgfplotsx()

# Fast plotting backend (alternative to `pgfplotsx()` above)
# gr()

In [None]:
const environment = HGFEnvironment() # data-generating HGF environment (presumably hierarchical Gaussian filter, default parameters — TODO confirm); `const` keeps global reads type-stable, and its `kappa`/`omega` fields are reused by the inference model

In [None]:
# Include the Turing model specification for the HGF experiment
# (presumably defines the `HGF` model constructor used in the inference cell — verify)
include(srcdir("models", "turing", "hgf.jl"));

In [None]:
Turing.setprogress!(true) # show Turing's sampling progress bar (inference over all time steps can take a while)

In [None]:
T = 10_000
seed = 42
rng = StableRNG(seed)

# Simulate `T` steps of the environment: hidden states `z`, `x` and observations `y`
zstates, xstates, observations = rand(rng, environment, T);

# `plotting` range: every 100th time step starting at 20. Derived from the data length
# instead of a hard-coded `10000`, so changing `T` can never index past the simulated series
# (for `T = 10_000` this is exactly the original `20:100:10000`).
prange = 20:100:lastindex(observations)
colors = ColorSchemes.tableau_10

# Some default settings for plotting
pfontsettings = (
    titlefontsize=18,
    guidefontsize=16,
    tickfontsize=14,
    legendfontsize=14,
    legend = :bottomleft,
    size = (400, 300)
)

# Left panel: hidden state `z`. `p1` is passed explicitly to `plot!` rather than relying on
# the implicit "current plot", matching how the right panel is built.
p1 = plot(xlabel = "Time step index", ylabel = ""; pfontsettings...)
p1 = plot!(p1, prange, zstates[prange], color = colors[1], linewidth = 2, label = L"z")

# Right panel: hidden state `x` together with the observations `y`
p2 = plot(xlabel = "Time step index", ylabel = ""; pfontsettings...)
p2 = plot!(p2, prange, xstates[prange], color = colors[7], linewidth = 2, label = L"x")
p2 = scatter!(p2, prange, observations[prange], color = colors[5], ms = 2, alpha = 0.5, msw = 0, label = L"y")

p = plot(p1, p2, layout = @layout([ a b ]), size = (800, 300))

display("image/png", p)

In [None]:
# Turing needs to recreate the model for every inference call, so `run_inference` receives a
# factory that builds a fresh `HGF` model from the given observation and prior arguments.
# The coupling parameters `kappa` and `omega` are fixed to the values of the true environment.
model = function (observation, zt_min_prior, xt_min_prior, z_std_prior, y_std_prior)
    kappa, omega = environment.kappa, environment.omega
    return HGF(
        observation, zt_min_prior, xt_min_prior, z_std_prior, y_std_prior, kappa, omega
    )
end
results = run_inference(model, observations; method = NUTS(), nsamples = 10);

In [None]:
# Collect the per-time-step posterior marginals of `z` and `x` from the inference results
e_states = extract_posteriors(T, results)

# Posterior means and standard deviations (note: `ev*` hold standard deviations, not variances)
emz = Distributions.mean.(e_states[:z])
evz = Distributions.std.(e_states[:z])

emx = Distributions.mean.(e_states[:x])
evx = Distributions.std.(e_states[:x])

# Left panel: true `z` against its posterior mean with a ±3 standard deviation ribbon
p1 = plot(xlabel = "Time step index", ylabel = ""; pfontsettings...)
p1 = plot!(p1, prange, zstates[prange], color = colors[1], linewidth = 2, label = L"z")
p1 = plot!(p1, prange, emz[prange], ribbon = 3evz[prange], color = colors[2], linewidth = 2, label = L"q(z)")

# Right panel: true `x` against its posterior `q(x)` (±3 std ribbon), plus the observations `y`
p2 = plot(xlabel = "Time step index", ylabel = ""; pfontsettings...)
p2 = plot!(p2, prange, xstates[prange], color = colors[7], linewidth = 2, label = L"x")
p2 = plot!(p2, prange, emx[prange], ribbon = 3evx[prange], color = colors[3], linewidth = 2, label = L"q(x)")
p2 = scatter!(p2, prange, observations[prange], color = colors[5], ms = 2, alpha = 0.5, msw = 0, label = L"y")

p = plot(p1, p2, layout = @layout([ a b ]), size = (800, 300))

display("image/png", p)

In [None]:
# Report the AMSE (as computed by the project helper `compute_amse`) of each state's
# posterior against the simulated ground truth; output format is "AMSE <name>: <value>".
for (name, truth, key) in (("Z", zstates, :z), ("X", xstates, :x))
    println("AMSE ", name, ": ", compute_amse(truth, e_states[key]))
end

In [None]:
versioninfo() # record Julia version and platform information for reproducibility

In [None]:
# Record the package versions of the active project for reproducibility. The Pkg REPL
# prefix form `] status` is not Julia syntax and only works in interactive front-ends;
# the functional Pkg API prints the same listing in notebooks and plain scripts alike.
import Pkg
Pkg.status()