"""
Title: Customizing Saving and Serialization
Author: Neel Kovelamudi
Date created: 2023/03/15
Last modified: 2023/03/15
Description: A more advanced guide on customizing saving for your layers and models.
Accelerator: None
"""
"""
## Introduction
This guide covers advanced methods that can be customized in Keras saving. For most
users, the methods outlined in the primary
[Serialize, save, and export guide](https://keras.io/guides/serialization_and_saving)
are sufficient.
"""
"""
### APIs
We will cover the following APIs:
- `save_assets()` and `load_assets()`
- `save_own_variables()` and `load_own_variables()`
- `get_build_config()` and `build_from_config()`
- `get_compile_config()` and `compile_from_config()`
When restoring a model, these get executed in the following order:
- `build_from_config()`
- `compile_from_config()`
- `load_own_variables()`
- `load_assets()`
"""
"""
## Setup
"""
import os
import numpy as np
import keras
"""
## State saving customization
These methods determine how the state of your model's layers is saved when calling
`model.save()`. You can override them to take full control of the state saving process.
"""
"""
### `save_own_variables()` and `load_own_variables()`
These methods save and load the state variables of the layer when `model.save()` and
`keras.models.load_model()` are called, respectively. By default, the state variables
saved and loaded are the weights of the layer (both trainable and non-trainable). Here is
the default implementation of `save_own_variables()`:
```python
def save_own_variables(self, store):
    all_vars = self._trainable_weights + self._non_trainable_weights
    for i, v in enumerate(all_vars):
        store[f"{i}"] = v.numpy()
```
The store used by these methods is a dictionary that can be populated with the layer
variables. Let's take a look at an example customizing this.
**Example:**
"""
@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomVariable(keras.layers.Dense):
    """Dense layer that adds an extra per-unit variable to its output.

    Demonstrates overriding `save_own_variables()` and `load_own_variables()`
    so that the extra variable is stored under its own key in the saved state.

    Args:
        units: Forwarded to `keras.layers.Dense`; also the length of
            `my_variable`.
        **kwargs: Forwarded unchanged to `keras.layers.Dense`.
    """

    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)
        # Randomly initialised vector of length `units`, added to the dense
        # output in `call()`.
        self.my_variable = keras.Variable(
            np.random.random((units,)), name="my_variable", dtype="float32"
        )

    def save_own_variables(self, store):
        """Write the parent layer's entries, then `my_variable`, into `store`."""
        super().save_own_variables(store)
        # Stores the value of the variable upon saving
        store["variables"] = self.my_variable.numpy()

    def load_own_variables(self, store):
        """Restore `my_variable` and the indexed weights from `store`."""
        # Assigns the value of the variable upon loading
        self.my_variable.assign(store["variables"])
        # Load the remaining weights
        # NOTE(review): this assumes `super().save_own_variables()` wrote one
        # entry per element of `self.weights`, keyed "0", "1", ... in the same
        # order -- confirm against the installed Keras version.
        for i, v in enumerate(self.weights):
            v.assign(store[f"{i}"])
        # Note: You must specify how all variables (including layer weights)
        # are loaded in `load_own_variables`.

    def call(self, inputs):
        """Return the dense output plus `my_variable`."""
        dense_out = super().call(inputs)
        return dense_out + self.my_variable
# Train a one-layer model briefly, save it, and reload it to check that the
# custom variable survives the round trip.
model = keras.Sequential([LayerWithCustomVariable(1)])

ref_input = np.random.random((8, 10))
# The layer has a single unit, so the targets are (batch, 1) to match its
# output rather than relying on the loss broadcasting an (8, 10) target.
ref_output = np.random.random((8, 1))
model.compile(optimizer="adam", loss="mean_squared_error")
model.fit(ref_input, ref_output)

model.save("custom_vars_model.keras")
restored_model = keras.models.load_model("custom_vars_model.keras")

