forked from keras-team/keras
Add ReduceLROnPlateau callback (keras-team#117)
* Implement the LearningRateScheduler callback
* Add missing files
* Cleanup
* Formatting
* Remove test util for building model
* Initial implementation of ReduceLROnPlateau
* Remove unused test variables
* Tests for ReduceLROnPlateau
* Improve docstrings
* Review comments
1 parent 0322038 · commit 19e1eab
Showing 4 changed files with 272 additions and 4 deletions.
@@ -0,0 +1,143 @@
import warnings

import numpy as np

from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
from keras_core.utils import io_utils


@keras_core_export("keras_core.callbacks.ReduceLROnPlateau")
class ReduceLROnPlateau(Callback):
    """Reduce learning rate when a metric has stopped improving.

    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This callback monitors a
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Example:

    ```python
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                                  patience=5, min_lr=0.001)
    model.fit(x_train, y_train, callbacks=[reduce_lr])
    ```

    Args:
        monitor: String. Quantity to be monitored.
        factor: Float. Factor by which the learning rate will be reduced.
            `new_lr = lr * factor`.
        patience: Integer. Number of epochs with no improvement after which
            learning rate will be reduced.
        verbose: Integer. 0: quiet, 1: update messages.
        mode: String. One of `{'auto', 'min', 'max'}`. In `'min'` mode,
            the learning rate will be reduced when the
            quantity monitored has stopped decreasing; in `'max'` mode it
            will be reduced when the quantity monitored has stopped
            increasing; in `'auto'` mode, the direction is automatically
            inferred from the name of the monitored quantity.
        min_delta: Float. Threshold for measuring the new optimum, to only
            focus on significant changes.
        cooldown: Integer. Number of epochs to wait before resuming normal
            operation after the learning rate has been reduced.
        min_lr: Float. Lower bound on the learning rate.
    """

    def __init__(
        self,
        monitor="val_loss",
        factor=0.1,
        patience=10,
        verbose=0,
        mode="auto",
        min_delta=1e-4,
        cooldown=0,
        min_lr=0,
        **kwargs,
    ):
        super().__init__()

        self.monitor = monitor
        if factor >= 1.0:
            raise ValueError(
                "ReduceLROnPlateau does not support a factor >= 1.0. "
                f"Received factor={factor}"
            )

        self.factor = factor
        self.min_lr = min_lr
        self.min_delta = min_delta
        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0  # Cooldown counter.
        self.wait = 0
        self.best = 0
        self.mode = mode
        self.monitor_op = None
        self._reset()

    def _reset(self):
        """Resets wait counter and cooldown counter."""
        if self.mode not in {"auto", "min", "max"}:
            warnings.warn(
                f"Learning rate reduction mode {self.mode} is unknown, "
                "fallback to auto mode.",
                stacklevel=2,
            )
            self.mode = "auto"
        if self.mode == "min" or (
            self.mode == "auto" and "acc" not in self.monitor
        ):
            self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
            self.best = np.inf
        else:
            self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
            self.best = -np.inf
        self.cooldown_counter = 0
        self.wait = 0

    def on_train_begin(self, logs=None):
        self._reset()

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs["lr"] = float(np.array(self.model.optimizer.learning_rate))
        current = logs.get(self.monitor)

        if current is None:
            warnings.warn(
                "Learning rate reduction is conditioned on metric "
                f"`{self.monitor}` which is not available. Available metrics "
                f"are: {','.join(list(logs.keys()))}.",
                stacklevel=2,
            )
        else:
            if self.in_cooldown():
                self.cooldown_counter -= 1
                self.wait = 0

            if self.monitor_op(current, self.best):
                self.best = current
                self.wait = 0
            elif not self.in_cooldown():
                self.wait += 1
                if self.wait >= self.patience:
                    old_lr = float(
                        np.array(self.model.optimizer.learning_rate)
                    )
                    if old_lr > np.float32(self.min_lr):
                        new_lr = old_lr * self.factor
                        new_lr = max(new_lr, self.min_lr)
                        self.model.optimizer.learning_rate = new_lr
                        if self.verbose > 0:
                            io_utils.print_msg(
                                f"\nEpoch {epoch + 1}: "
                                "ReduceLROnPlateau reducing "
                                f"learning rate to {new_lr}."
                            )
                        self.cooldown_counter = self.cooldown
                        self.wait = 0

    def in_cooldown(self):
        return self.cooldown_counter > 0
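To make the patience/cooldown interaction in `on_epoch_end()` easier to trace, here is a minimal, dependency-free sketch (the `simulate_plateau` helper is hypothetical, not part of this commit, and `min_lr` clamping is omitted for brevity) that replays the same state machine on a synthetic loss curve:

```python
# Hypothetical illustration, not part of this commit: replays the
# patience/cooldown state machine from on_epoch_end() on plain floats.
def simulate_plateau(
    losses, lr=0.1, factor=0.1, patience=1, cooldown=0, min_delta=1e-4
):
    best, wait, cooldown_counter = float("inf"), 0, 0
    for epoch, loss in enumerate(losses):
        if cooldown_counter > 0:  # still cooling down after a reduction
            cooldown_counter -= 1
            wait = 0
        if loss < best - min_delta:  # significant improvement resets wait
            best, wait = loss, 0
        elif cooldown_counter == 0:  # plateaued while not in cooldown
            wait += 1
            if wait >= patience:
                lr *= factor  # cut the rate and start the cooldown
                cooldown_counter, wait = cooldown, 0
        print(f"epoch {epoch}: lr={lr:.4g}")
    return lr


# A flat loss never improves, so with patience=1 and cooldown=2 the
# rate drops on epochs 1 and 3: 0.1 -> 0.01 -> 0.001.
simulate_plateau([0.5, 0.5, 0.5, 0.5], cooldown=2)
```

This mirrors the order of operations above: the cooldown counter is decremented before the plateau check, so the epoch that leaves cooldown can immediately count toward patience again.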
@@ -0,0 +1,127 @@
from keras_core import callbacks
from keras_core import layers
from keras_core import optimizers
from keras_core import testing
from keras_core.models import Sequential
from keras_core.testing import test_utils
from keras_core.utils import io_utils
from keras_core.utils import numerical_utils


class ReduceLROnPlateauTest(testing.TestCase):
    def setUp(self):
        (x_train, y_train), (x_test, y_test) = test_utils.get_test_data(
            train_samples=10,
            test_samples=10,
            input_shape=(3,),
            num_classes=2,
        )
        y_test = numerical_utils.to_categorical(y_test)
        y_train = numerical_utils.to_categorical(y_train)

        model = Sequential([layers.Dense(5), layers.Dense(2)])

        model.compile(
            loss="mse",
            optimizer=optimizers.Adam(0.1),
        )

        self.model = model
        self.x_train = x_train
        self.x_test = x_test
        self.y_train = y_train
        self.y_test = y_test

    def test_reduces_lr_with_model_fit(self):
        reduce_lr = callbacks.ReduceLROnPlateau(
            patience=1, factor=0.1, monitor="val_loss", min_delta=10
        )

        self.model.fit(
            self.x_train,
            self.y_train,
            validation_data=(self.x_test, self.y_test),
            callbacks=[reduce_lr],
            epochs=2,
        )

        self.assertEqual(self.model.optimizer.learning_rate.value, 0.01)

    def test_throws_when_optimizer_has_schedule(self):
        reduce_lr = callbacks.ReduceLROnPlateau(
            patience=1, factor=0.1, monitor="val_loss", min_delta=10
        )

        self.model.compile(
            loss="mse",
            optimizer=optimizers.Adam(
                optimizers.schedules.PolynomialDecay(
                    initial_learning_rate=0.1, decay_steps=10
                )
            ),
        )

        with self.assertRaisesRegex(
            TypeError,
            "This optimizer was created with a `LearningRateSchedule`",
        ):
            self.model.fit(
                self.x_train,
                self.y_train,
                validation_data=(self.x_test, self.y_test),
                callbacks=[reduce_lr],
                epochs=2,
            )

    def test_verbose_logging(self):
        reduce_lr = callbacks.ReduceLROnPlateau(
            patience=1, factor=0.1, monitor="val_loss", min_delta=10, verbose=1
        )
        io_utils.disable_interactive_logging()

        with self.assertLogs(level="INFO") as logs:
            self.model.fit(
                self.x_train,
                self.y_train,
                validation_data=(self.x_test, self.y_test),
                callbacks=[reduce_lr],
                epochs=2,
            )
            expected_log = "ReduceLROnPlateau reducing learning rate to 0.01"
            self.assertTrue(any(expected_log in log for log in logs.output))

    def test_honors_min_lr(self):
        reduce_lr = callbacks.ReduceLROnPlateau(
            patience=1,
            factor=0.1,
            monitor="val_loss",
            min_delta=10,
            min_lr=0.005,
        )

        self.model.fit(
            self.x_train,
            self.y_train,
            validation_data=(self.x_test, self.y_test),
            callbacks=[reduce_lr],
            epochs=4,
        )

        self.assertEqual(self.model.optimizer.learning_rate.value, 0.005)

    def test_cooldown(self):
        reduce_lr = callbacks.ReduceLROnPlateau(
            patience=1, factor=0.1, monitor="val_loss", min_delta=10, cooldown=2
        )

        self.model.fit(
            self.x_train,
            self.y_train,
            validation_data=(self.x_test, self.y_test),
            callbacks=[reduce_lr],
            epochs=4,
        )

        # With a cooldown of 2 epochs, we should only reduce the LR every
        # other epoch, so after 4 epochs we will have reduced 2 times.
        self.assertAllClose(self.model.optimizer.learning_rate.value, 0.001)
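As a quick sanity check on the constants asserted in these tests (illustrative arithmetic only, not part of the commit), with `Adam(0.1)` and `factor=0.1`:

```python
# Illustrative arithmetic only: the learning-rate trajectories that the
# assertions above expect, given an initial rate of 0.1 and factor=0.1.
initial_lr, factor, min_lr = 0.1, 0.1, 0.005

one_cut = initial_lr * factor      # test_reduces_lr_with_model_fit
two_cuts = initial_lr * factor**2  # test_cooldown
clipped = max(two_cuts, min_lr)    # test_honors_min_lr clamps at min_lr
print(f"{one_cut:.3g} {two_cuts:.3g} {clipped:.3g}")  # 0.01 0.001 0.005
```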