# MNIST CNN Autoencoder Example
Uses Conv2D and Deconv2D operations to create a simple auto-encoder. In this example, the network may simply learn the trivial (identity) function. Its intention is rather to show how the *tt.network.conv2d_transpose()* function works.

An image scale of [0, 1] is used here.

In [None]:
# Force matplotlib to use inline rendering
%matplotlib inline

import os
import sys

# add path to libraries for ipython
sys.path.append(os.path.expanduser("~/libs"))

import numpy as np
import tensorflow as tf
import tensortools as tt

In [None]:
# Side length of the square input frames; the graph reshapes flat inputs of
# FRAME_W_H ** 2 values to [-1, FRAME_W_H, FRAME_W_H, 1].
FRAME_W_H = 28
# Samples per batch; also the fixed first dimension of the X / Y_ placeholders.
BATCH_SIZE = 64
# The training loop iterates over range(MAX_STEPS + 1), i.e. MAX_STEPS inclusive.
MAX_STEPS = 500
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 1e-3
# Scale passed to the L2 regularizer of every conv / deconv layer.
# NOTE(review): the active loss is plain BCE; the line that adds the
# REGULARIZATION_LOSSES collection is commented out, so this value may have no
# effect on training unless tensortools applies it elsewhere - confirm.
REG = 5e-4

# When True, each training batch is thresholded with as_binary() before use.
TRAIN_ON_BINARY_DATA = True

In [None]:
# Load the MNIST data sets through the TensorFlow tutorial helper, using
# 'MNIST_data' as the data directory argument.
# NOTE(review): labels are requested one-hot, but the training loop discards
# the labels returned by next_batch(); only the images are used.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

In [None]:
# Select the CUDA device(s) for this process through the tensortools helper.
# NOTE(review): presumably this limits visibility to physical GPU 1, which
# TensorFlow would then address as '/gpu:0' (the device used by the training
# cell) - confirm against the tensortools implementation.
tt.hardware.set_cuda_devices([1])
# Dedicated graph; the following cells build into it via g.as_default().
g = tf.Graph()

### Inference

In [None]:
with g.as_default():
    def _layer_options(weight_init):
        """Return the keyword arguments common to all four layers.

        Only the weight initializer differs between encoder and decoder
        layers; bias init, L2 regularizer and ReLU activation are shared.
        A fresh regularizer is created on every call.
        """
        return dict(weight_init=weight_init,
                    bias_init=0.1,
                    regularizer=tf.contrib.layers.l2_regularizer(REG),
                    activation=tf.nn.relu)

    # Flat input frames and flat reconstruction targets, one row per sample.
    x = tf.placeholder(tf.float32, [BATCH_SIZE, FRAME_W_H ** 2], "X")
    y_ = tf.placeholder(tf.float32, [BATCH_SIZE, FRAME_W_H ** 2], "Y_")

    # Restore the single-channel image layout for the convolutions.
    x_image = tf.reshape(x, [-1, FRAME_W_H, FRAME_W_H, 1])

    # The three positional values after the input are presumably filter count,
    # kernel size and strides - TODO confirm against the tensortools API.
    with tf.variable_scope("Encoder"):
        # Layer 1: convolution, Xavier-initialized weights.
        conv1 = tt.network.conv2d(
            "Conv1", x_image, 4, (3, 3), (2, 2),
            **_layer_options(tf.contrib.layers.xavier_initializer_conv2d()))

        # Layer 2: convolution, Xavier-initialized weights.
        conv2 = tt.network.conv2d(
            "Conv2", conv1, 8, (3, 3), (2, 2),
            **_layer_options(tf.contrib.layers.xavier_initializer_conv2d()))
        encoder_out = conv2

    with tf.variable_scope("Decoder"):
        # Layer 3: transposed convolution, bilinear-initialized weights.
        conv3t = tt.network.conv2d_transpose(
            "Deconv1", encoder_out, 8, (3, 3), (2, 2),
            **_layer_options(tt.init.bilinear_initializer()))

        # Layer 4: transposed convolution back to a single channel.
        # Original author's note on the activation of this layer:
        # "no activation and relu work better here".
        conv4t = tt.network.conv2d_transpose(
            "Deconv2", conv3t, 1, (3, 3), (2, 2),
            **_layer_options(tt.init.bilinear_initializer()))
        decoder_out = conv4t

    # Flatten the reconstruction so it lines up with the flat placeholders.
    output = tf.reshape(decoder_out, [-1, FRAME_W_H ** 2])

