# 2-Stage VAE

Based on [Diagnosing and Enhancing VAE Models](https://arxiv.org/abs/1903.05789) by Dai and Wipf.


In [1]:
import json
import os, glob

from tqdm import tqdm_notebook as tqdm

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfs = tf.contrib.summary
tfe = tf.contrib.eager

import sonnet as snt
import numpy as np
from functools import reduce

from utils import is_valid_file, setup_eager_checkpoints_and_restore

tf.enable_eager_execution()

W0705 16:51:52.187434 140449970063104 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

W0705 16:51:52.202286 140449970063104 deprecation_wrapper.py:119] From /homes/gf332/Documents/projects/VAEs/vae_venv/lib/python3.7/site-packages/sonnet/python/custom_getters/restore_initializer.py:27: The name tf.GraphKeys is deprecated. Please use tf.compat.v1.GraphKeys instead.



In [2]:
def mnist_input_fn(data, batch_size=256, shuffle_samples=5000):
    """
    Build the tf.data input pipeline for a set of MNIST images.

    The images in `data` are sliced along their first axis, shuffled with a
    buffer of `shuffle_samples` elements, converted by `mnist_parse_fn` and
    finally grouped into batches of `batch_size`.

    Returns the resulting `tf.data.Dataset`.
    """
    return (tf.data.Dataset
            .from_tensor_slices(data)
            .shuffle(shuffle_samples)
            .map(mnist_parse_fn)
            .batch(batch_size))


def mnist_parse_fn(data):
    """Cast `data` to float32 and divide by 255."""
    as_float = tf.cast(data, tf.float32)
    return as_float / 255.


def _nesterov_momentum(lr):
    """Create a momentum optimizer (momentum 0.9, Nesterov updates) for learning rate `lr`."""
    return tf.train.MomentumOptimizer(learning_rate=lr,
                                      momentum=0.9,
                                      use_nesterov=True)


# Registry mapping an optimizer name to a callable that takes the learning rate
# and returns the optimizer instance.
optimizers = dict(
    sgd=tf.train.GradientDescentOptimizer,
    momentum=_nesterov_momentum,
    adam=tf.train.AdamOptimizer,
    rmsprop=tf.train.RMSPropOptimizer,
)

