# Improving Variational Inference

## Recap: Variational Inference

The key idea in variational inference is to approximate the true posterior with a variational distribution. This turns the inference problem into an optimization problem: minimize a measure of dissimilarity between the two distributions, such as the KL divergence.

<p align="center"> 
<img src="images/vi-recap.png" alt="Variational Inference" width="350"/>
</p>

In the diagram above, $x$ is the data (e.g. an image or a sound clip), assumed to have been generated by an unknown latent factor $z$. Part of the problem is to learn $z$, and to do that we need to approximate the true posterior with a tractable distribution.

Let $q_{x}(z)$ be the approximation of the true posterior $p(z | x)$. We choose the KL divergence as the measure of distance between these two distributions and minimize it.

$\begin{aligned} D_{\mathrm{KL}}\left[q_{x}(z) \| p(z | x)\right] &=\mathbb{E}_{z \sim q_{x}(z)}\left[\log q_{x}(z)-\log p(z | x)\right] \\ &=\mathbb{E}_{z \sim q_{x}(z)}\left[\log q_{x}(z)-\log \frac{p(z, x)}{p(x)}\right] \\ &=\mathbb{E}_{z \sim q_{x}(z)}\left[\log q_{x}(z)-\log p(z)-\log p(x | z)+\log p(x)\right] \\ &=\underbrace{\mathbb{E}_{z \sim q_{x}(z)}\left[\log q_{x}(z)-\log p(z)-\log p(x | z)\right]}_{\text {Only this part depends on } q_{x}}+\log p(x) \end{aligned}$

Using simple algebra and Bayes' rule, we arrive at the expression above. Since $\log p(x)$ does not depend on $q_{x}$, minimizing the KL divergence is equivalent to minimizing the expectation alone. That expectation can be approximated with stochastic samples from $q_{x}(z)$, and each term inside it can be computed in O(1) time per sample.
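As a minimal sketch of how the expectation can be estimated from samples, the snippet below uses a toy model that is an assumption for illustration only and does not come from these notes: $p(z)=\mathcal{N}(0,1)$, $p(x|z)=\mathcal{N}(z,1)$, and a Gaussian $q_{x}(z)$. The helper names `log_normal` and `kl_objective` are hypothetical. For this model the exact posterior is $\mathcal{N}(x/2, 1/2)$ and $p(x)=\mathcal{N}(0,2)$, so the estimate can be compared against $-\log p(x)$.

```python
import math
import random

def log_normal(v, mean, var):
    """Log density of a univariate Gaussian N(mean, var) evaluated at v."""
    return -0.5 * (math.log(2 * math.pi * var) + (v - mean) ** 2 / var)

def kl_objective(x, q_mean, q_var, num_samples=1000, seed=0):
    """Monte Carlo estimate of E_q[log q(z) - log p(z) - log p(x|z)].

    Assumed toy model (illustrative only): p(z) = N(0, 1), p(x|z) = N(z, 1),
    and q(z) = N(q_mean, q_var).
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))  # sample z ~ q
        total += (log_normal(z, q_mean, q_var)   # log q(z)
                  - log_normal(z, 0.0, 1.0)      # log p(z)
                  - log_normal(x, z, 1.0))       # log p(x|z)
    return total / num_samples

x = 1.5
# For this toy model the exact posterior is N(x/2, 1/2) and p(x) = N(0, 2).
neg_log_px = -log_normal(x, 0.0, 2.0)
at_posterior = kl_objective(x, q_mean=x / 2, q_var=0.5)
mismatched = kl_objective(x, q_mean=-1.0, q_var=2.0)
print(neg_log_px, at_posterior, mismatched)
```

When $q_{x}$ equals the exact posterior, the KL divergence is zero and the objective equals $-\log p(x)$; a mismatched $q_{x}$ gives a larger value, and the gap is the KL divergence.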

## Variational Inference as Importance Sampling

To train a latent variable model by maximum likelihood, we need the marginal likelihood of any given $x$: $p(x)=\sum_{z} p(z, x)$. Assuming a discrete latent code $z$, this means summing the joint probability $p(z, x)$ over all values of $z$. That becomes intractable when $z$ has an exponential number of possible values (e.g. $2^{n}$ values for a code of $n$ binary variables).
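To make the cost of exact marginalization concrete, here is a small sketch that enumerates every binary code $z \in \{0,1\}^{n}$. The joint distribution in `joint` is a hypothetical toy choice (uniform prior over codes, a logistic likelihood for a binary $x$), picked only so that the sum is easy to check; it is not a model from these notes.

```python
import itertools
import math

def joint(z, x):
    """Hypothetical toy joint p(z, x) for a binary code z and binary x.

    Prior: each bit of z is an independent fair coin.
    Likelihood: p(x = 1 | z) = sigmoid(sum(z) - n / 2).
    """
    n = len(z)
    prior = 0.5 ** n
    p_on = 1.0 / (1.0 + math.exp(-(sum(z) - n / 2)))
    return prior * (p_on if x == 1 else 1.0 - p_on)

def marginal(x, n):
    """Exact p(x) by brute-force enumeration; also counts the summed terms."""
    total, terms = 0.0, 0
    for z in itertools.product((0, 1), repeat=n):
        total += joint(z, x)
        terms += 1
    return total, terms

p1, terms = marginal(1, n=10)
p0, _ = marginal(0, n=10)
print(p1, p0, terms)
```

The number of terms doubles with every extra latent bit, which is why the sum is only feasible for very small $n$.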

But here is an intuition from empirical observations: for any given $x$, $p(x, z)$ typically has its probability mass concentrated in very few places. For example, if the latent code $z$ describes something semantic, such as whether there is a car in the image, then $z$ will be turned on only for the tiny fraction of images in which a car is present.

Hence in VAEs and most other meaningful latent variable models, the joint distribution is very peaked in $z$ for any given $x$. This implies that in a high-dimensional space we don't need to enumerate all the possibilities of $z$. We can place emphasis on the values of $z$ that are more likely under that $x$. This is one way of looking at _variational inference as Importance Sampling_.

**Intuition:** The variational distribution $q(z | x)$ samples the high-density regions of $p(z, x)$.

Have a look at the following derivation:

$\begin{aligned} \log p(x) &=\log \sum_{z} p(z, x) \\ &=\log \sum_{z} q(z | x) \frac{p(z, x)}{q(z | x)} \\ &=\log \mathbb{E}_{z \sim q(z | x)}\left[\frac{p(z, x)}{q(z | x)}\right] \\ & \geq \mathbb{E}_{z \sim q(z | x)}\left[\log \frac{p(z, x)}{q(z | x)}\right] \end{aligned}$

A few things to note here:

* The expression in the third line is very close to the Variational Lower Bound (VLB). In fact, moving the $\log$ inside the expectation by applying Jensen's inequality, as in the fourth line, gives the VLB itself.
* The third line of the derivation implies that if we draw many samples from $q(z | x)$, average their importance weights $\frac{p(z, x)}{q(z | x)}$ and then take the $\log$, we approach $\log p(x)$.
* The fourth line of the derivation gives the VLB. Because the $\log$ is now inside the expectation, even _one_ sample $z$ gives an unbiased estimate of the lower bound.
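The third and fourth lines of the derivation can be checked numerically on a tiny discrete example. The sketch below assumes a 4-state latent $z$ with made-up values for $p(z, x)$ at one fixed $x$ and a made-up proposal $q(z | x)$. The names `joint`, `q` and `importance_estimate` are hypothetical and chosen only for illustration.

```python
import math
import random

# Hypothetical values of p(z, x) for one fixed x and a 4-state latent z.
joint = [0.20, 0.02, 0.005, 0.001]
# Hypothetical proposal q(z | x); it must be positive wherever joint is.
q = [0.70, 0.15, 0.10, 0.05]

# Exact log p(x), obtained by summing the joint over all z.
log_px = math.log(sum(joint))

# Fourth line of the derivation: the variational lower bound, computed exactly.
vlb = sum(qz * math.log(pz / qz) for pz, qz in zip(joint, q))

def importance_estimate(num_samples, seed=0):
    """Third line: log of the average importance weight p(z, x) / q(z | x)."""
    rng = random.Random(seed)
    zs = rng.choices(range(len(q)), weights=q, k=num_samples)
    mean_weight = sum(joint[z] / q[z] for z in zs) / num_samples
    return math.log(mean_weight)

estimate = importance_estimate(20000)
print(log_px, vlb, estimate)
```

With many samples the importance-sampling estimate lands close to the exact $\log p(x)$, while the VLB stays strictly below it because this $q(z | x)$ is not proportional to $p(z, x)$.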


So the derivation above shows that Variational Inference is one way to do Importance Sampling, and that is exactly what we do in a VAE _with one sample_ (the fourth line of the derivation). **The question is: if we use multiple samples for importance sampling in a VAE, will that improve the lower bound?**

The answer is **Yes!** Burda et al. show how to do it in the paper [Importance Weighted Autoencoders](https://arxiv.org/abs/1509.00519) by Yuri Burda, Roger Grosse & Ruslan Salakhutdinov.
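To see the effect of using more samples, the sketch below computes the $k$-sample bound $\mathbb{E}_{z_{1..k} \sim q}\left[\log \frac{1}{k} \sum_{i} \frac{p(z_{i}, x)}{q(z_{i} | x)}\right]$ exactly, by enumerating all $k$-tuples of a small discrete latent. The tables `joint` and `q` are made-up numbers for a single fixed $x$, and `k_sample_bound` is a hypothetical helper. This is a toy check of the bound, not the training procedure from the paper.

```python
import itertools
import math

joint = [0.20, 0.02, 0.005]  # hypothetical p(z, x) for one fixed x
q = [0.5, 0.3, 0.2]          # hypothetical q(z | x)

def k_sample_bound(k):
    """Exact E[log (1/k) * sum_i p(z_i, x) / q(z_i | x)], z_1..z_k i.i.d. from q."""
    weights = [p / qz for p, qz in zip(joint, q)]
    bound = 0.0
    for zs in itertools.product(range(len(q)), repeat=k):
        prob = math.prod(q[z] for z in zs)  # probability of drawing this tuple
        bound += prob * math.log(sum(weights[z] for z in zs) / k)
    return bound

log_px = math.log(sum(joint))
bounds = [k_sample_bound(k) for k in (1, 2, 4, 8)]
print(bounds, log_px)
```

With $k = 1$ this is the usual single-sample VLB; in this example the bound increases with $k$ and stays below $\log p(x)$.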