In [1]:
import os
import sys

# add path to libraries for ipython
sys.path.append(os.path.expanduser("~/libs"))

import time
import math

import cv2
from scipy.misc import imresize
import numpy as np
import tensorflow as tf
import tensortools as tt

# import model.conv_deconv_model as model
# import model.conv_lstm_model as model
# import model.conv_lstmconv2d_model as model
import model.lstmconv2d_encoder_decoder as model

In [2]:
# --- Checkpoint and output paths ---
TRAIN_DIR = 'train_moving_mnist'               # directory the checkpoint is restored from
CHECKPOINT_FILE = 'model.ckpt-30000'           # checkpoint file name inside TRAIN_DIR
VIDEO_OUTPUT_NAME = 'out/predicted_video.avi'  # comparison video written by the main cell

# --- Sequence lengths (in frames) ---
INPUT_SEQ_LENGTH = 10    # conditioning frames given to the model
PREDICTION_LENGTH = 10   # future frames predicted and compared with ground truth

# --- Frame geometry ---
FRAME_HEIGHT = 64
FRAME_WIDTH = 64
FRAME_CHANNELS = 1       # 1 -> the video is written as grayscale

# --- Model / inference ---
# Passed as the last argument of model.inference; its meaning is not documented in
# this notebook (presumably a regularization weight) -- TODO confirm in the model code.
LAMBDA = 5e-4
BATCH_SIZE = 1

# --- GPU session options ---
GPU_MEMORY_FRACTION = 1.0
GPU_ALLOW_GROWTH = True

# INPUT: Moving MNIST validation data

In [3]:
# Validation clips of INPUT_SEQ_LENGTH + PREDICTION_LENGTH frames each: the first
# INPUT_SEQ_LENGTH frames condition the model, the rest serve as ground truth.
# (tt.datasets.moving_mnist.MovingMNISTTestDataset takes the same arguments if the
# test split is wanted instead.)
dataset_valid = tt.datasets.moving_mnist.MovingMNISTValidDataset(
    BATCH_SIZE,
    INPUT_SEQ_LENGTH + PREDICTION_LENGTH,
)

File mnist.h5 has already been downloaded.


# MAIN: restore the trained model, predict future frames, and write a comparison video

In [4]:
def to_uint8_side_by_side(left, right):
    '''Join two float frames horizontally and convert them to one 8-bit video frame.

    Both frames are assumed to hold intensities in [0, 1] (they are scaled by 255).
    Values outside that range are clipped first so the uint8 cast cannot wrap around.

    Args:
        left: frame array, assumed (FRAME_HEIGHT, FRAME_WIDTH, FRAME_CHANNELS).
        right: frame array with the same shape as `left`.

    Returns:
        uint8 array of shape (FRAME_HEIGHT, 2 * FRAME_WIDTH, FRAME_CHANNELS).
    '''
    side_by_side = np.concatenate((left, right), axis=1)  # axis 1 = width
    return (np.clip(side_by_side, 0.0, 1.0) * 255).astype(np.uint8)


