## Import packages

In [1]:
using Revise
using Rocket, GraphPPL, ReactiveMP, Distributions, SpecialFunctions
using LinearAlgebra, Random, BenchmarkTools;

## Settings

In [2]:
# Fixed-seed generator so that every run of the notebook draws the same data
rng = MersenneTwister(1234)

# Parameters of the linear Gaussian state space model
A = [ 1.001 1.6; 0.0 1.0 ]      # multiplies the previous state in the transition
B = diageye(2)                  # multiplies the current state in the observation
Q = diageye(2)                  # covariance of the state transition
P = 25.0 * diageye(2)           # covariance of the observation
z0 = MvNormalMeanCovariance(zeros(2), 100.0 * diageye(2))  # prior of the initial state

# Time budget (in seconds) for each `@benchmark` invocation
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 60;

## Generate data

In [3]:
# generate data function
# Sample a trajectory from the linear Gaussian state space model
#     z[i] ~ N(A * z[i-1], Q),    y[i] ~ N(B * z[i], P)
#
# Arguments:
#   rng        - random number generator handed to every `rand` call
#   A, Q       - matrix applied to the previous state and covariance of the state noise
#   B, P       - matrix applied to the current state and covariance of the observation noise
# Keywords:
#   nr_samples - number of time steps to generate
#   z_init     - state that precedes the first sample (it is not part of the output).
#                Defaults to the previously hard-coded `[10.0, -10.0]`, so existing
#                calls behave exactly as before, while models with a different state
#                dimension can now supply a matching start state.
#
# Returns the tuple `(z, y)` of hidden states and observations, each a
# `Vector{Vector{Float64}}` of length `nr_samples`.
function generate_data(rng, A, B, Q, P; nr_samples=10, z_init=[ 10.0, -10.0 ])
    z = Vector{Vector{Float64}}(undef, nr_samples)
    y = Vector{Vector{Float64}}(undef, nr_samples)

    z_prev = z_init
    for i in eachindex(z)
        # Hidden state first, then the observation generated from that state
        z[i] = rand(rng, MvNormal(A * z_prev, Q))
        y[i] = rand(rng, MvNormal(B * z[i], P))
        z_prev = z[i]
    end

    return z, y
end;

## ReactiveMP.jl scalefactor extension

In [4]:
using StatsFuns: log2π

In [5]:
# Return all four Gaussian parameters `(mean, weighted mean, covariance, precision)`
# of a distribution that is stored as mean and precision.
function getall(dist::MvNormalMeanPrecision)
    mean_vec, precision_mat = mean_precision(dist)
    # The covariance is recovered by inverting the stored precision
    covariance_mat = cholinv(precision_mat)
    weighted_mean = precision_mat * mean_vec
    return mean_vec, weighted_mean, covariance_mat, precision_mat
end

getall (generic function with 1 method)

In [6]:
# Return all four Gaussian parameters `(mean, weighted mean, covariance, precision)`
# of a distribution that is stored as mean and covariance.
function getall(dist::MvNormalMeanCovariance)
    mean_vec, covariance_mat = mean_cov(dist)
    # The precision is recovered by inverting the stored covariance
    precision_mat = cholinv(covariance_mat)
    weighted_mean = precision_mat * mean_vec
    return mean_vec, weighted_mean, covariance_mat, precision_mat
end

getall (generic function with 2 methods)

In [7]:
# Return all four Gaussian parameters `(mean, weighted mean, covariance, precision)`
# of a distribution that is stored as weighted mean and precision.
function getall(dist::MvNormalWeightedMeanPrecision)
    weighted_mean, precision_mat = weightedmean_precision(dist)
    # Covariance first, because the mean is obtained as covariance * weighted mean
    covariance_mat = cholinv(precision_mat)
    mean_vec = covariance_mat * weighted_mean
    return mean_vec, weighted_mean, covariance_mat, precision_mat
end

getall (generic function with 3 methods)

