## Import Packages

In [1]:
using ReactiveMP, Rocket, Random, GraphPPL, Distributions, LinearAlgebra, SpecialFunctions
using BenchmarkTools

## Settings

In [2]:
# Fixed-seed random number generator, so the generated data sets are reproducible.
rng = MersenneTwister(1234)

# Model parameters:
#   θ_true - success probability used when generating the Bernoulli observations
#   α, β   - parameters passed to the Beta prior on θ in the models below
θ_true = 0.7
α = 1.0
β = 1.5  

# Benchmark settings: overrides the global BenchmarkTools default `seconds` parameter
# (presumably the time budget per benchmark, in seconds — confirm in the BenchmarkTools manual).
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 60;

## Generate data

In [3]:
"""
    generate_data(rng, θ; nr_samples=10)

Draw `nr_samples` samples from `Bernoulli(θ)` using the random number generator
`rng` and return them after conversion by `bool2float`.
"""
function generate_data(rng, θ; nr_samples=10)
    draws = rand(rng, Bernoulli(θ), nr_samples)
    return bool2float(draws)
end

"""
    bool2float(x::AbstractVector{Bool}) -> Vector{Float64}

Convert a vector of Booleans into a freshly allocated `Vector{Float64}`,
mapping `true` to `1.0` and `false` to `0.0`.

Any `AbstractVector{Bool}` is accepted (e.g. `Vector{Bool}`, `BitVector`, views);
an empty input yields an empty `Vector{Float64}`.
"""
function bool2float(x::AbstractVector{Bool})
    # `collect(Float64, x)` converts element-wise and always returns a plain Vector.
    return collect(Float64, x)
end;

## ReactiveMP.jl scale factor extension

In [4]:
# Rule (message and scale factor) for the `:p` edge of a Bernoulli node, used when the
# message on `out` is a `PointMass` and the node meta is a `ScaleFactorMeta`.
@rule Bernoulli(:p, Marginalisation) (m_out::PointMass, meta::ScaleFactorMeta, ) = begin
    observed = mean(m_out)
    unit = one(observed)
    # Beta(1 + y, 2 - y), where y is the mean of the point-mass message.
    beta_message = Beta(unit + observed, 2unit - observed)
    # The same constant scale is attached to every message produced by this rule.
    return ScaledMessage(beta_message, -log(0.5))
end

# Rule (message and scale factor) for the `:out` edge of a Beta node with point-mass
# parameters: the message is obtained from the rule without meta, and a scale of 0.0
# is attached to it.
@rule Beta(:out, Marginalisation) (m_a::PointMass, m_b::PointMass, meta::ScaleFactorMeta) = begin
    plain_message = @call_rule Beta(:out, Marginalisation) (m_a = m_a, m_b = m_b)
    return ScaledMessage(plain_message, 0.0)
end

In [5]:
# Product function for the equality node: multiplies two scaled Beta messages.
#
# The message part is the analytical product of the two Beta messages. The scale part
# is the sum of both incoming scales plus the correction
#     -log B(a_l + a_r - 1, b_l + b_r - 1) + log B(a_l, b_l) + log B(a_r, b_r),
# where B is the Beta function and (a, b) are the parameters of each message.
#
# The correction is evaluated with `logbeta` instead of `log(beta(...))`: `beta`
# underflows to zero for large parameters (which arise as more messages are
# combined), and the logarithm of that would turn the scale into an infinity.
function ReactiveMP.prod(::ProdAnalytical, left::ScaledMessage{ <: Beta }, right::ScaledMessage{ <: Beta })
    a_left, b_left   = Distributions.params(left.message)
    a_right, b_right = Distributions.params(right.message)

    message = prod(ProdAnalytical(), left.message, right.message)

    scalefactor = left.scale + right.scale -
                  SpecialFunctions.logbeta(a_left + a_right - 1, b_left + b_right - 1) +
                  SpecialFunctions.logbeta(a_left, b_left) +
                  SpecialFunctions.logbeta(a_right, b_right)

    return ScaledMessage(message, scalefactor)
end

## Inference by ReactiveMP (scale factors)

