This notebook runs a comprehensive benchmark suite for the inference procedure in the double pendulum system, using the NUTS sampler from the Turing framework.

_Author: Dmitry Bagaev_

In [1]:
using DrWatson

In [2]:
@quickactivate "RxInferThesisExperiments"

In [13]:
using RxInferThesisExperiments, Turing, StaticArrays, Plots, PGFPlotsX, LaTeXStrings, ReverseDiff
using LinearAlgebra, StableRNGs, Random, BenchmarkTools, ColorSchemes, Dates, DataFrames, Logging

In [4]:
# Folder where the Turing/NUTS benchmark results for the double pendulum are stored
const bfolder = joinpath(datadir(), "nlds", "turing", "nuts")

"/Users/bvdmitri/.julia/dev/thesis/data/nlds/turing/nuts"

In [5]:
# Make sure the output folder for the benchmark result files exists
isdir(bfolder) || mkpath(bfolder);

In [6]:
# Double pendulum environment with its default parameters. It serves both as the
# data generator (`rand(rng, environment, T)`) and as the source of the
# state-transition function defined below.
const environment = DoublePendulum()

DoublePendulum()

In [7]:
# State-transition function of the environment; it uses an RK4 step internally
# (see the `src/` folder)
function f(state)
    transition = state_transition(environment)
    return transition(state)
end

f (generic function with 1 method)

In [8]:
# Load the Turing model specification; this file is presumably where
# `double_pendulum(observations, T)`, used by the benchmarks below, is defined.
include(srcdir("models", "turing", "doublependulum.jl"));

In [14]:
"""
    _nuts_posteriors(seed, T, nsamples; kwargs...)

Generate `T` steps of double pendulum data with `StableRNG(seed)`, build the Turing
model `double_pendulum(observations, T)`, draw `nsamples` NUTS samples (the sampler
RNG is also `StableRNG(seed)`) and return `(states, observations, e_states)`:
the true states, the observations and the estimates from `extract_posteriors`.

Keyword arguments are forwarded to the data generator
`rand(rng, environment, T; kwargs...)` (e.g. `random_start = true`).
"""
function _nuts_posteriors(seed, T, nsamples; kwargs...)
    states, observations = rand(StableRNG(seed), environment, T; kwargs...)
    model    = double_pendulum(observations, T)
    result   = sample_inference(model, method = NUTS(), nsamples = nsamples, rng = StableRNG(seed))
    e_states = extract_posteriors(T, result)
    return states, observations, e_states
end

"""
    run_benchmark(params)

Run one benchmark configuration for NUTS inference in the double pendulum model.

`params` must provide `T` (number of time steps), `nsamples` (number of NUTS
samples) and `seed`. All log messages produced while running are discarded.

Returns a `Dict` (consumed by `produce_or_load`) with the keys `"T"`, `"nsamples"`,
`"seed"`, `"states"`, `"e_states"`, `"observations"`, `"amse"`, `"emse"`,
`"benchmark_modelcreation"` and `"benchmark_inference"`.
"""
function run_benchmark(params)
    return with_logger(NullLogger()) do
        @unpack T, nsamples, seed = params

        # Accuracy of a single fixed-seed run
        states, observations, e_states = _nuts_posteriors(seed, T, nsamples)
        amse = compute_amse(states, e_states)

        benchmark_modelcreation = @benchmark double_pendulum($observations, $T)

        # `setup` regenerates the data and model outside the timed region, so only
        # `sample_inference` itself is measured
        method = NUTS()
        benchmark_inference = @benchmark sample_inference(model, method = $method, nsamples = $nsamples, rng = StableRNG($seed)) setup=begin
            states, observations = rand(StableRNG($seed), environment, $T);
            model = double_pendulum(observations, $T)
        end

        # Accuracy over runs with a random initial state; how seeds are derived
        # from `seed` is up to `compute_emse` (not visible here)
        emse = compute_emse(seed) do _seed
            sim_states, _, sim_estimates = _nuts_posteriors(_seed, T, nsamples; random_start = true)
            return compute_amse(sim_states, sim_estimates)
        end

        output = @strdict T nsamples seed states e_states observations amse emse benchmark_modelcreation benchmark_inference

        return output
    end
end

run_benchmark (generic function with 1 method)

In [18]:
# Parameter grid for the benchmarks: every combination of time horizon `T`,
# number of NUTS samples and RNG seed becomes one benchmark configuration
benchmark_params = let
    horizons      = [10, 20, 30, 50, 100]
    sample_counts = [50, 100]
    seeds         = [42]
    dict_list(Dict("T" => horizons, "nsamples" => sample_counts, "seed" => seeds))
end;

In [19]:
# Disable Turing's progress display globally, since it slightly slows down sampling
Turing.setprogress!(false)

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m[Turing]: progress logging is disabled globally
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m[AdvancedVI]: global PROGRESS is set as false


false

