# MovingBox Intermediate Frame Prediction by LSTM

In [4]:
from util import *
from util.parser import *
from util.img_kit import *
from util.notebook_display import *
from util.numeric_ops import *
from IPython import display
import numpy as np
from scipy import ndimage
from scipy import misc
from os import walk
import os
import tensorflow as tf
from PIL import Image
import tensorflow.contrib.rnn as rnn

import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['figure.figsize'] = (5.0, 5.0) # set default size of plots
plt.rcParams['image.cmap'] = 'gray'

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
def sample(collection, batch_size = 8, gap = 1, seq_len = None):
    """
    Draw a random batch of (before, after, mid) frame triplets.

    Input:
        collection: [img_data] - list of ndarray, one array of frames per
                    class, indexed by frame number along axis 0
        batch_size: maximum number of triplets to return
        gap:        number of frames skipped between `before` and `after`
                    (must be odd so that a unique middle frame exists)
        seq_len:    sequence length used to size the safety margin at the
                    end of every class; defaults to the module-level
                    `seq_size` when None
    Output:
        (before_imgs, after_imgs, mid_imgs)

        before_imgs: frames at the sampled start indices
        after_imgs:  frames `gap + 1` positions after the start indices
        mid_imgs:    frames `(gap + 1) // 2` positions after the start indices

        Each array holds at most `batch_size` frames along axis 0.
    Raises:
        AssertionError: if `gap` is even
        ValueError:     if `collection` is empty or no class has enough
                        frames to sample from
    """
    assert gap%2==1, "Gap must be odd !"
    if seq_len is None:
        seq_len = seq_size  # fall back to the notebook-level setting

    n_collection = len(collection)
    if n_collection == 0:
        raise ValueError("collection is empty")

    # number of start indices drawn from every class
    avg_num_per_class = int(np.ceil(batch_size/n_collection))
    # frames reserved at the end of each class so start + offset stays in range
    margin = (gap + 1) * seq_len
    mid_offset = (gap + 1) // 2
    after_offset = gap + 1

    before_chunks, after_chunks, mid_chunks = [], [], []
    # Visit the classes in random order without shuffling the caller's list,
    # so the truncation to `batch_size` below does not favour the same classes.
    for i in np.random.permutation(n_collection):
        imgs = collection[i]
        n_valid = imgs.shape[0] - margin
        if avg_num_per_class <= 0 or n_valid < avg_num_per_class:
            continue  # not enough frames in this class
        start = np.random.choice(n_valid, avg_num_per_class, replace=False)
        before_chunks.append(imgs[start])
        after_chunks.append(imgs[start + after_offset])
        mid_chunks.append(imgs[start + mid_offset])

    if not before_chunks:
        raise ValueError("no class has enough frames to sample from")

    before_imgs = np.concatenate(before_chunks, axis = 0)[:batch_size]
    after_imgs = np.concatenate(after_chunks, axis = 0)[:batch_size]
    mid_imgs = np.concatenate(mid_chunks, axis = 0)[:batch_size]
    return before_imgs, after_imgs, mid_imgs


def sample_train(batch_size = 8, gap = 1):
    """Draw a batch from `train_collection`; see `sample` for the arguments."""
    train_batch = sample(train_collection, batch_size, gap = gap)
    return train_batch

def sample_test(batch_size = 8, gap = 1):
    """Draw a batch from `test_collection`; see `sample` for the arguments."""
    test_batch = sample(test_collection, batch_size, gap)
    return test_batch

# Parameters

In [2]:
# Hyper-parameters shared by the sampling helper and the computation graph.
# Number of frames per sequence; sizes the sequence axis of the input placeholder.
seq_size        = 3
feature_size    = 1024    # intended size of the feature vector for the LSTM (not referenced in the visible cells)
lstm_state_size = 512   # number of units passed to rnn.BasicLSTMCell

# Encoder

In [None]:
def encode_img(img, is_training=True):
    """
    Encode a batch of images with a stack of six ReLU convolutions.

    Input:
        img:         batch of images (presumably [batch_size, H, W, C] -
                     TODO confirm against caller)
        is_training: accepted but not used by this function
    Output:
        batch of feature maps with `feature_channel` channels; the three
        stride-2 'same' convolutions each halve the spatial size
        (rounding up), the remaining layers keep it
    """
    # (filters, kernel_size, strides) of the hidden layers, applied in order
    hidden_layers = [
        (48, 8, 2),
        (48, 5, 2),
        (48, 4, 2),
        (48, 3, 1),
        (48, 2, 1),
    ]

    net = img
    for n_filters, k_size, stride in hidden_layers:
        net = tf.layers.conv2d(
            net,
            filters=n_filters,
            kernel_size=k_size,
            strides=stride,
            padding='same',
            activation=tf.nn.relu,
        )

    # output projection to the feature depth (`feature_channel` comes from outside this function)
    return tf.layers.conv2d(
        net,
        filters=feature_channel,
        kernel_size=2,
        padding='same',
        activation=tf.nn.relu,
    )

