In [1]:
using Revise
using Rocket, GraphPPL, ReactiveMP, Distributions, SpecialFunctions
using Plots, LinearAlgebra, Random, BenchmarkTools

┌ Info: Precompiling BenchmarkTools [6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf]
└ @ Base loading.jl:1423


## Generate data

In [2]:
#Generate data
# Simulate a linear Gaussian state-space model
#   x[i] ~ N(A * x[i-1], Q),   y[i] ~ N(B * x[i], P),
# where x[0] is the fixed initial state `x_init`.
#
# Arguments
# - `rng`: random number generator (pass a seeded one for reproducibility).
# - `A`, `Q`: state transition matrix and process-noise covariance.
# - `B`, `P`: observation matrix and observation-noise covariance.
# Keywords
# - `nsamples`: number of time steps. Defaults to the global `n` so that the
#   original call `generate_data(rng, A, B, Q, P)` keeps working; pass it
#   explicitly to avoid depending on notebook state.
# - `x_init`: the state preceding x[1]; defaults to [10.0, -10.0].
#
# Returns `(x, y)`: hidden states and observations, each a
# `Vector{Vector{Float64}}` of length `nsamples`.
function generate_data(rng, A, B, Q, P; nsamples::Integer = n, x_init = [ 10.0, -10.0 ])
    x = Vector{Vector{Float64}}(undef, nsamples)
    y = Vector{Vector{Float64}}(undef, nsamples)

    x_prev = x_init
    for i in 1:nsamples
        x[i] = rand(rng, MvNormal(A * x_prev, Q))
        y[i] = rand(rng, MvNormal(B * x[i], P))
        x_prev = x[i]
    end

    return x, y
end

generate_data (generic function with 1 method)

In [60]:
# Seed for reproducibility
seed = 1234

# Build the RNG from `seed` (previously the literal was duplicated here,
# so changing `seed` had no effect).
rng = MersenneTwister(seed)

# State-space model parameters (as used by `generate_data`):
# A - state transition matrix, B - observation matrix,
# Q - process-noise covariance, P - observation-noise covariance.
A = [ 1.2 1.7; 0 1 ]
B = diageye(2)
Q = diageye(2)
P = 25.0 .* diageye(2)

# Number of observations
n = 70;

# Sampling time budget (in seconds) for every subsequent `@benchmark`.
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 20

20

In [61]:
# Simulate hidden states `x` and observations `y` with the parameters defined
# above; the trailing `;` suppresses the notebook output.
x, y = generate_data(rng, A, B, Q, P);

In [62]:
# Prior over the initial state: zero mean, wide covariance 100·I.
x0 = MvNormalMeanCovariance(zeros(2), 100.0 * diageye(2));

## Inference by ReactiveMP (with scale factors)
In this section, we add a scale-factor update rule to ReactiveMP, compute the model evidence from the scale factors, and then compare the computation time with regular ReactiveMP (without scale factors), which is run in the next section.

In [63]:
#Product function for equality node
# Multiply two scaled multivariate-Gaussian messages.
#
# The normalised product comes from the existing analytical rule. The scale is
# accumulated as a negative log: the product N(μl, Σl)·N(μr, Σr) has normaliser
# Z = N(μl; μr, Σl + Σr), so with m = μl - μr and V = Σl + Σr
#   -log Z = ½ logdet(V) + (d/2) log(2π) + ½ mᵀ V⁻¹ m,
# which is added to the scales of both incoming messages.
function ReactiveMP.prod(
    ::ProdAnalytical,
    left::ScaledMessage{ <: MultivariateNormalDistributionsFamily },
    right::ScaledMessage{ <: MultivariateNormalDistributionsFamily },
)
    mean_left, var_left = mean_cov(left.message)
    mean_right, var_right = mean_cov(right.message)

    # Dimension of the Gaussians (named `d` to avoid shadowing the global `n`).
    d = length(mean_left)
    m = mean_left - mean_right
    V = var_left + var_right

    message = prod(ProdAnalytical(), left.message, right.message)
    # Use the linear solve `V \ m` instead of forming `inv(V)` explicitly:
    # cheaper and numerically more stable.
    scale = left.scale + right.scale + (logdet(V) + d * log(2π) + dot(m, V \ m)) / 2

    return ScaledMessage(message, scale)
