/
checkpoints.py
1595 lines (1350 loc) · 62.8 KB
/
checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for reading and writing sharded checkpoints.
The checkpointing utilities here can be used in two ways. The first is to use
the `Checkpointer` class. This requires having an optimizer and various
partitioning utilities setup, but allows for reading and writing of partitioned
parameters. It also allows different hosts to read different parameter
partitions in a multi-host setup, which results in much faster reads. This is
normally used during training where you have already created an optimizer based
on a config.
The second way is to use the `load_t5x_checkpoint` function. This doesn't
require an optimizer to be given up front so it is useful for things like
debugging and analysis of learned weights. However, this means that we cannot
do partitioned reads so loading will be slower than with the `Checkpointer`
class.
"""
import asyncio
import dataclasses
import functools
import os
import re
import subprocess
import time
from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple
from absl import logging
from flax import serialization
from flax import traverse_util
import jax
import jax.config
from jax.experimental import global_device_array as gda_lib
from jax.experimental import multihost_utils
from jax.experimental.gda_serialization import serialization as gda_serialization
import jax.numpy as jnp
import numpy as np
from t5x import checkpoint_importer
from t5x import checkpoint_utils
from t5x import optimizers
from t5x import partitioning
from t5x import state_utils
from t5x import train_state as train_state_lib
import tensorflow as tf
from tensorflow.io import gfile
import tensorstore as ts
import typing_extensions
from tensorboard.backend.event_processing import directory_watcher
from tensorboard.backend.event_processing import event_file_loader
from tensorboard.backend.event_processing import io_wrapper
# Aliases re-exported from the t5x/jax modules used throughout this file.
PartitionSpec = partitioning.PartitionSpec
# Concrete treedef type, used in type annotations for pytrees.
PyTreeDef = type(jax.tree_structure(None))
# Lazy array wrappers that defer materialization until the value is requested.
LazyArray = checkpoint_importer.LazyArray
LazyAwaitableArray = checkpoint_importer.LazyAwaitableArray
LazyThreadPoolArray = checkpoint_importer.LazyThreadPoolArray
# Version 3 is used since 2021-06-10, compared to version 2 the only change is
# that `bfloat16` arrays are written in Tensorstore using its native `bfloat16`
# support instead of casting them to `uint16`.
VERSION = 3
# Desired chunk size is 64MiB.
# This is large enough to keep CNS happy but small enough to support a wide
# range of partitionings.
_DESIRED_CHUNK_SIZE_BYTES = 64 * 1024 * 1024
# TODO(levskaya, adarob): how should we handle stacked/fused variables??
def _choose_chunk_shape(write_shape: Sequence[int],
target_elements: int) -> List[int]:
"""Chooses a chunk shape that evenly divides write_shape.
The chunk shape is chosen such that the total number of elements is less than
or equal to `target_elements`, but is otherwise as large as possible.
This uses a greedy algorithm that attempts to split the largest dimensions
first.
Args:
write_shape: Write shape for which to choose a chunk shape.
target_elements: Desired number of elements in chosen chunk shape. Must be
>= 1.
Returns:
List of length `len(write_shape)` specifying the chosen chunk shape.
"""
assert target_elements >= 1
rank = len(write_shape)
# `dim_factors[i]` is the list of divisors of `write_shape[i]`
dim_factors = [
[i for i in range(1, size + 1) if size % i == 0] for size in write_shape
]
# The current chunk shape is:
# [dim_factors[i][-1] for i in range(rank)]
def get_total_elements():
"""Returns the number of elements in the current chunk shape."""
total_elements = 1
for i in range(rank):
total_elements *= dim_factors[i][-1]
return total_elements
# Reduce the current chunk shape until the desired number of elements is
# reached.
while get_total_elements() > target_elements:
# Greedily reduce the largest dimension. This is not guaranteed to bring us
# the closest to `target_elements`, but is simple to implement and should
# work well enough.
dim_to_reduce = -1
dim_to_reduce_size = 1
for i in range(rank):
size = dim_factors[i][-1]
if size > dim_to_reduce_size:
dim_to_reduce_size = size
dim_to_reduce = i
# Can only fail to choose `dim_to_reduce` if all dimensions have size of 1.
# But that cannot happen since `target_elements >= 1`.
assert dim_to_reduce_size > 1
dim_factors[dim_to_reduce].pop()
return [dim_factors[i][-1] for i in range(rank)]
@dataclasses.dataclass
class _ParameterInfo:
  """Information needed to read/write and slice a partitioned parameter."""
  # The unique parameter name (path in the state dict, joined with '/').
  name: str
  # The full (non-partitioned) shape of the parameter.
  shape: Tuple[int, ...]
  # The TensorStore Spec containing the minimal information for read/write.
  # None for parameters that are stored directly in the msgpack file.
  ts_spec: Optional[ts.Spec]
  # The LocalChunkInfo for the part of the parameter local to this host.
  # None for unpartitioned parameters and for GDA-backed parameters.
  local_chunk_info: Optional[partitioning.LocalChunkInfo]
  # PartitionSpec mesh axes
  axes: Optional[partitioning.PartitionSpec] = None
# Register functions with flax.serialization to handle `ts.Spec`, so that a
# `ts.Spec` leaf can round-trip through a msgpack-serialized state dict.
serialization.register_serialization_state(
    ts.Spec,
    ty_to_state_dict=lambda t: t.to_json(),
    # The parameter may have been written to tensorstore or msgpack.
    # If the former, a dict of the spec will be stored. If the latter it will be
    # the value itself.
    ty_from_state_dict=lambda t, s: ts.Spec(s) if isinstance(s, dict) else s)
def _run_future_tree(future_tree):
"""Block until all futures are resolved on this host."""
future_leaves, treedef = jax.tree_flatten(future_tree)
# TODO(adarob): Use asyncio.run in py3.7+.
loop = asyncio.get_event_loop()
leaves = loop.run_until_complete(asyncio.gather(*future_leaves))
return jax.tree_unflatten(treedef, leaves)
def all_steps(checkpoints_dir: str) -> Sequence[int]:
  """Returns list of available step numbers in ascending order."""
  pattern = re.compile(r'.*/checkpoint_(\d+)/checkpoint$')
  candidate_paths = gfile.glob(
      os.path.join(checkpoints_dir, 'checkpoint_*', 'checkpoint'))
  steps = []
  for path in candidate_paths:
    match = pattern.match(path)
    if match:
      steps.append(int(match.group(1)))
  return sorted(steps)
def latest_step(checkpoints_dir: str) -> Optional[int]:
  """Returns latest step number or None if no checkpoints exist."""
  steps = all_steps(checkpoints_dir)
  # `all_steps` returns an ascending list, so the last entry is the latest.
  return steps[-1] if steps else None
def _get_local_data(x):
  """Returns the first local shard's data for a GDA; other values pass through."""
  if isinstance(x, gda_lib.GlobalDeviceArray):
    return x.local_data(0)
  return x
