In [None]:
using DrWatson

In [None]:
@quickactivate "RxInferThesisExperiments"

In [None]:
using RxInferThesisExperiments, RxInfer, StaticArrays, Plots, PGFPlotsX, LaTeXStrings
using LinearAlgebra, StableRNGs, Random, BenchmarkTools, ColorSchemes, Dates

In [None]:
# Cache directory for benchmark results (read/written by `produce_or_load` below)
const bfolder = datadir("hgf", "rxinfer")
# Output directory for the generated `.tex` figures
const outdir = plotsdir("hgf")

In [None]:
# Use the PGFPlotsX backend for Plots; the figures below are saved as `.tex` files
pgfplotsx()

In [None]:
# Create the directories for cached benchmark results and for the generated figures
# (`mkpath` also creates missing parent directories and does not fail if they exist)
mkpath(bfolder);
mkpath(outdir)

In [None]:
# HGF data-generating environment: datasets are drawn with `rand(rng, environment, T)` and
# its `kappa`/`omega` fields parameterize the `hgf` model in the benchmarks below
const environment = HGFEnvironment()

In [None]:
# Include the RxInfer HGF model specification
# (presumably defines `hgf` and `hgfconstraints` used below — verify in hgf.jl)
include(srcdir("models", "rxinfer", "hgf.jl"));

In [None]:
"""
    _simulate_and_infer(seed, T, niterations)

Draw hidden states and `T` observations from `environment` using `StableRNG(seed)`,
run inference on the HGF model for `niterations` iterations and return a named tuple
with the ground truth (`zstates`, `xstates`), the `observations`, the extracted
posteriors (`e_states`) and the AMSE of each hidden-state sequence (`z_amse`, `x_amse`).
"""
function _simulate_and_infer(seed, T, niterations)
    zstates, xstates, observations = rand(StableRNG(seed), environment, T)
    model    = hgf(environment.kappa, environment.omega)
    result   = run_inference(model, observations; iterations = niterations)
    e_states = extract_posteriors(T, result)
    z_amse   = compute_amse(zstates, e_states[:z])
    x_amse   = compute_amse(xstates, e_states[:x])
    return (zstates = zstates, xstates = xstates, observations = observations,
            e_states = e_states, z_amse = z_amse, x_amse = x_amse)
end

"""
    run_benchmark(params)

Run one benchmark configuration. `params` must provide `T` (dataset size),
`niterations` (number of inference iterations) and `seed`.

Returns a `Dict{String}` (via `@strdict`) containing the parameters, the ground-truth
`states`, the posterior estimates `e_states`, the `observations`, the AMSE and EMSE of
both hidden-state sequences, and the BenchmarkTools trials for model creation and
for inference.
"""
function run_benchmark(params)
    @unpack T, niterations, seed = params

    # Reference run: ground truth, posteriors and accuracy for this exact configuration
    @unpack zstates, xstates, observations, e_states, z_amse, x_amse =
        _simulate_and_infer(seed, T, niterations)

    # Time model creation with the HGF constraints (no observations involved)
    benchmark_modelcreation = @benchmark RxInfer.create_model(hgf(environment.kappa, environment.omega), 
        constraints = hgfconstraints(),
    )

    # Time inference only: `setup` builds the model and dataset once per sample,
    # so their construction is excluded from the measured time
    benchmark_inference = @benchmark run_inference(model, observations; iterations = $niterations) seconds=30 setup=begin
        model = hgf($(environment.kappa), $(environment.omega))
        zstates, xstates, observations = rand(StableRNG($seed), environment, $T);
    end

    # Expected MSE: `compute_emse` calls the closure with seeds derived from `seed`
    # (presumably several independent repetitions that it aggregates — see `compute_emse`)
    emse = compute_emse(seed) do _seed
        local trial = _simulate_and_infer(_seed, T, niterations)
        return [ trial.z_amse, trial.x_amse ]
    end

    z_emse = emse[1]
    x_emse = emse[2]

    states = (z = zstates, x = xstates)

    output = @strdict T niterations seed states e_states observations z_amse x_amse z_emse x_emse benchmark_modelcreation benchmark_inference

    return output
end

In [None]:
# Parameter grid for the benchmarks. `dict_list` expands the vector-valued entries into
# their Cartesian product, giving one parameter `Dict` per (T, niterations, seed) combination.
benchmark_params = dict_list(Dict(
    "T"           => [ 10, 20, 30, 100, 300, 500, 700, 1000, 3000, 5000, 7000, 10_000, 30_000, 50_000, 70_000, 100_000 ],
    "niterations" => [ 3, 5, 10, 20 ],
    "seed"        => [ 42 ]
));

In [None]:
# The first run may be slow; progress can be tracked in the terminal.
# Subsequent runs do not recompute the benchmarks:
# `produce_or_load` reloads the cached results from `bfolder` instead.
benchmarks = [
    first(produce_or_load(run_benchmark, bfolder, params; tag = false, force = false))
    for params in benchmark_params
];

In [None]:
# Gather all cached benchmark results from `bfolder` into one table, sorted by dataset size `T`
# (presumably a DataFrame, given the column-symbol `sort` — confirm in `prepare_benchmarks_table`)
benchmarks_table = sort(prepare_benchmarks_table(bfolder), [ :T ])

# Extra plots

In [None]:
# Plotting: inference time vs. dataset size (p1) and vs. number of iterations (p2)
colors = ColorSchemes.tableau_10

# Some default settings for plotting
pfontsettings = (
    titlefontsize=18,
    guidefontsize=16,
    tickfontsize=14,
    legendfontsize=14,
    legend = :outertop,
    legend_font_halign = :left,
    legend_orientation=:h,
    legend_column = 2,
    size = (800, 300)
)

# Inference time (in ms) recorded in one row of `benchmarks_table`.
# NOTE(review): assumes entries 1 and 3 of the stored "inference" summary are the total
# time and the GC time, so GC time is excluded — confirm against `prepare_benchmarks_table`.
_inference_ms(row) = to_ms(row["inference"][1] - row["inference"][3])

p1xticks = (
    [ 10, 100, 1000, 10_000, 100_000 ],
    [ L"10^1", L"10^2", L"10^3", L"10^4", L"10^5" ]
)

p1yticks = (
    [ 0.1, 1.0, 10.0, 100, 1000, 10000, 100000 ], 
    [ L"10^{-1}", L"10^{0}", L"10^{1}", L"10^{2}", L"10^{3}", L"10^{4}", L"10^{5}" ]
)

p1 = plot(
    xlabel = "Number of observations in dataset (log10-scale)", 
    ylabel = "Time (in ms, log10-scale)"; 
    xscale = :log10,
    yscale = :log10,
    xticks = p1xticks,
    yticks = p1yticks,
    ylims = (minimum(p1yticks[1]), maximum(p1yticks[1])),
    pfontsettings...
)

nits = [ 3, 5, 10, 20 ];
mshapes = [ :diamond, :circle, :rect, :utriangle ]
styles = [ :solid, :dash, :dot, :dashdot ]

# One curve per number of iterations; the x-axis is the dataset size `T`.
# Loop locals are deliberately not called `range` to avoid shadowing `Base.range`.
for (index, (mshape, nit)) in enumerate(zip(mshapes, nits))
    rows  = sort(filter(r -> r["niterations"] == nit, benchmarks_table), [ :T ])
    sizes = map(r -> r["T"], eachrow(rows))
    times = map(_inference_ms, eachrow(rows))

    plot!(p1, sizes, times, label = "$nit iterations", marker = mshape, color = colors[index], style = styles[index])
end

savefig(p1, joinpath(outdir, "04-rxinfer_hgf_scalability_size.tex"))

##

p2xticks = (
    nits,
    string.(nits)
)

# Same log-scale time axis as `p1`
p2yticks = p1yticks

p2 = plot(
    xlabel = "Number of iterations", 
    ylabel = "Time (in ms, log10-scale)"; 
    yscale = :log10,
    xticks = p2xticks,
    yticks = p2yticks,
    ylims = (minimum(p2yticks[1]), maximum(p2yticks[1])),
    pfontsettings...
)

# Other available marker shapes: :diamond, :hexagon, :cross, :xcross, :utriangle, :dtriangle,
# :rtriangle, :ltriangle, :pentagon, :heptagon

Ts = [ 10, 1000, 10_000, 100_000 ];
mshapes = [ :utriangle, :dtriangle, :rtriangle, :ltriangle ]

# One curve per dataset size; the x-axis is the number of iterations (restricted to `nits`).
for (index, (mshape, T)) in enumerate(zip(mshapes, Ts))
    at_size    = sort(filter(r -> r["T"] == T, benchmarks_table), [ :niterations ])
    rows       = filter(r -> r["niterations"] ∈ nits, at_size)
    iterations = map(r -> r["niterations"], eachrow(rows))
    times      = map(_inference_ms, eachrow(rows))

    plot!(p2, iterations, times, label = "$T observations", marker = mshape, color = colors[index], style = styles[index])
end

savefig(p2, joinpath(outdir, "04-rxinfer_hgf_scalability_nits.tex"))

p = plot(p1, p2, size = (800, 600), layout = @layout([ a; b ]))

display("image/png", p)

# Versions

In [None]:
# Record the Julia version and platform details for reproducibility
versioninfo(stdout; verbose = false)

In [None]:
] status