In [None]:
using DrWatson

In [None]:
@quickactivate "RxInferThesisExperiments"

In [None]:
using RxInferThesisExperiments, Plots, PGFPlotsX, LaTeXStrings
using LinearAlgebra, StableRNGs, Random, BenchmarkTools, ColorSchemes, Dates, DataFrames

import RxInfer, ReactiveMP, ForneyLab, Turing

In [None]:
# Select the PGFPlotsX backend of Plots; the figure is later saved as a `.tex` file.
pgfplotsx()

# Alternative backend, kept for quick switching: uncomment to use GR instead.
# gr()

In [None]:
# Folder that receives the figures produced by this notebook
# (path built by `plotsdir` for the "hgf" experiment).
const outfolder = plotsdir("hgf")

In [None]:
# Make sure the output folder exists (missing parent directories are created as well);
# the trailing `;` suppresses the cell output.
mkpath(outfolder);

In [None]:
"""
    analyze_benchmarks(filterfunction, bfolder)

Load the benchmark table stored under `bfolder`, keep the rows accepted by
`filterfunction`, and extract the timings needed for plotting.

# Arguments
- `filterfunction`: predicate applied to every row of the table returned by
  `prepare_benchmarks_table(bfolder)`; only rows for which it returns `true` are kept.
  Because it is the first argument it can be supplied with `do`-block syntax.
- `bfolder`: location of the benchmark results, forwarded to `prepare_benchmarks_table`.

# Returns
A 3-tuple `(sorted, (inference, creation), (min_timing_range, max_timing_range))`:
- `sorted`: the selected rows ordered by the `:T` column;
- `inference`, `creation`: one value per row, computed as the first minus the third
  element of the corresponding `inference` / `creation` entry;
- `min_timing_range`, `max_timing_range`: smallest and largest value over both vectors.

# Throws
- `ArgumentError` if no row passes `filterfunction`, since no timing range can be
  computed from an empty selection.
"""
function analyze_benchmarks(filterfunction, bfolder)
    benchmarks = prepare_benchmarks_table(bfolder)

    # Select only a portion of benchmarks for plotting
    filtered = filter(filterfunction, benchmarks)
    sorted = sort(filtered, [:T])

    # Each entry contributes its first element minus its third element.
    # Original note: RxInfer includes the model creation time in it.
    # NOTE(review): the meaning of the individual elements is defined by
    # `prepare_benchmarks_table` and is not visible here — confirm.
    adjusted(column) = getindex.(column, 1) .- getindex.(column, 3)

    inference = adjusted(sorted.inference)
    creation = adjusted(sorted.creation)

    # Fail with an actionable message instead of the generic
    # "reducing over an empty collection" error raised by `minimum`/`maximum`.
    if isempty(inference)
        throw(ArgumentError(
            "no benchmarks from $(bfolder) passed the filter; cannot compute timing ranges"
        ))
    end

    min_timing_range = min(minimum(inference), minimum(creation))
    max_timing_range = max(maximum(inference), maximum(creation))

    return sorted, (inference, creation), (min_timing_range, max_timing_range)
end

In [None]:
# Seed a benchmark run must have been recorded with to be included in the comparison.
target_seed = 42
# Iteration count selected for the RxInfer and ForneyLab benchmarks (also shown in the plot labels).
target_niterations = 3
# Sample count selected for the Turing NUTS benchmarks (also shown in the plot label).
target_nsamples = 100

In [None]:
# RxInfer results: keep only the runs recorded with the target iteration count and seed.
rxifb, (rxi_inference, rxi_creation), (rxi_min_tr, rxi_max_tr) = analyze_benchmarks(
    r -> r["niterations"] == target_niterations && r["seed"] == target_seed,
    datadir("hgf", "rxinfer"),
)

In [None]:
# ForneyLab results: keep only the runs recorded with the target iteration count and seed.
flfb, (fl_inference, fl_creation), (fl_min_tr, fl_max_tr) = analyze_benchmarks(
    r -> r["niterations"] == target_niterations && r["seed"] == target_seed,
    datadir("hgf", "forneylab"),
)

In [None]:
# Turing (NUTS) results: keep only the runs recorded with the target sample count and seed.
tgfb, (tg_inference, tg_creation), (tg_min_tr, tg_max_tr) = analyze_benchmarks(
    r -> r["nsamples"] == target_nsamples && r["seed"] == target_seed,
    datadir("hgf", "turing", "nuts"),
)

In [None]:
# Common timing bounds across the three benchmark suites, so that every curve
# fits on the same y-axis.
min_timing_range = min(rxi_min_tr, fl_min_tr, tg_min_tr)
max_timing_range = max(rxi_max_tr, fl_max_tr, tg_max_tr)

# Ten tick positions spaced evenly in log-space between the smallest and largest timing.
timing_range = exp.(range(log(min_timing_range), log(max_timing_range); length = 10))
# Every distinct value of `T` that occurs in any of the three tables, in ascending order.
sizes_range = sort(collect(union(rxifb.T, flfb.T, tgfb.T)))

# Tick specifications in the `(positions, labels)` form; every ".0" occurrence in the
# formatted timing labels is stripped.
yticks = (timing_range, replace.(to_ms_str.(timing_range; digits = 0), ".0" => ""))
xticks = (sizes_range, string.(sizes_range))

# Font settings shared by all text elements of the figure.
pfontsettings = (
    titlefontsize=18,
    guidefontsize=16,
    tickfontsize=14,
    legendfontsize=14,
    legend_font_halign = :left
)

# Empty log-log canvas; the series are added one by one below.
p = plot(
    size = (800, 400),
    yscale = :log10, xscale = :log10, yticks = yticks, xticks = xticks,
    ylabel = "Time (log-scale)", xlabel = "Number of observations (log-scale)",
    legend = :outerright;
    pfontsettings...
)

# `plot!` mutates `p` in place, so no reassignment is needed.
plot!(p, rxifb.T, rxi_inference, label = "Reactive MP ($(target_niterations) iterations)", marker = :circle)
plot!(p, flfb.T, fl_inference, label = "Scheduled MP (inference, $(target_niterations) iterations)", marker = :utriangle)
plot!(p, flfb.T, fl_creation, label = "Scheduled MP (compilation)", marker = :rect)
plot!(p, tgfb.T, tg_inference, label = "NUTS ($target_nsamples)", marker = :dtriangle)

# Save exactly this figure rather than whatever the "current" plot happens to be.
savefig(p, joinpath(outfolder, "04-benchmark_comparison.tex"))

display("image/png", p)