# MovingMNIST LSTMConv2D Example
Adapted from [github: TensorFlow Examples](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/recurrent_network.ipynb).

Uses a custom convolutional LSTM cell (`LSTMConv2D`, similar to the `ConvLSTM2D` layer in Keras). The MovingMNIST dataset is generated on the fly, adapted from [Unsupervised Learning with LSTMs](https://github.com/emansim/unsupervised-videos).

In [None]:
# Force matplotlib to use inline rendering
%matplotlib inline

from __future__ import print_function

import os
import sys

# add path to libraries for ipython
sys.path.append(os.path.expanduser("~/libs"))

import random

import numpy as np
import tensorflow as tf
import tensortools as tt

In [None]:
BATCH_SIZE = 24

MAX_STEPS = 500      # number of training iterations
DISPLAY_STEPS = 10   # print the training-minibatch loss every DISPLAY_STEPS steps
VALID_STEPS = 100    # run a full validation pass every VALID_STEPS steps


LEARNING_RATE = 5e-4
# Misspelled legacy name, kept as an alias so existing references still work.
LEARGNING_RATE = LEARNING_RATE

TIME_STEPS_IN = 10     # number of input frames fed to the model per sequence
PREDICTION_STEPS = 1   # effectively hard coded: the model outputs a single next frame
KERNEL_FILTERS = 64    # filter count passed to each ConvLSTM cell
KERNEL_SIZE = 5        # convolution kernel height/width
IMAGE_SIZE = 64        # frame height/width in pixels
CHANNELS = 1

LSTM_LAYERS = 2

REG_LAMBDA = 5e-4      # L2 regularization strength for the output deconvolution

# Passed as `device=` to the ConvLSTM cells and the output layer
# (presumably for variable placement — confirm against tensortools).
MEMORY_DEVICE = '/cpu:0'

### Input data

In [None]:
# Train / validation / test MovingMNIST generators. Each is built with the same
# batch size and a sequence length of TIME_STEPS_IN input frames plus
# PREDICTION_STEPS target frames.
dataset_train = tt.datasets.moving_mnist.MovingMNISTTrainDataset(
    BATCH_SIZE, TIME_STEPS_IN + PREDICTION_STEPS)
dataset_valid = tt.datasets.moving_mnist.MovingMNISTValidDataset(
    BATCH_SIZE, TIME_STEPS_IN + PREDICTION_STEPS)
dataset_test = tt.datasets.moving_mnist.MovingMNISTTestDataset(
    BATCH_SIZE, TIME_STEPS_IN + PREDICTION_STEPS)

In [None]:
b = dataset_train.get_batch()

# Display every frame of the batch's first sequence (nrows=3, ncols=5 layout).
tt.visualization.display_batch(b[0], nrows=3, ncols=5)

### Graph construction

In [None]:
# Dedicated graph for this notebook; the following cells build and run the
# model inside `with g.as_default():`.
g = tf.Graph()

In [None]:
def RNN(x):
    """Run the (stacked) ConvLSTM over an input sequence and return its last output.

    Args:
        x: input tensor of shape (batch, TIME_STEPS_IN, height, width, channels),
           matching the `X` placeholder built in the model cell.

    Returns:
        The ConvLSTM output at the final time step.
    """
    # Permute batch and time axes so time comes first.
    x = tf.transpose(x, [1, 0, 2, 3, 4])
    # Split along time into TIME_STEPS_IN tensors and drop each one's
    # singleton time axis, giving a list of per-step frames.
    x = [tf.squeeze(step_input, (0,)) for step_input in tf.split(0, TIME_STEPS_IN, x)]

    def make_cell():
        # One ConvLSTM cell with the configured kernel and frame geometry.
        return tt.recurrent.BasicLSTMConv2DCell(KERNEL_SIZE, KERNEL_SIZE, KERNEL_FILTERS,
                                                IMAGE_SIZE, IMAGE_SIZE,
                                                forget_bias=1.0,
                                                hidden_activation=tt.network.hard_sigmoid,
                                                device=MEMORY_DEVICE)

    if LSTM_LAYERS > 1:
        # Build a distinct cell object per layer rather than repeating one
        # instance (`[cell] * n`). Sharing an instance across layers can tie
        # their state/variables together, depending on the cell implementation.
        lstm_cell = tt.recurrent.MultiRNNConv2DCell(
            [make_cell() for _ in range(LSTM_LAYERS)], state_is_tuple=True)
    else:
        lstm_cell = make_cell()

    # Unroll over all time steps; the final states are not needed here.
    outputs, _ = tt.recurrent.rnn_conv2d(lstm_cell, x)

    # Only the output of the last time step is used for the prediction.
    return outputs[-1]

In [None]:
with g.as_default():
    x = tf.placeholder(tf.float32, [None, TIME_STEPS_IN, IMAGE_SIZE, IMAGE_SIZE, CHANNELS], "X")
    y_ = tf.placeholder(tf.float32, [None, IMAGE_SIZE, IMAGE_SIZE, CHANNELS], "Y_")

    # Rescale the input from [0, 1] to [-1, 1] (roughly zero mean).
    # Use a separate name: rebinding `x` would make every later
    # `feed_dict={x: ...}` feed this derived tensor directly, overriding its
    # value and silently skipping the rescaling.
    x_scaled = x * 2 - 1

    out = RNN(x_scaled)

    # Transposed convolution with a KERNEL_SIZE x KERNEL_SIZE kernel (not 1x1)
    # mapping the ConvLSTM features to the output frame. The positional args are
    # presumably (filters=1, kernel h/w, stride h/w = 1) — confirm in tensortools.
    pred = tt.network.conv2d_transpose("Deconv", out, 1,
                                       KERNEL_SIZE, KERNEL_SIZE, 1, 1,
                                       regularizer=tf.contrib.layers.l2_regularizer(REG_LAMBDA),
                                       device=MEMORY_DEVICE)

    # Convert back from [-1, 1] to the [0, 1] value scale of the targets.
    pred = (pred + 1) / 2

In [None]:
with g.as_default():
    with tf.name_scope("Train"):
        # Both losses sum over all pixels and divide by the fixed BATCH_SIZE,
        # not by the number of sequences actually fed. tf.nn.l2_loss(t) is
        # sum(t ** 2) / 2.
        residual = pred - y_
        loss_l2 = tf.nn.l2_loss(residual) / BATCH_SIZE
        loss_l1 = tf.reduce_sum(tf.abs(residual)) / BATCH_SIZE

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if reg_losses:
            reg_loss = tf.add_n(reg_losses, name="reg_loss")
        else:
            # tf.add_n rejects an empty list; use zero when no layer registered
            # a regularizer.
            reg_loss = tf.constant(0.0, name="reg_loss")
        # Only the L2 reconstruction loss (plus regularization) is optimized;
        # loss_l1 is just reported.
        total_loss = tf.add(loss_l2, reg_loss, name="total_loss")

        # Note: `optimizer` holds the Adam training op returned by minimize().
        optimizer = tf.train.AdamOptimizer(learning_rate=LEARGNING_RATE).minimize(total_loss)

### Training

In [None]:
with g.as_default():
    # Launch the graph. allow_growth lets TensorFlow allocate GPU memory on
    # demand instead of reserving it all at session start.
    gpu_options = tf.GPUOptions(allow_growth=True)
    sess = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options))

    sess.run(tf.initialize_all_variables())

    tt.visualization.show_graph(sess.graph_def)

    dataset_train.reset()

    # Train for MAX_STEPS iterations. `range` is used instead of the
    # Python 2-only `xrange`, so this runs on Python 2 and 3.
    for step in range(1, MAX_STEPS + 1):
        batch = dataset_train.get_batch()
        # The first TIME_STEPS_IN frames are the input; the frame right after
        # them is the target (same indexing as the test cells).
        batch_x = batch[:, :TIME_STEPS_IN, :, :, :]
        batch_y = batch[:, TIME_STEPS_IN, :, :, :]

        # Run one optimization step (backprop).
        sess.run(optimizer, feed_dict={x: batch_x, y_: batch_y})

        if step % DISPLAY_STEPS == 0:
            # Report the losses on the current training minibatch.
            l1, l2 = sess.run([loss_l1, loss_l2], feed_dict={x: batch_x, y_: batch_y})
            print("@{}/{}: Minibatch Train Loss-L2= {:.6f} (Train Loss-L1= {:.6f})".format(
                    step, MAX_STEPS,
                    l2, l1))

        if step % VALID_STEPS == 0:
            # Full pass over the validation set. Use dedicated names so the
            # training batch, the notebook-global `b`, and IPython's `_` are
            # not overwritten.
            dataset_valid.reset()
            num_batches = dataset_valid.dataset_size // dataset_valid.batch_size
            loss_sum = 0
            for batch_idx in range(num_batches):
                valid_batch = dataset_valid.get_batch()
                valid_x = valid_batch[:, :TIME_STEPS_IN, :, :, :]
                valid_y = valid_batch[:, TIME_STEPS_IN, :, :, :]
                l2 = sess.run(loss_l2, feed_dict={x: valid_x, y_: valid_y})
                loss_sum += l2 / PREDICTION_STEPS

            avg_loss_per_frame = loss_sum / num_batches
            print("@{}: Minibatch Avg. Valid Loss-L2 per Frame= {:.6f}".format(
                    step, avg_loss_per_frame))

    print("Optimization Finished!")

