/
transform.py
1422 lines (1101 loc) · 46.2 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient transformations."""
import functools
from typing import Any, Callable, NamedTuple, Optional, Union
import chex
import jax
import jax.numpy as jnp
from optax._src import base
from optax._src import numerics
from optax._src import utils
from optax._src import wrappers
# pylint:disable=no-value-for-parameter
# Module-private short alias for `numerics.abs_sq`. It is applied to gradient
# leaves below, including complex ones (see `update_moment_per_elem_norm`);
# presumably it returns |x|**2 elementwise -- confirm against numerics.
_abs_sq = numerics.abs_sq
class TraceState(NamedTuple):
  """Optimizer state for `trace`: the decayed sum of past updates."""
  # One accumulator per leaf, with the same tree structure as the params.
  trace: base.Params
def trace(
    decay: float,
    nesterov: bool = False,
    accumulator_dtype: Optional[Any] = None,
) -> base.GradientTransformation:
  """Accumulate a decaying trace (momentum buffer) of past updates.

  The trace follows `trace = decay * trace + t`. Compare with `ema`, which
  uses `ema = decay * ema + (1 - decay) * t`; both forms are common in the
  optimization literature.

  Args:
    decay: Multiplicative decay applied to the previous trace.
    nesterov: If True, emit the Nesterov look-ahead `t + decay * new_trace`
      instead of the new trace itself.
    accumulator_dtype: Optional `dtype` for the stored trace; when `None` the
      `dtype` is inferred from `params` and `updates`.

  Returns:
    A `GradientTransformation` object.
  """
  accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)

  def _accumulate(grad, buf):
    return grad + decay * buf

  def init_fn(params):
    zeros = jax.tree_util.tree_map(
        lambda leaf: jnp.zeros_like(leaf, dtype=accumulator_dtype), params)
    return TraceState(trace=zeros)

  def update_fn(updates, state, params=None):
    del params
    fresh_trace = jax.tree_util.tree_map(_accumulate, updates, state.trace)
    if nesterov:
      out = jax.tree_util.tree_map(_accumulate, updates, fresh_trace)
    else:
      out = fresh_trace
    # Only the stored copy is cast; the emitted updates keep their dtype.
    stored = utils.cast_tree(fresh_trace, accumulator_dtype)
    return out, TraceState(trace=stored)

  return base.GradientTransformation(init_fn, update_fn)
def update_moment(updates, moments, decay, order):
  """Return the exponential moving average of the `order`-th power of updates.

  Each leaf becomes `(1 - decay) * g**order + decay * t`, where `g` is the
  update leaf and `t` the matching leaf of `moments`.
  """
  def _ema_leaf(g, t):
    return (1 - decay) * (g ** order) + decay * t

  return jax.tree_util.tree_map(_ema_leaf, updates, moments)
def update_infinity_moment(updates, moments, decay, eps):
  """Return the decayed running maximum of |updates| (infinity-norm moment).

  Each leaf becomes `max(|g| + eps, decay * t)`.
  """
  def _inf_leaf(g, t):
    return jnp.maximum(jnp.abs(g) + eps, decay * t)

  return jax.tree_util.tree_map(_inf_leaf, updates, moments)
def update_moment_per_elem_norm(updates, moments, decay, order):
  """Return the EMA of the `order`-th power of each element's magnitude.

  Real leaves use `g ** order` directly. Complex leaves use
  `_abs_sq(g) ** (order / 2)` instead of raising the complex value itself.
  """

  def _elem_power(g):
    if not jnp.isrealobj(g):
      exponent = order / 2
      # JAX emits different HLO for int and float exponents, so prefer an
      # int whenever the halved order is integral.
      if exponent.is_integer():
        exponent = int(exponent)
      return _abs_sq(g) ** exponent
    return g ** order

  return jax.tree_util.tree_map(
      lambda g, t: (1 - decay) * _elem_power(g) + decay * t,
      updates, moments)
@functools.partial(jax.jit, inline=True)
def bias_correction(moment, decay, count):
  """Divide `moment` by `1 - decay**count`; tends to a no-op as count grows."""
  # Compute the factor before any dtype conversion: evaluating
  # `decay**count` in low precision can round it to exactly 1, which would
  # make the factor zero and trigger a division by zero.
  correction = 1 - decay**count

  def _rescale(leaf):
    # Casting the factor (not the leaf) keeps e.g. bfloat16 moments in
    # bfloat16 in the optimizer state; the division runs in the leaf's dtype.
    return leaf / correction.astype(leaf.dtype)

  return jax.tree_util.tree_map(_rescale, moment)
def _reject_complex(params):
  """Raise `ValueError` if any leaf of `params` is complex-valued."""
  leaves = jax.tree_util.tree_leaves(params)
  has_complex = any(jnp.iscomplexobj(leaf) for leaf in leaves)
  if has_complex:
    raise ValueError('This transformation does not support complex parameters.')
class EmaState(NamedTuple):
  """Optimizer state for `ema`: a step counter and the running average."""
  count: chex.Array  # Scalar int32 step counter (shape=()).
  ema: base.Params  # Running average, one leaf per parameter leaf.
def ema(
    decay: float,
    debias: bool = True,
    accumulator_dtype: Optional[Any] = None
) -> base.GradientTransformation:
  """Compute an exponential moving average of past updates.

  Note: `trace` and `ema` have very similar but distinct updates;
  `ema = decay * ema + (1-decay) * t`, while `trace = decay * trace + t`.
  Both are frequently found in the optimization literature.

  Args:
    decay: Decay rate for the exponential moving average.
    debias: Whether to debias the transformed gradient.
    accumulator_dtype: Optional `dtype` to be used for the accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    A `GradientTransformation` object.
  """
  accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)

  def init_fn(params):
    return EmaState(
        count=jnp.zeros([], jnp.int32),
        ema=jax.tree_util.tree_map(
            lambda t: jnp.zeros_like(t, dtype=accumulator_dtype), params))

  def update_fn(updates, state, params=None):
    del params
    updates = new_ema = update_moment(updates, state.ema, decay, order=1)
    # Use the same saturating counter helper as every other transformation
    # in this module (`numerics`, not `utils`).
    count_inc = numerics.safe_int32_increment(state.count)
    if debias:
      # Debias the uncast average; only the stored copy is cast below.
      updates = bias_correction(new_ema, decay, count_inc)
    state_ema = utils.cast_tree(new_ema, accumulator_dtype)
    return updates, EmaState(count=count_inc, ema=state_ema)

  return base.GradientTransformation(init_fn, update_fn)
class ScaleByRssState(NamedTuple):
  """Optimizer state for `scale_by_rss`."""
  # Running sum of squared gradient magnitudes, shaped like the params.
  sum_of_squares: base.Updates
def scale_by_rss(
    initial_accumulator_value: float = 0.1,
    eps: float = 1e-7
) -> base.GradientTransformation:
  """Divide updates by the root of the running sum of squared gradients.

  References:
    [Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
    [McMahan et al., 2010](https://arxiv.org/abs/1002.4908)

  Args:
    initial_accumulator_value: Starting value for accumulators, must be >= 0.
    eps: A small floating point value to avoid zero denominator.

  Returns:
    A `GradientTransformation` object.
  """

  def init_fn(params):
    return ScaleByRssState(
        sum_of_squares=jax.tree_util.tree_map(
            lambda leaf: jnp.full_like(leaf, initial_accumulator_value),
            params))

  def update_fn(updates, state, params=None):
    del params
    accum = jax.tree_util.tree_map(
        lambda g, acc: _abs_sq(g) + acc, updates, state.sum_of_squares)

    def _precondition(acc, g):
      # Elements whose accumulated sum is not positive get a zero scale.
      factor = jnp.where(acc > 0, jax.lax.rsqrt(acc + eps), 0.0)
      return factor * g

    scaled = jax.tree_util.tree_map(_precondition, accum, updates)
    return scaled, ScaleByRssState(sum_of_squares=accum)

  return base.GradientTransformation(init_fn, update_fn)
class ScaleByRmsState(NamedTuple):
  """Optimizer state for `scale_by_rms` (RMS-normalized updates)."""
  # EMA of squared gradient magnitudes, shaped like the params.
  nu: base.Updates
def scale_by_rms(
    decay: float = 0.9,
    eps: float = 1e-8,
    initial_scale: float = 0.
) -> base.GradientTransformation:
  r"""Divide updates by the root of an EMA of squared gradients (RMSProp).

  WARNING: PyTorch and optax's RMSprop implementations differ and could impact
  performance. In the denominator, optax uses $\sqrt{v + \epsilon}$ whereas
  PyTorch uses $\sqrt{v} + \epsilon$. See
  https://github.com/google-deepmind/optax/issues/532 for more detail.

  References:
    [Hinton](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)

  Args:
    decay: Decay rate for the exponentially weighted average of squared grads.
    eps: Term added inside the square root to improve numerical stability.
    initial_scale: Initial value for the second moment.

  Returns:
    A `GradientTransformation` object.
  """

  def init_fn(params):
    second = jax.tree_util.tree_map(
        lambda leaf: jnp.full_like(leaf, initial_scale), params)
    return ScaleByRmsState(nu=second)

  def update_fn(updates, state, params=None):
    del params
    nu = update_moment_per_elem_norm(updates, state.nu, decay, 2)

    def _normalize(g, n):
      return g * jax.lax.rsqrt(n + eps)

    return (jax.tree_util.tree_map(_normalize, updates, nu),
            ScaleByRmsState(nu=nu))

  return base.GradientTransformation(init_fn, update_fn)
class ScaleByRStdDevState(NamedTuple):
  """Optimizer state for `scale_by_stddev` (centered RMS normalization)."""
  mu: base.Updates  # EMA of the updates.
  nu: base.Updates  # EMA of squared update magnitudes.
def scale_by_stddev(
    decay: float = 0.9,
    eps: float = 1e-8,
    initial_scale: float = 0.
) -> base.GradientTransformation:
  """Divide updates by the root of a centered EMA of squared gradients.

  References:
    [Hinton](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)

  Args:
    decay: Decay rate for the exponentially weighted average of squared grads.
    eps: Term added inside the square root to improve numerical stability.
    initial_scale: Initial value for the second moment.

  Returns:
    A `GradientTransformation` object.
  """

  def init_fn(params):
    mean = jax.tree_util.tree_map(jnp.zeros_like, params)
    mean_sq = jax.tree_util.tree_map(
        lambda leaf: jnp.full_like(leaf, initial_scale), params)
    return ScaleByRStdDevState(mu=mean, nu=mean_sq)

  def update_fn(updates, state, params=None):
    del params
    mu = update_moment(updates, state.mu, decay, 1)
    nu = update_moment_per_elem_norm(updates, state.nu, decay, 2)

    def _normalize(g, m, n):
      # Centered second moment: n - |m|^2.
      return g * jax.lax.rsqrt(n - _abs_sq(m) + eps)

    return (jax.tree_util.tree_map(_normalize, updates, mu, nu),
            ScaleByRStdDevState(mu=mu, nu=nu))

  return base.GradientTransformation(init_fn, update_fn)
class ScaleByAdamState(NamedTuple):
  """State shared by Adam-family transformations (Adam, Adamax, Yogi, RAdam)."""
  count: chex.Array  # Scalar int32 step counter (shape=()).
  mu: base.Updates  # First-moment accumulator.
  nu: base.Updates  # Second (or infinity) moment accumulator.
def scale_by_adam(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[chex.ArrayDType] = None,
    *,
    nesterov: bool = False
) -> base.GradientTransformation:
  """Rescale updates following the Adam algorithm (optionally Nadam).

  References:
    Kingma et al, `Adam: A Method for Stochastic Optimization
    <https://arxiv.org/abs/1412.6980>`_, 2014

    Dozat, `Incorporating Nesterov Momentum into Adam
    <https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ>`_ 2016

  .. warning::
    PyTorch and optax's adam follow Algorithm 1 of the Kingma
    and Ba's Adam paper, if reproducing old results note that TensorFlow
    used instead the formulation just before Section 2.1 of the paper.
    See https://github.com/deepmind/optax/issues/571 for more detail.

  Args:
    b1: Decay rate for the exponentially weighted average of grads.
    b2: Decay rate for the exponentially weighted average of squared grads.
    eps: Term added to the denominator to improve numerical stability.
    eps_root: Term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    mu_dtype: Optional `dtype` for the first-order accumulator; if `None` the
      `dtype` is inferred from `params` and `updates`.
    nesterov: Whether to use Nesterov momentum (the Nadam variant described
      in [Dozat 2016]).

  Returns:
    A `GradientTransformation` object.
  """
  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    first = jax.tree_util.tree_map(
        lambda leaf: jnp.zeros_like(leaf, dtype=mu_dtype), params)
    second = jax.tree_util.tree_map(jnp.zeros_like, params)
    return ScaleByAdamState(
        count=jnp.zeros([], jnp.int32), mu=first, nu=second)

  def _nesterov_mu_hat(mu, grads, step):
    # Look-ahead first moment: the moment debiased for the *next* step,
    # blended with the debiased current gradient.
    return jax.tree_util.tree_map(
        lambda m, g: b1 * m + (1 - b1) * g,
        bias_correction(mu, b1, numerics.safe_int32_increment(step)),
        bias_correction(grads, b1, step))

  def update_fn(updates, state, params=None):
    del params
    mu = update_moment(updates, state.mu, b1, 1)
    nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
    step = numerics.safe_int32_increment(state.count)
    mu_hat = (_nesterov_mu_hat(mu, updates, step) if nesterov
              else bias_correction(mu, b1, step))
    # Algorithm 2 of Dozat 2016 also multiplies nu_hat by b2; as in other
    # Nadam implementations, that extra factor is omitted here.
    nu_hat = bias_correction(nu, b2, step)
    out = jax.tree_util.tree_map(
        lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
    return out, ScaleByAdamState(
        count=step, mu=utils.cast_tree(mu, mu_dtype), nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
class ScaleByAmsgradState(NamedTuple):
  """Optimizer state for `scale_by_amsgrad`."""
  count: chex.Array  # Scalar int32 step counter (shape=()).
  mu: base.Updates  # First-moment EMA.
  nu: base.Updates  # Second-moment EMA (not debiased).
  nu_max: base.Updates  # Running max of the debiased second moment.
def scale_by_amsgrad(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[chex.ArrayDType] = None,
) -> base.GradientTransformation:
  """Rescale updates following the AMSGrad algorithm.

  References:
    [Reddi et al, 2018](https://openreview.net/forum?id=ryQu7f-RZ)

  Args:
    b1: Decay rate for the exponentially weighted average of grads.
    b2: Decay rate for the exponentially weighted average of squared grads.
    eps: Term added to the denominator to improve numerical stability.
    eps_root: Term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    mu_dtype: Optional `dtype` for the first-order accumulator; if `None` the
      `dtype` is inferred from `params` and `updates`.

  Returns:
    A `GradientTransformation` object.
  """
  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    first = jax.tree_util.tree_map(
        lambda leaf: jnp.zeros_like(leaf, dtype=mu_dtype), params)
    second = jax.tree_util.tree_map(jnp.zeros_like, params)
    running_max = jax.tree_util.tree_map(jnp.zeros_like, params)
    return ScaleByAmsgradState(
        count=jnp.zeros([], jnp.int32), mu=first, nu=second,
        nu_max=running_max)

  def update_fn(updates, state, params=None):
    del params
    mu = update_moment(updates, state.mu, b1, 1)
    nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
    step = numerics.safe_int32_increment(state.count)
    mu_hat = bias_correction(mu, b1, step)
    # The running maximum is taken over the *debiased* second moments.
    nu_max = jax.tree_util.tree_map(
        jnp.maximum, state.nu_max, bias_correction(nu, b2, step))
    out = jax.tree_util.tree_map(
        lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_max)
    return out, ScaleByAmsgradState(
        count=step, mu=utils.cast_tree(mu, mu_dtype), nu=nu, nu_max=nu_max)

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_adamax(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8
) -> base.GradientTransformation:
  """Rescale updates following the Adamax (infinity-norm Adam) algorithm.

  References:
    [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

  Args:
    b1: Decay rate for the exponentially weighted average of grads.
    b2: Decay rate for the exponentially weighted maximum of grads.
    eps: Term added to the denominator to improve numerical stability.

  Returns:
    A `GradientTransformation` object.
  """

  def init_fn(params):
    first = jax.tree_util.tree_map(jnp.zeros_like, params)
    infinity = jax.tree_util.tree_map(jnp.zeros_like, params)
    return ScaleByAdamState(
        count=jnp.zeros([], jnp.int32), mu=first, nu=infinity)

  def update_fn(updates, state, params=None):
    del params
    step = numerics.safe_int32_increment(state.count)
    mu = update_moment(updates, state.mu, b1, 1)
    nu = update_infinity_moment(updates, state.nu, b2, eps)
    # Only the mean is debiased; the infinity moment needs no correction.
    mu_hat = bias_correction(mu, b1, step)
    out = jax.tree_util.tree_map(lambda m, v: m / v, mu_hat, nu)
    return out, ScaleByAdamState(count=step, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
class ScaleByLionState(NamedTuple):
  """Optimizer state for `scale_by_lion`."""
  count: chex.Array  # Scalar int32 step counter (shape=()).
  mu: base.Updates  # Momentum (EMA of gradients with rate b2).
def scale_by_lion(
    b1: float = 0.9,
    b2: float = 0.99,
    mu_dtype: Optional[chex.ArrayDType] = None,
) -> base.GradientTransformation:
  """Rescale updates following the Lion algorithm.

  References:
    [Chen et al, 2023](https://arxiv.org/abs/2302.06675)

  Args:
    b1: Rate for combining the momentum and the current grad.
    b2: Decay rate for the exponentially weighted average of grads.
    mu_dtype: Optional `dtype` for the momentum; if `None` the `dtype` is
      inferred from `params` and `updates`.

  Returns:
    A `GradientTransformation` object.
  """
  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    momentum = jax.tree_util.tree_map(
        lambda leaf: jnp.zeros_like(leaf, dtype=mu_dtype), params)
    return ScaleByLionState(count=jnp.zeros([], jnp.int32), mu=momentum)

  def update_fn(updates, state, params=None):
    del params
    # The emitted direction mixes the gradient with the *previous* momentum
    # (rate b1) and keeps only its sign.
    directions = jax.tree_util.tree_map(
        lambda g, m: jnp.sign((1. - b1) * g + b1 * m), updates, state.mu)
    # The momentum itself is tracked with rate b2.
    momentum = utils.cast_tree(
        update_moment(updates, state.mu, b2, 1), mu_dtype)
    step = numerics.safe_int32_increment(state.count)
    return directions, ScaleByLionState(count=step, mu=momentum)

  return base.GradientTransformation(init_fn, update_fn)
ScaleState = base.EmptyState


def scale(
    step_size: float
) -> base.GradientTransformation:
  """Multiply every update leaf by the constant `step_size`.

  Args:
    step_size: A scalar corresponding to a fixed scaling factor for updates.

  Returns:
    A `GradientTransformation` object.
  """

  def init_fn(params):
    del params
    return ScaleState()

  def update_fn(updates, state, params=None):
    del params
    return jax.tree_util.tree_map(lambda g: step_size * g, updates), state

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_param_block_norm(
    min_scale: float = 1e-3
) -> base.GradientTransformation:
  """Scale each update leaf by the norm of the matching parameter leaf.

  Each leaf of the params/updates pytree is one "block", e.g. the weight
  vector of a Linear layer or the weight matrix of a convolution.

  Args:
    min_scale: Minimum scaling factor.

  Returns:
    A `GradientTransformation` object.

  Raises:
    ValueError: If `params` is not passed to `update`.
  """

  def init_fn(params):
    del params
    return base.EmptyState()

  def update_fn(updates, state, params):
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)

    def _scale_leaf(u, p):
      return u * numerics.safe_norm(p, min_scale)

    return jax.tree_util.tree_map(_scale_leaf, updates, params), state

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_param_block_rms(
    min_scale: float = 1e-3
) -> base.GradientTransformation:
  """Scale each update leaf by the RMS of the matching *parameter* leaf.

  Each leaf of the params/updates pytree is one "block", e.g. the weight
  vector of a Linear layer or the weight matrix of a convolution. The RMS is
  computed over the parameter values, not over the gradient.

  Args:
    min_scale: Minimum scaling factor.

  Returns:
    A `GradientTransformation` object.

  Raises:
    ValueError: If `params` is not passed to `update`.
  """

  def init_fn(params):
    del params
    return base.EmptyState()

  def update_fn(updates, state, params):
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)

    def _scale_leaf(u, p):
      return u * numerics.safe_root_mean_squares(p, min_scale)

    return jax.tree_util.tree_map(_scale_leaf, updates, params), state

  return base.GradientTransformation(init_fn, update_fn)
class ScaleByAdaDeltaState(NamedTuple):
  """Optimizer state for rescaling by the Adadelta algorithm."""
  e_g: base.Updates  # Running average of squared gradients.
  e_x: base.Updates  # Running average of squared (rescaled) updates.
def scale_by_adadelta(
    rho: float = 0.9, eps: float = 1e-6
) -> base.GradientTransformation:
  """Rescale updates following the Adadelta algorithm.

  References:
    [Matthew D. Zeiler, 2012](https://arxiv.org/pdf/1212.5701.pdf)

  Args:
    rho: A coefficient used for computing a running average of squared
      gradients.
    eps: Term added to the denominator to improve numerical stability.

  Returns:
    A `GradientTransformation` object.
  """

  def init_fn(params):
    sq_grad_avg = jax.tree_util.tree_map(jnp.zeros_like, params)
    sq_update_avg = jax.tree_util.tree_map(jnp.zeros_like, params)
    return ScaleByAdaDeltaState(e_g=sq_grad_avg, e_x=sq_update_avg)

  def update_fn(updates, state, params=None):
    del params
    e_g = update_moment(updates, state.e_g, rho, 2)

    def _adadelta_step(g, cur_e_g, prev_e_x):
      # Ratio of the previous update-RMS to the (current) gradient-RMS.
      ratio = jnp.sqrt(prev_e_x + eps) / jnp.sqrt(cur_e_g + eps)
      return ratio * g

    scaled = jax.tree_util.tree_map(_adadelta_step, updates, e_g, state.e_x)
    # The squared-update average is refreshed with the *rescaled* updates.
    e_x = update_moment(scaled, state.e_x, rho, 2)
    return scaled, ScaleByAdaDeltaState(e_g=e_g, e_x=e_x)

  return base.GradientTransformation(init_fn, update_fn)
class ScaleByBeliefState(NamedTuple):
  """Optimizer state for rescaling by the AdaBelief algorithm."""
  count: chex.Array  # Scalar int32 step counter (shape=()).
  mu: base.Updates  # First-moment EMA.
  nu: base.Updates  # EMA of squared prediction error (plus eps_root).
def scale_by_belief(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-16,
    eps_root: float = 1e-16
) -> base.GradientTransformation:
  """Rescale updates following the AdaBelief algorithm.

  References:
    [Zhuang et al, 2020](https://arxiv.org/abs/2010.07468)

  Args:
    b1: Decay rate for the exponentially weighted average of grads.
    b2: Decay rate for the exponentially weighted average of variance of grads.
    eps: Term added to the denominator to improve numerical stability.
    eps_root: Term added to the second moment of the prediction error to
      improve numerical stability. If backpropagating gradients through the
      gradient transformation (e.g. for meta-learning), this must be non-zero.

  Returns:
    A `GradientTransformation` object.
  """

  def init_fn(params):
    first = jax.tree_util.tree_map(jnp.zeros_like, params)
    centered = jax.tree_util.tree_map(jnp.zeros_like, params)
    return ScaleByBeliefState(
        count=jnp.zeros([], jnp.int32), mu=first, nu=centered)

  def update_fn(updates, state, params=None):
    del params
    mu = update_moment(updates, state.mu, b1, 1)
    # Prediction error: deviation of the gradient from the *previous* mean.
    residual = jax.tree_util.tree_map(lambda g, m: g - m, updates, state.mu)
    nu = jax.tree_util.tree_map(
        lambda v: v + eps_root,
        update_moment_per_elem_norm(residual, state.nu, b2, 2))
    step = numerics.safe_int32_increment(state.count)
    mu_hat = bias_correction(mu, b1, step)
    nu_hat = bias_correction(nu, b2, step)
    out = jax.tree_util.tree_map(
        lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat)
    return out, ScaleByBeliefState(count=step, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_yogi(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-3,
    eps_root: float = 0.0,
    initial_accumulator_value: float = 1e-6
) -> base.GradientTransformation:
  """Rescale updates following the Yogi algorithm.

  Supports complex numbers, see
  https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29

  References:
    [Zaheer et al, 2018](https://papers.nips.cc/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html) #pylint:disable=line-too-long

  Args:
    b1: Decay rate for the exponentially weighted average of grads.
    b2: Decay rate for the exponentially weighted average of variance of grads.
    eps: Term added to the denominator to improve numerical stability.
    eps_root: Term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    initial_accumulator_value: The starting value for accumulators.
      Only positive values are allowed.

  Returns:
    A `GradientTransformation` object.
  """

  def init_fn(params):
    def _filled(leaf):
      return jnp.full_like(leaf, initial_accumulator_value)
    return ScaleByAdamState(
        count=jnp.zeros([], jnp.int32),
        mu=jax.tree_util.tree_map(_filled, params),
        nu=jax.tree_util.tree_map(_filled, params))

  def _yogi_second_moment(g, v):
    g_sq = _abs_sq(g)
    # Sign-controlled additive update instead of Adam's multiplicative EMA.
    return v - (1 - b2) * jnp.sign(v - g_sq) * g_sq

  def update_fn(updates, state, params=None):
    del params
    mu = update_moment(updates, state.mu, b1, 1)
    nu = jax.tree_util.tree_map(_yogi_second_moment, updates, state.nu)
    step = numerics.safe_int32_increment(state.count)
    mu_hat = bias_correction(mu, b1, step)
    nu_hat = bias_correction(nu, b2, step)
    out = jax.tree_util.tree_map(
        lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
    return out, ScaleByAdamState(count=step, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_radam(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    threshold: float = 5.0
) -> base.GradientTransformation:
  """Rescale updates following the Rectified Adam (RAdam) algorithm.

  References:
    [Liu et al, 2020](https://arxiv.org/abs/1908.03265)

  Args:
    b1: Decay rate for the exponentially weighted average of grads.
    b2: Decay rate for the exponentially weighted average of squared grads.
    eps: Term added to the denominator to improve numerical stability.
    eps_root: Term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    threshold: Threshold for variance tractability.

  Returns:
    A `GradientTransformation` object.
  """
  ro_inf = 2./(1 - b2) - 1

  def _rectified(operands):
    ro, mu_hat, nu_hat = operands
    # Variance rectification factor.
    r = jnp.sqrt((ro - 4)*(ro - 2)*ro_inf/((ro_inf - 4)*(ro_inf - 2)*ro))
    return jax.tree_util.tree_map(
        lambda m, v: r*m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)

  def init_fn(params):
    first = jax.tree_util.tree_map(jnp.zeros_like, params)
    second = jax.tree_util.tree_map(jnp.zeros_like, params)
    return ScaleByAdamState(
        count=jnp.zeros([], jnp.int32), mu=first, nu=second)

  def update_fn(updates, state, params=None):
    del params
    mu = update_moment(updates, state.mu, b1, 1)
    nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
    step = numerics.safe_int32_increment(state.count)
    b2t = b2**step
    ro = ro_inf - 2 * step * b2t / (1 - b2t)
    mu_hat = bias_correction(mu, b1, step)
    nu_hat = bias_correction(nu, b2, step)
    # Below the threshold the variance is not tractable: fall back to the
    # debiased first moment (momentum SGD).
    out = jax.lax.cond(
        ro >= threshold, _rectified, lambda _: mu_hat,
        (ro, mu_hat, nu_hat))
    return out, ScaleByAdamState(count=step, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
class ScaleByRpropState(NamedTuple):
  """Optimizer state for `scale_by_rprop`."""
  # Current per-element step size, one leaf per parameter leaf.
  step_sizes: base.Updates
  # Signed step stored at the previous call; its sign is compared with the
  # new gradient's sign to decide whether each step size grows or shrinks.
  prev_updates: base.Updates
def scale_by_rprop(
    learning_rate: float,
    eta_minus: float = 0.5,
    eta_plus: float = 1.2,
    min_step_size: float = 1e-6,
    max_step_size: float = 50.0,
) -> base.GradientTransformation:
  """Scale with the Rprop optimizer.

  Rprop, short for resilient backpropagation, is a first-order variant of
  gradient descent. It responds only to the sign of the gradient, increasing
  or decreasing a per-parameter step size exponentially to speed up
  convergence and avoid oscillations.

  For each element, the output of a step is `step_size * sign(g)`, using the
  step size adapted at that same step. It is zero where the gradient changed
  sign relative to the previous step, as in the PyTorch reference.

  References:
    PyTorch implementation:
      https://pytorch.org/docs/stable/generated/torch.optim.Rprop.html
    Riedmiller and Braun, 1993: https://ieeexplore.ieee.org/document/298623
    Igel and Hüsken, 2003:
      https://www.sciencedirect.com/science/article/abs/pii/S0925231201007007

  Args:
    learning_rate: The initial step size.
    eta_minus: Multiplicative factor for decreasing step size. This is applied
      when the gradient changes sign from one step to the next.
    eta_plus: Multiplicative factor for increasing step size. This is applied
      when the gradient has the same sign from one step to the next.
    min_step_size: Minimum allowed step size. Smaller steps will be clipped to
      this value.
    max_step_size: Maximum allowed step size. Larger steps will be clipped to
      this value.

  Returns:
    The corresponding `GradientTransformation`.
  """

  def init_fn(params):
    step_sizes = jax.tree_util.tree_map(
        lambda p: learning_rate * jnp.ones_like(p), params)
    prev_updates = jax.tree_util.tree_map(jnp.zeros_like, params)
    return ScaleByRpropState(step_sizes, prev_updates)

  def update_fn(updates, state, params=None):
    del params
    # > 0: gradient kept its sign; < 0: it flipped; == 0: either value is 0.
    sign = jax.tree_util.tree_map(
        lambda g, prev_g: g * prev_g, updates, state.prev_updates)
    step_sizes = jax.tree_util.tree_map(
        lambda s, step_size: jnp.where(
            s == 0,
            step_size,
            jnp.clip(
                step_size * jnp.where(s > 0, eta_plus, eta_minus),
                a_min=min_step_size, a_max=max_step_size
            )
        ),
        sign, state.step_sizes
    )
    # Signed step for THIS iteration. Where the sign flipped, the step is
    # suppressed; storing zero there makes the next comparison yield s == 0,
    # which leaves that step size unchanged on the following step.
    new_updates = jax.tree_util.tree_map(
        lambda s, g, step_size: jnp.where(
            s < 0, jnp.zeros_like(g), step_size * jnp.sign(g)),
        sign, updates, step_sizes)
    # Emit the freshly computed step. Emitting `state.prev_updates` instead
    # would lag by one step: the first update would always be zero and a zero
    # gradient would still move the parameters along the old direction.
    return new_updates, ScaleByRpropState(step_sizes, new_updates)

  return base.GradientTransformation(init_fn, update_fn)
AddDecayedWeightsState = base.EmptyState


def add_decayed_weights(
    weight_decay: Union[float, jax.Array] = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None
) -> base.GradientTransformation:
  """Add `weight_decay * params` to the updates.

  Args:
    weight_decay: A scalar weight decay rate.
    mask: A tree with same structure as (or a prefix of) the params PyTree,
      or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.

  Returns:
    A `GradientTransformation` object.

  Raises:
    ValueError: If `params` is not passed to `update`.
  """

  def init_fn(params):
    del params
    return AddDecayedWeightsState()

  def update_fn(updates, state, params):
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)
    decayed = jax.tree_util.tree_map(
        lambda g, p: g + weight_decay * p, updates, params)
    return decayed, state

  transform = base.GradientTransformation(init_fn, update_fn)
  if mask is None:
    return transform
  # With a mask, only the selected leaves are decayed; it is common to skip
  # e.g. bias units and batch statistics.
  return wrappers.masked(transform, mask)
class ScaleByScheduleState(NamedTuple):
  """Optimizer state for `scale_by_schedule`: the number of steps taken."""
  count: chex.Array  # Scalar int32; passed to the schedule each step.
def scale_by_learning_rate(
    learning_rate: base.ScalarOrSchedule,
    *,
    flip_sign: bool = True,
) -> base.GradientTransformation:
  """Scale by the (negated) learning rate, given as a scalar or a schedule.

  Args:
    learning_rate: Either a scalar or a schedule (a callable mapping an int
      step to a float).
    flip_sign: When True (the default), scale by the *negative* learning
      rate.

  Returns:
    An optax.GradientTransformation that multiplies the gradient by
    `-learning_rate` when `flip_sign` is True, or by `learning_rate`
    otherwise.
  """
  sign = -1 if flip_sign else 1
  if not callable(learning_rate):
    return scale(sign * learning_rate)
  return scale_by_schedule(lambda count: sign * learning_rate(count))
def scale_by_schedule(
    step_size_fn: base.Schedule
) -> base.GradientTransformation:
  """Scale updates by a step size drawn from a schedule of the step count.

  Args:
    step_size_fn: A function that takes an update count as input and returns
      the step size to multiply the updates by.

  Returns:
    A `GradientTransformation` object.
  """

  def init_fn(params):
    del params
    return ScaleByScheduleState(count=jnp.zeros([], jnp.int32))

  def update_fn(updates, state, params=None):
    del params
    factor = step_size_fn(state.count)
    # Materialize the factor in each leaf's dtype so the product keeps it.
    scaled = jax.tree_util.tree_map(
        lambda g: jnp.array(factor, dtype=g.dtype) * g, updates)
    next_count = numerics.safe_int32_increment(state.count)
    return scaled, ScaleByScheduleState(count=next_count)

  return base.GradientTransformation(init_fn, update_fn)
class ScaleByTrustRatioState(NamedTuple):
  """Empty state: trust-ratio scaling keeps nothing between steps."""
def scale_by_trust_ratio(
min_norm: float = 0.0,
trust_coefficient: float = 1.,
eps: float = 0.,
) -> base.GradientTransformation:
"""Scale updates by `trust ratio`.
References:
[You et. al 2020](https://arxiv.org/abs/1904.00962)
Args: