In [1]:
using Revise
using Rocket, GraphPPL, ReactiveMP, Distributions, SpecialFunctions
using Plots, LinearAlgebra, Random

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1423


## Generate data

In [2]:
"""
    generate_data(rng, A, B, Q, P; nsamples = n)

Simulate a linear Gaussian state-space trajectory and return `(x, y)`.

Starting from the fixed initial state `[10.0, -10.0]`, every step draws
`x[i] ~ MvNormal(A * x[i-1], Q)` followed by `y[i] ~ MvNormal(B * x[i], P)`.

# Arguments
- `rng`: random number generator forwarded to `rand`.
- `A`, `B`: matrices applied to the previous state and to the current state.
- `Q`, `P`: covariance matrices of the state and observation draws.

# Keywords
- `nsamples`: number of steps to simulate. Defaults to the notebook-level
  global `n`, which is what the original implementation always used.

# Returns
A tuple `(x, y)` of `Vector{Vector{Float64}}`, each of length `nsamples`.
"""
function generate_data(rng, A, B, Q, P; nsamples = n)
    x = Vector{Vector{Float64}}(undef, nsamples)
    y = Vector{Vector{Float64}}(undef, nsamples)

    # State that precedes the first generated sample
    x_prev = [ 10.0, -10.0 ]

    for i in 1:nsamples
        x[i] = rand(rng, MvNormal(A * x_prev, Q))
        y[i] = rand(rng, MvNormal(B * x[i], P))
        x_prev = x[i]
    end

    return x, y
end

generate_data (generic function with 1 method)

In [3]:
# Seed for reproducibility
seed = 1234

# Build the generator from `seed` (previously the literal was repeated here,
# so editing `seed` had no effect on the random stream).
rng = MersenneTwister(seed)


#θ = π/35; 
# Matrix applied to the previous state in `generate_data` and in the models
A = [1.2 1.7; 0.0 1.0]
# Matrix applied to the current state to form the observation mean
B = diageye(2)
# Covariance of the state draws
Q = diageye(2)
# Covariance of the observation draws
P = 25.0 .* diageye(2)

# Number of observations
n = 50;

In [4]:
# Draw the synthetic states `x` and observations `y`; `y` is what the inference calls below consume
x, y = generate_data(rng, A, B, Q, P);

In [5]:
# Prior handed to the models: zero mean with covariance 100 * I (the models read it via `mean(x0)` and `cov(x0)`)
x0 = MvNormalMeanCovariance(zeros(2), 100.0 * diageye(2));

## Inference by ReactiveMP (involving Scale factor)
In this section, we add scale-factor update rules to ReactiveMP, compute the model evidence from the scale factors, and compare the computation time with the previous ReactiveMP inference that does not use scale factors.

In [6]:
#create rules (both message and scalefactor) for Normal node
@rule MvNormalMeanCovariance(:μ, Marginalisation) (m_out::PointMass, m_Σ::PointMass, meta::ScaleFactorMeta) = begin
    # Reuse the rule without meta to obtain the message itself
    unscaled = @call_rule MvNormalMeanCovariance(:μ, Marginalisation) (m_out = m_out, m_Σ = m_Σ)
    # This rule attaches a scale of 0.0 to the wrapped message
    return ScaledMessage(unscaled, 0.0)
end

@rule MvNormalMeanCovariance(:μ, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_Σ::PointMass, meta::ScaleFactorMeta) = begin
    # Reuse the rule without meta to obtain the message itself
    unscaled = @call_rule MvNormalMeanCovariance(:μ, Marginalisation) (m_out = m_out, m_Σ = m_Σ)
    # An incoming message without a scale is given a scale of 0.0 here
    return ScaledMessage(unscaled, 0.0)
end


@rule MvNormalMeanCovariance(:out, Marginalisation) (m_μ::PointMass, m_Σ::PointMass, meta::ScaleFactorMeta) = begin
    # Reuse the rule without meta to obtain the message itself
    unscaled = @call_rule MvNormalMeanCovariance(:out, Marginalisation) (m_μ = m_μ, m_Σ = m_Σ)
    # This rule attaches a scale of 0.0 to the wrapped message
    return ScaledMessage(unscaled, 0.0)
end

@rule MvNormalMeanCovariance(:out, Marginalisation) (m_μ::MultivariateNormalDistributionsFamily, m_Σ::PointMass, meta::ScaleFactorMeta) = begin
    # Reuse the rule without meta to obtain the message itself
    unscaled = @call_rule MvNormalMeanCovariance(:out, Marginalisation) (m_μ = m_μ, m_Σ = m_Σ)
    # An incoming message without a scale is given a scale of 0.0 here
    return ScaledMessage(unscaled, 0.0)
end