In [3]:
class MnistTwoStageVAE(snt.AbstractModule):
    """
    Two-stage VAE for 28x28 MNIST images.

    The first stage is a convolutional VAE over images with a learned scalar
    likelihood scale exp(log_gamma_x). The second stage is a small fully
    connected VAE over the first-stage latent codes with a learned scalar
    likelihood scale exp(log_gamma_z).

    Behaviour is controlled through plain attributes:
      * `is_training`      - passed (together with `use_second_stage`) to the
                             BatchNorm layers of the first stage.
      * `use_second_stage` - when True, the first-stage latents are routed
                             through the second-stage VAE before decoding and
                             the first-stage BatchNorm layers run in
                             non-training mode.
      * `first_run`        - True until the first call; forces the second stage
                             to be built once so that all variables exist.
    """

    def __init__(self,
                 latent_dim=32,
                 second_stage_depth=3,
                 name="mnist_vae"):
        """
        Args:
            latent_dim: size of the latent code of both stages.
            second_stage_depth: number of Linear + ReLU pairs in the second
                stage encoder and decoder.
            name: Sonnet module name.
        """

        super(MnistTwoStageVAE, self).__init__(name=name)

        self.latent_dim = latent_dim
        self.second_stage_depth = second_stage_depth

        self.is_training = True
        self.use_second_stage = False

        self.first_run = True


    @property
    def kl_first_stage(self):
        """
        Elementwise KL divergence between the first-stage variational posterior
        and the first-stage prior created on the last call:

        KL[ q(z | x) || p(z) ]

        """
        self._ensure_is_connected()

        return tfd.kl_divergence(self.latent_posterior, self.latent_prior)


    @property
    def kl_second_stage(self):
        """
        Elementwise KL divergence between the second-stage variational posterior
        and the second-stage prior created on the last call that used the
        second stage.
        """
        self._ensure_is_connected()

        return tfd.kl_divergence(self.second_stage_posterior, self.second_stage_prior)


    @property
    def log_prob_first_stage(self):
        """Sum over all elements of the image log-likelihood from the last call."""
        return tf.reduce_sum(self._log_prob)


    @property
    def log_prob_second_stage(self):
        """
        Sum over all elements of the latent log-likelihood from the last call
        that used the second stage.
        """
        return tf.reduce_sum(self._log_prob_second_stage)

    @snt.reuse_variables
    def get_first_stage_variables(self):
        """
        Collect the variables of the first stage: encoder layers, both encoder
        heads, decoder layers and `log_gamma_x`. Returns them as a tuple.

        Relies on the layer lists created by a previous call of the module.
        """

        all_variables = ()

        # Add all variables from the encoder
        for layer in self.encoder_layers:

            if isinstance(layer, snt.AbstractModule):
                all_variables += layer.get_all_variables()

        all_variables += self.encoder_loc_head.get_all_variables()
        all_variables += self.encoder_log_scale_head.get_all_variables()

        # Add all variables from the decoder
        for layer in self.decoder_levels:

            if isinstance(layer, snt.AbstractModule):
                all_variables += layer.get_all_variables()

        # Add gamma
        all_variables += (tf.get_variable("log_gamma_x", [], tf.float32, tf.zeros_initializer()),)

        return all_variables


    @snt.reuse_variables
    def get_second_stage_variables(self):
        """
        Collect the variables of the second stage: encoder layers, both encoder
        heads, decoder layers, the residual head and `log_gamma_z`. Returns
        them as a tuple.

        Relies on the layer lists created by a previous call of the module that
        built the second stage.
        """

        all_variables = ()

        # Add all variables from the encoder
        for layer in self.second_stage_encoder_layers:

            if isinstance(layer, snt.AbstractModule):
                all_variables += layer.get_all_variables()

        all_variables += self.second_stage_loc_head.get_all_variables()
        all_variables += self.second_stage_log_scale_head.get_all_variables()

        # Add all variables from the decoder
        for layer in self.second_stage_decoder_layers:

            if isinstance(layer, snt.AbstractModule):
                all_variables += layer.get_all_variables()

        all_variables += self.second_stage_residual_head.get_all_variables()

        # Add gamma
        all_variables += (tf.get_variable("log_gamma_z", [], tf.float32, tf.zeros_initializer()),)

        return all_variables

    # =====================================================================
    # First stage
    # =====================================================================

    @snt.reuse_variables
    def encode_first_stage(self, inputs):
        """
        Builds the encoder part of the VAE, i.e. q(z | x).
        This maps from the input to the latent representation.

        Stores the posterior in `self.latent_posterior` and returns one sample
        from it.
        """

        # ----------------------------------------------------------------
        # Define Layers
        # ----------------------------------------------------------------
        self.encoder_layers = [
            snt.Conv2D(output_channels=64,
                       kernel_shape=(5, 5),
                       stride=2,
                       name="encoder_conv1"),
            tf.nn.leaky_relu,
            snt.Conv2D(output_channels=128,
                       kernel_shape=(5, 5),
                       stride=2,
                       use_bias=False,
                       name="encoder_conv2"),
            snt.BatchNorm(), #tf.keras.layers.BatchNormalization(),
            tf.nn.leaky_relu,
            snt.BatchFlatten(),
            snt.Linear(output_size=1024,
                       use_bias=False,
                       name="encoder_linear1"),
            snt.BatchNorm(), #tf.keras.layers.BatchNormalization(),
            tf.nn.leaky_relu
        ]

        self.encoder_loc_head = snt.Linear(output_size=self.latent_dim,
                                           name="encoder_loc_head")
        self.encoder_log_scale_head = snt.Linear(output_size=self.latent_dim,
                                                 name="encoder_log_scale_head")

        # ----------------------------------------------------------------
        # Apply Layers
        # ----------------------------------------------------------------

        activations = inputs

        for layer in self.encoder_layers:
            if isinstance(layer, snt.BatchNorm): #tf.keras.layers.BatchNormalization):
                # Batch statistics are only updated while the first stage is being trained
                activations = layer(activations, is_training=self.is_training and not self.use_second_stage)
            else:
                activations = layer(activations)

        # Get latent posterior statistics
        loc = self.encoder_loc_head(activations)
        # softplus keeps the scale positive; the small offset keeps it away from 0
        scale = 1e-6 + tf.nn.softplus(self.encoder_log_scale_head(activations))

        # Create latent posterior
        self.latent_posterior = tfd.Normal(loc=loc, scale=scale)

        return self.latent_posterior.sample()


    @snt.reuse_variables
    def decode_first_stage(self, latent_code):
        """
        Builds the decoder part of the VAE by reversing the encoder: the
        transposed location head followed by the encoder layers in reverse
        order, with every transposable layer replaced by its transpose.
        Non-transposable entries (activation functions and the BatchNorm
        instances) are reused as they are.

        Also creates the standard normal prior `self.latent_prior` with the
        shape of `latent_code`. Returns the sigmoid of the final activations.
        """
        # ----------------------------------------------------------------
        # Define Layers
        # ----------------------------------------------------------------

        self.decoder_levels = [self.encoder_loc_head.transpose()]

        for layer in self.encoder_layers[::-1]:

            # Some layers need care for reversing
            if isinstance(layer, snt.Transposable):
                layer = layer.transpose()

            elif isinstance(layer, snt.BatchFlatten):
                # NOTE(review): snt.BatchFlatten appears to be Transposable in
                # Sonnet v1, so this branch is presumably never taken - TODO confirm.
                # The name is now qualified so that it cannot raise a NameError.
                layer = snt.BatchReshape((28, 28, 1))

            # Add layer
            self.decoder_levels.append(layer)


        # ----------------------------------------------------------------
        # Apply Layers
        # ----------------------------------------------------------------

        # Create prior
        self.latent_prior = tfd.Normal(loc=tf.zeros_like(latent_code),
                                       scale=tf.ones_like(latent_code))

        activations = latent_code

        for layer in self.decoder_levels:
            if isinstance(layer, snt.BatchNorm): #tf.keras.layers.BatchNormalization):
                activations = layer(activations, is_training=self.is_training and not self.use_second_stage)
            else:
                activations = layer(activations)

        return tf.nn.sigmoid(activations)

    # =====================================================================
    # Second stage
    # =====================================================================

    @snt.reuse_variables
    def encode_second_stage(self, inputs):
        """
        Builds the second-stage encoder, mapping first-stage latents `inputs`
        to the second-stage posterior.

        Stores the posterior in `self.second_stage_posterior` and returns one
        sample from it.
        """
        # ----------------------------------------------------------------
        # Define Layers
        # ----------------------------------------------------------------

        self.second_stage_encoder_layers = []

        for i in range(self.second_stage_depth):

            self.second_stage_encoder_layers.append(snt.Linear(output_size=self.latent_dim))
            self.second_stage_encoder_layers.append(tf.nn.relu)

        self.second_stage_loc_head = snt.Linear(output_size=self.latent_dim)
        self.second_stage_log_scale_head = snt.Linear(output_size=self.latent_dim)

        # ----------------------------------------------------------------
        # Apply Layers
        # ----------------------------------------------------------------

        activations = inputs

        for layer in self.second_stage_encoder_layers:
            activations = layer(activations)

        # Add residual connection (implemented as a concatenation with the input)
        activations = tf.concat([inputs, activations], axis=-1)

        # Get second stage latent statistics
        loc = self.second_stage_loc_head(activations)
        log_scale = self.second_stage_log_scale_head(activations)
        scale = 1e-6 + tf.nn.softplus(log_scale)

        # Create second stage distribution
        self.second_stage_posterior = tfd.Normal(loc=loc,
                                                 scale=scale)

        return self.second_stage_posterior.sample()


    @snt.reuse_variables
    def decode_second_stage(self, latents):
        """
        Builds the second-stage decoder, mapping second-stage latents back to
        the space of first-stage latents.

        Also creates the standard normal prior `self.second_stage_prior` with
        the shape of `latents`. Returns the output of the residual head, which
        `_build` uses as the location of the second-stage likelihood.
        """
        # ----------------------------------------------------------------
        # Define Layers
        # ----------------------------------------------------------------

        self.second_stage_decoder_layers = []

        for i in range(self.second_stage_depth):

            self.second_stage_decoder_layers.append(snt.Linear(output_size=self.latent_dim))
            self.second_stage_decoder_layers.append(tf.nn.relu)

        self.second_stage_residual_head = snt.Linear(output_size=self.latent_dim)

        # ----------------------------------------------------------------
        # Apply Layers
        # ----------------------------------------------------------------

        self.second_stage_prior = tfd.Normal(loc=tf.zeros_like(latents),
                                             scale=tf.ones_like(latents))

        activations = latents

        for layer in self.second_stage_decoder_layers:
            activations = layer(activations)

        # Add residual connection (implemented as a concatenation with the input)
        activations = tf.concat([latents, activations], axis=-1)

        likelihood_loc = self.second_stage_residual_head(activations)

        return likelihood_loc


    # =====================================================================
    # Build
    # =====================================================================

    def _build(self, inputs):
        """
        Build the two-stage VAE:
        1. Reshape the input to (28, 28, 1) images and encode them into a
           sample of the first-stage posterior.
        2. If `use_second_stage` is set (or on the very first call, so that
           the second-stage variables get created), pass the latents through
           the second-stage VAE, record the second-stage log-likelihood and
           continue with the second-stage reconstruction of the latents.
        3. Decode the latents into an image and record the image
           log-likelihood under a Normal with scale exp(log_gamma_x).

        Returns the reconstructed image.
        """

        reshaper = snt.BatchReshape((28, 28, 1))
        inputs = reshaper(inputs)

        # Code the latents on the first stage
        latents = self.encode_first_stage(inputs)

        # If the first stage is trained, train the second stage
        if self.use_second_stage or self.first_run:

            second_stage_latents = self.encode_second_stage(latents)
            latents_ = self.decode_second_stage(second_stage_latents)

            self.latent_log_gamma = tf.get_variable("log_gamma_z", [], tf.float32, tf.zeros_initializer())
            self.latent_gamma = tf.exp(self.latent_log_gamma)

            # Create likelihood distribution centred on the second-stage
            # reconstruction of the latents. Centring it on `latents` itself
            # would make the log-probability below independent of the
            # second-stage networks.
            self.likelihood_second_stage = tfd.Normal(loc=latents_,
                                                      scale=self.latent_gamma)

            self._log_prob_second_stage = self.likelihood_second_stage.log_prob(latents)

            latents = latents_

            if self.first_run:
                self.first_run = False

        # Reconstruct image from the latents
        reconstruction = self.decode_first_stage(latents)

        self.log_gamma = tf.get_variable("log_gamma_x", [], tf.float32, tf.zeros_initializer())
        self.likelihood = tfd.Normal(loc=reconstruction,
                                     scale=tf.exp(self.log_gamma))
        self._log_prob = self.likelihood.log_prob(inputs)

        return reconstruction

