In [1]:
using Rocket, GraphPPL, ReactiveMP, Distributions, SpecialFunctions
using Plots, LinearAlgebra, Random
using Revise

## Generate data

In [12]:
#Generate data
function generate_data(rng, A, B, Q, P)
    x_prev = [ 10.0, -10.0 ]

    x = Vector{Vector{Float64}}(undef, n)
    y = Vector{Vector{Float64}}(undef, n)

    for i in 1:n
        x[i] = rand(rng, MvNormal(A * x_prev, Q))
        y[i] = rand(rng, MvNormal(B * x[i], P))
        x_prev = x[i]
    end
    
    return x, y
end

generate_data (generic function with 1 method)

In [39]:
# Seed for reproducibility
seed = 1234

rng = MersenneTwister(1234)


#θ = π/35; 
A = [ 1.2 1.7; 0 1 ]
B = diageye(2)
Q = diageye(2)
P = 25.0 .* diageye(2)

# Number of observations
n = 50;

In [40]:
logdet(A)

0.1823215567939546

In [41]:
logdet(B)

0.0

In [42]:
x, y = generate_data(rng, A, B, Q, P);

In [43]:
x0 = MvNormalMeanCovariance(zeros(2), 100.0 * diageye(2));

## Inference by ReactiveMP (involving Scale factor)
In this section, we add Scale factor update rule to ReactiveMP, then we compute the model evidence by scale factors and compare the computational time with the previous ReactiveMP without scale factors.

In [44]:
#Define a new structure, which consists of a message and a -log(scalefactor)
struct MessageWithScaleFactor{T}
    message :: T
    scalefactor :: Float64
end

In [45]:
#Define a structure for scalefactor
#When we use this, ReactiveMP will switch to MessageWithScaleFactor backend
struct scalefactormeta end

In [46]:
#create rules (both message and scalefactor) for Normal node
@rule MvNormalMeanCovariance(:μ, Marginalisation) (m_out::PointMass, m_Σ::PointMass, meta::scalefactormeta) = begin
    message = @call_rule MvNormalMeanCovariance(:μ, Marginalisation) (m_out = m_out, m_Σ = m_Σ) 
    scalefactor = 0.0
    return MessageWithScaleFactor(message,scalefactor)
end

@rule MvNormalMeanCovariance(:μ, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_Σ::PointMass, meta::scalefactormeta) = begin
    message = @call_rule MvNormalMeanCovariance(:μ, Marginalisation) (m_out = m_out, m_Σ = m_Σ) 
    scalefactor = 0.0
    return MessageWithScaleFactor(message,scalefactor)
end


@rule MvNormalMeanCovariance(:out, Marginalisation) (m_μ::PointMass, m_Σ::PointMass, meta::scalefactormeta) = begin
    message = @call_rule MvNormalMeanCovariance(:out, Marginalisation) (m_μ = m_μ, m_Σ = m_Σ) 
    scalefactor = 0.0
    return MessageWithScaleFactor(message,scalefactor)
end

@rule MvNormalMeanCovariance(:out, Marginalisation) (m_μ::MultivariateNormalDistributionsFamily, m_Σ::PointMass, meta::scalefactormeta) = begin
    message = @call_rule MvNormalMeanCovariance(:out, Marginalisation) (m_μ = m_μ, m_Σ = m_Σ) 
    scalefactor = 0.0
    return MessageWithScaleFactor(message,scalefactor)
end

@rule MvNormalMeanCovariance(:out, Marginalisation) (m_μ::MessageWithScaleFactor, m_Σ::PointMass, meta::scalefactormeta) = begin 
    message = @call_rule MvNormalMeanCovariance(:out, Marginalisation) (m_μ = m_μ.message, m_Σ = m_Σ) 
    scalefactor = m_μ.scalefactor
    return MessageWithScaleFactor(message,scalefactor)
end

@rule MvNormalMeanCovariance(:μ, Marginalisation) (m_out::MessageWithScaleFactor, m_Σ::PointMass, meta::scalefactormeta) = begin 
    message = @call_rule MvNormalMeanCovariance(:μ, Marginalisation) (m_out = m_out.message, m_Σ = m_Σ) 
    scalefactor = m_out.scalefactor
    return MessageWithScaleFactor(message, scalefactor)
end

In [58]:
@rule typeof(*)(:out, Marginalisation) (m_A::PointMass, m_in::MessageWithScaleFactor, meta::TinyCorrection) = begin 
    A = mean(m_A)
    message = @call_rule typeof(*)(:out, Marginalisation) (m_A = m_A, m_in = m_in.message, meta=meta)
    scalefactor = m_in.scalefactor 

    return MessageWithScaleFactor(message, scalefactor)
