#### References:
- https://github.com/pytorch/examples/blob/master/vae/main.py
- https://github.com/wiseodd/generative-models/blob/master/VAE/vanilla_vae/vae_pytorch.py

In [16]:
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.init
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

## Config

In [33]:
class BaseConfig():
    """Hyper-parameters and run settings shared by the notebook cells."""

    def __init__(self):
        # Training / runtime settings.
        self.batch_size = 64
        self.epochs = 10
        self.cuda = False
        self.seed = 1
        self.log_interval = 100

        # Layer widths read by VAE.__init__ (config.h_size / config.z_size).
        # NOTE(review): the original config did not define these, so building
        # the model raised AttributeError. The values below are an assumption
        # (they mirror the referenced PyTorch VAE example) -- confirm.
        self.h_size = 400
        self.z_size = 20

# Single configuration instance read by the cells below.
config = BaseConfig()

# Seed the random number generators from the config.
# (The original read `args.seed`, but no `args` object is defined; only
# `config` exists at this point.)
torch.manual_seed(config.seed)
if config.cuda:
    torch.cuda.manual_seed(config.seed)

## Dataset

In [34]:
# MNIST train / test loaders. Images go through transforms.ToTensor(); the
# dataset must already exist on disk because download=False.
# (The original read `args.batch_size`, but only `config` is defined.)
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../../../datasets/', train=True, download=False, transform=transforms.ToTensor()),
    batch_size=config.batch_size,
    shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../../../datasets', train=False, download=False, transform=transforms.ToTensor()),
    batch_size=config.batch_size,
    shuffle=True)

In [35]:
import torch
import torch.nn as nn
import torch.nn.init
import torch.optim as optim
from torch.autograd import Variable

