In [2]:
#import argparse
import input_data
import tensorflow as tf
import time
from tensorflow.python.ops import control_flow_ops

In [15]:
# Hyperparams
EPOCHS = 100
BATCH_SIZE = 128
DISPLAY_STEP = 1

# Batch norm:
EMA_DECAY = 0.9

# Optimizer
ETA = 0.01
BETA1 = 0.9
BETA2 = 0.999
EPSILON = 1e-8

# Architecture
N_ENCODER_H1 = N_DECODER_H3 = 1000
N_ENCODER_H2 = N_DECODER_H2 = 500
N_ENCODER_H3 = N_DECODER_H1 = 250
N_CODE = 125 # size of encoded layer

In [5]:
def layer_batch_norm(x, n_out, is_training):
    beta_init  = tf.constant_initializer(value=0., dtype=tf.float32)
    gamma_init = tf.constant_initializer(value=1., dtype=tf.float32)
    beta  = tf.get_variable('beta',  [n_out], initializer=beta_init)
    gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
    batch_mean, batch_var = tf.nn.moments(x, [0], name='moments')
    ema = tf.train.ExponentialMovingAverage(decay=EMA_DECAY)
    ema_apply_op = ema.apply([batch_mean, batch_var])
    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
    
    def mean_var_with_update():
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(batch_mean), tf.identity(batch_var)
        
    mean, var = control_flow_ops.cond(
        is_training, mean_var_with_update, lambda: (ema_mean, ema_var))
    x_reshaped = tf.reshape(x, [-1, 1, 1, n_out])
    normed = tf.nn.batch_norm_with_global_normalization(
        x_reshaped, mean, var, beta, gamma, 1e-3, True)
    return tf.reshape(normed, [-1, n_out])

In [6]:
def fully_connected(input, weight_shape, bias_shape, is_training):
    weight_init = tf.random_normal_initializer(
        stddev=(1. / weight_shape[0]) ** 0.5)
    bias_init = tf.constant_initializer(value=0)
    W = tf.get_variable('W', weight_shape, initializer=weight_init)
    b = tf.get_variable('b', bias_shape,   initializer=bias_init)
    logits = tf.matmul(input, W) + b
    return tf.nn.tanh(
        layer_batch_norm(logits, weight_shape[1], is_training))

In [7]:
def encoder(x, n_code, is_training):
    with tf.variable_scope('encoder'):
        with tf.variable_scope('h1'):
            h1 = fully_connected(
                x, [784, N_ENCODER_H1], [N_ENCODER_H1], is_training)
        with tf.variable_scope('h2'):
            h2 = fully_connected(h1, 
                                 [N_ENCODER_H1, N_ENCODER_H2], 
                                 [N_ENCODER_H2], 
                                 is_training)
        with tf.variable_scope('h3'):
            h3 = fully_connected(h2,
                                 [N_ENCODER_H2, N_ENCODER_H3],
                                 [N_ENCODER_H3],
                                 is_training)
        with tf.variable_scope('code'):
            code = fully_connected(
                h3, [N_ENCODER_H3, n_code], [n_code], is_training)
    return code

In [19]:
def decoder(code, n_code, is_training):
    with tf.variable_scope('decoder'):
        with tf.variable_scope('h1'):
            h1 = fully_connected(code, 
                                 [n_code, N_DECODER_H1], 
                                 [N_DECODER_H1], 
                                 is_training)
        with tf.variable_scope('h2'):
            h2 = fully_connected(h1, 
                                 [N_DECODER_H1, N_DECODER_H2], 
                                 [N_DECODER_H2], 
                                 is_training)
        with tf.variable_scope('h3'):
            h3 = fully_connected(h2, 
                                 [N_DECODER_H2, N_DECODER_H3], 
                                 [N_DECODER_H3], 
                                 is_training)
        with tf.variable_scope("output"):
            output = fully_connected(
                h3, [N_DECODER_H3, 784], [784], is_training)
    return output

In [23]:
def loss(output, x):
    with tf.variable_scope('training'):
        l2 = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(output, x)), 1))
        train_loss = tf.reduce_mean(l2)
        train_summary_op = tf.summary.scalar('train_cost', train_loss)
        return train_loss, train_summary_op

In [12]:
def training(cost, global_step):
    optimizer = tf.train.AdamOptimizer(learning_rate=ETA, 
                                       beta1=BETA1, 
                                       beta2=BETA2, 
                                       epsilon=EPSILON, 
                                       use_locking=False, 
                                       name='Adam')
    train_op = optimizer.minimize(cost, global_step=global_step)
    return train_op

In [25]:
def image_summary(label, tensor):
    tensor_reshaped = tf.reshape(tensor, [-1, 28, 28, 1])
    return tf.summary.image(label, tensor_reshaped)

