# VAE

## PGM

In [1]:
"""Variational auto-encoder for MNIST data.

References
----------
http://edwardlib.org/tutorials/decoder
http://edwardlib.org/tutorials/inference-networks
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import edward as ed
import numpy as np
import os
import tensorflow as tf

from edward.models import Bernoulli, Normal
from edward.util import Progbar
from keras.layers import Dense

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('/tmp')

from scipy.misc import imsave


def generator(array, batch_size):
  """Generate batch with respect to array's first axis."""
  start = 0  # pointer to where we are in iteration
  while True:
    stop = start + batch_size
    diff = stop - array.shape[0]
    if diff <= 0:
      batch = array[start:stop]
      start += batch_size
    else:
      batch = np.concatenate((array[start:], array[:diff]))
      start = diff
    batch = np.random.binomial(1, batch)  # binarize images
    yield batch


ed.set_seed(42)

data_dir = "/tmp/data"
out_dir = "/tmp/out"
if not os.path.exists(out_dir):
  os.makedirs(out_dir)
M = 100  # batch size during training
d = 2  # latent dimension

# DATA. MNIST batches are fed at training time.
#(x_train, _), (x_test, _) = mnist(data_dir)
x_train, x_test = mnist.train.images, mnist.test.images
x_train_generator = generator(x_train, M)

# MODEL
# Define a subgraph of the full model, corresponding to a minibatch of
# size M.
z = Normal(loc=tf.zeros([M, d]), scale=tf.ones([M, d]))
hidden = Dense(256, activation='relu')(z.value())
hidden = Dense(256, activation='relu')(hidden)
logits = Dense(28 * 28)(hidden)
x = Bernoulli(logits=logits)

# INFERENCE
# Define a subgraph of the variational model, corresponding to a
# minibatch of size M.
x_ph = tf.placeholder(tf.int32, [M, 28 * 28])
hidden = Dense(256, activation='relu')(tf.cast(x_ph, tf.float32))
hidden = Dense(256, activation='relu')(hidden)
qz = Normal(loc=Dense(d)(hidden),
            scale=Dense(d, activation='softplus')(hidden))

# Bind p(x, z) and q(z | x) to the same TensorFlow placeholder for x.
inference = ed.KLqp({z: qz}, data={x: x_ph})
a = tf.placeholder(tf.float32, (), name='lr')
optimizer = tf.train.RMSPropOptimizer(a, epsilon=1.0)
inference.initialize(optimizer=optimizer)

tf.global_variables_initializer().run()

n_epoch = 100
n_iter_per_epoch = x_train.shape[0] // M
for epoch in range(1, n_epoch + 1):
  print("Epoch: {0}".format(epoch))
  avg_loss = 0.0

  pbar = Progbar(n_iter_per_epoch)
  for t in range(1, n_iter_per_epoch + 1):
    pbar.update(t)
    x_batch = next(x_train_generator)
    if t < 33:
        lr = 0.01
    elif t<67:
        lr = 0.003
    else:
        lr = 0.001
    info_dict = inference.update(feed_dict={x_ph: x_batch, a: lr})
    avg_loss += info_dict['loss']

  # Print a lower bound to the average marginal likelihood for an
  # image.
  avg_loss = avg_loss / n_iter_per_epoch
  avg_loss = avg_loss / M
  print("-log p(x) <= {:0.3f}".format(avg_loss))

  # Prior predictive check.
  images = logits.eval()
  for m in range(M):
    imsave(os.path.join(out_dir, '%d.png') % m, images[m].reshape(28, 28))

Using TensorFlow backend.


Extracting /tmp/train-images-idx3-ubyte.gz
Extracting /tmp/train-labels-idx1-ubyte.gz
Extracting /tmp/t10k-images-idx3-ubyte.gz
Extracting /tmp/t10k-labels-idx1-ubyte.gz


## NN

In [1]:
import numpy as np
import tensorflow as tf
import os
from edward.util import Progbar

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('/tmp')

from scipy.misc import imsave


def generator(array, batch_size):
  """Generate batch with respect to array's first axis."""
  start = 0  # pointer to where we are in iteration
  while True:
    stop = start + batch_size
    diff = stop - array.shape[0]
    if diff <= 0:
      batch = array[start:stop]
      start += batch_size
    else:
      batch = np.concatenate((array[start:], array[:diff]))
      start = diff
        
    yield batch
    
data_dir = "/tmp/data"
out_dir = "/tmp/out"
if not os.path.exists(out_dir):
  os.makedirs(out_dir)
M = 100  # batch size during training
d = 2  # latent dimension

# DATA. MNIST batches are fed at training time.
#(x_train, _), (x_test, _) = mnist(data_dir)
x_train, x_test = mnist.train.images, mnist.test.images
x_train_generator = generator(x_train, M)
x_test_generator = generator(x_test, M)

