## Import packages

In [1]:
using Revise
using ReactiveMP, Rocket, GraphPPL, Distributions
using LinearAlgebra, Random
using BenchmarkTools

## Settings

In [2]:
# seed for reproducibility
rng = MersenneTwister(1234)

# model parameters
# Both matrices are applied as `A*z` / `B*z` to one-hot vectors in `generate_data`,
# so each *column* is a probability vector; every column below sums to 1.
A = [0.6 0.1 0.2; 0.3 0.7 0.3; 0.1 0.2 0.5];    # Process transition
B = [0.8 0.25 0.1; 0.1 0.5 0.6; 0.1 0.25 0.3];  # Observation transition
z0 = [1.0, 0.0, 0.0]; # Initial state

# benchmark settings
# time budget in seconds given to each `@benchmark` run in this notebook
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 60;

## Generate data

In [3]:
"""
    onehot_vec(dist::Categorical, rng)

Draw a single category from `dist` using `rng` and return it one-hot encoded as a
`Vector{Float64}` whose length is the number of categories of `dist`.
"""
function onehot_vec(dist::Categorical, rng)
    onehot = zeros(ncategories(dist))
    drawn = rand(rng, dist)   # index of the sampled category
    onehot[drawn] = 1.0
    return onehot
end

"""
    generate_data(rng, z0, A, B; nr_samples=10)

Simulate `nr_samples` steps of a hidden Markov model and return the observations.

Starting from the state vector `z0`, every step samples the next hidden state from
`Categorical(A * previous_state)` and then an observation from
`Categorical(B * state)`; both are one-hot encoded through `onehot_vec`. The hidden
states are stored internally, but only the vector of observations is returned.
"""
function generate_data(rng, z0, A, B; nr_samples=10)
    states = Vector{Vector{Float64}}(undef, nr_samples)        # one-hot hidden states
    observations = Vector{Vector{Float64}}(undef, nr_samples)  # one-hot observations

    previous = z0
    for step in 1:nr_samples
        current = onehot_vec(Categorical(A * previous), rng)
        states[step] = current
        observations[step] = onehot_vec(Categorical(B * current), rng)
        previous = current
    end

    return observations
end;

## ReactiveMP.jl scale factor extension

In [4]:
# Forward (:out) message for a plain `Categorical` input, i.e. an input that carries
# no scale yet: evaluate the Transition rule without meta and wrap the result in a
# `ScaledMessage` whose scale is 0.0.
@rule Transition(:out, Marginalisation) (m_in::Categorical, m_a::PointMass, meta::ScaleFactorMeta) = begin 
    unscaled = @call_rule Transition(:out, Marginalisation) (m_in = m_in, m_a = m_a)
    return ScaledMessage(unscaled, 0.0)
end

# Forward (:out) message for an input that already carries a scale: evaluate the
# Transition rule without meta on the wrapped message and pass the incoming scale
# through unchanged.
@rule Transition(:out, Marginalisation) (m_in::ScaledMessage, m_a::PointMass, meta::ScaleFactorMeta) = begin 
    # The former local `A = mean(m_a)` was never read and has been dropped.
    message = @call_rule Transition(:out, Marginalisation) (m_in = m_in.message, m_a = m_a)
    scalefactor = m_in.scale 
    return ScaledMessage(message, scalefactor)
end

# Backward (:in) message for a `PointMass` output.
# The unnormalised message is `A' * probvec(m_out)`; its sum is used both to normalise
# the `Categorical` and, as `-log(sum)`, for the scale factor.
@rule Transition(:in, Marginalisation) (m_out::PointMass, m_a::PointMass, meta::ScaleFactorMeta) = begin 
    A = mean(m_a)
    # Compute the matrix-vector product and its sum once instead of three times.
    unnormalised = A' * probvec(m_out)
    normaliser = sum(unnormalised)
    message = Categorical(unnormalised ./ normaliser)
    scalefactor = -log(normaliser)
    return ScaledMessage(message, scalefactor)
end

# Backward (:in) message for an output that already carries a scale.
# The unnormalised message is `A' * probvec(m_out.message)`; its sum normalises the
# `Categorical`, and `-log(sum)` is added to the incoming scale.
@rule Transition(:in, Marginalisation) (m_out::ScaledMessage, m_a::PointMass, meta::ScaleFactorMeta) = begin 
    A = mean(m_a)
    # Compute the matrix-vector product and its sum once instead of three times.
    unnormalised = A' * probvec(m_out.message)
    normaliser = sum(unnormalised)
    message = Categorical(unnormalised ./ normaliser)
    scalefactor = m_out.scale - log(normaliser)

    return ScaledMessage(message, scalefactor)
end;