In [4]:
# Smoke test: connect the model, touch the first-stage quantities, switch to
# the second stage and verify that the per-stage variable getters together
# cover exactly the variables of the whole module.
mvae = MnistTwoStageVAE()

test_ones = tf.ones((1, 28, 28))

# Called twice: the first call also builds the second stage, the second does not
mvae(test_ones)
mvae(test_ones)

mvae.kl_first_stage
mvae.log_prob_first_stage

mvae.use_second_stage = True

mvae(test_ones)

fsv = mvae.get_first_stage_variables()
ssv = mvae.get_second_stage_variables()

var_names = {variable.name for variable in fsv} | {variable.name for variable in ssv}
all_var_names = {variable.name for variable in mvae.get_all_variables()}

# Both differences should be empty if every variable is captured
print(var_names - all_var_names)
print(all_var_names - var_names)

W0705 16:51:52.350741 140449970063104 deprecation_wrapper.py:119] From /homes/gf332/Documents/projects/VAEs/vae_venv/lib/python3.7/site-packages/sonnet/python/modules/base.py:177: The name tf.make_template is deprecated. Please use tf.compat.v1.make_template instead.

W0705 16:51:53.129023 140449970063104 deprecation_wrapper.py:119] From /homes/gf332/Documents/projects/VAEs/vae_venv/lib/python3.7/site-packages/sonnet/python/modules/base.py:278: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0705 16:51:53.132635 140449970063104 deprecation_wrapper.py:119] From /homes/gf332/Documents/projects/VAEs/vae_venv/lib/python3.7/site-packages/sonnet/python/modules/base.py:579: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.