# INFERENCE
# Define a subgraph of the variational model, corresponding to a
# minibatch of size M.
x_ph = tf.placeholder(tf.float32, [M, 28 * 28])
hidden = tf.layers.dense(x_ph, 256, activation=tf.nn.relu)
hidden = tf.layers.dense(hidden, 256, activation=tf.nn.relu)
qz = tf.distributions.Normal(loc=tf.layers.dense(hidden, d),
            scale=tf.layers.dense(hidden, d, activation=tf.nn.softplus))
z = tf.distributions.Normal(loc=tf.zeros([M, d]), scale=tf.ones([M, d]))
qz_sample = tf.reshape(qz.sample(1), [-1,d])

# MODEL
# Define a subgraph of the full model, corresponding to a minibatch of
# size M.
hidden = tf.layers.dense(qz_sample, 256, activation=tf.nn.relu)
hidden = tf.layers.dense(hidden, 256, activation=tf.nn.relu)
logits = tf.layers.dense(hidden, 28 * 28)

elbo_ce = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=x_ph)
elbo_kl = tf.contrib.distributions.kl_divergence(qz, z)
loss = tf.reduce_mean(tf.reduce_sum(elbo_ce, axis=1) + tf.reduce_sum(elbo_kl, axis=1))

Extracting /tmp/train-images-idx3-ubyte.gz
Extracting /tmp/train-labels-idx1-ubyte.gz
Extracting /tmp/t10k-images-idx3-ubyte.gz
Extracting /tmp/t10k-labels-idx1-ubyte.gz


In [3]:
# Bind p(x, z) and q(z | x) to the same TensorFlow placeholder for x.
a = tf.placeholder(tf.float32, (), name='lr')
optimizer = tf.train.RMSPropOptimizer(a, epsilon=1.0)
train = optimizer.minimize(loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

n_epoch = 300
n_iter_per_epoch = x_train.shape[0] // M
for epoch in range(1, n_epoch + 1):
    print("Epoch: {0}".format(epoch))
    avg_loss = 0.0

    pbar = Progbar(n_iter_per_epoch)
    for t in range(1, n_iter_per_epoch + 1):
        pbar.update(t)
        x_batch = next(x_train_generator)
        if t < 100:
            lr = 0.01
        elif t<200:
            lr = 0.003
        else:
            lr = 0.001
        _, l = sess.run([train, loss], feed_dict={x_ph: x_batch, a: lr})
        avg_loss += l

    # Print a lower bound to the average marginal likelihood for an
    # image.
    avg_loss = avg_loss / n_iter_per_epoch
    avg_loss = avg_loss / M
    print("-log p(x) <= {:0.9f}".format(avg_loss))

    # Prior predictive check.
    x_te_batch = next(x_test_generator)
    images = sess.run(logits, feed_dict={x_ph: x_te_batch})
    for m in range(M):
        imsave(os.path.join(out_dir, 'tf-%d.png') % m, images[m].reshape(28, 28))

Epoch: 1
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 2.094773068
Epoch: 2
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.899471207
Epoch: 3
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.818423749
Epoch: 4
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.777073597
Epoch: 5
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.749046630
Epoch: 6
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.727784376
Epoch: 7
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.708131479
Epoch: 8
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.690938685
Epoch: 9
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.676714485
Epoch: 10
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.663287454
Epoch: 11
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <

Epoch: 89
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.444682873
Epoch: 90
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.443812109
Epoch: 91
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.442261048
Epoch: 92
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.441710467
Epoch: 93
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.443322100
Epoch: 94
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.441769680
Epoch: 95
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.439868834
Epoch: 96
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.441670528
Epoch: 97
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.439892510
Epoch: 98
550/550 [100%] ██████████████████████████████ Elapsed: 1s 
-log p(x) <= 1.440763854
Epoch: 99
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-

550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.416945252
Epoch: 177
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.415250732
Epoch: 178
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.413470109
Epoch: 179
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.413798755
Epoch: 180
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.414278865
Epoch: 181
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.413362197
Epoch: 182
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.411851879
Epoch: 183
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.416820471
Epoch: 184
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.412721109
Epoch: 185
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.412744560
Epoch: 186
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-l

550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.401748088
Epoch: 264
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.401802702
Epoch: 265
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.398187009
Epoch: 266
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.399588265
Epoch: 267
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.400804171
Epoch: 268
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.401860356
Epoch: 269
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.401639632
Epoch: 270
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.400145901
Epoch: 271
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.400682203
Epoch: 272
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-log p(x) <= 1.400617491
Epoch: 273
550/550 [100%] ██████████████████████████████ Elapsed: 1s
-l