In [None]:
with g.as_default():
    with tf.name_scope("Train"):
        with tf.device('/gpu:0'):
            # Image layout for the image-based losses. Note that x_image is
            # rebound here: from now on it refers to the reconstruction, not
            # to the reshaped network input.
            x_image = tf.reshape(output, [-1, FRAME_W_H, FRAME_W_H, 1])
            y_image = tf.reshape(y_, [-1, FRAME_W_H, FRAME_W_H, 1])

            # Both loss terms are built, although only BCE is optimized;
            # show() reports ms_ssim during evaluation.
            bce = tt.loss.bce(x_image, y_image)
            ms_ssim = tt.loss.ms_ssim(x_image, y_image,
                                      patch_size=7,
                                      level_weights=[0.25, 0.75])

            # alpha is the mixing weight of the currently disabled objective:
            #   (1-alpha) * bce + alpha * ms_ssim
            #       + tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
            alpha = 0.84
            loss = bce

            optimizer = tf.train.AdamOptimizer(LEARNING_RATE)
            train_step = optimizer.minimize(loss)

    # InteractiveSession is what lets the later .eval() / .run() calls work
    # without an explicit session argument; device placement is logged.
    # NOTE(review): later TensorFlow releases deprecate initialize_all_variables
    # in favor of global_variables_initializer - confirm the installed version.
    sess = tf.InteractiveSession(
        config=tf.ConfigProto(log_device_placement=True))
    sess.run(tf.initialize_all_variables())

### Training

In [None]:
def as_binary(data, threshold=0.5):
    """Return a binarized copy of ``data``.

    Values strictly greater than ``threshold`` become 1.0, values less than
    or equal to it become 0.0. Entries for which neither comparison holds
    (e.g. NaN) are left unchanged. The input itself is not modified, and the
    result is the object returned by ``data.copy()``.

    Args:
        data: Array supporting ``copy()`` and boolean-mask assignment
            (e.g. a numpy array).
        threshold: Cut-off value; defaults to 0.5, the original behavior.

    Returns:
        The binarized copy of ``data``.
    """
    binarized = data.copy()
    # Take both masks before writing: with a threshold >= 1.0 the freshly
    # written 1.0 values would otherwise be caught by the second comparison.
    above = binarized > threshold
    at_or_below = binarized <= threshold
    binarized[above] = 1.0
    binarized[at_or_below] = 0.0
    return binarized

In [None]:
with g.as_default():
    # Runs MAX_STEPS + 1 optimizer steps; the input batch doubles as the
    # reconstruction target.
    for step in range(MAX_STEPS + 1):
        # Labels returned by next_batch() are not needed for reconstruction.
        batch_images, _ = mnist.train.next_batch(BATCH_SIZE)
        if TRAIN_ON_BINARY_DATA:
            batch_images = as_binary(batch_images)

        feed = {x: batch_images, y_: batch_images}

        # Report the loss on the current batch every 100 steps, evaluated
        # before this step's parameter update.
        if not step % 100:
            train_loss = loss.eval(feed_dict=feed)
            print("step %d / %d, loss %g" % (step, MAX_STEPS, train_loss))

        train_step.run(feed_dict=feed)

### Evaluation

In [None]:
def show(images, title):
    """Print value statistics and loss figures for a batch of images.

    The batch is fed both as network input and as reconstruction target, and
    the `loss`, `bce` and `ms_ssim` tensors are evaluated in that order.

    Args:
        images: Batch of flat images fed to the `x` and `y_` placeholders.
        title: Heading printed above the figures.
    """
    feed = {x: images, y_: images}
    report = (
        ("test loss: %g", loss),
        ("test BCE (w/o reg): %g", bce),
        ("test MS-SSIM (w/o reg): %g", ms_ssim),
    )

    print("### {} ###".format(title))
    print("Value range: [{}, {}] with avg {}".format(
        images.min(), images.max(), images.mean()))
    for line_format, tensor in report:
        print(line_format % tensor.eval(feed_dict=feed))

# Evaluate on the first BATCH_SIZE test images, once with the original float
# values and once thresholded to {0, 1}.
test_images = mnist.test.images[:BATCH_SIZE]
lowest, highest = test_images.min(), test_images.max()
print("Value range: [{}, {}]".format(lowest, highest))
test_images_binary = as_binary(test_images)

show(test_images, "Float")
show(test_images_binary, "Binary")

In [None]:
def show_visualization(images, title):
    """Display a batch of images next to the network's reconstruction.

    The reconstruction is obtained by evaluating `output` with `images` fed
    to the `x` placeholder. Both batches are reshaped to
    [-1, FRAME_W_H, FRAME_W_H, 1], multiplied by 255 and handed to
    tt.visualization.display_batch, ground truth first.

    Args:
        images: Batch of flat images fed to the `x` placeholder.
        title: Title prefix; "-GT" and "-Reconstr." are appended to it.
    """
    image_shape = [-1, FRAME_W_H, FRAME_W_H, 1]
    reconstructions = output.eval(feed_dict={x: images})

    # The two 4s are presumably the rows / columns of the displayed grid -
    # TODO confirm against tt.visualization.display_batch.
    for batch, suffix in ((images, "-GT"), (reconstructions, "-Reconstr.")):
        tt.visualization.display_batch(
            np.reshape(batch, image_shape) * 255, 4, 4, title + suffix)

In [None]:
# Ground truth vs. reconstruction for the unmodified (float) test batch.
show_visualization(test_images, "Float")

In [None]:
# Ground truth vs. reconstruction for the test batch binarized by as_binary().
show_visualization(test_images_binary, "Binary")