In [None]:
import torch


Given the following distribution:

$P(z) = Exp(\lambda=1) = \lambda \exp(-\lambda z)$

$P(x|z) = \textrm{Normal}(\mu=z, \sigma^2=1)$

We want to model $P(z|x)$, which is commonly known as the posterior distribution.

In principle, this quantity can be computed using Bayes' rule:

$P(z|x) = \frac{P(x|z)P(z)}{P(x)} = \frac{P(x|z)P(z)}{\int P(x|z)P(z) dz}$

We note that the integral in the denominator is in general intractable, so instead we choose a surrogate distribution $q_{\theta}(z)$ and try to get it as close as possible to $P(z|x)$. We attempt to find the $\theta$ such that the probability of the observed $x$ is maximized. That is:

$$\arg\min_{q_{\theta}} -E_{z \sim q_{\theta}(z)}[\log P(x)]$$

We note that:
\begin{align}
\log P(x)  = E_{z \sim q_{\theta}(z)}[\log P(x)] &= E_{z \sim q_{\theta}(z)}\left[\log \left(P(x) \frac{q_{\theta}(z)}{q_{\theta}(z)}\right)\right] \\
&= E_{z \sim q_{\theta}(z)}\left[\log \left(\frac{P(x, z)}{P(z|x)} \frac{q_{\theta}(z)}{q_{\theta}(z)}\right)\right] \\
&= E_{z \sim q_{\theta}(z)}\left[\log \left(\frac{P(x, z)}{q_{\theta}(z)} \frac{q_{\theta}(z)}{P(z|x)}\right)\right] \\
&= E_{z \sim q_{\theta}(z)}\left[\log \left(\frac{P(x, z)}{q_{\theta}(z)}\right) + \log \left( \frac{q_{\theta}(z)}{P(z|x)}\right)\right] \\
&= E_{z \sim q_{\theta}(z)}\left[\log \left(\frac{P(x, z)}{q_{\theta}(z)}\right)\right] + E_{z \sim q_{\theta}(z)}\left[ \log \left( \frac{q_{\theta}(z)}{P(z|x)}\right)\right] \\
&= \mathcal{L}_{\theta} + D_{KL}(q_{\theta}(z) || P(z|x)) \\
\log P(x) &\geq \mathcal{L}_{\theta}
\end{align}

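Since the KL divergence is non-negative, $\mathcal{L}_{\theta}$ is a lower bound on $\log P(x)$. It is known as the evidence lower bound, or the ELBO. Because $\log P(x)$ does not depend on $\theta$, maximizing $\mathcal{L}_{\theta}$ is equivalent to minimizing $D_{KL}(q_{\theta}(z) || P(z|x))$.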
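As a sanity check of the decomposition $\log P(x) = \mathcal{L}_{\theta} + D_{KL}$, the sketch below approximates all three quantities with trapezoidal sums for a single observation. It is only an illustration: the surrogate $q_{\theta}(z) = \theta \exp(-\theta z)$, the values `x=1.5` and `theta=0.8`, the grid bounds, and the helper names are assumptions made for this example, and only the Python standard library is used.

```python
import math


def log_normal_pdf(x, mean):
    """Log density of Normal(mean, 1) evaluated at x."""
    return -0.5 * (x - mean) ** 2 - 0.5 * math.log(2.0 * math.pi)


def decomposition_on_grid(x, theta, lam=1.0, z_max=40.0, n=40_000):
    """Return (log_evidence, elbo, kl) from trapezoidal sums over [0, z_max]."""
    h = z_max / n
    grid = [i * h for i in range(n + 1)]
    weights = [h] * (n + 1)
    weights[0] = weights[-1] = 0.5 * h  # trapezoidal end-point weights

    # log P(x, z) = log P(x|z) + log P(z), with P(z) = lam * exp(-lam * z)
    log_joint = [log_normal_pdf(x, z) + math.log(lam) - lam * z for z in grid]
    log_evidence = math.log(
        sum(w * math.exp(lj) for w, lj in zip(weights, log_joint))
    )

    # Assumed surrogate for this example: q_theta(z) = theta * exp(-theta * z)
    log_q = [math.log(theta) - theta * z for z in grid]

    # ELBO = E_q[log P(x, z) - log q(z)]
    elbo = sum(
        w * math.exp(lq) * (lj - lq)
        for w, lq, lj in zip(weights, log_q, log_joint)
    )
    # KL(q || P(z|x)) with log P(z|x) = log P(x, z) - log P(x)
    kl = sum(
        w * math.exp(lq) * (lq - (lj - log_evidence))
        for w, lq, lj in zip(weights, log_q, log_joint)
    )
    return log_evidence, elbo, kl


log_evidence, elbo, kl = decomposition_on_grid(x=1.5, theta=0.8)
print(f"log P(x) = {log_evidence:.6f}")
print(f"ELBO     = {elbo:.6f}")
print(f"KL       = {kl:.6f}")
```

Up to discretization error, the printed ELBO and KL should add up to the printed $\log P(x)$, and the KL term should be positive. Quadrature is only feasible here because $z$ is one-dimensional.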


Assume that we have observed a dataset $\mathcal{D} = \{x_i\}_{i=1}^{N}$ whose points are conditionally independent given $z$. Then, $P(\mathcal{D}|z) = \prod_{i=1}^{N} P(x_i|z)$.

The ELBO is then:

\begin{align}
\mathcal{L}_{\theta} &= E_{z\sim q_{\theta}(z)}[\log(P(\mathcal{D}, z)) - \log(q_{\theta}(z))] \\
&= E_{z\sim q_{\theta}(z)}[\log(P(\mathcal{D}|z)P(z)) - \log q_{\theta}(z)] \\
&= E_{z\sim q_{\theta}(z)}[\log P(\mathcal{D}|z) + \log P(z) - \log q_{\theta}(z)] \\
&= E_{z\sim q_{\theta}(z)}\left[\sum_{i=1}^{N}\log P(x_i|z) + \log P(z) - \log q_{\theta}(z)\right] \\
&= E_{z\sim q_{\theta}(z)}\left[-\frac{1}{2}\sum_{i=1}^{N} (x_i - z)^2 - \frac{N}{2} \log(2\pi) + \log \lambda - \lambda z - \log q_{\theta}(z) \right] \\
&= E_{z\sim q_{\theta}(z)}\left[-\frac{1}{2}\sum_{i=1}^{N} (x_i^2 -2 x_i z + z^2) - \frac{N}{2} \log(2\pi) + \log \lambda - \lambda z - \log q_{\theta}(z) \right]
\end{align}

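This equation is true for all $q_{\theta}(z)$. Now we assume that $q_{\theta}(z) = \textrm{Exp}(\theta) = \theta \exp(-\theta z)$, so that $\log q_{\theta}(z) = \log \theta - \theta z$, $E_{z\sim q_{\theta}(z)}[z] = \frac{1}{\theta}$, and $E_{z\sim q_{\theta}(z)}[z^2] = \frac{2}{\theta^2}$.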

\begin{align}
\mathcal{L}_{\theta} &= E_{z\sim q_{\theta}(z)}\left[-\frac{1}{2}\sum_{i=1}^{N} (x_i^2 -2 x_i z + z^2) - \frac{N}{2} \log(2\pi) + \log \lambda - \lambda z - \log \theta + \theta z \right] \\
&= -\frac{1}{2}\sum_{i=1}^{N} E_{z\sim q_{\theta}(z)}\left[x_i^2 -2 x_i z + z^2\right] + E_{z\sim q_{\theta}(z)}\left[- \frac{N}{2} \log(2\pi) + \log \lambda - \lambda z - \log \theta + \theta z \right] \\
&= -\frac{1}{2}\sum_{i=1}^{N} \left(x_i^2 - \frac{2}{\theta}x_i + \frac{2}{\theta^2}\right) - \frac{N}{2} \log(2\pi) + \log\lambda - \frac{\lambda}{\theta} - \log\theta + 1 \\
&= -\frac{1}{2}\sum_{i=1}^{N} x_i^2 + \frac{1}{\theta}\sum_{i=1}^N x_i - \frac{N}{\theta^2} - \frac{N}{2} \log(2\pi) + \log\lambda - \frac{\lambda}{\theta} - \log\theta + 1 \\
&= \frac{1}{\theta} \left(\sum_{i=1}^{N} (x_i) - \lambda \right) - \frac{N}{\theta^2} - \log \theta + C
\end{align}
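The closed-form expression above can be checked against a plain Monte Carlo estimate of the same expectation. The following is a minimal sketch using only the Python standard library; the dataset `xs`, the value `theta = 2.0`, the sample count, and the function names are made up for illustration.

```python
import math
import random


def elbo_closed_form(xs, theta, lam=1.0):
    """ELBO for q = Exp(theta), including the theta-independent constants."""
    n = len(xs)
    return (
        -0.5 * sum(x * x for x in xs)
        + sum(xs) / theta
        - n / theta**2
        - 0.5 * n * math.log(2.0 * math.pi)
        + math.log(lam)
        - lam / theta
        - math.log(theta)
        + 1.0
    )


def elbo_monte_carlo(xs, theta, lam=1.0, num_samples=100_000, seed=0):
    """Average of log P(D|z) + log P(z) - log q(z) over samples z ~ Exp(theta)."""
    rng = random.Random(seed)
    n = len(xs)
    total = 0.0
    for _ in range(num_samples):
        z = rng.expovariate(theta)  # rate parameterization, mean 1/theta
        log_lik = sum(-0.5 * (x - z) ** 2 for x in xs) - 0.5 * n * math.log(2.0 * math.pi)
        log_prior = math.log(lam) - lam * z
        log_q = math.log(theta) - theta * z
        total += log_lik + log_prior - log_q
    return total / num_samples


xs = [0.3, 1.1, 0.7, 2.0, 0.9]  # hypothetical observations
theta = 2.0
closed = elbo_closed_form(xs, theta)
estimate = elbo_monte_carlo(xs, theta)
print(f"closed form: {closed:.4f}")
print(f"Monte Carlo: {estimate:.4f}")
```

The two numbers should agree up to Monte Carlo noise, which shrinks as `num_samples` grows.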

Here $C$ collects all terms that do not depend on $\theta$. To find the optimal $\theta^*$, we look for the $\theta$ at which $\nabla_{\theta} \mathcal{L}_{\theta} = 0$. Let $B = \sum_{i=1}^{N} x_i - \lambda$, then:

\begin{align}
0 &= \nabla_{\theta} \mathcal{L}_{\theta} \\
&= \nabla_{\theta} \left(\frac{1}{\theta} B - \frac{N}{\theta^2} - \log \theta + C\right) \\
&= \frac{-1}{\theta^2} B + \frac {2N}{\theta^3} - \frac{1}{\theta} \\
\Rightarrow 0 &= - \theta^2 -\theta B + 2N \qquad \text{(multiplying both sides by } \theta^3 > 0 \text{)}
\end{align}

This is a quadratic equation in $\theta$ with solutions:
$$\theta^* = \frac{-B}{2} \pm \frac{1}{2}\sqrt{B^2 + 8 N}$$

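Note that $\theta$ is constrained to be positive. Since $\sqrt{B^2 + 8N} > |B|$, the root with the minus sign is always negative and therefore infeasible, while the root with the plus sign is always positive. Hence $\theta^* = \frac{-B}{2} + \frac{1}{2}\sqrt{B^2 + 8 N}$. Because $\mathcal{L}_{\theta} \to -\infty$ both as $\theta \to 0^+$ and as $\theta \to \infty$, this unique stationary point is a maximum.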
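The closed-form solution can be checked numerically. The sketch below computes the positive root and confirms that the gradient vanishes there and that nearby values of $\theta$ give a lower ELBO. The dataset `xs` and the function names are hypothetical, chosen only for this example.

```python
import math


def optimal_theta(xs, lam=1.0):
    """Positive root of theta^2 + B * theta - 2N = 0, with B = sum(xs) - lam."""
    n = len(xs)
    b = sum(xs) - lam
    return 0.5 * (-b + math.sqrt(b * b + 8.0 * n))


def elbo_up_to_constant(xs, theta, lam=1.0):
    """B / theta - N / theta^2 - log(theta), i.e. the ELBO without the constant C."""
    b = sum(xs) - lam
    return b / theta - len(xs) / theta**2 - math.log(theta)


def elbo_gradient(xs, theta, lam=1.0):
    """Derivative of the ELBO with respect to theta."""
    b = sum(xs) - lam
    return -b / theta**2 + 2.0 * len(xs) / theta**3 - 1.0 / theta


xs = [0.3, 1.1, 0.7, 2.0, 0.9]  # hypothetical observations
theta_star = optimal_theta(xs)
print(f"theta* = {theta_star:.6f}")
print(f"gradient at theta* = {elbo_gradient(xs, theta_star):.2e}")
for factor in (0.9, 1.0, 1.1):
    value = elbo_up_to_constant(xs, factor * theta_star)
    print(f"ELBO - C at {factor:.1f} * theta* = {value:.6f}")
```

For very large positive $B$, the expression $-B + \sqrt{B^2 + 8N}$ subtracts two nearly equal numbers. The algebraically equivalent form $\theta^* = 4N / (B + \sqrt{B^2 + 8N})$ avoids that cancellation.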

