From 1528c68e1e60fe137ea8dc07e42ab970aae9ea32 Mon Sep 17 00:00:00 2001
From: Saurav Maheshkar
Date: Tue, 25 Jun 2024 16:42:51 +0530
Subject: [PATCH 1/5] feat: add normalize_by_update_norm

---
 optax/__init__.py            |  2 ++
 optax/_src/transform.py      | 37 ++++++++++++++++++++++++++++++++++++
 optax/_src/transform_test.py |  1 +
 3 files changed, 40 insertions(+)

diff --git a/optax/__init__.py b/optax/__init__.py
index 7aaa5d68..360ee3a0 100644
--- a/optax/__init__.py
+++ b/optax/__init__.py
@@ -105,6 +105,8 @@
 from optax._src.transform import centralize
 from optax._src.transform import ema
 from optax._src.transform import EmaState
+from optax._src.transform import normalize_by_update_norm
+from optax._src.transform import NormalizeByUpdateNormState
 from optax._src.transform import scale
 from optax._src.transform import scale_by_adadelta
 from optax._src.transform import scale_by_adam
diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index 9dda6786..46c61dbd 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -1419,6 +1419,43 @@ def update_fn(
   return base.GradientTransformationExtraArgs(_init_empty_state, update_fn)
 
 
+class NormalizeByUpdateNormState(NamedTuple):
+  """State for normalize_by_update_norm."""
+  scale_factor: float
+  eps: float
+
+
+def normalize_by_update_norm(
+    scale_factor: float = 1.0, eps: float = 1e-6
+) -> base.GradientTransformation:
+  """
+  Scale by the inverse of the gradient norm.
+
+  Args:
+    scale_factor: (float) scaling factor
+    eps: (float) jitter term to avoid dividing by 0
+
+  Returns:
+    A `GradientTransformation` object.
+  """
+
+  def init_fn(params):
+    del params
+    return NormalizeByUpdateNormState(scale_factor, eps)
+
+  def update_fn(
+      updates: base.Updates,
+      state: NormalizeByUpdateNormState,
+      params: Optional[base.Params] = None,
+  ) -> tuple[base.Updates, NormalizeByUpdateNormState]:
+    del params
+    g_norm = (utils.global_norm(updates) + eps) / scale_factor
+    updates = jtu.tree_map(lambda g: g / g_norm, updates)
+    return updates, state
+
+  return base.GradientTransformation(init_fn, update_fn)
+
+
 ### Legacy symbols to be removed. ###
diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py
index 39a91449..0aefe9b0 100644
--- a/optax/_src/transform_test.py
+++ b/optax/_src/transform_test.py
@@ -53,6 +53,7 @@ def setUp(self):
       ('param_block_norm', transform.scale_by_param_block_norm),
       ('param_block_rms', transform.scale_by_param_block_rms),
       ('distance_over_gradients', transform.scale_by_distance_over_gradients),
+      ('normalize_by_update_norm', transform.normalize_by_update_norm),
   ])
   def test_scalers(self, scaler_constr):
     params = self.init_params

From 6b5e78f8c69af929f013c144da16b4d3730206be Mon Sep 17 00:00:00 2001
From: Saurav Maheshkar
Date: Fri, 28 Jun 2024 14:02:12 +0530
Subject: [PATCH 2/5] docs: update docstrings

Co-authored-by: Fabian Pedregosa
---
 optax/_src/transform.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index 46c61dbd..2abc1faa 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -1432,7 +1432,7 @@ def normalize_by_update_norm(
   Scale by the inverse of the gradient norm.
 
   Args:
-    scale_factor: (float) scaling factor
+    scale_factor: factor by which the update will be multiplied (defaults to 1).
     eps: (float) jitter term to avoid dividing by 0
 
   Returns:

From c2a7714c89f2d08a9abe80b2221f6f9aec393b15 Mon Sep 17 00:00:00 2001
From: Saurav Maheshkar
Date: Fri, 28 Jun 2024 14:17:58 +0530
Subject: [PATCH 3/5] feat: add doctests

---
 optax/_src/transform.py | 25 +++++++++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)

diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index 2abc1faa..c21b056b 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -1429,7 +1429,28 @@ def normalize_by_update_norm(
     scale_factor: float = 1.0, eps: float = 1e-6
 ) -> base.GradientTransformation:
   """
-  Scale by the inverse of the gradient norm.
+  Scale by the inverse of the update norm.
+
+  Examples:
+    >>> import optax
+    >>> import jax
+    >>> import jax.numpy as jnp
+    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
+    >>> solver = optax.normalize_by_update_norm(scale_factor=1.0)
+    >>> params = jnp.array([1., 2., 3.])
+    >>> print('Objective function: ', f(params))
+    Objective function:  14.0
+    >>> opt_state = solver.init(params)
+    >>> for _ in range(5):
+    ...  grad = jax.grad(f)(params)
+    ...  updates, opt_state = solver.update(grad, opt_state, params)
+    ...  params = optax.apply_updates(params, updates)
+    ...  print('Objective function: {:.2E}'.format(f(params)))
+    Objective function: 2.25E+01
+    Objective function: 3.30E+01
+    Objective function: 4.54E+01
+    Objective function: 5.99E+01
+    Objective function: 7.64E+01
 
   Args:
     scale_factor: factor by which the update will be multiplied (defaults to 1).
@@ -1449,7 +1470,7 @@ def update_fn(
       state: NormalizeByUpdateNormState,
       params: Optional[base.Params] = None,
   ) -> tuple[base.Updates, NormalizeByUpdateNormState]:
     del params
-    g_norm = (utils.global_norm(updates) + eps) / scale_factor
+    g_norm = (otu.tree_l2_norm(updates) + eps) / scale_factor
     updates = jtu.tree_map(lambda g: g / g_norm, updates)
     return updates, state

From 5c01f6f861e518f5d829c640061d26cfb1318659 Mon Sep 17 00:00:00 2001
From: Saurav Maheshkar
Date: Fri, 28 Jun 2024 14:47:43 +0530
Subject: [PATCH 4/5] docs: update transformations

---
 docs/api/transformations.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/api/transformations.rst b/docs/api/transformations.rst
index c1fcb9cc..33395c54 100644
--- a/docs/api/transformations.rst
+++ b/docs/api/transformations.rst
@@ -32,6 +32,8 @@ Transformations
     identity
     keep_params_nonnegative
     NonNegativeParamsState
+    normalize_by_update_norm
+    NormalizeByUpdateNormState
     OptState
     Params
     per_example_global_norm_clip

From b849fc63fd13fcf5a5d47a175e2e5041ebf9b7b7 Mon Sep 17 00:00:00 2001
From: Saurav Maheshkar
Date: Fri, 28 Jun 2024 17:29:35 +0530
Subject: [PATCH 5/5] docs: update transformations API docs

---
 docs/api/transformations.rst | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/docs/api/transformations.rst b/docs/api/transformations.rst
index 33395c54..3987c96d 100644
--- a/docs/api/transformations.rst
+++ b/docs/api/transformations.rst
@@ -174,6 +174,10 @@ Transformations and states
 .. autoclass:: NonNegativeParamsState
    :members:
 
+.. autofunction:: normalize_by_update_norm
+.. autoclass:: NormalizeByUpdateNormState
+   :members:
+
 .. autofunction:: per_example_global_norm_clip
 .. autofunction:: per_example_layer_norm_clip
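
Usage note on the doctest in PATCH 3/5: the objective there increases because nothing negates the normalized update; the transform only rescales. Below is a minimal sketch, assuming the patches above are applied, of how normalize_by_update_norm composes with the existing optax.chain and optax.scale APIs to get normalized gradient descent. The step size of 0.1 is an arbitrary value chosen for illustration, not something prescribed by the patches.

import jax
import jax.numpy as jnp
import optax

def f(x):
  return jnp.sum(x ** 2)  # same quadratic objective as the doctest

# Chaining with a negative scale turns the transform into normalized
# gradient descent: each step moves a fixed distance (0.1 here) along
# -grad / ||grad||, regardless of the raw gradient magnitude.
solver = optax.chain(
    optax.normalize_by_update_norm(scale_factor=1.0),  # added in PATCH 1/5
    optax.scale(-0.1),  # hypothetical step size for this sketch
)

params = jnp.array([1., 2., 3.])
opt_state = solver.init(params)
for _ in range(5):
  grad = jax.grad(f)(params)
  updates, opt_state = solver.update(grad, opt_state, params)
  params = optax.apply_updates(params, updates)
  print(f(params))  # decreases toward the optimum at 0

Chained this way, the effective step length is scale_factor * step_size independent of the gradient's magnitude, which matches the docstring's description of scale_factor as the factor by which the (unit-norm) update is multiplied.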