In [3]:
using ReactiveMP, Rocket, Random, GraphPPL, Distributions, LinearAlgebra, SpecialFunctions
using BenchmarkTools

## Generate data

In [7]:
# Convert a vector of Booleans into a `Vector{Float64}` of the same length,
# mapping `true` to 1.0 and `false` to 0.0.
#
# Any `AbstractVector{Bool}` is accepted (e.g. `Vector{Bool}` or `BitVector`);
# an empty input yields an empty `Vector{Float64}`.
function bool2float(x::AbstractVector{Bool})
    # The array constructor converts element-wise (Float64(true) == 1.0,
    # Float64(false) == 0.0) and always returns a freshly allocated Vector{Float64}.
    return Vector{Float64}(x)
end

bool2float (generic function with 1 method)

In [8]:
# Synthetic coin-toss data set. The RNG is seeded first so that the draws below
# are the same on every run.
Random.seed!(12)

# Number of coin tosses.
N = 50
# Probability that a single toss comes up heads.
θ_true = 0.7
# Observation model p(y|θ).
likelihood = Bernoulli(θ_true)

# Observations converted to floats: y = 1.0 for a head, y = 0.0 for a tail.
y_data = bool2float(rand(likelihood, N))
# Number of heads; since a head is encoded as 1.0 this is the sum of the observations.
n_head = sum(y_data);

In [9]:
# Shape parameters (α, β) of the Beta prior placed on θ.
(α, β) = (1.0, 1.5)

(1.0, 1.5)

## Inference by ReactiveMP (without scale factor)

In [10]:
# Coin-toss model declared with the `@model` macro.
# `n` is the number of observations, i.e. how many `y[i]` data variables are created.
# Returns the data variables `y` and the variable `θ`.
@model function coin_toss(n)
    # Beta prior on θ. α and β are not arguments: they are read from the enclosing
    # (global) scope, so they must be defined before the model is created.
    θ ~ Beta(α, β)
    # `n` Float64-valued data variables created with `datavar`.
    y = datavar(Float64,n)
    # Every observation shares the same Bernoulli parameter θ.
    for i=1:n
        y[i] ~ Bernoulli(θ)
    end
    
    return y, θ
end

coin_toss (generic function with 1 method)

In [11]:
# Build the `coin_toss` model for `data`, subscribe to the marginal of θ and to the
# Bethe free energy score, and benchmark feeding the data into the model.
#
# Arguments:
#   data - observations passed to `update!`; its length sets the number of data variables.
#
# Returns the tuple `(θ_mar, bfe, bmark)`:
#   θ_mar - `keep(Marginal)` actor subscribed to `getmarginal(θ)`,
#   bfe   - `keep(Float64)` actor subscribed to the `BetheFreeEnergy()` score of the model,
#   bmark - result of `@benchmark update!(y, data)`.
#
# NOTE(review): `@benchmark` evaluates `update!` repeatedly, so presumably both actors
# end up holding one entry per evaluation — later cells only read `.values[1]`.
function inference(data)
    n = length(data)
    model, (y, θ) = coin_toss(n)

    θ_mar = keep(Marginal)
    bfe = keep(Float64)

    θ_subscribe = subscribe!(getmarginal(θ), θ_mar)
    bfe_subscribe = subscribe!(score(Float64, BetheFreeEnergy(), model), bfe)

    # `finally` guarantees both subscriptions are released even when `update!`
    # (or the benchmark harness) throws; on success the behaviour is unchanged.
    try
        bmark = @benchmark update!($y, $data)
        return θ_mar, bfe, bmark
    finally
        unsubscribe!((θ_subscribe, bfe_subscribe))
    end
end

inference (generic function with 1 method)

In [12]:
# Run inference without scale factors on the generated data; keeps the θ-marginal
# actor, the Bethe-free-energy actor and the benchmark result for the cells below.
θ_infer, bfe, bmark = inference(y_data);

## Inference by ReactiveMP with scale factor

In [15]:
# Scaled rule for the Bernoulli node: message towards the `:p` interface when the
# `out` message is a `PointMass` and the node meta is `ScaleFactorMeta`.
# Returns a `ScaledMessage` holding the Beta message together with its scale factor.
@rule Bernoulli(:p, Marginalisation) (m_out::PointMass, meta::ScaleFactorMeta) = begin
    obs = mean(m_out)
    # Beta(1 + obs, 2 - obs); `one(obs)` keeps the constants in the type of `obs`.
    beta_msg = Beta(one(obs) + obs, 2one(obs) - obs)
    # Constant scale factor -log(0.5) = log(2). Note B(2, 1) = B(1, 2) = 1/2, so this
    # presumably is minus the log-normaliser of the Beta message for a 0/1 observation
    # — TODO confirm against the scale-factor convention used by `ScaledMessage`.
    return ScaledMessage(beta_msg, -log(0.5))
end

# Scaled rule for the Beta node: message towards the `:out` interface when both
# shape parameters arrive as `PointMass` messages and the node meta is `ScaleFactorMeta`.
@rule Beta(:out, Marginalisation) (m_a::PointMass, m_b::PointMass, meta::ScaleFactorMeta) = begin
    # Reuse the `Beta(:out, Marginalisation)` rule, invoked here without a meta argument.
    plain = @call_rule Beta(:out, Marginalisation) (m_a = m_a, m_b = m_b)
    # This message carries a zero scale factor.
    return ScaledMessage(plain, 0.0)
