Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add atol option to contrib.reduce_on_plateau() #698

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 57 additions & 62 deletions examples/contrib/reduce_on_plateau.ipynb

Large diffs are not rendered by default.

24 changes: 20 additions & 4 deletions optax/contrib/reduce_on_plateau.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class ReduceLROnPlateauState(NamedTuple):
def reduce_on_plateau(
factor: float = 0.1,
patience: int = 10,
threshold: float = 1e-4,
rtol: float = 1e-4,
atol: float = 0.0,
cooldown: int = 0,
) -> base.GradientTransformationExtraArgs:
"""Reduce learning rate when a metric has stopped improving.
Expand All @@ -53,14 +54,29 @@ def reduce_on_plateau(
factor: Factor by which to reduce the learning rate. new_lr = lr * factor.
patience: Number of iterations with no improvement after which learning rate
will be reduced.
threshold: Threshold for measuring the new optimum, to only focus on
significant changes.
rtol: Relative tolerance for measuring new optimum.
atol: Absolute tolerance for measuring new optimum.
cooldown: Number of iterations to wait before resuming normal operation
after lr has been reduced.

Returns:
A GradientTransformationExtraArgs object.
"""
if rtol < 0.0 or atol < 0.0:
raise ValueError(
"Both rtol and atol must be non-negative, got "
f"rtol = {rtol} and atol = {atol}."
)
elif rtol == 0.0 and atol == 0.0:
raise ValueError(
"At least one of rtol or atol must be positive, got "
f"rtol = {rtol} and atol = {atol}."
)
elif rtol > 1.0:
raise ValueError(
f"rtol must be less than or equal to 1.0, got rtol = {rtol}."
)


def init_fn(params) -> ReduceLROnPlateauState:
del params
Expand All @@ -82,7 +98,7 @@ def update_fn(
del params, extra_args

# Update plateau count and check if plateaued
has_improved = jnp.where((loss / state.best_loss - 1) < -threshold, 1, 0)
has_improved = jnp.where(loss < (1 - rtol) * state.best_loss - atol, 1, 0)
new_best_loss = jnp.where(has_improved, loss, state.best_loss)

curr_plateau_count = jnp.where(
Expand Down
3 changes: 2 additions & 1 deletion optax/contrib/reduce_on_plateau_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def setUp(self):
self.transform = contrib.reduce_on_plateau(
factor=0.1,
patience=self.patience,
threshold=1e-4,
rtol=1e-4,
atol=0.0,
cooldown=self.cooldown,
)
self.updates = {'params': jnp.array(1.0)} # dummy updates
Expand Down
Loading