In [1]:
'''
Evaluate trained PredNet on KITTI sequences.
Calculates mean-squared error

SS 12/5/2016:
modified to read in batched .hkl files
50 Hz Shanghai avi's (108 avi's, 20 clips per avi), in 20 separate files
'''

import os
import numpy as np
import pandas as pd
from six.moves import cPickle
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from keras import backend as K
from keras.models import Model, model_from_json
from keras.layers import Input, Dense, Flatten

from prednet import PredNet
from data_utils import SequenceGenerator
from kitti_settings_ss import *

#parts = 1
parts = 20
suffix = '_P20_'

n_plot = 40
batch_size = 10
nt = 10

weights_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_weights.hdf5')
json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model.json')

# Load trained model
f = open(json_file, 'r')
json_string = f.read()
f.close()
train_model = model_from_json(json_string, custom_objects = {'PredNet': PredNet})
train_model.load_weights(weights_file)

# Create testing model (to output predictions)
layer_config = train_model.layers[1].get_config()
layer_config['output_mode'] = 'prediction'
dim_ordering = layer_config['dim_ordering']
test_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)
input_shape = list(train_model.layers[0].batch_input_shape[1:])
input_shape[0] = nt
inputs = Input(shape=tuple(input_shape))
predictions = test_prednet(inputs)
test_model = Model(input=inputs, output=predictions)
if not os.path.exists(RESULTS_SAVE_DIR): os.mkdir(RESULTS_SAVE_DIR)
f = open(RESULTS_SAVE_DIR + 'prediction_scores.txt', 'w')

for part in range(1,parts+1):

    curr_test = 'X_test'+suffix+str(part)+'.hkl'
    curr_sources = 'sources_test'+suffix+str(part)+'.hkl'
    test_file = os.path.join(DATA_DIR, curr_test)
    test_sources = os.path.join(DATA_DIR, curr_sources)

    test_generator = SequenceGenerator(test_file, test_sources, nt, sequence_start_mode='unique', dim_ordering=dim_ordering)
    X_test = test_generator.create_all()
    X_hat = test_model.predict(X_test, batch_size)
    if dim_ordering == 'th':
        X_test = np.transpose(X_test, (0, 1, 3, 4, 2))
        X_hat = np.transpose(X_hat, (0, 1, 3, 4, 2))

    curr_mse_frame2 = 'mse_frame2'+suffix+str(part)+'.csv'
    mse_frame2_out = os.path.join(RESULTS_SAVE_DIR, curr_mse_frame2)
    
    mse_point = (X_test[:, 1:] - X_hat[:, 1:])**2
    mse_frame2 = np.apply_over_axes(np.mean, mse_point, [2,3,4])
    np.savetxt(mse_frame2_out, mse_frame2, delimiter=",")

    print 'mse_frame2'
    print type(mse_frame2)
    print mse_frame2.shape
    
    # Compare overall MSE's write results to prediction_scores.txt
    mse_model = np.mean( (X_test[:, 1:] - X_hat[:, 1:])**2 )  # look at all timesteps except the first
    mse_prev = np.mean( (X_test[:, :-1] - X_test[:, 1:])**2 )
    mse_last = np.mean( (X_test[:, -1] - X_hat[:, -1])**2 )  # look only the last frame
    f.write("part number: %d\n" % part)
    f.write("Model MSE: %f\n" % mse_model)
    f.write("Last Frame MSE: %f" % mse_last)
    f.write("Previous Frame MSE: %f" % mse_prev)
    
f.close()


Using Theano backend.


X shape
<type 'numpy.ndarray'>
(1080, 128, 160, 3)
source shape
<type 'list'>
1080
mse_frame
<type 'numpy.ndarray'>
(108, 9)
mse_frame2
<type 'numpy.ndarray'>
(108, 9, 1, 1, 1)
X shape
<type 'numpy.ndarray'>
(1080, 128, 160, 3)
source shape
<type 'list'>
1080
mse_frame
<type 'numpy.ndarray'>
(108, 9)
mse_frame2
<type 'numpy.ndarray'>
(108, 9, 1, 1, 1)
X shape
<type 'numpy.ndarray'>
(1080, 128, 160, 3)
source shape
<type 'list'>
1080
mse_frame
<type 'numpy.ndarray'>
(108, 9)
mse_frame2
<type 'numpy.ndarray'>
(108, 9, 1, 1, 1)
X shape
<type 'numpy.ndarray'>
(1080, 128, 160, 3)
source shape
<type 'list'>
1080
mse_frame
<type 'numpy.ndarray'>
(108, 9)
mse_frame2
<type 'numpy.ndarray'>
(108, 9, 1, 1, 1)
X shape
<type 'numpy.ndarray'>
(1080, 128, 160, 3)
source shape
<type 'list'>
1080
mse_frame
<type 'numpy.ndarray'>
(108, 9)
mse_frame2
<type 'numpy.ndarray'>
(108, 9, 1, 1, 1)
X shape
<type 'numpy.ndarray'>
(1080, 128, 160, 3)
source shape
<type 'list'>
1080
mse_frame
<type 'numpy.ndarray'>