diff --git a/docs/api/transformations.rst b/docs/api/transformations.rst
index c1fcb9cc..3987c96d 100644
--- a/docs/api/transformations.rst
+++ b/docs/api/transformations.rst
@@ -32,6 +32,8 @@ Transformations
     identity
     keep_params_nonnegative
     NonNegativeParamsState
+    normalize_by_update_norm
+    NormalizeByUpdateNormState
     OptState
     Params
     per_example_global_norm_clip
@@ -172,6 +174,10 @@ Transformations and states
 .. autoclass:: NonNegativeParamsState
     :members:
 
+.. autofunction:: normalize_by_update_norm
+.. autoclass:: NormalizeByUpdateNormState
+    :members:
+
 .. autofunction:: per_example_global_norm_clip
 .. autofunction:: per_example_layer_norm_clip
 
diff --git a/optax/__init__.py b/optax/__init__.py
index 7aaa5d68..360ee3a0 100644
--- a/optax/__init__.py
+++ b/optax/__init__.py
@@ -105,6 +105,8 @@
 from optax._src.transform import centralize
 from optax._src.transform import ema
 from optax._src.transform import EmaState
+from optax._src.transform import normalize_by_update_norm
+from optax._src.transform import NormalizeByUpdateNormState
 from optax._src.transform import scale
 from optax._src.transform import scale_by_adadelta
 from optax._src.transform import scale_by_adam
diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index 9dda6786..c21b056b 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -1419,6 +1419,64 @@ def update_fn(
   return base.GradientTransformationExtraArgs(_init_empty_state, update_fn)
 
 
+class NormalizeByUpdateNormState(NamedTuple):
+  """State for normalize_by_update_norm."""
+  scale_factor: float
+  eps: float
+
+
+def normalize_by_update_norm(
+    scale_factor: float = 1.0, eps: float = 1e-6
+) -> base.GradientTransformation:
+  """
+  Scale by the inverse of the update norm.
+
+  Examples:
+    >>> import optax
+    >>> import jax
+    >>> import jax.numpy as jnp
+    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
+    >>> solver = optax.normalize_by_update_norm(scale_factor=1.0)
+    >>> params = jnp.array([1., 2., 3.])
+    >>> print('Objective function: ', f(params))
+    Objective function:  14.0
+    >>> opt_state = solver.init(params)
+    >>> for _ in range(5):
+    ...   grad = jax.grad(f)(params)
+    ...   updates, opt_state = solver.update(grad, opt_state, params)
+    ...   params = optax.apply_updates(params, updates)
+    ...   print('Objective function: {:.2E}'.format(f(params)))
+    Objective function: 2.25E+01
+    Objective function: 3.30E+01
+    Objective function: 4.54E+01
+    Objective function: 5.99E+01
+    Objective function: 7.64E+01
+
+  Args:
+    scale_factor: factor by which the update will be multiplied (defaults to 1).
+    eps: jitter term to avoid dividing by 0.
+
+  Returns:
+    A `GradientTransformation` object.
+  """
+
+  def init_fn(params):
+    del params
+    return NormalizeByUpdateNormState(scale_factor, eps)
+
+  def update_fn(
+      updates: base.Updates,
+      state: NormalizeByUpdateNormState,
+      params: Optional[base.Params] = None,
+  ) -> tuple[base.Updates, NormalizeByUpdateNormState]:
+    del params
+    g_norm = (otu.tree_l2_norm(updates) + eps) / scale_factor
+    updates = jtu.tree_map(lambda g: g / g_norm, updates)
+    return updates, state
+
+  return base.GradientTransformation(init_fn, update_fn)
+
+
 ### Legacy symbols to be removed. ###
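
(Reviewer note, not part of the diff.) As I read the new `update_fn`, the whole update tree is divided by its global L2 norm and multiplied by `scale_factor`, so the returned updates have a global norm close to `scale_factor` whenever that norm dominates `eps`. A minimal sketch of that property, assuming the function is exported as `optax.normalize_by_update_norm` per the `optax/__init__.py` change above:

import jax.numpy as jnp
import optax

# A toy update tree whose global L2 norm is 5 (a 3-4-5 triangle).
updates = {'w': jnp.array([3., 4.]), 'b': jnp.array([0.])}

tx = optax.normalize_by_update_norm(scale_factor=2.0)
state = tx.init(updates)
scaled, _ = tx.update(updates, state)

print(optax.global_norm(scaled))  # ~2.0, i.e. the requested scale_factor
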
diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py
index 39a91449..0aefe9b0 100644
--- a/optax/_src/transform_test.py
+++ b/optax/_src/transform_test.py
@@ -53,6 +53,7 @@ def setUp(self):
       ('param_block_norm', transform.scale_by_param_block_norm),
       ('param_block_rms', transform.scale_by_param_block_rms),
       ('distance_over_gradients', transform.scale_by_distance_over_gradients),
+      ('normalize_by_update_norm', transform.normalize_by_update_norm),
   ])
   def test_scalers(self, scaler_constr):
     params = self.init_params
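
(Reviewer note, not part of the diff.) The docstring example applies the normalized gradient directly via `apply_updates`, so the objective grows; in practice the transformation would typically be chained with an optimizer that negates and scales the step. A minimal sketch under that assumption, using only the new transformation plus standard optax APIs (`chain`, `sgd`, `apply_updates`):

import jax
import jax.numpy as jnp
import optax

def loss(x):
  return jnp.sum(x ** 2)  # same quadratic objective as in the docstring

# Normalize the gradient to unit norm, then let SGD negate it and apply
# the learning rate: normalized gradient descent.
solver = optax.chain(
    optax.normalize_by_update_norm(scale_factor=1.0),
    optax.sgd(learning_rate=0.1),
)

params = jnp.array([1., 2., 3.])
opt_state = solver.init(params)
for _ in range(5):
  grad = jax.grad(loss)(params)
  updates, opt_state = solver.update(grad, opt_state, params)
  params = optax.apply_updates(params, updates)  # objective now decreases
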