W0705 16:51:53.136412 140449970063104 deprecation_wrapper.py:119] From /homes/gf332/Documents/projects/VAEs/vae_venv/lib/python3.7/site-packages/sonnet/python/modules/conv.py:134: The

set()
set()


In [5]:
def run(config, model_dir, is_training, train_first_stage=True):
    """
    Create a `MnistTwoStageVAE`, restore it from a checkpoint if one exists and
    optionally train it on MNIST.

    Args:
        config: dict with the keys "num_training_examples", "batch_size",
            "optimizer" (a key of `optimizers`), "learning_rate",
            "checkpoint_name", "log_freq", "warmup", "beta1", "num_epochs",
            "beta2" and "num_epochs_stage_2".
        model_dir: directory under which "checkpoints" and "log" are placed.
        is_training: when False, the model is only created and restored.
        train_first_stage: when True, the first stage is trained before the
            second stage; when False only the second stage is trained.

    Returns:
        The model instance.
    """

    num_batches = config["num_training_examples"] // config["batch_size"] + 1
  
    print("Configuration:")
    print(json.dumps(config, indent=4, sort_keys=True))

    # ==========================================================================
    # Load dataset
    # ==========================================================================

    ((train_data, _),
     (eval_data, _)) = tf.keras.datasets.mnist.load_data()

    # ==========================================================================
    # Create model
    # ==========================================================================

    g = tf.get_default_graph()
    
    with g.as_default():
    
        vae = MnistTwoStageVAE(latent_dim=64)
        vae(tf.zeros((1, 28, 28)))
        
        del vae
        
    vae = MnistTwoStageVAE(latent_dim=64)
    # Connect the module once so that all of its variables exist
    vae(tf.zeros((1, 28, 28)))
    
    vae.use_second_stage = not train_first_stage

    optimizer = optimizers[config["optimizer"]](config["learning_rate"])

    # ==========================================================================
    # Define Checkpoints
    # ==========================================================================

    global_step = tf.train.get_or_create_global_step()

    trainable_vars = vae.get_all_variables() + (global_step,)
    checkpoint_dir = os.path.join(model_dir, "checkpoints")

    checkpoint, ckpt_prefix = setup_eager_checkpoints_and_restore(
        variables=trainable_vars,
        checkpoint_dir=checkpoint_dir,
        checkpoint_name=config["checkpoint_name"])

    # ==========================================================================
    # Define Tensorboard Summary writer
    # ==========================================================================

    logdir = os.path.join(model_dir, "log")
    writer = tfs.create_file_writer(logdir)
    writer.set_as_default()

    tfs.graph(g)
    tfs.flush(writer)

    # ==========================================================================
    # Train the model
    # ==========================================================================

    if is_training:
        
        if train_first_stage:
            beta = config["beta1"]

            for epoch in range(1, config["num_epochs"] + 1):

                dataset = mnist_input_fn(data=train_data,
                                        batch_size=config["batch_size"])

                with tqdm(total=num_batches) as pbar:
                    for batch in dataset:
                        # Increment global step
                        global_step.assign_add(1)

                        # Record gradients of the forward pass
                        with tf.GradientTape() as tape, tfs.record_summaries_every_n_global_steps(config["log_freq"]):

                            output = vae(batch)

                            kl = vae.kl_first_stage
                            total_kl = tf.reduce_sum(kl)

                            log_prob = vae.log_prob_first_stage

                            # Ramps linearly from 0 to 1 over `warmup` epochs worth of steps
                            warmup_coef = tf.minimum(1., global_step.numpy() / (config["warmup"] * num_batches))

                            # negative ELBO
                            loss = total_kl - beta * warmup_coef * log_prob 

                            output = tf.cast(output, tf.float32)

                            # Add tensorboard summaries
                            tfs.scalar("Loss", loss)
                            # Log the summed KL, not the per-element tensor
                            tfs.scalar("Total_KL", total_kl)
                            tfs.scalar("Max_KL", tf.reduce_max(kl))
                            tfs.scalar("Log-Probability", log_prob)
                            tfs.scalar("Warmup_Coef", warmup_coef)
                            tfs.scalar("Gamma-x", tf.exp(vae.log_gamma))
                            tfs.image("Reconstruction", output)

                        # Backprop, only through the first-stage variables
                        first_stage_vars = vae.get_first_stage_variables()
                        grads = tape.gradient(loss, first_stage_vars)
                        optimizer.apply_gradients(zip(grads, first_stage_vars))

                        # Update the progress bar
                        pbar.update(1)
                        pbar.set_description("Epoch {}, ELBO: {:.2f}".format(epoch, loss))

                checkpoint.save(ckpt_prefix)

            tfs.flush(writer)
            print("First Stage Training Complete!")
        
        # From here on the first-stage latents are routed through the second stage
        vae.use_second_stage = True
        
        beta = config["beta2"]
        
        for epoch in range(1, config["num_epochs_stage_2"] + 1):

            dataset = mnist_input_fn(data=train_data,
                                     batch_size=config["batch_size"])

            with tqdm(total=num_batches) as pbar:
                for batch in dataset:
                    # Increment global step
                    global_step.assign_add(1)

                    # Record gradients of the forward pass
                    with tf.GradientTape() as tape, tfs.record_summaries_every_n_global_steps(config["log_freq"]):

                        output = vae(batch)

                        kl = vae.kl_second_stage
                        total_kl = tf.reduce_sum(kl)
                        
                        log_prob = vae.log_prob_second_stage

                        warmup_coef = tf.minimum(1., global_step.numpy() / (config["warmup"] * num_batches))

                        # negative ELBO
                        loss = total_kl - beta * warmup_coef * log_prob 

                        output = tf.cast(output, tf.float32)

                        # Add tensorboard summaries
                        tfs.scalar("Loss", loss)
                        # Log the summed KL, not the per-element tensor
                        tfs.scalar("Total_KL", total_kl)
                        tfs.scalar("Max_KL", tf.reduce_max(kl))
                        tfs.scalar("Log-Probability", log_prob)
                        tfs.scalar("Warmup_Coef", warmup_coef)
                        # `latent_gamma` is already exp(log_gamma_z); do not exponentiate again
                        tfs.scalar("Gamma-z", vae.latent_gamma)
                        tfs.image("Reconstruction", output)

                    # Backprop, only through the second-stage variables
                    second_stage_vars = vae.get_second_stage_variables()
                    grads = tape.gradient(loss, second_stage_vars)
                    optimizer.apply_gradients(zip(grads, second_stage_vars))

                    # Update the progress bar
                    pbar.update(1)
                    pbar.set_description("Epoch {}, ELBO: {:.2f}".format(epoch, loss))

            checkpoint.save(ckpt_prefix)
            
        tfs.flush(writer)
        print("Second Stage Training Complete!")

    return vae

