This repository has been archived by the owner on Nov 3, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 653
/
save_load_utils.py
107 lines (98 loc) · 4.79 KB
/
save_load_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import warnings
import h5py
import keras.backend as K
from keras import optimizers
from keras.engine import saving
def save_all_weights(model, filepath, include_optimizer=True):
    """Save model weights and optimizer weights, but not configuration, to an HDF5 file.

    Functionally between `save` and `save_weights`.
    The HDF5 file contains:
        - the model's weights under the 'model_weights' group
        - the model's optimizer's state under 'optimizer_weights' (if any)
    If you have a complicated model or set of models that do not serialize
    to JSON correctly, use this method.

    # Arguments
        model: Keras model instance to be saved.
        filepath: String, path where to save the model.
        include_optimizer: If True, save optimizer's state together.

    # Raises
        ImportError: if h5py is not available.
    """
    if h5py is None:
        raise ImportError('`save_all_weights` requires h5py.')
    with h5py.File(filepath, 'w') as f:
        model_weights_group = f.create_group('model_weights')
        saving.save_weights_to_hdf5_group(model_weights_group, model.layers)
        if include_optimizer and hasattr(model, 'optimizer') and model.optimizer:
            if isinstance(model.optimizer, optimizers.TFOptimizer):
                warnings.warn(
                    'TensorFlow optimizers do not '
                    'make it possible to access '
                    'optimizer attributes or optimizer state '
                    'after instantiation. '
                    'As a result, we cannot save the optimizer '
                    'as part of the model save file. '
                    'You will have to compile your model again after loading it. '
                    'Prefer using a Keras optimizer instead '
                    '(see keras.io/optimizers).')
            else:
                # Save optimizer weights. Default to None so an optimizer
                # without a `weights` attribute is skipped rather than raising.
                symbolic_weights = getattr(model.optimizer, 'weights', None)
                if symbolic_weights:
                    optimizer_weights_group = f.create_group('optimizer_weights')
                    weight_values = K.batch_get_value(symbolic_weights)
                    weight_names = []
                    for i, (w, val) in enumerate(zip(symbolic_weights,
                                                     weight_values)):
                        # Theano gives unnamed shared variables the default
                        # name "/variable"; treat that as unnamed. Other
                        # backends are unnamed when `name` is empty/falsy.
                        if K.backend() == 'theano':
                            has_usable_name = (hasattr(w, 'name')
                                               and w.name != "/variable")
                        else:
                            has_usable_name = hasattr(w, 'name') and w.name
                        if has_usable_name:
                            name = str(w.name)
                        else:
                            name = 'param_' + str(i)
                        weight_names.append(name.encode('utf8'))
                    optimizer_weights_group.attrs['weight_names'] = weight_names
                    for name, val in zip(weight_names, weight_values):
                        param_dset = optimizer_weights_group.create_dataset(
                            name,
                            val.shape,
                            dtype=val.dtype)
                        if not val.shape:
                            # h5py scalar datasets are written via the () index.
                            param_dset[()] = val
                        else:
                            param_dset[:] = val
def load_all_weights(model, filepath, include_optimizer=True):
    """Loads the weights of a model saved via `save_all_weights`.

    If the model has been compiled, optionally load its optimizer's weights.

    # Arguments
        model: instantiated model with architecture matching the saved model.
            Compile the model beforehand if you want to load optimizer weights.
        filepath: String, path to the saved model.
        include_optimizer: If True, also restore the optimizer's state when
            present in the file and the model has a compiled optimizer.

    # Returns
        None. The model will have its weights updated.

    # Raises
        ImportError: if h5py is not available.
        ValueError: In case of an invalid savefile.
    """
    if h5py is None:
        raise ImportError('`load_all_weights` requires h5py.')
    with h5py.File(filepath, mode='r') as f:
        # Restore the layer weights first.
        saving.load_weights_from_hdf5_group(f['model_weights'], model.layers)
        # Then, if requested and available on both sides, the optimizer state.
        can_load_optimizer = (include_optimizer
                              and 'optimizer_weights' in f
                              and hasattr(model, 'optimizer')
                              and model.optimizer)
        if can_load_optimizer:
            group = f['optimizer_weights']
            names = [raw.decode('utf8')
                     for raw in group.attrs['weight_names']]
            model.optimizer.set_weights([group[n] for n in names])