This notebook runs a comprehensive benchmark suite for the inference procedure in the rotating tracking model (a linear dynamical system) using the RxInfer framework.

_Author: Dmitry Bagaev_

In [None]:
using DrWatson

In [None]:
# Activate the "RxInferThesisExperiments" project environment so that the packages
# loaded below are resolved from that project
@quickactivate "RxInferThesisExperiments"

In [None]:
using RxInferThesisExperiments, RxInfer, StaticArrays, Plots, PGFPlotsX, LaTeXStrings
using LinearAlgebra, StableRNGs, Random, BenchmarkTools, ColorSchemes, Dates, DataFrames

In [None]:
# Folder where `produce_or_load` stores the benchmark results and from which
# `prepare_benchmarks_table` later reads them back
const bfolder = datadir("lds", "rxinfer")

In [None]:
# Make sure the benchmark results folder exists before any results are saved into it
mkpath(bfolder);

In [None]:
# Include the RxInfer model specification; presumably this defines `rotating`,
# which `run_benchmark` below uses (TODO confirm against rotating.jl)
include(srcdir("models", "rxinfer", "rotating.jl"));

In [None]:
"""
    _rotating_experiment(T, d, seed)

Run one inference experiment on the `rotating` model.

The function builds a `d`-dimensional `RotatingTracking` environment and draws `T`
states/observations from it. Both steps are seeded with `StableRNG(seed)`. It then
runs `inference` on the `rotating` model and scores the posterior estimates with
`compute_amse`.

Returns a `NamedTuple` with the fields `environment`, `states`, `observations`,
`e_states` and `amse`, in that order, so it can also be destructured positionally.
"""
function _rotating_experiment(T, d, seed)
    environment          = RotatingTracking(d; rng = StableRNG(seed))
    states, observations = rand(StableRNG(seed), environment, T)
    model                = rotating(T, environment)
    result               = inference(model = model, data = (y = observations, ))
    e_states             = extract_posteriors(T, result)
    amse                 = compute_amse(states, e_states)
    return (environment = environment, states = states, observations = observations,
            e_states = e_states, amse = amse)
end

"""
    run_benchmark(params)

Benchmark RxInfer inference on the rotating tracking model for one parameter combination.

`params` must provide `T` (number of observations), `d` (state dimension) and `seed`
(RNG seed). The function performs three steps:

1. It runs the experiment once for `seed` and records its `amse`.
2. It benchmarks model creation and inference with BenchmarkTools.
3. It computes `emse` with `compute_emse`, which calls the experiment for each seed it
   supplies. Those seeds are presumably derived from `seed`; confirm against `compute_emse`.

Returns the `@strdict` of `T`, `d`, `seed`, `states`, `e_states`, `observations`, `amse`,
`emse`, `benchmark_modelcreation` and `benchmark_inference`, a `Dict` keyed by their names
as strings.
"""
function run_benchmark(params)
    @unpack T, d, seed = params

    environment, states, observations, e_states, amse = _rotating_experiment(T, d, seed)

    benchmark_modelcreation = @benchmark RxInfer.create_model(rotating($T, $(environment)))

    # The model and data are rebuilt in `setup` (once per sample), so only the
    # `inference` call itself is timed.
    benchmark_inference = @benchmark inference(model = model, data = (y = observations, )) setup=begin
        model = rotating($T, $(environment))
        states, observations = rand(StableRNG($seed), $environment, $T);
    end

    # The closure only reads `T` and `d`; it no longer assigns locals with the same
    # names as the outer ones, so nothing gets captured and reassigned.
    emse = compute_emse(seed) do _seed
        return _rotating_experiment(T, d, _seed).amse
    end

    output = @strdict T d seed states e_states observations amse emse benchmark_modelcreation benchmark_inference

    return output
end

In [None]:
# Parameter grid for the benchmarks: `dict_list` expands it into one configuration
# per combination of the values listed below
benchmark_params = let grid = Dict(
        "T"    => [ 10, 20, 30 ],
        "d"    => [ 2, 3, 4 ],
        "seed" => [ 42 ]
    )
    dict_list(grid)
end;

In [None]:
# The first run may be slow; progress can be tracked in the terminal.
# Later runs do not recompute existing benchmarks: with `force = false`,
# `produce_or_load` reloads them from the data folder instead.
benchmarks = [
    first(produce_or_load(run_benchmark, bfolder, params; tag = false, force = false))
    for params in benchmark_params
];