end

@rule typeof(*)(:in, Marginalisation) (m_out::MessageWithScaleFactor, m_A::PointMass, meta::TinyCorrection) = begin
     
    message = @call_rule typeof(*)(:in, Marginalisation) (m_out = m_out.message, m_A = m_A, meta=meta)
    scalefactor = m_out.scalefactor + logdet(A)
    return MessageWithScaleFactor(message, scalefactor)
end
@rule typeof(*)(:in, Marginalisation) (m_out::PointMass, m_A::PointMass) = PointMass(mean(inv, m_A) * mean(m_out))

In [59]:
#Product function for equality node
function ReactiveMP.prod(::ProdAnalytical, left::MessageWithScaleFactor{ <: MultivariateNormalDistributionsFamily }, right::MessageWithScaleFactor{ <: MultivariateNormalDistributionsFamily })
    mean_left, var_left = mean_cov(left.message)
    mean_right, var_right = mean_cov(right.message)
    n = length(mean_left)
    m, V = mean_left - mean_right, var_left + var_right

    message = prod(ProdAnalytical(),left.message,right.message)
    scalefactor = left.scalefactor + right.scalefactor + 0.5*logdet(V) + n/2*log(2π) + m'*inv(V)*m/2

    return MessageWithScaleFactor(message,scalefactor)
end

In [60]:
#Define model with meta = scalefactormeta
@model function rotate_ssm_scalefactor(n, x0, A, B, Q, P)
    meta = scalefactormeta();
    # We create constvar references for better efficiency
    cA = constvar(A)
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)
        
    # `x` is a sequence of hidden states
    x = randomvar(n)
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)
        
    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0)) where {meta = meta}
    x_prev = x_prior
        
    for i in 1:n
        x[i] ~ MvNormalMeanCovariance(cA * x_prev, cQ) where {meta=meta}
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP) where {meta=meta}
        x_prev = x[i]
    end
        
    return x, y
end

rotate_ssm_scalefactor (generic function with 1 method)

In [61]:
#Now we do inference
function inference(data, x0, A, B, Q, P)

    # We create a model and get references for 
    # hidden states and observations
    model, (x, y) = rotate_ssm_scalefactor(n, x0, A, B, Q, P);

    xbuffer   = buffer(Marginal, n)
    
    # We subscribe on posterior marginals of `x`
    xsubscription = subscribe!(getmarginals(x), xbuffer)
    # `update!` updates our clamped datavars
    update!(y, data)

    # It is important to always unsubscribe
    unsubscribe!(xsubscription)
    
    return xbuffer
end

inference (generic function with 2 methods)

In [62]:
xmarginals_sf = inference(y, x0, A, B, Q, P);

In [63]:
xmarginals_sf

