Naive variational message passing demo
===

ForneyLab supports variational message passing (VMP). Variational methods often require long and intricate derivations. VMP on a factor graph eases this burden by expressing the variational algorithm as a set of local update rules. Dauwels (2007) gives a generic introduction to VMP on factor graphs, and ForneyLab implements this approach.
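For reference, here is one common way to write the local rule for a mean-field factorization on a Forney-style factor graph, where every edge connects exactly two nodes. The symbols are ours and the formula is the textbook VMP rule, not a transcription of ForneyLab's code. For a node $f(x, z_1, \dots, z_M)$, the message sent along edge $x$ averages the log of the factor over the current beliefs on the other edges, and the belief on $x$ is the normalized product of the two messages on that edge:

$$\nu_{f \to x}(x) \propto \exp\left( \int q(z_1) \cdots q(z_M)\, \log f(x, z_1, \dots, z_M)\, dz_1 \cdots dz_M \right), \qquad q(x) \propto \nu_{f_a \to x}(x)\, \nu_{f_b \to x}(x)$$

where $f_a$ and $f_b$ are the two nodes attached to edge $x$.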

In this demo we illustrate VMP by estimating the mean and variance of a Gaussian distribution from samples drawn from it. The factor graph below shows our generative model. We observe the samples $y = \{y_1, \dots, y_k, \dots, y_n\}$ and wish to estimate the posterior distribution over the mean $m$ and variance $s$ of the generating Gaussian. We use the factor graph notation of Reller (2012, State-space methods in statistical signal processing), where filled black nodes represent observed variables and dotted arcs denote the repetition of a section.

<img src="images/gauss_est.png">

For estimation we use a variational message passing algorithm with a mean field factorized distribution over our variables $q(m,s,y) = q(m)\,q(s)\,q(y)$.
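Written out, the generative model in the figure, with the priors that the PriorNodes :m0 and :s0 attach in the code below (a Gaussian with mean 0 and variance 100 on $m$, an inverse gamma with shape and scale 0.01 on $s$), is:

$$p(y, m, s) = \mathcal{N}(m \mid 0, 100)\; \mathcal{IG}(s \mid 0.01, 0.01) \prod_{k=1}^{n} \mathcal{N}(y_k \mid m, s)$$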

We start by building the graph.

In [1]:
using ForneyLab

# Initial settings
N              = 20 # Number of observed samples
n_its          = 50 # Number of VMP iterations
true_mean      = 5.0
true_variance  = 2.0
y_observations = sqrt(true_variance)*randn(N) .+ true_mean # Simulated observations y_k ~ N(true_mean, true_variance)

# Build graph
for k=1:N
    GaussianNode(id=:g*k) # Gaussian likelihood node; :g*k concatenates the symbol and the index, e.g. :g1
    EqualityNode(id=:m_eq*k) # Equality node chain for mean
    EqualityNode(id=:s_eq*k) # Equality node chain for variance
    TerminalNode(y_observations[k], id=:y*k) # Observed y values are stored in terminal node values

    Edge(n(:g*k).i[:out], n(:y*k), id=:y*k)
    Edge(n(:m_eq*k).i[3], n(:g*k).i[:mean], id=:m*k)
    Edge(n(:s_eq*k).i[3], n(:g*k).i[:variance], id=:s*k)
    if k > 1 # Connect sections
        Edge(n(:m_eq*(k-1)).i[2], n(:m_eq*k).i[1])
        Edge(n(:s_eq*(k-1)).i[2], n(:s_eq*k).i[1])
    end
end

# Attach beginning and end nodes
PriorNode(GaussianDistribution(m=0.0, V=100.0), id=:m0) # Broad Gaussian prior on the mean (variance 100)
PriorNode(InverseGammaDistribution(a=0.01, b=0.01), id=:s0) # Broad inverse-gamma prior on the variance
TerminalNode(vague(GaussianDistribution), id=:mN) # Vague terminal closes the mean chain
TerminalNode(vague(InverseGammaDistribution), id=:sN) # Vague terminal closes the variance chain

Edge(n(:m0), n(:m_eq1).i[1])
Edge(n(:s0), n(:s_eq1).i[1])
Edge(n(:m_eq*N).i[2], n(:mN))
Edge(n(:s_eq*N).i[2], n(:sN));
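Each EqualityNode in the loop above represents the standard equality factor of a Forney-style factor graph, which forces its three edges to carry the same variable. Chaining these nodes ties every $m_k$ edge to one shared mean $m$ (and every $s_k$ edge to one shared variance $s$), while the vague terminal nodes :mN and :sN close each chain without adding information:

$$f_{=}(x, x', x'') = \delta(x - x')\,\delta(x - x'')$$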

The final estimates are available after the last section. We attach write buffers to the interfaces that connect the last section to the terminal nodes :mN and :sN to collect these results.

In [2]:
# Specify some write buffers
m_out = attachWriteBuffer(n(:mN).i[:out].partner)
s_out = attachWriteBuffer(n(:sN).i[:out].partner);

We specify a variational Bayes algorithm with a naive (mean-field) factorization. The factorization is passed as a dictionary that maps vectors of edges to distribution types; the type in each entry is assigned to the approximate posterior on every edge in that entry's edge vector.
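For this conjugate model each mean-field update has a closed form. The equations below are the standard textbook derivation, with $q(m) = \mathcal{N}(m \mid \mu_m, v_m)$, $q(s) = \mathcal{IG}(s \mid a_s, b_s)$, prior parameters $m_0 = 0$, $v_0 = 100$, $a_0 = b_0 = 0.01$, and $\mathbb{E}_{q(s)}[1/s] = a_s / b_s$. They illustrate what the messages compute; they are not a transcription of ForneyLab's internals.

$$v_m = \left( \frac{1}{v_0} + n\,\frac{a_s}{b_s} \right)^{-1}, \qquad \mu_m = v_m \left( \frac{m_0}{v_0} + \frac{a_s}{b_s} \sum_{k=1}^{n} y_k \right)$$

$$a_s = a_0 + \frac{n}{2}, \qquad b_s = b_0 + \frac{1}{2} \sum_{k=1}^{n} \left[ (y_k - \mu_m)^2 + v_m \right]$$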

In [3]:
# Specify the variational algorithm for n_its vmp iterations
algo = VariationalBayes(Dict(   eg(:m*(1:N)) => GaussianDistribution,
                                eg(:s*(1:N)) => InverseGammaDistribution,
                                eg(:y*(1:N)) => GaussianDistribution),
                        n_iterations=n_its)

show(algo)

VariationalBayes inference algorithm
    number of factors: 22
    number of iterations: 50


We can also draw a single factor of the factorization, here the third one.

In [4]:
draw(algo.factorization.factors[3])

Now we run the algorithm, which executes the VMP message passing schedules for n_its iterations, and inspect the results.

In [5]:
run(algo)

# Inspect the results
println("True mean: $(true_mean)")
println("True variance: $(true_variance)")
println("Number of samples: $(length(y_observations))")
println("Sample mean: $(round(mean(y_observations),2))")
println("Sample variance: $(round(var(y_observations),2))")
println("\n----- Estimation after $(n_its) VMP updates -----")
println("Mean estimate: $(round(mean(m_out[end])[1],2)), with variance $(round(var(m_out[end])[1, 1],2))")
println("Variance estimate: $(round(mean(s_out[end]),2)), with variance $(round(var(s_out[end]),2))")

True mean: 5.0
True variance: 2.0
Number of samples: 20
Sample mean: 4.86
Sample variance: 2.46

----- Estimation after 50 VMP updates -----
Mean estimate: 4.86, with variance 0.03
Variance estimate: 2.62, with variance 0.86
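To make the updates concrete without ForneyLab, the sketch below iterates the standard closed-form mean-field updates for this model in plain Julia (assuming Julia 1.x with the Random and Statistics standard libraries). The function names vmp_gauss, ig_mean and ig_var are hypothetical helpers introduced for illustration, not ForneyLab API. Because the sketch draws its own random samples, its numbers will differ from the run above.

```julia
using Random, Statistics

# Hypothetical helper (not part of ForneyLab): mean-field VMP for
#   y_k ~ N(m, s),  m ~ N(m0, v0),  s ~ IG(a0, b0)
# with q(m) = N(mu_m, v_m) and q(s) = IG(a_s, b_s).
function vmp_gauss(y::AbstractVector{<:Real};
                   m0=0.0, v0=100.0, a0=0.01, b0=0.01, n_its=50)
    n = length(y)
    a_s, b_s = a0, b0          # start q(s) at the prior
    mu_m, v_m = m0, v0         # start q(m) at the prior
    for _ in 1:n_its
        e_prec = a_s / b_s     # E[1/s] under IG(a_s, b_s)
        # q(m) update given the current q(s)
        v_m = 1 / (1 / v0 + n * e_prec)
        mu_m = v_m * (m0 / v0 + e_prec * sum(y))
        # q(s) update given the current q(m); E[(y_k - m)^2] = (y_k - mu_m)^2 + v_m
        a_s = a0 + n / 2
        b_s = b0 + 0.5 * sum(abs2.(y .- mu_m) .+ v_m)
    end
    return (mu_m = mu_m, v_m = v_m, a_s = a_s, b_s = b_s)
end

# Mean and variance of IG(a, b); finite only for a > 1 and a > 2, respectively
ig_mean(a, b) = b / (a - 1)
ig_var(a, b) = b^2 / ((a - 1)^2 * (a - 2))

# Usage on synthetic data generated like the notebook's (fresh random draws)
Random.seed!(1)
N = 20
y = sqrt(2.0) .* randn(N) .+ 5.0
post = vmp_gauss(y)
println("Mean estimate: ", round(post.mu_m, digits=2), ", with variance ", round(post.v_m, digits=2))
println("Variance estimate: ", round(ig_mean(post.a_s, post.b_s), digits=2),
        ", with variance ", round(ig_var(post.a_s, post.b_s), digits=2))
```

Since the inverse-gamma variance equals $\text{mean}^2 / (a - 2)$, the same identity gives a quick consistency check on the ForneyLab output above: assuming the standard update sets the shape to $a_0 + n/2 = 10.01$, a variance estimate of 2.62 implies a variance of $2.62^2 / 8.01 \approx 0.86$, which matches the printed value.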