In [None]:
def encode_seq(img_seq):
    """
    Encode every element of an image sequence with `encode_img`.

    Input:
        img_seq: iterable of image batches (presumably seq_size entries of
                 [batch_size, 32, 32, 1] - TODO confirm against caller)
    Output:
        list with one encoded feature per element of `img_seq`, in order
    """
    encoded = []
    for frame in img_seq:
        encoded.append(encode_img(frame))
    return encoded

# Decoder

In [None]:
def decode(feature, is_training=True):
    """
    Decode a batch of feature maps back into images.

    Input:
        feature:     batch of feature maps (presumably
                     [batch_size, h, w, feature_channel] - TODO confirm)
        is_training: accepted but not used by this function
    Output:
        batch of 3-channel images with tanh activation; the three stride-2
        transposed convolutions each double the spatial size, the remaining
        layers keep it
    """
    net = tf.layers.conv2d(feature, filters=84, kernel_size=3, padding='same', activation=tf.nn.relu)

    # (filters, kernel_size, strides, activation) of the hidden transposed convolutions
    upsampling_layers = [
        (48, 5, 2, tf.nn.relu),
        (48, 7, 2, tf.nn.relu),
        (48, 4, 2, tf.nn.relu),
        (48, 3, 1, tf.nn.tanh),
        (24, 2, 1, tf.nn.tanh),
    ]
    for n_filters, k_size, stride, act in upsampling_layers:
        net = tf.layers.conv2d_transpose(
            net,
            filters=n_filters,
            kernel_size=k_size,
            strides=stride,
            activation=act,
            padding='same',
        )

    # final projection to 3 output channels
    return tf.layers.conv2d_transpose(
        net,
        filters=3,
        kernel_size=2,
        strides=1,
        activation=tf.nn.tanh,
        padding='same',
    )

In [None]:
def predict_in_between(seq_feature, frame1, frame2):
    """
    Predict the frame lying between two consecutive frames.

    NOTE(review): not implemented yet - the function body consists of this
    docstring only, so a call currently returns None.

    Input:
        seq_feature:     encoded feature of image sequence  [batch_size, 8,  8,  feature_dim]
        frame1, frame2:  two consecutive frames             [batch_size,32, 32,  1]
    Output:
        frame in between (intended; nothing is returned at the moment)
    """

## Loss

In [6]:
def get_loss(gd_imgs, output_imgs):
    """
    Reconstruction loss between predicted and ground-truth images.

    Input:
        gd_imgs:     ground-truth images
        output_imgs: predicted images, same shape as `gd_imgs`
    Output:
        scalar tensor: norm of the element-wise difference, using tf.norm's
        default order
    """
    # The previous body referenced the undefined names `x` and `gd`.
    return tf.norm(output_imgs - gd_imgs)

## Computation Graph

In [None]:
# Build the training graph: encode the input sequence, run an LSTM over the
# per-frame features, decode the LSTM outputs and attach a loss and optimizer.
# NOTE(review): this cell looks unfinished - see the notes below before running it.
tf.reset_default_graph()
batch_seq        = tf.placeholder(tf.float32, [None, seq_size, 32, 32, 1])
is_training      = tf.placeholder(tf.bool, ())


# NOTE(review): encode_seq iterates over its argument and returns a Python list;
# here it receives a single tensor whose leading axis is the batch - confirm the
# sequence is meant to be split along axis 1 before encoding.
feature_seq      =  encode_seq(batch_seq)
lstm_cell = rnn.BasicLSTMCell(lstm_state_size)
splitted_fea_seq = tf.split(feature_seq, seq_size, axis=1) # [None, 32, 32, 1]
output_feature, states = rnn.static_rnn(lstm_cell, splitted_fea_seq)
# NOTE(review): decode returns tensors; wrapping them in np.array is presumably
# not intended (tf.stack?) - TODO confirm.
output_imgs = np.array([decode(f) for f in output_feature])  # [batch_size, 32, 32, 1]

# NOTE(review): get_loss is defined with two required parameters
# (gd_imgs, output_imgs) but is called without arguments here.
loss = get_loss()
# alpha   = tf.Variable(0., "alpha") # parameter for leaky relu
# NOTE(review): generate, batch_before, batch_after, batch_mid, content_loss,
# get_solver, learning_rate and beta are not defined in the visible cells -
# presumably provided by the `util` star-imports or left over from an earlier
# version of the notebook; verify.
G_batch = generate(batch_before, 1, batch_after)


G_loss = content_loss(G_batch, batch_mid)
G_solver = get_solver(learning_rate, beta)

# Restrict the optimizer to trainable variables under the 'generator' scope.
G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator') 

G_train_step = G_solver.minimize(G_loss, var_list=G_vars)