end

In [64]:
#Define model with meta = scalefactormeta
# Linear Gaussian state-space model whose nodes get `ScaleFactorMeta()` as their
# default meta (presumably so messages carry scale factors via the
# `ScaledMessage` product rule — confirm against the package providing it):
#   x_prior ~ N(mean(x0), cov(x0))
#   x[i]    ~ N(A * x[i-1], Q)   with x[0] = x_prior
#   y[i]    ~ N(B * x[i], P)
# Arguments: `n` number of time steps; `x0` prior distribution (only its mean
# and covariance are read); `A`/`B` transition/observation matrices; `Q`/`P`
# process/observation noise covariances.
# Callers receive `model, (x, y)`: the hidden-state random variables and the
# observation datavars.
@model [default_meta=ScaleFactorMeta() ] function rotate_ssm_scalefactor(n, x0, A, B, Q, P)
    # We create constvar references for better efficiency
    cA = constvar(A)
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)
        
    # `x` is a sequence of hidden states
    x = randomvar(n)
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)
        
    # Prior on the state preceding x[1]
    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
    x_prev = x_prior
        
    # Chain the transition and observation factors through time
    for i in 1:n
        x[i] ~ MvNormalMeanCovariance(cA * x_prev, cQ)
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
        x_prev = x[i]
    end
        
    return x, y
end

rotate_ssm_scalefactor (generic function with 1 method)

In [65]:
#Now we do inference
# Run scale-factor message passing on `rotate_ssm_scalefactor` and benchmark it.
#
# Builds the model with one hidden state per element of `data`, subscribes a
# buffer to the posterior marginals of `x`, and benchmarks `update!` of the
# observation datavars with `data` using `@benchmark`.
#
# Returns `(xbuffer, bmark_scf)`: the marginal buffer and the
# `BenchmarkTools.Trial`.
#
# NOTE(review): the regular-ReactiveMP section later defines another
# `inference` method with the same signature, which replaces this one.
# Re-run this cell before re-executing the scale-factor section.
function inference(data, x0, A, B, Q, P)
    # Size the model from the data rather than from the notebook-global `n`.
    nobs = length(data)

    # We create a model and get references for
    # hidden states and observations
    _, (x, y) = rotate_ssm_scalefactor(nobs, x0, A, B, Q, P)

    xbuffer = buffer(Marginal, nobs)

    # We subscribe on posterior marginals of `x`
    xsubscription = subscribe!(getmarginals(x), xbuffer)

    bmark_scf = try
        # `update!` pushes `data` into the clamped datavars `y`. `@benchmark`
        # repeats it many times, so the buffer reflects the most recent
        # `update!` (assumes the buffer keeps the latest emission — confirm).
        @benchmark update!($y, $data)
    finally
        # It is important to always unsubscribe, even if inference throws
        unsubscribe!(xsubscription)
    end

    return xbuffer, bmark_scf
end

inference (generic function with 1 method)

In [66]:
# Run and benchmark scale-factor inference on the simulated observations `y`.
xmarginals_sf, bmark_scf = inference(y, x0, A, B, Q, P);

In [67]:
# Scale factor carried by the final posterior marginal. It is compared with the
# Bethe free energy (minus log evidence) of the regular model further below.
xmarginals_sf[end].data.scale

458.5427007954405

## Inference by regular ReactiveMP

