From daee7365a18a1bdb6c44e94d8811722d89ac1704 Mon Sep 17 00:00:00 2001
From: Jaakko Luttinen
Date: Mon, 16 Feb 2015 11:46:07 +0200
Subject: [PATCH] DOC: Write stochastic variational inference to doc

---
 doc/source/user_guide/advanced.rst | 84 +++++++++++++++++++++++++++++-
 1 file changed, 83 insertions(+), 1 deletion(-)

diff --git a/doc/source/user_guide/advanced.rst b/doc/source/user_guide/advanced.rst
index 3f9cda1f9..1ca0788b2 100644
--- a/doc/source/user_guide/advanced.rst
+++ b/doc/source/user_guide/advanced.rst
@@ -246,8 +246,90 @@ the next iteration with a new annealing value.
 
 
 Stochastic variational inference
 --------------------------------
 
+In stochastic variational inference :cite:`Hoffman:2013`, the idea is to use
+mini-batches of a large dataset to compute noisy gradients and to learn the VB
+distributions by stochastic gradient ascent. In order for it to be useful, the
+model must be such that it can be divided into "intermediate" and "global"
+variables: the number of intermediate variables grows with the data, whereas
+the number of global variables remains fixed. The global variables are learnt
+by the stochastic optimization.
+
+Denoting the data by :math:`Y=[Y_1, \ldots, Y_N]`, the intermediate variables
+by :math:`Z=[Z_1, \ldots, Z_N]` and the global variables by :math:`\theta`,
+the model needs to have the following structure:
+
+.. math::
+
+   p(Y, Z, \theta) = p(\theta) \prod^N_{n=1} p(Y_n|Z_n,\theta) p(Z_n|\theta)
+
+The algorithm consists of three steps, which are iterated: 1) a random
+mini-batch of the data is selected, 2) the corresponding intermediate
+variables are updated by using the normal VB update equations, and 3) the
+global variables are updated with (stochastic) gradient ascent as if there
+were as many replications of the mini-batch as needed to recover the original
+dataset size.
+
+The learning rate of the gradient ascent must satisfy:
+
+.. math::
+
+   \sum^\infty_{i=1} \alpha_i = \infty \qquad \text{and} \qquad
+   \sum^\infty_{i=1} \alpha^2_i < \infty,
+
+where :math:`i` is the iteration number. An example of a valid learning rate
+is :math:`\alpha_i = (\delta + i)^{-\gamma}`, where :math:`\delta \geq 0` is a
+delay and :math:`\gamma \in (0.5, 1]` is a forgetting rate.
+
+Stochastic variational inference is relatively easy to use in BayesPy. The
+idea is that the user constructs the model for the size of a mini-batch and
+specifies a multiplier for those plate axes that are replicated. For the PCA
+example, the mini-batch model can be constructed as follows. We decide to use
+``X`` as the intermediate variable and the other variables as global
+variables. The global variables ``alpha``, ``C`` and ``tau`` are constructed
+exactly as before (a reminder sketch is given below). The intermediate
+variable ``X`` is constructed as:
+
+>>> X = GaussianARD(0, 1,
+...                 shape=(D,),
+...                 plates=(1,5),
+...                 plates_multiplier=(1,20),
+...                 name='X')
+
+Note that the plates are ``(1,5)``, whereas they are ``(1,100)`` in the full
+model. Thus, we need to provide the plates multiplier ``(1,20)`` to define how
+the plates are replicated in order to recover the full dataset. The
+multipliers do not need to be integers; in this case, the latter plate axis is
+multiplied by :math:`100/5=20`.
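+
+For reference, the construction of those global variables can look roughly
+like the following. This is only a reminder sketch: the hyperparameter values
+and the number of observation rows ``M`` below are illustrative, so use the
+same values as in the full PCA model.
+
+>>> from bayespy.nodes import GaussianARD, Gamma  # repeated for completeness
+>>> alpha = Gamma(1e-5, 1e-5, plates=(D,), name='alpha')  # ARD precisions
+>>> C = GaussianARD(0, alpha, shape=(D,), plates=(M,1), name='C')  # loadings
+>>> tau = Gamma(1e-5, 1e-5, name='tau')  # observation noise precision
+
+Because these variables are global, their plates do not grow with the number
+of data samples, so they need no plate multipliers.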
+
+The remaining variables are defined as before:
+
+>>> F = Dot(C, X)
+>>> Y = GaussianARD(F, tau, name='Y')
+
+Note that the plates of ``F`` and ``Y`` also correspond to the size of the
+mini-batch, and that these nodes deduce the plate multipliers from their
+parents, so we do not need to specify the multipliers here explicitly
+(although it is OK to do so).
+
+Let us construct the inference engine for the new mini-batch model:
+
+>>> Q = VB(Y, C, X, alpha, tau)
+
+Use random initialization for ``C`` to break the symmetry between ``C`` and
+``X``:
+
+>>> C.initialize_from_random()
+
+Then, the stochastic variational inference algorithm could look as follows:
+
+>>> for n in range(200):
+...     subset = np.random.choice(100, 5)
+...     Y.observe(data[:,subset])
+...     Q.update(X)
+...     learning_rate = (n + 2.0) ** (-0.7)
+...     Q.gradient_step(C, alpha, tau, scale=learning_rate)
+
+The loop consists of three parts: 1) draw a random mini-batch of the data (5
+samples out of 100), 2) update the intermediate variable ``X``, and 3) update
+the global variables with a gradient-ascent step using a proper learning
+rate.
+
 Black-box variational inference
 -------------------------------
-
+TODO