def get_checkpoint_dir(checkpoints_dir: str, step: int) -> str:
  """Returns path to a checkpoint dir given a parent directory and step."""
  return os.path.join(checkpoints_dir, 'checkpoint_{}'.format(step))
def _cast(target: PyTreeDef, dtype: jnp.dtype):
  """Cast arrays in target to dtype."""

  def maybe_cast(x):
    # Common non-array types (step counters, strings) must not be cast.
    if isinstance(x, (int, str)):
      return x
    if x.dtype == dtype:
      # Already the requested dtype; return unchanged.
      return x
    if isinstance(x, jax.ShapeDtypeStruct):
      return jax.ShapeDtypeStruct(x.shape, dtype)
    if isinstance(x, gda_lib.GlobalDeviceArray):
      raise ValueError('GDA cast not supported.')
    return x.astype(dtype)

  return jax.tree_map(maybe_cast, target)
def _update_ts_path_from_relative_to_absolute(
ckpt_dir: str, ts_spec_dict: MutableMapping[str, Any]):
"""Update (in-place) the path and gcs bucket (if applicable) in a TS Spec."""
# Handle `gs://` paths.
m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_dir, re.DOTALL)
if m is not None:
if ts_spec_dict['kvstore']['driver'] != 'gcs':
raise ValueError(f'Incorrect TensorStore Spec. '
f'Expects kvstore driver to be "gcs" for {ckpt_dir}. '
f'Got {ts_spec_dict}')
bucket = m.group(1)
ckpt_dir = m.group(2)
ts_spec_dict['kvstore']['bucket'] = bucket
# Update the path with `ckpt_dir`
if 'path' in ts_spec_dict['kvstore']:
# tensorstore>=0.1.14 format
ts_spec_dict['kvstore']['path'] = os.path.join(
ckpt_dir, ts_spec_dict['kvstore']['path'])
elif 'path' in ts_spec_dict:
# tensorstore<0.1.14 format
ts_spec_dict['path'] = os.path.join(ckpt_dir, ts_spec_dict['path'])
else:
raise ValueError(
'Incorrect TensorStore Spec. Expects "path" to be a key of spec or '
f'`spec["kvstore"]`. Got {ts_spec_dict}')
def _maybe_update_ts_from_file_to_gcs(ckpt_contents):
"""Updates the TensorStore driver from gfile to gcs."""
def _gfile_to_gcs_driver(arr_or_ts_spec_dict):
"""Converts the ts.Spec dict using gfile driver to gcs driver."""
if not isinstance(arr_or_ts_spec_dict, dict):
return arr_or_ts_spec_dict
if arr_or_ts_spec_dict['kvstore']['driver'] in ('file', 'gfile'):
ts_spec_dict = arr_or_ts_spec_dict
path = ts_spec_dict['kvstore'].pop('path')
ts_spec_dict['path'] = path
# This will be updated to the actual bucket in `_read_ts`.
ts_spec_dict['kvstore'] = {'bucket': 't5x-dummy-bucket', 'driver': 'gcs'}
else:
if arr_or_ts_spec_dict['kvstore']['driver'] != 'gcs':
raise ValueError('Unsupported TensoreStore driver. Got '
f'{arr_or_ts_spec_dict["kvstore"]["driver"]}.')
ts_spec_dict = arr_or_ts_spec_dict
return ts_spec_dict
def _is_leaf(value):
return not isinstance(
value, dict) or set(value.keys()) >= {'driver', 'kvstore', 'metadata'}
return jax.tree_map(_gfile_to_gcs_driver, ckpt_contents, is_leaf=_is_leaf)
def _maybe_update_ts_from_gcs_to_file(ckpt_contents):
"""Updates the TensorStore driver to gfile or file if different."""
# if saved in gcs, change to file
def _gcs_to_file_driver(arr_or_ts_spec_dict):
if not isinstance(arr_or_ts_spec_dict, dict):
return arr_or_ts_spec_dict
if arr_or_ts_spec_dict['kvstore']['driver'] == 'gcs':
ts_spec_dict = arr_or_ts_spec_dict
path = ts_spec_dict.pop('path')
driver = 'file'
ts_spec_dict['kvstore'] = {'path': path, 'driver': driver}
elif arr_or_ts_spec_dict['kvstore']['driver'] == 'gfile':
ts_spec_dict = arr_or_ts_spec_dict
driver = 'file'
ts_spec_dict['kvstore']['driver'] = driver
elif arr_or_ts_spec_dict['kvstore']['driver'] == 'file':
ts_spec_dict = arr_or_ts_spec_dict
else:
raise ValueError('Unsupported TensoreStore driver. Got '
f'{arr_or_ts_spec_dict["kvstore"]["driver"]}.')
return ts_spec_dict
def _is_leaf(value):
return not isinstance(
value, dict) or set(value.keys()) >= {'driver', 'kvstore', 'metadata'}
return jax.tree_map(_gcs_to_file_driver, ckpt_contents, is_leaf=_is_leaf)
class _BytesConditionVariable(object):
"""Wraps a condition variable to control concurrency based on bytes."""
def __init__(self, num_bytes):
self._max_bytes = num_bytes
self._num_bytes = num_bytes
self._cv = asyncio.Condition(lock=asyncio.Lock())
async def wait_for_bytes(self, n_bytes):
async with self._cv:
await self._cv.wait_for(lambda: self._num_bytes > n_bytes)
self._num_bytes -= n_bytes
assert self._num_bytes >= 0
async def return_bytes(self, n_bytes):
async with self._cv:
self._num_bytes += n_bytes
assert self._num_bytes <= self._max_bytes
self._cv.notify_all()
class SaveStateTransformationFn(typing_extensions.Protocol):
  """Protocol for callables that transform state/infos before saving."""

  def __call__(self, state_dict: PyTreeDef,
               parameter_infos: PyTreeDef) -> Tuple[PyTreeDef, PyTreeDef]:
    """Transforms the state and param info, e.g., by remapping parameters.

    Args:
      state_dict: State in the current model.
      parameter_infos: PyTree containing `_ParameterInfo` objects.

    Returns:
      A tuple whose first element is the result of transforming `state_dict` and
      whose second element is the result of transforming `parameter_infos`.
    """
