/
trainer_lib.py
992 lines (853 loc) · 38.3 KB
/
trainer_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
# coding=utf-8
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trax main training functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
import gzip as gzip_lib
import itertools
import os
import pickle
import random
import sys
import time
from absl import logging
import gin
import jax
import numpy
import six
import tensorflow.compat.v2 as tf
from trax import history as trax_history
from trax import jaxboard
from trax import layers as tl
from trax import lr_schedules as lr
from trax import math
from trax import optimizers as trax_opt
from trax.math import numpy as np
from trax.math import random as jax_random
from trax.shapes import ShapeDtype
from trax.supervised import inputs as trax_inputs
# TODO(afrozm): Maybe flatten everything from OptState into TrainerState.
# Everything needed to resume training (see `Trainer.state`).
TrainerState = collections.namedtuple(
    '_TrainerState',
    ('step',         # Current training step number.
     'opt_state',    # OptState.
     'history',      # trax.history.History.
     'model_state',  # Auxiliary state of the model.
    ))

# Optimizer-side training state, kept together so it can be replicated and
# saved as one unit.
OptState = collections.namedtuple(
    '_OptState',
    ('weights',      # Model weights.
     'slots',        # Per-parameter optimizer state, e.g. gradient moments.
     'opt_params',   # Optimizer (hyper)parameters, e.g. learning rate, momentum.
    ))
