In [1]:
using DrWatson

In [2]:
# Activate the DrWatson project this notebook belongs to (expected name:
# "RxInferThesisExperiments") before any project code is loaded.
@quickactivate "RxInferThesisExperiments"

In [14]:
using RxInferThesisExperiments, ForneyLab, StaticArrays, Plots, PGFPlotsX, LaTeXStrings
using LinearAlgebra, StableRNGs, Random, BenchmarkTools, ColorSchemes, ProgressMeter, Dates

In [4]:
# Folder "hgf/forneylab" inside the project's data directory. Benchmark results are
# written to and reloaded from this folder by the cells below.
const bfolder = datadir("hgf", "forneylab")

"/Users/bvdmitri/.julia/dev/thesis/data/hgf/forneylab"

In [5]:
# Create the benchmark data folder (including missing parent folders) up front,
# so that results can be saved into it later.
mkpath(bfolder);

In [6]:
# Shared HGF environment. `run_benchmark` reads its `kappa` and `omega` fields to build
# the model and samples states/observations from it via `rand(rng, environment, T)`.
const environment = HGFEnvironment()

HGFEnvironment()

In [7]:
# Include the ForneyLab model specification from the project's `src/models/forneylab`
# folder. Presumably this file defines `hgf` and the inference helpers used by
# `run_benchmark` — confirm in `hgf.jl`.
include(srcdir("models", "forneylab", "hgf.jl"));

In [8]:
"""
    run_benchmark(params)

Run a single HGF benchmark configuration with the ForneyLab model and collect its results.

`params` must provide the entries `T`, `niterations` and `seed` (extracted with `@unpack`).

Steps performed:
1. Sample states and observations from the global `environment` with `StableRNG(seed)`.
2. Run inference once and score the `:z` and `:x` posteriors with `compute_amse`.
3. Benchmark model creation and inference with `@benchmark`.
4. Obtain the EMSE pair from `compute_emse`, which is given a function mapping a seed to
   the `[z, x]` AMSE values of a fresh run.

Returns the dictionary produced by `@strdict`; its keys are the names of the listed local
variables, which is why those variables must keep their exact names.
"""
function run_benchmark(params)
    @unpack T, niterations, seed = params

    # Reference run: one data set and one inference pass for the requested seed.
    zstates, xstates, observations = rand(StableRNG(seed), environment, T)
    states = (z = zstates, x = xstates)

    reference_run = run_inference(
        hgf(environment.kappa, environment.omega),
        observations;
        iterations = niterations,
    )
    e_states = extract_posteriors(T, reference_run)
    z_amse, x_amse = compute_amse(zstates, e_states[:z]), compute_amse(xstates, e_states[:x])

    # Timing of model construction alone.
    benchmark_modelcreation = @benchmark hgf(environment.kappa, environment.omega; force = true)

    # Timing of inference alone: the `setup` block rebuilds the model and regenerates the
    # data before the timed call, so neither is part of the measured expression.
    benchmark_inference = @benchmark run_inference(model, observations; iterations = $niterations) setup=begin
        model = hgf($(environment.kappa), $(environment.omega))
        zstates, xstates, observations = rand(StableRNG($seed), environment, $T);
    end

    # Function handed to `compute_emse`: AMSE pair `[z, x]` of a fresh run for `_seed`.
    # Fresh names declared `local` guarantee that nothing here rebinds the reference-run
    # variables of the enclosing function.
    function amse_pair(_seed)
        local zs, xs, ys = rand(StableRNG(_seed), environment, T)
        local fresh_run  = run_inference(
            hgf(environment.kappa, environment.omega),
            ys;
            iterations = niterations,
        )
        local posteriors = extract_posteriors(T, fresh_run)
        return [ compute_amse(zs, posteriors[:z]), compute_amse(xs, posteriors[:x]) ]
    end

    emse_pair = compute_emse(amse_pair, seed)
    z_emse    = emse_pair[1]
    x_emse    = emse_pair[2]

    return @strdict T niterations seed states e_states observations z_amse x_amse z_emse x_emse benchmark_modelcreation benchmark_inference
end

run_benchmark (generic function with 1 method)

In [21]:
# Parameter grid for the benchmarks. `dict_list` expands the vectors below into one
# parameter dictionary per combination (4 values of T × 2 of niterations × 1 seed,
# i.e. 8 configurations).
benchmark_params = dict_list(Dict(
    "T"           => [ 10, 20, 30, 100 ],
    "niterations" => [ 3, 10 ],
    "seed"        => [ 42 ]
));