class RestoreStateTransformationFn(typing_extensions.Protocol):
  """Protocol for callables that transform a checkpoint state on restore."""

  def __call__(self,
               state_dict: PyTreeDef,
               target_state_dict: PyTreeDef,
               *,
               is_resuming: bool = False) -> PyTreeDef:
    """Transforms the given checkpoint state, e.g., by remapping parameters.

    Args:
      state_dict: State to transform, which could be from a previous version of
        the model.
      target_state_dict: State in the current model.
      is_resuming: `True` iff this restore call is due to a job resuming after
        being temporarily stopped due to, for example, a preemption. This is
        useful when there is restore logic that should run when restoring from
        some pre-existing checkpoint, but that should not run again when
        resuming from a newly-written checkpoint.

    Returns:
      The result of transforming the `state_dict`.
    """
class Checkpointer(object):
"""Handles saving and restoring potentially-sharded T5X checkpoints.
Checkpoints are stored using a combination of msgpack (via flax.serialization)
and TensorStore.
Parameters (and other objects) that are not partitioned are written to the
msgpack binary directly (by host 0). Partitioned parameters are each written
to their own TensorStore, with each host writing their portion to the same
TensorStore in parallel. If a partition is written on multiple hosts, the
partition is further sharded across these replicas to avoid additional
overhead. In place of the parameter, a `tensorstore.Spec` is written to the
msgpack (by host 0) as a reference to be used during restore. Note that the
path of the array being written is relative. This makes the checkpoints
portable. In other words, even if the checkpoint files are moved to a new
directory, they can still be loaded. Because the path is relative, the
checkpoint directory information has to be dynamically provided. This is done
by `_update_ts_path_from_relative_to_absolute`.
For TensorStore driver using Google Cloud Storage (GCS) Key-Value Storage
Layer, the GCS bucket information is necessary. When a checkpoint is written
using the gcs driver, we don't want to hardcode the bucket information in the
resulting file in order to maintain the portability. Therefore, we use a dummy
bucket name of "t5x-dummy-bucket". When reading or writing the checkpoint, the
bucket information is parsed from the checkpoint directory and the bucket
information is dynamically updated.
Attributes:
checkpoints_dir: a path to a directory to save checkpoints in and restore
them from.
keep: an optional maximum number of checkpoints to keep. If more than this
number of checkpoints exist after a save, the oldest ones will be
automatically deleted to save space.
restore_dtype: optional dtype to cast targets to after restoring.
save_dtype: dtype to cast targets to before saving.
"""
def __init__(self,
             train_state: train_state_lib.TrainState,
             partitioner: partitioning.BasePartitioner,
             checkpoints_dir: str,
             dataset_iterator: Optional[tf.data.Iterator] = None,
             *,
             keep: Optional[int] = None,
             save_dtype: jnp.dtype = np.float32,
             restore_dtype: Optional[jnp.dtype] = None,
             use_gda: Optional[bool] = False):
  """Checkpointer constructor.

  Args:
    train_state: A train state to be used to determine the structure of the
      parameter tree, and the *full* (non-partitioned) parameter shapes and
      dtypes. Saved and restored train states must match this structure.
    partitioner: the partitioner to use for determining the local chunks
      mapping or to perform params partitioning on restore.
    checkpoints_dir: a path to a directory to save checkpoints in and restore
      them from.
    dataset_iterator: an optional iterator to save/restore.
    keep: an optional maximum number of checkpoints to keep. If more than this
      number of checkpoints exist after a save, the oldest ones will be
      automatically deleted to save space.
    save_dtype: dtype to cast targets to before saving.
    restore_dtype: optional dtype to cast targets to after restoring. If None,
      no parameter casting is performed.
    use_gda: if True, enabled gda_lib.GlobalDeviceArray. Note: this is
      currently an experimental feature under development.
  """
  self._train_state = train_state
  self._partitioner = partitioner
  self.checkpoints_dir = checkpoints_dir
  self.keep = keep
  # Immutable due to use in `_get_parameter_infos`
  self._save_dtype = save_dtype
  self.restore_dtype = restore_dtype
  # Dataset state is checkpointed via tf.train.Checkpoint, separately from
  # the model parameters.
  self._dataset_ckpt = (
      tf.train.Checkpoint(ds=dataset_iterator) if dataset_iterator else None)
  self._use_gda = use_gda
  if self._use_gda:
    logging.info('Checkpointing using GDA format is enabled.')
  data_layout = partitioner.get_data_layout()
  # One dataset checkpoint file per data shard, e.g. 'train_ds-001-of-008'.
  self._dataset_ckpt_name = (
      f'train_ds-'
      f'{data_layout.shard_id:03}-of-{data_layout.num_shards:03}')
  # Only the first host in each replica set writes the (identical) dataset
  # state for its shard.
  self._should_write_dataset_ckpt = (
      dataset_iterator and data_layout.is_first_host_in_replica_set)
  self._parameter_infos = self._get_parameter_infos()
  # Install a fresh event loop for the async TensorStore reads/writes issued
  # by this checkpointer.
  asyncio.set_event_loop(asyncio.new_event_loop())
