In [None]:
%matplotlib ipympl

import common.torch as torch
import matplotlib as mpl
import matplotlib.pyplot as plt
import torchvision
import math
mpl.style.use('ggplot')

from IPython.display import display, Math


Given the following distribution:

$P(z) = Exp(\lambda=1) = \lambda \exp(-\lambda z)$

$P(x|z) = \textrm{Normal}(\mu=z, \sigma^2=1)$

We want to model $P(z|x)$, which is commonly known as the posterior distribution.

In principle, this quantity can be computed using Bayes' rule:

$P(z|x) = \frac{P(x|z)P(z)}{P(x)} = \frac{P(x|z)P(z)}{\int P(x|z)P(z) dz}$

We note that the integral in the denominator is in general intractable, so instead, we choose a surrogate distribution $q_{\theta}(z)$ and try to make it as close as possible to $P(z|x)$. That is, we look for the $\theta$ that minimizes the KL divergence between the surrogate and the true posterior:

$$\arg\min_{\theta} D_{KL}(q_{\theta}(z) || P(z|x))$$

As shown below, this is equivalent to maximizing a lower bound on the log probability of the observed $x$.

We note that:
\begin{align}
\log P(x)  = E_{z \sim q_{\theta}(z)}[\log P(x)] &= E_{z \sim q_{\theta}(z)}\left[\log \left(P(x) \frac{q_{\theta}(z)}{q_{\theta}(z)}\right)\right] \\
&= E_{z \sim q_{\theta}(z)}\left[\log \left(\frac{P(x, z)}{P(z|x)} \frac{q_{\theta}(z)}{q_{\theta}(z)}\right)\right] \\
&= E_{z \sim q_{\theta}(z)}\left[\log \left(\frac{P(x, z)}{q_{\theta}(z)} \frac{q_{\theta}(z)}{P(z|x)}\right)\right] \\
&= E_{z \sim q_{\theta}(z)}\left[\log \left(\frac{P(x, z)}{q_{\theta}(z)}\right) + \log \left( \frac{q_{\theta}(z)}{P(z|x)}\right)\right] \\
&= E_{z \sim q_{\theta}(z)}\left[\log \left(\frac{P(x, z)}{q_{\theta}(z)}\right)\right] + E_{z \sim q_{\theta}(z)}\left[ \log \left( \frac{q_{\theta}(z)}{P(z|x)}\right)\right] \\
&= \mathcal{L_{\theta}} + D_{KL}(q_{\theta}(z) || P(z|x)) \\
\log P(x) &\geq \mathcal{L_{\theta}}
\end{align}

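Since the KL divergence is non-negative, $\mathcal{L}_{\theta}$ is a lower bound on the log evidence $\log P(x)$, and is therefore known as the evidence lower bound, or ELBO. Because $\log P(x)$ does not depend on $\theta$, maximizing $\mathcal{L}_{\theta}$ is equivalent to minimizing $D_{KL}(q_{\theta}(z) || P(z|x))$.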

Assume that we have observed a dataset $\mathcal{D} = \{x_i\}_{i=1}^{N}$ whose points are conditionally independent given $z$. Then, $P(\mathcal{D}|z) = \prod_{i=1}^{N} P(x_i|z)$.

The ELBO is then:

\begin{align}
\mathcal{L}_{\theta} &= E_{z\sim q_{\theta}(z)}[\log(P(\mathcal{D}, z)) - \log(q_{\theta}(z))] \\
&= E_{z\sim q_{\theta}(z)}[\log(P(\mathcal{D}|z)P(z)) - \log q_{\theta}(z)] \\
&= E_{z\sim q_{\theta}(z)}[\log P(\mathcal{D}|z) + \log P(z) - \log q_{\theta}(z)] \\
&= E_{z\sim q_{\theta}(z)}\left[\sum_{i=1}^{N}\log P(x_i|z) + \log P(z) - \log q_{\theta}(z)\right] \\
&= E_{z\sim q_{\theta}(z)}\left[-\frac{1}{2}\sum_{i=1}^{N} (x_i - z)^2 - \frac{N}{2} \log(2\pi) + \log \lambda - \lambda z - \log \theta + \theta z \right] \\
&= E_{z\sim q_{\theta}(z)}\left[-\frac{1}{2}\sum_{i=1}^{N} (x_i^2 -2 x_i z + z^2) - \frac{N}{2} \log(2\pi) + \log \lambda - \lambda z - \log \theta + \theta z \right]
\end{align}

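The first four lines above hold for any $q_{\theta}(z)$; the last two already use $q_{\theta}(z) = Exp(\theta)$, i.e. $\log q_{\theta}(z) = \log\theta - \theta z$. To evaluate the expectation we use the moments of the exponential distribution, $E_{z\sim q_{\theta}}[z] = \frac{1}{\theta}$ and $E_{z\sim q_{\theta}}[z^2] = \frac{2}{\theta^2}$: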

\begin{align}
\mathcal{L}_{\theta} &= E_{z\sim q_{\theta}(z)}\left[-\frac{1}{2}\sum_{i=1}^{N} (x_i^2 -2 x_i z + z^2) - \frac{N}{2} \log(2\pi) + \log \lambda - \lambda z - \log \theta + \theta z \right] \\
&= -\frac{1}{2}\sum_{i=1}^{N} E_{z\sim q_{\theta}(z)}\left[x_i^2 -2 x_i z + z^2\right] + E_{z\sim q_{\theta}(z)}\left[- \frac{N}{2} \log(2\pi) + \log \lambda - \lambda z - \log \theta + \theta z \right] \\
&= -\frac{1}{2}\sum_{i=1}^{N} \left(x_i^2 - \frac{2}{\theta}x_i + \frac{2}{\theta^2}\right) - \frac{N}{2} \log(2\pi) + \log\lambda - \frac{\lambda}{\theta} - \log\theta + 1 \\
&= -\frac{1}{2}\sum_{i=1}^{N} x_i^2 + \frac{1}{\theta}\sum_{i=1}^N x_i - \frac{N}{\theta^2} - \frac{N}{2} \log(2\pi) + \log\lambda - \frac{\lambda}{\theta} - \log\theta + 1 \\
&= \frac{1}{\theta} \left(\sum_{i=1}^{N} (x_i) - \lambda \right) - \frac{N}{\theta^2} - \log \theta + C
\end{align}

Here $C = -\frac{1}{2}\sum_{i=1}^{N} x_i^2 - \frac{N}{2}\log(2\pi) + \log\lambda + 1$ collects the terms that do not depend on $\theta$. To find the optimal $\theta^*$, we solve $\nabla_{\theta} \mathcal{L}_{\theta} = 0$. Let $B = \sum_{i=1}^{N} x_i - \lambda$; then:

\begin{align}
0 &= \nabla_{\theta} \mathcal{L}_{\theta} \\
&= \nabla_{\theta} \left(\frac{1}{\theta} B - \frac{N}{\theta^2} - \log \theta + C\right) \\
&= \frac{-1}{\theta^2} B + \frac {2N}{\theta^3} - \frac{1}{\theta} \\
\iff 0 &= -\theta^2 - \theta B + 2N \qquad \text{(multiplying through by } \theta^3 > 0\text{)}
\end{align}

This is a quadratic equation in $\theta$ with solutions:
$$\theta^* = \frac{-B}{2} \pm \frac{1}{2}\sqrt{B^2 + 8 N}$$

Note that $\theta$ is constrained to be positive. Since $\sqrt{B^2 + 8N} > |B|$, exactly one root is positive, so the optimum is $\theta^* = \frac{1}{2}\left(-B + \sqrt{B^2 + 8N}\right)$.



In [None]:
# Seed first so the ground-truth latent and its observations are reproducible.
torch.manual_seed(0)

# Generative model: z ~ Exp(prior_lambda), x_i | z ~ Normal(z, normal_std).
prior_lambda = 1.0
normal_std = 1.0
prior_dist = torch.distributions.Exponential(prior_lambda)

# Draw one latent value and five noisy observations of it.
# (An earlier variant of this experiment used a single observation, torch.tensor(2.3).)
sampled_z = prior_dist.sample()
observation_dist = torch.distributions.Normal(sampled_z, normal_std)
sampled_x = observation_dist.sample((5,))

print(f'z: {sampled_z} x: {sampled_x}')

In [None]:
# Evaluate the joint density p(z, D) = p(z) * prod_i p(x_i | z) on a uniform grid of z,
# then normalize it numerically to obtain a reference posterior p(z | D).
Zs = torch.linspace(0, 15, 2000, dtype=torch.float64)

# Broadcast every grid point against every observation: (len(Zs), 1) vs sampled_x,
# then sum the per-observation log-likelihoods over the last axis.
log_likelihood = torch.distributions.Normal(Zs.unsqueeze(-1), normal_std).log_prob(sampled_x).sum(-1)
joint_evals = torch.exp(prior_dist.log_prob(Zs) + log_likelihood)

# Riemann-sum normalization. Use the true grid spacing: Zs[1] alone only equals the
# spacing when the grid happens to start at 0.
grid_step = Zs[1] - Zs[0]
approx_posterior = joint_evals / torch.sum(joint_evals * grid_step)

In [None]:
# Assuming an exponential posterior, find the optimal theta in closed form: it is the
# positive root of theta^2 + B*theta - 2N = 0 (the product of the roots is -2N < 0,
# so the other root is always negative).
B = torch.sum(sampled_x) - prior_lambda
theta_star = 0.5 * (-B + torch.sqrt(B**2 + 8 * sampled_x.numel()))
# '.3f' = three decimals (the previous '3f' meant "minimum width 3", six decimals).
display(Math(rf'$\theta^*={theta_star:.3f}$'))
variational_posterior = torch.distributions.Exponential(theta_star)
variational_evals = torch.exp(variational_posterior.log_prob(Zs))


In [None]:
# Compare the unnormalized joint, the numerically normalized posterior and the
# closed-form variational posterior on the same axes.
fig, ax = plt.subplots()
curves = (
    (joint_evals, 'joint'),
    (approx_posterior, 'numerical posterior'),
    (variational_evals, 'variational posterior'),
)
for values, name in curves:
    ax.plot(Zs, values, label=name)
ax.set_xlabel('Z')
ax.set_ylabel('$p(Z|X_{obs})$')
ax.set_xlim(0, 5)
ax.legend()


In [None]:
def compute_KL(Zs, numerical_posterior, variational_posterior):
    """Numerically approximate KL(q || p) on a uniform grid with a left Riemann sum.

    Args:
        Zs: 1-D tensor of uniformly spaced grid points (at least two); the grid
            need not start at 0.
        numerical_posterior: density values p(z) evaluated at ``Zs``.
        variational_posterior: a torch distribution q whose ``log_prob`` is
            evaluated at ``Zs``.

    Returns:
        0-dim tensor approximating the integral of q(z) * (log q(z) - log p(z)) over
        the grid range.
    """
    # Spacing of the uniform grid (Zs[1] alone is only correct when Zs[0] == 0).
    delta = Zs[1] - Zs[0]
    variational_vals = variational_posterior.log_prob(Zs)
    exp_variational_vals = torch.exp(variational_vals)
    numerical_vals = torch.log(numerical_posterior)
    integrand = exp_variational_vals * (variational_vals - numerical_vals)
    # Where q underflows to 0 the contribution is 0 by the 0*log(0) = 0 convention;
    # without this guard, points where p also underflows give 0 * inf = nan.
    integrand = torch.where(exp_variational_vals > 0, integrand, torch.zeros_like(integrand))
    return delta * torch.sum(integrand)
    
# Sweep candidate rates and record KL(Exp(rate) || numerical posterior) for each.
thetas = torch.linspace(0.5, 2, 500, dtype=torch.float64)
kl_divergences = [
    compute_KL(Zs, approx_posterior, torch.distributions.Exponential(rate))
    for rate in thetas
]

fig, ax = plt.subplots()
ax.plot(thetas, kl_divergences)
ax.set_ylabel('KL Divergence')
ax.set_xlabel(r'$\theta$')


Note that the plot doesn't show a minimum at the expected location because the KL divergence is only computed for a subset of the range $[0, \infty)$.

Now we try to optimize the surrogate using PyTorch.

We want to optimize the ELBO:
\begin{align}
\mathcal{L}_{\theta} &= E_{z\sim q_{\theta}(z)} \left[\log \frac{p(\mathcal{D},z)}{q_{\theta}(z)} \right]
\end{align}

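Which means that we need to compute $\nabla_{\theta}\mathcal{L}_{\theta}$. This is difficult because the expectation is taken with respect to a distribution that itself depends on $\theta$, so the gradient cannot simply be moved inside the expectation. We therefore apply the reparameterization trick: we write $z = g_{\theta}(t)$, where $t\sim\epsilon$ is drawn from a fixed noise distribution that does not depend on $\theta$, and $g_{\theta}$ is chosen so that $z \sim q_{\theta}(z)$ (in an amortized model such as a VAE, $g$ may also depend on $x$, i.e. $z = g_{\theta}(x, t)$). Then $E_{z\sim q_{\theta}}[f(z)] = E_{t \sim \epsilon}[f(g_{\theta}(t))]$, and the gradient can be taken inside the expectation over $t$. This expectation can be estimated by drawing samples from $\epsilon$ and computing the sample mean of $f(g_{\theta}(t))$. For the exponential family used here, $t \sim Exp(1)$ and $g_{\theta}(t) = t/\theta$ gives $z \sim Exp(\theta)$.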


In [None]:
# Now we try to find the parameter through optimization

class VariationalModel(torch.nn.Module):
    """Exponential variational family q_theta(z) = Exp(theta) for the toy model.

    Generative model: z ~ Exp(prior_rate), x_i | z ~ Normal(z, obs_std).
    ``param`` holds the variational rate theta. It is trained by maximizing a Monte
    Carlo estimate of the ELBO, using the reparameterization z = t / theta with
    t ~ Exp(1). Note that plain gradient steps do not keep ``param`` positive.
    """

    # Default number of Monte Carlo samples per ELBO estimate.
    N = 128

    def __init__(self, prior_rate=1.0, obs_std=1.0, init_rate=1.0, n_samples=None):
        """
        Args:
            prior_rate: rate lambda of the Exp(lambda) prior on z.
            obs_std: standard deviation of the Gaussian likelihood p(x_i | z).
            init_rate: initial value of the variational rate theta.
            n_samples: Monte Carlo samples per ELBO estimate (defaults to ``N``).
        """
        super().__init__()

        self.param = torch.nn.Parameter(data=torch.tensor([float(init_rate)]))
        # Parameter-free noise distribution for the reparameterization trick.
        self.sampling_dist = torch.distributions.Exponential(1.0)
        self.prior = torch.distributions.Exponential(prior_rate)
        self.obs_std = obs_std
        if n_samples is not None:
            self.N = int(n_samples)

    def elbo(self, data):
        """Monte Carlo estimate of E_q[log p(D, z) - log q_theta(z)].

        Args:
            data: tensor of observations; every element is treated as one x_i.

        Returns:
            A 1-element tensor (same shape as ``param``) holding the ELBO estimate.
        """
        sampled_z = self.sampling_dist.sample((self.N,))
        # Reparameterization trick for exponential distributions: t / theta ~ Exp(theta).
        reparametrized_z = sampled_z / self.param  # (N,)

        # log p(D, z) = log p(z) + sum_i log p(x_i | z), evaluated for all samples at once:
        # (N, 1) locations broadcast against the flattened observations -> (N, n_obs).
        log_p_z = self.prior.log_prob(reparametrized_z)
        obs = torch.distributions.Normal(reparametrized_z.unsqueeze(-1), self.obs_std)
        log_data_given_z = obs.log_prob(data.reshape(-1)).sum(-1)

        log_model = torch.distributions.Exponential(self.param).log_prob(reparametrized_z)
        # Average over samples; keepdim preserves the original 1-element output shape.
        return (log_p_z + log_data_given_z - log_model).mean(dim=0, keepdim=True)
        

In [None]:
# Fit the variational rate theta by gradient ascent on the Monte Carlo ELBO.
model = VariationalModel()
model.train()
opt = torch.optim.SGD(model.parameters(), lr=1e-4)

for i in range(1000):
    opt.zero_grad()

    loss = -model.elbo(sampled_x)
    loss.backward()
    if i % 100 == 0:
        print(i, loss.detach(), model.param.data)
    opt.step()

# The log above is taken before each step, so its last line predates the final 100
# updates; report the trained rate next to the closed-form optimum for comparison.
print('final', model.param.data, 'theta*', theta_star)


This is really close to the optimal value we found above!

In [None]:
# Now we try to implement a VAE on MNIST

class VariationalAutoEncoder(torch.nn.Module):
    r"""Convolutional VAE for 1x28x28 images with a 32-dimensional Gaussian latent.

    The encoder maps an image to the mean and log standard deviation of a diagonal
    Gaussian q_\theta(z | x). The decoder maps z, reshaped to (B, 32, 1, 1), back to
    a 1x28x28 image that is the mean of a unit-variance Gaussian p_\phi(x | z).
    The prior p(z) is N(0, I).
    """

    def __init__(self):
        super().__init__()
        # Encoder spatial sizes: 28 -> 26 -> 24 -> pool 12 -> 10 -> 8 -> pool 4 -> 3 -> 2 -> pool 1.
        self._encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1, 8, 3),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(8, 8, 3),
            torch.nn.LeakyReLU(),
            torch.nn.MaxPool2d(2, 2),
            torch.nn.Conv2d(8, 16, 3),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(16, 16, 3),
            torch.nn.LeakyReLU(),
            torch.nn.MaxPool2d(2, 2),
            torch.nn.Conv2d(16, 32, 2),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(32, 32, 2),
            torch.nn.LeakyReLU(),
            torch.nn.MaxPool2d(2, 2),
        )

        self._mean_model = torch.nn.Linear(32, 32)
        self._log_sigma_model = torch.nn.Linear(32, 32)

        # Decoder spatial sizes: 1 -> up 2 -> 4 -> 6 -> up 12 -> 14 -> 14 -> up 28 -> 28.
        self._decoder = torch.nn.Sequential(
            torch.nn.Conv2d(32, 32, 1),
            torch.nn.LeakyReLU(),
            torch.nn.Upsample(scale_factor=2, mode='bilinear'),
            torch.nn.Conv2d(32, 16, 3, padding=2),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(16, 16, 3, padding=2),
            torch.nn.LeakyReLU(),
            torch.nn.Upsample(scale_factor=2, mode='bilinear'),
            torch.nn.Conv2d(16, 8, 3, padding=2),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(8, 8, 3, padding=1),
            torch.nn.LeakyReLU(),
            torch.nn.Upsample(scale_factor=2, mode='bilinear'),
            torch.nn.Conv2d(8, 1, 3, padding=1),
        )

        # Standard-normal prior over the 32-dimensional latent (kept on the CPU).
        self._prior_mean = torch.zeros(32)
        self._scale_tril = torch.eye(32)
        self._prior_dist = torch.distributions.MultivariateNormal(self._prior_mean, scale_tril=self._scale_tril)

    @staticmethod
    def _standard_normal_log_prob(x: torch.Tensor) -> torch.Tensor:
        """Log density of N(0, I) at x, summed over the last dimension."""
        return -0.5 * torch.sum(x ** 2, -1) - 0.5 * x.shape[-1] * math.log(2 * math.pi)

    def encoder(self, x: torch.Tensor):
        r"""Implements q_{\theta}(z | x).

        Args:
            x: images of shape (B, 1, 28, 28), or a single (1, 28, 28) image.

        Returns:
            (mean, log_sigma), each of shape (B, 32), or (32,) for a single image.
        """
        feature = self._encoder(x)
        # Collapse the trailing (32, 1, 1) feature map to 32 features. Unlike
        # squeeze(), this keeps a batch dimension of size 1.
        feature = feature.flatten(start_dim=-3)
        mean = self._mean_model(feature)
        log_sigma = self._log_sigma_model(feature)
        return mean, log_sigma

    def decoder(self, z: torch.Tensor):
        r"""Implements the mean of p_{\phi}(x | z) for z of shape (B, 32, 1, 1)."""
        return self._decoder(z)

    def elbo(self, data: torch.Tensor):
        r"""Single-sample Monte Carlo estimate of the ELBO, averaged over the batch.

        \mathcal{L}_{\theta, \phi} = E_{z \sim q_{\theta}(z|x)}[\log p_{\phi}(x|z) + \log p(z) - \log q_{\theta}(z|x)]

        Uses the reparameterization z = mean + exp(log_sigma) * t with t ~ N(0, I).
        By the change of variables formula q(z|x) = N(t; 0, I) / |det(dz/dt)|, and
        dz/dt = diag(exp(log_sigma)), so log q(z|x) = log N(t; 0, I) - sum(log_sigma).
        The Gaussian likelihood omits its constant normalizer, which does not affect
        gradients.

        Args:
            data: images of shape (B, 1, 28, 28).

        Returns:
            0-dim tensor: the batch mean of the per-example ELBO estimates.
        """
        mean, log_sigma = self.encoder(data)
        t = torch.randn_like(mean)  # noise drawn directly on the model's device
        z = mean + torch.exp(log_sigma) * t  # (B, 32)
        output = self.decoder(z[..., None, None])

        # The decoder parametrizes a Gaussian with identity covariance (an MSE loss).
        log_p_x_given_z = torch.sum(-0.5 * (data - output) ** 2, (-1, -2, -3))

        # Standard-normal prior, summed over the latent dimension -> (B,).
        log_p_z = self._standard_normal_log_prob(z)

        # Change of variables: subtract log|det J| = sum(log_sigma) -> (B,).
        log_q_z_given_x = self._standard_normal_log_prob(t) - torch.sum(log_sigma, -1)

        return torch.mean(log_p_x_given_z + log_p_z - log_q_z_given_x)


In [None]:
# Instantiate an untrained VAE; the next cell moves it to the training device and trains it.
model = VariationalAutoEncoder()

In [None]:

import os

# Where MNIST is downloaded; override with the MNIST_ROOT environment variable.
mnist_root = os.environ.get('MNIST_ROOT', '/home/erick/scratch')
# Fall back to the CPU when no GPU is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps pixels to [-1, 1].
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(0.5, 0.5)
])

dataset = torchvision.datasets.MNIST(mnist_root, download=True, transform=transforms)

# Sanity check: every sample is the first MNIST image (256 * 100 copies, i.e. 100
# batches of 256), so the VAE is asked to overfit a single digit.
subset = torch.utils.data.Subset(dataset, [0]*256*100)

loader = torch.utils.data.DataLoader(subset, batch_size=256, shuffle=True)


model = model.to(device)
model = model.train()
optim = torch.optim.SGD(model.parameters(), lr=1e-4)

image = dataset[0][0].unsqueeze(0).to(device)

for epoch_idx in range(10):
    for i, (batch, labels) in enumerate(loader):
        optim.zero_grad()
        batch = batch.to(device)
        loss = -model.elbo(batch)
        loss.backward()
        optim.step()

        if i % 100 == 0:
            print(epoch_idx, i, loss.item())
    # Monitor the encoding of the training image without building an autograd graph.
    with torch.no_grad():
        mean, log_sigma = model.encoder(image)
    print(torch.norm(mean, p=1), torch.mean(torch.exp(log_sigma)))
    
    

In [None]:
# Show the single training image (channel 0 of the first MNIST sample).
first_digit, _ = dataset[0]
fig, ax = plt.subplots()
ax.imshow(first_digit[0])

In [None]:
# Encode the first MNIST image on whichever device the model lives on; inference only,
# so no autograd graph is needed.
model_device = next(model.parameters()).device
image = dataset[0][0].unsqueeze(0).to(model_device)

with torch.no_grad():
    mean, log_sigma = model.encoder(image)

In [None]:
# Only a cell's last expression is echoed, so display both quantities explicitly:
# the L2 norm of the latent mean and the per-dimension standard deviations.
display(torch.norm(mean), torch.exp(log_sigma))