In [6]:
# Directory that receives the "checkpoints" and "log" sub-directories
MODEL_DIR = "/tmp/2-stage-vae/"

config = dict(
    # Data
    num_training_examples=60000,
    batch_size=250,

    # Number of epochs for the first and the second stage
    num_epochs=5,
    num_epochs_stage_2=5,

    # Weights of the log-likelihood term in the loss of each stage
    beta1=0.1,
    beta2=0.1,
    # Number of epochs over which the warmup coefficient ramps up to 1
    warmup=10.,

    # Optimisation
    learning_rate=1e-3,
    optimizer="adam",

    # Checkpointing and summary frequency (in global steps)
    checkpoint_name="_ckpt",
    log_freq=100,
)

vae = run(config,
          model_dir=MODEL_DIR,
          is_training=True,
          train_first_stage=True)

Configuration:
{
    "batch_size": 250,
    "beta1": 0.1,
    "beta2": 0.1,
    "checkpoint_name": "_ckpt",
    "learning_rate": 0.001,
    "log_freq": 100,
    "num_epochs": 5,
    "num_epochs_stage_2": 5,
    "num_training_examples": 60000,
    "optimizer": "adam",
    "warmup": 10.0
}


W0705 16:52:17.449425 140449970063104 deprecation.py:506] From /homes/gf332/Documents/projects/VAEs/vae_venv/lib/python3.7/site-packages/sonnet/python/modules/conv.py:298: calling TruncatedNormal.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0705 16:52:17.456937 140449970063104 deprecation.py:506] From /homes/gf332/Documents/projects/VAEs/vae_venv/lib/python3.7/site-packages/sonnet/python/modules/conv.py:303: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