def _get_state_dict_for_save(self,
                             state_dict: Dict[str, Any],
                             lazy_load: bool = True) -> Mapping[str, Any]:
  """Gets the optimizer state dict, optionally wrapping arrays lazily."""

  def _lazy_load_device_array(arr):
    # Defer the device-to-host transfer until the value is actually needed.
    if isinstance(arr, jax.xla.DeviceArray):
      return LazyThreadPoolArray(arr.shape, arr.dtype, lambda: np.array(arr))
    return arr

  if not lazy_load:
    return state_dict
  return jax.tree_map(_lazy_load_device_array, state_dict)
def _get_parameter_infos(self):
  """Generates the state dict of _ParameterInfos for the Optimizer.

  We generate a state dict (matching the shape of the optimizer state dict)
  that stores a _ParameterInfo for each parameter array.

  The _ParameterInfo contains the TensorStore spec for the parameter array and
  the LocalChunkInfo describing the slice of the array local to this host.

  Returns:
    The state dict of _ParameterInfo objects.
  """

  def _get_param_info(name: str, arr: Any, axes: partitioning.PartitionSpec):
    # If a node in your model is None it is probably a param_state that is not
    # used because of a MultiOptimizer. We don't want to have any parameter
    # info for it because it shouldn't be saved or restored.
    if arr is None:
      return None
    # Pass-through empty dict leaves, which occur with optax EmptyState().
    if isinstance(arr, dict) and not arr:
      return {}
    # Unpartitioned parameters get no TensorStore spec; they are stored
    # directly in the msgpack file by host 0 (see `_write_array`).
    if axes is None:
      return _ParameterInfo(
          name=name,
          shape=arr.shape,
          ts_spec=None,
          local_chunk_info=None,
          axes=None)
    if self._use_gda and isinstance(arr, gda_lib.GlobalDeviceArray):
      # GDA serialization computes its own chunk layout; reuse its metadata,
      # minus dtype which is set on the spec below.
      local_chunk_info = None
      metadata = gda_serialization._get_metadata(arr)  # pylint: disable=protected-access
      del metadata['dtype']
    else:
      # Non-GDA path: derive this host's write region and pick a chunk shape
      # of roughly `_DESIRED_CHUNK_SIZE_BYTES` that evenly divides it.
      local_chunk_info = self._partitioner.get_local_chunk_info(
          arr.shape, axes)
      write_shape = [
          si if sl == slice(None) else sl.stop - sl.start
          for si, sl in zip(arr.shape, local_chunk_info.slice)
      ]
      # TODO(levskaya, adarob): how should we handle stacked/fused variables??
      chunk_shape = _choose_chunk_shape(
          write_shape,
          target_elements=_DESIRED_CHUNK_SIZE_BYTES / arr.dtype.itemsize)
      metadata = {
          'compressor': {
              'id': 'gzip'
          },
          'shape': arr.shape,
          'chunks': np.array(chunk_shape),
      }

    if self.checkpoints_dir.startswith('gs://'):
      spec = {
          'driver': 'zarr',
          'dtype': jnp.dtype(arr.dtype).name,
          'kvstore': {
              'driver': 'gcs',
              # We always write with a dummy bucket and dynamically update the
              # bucket information. This makes the checkpoint files portable
              # and not bind to the bucket that it was originally written to.
              'bucket': 't5x-dummy-bucket',
          },
          'path': name.replace('/', '.'),
          'metadata': metadata,
      }
    else:
      spec = {
          'driver': 'zarr',
          'dtype': jnp.dtype(arr.dtype).name,
          'kvstore': {
              'driver': 'file',
              # Relative path: made absolute at read/write time by
              # `_update_ts_path_from_relative_to_absolute`.
              'path': name.replace('/', '.')
          },
          'metadata': metadata,
      }

    return _ParameterInfo(
        name,
        shape=arr.shape,
        ts_spec=ts.Spec(spec),
        local_chunk_info=local_chunk_info,
        axes=axes)

  # Create a tree of param names as the keys on the path to each leaf
  # separated by "/".
  param_names = traverse_util.unflatten_dict({
      k: '/'.join(k) for k in traverse_util.flatten_dict(
          self._train_state.state_dict(), keep_empty_nodes=True)
  })

  return jax.tree_map(
      _get_param_info, param_names,
      self._get_state_dict_for_save(self._train_state.state_dict()),
      self._partitioner.get_mesh_axes(self._train_state).state_dict())