In [22]:
# The first run may be slow; its progress can be tracked in the terminal.
# Subsequent runs do not create new benchmarks: existing results are reloaded
# from the data folder instead (`force = false`).
# NOTE(review): recent DrWatson documentation lists the argument order as
# `(f, config, path)` — confirm that `(f, path, config)` is still accepted by the
# pinned DrWatson version.
benchmarks = [
    first(produce_or_load(run_benchmark, bfolder, params; tag = false, force = false)) for params in benchmark_params
];

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile /Users/bvdmitri/.julia/dev/thesis/data/hgf/forneylab/T=100_niterations=3_seed=42.jld2 does not exist. Producing it now...
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile /Users/bvdmitri/.julia/dev/thesis/data/hgf/forneylab/T=100_niterations=3_seed=42.jld2 saved.
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile /Users/bvdmitri/.julia/dev/thesis/data/hgf/forneylab/T=10_niterations=10_seed=42.jld2 does not exist. Producing it now...
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile /Users/bvdmitri/.julia/dev/thesis/data/hgf/forneylab/T=10_niterations=10_seed=42.jld2 saved.
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile /Users/bvdmitri/.julia/dev/thesis/data/hgf/forneylab/T=20_niterations=10_seed=42.jld2 does not exist. Producing it now...
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile /Users/bvdmitri/.julia/dev/thesis/data/hgf/forneylab/T=20_niterations=10_seed=42.jld2 saved.
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFil

In [23]:
# Collect the result files found in the benchmark folder into a table and order its rows by `T`.
sort(prepare_benchmarks_table(bfolder), [ :T ])

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mScanning folder /Users/bvdmitri/.julia/dev/thesis/data/hgf/forneylab for result files.
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mAdded 8 entries.


Row,T,z_emse,x_emse,z_amse,niterations,seed,x_amse,inference,creation
Unnamed: 0_level_1,Int64?,Float64?,Float64?,Float64?,Int64?,Int64?,Float64?,Tuple…?,Tuple…?
1,10,1.14916,0.1582,1.03239,10,42,0.224379,"(4.42955e7, 5.2719e7, 0.0)","(20.5719, 24.2844, 0.0)"
2,10,1.32345,0.158711,1.18141,3,42,0.231314,"(1.26549e7, 1.39382e7, 0.0)","(20.0281, 22.458, 0.0)"
3,20,0.767198,0.154245,0.586204,10,42,0.187199,"(9.22362e7, 1.03299e8, 3.5976e6)","(21.8494, 28.0149, 0.0)"
4,20,0.881845,0.155043,0.663982,3,42,0.190732,"(2.57316e7, 3.00252e7, 0.0)","(20.8613, 23.7992, 0.0)"
5,30,0.59447,0.152034,0.432486,10,42,0.165137,"(1.39433e8, 1.49582e8, 3.5124e6)","(21.2, 25.1332, 0.0)"
6,30,0.681506,0.152781,0.488609,3,42,0.167435,"(3.85109e7, 4.28757e7, 0.0)","(20.1266, 22.6136, 0.0)"
7,100,0.240489,0.154593,0.161371,10,42,0.153471,"(4.85097e8, 5.24958e8, 2.30265e7)","(20.4794, 24.596, 0.0)"
8,100,0.274408,0.155031,0.18027,3,42,0.154218,"(1.42926e8, 1.60292e8, 3.95558e6)","(20.6894, 23.634, 0.0)"


# Versions

In [12]:
# Record the Julia version and platform that this notebook was run on.
versioninfo()

Julia Version 1.9.0
Commit 8e630552924 (2023-05-07 11:25 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin22.4.0)
  CPU: 12 × Intel(R) Core(TM) i7-8850H CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
  Threads: 2 on 12 virtual cores


In [13]:
] status

[36m[1mProject[22m[39m RxInferThesisExperiments v1.0.0
[32m[1mStatus[22m[39m `~/.julia/dev/thesis/Project.toml`
  [90m[b5ca4192] [39mAdvancedVI v0.2.3
  [90m[6e4b80f9] [39mBenchmarkTools v1.3.2
  [90m[76274a88] [39mBijectors v0.12.4
  [90m[35d6a980] [39mColorSchemes v3.21.0
  [90m[a93c6f00] [39mDataFrames v1.5.0
  [90m[31c24e10] [39mDistributions v0.25.95
  [90m[634d3b9d] [39mDrWatson v2.12.5
  [90m[442a2c76] [39mFastGaussQuadrature v0.5.1
  [90m[9fc3f58a] [39mForneyLab v0.12.0
  [90m[f6369f11] [39mForwardDiff v0.10.35
  [90m[14197337] [39mGenericLinearAlgebra v0.3.11
  [90m[19dc6840] [39mHCubature v1.5.1
  [90m[7073ff75] [39mIJulia v1.24.0
  [90m[b964fa9f] [39mLaTeXStrings v1.3.0
  [90m[bdcacae8] [39mLoopVectorization v0.12.159
  [90m[3bd65402] [39mOptimisers v0.2.18
  [90m[8314cec4] [39mPGFPlotsX v1.6.0
  [90m[e4faabce] [39mPProf v2.2.2
  [90m[91a5bcdd] [39mPlots v1.38.15
  [90m[92933f4c] [39mProgressMeter v1.7.2
  [90m[37e2e3b7] [39m