Variational message passing demo
===

ForneyLab supports variational message passing (VMP) (Dauwels, 2007). In this demo we illustrate VMP by estimating the mean $m$ and precision $w$ (inverse variance) of i.i.d. samples drawn from a Gaussian distribution. The observation model is
\begin{align*}
    y_i \sim \mathcal{N}(m, w^{-1}), \quad i = 1, \dots, n
\end{align*}
The factor graph below shows one section of our generative model, for a single observation $y_i$.

```
-------------> = ---> (w)
               | 
---> = --------|----> (m)
     |         |
     ---> N <---
          | 
        (y_i)
```

Variational inference approximates the posterior over $m$ and $w$ by a recognition distribution
\begin{align*}
    p(m, w | y_{1:n}) \approx q(m)\times q(w)
\end{align*}
We use variational message passing to minimize the KL divergence $\mathrm{KL}\left[q(m)\,q(w) \,\|\, p(m, w | y_{1:n})\right]$ between the recognition distribution and the exact posterior.
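
This KL objective cannot be evaluated directly, because it involves the unknown posterior. A standard way around this, sketched here with the symbol $F[q]$ (notation that does not appear in the original demo), is to minimize the variational free energy instead:
\begin{align*}
    F[q] &= \mathbb{E}_{q(m)q(w)}\left[\log \frac{q(m)\,q(w)}{p(y_{1:n}, m, w)}\right] \\
         &= \mathrm{KL}\left[q(m)\,q(w) \,\|\, p(m, w | y_{1:n})\right] - \log p(y_{1:n})
\end{align*}
Since $\log p(y_{1:n})$ does not depend on $q$, minimizing $F[q]$ also minimizes the KL divergence. Holding one factor fixed and minimizing over the other gives the coordinate updates
\begin{align*}
    \log q(m) &= \mathbb{E}_{q(w)}\left[\log p(y_{1:n}, m, w)\right] + \text{const} \\
    \log q(w) &= \mathbb{E}_{q(m)}\left[\log p(y_{1:n}, m, w)\right] + \text{const}
\end{align*}
which are iterated in turn until convergence.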

In [1]:
using ForneyLab

n = 10

g = FactorGraph()

# Priors
m ~ GaussianMeanVariance(constant(0.0), constant(100.0))
w ~ Gamma(constant(0.01), constant(0.01))

# Observation model
for i = 1:n
    y_i ~ GaussianMeanPrecision(m, w)
    placeholder(y_i, :y, index=i)
end

# Assign id for ease of lookup
m.id = :m
w.id = :w
;

:w

With the model defined, we can now specify the recognition factorization, generate a schedule for each recognition factor and compile these schedules to Julia code.

In [2]:
# Specify recognition factorization
q_m = RecognitionFactor(m)
q_w = RecognitionFactor(w)

# Generate schedules
schedule_q_m = variationalSchedule(q_m)
schedule_q_w = variationalSchedule(q_w)

# Convert schedules to executable Julia code
algo_q_m = messagePassingAlgorithm(schedule_q_m, m)
algo_q_m = replace(algo_q_m, "step!", "stepM!")
algo_q_w = messagePassingAlgorithm(schedule_q_w, w)
algo_q_w = replace(algo_q_w, "step!", "stepW!");

# Inspect the algorithm code
println(algo_q_m)
println("\n")
println(algo_q_w)

function stepM!(marginals::Dict, data::Dict)

messages = Array{Message}(20)

messages[1] = ruleVBGaussianMeanVariance3(ProbabilityDistribution(PointMass, m=0.0), ProbabilityDistribution(PointMass, m=100.0), nothing)
messages[2] = ruleVBGaussianMeanPrecision1(nothing, marginals[:w], ProbabilityDistribution(PointMass, m=data[:y][1]))
messages[3] = ruleVBGaussianMeanPrecision1(nothing, marginals[:w], ProbabilityDistribution(PointMass, m=data[:y][3]))
messages[4] = ruleVBGaussianMeanPrecision1(nothing, marginals[:w], ProbabilityDistribution(PointMass, m=data[:y][4]))
messages[5] = ruleVBGaussianMeanPrecision1(nothing, marginals[:w], ProbabilityDistribution(PointMass, m=data[:y][6]))
messages[6] = ruleVBGaussianMeanPrecision1(nothing, marginals[:w], ProbabilityDistribution(PointMass, m=data[:y][8]))
messages[7] = ruleVBGaussianMeanPrecision1(nothing, marginals[:w], ProbabilityDistribution(PointMass, m=data[:y][10]))
messages[8] = ruleVBGaussianMeanPrecision1(nothing, marginals[:w], Probabil

We can now iteratively execute the updates for each factor, and inspect the results.
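
For this conjugate model the two updates have a closed form. The following is a sketch, not taken from ForneyLab's generated code: it assumes the Gamma prior is parameterized by shape and rate, and the symbols $m_0, v_0, a_0, b_0$ (prior mean, prior variance, prior shape, prior rate) and $\mu, v, a, b$ (parameters of the recognition factors) are introduced here for illustration. In the model above, $m_0 = 0$, $v_0 = 100$ and $a_0 = b_0 = 0.01$.
\begin{align*}
    q(m) &= \mathcal{N}(m \,|\, \mu, v), & v^{-1} &= v_0^{-1} + n\,\mathbb{E}[w], & \mu &= v\left(v_0^{-1} m_0 + \mathbb{E}[w] \sum_{i=1}^{n} y_i\right) \\
    q(w) &= \mathrm{Gam}(w \,|\, a, b), & a &= a_0 + \frac{n}{2}, & b &= b_0 + \frac{1}{2} \sum_{i=1}^{n} \left((y_i - \mu)^2 + v\right)
\end{align*}
Here $\mathbb{E}[w] = a/b$ is the mean of the current $q(w)$, and $(y_i - \mu)^2 + v$ is $\mathbb{E}\left[(y_i - m)^2\right]$ under the current $q(m)$. Each update depends on the other factor only through these expectations, which is why the two steps are alternated.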

In [3]:
eval(parse(algo_q_m))
eval(parse(algo_q_w))

# Toy dataset
m_true = 3.0
w_true = 4.0
y = sqrt(1/w_true)*randn(n) + m_true
data = Dict(:y => y)

# Initial recognition distributions
marginals = Dict(:m => ProbabilityDistribution(Gaussian, m=0.0, v=100.0),
                 :w => ProbabilityDistribution(Gamma, a=0.01, b=0.01))

n_its = 50
for i = 1:n_its
    stepM!(marginals, data)
    stepW!(marginals, data)
end
;

In [4]:
# Inspect the results
println("True mean: $(m_true)")
println("True precision: $(w_true)")
println("Number of samples: $(n)")
println("Sample mean: $(round(mean(y),2))")
println("Sample precision: $(round(1/var(y),2))")
println("\n----- Estimation after $(n_its) VMP updates -----")
println("Mean estimate: $(round(mean(marginals[:m]),2)), with variance $(round(var(marginals[:m]),2))")
println("Precision estimate: $(round(mean(marginals[:w]),2)), with variance $(round(var(marginals[:w]),2))")

True mean: 3.0
True precision: 4.0
Number of samples: 10
Sample mean: 2.85
Sample precision: 2.5

----- Estimation after 50 VMP updates -----
Mean estimate: 2.85, with variance 0.04
Precision estimate: 2.49, with variance 1.24
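
As a rough cross-check on these estimates, the same mean-field updates can be written out by hand in plain Julia, without ForneyLab. The sketch below is not ForneyLab's generated code: `vmp_closed_form` and its argument names are hypothetical, it assumes the Gamma prior is parameterized by shape and rate, and it runs on a small fixed dataset rather than the random samples drawn above.

```julia
# Hand-written mean-field updates for a Gaussian with unknown mean and precision.
# Hypothetical helper for illustration; not part of ForneyLab.
# Priors (assumed parameterization): m ~ N(m0, v0), w ~ Gamma(shape=a0, rate=b0).
function vmp_closed_form(y, m0, v0, a0, b0, n_its)
    n = length(y)

    # Initialize the recognition factors at the priors
    mu, v = m0, v0
    a, b = a0, b0

    for it = 1:n_its
        # Update q(m) = N(mu, v), using E[w] = a/b under the current q(w)
        e_w = a / b
        v = 1 / (1/v0 + n*e_w)
        mu = v * (m0/v0 + e_w*sum(y))

        # Update q(w) = Gamma(a, b), using E[(y_i - m)^2] = (y_i - mu)^2 + v
        a = a0 + n/2
        b = b0 + 0.5*(sum(abs2, y .- mu) + n*v)
    end

    return mu, v, a, b
end

# Small fixed dataset (made up for this example)
y = [2.6, 3.1, 3.4, 2.9, 3.0]
mu, v, a, b = vmp_closed_form(y, 0.0, 100.0, 0.01, 0.01, 50)

println("Mean estimate: $(mu), with variance $(v)")
println("Precision estimate: $(a/b), with variance $(a/b^2)")
```

Calling `vmp_closed_form(y, 0.0, 100.0, 0.01, 0.01, n_its)` with the demo's own `y` should give values close to the mean and precision estimates reported above, provided the shape-rate assumption holds.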