def _get_checkpoint_dir(self, step: int) -> str:
  """Returns the directory for `step` under this checkpointer's root."""
  return get_checkpoint_dir(self.checkpoints_dir, step)
def all_steps(self) -> Sequence[int]:
  """Returns the ascending list of step numbers with existing checkpoints."""
  return all_steps(self.checkpoints_dir)
def latest_step(self) -> Optional[int]:
  """Returns the most recent checkpointed step, or None if there are none."""
  return latest_step(self.checkpoints_dir)
def _remove_old_checkpoints(self):
  """Prunes the oldest checkpoint dirs so that at most `self.keep` remain."""
  if not self.keep:
    # keep=None (or 0) means unlimited retention.
    return
  steps = self.all_steps()
  excess = len(steps) - self.keep
  if excess <= 0:
    return
  # `steps` is ascending, so the first `excess` entries are the oldest.
  for old_step in steps[:excess]:
    checkpoint_utils.remove_checkpoint_dir(self._get_checkpoint_dir(old_step))
def save(self,
         train_state: train_state_lib.TrainState,
         state_transformation_fns: Sequence[SaveStateTransformationFn] = (),
         *,
         concurrent_gb: int = 128):
  """Saves a checkpoint for the given train state.

  Args:
    train_state: the train state to save. May contain a combination of
      LazyArray objects and arrays (e.g., np.ndarray, jax.DeviceArray)
    state_transformation_fns: Transformations to apply, in order, to the state
      before writing.
    concurrent_gb: the approximate number of gigabytes of partitionable
      parameters to process in parallel. Useful to preserve RAM.
  """
  step = train_state.step
  step = step.get() if isinstance(step, LazyArray) else step
  step = _get_local_data(step)
  # Integer, to avoid side effects in the checkpoint path.
  step = int(step)

  # Share a timestamp across devices.
  timestamp = multihost_utils.broadcast_one_to_all(np.int32(time.time()))

  final_dir = os.path.join(self.checkpoints_dir, f'checkpoint_{step}')
  # Write into a temporary dir and atomically rename/move on success so a
  # partially-written checkpoint is never mistaken for a complete one.
  tmp_dir = final_dir + f'.tmp-{timestamp}'

  if gfile.exists(final_dir):
    logging.info(
        'Skipping save checkpoint for step %d (directory %s already exists)',
        step, final_dir)
    return

  logging.info('Saving checkpoint for step %d to %s', step, tmp_dir)

  if jax.process_index() == 0:
    gfile.makedirs(tmp_dir)
  # Block all hosts until directory is ready.
  multihost_utils.sync_global_devices(f'checkpointer:make_dir:{tmp_dir}')

  written_state_dict = self._write_state_to_tensorstore(
      tmp_dir, train_state, concurrent_gb, state_transformation_fns)

  if self._should_write_dataset_ckpt:
    logging.info("Writing dataset iterator state to '%s'.",
                 self._dataset_ckpt_name)
    try:
      self._dataset_ckpt.write(os.path.join(tmp_dir, self._dataset_ckpt_name))
    except tf.errors.FailedPreconditionError as e:
      logging.error(
          'Input pipeline must be stateless in order to checkpoint. Cache '
          'stateful steps offline or disable iterator checkpointing.')
      raise e

  # Block until complete on all hosts.
  multihost_utils.sync_global_devices(
      f'checkpointer:tensorstore_write_complete:{tmp_dir}')

  if jax.process_index() == 0:
    written_state_dict = jax.tree_map(_get_local_data, written_state_dict)

    # Write msgpack file in host 0 only
    msgpack_bytes = serialization.to_bytes({
        'version': VERSION,
        'optimizer': written_state_dict
    })
    with gfile.GFile(os.path.join(tmp_dir, 'checkpoint'), 'wb') as fp:
      fp.write(msgpack_bytes)

    # Finalize checkpoint directory. gfile.rename is not atomic/efficient on
    # GCS, so use gsutil to move the directory there.
    if final_dir.startswith('gs://'):
      subprocess.run(['gsutil', '-m', 'mv', tmp_dir, final_dir],
                     stdout=subprocess.DEVNULL,
                     check=True)
    else:
      gfile.rename(tmp_dir, final_dir)
    logging.info('Saved checkpoint for step %d to %s', step, final_dir)

    # Remove old checkpoints, if necessary.
    self._remove_old_checkpoints()

  # Block until complete on all hosts.
  multihost_utils.sync_global_devices(
      f'checkpointer:write_complete:{final_dir}')