In [None]:
# Collect all stored benchmark results into one table, ordered by the number of observations `T`
benchmarks_table = let raw = prepare_benchmarks_table(bfolder)
    sort(raw, [ :T ])
end

# Extra plots

In [None]:
# Use the PGFPlotsX backend of Plots.jl for the figures below
pgfplotsx()

In [None]:
# Plotting setup: color palette shared by all figures
colors = ColorSchemes.tableau_10

# Some default settings for plotting
pfontsettings = (
    titlefontsize=18,
    guidefontsize=16,
    tickfontsize=14,
    legendfontsize=14,
    legend = :outerright,
    legend_font_halign = :left,
    size = (800, 300)
)

# The table can hold several rows with the same `T` (one per dimension `d`), so keep
# every tick position only once.
p1xvalues = unique(benchmarks_table.T)
p1xticks  = (p1xvalues, string.(p1xvalues))

p1yticks = (
    [ 0.01, 0.1, 0.3, 0.5, 0.7, 1, 1.3, 1.5, 2, 3 ], 
    [ "0.01", "0.1", "0.3", "0.5", "0.7", "1", "1.3", "1.5", "2", "3" ]
)

p1 = plot(
    xlabel = "Number of observations in dataset (log10-scale)", 
    ylabel = "Time (in ms, log10-scale)"; 
    xscale = :log10,
    yscale = :log10,
    xticks = p1xticks,
    yticks = p1yticks,
    pfontsettings...
)

nd = [ 2, 3, 4 ];
mshapes = [ :utriangle, :diamond, :pentagon ]

# One line per state dimension `d`: inference time as a function of `T`
for (index, (mshape, d)) in enumerate(zip(mshapes, nd))
    filtered    = filter((r) -> r["d"] == d, benchmarks_table)
    sorted      = sort(filtered, [ :T ])
    # Named `xs` rather than `range` so that `Base.range` is not shadowed
    xs          = map(f -> f["T"], eachrow(sorted))
    # NOTE(review): the meaning of the entries of the "inference" column (here entry 1
    # minus entry 3) is defined by `prepare_benchmarks_table` — confirm there
    t_inference = map(f -> to_ms(f["inference"][1] - f["inference"][3]), eachrow(sorted))
    
    plot!(p1, xs, t_inference, label = "Reactive MP inference ($d dimensional)", marker = mshape, color = colors[index])
end

##

# Inference time as a function of the number of dimensions, one line per `T`.
# The tick values match the "d" values in `benchmark_params`.
p2xvalues = [ 2, 3, 4 ]
p2xticks  = (p2xvalues, string.(p2xvalues))

p2yticks = (
    [ 0.01, 0.1, 0.3, 0.5, 0.7, 1, 1.3, 1.5, 2, 3 ], 
    [ "0.01", "0.1", "0.3", "0.5", "0.7", "1", "1.3", "1.5", "2", "3" ]
)

p2 = plot(
    xlabel = "Number of dimensions", 
    ylabel = "Time (in ms, log10-scale)"; 
    yscale = :log10,
    xticks = p2xticks,
    yticks = p2yticks,
    pfontsettings...
)

Ts = [ 10, 20, 30 ];
# `zip` stops at the shorter list, so the extra marker is only a spare
mshapes = [ :utriangle, :diamond, :pentagon, :circle ]

for (index, (mshape, T)) in enumerate(zip(mshapes, Ts))
    filtered    = filter((r) -> r["T"] == T, benchmarks_table)
    sorted      = sort(filtered, [ :d ])
    # Named `xs` rather than `range` so that `Base.range` is not shadowed
    xs          = map(f -> f["d"], eachrow(sorted))
    # NOTE(review): same "inference" column convention as in the first figure — confirm
    # against `prepare_benchmarks_table`
    t_inference = map(f -> to_ms(f["inference"][1] - f["inference"][3]), eachrow(sorted))
    
    plot!(p2, xs, t_inference, label = "Reactive MP inference ($T observations)", marker = mshape, color = colors[index])
end

plot(p1, p2, size = (800, 600), layout = @layout([ a; b ]))

# Versions

In [None]:
# Julia version and system information, recorded for reproducibility
versioninfo()

In [None]:
# Status of the packages in the active project environment, recorded for reproducibility.
# This uses the Pkg API instead of the IJulia-only `] status` REPL mode, so the cell is
# also valid plain Julia code (e.g. when the notebook is run as a script).
import Pkg
Pkg.status()