
Commit

Merge pull request #629 from vz415:master
PiperOrigin-RevId: 592225070
OptaxDev committed Dec 19, 2023
2 parents e084f56 + eb176d9 commit e355cd5
Showing 4 changed files with 228 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/api.rst
@@ -619,6 +619,7 @@ Schedules
.. autofunction:: piecewise_constant_schedule
.. autofunction:: piecewise_interpolate_schedule
.. autofunction:: polynomial_schedule
.. autofunction:: optax.contrib.reduce_on_plateau
.. autofunction:: sgdr_schedule
.. autofunction:: warmup_cosine_decay_schedule
.. autofunction:: warmup_exponential_decay_schedule
2 changes: 2 additions & 0 deletions optax/contrib/__init__.py
@@ -27,6 +27,8 @@
from optax.contrib.privacy import dpsgd
from optax.contrib.prodigy import prodigy
from optax.contrib.prodigy import ProdigyState
from optax.contrib.reduce_on_plateau import reduce_on_plateau
from optax.contrib.reduce_on_plateau import ReduceLROnPlateauState
from optax.contrib.sam import normalize
from optax.contrib.sam import NormalizeState
from optax.contrib.sam import sam
128 changes: 128 additions & 0 deletions optax/contrib/reduce_on_plateau.py
@@ -0,0 +1,128 @@
# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reduce Learning Rate on Plateau callback.
This callback monitors a quantity and if no improvement is seen for a 'patience'
number of epochs, the learning rate is reduced by a factor of 'reduce_factor'.
Optionally, a cooldown period can be specified during which the learning rate
will not be reduced.
"""
from typing import NamedTuple, Tuple

import chex
import jax
import jax.numpy as jnp
from optax._src import base
from optax._src import numerics


class ReduceLROnPlateauState(NamedTuple):
"""State for the ReduceLROnPlateau callback."""

lr: chex.Array # shape=(), dtype=jnp.float32
best_loss: chex.Array # shape=(), dtype=jnp.float32
plateau_count: chex.Array # shape=(), dtype=jnp.int32
cooldown_counter: chex.Array # shape=(), dtype=jnp.int32


def reduce_on_plateau(
    factor: float = 0.1,
    patience: int = 10,
    threshold: float = 1e-4,
    cooldown: int = 0,
) -> base.GradientTransformationExtraArgs:
"""Reduce learning rate when a metric has stopped improving.
Models often benefit from reducing the learning once learning stagnates.
his scheduler reads a metrics quantity and if no improvement is seen for
a ``patience`` number of epochs, the learning rate is reduced.
Args:
factor: Factor by which to reduce the learning rate. new_lr = lr * factor.
patience: Number of iterations with no improvement after which learning rate
will be reduced.
threshold: Threshold for measuring the new optimum, to only focus on
significant changes.
cooldown: Number of iterations to wait before resuming normal operation
after lr has been reduced.
Returns:
A GradientTransformationExtraArgs object.
"""

  def init_fn(params) -> ReduceLROnPlateauState:
    del params
    return ReduceLROnPlateauState(
        best_loss=jnp.asarray(float("inf"), dtype=jnp.float32),
        plateau_count=jnp.asarray(0, jnp.int32),
        lr=jnp.asarray(1.0, dtype=jnp.float32),
        cooldown_counter=jnp.asarray(0, jnp.int32),
    )

  def update_fn(
      updates: base.Updates,
      state: ReduceLROnPlateauState,
      params=None,
      *,
      loss,
      **extra_args,
  ) -> Tuple[base.Params, ReduceLROnPlateauState]:
    del params, extra_args

    # Update plateau count and check if plateaued
    has_improved = jnp.where((loss / state.best_loss - 1) < -threshold, 1, 0)
    new_best_loss = jnp.where(has_improved, loss, state.best_loss)

    curr_plateau_count = jnp.where(
        has_improved, 0, numerics.safe_int32_increment(state.plateau_count)
    )

    # We're in cooldown, so reduce the counter and ignore any bad epochs
    def in_cooldown():
      new_plateau_count = 0
      new_lr = state.lr
      new_cooldown_counter = state.cooldown_counter - 1
      return new_plateau_count, new_lr, new_cooldown_counter

    # We're not in cooldown, so update the plateau count and lr as usual
    def not_in_cooldown():
      new_plateau_count = jnp.where(
          curr_plateau_count == patience, 0, curr_plateau_count
      )
      new_lr = jnp.where(
          curr_plateau_count == patience,
          state.lr * factor,
          state.lr,
      )
      new_cooldown_counter = jnp.where(
          curr_plateau_count == patience, cooldown, 0
      )
      return new_plateau_count, new_lr, new_cooldown_counter

    new_plateau_count, new_lr, new_cooldown_counter = jax.lax.cond(
        state.cooldown_counter > 0, in_cooldown, not_in_cooldown
    )

    updates = jax.tree_util.tree_map(lambda g: new_lr * g, updates)

    new_state = ReduceLROnPlateauState(
        plateau_count=new_plateau_count,
        best_loss=new_best_loss,
        lr=new_lr,
        cooldown_counter=new_cooldown_counter,
    )
    return updates, new_state

  return base.GradientTransformationExtraArgs(init_fn, update_fn)
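For context, a minimal usage sketch (not part of this commit): the transformation is meant to be chained after an inner optimizer, with the current loss passed to `update` as an extra keyword argument. This assumes `optax.chain` forwards extra keyword arguments to transformations that accept them; the toy `loss_fn` and the hyperparameter values are illustrative only.

import jax
import jax.numpy as jnp
import optax
from optax import contrib

params = {'w': jnp.ones((3,))}

def loss_fn(p):
  return jnp.sum(p['w'] ** 2)

# The inner optimizer proposes updates; reduce_on_plateau then rescales them
# by its internal multiplier (1.0 until a plateau is detected).
opt = optax.chain(
    optax.adam(1e-3),
    contrib.reduce_on_plateau(factor=0.5, patience=10, cooldown=5),
)
opt_state = opt.init(params)

@jax.jit
def step(params, opt_state):
  loss, grads = jax.value_and_grad(loss_fn)(params)
  # The current loss is supplied as an extra argument so the schedule can
  # detect plateaus.
  updates, opt_state = opt.update(grads, opt_state, params, loss=loss)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss

for _ in range(100):
  params, opt_state, loss = step(params, opt_state)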
97 changes: 97 additions & 0 deletions optax/contrib/reduce_on_plateau_test.py
@@ -0,0 +1,97 @@
# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `reduce_on_plateau.py`."""

from absl.testing import absltest
import chex
import jax.numpy as jnp
from optax import contrib


class ReduceLROnPlateauTest(absltest.TestCase):

  def test_learning_rate_reduced_after_cooldown_period_is_over(self):
    """Test that learning rate is reduced again after cooldown period is over."""

    # Define a dummy update and extra_args
    updates = {'params': jnp.array(1.0)}
    patience = 5
    cooldown = 5
    # Apply the transformation to the updates and state
    transform = contrib.reduce_on_plateau(patience=patience, cooldown=cooldown)
    state = transform.init(updates['params'])
    for _ in range(patience + 1):
      updates, state = transform.update(updates=updates, state=state, loss=1.0)
    # Check that learning rate is reduced
    # we access the fields inside new_state using indices instead of attributes
    # because otherwise pytype throws an error
    lr, best_loss, plateau_count, cooldown_counter = state
    chex.assert_trees_all_close(lr, 0.1)
    chex.assert_trees_all_close(best_loss, 1.0)
    chex.assert_trees_all_close(plateau_count, 0)
    chex.assert_trees_all_close(cooldown_counter, cooldown)
    chex.assert_trees_all_close(updates, {'params': jnp.array(0.1)})

    _, state = transform.update(updates=updates, state=state, loss=1.0)
    lr, best_loss, plateau_count, cooldown_counter = state
    chex.assert_trees_all_close(lr, 0.1)
    chex.assert_trees_all_close(best_loss, 1.0)
    chex.assert_trees_all_close(plateau_count, 0)
    chex.assert_trees_all_close(cooldown_counter, cooldown - 1)

  def test_learning_rate_is_not_reduced(self):
    """Test that plateau count resets after a new best loss is found."""
    state = contrib.ReduceLROnPlateauState(
        best_loss=jnp.array(0.1, dtype=jnp.float32),
        plateau_count=jnp.array(3, dtype=jnp.int32),
        lr=jnp.array(0.01, dtype=jnp.float32),
        cooldown_counter=jnp.array(0, dtype=jnp.int32),
    )
    # Define a dummy update and extra_args
    updates = {'params': 1}
    # Apply the transformation to the updates and state
    transform = contrib.reduce_on_plateau(
        factor=0.5, patience=5, threshold=1e-4, cooldown=5
    )
    _, new_state = transform.update(updates=updates, state=state, loss=0.01)
    lr, best_loss, plateau_count, _ = new_state
    # Check that plateau count resets
    chex.assert_trees_all_close(plateau_count, 0)
    chex.assert_trees_all_close(lr, 0.01)
    chex.assert_trees_all_close(best_loss, 0.01)

  def test_learning_rate_not_reduced_during_cooldown(self):
    """Test that learning rate is not reduced during cooldown."""
    # Define a state where cooldown_counter is positive
    state = contrib.ReduceLROnPlateauState(
        best_loss=jnp.array(0.1, dtype=jnp.float32),
        plateau_count=jnp.array(4, dtype=jnp.int32),
        lr=jnp.array(0.01, dtype=jnp.float32),
        cooldown_counter=jnp.array(3, dtype=jnp.int32),
    )
    # Define a dummy update and extra_args
    updates = {'params': 1}
    # Apply the transformation to the updates and state
    transform = contrib.reduce_on_plateau(
        factor=0.5, patience=5, threshold=1e-4, cooldown=5
    )
    _, new_state = transform.update(updates=updates, state=state, loss=0.15)
    # Check that learning rate is not reduced
    lr, _, _, _ = new_state
    chex.assert_trees_all_close(lr, 0.01)


if __name__ == '__main__':
  absltest.main()
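A quick numeric sketch of the relative threshold exercised above (not part of this commit; values are illustrative): an improvement only registers when (loss / best_loss - 1) < -threshold, so a loss that improves by less than `threshold` in relative terms still increments the plateau count.

import jax.numpy as jnp
from optax import contrib

updates = {'params': jnp.array(1.0)}
transform = contrib.reduce_on_plateau(factor=0.5, patience=2, threshold=1e-2)
state = transform.init(updates['params'])

# The first loss becomes the best loss; the next two improve by less than 1%
# relative, so they count as plateau steps, and after `patience` such steps
# the internal multiplier is scaled by `factor`.
for loss in (1.0, 0.995, 0.994):
  _, state = transform.update(updates=updates, state=state, loss=loss)
print(state.lr)  # ~0.5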
