From 2f42ee2a37d28f40d53fe97836de94782b25894f Mon Sep 17 00:00:00 2001 From: Jason Hartford Date: Fri, 28 Jul 2017 15:45:57 -0700 Subject: [PATCH] added weight loading --- deepiv/models.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/deepiv/models.py b/deepiv/models.py index bbd505c..157e813 100644 --- a/deepiv/models.py +++ b/deepiv/models.py @@ -11,6 +11,12 @@ from keras.models import Model from keras import backend as K from keras.layers import Lambda, InputLayer +from keras.engine import topology +try: + import h5py +except ImportError: + h5py = None + import keras.utils @@ -404,3 +410,14 @@ def __getitem__(self, idx): if idx == (len(self) - 1): self.shuffle() return batch_features, batch_y + +def load_weights(filepath, model): + if h5py is None: + raise ImportError('`load_weights` requires h5py.') + + with h5py.File(filepath, mode='r') as f: + # set weights + topology.load_weights_from_hdf5_group(f['model_weights'], model.layers) + + return model +