end

In [20]:
# Product of two scaled Beta messages (used at the equality node).
#
# The resulting `ScaledMessage` wraps the analytical product of the two Beta messages
# and the accumulated scale factor
#
#   left.scale + right.scale
#       - log B(a_left + a_right - 1, b_left + b_right - 1)
#       + log B(a_left, b_left) + log B(a_right, b_right)
#
# where B is the beta function. The logarithms are evaluated with `logbeta` rather than
# `log(beta(...))`: for large shape parameters `beta` underflows to zero, which would
# turn the scale factor into ±Inf or NaN.
function ReactiveMP.prod(::ProdAnalytical, left::ScaledMessage{ <: Beta }, right::ScaledMessage{ <: Beta })
    a_left, b_left = Distributions.params(left.message)
    a_right, b_right = Distributions.params(right.message)

    message = prod(ProdAnalytical(), left.message, right.message)

    scalefactor = left.scale + right.scale -
                  logbeta(a_left + a_right - 1, b_left + b_right - 1) +
                  logbeta(a_left, b_left) + logbeta(a_right, b_right)

    return ScaledMessage(message, scalefactor)
end

In [17]:
# Coin-toss model with the same structure as `coin_toss`, declared with
# `default_meta = ScaleFactorMeta()` as a model option.
# NOTE(review): presumably this makes every node carry `ScaleFactorMeta`, so the
# rules taking `meta::ScaleFactorMeta` are selected — confirm against GraphPPL docs.
# `n` is the number of observations; returns the data variables `y` and the variable `θ`.
@model [default_meta = ScaleFactorMeta()] function coin_toss_sf(n)
    # Beta prior on θ; α and β are read from the enclosing (global) scope.
    θ ~ Beta(α, β)
    # `n` Float64-valued data variables created with `datavar`.
    y = datavar(Float64,n)

    # Every observation shares the same Bernoulli parameter θ.
    for i=1:n
        y[i] ~ Bernoulli(θ)
    end

    return y, θ
end

coin_toss_sf (generic function with 1 method)

In [21]:
# Build the `coin_toss_sf` model for `data`, subscribe to the marginal of θ and
# benchmark feeding the data into the model.
#
# Arguments:
#   data - observations passed to `update!`; its length sets the number of data variables.
#
# Returns the tuple `(θ_mar_sf, bmark_scf)`:
#   θ_mar_sf  - `keep(Marginal)` actor subscribed to `getmarginal(θ)`,
#   bmark_scf - result of `@benchmark update!(y, data)`.
function inference_sf(data)
    n = length(data)
    model, (y, θ) = coin_toss_sf(n)

    θ_mar_sf = keep(Marginal)
    θ_sub = subscribe!(getmarginal(θ), θ_mar_sf)

    # `finally` guarantees the subscription is released even when `update!`
    # (or the benchmark harness) throws; on success the behaviour is unchanged.
    try
        bmark_scf = @benchmark update!($y, $data)
        return θ_mar_sf, bmark_scf
    finally
        unsubscribe!(θ_sub)
    end
end

inference_sf (generic function with 1 method)

In [22]:
# Run inference with scale factors on the same data; keeps the θ-marginal actor
# and the benchmark result for the comparison cells below.
θ_infer_sf, bmark_scf = inference_sf(y_data);

## FE comparison

In [26]:
# Free energy (FE) computed by ReactiveMP without scale factors:
# the first value stored in the `bfe` actor.
bfe.values[1]

35.46718331705813

In [29]:
# FE obtained via scale factors: the `scale` field of the first stored marginal of θ.
θ_infer_sf.values[1].data.scale

35.46718331705815

In [30]:
# Difference between the two FE values (scale-factor result minus Bethe free energy score).
θ_infer_sf.values[1].data.scale - bfe.values[1]

2.1316282072803006e-14

## Performance comparison by BenchmarkTools

In [31]:
# Benchmark of `update!` on the model without scale factors (returned by `inference`).
bmark 

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m129.600 μs[22m[39m … [35m 21.299 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 94.00%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m137.700 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m204.100 μs[22m[39m ± [32m550.348 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m11.73% ±  4.45%

  [39m█[34m▅[39m[39m▃[39m▂[39m▁[39m▁[39m [39m [39m [39m [32m▆[39m[39m▂[39m▁[39m▁[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁
  [39m█[34m█[39

In [32]:
# Benchmark of `update!` on the model with scale factors (returned by `inference_sf`).
bmark_scf

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m36.500 μs[22m[39m … [35m 13.768 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 99.50%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m38.000 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m46.502 μs[22m[39m ± [32m279.545 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m13.39% ±  2.22%

  [39m▅[39m█[39m▆[34m▅[39m[39m▅[39m▄[39m▄[39m▄[39m▃[39m▃[39m▃[39m▂[39m▁[39m▂[39m▁[39m▁[39m [32m [39m[39m [39m [39m [39m [39m▂[39m▃[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[39m█[39m█[34m█