def _write_state_to_tensorstore(
    self,
    ckpt_dir: str,
    train_state: train_state_lib.TrainState,
    concurrent_gb: int,
    state_transformation_fns: Sequence[SaveStateTransformationFn],
) -> Mapping[str, Any]:
  """Writes extracted state from train state to Tensorstore.

  Args:
    ckpt_dir: the (temporary) checkpoint directory to write into.
    train_state: the train state whose state dict is written.
    concurrent_gb: approximate gigabytes of partitioned parameters to
      materialize/write in parallel; bounds host memory usage.
    state_transformation_fns: transformations applied, in order, before
      writing.

  Returns:
    The state dict with partitioned arrays replaced by their (relative)
    `ts.Spec`s, suitable for serializing to the msgpack file.
  """
  concurrent_bytes = concurrent_gb * 10**9
  # Gates concurrent writes so at most ~`concurrent_bytes` of array data is
  # materialized in host memory at once.
  bytes_cv = _BytesConditionVariable(concurrent_bytes)

  async def _write_array(maybe_arr: Any,
                         param_info: Optional[_ParameterInfo],
                         cast: bool = False):
    """Maybe write to TensorStore, returning object to write to msgpack.

    Args:
      maybe_arr: array or LazyArray to be written
      param_info: ParameterInfo object. If None (or if param_info.ts_spec is
        None), the array will be immediately returned without writing to
        tensorstore. This is because array is None or is not partitioned, and
        should be written separately.
      cast: if True, performs cast operation using self._save_dtype.

    Returns:
      Tensorstore spec corresponding to the written array.
    """
    if param_info is None or param_info.ts_spec is None:
      # Write to the msgpack file on host 0.
      if isinstance(maybe_arr, LazyArray):
        return await maybe_arr.get_async()
      return maybe_arr

    # Only write each chunk of a parameter from one host
    if self._use_gda or param_info.local_chunk_info.replica_id == 0:
      arr = maybe_arr

      # Wait until memory is available.
      if isinstance(arr, gda_lib.GlobalDeviceArray):
        # Count only shards this host is responsible for writing.
        n_bytes = sum([
            shard.data.nbytes
            for shard in arr.local_shards
            if shard.replica_id == 0
        ])
      else:
        n_bytes = arr.nbytes
      if n_bytes > concurrent_bytes:
        # A single array larger than the budget must still be written; clamp
        # the request so it can acquire the whole budget.
        logging.warning(
            'Temporarily increasing the concurrency limits from %d bytes to '
            '%d bytes to fit %s.', concurrent_bytes, n_bytes, param_info.name)
        n_bytes = concurrent_bytes
      await bytes_cv.wait_for_bytes(n_bytes)

      if isinstance(maybe_arr, LazyArray):
        arr = await arr.get_async()
      elif not isinstance(arr, np.ndarray) and not isinstance(
          arr, gda_lib.GlobalDeviceArray):
        # Cast jax.DeviceArray to np.ndarray.
        arr = np.array(maybe_arr, dtype=maybe_arr.dtype)

      tmp_ts_spec_dict = param_info.ts_spec.to_json()
      if cast:
        # Set desired destination dtype.
        tmp_ts_spec_dict['dtype'] = jnp.dtype(self._save_dtype).name
        # Persist the cast dtype on the stored spec so the msgpack-recorded
        # spec matches the data actually written.
        param_info.ts_spec = ts.Spec(tmp_ts_spec_dict)

      # Path and gcs bucket (if applicable) information is updated in-place.
      _update_ts_path_from_relative_to_absolute(ckpt_dir, tmp_ts_spec_dict)

      if cast:
        # Set up casting spec.
        tmp_ts_spec_dict = {
            'base': tmp_ts_spec_dict,
            'driver': 'cast',
            'dtype': jnp.dtype(arr.dtype).name,  # dtype before cast
        }

      if self._use_gda:
        await gda_serialization.async_serialize(arr, tmp_ts_spec_dict)
      else:
        t = await ts.open(
            tmp_ts_spec_dict,
            create=True,
            open=True,
            context=ts.Context({'file_io_concurrency': {
                'limit': 128
            }}))
        await t[param_info.local_chunk_info.slice].write(arr)

      await bytes_cv.return_bytes(n_bytes)

    # N.B. we return the original ts_spec (before
    # `_update_ts_path_from_relative_to_absolute` was called). This is because
    # we'd like to keep the path as relative, i.e., it doesn't hardcode the
    # directory that the checkpoint was originally written. This makes the
    # checkpoints portable.
    return param_info.ts_spec

  transformed_state_dict, transformed_parameter_infos = (
      self._transform_state_and_infos(train_state.state_dict(),
                                      self._parameter_infos,
                                      state_transformation_fns))

  state_dict_for_save = self._get_state_dict_for_save(transformed_state_dict)

  def _cast_arr_if_not_partitioned(maybe_arr, param_info):
    # Unpartitioned targets are cast eagerly here; partitioned ones are cast
    # via the TensorStore 'cast' driver inside `_write_array`.
    if param_info is None or param_info.ts_spec is None:
      return _cast(maybe_arr, self._save_dtype)
    return maybe_arr

  state_dict_for_save['target'] = jax.tree_multimap(
      _cast_arr_if_not_partitioned, state_dict_for_save['target'],
      transformed_parameter_infos['target'])
  future_written_state = {}
  for k in state_dict_for_save.keys():
    # ensure that only 'target' is cast
    future_written_state[k] = jax.tree_multimap(
        functools.partial(_write_array, cast=(k == 'target')),
        state_dict_for_save[k], transformed_parameter_infos[k])

  # Block until complete on this host.
  written_state_dict = _run_future_tree(future_written_state)

  # Block until complete on all hosts.
  multihost_utils.sync_global_devices(
      f'checkpointer:ts_write_complete:{ckpt_dir}')

  return written_state_dict