# The restored layer must carry the same extra variable as the saved one.
np.testing.assert_allclose(
    model.layers[0].my_variable.numpy(),
    restored_model.layers[0].my_variable.numpy(),
)
"""
### `save_assets()` and `load_assets()`
These methods can be added to your model class definition to store and load any
additional information that your model needs.
For example, NLP domain layers such as TextVectorization layers and IndexLookup layers
may need to store their associated vocabulary (or lookup table) in a text file upon
saving.
Let's take a look at the basics of this workflow with a simple file `vocabulary.txt`.
**Example:**
"""
@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomAssets(keras.layers.Dense):
    """Dense layer that persists a `vocab` string as a text-file asset.

    Args:
        vocab: Sentence stored on the layer and written to `vocabulary.txt`
            by `save_assets()`. It must be a string by the time the model is
            saved, because it is passed directly to `f.write()`.
        *args: Forwarded unchanged to `keras.layers.Dense`.
        **kwargs: Forwarded unchanged to `keras.layers.Dense`.
    """

    def __init__(self, vocab=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.vocab = vocab

    def save_assets(self, inner_path):
        """Write `self.vocab` to `vocabulary.txt` inside `inner_path`."""
        # Writes the vocab (sentence) to text file at save time.
        # The encoding is pinned so the asset does not depend on the locale
        # of the machine that saves or loads the model.
        vocab_path = os.path.join(inner_path, "vocabulary.txt")
        with open(vocab_path, "w", encoding="utf-8") as f:
            f.write(self.vocab)

    def load_assets(self, inner_path):
        """Read `vocabulary.txt` from `inner_path` and set `self.vocab`."""
        # Reads the vocab (sentence) from text file at load time.
        vocab_path = os.path.join(inner_path, "vocabulary.txt")
        with open(vocab_path, "r", encoding="utf-8") as f:
            text = f.read()
        # The "<unk>" placeholder is substituted on load, so the restored
        # vocab visibly differs from the saved one.
        self.vocab = text.replace("<unk>", "little")
# Wrap the asset-carrying layer in a model and call it once on random input
# before saving.
model = keras.Sequential(
    [
        LayerWithCustomAssets(units=5, vocab="Mary had a <unk> lamb."),
    ]
)

x = np.random.random((10, 10))
y = model(x)

# Round-trip through disk; `load_assets()` rewrites the placeholder token.
model.save("custom_assets_model.keras")
restored_model = keras.models.load_model("custom_assets_model.keras")

np.testing.assert_string_equal(
    actual=restored_model.layers[0].vocab,
    desired="Mary had a little lamb.",
)
"""
## `build` and `compile` saving customization
### `get_build_config()` and `build_from_config()`
These methods work together to save the layer's built states and restore them upon
loading.
By default, this only includes a build config dictionary with the layer's input shape,
but overriding these methods can be used to include further Variables and Lookup Tables
that can be useful to restore for your built model.
**Example:**
"""
@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomBuild(keras.layers.Layer):
    """Linear layer whose `build()` takes an extra initializer argument.

    The extra argument is recorded by `get_build_config()` and replayed by
    `build_from_config()`.

    Args:
        units: Size of the last output dimension (default 32).
        **kwargs: Forwarded unchanged to `keras.layers.Layer`.
    """

    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def call(self, inputs):
        """Return `inputs @ w + b`."""
        projected = keras.ops.matmul(inputs, self.w)
        return projected + self.b

    def get_config(self):
        """Return the base layer config extended with `units`."""
        base_config = super().get_config()
        return dict(units=self.units, **base_config)

    def build(self, input_shape, layer_init):
        """Create `w` and `b`, both initialised with `layer_init`.

        Because this override adds the `layer_init` argument, `build()` has
        to be called manually, with that argument, before the first
        execution of `call()`.
        """
        super().build(input_shape)
        self._input_shape = input_shape

        in_features = input_shape[-1]
        self.w = self.add_weight(
            shape=(in_features, self.units),
            initializer=layer_init,
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer=layer_init,
            trainable=True,
        )
        # Kept so that `get_build_config()` can report it later.
        self.layer_init = layer_init

    def get_build_config(self):
        """Return the arguments `build()` was called with."""
        # Stores our initializer for `build()`
        return {
            "layer_init": self.layer_init,
            "input_shape": self._input_shape,
        }

    def build_from_config(self, config):
        """Call `build()` with the values recorded in `config`."""
        # Calls `build()` with the parameters at loading time
        saved_shape = config["input_shape"]
        saved_init = config["layer_init"]
        self.build(saved_shape, saved_init)
# `build()` takes the extra `layer_init` argument, so it is invoked by hand
# before the layer is placed in a model.
custom_layer = LayerWithCustomBuild(units=16)
custom_layer.build(input_shape=(8,), layer_init="random_normal")

model = keras.Sequential(
    [custom_layer, keras.layers.Dense(1, activation="sigmoid")]
)

x = np.random.random((16, 8))
y = model(x)

# Save, reload, and check that the build arguments came back.
model.save("custom_build_model.keras")
restored_model = keras.models.load_model("custom_build_model.keras")

np.testing.assert_equal(
    actual=restored_model.layers[0].layer_init,
    desired="random_normal",
)
np.testing.assert_equal(actual=restored_model.built, desired=True)
"""
### `get_compile_config()` and `compile_from_config()`
These methods work together to save the information with which the model was compiled
(optimizers, losses, etc.) and restore and re-compile the model with this information.
Overriding these methods can be useful for compiling the restored model with custom
optimizers, custom losses, etc., as these will need to be deserialized prior to calling
`model.compile` in `compile_from_config()`.
Let's take a look at an example of this.
**Example:**
"""
@keras.saving.register_keras_serializable(package="my_custom_package")
def small_square_sum_loss(y_true, y_pred):
    """Return the squared error divided by 10 and summed over axis 1."""
    scaled_sq_err = keras.ops.square(y_pred - y_true) / 10.0
    return keras.ops.sum(scaled_sq_err, axis=1)
@keras.saving.register_keras_serializable(package="my_custom_package")
def mean_pred(y_true, y_pred):
    """Return the mean of `y_pred`; `y_true` is accepted but not used."""
    prediction_mean = keras.ops.mean(y_pred)
    return prediction_mean
@keras.saving.register_keras_serializable(package="my_custom_package")
class ModelWithCustomCompile(keras.Model):
    """Two-layer model with a custom `compile()` signature.

    `compile()` takes `loss_fn` in place of `loss`, so the compile arguments
    are saved and restored through `get_compile_config()` and
    `compile_from_config()`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = keras.layers.Dense(8, activation="relu")
        self.dense2 = keras.layers.Dense(4, activation="softmax")

    def call(self, inputs):
        """Apply `dense1` and then `dense2` to `inputs`."""
        hidden = self.dense1(inputs)
        return self.dense2(hidden)

    def compile(self, optimizer, loss_fn, metrics):
        """Compile the model and keep the arguments for later serialization."""
        super().compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)
        self.model_optimizer = optimizer
        self.loss_fn = loss_fn
        self.loss_metrics = metrics

    def get_compile_config(self):
        """Return the compile arguments recorded by `compile()`."""
        # These parameters will be serialized at saving time.
        compile_config = {
            "model_optimizer": self.model_optimizer,
            "loss_fn": self.loss_fn,
            "metric": self.loss_metrics,
        }
        return compile_config

    def compile_from_config(self, config):
        """Deserialize the saved compile arguments and re-compile with them."""
        # Deserializes the compile parameters (important, since many are custom)
        deserialize = keras.utils.deserialize_keras_object
        restored = {
            key: deserialize(config[key])
            for key in ("model_optimizer", "loss_fn", "metric")
        }

        # Calls compile with the deserialized parameters
        self.compile(
            optimizer=restored["model_optimizer"],
            loss_fn=restored["loss_fn"],
            metrics=restored["metric"],
        )
# Compile with the custom loss and metric, fit on random data, then check
# that the compile arguments survive a save/load round trip.
model = ModelWithCustomCompile()
model.compile(
    optimizer="SGD",
    loss_fn=small_square_sum_loss,
    metrics=["accuracy", mean_pred],
)

x = np.random.random((4, 8))
y = np.random.random((4,))
model.fit(x, y)

model.save("custom_compile_model.keras")
restored_model = keras.models.load_model("custom_compile_model.keras")

np.testing.assert_equal(
    actual=model.model_optimizer,
    desired=restored_model.model_optimizer,
)
np.testing.assert_equal(
    actual=model.loss_fn,
    desired=restored_model.loss_fn,
)
np.testing.assert_equal(
    actual=model.loss_metrics,
    desired=restored_model.loss_metrics,
)
"""
## Conclusion
Using the methods learned in this tutorial allows for a wide variety of use cases,
allowing the saving and loading of complex models with exotic assets and state
elements. To recap:
- `save_own_variables` and `load_own_variables` determine how your states are saved
and loaded.
- `save_assets` and `load_assets` can be added to store and load any additional
information your model needs.
- `get_build_config` and `build_from_config` save and restore the model's built
states.
- `get_compile_config` and `compile_from_config` save and restore the model's
compiled states.
"""