@rule MvNormalMeanCovariance(:out, Marginalisation) (m_μ::ScaledMessage, m_Σ::PointMass, meta::ScaleFactorMeta) = begin
    # Unwrap the incoming scaled message and reuse the rule without meta
    unscaled = @call_rule MvNormalMeanCovariance(:out, Marginalisation) (m_μ = m_μ.message, m_Σ = m_Σ)
    # The incoming scale is passed through unchanged
    return ScaledMessage(unscaled, m_μ.scale)
end

@rule MvNormalMeanCovariance(:μ, Marginalisation) (m_out::ScaledMessage, m_Σ::PointMass, meta::ScaleFactorMeta) = begin
    # Unwrap the incoming scaled message and reuse the rule without meta
    unscaled = @call_rule MvNormalMeanCovariance(:μ, Marginalisation) (m_out = m_out.message, m_Σ = m_Σ)
    # The incoming scale is passed through unchanged
    return ScaledMessage(unscaled, m_out.scale)
end

In [12]:
@rule typeof(*)(:out, Marginalisation) (m_A::PointMass, m_in::ScaledMessage, meta::ScaleFactorMeta) = begin
    # Compute the plain message from the unwrapped input using the `TinyCorrection` rule
    forwarded = @call_rule typeof(*)(:out, Marginalisation) (m_A = m_A, m_in = m_in.message, meta=TinyCorrection())
    # The scale of the incoming message is passed through unchanged
    return ScaledMessage(forwarded, m_in.scale)
end

@rule typeof(*)(:in, Marginalisation) (m_out::ScaledMessage, m_A::PointMass, meta::ScaleFactorMeta) = begin
    A = mean(m_A)
    # Compute the plain message from the unwrapped input using the `TinyCorrection` rule
    message = @call_rule typeof(*)(:in, Marginalisation) (m_out = m_out.message, m_A = m_A, meta=TinyCorrection())
    # The scale picks up log|det A|. `logdet(A)` throws a DomainError when
    # det(A) < 0, so take the log of the absolute determinant instead; for a
    # positive determinant the value is the same as before.
    logabsdet_A, _ = logabsdet(A)
    scale = m_out.scale + logabsdet_A
    return ScaledMessage(message, scale)
end

# @rule typeof(*)(:in, Marginalisation) (m_out::PointMass, m_A::PointMass) = PointMass(mean(inv, m_A) * mean(m_out))

In [13]:
# Product of two scaled multivariate Gaussian messages (used at equality nodes).
#
# The resulting message is the analytical product of the two wrapped messages.
# The resulting scale is the sum of both incoming scales plus
#   logdet(V)/2 + d/2 * log(2π) + m' * V⁻¹ * m / 2,
# where m is the difference of the two means, V the sum of the two covariances
# and d the dimension of the mean.
function ReactiveMP.prod(::ProdAnalytical, left::ScaledMessage{ <: MultivariateNormalDistributionsFamily }, right::ScaledMessage{ <: MultivariateNormalDistributionsFamily })
    mean_left, cov_left = mean_cov(left.message)
    mean_right, cov_right = mean_cov(right.message)
    d = length(mean_left)
    m = mean_left - mean_right
    V = cov_left + cov_right

    message = prod(ProdAnalytical(), left.message, right.message)
    # Solve V \ m rather than forming inv(V): cheaper and numerically better behaved
    scale = left.scale + right.scale + 0.5*logdet(V) + d/2*log(2π) + dot(m, V \ m)/2

    return ScaledMessage(message, scale)
end

In [14]:
# State-space model in which every node receives `ScaleFactorMeta()` as its default meta.
# The rules defined earlier in this notebook dispatch on `meta::ScaleFactorMeta`.
# Arguments: `n` number of time steps, `x0` prior distribution (only its mean and
# covariance are read), `A`/`B` matrices multiplying the states, `Q`/`P` covariances.
# Returns the hidden-state variables `x` and the data variables `y`.
@model [default_meta=ScaleFactorMeta() ] function rotate_ssm_scalefactor(n, x0, A, B, Q, P)
    # We create constvar references for better efficiency
    cA = constvar(A)
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)
        
    # `x` is a sequence of hidden states
    x = randomvar(n)
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)
        
    # Prior over the state that precedes x[1]
    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
    x_prev = x_prior
        
    for i in 1:n
        # Each state depends on the previous one through cA; each observation on its state through cB
        x[i] ~ MvNormalMeanCovariance(cA * x_prev, cQ)
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
        x_prev = x[i]
    end
        
    return x, y
end

rotate_ssm_scalefactor (generic function with 1 method)

