In [8]:
import os
import sys

# add path to libraries for ipython
sys.path.append(os.path.expanduser("~/libs"))

import time
import math

import cv2
from scipy.misc import imresize
import numpy as np
import tensorflow as tf
import tensortools as tt

import model.conv_deconv_model as model
# import model.conv_lstm_model as model

In [11]:
# --- Input video ---------------------------------------------------------------
# Other clips used during experimentation:
#   'tmp/UCF11_updated_mpg/golf_swing/v_golf_10/v_golf_10_02.mpg'
#   'tmp/UCF11_updated_mpg/golf_swing/v_golf_14/v_golf_14_03.mpg'  (overfitting file)
VIDEO_FILE = 'tmp/UCF11_updated_mpg/tennis_swing/v_tennis_01/v_tennis_01_01.mpg'
VIDEO_START_FRAME = 0  # passed to VideoReader as the start frame

# --- Checkpoint and output -----------------------------------------------------
TRAIN_DIR = 'train'
CHECKPOINT_FILE = 'model.ckpt-30000'
VIDEO_OUTPUT_NAME = 'out/predicted_video.avi'

# --- Sequence lengths ----------------------------------------------------------
INPUT_SEQ_LENGTH = 5       # frames stacked along the channel axis per model input
PREDICTION_LENGTH = 270    # frames generated autoregressively
GROUND_TRUTH_LENGTH = 5    # real frames written before the prediction

# The last INPUT_SEQ_LENGTH ground-truth frames seed the prediction, so there must
# be at least that many; fail here instead of deep inside the prediction loop.
assert GROUND_TRUTH_LENGTH >= INPUT_SEQ_LENGTH, \
    'GROUND_TRUTH_LENGTH must be >= INPUT_SEQ_LENGTH to seed the prediction'

# --- Frame geometry ------------------------------------------------------------
BASE_FRAME_WIDTH = 320     # frame width before scaling
BASE_FRAME_HEIGHT = 240    # frame height before scaling
FRAME_SCALE_FACTOR = 1.0
FRAME_WIDTH = int(BASE_FRAME_WIDTH * FRAME_SCALE_FACTOR)
FRAME_HEIGHT = int(BASE_FRAME_HEIGHT * FRAME_SCALE_FACTOR)
FRAME_CHANNELS = 1         # 1 -> frames are converted to grayscale; otherwise color

# --- Model / runtime -----------------------------------------------------------
LAMBDA = 5e-4              # passed to model.inference; presumably a regularization weight - TODO confirm

BATCH_SIZE = 1

GPU_MEMORY_FRACTION = 1.0  # per_process_gpu_memory_fraction given to the TF session

# MAIN

In [12]:
# Restore the trained model, then write to VIDEO_OUTPUT_NAME:
#   1. GROUND_TRUTH_LENGTH real frames read from VIDEO_FILE,
#   2. two black separator frames,
#   3. PREDICTION_LENGTH predicted frames, generated autoregressively: each
#      prediction replaces the oldest frame of the input window for the next step.
with tf.Graph().as_default():
    # One input sample = INPUT_SEQ_LENGTH frames stacked along the channel axis.
    seq_batch = tf.placeholder(
        tf.float32,
        shape=[None, FRAME_HEIGHT, FRAME_WIDTH, FRAME_CHANNELS * INPUT_SEQ_LENGTH])

    # build graph and compute predictions from the inference model
    model_output = model.inference(seq_batch,
                                   FRAME_CHANNELS,
                                   INPUT_SEQ_LENGTH,
                                   LAMBDA)

    # Saver is only used to restore the trained weights from the checkpoint.
    saver = tf.train.Saver(tf.all_variables())

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=GPU_MEMORY_FRACTION)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        checkpoint_path = os.path.join(TRAIN_DIR, CHECKPOINT_FILE)
        saver.restore(sess, checkpoint_path)

        with tt.utils.video.VideoReader(VIDEO_FILE, VIDEO_START_FRAME) as vr:
            with tt.utils.video.VideoWriter(
                    VIDEO_OUTPUT_NAME,
                    frame_size=(FRAME_HEIGHT, FRAME_WIDTH),
                    is_color=FRAME_CHANNELS != 1) as vw:

                # --- Ground truth: copy real frames; the last ones seed the model.
                input_frames = []
                for i in range(GROUND_TRUTH_LENGTH):
                    frame = vr.next_frame(FRAME_SCALE_FACTOR)

                    # Detect a failed read before any conversion touches the frame.
                    if frame is None:
                        print('Warning: Error while reading frame.')
                        continue

                    if FRAME_CHANNELS == 1:
                        frame = tt.utils.image.to_grayscale(frame)
                    vw.write_frame(frame)

                    # Only the final INPUT_SEQ_LENGTH ground-truth frames are model
                    # input; pixel values are mapped from [0, 255] to [-1, 1].
                    if i >= GROUND_TRUTH_LENGTH - INPUT_SEQ_LENGTH:
                        input_frames.append((frame - 127.5) / 127.5)

                if len(input_frames) < INPUT_SEQ_LENGTH:
                    raise RuntimeError(
                        'Need {} seed frames for prediction but only {} were read; '
                        'check GROUND_TRUTH_LENGTH >= INPUT_SEQ_LENGTH and VIDEO_FILE.'
                        .format(INPUT_SEQ_LENGTH, len(input_frames)))

                # --- Separator: two black frames between ground truth and prediction.
                black_frame = np.zeros((FRAME_HEIGHT, FRAME_WIDTH, FRAME_CHANNELS), dtype=np.uint8)
                vw.write_frame(black_frame)
                vw.write_frame(black_frame)

                # --- Autoregressive prediction.
                for _ in range(PREDICTION_LENGTH):
                    # Stack the window along the channel axis and add a batch axis.
                    seq_input = np.expand_dims(np.concatenate(input_frames, axis=2), axis=0)

                    predicted = sess.run(model_output, feed_dict={seq_batch: seq_input})
                    # Drop the batch axis and clamp to [-1, 1], the same range the
                    # ground-truth inputs were normalized to, before feeding it back.
                    predicted_frame = np.clip(np.squeeze(predicted, axis=0), -1.0, 1.0)

                    # Map back to [0, 255]. Round instead of truncating so float error
                    # does not bias pixels downward (e.g. 254.99998 -> 254).
                    video_frame = np.rint(predicted_frame * 127.5 + 127.5).astype(np.uint8)
                    vw.write_frame(video_frame)

                    # Slide the window: drop the oldest frame, append the prediction.
                    input_frames = input_frames[1:] + [predicted_frame]

print('DONE')

DONE