# Metric name -> metric layer, used for evaluation whenever a `Trainer` is
# built with `metrics=None`.
# NOTE(review): these layer instances are module-level and therefore shared by
# every Trainer that falls back to the defaults.
_DEFAULT_METRICS = dict(
    loss=tl.CrossEntropyLoss(),
    accuracy=tl.AccuracyScalar(),
    sequence_accuracy=tl.SequenceAccuracyScalar(),
    neg_log_perplexity=tl.Serial(tl.CrossEntropyLoss(), tl.Negate()),
    weights_per_batch_per_core=tl.SumOfWeights(),
)
class Trainer(object):
  """Trax trainer.

  A trainer runs training steps, trains for multi-step "epochs" followed by
  evaluation, saves/restores the training state (checkpoints) and writes
  train/eval summaries.
  """

  def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
               output_dir=None, random_seed=None, n_devices=None,
               checkpoints_at=None, should_save_checkpoints=True,
               should_write_summaries=True, nontrainable_param_map=None,
               id_to_mask=None,
               metrics=None, checkpoint_highest=None, checkpoint_lowest=None):
    """Builds the models, jitted step functions and initial training state.

    Args:
      model: Callable taking `mode=` ('train' or 'eval') and returning the
        model layer for that mode.
      loss_fn: Loss layer; it is appended to the training model with
        `tl.Serial` to form the function that gets differentiated.
      optimizer: Optimizer factory, called as `optimizer(learning_rate=0.0)`;
        the actual parameters are filled in per step from `lr_schedule`.
      lr_schedule: Callable taking the training history and returning a
        function from step number to a dict of nontrainable parameters.
      inputs: An Inputs instance, or a callable returning one.
      output_dir: Directory for checkpoints and summaries; if falsy, both are
        disabled.
      random_seed: Seed for all randomness; see `_init_host_and_devices`.
      n_devices: Number of devices to use; `None` means all available.
      checkpoints_at: Steps at which an extra, step-numbered checkpoint is
        kept.
      should_save_checkpoints: Whether to save checkpoints (only ever done on
        the chief host).
      should_write_summaries: Whether to write train/eval summaries.
      nontrainable_param_map: Dict mapping model-state keys to nontrainable
        parameter names; keys not in the map are used as-is.
      id_to_mask: Input id to mask out; forwarded to `_add_weights_and_mask`.
      metrics: Dict from metric name to metric layer used during evaluation;
        defaults to `_DEFAULT_METRICS`.
      checkpoint_highest: Name of an eval metric; when set, a
        'highest_<metric>' checkpoint is saved whenever it reaches a new max.
      checkpoint_lowest: Likewise for a new min ('lowest_<metric>').
    """
    self._is_chief, self._n_devices, rng = (
        self._init_host_and_devices(n_devices, random_seed))
    self._should_save_checkpoints = should_save_checkpoints and self._is_chief
    self._checkpoints_at = checkpoints_at or []
    self._should_write_summaries = should_write_summaries
    if not output_dir:
      self._should_save_checkpoints = False
      self._should_write_summaries = False
    self._checkpoint_highest = checkpoint_highest
    self._checkpoint_lowest = checkpoint_lowest
    self._id_to_mask = id_to_mask
    self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    # Inputs is either an Inputs instance or a function that returns it.
    self._inputs = inputs
    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
      self._inputs = inputs()
    # Mask id_to_mask and add weights if needed.
    # TODO(lukaszkaiser, jonni): move this out of Trainer to input processing.
    self._inputs = _add_weights_and_mask(self._inputs, id_to_mask)

    # Initialize the learning rate to a dummy value. It will be set in reset().
    opt = optimizer(learning_rate=0.0)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')

    # Setup state.
    rng, init_rng = jax_random.split(rng)
    self._rngs = np.stack(jax_random.split(rng, self._n_devices))

    def new_opt_state_and_model_state(shape_dtype, rng):
      """Returns optimizer and model states suitable for training a model."""
      # Combine inputs and targets on the stack.
      shapes, dtypes = shape_dtype
      input_signature = tuple(ShapeDtype(s, d)
                              for (s, d) in zip(shapes, dtypes))
      # We need to create a new model instance and not reuse `model_train` here,
      # because `m.initialize` puts cached parameter values in `m` and hence the
      # next call of `m.initialize` will give wrong results.
      m = tl.Serial(model(mode='train'), loss_fn)
      m._set_rng_recursive(rng)  # pylint: disable=protected-access
      weights, state = m.init(input_signature)
      (slots, opt_params) = opt.tree_init(weights)
      return (OptState(weights, slots, opt_params), state)

    if _is_jit_init():
      # JIT parameter initialization to avoid memory fragmentation
      new_opt_state_and_model_state = math.jit(new_opt_state_and_model_state,
                                               static_argnums=(0,))
    self._new_opt_state_and_model_state = (
        lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
            self._inputs.example_shape_dtype, init_rng))

    # Arrange and initialize metrics layers.
    self._metrics = list(sorted(self._metrics_dict.keys()))
    metrics_layers = [self._metrics_dict[m] for m in self._metrics]
    metrics_in_parallel = tl.Branch(*metrics_layers)
    metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
    example_signature = tuple(
        ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype)
    )
    model_predict_eval.init(example_signature)
    output_signature = model_predict_eval.output_signature(example_signature)
    m_weights, m_state = metrics_in_parallel.init(output_signature)
    self._metrics_weights = self._for_n_devices(m_weights)
    self._metrics_state = self._for_n_devices(m_state)

    # Jit model_predict and update so they're fast.
    self._jit_eval = _jit_predict_fn(
        model_predict_eval, metrics_in_parallel, self._n_devices)
    self._jit_update_fn = _jit_update_fn(
        model_train, loss_fn, opt, self._n_devices)

    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    # TODO(pkozakowski): "Learning rate schedules" are currently able to control
    # control all optimizer parameters and model state, so let's rename them
    # accordingly.
    self._lr_schedule = lr_schedule

    if nontrainable_param_map is None:
      nontrainable_param_map = {}
    self._nontrainable_param_map = nontrainable_param_map

    # Those fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._lr_fn = None
    self._opt_state = None
    self._step = None
    self._model_state = None
    self.reset(output_dir)

  @property
  def n_devices(self):
    """Number of devices this trainer runs on."""
    return self._n_devices

  @property
  def step(self):
    """Current training step number."""
    return self._step

  @property
  def model_weights(self):
    """Weights of the model alone (without the loss layer), unreplicated."""
    # Currently we need to pick [0] as we ignore loss weights (empty).
    weights = self._opt_state.weights[0]
    if self.n_devices > 1:
      unreplicate = lambda x: x[0]
      weights = math.nested_map(unreplicate, weights)
    return weights

  @model_weights.setter
  def model_weights(self, weights):
    """Replaces the model weights, replicating them across devices."""
    new_model_weights = self._for_n_devices(weights)
    if isinstance(self._opt_state.weights, list):
      self._opt_state.weights[0] = new_model_weights
    else:  # weights are a tuple, need to re-create
      new_weights = [new_model_weights] + list(self._opt_state.weights[1:])
      self._opt_state = self._opt_state._replace(weights=new_weights)

  @property
  def state(self):
    """Current `TrainerState` (opt_state is replicated across devices)."""
    return TrainerState(
        opt_state=self._opt_state, step=self._step, history=self._history,
        model_state=self._model_state)

  @property
  def nontrainable_params(self):
    """Dict of nontrainable parameters (e.g. learning rate) at this step."""
    # TODO(afrozm): Give further thought to this name.
    # TODO(lukaszkaiser): it makes no sense to use an accelerator (e.g. TPU)
    # in op-by-op mode just to compute the learning rate. However, there
    # should be a cleaner approach than forcibly swapping out the backend.
    with math.use_backend('numpy'):
      return self._lr_fn(self._step)

  def reset(self, output_dir, init_checkpoint=None):
    """Reset the model parameters.

    Restores the parameters from the given output_dir if a checkpoint exists,
    otherwise randomly initializes them.

    Does not re-jit the model.

    Args:
      output_dir: Output directory.
      init_checkpoint: Initial checkpoint to use (default $output_dir/model.pkl)
    """
    self.close()
    self._output_dir = output_dir
    if output_dir is not None:
      tf.io.gfile.makedirs(output_dir)
    else:
      assert not self._should_save_checkpoints
      assert not self._should_write_summaries

    # Create summary writers and history.
    if self._should_write_summaries:
      self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'train'),
                                              enable=self._is_chief)
      self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'eval'),
                                             enable=self._is_chief)

    # Reset the train and eval streams.
    self._train_stream = _repeat_stream(self._inputs.train_stream,
                                        self._n_devices)
    # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
    # set by adding a padding and stopping the stream when too large.
    self._eval_stream = _repeat_stream(
        self._inputs.eval_stream, self._n_devices)
    self._train_eval_stream = _repeat_stream(
        self._inputs.train_eval_stream, self._n_devices)

    # Restore the training state.
    if output_dir is not None:
      state = load_trainer_state(output_dir, init_checkpoint)
    else:
      state = TrainerState(step=None, opt_state=None,
                           history=trax_history.History(), model_state=None)
    self._step = state.step or 0
    history = state.history
    self._lr_fn = self._lr_schedule(history)
    self._history = history
    if state.opt_state:
      opt_state = state.opt_state
      model_state = state.model_state
    else:
      opt_state, model_state = self._new_opt_state_and_model_state()
      model_state = self._for_n_devices(model_state)
    self._opt_state = OptState(*self._for_n_devices(opt_state))
    self._model_state = model_state
    if not state.opt_state and self._should_save_checkpoints:
      self.save_state(keep=False)

    self.update_nontrainable_params()

  def train_epoch(self, n_steps, n_eval_steps):
    """Runs `n_steps` of training, with periodic logging, saving, and evals.

    After the training steps this logs throughput, runs `evaluate` on
    `n_eval_steps` batches per eval stream, saves the latest checkpoint and,
    if configured, the 'highest_*' / 'lowest_*' best-metric checkpoints.

    Args:
      n_steps: Number of training steps to run.
      n_eval_steps: Number of batches to evaluate on, per eval stream.
    """
    # TODO(jonni): Clarify how this method relates to the stricter notion of
    # epoch (training for as many steps as needed for a full pass through the
    # training data).
    print()  # Add visual separator in logs for start of training epoch.
    start_time = time.time()

    for _ in range(n_steps):
      batch = next(self._train_stream)
      if self.n_devices > 1:  # TODO(lukaszkaiser): use everywhere if possible.
        batch = _reshape_by_device(batch, self.n_devices)
      self.train_step(batch)
      if self._should_save_now():
        self.save_state(keep=True)
      if self._should_log_now():
        # Pass the step explicitly, like every other `scalar` call in this
        # class, instead of relying on the writer's default step.
        for (name, value) in self.nontrainable_params.items():
          self._train_sw.scalar('training/{}'.format(name), value,
                                step=self._step)

    # At end of n_steps, do bookkeeping, run evals, and save state.
    elapsed_time = time.time() - start_time
    self.log_step('Ran %d train steps in %0.2f secs' % (n_steps, elapsed_time))
    if self._train_sw and n_steps > 1:
      self._train_sw.scalar('training/steps per second',
                            n_steps / elapsed_time, step=self._step)
      self._train_sw.flush()
    self.evaluate(n_eval_steps)
    if self._eval_sw:
      self._eval_sw.flush()
    if self._should_save_checkpoints:
      self.save_state(keep=False)
    if self._should_save_checkpoints and self._current_step_is_best(high=True):
      self.save_state(keep=False, prefix='highest_' + self._checkpoint_highest)
    if self._should_save_checkpoints and self._current_step_is_best(high=False):
      self.save_state(keep=False, prefix='lowest_' + self._checkpoint_lowest)

  def train_step(self, batch):
    """Run one training step and update self._opt_state."""
    # Calculate the current optimizer parameters.
    # TODO(pkozakowski): Optimizer parameters get polluted with model state,
    # which doesn't break anything but is weird. Filter it out.
    opt_param_updates = self._for_n_devices(
        math.nested_map(np.array, self.nontrainable_params))
    opt_state = self._opt_state
    opt_state.opt_params.update(opt_param_updates)

    # Run the update.
    (weights, slots), self._model_state, self._rngs = self._jit_update_fn(
        self._step, opt_state, batch, self._model_state, self._rngs)
    # Push scheduled nontrainable params into matching model-state entries.
    self._model_state = self._map_to_state_dicts(self._state_dicts_update)
    self._opt_state = opt_state._replace(weights=weights, slots=slots)
    self._step += 1

  def evaluate(self, n_eval_steps):
    """Evaluate the model and log metrics."""
    _, rng = jax_random.split(self._rngs[0])
    # TODO(lukaszkaiser): both model state and parameters by default include
    # the loss layer. Currently, we access the pure-model parameters by just
    # indexing, [0] here. But we should make it more explicit in a better API.
    weights = (self._opt_state[0][0], self._metrics_weights)
    state = (self._model_state[0], self._metrics_state)
    self.log_step('Evaluation')
    train_eval_slice = itertools.islice(self._train_eval_stream, n_eval_steps)
    train_metrics, _ = self.evaluation_round(train_eval_slice, weights, state,
                                             rng)
    self.log_metrics(train_metrics, self._train_sw, 'train')
    eval_slice = itertools.islice(self._eval_stream, n_eval_steps)
    eval_metrics, _ = self.evaluation_round(eval_slice, weights, state, rng)
    self.log_metrics(eval_metrics, self._eval_sw, 'eval')
    self.log_step('Finished evaluation')

    # Save the optimizer weights in the history
    for (name, value) in self.nontrainable_params.items():
      self._history.append('train', 'training/{}'.format(name), self._step,
                           value)

  def evaluation_round(self, inputs_stream, weights, state, rng):
    """Evaluate.

    Args:
      inputs_stream: iterable of inputs to evaluate on.
      weights: weights for each f in eval_fns.
      state: state for each f in eval_fns.
      rng: random number generator.

    Returns:
      metrics: dict from metric name to metric value averaged over the number of
        inputs (empty if `inputs_stream` yields nothing).
      state: end state for `predict_fn`.
    """
    metrics = collections.defaultdict(float)
    count = 0
    for inp in inputs_stream:
      count += 1
      rng, subrng = jax_random.split(rng)
      metric_values, _ = self._jit_eval(inp, weights, state, subrng)
      try:
        metric_values = list(metric_values)
      except TypeError:
        # A single metric comes back as a scalar rather than a sequence.
        metric_values = [float(metric_values)]
      for m, v in zip(self._metrics, metric_values):
        metrics[m] += v
    return {m: v / count for (m, v) in six.iteritems(metrics)}, state

  def update_model_state(self, key, value):
    """Returns the new value for model-state entry `key` (currently `value`).

    If `key`, after translation through `nontrainable_param_map`, names a
    current nontrainable parameter, that parameter's value (passed through
    `_for_n_devices`) replaces `value`; otherwise `value` is returned as-is.
    """
    # Translate model state keys to nontrainable param names; a key that is
    # not in the mapping stays the same.
    p_name = self._nontrainable_param_map.get(key, key)
    # Evaluate the schedule once rather than up to three times per call.
    params = self.nontrainable_params
    if p_name in params:
      if self._step == 0:
        log('Mapping model state key {} to nontrainable param {}.'.format(
            key, p_name))
      return self._for_n_devices(np.array(params[p_name]))
    return value

  def update_nontrainable_params(self):
    """Rebuilds the nontrainable-parameter schedule from the current history."""
    self._lr_fn = self._lr_schedule(self._history)

  def save_gin(self):
    """Writes the operative gin config to output_dir (and to the train summary)."""
    assert self._output_dir is not None
    config_path = os.path.join(self._output_dir, 'config.gin')
    config_str = gin.operative_config_str()
    with tf.io.gfile.GFile(config_path, 'w') as f:
      f.write(config_str)
    sw = self._train_sw
    if sw:
      sw.text('gin_config',
              jaxboard.markdownify_operative_config_str(config_str))

  def _save_state_dict(self, trainer_state_dict, weights_file):
    """Pickles `trainer_state_dict` to `weights_file` and logs it (not to stdout)."""
    pickle_to_file(trainer_state_dict, weights_file)
    log('Model saved to %s' % weights_file, stdout=False)

  def save_state(self, keep, prefix='model'):
    """Save trainer state given a possibly replicated opt_state.

    Args:
      keep: If True, additionally write a step-numbered copy
        ('<prefix>_<step>.pkl').
      prefix: File name prefix; the main file is '<prefix>.pkl'.
    """
    opt_state = self._opt_state
    if self.n_devices > 1:
      first_replica = lambda x: x[0]
      opt_state = OptState(*math.nested_map(first_replica, opt_state))
    # This line, while optional, allows JAX to transfer arrays from the device
    # to the host in parallel, which is particularly important for cloud TPU.
    if math.backend_name() == 'jax':
      opt_state = jax.device_get(opt_state)
    step, history, model_state = self._step, self._history, self._model_state
    output_dir = self._output_dir

    weights_file = os.path.join(output_dir, prefix + '.pkl')

    # This dict will be stored as the model.
    trainer_state_dict = make_trainer_state_dict(step,
                                                 opt_state,
                                                 history,
                                                 model_state)
    self._save_state_dict(trainer_state_dict, weights_file)

    if keep:
      weights_file = os.path.join(output_dir, '{}_{}.pkl'.format(prefix, step))
      self._save_state_dict(trainer_state_dict, weights_file)

  def save_computation_graphs(self, save_backward_graph):
    """Dump computation graphs to files (single-device only)."""
    if self.n_devices != 1:
      return  # TODO(lukaszkaiser): make this work with more devices.
    # With exactly one device no batch reshaping is needed.
    batch = next(self._train_stream)
    output_dir = self._output_dir
    weights = self._opt_state[0][0]
    forward_computation = jax.xla_computation(self._model_predict_eval)(
        batch, weights=weights, state=self._model_state[0],
        rng=self._rngs[0])
    with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.txt'), 'w') as f:
      f.write(forward_computation.GetHloText())
    with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.dot'), 'w') as f:
      f.write(forward_computation.GetHloDotGraph())
    backward_computation = jax.xla_computation(self._jit_update_fn)(
        self._step, self._opt_state, batch, self._model_state,
        self._rngs)
    with tf.io.gfile.GFile(os.path.join(output_dir, 'backward.txt'), 'w') as f:
      f.write(backward_computation.GetHloText())
    if save_backward_graph:  # Backward graphs can be large so we guard it.
      with tf.io.gfile.GFile(
          os.path.join(output_dir, 'backward.dot'), 'w') as f:
        f.write(backward_computation.GetHloDotGraph())

  def log_step(self, step_message):
    """Logs `step_message` prefixed with the current step number."""
    log('Step % 6d: %s' % (self.step, step_message))

  def log_metrics(self, metrics, summ_writer, log_prefix):
    """Log metrics to summary writer and history."""
    history = self._history
    rjust_len = max([0] + [len(name) for name in metrics])
    for name, value in six.iteritems(metrics):
      self.log_step('%s %s | % .8f' % (
          log_prefix.ljust(5), name.rjust(rjust_len), value))
      full_name = 'metrics/' + name
      if history:
        history.append(log_prefix, full_name, self.step, value)
      if summ_writer:
        summ_writer.scalar(full_name, value, self.step)

  def print_n_weights(self):
    """Prints the total count of trainable weights."""
    weights = self._opt_state.weights
    if self.n_devices > 1:
      # Count a single replica only.
      weights = math.nested_map(lambda x: x[0], weights)
    total_size = _nested_reduce(sum, _sizes(weights))
    self.log_step('Total number of trainable weights: %d' % total_size)

  def _init_host_and_devices(self, n_devices=None, random_seed=None):
    """Initializes host and device attributes for this trainer.

    Args:
      n_devices: Number of devices this trainer will use. If `None`, get the
          number from the backend.
      random_seed: Random seed as the starting point for all random numbers used
          by the trainer. If `None`, calculate one from system time and host id
          (only when running on more than one host).

    Returns:
      is_chief: True if this trainer has special chief responsibilities.
      n_devices: The passed in value of n_devices or a computed default.
      rng: The result of `init_random_number_generators(random_seed)`, used
          as the trainer's base random key.

    Raises:
      ValueError: On the JAX backend, if `n_devices` differs from the number
          of available devices.
    """
    if math.backend_name() == 'jax':
      host_id = jax.host_id()
      host_count = jax.host_count()
    else:
      host_id = 0
      host_count = 1
    is_chief = (host_id == 0)

    device_count = math.device_count()
    n_devices = n_devices or device_count
    # TODO(lukaszkaiser): remove this restriction when possible.
    if n_devices != device_count and math.backend_name() == 'jax':
      raise ValueError('JAX cannot work yet with n_devices != all devices: '
                       '%d != %d' % (n_devices, device_count))

    if random_seed is None and host_count > 1:
      random_seed = int(1e6 * (host_id + time.time())) % 2**32
    return is_chief, n_devices, init_random_number_generators(random_seed)

  def _map_to_state_dicts(self, f):
    """Map the function f to all single-key dicts in model state."""
    # TODO(jonni): Can we replace _nested_map with math.nested_map?
    def _nested_map(f, x):
      if isinstance(x, list):
        return [_nested_map(f, y) for y in x]
      if isinstance(x, tuple):
        return tuple([_nested_map(f, y) for y in x])
      if isinstance(x, dict) and len(x) == 1:
        return f(x)
      return x
    return _nested_map(f, self._model_state)

  def _state_dicts_update(self, state_dict):
    """Applies `update_model_state` to a single-key model-state dict."""
    assert len(state_dict.keys()) == 1
    key = list(state_dict.keys())[0]
    value = state_dict[key]
    return {key: self.update_model_state(key, value)}

  def _should_save_now(self):
    """True if a kept checkpoint was requested for the current step."""
    return self._should_save_checkpoints and self._step in self._checkpoints_at

  def _current_step_is_best(self, high):
    """Is the current step the best (highest if high, else lowest).

    Args:
      high: If True, compare the `checkpoint_highest` metric by maximum;
        otherwise the `checkpoint_lowest` metric by minimum.

    Returns:
      False if no metric is tracked in that direction or it has no eval
      entries yet; otherwise whether the latest eval entry was logged at the
      current step and equals the best value seen so far.
    """
    metric = self._checkpoint_highest if high else self._checkpoint_lowest
    if metric is None:
      return False
    # History is a list of pairs (step, value).
    history = self._history.get('eval', 'metrics/' + metric)
    if not history:
      # Nothing logged yet (e.g. evaluation ran 0 steps); max()/min() and
      # history[-1] would raise on an empty sequence.
      return False
    sequence = [float(i[1]) for i in history]  # Just the values.
    best = max(sequence) if high else min(sequence)  # Best value.
    last_is_best = float(history[-1][1]) == best  # Is last the best?
    cur_step = history[-1][0] == self._step  # Is last the current step?
    return cur_step and last_is_best

  def _should_log_now(self):
    """True when summaries are on and the step is 1 or a multiple of 10."""
    return (self._train_sw is not None
            and (self._step == 1 or self._step % 10 == 0))

  def _for_n_devices(self, x):
    """Replicates/broadcasts `x` for n devices if `self.n_devices > 1`."""
    return tl.for_n_devices(x, self.n_devices)  # pylint: disable=protected-access

  def close(self):
    """Closes and drops the train/eval summary writers, if open."""
    if self._train_sw is not None:
      self._train_sw.close()
      self._train_sw = None
    if self._eval_sw is not None:
      self._eval_sw.close()
      self._eval_sw = None
