\begin{align}
\log p(x) &= \log \int_z dz \, q(z | x) \frac{p(x|z) p(z)}{q(z|x)} \\
&\ge \int_z dz \, q(z | x) \log  \frac{p(x|z) p(z)}{q(z|x)} \\
&= \mathbb{E}_{Z \sim q(z|x)} [\log p(x|Z)] - \mathbf{D}_{KL} (q(Z|x) || p(Z))
\end{align}

학습 데이터를 잘 반영한다는 것은 maximum likelihood 추정과 마찬가지로 $\prod_{x \in D} p(x)$를 최대화하는 확률밀도함수 $p(x)$를 찾는 것이다. $\log$를 취하면 (상수배를 제외하고) $\mathbb{E}_{X \sim D} \log p(X)$로 쓸 수 있다. 위 부등식 우변의 첫 번째 항은 decoder와, 두 번째 항은 encoder와 관련되어 있다.

## 1. Required modules

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from functools import partial

import matplotlib.pyplot as plt
%matplotlib inline


# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 2. Loading the MNIST data

In [2]:
batch_size = 128

# Turn each image into a tensor, then flatten it into a 1-D vector so it can
# be fed straight into a fully connected layer.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    partial(torch.reshape, shape=(-1,)),
])

# Training split: downloaded into ./data if needed and shuffled.
train_loader = DataLoader(
    datasets.MNIST(root='./data', train=True, download=True, transform=img_transform),
    batch_size=batch_size,
    shuffle=True,
)

# Test split: read from the same root and iterated in a fixed order.
test_loader = DataLoader(
    datasets.MNIST(root='./data', train=False, transform=img_transform),
    batch_size=batch_size,
    shuffle=False,
)

# Peek at one training batch to confirm the shape produced by the transform.
imgs, targets = next(iter(train_loader))
print(imgs.shape)

torch.Size([128, 784])


## 3. Building a variational autoencoder model

