DOC: Write stochastic variational inference to doc
jluttine committed Feb 16, 2015
1 parent 97cdd74 commit daee736
Showing 1 changed file with 83 additions and 1 deletion.
84 changes: 83 additions & 1 deletion doc/source/user_guide/advanced.rst
@@ -246,8 +246,90 @@ the next iteration with a new annealing value.
Stochastic variational inference
--------------------------------

In stochastic variational inference :cite:`Hoffman:2013`, the idea is to use
mini-batches of large datasets to compute noisy gradients and to learn the VB
distributions by stochastic gradient ascent. For the method to be applicable,
the variables of the model must be divisible into "intermediate" and "global"
variables. The number of intermediate variables grows with the size of the
data, whereas the number of global variables remains fixed. The global
variables are learnt by the stochastic optimization.

By denoting the data as :math:`Y=[Y_1, \ldots, Y_N]`, the intermediate variables
as :math:`Z=[Z_1, \ldots, Z_N]` and the global variables as :math:`\theta`, the
model needs to have the following structure:

.. math::

   p(Y, Z, \theta) &= p(\theta) \prod^N_{n=1} p(Y_n|Z_n,\theta) p(Z_n|\theta)

The algorithm iterates three steps: 1) a random mini-batch of the data is
selected, 2) the corresponding intermediate variables are updated using the
normal VB update equations, and 3) the global variables are updated with a
(stochastic) gradient step as if there were as many replications of the
mini-batch as needed to recover the original dataset size.
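
Roughly speaking, in terms of the natural parameters :math:`\lambda` of the
approximate posterior of the global variables, step 3 of :cite:`Hoffman:2013`
amounts to the convex update

.. math::

   \lambda_i = (1 - \alpha_i) \lambda_{i-1} + \alpha_i \hat{\lambda}_i,

where :math:`\hat{\lambda}_i` is the optimal natural parameter computed from
the replicated mini-batch and :math:`\alpha_i` is the learning rate discussed
below.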

The learning rate for the gradient ascent must satisfy:

.. math::

   \sum^\infty_{i=1} \alpha_i = \infty \qquad \text{and} \qquad
   \sum^\infty_{i=1} \alpha^2_i < \infty,

where :math:`i` is the iteration number. An example of a valid learning rate
sequence is :math:`\alpha_i = (\delta + i)^{-\gamma}`, where :math:`\delta \geq
0` is a delay and :math:`\gamma\in (0.5, 1]` is a forgetting rate.
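
As a small illustration (not part of the model code), the learning rate used
later in this section, ``(n + 2.0) ** (-0.7)``, is this schedule with an
assumed delay :math:`\delta = 2` and forgetting rate :math:`\gamma = 0.7`; the
helper ``learning_rate`` below is hypothetical:

>>> delta = 2.0    # delay
>>> gamma = 0.7    # forgetting rate, must lie in (0.5, 1]
>>> def learning_rate(i):
...     # decaying step size (delta + i)**(-gamma)
...     return (delta + i) ** (-gamma)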

Stochastic variational inference is relatively easy to use in BayesPy. The idea
is that the user creates a model for the size of a mini-batch and specifies a
multiplier for those plate axes that are replicated. For the PCA example, the
mini-batch model can be constructed as follows. We use ``X`` as the
intermediate variable and treat the other variables as global. The global
variables ``alpha``, ``C`` and ``tau`` are constructed identically to before
(see the sketch below).
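
For reference, the global variables of the PCA example might be constructed
roughly as follows; the dimensions ``M`` and ``D`` and the hyperparameter
values are assumptions for illustration, not copied from the earlier section:

>>> import numpy as np
>>> from bayespy.nodes import GaussianARD, Gamma, Dot
>>> M = 10   # number of observed dimensions (assumed)
>>> D = 3    # number of latent components (assumed)
>>> alpha = Gamma(1e-3, 1e-3, plates=(D,), name='alpha')
>>> C = GaussianARD(0, alpha, shape=(D,), plates=(M,1), name='C')
>>> tau = Gamma(1e-3, 1e-3, name='tau')
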

The intermediate variable ``X`` is constructed as:

>>> X = GaussianARD(0, 1,
...                 shape=(D,),
...                 plates=(1,5),
...                 plates_multiplier=(1,20),
...                 name='X')

Note that the plates are ``(1,5)`` whereas they are ``(1,100)`` in the full
model. Thus, we need to provide a plates multiplier ``(1,20)`` to define how
the plates are replicated to recover the full dataset. The multipliers do not
need to be integers; in this case, the latter plate axis is multiplied by
:math:`100/5=20`. The remaining variables are defined as before:

>>> F = Dot(C, X)
>>> Y = GaussianARD(F, tau, name='Y')

Note that the plates of ``Y`` and ``F`` also correspond to the size of the
mini-batch and that they deduce the plate multipliers from their parents, so we
do not need to specify the multiplier explicitly here (although it is fine to
do so).

Let us construct the inference engine for the new mini-batch model:

>>> Q = VB(Y, C, X, alpha, tau)

Use random initialization for ``C`` to break the symmetry in ``C`` and ``X``:

>>> C.initialize_from_random()

Then, the stochastic variational inference algorithm could look as follows:

>>> for n in range(200):
...     subset = np.random.choice(100, 5)                # draw a random mini-batch of 5 samples
...     Y.observe(data[:,subset])                        # observe the mini-batch
...     Q.update(X)                                      # update the intermediate variable
...     learning_rate = (n + 2.0) ** (-0.7)              # decaying learning rate (delta=2, gamma=0.7)
...     Q.gradient_step(C, alpha, tau, scale=learning_rate)

The loop consists of three parts: 1) draw a random mini-batch of the data (5
samples out of 100), 2) update the intermediate variable ``X``, and 3) update
the global variables with a gradient ascent step scaled by the learning rate.
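
As a possible variation (a sketch, not taken from the BayesPy documentation),
one can sweep through the data in disjoint mini-batches so that every sample is
visited once per epoch, while keeping the same update calls:

>>> iteration = 0
>>> for epoch in range(10):
...     order = np.random.permutation(100)              # shuffle the sample indices
...     for k in range(0, 100, 5):
...         subset = order[k:k+5]                       # disjoint mini-batch of 5 samples
...         Y.observe(data[:,subset])
...         Q.update(X)
...         learning_rate = (iteration + 2.0) ** (-0.7)
...         Q.gradient_step(C, alpha, tau, scale=learning_rate)
...         iteration += 1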


Black-box variational inference
-------------------------------


TODO