BufferActor{Marginal,Array{Marginal,1}}(Marginal[Marginal(MessageWithScaleFactor{MvNormalWeightedMeanPrecision{Float64,Array{Float64,1},Array{Float64,2}}}(MvNormalWeightedMeanPrecision(
xi: [-5.154798743055683, -15.826890766427375]
Λ: [0.2259374109510591 0.42710733629339115; 0.42710733629339115 1.524002352773202]
)
, 338.24776196433646)), Marginal(MessageWithScaleFactor{MvNormalWeightedMeanPrecision{Float64,Array{Float64,1},Array{Float64,2}}}(MvNormalWeightedMeanPrecision(
xi: [-8.826993661174992, -22.664963725694108]
Λ: [0.24733118952849362 0.3858862535058092; 0.3858862535058092 1.6553892936970167]
)
, 338.24776196434135)), Marginal(MessageWithScaleFactor{MvNormalWeightedMeanPrecision{Float64,Array{Float64,1},Array{Float64,2}}}(MvNormalWeightedMeanPrecision(
xi: [-13.504392336002027, -29.785140134515927]
Λ: [0.25479441401448444 0.3578381384233471; 0.3578381384233471 1.8153758927335988]
)
, 338.24776196439313)), Marginal(MessageWithScaleFactor{MvNormalWeightedMeanPrecision{Float64,Arra

## Inference by regular ReactiveMP

In [53]:
@model function rotate_ssm(n, x0, A, B, Q, P)
    
    # We create constvar references for better efficiency
    cA = constvar(A)
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)
    
    # `x` is a sequence of hidden states
    x = randomvar(n)
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)
    
    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
    x_prev = x_prior
    
    for i in 1:n
        x[i] ~ MvNormalMeanCovariance(cA * x_prev, cQ)
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
        x_prev = x[i]
    end
    
    return x, y
end

rotate_ssm (generic function with 1 method)

In [54]:
function inference(data, x0, A, B, Q, P)

    # We create a model and get references for 
    # hidden states and observations
    model, (x, y) = rotate_ssm(n, x0, A, B, Q, P);

    xbuffer   = buffer(Marginal, n)
    bfe       = nothing
    
    # We subscribe on posterior marginals of `x`
    xsubscription = subscribe!(getmarginals(x), xbuffer)
    # We are also intereset in BetheFreeEnergy functional,
    # which in this case is equal to minus log evidence
    fsubcription = subscribe!(score(BetheFreeEnergy(), model), (v) -> bfe = v)

    # `update!` updates our clamped datavars
    update!(y, data)

    # It is important to always unsubscribe
    unsubscribe!((xsubscription, fsubcription))
    
    return xbuffer, bfe
end

inference (generic function with 2 methods)

In [55]:
xmarginals, bfe = inference(y, x0, A, B, Q, P);

In [56]:
bfe

329.13168253878126

In [57]:
me_scf = xmarginals_sf[end].data.scalefactor

329.1316824490589

In [444]:
bfe - me_scf

-0.36464311348635725

In [445]:
logdet(A)*2

0.3646431135879092

In [3]:
@model function example(A)
    
    # We create constvar references for better efficiency
    cA = constvar(A)
    
    # `x` is a sequence of hidden states
    x ~ MvNormalMeanCovariance([1., 1.], [2. 0.;0. 1.])
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64})
    
    y~ MvNormalMeanPrecision(A*x, [1e10 0; 0 1e10])
    
    return x, y
end

example (generic function with 1 method)

In [4]:
function inference(data, A)

    # We create a model and get references for 
    # hidden states and observations
    model, (x, y) = example(A);

    xbuffer   = keep(Marginal)
    bfe       = nothing
    
    # We subscribe on posterior marginals of `x`
    xsubscription = subscribe!(getmarginal(x), xbuffer)
    # We are also intereset in BetheFreeEnergy functional,
    # which in this case is equal to minus log evidence
    fsubcription = subscribe!(score(BetheFreeEnergy(), model), (v) -> bfe = v)

    # `update!` updates our clamped datavars
    update!(y, data)

    # It is important to always unsubscribe
    unsubscribe!((xsubscription, fsubcription))
    
    return xbuffer, bfe
end

inference (generic function with 1 method)

In [498]:
# #Product function for equality node
# function ReactiveMP.prod(::ProdAnalytical, left::MultivariateNormalDistributionsFamily, right::PointMass)

#     return left
# end
# @marginalrule MvNormalMeanCovariance(:out_μ_Σ) (m_out::PointMass, m_μ::PointMass, m_Σ::PointMass, ) = begin 
#     return m_out, m_μ, m_Σ
# end

# @average_energy MvNormalMeanCovariance (q_out_μ_Σ::Any,) = begin
#     # naive: @views (d*log2π + mean(logdet, q_Σ) + tr(cholinv(mean(q_Σ))*( V[1:d,1:d] - V[1:d,d+1:end] - V[d+1:end,1:d] + V[d+1:end,d+1:end] + (m[1:d] - m[d+1:end])*(m[1:d] - m[d+1:end])' ))) / 2
#     println(q_out_μ_Σ)
#     return 0
# end

# @average_energy MvNormalMeanCovariance (q_out_μ_Σ::Any,) = begin
#     dim = ndims(q_out_μ_Σ[1])
#     m_mean = mean(q_out_μ_Σ[2])
#     m_out   = mean(q_out_μ_Σ[1])
#     return (dim * log(2π) + mean(logdet, q_out_μ_Σ[3]) + tr(cholinv(mean(q_out_μ_Σ[3]))*((m_out - m_mean)*(m_out - m_mean)'))) / 2
# end

In [5]:
xmarginals, bfe = inference([3., 2.], A);

In [6]:
bfe

3.311216657840774

In [8]:
-logpdf(MvNormalMeanCovariance([1., 1.], [2. 0.;0. 1.]),A\[3., 2.]) 

3.1288951011337622

In [9]:
-logpdf(MvNormalMeanCovariance(A*[1., 1.], A*[2. 0.;0. 1.]*A'),[3., 2.]) 

3.3112166579277176