In [8]:
# Analytical product of two scaled multivariate Gaussian messages (equality node).
#
# The resulting message adds the weighted means and the precisions of both inputs.
# The resulting scale is the sum of both input scales plus
#     logdet(V)/2 + n/2 * log(2π) + m' * inv(V) * m / 2,
# with m = μ_left - μ_right and V = Σ_left + Σ_right, i.e. the negative log-density
# of a Gaussian with covariance V evaluated at the difference of the two means.
function ReactiveMP.prod(
    ::ProdAnalytical,
    left::ScaledMessage{ <: MultivariateNormalDistributionsFamily },
    right::ScaledMessage{ <: MultivariateNormalDistributionsFamily },
)
    mean_l, wmean_l, cov_l, prec_l = getall(left.message)
    mean_r, wmean_r, cov_r, prec_r = getall(right.message)

    # Gaussian part of the product: natural parameters simply add up
    product = MvNormalWeightedMeanPrecision(wmean_l + wmean_r, prec_l + prec_r)

    dim = length(mean_l)
    mean_gap = mean_l - mean_r
    inv_cov_sum, logdet_cov_sum = ReactiveMP.cholinv_logdet(cov_l + cov_r)

    # Quadratic form of the mean difference under the summed covariance
    quadratic = dot(mean_gap, inv_cov_sum, mean_gap)/2

    scale = left.scale + right.scale + logdet_cov_sum/2 + dim/2*log2π + quadratic

    return ScaledMessage(product, scale)
end

## Inference by ReactiveMP (scale factors)

In [9]:
# define model with meta = scalefactormeta
#
# Linear Gaussian state space model whose nodes receive `ScaleFactorMeta()` as
# their default meta (see the `default_meta` option below).
# NOTE(review): presumably this meta makes messages carry scale factors — confirm
# against the ReactiveMP version in use.
#
# Arguments:
#   z0         - distribution whose `mean` and `cov` define the prior on the initial state
#   A, Q       - matrix applied to the previous state and covariance of each transition
#   B, P       - matrix applied to the current state and covariance of each observation
#   nr_samples - number of time steps (hidden states / observations) in the model
#
# Returns the references `z` (hidden states) and `y` (observations).
@model [ default_meta=ScaleFactorMeta() ] function model_lgssm_scalefactor(z0, A, B, Q, P; nr_samples=10)

    # we create constant variables for better efficiency
    cA = constvar(A)
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)
        
    # `z` is a sequence of hidden states
    z = randomvar(nr_samples)

    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, nr_samples)
        
    # prior on the state that precedes the first time step
    z_prior ~ MvNormalMeanCovariance(mean(z0), cov(z0))
    z_prev = z_prior
        
    # chain the states: each state depends on the previous one,
    # each observation depends on the state of the same time step
    for i in 1:nr_samples
        z[i] ~ MvNormalMeanCovariance(cA * z_prev, cQ)
        y[i] ~ MvNormalMeanCovariance(cB * z[i], cP)
        z_prev = z[i]
    end
        
    return z, y
end

model_lgssm_scalefactor (generic function with 1 method)

In [10]:
# Run inference on the scale-factor model and benchmark the data update.
#
# Arguments:
#   data           - observations that are fed into the model's datavars
#   z0, A, B, Q, P - prior and model parameters, forwarded to `model_lgssm_scalefactor`
#   nr_samples     - number of time steps; also the size of the marginals buffer
#
# Returns the buffer holding the marginals of the hidden states and the
# result of `@benchmark` on the `update!` call.
function inference_lgssm_scalefactor(data, z0, A, B, Q, P; nr_samples=10)

    # Only the variable references are needed here, the model object itself is unused
    _, (states, observations) = model_lgssm_scalefactor(
        z0, A, B, Q, P;
        nr_samples = nr_samples,
        options = (limit_stack_depth = 500, ),
    )

    # Buffer that receives the posterior marginals of the hidden states
    marginals_buffer = buffer(Marginal, nr_samples)
    marginals_subscription = subscribe!(getmarginals(states), marginals_buffer)

    # Every benchmark evaluation pushes `data` into the clamped datavars
    timing = @benchmark update!($observations, $data)

    # Release the subscription before handing back the results
    unsubscribe!(marginals_subscription)

    return marginals_buffer, timing