In [6]:
# Coin-toss model for the scale-factor approach. `ScaleFactorMeta()` is declared as the
# `default_meta` of the model.
#   α, β       - parameters of the Beta prior on θ
#   nr_samples - number of observation placeholders to create (default 10)
# Returns the data variables `y` and the random variable `θ`.
@model [ default_meta = ScaleFactorMeta() ] function model_cointoss_scalefactor(α, β; nr_samples=10)

    # Beta prior on θ
    θ ~ Beta(α, β)

    # `nr_samples` Float64 data variables; their values are supplied later via `update!`
    y = datavar(Float64, nr_samples)

    # every observation is Bernoulli-distributed with the shared parameter θ
    for n = 1:nr_samples
        y[n] ~ Bernoulli(θ)
    end

    return y, θ
end

model_cointoss_scalefactor (generic function with 1 method)

In [7]:
"""
    inference_cointoss_scalefactor(data, α, β; nr_samples=10)

Build the scale-factor coin-toss model with prior parameters `α`, `β` and
`nr_samples` data variables, record the marginals of `θ` while benchmarking
`update!(y, data)`, and return `(recorded marginals, benchmark result)`.
"""
function inference_cointoss_scalefactor(data, α, β; nr_samples=10)
    model, (y, θ) = model_cointoss_scalefactor(
        α, β; nr_samples = nr_samples, options = (limit_stack_depth = 500, )
    )

    # Keep-actor subscribed to the marginal stream of θ.
    marginals = keep(Marginal)
    subscription = subscribe!(getmarginal(θ), marginals)

    # Every benchmark evaluation pushes the complete data set into the data variables.
    trial = @benchmark update!($y, $data)

    unsubscribe!(subscription)

    return marginals, trial
end

inference_cointoss_scalefactor (generic function with 1 method)

## Inference by ReactiveMP (Bethe free energy)

In [8]:
# Coin-toss model for the Bethe-free-energy approach (no default meta is specified).
#   α, β       - parameters of the Beta prior on θ
#   nr_samples - number of observation placeholders to create (default 10)
# Returns the data variables `y` and the random variable `θ`.
@model function model_cointoss_bfe(α, β; nr_samples=10)

    # Beta prior on θ
    θ ~ Beta(α, β)

    # `nr_samples` Float64 data variables; their values are supplied later via `update!`
    y = datavar(Float64, nr_samples)

    # every observation is Bernoulli-distributed with the shared parameter θ
    for n = 1:nr_samples
        y[n] ~ Bernoulli(θ)
    end
    
    return y, θ
end

model_cointoss_bfe (generic function with 1 method)

In [9]:
"""
    inference_cointoss_bfe(data, α, β; nr_samples=10)

Build the Bethe-free-energy coin-toss model with prior parameters `α`, `β` and
`nr_samples` data variables, record both the marginals of `θ` and the
`BetheFreeEnergy` score while benchmarking `update!(y, data)`, and return
`(recorded marginals, recorded scores, benchmark result)`.
"""
function inference_cointoss_bfe(data, α, β; nr_samples=10)
    model, (y, θ) = model_cointoss_bfe(
        α, β; nr_samples = nr_samples, options = (limit_stack_depth = 500, )
    )

    # Keep-actors for the marginal stream of θ and for the Float64 score stream.
    marginals = keep(Marginal)
    free_energies = keep(Float64)

    marginal_sub = subscribe!(getmarginal(θ), marginals)
    energy_sub = subscribe!(score(Float64, BetheFreeEnergy(), model), free_energies)

    # Every benchmark evaluation pushes the complete data set into the data variables.
    trial = @benchmark update!($y, $data)

    unsubscribe!((marginal_sub, energy_sub))

    return marginals, free_energies, trial
end

inference_cointoss_bfe (generic function with 1 method)

## Correctness check and performance comparison

#### Scale factors (N=10)

In [10]:
# Generate 10 observations with success probability θ_true from the seeded `rng`.
data = generate_data(rng, θ_true; nr_samples=10);

In [11]:
# Run the scale-factor inference on the 10 observations, print the negated scale stored
# with the last recorded θ marginal, and display the benchmark result.
θ_mar_sf, bmark_sf = inference_cointoss_scalefactor(data, α, β; nr_samples=10);
println(-θ_mar_sf[end].data.scale)
bmark_sf

-7.349984471920427


