In [1]:
import os
import sys

# add path to libraries for ipython
sys.path.append(os.path.expanduser("~/libs"))

import time
import math

import cv2
from scipy.misc import imresize
import numpy as np
import tensorflow as tf
import tensortools as tt

from model.conv_deconv_model import ConvDeconvModel
from model.conv_lstm_model import ConvLSTMModel
from model.conv_lstmconv2d_model import ConvLSTMConv2DModel
from model.lstm_encoder_decoder import LSTMDecoderEncoderModel
from model.lstmconv2d_encoder_decoder import LSTMConv2DDecoderEncoderModel
from model.conv_lstmconv2d_encoder_decoder import ConvLSTMConv2DDecoderEncoderModel
from model.conv_lstmconv2d_encoder_decoder_v2 import ConvLSTMConv2DDecoderEncoderModelV2

In [5]:
# Checkpoint to restore, and the path of the comparison video to render.
TRAIN_DIR = 'train_moving_mnist_lstmconv2d_en_decoder'
CHECKPOINT_FILE = 'model.ckpt-47000'
VIDEO_OUTPUT_NAME = 'out/predicted_video_train_on_prev_lstmconv2d_ssim_mse.avi'

# Number of observed frames fed to the model, and number of frames it predicts.
INPUT_SEQ_LENGTH, PREDICTION_LENGTH = 10, 10

# Geometry of a single frame.
FRAME_WIDTH, FRAME_HEIGHT, FRAME_CHANNELS = 64, 64, 1

BATCH_SIZE = 1

# TensorFlow GPU session options.
GPU_MEMORY_FRACTION, GPU_ALLOW_GROWTH = 1.0, True
GPU_ALLOW_GROWTH = True

# Input data: Moving MNIST validation sequences

In [3]:
# Moving MNIST validation split with two digits per sequence. Each sample holds
# the observed frames followed by the frames to predict, so the requested
# sequence length is their sum. (MovingMNISTTrainDataset and
# MovingMNISTTestDataset take the same two leading arguments if another split
# is wanted.)
dataset_valid = tt.datasets.moving_mnist.MovingMNISTValidDataset(
    BATCH_SIZE,
    INPUT_SEQ_LENGTH + PREDICTION_LENGTH,
    num_digits=2,
)

File mnist.h5 has already been downloaded.


# Main: restore the checkpoint, predict, and write a prediction vs. ground-truth video

In [6]:
def _to_uint8_frame(frame):
    """Convert a float frame to uint8 pixel values for the video writer.

    Values are clipped to [0, 1] before scaling by 255, so out-of-range values
    saturate instead of wrapping around in the uint8 cast. The model's
    predictions were observed to go below 0. The cast truncates, as
    ``astype`` does.

    Args:
        frame: float array; presumably values in [0, 1]
            (TODO confirm for the dataset frames).

    Returns:
        uint8 array with the same shape as ``frame``.
    """
    return (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)


# Number of black frames inserted between the observed inputs and the
# predictions in the video; 0 gives a continuous video.
NUM_SEPARATOR_FRAMES = 0

with tf.Graph().as_default():
    # Observed input frames, and the ground-truth future frames the loss is
    # computed against.
    seq_batch = tf.placeholder(tf.float32, shape=[None, INPUT_SEQ_LENGTH, FRAME_HEIGHT, FRAME_WIDTH, FRAME_CHANNELS])
    pred_batch = tf.placeholder(tf.float32, shape=[None, PREDICTION_LENGTH, FRAME_HEIGHT, FRAME_WIDTH, FRAME_CHANNELS])

    # build graph and compute predictions from the inference model
    model = LSTMConv2DDecoderEncoderModel(seq_batch, pred_batch,
                                          lstm_layers=1,
                                          lstm_filters=32)

    total_loss = model.total_loss
    loss = model.loss

    # Saver used only to restore the trained weights. Without a var_list it
    # covers all saveable variables, and it avoids tf.all_variables(), which
    # newer TensorFlow versions deprecate/remove.
    saver = tf.train.Saver()

    # The video writer may fail (possibly silently) if the target directory
    # does not exist, so create it up front.
    output_dir = os.path.dirname(VIDEO_OUTPUT_NAME)
    if output_dir and not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # Create a session for running operations in the Graph
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=GPU_MEMORY_FRACTION,
        allow_growth=GPU_ALLOW_GROWTH)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        checkpoint_path = os.path.join(TRAIN_DIR, CHECKPOINT_FILE)
        saver.restore(sess, checkpoint_path)

        # NOTE(review): every written frame is two frames side by side, so it
        # is 2 * FRAME_WIDTH wide, but frame_size declares only FRAME_WIDTH.
        # Confirm the order tt.utils.video.VideoWriter expects
        # (cv2.VideoWriter takes (width, height)) and double the width.
        with tt.utils.video.VideoWriter(
            VIDEO_OUTPUT_NAME,
            fps=10,
            frame_size=(FRAME_HEIGHT, FRAME_WIDTH),
            is_color=FRAME_CHANNELS != 1) as vw:

            input_frames = dataset_valid.get_batch()

            # Observed phase: the same input frame is shown on both halves.
            for i in range(INPUT_SEQ_LENGTH):
                frame = input_frames[0, i, :, :, :]
                vw.write_frame(_to_uint8_frame(np.concatenate((frame, frame), axis=1)))

            # Optional black separator between the observed and predicted phases.
            black_frame = np.zeros((FRAME_HEIGHT, FRAME_WIDTH * 2, FRAME_CHANNELS), dtype=np.uint8)
            for _ in range(NUM_SEPARATOR_FRAMES):
                vw.write_frame(black_frame)

            predicted_frames, total_error, error = sess.run(
                [model.predictions, total_loss, loss],
                feed_dict={seq_batch: input_frames[:, 0:INPUT_SEQ_LENGTH, :, :, :],
                           pred_batch: input_frames[:, INPUT_SEQ_LENGTH:INPUT_SEQ_LENGTH + PREDICTION_LENGTH, :, :, :]})
            # Single formatted strings print the same under Python 2 and 3
            # (print('a', b) prints a tuple in Python 2).
            print('total error: {}'.format(total_error))
            print('error: {}'.format(error))

            # Raw prediction range; anything outside [0, 1] is clipped when
            # written to the video.
            print('min: {}'.format(np.min(predicted_frames)))
            print('max: {}'.format(np.max(predicted_frames)))

            # Prediction phase: predicted frame on the left, ground truth on the right.
            for j in range(PREDICTION_LENGTH):
                predicted_frame = predicted_frames[0, j, :, :, :]
                gt_frame = input_frames[0, INPUT_SEQ_LENGTH + j, :, :, :]
                vw.write_frame(_to_uint8_frame(np.concatenate((predicted_frame, gt_frame), axis=1)))

print('DONE')

('total error', 11.68919)
('error', 11.687532)
('min', -0.21953911)
('max', 0.97124642)
DONE