# Model
class VAE(nn.Module):
    """Fully-connected variational autoencoder over flattened 28x28 images.

    The encoder maps a 784-dim input to the mean and log-variance of the
    latent code; the decoder maps a latent sample back to 784 sigmoid outputs.

    Args:
        config: object exposing ``h_size`` (hidden layer width) and
            ``z_size`` (latent width). It is kept on ``self.config``.
    """

    def __init__(self, config):
        super(VAE, self).__init__()
        self.config = config

        # P(Z|X): encoder
        self.fc_xh = nn.Linear(784, config.h_size)
        self.fc_hz_mu = nn.Linear(config.h_size, config.z_size)
        self.fc_hz_var = nn.Linear(config.h_size, config.z_size)

        # P(X|Z): decoder
        self.fc_zh = nn.Linear(config.z_size, config.h_size)
        self.fc_hx = nn.Linear(config.h_size, 784)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def initialize(self):
        """Apply Xavier-normal initialization to the weight of every linear layer.

        Biases are left untouched.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                # `xavier_normal_` is the in-place successor of the
                # deprecated `xavier_normal`.
                torch.nn.init.xavier_normal_(module.weight)

    def encode(self, x):
        """
        Args:
            x [batch_size, 28*28]
        Return:
            mu [batch_size, z_size]
            log_variance [batch_size, z_size]
        """
        h = self.relu(self.fc_xh(x))

        mu = self.fc_hz_mu(h)
        log_variance = self.fc_hz_var(h)

        return mu, log_variance

    def reparameterize(self, mu, log_variance):
        """Sample z = mu + eps * std with eps drawn from a standard normal.

        Args:
            mu [batch_size, z_size]
            log_variance [batch_size, z_size]
        Return:
            z [batch_size, z_size]
        """
        std = torch.exp(0.5 * log_variance)

        # Draw the noise with the same shape, dtype and device as `std`, so
        # sampling follows wherever the inputs live instead of relying on the
        # `config.cuda` flag matching the actual device of the model.
        eps = torch.randn_like(std)

        return mu + eps * std

    def decode(self, z):
        """Reconstruct X with P(X|Z).

        Args:
            z [batch_size, z_size]
        Return:
            x_recon [batch_size, 28*28], each value passed through a sigmoid
        """
        h = self.relu(self.fc_zh(z))
        return self.sigmoid(self.fc_hx(h))

    def forward(self, x):
        """Encode, sample and decode a batch.

        Args:
            x: tensor that can be viewed as [batch_size, 784]
        Return:
            x_recon [batch_size, 784], mu [batch_size, z_size],
            log_variance [batch_size, z_size]
        """
        # Encode X => Z
        mu, log_variance = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, log_variance)

        # Reconstruct Z => X
        x_recon = self.decode(z)
        return x_recon, mu, log_variance


In [36]:
# Build the model from the shared config (VAE.__init__ requires it) and move
# it to the GPU when requested. The original called `VAE()` with no argument
# and read an undefined `args.cuda`.
model = VAE(config)
if config.cuda:
    model.cuda()
model.initialize()

# Binary cross-entropy summed over all elements. The reduction is chosen at
# construction time; assigning `size_average` afterwards is not honoured by
# current PyTorch releases, which would leave the loss averaged.
binary_xent = nn.BCELoss(reduction='sum')


def loss_function(recon_images, images, mu, log_variance):
    """Return the summed VAE loss: reconstruction BCE plus KL divergence.

    NOTE(review): `train_one_epoch` calls `loss_function`, but no definition
    is visible in this file. This is the standard VAE objective used by the
    referenced PyTorch example -- drop it if a definition exists elsewhere.

    Args:
        recon_images [batch_size, 784]: decoder output.
        images: input batch, viewable as [batch_size, 784].
        mu [batch_size, z_size]
        log_variance [batch_size, z_size]
    Return:
        Scalar tensor, summed over the batch.
    """
    recon_loss = binary_xent(recon_images, images.view(-1, 784))
    # KL(q(z|x) || N(0, I)) in closed form for a diagonal Gaussian.
    kl_divergence = -0.5 * torch.sum(1 + log_variance - mu.pow(2) - log_variance.exp())
    return recon_loss + kl_divergence


optimizer = optim.Adam(model.parameters(), lr=1e-3)

## Train

In [46]:
def train_one_epoch(epoch):
    """Train the global ``model`` for one pass over ``train_loader``.

    Prints the per-sample loss every ``config.log_interval`` batches and the
    epoch's average per-sample loss, then saves the last batch's
    reconstructions (and, on epoch 1, the matching inputs) under ``./data``.

    Args:
        epoch: epoch number used in log lines and output file names; the
            input images are saved only when it equals 1.
    """
    import os

    model.train()  # set to training mode
    train_loss = 0
    for batch_idx, (images, _) in enumerate(train_loader):
        # images: [batch_size, 1, 28, 28]
        if config.cuda:
            images = images.cuda()
        optimizer.zero_grad()

        recon_images, mu, log_variance = model(images)
        loss = loss_function(recon_images, images, mu, log_variance)

        loss.backward()
        # `.item()` extracts the Python float from the scalar loss; indexing
        # a 0-dim tensor with `[0]` raises on current PyTorch.
        batch_loss = loss.item()
        train_loss += batch_loss
        optimizer.step()
        if batch_idx > 0 and batch_idx % config.log_interval == 0:
            log_string = f'Epoch {epoch} | '
            log_string += f'[{batch_idx*len(images)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\t'
            log_string += f'Loss: {batch_loss / len(images):.4f}'
            print(log_string)

    print(f'Epoch {epoch} | Average Loss: {train_loss / len(train_loader.dataset):.4f}\n')

    # Make sure the output directory exists before writing any image.
    os.makedirs('./data', exist_ok=True)

    # Save original images (last batch of the epoch)
    if epoch == 1:
        images = images.view(images.size(0), 1, 28, 28)
        save_image(images.data, './data/real_images.png')

    # Save reconstructed images (last batch of the epoch)
    recon_images = recon_images.view(recon_images.size(0), 1, 28, 28)
    save_image(recon_images.data, './data/recon_images-%d.png' % (epoch))

# Run two training epochs, numbered starting from 1.
for epoch_number in range(1, 3):
    train_one_epoch(epoch_number)

## Test

In [None]:
def test()