In [5]:
# Product of two scaled Categorical messages (used at equality nodes).
# The wrapped messages are multiplied with the analytical product; the new scale is the
# sum of both scales minus the log of the dot product of the two probability vectors.
function ReactiveMP.prod(::ProdAnalytical, left::ScaledMessage{<:Categorical}, right::ScaledMessage{<:Categorical})
    p_left = probvec(left.message)
    p_right = probvec(right.message)

    combined = prod(ProdAnalytical(), left.message, right.message)
    overlap = dot(p_left, p_right)

    return ScaledMessage(combined, left.scale + right.scale - log(overlap))
end;

## Inference by ReactiveMP (scale factors)

In [6]:
# Hidden Markov model with `nr_samples` time steps, built with `ScaleFactorMeta()` as
# the default meta.
# NOTE(review): presumably the default meta is what routes message computation to the
# `ScaleFactorMeta` rules defined earlier in this notebook — confirm.
# Returns the state variables `z` and the observation variables `y`.
@model [ default_meta = ScaleFactorMeta() ] function model_hmm_scalefactor(A, B; nr_samples=10)
    #define variables
    z = randomvar(nr_samples)
    y = datavar(Vector{Float64},nr_samples)

    # transition matrices enter the graph as constants
    cA = constvar(A)
    cB = constvar(B)

    # define initial state
    # uniform prior over exactly 3 categories (hard-coded, so A is expected to be 3x3)
    z_init ~ Categorical([1/3, 1/3, 1/3]) 

    z_prev = z_init

    # chain: each state depends on the previous one via cA, each observation on its state via cB
    for n = 1:nr_samples
        z[n] ~ Transition(z_prev, cA) 
        y[n] ~ Transition(z[n], cB) 
        z_prev = z[n]
    end

    return z, y
end

model_hmm_scalefactor (generic function with 1 method)

In [7]:
"""
    inference_hmm_scalefactor(data, A, B; nr_samples=10)

Build the scale-factor HMM for `nr_samples` steps, subscribe to the marginals of the
state variables, and benchmark pushing `data` into the observation variables.

Returns the `keep` actor holding the collected state marginals and the
`BenchmarkTools` result of the `update!` call.
"""
function inference_hmm_scalefactor(data, A, B; nr_samples=10)
    _, (states, observations) = model_hmm_scalefactor(
        A, B; nr_samples=nr_samples, options=(limit_stack_depth=500,)
    )

    kept_marginals = keep(Vector{Marginal})
    subscription = subscribe!(getmarginals(states), kept_marginals)

    # The subscription stays active while benchmarking, so the actor collects the
    # marginals emitted during the benchmark runs.
    trial = @benchmark update!($observations, $data)

    unsubscribe!(subscription)

    return kept_marginals, trial
end

inference_hmm_scalefactor (generic function with 1 method)

## Inference by ReactiveMP (Bethe free energy)

In [8]:
# Hidden Markov model with `nr_samples` time steps and no custom meta; used for the
# Bethe free energy run. The graph has the same structure as `model_hmm_scalefactor`.
# Returns the state variables `z` and the observation variables `y`.
@model function model_hmm_bfe(A, B; nr_samples=10)
    #define variables
    z = randomvar(nr_samples)
    y = datavar(Vector{Float64}, nr_samples)

    # transition matrices enter the graph as constants
    cA = constvar(A)
    cB = constvar(B)
    
    # define initial state
    # uniform prior over exactly 3 categories (hard-coded, so A is expected to be 3x3)
    z_init ~ Categorical([1/3, 1/3, 1/3])

    z_prev = z_init

    # chain: each state depends on the previous one via cA, each observation on its state via cB
    for n = 1:nr_samples
        z[n] ~ Transition(z_prev, cA)
        y[n] ~ Transition(z[n], cB) 
        z_prev = z[n]
    end

    return z, y
end

model_hmm_bfe (generic function with 1 method)

In [9]:
"""
    inference_hmm_bfe(data, A, B; nr_samples=10)

Build the plain HMM for `nr_samples` steps, subscribe to the state marginals and to
the Bethe free energy score of the model, and benchmark pushing `data` into the
observation variables.

Returns the `keep` actor with the state marginals, the `keep` actor with the free
energy values, and the `BenchmarkTools` result of the `update!` call.
"""
function inference_hmm_bfe(data, A, B; nr_samples=10)
    model, (states, observations) = model_hmm_bfe(
        A, B; nr_samples=nr_samples, options=(limit_stack_depth=500,)
    )

    kept_marginals = keep(Vector{Marginal})
    kept_free_energy = keep(Float64)

    marginals_subscription = subscribe!(getmarginals(states), kept_marginals)
    free_energy_subscription = subscribe!(
        score(Float64, BetheFreeEnergy(), model), kept_free_energy
    )

    # Both subscriptions stay active while benchmarking, so marginals and free energy
    # are computed and collected during the benchmark runs.
    trial = @benchmark update!($observations, $data)

    unsubscribe!((marginals_subscription, free_energy_subscription))

    return kept_marginals, kept_free_energy, trial
end

inference_hmm_bfe (generic function with 1 method)

## Correctness check and performance comparison

#### Scale factors (N=10)