BenchmarkTools.Trial: 10000 samples with 5 evaluations.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m6.400 μs[22m[39m … [35m 3.054 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 99.29%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m6.960 μs              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m8.775 μs[22m[39m ± [32m58.925 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m13.37% ±  1.99%

  [39m [39m▄[39m█[39m▆[34m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▃[39m█[39m█[39m█[34m▇[39m[39m▅

#### Bethe free energy (N=10)

In [12]:
# Run the Bethe-free-energy inference on the same 10 observations, print the negated
# last recorded score, and display the benchmark result.
θ_mar_bfe, bfe, bmark_bfe = inference_cointoss_bfe(data, α, β; nr_samples=10);
println(-bfe[end])
bmark_bfe

-7.349984471920429


BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m25.300 μs[22m[39m … [35m 26.079 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 99.54%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m28.300 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m35.594 μs[22m[39m ± [32m365.255 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m14.47% ±  1.41%

  [39m [39m█[39m▇[39m▄[34m▂[39m[39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▇[39m█[39m█[39m█

#### Scale factors (N=100)

In [13]:
# Generate 100 observations with success probability θ_true (continues the same `rng` stream).
data = generate_data(rng, θ_true; nr_samples=100);

In [14]:
# Run the scale-factor inference on the 100 observations, print the negated scale stored
# with the last recorded θ marginal, and display the benchmark result.
θ_mar_sf, bmark_sf = inference_cointoss_scalefactor(data, α, β; nr_samples=100);
println(-θ_mar_sf[end].data.scale)
bmark_sf

-57.64998869192096


BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m64.600 μs[22m[39m … [35m 26.146 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 99.45%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m67.900 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m89.448 μs[22m[39m ± [32m621.685 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m16.98% ±  2.44%

  [39m▇[39m█[34m▆[39m[39m▄[39m▃[39m▃[39m▃[39m▃[39m▄[39m▅[39m▅[39m▄[39m▃[39m▃[39m▂[39m▁[32m▁[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[39m█[34m█[39m[

#### Bethe free energy (N=100)

In [15]:
# Run the Bethe-free-energy inference on the same 100 observations, print the negated
# last recorded score, and display the benchmark result.
θ_mar_bfe, bfe, bmark_bfe = inference_cointoss_bfe(data, α, β; nr_samples=100);
println(-bfe[end])
bmark_bfe

-57.64998869192087


BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m234.500 μs[22m[39m … [35m22.721 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 98.22%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m244.700 μs              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m307.073 μs[22m[39m ± [32m 1.029 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m15.95% ±  4.70%

  [39m▇[39m█[39m▇[34m▇[39m[39m▅[39m▄[39m▄[39m▄[39m▃[39m▂[39m▂[39m▁[39m▂[39m▁[39m▁[39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[39m█[39m█[34m█

#### Scale factors (N=1000)

In [16]:
# Generate 1000 observations with success probability θ_true (continues the same `rng` stream).
data = generate_data(rng, θ_true; nr_samples=1000);

In [17]:
# Run the scale-factor inference on the 1000 observations, print the negated scale stored
# with the last recorded θ marginal, and display the benchmark result.
θ_mar_sf, bmark_sf = inference_cointoss_scalefactor(data, α, β; nr_samples=1000);
println(-θ_mar_sf[end].data.scale)
bmark_sf

-616.8891152516392


BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m641.700 μs[22m[39m … [35m20.715 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 92.29%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m699.050 μs              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m877.647 μs[22m[39m ± [32m 1.099 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m10.33% ±  7.94%

  [39m█[39m█[34m▆[39m[39m▅[39m▅[39m▅[39m▅[39m▄[32m▄[39m[39m▄[39m▃[39m▃[39m▂[39m▂[39m▂[39m▁[39m▁[39m▁[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[39m█[34m█[39m[

#### Bethe free energy (N=1000)

In [18]:
# Run the Bethe-free-energy inference on the same 1000 observations, print the negated
# last recorded score, and display the benchmark result.
θ_mar_bfe, bfe, bmark_bfe = inference_cointoss_bfe(data, α, β; nr_samples=1000);
println(-bfe[end])
bmark_bfe

-616.8891152516348


BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m2.614 ms[22m[39m … [35m18.681 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 76.50%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m3.779 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m4.211 ms[22m[39m ± [32m 2.237 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m8.73% ± 12.80%

  [39m [39m▆[39m█[39m█[39m▇[34m▆[39m[39m▅[32m▁[39m[39m▃[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▅[39m█[39m█[39m█[39m█[34m█[39m[39m