-
Notifications
You must be signed in to change notification settings - Fork 426
/
savers.py
468 lines (371 loc) · 15 KB
/
savers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility classes for saving model checkpoints and snapshots."""
import abc
import datetime
import os
import pickle
import time
from typing import Mapping, Optional, Union
from absl import logging
from acme import core
from acme.utils import signals
from acme.utils import paths
import sonnet as snt
import tensorflow as tf
import tree
from tensorflow.python.saved_model import revived_types
# Convenient alias for TensorFlow's Python-state checkpointing hook.
PythonState = tf.train.experimental.PythonState
# The union of types that tf.train.Checkpoint can track directly.
Checkpointable = Union[tf.Module, tf.Variable, PythonState]
# Default time-to-live (in seconds) applied to checkpoint/snapshot directories.
_DEFAULT_CHECKPOINT_TTL = int(datetime.timedelta(days=5).total_seconds())
_DEFAULT_SNAPSHOT_TTL = int(datetime.timedelta(days=90).total_seconds())
class TFSaveable(abc.ABC):
  """An interface for objects that expose their checkpointable TF state.

  Implementers return a mapping from names to TF-checkpointable objects;
  `CheckpointingRunner` (below in this file) checkpoints this state rather
  than the object itself.
  """

  @property
  @abc.abstractmethod
  def state(self) -> Mapping[str, Checkpointable]:
    """Returns TensorFlow checkpointable state."""
class Checkpointer:
  """Convenience class for periodically checkpointing.

  This can be used to checkpoint any object with trackable state (e.g.
  tensorflow variables or modules); see tf.train.Checkpoint for
  details. Objects inheriting from tf.train.experimental.PythonState can also
  be checkpointed.

  Typically people use Checkpointer to make sure that they can correctly
  recover from a machine going down during learning. For more permanent
  storage of self-contained "networks" see the Snapshotter object.

  Usage example:

  ```python
  model = snt.Linear(10)
  checkpointer = Checkpointer(objects_to_save={'model': model})

  for _ in range(100):
    # ...
    checkpointer.save()
  ```
  """

  def __init__(
      self,
      objects_to_save: Mapping[str, Union[Checkpointable, core.Saveable]],
      *,
      directory: str = '~/acme/',
      subdirectory: str = 'default',
      time_delta_minutes: float = 10.0,
      enable_checkpointing: bool = True,
      add_uid: bool = True,
      max_to_keep: int = 1,
      checkpoint_ttl_seconds: Optional[int] = _DEFAULT_CHECKPOINT_TTL,
      keep_checkpoint_every_n_hours: Optional[int] = None,
  ):
    """Builds the saver object.

    Args:
      objects_to_save: Mapping specifying what to checkpoint.
      directory: Which directory to put the checkpoint in.
      subdirectory: Sub-directory to use (e.g. if multiple checkpoints are
        being saved).
      time_delta_minutes: How often to save the checkpoint, in minutes.
      enable_checkpointing: whether to checkpoint or not.
      add_uid: If True adds a UID to the checkpoint path, see
        `paths.get_unique_id()` for how this UID is generated.
      max_to_keep: The maximum number of checkpoints to keep.
      checkpoint_ttl_seconds: TTL (time to live) in seconds for checkpoints.
      keep_checkpoint_every_n_hours: keep_checkpoint_every_n_hours passed to
        tf.train.CheckpointManager.
    """
    # Convert `Saveable` objects to TF `Checkpointable` first, if necessary.
    def to_ckptable(x: Union[Checkpointable, core.Saveable]) -> Checkpointable:
      if isinstance(x, core.Saveable):
        return SaveableAdapter(x)
      return x

    objects_to_save = {k: to_ckptable(v) for k, v in objects_to_save.items()}

    self._time_delta_minutes = time_delta_minutes
    self._last_saved = 0.
    self._enable_checkpointing = enable_checkpointing
    self._checkpoint_manager = None

    if enable_checkpointing:
      # Checkpoint object that handles saving/restoring.
      self._checkpoint = tf.train.Checkpoint(**objects_to_save)
      self._checkpoint_dir = paths.process_path(
          directory,
          'checkpoints',
          subdirectory,
          ttl_seconds=checkpoint_ttl_seconds,
          backups=False,
          add_uid=add_uid)

      # Create a manager to maintain different checkpoints.
      self._checkpoint_manager = tf.train.CheckpointManager(
          self._checkpoint,
          directory=self._checkpoint_dir,
          max_to_keep=max_to_keep,
          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

      # Restore immediately so a preempted job resumes from its last save.
      self.restore()

  def save(self, force: bool = False) -> bool:
    """Save the checkpoint if it's the appropriate time, otherwise no-ops.

    Args:
      force: Whether to force a save regardless of time elapsed since last
        save.

    Returns:
      A boolean indicating if a save event happened.
    """
    if not self._enable_checkpointing:
      return False

    # Rate-limit saves to at most one per `time_delta_minutes`.
    if (not force and
        time.time() - self._last_saved < 60 * self._time_delta_minutes):
      return False

    checkpoint_manager: tf.train.CheckpointManager = self.checkpoint_manager

    # Save any checkpoints.
    logging.info('Saving checkpoint: %s', checkpoint_manager.directory)
    checkpoint_manager.save()
    self._last_saved = time.time()

    return True

  def restore(self):
    """Restore from most recent checkpoint."""
    # Restore from the most recent checkpoint (if it exists); `restore(None)`
    # is a no-op when no checkpoint has been written yet.
    checkpoint_to_restore = self.checkpoint_manager.latest_checkpoint
    logging.info('Attempting to restore checkpoint: %s',
                 checkpoint_to_restore)
    self._checkpoint.restore(checkpoint_to_restore)

  @property
  def directory(self):
    """Directory the checkpoints are written to."""
    return self.checkpoint_manager.directory

  @property
  def checkpoint_manager(self) -> tf.train.CheckpointManager:
    """The underlying manager; raises if checkpointing is disabled."""
    if not self._enable_checkpointing:
      raise ValueError(
          'Check-point not enabled. No checkpoint manager available.'
      )
    # At this point, _enable_checkpointing is true, so _checkpoint_manager
    # should not be None.
    assert self._checkpoint_manager is not None
    return self._checkpoint_manager
class CheckpointingRunner(core.Worker):
  """Wrap an object and expose a run method which checkpoints periodically.

  This internally creates a Checkpointer around `wrapped` object and exposes
  all of the methods of `wrapped`. Additionally, any `**kwargs` passed to the
  runner are forwarded to the internal Checkpointer.
  """

  def __init__(
      self,
      wrapped: Union[Checkpointable, core.Saveable, TFSaveable],
      key: str = 'wrapped',  # Name under which `wrapped` is checkpointed.
      *,
      time_delta_minutes: int = 30,
      **kwargs,
  ):

    if isinstance(wrapped, TFSaveable):
      # If the object to be wrapped exposes its TF State, checkpoint that.
      objects_to_save = wrapped.state
    else:
      # Otherwise checkpoint the wrapped object itself.
      objects_to_save = wrapped

    self._wrapped = wrapped
    self._time_delta_minutes = time_delta_minutes
    self._checkpointer = Checkpointer(
        objects_to_save={key: objects_to_save},
        time_delta_minutes=time_delta_minutes,
        **kwargs)

  # Handle preemption signal. Note that this must happen in the main thread.
  def _signal_handler(self):
    """Forces a checkpoint save when the process receives SIGTERM."""
    logging.info('Caught SIGTERM: forcing a checkpoint save.')
    self._checkpointer.save(force=True)

  def step(self):
    """Performs one unit of work: a learner step or a blocking checkpoint."""
    if isinstance(self._wrapped, core.Learner):
      # Learners have a step() method, so alternate between that and ckpt call.
      self._wrapped.step()
      self._checkpointer.save()
    else:
      # Wrapped object doesn't have a run method; set our run method to ckpt.
      self.checkpoint()

  def run(self):
    """Runs the checkpointer."""
    # runtime_terminator routes SIGTERM to _signal_handler so state is saved
    # before the process exits.
    with signals.runtime_terminator(self._signal_handler):
      while True:
        self.step()

  def __dir__(self):
    # Advertise the wrapped object's attributes plus our own `get_directory`.
    return dir(self._wrapped) + ['get_directory']

  # TODO(b/195915583) : Throw when wrapped object has get_directory() method.
  def __getattr__(self, name):
    # Delegate all unknown attribute lookups to the wrapped object, except
    # `get_directory`, which is answered by the runner itself.
    if name == 'get_directory':
      return self.get_directory
    return getattr(self._wrapped, name)

  def checkpoint(self):
    """Saves (rate-limited) and then sleeps for one checkpoint interval."""
    self._checkpointer.save()
    # Do not sleep for a long period of time to avoid LaunchPad program
    # termination hangs (time.sleep is not interruptible).
    for _ in range(self._time_delta_minutes * 60):
      time.sleep(1)

  def get_directory(self):
    """Returns the directory the internal Checkpointer writes to."""
    return self._checkpointer.directory
class Snapshotter:
  """Convenience class for periodically snapshotting.

  Objects which can be snapshotted are limited to Sonnet or tensorflow Modules
  which implement a __call__ method. This will save the module's graph and
  variables such that they can be loaded later using `tf.saved_model.load`. See
  https://www.tensorflow.org/guide/saved_model for more details.

  The Snapshotter is typically used to save infrequent permanent self-contained
  snapshots which can be loaded later for inspection. For frequent saving of
  model parameters in order to guard against pre-emption of the learning
  process see the Checkpointer class.

  Usage example:

  ```python
  model = snt.Linear(10)
  snapshotter = Snapshotter(objects_to_save={'model': model})

  for _ in range(100):
    # ...
    snapshotter.save()
  ```
  """

  def __init__(
      self,
      objects_to_save: Mapping[str, snt.Module],
      *,
      directory: str = '~/acme/',
      time_delta_minutes: float = 30.0,
      # NOTE: `Optional[int]` (rather than `int | None`) keeps the annotation
      # style consistent with the rest of this file and avoids requiring
      # Python 3.10+ syntax.
      snapshot_ttl_seconds: Optional[int] = _DEFAULT_SNAPSHOT_TTL,
  ):
    """Builds the saver object.

    Args:
      objects_to_save: Mapping specifying what to snapshot.
      directory: Which directory to put the snapshot in.
      time_delta_minutes: How often to save the snapshot, in minutes.
      snapshot_ttl_seconds: TTL (time to live) in seconds for snapshots. If
        `None`, then snapshots will be created in `directory` without a TTL.
    """
    objects_to_save = objects_to_save or {}

    self._time_delta_minutes = time_delta_minutes
    self._last_saved = 0.
    self._snapshots = {}

    # Save the base directory path so we can refer to it if needed.
    self.directory = paths.process_path(
        directory, 'snapshots', ttl_seconds=snapshot_ttl_seconds)

    # Save a dictionary mapping paths to snapshot capable models.
    for name, module in objects_to_save.items():
      path = os.path.join(self.directory, name)
      self._snapshots[path] = make_snapshot(module)

  def save(self, force: bool = False) -> bool:
    """Snapshots if it's the appropriate time, otherwise no-ops.

    Args:
      force: If True, save new snapshot no matter how long it's been since the
        last one.

    Returns:
      A boolean indicating if a save event happened.
    """
    seconds_since_last = time.time() - self._last_saved
    if (self._snapshots and
        (force or seconds_since_last >= 60 * self._time_delta_minutes)):
      # Save any snapshots.
      for path, snapshot in self._snapshots.items():
        tf.saved_model.save(snapshot, path)

      # Record the time we finished saving.
      self._last_saved = time.time()
      return True

    return False
class Snapshot(tf.Module):
  """Thin wrapper which allows the module to be saved.

  The `_module`, `_variables` and `_trainable_variables` attributes are
  populated externally by `make_snapshot` (defined later in this file); the
  attribute names must stay stable because the `revived_types` registration
  below restores them via `setattr`.
  """

  def __init__(self):
    super().__init__()
    self._module = None
    self._variables = None
    self._trainable_variables = None

  @tf.function
  def __call__(self, *args, **kwargs):
    # Guard against calling a Snapshot whose module was never attached.
    if self._module is None:
      raise ValueError('_module not set')
    return self._module(*args, **kwargs)

  @property
  def submodules(self):
    return [self._module]

  @property
  def variables(self):
    return self._variables

  @property
  def trainable_variables(self):
    return self._trainable_variables
# Registers the Snapshot object above such that when it is restored by
# tf.saved_model.load it will be restored as a Snapshot. This is important
# because it allows us to expose the __call__, and *_variables properties.
revived_types.register_revived_type(
'acme_snapshot',
lambda obj: isinstance(obj, Snapshot),
versions=[
revived_types.VersionedTypeRegistration(
object_factory=lambda proto: Snapshot(),
version=1,
min_producer_version=1,
min_consumer_version=1,
setter=setattr,
)
])
def make_snapshot(module: snt.Module):
  """Create a thin wrapper around a module to make it snapshottable.

  Args:
    module: a Sonnet module which has had variables (and hence an input
      signature) created, e.g. via `create_variables`.

  Returns:
    A `Snapshot` wrapping `module`, with its `__call__` traced at the module's
    input signature (and `initial_state` traced as well for RNN cores).

  Raises:
    ValueError: if no input signature can be found on `module`.
  """
  # Get the input signature as long as it has been created.
  input_signature = _get_input_signature(module)
  if input_signature is None:
    raise ValueError(
        ('module instance "{}" has no input_signature attribute, '
         'which is required for snapshotting; run '
         'create_variables to add this annotation.').format(module.name))

  # Wrap the module up in tf.function so we can process it properly.
  @tf.function
  def wrapped_module(*args, **kwargs):
    return module(*args, **kwargs)

  # pylint: disable=protected-access
  snapshot = Snapshot()
  snapshot._module = wrapped_module
  snapshot._variables = module.variables
  snapshot._trainable_variables = module.trainable_variables
  # pylint: enable=protected-access
  # (The original comment said `disable` again, which would have left the
  # protected-access check suppressed for the remainder of the module.)

  # Make sure the snapshot has the proper input signature.
  snapshot.__call__.get_concrete_function(*input_signature)

  # If we are an RNN also save the initial-state generating function.
  if isinstance(module, snt.RNNCore):
    snapshot.initial_state = tf.function(module.initial_state)
    snapshot.initial_state.get_concrete_function(
        tf.TensorSpec(shape=(), dtype=tf.int32))
  return snapshot
def _get_input_signature(module: snt.Module) -> Optional[tf.TensorSpec]:
  """Get module input signature.

  Works even if the module with signature is wrapped into snt.Sequential or
  snt.DeepRNN.

  Args:
    module: the module which input signature we need to get. The module has to
      either have input_signature itself (i.e. you have to run create_variables
      on the module), or it has to be a module (with input_signature) wrapped
      in (one or multiple) snt.Sequential or snt.DeepRNNs.

  Returns:
    Input signature of the module or None if it's not available.
  """
  # NOTE(review): the code below indexes into the returned value
  # (`input_signature[-1] = ...`), so in practice this returns a mutable
  # sequence of TensorSpecs rather than a single `tf.TensorSpec` as the
  # annotation suggests — confirm against sonnet's `create_variables`.
  if hasattr(module, '_input_signature'):
    return module._input_signature  # pylint: disable=protected-access

  if isinstance(module, snt.Sequential):
    first_layer = module._layers[0]  # pylint: disable=protected-access
    return _get_input_signature(first_layer)

  if isinstance(module, snt.DeepRNN):
    first_layer = module._layers[0]  # pylint: disable=protected-access
    input_signature = _get_input_signature(first_layer)

    # Wrapping a module in DeepRNN changes its state shape, so we need to bring
    # it up to date. The last signature entry is the recurrent state; rebuild
    # it from a batch-1 initial state with an unknown batch dimension.
    state = module.initial_state(1)
    input_signature[-1] = tree.map_structure(
        lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state)

    return input_signature

  return None
class SaveableAdapter(tf.train.experimental.PythonState):
  """Adapter which allows `Saveable` object to be checkpointed by TensorFlow.

  Bridges Acme's `core.Saveable` save/restore protocol to TensorFlow's
  `PythonState` serialize/deserialize protocol by pickling the saved state.
  """

  def __init__(self, object_to_save: core.Saveable):
    self._object_to_save = object_to_save

  def serialize(self):
    # Ask the wrapped object for its state and pickle it into bytes.
    return pickle.dumps(self._object_to_save.save())

  def deserialize(self, pickled: bytes):
    # Unpickle the stored bytes and hand the state back to the wrapped object.
    self._object_to_save.restore(pickle.loads(pickled))