def _transform_state_and_infos(
    self,
    state_dict: PyTreeDef,
    parameter_infos: PyTreeDef,
    state_transformation_fns: Sequence[SaveStateTransformationFn],
) -> Tuple[PyTreeDef, PyTreeDef]:
  """Applies transformations to the state dict and parameter infos PyTrees."""
  current = (state_dict, parameter_infos)
  # Each fn maps (state_dict, parameter_infos) -> (state_dict, parameter_infos)
  # and is applied in the order given.
  for transform in state_transformation_fns:
    current = transform(*current)
  return current
def restore(
    self,
    step: Optional[int] = None,
    path: Optional[str] = None,
    state_transformation_fns: Sequence[RestoreStateTransformationFn] = (),
    fallback_state: Optional[Mapping[str, Any]] = None,
    lazy_parameters: bool = False) -> train_state_lib.TrainState:
  """Restores the host-specific parameters in an Optimizer.

  Either `step` or `path` can be specified, but not both. If neither are
  specified, restores from the latest checkpoint in the checkpoints directory.

  Args:
    step: the optional step number to restore from.
    path: an optional absolute path to a checkpoint file to restore from.
    state_transformation_fns: Transformations to apply, in order, to the state
      after reading.
    fallback_state: a state dict of an optimizer to fall back to for loading
      params that do not exist in the checkpoint (after applying all
      `state_transformation_fns`), but do exist in `Checkpointer.optimizer`.
      The union of `fallback_state` and state loaded from the checkpoint must
      match `Checkpointer.optimizer`.
    lazy_parameters: whether to load the parameters as LazyArrays to preserve
      memory.

  Returns:
    The restored train state.

  Raises:
    ValueError if both `step` and `path` are specified.
    ValueError if checkpoint at `path` or `step` does not exist.
    ValueError if `step` and `path` are not specified and no checkpoint is
      found in the checkpoints directory.
  """
  # Lazy parameters are materialized on demand, so they cannot be eagerly
  # placed on devices by the partitioner.
  if lazy_parameters and self._partitioner.params_on_devices:
    raise ValueError('Lazy Parameters cannot be copied to devices, please '
                     'set partitioner.params_on_devices=False.')
  if step is not None and path is not None:
    raise ValueError('At most one of `step` or `path` may be provided.')
  if path:
    ckpt_path = path
  else:
    if step is None:
      step = self.latest_step()
      if not step:
        raise ValueError(f'No checkpoints found in {self.checkpoints_dir}.')
    ckpt_path = self._get_checkpoint_dir(step)

  # `ckpt_path` may point at either the checkpoint directory or the msgpack
  # file itself; normalize so `ckpt_path` is the file and `ckpt_dir` is its
  # containing directory.
  if gfile.isdir(ckpt_path):
    ckpt_dir = ckpt_path
    ckpt_path = os.path.join(ckpt_path, 'checkpoint')
  else:
    ckpt_dir = os.path.dirname(ckpt_path)

  if not gfile.exists(ckpt_path) or gfile.isdir(ckpt_path):
    raise ValueError(f'Path is not a valid T5X checkpoint: {ckpt_path}')

  logging.info('Restoring from checkpoint: %s', ckpt_path)

  with gfile.GFile(ckpt_path, 'rb') as fp:
    # TODO(adarob): Use threaded reading as in flax.checkpoints.
    raw_contents = fp.read()
    # TensorFlow checkpoint index files begin with this literal header;
    # detect it early to give a clear error instead of a msgpack failure.
    if raw_contents.startswith(b'model_checkpoint_path'):
      raise ValueError(
          'Attempting to restore a TensorFlow checkpoint as a native T5X '
          'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: ' +
          ckpt_path)

    # `ckpt_contents['optimizer']` is a pytree with a realized np.array for
    # leaves (params or states) written as msgpack and a ts.Spec (in a dict)
    # for leaves written by TensorStore.
    ckpt_contents = serialization.msgpack_restore(raw_contents)

  # If reading a ckpt that was written with gfile driver but the current
  # session uses the gcs driver, convert the ckpt's driver to gcs.
  if ckpt_dir.startswith('gs://'):
    ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents)
  # If a ckpt was saved in gcs and is being loaded locally, then convert the
  # driver to file or gfile. If the ckpt was not saved in gcs, do not change.
  else:
    ckpt_contents = _maybe_update_ts_from_gcs_to_file(ckpt_contents)

  ckpt_state_dict = self._get_optimizer_state_dict(ckpt_contents,
                                                   state_transformation_fns)

  # The state dict may contain TensorStore specs that need to be read.
  # `dummy_spec` stands in for leaves that were stored inline as msgpack and
  # therefore have no real TensorStore spec.
  dummy_spec = ts.Spec({'driver': 'zarr', 'kvstore': {'driver': 'memory'}})

  # `dummy_written_state_dict` is a pytree with a `dummy_spec` for leaves
  # (params or states) written as msgpack and a ts.Spec (in a dict) for leaves
  # written by TensorStore.
  dummy_written_state_dict = jax.tree_map(
      lambda x: x.ts_spec or dummy_spec,
      self._parameter_infos,
  )

  if fallback_state is None:
    restore_parameter_infos = self._parameter_infos
  else:
    # If `fallback_state` was specified, restore only the subset
    # of parameters matched by `self._get_optimizer_state_dict`. The
    # rest will be provided by `fallback_state`.
    dummy_written_state_dict = state_utils.intersect_state(
        dummy_written_state_dict, ckpt_state_dict)
    restore_parameter_infos = state_utils.intersect_state(
        self._parameter_infos, ckpt_state_dict)

  # Log exactly which keys will be read from the checkpoint.
  restore_parameter_infos_flat = state_utils.flatten_state_dict(
      restore_parameter_infos)
  for key in restore_parameter_infos_flat.keys():
    logging.info('Restoring key from ckpt: %s', key)

  # NB: `serialization.from_state_dict` doesn't check whether the shapes match
  # at the leaf level. Non-partitioned leaves (e.g., optimizer states) can
  # load arrays with inconsistent shapes.
  # `written_state_dict` is a pytree with a realized np.array for leaves
  # (params or states) written as msgpack and a `ts.Spec` for leaves written
  # by TensorStore.
  written_state_dict = serialization.from_state_dict(dummy_written_state_dict,
                                                     ckpt_state_dict)
  state_dict = self._read_state_from_tensorstore(
      ckpt_path,
      written_state_dict,
      restore_parameter_infos=restore_parameter_infos,
      lazy_parameters=lazy_parameters)

  # If `fallback_state` was specified, then fill the missing parameters.
  if fallback_state is not None:
    state_dict = state_utils.merge_state(state_dict, fallback_state)

  # Log keys that ended up in the final state without being read from the
  # checkpoint (i.e., supplied by `fallback_state`).
  for key in state_utils.flatten_state_dict(state_dict).keys():
    if key not in restore_parameter_infos_flat:
      logging.info('Not restoring key from ckpt: %s', key)

  # Restore the tf.data iterator state saved alongside the checkpoint, if any.
  if self._dataset_ckpt:
    logging.info("Restoring dataset iterator from '%s'.",
                 self._dataset_ckpt_name)
    self._dataset_ckpt.read(os.path.join(
        ckpt_dir, self._dataset_ckpt_name)).assert_consumed()

  return self._restore_train_state(state_dict)
