_Author: Dmitry Bagaev_

In [None]:
using DrWatson

In [None]:
# Activate the `RxInferThesisExperiments` DrWatson project so the packages below resolve against it
@quickactivate("RxInferThesisExperiments")

In [None]:
using RxInferThesisExperiments, ForneyLab, StaticArrays, Plots, PGFPlotsX, LaTeXStrings
using LinearAlgebra, StableRNGs, Random, BenchmarkTools, ColorSchemes, Dates

import Distributions

In [None]:
# For the thesis I use the `pgfplotsx` backend of `Plots`, since it generates high-quality plots.
# It is very slow, however; to generate plots faster, uncomment the `gr()` line below.
pgfplotsx()

# Use the fast plotting backend
# gr()

In [None]:
# Ensure the output folder for plots exists before anything is written to it
# NOTE(review): the folder is named `lds` although this notebook runs the rotating experiment — confirm intended
plotsdir("lds") |> mkpath;

## Rotating tracking environment

In [None]:
# Synthetic rotating-tracking environment used to generate the data below.
# NOTE(review): the positional arguments follow the project's `RotatingTracking` constructor
# (presumably state dimension, transition matrix, observation matrix and two noise
# covariances) — confirm against its definition.
const environment = let θ = π/20
    RotatingTracking(
        2,
        [cos(θ) sin(θ)/2; -sin(θ)/2 cos(θ)],
        [0.0 -1.9; 1.3 0.0],
        [0.0001 0.0; 0.0 0.0001],
        [1.0 0.0; 0.0 1.0]
    )
end

In [None]:
# Load the ForneyLab model specification for this experiment
# (presumably defines the `rotating` model used below — confirm in the included file)
srcdir("models", "forneylab", "rotating.jl") |> include;

In [None]:
# Generate `T` hidden states and the corresponding observations from the environment,
# using a fixed seed so that the experiment is reproducible
T = 250
seed = 42
rng = StableRNG(seed)

states, observations = rand(rng, environment, T);

# Range of time steps to plot (here: the whole trajectory)
prange = firstindex(states):lastindex(states)
colors = ColorSchemes.tableau_10

# Default font, legend and size settings shared by all plots in this notebook
pfontsettings = (
    titlefontsize = 18,
    guidefontsize = 16,
    tickfontsize = 14,
    legendfontsize = 14,
    legend = :bottomleft,
    size = (400, 300)
)

# Left panel: first state component and the first observation component
p1 = plot(xlabel = "Time step index", ylabel = "First component of the state"; pfontsettings...)
p1 = plot!(p1, prange, getindex.(states, 1)[prange], color = colors[1], linewidth = 2, label = L"s^{(1)}")
p1 = scatter!(p1, prange, getindex.(observations, 1)[prange], ms = 2, msw = 0, color = colors[5], alpha = 0.5, label = L"y^{(1)}")

# Right panel: second state component and the second observation component
p2 = plot(xlabel = "Time step index", ylabel = "Second component of the state"; pfontsettings...)
p2 = plot!(p2, prange, getindex.(states, 2)[prange], color = colors[3], linewidth = 2, label = L"s^{(2)}")
p2 = scatter!(p2, prange, getindex.(observations, 2)[prange], ms = 2, msw = 0, color = colors[5], alpha = 0.5, label = L"y^{(2)}")

p = plot(p1, p2, layout = @layout([ a b ]), size = (800, 300))

display("image/png", p)

In [None]:
# Build the model via `rotating(T, seed, environment)` and run inference on the generated observations
model = rotating(T, seed, environment)
results = run_inference(model, observations);

In [None]:
# Posterior marginals over the hidden states, summarised by their means and standard deviations
e_states = extract_posteriors(T, results)

em = Distributions.mean.(e_states)
ev = Distributions.std.(e_states)

# NOTE(review): the ribbons index each element of `ev` as `[i, i]`; this relies on `std` of
# these posteriors being indexable that way — confirm against `extract_posteriors`.

# Left panel: first state component, its observations and the posterior mean ± 3 std.
# The ribbon is sliced with `prange` too, so it stays aligned with the plotted mean.
p1 = plot(xlabel = "Time step index", ylabel = "First component of the state"; pfontsettings...)
p1 = plot!(p1, prange, getindex.(states, 1)[prange], color = colors[1], linewidth = 2, label = L"s^{(1)}")
p1 = scatter!(p1, prange, getindex.(observations, 1)[prange], ms = 2, msw = 0, color = colors[5], alpha = 0.5, label = L"y^{(1)}")
p1 = plot!(p1, prange, getindex.(em, 1)[prange], ribbon = 3 .* getindex.(ev, 1, 1)[prange], color = colors[2], linewidth = 2, label = L"q(s^{(1)})")

# Right panel: second state component, its observations and the posterior mean ± 3 std
p2 = plot(xlabel = "Time step index", ylabel = "Second component of the state"; pfontsettings...)
p2 = plot!(p2, prange, getindex.(states, 2)[prange], color = colors[3], linewidth = 2, label = L"s^{(2)}")
p2 = scatter!(p2, prange, getindex.(observations, 2)[prange], ms = 2, msw = 0, color = colors[5], alpha = 0.5, label = L"y^{(2)}")
p2 = plot!(p2, prange, getindex.(em, 2)[prange], ribbon = 3 .* getindex.(ev, 2, 2)[prange], color = colors[4], linewidth = 2, label = L"q(s^{(2)})")

p = plot(p1, p2, size = (800, 300), layout = @layout([ a b ]))

display("image/png", p)

In [None]:
# Report the AMSE (as computed by the project's `compute_amse`) of the posteriors against the true states
println("AMSE: $(compute_amse(states, e_states))")

## Versions

In [None]:
# Print Julia and system version information for reproducibility.
# `versioninfo` lives in the `InteractiveUtils` stdlib, which is only loaded automatically
# in interactive sessions, so import it explicitly for non-interactive script runs.
using InteractiveUtils: versioninfo
versioninfo()

In [None]:
# Print the package versions of the active environment.
# `Pkg.status()` is used instead of the `] status` REPL shortcut, which is not valid Julia
# syntax outside the REPL/IJulia, so this cell also works when run as a plain script.
import Pkg
Pkg.status()