with tf.Graph().as_default():
    frame_shape = [FRAME_HEIGHT, FRAME_WIDTH, FRAME_CHANNELS]
    # Conditioning frames: the first INPUT_SEQ_LENGTH frames of each clip.
    seq_batch = tf.placeholder(tf.float32, shape=[None, INPUT_SEQ_LENGTH] + frame_shape)
    # Second model input. It is fed PREDICTION_LENGTH frames below, so its time
    # dimension must be PREDICTION_LENGTH (declaring INPUT_SEQ_LENGTH only worked
    # because both lengths are currently 10).
    pred_batch = tf.placeholder(tf.float32, shape=[None, PREDICTION_LENGTH] + frame_shape)

    # Build the graph and compute predictions from the inference model.
    model_output = model.inference(seq_batch,
                                   pred_batch,
                                   FRAME_CHANNELS,
                                   PREDICTION_LENGTH,  # FIXME: unclear whether inference still uses this
                                   LAMBDA)

    # Saver used only to restore the trained weights from the checkpoint.
    saver = tf.train.Saver(tf.all_variables())

    # Session that may use the configured share of GPU memory.
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=GPU_MEMORY_FRACTION,
        allow_growth=GPU_ALLOW_GROWTH)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        saver.restore(sess, os.path.join(TRAIN_DIR, CHECKPOINT_FILE))

        # NOTE(review): every frame written below is FRAME_WIDTH * 2 wide (two panels),
        # but frame_size declares a single-panel size. Confirm whether
        # tt.utils.video.VideoWriter resizes frames or expects the full frame size.
        with tt.utils.video.VideoWriter(
                VIDEO_OUTPUT_NAME,
                fps=10,
                frame_size=(FRAME_HEIGHT, FRAME_WIDTH),
                is_color=FRAME_CHANNELS != 1) as vw:

            input_frames = dataset_valid.get_batch()

            # 1) Conditioning frames, duplicated side by side so they match the
            #    two-panel layout used for the predictions below.
            for i in range(INPUT_SEQ_LENGTH):
                frame = input_frames[0, i]
                vw.write_frame(to_uint8_side_by_side(frame, frame))

            # 2) Two black separator frames between the input and prediction sections.
            black_frame = np.zeros((FRAME_HEIGHT, FRAME_WIDTH * 2, FRAME_CHANNELS), dtype=np.uint8)
            for _ in range(2):
                vw.write_frame(black_frame)

            # The second model input starts at the last conditioning frame and spans
            # PREDICTION_LENGTH frames, so it contains ground-truth future frames.
            # NOTE(review): if model.inference uses pred_batch as decoder inputs, the
            # predictions are conditioned on ground truth (teacher forcing) -- confirm.
            pred_start = INPUT_SEQ_LENGTH - 1
            predicted_frames = sess.run(model_output, feed_dict={
                seq_batch: input_frames[:, :INPUT_SEQ_LENGTH],
                pred_batch: input_frames[:, pred_start:pred_start + PREDICTION_LENGTH],
            })
            # str.format keeps this output identical under Python 2 and 3
            # (print("min", x) prints a tuple in Python 2).
            print("min {}".format(np.min(predicted_frames)))
            print("max {}".format(np.max(predicted_frames)))

            # 3) Prediction (left panel) next to the matching ground truth (right panel).
            for j in range(PREDICTION_LENGTH):
                predicted_frame = predicted_frames[0, j]
                gt_frame = input_frames[0, INPUT_SEQ_LENGTH + j]
                vw.write_frame(to_uint8_side_by_side(predicted_frame, gt_frame))

print('DONE')

LSTMStateTuple(c=array([[[[ -2.43138969e-02,  -6.06607139e-01,   1.82922501e-02, ...,
            1.18917570e-01,   2.38289297e-01,   1.58113867e-01],
         [ -1.68069616e-01,  -4.04568195e-01,  -2.47341003e-02, ...,
            6.13150187e-02,   2.19614804e-01,   1.33122504e-01],
         [ -6.05173230e-01,  -3.03347856e-01,  -2.17048451e-02, ...,
            4.06243168e-02,   1.43029422e-01,   1.95424363e-01],
         ..., 
         [ -2.60363841e+00,   7.86047205e-02,  -9.48971137e-03, ...,
           -2.41402946e-02,  -7.84939468e-01,  -1.80609643e-01],
         [ -2.44910955e+00,   1.50817588e-01,   8.32132697e-02, ...,
           -6.40981570e-02,  -1.43055642e+00,  -8.33836645e-02],
         [ -2.65334988e+00,   2.78674304e-01,   2.65048534e-01, ...,
           -1.67480543e-01,  -1.95025766e+00,  -1.24449179e-01]],

        [[ -5.03148198e-01,  -3.48773569e-01,  -2.40946449e-02, ...,
            1.13054104e-01,   1.67718142e-01,   2.48381257e-01],
         [ -8.33336711e-01, 