# 2-Stage VAE

Based on [Diagnosing and Enhancing VAE Models](https://arxiv.org/abs/1903.05789) by Dai and Wipf.


In [8]:
import json
import os, glob

from tqdm import tqdm_notebook as tqdm

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfs = tf.contrib.summary
tfe = tf.contrib.eager

import sonnet as snt
import numpy as np
from functools import reduce

from utils import is_valid_file, setup_eager_checkpoints_and_restore

tf.enable_eager_execution()

In [2]:
def mnist_input_fn(data, batch_size=256, shuffle_samples=5000):
    dataset = tf.data.Dataset.from_tensor_slices(data)
    dataset = dataset.shuffle(shuffle_samples)
    dataset = dataset.map(mnist_parse_fn)
    dataset = dataset.batch(batch_size)

    return dataset


def mnist_parse_fn(data):
    return tf.cast(data, tf.float32) / 255.


optimizers = {
    "sgd": tf.train.GradientDescentOptimizer,
    "momentum": lambda lr:
                    tf.train.MomentumOptimizer(learning_rate=lr,
                                               momentum=0.9,
                                               use_nesterov=True),
    "adam": tf.train.AdamOptimizer,
    "rmsprop": tf.train.RMSPropOptimizer
}

In [13]:
class MnistVAE(snt.AbstractModule):

    def __init__(self, latent_dim=2, name="mnist_vae"):

        super(MnistVAE, self).__init__(name=name)

        self.latent_dim = latent_dim
        self.is_training = True


    @property
    def kl_divergence(self):
        """
        Calculates the KL divergence between the current variational posterior and the prior:

        KL[ q(z | theta) || p(z) ]

        """
        self._ensure_is_connected()
        
        return tfd.kl_divergence(self.latent_posterior, self.latent_prior)

    @property
    def input_log_prob(self):
        """
        Returns the log-likelihood of the current input for the output Bernoulli
        """
        return tf.reduce_sum(self.log_prob)


    @snt.reuse_variables
    def encode(self, inputs):
        """
        Builds the encoder part of the VAE, i.e. q(x | theta).
        This maps from the input to the latent representation.
        """

        # ----------------------------------------------------------------
        # Define Layers
        # ----------------------------------------------------------------
        self.encoder_layers = [
            snt.Conv2D(output_channels=64,
                       kernel_shape=(5, 5),
                       stride=2),
            tf.nn.leaky_relu,
            snt.Conv2D(output_channels=128,
                       kernel_shape=(5, 5),
                       stride=2,
                       use_bias=False),
            snt.BatchNorm(),
            tf.nn.leaky_relu,
            snt.BatchFlatten(),
            snt.Linear(output_size=1024,
                       use_bias=False),
            snt.BatchNorm(),
            tf.nn.leaky_relu
        ]
        
        self.encoder_loc_head = snt.Linear(output_size=self.latent_dim)
        self.encoder_log_scale_head = snt.Linear(output_size=self.latent_dim)
        
        # ----------------------------------------------------------------
        # Apply Layers
        # ----------------------------------------------------------------

        activations = inputs
        
        for layer in self.encoder_layers:
            if isinstance(layer, snt.BatchNorm):
                activations = layer(activations, is_training=self.is_training)
            else:
                activations = layer(activations)
            
        # Get latent posterior statistics
        loc = self.encoder_loc_head(activations)
        scale = 1e-6 + tf.nn.softplus(self.encoder_log_scale_head(activations))
        
        # Create latent posterior
        self.latent_posterior = tfd.Normal(loc=loc, scale=scale)
        
        return self.latent_posterior.sample()


    @snt.reuse_variables
    def decode(self, latent_code):
        """
        Builds the decoder part of the VAE
        """
        # ----------------------------------------------------------------
        # Define Layers
        # ----------------------------------------------------------------
        
        self.decoder_levels = [self.encoder_loc_head.transpose()]
        
        for layer in self.encoder_layers[::-1]:
            
            if isinstance(layer, snt.Transposable):
                self.decoder_levels.append(layer.transpose())
                
            else:
                if isinstance(layer, snt.BatchFlatten):
                    self.decoder_levels.append(BatchReshape((28, 28, 1)))
        
        # ----------------------------------------------------------------
        # Apply Layers
        # ----------------------------------------------------------------

        # Create prior
        self.latent_prior = tfd.Normal(loc=tf.zeros_like(latent_code),
                                       scale=tf.ones_like(latent_code))
        
        activations = latent_code
        
        for layer in self.decoder_levels:
            activations = layer(activations)
        
        return tf.nn.sigmoid(activations)


    def _build(self, inputs):
        """
        Build standard VAE:
        1. Encode input -> latent mu, sigma
        2. Sample z ~ N(z | mu, sigma)
        """
        
        reshaper = snt.BatchReshape((28, 28, 1))
        inputs = reshaper(inputs)
        
        latents = self.encode(inputs)
        reconstruction = self.decode(latents)
        
        self.log_gamma = tf.get_variable("log_gamma_x", [], tf.float32, tf.zeros_initializer())
        self.likelihood = tfd.Normal(loc=reconstruction,
                                     scale=tf.exp(self.log_gamma))
        self.log_prob = self.likelihood.log_prob(inputs)
        
        return reconstruction

In [14]:
def run(config, model_dir, is_training):

    num_batches = config["num_training_examples"] // config["batch_size"] + 1
  
    print("Configuration:")
    print(json.dumps(config, indent=4, sort_keys=True))

    # ==========================================================================
    # Load dataset
    # ==========================================================================

    ((train_data, _),
     (eval_data, _)) = tf.keras.datasets.mnist.load_data()

    # ==========================================================================
    # Create model
    # ==========================================================================

    g = tf.get_default_graph()
    
    with g.as_default():
    
        vae = MnistVAE()
        vae(tf.zeros((1, 28, 28)))
        
        del vae
        
    vae = MnistVAE()
    vae(tf.zeros((1, 28, 28)))

    optimizer = optimizers[config["optimizer"]](config["learning_rate"])

    # ==========================================================================
    # Define Checkpoints
    # ==========================================================================

    global_step = tf.train.get_or_create_global_step()

    trainable_vars = vae.get_all_variables() + (global_step,)
    checkpoint_dir = os.path.join(model_dir, "checkpoints")

    checkpoint, ckpt_prefix = setup_eager_checkpoints_and_restore(
        variables=trainable_vars,
        checkpoint_dir=checkpoint_dir,
        checkpoint_name=config["checkpoint_name"])

    # ==========================================================================
    # Define Tensorboard Summary writer
    # ==========================================================================

    logdir = os.path.join(model_dir, "log")
    writer = tfs.create_file_writer(logdir)
    writer.set_as_default()

    tfs.graph(g)
    tfs.flush(writer)

    # ==========================================================================
    # Train the model
    # ==========================================================================

    beta = config["beta"]

    if is_training:
        for epoch in range(1, config["num_epochs"] + 1):

            dataset = mnist_input_fn(data=train_data,
                                    batch_size=config["batch_size"])

            with tqdm(total=num_batches) as pbar:
                for batch in dataset:
                    # Increment global step
                    global_step.assign_add(1)

                    # Record gradients of the forward pass
                    with tf.GradientTape() as tape, tfs.record_summaries_every_n_global_steps(config["log_freq"]):

                        output = vae(batch)

                        kl = vae.kl_divergence 
                        total_kl = tf.reduce_sum(kl)
                        
                        log_prob = vae.input_log_prob

                        warmup_coef = tf.minimum(1., global_step.numpy() / (config["warmup"] * num_batches))

                        # negative ELBO
                        loss = total_kl - beta * warmup_coef * log_prob 

                        output = tf.cast(tf.expand_dims(output, axis=-1), tf.float32)

                        # Add tensorboard summaries
                        tfs.scalar("Loss", loss)
                        tfs.scalar("Total_KL", kl)
                        tfs.scalar("Max_KL", tf.reduce_max(kl))
                        tfs.scalar("Log-Probability", log_prob)
                        tfs.scalar("Warmup_Coef", warmup_coef)
                        tfs.scalar("Gamma-x", tf.exp(vae.log_gamma))
                        tfs.image("Reconstruction", output)

                    # Backprop
                    grads = tape.gradient(loss, vae.get_all_variables())
                    optimizer.apply_gradients(zip(grads, vae.get_all_variables()))

                    # Update the progress bar
                    pbar.update(1)
                    pbar.set_description("Epoch {}, ELBO: {:.2f}".format(epoch, loss))

            checkpoint.save(ckpt_prefix)

    return vae

In [15]:
MODEL_DIR = "/tmp/2-stage-vae/"

config = {
        "num_training_examples": 60000,
        "batch_size": 250,
        "num_epochs": 40,
        
        "beta": 1.,
        "warmup": 20.,
        
        "learning_rate": 1e-3,
        
        "optimizer": "adam",
        
        "checkpoint_name": "_ckpt",
        "log_freq": 100,
    }

vae = run(config, model_dir=MODEL_DIR, is_training=True)

Configuration:
{
    "batch_size": 250,
    "beta": 1.0,
    "checkpoint_name": "_ckpt",
    "learning_rate": 0.001,
    "log_freq": 100,
    "num_epochs": 40,
    "num_training_examples": 60000,
    "optimizer": "adam",
    "warmup": 20.0
}
No checkpoint found!


HBox(children=(IntProgress(value=0, max=241), HTML(value='')))




KeyboardInterrupt: 