Code Demo for Variational Message Passing
=====================================

## Background

Bayesian inference is one of the well-established foundations of machine learning. Since exact inference is intractable in most cases, approximate approaches such as *Monte Carlo methods*, *expectation propagation (EP)* and *variational message passing (VMP)* are used instead; EP and VMP perform iterative local inference on a decomposed (factorized) graph.

*Monte Carlo* methods can be computationally expensive, which often makes them impractical in real-world settings. *EP* and *VMP* are similar in that both optimize a proposed distribution $Q(X)$ to approximate the posterior distribution $P(X|D)$. They differ as follows:

* *Expectation propagation* optimizes the "inclusive" divergence $ \mathbf{KL}(P(X|D)\ ||\ Q(X)) $
    1. The approximation tries to summarize the entire posterior rather than any single mode. If the posterior is complex, with multiple conflicting solutions, this can lead to very broad approximations or non-convergence;
    2. The optimization assigns probability wherever the posterior is non-zero, which can put mass on regions where the posterior is very small (zero avoiding).

* *Variational message passing* optimizes the "exclusive" divergence $ \mathbf{KL}(Q(X)\ ||\ P(X|D)) $
    1. Like EM, it is guaranteed to converge (to a local optimum), because each update increases the evidence lower bound (ELBO);
    2. The optimization may fit one mode of the posterior over-confidently, but it is unlikely to assign probability to regions where the posterior is zero (zero forcing).

![Zero forcing and zero avoiding](zeroforcing.png)
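
The difference between the two divergences can be reproduced numerically. The sketch below is an illustration only, not part of the VMP framework in this notebook: it discretizes an assumed bimodal target $P$ on a grid and fits a single Gaussian $Q$ by brute-force grid search, once minimizing the inclusive $\mathbf{KL}(P\ ||\ Q)$ and once the exclusive $\mathbf{KL}(Q\ ||\ P)$. The mixture parameters, grid ranges and helper names (`normal_pdf`, `kl`, `fit_gaussian`) are arbitrary choices for this example.

```python
import numpy as np

# Grid on which both densities are discretized (an arbitrary choice for illustration)
xs = np.linspace(-8.0, 8.0, 2001)
dx = xs[1] - xs[0]


def normal_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


# Assumed bimodal target P: an equal mixture of two well-separated Gaussians
p = 0.5 * normal_pdf(xs, -3.0, 0.7) + 0.5 * normal_pdf(xs, 3.0, 0.7)
p /= p.sum() * dx


def kl(a, b):
    # Discretized KL(a || b); the tiny floor keeps log() finite where a density underflows to 0
    floor = 1e-300
    return np.sum(a * (np.log(a + floor) - np.log(b + floor))) * dx


def fit_gaussian(objective):
    # Brute-force search over the mean and standard deviation of a single Gaussian Q
    best = None
    for mean in np.linspace(-4.0, 4.0, 81):
        for std in np.linspace(0.3, 5.0, 48):
            q = normal_pdf(xs, mean, std)
            q /= q.sum() * dx
            score = objective(q)
            if best is None or score < best[0]:
                best = (score, mean, std)
    return best[1], best[2]


ep_mean, ep_std = fit_gaussian(lambda q: kl(p, q))    # inclusive KL(P || Q), as in EP
vmp_mean, vmp_std = fit_gaussian(lambda q: kl(q, p))  # exclusive KL(Q || P), as in VMP
print("inclusive: mean = {0:.2f}, std = {1:.2f}".format(ep_mean, ep_std))
print("exclusive: mean = {0:.2f}, std = {1:.2f}".format(vmp_mean, vmp_std))
```

The inclusive fit should end up centred between the two modes with a standard deviation wide enough to cover both, while the exclusive fit should lock onto one mode with roughly that mode's width. Which mode wins is an arbitrary tie-break here, because the target is symmetric.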

There are also approaches such as *Power EP* that are flexible enough to cover both cases; more information can be found in the references at the end of this notebook.

Mature inference libraries already implement the methods above, such as [Infer.NET](https://dotnet.github.io/infer/default.html), or [BayesPy](https://pypi.org/project/bayespy/) in Python. However, they have multiple complicated abstraction layers that hide the core inference algorithm, which makes them hard to read and understand.

This notebook focuses on miniature examples that demonstrate the basic *VMP* algorithm over Gaussian distributions. It can chain multiple Gaussian or Gamma random variables into a basic Bayesian graph, and solve for the posterior given the observations and prior distributions.

## Implementation

In [15]:
# only math and scipy.special are used below; numpy, torch and top-level scipy were unused
import math
import scipy.special as sps

*Variable* and *Distribution* lay the foundation of the miniature *VMP* framework. A *Variable* is a node in the Bayesian graph; it can be either a constant value (used as a hyperparameter of a prior distribution) or a random variable following a given distribution. *Distribution* is the base class for the different distribution types. The supported subclasses are members of the exponential family: the Gaussian and Gamma distributions.

*VMP* requires the distributions in the graph to be conjugate, that is, each prior distribution must have the same functional form as the corresponding posterior distribution.
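
As a concrete instance of conjugacy, a Gamma prior on the precision of Gaussian observations with a known mean yields a Gamma posterior. The sketch below computes that posterior twice: once with the textbook update, and once by adding natural-parameter vectors with respect to $[\tau, \ln\tau]$, which mirrors the arithmetic that `GammaDistribution.calcNatureStats` performs on its messages. The prior values and data are made up for illustration.

```python
# Hypothetical setting: Gamma(alpha, beta) prior (shape, rate) on a precision tau,
# and observations x_i ~ N(mu, 1/tau) with a known mean mu
alpha, beta, mu = 10.0, 1.0, 0.0
data = [1.0, -0.5, 2.0]

# Textbook conjugate update: the posterior is again a Gamma distribution
post_alpha = alpha + len(data) / 2
post_beta = beta + sum((x - mu) ** 2 for x in data) / 2

# Same update written as a sum of natural-parameter vectors w.r.t. u(tau) = [tau, ln tau]:
# the prior contributes [-beta, alpha - 1], each observation [-(x - mu)^2 / 2, 1/2]
eta = [-beta, alpha - 1]
for x in data:
    eta[0] += -(x - mu) ** 2 / 2
    eta[1] += 0.5
nat_alpha, nat_beta = eta[1] + 1, -eta[0]

print(post_alpha, post_beta)  # 11.5 3.625
print(nat_alpha, nat_beta)    # 11.5 3.625
```

Adding natural parameters is what makes the message-passing updates local: each child only has to report its own contribution vector.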

Message-passing inference has two stages, *forward* and *backward*: the forward stage propagates expectations along the direction of the DAG, and the backward stage passes messages from children back to their parents.

In [16]:
# A Variable is either a hyperparameter (constant) or a parameter following a distribution
class Variable:
    def __init__(self, _name:str, _dist):
        self.name = _name
        # self.expectedValue is either a constant value (hyperparameter) or a distribution's expected natural statistics vector
        self.expectedValue = None
        self.dist = None
        self.refDists = list()
        self.messages = list()
        self.observation = None
        self.version = 0
        if isinstance(_dist, Distribution):
            _dist.registerOutputVariable(self)
        elif isinstance(_dist, (int, float)):
            self.expectedValue = _dist
    
    def __str__(self):
        return "Variable Name: {0}, Distribution: {1}, Expectation: {2}".format(self.name, (type(self.dist).__name__ if self.dist else "point-mass"), self.getExpValue())
    
    def __repr__(self):
        return "<{0}:{1}={2}>".format(self.name, (type(self.dist).__name__ if self.dist else "point-mass"), self.getExpValue())
    
    # is the sink of the inference DAG
    def isSink(self):
        return self.observation is not None and not self.refDists
    
    # is the source of the inference DAG
    def isSource(self):
        return self.expectedValue is not None and self.dist is None
    
    # reset the inference state
    def resetState(self):
        self.version = 0
        self.messages.clear()
        if self.dist:
            self.expectedValue = None
    
    # provide observed value of the variable
    def observe(self, _evidence):
        if not self.dist:
            raise ValueError("You cannot observe constants")
        self.observation = _evidence
    
    # being depended on as input parameter for a distribution
    def reference(self, _dist):
        self.refDists.append(_dist)
    
    # get the expectation of the variable
    # (for random variables, only meaningful after inference)
    def getExpValue(self):
        return self.expectedValue[0] if isinstance(self.expectedValue, list) else self.expectedValue
    
    # update the variable from its own (underlying) distribution
    def _pullForward(self, _version):
        if self.dist and self.version < _version:
            self.dist._forward(_version)
            self.version = _version
        self.messages.clear()
            
        return self.version
    
    # update the variable from the distributions that reference it
    def _pullBackward(self, _version):
        if self.version < _version:
            for ref in self.refDists:
                ref._backward(_version)            
            if self.dist:
                self.version = _version

        return self.version

# Distribution base class for variational message-passing inference
class Distribution:
    def __init__(self, _inputNames, _outputName):
        self.params = dict()
        self.inputs = _inputNames
        self.output = _outputName
    
    # subclasses should compute the expected natural statistics vector for 3 cases:
    # 1. the variable was observed
    # 2. the variable was not observed, and there are no backward messages
    # 3. the variable was not observed, and there are backward messages
    def calcNatureStats(self, _messages):
        raise NotImplementedError("Subclasses should implement this!")
    
    # create the backward message for the parameter named _target;
    # subclasses should provide the corresponding message for each parameter type
    def calcBackwardMessage(self, _target):
        raise NotImplementedError("Subclasses should implement this!")
    
    # assign the output parameter to the distribution
    def registerOutputVariable(self, _var):
        self.params[self.output] = _var
        _var.dist = self
    
    # assign the input parameter to the distribution
    def registerInputVariable(self, _name, _var):
        if _name not in self.inputs:
            raise ValueError("parameter name {0} not found".format(_name))
        self.params[_name] = _var
        _var.reference(self)
    
    # forward passing messages to its output variable
    def _forward(self, _version):
        if self.params[self.output].version < _version:
            vers_ = 0
            for p in self.inputs:
                vers_ = max(vers_, self.params[p]._pullForward(_version))
            
            if self.params[self.output].version <= vers_:
                self.params[self.output].expectedValue = self.calcNatureStats(None)
        
            self.params[self.output].version = _version

    # backward passing messages to its parameter variables
    def _backward(self, _version):
        if self.params[self.output].version < _version:
            self.params[self.output]._pullBackward(_version)
            self.params[self.output].expectedValue = self.calcNatureStats(self.params[self.output].messages)
        
        for t in self.inputs:
            m = self.calcBackwardMessage(t)
            if m is not None:
                self.params[t].messages.append(m)

In this example, we implement two distributions, *Gaussian* and *Gamma*. The conditional probability of each given its parents (parameters) can easily be written in a *multi-linear* form with respect to the random variable and its parameters, as described in [1], which greatly simplifies the computation of the variational messages.

In particular, each distribution supported in the graph can be written in the form
$$ \ln P(Y|pa_y) = \phi_y (pa_y)^{\top} \mathbf{u}_y(Y) + f_y(Y) + g_y(pa_y) $$
where $Y$ is the random variable and $pa_y$ are its parent nodes. Here $\phi_y(pa_y)$ is the natural parameter vector of the distribution and $\mathbf{u}_y(Y)$ is its natural statistics vector.
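
For the two distributions used in this notebook, the multi-linear form can be written out explicitly (standard exponential-family identities; $\gamma$ denotes a precision, and the Gamma distribution uses shape $a$ and rate $b$):

$$
\ln \mathcal{N}(X \mid \mu, \gamma^{-1}) = \begin{bmatrix} \gamma\mu \\ -\gamma/2 \end{bmatrix}^{\top} \begin{bmatrix} X \\ X^2 \end{bmatrix} + \frac{1}{2}\left(\ln\gamma - \gamma\mu^2 - \ln 2\pi\right)
$$

$$
\ln \mathrm{Gam}(\gamma \mid a, b) = \begin{bmatrix} -b \\ a - 1 \end{bmatrix}^{\top} \begin{bmatrix} \gamma \\ \ln\gamma \end{bmatrix} + a\ln b - \ln\Gamma(a)
$$

Because the Gaussian log density is multi-linear, it can also be regrouped around the natural statistics of each parent. Taking expectations of the coefficient vectors under the factorized $Q$ gives the backward messages built in `NormalDistribution.calcBackwardMessage`:

$$
\ln \mathcal{N}(X \mid \mu, \gamma^{-1}) = \begin{bmatrix} \gamma X \\ -\gamma/2 \end{bmatrix}^{\top} \begin{bmatrix} \mu \\ \mu^2 \end{bmatrix} + \frac{1}{2}\left(\ln\gamma - \gamma X^2 - \ln 2\pi\right)
$$

$$
\ln \mathcal{N}(X \mid \mu, \gamma^{-1}) = \begin{bmatrix} -\frac{1}{2}\left(X^2 - 2X\mu + \mu^2\right) \\ 1/2 \end{bmatrix}^{\top} \begin{bmatrix} \gamma \\ \ln\gamma \end{bmatrix} - \frac{1}{2}\ln 2\pi
$$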

The distributions below must provide two values:
1. The expected natural statistics $\langle \mathbf{u}_y(Y) \rangle$ of the random variable, which are given by $-\frac{d\tilde{g}_y(\phi)}{d\phi}$, where $\tilde{g}_y(\phi)$ is $g_y$ re-parameterized as a function of the natural parameter $\phi$. For the Gaussian this is $[E(Y), E(Y^2)]$, and for the Gamma it is $[E(Y), E(\ln Y)]$.
2. The backward messages to their parameters, which requires rewriting the *multi-linear* log density as a function of the natural statistics of each parent node.
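
Item 1 can be checked numerically for the Gamma distribution. The sketch below is an illustration (the helper names `gamma_log_partition` and `expected_stats_numeric` are hypothetical): it writes the log-partition function $A(\phi) = -\tilde{g}(\phi)$ in terms of the natural parameters $\phi = [-b, a - 1]$, differentiates it with central finite differences, and compares the result with the closed form $[a/b,\ \psi(a) - \ln b]$ used in `GammaDistribution.calcNatureStats`.

```python
import math
from scipy.special import digamma, gammaln


def gamma_log_partition(eta):
    """A(eta) = -g~(eta) for a Gamma(a, b) density with eta = [-b, a - 1] and u = [x, ln x]."""
    a, b = eta[1] + 1.0, -eta[0]
    return gammaln(a) - a * math.log(b)


def expected_stats_numeric(eta, h=1e-6):
    """Approximate <u> = dA/deta = -dg~/deta with central finite differences."""
    grads = []
    for i in range(len(eta)):
        up, down = list(eta), list(eta)
        up[i] += h
        down[i] -= h
        grads.append((gamma_log_partition(up) - gamma_log_partition(down)) / (2 * h))
    return grads


a, b = 10.0, 1.0  # same shape and rate as the demo graph's Gamma prior
numeric = expected_stats_numeric([-b, a - 1.0])
closed_form = [a / b, digamma(a) - math.log(b)]
print(numeric)
print(closed_form)
```

The two vectors should agree to several decimal places, which is the identity that lets each distribution report its expectations without any numerical integration.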

In [17]:
# Gamma distribution whose parameters (alpha, beta) must be constants
class GammaDistribution(Distribution):    
    _ParamAlpha = "alpha"
    _ParamBeta = "beta"
    _ParamGamma = "gamma"
    
    def __init__(self, _alpha, _beta):
        super().__init__([GammaDistribution._ParamAlpha, GammaDistribution._ParamBeta], GammaDistribution._ParamGamma)
        if _alpha.dist:
            raise ValueError("parameter alpha must be point-mass distribution")
        self.registerInputVariable(GammaDistribution._ParamAlpha, _alpha)
        if _beta.dist:
            raise ValueError("parameter beta must be point-mass distribution")
        self.registerInputVariable(GammaDistribution._ParamBeta, _beta)
        
    def calcNatureStats(self, _messages):
        if self.params[self.output].observation is not None:
            val_ = self.params[self.output].observation            
            return [val_, math.log(val_)]
        elif _messages:
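            # natural parameters w.r.t. [x, ln x]: prior [-beta, alpha - 1] plus child messages
            a = -self.params[GammaDistribution._ParamBeta].expectedValue + sum(map(lambda x: x[0], _messages))
            b = self.params[GammaDistribution._ParamAlpha].expectedValue - 1 + sum(map(lambda x: x[1], _messages))
            # convert back to the posterior shape (a) and rate (b)
            a, b = b + 1, -a
            # E(x) = a/b, E(ln(x)) = digamma(a) - ln(b)
            return [a / b, sps.digamma(a) - math.log(b)]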
        else:
            a = self.params[GammaDistribution._ParamAlpha].expectedValue
            b = self.params[GammaDistribution._ParamBeta].expectedValue            
            # E(x) = a/b, E(ln(x)) = digamma(a) - ln(b)
            return [a / b, sps.digamma(a) - math.log(b)]
    
    # No backward messages: alpha and beta are constants (non-Bayesian)
    def calcBackwardMessage(self, _target):
        return None

# Normal distribution whose mean and precision may be random variables (Bayesian parameters)
class NormalDistribution(Distribution):
    _ParamMu = "mu"
    _ParamGamma = "gm"
    _ParamX = "x"
    
    def __init__(self, _mu, _gamma):
        super().__init__([NormalDistribution._ParamMu, NormalDistribution._ParamGamma], NormalDistribution._ParamX)
        if _mu.dist and not isinstance(_mu.dist, NormalDistribution):
            raise ValueError("parameter mu must be {0} or point-mass distribution".format(NormalDistribution.__name__))
        self.registerInputVariable(NormalDistribution._ParamMu, _mu)
        if _gamma.dist and not isinstance(_gamma.dist, GammaDistribution):
            raise ValueError("parameter gamma must be {0} or point-mass distribution".format(GammaDistribution.__name__))
        self.registerInputVariable(NormalDistribution._ParamGamma, _gamma)
    
    def calcNatureStats(self, _messages):
        if self.params[self.output].observation is not None:
            val_ = self.params[self.output].observation
            return [val_, val_ * val_]        
        elif not _messages:
            # no child messages: Q(x) = N(E(mu), 1/E(gamma)), whether mu is constant or random
            # E(x^2) = E(x)^2 + Var(x), with Var(x) = 1/E(gamma)
            mu_ = self.params[NormalDistribution._ParamMu].getExpValue()
            phi2_ = 1 / self.params[NormalDistribution._ParamGamma].getExpValue()
            return [mu_, mu_ * mu_ + phi2_]
        else:
            # E(x) = mu = -a/(2b)
            # mu^2 = a^2/(4b^2)
            # gamma = -2b
            # E(x^2) = mu^2 + 1/gamma = (a^2-2b)/(4b^2)
            a = sum(map(lambda x: x[0], _messages)) + self.params[NormalDistribution._ParamMu].getExpValue() * self.params[NormalDistribution._ParamGamma].getExpValue()
            b = sum(map(lambda x: x[1], _messages)) - self.params[NormalDistribution._ParamGamma].getExpValue() / 2
            return [-a / (2 * b), (a * a - 2 * b) / (4 * b * b)]
    
    def calcBackwardMessage(self, _target):
        if _target == NormalDistribution._ParamMu:            
            # case of non-bayesian parameter
            if not self.params[NormalDistribution._ParamMu].dist:
                return None
            # parameter is normal distribution
            return [self.params[NormalDistribution._ParamGamma].getExpValue() * self.params[NormalDistribution._ParamX].expectedValue[0], -self.params[NormalDistribution._ParamGamma].getExpValue() / 2]
        
        elif _target == NormalDistribution._ParamGamma:
            # case of non-bayesian parameter            
            if not self.params[NormalDistribution._ParamGamma].dist:
                return None
            # parameter is gamma distribution
            mu2Exp_ = self.params[NormalDistribution._ParamMu].expectedValue[1] if isinstance(self.params[NormalDistribution._ParamMu].expectedValue, list) else self.params[NormalDistribution._ParamMu].expectedValue * self.params[NormalDistribution._ParamMu].expectedValue
            return [-(self.params[NormalDistribution._ParamX].expectedValue[1] + mu2Exp_) / 2 + self.params[NormalDistribution._ParamX].expectedValue[0] * self.params[NormalDistribution._ParamMu].getExpValue(), 0.5]
        else:
            raise ValueError("parameter {0} not found".format(_target))            
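
To see that the message branch of `NormalDistribution.calcNatureStats` is the usual conjugate Gaussian update, the standalone sketch below repeats its arithmetic for the first iteration of the demo graph further down, assuming the Gamma precision of $Y$ is still at its prior mean $\alpha/\beta = 10$ and each observation has precision 5. The helper `natural_to_moments` is hypothetical; it just converts natural parameters $(a, b)$ with respect to $[x, x^2]$ into $[E(x), E(x^2)]$.

```python
def natural_to_moments(a, b):
    """Map Gaussian natural parameters (a, b) w.r.t. u = [x, x^2] to [E(x), E(x^2)]."""
    mean = -a / (2 * b)
    precision = -2 * b
    return [mean, mean * mean + 1 / precision]


# Values taken from the demo graph; <gamma_Y> is assumed to be at its prior mean 10/1
m0, gamma0 = -10.0, 10.0   # prior mean of Y and current <gamma_Y>
gamma_x = 5.0              # fixed precision of each observation X_i
obs = [1.0, 5.0]

# Prior contributes [gamma0 * m0, -gamma0 / 2]; each child X_i sends [gamma_x * x_i, -gamma_x / 2]
a = gamma0 * m0 + sum(gamma_x * x for x in obs)
b = -gamma0 / 2 - len(obs) * gamma_x / 2
ex, ex2 = natural_to_moments(a, b)

# The closed form used in NormalDistribution.calcNatureStats
ex2_code = (a * a - 2 * b) / (4 * b * b)

# Textbook conjugate Gaussian update for comparison
post_prec = gamma0 + len(obs) * gamma_x
post_mean = (gamma0 * m0 + gamma_x * sum(obs)) / post_prec
print(ex, post_mean)  # both -3.5
print(ex2, ex2_code, post_mean ** 2 + 1 / post_prec)
```

The posterior mean of $-3.5$ is the same value the demo run reports for its first iteration.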

The following is a simple inference function that computes the posterior expectation of a hidden variable, given the graph structure and the observations.

It alternates *forward* and *backward* message-passing passes until the change in the target variable's expectation between two consecutive iterations is at most `_eps`.

In [32]:
# Run inference on a graph (list of chained variables)
# for a query (the variable whose posterior is targeted);
# returns the posterior expectation of the query variable
def inference(_graph, _query, _eps = 1E-5):
    if _eps <= 0:
        raise ValueError("parameter eps must be greater than 0")
    sinks_ = list(filter(Variable.isSink, _graph))
    sources_ = list(filter(Variable.isSource, _graph))
    for p in _graph:
        p.resetState()
    delta_ = _eps + 1
    i = 0
    lastVal_ = None
    while delta_ > _eps:
        i += 1
        for s in sinks_:
            s._pullForward(i)
        i += 1
        for s in sources_:
            s._pullBackward(i)
        if lastVal_ is not None:
            delta_ = abs(_query.getExpValue() - lastVal_)
        lastVal_ = _query.getExpValue()
        print("iter {0}: exp = {1}".format(i // 2, lastVal_))
    return lastVal_

The following code builds a simple demo chain graph. $y$ is a hidden Gaussian variable with mean $-10$ whose *precision* follows a Gamma distribution (shape 10, rate 1). $y$ in turn serves as the *mean* parameter of the Gaussian variable $x$, whose precision is fixed at 5. $x$ has two observations, $x_1 = 1$ and $x_2 = 5$, and the inference target is $E(Y|X = x_1, x_2)$.
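
Written out, the demo graph and the coordinate-wise updates that VMP performs on it are as follows (a standard derivation for this Normal-Gamma chain, with $\langle\cdot\rangle$ the expectation under the current $Q$ and $\mathcal{N}(\cdot, \cdot)$ parameterized by mean and variance):

$$
\gamma_Y \sim \mathrm{Gam}(10, 1), \qquad Y \sim \mathcal{N}\left(-10,\ \gamma_Y^{-1}\right), \qquad X_i \sim \mathcal{N}\left(Y,\ 5^{-1}\right),\ i = 1, 2
$$

$$
Q(Y) = \mathcal{N}\left(\frac{-10\langle\gamma_Y\rangle + 5(x_1 + x_2)}{\langle\gamma_Y\rangle + 10},\ \frac{1}{\langle\gamma_Y\rangle + 10}\right),
\qquad
Q(\gamma_Y) = \mathrm{Gam}\left(10 + \tfrac{1}{2},\ 1 + \tfrac{1}{2}\left(\langle Y^2\rangle + 20\langle Y\rangle + 100\right)\right)
$$

With $\langle\gamma_Y\rangle$ at its prior mean $10$, the first update gives $\langle Y \rangle = (-100 + 30)/20 = -3.5$.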

In [33]:
a = Variable("Alpha", 10)
b = Variable("Beta", 1)
g1 = Variable("GammaY", GammaDistribution(a, b))
g2 = Variable("GammaX", 5.0)
m = Variable("Mu", -10)
y = Variable("Y", NormalDistribution(m, g1))
x1 = Variable("X1", NormalDistribution(y, g2))
x2 = Variable("X2", NormalDistribution(y, g2))
x1.observe(1)
x2.observe(5)
graph_ = [a, b, g1, g2, m, y, x1, x2]

In [34]:
inference(graph_, y)

iter 1: exp = -3.5
iter 2: exp = 2.4116379310344827
iter 3: exp = 2.8274818617172435
iter 4: exp = 2.8382154765027456
iter 5: exp = 2.8384792504300926
iter 6: exp = 2.8384857245040926


2.8384857245040926
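
As an independent cross-check (a hand-written sketch, not part of the framework above), the two updates for this model can be iterated directly in plain Python. The function name `vmp_demo` and its defaults are illustrative: the defaults hard-code the demo graph's constants, the loop starts $\langle\gamma_Y\rangle$ at its prior mean $\alpha/\beta$ as the first forward pass does, updates $Q(Y)$ before $Q(\gamma_Y)$ in each iteration, and uses the same stopping rule as `inference`.

```python
def vmp_demo(obs=(1.0, 5.0), m=-10.0, alpha=10.0, beta=1.0, gamma_x=5.0, eps=1e-5):
    """Coordinate-ascent VMP for Y ~ N(m, 1/tau), tau ~ Gam(alpha, beta), x_i ~ N(Y, 1/gamma_x)."""
    e_tau = alpha / beta  # start Q(tau) at its prior mean
    history = []
    while True:
        # Update Q(Y) given <tau>: Gaussian posterior precision and mean
        prec = e_tau + len(obs) * gamma_x
        e_y = (e_tau * m + gamma_x * sum(obs)) / prec
        e_y2 = e_y * e_y + 1.0 / prec
        # Update Q(tau) given <Y> and <Y^2>: Gamma posterior shape and rate
        shape = alpha + 0.5
        rate = beta + 0.5 * (e_y2 - 2.0 * m * e_y + m * m)
        e_tau = shape / rate
        history.append(e_y)
        print("iter {0}: exp = {1}".format(len(history), e_y))
        # Stop once the change in E(Y) is at most eps, as in inference()
        if len(history) > 1 and abs(history[-1] - history[-2]) <= eps:
            return e_y, history


result, history = vmp_demo()
```

Under these assumptions it should print the same sequence of expectations as the notebook run above, up to floating-point rounding in the last digits.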

## Reference
[1] [Winn, J., & Bishop, C. M. (2005). Variational message passing. Journal of Machine Learning Research, 6, 661-694](https://www.microsoft.com/en-us/research/publication/variational-message-passing/)

[2] [Minka, T. P. (2005). Divergence measures and message passing. Microsoft Research Technical Report, (MSR-TR-2005-173)](https://www.seas.harvard.edu/courses/cs281/papers/minka-divergence.pdf)

[3] [Working with different inference algorithms](https://dotnet.github.io/infer/userguide/Working%20with%20different%20inference%20algorithms.html)