diff --git a/optax/contrib/_scale_by_grad_norm.py b/optax/contrib/_scale_by_grad_norm.py
new file mode 100644
index 00000000..8e775f79
--- /dev/null
+++ b/optax/contrib/_scale_by_grad_norm.py
@@ -0,0 +1,61 @@
+# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""GradientTransformation to scale updates by the gradient norm."""
+
+from typing import NamedTuple
+
+import jax
+
+from optax._src import base
+from optax._src import linear_algebra
+
+
+class ScaleByGradientNorm(NamedTuple):
+  """State of the `GradientTransformation` returned by `scale_by_gradient_norm`.
+
+  Attributes:
+    scale: (float) scaling factor
+    eps: (float) jitter term to avoid dividing by 0
+  """
+
+  scale: float
+  eps: float
+
+
+def scale_by_gradient_norm(
+    scale: float = 1.0, eps: float = 1e-6
+) -> base.GradientTransformation:
+  """Scale updates by the global norm of the gradient.
+
+  Args:
+    scale: (float) scaling factor
+    eps: (float) jitter term to avoid dividing by 0
+
+  Returns:
+    An (init_fn, update_fn) tuple.
+  """
+
+  def init_fn(params):
+    del params
+    return ScaleByGradientNorm(scale, eps)
+
+  def update_fn(updates, state, params=None):
+    del params
+    # Divide the whole update tree by its global norm (plus eps), then rescale.
+    g_norm = (linear_algebra.global_norm(updates) + eps) / scale
+    updates = jax.tree_util.tree_map(lambda g: g / g_norm, updates)
+    return updates, state
+
+  return base.GradientTransformation(init_fn, update_fn)
diff --git a/optax/contrib/_scale_by_grad_norm_test.py b/optax/contrib/_scale_by_grad_norm_test.py
new file mode 100644
index 00000000..fd40a366
--- /dev/null
+++ b/optax/contrib/_scale_by_grad_norm_test.py
@@ -0,0 +1,45 @@
+# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `_scale_by_grad_norm.py`."""
+
+from absl.testing import absltest
+import chex
+import jax
+import jax.numpy as jnp
+from optax.contrib._scale_by_grad_norm import scale_by_gradient_norm
+
+
+class ScaleByGradientNormTest(chex.TestCase):
+  @chex.all_variants
+  def test_scale_by_gradient_norm(self):
+    params = jnp.array([300.0, -400.0])
+    updates = jnp.array([300.0, -400.0])
+
+    optim = scale_by_gradient_norm(scale=1.0)
+    init_fn = self.variant(optim.init)
+    transform_fn = self.variant(optim.update)
+
+    state = init_fn(params)
+    chex.assert_tree_all_finite(state)
+
+    updates, state = transform_fn(updates, state, params)
+    chex.assert_tree_all_finite((params, updates, state))
+    jax.tree_util.tree_map(
+        lambda *args: chex.assert_equal_shape(args), params, updates
+    )
+
+
+if __name__ == "__main__":
+  absltest.main()
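
For reviewers, a minimal usage sketch of the new transform chained with a standard optimizer. This assumes the transform is imported from the private module path added in this PR; a contrib __init__ export is not part of this diff.

    import jax.numpy as jnp
    import optax
    from optax.contrib._scale_by_grad_norm import scale_by_gradient_norm

    params = {"w": jnp.array([3.0, -4.0])}
    grads = {"w": jnp.array([3.0, -4.0])}

    # Normalise updates to (approximately) unit global norm, then take an SGD step.
    tx = optax.chain(
        scale_by_gradient_norm(scale=1.0),
        optax.sgd(learning_rate=0.1),
    )
    state = tx.init(params)
    updates, state = tx.update(grads, state, params)
    params = optax.apply_updates(params, updates)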