In [68]:
# Linear Gaussian state-space model with no custom default meta:
#   x_prior ~ N(mean(x0), cov(x0))
#   x[i]    ~ N(A * x[i-1], Q)   with x[0] = x_prior
#   y[i]    ~ N(B * x[i], P)
# Arguments: `n` number of time steps; `x0` prior distribution (only its mean
# and covariance are read); `A`/`B` transition/observation matrices; `Q`/`P`
# process/observation noise covariances.
# Callers receive `model, (x, y)`: the hidden-state random variables and the
# observation datavars.
@model function rotate_ssm(n, x0, A, B, Q, P)
    
    # We create constvar references for better efficiency
    cA = constvar(A)
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)
    
    # `x` is a sequence of hidden states
    x = randomvar(n)
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)
    
    # Prior on the state preceding x[1]
    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
    x_prev = x_prior
    
    # Chain the transition and observation factors through time
    for i in 1:n
        x[i] ~ MvNormalMeanCovariance(cA * x_prev, cQ)
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
        x_prev = x[i]
    end
    
    return x, y
end

rotate_ssm (generic function with 1 method)

In [69]:
# Run regular ReactiveMP message passing on `rotate_ssm`, track the Bethe free
# energy, and benchmark the run.
#
# Builds the model with one hidden state per element of `data`, subscribes a
# buffer to the posterior marginals of `x` and a callback to the BFE stream,
# then benchmarks `update!` of the observation datavars with `data`.
#
# Returns `(xbuffer, bfe, bmark_reactmp)`:
# - `xbuffer`: the marginal buffer;
# - `bfe`: the last BFE value emitted (`nothing` if none was emitted);
# - `bmark_reactmp`: the `BenchmarkTools.Trial`.
#
# NOTE(review): this method has the same signature as the scale-factor
# `inference` defined earlier and replaces it.
function inference(data, x0, A, B, Q, P)
    # Size the model from the data rather than from the notebook-global `n`.
    nobs = length(data)

    # We create a model and get references for
    # hidden states and observations
    model, (x, y) = rotate_ssm(nobs, x0, A, B, Q, P)

    xbuffer = buffer(Marginal, nobs)
    bfe     = nothing

    # We subscribe on posterior marginals of `x`
    xsubscription = subscribe!(getmarginals(x), xbuffer)
    # We are also interested in the BetheFreeEnergy functional,
    # which in this case is equal to minus log evidence.
    # The callback overwrites `bfe`, so it holds the latest emitted value.
    fsubscription = subscribe!(score(BetheFreeEnergy(), model), (v) -> bfe = v)

    bmark_reactmp = try
        # `update!` updates our clamped datavars
        @benchmark update!($y, $data)
    finally
        # It is important to always unsubscribe, even if inference throws
        unsubscribe!((xsubscription, fsubscription))
    end

    return xbuffer, bfe, bmark_reactmp
end

inference (generic function with 1 method)

In [70]:
# Run and benchmark regular ReactiveMP inference, also collecting the BFE.
xmarginals, bfe, bmark_reactmp = inference(y, x0, A, B, Q, P);

In [71]:
# Bethe free energy of the regular model (equal to minus log evidence here).
bfe

458.5426925542269

In [72]:
#Calculate the difference between scale factor and bfe
# A value near zero means the scale factor reproduces the minus log evidence
# reported by the Bethe free energy.
xmarginals_sf[end].data.scale - bfe

8.241213606652309e-6

## Performance comparison using BenchmarkTools

In [73]:
# benchmark of ReactiveMP with scale factor
# (the `BenchmarkTools.Trial` returned by the scale-factor `inference` run)
bmark_scf

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  1.255 ms … 15.101 ms  ┊ GC (min … max):  0.00% … 87.82%
 Time  (median):     1.423 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   1.670 ms ±  1.432 ms  ┊ GC (mean ± σ):  10.02% ± 10.37%

  █
  █▇█▃▂

In [74]:
# benchmark of regular ReactiveMP
# (the `BenchmarkTools.Trial` returned by the regular `inference` run)
bmark_reactmp

BenchmarkTools.Trial: 7213 samples with 1 evaluation.
 Range (min … max):  2.005 ms … 19.530 ms  ┊ GC (min … max): 0.00% … 81.50%
 Time  (median):     2.538 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.761 ms ±  1.742 ms  ┊ GC (mean ± σ):  7.52% ± 10.13%

  █▆█▃
  ████▆▄