Skip to content

Commit

Permalink
Add ReduceLROnPlateau callback (keras-team#117)
Browse files Browse the repository at this point in the history
* Implement the LearningRateScheduler callback

* Add missing files

* Cleanup

* Formatting

* Remove test util for building model

* Initial implementation of ReduceLROnPlateau

* Remove unused test variables

* Tests for ReduceLROnPlateau

* Improve docstrings

* Review comments
  • Loading branch information
ianstenbit authored and fchollet committed May 9, 2023
1 parent 0322038 commit 19e1eab
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 4 deletions.
1 change: 1 addition & 0 deletions keras_core/callbacks/__init__.py
Expand Up @@ -6,5 +6,6 @@
from keras_core.callbacks.lambda_callback import LambdaCallback
from keras_core.callbacks.learning_rate_scheduler import LearningRateScheduler
from keras_core.callbacks.progbar_logger import ProgbarLogger
from keras_core.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
from keras_core.callbacks.remote_monitor import RemoteMonitor
from keras_core.callbacks.terminate_on_nan import TerminateOnNaN
5 changes: 1 addition & 4 deletions keras_core/callbacks/learning_rate_scheduler.py
@@ -1,6 +1,5 @@
import numpy as np

from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
from keras_core.utils import io_utils
Expand Down Expand Up @@ -54,9 +53,7 @@ def on_epoch_begin(self, epoch, logs=None):
raise ValueError('Optimizer must have a "learning_rate" attribute.')

try: # new API
learning_rate = backend.Variable(
self.model.optimizer.learning_rate
).numpy()
learning_rate = float(np.array(self.model.optimizer.learning_rate))
learning_rate = self.schedule(epoch, learning_rate)
except TypeError: # Support for old API for backward compatibility
learning_rate = self.schedule(epoch)
Expand Down
143 changes: 143 additions & 0 deletions keras_core/callbacks/reduce_lr_on_plateau.py
@@ -0,0 +1,143 @@
import warnings

import numpy as np

from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
from keras_core.utils import io_utils


@keras_core_export("keras_core.callbacks.ReduceLROnPlateau")
class ReduceLROnPlateau(Callback):
"""Reduce learning rate when a metric has stopped improving.
Models often benefit from reducing the learning rate by a factor
of 2-10 once learning stagnates. This callback monitors a
quantity and if no improvement is seen for a 'patience' number
of epochs, the learning rate is reduced.
Example:
```python
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
patience=5, min_lr=0.001)
model.fit(x_train, y_train, callbacks=[reduce_lr])
```
Args:
monitor: String. Quantity to be monitored.
factor: Float. Factor by which the learning rate will be reduced.
`new_lr = lr * factor`.
patience: Integer. Number of epochs with no improvement after which
learning rate will be reduced.
verbose: Integer. 0: quiet, 1: update messages.
mode: String. One of `{'auto', 'min', 'max'}`. In `'min'` mode,
the learning rate will be reduced when the
quantity monitored has stopped decreasing; in `'max'` mode it will
be reduced when the quantity monitored has stopped increasing; in
`'auto'` mode, the direction is automatically inferred from the name
of the monitored quantity.
min_delta: Float. Threshold for measuring the new optimum, to only focus
on significant changes.
cooldown: Integer. Number of epochs to wait before resuming normal
operation after the learning rate has been reduced.
min_lr: Float. Lower bound on the learning rate.
"""

def __init__(
self,
monitor="val_loss",
factor=0.1,
patience=10,
verbose=0,
mode="auto",
min_delta=1e-4,
cooldown=0,
min_lr=0,
**kwargs,
):
super().__init__()

self.monitor = monitor
if factor >= 1.0:
raise ValueError(
"ReduceLROnPlateau does not support a factor >= 1.0. "
f"Received factor={factor}"
)

self.factor = factor
self.min_lr = min_lr
self.min_delta = min_delta
self.patience = patience
self.verbose = verbose
self.cooldown = cooldown
self.cooldown_counter = 0 # Cooldown counter.
self.wait = 0
self.best = 0
self.mode = mode
self.monitor_op = None
self._reset()

def _reset(self):
"""Resets wait counter and cooldown counter."""
if self.mode not in {"auto", "min", "max"}:
warnings.warn(
f"Learning rate reduction mode {self.mode} is unknown, "
"fallback to auto mode.",
stacklevel=2,
)
self.mode = "auto"
if self.mode == "min" or (
self.mode == "auto" and "acc" not in self.monitor
):
self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
self.best = np.Inf
else:
self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
self.best = -np.Inf
self.cooldown_counter = 0
self.wait = 0

def on_train_begin(self, logs=None):
self._reset()

def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
logs["lr"] = float(np.array(self.model.optimizer.learning_rate))
current = logs.get(self.monitor)

if current is None:
print("tacos")
warnings.warn(
"Learning rate reduction is conditioned on metric "
f"`{self.monitor}` which is not available. Available metrics "
f"are: {','.join(list(logs.keys()))}.",
stacklevel=2,
)
else:
if self.in_cooldown():
self.cooldown_counter -= 1
self.wait = 0

if self.monitor_op(current, self.best):
self.best = current
self.wait = 0
elif not self.in_cooldown():
self.wait += 1
if self.wait >= self.patience:
old_lr = float(np.array(self.model.optimizer.learning_rate))
if old_lr > np.float32(self.min_lr):
new_lr = old_lr * self.factor
new_lr = max(new_lr, self.min_lr)
self.model.optimizer.learning_rate = new_lr
if self.verbose > 0:
io_utils.print_msg(
f"\nEpoch {epoch +1}: "
"ReduceLROnPlateau reducing "
f"learning rate to {new_lr}."
)
self.cooldown_counter = self.cooldown
self.wait = 0

def in_cooldown(self):
return self.cooldown_counter > 0
127 changes: 127 additions & 0 deletions keras_core/callbacks/reduce_lr_on_plateau_test.py
@@ -0,0 +1,127 @@
from keras_core import callbacks
from keras_core import layers
from keras_core import optimizers
from keras_core import testing
from keras_core.models import Sequential
from keras_core.testing import test_utils
from keras_core.utils import io_utils
from keras_core.utils import numerical_utils


class ReduceLROnPlateauTest(testing.TestCase):
def setUp(self):
(x_train, y_train), (x_test, y_test) = test_utils.get_test_data(
train_samples=10,
test_samples=10,
input_shape=(3,),
num_classes=2,
)
y_test = numerical_utils.to_categorical(y_test)
y_train = numerical_utils.to_categorical(y_train)

model = Sequential([layers.Dense(5), layers.Dense(2)])

model.compile(
loss="mse",
optimizer=optimizers.Adam(0.1),
)

self.model = model
self.x_train = x_train
self.x_test = x_test
self.y_train = y_train
self.y_test = y_test

def test_reduces_lr_with_model_fit(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1, factor=0.1, monitor="val_loss", min_delta=10
)

self.model.fit(
self.x_train,
self.y_train,
validation_data=(self.x_test, self.y_test),
callbacks=[reduce_lr],
epochs=2,
)

self.assertEqual(self.model.optimizer.learning_rate.value, 0.01)

def test_throws_when_optimizer_has_schedule(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1, factor=0.1, monitor="val_loss", min_delta=10
)

self.model.compile(
loss="mse",
optimizer=optimizers.Adam(
optimizers.schedules.PolynomialDecay(
initial_learning_rate=0.1, decay_steps=10
)
),
)

with self.assertRaisesRegex(
TypeError,
"This optimizer was created with a `LearningRateSchedule`",
):
self.model.fit(
self.x_train,
self.y_train,
validation_data=(self.x_test, self.y_test),
callbacks=[reduce_lr],
epochs=2,
)

def test_verbose_logging(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1, factor=0.1, monitor="val_loss", min_delta=10, verbose=1
)
io_utils.disable_interactive_logging()

with self.assertLogs(level="INFO") as logs:
self.model.fit(
self.x_train,
self.y_train,
validation_data=(self.x_test, self.y_test),
callbacks=[reduce_lr],
epochs=2,
)
expected_log = "ReduceLROnPlateau reducing learning rate to 0.01"
self.assertTrue(any(expected_log in log for log in logs.output))

def test_honors_min_lr(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1,
factor=0.1,
monitor="val_loss",
min_delta=10,
min_lr=0.005,
)

self.model.fit(
self.x_train,
self.y_train,
validation_data=(self.x_test, self.y_test),
callbacks=[reduce_lr],
epochs=4,
)

self.assertEqual(self.model.optimizer.learning_rate.value, 0.005)

def test_cooldown(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1, factor=0.1, monitor="val_loss", min_delta=10, cooldown=2
)

self.model.fit(
self.x_train,
self.y_train,
validation_data=(self.x_test, self.y_test),
callbacks=[reduce_lr],
epochs=4,
)

# With a cooldown of 2 epochs, we should only reduce the LR every other
# epoch, so after 4 epochs we will have reduced 2 times.
self.assertAllClose(self.model.optimizer.learning_rate.value, 0.001)

0 comments on commit 19e1eab

Please sign in to comment.