def _restore_train_state(
    self,
    state_dict: optimizers.OptimizerStateType) -> train_state_lib.TrainState:
  """Restores a TrainState from an Optimizer state_dict."""
  restored = self._train_state.restore_state(state_dict)
  # GDA restores already place parameters; likewise skip when the partitioner
  # keeps params off-device.
  if self._use_gda or not self._partitioner.params_on_devices:
    return restored
  logging.info('Moving params to devices.')
  axes = self._partitioner.get_mesh_axes(restored)
  return self._partitioner.move_params_to_devices(restored, axes)
def _create_lazy_awaitable_array(
self, param_info: _ParameterInfo, maybe_ts_spec: Any, ckpt_path: str,
restore_dtype: Optional[jnp.dtype]) -> LazyAwaitableArray:
"""Creates LazyArray from tensorstore.
Does not materialize the array immediately.
Args:
param_info: Information about how to read the parameter, host based sliced
reads and the like.
maybe_ts_spec: The tensorstore spec to read the parameter or some other
object. If this is an array then we will do a host based sliced read on
it (provided the param_info says to). Anything else we just return.
ckpt_path: A base location to use when resolving the relative paths in the
tensorstore spec.
restore_dtype: type to restore as. None indicates that no cast is
requested.
Returns:
LazyArray object.