In [6]:
import tensorflow as tf
import os
import numpy as np
tfd = tf.contrib.distributions


In [18]:
def make_encoder(images, code_size):
    """Build the approximate posterior over latent codes for a batch of images.

    Four 3x3 'SAME' convolutions (32, 32, 64, 64 filters) are applied, the
    result is averaged over axes 1 and 2, and two dense heads produce the
    location and the scale of a diagonal Gaussian.

    Args:
        images: 4-D float tensor fed to `tf.layers.conv2d`
            (assumed channels-last, the `tf.layers.conv2d` default -- TODO confirm).
        code_size: number of latent dimensions (units of each dense head).

    Returns:
        A `(posterior, mean)` tuple: the `tfd.MultivariateNormalDiag`
        distribution and the tensor used as its location.
    """
    x = images
    # ReLU between the convolutions: without a non-linearity the whole stack
    # collapses into a single linear map of the input.
    for filters in (32, 32, 64, 64):
        x = tf.layers.conv2d(x, filters, 3, padding='SAME',
                             activation=tf.nn.relu)
    # Average over axes 1 and 2 so a single feature vector is left per example.
    x = tf.reduce_mean(x, axis=[1, 2])
    mean = tf.layers.dense(x, code_size)
    # Softplus keeps the diagonal scale strictly positive; a raw dense output
    # may be zero or negative, which is not a valid standard deviation.
    stddev = tf.layers.dense(x, code_size, activation=tf.nn.softplus)
    return tfd.MultivariateNormalDiag(mean, stddev), mean

In [17]:
def make_prior(code_size):
    """Return the prior over latent codes.

    The prior is a `tfd.MultivariateNormalDiag` whose location is all zeros
    and whose diagonal scale is all ones, both built from `code_size`.

    Args:
        code_size: shape argument passed to `tf.zeros` / `tf.ones`.

    Returns:
        The `tfd.MultivariateNormalDiag` distribution.
    """
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(code_size),
        scale_diag=tf.ones(code_size),
    )

In [24]:
def make_decoder(code, image_shape):
    """Build the likelihood of an image given a latent code.

    The code is projected by a dense layer to as many units as the image has
    elements, reshaped to the image shape, passed through four 3x3 'SAME'
    transposed convolutions (64, 64, 32, 32 filters) and a final transposed
    convolution that produces per-pixel Bernoulli logits.

    Args:
        code: latent code tensor fed to `tf.layers.dense`.
        image_shape: indexable with at least three entries, used as the
            reshape target (assumed height, width, channels -- TODO confirm).
            Entries may be ints or statically known `Dimension`s.

    Returns:
        `tfd.Independent(tfd.Bernoulli(logits), 3)`, i.e. the last three
        batch dimensions of the Bernoulli are reinterpreted as event dims.
    """
    # Plain ints so the unit count and the reshape target do not depend on
    # how TensorFlow converts `Dimension` objects.
    height, width, channels = (int(image_shape[i]) for i in range(3))
    num_units = height * width * channels

    x = tf.layers.dense(code, num_units, activation=tf.nn.relu)
    x = tf.reshape(x, (-1, height, width, channels))
    # ReLU on the hidden layers: without it the decoder is a purely linear
    # function of the code.
    for filters in (64, 64, 32, 32):
        x = tf.layers.conv2d_transpose(x, filters, 3, padding='SAME',
                                       activation=tf.nn.relu)
    # Logits stay linear and use the image's own channel count (previously
    # hard-coded to 3) so they line up with the images passed to log_prob.
    logits = tf.layers.conv2d_transpose(x, channels, 3, padding='SAME')
    return tfd.Independent(tfd.Bernoulli(logits), 3)

In [11]:
# Wrap both builders with tf.make_template (under the names 'encoder' and
# 'decoder') so that the variables they create are made on the first call and
# reused on any later call instead of being duplicated.
# NOTE(review): the names are rebound in place, so re-running this cell without
# first re-running the `def` cells wraps a template in another template, and
# re-running a `def` cell afterwards silently drops the template. The execution
# counts (In [11] here vs In [18] / In [24] on the defs) suggest the cells were
# run out of order -- confirm with Restart & Run All.
make_encoder = tf.make_template('encoder', make_encoder)
make_decoder = tf.make_template('decoder', make_decoder)

