## Import Packages

In [1]:
using ReactiveMP, Rocket, Random, GraphPPL, Distributions, LinearAlgebra, SpecialFunctions
using BenchmarkTools

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1423


## Settings

In [2]:
# Fixed-seed generator; it is passed explicitly to `generate_data` so every run
# draws the same coin flips.
rng = MersenneTwister(1234)

# Ground-truth bias of the simulated coin.
θ_true = 0.7
# Hyperparameters of the Beta(α, β) prior placed on θ by both models below.
α, β = 1.0, 1.5

# benchmark settings
# Time budget (in seconds) that each `@benchmark` run below may spend sampling.
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 60;

## Generate data

In [3]:
"""
    generate_data(rng, θ; nr_samples=10)

Draw `nr_samples` coin flips from `Bernoulli(θ)` using `rng`. Return them as a
`Vector{Float64}` of `0.0`/`1.0` values, which matches the `Float64` datavars of the
models.
"""
generate_data(rng, θ; nr_samples=10) = rand(rng, Bernoulli(θ), nr_samples) |> bool2float

"""
    bool2float(x::AbstractVector{Bool})

Convert a vector of `Bool`s to a `Vector{Float64}` (`true ↦ 1.0`, `false ↦ 0.0`).
Accepts any `AbstractVector{Bool}` (e.g. `Vector{Bool}`, `BitVector`, views).
"""
bool2float(x::AbstractVector{Bool}) = Float64.(x);

## ReactiveMP.jl scale factor extension

In [4]:
# Scale-factor variant of the Bernoulli node rule: the message towards the parameter `:p`
# given an observed outcome `out = r`.
#
# The likelihood as a function of p is
#     m(p) = p^r (1 - p)^(1 - r) = B(1 + r, 2 - r) ⋅ Beta(p; 1 + r, 2 - r),
# so the outgoing message is that Beta distribution. Its scale is -log B(1 + r, 2 - r),
# because in this notebook `scale` accumulates negative log scale factors.
# For binary r ∈ {0, 1} this equals -log(0.5), the value previously hard-coded.
# Computing it with `logbeta` keeps it correct for soft observations r ∈ (0, 1) too.
@rule Bernoulli(:p, Marginalisation) (m_out::PointMass, meta::ScaleFactorMeta, ) = begin
    r = mean(m_out)
    a_p, b_p = one(r) + r, 2one(r) - r
    message = Beta(a_p, b_p)
    scalefactor = -logbeta(a_p, b_p)
    return ScaledMessage(message, scalefactor)
end

# Scale-factor variant of the Beta prior node rule: the message towards `:out`.
# The message itself comes from the default (meta-less) Beta rule via `@call_rule`.
# The prior node contributes no scale, so the scale is fixed at 0.0.
@rule Beta(:out, Marginalisation) (m_a::PointMass, m_b::PointMass, meta::ScaleFactorMeta) = begin
    prior_msg = @call_rule Beta(:out, Marginalisation) (m_a = m_a, m_b = m_b)
    return ScaledMessage(prior_msg, 0.0)
end

In [5]:
# Product of two scaled Beta messages (used at the equality node).
#
# Beta(a₁, b₁) ⋅ Beta(a₂, b₂)
#     = B(a₁ + a₂ - 1, b₁ + b₂ - 1) / (B(a₁, b₁) B(a₂, b₂)) ⋅ Beta(a₁ + a₂ - 1, b₁ + b₂ - 1).
# `scale` holds negative log scale factors, so the incoming scales are summed and the
# log of the normaliser above is subtracted.
# `logbeta` is used instead of `log(beta(...))`: for large parameters (many observations)
# `beta` underflows to 0.0 and `log` would then return -Inf.
function ReactiveMP.prod(::ProdAnalytical, left::ScaledMessage{<:Beta},
                         right::ScaledMessage{<:Beta})
    a_left, b_left = Distributions.params(left.message)
    a_right, b_right = Distributions.params(right.message)

    message = prod(ProdAnalytical(), left.message, right.message)
    scalefactor = left.scale + right.scale -
                  logbeta(a_left + a_right - 1, b_left + b_right - 1) +
                  logbeta(a_left, b_left) + logbeta(a_right, b_right)

    return ScaledMessage(message, scalefactor)
end

## Inference by ReactiveMP (scale factors)

In [6]:
# Beta–Bernoulli coin-toss model whose default node meta is `ScaleFactorMeta()`.
# Presumably this makes ReactiveMP dispatch to the `meta::ScaleFactorMeta` rules
# defined in this notebook.
#   α, β       — parameters of the Beta prior on the coin bias θ
#   nr_samples — number of observations (length of the datavar vector `y`)
# Returns the datavar vector `y` and the latent variable `θ`.
@model [ default_meta = ScaleFactorMeta() ] function model_cointoss_scalefactor(α, β; nr_samples=10)

    # Prior on the coin bias.
    θ ~ Beta(α, β)

    # Observed tosses, supplied later via `update!(y, data)`.
    y = datavar(Float64, nr_samples)

    # Each toss is an independent Bernoulli draw given θ.
    for n = 1:nr_samples
        y[n] ~ Bernoulli(θ)
    end

    return y, θ
end

model_cointoss_scalefactor (generic function with 1 method)

In [7]:
"""
    inference_cointoss_scalefactor(data, α, β; nr_samples = length(data))

Build `model_cointoss_scalefactor(α, β)`, subscribe to the posterior marginal of θ
and benchmark `update!(y, data)`.

Returns `(θ_mar_sf, bmark_sf)`:
- `θ_mar_sf` is a `keep(Marginal)` actor holding the θ marginals emitted while
  benchmarking.
- `bmark_sf` is the BenchmarkTools trial.

The benchmark calls `update!` repeatedly, so read the last entry (`θ_mar_sf[end]`).
`nr_samples` must match `length(data)` and now defaults to it.
"""
function inference_cointoss_scalefactor(data, α, β; nr_samples = length(data))
    # `limit_stack_depth` presumably guards against stack overflows for large `nr_samples`.
    _, (y, θ) = model_cointoss_scalefactor(α, β; nr_samples = nr_samples,
                                           options = (limit_stack_depth = 500,))

    θ_mar_sf = keep(Marginal)
    θ_sub = subscribe!(getmarginal(θ), θ_mar_sf)

    # Always release the subscription, even if benchmarking throws.
    bmark_sf = try
        @benchmark update!($y, $data)
    finally
        unsubscribe!(θ_sub)
    end

    return θ_mar_sf, bmark_sf
end

inference_cointoss_scalefactor (generic function with 1 method)

## Inference by ReactiveMP (Bethe free energy)

In [8]:
# Beta–Bernoulli coin-toss model with default node meta. It is used as the baseline
# whose Bethe free energy is compared against the scale-factor estimate.
#   α, β       — parameters of the Beta prior on the coin bias θ
#   nr_samples — number of observations (length of the datavar vector `y`)
# Returns the datavar vector `y` and the latent variable `θ`.
@model function model_cointoss_bfe(α, β; nr_samples=10)

    # Prior on the coin bias.
    θ ~ Beta(α, β)

    # Observed tosses, supplied later via `update!(y, data)`.
    y = datavar(Float64, nr_samples)

    # Each toss is an independent Bernoulli draw given θ.
    for n = 1:nr_samples
        y[n] ~ Bernoulli(θ)
    end
    
    return y, θ
end

model_cointoss_bfe (generic function with 1 method)

In [14]:
"""
    inference_cointoss_bfe(data, α, β; nr_samples = length(data))

Build `model_cointoss_bfe(α, β)`, subscribe to the posterior marginal of θ and to the
Bethe free energy score, then benchmark `update!(y, data)`.

Returns `(θ_mar, bfe, bmark)`:
- `θ_mar` is a `keep(Marginal)` actor holding the θ marginals.
- `bfe` is a `keep(Float64)` actor holding the BFE values.
- `bmark` is the BenchmarkTools trial.

The benchmark calls `update!` repeatedly, so read the last entries (`[end]`).
`nr_samples` must match `length(data)` and now defaults to it.
"""
function inference_cointoss_bfe(data, α, β; nr_samples = length(data))
    # `limit_stack_depth` presumably guards against stack overflows for large `nr_samples`.
    model, (y, θ) = model_cointoss_bfe(α, β; nr_samples = nr_samples,
                                       options = (limit_stack_depth = 500,))

    θ_mar = keep(Marginal)
    bfe = keep(Float64)

    θ_subscribe = subscribe!(getmarginal(θ), θ_mar)
    bfe_subscribe = subscribe!(score(Float64, BetheFreeEnergy(), model), bfe)

    # Always release both subscriptions, even if benchmarking throws.
    bmark = try
        @benchmark update!($y, $data)
    finally
        unsubscribe!((θ_subscribe, bfe_subscribe))
    end

    return θ_mar, bfe, bmark
end

inference_cointoss_bfe (generic function with 1 method)

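## Correctness check and performance comparison

For each data size N, the scale-factor cell prints the negated accumulated scale of the final θ marginal, which is the log-evidence $\log p(y)$. The BFE cell prints the Bethe free energy. For this conjugate Beta–Bernoulli model the two numbers should agree up to sign (e.g. −6.5233 vs. 6.5233 for N = 10).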

#### Scale factors (N=10)

In [15]:
# Draw N = 10 coin flips with bias θ_true (this advances the shared `rng`).
data = generate_data(rng, θ_true; nr_samples=10);

In [16]:
# Scale-factor inference and benchmark on the N = 10 data set.
θ_mar_sf, bmark_sf = inference_cointoss_scalefactor(data, α, β; nr_samples=10);
# Negated scale of the last θ marginal; expected to match the BFE below up to sign.
println(-θ_mar_sf[end].data.scale)
# Display the benchmark trial.
bmark_sf

-6.523305898735956


BenchmarkTools.Trial: 10000 samples with 5 evaluations.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m6.420 μs[22m[39m … [35m 4.270 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 99.41%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m7.120 μs              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m8.957 μs[22m[39m ± [32m69.076 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m13.30% ±  1.72%

  [39m▇[39m█[39m [39m [34m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m█[39m█[39m█[39m▅[34m▃[39m[39m▄

#### Bethe free energy (N=10)

In [17]:
# BFE inference and benchmark on the same N = 10 data set.
θ_mar_bfe, bfe, bmark_bfe = inference_cointoss_bfe(data, α, β; nr_samples=10);
# Last recorded Bethe free energy value.
println(bfe[end])
# Display the benchmark trial.
bmark_bfe

Marginal(Beta{Float64}(α=9.0, β=3.5))


KeepActor{Float64}([6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957  …  6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957, 6.523305898735957])

#### Scale factors (N=100)

In [None]:
# Draw N = 100 coin flips with bias θ_true (this advances the shared `rng`).
data = generate_data(rng, θ_true; nr_samples=100);

In [None]:
# Scale-factor inference and benchmark on the N = 100 data set.
θ_mar_sf, bmark_sf = inference_cointoss_scalefactor(data, α, β; nr_samples=100);
# Negated scale of the last θ marginal; expected to match the BFE below up to sign.
println(-θ_mar_sf[end].data.scale)
# Display the benchmark trial.
bmark_sf

#### Bethe free energy (N=100)

In [None]:
# BFE inference and benchmark on the same N = 100 data set.
θ_mar_bfe, bfe, bmark_bfe = inference_cointoss_bfe(data, α, β; nr_samples=100);
# Last recorded Bethe free energy value.
println(bfe[end])
# Display the benchmark trial.
bmark_bfe

#### Scale factors (N=1000)

In [None]:
# Draw N = 1000 coin flips with bias θ_true (this advances the shared `rng`).
data = generate_data(rng, θ_true; nr_samples=1000);

In [None]:
# Scale-factor inference and benchmark on the N = 1000 data set.
θ_mar_sf, bmark_sf = inference_cointoss_scalefactor(data, α, β; nr_samples=1000);
# Negated scale of the last θ marginal; expected to match the BFE below up to sign.
println(-θ_mar_sf[end].data.scale)
# Display the benchmark trial.
bmark_sf

#### Bethe free energy (N=1000)

In [None]:
# BFE inference and benchmark on the same N = 1000 data set.
θ_mar_bfe, bfe, bmark_bfe = inference_cointoss_bfe(data, α, β; nr_samples=1000);
# Last recorded Bethe free energy value.
println(bfe[end])
# Display the benchmark trial.
bmark_bfe