end

inference_lgssm_scalefactor (generic function with 1 method)

## Inference by ReactiveMP (Bethe free energy)

In [11]:
# Linear Gaussian state space model without any custom meta; used for the
# Bethe free energy based computation.
#
# Arguments:
#   z0         - distribution whose `mean` and `cov` define the prior on the initial state
#   A, Q       - matrix applied to the previous state and covariance of each transition
#   B, P       - matrix applied to the current state and covariance of each observation
#   nr_samples - number of time steps (hidden states / observations) in the model
#
# Returns the references `z` (hidden states) and `y` (observations).
@model function model_lgssm_bfe(z0, A, B, Q, P; nr_samples=10)
    
    # We create constvar references for better efficiency
    cA = constvar(A)
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)
    
    # `z` is a sequence of hidden states
    z = randomvar(nr_samples)
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, nr_samples)
    
    # prior on the state that precedes the first time step
    z_prior ~ MvNormalMeanCovariance(mean(z0), cov(z0))
    z_prev = z_prior
    
    # chain the states: each state depends on the previous one,
    # each observation depends on the state of the same time step
    for i in 1:nr_samples
        z[i] ~ MvNormalMeanCovariance(cA * z_prev, cQ)
        y[i] ~ MvNormalMeanCovariance(cB * z[i], cP)
        z_prev = z[i]
    end
    
    return z, y
end

model_lgssm_bfe (generic function with 1 method)

In [12]:
# Run inference on the plain model, track the Bethe free energy and benchmark
# the data update.
#
# Arguments:
#   data           - observations that are fed into the model's datavars
#   z0, A, B, Q, P - prior and model parameters, forwarded to `model_lgssm_bfe`
#   nr_samples     - number of time steps; also the size of the marginals buffer
#
# Returns the buffer holding the marginals of the hidden states, the last Bethe
# free energy value that was emitted (`nothing` if none was), and the result of
# `@benchmark` on the `update!` call.
function inference_lgssm_bfe(data, z0, A, B, Q, P; nr_samples=10)

    # Build the model and keep references to hidden states and observations
    model, (states, observations) = model_lgssm_bfe(
        z0, A, B, Q, P;
        nr_samples = nr_samples,
        options = (limit_stack_depth = 500, ),
    )

    marginals_buffer = buffer(Marginal, nr_samples)
    free_energy      = nothing

    # Subscribe to the posterior marginals of `z`
    marginals_subscription = subscribe!(getmarginals(states), marginals_buffer)

    # We are also interested in the Bethe free energy functional, which in this
    # case is equal to minus log evidence; every emitted value overwrites the last
    store_free_energy = (value) -> free_energy = value
    free_energy_subscription = subscribe!(score(BetheFreeEnergy(), model), store_free_energy)

    # Every benchmark evaluation pushes `data` into the clamped datavars
    timing = @benchmark update!($observations, $data)

    # Release both subscriptions before handing back the results
    unsubscribe!((marginals_subscription, free_energy_subscription))

    return marginals_buffer, free_energy, timing
end

inference_lgssm_bfe (generic function with 1 method)

## Correctness check and performance comparison

#### Scale factors (N=10)

In [13]:
# Draw 10 observations from the model; the sampled hidden states are discarded
_, data = generate_data(rng, A, B, Q, P; nr_samples=10);

In [14]:
# Scale-factor inference on the 10 observations, together with its benchmark
xmarginals_sf, bmark_sf = inference_lgssm_scalefactor(data, z0, A, B, Q, P; nr_samples=10);
# Negated scale stored with the marginal of the last state; compare with `-bfe`
println(-xmarginals_sf[end].data.scale)
# Trailing expression: display the benchmark result as the cell output
bmark_sf