In [15]:
# Run inference on the scale-factor model and return the buffer that holds
# the posterior marginals of the hidden states.
#
# `data` is the vector of observations; `x0`, `A`, `B`, `Q`, `P` are forwarded
# to `rotate_ssm_scalefactor`.
function inference(data, x0, A, B, Q, P)

    # Size the model from the data itself instead of relying on a global `n`
    n = length(data)

    # We create a model and get references for
    # hidden states and observations
    _, (x, y) = rotate_ssm_scalefactor(n, x0, A, B, Q, P)

    xbuffer = buffer(Marginal, n)

    # We subscribe on posterior marginals of `x`
    xsubscription = subscribe!(getmarginals(x), xbuffer)
    # `update!` updates our clamped datavars
    update!(y, data)

    # It is important to always unsubscribe
    unsubscribe!(xsubscription)

    return xbuffer
end

inference (generic function with 1 method)

In [16]:
# Posterior marginals of the hidden states from the scale-factor model
xmarginals_sf = inference(y, x0, A, B, Q, P);

In [17]:
# Display the buffer; each marginal wraps a ScaledMessage that carries a scale value
xmarginals_sf

BufferActor{Marginal, Vector{Marginal}}(Marginal[Marginal(ScaledMessage{MvNormalWeightedMeanPrecision{Float64, Vector{Float64}, Matrix{Float64}}}(MvNormalWeightedMeanPrecision(
xi: [-5.154798743055691, -15.826890766427399]
Λ: [0.22593741095105907 0.4271073362933909; 0.427107336293391 1.5240023527732014]
)
, 329.13168412463995)), Marginal(ScaledMessage{MvNormalWeightedMeanPrecision{Float64, Vector{Float64}, Matrix{Float64}}}(MvNormalWeightedMeanPrecision(
xi: [-8.826993661175004, -22.664963725694136]
Λ: [0.2473311895284936 0.385886253505809; 0.385886253505809 1.6553892936970163]
)
, 329.13168412464483)), Marginal(ScaledMessage{MvNormalWeightedMeanPrecision{Float64, Vector{Float64}, Matrix{Float64}}}(MvNormalWeightedMeanPrecision(
xi: [-13.504392336002036, -29.785140134515945]
Λ: [0.2547944140144844 0.3578381384233469; 0.35783813842334694 1.8153758927335981]
)
, 329.13168412469656)), Marginal(ScaledMessage{MvNormalWeightedMeanPrecision{Float64, Vector{Float64}, Matrix{Float64}}}(MvNormal

## Inference by regular ReactiveMP

In [18]:
# Same state-space model as `rotate_ssm_scalefactor`, but without a default meta.
# Arguments: `n` number of time steps, `x0` prior distribution (only its mean and
# covariance are read), `A`/`B` matrices multiplying the states, `Q`/`P` covariances.
# Returns the hidden-state variables `x` and the data variables `y`.
@model function rotate_ssm(n, x0, A, B, Q, P)
    
    # We create constvar references for better efficiency
    cA = constvar(A)
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)
    
    # `x` is a sequence of hidden states
    x = randomvar(n)
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)
    
    # Prior over the state that precedes x[1]
    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
    x_prev = x_prior
    
    for i in 1:n
        # Each state depends on the previous one through cA; each observation on its state through cB
        x[i] ~ MvNormalMeanCovariance(cA * x_prev, cQ)
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
        x_prev = x[i]
    end
    
    return x, y
end

rotate_ssm (generic function with 1 method)

In [19]:
# Run inference on the regular (no scale factor) model.
#
# Returns a tuple `(xbuffer, bfe)`: the buffer with the posterior marginals of
# the hidden states and the last Bethe free energy value that was emitted
# (`nothing` if none was emitted).
function inference(data, x0, A, B, Q, P)

    # Size the model from the data itself instead of relying on a global `n`
    n = length(data)

    # We create a model and get references for
    # hidden states and observations
    model, (x, y) = rotate_ssm(n, x0, A, B, Q, P)

    xbuffer = buffer(Marginal, n)
    bfe     = nothing

    # We subscribe on posterior marginals of `x`
    xsubscription = subscribe!(getmarginals(x), xbuffer)
    # We are also interested in the BetheFreeEnergy functional,
    # which in this case is equal to minus log evidence
    fsubscription = subscribe!(score(BetheFreeEnergy(), model), (v) -> bfe = v)

    # `update!` updates our clamped datavars
    update!(y, data)

    # It is important to always unsubscribe
    unsubscribe!((xsubscription, fsubscription))

    return xbuffer, bfe
end

inference (generic function with 1 method)

In [20]:
# Posterior marginals and Bethe free energy from the regular model
xmarginals, bfe = inference(y, x0, A, B, Q, P);

In [21]:
# Display the Bethe free energy returned by the regular model
bfe

329.1316825388758

In [29]:
# Compare the scale stored in the last scale-factor marginal with the Bethe free energy
println("ScaleFactor approximates BFE: ", xmarginals_sf[end].data.scale ≈ bfe)

ScaleFactor approximates BFE: true


In [None]:
# Closed-form log density of the first observation alone:
# push the prior through A, add Q, push through B, add P.
vz = A * cov(x0) * A'
mz = A * mean(x0)
vy = B * (vz + Q) * B' + P
my = B * mz
logpdf(MvNormalMeanCovariance(my, vy), y[1])

In [None]:
# NOTE(review): `me_scf` is not defined anywhere in the visible notebook, so this cell
# would throw UndefVarError as written. Presumably it was meant to hold the scale-factor
# evidence (cf. `xmarginals_sf[end].data.scale`) — confirm before running.
bfe - me_scf

In [22]:
# Twice the log-determinant of A, the term added to the scale by the `:in` rule of `*`
logdet(A)*2

0.3646431135879092

In [23]:
# Minimal model with one latent vector and one observation, used to check the
# free energy against a closed-form log density. Returns the variables `x` and `y`.
@model function example(A)
    
    # Constvar reference for the matrix; it is used in the product below
    # (previously it was created but the raw `A` was multiplied instead)
    cA = constvar(A)
    
    # `x` is a single latent vector with a fixed Gaussian prior
    x ~ MvNormalMeanCovariance([1., 1.], [2. 0.;0. 1.])
    # `y` is a single "clamped" observation
    y = datavar(Vector{Float64})
    
    # Very large precision, so `y` is tied tightly to `cA * x`
    y ~ MvNormalMeanPrecision(cA * x, [1e10 0; 0 1e10])
    
    return x, y
end

example (generic function with 1 method)

In [24]:
# Run inference on the `example` model for a single observation `data`.
# Returns the actor that keeps the posterior marginals of `x` together with the
# last Bethe free energy value that was emitted (`nothing` if none was emitted).
function inference(data, A)
    # Build the model and grab handles to the latent and observed variables
    model, (latent, observed) = example(A)

    posterior = keep(Marginal)
    free_energy = nothing

    # Record posterior marginals of the latent variable
    marginal_sub = subscribe!(getmarginal(latent), posterior)
    # Record the Bethe free energy, which in this case is equal to minus log evidence
    energy_sub = subscribe!(score(BetheFreeEnergy(), model), value -> (free_energy = value))

    # Feed the observation into the clamped datavar
    update!(observed, data)

    # Always release the subscriptions
    unsubscribe!((marginal_sub, energy_sub))

    return posterior, free_energy
end

inference (generic function with 2 methods)

In [25]:
# #Product function for equality node
# function ReactiveMP.prod(::ProdAnalytical, left::MultivariateNormalDistributionsFamily, right::PointMass)

#     return left
# end
# @marginalrule MvNormalMeanCovariance(:out_μ_Σ) (m_out::PointMass, m_μ::PointMass, m_Σ::PointMass, ) = begin 
#     return m_out, m_μ, m_Σ
# end

# @average_energy MvNormalMeanCovariance (q_out_μ_Σ::Any,) = begin
#     # naive: @views (d*log2π + mean(logdet, q_Σ) + tr(cholinv(mean(q_Σ))*( V[1:d,1:d] - V[1:d,d+1:end] - V[d+1:end,1:d] + V[d+1:end,d+1:end] + (m[1:d] - m[d+1:end])*(m[1:d] - m[d+1:end])' ))) / 2
#     println(q_out_μ_Σ)
#     return 0
# end

# @average_energy MvNormalMeanCovariance (q_out_μ_Σ::Any,) = begin
#     dim = ndims(q_out_μ_Σ[1])
#     m_mean = mean(q_out_μ_Σ[2])
#     m_out   = mean(q_out_μ_Σ[1])
#     return (dim * log(2π) + mean(logdet, q_out_μ_Σ[3]) + tr(cholinv(mean(q_out_μ_Σ[3]))*((m_out - m_mean)*(m_out - m_mean)'))) / 2
# end

In [26]:
# Run the `example` model on the single observation [3, 2]; this overwrites `xmarginals` and `bfe`
xmarginals, bfe = inference([3., 2.], A);

In [27]:
# Negated Bethe free energy, to be compared with the closed-form log densities in the following cells
-bfe

-3.311216657840774

In [28]:
# Closed-form check: log density of the prior on x evaluated at A \ [3, 2], minus logdet(A)
logpdf(MvNormalMeanCovariance([1., 1.], [2. 0.;0. 1.]),A\[3., 2.]) - logdet(A)

-3.3112166579277167

In [29]:
# Same check done directly on the observation: Gaussian with mean A*m and covariance A*V*A', evaluated at [3, 2]
logpdf(MvNormalMeanCovariance(A*[1., 1.], A*[2. 0.;0. 1.]*A'),[3., 2.]) 

-3.3112166579277176