Commit: Merge pull request #629 from vz415:master

PiperOrigin-RevId: 592225070
Showing 4 changed files with 228 additions and 0 deletions.
@@ -0,0 +1,128 @@
# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reduce Learning Rate on Plateau callback. | ||
This callback monitors a quantity and if no improvement is seen for a 'patience' | ||
number of epochs, the learning rate is reduced by a factor of 'reduce_factor'. | ||
Optionally, a cooldown period can be specified during which the learning rate | ||
will not be reduced. | ||
""" | ||
from typing import NamedTuple, Tuple

import chex
import jax
import jax.numpy as jnp
from optax._src import base
from optax._src import numerics


class ReduceLROnPlateauState(NamedTuple):
  """State for the ReduceLROnPlateau callback."""

  lr: chex.Array  # shape=(), dtype=jnp.float32
  best_loss: chex.Array  # shape=(), dtype=jnp.float32
  plateau_count: chex.Array  # shape=(), dtype=jnp.int32
  cooldown_counter: chex.Array  # shape=(), dtype=jnp.int32


def reduce_on_plateau(
    factor: float = 0.1,
    patience: int = 10,
    threshold: float = 1e-4,
    cooldown: int = 0,
) -> base.GradientTransformationExtraArgs:
  """Reduce learning rate when a metric has stopped improving.

  Models often benefit from reducing the learning rate once learning
  stagnates. This scheduler reads a metric quantity and, if no improvement is
  seen for a ``patience`` number of epochs, reduces the learning rate.

  Args:
    factor: Factor by which to reduce the learning rate. new_lr = lr * factor.
    patience: Number of iterations with no improvement after which the
      learning rate will be reduced.
    threshold: Threshold for measuring the new optimum, to only focus on
      significant changes.
    cooldown: Number of iterations to wait before resuming normal operation
      after the learning rate has been reduced.

  Returns:
    A GradientTransformationExtraArgs object.
  """

  def init_fn(params) -> ReduceLROnPlateauState:
    del params
    return ReduceLROnPlateauState(
        best_loss=jnp.asarray(float("inf"), dtype=jnp.float32),
        plateau_count=jnp.asarray(0, jnp.int32),
        lr=jnp.asarray(1.0, dtype=jnp.float32),
        cooldown_counter=jnp.asarray(0, jnp.int32),
    )

  def update_fn(
      updates: base.Updates,
      state: ReduceLROnPlateauState,
      params=None,
      *,
      loss,
      **extra_args,
  ) -> Tuple[base.Params, ReduceLROnPlateauState]:
    del params, extra_args

    # Update the plateau count and check whether training has plateaued. A new
    # loss only counts as an improvement if it beats `best_loss` by more than
    # `threshold` in relative terms.
    has_improved = jnp.where((loss / state.best_loss - 1) < -threshold, 1, 0)
    new_best_loss = jnp.where(has_improved, loss, state.best_loss)

    curr_plateau_count = jnp.where(
        has_improved, 0, numerics.safe_int32_increment(state.plateau_count)
    )

    # We're in cooldown, so reduce the counter and ignore any bad epochs.
    def in_cooldown():
      new_plateau_count = 0
      new_lr = state.lr
      new_cooldown_counter = state.cooldown_counter - 1
      return new_plateau_count, new_lr, new_cooldown_counter

    # We're not in cooldown, so update the plateau count and lr as usual.
    def not_in_cooldown():
      new_plateau_count = jnp.where(
          curr_plateau_count == patience, 0, curr_plateau_count
      )
      new_lr = jnp.where(
          curr_plateau_count == patience,
          state.lr * factor,
          state.lr,
      )
      new_cooldown_counter = jnp.where(
          curr_plateau_count == patience, cooldown, 0
      )
      return new_plateau_count, new_lr, new_cooldown_counter

    new_plateau_count, new_lr, new_cooldown_counter = jax.lax.cond(
        state.cooldown_counter > 0, in_cooldown, not_in_cooldown
    )

    # Scale the incoming gradient updates by the (possibly reduced) lr.
    updates = jax.tree_util.tree_map(lambda g: new_lr * g, updates)

    new_state = ReduceLROnPlateauState(
        plateau_count=new_plateau_count,
        best_loss=new_best_loss,
        lr=new_lr,
        cooldown_counter=new_cooldown_counter,
    )
    return updates, new_state

  return base.GradientTransformationExtraArgs(init_fn, update_fn)
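
For readers skimming the diff, here is a minimal standalone usage sketch of the
new transformation, mirroring the tests below. The `grads` pytree and the
constant loss values are illustrative placeholders, not part of this commit:

    import jax.numpy as jnp
    from optax import contrib

    # Cut the scale by 10x if the loss fails to improve for 2 consecutive steps.
    transform = contrib.reduce_on_plateau(factor=0.1, patience=2)
    grads = {'w': jnp.array(0.5)}
    state = transform.init(grads)

    for loss in (1.0, 1.0, 1.0):  # a plateau: the loss never improves
      # `loss` is passed as a keyword extra arg; the returned updates are the
      # incoming gradients scaled by the current learning-rate multiplier.
      scaled, state = transform.update(updates=grads, state=state, loss=loss)

Note that the transformation starts from an internal multiplier of 1.0 and only
rescales updates, so it is meant to be composed with a base optimizer that sets
the actual step size.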
@@ -0,0 +1,97 @@
# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `reduce_on_plateau.py`."""

from absl.testing import absltest
import chex
import jax.numpy as jnp
from optax import contrib


class ReduceLROnPlateauTest(absltest.TestCase):

  def test_learning_rate_reduced_after_cooldown_period_is_over(self):
    """Test that the lr is reduced after `patience` bad steps, then cooldown."""

    # Define a dummy update and extra_args.
    updates = {'params': jnp.array(1.0)}
    patience = 5
    cooldown = 5
    # Apply the transformation to the updates and state.
    transform = contrib.reduce_on_plateau(patience=patience, cooldown=cooldown)
    state = transform.init(updates['params'])
    for _ in range(patience + 1):
      updates, state = transform.update(updates=updates, state=state, loss=1.0)
    # Check that the learning rate is reduced. We unpack the state
    # positionally rather than via attributes because pytype otherwise
    # throws an error.
    lr, best_loss, plateau_count, cooldown_counter = state
    chex.assert_trees_all_close(lr, 0.1)
    chex.assert_trees_all_close(best_loss, 1.0)
    chex.assert_trees_all_close(plateau_count, 0)
    chex.assert_trees_all_close(cooldown_counter, cooldown)
    chex.assert_trees_all_close(updates, {'params': jnp.array(0.1)})

    # One more bad step during cooldown: the lr is unchanged and the
    # cooldown counter is decremented.
    _, state = transform.update(updates=updates, state=state, loss=1.0)
    lr, best_loss, plateau_count, cooldown_counter = state
    chex.assert_trees_all_close(lr, 0.1)
    chex.assert_trees_all_close(best_loss, 1.0)
    chex.assert_trees_all_close(plateau_count, 0)
    chex.assert_trees_all_close(cooldown_counter, cooldown - 1)

  def test_learning_rate_is_not_reduced(self):
    """Test that plateau count resets after a new best loss is found."""
    state = contrib.ReduceLROnPlateauState(
        best_loss=jnp.array(0.1, dtype=jnp.float32),
        plateau_count=jnp.array(3, dtype=jnp.int32),
        lr=jnp.array(0.01, dtype=jnp.float32),
        cooldown_counter=jnp.array(0, dtype=jnp.int32),
    )
    # Define a dummy update and extra_args.
    updates = {'params': 1}
    # Apply the transformation to the updates and state.
    transform = contrib.reduce_on_plateau(
        factor=0.5, patience=5, threshold=1e-4, cooldown=5
    )
    _, new_state = transform.update(updates=updates, state=state, loss=0.01)
    lr, best_loss, plateau_count, _ = new_state
    # Check that the plateau count resets.
    chex.assert_trees_all_close(plateau_count, 0)
    chex.assert_trees_all_close(lr, 0.01)
    chex.assert_trees_all_close(best_loss, 0.01)

  def test_learning_rate_not_reduced_during_cooldown(self):
    """Test that learning rate is not reduced during cooldown."""
    # Define a state where cooldown_counter is positive.
    state = contrib.ReduceLROnPlateauState(
        best_loss=jnp.array(0.1, dtype=jnp.float32),
        plateau_count=jnp.array(4, dtype=jnp.int32),
        lr=jnp.array(0.01, dtype=jnp.float32),
        cooldown_counter=jnp.array(3, dtype=jnp.int32),
    )
    # Define a dummy update and extra_args.
    updates = {'params': 1}
    # Apply the transformation to the updates and state.
    transform = contrib.reduce_on_plateau(
        factor=0.5, patience=5, threshold=1e-4, cooldown=5
    )
    _, new_state = transform.update(updates=updates, state=state, loss=0.15)
    # Check that the learning rate is not reduced.
    lr, _, _, _ = new_state
    chex.assert_trees_all_close(lr, 0.01)


if __name__ == '__main__':
  absltest.main()
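
Finally, a sketch of how this transformation would typically be combined with a
base optimizer in a training loop. This is an assumption about intended usage
rather than part of the commit: it relies on `optax.chain` forwarding keyword
extra args such as `loss` to the transformations that accept them, which holds
only in optax versions with extra-args support.

    import jax
    import jax.numpy as jnp
    import optax
    from optax import contrib

    def loss_fn(w):  # a toy quadratic objective
      return jnp.sum(w ** 2)

    # Adam proposes updates; reduce_on_plateau rescales them when learning
    # stagnates.
    opt = optax.chain(
        optax.adam(learning_rate=1e-3),
        contrib.reduce_on_plateau(factor=0.5, patience=10),
    )
    params = jnp.array(1.0)
    opt_state = opt.init(params)

    for _ in range(100):
      loss, grads = jax.value_and_grad(loss_fn)(params)
      updates, opt_state = opt.update(grads, opt_state, params, loss=loss)
      params = optax.apply_updates(params, updates)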