@gin.configurable(blacklist=['output_dir'])
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.CrossEntropyLoss(),
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.Adafactor,
          lr_schedule=lr.MultifactorSchedule,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          save_backward_graph=False,
          nontrainable_param_map=None,
          id_to_mask=None,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          custom_train_fn=None):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: Callable taking `mode=` ('train' or 'eval') and returning the model
      layer for that mode.
    loss_fn: Loss layer; the trainer appends it to the model with `tl.Serial`
      to compute the training loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    save_backward_graph: bool, if True, save backward graph to file too.
    nontrainable_param_map: dict, mapping from model nontrainable parameter
      names to control names in PolicySchedule.
    id_to_mask: id to mask out (None by default).
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    custom_train_fn: custom train function to call, entirely bypassing this one

  Returns:
    trax.TrainerState
  """
  if custom_train_fn is not None:
    return custom_train_fn(output_dir, model=model)

  n_devices = num_devices()
  # TODO(lukaszkaiser): remove has_weights and id_to_mask (configure loss).
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule, inputs,
                          output_dir,
                          random_seed=random_seed, n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          nontrainable_param_map=nontrainable_param_map,
                          metrics=metrics, id_to_mask=id_to_mask,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest)

  # `eval_steps=None` is documented as "eval disabled". Normalize it to 0 so
  # that `None > 0` cannot raise and `islice(stream, None)` cannot evaluate
  # without bound.
  eval_steps = eval_steps or 0
  epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
  if eval_frequency and eval_steps > 0:
    # The first epoch is a single step. The second brings the count up to
    # eval_frequency and is omitted when it would be empty (eval_frequency 1),
    # which would otherwise repeat the step-1 eval and bookkeeping.
    first_epochs = [1] if eval_frequency == 1 else [1, eval_frequency - 1]
    epoch_steps = itertools.chain(first_epochs,
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  try:
    for n_epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(n_epoch_steps, eval_steps)

      # Update nontrainable parameters with new history
      trainer.update_nontrainable_params()

      # Bookkeeping we do at the first step
      if trainer.step == 1:
        # Save computation graph (single-device only for now)
        if save_graphs and math.backend_name() == 'jax':
          trainer.save_computation_graphs(save_backward_graph)

        # Save Gin config
        trainer.save_gin()

    trainer.log_step('Training done')
  finally:
    # Exceptions propagate unchanged (with their traceback); summaries are
    # closed either way.
    trainer.close()
  return trainer.state
@gin.configurable
def num_devices(value=None):
  """Gin hook for the number of devices `train` should use.

  Args:
    value: Device count to use, or None (the default) to let the trainer use
      every available device.

  Returns:
    `value`, unchanged.
  """
  return value
@gin.configurable
def _is_jit_init(value=True):
  """Gin hook: whether `Trainer` JIT-compiles its weight initialization.

  Args:
    value: True (default) to wrap the optimizer/model-state initializer in
      `math.jit`.

  Returns:
    `value`, unchanged.
  """
  return value
@gin.configurable
def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True):
  """Returns a (JIT-compiled) function that computes updates for one step.

  The returned function has the signature
  `update(i, opt_state, batch, state, rng) -> (update_result, state, rng)`,
  where `update_result` is what `optimizer.tree_update` returns
  (`Trainer.train_step` unpacks it as `(weights, slots)`).

  Args:
    predict_fn: Model layer; combined with `loss_fn` via `tl.Serial` and
      differentiated with respect to the combined weights.
    loss_fn: Loss layer applied after the model.
    optimizer: Optimizer instance providing `tree_update`.
    n_devices: Number of devices. With 1 device a plain (optionally jitted)
      function is returned. With more, the update is `pmap`-ed over a 'batch'
      axis and gradients are averaged across all devices.
    jit: Whether to JIT-compile. Only consulted in the single-device case;
      the multi-device path is always `pmap`-ed.
  """
  model_and_loss = tl.Serial(predict_fn, loss_fn)
  # Gradients are always wrt. the first argument, so putting weights first.
  def model_and_loss_call(weights, batch, state, rng):
    res = model_and_loss(batch, weights=weights, state=state, rng=rng)
    # The layer's updated state is returned as auxiliary output, consumed by
    # `math.grad(..., has_aux=True)` below.
    return res, model_and_loss.state
  if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
    def single_update(i, opt_state, batch, state, rng):
      weights, slots, opt_params = opt_state
      # `rng` is indexed with [0] and the next key is returned wrapped in a
      # list, presumably to match the per-device rng layout of the pmap path.
      rng, subrng = jax_random.split(rng[0])
      grad_fn = math.grad(model_and_loss_call, has_aux=True)
      grads, state = grad_fn(weights, batch, state, rng)
      return optimizer.tree_update(
          i, grads, weights, slots, opt_params), state, [subrng]
    return math.jit(single_update) if jit else single_update

  # Else, for n_devices > 1:
  @functools.partial(math.pmap, axis_name='batch')
  def mapped_update(i, opt_state, batch, state, rng):
    """This is a multi-device version of the update function above."""
    # We assume all tensors have the first dimension = n_devices.
    weights, slots, opt_params = opt_state
    rng, subrng = jax_random.split(rng)
    grad_fn = math.grad(model_and_loss_call, has_aux=True)
    grads, state = grad_fn(weights, batch, state, rng)
    # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just
    # the number of devices on this host machine, however psum goes over all
    # devices of all hosts (ex: a TPU pod) and we need to be averaging over all
    # of them.
    grads = jax.tree_util.tree_map(
        lambda g: math.psum(g, 'batch') / math.psum(np.array(1.0), 'batch'),
        grads)
    return optimizer.tree_update(
        i, grads, weights, slots, opt_params), state, subrng

  def update(i, opt_state, batch, state, rng):
    # pmap maps over the leading axis of each argument, so the step number is
    # repeated once per device.
    return mapped_update(np.repeat(i, n_devices), opt_state, batch, state, rng)

  return update
@gin.configurable
def _jit_predict_fn(model_predict, metric_fn, n_devices, jit=True):
  """Builds the evaluation function: the model followed by the metric layer(s).

  Args:
    model_predict: Model layer used for evaluation.
    metric_fn: Layer computing metrics from the model's outputs.
    n_devices: Number of devices, forwarded to `tl.jit_forward`.
    jit: If False, return the un-compiled pure function.

  Returns:
    The pure function of `tl.Serial(model_predict, metric_fn)`, compiled
    with `tl.jit_forward` unless `jit` is False.
  """
  pure_fn = tl.Serial(model_predict, metric_fn).pure_fn
  if jit:
    return tl.jit_forward(pure_fn, n_devices)
  return pure_fn
@gin.configurable
def _jit_compute_loss_fn(predict_fn, loss_fn, n_devices, jit=True):
  """Returns a (JIT-compiled) function that computes the loss for one step.

  Args:
    predict_fn: Model forward function, passed through to `loss_fn`.
    loss_fn: Callable invoked as `loss_fn(weights, batch, predict_fn, state,
      rng)` and returning a `(loss, new_state)` pair.
    n_devices: Number of devices. With more than one, the computation is
      `pmap`-ed over them and the batch is reshaped per device.
    jit: Whether to JIT-compile the single-device function. Ignored when
      `n_devices > 1`.

  Returns:
    A function `(opt_state, batch, state, rng) -> (loss, state, next_rng)`
    that uses `opt_state[0]` as the weights.
  """
  if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
    def single_compute_loss(opt_state, batch, state, rng):
      # `rng` is indexed with [0] and the next key is returned in a list,
      # presumably matching a one-entry-per-device rng array.
      rng, subrng = jax_random.split(rng[0])
      loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng)
      return loss_val, state, [subrng]
    return math.jit(single_compute_loss) if jit else single_compute_loss

  # Else, for n_devices > 1:
  @functools.partial(math.pmap, axis_name='batch')
  def mapped_compute_loss(opt_state, batch, state, rng):
    """Multi-device version of `single_compute_loss` above."""
    # We assume all tensors have the first dimension = n_devices.
    rng, subrng = jax_random.split(rng)
    loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng)
    return loss_val, state, subrng

  def compute_loss(opt_state, batch, state, rng):
    # Unlike the single-device path, the batch is split across devices here.
    return mapped_compute_loss(
        opt_state, _reshape_by_device(batch, n_devices), state, rng)

  return compute_loss
def log(s, stdout=True):
  """Logs `s` at INFO level and, unless `stdout` is False, echoes it.

  Args:
    s: Message to log.
    stdout: If True (default), also print `s` and flush stdout right away.
  """
  logging.info(s)
  if not stdout:
    return
  print(s)
  sys.stdout.flush()
def epochs(total_steps, steps_to_skip, epoch_steps):
  """Generates the number of steps in each epoch before reaching total_steps.

  Epochs completed before a restart (the first `steps_to_skip` steps) are
  dropped. An epoch that was interrupted part-way contributes only its
  remaining steps. The last yielded value is truncated so that the total
  never exceeds `total_steps - steps_to_skip`.

  Args:
    total_steps: int, total number of steps.
    steps_to_skip: int, number of steps to skip because of a restart.
    epoch_steps: iterable of int, numbers of steps in each epoch.

  Yields:
    epoch_steps: int, number of steps in this epoch. If `steps_to_skip` is
      already >= `total_steps`, at most a single 0 is yielded (never a
      negative count).
  """
  # Clamp so that restarting at or past `total_steps` cannot produce a
  # negative epoch length.
  steps_to_go = max(total_steps - steps_to_skip, 0)
  remaining_epochs = iter(epoch_steps)

  # Remove the desired number of steps from the stream.
  for steps_this_epoch in remaining_epochs:
    if steps_this_epoch >= steps_to_skip:
      leftover = steps_this_epoch - steps_to_skip
      if leftover > 0:
        # Put back the number of steps left in the unfinished epoch.
        remaining_epochs = itertools.chain([leftover], remaining_epochs)
      break
    steps_to_skip -= steps_this_epoch

  # Yield the remaining steps per epoch up to total_steps.
  for steps_this_epoch in remaining_epochs:
    steps_this_epoch = min(steps_this_epoch, steps_to_go)
    yield steps_this_epoch
    steps_to_go -= steps_this_epoch
    if steps_to_go == 0:
      break
def make_trainer_state_dict(step,
                            opt_state,
                            history,
                            model_state):
  """Flattens trainer state into a plain dictionary for saving to disk.

  Args:
    step: int, a step number
    opt_state: OptState namedtuple; `opt_state.weights[0]` and
      `opt_state.weights[1]` are stored as 'weights' and 'loss_weights'.
    history: `trax.history.History`, the history object.
    model_state: A nested structure of the model state; elements 0 and 1 are
      stored as 'state' and 'loss_state'.

  Returns:
    A dict with keys 'step', 'weights', 'loss_weights', 'slots',
    'opt_params', 'history', 'state', 'loss_state' and 'version_timestamp'.
  """
  model_weights, loss_weights = opt_state.weights[0], opt_state.weights[1]
  return dict(
      step=step,
      weights=model_weights,
      loss_weights=loss_weights,
      slots=opt_state.slots,
      opt_params=opt_state.opt_params,
      history=history,
      state=model_state[0],
      loss_state=model_state[1],
      # Version marker; to update in the future if needed.
      version_timestamp='Jan-13-2020',
  )
def trainer_state_from_dict(trainer_state_dict):
  """Given the trainer state dictionary, returns `TrainerState`.

  Reads the keys written by `make_trainer_state_dict`. It also accepts an
  older layout that stores a single 'model_state' entry, and dictionaries
  without 'loss_weights'.

  Args:
    trainer_state_dict: dict with keys 'step', 'history', 'weights', 'slots',
      'opt_params', and either 'model_state' or both 'state' and
      'loss_state'. The 'loss_weights' key is optional.

  Returns:
    A `TrainerState` whose `opt_state` is an `OptState`. When 'loss_weights'
    is present, `opt_state.weights` is the pair
    `(weights, loss_weights)`.
  """
  # TODO(afrozm): This becomes simpler if OptState is flattened into
  # TrainerState.
  step = trainer_state_dict['step']
  history = trainer_state_dict['history']
  # TODO(lukaszkaiser): remove the first branch after everyone ports to 'state'.
  if 'model_state' in trainer_state_dict:
    model_state = trainer_state_dict['model_state']
  else:
    model_state = (trainer_state_dict['state'],
                   trainer_state_dict['loss_state'])
  weights = trainer_state_dict['weights']
  # TODO(lukaszkaiser): remove the next 2 lines after 'loss_weights' is in use.
  if 'loss_weights' in trainer_state_dict:
    weights = (weights, trainer_state_dict['loss_weights'])
  opt_state = OptState(
      weights=weights,
      slots=trainer_state_dict['slots'],
      opt_params=trainer_state_dict['opt_params'])
  return TrainerState(step=step, opt_state=opt_state,
                      history=history, model_state=model_state)
def load_trainer_state(output_dir, weights_file=None):
  """Returns a TrainerState instance loaded from the given `output_dir`.

  Args:
    output_dir: directory whose 'model.pkl' is used when `weights_file` is
      not given.
    weights_file: optional explicit path of the pickled trainer state.

  Returns:
    The `TrainerState` rebuilt from the pickle. If `weights_file` was not
    given and `output_dir/model.pkl` does not exist, a `TrainerState` with
    `step=None`, `opt_state=None`, `model_state=None` and a new
    `trax_history.History()` is returned instead.

  Raises:
    ValueError: if an explicitly given `weights_file` does not exist.

  Note: the file is read with `pickle.load`, which can execute arbitrary
  code; only load checkpoints from trusted locations.
  """
  using_default_file = weights_file is None
  if using_default_file:
    weights_file = os.path.join(output_dir, 'model.pkl')
  if not tf.io.gfile.exists(weights_file):
    if using_default_file:
      # No checkpoint saved yet: start from an empty trainer state.
      return TrainerState(step=None, opt_state=None,
                          history=trax_history.History(), model_state=None)
    raise ValueError('File not found: %s' % weights_file)
  with tf.io.gfile.GFile(weights_file, 'rb') as f:
    loaded_dict = pickle.load(f)
  trainer_state = trainer_state_from_dict(loaded_dict)
  log('Model loaded from %s at step %d' % (weights_file, trainer_state.step))
  logging.debug('From loaded model : history = %s', trainer_state.history)
  return trainer_state
def init_random_number_generators(seed=None):
  """Seeds Python, NumPy and TensorFlow RNGs and returns `get_prng(seed)`.

  Python's `random` is seeded first with `seed` (None is accepted). If `seed`
  is None, an int in [0, 2**31 - 1] is then drawn from `random`. That value
  seeds NumPy and TensorFlow and is passed to `jax_random.get_prng`.

  Args:
    seed: optional int seed.

  Returns:
    `jax_random.get_prng` applied to the resolved seed.
  """
  random.seed(seed)
  resolved_seed = seed if seed is not None else random.randint(0, 2**31 - 1)
  for seed_fn in (numpy.random.seed, tf.random.set_seed):
    seed_fn(resolved_seed)
  return jax_random.get_prng(resolved_seed)
def _reshape_by_device(x, n_devices):
  """Reshapes possibly nested `x` into a shape (n_devices, ...).

  Thin local alias for `tl.reshape_by_device`.
  """
  reshaped = tl.reshape_by_device(x, n_devices)
  return reshaped
def _nested_reduce(f, x):
"""Fold the function f to the nested structure x (dicts, tuples, lists)."""
if isinstance(x, list):
return f([_nested_reduce(f, y) for y in x])
if isinstance(x, tuple):
return f([_nested_reduce(f, y) for y in x])
return x
def _sizes(x):
  """Maps every leaf of the nested structure `x` to its `.size`.

  Any leaf whose `.size` lookup raises is mapped to 0 instead.
  """
  def _size_or_zero(leaf):
    try:
      leaf_size = leaf.size
    except Exception:  # pylint: disable=broad-except
      leaf_size = 0
    return leaf_size

  return math.nested_map(_size_or_zero, x)
def _repeat_stream(stream, n_devices):
"""Repeat a stream indefinitely."""
while True:
for example in stream(n_devices):
yield example
def pickle_to_file(obj, file_path, gzip=False):
  """Pickles `obj` to `file_path`, optionally gzip-compressed.

  The pickle is written to `file_path + '._tmp_'` first and then renamed
  over `file_path`. A failure while pickling therefore does not leave a
  partially written file at `file_path`.

  Args:
    obj: object to pickle.
    file_path: destination path, opened via `tf.io.gfile`.
    gzip: if True, compress the pickle with gzip (compression level 2).
  """
  tmp_file_path = file_path + '._tmp_'
  with tf.io.gfile.GFile(tmp_file_path, 'wb') as f:
    if gzip:
      with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf:
        pickle.dump(obj, gzipf)
    else:
      pickle.dump(obj, f)
  # A rename is much less error-prone than pickling a large object in place.
  tf.io.gfile.rename(tmp_file_path, file_path, overwrite=True)
def unpickle_from_file(file_path, gzip=False):
  """Loads and returns the object pickled at `file_path`.

  Args:
    file_path: path to read, opened via `tf.io.gfile.GFile`.
    gzip: if True, the file is gunzipped before unpickling.

  Returns:
    The unpickled object.

  Note: unpickling can execute arbitrary code; only load trusted files.
  """
  with tf.io.gfile.GFile(file_path, 'rb') as f:
    if gzip:
      with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf:
        return pickle.load(gzipf)
    return pickle.load(f)
def _add_weights_and_mask(inputs, id_to_mask):
  """Add weights to inputs without weights and masks by id if requested.

  Each of the (train, eval, train_eval) streams of inputs is augmented in
  the following way:
  * if the stream consists of pairs (inputs, targets), a loss mask is added
    that is created as a tensor of ones of the same shape as targets
  * if id_to_mask is not None, and the stream (after the previous point) has
    triples (inputs, targets, weights), the weights are multiplied by a 0/1
    mask that is 0 iff targets is equal to id_to_mask (1 otherwise).
  Examples with fewer than 2 or more than 3 elements are passed through
  unchanged; requesting masking for such a stream fails an assertion.

  Args:
    inputs: a trax_inputs.Inputs object to operate on
    id_to_mask: int or None, id to pad in targets if not None

  Returns:
    a trax_inputs.Inputs object with augmented streams
  """
  def _with_masks(input_stream):
    """Create masks for the given stream."""
    for example in input_stream:
      if not 2 <= len(example) <= 3:
        assert id_to_mask is None, 'Cannot automatically mask this stream.'
        yield example
        continue
      if len(example) == 2:
        weights = numpy.ones_like(example[1]).astype(numpy.float32)
      else:
        # astype returns a copy, so the in-place multiply below never
        # modifies the weights array that came from the stream.
        weights = example[2].astype(numpy.float32)
      if id_to_mask is not None:
        # Zero out the weight wherever the target equals the id to mask.
        mask = 1.0 - numpy.equal(example[1], id_to_mask).astype(numpy.float32)
        weights *= mask
      yield (example[0], example[1], weights)

  return trax_inputs.Inputs(
      train_stream=lambda n: _with_masks(inputs.train_stream(n)),
      eval_stream=lambda n: _with_masks(inputs.eval_stream(n)),
      train_eval_stream=lambda n: _with_masks(inputs.train_eval_stream(n)))