-65.37723230272508


BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  128.100 μs …   7.860 ms  ┊ GC (min … max): 0.00% … 96.53%
 Time  (median):     135.700 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   151.930 μs ± 243.586 μs  ┊ GC (mean ± σ):  6.29% ±  3.86%

  [39m▄[39m▇[39m█[34m▇[39m[39m▆[39m▅[39m▄[39m▄[39m▃[39m▃[32m▂[39m[39m▂[39m▃[39m▃[39m▁[39m▁[39m [39m▁[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[39m█[39m█

#### Bethe free energy (N=10)

In [15]:
# Bethe-free-energy inference on the same 10 observations, together with its benchmark
zmarginals, bfe, bmark_bfe = inference_lgssm_bfe(data, z0, A, B, Q, P; nr_samples=10)
# Negated Bethe free energy; compare with the negated scale printed for scale factors
println(-bfe)
# Trailing expression: display the benchmark result as the cell output
bmark_bfe

-65.3772323030775


BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  252.300 μs …  12.388 ms  ┊ GC (min … max): 0.00% … 93.62%
 Time  (median):     285.000 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   323.744 μs ± 431.054 μs  ┊ GC (mean ± σ):  6.65% ±  4.91%

  [39m [39m▂[39m█[39m▆[39m▇[34m▅[39m[39m▂[39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▄[39m█[39m█

#### Scale factors (N=100)

In [16]:
# Draw 100 observations (continuing with the same `rng`); hidden states are discarded
_, data = generate_data(rng, A, B, Q, P; nr_samples=100);

In [17]:
# Scale-factor inference on the 100 observations, together with its benchmark
xmarginals_sf, bmark_sf = inference_lgssm_scalefactor(data, z0, A, B, Q, P; nr_samples=100);
# Negated scale stored with the marginal of the last state; compare with `-bfe`
println(-xmarginals_sf[end].data.scale)
# Trailing expression: display the benchmark result as the cell output
bmark_sf

-637.3876361386191

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  1.413 ms … 16.308 ms  ┊ GC (min … max): 0.00% …  0.00%
 Time  (median):     1.528 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.760 ms ±  1.029 ms  ┊ GC (mean ± σ):  6.83% ± 10.10%

  [39m█[34m▆[39m[39m▅[32m▄[39m[39m▃[39m▃[39m▂[39m▂[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[34m█[39m[39m█[32m█[39m[39m█[39m█




#### Bethe free energy (N=100)

In [18]:
# Bethe-free-energy inference on the same 100 observations, together with its benchmark
zmarginals, bfe, bmark_bfe = inference_lgssm_bfe(data, z0, A, B, Q, P; nr_samples=100)
# Negated Bethe free energy; compare with the negated scale printed for scale factors
println(-bfe)
# Trailing expression: display the benchmark result as the cell output
bmark_bfe

#### Scale factors (N=1000)

In [None]:
# Draw 1000 observations (continuing with the same `rng`); hidden states are discarded
_, data = generate_data(rng, A, B, Q, P; nr_samples=1000);

In [None]:
# Scale-factor inference on the 1000 observations, together with its benchmark
xmarginals_sf, bmark_sf = inference_lgssm_scalefactor(data, z0, A, B, Q, P; nr_samples=1000);
# Negated scale stored with the marginal of the last state; compare with `-bfe`
println(-xmarginals_sf[end].data.scale)
# Trailing expression: display the benchmark result as the cell output
bmark_sf

#### Bethe free energy (N=1000)

In [None]:
# Bethe-free-energy inference on the same 1000 observations, together with its benchmark
zmarginals, bfe, bmark_bfe = inference_lgssm_bfe(data, z0, A, B, Q, P; nr_samples=1000)
# Negated Bethe free energy; compare with the negated scale printed for scale factors
println(-bfe)
# Trailing expression: display the benchmark result as the cell output
bmark_bfe