In [None]:
using DrWatson

In [None]:
# Activate the `RxInferThesisExperiments` project so that the DrWatson path helpers
# used below (`datadir`, `srcdir`) resolve relative to it.
@quickactivate "RxInferThesisExperiments"

In [None]:
using RxInferThesisExperiments, Turing, StaticArrays, Plots, PGFPlotsX, LaTeXStrings
using LinearAlgebra, StableRNGs, Random, BenchmarkTools, ColorSchemes, Dates, ProgressMeter, Logging

In [None]:
# Folder for the saved NUTS benchmark results of the HGF model; results are written
# here by `produce_or_load` and read back by `prepare_benchmarks_table` below.
const bfolder_nuts = datadir("hgf", "turing", "nuts")

In [None]:
# Create the results folder (and any missing parent folders) before any benchmark is saved into it
mkpath(bfolder_nuts);

In [None]:
# Shared HGF environment: used to simulate the data and supplies the fixed `kappa`
# and `omega` parameters passed to the `HGF` model.
const environment = HGFEnvironment()

In [None]:
# Include the Turing HGF model specification
# (presumably defines the `HGF` model used below — confirm against hgf.jl)
include(srcdir("models", "turing", "hgf.jl"));

In [None]:
"""
    make_hgf_model(env)

Return a closure that builds the Turing `HGF` model from its per-run arguments
(`observation`, `zt_min_prior`, `xt_min_prior`, `z_std_prior`, `y_std_prior`),
appending the fixed `kappa` and `omega` parameters read from `env`.

Defined at top level (not inside `run_benchmark_nuts`) because the `setup` of
`@benchmark` is evaluated at module scope and must be able to call it.
"""
function make_hgf_model(env)
    return (observation, zt_min_prior, xt_min_prior, z_std_prior, y_std_prior) -> begin
        HGF(observation, zt_min_prior, xt_min_prior, z_std_prior, y_std_prior,
            env.kappa, env.omega)
    end
end

"""
    run_benchmark_nuts(params)

Benchmark NUTS inference on the HGF model for a single parameter set.

`params` must provide:
- `T`: passed to `rand(rng, environment, T)` as the length of the simulated data
  (presumably the number of time steps — confirm against `HGFEnvironment`);
- `nsamples`: forwarded to `run_inference`;
- `seed`: seeds the `StableRNG`s used for data generation and inference.

All log messages are discarded (`NullLogger`) while the benchmark runs.

Returns a `String`-keyed dictionary (built with `@strdict`) containing the inputs
`T`, `nsamples`, `seed`, the simulated `states` (`z`, `x`) and `observations`, the
posterior estimates `e_states`, the AMSE of the `z`/`x` estimates for `seed`
(`z_amse`, `x_amse`), the two entries returned by `compute_emse` (`z_emse`,
`x_emse`), and the BenchmarkTools results `benchmark_modelcreation` and
`benchmark_inference`.
"""
function run_benchmark_nuts(params)
    return with_logger(NullLogger()) do
        @unpack T, nsamples, seed = params

        # Reference run for `seed`: simulate data, run inference, score the estimates.
        zstates, xstates, observations = rand(StableRNG(seed), environment, T);
        model    = make_hgf_model(environment)
        result   = run_inference(model, observations; nsamples = nsamples, rng = StableRNG(seed))
        e_states = extract_posteriors(T, result)
        z_amse   = compute_amse(zstates, e_states[:z])
        x_amse   = compute_amse(xstates, e_states[:x])

        # Turing recreates the model on every inference run, so there is no separate
        # model-creation step to time here; an empty benchmark is stored as a placeholder.
        # NOTE(review): presumably kept so all backends' results share the same keys — confirm.
        benchmark_modelcreation = @benchmark begin end

        # Model and data are rebuilt in `setup`, which BenchmarkTools does not time,
        # so only `run_inference` itself is measured.
        benchmark_inference = @benchmark run_inference(model, observations; nsamples = $nsamples, rng = StableRNG($seed)) setup=begin
            model = make_hgf_model(environment)
            zstates, xstates, observations = rand(StableRNG($seed), environment, $T);
        end

        # The closure re-runs the whole pipeline for the seed `compute_emse` hands it.
        # `local` stops these assignments from overwriting the reference-run variables
        # above, which are still needed for the output.
        emse = compute_emse(seed) do _seed
            local zstates, xstates, observations = rand(StableRNG(_seed), environment, T);
            local model    = make_hgf_model(environment)
            local result   = run_inference(model, observations; nsamples = nsamples, rng = StableRNG(_seed))
            local e_states = extract_posteriors(T, result)
            return [ compute_amse(zstates, e_states[:z]), compute_amse(xstates, e_states[:x]) ]
        end

        z_emse = emse[1]
        x_emse = emse[2]

        states = (z = zstates, x = xstates)

        output = @strdict T nsamples seed states e_states observations z_amse x_amse z_emse x_emse benchmark_modelcreation benchmark_inference

        return output
    end
end

In [None]:
# Parameter grid for the NUTS benchmarks. `dict_list` expands it into a list with one
# parameter set per combination of the listed values.
benchmark_params_nuts = let grid = Dict(
        "T"        => [10, 20, 30, 100, 300],
        "nsamples" => [100, 200],
        "seed"     => [42],
    )
    dict_list(grid)
end;

In [None]:
# The first run may be slow. `produce_or_load` saves each result under `bfolder_nuts`,
# and because `force = false`, later runs reload those saved results from the data
# folder instead of re-running the benchmarks.
benchmarks_nuts = [
    first(produce_or_load(run_benchmark_nuts, bfolder_nuts, params; tag = false, force = false))
    for params in benchmark_params_nuts
];

In [None]:
# Build a results table from the saved runs in `bfolder_nuts` and sort its rows by `T`
sort(prepare_benchmarks_table(bfolder_nuts), [ :T ])

# Versions

In [None]:
# Record the Julia version and system information the benchmarks were run with
versioninfo()

In [None]:
] status