In [26]:
def vae(images, code_size=32):
    """Assemble the VAE objective for a batch of images.

    Builds the prior, encodes `images` into a posterior, decodes one sample
    from that posterior, and returns the batch mean of
    `log_prob(images) - KL(posterior, prior)`.

    Args:
        images: image batch tensor; `images.shape[1:]` is handed to the
            decoder as the per-example image shape.
        code_size: number of latent dimensions. Defaults to 32, the value
            that was previously hard-coded.

    Returns:
        A `(elbo, mean)` tuple: the scalar objective and the posterior
        location returned by the encoder.
    """
    prior = make_prior(code_size)
    posterior, mean = make_encoder(images, code_size)
    code = posterior.sample()
    decoder = make_decoder(code, images.shape[1:])
    likelihood = decoder.log_prob(images)
    divergence = tfd.kl_divergence(posterior, prior)
    elbo = tf.reduce_mean(likelihood - divergence)
    return elbo, mean

In [13]:
def load_data(folder):
    """Load and stack the two image arrays stored in `folder`.

    Reads `inputs.npy` and `examples.npy`, reshapes the latter to
    `(-1, 64, 64, 3)`, and concatenates inputs followed by examples along
    axis 0.

    Args:
        folder: directory containing the two `.npy` files.

    Returns:
        A single numpy array with the inputs first and the reshaped examples
        after them.
    """
    def _read(filename):
        return np.load(os.path.join(folder, filename))

    image_batches = [
        _read('inputs.npy'),
        _read('examples.npy').reshape(-1, 64, 64, 3),
    ]
    return np.concatenate(image_batches, axis=0)

In [36]:
def run_training():
    """Train the VAE on the arrays under 'complearn/test/' and log the ELBO.

    Builds a `tf.data` pipeline (shuffle buffer 1000, batches of 64) fed from
    a placeholder, maximizes the ELBO with Adam (learning rate 0.001) for
    100 epochs, and prints the mean per-batch ELBO after every epoch.

    The ELBO is fetched in the same `sess.run` call as the training op.
    Evaluating it separately after the epoch loop would pull from the already
    exhausted iterator and raise `tf.errors.OutOfRangeError`.
    """
    EPOCHS = 100
    # Only the test split is loaded; the train/val loads were left disabled:
    #     train_data = load_data('complearn/train/')
    #     val_data = load_data('complearn/val/')
    data = load_data('complearn/test/')
    print('data loaded')

    data_placeholder = tf.placeholder(
        tf.float32, [None, data.shape[1], data.shape[2], data.shape[3]])
    dataset = tf.data.Dataset.from_tensor_slices(data_placeholder)
    dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.batch(64)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    elbo, mean = vae(next_element)
    optimize = tf.train.AdamOptimizer(0.001).minimize(-elbo)
    # Must be created before the session is opened: the monitored session
    # finalizes the graph, after which no new ops can be added.
    init = tf.global_variables_initializer()
    print('Training started')
    with tf.train.MonitoredTrainingSession() as sess:
        sess.run(init)
        for epoch in range(EPOCHS):
            # Re-feed the data to restart the iterator at each epoch.
            sess.run(iterator.initializer, feed_dict={data_placeholder: data})
            batch_elbos = []
            while True:
                try:
                    _, batch_elbo = sess.run([optimize, elbo])
                except tf.errors.OutOfRangeError:
                    # Iterator exhausted: the epoch is over.
                    break
                batch_elbos.append(batch_elbo)
            # Unweighted mean over batches (the last batch may be smaller).
            # NaN if the dataset produced no batch at all.
            epoch_elbo = float(np.mean(batch_elbos)) if batch_elbos else float('nan')
            print('Epoch %d elbo loss: %f' % (epoch, epoch_elbo))
    

In [37]:
# Kick off training; this loads the data, builds the graph and runs the epoch loop.
run_training()

data loaded
Training started
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


KeyboardInterrupt: 