In [29]:
def evaluate(output, x):
    with tf.variable_scope('validation'):
        in_im_op = image_summary('input_image', x)
        out_im_op = image_summary('output_image', output)
        l2 = tf.sqrt(tf.reduce_sum(
            tf.square(
                tf.subtract(output, x, name='val_diff')), 
            1))
        val_loss = tf.reduce_mean(l2)
        val_summary_op = tf.summary.scalar('val_cost', val_loss)
        return val_loss, in_im_op, out_im_op, val_summary_op

In [17]:
mnist = input_data.read_data_sets('data/', one_hot=True)

Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz


In [36]:
with tf.Graph().as_default():
    with tf.variable_scope('autoencoder'):
        x = tf.placeholder('float', [None, 784]) # images: 28 * 28 = 784
        is_training = tf.placeholder(tf.bool)
        code = encoder(x, N_CODE, is_training)
        output = decoder(code, N_CODE, is_training)
        cost, train_summary_op = loss(output, x)
        global_step = tf.Variable(0, name='global_step', trainable=False)
        train_op = training(cost, global_step)
        eval_op, in_im_op, out_im_op, val_summary_op = evaluate(output, x)
        summary_op = tf.summary.merge_all()
        saver = tf.train.Saver(max_to_keep=20)
        sess = tf.Session()
        train_writer = tf.summary.FileWriter(
            'mnist_autoencoder_h=%d_logs/' %N_CODE, graph=sess.graph)
        val_writer = tf.summary.FileWriter(
            'mnist_autoencoder_h=%d_logs/' %N_CODE, graph=sess.graph)
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        
        # Train
        for epoch in range(EPOCHS):
            avg_cost = 0.
            n_batches = mnist.train.num_examples // BATCH_SIZE
            
            for batch in range(n_batches):
                batch_x, batch_y = mnist.train.next_batch(BATCH_SIZE)
                _, new_cost, train_summary = sess.run(
                    [train_op, cost, train_summary_op], 
                    feed_dict={x: batch_x, is_training: True})
                train_writer.add_summary(train_summary, 
                                         sess.run(global_step))
                avg_cost += new_cost / n_batches
                
            if epoch % DISPLAY_STEP == 0:
                train_writer.add_summary(train_summary, 
                                         sess.run(global_step))
                validation_loss, in_im, out_im, val_summary = sess.run(
                    [eval_op, in_im_op, out_im_op, val_summary_op],
                    feed_dict={x: mnist.validation.images, 
                               is_training: False})
                val_writer.add_summary(in_im,       sess.run(global_step))
                val_writer.add_summary(out_im,      sess.run(global_step))
                val_writer.add_summary(val_summary, sess.run(global_step))
                print('Epoch: %04d Cost: %.8f Valid. loss: %.8f' 
                      %(epoch + 1, avg_cost, validation_loss))

                saver.save(
                    sess, 
                    'mnist_autoencoder_h=%d_logs/model-checkpoint-%04d'
                    %(N_CODE, epoch + 1),
                    global_step=global_step)
                
        print('Optimization finished!')
        test_loss = sess.run(eval_op, 
                             feed_dict={x: mnist.test.images, 
                                        is_training: False})
        print('Test Loss:', test_loss)

Epoch: 0001 Cost: 6.54849421 Valid. loss: 4.81077480
Epoch: 0002 Cost: 4.51690872 Valid. loss: 4.12351942
Epoch: 0003 Cost: 3.96013405 Valid. loss: 3.62583232
Epoch: 0004 Cost: 3.55808474 Valid. loss: 3.32946634
Epoch: 0005 Cost: 3.28801332 Valid. loss: 3.05751657
Epoch: 0006 Cost: 3.10688807 Valid. loss: 2.92666411
Epoch: 0007 Cost: 2.97580387 Valid. loss: 2.77790880
Epoch: 0008 Cost: 2.87400215 Valid. loss: 2.70492435
Epoch: 0009 Cost: 2.80418155 Valid. loss: 2.68389130
Epoch: 0010 Cost: 2.73504546 Valid. loss: 2.54963827
Epoch: 0011 Cost: 2.67495095 Valid. loss: 2.53508329
Epoch: 0012 Cost: 2.62286801 Valid. loss: 2.45514178
Epoch: 0013 Cost: 2.56989626 Valid. loss: 2.40769243
Epoch: 0014 Cost: 2.53533129 Valid. loss: 2.35881948
Epoch: 0015 Cost: 2.49438467 Valid. loss: 2.35260797
Epoch: 0016 Cost: 2.46109661 Valid. loss: 2.29149103
Epoch: 0017 Cost: 2.42657231 Valid. loss: 2.26814651
Epoch: 0018 Cost: 2.41041783 Valid. loss: 2.24904466
Epoch: 0019 Cost: 2.37967764 Valid. loss: 2.24

KeyboardInterrupt: 