### Testing

In [None]:
with g.as_default():
    batch = dataset_test.get_batch()

    # Take the first sequence of the test batch: its TIME_STEPS_IN input
    # frames and the frame that follows them as the target.
    batch_x = batch[0, :TIME_STEPS_IN, :, :, :]
    batch_y = batch[0, TIME_STEPS_IN, :, :, :]
    # The losses are divided by the fixed BATCH_SIZE, but only one sequence is
    # fed below, so multiply back to get this sequence's loss.
    loss_factor = BATCH_SIZE

    print('IN:')
    for frame in batch_x:
        tt.visualization.display_array(frame * 255)

    print('TARGET:')
    tt.visualization.display_array(batch_y * 255)

    print('PREDICTION:')
    # Re-add a batch axis of size 1 to match the placeholders.
    batch_x = np.expand_dims(batch_x, axis=0)
    batch_y = np.expand_dims(batch_y, axis=0)
    prediction, l1, l2 = sess.run([pred, loss_l1, loss_l2], feed_dict={x: batch_x, y_: batch_y})
    prediction_squeezed = np.squeeze(prediction, axis=0)
    prediction_scaled = prediction_squeezed * 255
    tt.visualization.display_array(prediction_scaled)

    print('PREDICTION (fixed):')
    # No squashing to [0, 1] is visible in the graph, so clamp the prediction
    # into that range before rescaling for display.
    prediction_squeezed = np.clip(prediction_squeezed, 0, 1)
    prediction_scaled = prediction_squeezed * 255
    tt.visualization.display_array(prediction_scaled)

    print('min-value: ', np.min(prediction_scaled))
    print('max-value: ', np.max(prediction_scaled))
    print('Test l1-loss:', l1 * loss_factor)
    print('Test l2-loss:', l2 * loss_factor)

In [None]:
with g.as_default():
    # Average next-frame L2 loss over the whole TEST set. The batches must come
    # from dataset_test; the batch count uses its own batch size.
    dataset_test.reset()
    num_batches = dataset_test.dataset_size // dataset_test.batch_size
    print(num_batches)
    loss_sum = 0
    for batch_idx in range(num_batches):
        batch = dataset_test.get_batch()
        batch_x = batch[:, :TIME_STEPS_IN, :, :, :]
        batch_y = batch[:, TIME_STEPS_IN, :, :, :]
        l2 = sess.run(loss_l2, feed_dict={x: batch_x, y_: batch_y})
        loss_sum += l2 / PREDICTION_STEPS

    avg_loss_per_frame = loss_sum / num_batches
    print("Avg. Test Loss-L2 per Frame= {:.6f}".format(avg_loss_per_frame))