From ae75b614830f609c04ae655d244b08bb571caa05 Mon Sep 17 00:00:00 2001 From: vz415 Date: Fri, 10 Nov 2023 14:24:16 -0800 Subject: [PATCH 1/6] Add reduce_on_plateau LR scheduler to contrib directory. --- optax/contrib/__init__.py | 2 + optax/contrib/reduce_on_plateau.py | 131 ++++++++++++++++++++++++ optax/contrib/reduce_on_plateau_test.py | 106 +++++++++++++++++++ 3 files changed, 239 insertions(+) create mode 100644 optax/contrib/reduce_on_plateau.py create mode 100644 optax/contrib/reduce_on_plateau_test.py diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index a9d4b989..d89d18d5 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -23,6 +23,8 @@ from optax.contrib.privacy import differentially_private_aggregate from optax.contrib.privacy import DifferentiallyPrivateAggregateState from optax.contrib.privacy import dpsgd +from optax.contrib.reduce_on_plateau import reduce_on_plateau +from optax.contrib.reduce_on_plateau import ReduceLROnPlateauState from optax.contrib.sam import normalize from optax.contrib.sam import NormalizeState from optax.contrib.sam import sam diff --git a/optax/contrib/reduce_on_plateau.py b/optax/contrib/reduce_on_plateau.py new file mode 100644 index 00000000..2b843317 --- /dev/null +++ b/optax/contrib/reduce_on_plateau.py @@ -0,0 +1,131 @@ +# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reduce Learning Rate on Plateau callback. + +This callback monitors a quantity and if no improvement is seen for a 'patience' +number of epochs, the learning rate is reduced by a factor of 'reduce_factor'. +Optionally, a cooldown period can be specified during which the learning rate +will not be reduced. +""" +from typing import NamedTuple + +import jax +import jax.numpy as jnp + +from optax._src import base + + +class ReduceLROnPlateauState(NamedTuple): + """State for the ReduceLROnPlateau callback.""" + reduce_factor: float + patience: int + min_improvement: float + best_loss: float + plateau_count: int + lr: float + cooldown_counter: int + cooldown:int + + +def reduce_on_plateau( + reduce_factor: float, + patience: int, + min_improvement:float, + cooldown:int +) -> base.GradientTransformationExtraArgs: + """ Args: + reduce_factor: Factor by which the learning rate will be reduced. + new_lr = lr * factor. + patience: Number of epochs with no improvement after which learning + rate will be reduced. + min_improvement: Threshold for measuring the new optimum, to only focus on + significant changes. + cooldown: Number of epochs to wait before resuming normal operation + after lr has been reduced. 
+ """ + + + def init_fn(params): + del params + return ReduceLROnPlateauState(patience=patience, + reduce_factor=reduce_factor, + min_improvement=min_improvement, + cooldown=cooldown, + cooldown_counter=0, + plateau_count=0, + best_loss=float("inf"), + lr=1, + ) + + def update_fn( + updates, + state, + params=None, + extra_args={}, + ): + del params + current_loss = extra_args.get("loss") + + # Check if the current loss is the best so far + + best_loss = state.best_loss + # Update plateau count and check if plateaued + has_improved = jnp.where( + (current_loss / best_loss - 1) < -state.min_improvement, 1, 0 + ) + new_best_loss = jnp.where(has_improved, current_loss, best_loss) + + curr_plateau_count = jnp.where(has_improved, 0, state.plateau_count + 1) + + + # We're in cooldown, so reduce the counter and ignore any bad epochs + def in_cooldown(): + new_plateau_count = 0 + new_lr = state.lr + new_cooldown_counter = state.cooldown_counter - 1 + return new_plateau_count, new_lr, new_cooldown_counter + + # We're not in cooldown, so update the plateau count and lr as usual + def not_in_cooldown(): + new_plateau_count = jnp.where( + curr_plateau_count == state.patience, 0, curr_plateau_count + ) + new_lr = jnp.where( + curr_plateau_count == state.patience, + state.lr * state.reduce_factor, + state.lr, + ) + new_cooldown_counter = jnp.where( + curr_plateau_count == state.patience, state.cooldown, 0 + ) + return new_plateau_count, new_lr, new_cooldown_counter + + new_plateau_count, new_lr, new_cooldown_counter = jax.lax.cond(state.cooldown_counter > 0, in_cooldown, not_in_cooldown) + + updates = jax.tree_util.tree_map(lambda g: new_lr * g, updates) + + new_state = ReduceLROnPlateauState( + patience=state.patience, + reduce_factor=state.reduce_factor, + min_improvement=state.min_improvement, + plateau_count=new_plateau_count, + best_loss=new_best_loss, + lr=new_lr, + cooldown_counter=new_cooldown_counter, + cooldown=state.cooldown, + ) + return updates, new_state + + return base.GradientTransformationExtraArgs(init_fn, update_fn) \ No newline at end of file diff --git a/optax/contrib/reduce_on_plateau_test.py b/optax/contrib/reduce_on_plateau_test.py new file mode 100644 index 00000000..5110f180 --- /dev/null +++ b/optax/contrib/reduce_on_plateau_test.py @@ -0,0 +1,106 @@ +# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for `reduce_on_plateau.py`.""" + +from absl.testing import absltest +import chex +import jax +import jax.numpy as jnp + +from optax._src import base +from optax._src import transform +from optax.contrib.reduce_on_plateau import ReduceLROnPlateauState, reduce_on_plateau + + +class ReduceLROnPlateauTest(absltest.TestCase): + def test_learning_rate_reduced_after_cooldown_period_is_over(self): + """Test that learning rate is reduced again after cooldown period is over.""" + # Define a state where cooldown_counter is zero + state = ReduceLROnPlateauState( + reduce_factor=0.5, + patience=5, + min_improvement=1e-4, + best_loss=0.1, + ## + plateau_count=4, + ## + lr=0.01, + cooldown=5, + cooldown_counter=0, + ) + # Define a dummy update and extra_args + updates = {'params': 1} + extra_args = {'loss': 0.15} + # Apply the transformation to the updates and state + transform = reduce_on_plateau(reduce_factor=0.5, patience=5, min_improvement=1e-4, cooldown=5) + new_updates, new_state = transform.update(updates=updates, state=state, extra_args=extra_args) + # Check that learning rate is reduced + assert new_state.lr == 0.005 + assert new_state.plateau_count == 0 + assert new_state.cooldown_counter == 5 + new_updates, new_state = transform.update(updates=updates, state=new_state, extra_args=extra_args) + assert new_state.lr == 0.005 + assert new_state.plateau_count == 0 + assert new_state.cooldown_counter == 4 + + + def test_learning_rate_is_not_reduced(self): + """Test that plateau count resets after a new best loss is found.""" + state = ReduceLROnPlateauState( + reduce_factor=0.5, + patience=5, + min_improvement=1e-4, + best_loss=0.1, + plateau_count=3, + lr=0.01, + cooldown_counter=0, + cooldown=5, + ) + # Define a dummy update and extra_args + updates = {'params': 1} + extra_args = {'loss': 0.01} + # Apply the transformation to the updates and state + transform = reduce_on_plateau(reduce_factor=0.5, patience=5, min_improvement=1e-4, cooldown=5) + new_updates, new_state = transform.update(updates=updates, state=state, extra_args=extra_args) + # Check that plateau count resets + assert new_state.plateau_count == 0 + assert new_state.best_loss == 0.01 + + + def test_learning_rate_not_reduced_during_cooldown(self): + """Test that learning rate is not reduced during cooldown.""" + # Define a state where cooldown_counter is positive + state = ReduceLROnPlateauState( + reduce_factor=0.5, + patience=5, + min_improvement=1e-4, + best_loss=0.1, + plateau_count=4, + lr=0.01, + cooldown=5, + cooldown_counter=3, + ) + # Define a dummy update and extra_args + updates = {'params': 1} + extra_args = {'loss': 0.15} + # Apply the transformation to the updates and state + transform = reduce_on_plateau(reduce_factor=0.5, patience=5, min_improvement=1e-4, cooldown=5) + new_updates, new_state = transform.update(updates=updates, state=state, extra_args=extra_args) + # Check that learning rate is not reduced + assert new_state.lr == 0.01 + + +if __name__ == '__main__': + absltest.main() \ No newline at end of file From 8facf45155ae9055b801efe1d7114fbc635861d7 Mon Sep 17 00:00:00 2001 From: vz415 <19888961+vz415@users.noreply.github.com> Date: Mon, 11 Dec 2023 14:17:01 -0600 Subject: [PATCH 2/6] fixed spacing for ReduceLROnPlateau --- optax/contrib/reduce_on_plateau.py | 198 ++++++++++++------------ optax/contrib/reduce_on_plateau_test.py | 158 ++++++++++--------- 2 files changed, 182 insertions(+), 174 deletions(-) diff --git 
a/optax/contrib/reduce_on_plateau.py b/optax/contrib/reduce_on_plateau.py index 2b843317..8da5d769 100644 --- a/optax/contrib/reduce_on_plateau.py +++ b/optax/contrib/reduce_on_plateau.py @@ -28,104 +28,108 @@ class ReduceLROnPlateauState(NamedTuple): - """State for the ReduceLROnPlateau callback.""" - reduce_factor: float - patience: int - min_improvement: float - best_loss: float - plateau_count: int - lr: float - cooldown_counter: int - cooldown:int + """State for the ReduceLROnPlateau callback.""" + reduce_factor: float + patience: int + min_improvement: float + best_loss: float + plateau_count: int + lr: float + cooldown_counter: int + cooldown:int def reduce_on_plateau( - reduce_factor: float, - patience: int, - min_improvement:float, - cooldown:int + reduce_factor: float, + patience: int, + min_improvement:float, + cooldown:int ) -> base.GradientTransformationExtraArgs: - """ Args: - reduce_factor: Factor by which the learning rate will be reduced. - new_lr = lr * factor. - patience: Number of epochs with no improvement after which learning - rate will be reduced. - min_improvement: Threshold for measuring the new optimum, to only focus on - significant changes. - cooldown: Number of epochs to wait before resuming normal operation - after lr has been reduced. - """ - - - def init_fn(params): - del params - return ReduceLROnPlateauState(patience=patience, - reduce_factor=reduce_factor, - min_improvement=min_improvement, - cooldown=cooldown, - cooldown_counter=0, - plateau_count=0, - best_loss=float("inf"), - lr=1, - ) - - def update_fn( - updates, - state, - params=None, - extra_args={}, - ): - del params - current_loss = extra_args.get("loss") - - # Check if the current loss is the best so far - - best_loss = state.best_loss - # Update plateau count and check if plateaued - has_improved = jnp.where( - (current_loss / best_loss - 1) < -state.min_improvement, 1, 0 - ) - new_best_loss = jnp.where(has_improved, current_loss, best_loss) - - curr_plateau_count = jnp.where(has_improved, 0, state.plateau_count + 1) - - - # We're in cooldown, so reduce the counter and ignore any bad epochs - def in_cooldown(): - new_plateau_count = 0 - new_lr = state.lr - new_cooldown_counter = state.cooldown_counter - 1 - return new_plateau_count, new_lr, new_cooldown_counter - - # We're not in cooldown, so update the plateau count and lr as usual - def not_in_cooldown(): - new_plateau_count = jnp.where( - curr_plateau_count == state.patience, 0, curr_plateau_count - ) - new_lr = jnp.where( - curr_plateau_count == state.patience, - state.lr * state.reduce_factor, - state.lr, - ) - new_cooldown_counter = jnp.where( - curr_plateau_count == state.patience, state.cooldown, 0 - ) - return new_plateau_count, new_lr, new_cooldown_counter - - new_plateau_count, new_lr, new_cooldown_counter = jax.lax.cond(state.cooldown_counter > 0, in_cooldown, not_in_cooldown) - - updates = jax.tree_util.tree_map(lambda g: new_lr * g, updates) - - new_state = ReduceLROnPlateauState( - patience=state.patience, - reduce_factor=state.reduce_factor, - min_improvement=state.min_improvement, - plateau_count=new_plateau_count, - best_loss=new_best_loss, - lr=new_lr, - cooldown_counter=new_cooldown_counter, - cooldown=state.cooldown, - ) - return updates, new_state - - return base.GradientTransformationExtraArgs(init_fn, update_fn) \ No newline at end of file + """ Args: + reduce_factor: Factor by which the learning rate will be reduced. + new_lr = lr * factor. 
+ patience: Number of epochs with no improvement after which learning + rate will be reduced. + min_improvement: Threshold for measuring the new optimum, to only focus on + significant changes. + cooldown: Number of epochs to wait before resuming normal operation + after lr has been reduced. + """ + + + def init_fn(params): + del params + return ReduceLROnPlateauState(patience=patience, + reduce_factor=reduce_factor, + min_improvement=min_improvement, + cooldown=cooldown, + cooldown_counter=0, + plateau_count=0, + best_loss=float("inf"), + lr=1, + ) + + def update_fn( + updates, + state, + params=None, + extra_args=None, + ): + del params + if extra_args is None: + extra_args = {} + current_loss = extra_args.get("loss") + + # Check if the current loss is the best so far + + best_loss = state.best_loss + # Update plateau count and check if plateaued + has_improved = jnp.where( + (current_loss / best_loss - 1) < -state.min_improvement, 1, 0 + ) + new_best_loss = jnp.where(has_improved, current_loss, best_loss) + + curr_plateau_count = jnp.where(has_improved, 0, state.plateau_count + 1) + + + # We're in cooldown, so reduce the counter and ignore any bad epochs + def in_cooldown(): + new_plateau_count = 0 + new_lr = state.lr + new_cooldown_counter = state.cooldown_counter - 1 + return new_plateau_count, new_lr, new_cooldown_counter + + # We're not in cooldown, so update the plateau count and lr as usual + def not_in_cooldown(): + new_plateau_count = jnp.where( + curr_plateau_count == state.patience, 0, curr_plateau_count + ) + new_lr = jnp.where( + curr_plateau_count == state.patience, + state.lr * state.reduce_factor, + state.lr, + ) + new_cooldown_counter = jnp.where( + curr_plateau_count == state.patience, state.cooldown, 0 + ) + return new_plateau_count, new_lr, new_cooldown_counter + + new_plateau_count, new_lr, new_cooldown_counter = jax.lax.cond( + state.cooldown_counter > 0, in_cooldown, not_in_cooldown) + + updates = jax.tree_util.tree_map(lambda g: new_lr * g, updates) + + new_state = ReduceLROnPlateauState( + patience=state.patience, + reduce_factor=state.reduce_factor, + min_improvement=state.min_improvement, + plateau_count=new_plateau_count, + best_loss=new_best_loss, + lr=new_lr, + cooldown_counter=new_cooldown_counter, + cooldown=state.cooldown, + ) + return updates, new_state + + + return base.GradientTransformationExtraArgs(init_fn, update_fn) diff --git a/optax/contrib/reduce_on_plateau_test.py b/optax/contrib/reduce_on_plateau_test.py index 5110f180..97980ceb 100644 --- a/optax/contrib/reduce_on_plateau_test.py +++ b/optax/contrib/reduce_on_plateau_test.py @@ -15,92 +15,96 @@ """Tests for `reduce_on_plateau.py`.""" from absl.testing import absltest -import chex -import jax -import jax.numpy as jnp -from optax._src import base -from optax._src import transform from optax.contrib.reduce_on_plateau import ReduceLROnPlateauState, reduce_on_plateau class ReduceLROnPlateauTest(absltest.TestCase): - def test_learning_rate_reduced_after_cooldown_period_is_over(self): - """Test that learning rate is reduced again after cooldown period is over.""" - # Define a state where cooldown_counter is zero - state = ReduceLROnPlateauState( - reduce_factor=0.5, - patience=5, - min_improvement=1e-4, - best_loss=0.1, - ## - plateau_count=4, - ## - lr=0.01, - cooldown=5, - cooldown_counter=0, - ) - # Define a dummy update and extra_args - updates = {'params': 1} - extra_args = {'loss': 0.15} - # Apply the transformation to the updates and state - transform = reduce_on_plateau(reduce_factor=0.5, 
patience=5, min_improvement=1e-4, cooldown=5) - new_updates, new_state = transform.update(updates=updates, state=state, extra_args=extra_args) - # Check that learning rate is reduced - assert new_state.lr == 0.005 - assert new_state.plateau_count == 0 - assert new_state.cooldown_counter == 5 - new_updates, new_state = transform.update(updates=updates, state=new_state, extra_args=extra_args) - assert new_state.lr == 0.005 - assert new_state.plateau_count == 0 - assert new_state.cooldown_counter == 4 + def test_learning_rate_reduced_after_cooldown_period_is_over(self): + """Test that learning rate is reduced again after cooldown period is over. + """ + # Define a state where cooldown_counter is zero + state = ReduceLROnPlateauState( + reduce_factor=0.5, + patience=5, + min_improvement=1e-4, + best_loss=0.1, + ## + plateau_count=4, + ## + lr=0.01, + cooldown=5, + cooldown_counter=0, + ) + # Define a dummy update and extra_args + updates = {'params': 1} + extra_args = {'loss': 0.15} + # Apply the transformation to the updates and state + transform = reduce_on_plateau( + reduce_factor=0.5, patience=5, min_improvement=1e-4, cooldown=5) + _, new_state = transform.update( + updates=updates, state=state, extra_args=extra_args) + # Check that learning rate is reduced + assert new_state.lr == 0.005 + assert new_state.plateau_count == 0 + assert new_state.cooldown_counter == 5 + _, new_state = transform.update( + updates=updates, state=new_state, extra_args=extra_args) + assert new_state.lr == 0.005 + assert new_state.plateau_count == 0 + assert new_state.cooldown_counter == 4 - def test_learning_rate_is_not_reduced(self): - """Test that plateau count resets after a new best loss is found.""" - state = ReduceLROnPlateauState( - reduce_factor=0.5, - patience=5, - min_improvement=1e-4, - best_loss=0.1, - plateau_count=3, - lr=0.01, - cooldown_counter=0, - cooldown=5, - ) - # Define a dummy update and extra_args - updates = {'params': 1} - extra_args = {'loss': 0.01} - # Apply the transformation to the updates and state - transform = reduce_on_plateau(reduce_factor=0.5, patience=5, min_improvement=1e-4, cooldown=5) - new_updates, new_state = transform.update(updates=updates, state=state, extra_args=extra_args) - # Check that plateau count resets - assert new_state.plateau_count == 0 - assert new_state.best_loss == 0.01 + def test_learning_rate_is_not_reduced(self): + """Test that plateau count resets after a new best loss is found.""" + state = ReduceLROnPlateauState( + reduce_factor=0.5, + patience=5, + min_improvement=1e-4, + best_loss=0.1, + plateau_count=3, + lr=0.01, + cooldown_counter=0, + cooldown=5, + ) + # Define a dummy update and extra_args + updates = {'params': 1} + extra_args = {'loss': 0.01} + # Apply the transformation to the updates and state + transform = reduce_on_plateau( + reduce_factor=0.5, patience=5, min_improvement=1e-4, cooldown=5) + _, new_state = transform.update( + updates=updates, state=state, extra_args=extra_args) + # Check that plateau count resets + assert new_state.plateau_count == 0 + assert new_state.best_loss == 0.01 - def test_learning_rate_not_reduced_during_cooldown(self): - """Test that learning rate is not reduced during cooldown.""" - # Define a state where cooldown_counter is positive - state = ReduceLROnPlateauState( - reduce_factor=0.5, - patience=5, - min_improvement=1e-4, - best_loss=0.1, - plateau_count=4, - lr=0.01, - cooldown=5, - cooldown_counter=3, - ) - # Define a dummy update and extra_args - updates = {'params': 1} - extra_args = {'loss': 0.15} - # 
Apply the transformation to the updates and state - transform = reduce_on_plateau(reduce_factor=0.5, patience=5, min_improvement=1e-4, cooldown=5) - new_updates, new_state = transform.update(updates=updates, state=state, extra_args=extra_args) - # Check that learning rate is not reduced - assert new_state.lr == 0.01 + def test_learning_rate_not_reduced_during_cooldown(self): + """Test that learning rate is not reduced during cooldown.""" + # Define a state where cooldown_counter is positive + state = ReduceLROnPlateauState( + reduce_factor=0.5, + patience=5, + min_improvement=1e-4, + best_loss=0.1, + plateau_count=4, + lr=0.01, + cooldown=5, + cooldown_counter=3, + ) + # Define a dummy update and extra_args + updates = {'params': 1} + extra_args = {'loss': 0.15} + # Apply the transformation to the updates and state + transform = reduce_on_plateau( + reduce_factor=0.5, patience=5, min_improvement=1e-4, cooldown=5) + _, new_state = transform.update( + updates=updates, state=state, extra_args=extra_args) + # Check that learning rate is not reduced + assert new_state.lr == 0.01 if __name__ == '__main__': - absltest.main() \ No newline at end of file + absltest.main() + \ No newline at end of file From f02913d2101d12468954c87773975126a454c3bf Mon Sep 17 00:00:00 2001 From: vz415 <19888961+vz415@users.noreply.github.com> Date: Mon, 11 Dec 2023 14:35:29 -0600 Subject: [PATCH 3/6] fix prodigy file type error --- optax/contrib/prodigy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/contrib/prodigy.py b/optax/contrib/prodigy.py index a6c5bab2..392b671d 100644 --- a/optax/contrib/prodigy.py +++ b/optax/contrib/prodigy.py @@ -46,7 +46,7 @@ class ProdigyState(NamedTuple): def prodigy( learning_rate: base.ScalarOrSchedule = 0.1, betas: tuple[float, float] = (0.9, 0.999), - beta3: float | None = None, + beta3: float = None, eps: float = 1e-8, estim_lr0: float = 1e-6, estim_lr_coef: float = 1.0, From 530d4ae5b92eb21da6954f181b70fe29e96cc61c Mon Sep 17 00:00:00 2001 From: vz415 <19888961+vz415@users.noreply.github.com> Date: Tue, 12 Dec 2023 08:52:47 -0600 Subject: [PATCH 4/6] add reduce_on_plateau scheduler to documentation --- docs/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/api.rst b/docs/api.rst index 39b129d0..964acd77 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -619,6 +619,7 @@ Schedules .. autofunction:: piecewise_constant_schedule .. autofunction:: piecewise_interpolate_schedule .. autofunction:: polynomial_schedule +.. autofunction:: optax.contrib.reduce_on_plateau.reduce_on_plateau .. autofunction:: sgdr_schedule .. autofunction:: warmup_cosine_decay_schedule .. autofunction:: warmup_exponential_decay_schedule From 96c2b539dd791b19ddb1350a2fd4b267d1d45afa Mon Sep 17 00:00:00 2001 From: vz415 <19888961+vz415@users.noreply.github.com> Date: Thu, 14 Dec 2023 10:32:28 -0600 Subject: [PATCH 5/6] make suggested documentation changes --- docs/api.rst | 2 +- optax/contrib/reduce_on_plateau.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 964acd77..8552cd33 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -619,7 +619,7 @@ Schedules .. autofunction:: piecewise_constant_schedule .. autofunction:: piecewise_interpolate_schedule .. autofunction:: polynomial_schedule -.. autofunction:: optax.contrib.reduce_on_plateau.reduce_on_plateau +.. autofunction:: optax.contrib.reduce_on_plateau .. autofunction:: sgdr_schedule .. autofunction:: warmup_cosine_decay_schedule .. 
autofunction:: warmup_exponential_decay_schedule diff --git a/optax/contrib/reduce_on_plateau.py b/optax/contrib/reduce_on_plateau.py index 8da5d769..8f20b38c 100644 --- a/optax/contrib/reduce_on_plateau.py +++ b/optax/contrib/reduce_on_plateau.py @@ -42,10 +42,17 @@ class ReduceLROnPlateauState(NamedTuple): def reduce_on_plateau( reduce_factor: float, patience: int, - min_improvement:float, - cooldown:int + min_improvement: float, + cooldown: int ) -> base.GradientTransformationExtraArgs: - """ Args: + """ Reduce learning rate when a metric has stopped improving. + + Models often benefit from reducing the learning once learning stagnates. + his scheduler reads a metrics quantity and if no improvement is seen for + a ‘patience’ number of epochs, the learning rate is reduced + + Args: + reduce_factor: Factor by which the learning rate will be reduced. new_lr = lr * factor. patience: Number of epochs with no improvement after which learning From eb176d9e5265ec404509542dd1b5a84107b74e95 Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa Date: Thu, 14 Dec 2023 14:29:52 -0600 Subject: [PATCH 6/6] Apply suggestions from code review --- optax/contrib/reduce_on_plateau.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/optax/contrib/reduce_on_plateau.py b/optax/contrib/reduce_on_plateau.py index 8f20b38c..218f03f0 100644 --- a/optax/contrib/reduce_on_plateau.py +++ b/optax/contrib/reduce_on_plateau.py @@ -36,7 +36,7 @@ class ReduceLROnPlateauState(NamedTuple): plateau_count: int lr: float cooldown_counter: int - cooldown:int + cooldown: int def reduce_on_plateau( @@ -49,17 +49,17 @@ def reduce_on_plateau( Models often benefit from reducing the learning once learning stagnates. his scheduler reads a metrics quantity and if no improvement is seen for - a ‘patience’ number of epochs, the learning rate is reduced + a ‘patience’ number of epochs, the learning rate is reduced. Args: reduce_factor: Factor by which the learning rate will be reduced. new_lr = lr * factor. - patience: Number of epochs with no improvement after which learning + patience: Number of iterations with no improvement after which learning rate will be reduced. min_improvement: Threshold for measuring the new optimum, to only focus on significant changes. - cooldown: Number of epochs to wait before resuming normal operation + cooldown: Number of iterations to wait before resuming normal operation after lr has been reduced. """ @@ -88,7 +88,6 @@ def update_fn( current_loss = extra_args.get("loss") # Check if the current loss is the best so far - best_loss = state.best_loss # Update plateau count and check if plateaued has_improved = jnp.where(
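
Usage sketch (not part of the patches above; the base optimizer, toy parameters, and loss function below are illustrative assumptions, not code from this series): the reduce_on_plateau transform added here rescales incoming updates by a factor that starts at 1 and shrinks whenever the monitored loss stops improving, so it is intended to be composed with a base optimizer rather than used as a standalone schedule. The snippet mirrors the call pattern used by the tests in this series, passing the monitored loss through extra_args.

import jax
import jax.numpy as jnp
import optax
from optax.contrib.reduce_on_plateau import reduce_on_plateau

def loss_fn(params, x, y):
  # Simple least-squares objective, used only to drive the example.
  return jnp.mean((params['w'] * x - y) ** 2)

params = {'w': jnp.array(5.0)}
x = jnp.linspace(0.0, 1.0, 8)
y = jnp.zeros(8)

base_opt = optax.adam(1e-2)
plateau = reduce_on_plateau(
    reduce_factor=0.5, patience=5, min_improvement=1e-4, cooldown=5)

base_state = base_opt.init(params)
plateau_state = plateau.init(params)

for _ in range(100):
  loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
  # First compute the base optimizer's updates...
  updates, base_state = base_opt.update(grads, base_state, params)
  # ...then rescale them by the plateau-adjusted factor, supplying the
  # monitored loss via extra_args as the tests in this series do.
  updates, plateau_state = plateau.update(
      updates=updates, state=plateau_state, extra_args={'loss': loss})
  params = optax.apply_updates(params, updates)

Because init_fn starts the scaling factor at 1, composing the transform after any optimizer leaves that optimizer's step size untouched until a plateau is actually detected.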