Skip to content

Commit

Permalink
Add weights for tensorflow and theano (fix issue #46)
Browse files Browse the repository at this point in the history
  • Loading branch information
bill-lotter committed Oct 22, 2018
1 parent 6cda934 commit 09c2301
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
1 change: 0 additions & 1 deletion README.md
Expand Up @@ -8,7 +8,6 @@ The PredNet is a deep recurrent convolutional neural network that is inspired by
The architecture is implemented as a custom layer<sup>1</sup> in [Keras](http://keras.io/).
Code and model data is now compatible with Keras 2.0.
Specifically, it has been tested on Keras 2.0.6 with Theano 0.9.0, Tensorflow 1.2.1, and Python 2.7 (for your convenience, we have added an environment.yml file for setting up your python environment).
The provided weights were trained with the Theano backend.
For previous versions of the code compatible with Keras 1.2.1, use fbcdc18.
To convert old PredNet model files and weights for Keras 2.0 compatibility, see ```convert_model_to_keras2``` in `keras_utils.py`.
<br>
Expand Down
6 changes: 4 additions & 2 deletions download_models.sh
@@ -1,5 +1,7 @@
savedir="model_data_keras2"
mkdir -p -- "$savedir"
wget https://www.dropbox.com/s/z7ittwfxa5css7a/model_data_keras2.zip?dl=0 -O $savedir/model_data_keras2.zip
unzip -j $savedir/model_data_keras2.zip -d $savedir
wget https://www.dropbox.com/s/iutxm0anhxqca0z/model_data_keras2.zip?dl=0 -O $savedir/model_data_keras2.zip
unzip $savedir/model_data_keras2.zip -d $savedir
rm $savedir/model_data_keras2.zip
mv $savedir/model_data_keras2/* $savedir
rm -r $savedir/model_data_keras2
2 changes: 1 addition & 1 deletion kitti_evaluate.py
Expand Up @@ -24,7 +24,7 @@
batch_size = 10
nt = 10

weights_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_weights.hdf5')
weights_file = os.path.join(WEIGHTS_DIR, 'tensorflow_weights/prednet_kitti_weights.hdf5')
json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model.json')
test_file = os.path.join(DATA_DIR, 'X_test.hkl')
test_sources = os.path.join(DATA_DIR, 'sources_test.hkl')
Expand Down

0 comments on commit 09c2301

Please sign in to comment.