In [None]:
%matplotlib ipympl
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.style.use('ggplot')

from IPython.display import display, Math


Given the following distribution:

$P(z) = Exp(\lambda=1) = \lambda \exp(-\lambda z)$

$P(x|z) = \textrm{Normal}(\mu=z, \sigma^2=1)$

We want to model $P(z|x)$, which is commonly known as the posterior distribution.

In principle, this quantity can be computed using Bayes' rule:

$P(z|x) = \frac{P(x|z)P(z)}{P(x)} = \frac{P(x|z)P(z)}{\int P(x|z)P(z) dz}$

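We note that the integral in the denominator is in general intractable, so instead we choose a surrogate distribution $q_{\theta}(z)$ and try to make it as close as possible to $P(z|x)$. The evidence $\log P(x)$ itself does not depend on $\theta$ (the expectation $E_{z \sim q_{\theta}(z)}[\log P(x)]$ is simply $\log P(x)$), so it cannot serve as the objective directly. Instead, we measure closeness with the KL divergence and solve:

$$\arg\min_{\theta} D_{KL}\left(q_{\theta}(z) \,\|\, P(z|x)\right)$$

As shown below, this is equivalent to maximizing a lower bound on $\log P(x)$.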

We note that:
\begin{align}
\log P(x)  = E_{z \sim q_{\theta}(z)}[\log P(x)] &= E_{z \sim q_{\theta}(z)}\left[\log \left(P(x) \frac{q_{\theta}(z)}{q_{\theta}(z)}\right)\right] \\
&= E_{z \sim q_{\theta}(z)}\left[\log \left(\frac{P(x, z)}{P(z|x)} \frac{q_{\theta}(z)}{q_{\theta}(z)}\right)\right] \\
&= E_{z \sim q_{\theta}(z)}\left[\log \left(\frac{P(x, z)}{q_{\theta}(z)} \frac{q_{\theta}(z)}{P(z|x)}\right)\right] \\
&= E_{z \sim q_{\theta}(z)}\left[\log \left(\frac{P(x, z)}{q_{\theta}(z)}\right) + \log \left( \frac{q_{\theta}(z)}{P(z|x)}\right)\right] \\
&= E_{z \sim q_{\theta}(z)}\left[\log \left(\frac{P(x, z)}{q_{\theta}(z)}\right)\right] + E_{z \sim q_{\theta}(z)}\left[ \log \left( \frac{q_{\theta}(z)}{P(z|x)}\right)\right] \\
&= \mathcal{L}_{\theta} + D_{KL}(q_{\theta}(z) || P(z|x)) \\
\log P(x) &\geq \mathcal{L}_{\theta}
\end{align}

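Since the KL divergence is non-negative, $\mathcal{L}_{\theta}$ is a lower bound on the log evidence $\log P(x)$; it is therefore known as the evidence lower bound, or ELBO. Because $\log P(x)$ does not depend on $\theta$, maximizing $\mathcal{L}_{\theta}$ over $\theta$ is equivalent to minimizing $D_{KL}(q_{\theta}(z) || P(z|x))$.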


Assume that we have observed a dataset $\mathcal{D} = \{x_i\}_{i=1}^{N}$. Then, $P(\mathcal{D}|z) = \prod_{i=1}^{N} P(x_i|z)$.

The ELBO is then:

\begin{align}
\mathcal{L}_{\theta} &= E_{z\sim q_{\theta}(z)}[\log(P(\mathcal{D}, z)) - \log(q_{\theta}(z))] \\
&= E_{z\sim q_{\theta}(z)}[\log(P(\mathcal{D}|z)P(z)) - \log q_{\theta}(z)] \\
&= E_{z\sim q_{\theta}(z)}[\log P(\mathcal{D}|z) + \log P(z) - \log q_{\theta}(z)] \\
&= E_{z\sim q_{\theta}(z)}\left[\sum_{i=1}^{N}\log P(x_i|z) + \log P(z) - \log q_{\theta}(z)\right] \\
&= E_{z\sim q_{\theta}(z)}\left[-\frac{1}{2}\sum_{i=1}^{N} (x_i - z)^2 - \frac{N}{2} \log(2\pi) + \log \lambda - \lambda z - \log \theta + \theta z \right] \\
&= E_{z\sim q_{\theta}(z)}\left[-\frac{1}{2}\sum_{i=1}^{N} (x_i^2 -2 x_i z + z^2) - \frac{N}{2} \log(2\pi) + \log \lambda - \lambda z - \log \theta + \theta z \right]
\end{align}

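The first four lines hold for any surrogate $q_{\theta}(z)$. The last two lines also substitute the Gaussian log-likelihood (with $\sigma^2 = 1$), the exponential prior $\log P(z) = \log\lambda - \lambda z$, and the exponential surrogate $q_{\theta}(z) = Exp(\theta)$, i.e. $\log q_{\theta}(z) = \log\theta - \theta z$. Taking the expectation under $q_{\theta}(z) = Exp(\theta)$, with $E[z] = 1/\theta$ and $E[z^2] = 2/\theta^2$, gives: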

\begin{align}
\mathcal{L}_{\theta} &= E_{z\sim q_{\theta}(z)}\left[-\frac{1}{2}\sum_{i=1}^{N} (x_i^2 -2 x_i z + z^2) - \frac{N}{2} \log(2\pi) + \log \lambda - \lambda z - \log \theta + \theta z \right] \\
&= -\frac{1}{2}\sum_{i=1}^{N} E_{z\sim q_{\theta}(z)}\left[x_i^2 -2 x_i z + z^2\right] + E_{z\sim q_{\theta}(z)}\left[- \frac{N}{2} \log(2\pi) + \log \lambda - \lambda z - \log \theta + \theta z \right] \\
&= -\frac{1}{2}\sum_{i=1}^{N} \left(x_i^2 - \frac{2}{\theta}x_i + \frac{2}{\theta^2}\right) - \frac{N}{2} \log(2\pi) + \log\lambda - \frac{\lambda}{\theta} - \log\theta + 1 \\
&= -\frac{1}{2}\sum_{i=1}^{N} x_i^2 + \frac{1}{\theta}\sum_{i=1}^N x_i - \frac{N}{\theta^2} - \frac{N}{2} \log(2\pi) + \log\lambda - \frac{\lambda}{\theta} - \log\theta + 1 \\
&= \frac{1}{\theta} \left(\sum_{i=1}^{N} (x_i) - \lambda \right) - \frac{N}{\theta^2} - \log \theta + C
\end{align}

To compute the optimal $\theta^*$, we need to compute the $\theta$ for which $\nabla_{\theta} \mathcal{L}_{\theta} = 0$. Let $B = \sum_{i=1}^{N} (x_i) - \lambda$, then:

\begin{align}
0 &= \nabla_{\theta} \mathcal{L}_{\theta} \\
&= \nabla_{\theta} \left(\frac{1}{\theta} B - \frac{N}{\theta^2} - \log \theta + C\right) \\
&= \frac{-1}{\theta^2} B + \frac {2N}{\theta^3} - \frac{1}{\theta} \\
\Longleftrightarrow \quad 0 &= - \theta^2 -\theta B + 2N \qquad (\text{multiplying by } \theta^3 > 0)
\end{align}

This is a quadratic equation in $\theta$ with solutions:
$$\theta^* = \frac{-B}{2} \pm \frac{1}{2}\sqrt{B^2 + 8 N}$$

Note that $\theta$ is constrained to be positive. The product of the two roots is $-2N < 0$, so exactly one root is positive and feasible: $\theta^* = \frac{-B + \sqrt{B^2 + 8N}}{2}$.



In [None]:
# Ground truth: seed the RNG, draw a latent z from the Exp(lambda) prior, then
# draw 5 observations x_i ~ Normal(z, normal_std) conditioned on that z.
torch.manual_seed(0)

prior_lambda = 1.0
normal_std = 1.0

prior_dist = torch.distributions.Exponential(prior_lambda)
sampled_z = prior_dist.sample()

observation_dist = torch.distributions.Normal(loc=sampled_z, scale=normal_std)
sampled_x = observation_dist.sample(torch.Size([5]))

print(f'z: {sampled_z} x: {sampled_x}')

In [None]:
# Evaluate the log joint log p(z) + sum_i log p(x_i | z) on a uniform grid of z
# values, then normalise it numerically to approximate the posterior p(z | x).
Zs = torch.linspace(0, 15, 2000, dtype=torch.float64)
# Spacing between grid points (do not assume the grid starts at 0).
grid_spacing = Zs[1] - Zs[0]

# Vectorised over the grid: loc has shape (len(Zs), 1) and broadcasts against the
# flattened observations, giving one row of per-observation log-likelihoods per z.
log_likelihood = (
    torch.distributions.Normal(Zs.unsqueeze(-1), normal_std)
    .log_prob(sampled_x.reshape(-1))
    .sum(dim=-1)
)
log_joint = prior_dist.log_prob(Zs) + log_likelihood

# Unnormalised joint density p(z, x) on the grid.
joint_evals = torch.exp(log_joint)

# Normalise in log space (log-sum-exp) so the result stays finite even when
# exp(log_joint) underflows for large datasets; dividing by the grid spacing turns
# the discrete weights into a density whose Riemann sum over the grid is 1.
approx_posterior = torch.exp(log_joint - torch.logsumexp(log_joint, dim=0)) / grid_spacing

In [None]:
# Assuming an exponential variational posterior q_theta(z) = Exp(theta), the
# optimal rate is the positive root of theta^2 + B*theta - 2N = 0, where
# B = sum_i x_i - lambda and N is the number of observations.
num_obs = sampled_x.numel()
B = torch.sum(sampled_x) - prior_lambda
discriminant_root = torch.sqrt(B**2 + 8 * num_obs)
# The roots have product -2N < 0, so exactly one is positive. Evaluate it in the
# algebraically equivalent form that avoids cancellation between B and the root:
#   B >= 0:  4N / (B + sqrt(B^2 + 8N))      B < 0:  (-B + sqrt(B^2 + 8N)) / 2
# (B + sqrt(...) > 0 always, so neither branch divides by zero.)
theta_star = torch.where(
    B >= 0,
    4 * num_obs / (B + discriminant_root),
    0.5 * (discriminant_root - B),
)
display(Math(rf'$\theta^*={theta_star:.3f}$'))
variational_posterior = torch.distributions.Exponential(theta_star)
variational_evals = torch.exp(variational_posterior.log_prob(Zs))


In [None]:
# Overlay the joint, the grid-normalised posterior and the variational fit.
fig, ax = plt.subplots()
curves = (
    (joint_evals, 'joint'),
    (approx_posterior, 'numerical posterior'),
    (variational_evals, 'variational posterior'),
)
for values, label in curves:
    ax.plot(Zs, values, label=label)
ax.set_xlabel('Z')
ax.set_ylabel('$p(Z|X_{obs})$')
ax.set_xlim(0, 5)
ax.legend()


In [None]:
def compute_KL(Zs, numerical_posterior, variational_posterior):
    """Approximate KL(q || p) on a uniform grid with a Riemann sum.

    Args:
        Zs: 1-D tensor of uniformly spaced grid points (needs at least two).
        numerical_posterior: values of the density p(z | x) evaluated at ``Zs``.
        variational_posterior: a ``torch.distributions`` object q providing
            ``log_prob``.

    Returns:
        0-dim tensor ``dz * sum_k q(z_k) * (log q(z_k) - log p(z_k))``. The
        integral is truncated to the grid range ``[Zs[0], Zs[-1]]``.
    """
    # Grid spacing; Zs[1] alone equals the spacing only when the grid starts at 0.
    delta = Zs[1] - Zs[0]
    log_q = variational_posterior.log_prob(Zs)
    q = torch.exp(log_q)
    log_p = torch.log(numerical_posterior)
    return delta * torch.sum(q * (log_q - log_p))
    
# Sweep the rate of q_theta = Exp(theta) over [theta*/2, 2*theta*] so that the
# analytic optimum is guaranteed to lie inside the plotted range.
theta_center = float(theta_star)
thetas = torch.linspace(0.5 * theta_center, 2.0 * theta_center, 500, dtype=torch.float64)
kl_divergences = torch.stack([
    compute_KL(Zs, approx_posterior, torch.distributions.Exponential(theta))
    for theta in thetas
])

fig, ax = plt.subplots()
ax.plot(thetas, kl_divergences)
# Reference line at the closed-form optimum of the ELBO.
ax.axvline(theta_center, color='k', linestyle='--', label=r'analytic $\theta^*$')
ax.set_ylabel('KL Divergence')
ax.set_xlabel(r'$\theta$')
ax.legend();


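Note that the minimum of this curve may not sit exactly at the expected location $\theta^*$. The KL divergence is only computed over a finite grid ($z \in [0, 15]$) rather than the whole range $[0, \infty)$, and it is evaluated with a simple Riemann sum, so the curve is only an approximation of the true KL divergence.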

Now we try to optimize the surrogate using pytorch.

We want to optimize the ELBO:
\begin{align}
\mathcal{L}_{\theta} &= E_{z\sim q_{\theta}(z)} \left[\log \frac{p(\mathcal{D},z)}{q_{\theta}(z)} \right]
\end{align}

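This means that we need to compute $\nabla_{\theta}\mathcal{L}_{\theta}$. This is difficult because the expectation is taken with respect to a distribution that itself depends on $\theta$, so the gradient cannot simply be moved inside the expectation. We therefore apply the reparameterization trick. We write a sample as $z = g_{\theta}(t)$, where $t \sim \epsilon$ is drawn from a fixed noise distribution that does not depend on $\theta$, and $g_{\theta}$ is chosen so that $z \sim q_{\theta}(z)$. Then $E_{z \sim q_{\theta}}[f(z)] = E_{t \sim \epsilon}[f(g_{\theta}(t))]$, and the gradient can be taken inside the expectation over $t$. This expectation (and its gradient) can be estimated by drawing samples from $\epsilon$ and computing the sample mean of $f(g_{\theta}(t))$. For the exponential surrogate used here, $t \sim Exp(1)$ and $g_{\theta}(t) = t / \theta$, which gives $z \sim Exp(\theta)$.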


In [None]:
# Now we try to find the parameter through optimization

class VariationalModel(torch.nn.Module):
    """Variational family q_theta(z) = Exp(theta) for the model
    z ~ Exp(prior_rate), x_i | z ~ Normal(z, obs_std).

    ``param`` holds theta (initialised to 1.0) and is the only trainable
    parameter. ``elbo`` returns a reparameterized Monte Carlo estimate of the
    ELBO, so it can be differentiated with respect to ``param``.
    """

    def __init__(self, prior_rate=1.0, obs_std=1.0):
        super().__init__()

        # theta, the rate of q_theta(z); shape (1,). It is not constrained to be
        # positive: a step that pushes it to <= 0 makes the distributions invalid.
        self.param = torch.nn.Parameter(data=torch.tensor([1.0]))
        # Fixed noise distribution for the reparameterization: t ~ Exp(1).
        self.sampling_dist = torch.distributions.Exponential(1.0)
        self.prior = torch.distributions.Exponential(prior_rate)
        self.obs_std = obs_std

    def elbo(self, data, num_samples=1):
        """Monte Carlo estimate of E_q[log p(data, z) - log q_theta(z)].

        Args:
            data: tensor of observations x_i. It is flattened, so every element
                is treated as one observation.
            num_samples: number of reparameterized samples of z to average over.
                The default of 1 is the single-sample estimator.

        Returns:
            Tensor of shape (1,) holding the ELBO estimate.

        Raises:
            ValueError: if ``num_samples`` is smaller than 1.
        """
        if num_samples < 1:
            raise ValueError(f'num_samples must be >= 1, got {num_samples}')

        # Reparameterization trick for exponential distributions: if t ~ Exp(1)
        # then z = t / theta ~ Exp(theta), and z is differentiable in theta.
        noise = self.sampling_dist.sample((num_samples, 1))  # (S, 1)
        reparametrized_z = noise / self.param                 # (S, 1)

        # log p(D, z) = log p(z) + sum_i log p(x_i | z)
        log_p_z = self.prior.log_prob(reparametrized_z)       # (S, 1)
        obs = torch.distributions.Normal(reparametrized_z, self.obs_std)
        observations = torch.as_tensor(data).reshape(-1)      # (N,)
        log_data_given_z = obs.log_prob(observations).sum(dim=-1, keepdim=True)  # (S, 1)

        log_model = torch.distributions.Exponential(self.param).log_prob(reparametrized_z)
        # Average the per-sample estimates over the sample dimension -> (1,).
        return (log_p_z + log_data_given_z - log_model).mean(dim=0)
        

In [None]:
# Optimise theta by stochastic gradient descent on the negative ELBO.
NUM_STEPS = 40_000
LEARNING_RATE = 1e-5
LOG_EVERY = 1_000

model = VariationalModel()
model.train()
opt = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

for step in range(NUM_STEPS):
    opt.zero_grad()

    # Each call draws a fresh reparameterized sample, so the loss is noisy.
    loss = -model.elbo(sampled_x)
    loss.backward()
    if step % LOG_EVERY == 0:
        print(step, loss.detach(), model.param.data)
    opt.step()

# Compare the learned rate with the closed-form optimum computed earlier.
print(f'learned theta: {model.param.item():.4f}  analytic theta*: {float(theta_star):.4f}')


This is really close to the optimal value we found above!