In [10]:
# 10 observations; `rng` is the notebook-wide generator, so the draw depends on how
# often it has been used before this cell runs
data = generate_data(rng, z0, A, B; nr_samples=10);

In [11]:
# scale-factor inference on the 10-sample data set
zmarginals_sf, bmark_sf = inference_hmm_scalefactor(data, A, B; nr_samples=10);
# negated scale stored with the last state's marginal in the most recent kept entry;
# to be compared with `-bfe[end]` from the Bethe free energy run on the same data
println(-zmarginals_sf[end][end].data.scale)
bmark_sf

-10.763796563572129


BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  35.100 μs … 512.500 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     41.600 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   46.324 μs ±  13.930 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  [39m [39m▁[39m█[39m▇[39m█[39m▃[39m▃[34m▁[39m[39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▃[39m█[39m█[39m█[39m█

#### Bethe free energy (N=10)

In [12]:
# Bethe free energy inference on the same 10-sample data set
zmarginals, bfe, bmark_bfe = inference_hmm_bfe(data, A, B; nr_samples=10)
# negated last recorded free energy; to be compared with the negated scale printed by
# the scale-factor run
println(-bfe[end])
bmark_bfe

-10.76379656357213


BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   83.200 μs … 232.314 ms  ┊ GC (min … max):  0.00% … 99.89%
 Time  (median):      96.600 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   127.695 μs ±   2.322 ms  ┊ GC (mean ± σ):  18.17% ±  1.00%

  [39m [39m [39m▅[39m█[39m▆[34m▂[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▂[39m▆[39

#### Scale factors (N=100)

In [13]:
# 100 observations, drawn from the same notebook-wide `rng` (continues its stream)
data = generate_data(rng, z0, A, B; nr_samples=100);

In [14]:
# scale-factor inference on the 100-sample data set
zmarginals_sf, bmark_sf = inference_hmm_scalefactor(data, A, B; nr_samples=100);
# negated scale stored with the last state's marginal in the most recent kept entry
println(-zmarginals_sf[end][end].data.scale)
bmark_sf

-104.81619895041653


BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  409.000 μs … 449.752 ms  ┊ GC (min … max):  0.00% … 99.74%
 Time  (median):     471.150 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   643.427 μs ±   7.234 ms  ┊ GC (mean ± σ):  19.41% ±  1.73%

  [39m [39m [39m [39m█[39m▇[39m [39m [34m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▁[39m▃[39

#### Bethe free energy (N=100)

In [15]:
# Bethe free energy inference on the same 100-sample data set
zmarginals, bfe, bmark_bfe = inference_hmm_bfe(data, A, B; nr_samples=100)
# negated last recorded free energy
println(-bfe[end])
bmark_bfe

-104.81619895041644


BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  861.500 μs … 422.511 ms  ┊ GC (min … max):  0.00% … 99.64%
 Time  (median):     981.600 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):     1.256 ms ±   9.048 ms  ┊ GC (mean ± σ):  16.07% ±  2.23%

  [39m [39m [39m [39m▂[39m▃[39m▂[39m█[39m▃[39m▁[34m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▃[39m█[39

#### Scale factors (N=1000)

In [16]:
# 1000 observations, drawn from the same notebook-wide `rng` (continues its stream)
data = generate_data(rng, z0, A, B; nr_samples=1000);

In [17]:
# scale-factor inference on the 1000-sample data set
zmarginals_sf, bmark_sf = inference_hmm_scalefactor(data, A, B; nr_samples=1000);
# negated scale stored with the last state's marginal in the most recent kept entry
println(-zmarginals_sf[end][end].data.scale)
bmark_sf

-1058.9038469849415


BenchmarkTools.Trial: 2701 samples with 1 evaluation.
 Range (min … max):   5.630 ms …    2.943 s  ┊ GC (min … max):  0.00% … 99.65%
 Time  (median):      8.372 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   22.218 ms ± 145.533 ms  ┊ GC (mean ± σ):  62.40% ± 26.63%

  [39m▆[39m█[34m█[39m[39m▅[39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂[39m▂[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁
  [39m█[39m█[34m█[39m[3

#### Bethe free energy (N=1000)

In [18]:
# Bethe free energy inference on the same 1000-sample data set
zmarginals, bfe, bmark_bfe = inference_hmm_bfe(data, A, B; nr_samples=1000)
# negated last recorded free energy
println(-bfe[end])
bmark_bfe

-1058.9038469849293


BenchmarkTools.Trial: 3279 samples with 1 evaluation.
 Range (min … max):  13.404 ms … 505.365 ms  ┊ GC (min … max):  0.00% … 96.34%
 Time  (median):     15.315 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   18.284 ms ±  36.122 ms  ┊ GC (mean ± σ):  15.77% ±  7.69%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m▃[39m▅[39m█[39m▆[39m▇[34m▆[39m[39m▇[39m▆[39m▃[39m▂[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▂[39m▂[39m▂[39m▂[