In [20]:
# The first run may be slow; progress can be followed in the terminal.
# Later runs do not recompute existing benchmarks: result files already
# present in `bfolder` are loaded instead (`force = false`).
benchmarks = [
    first(produce_or_load(run_benchmark, bfolder, params; tag = false, force = false))
    for params in benchmark_params
];

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile /Users/bvdmitri/.julia/dev/thesis/data/nlds/turing/nuts/T=50_nsamples=50_seed=42.jld2 does not exist. Producing it now...
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile /Users/bvdmitri/.julia/dev/thesis/data/nlds/turing/nuts/T=50_nsamples=50_seed=42.jld2 saved.
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile /Users/bvdmitri/.julia/dev/thesis/data/nlds/turing/nuts/T=100_nsamples=50_seed=42.jld2 does not exist. Producing it now...
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile /Users/bvdmitri/.julia/dev/thesis/data/nlds/turing/nuts/T=100_nsamples=50_seed=42.jld2 saved.
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile /Users/bvdmitri/.julia/dev/thesis/data/nlds/turing/nuts/T=50_nsamples=100_seed=42.jld2 does not exist. Producing it now...
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile /Users/bvdmitri/.julia/dev/thesis/data/nlds/turing/nuts/T=50_nsamples=100_seed=42.jld2 saved.
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39

In [21]:
# Collect all stored results and order them by horizon, then by sample count
let table = prepare_benchmarks_table(bfolder)
    sort(table, [:T, :nsamples])
end

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mScanning folder /Users/bvdmitri/.julia/dev/thesis/data/nlds/turing/nuts for result files.
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mAdded 10 entries.


Row,T,nsamples,seed,amse,emse,inference,creation
Unnamed: 0_level_1,Int64?,Int64?,Int64?,Float64?,Float64?,Tuple…?,Tuple…?
1,10,50,42,8.75468,6.20178,"(6.05655e9, 6.05655e9, 6.87432e7)","(3.422, 4.3362, 0.0)"
2,10,100,42,5.71171,4.92248,"(1.13747e10, 1.13747e10, 1.37468e8)","(3.291, 3.9133, 0.0)"
3,20,50,42,4.36263,10.0011,"(9.22417e9, 9.22417e9, 7.15448e7)","(3.494, 3.98589, 0.0)"
4,20,100,42,3.69172,8.02788,"(2.1521e10, 2.1521e10, 1.74118e8)","(3.291, 3.84254, 0.0)"
5,30,50,42,10.2714,8.10127,"(1.81542e10, 1.81542e10, 1.17916e8)","(3.291, 3.83923, 0.0)"
6,30,100,42,24.3499,6.32423,"(4.01462e10, 4.01462e10, 2.43946e8)","(3.292, 3.91343, 0.0)"
7,50,50,42,12.5688,5.25755,"(3.79663e10, 3.79663e10, 1.86353e8)","(3.382, 3.91803, 0.0)"
8,50,100,42,6.36813,4.90763,"(7.45187e10, 7.45187e10, 3.91663e8)","(3.29, 3.86858, 0.0)"
9,100,50,42,0.615323,5.09143,"(9.7227e10, 9.7227e10, 7.78381e8)","(3.292, 3.89914, 0.0)"
10,100,100,42,1.24621,3.7418,"(2.4265e11, 2.4265e11, 1.18686e9)","(3.382, 3.93415, 0.0)"


# Versions

In [22]:
# Print the Julia version and platform information that the timings above were obtained on
versioninfo()

Julia Version 1.9.0
Commit 8e630552924 (2023-05-07 11:25 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin22.4.0)
  CPU: 12 × Intel(R) Core(TM) i7-8850H CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
  Threads: 2 on 12 virtual cores


In [14]:
] status

[36m[1mProject[22m[39m RxInferThesisExperiments v1.0.0
[32m[1mStatus[22m[39m `~/.julia/dev/thesis/Project.toml`
  [90m[6e4b80f9] [39mBenchmarkTools v1.3.2
  [90m[35d6a980] [39mColorSchemes v3.21.0
  [90m[a93c6f00] [39mDataFrames v1.5.0
[32m⌃[39m [90m[31c24e10] [39mDistributions v0.25.94
  [90m[634d3b9d] [39mDrWatson v2.12.5
  [90m[9fc3f58a] [39mForneyLab v0.12.0
  [90m[f6369f11] [39mForwardDiff v0.10.35
  [90m[7073ff75] [39mIJulia v1.24.0
  [90m[b964fa9f] [39mLaTeXStrings v1.3.0
  [90m[3bd65402] [39mOptimisers v0.2.18
  [90m[8314cec4] [39mPGFPlotsX v1.6.0
  [90m[e4faabce] [39mPProf v2.2.2
[32m⌃[39m [90m[91a5bcdd] [39mPlots v1.38.12
  [90m[37e2e3b7] [39mReverseDiff v1.14.6
[32m⌃[39m [90m[86711068] [39mRxInfer v2.10.4
  [90m[860ef19b] [39mStableRNGs v1.0.0
  [90m[aedffcd0] [39mStatic v0.8.7
  [90m[90137ffa] [39mStaticArrays v1.5.25
  [90m[fce5fe82] [39mTuring v0.25.1
  [90m[e88e6eb3] [39mZygote v0.6.61
  [90m[37e2e46d] [39mLinearAlg