In [3]:
class VAE(nn.Module):
    """Variational autoencoder with one hidden layer on each side.

    The encoder maps an input of size ``nx`` to the mean and log-variance of
    an ``nz``-dimensional diagonal Gaussian; the decoder maps a latent sample
    back to ``nx`` values squashed into (0, 1) by a sigmoid.

    Args:
        nx: size of the input (and of the reconstruction).
        nh: width of the hidden layer used by both encoder and decoder.
        nz: dimensionality of the latent variable.
    """

    def __init__(self, nx, nh, nz):
        super().__init__()

        self.nz = nz

        # Encoder: shared hidden layer followed by two parallel heads.
        self.fc11 = nn.Linear(nx, nh)
        self.mu = nn.Linear(nh, nz)
        self.log_var = nn.Linear(nh, nz)

        # Decoder: latent -> hidden -> reconstruction.
        self.fc21 = nn.Linear(nz, nh)
        self.fc22 = nn.Linear(nh, nx)

    def encoder(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for ``x``."""
        hidden = torch.relu(self.fc11(x))
        mean = self.mu(hidden)
        log_variance = self.log_var(hidden)
        return mean, log_variance

    def decoder(self, z):
        """Map latent codes ``z`` to reconstructions with values in (0, 1)."""
        hidden = torch.relu(self.fc21(z))
        logits = self.fc22(hidden)
        return torch.sigmoid(logits)

    def sample_z(self, mu, log_var):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        Writing the sample as a deterministic function of ``mu`` and
        ``log_var`` plus external noise keeps it differentiable with respect
        to both (the reparameterization trick).
        """
        std = torch.exp(0.5 * log_var)
        noise = torch.randn_like(std)
        return noise * std + mu

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encoder(x)
        reconstruction = self.decoder(self.sample_z(mu, log_var))
        return reconstruction, mu, log_var
    
# 784-pixel input, 512 hidden units, 2-D latent space.
# Module.to() returns the module itself, so construction and placement chain.
vae = VAE(nx=784, nh=512, nz=2).to(device)
vae

VAE(
  (fc11): Linear(in_features=784, out_features=512, bias=True)
  (mu): Linear(in_features=512, out_features=2, bias=True)
  (log_var): Linear(in_features=512, out_features=2, bias=True)
  (fc21): Linear(in_features=2, out_features=512, bias=True)
  (fc22): Linear(in_features=512, out_features=784, bias=True)
)

\begin{align}
\mathcal{D}_{KL} [ \mathcal{N} (\mu, \mathrm{diag}(\sigma^2)) || \mathcal{N} (0, I)] = \frac{1}{2} \sum_k \left( \exp(\log \sigma_k^2) + \mu_k^2 - 1 - \log \sigma_k^2 \right)
\end{align}

#### Non-negative Kullback-Leibler divergence

\begin{align}
\mathcal{D}_{KL} &= \mathbb{E}_{P} \log \frac{P}{Q} \\
&= \mathbb{E}_{P} \left(- \log \frac{Q}{P} \right) \\
&\stackrel{\text{Jensen's inequality}}{\ge} -\log \mathbb{E}_{P} \frac{Q}{P} \; \because -\log(x) \; \text{is a convex function}\\
&= - \log 1 \\
&= 0
\end{align}

In [4]:
# Adam over every VAE parameter; no hyperparameters are passed, so the library defaults apply.
optimizer = optim.Adam(vae.parameters())

def loss_function(recon_x, x, mu, log_var):
    """Negative evidence lower bound, summed over every element of the batch.

    Args:
        recon_x: decoder output with values in [0, 1].
        x: target values, same shape as ``recon_x``.
        mu: mean of the approximate posterior.
        log_var: log-variance of the approximate posterior.

    Returns:
        Scalar tensor: reconstruction cross-entropy plus the KL divergence
        between N(mu, diag(exp(log_var))) and N(0, I). Minimizing it
        maximizes the lower bound.
    """
    # Reconstruction term: -E_q[log p(x|z)] under a Bernoulli likelihood.
    reconstruction_nll = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # Closed-form KL for a diagonal Gaussian against the standard normal.
    kl_terms = torch.exp(log_var) + mu ** 2 - 1 - log_var
    kl_divergence = 0.5 * kl_terms.sum()

    return reconstruction_nll + kl_divergence

def train(epoch):
    """Run one training epoch of the module-level ``vae`` over ``train_loader``.

    For every mini-batch the model is evaluated, ``loss_function`` is
    back-propagated and ``optimizer`` takes one step. Progress is printed
    every 100 batches, followed by the per-sample average loss for the epoch.

    Args:
        epoch: epoch index, used only in the printed messages.
    """
    vae.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        # Move the batch to the same device the model was placed on. The
        # previous hard-coded .cuda() raised on machines without a GPU even
        # though `device` already falls back to the CPU.
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 100 == 0:
            # The loss is summed over the batch, so divide by the batch
            # length to report a per-sample value.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

In [5]:
# Train for five epochs; train() prints its own progress and the per-epoch average loss.
for epoch in range(5):
    train(epoch)



====> Epoch: 0 Average loss: 186.3493
====> Epoch: 1 Average loss: 165.4888
====> Epoch: 2 Average loss: 161.6146
====> Epoch: 3 Average loss: 159.3400
====> Epoch: 4 Average loss: 157.5723


## Reference

1. C. Doersch, Tutorial on variational autoencoders, arXiv:1606.05908v2, 2016
2. R.G. Krishnan, U. Shalit, D. Sontag, Deep Kalman Filters, arXiv:1511.05121v2, 2015
3. J. Duchi, [Derivations for linear algebra and optimization](http://web.stanford.edu/~jduchi/projects/general_notes.pdf)
4. A. Kristiadi, [Variational autoencoder: intuition and implementation](https://wiseodd.github.io/techblog/2016/12/10/variational-autoencoder/), blog post 2016
5. https://github.com/lyeoni/pytorch-mnist-VAE
6. [mxnet variational autoencoder example](https://github.com/apache/incubator-mxnet/tree/master/example/autoencoder/variational_autoencoder)