# Runtime MNIST CNN Autoencoder Example
Uses Conv2D and Deconv2D (transposed convolution) operations to create a simple autoencoder. In this example, it is possible that the network merely learns the trivial (identity) function. Its purpose is rather to show how the *tt.network.conv2d_transpose()* function works.

The images use a pixel value range of [0, 1] here.

In [None]:
# Force matplotlib to use inline rendering
%matplotlib inline

import os
import sys

# Add the user's ~/libs directory to the module search path so that libraries
# kept there can be imported from IPython.
# NOTE(review): presumably this is where `tensortools` lives - confirm.
sys.path.append(os.path.expanduser("~/libs"))

import numpy as np
import tensorflow as tf
import tensortools as tt

In [None]:
# Hyperparameters for the training run.

# Passed to runtime.train() as the batch size.
BATCH_SIZE = 16
# Passed to the model as reg_lambda, where it parameterizes the L2 regularizers.
REG_LAMBDA = 0.0005
# The three values below are handed positionally to runtime.build().
# NOTE(review): presumably the initial learning rate, the number of steps
# between decays, and the multiplicative decay factor - confirm against
# the runtime's build() signature.
INITIAL_LR = 5e-3
LR_DECAY_STEP_INTERVAL = 10000
LR_DECAY_FACTOR = 0.5

In [None]:
# Instantiate the MNIST train / validation / test splits in one go.
# Construction order is train, valid, test.
dataset_train, dataset_valid, dataset_test = (
    tt.datasets.mnist.MNISTTrainDataset(),
    tt.datasets.mnist.MNISTValidDataset(),
    tt.datasets.mnist.MNISTTestDataset(),
)

### Model

In [None]:
class SimpleCNNAutoencoderModel(tt.model.AbstractModel):
    """Small convolutional autoencoder built from four layers.

    The encoder stacks two ``tt.network.conv2d`` layers ("Conv1", "Conv2");
    the decoder mirrors them with two ``tt.network.conv2d_transpose`` layers
    ("Deconv1", "Deconv2"). Every layer uses a bias init of 0.1 and an L2
    regularizer parameterized by ``self.reg_lambda``.
    """

    def __init__(self, reg_lambda=0.0):
        """Forward ``reg_lambda`` to the base model's constructor."""
        super(SimpleCNNAutoencoderModel, self).__init__(reg_lambda)

    @tt.utils.attr.override
    def inference(self, inputs, targets, is_training=True,
                  device_scope=None, memory_device=None):
        """Build the encoder/decoder graph and return the decoder output.

        Only ``inputs`` is consumed here; ``targets``, ``is_training``,
        ``device_scope`` and ``memory_device`` are accepted to match the
        overridden signature but are not used by this model.
        """
        # NOTE(review): the positional arguments after the input tensor are
        # presumably (number of filters, kernel size, strides) - confirm
        # against tt.network.conv2d / conv2d_transpose.
        kernel = (3, 3)
        stride = (2, 2)

        def layer_options(weight_init, activation):
            # Keyword arguments common to all four layers. A fresh L2
            # regularizer is created for each layer.
            return dict(
                weight_init=weight_init,
                bias_init=0.1,
                regularizer=tf.contrib.layers.l2_regularizer(self.reg_lambda),
                activation=activation,
            )

        with tf.variable_scope("Encoder"):
            # Layers 1 and 2: ReLU convolutions with Xavier-initialized weights.
            hidden = tt.network.conv2d(
                "Conv1", inputs, 4, kernel, stride,
                **layer_options(
                    tf.contrib.layers.xavier_initializer_conv2d(),
                    tf.nn.relu))
            encoder_out = tt.network.conv2d(
                "Conv2", hidden, 4, kernel, stride,
                **layer_options(
                    tf.contrib.layers.xavier_initializer_conv2d(),
                    tf.nn.relu))

        with tf.variable_scope("Decoder"):
            # Layer 3: ReLU transposed convolution with bilinear weight init.
            upsampled = tt.network.conv2d_transpose(
                "Deconv1", encoder_out, 4, kernel, stride,
                **layer_options(tt.init.bilinear_initializer(), tf.nn.relu))
            # Layer 4: sigmoid output layer, matching the [0, 1] image range
            # and the BCE loss used in loss().
            decoder_out = tt.network.conv2d_transpose(
                "Deconv2", upsampled, 1, kernel, stride,
                **layer_options(tt.init.bilinear_initializer(), tf.nn.sigmoid))

        return decoder_out

    @tt.utils.attr.override
    def loss(self, predictions, targets):
        """Return ``tt.loss.bce`` of the predictions against the targets."""
        reconstruction_loss = tt.loss.bce(predictions, targets)
        return reconstruction_loss

### Training

In [None]:
# Create the runtime, hand it the three dataset splits and the model,
# then build it with the learning-rate settings defined above.
runtime = tt.core.DefaultRuntime()
runtime.register_datasets(dataset_train, dataset_valid, dataset_test)
runtime.register_model(
    SimpleCNNAutoencoderModel(reg_lambda=REG_LAMBDA)
)
# NOTE(review): is_autoencoder=True presumably makes the runtime use the
# inputs as reconstruction targets - confirm against DefaultRuntime.build().
runtime.build(INITIAL_LR, LR_DECAY_STEP_INTERVAL, LR_DECAY_FACTOR,
              is_autoencoder=True)

In [None]:
# Train for 3000 steps with checkpointing and summaries switched off.
runtime.train(
    batch_size=BATCH_SIZE,
    steps=3000,
    do_checkpoints=False,
    do_summary=False,
)

### Evaluation

In [None]:
# Fetch four validation samples; the second element of the returned pair
# is not needed here.
x, _ = dataset_valid.get_batch(4)

# Show the inputs, scaled by 255 for display, in a 2x2 grid.
tt.visualization.display_batch(
    x * 255,
    nrows=2,
    ncols=2,
    title="Input",
)

# Run the trained model on the same samples.
pred = runtime.predict(x)

# Show the reconstructions in the same layout for side-by-side comparison.
tt.visualization.display_batch(
    pred * 255,
    nrows=2,
    ncols=2,
    title="Reconstruction",
)

### Terminate

In [None]:
# Close the runtime now that training and evaluation are finished.
# NOTE(review): presumably this releases the underlying TensorFlow session
# and its resources - confirm against DefaultRuntime.close().
runtime.close()