diff --git a/docs/api.rst b/docs/api.rst index 39b129d0..8552cd33 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -619,6 +619,7 @@ Schedules .. autofunction:: piecewise_constant_schedule .. autofunction:: piecewise_interpolate_schedule .. autofunction:: polynomial_schedule +.. autofunction:: optax.contrib.reduce_on_plateau .. autofunction:: sgdr_schedule .. autofunction:: warmup_cosine_decay_schedule .. autofunction:: warmup_exponential_decay_schedule diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index ce85ff06..8c8b85ee 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -27,6 +27,8 @@ from optax.contrib.privacy import dpsgd from optax.contrib.prodigy import prodigy from optax.contrib.prodigy import ProdigyState +from optax.contrib.reduce_on_plateau import reduce_on_plateau +from optax.contrib.reduce_on_plateau import ReduceLROnPlateauState from optax.contrib.sam import normalize from optax.contrib.sam import NormalizeState from optax.contrib.sam import sam diff --git a/optax/contrib/reduce_on_plateau.py b/optax/contrib/reduce_on_plateau.py new file mode 100644 index 00000000..b3be6ad0 --- /dev/null +++ b/optax/contrib/reduce_on_plateau.py @@ -0,0 +1,128 @@ +# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reduce Learning Rate on Plateau callback. + +This callback monitors a quantity and if no improvement is seen for a 'patience' +number of epochs, the learning rate is reduced by a factor of 'reduce_factor'. +Optionally, a cooldown period can be specified during which the learning rate +will not be reduced. +""" +from typing import NamedTuple, Tuple + +import chex +import jax +import jax.numpy as jnp +from optax._src import base +from optax._src import numerics + + +class ReduceLROnPlateauState(NamedTuple): + """State for the ReduceLROnPlateau callback.""" + + lr: chex.Array # shape=(), dtype=jnp.float32 + best_loss: chex.Array # shape=(), dtype=jnp.float32 + plateau_count: chex.Array # shape=(), dtype=jnp.int32 + cooldown_counter: chex.Array # shape=(), dtype=jnp.int32 + + +def reduce_on_plateau( + factor: float = 0.1, + patience: int = 10, + threshold: float = 1e-4, + cooldown: int = 0, +) -> base.GradientTransformationExtraArgs: + """Reduce learning rate when a metric has stopped improving. + + Models often benefit from reducing the learning once learning stagnates. + his scheduler reads a metrics quantity and if no improvement is seen for + a ``patience`` number of epochs, the learning rate is reduced. + + Args: + factor: Factor by which to reduce the learning rate. new_lr = lr * factor. + patience: Number of iterations with no improvement after which learning rate + will be reduced. + threshold: Threshold for measuring the new optimum, to only focus on + significant changes. + cooldown: Number of iterations to wait before resuming normal operation + after lr has been reduced. + + Returns: + A GradientTransformationExtraArgs object. + """ + + def init_fn(params) -> ReduceLROnPlateauState: + del params + return ReduceLROnPlateauState( + best_loss=jnp.asarray(float("inf"), dtype=jnp.float32), + plateau_count=jnp.asarray(0, jnp.int32), + lr=jnp.asarray(1.0, dtype=jnp.float32), + cooldown_counter=jnp.asarray(0, jnp.int32), + ) + + def update_fn( + updates: base.Updates, + state: ReduceLROnPlateauState, + params=None, + *, + loss, + **extra_args, + ) -> Tuple[base.Params, ReduceLROnPlateauState]: + del params, extra_args + + # Update plateau count and check if plateaued + has_improved = jnp.where((loss / state.best_loss - 1) < -threshold, 1, 0) + new_best_loss = jnp.where(has_improved, loss, state.best_loss) + + curr_plateau_count = jnp.where( + has_improved, 0, numerics.safe_int32_increment(state.plateau_count) + ) + + # We're in cooldown, so reduce the counter and ignore any bad epochs + def in_cooldown(): + new_plateau_count = 0 + new_lr = state.lr + new_cooldown_counter = state.cooldown_counter - 1 + return new_plateau_count, new_lr, new_cooldown_counter + + # We're not in cooldown, so update the plateau count and lr as usual + def not_in_cooldown(): + new_plateau_count = jnp.where( + curr_plateau_count == patience, 0, curr_plateau_count + ) + new_lr = jnp.where( + curr_plateau_count == patience, + state.lr * factor, + state.lr, + ) + new_cooldown_counter = jnp.where( + curr_plateau_count == patience, cooldown, 0 + ) + return new_plateau_count, new_lr, new_cooldown_counter + + new_plateau_count, new_lr, new_cooldown_counter = jax.lax.cond( + state.cooldown_counter > 0, in_cooldown, not_in_cooldown + ) + + updates = jax.tree_util.tree_map(lambda g: new_lr * g, updates) + + new_state = ReduceLROnPlateauState( + plateau_count=new_plateau_count, + best_loss=new_best_loss, + lr=new_lr, + cooldown_counter=new_cooldown_counter, + ) + return updates, new_state + + return base.GradientTransformationExtraArgs(init_fn, update_fn) diff --git a/optax/contrib/reduce_on_plateau_test.py b/optax/contrib/reduce_on_plateau_test.py new file mode 100644 index 00000000..7f77fe29 --- /dev/null +++ b/optax/contrib/reduce_on_plateau_test.py @@ -0,0 +1,97 @@ +# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `reduce_on_plateau.py`.""" + +from absl.testing import absltest +import chex +import jax.numpy as jnp +from optax import contrib + + +class ReduceLROnPlateauTest(absltest.TestCase): + + def test_learning_rate_reduced_after_cooldown_period_is_over(self): + """Test that learning rate is reduced again after cooldown period is over.""" + + # Define a dummy update and extra_args + updates = {'params': jnp.array(1.0)} + patience = 5 + cooldown = 5 + # Apply the transformation to the updates and state + transform = contrib.reduce_on_plateau(patience=patience, cooldown=cooldown) + state = transform.init(updates['params']) + for _ in range(patience + 1): + updates, state = transform.update(updates=updates, state=state, loss=1.0) + # Check that learning rate is reduced + # we access the fields inside new_state using indices instead of attributes + # because otherwise pytype throws an error + lr, best_loss, plateau_count, cooldown_counter = state + chex.assert_trees_all_close(lr, 0.1) + chex.assert_trees_all_close(best_loss, 1.0) + chex.assert_trees_all_close(plateau_count, 0) + chex.assert_trees_all_close(cooldown_counter, cooldown) + chex.assert_trees_all_close(updates, {'params': jnp.array(0.1)}) + + _, state = transform.update(updates=updates, state=state, loss=1.0) + lr, best_loss, plateau_count, cooldown_counter = state + chex.assert_trees_all_close(lr, 0.1) + chex.assert_trees_all_close(best_loss, 1.0) + chex.assert_trees_all_close(plateau_count, 0) + chex.assert_trees_all_close(cooldown_counter, cooldown - 1) + + def test_learning_rate_is_not_reduced(self): + """Test that plateau count resets after a new best loss is found.""" + state = contrib.ReduceLROnPlateauState( + best_loss=jnp.array(0.1, dtype=jnp.float32), + plateau_count=jnp.array(3, dtype=jnp.int32), + lr=jnp.array(0.01, dtype=jnp.float32), + cooldown_counter=jnp.array(0, dtype=jnp.int32), + ) + # Define a dummy update and extra_args + updates = {'params': 1} + # Apply the transformation to the updates and state + transform = contrib.reduce_on_plateau( + factor=0.5, patience=5, threshold=1e-4, cooldown=5 + ) + _, new_state = transform.update(updates=updates, state=state, loss=0.01) + lr, best_loss, plateau_count, _ = new_state + # Check that plateau count resets + chex.assert_trees_all_close(plateau_count, 0) + chex.assert_trees_all_close(lr, 0.01) + chex.assert_trees_all_close(best_loss, 0.01) + + def test_learning_rate_not_reduced_during_cooldown(self): + """Test that learning rate is not reduced during cooldown.""" + # Define a state where cooldown_counter is positive + state = contrib.ReduceLROnPlateauState( + best_loss=jnp.array(0.1, dtype=jnp.float32), + plateau_count=jnp.array(4, dtype=jnp.int32), + lr=jnp.array(0.01, dtype=jnp.float32), + cooldown_counter=jnp.array(3, dtype=jnp.int32), + ) + # Define a dummy update and extra_args + updates = {'params': 1} + # Apply the transformation to the updates and state + transform = contrib.reduce_on_plateau( + factor=0.5, patience=5, threshold=1e-4, cooldown=5 + ) + _, new_state = transform.update(updates=updates, state=state, loss=0.15) + # Check that learning rate is not reduced + lr, _, _, _ = new_state + chex.assert_trees_all_close(lr, 0.01) + + +if __name__ == '__main__': + absltest.main()