Checkpoint found at /tmp/2-stage-vae/checkpoints/_ckpt-5, restoring...
Model restored!


HBox(children=(IntProgress(value=0, max=241), HTML(value='')))




HBox(children=(IntProgress(value=0, max=241), HTML(value='')))




HBox(children=(IntProgress(value=0, max=241), HTML(value='')))




HBox(children=(IntProgress(value=0, max=241), HTML(value='')))




HBox(children=(IntProgress(value=0, max=241), HTML(value='')))


First Stage Training Complete!


HBox(children=(IntProgress(value=0, max=241), HTML(value='')))




HBox(children=(IntProgress(value=0, max=241), HTML(value='')))




HBox(children=(IntProgress(value=0, max=241), HTML(value='')))




HBox(children=(IntProgress(value=0, max=241), HTML(value='')))




HBox(children=(IntProgress(value=0, max=241), HTML(value='')))


Second Stage Training Complete!


In [9]:
# Scratch check of BatchNorm outside training. The original cell ended in a
# dangling positional `test` after a keyword argument, which is a SyntaxError.
# With test_local_stats=True the batch statistics are used, so a constant
# input normalises to zeros - consistent with the output recorded below.
snt.BatchNorm()(tf.ones((1, 28, 28, 3)), is_training=False, test_local_stats=True)

<tf.Tensor: id=1153684, shape=(1, 28, 28, 3), dtype=float32, numpy=
array([[[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0