"""Callbacks: utilities called at certain points during model training.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import csv
import six
import numpy as np
import time
import json
import warnings
import io
from collections import deque
from collections import OrderedDict
try:
    from collections.abc import Iterable  # Python 3.3+
except ImportError:
    from collections import Iterable  # Python 2 fallback
from collections import defaultdict
from ..utils.generic_utils import Progbar
from .. import backend as K
from ..engine.training_utils import standardize_input_data
try:
import requests
except ImportError:
requests = None
_TRAIN = 'train'
_TEST = 'test'
_PREDICT = 'predict'
class CallbackList(object):
"""Container abstracting a list of callbacks.
# Arguments
callbacks: List of `Callback` instances.
queue_length: Queue length for keeping
running statistics over callback execution time.
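# Example
A minimal sketch of driving the container manually; in normal use,
`fit` constructs and drives a `CallbackList` itself:
```python
callback_list = CallbackList([History()])
callback_list.set_model(model)
callback_list.on_train_begin()
```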
"""
def __init__(self, callbacks=None, queue_length=10):
callbacks = callbacks or []
self.callbacks = [c for c in callbacks]
self.queue_length = queue_length
self.params = {}
self.model = None
self._reset_batch_timing()
def _reset_batch_timing(self):
self._delta_t_batch = 0.
self._delta_ts = defaultdict(lambda: deque([], maxlen=self.queue_length))
def append(self, callback):
self.callbacks.append(callback)
def set_params(self, params):
self.params = params
for callback in self.callbacks:
callback.set_params(params)
def set_model(self, model):
self.model = model
for callback in self.callbacks:
callback.set_model(model)
def _call_batch_hook(self, mode, hook, batch, logs=None):
"""Helper function for all batch_{begin | end} methods."""
if not self.callbacks:
return
hook_name = 'on_{mode}_batch_{hook}'.format(mode=mode, hook=hook)
if hook == 'end':
if not hasattr(self, '_t_enter_batch'):
self._t_enter_batch = time.time()
# Batch is ending, calculate batch time
self._delta_t_batch = time.time() - self._t_enter_batch
logs = logs or {}
t_before_callbacks = time.time()
for callback in self.callbacks:
batch_hook = getattr(callback, hook_name)
batch_hook(batch, logs)
self._delta_ts[hook_name].append(time.time() - t_before_callbacks)
delta_t_median = np.median(self._delta_ts[hook_name])
if (self._delta_t_batch > 0. and
delta_t_median > 0.95 * self._delta_t_batch and
delta_t_median > 0.1):
warnings.warn(
'Method (%s) is slow compared '
'to the batch update (%f). Check your callbacks.'
% (hook_name, delta_t_median), RuntimeWarning)
if hook == 'begin':
self._t_enter_batch = time.time()
def _call_begin_hook(self, mode):
"""Helper function for on_{train|test|predict}_begin methods."""
if mode == _TRAIN:
self.on_train_begin()
elif mode == _TEST:
self.on_test_begin()
else:
self.on_predict_begin()
def _call_end_hook(self, mode):
"""Helper function for on_{train|test|predict}_end methods."""
if mode == _TRAIN:
self.on_train_end()
elif mode == _TEST:
self.on_test_end()
else:
self.on_predict_end()
def on_batch_begin(self, batch, logs=None):
self._call_batch_hook(_TRAIN, 'begin', batch, logs=logs)
def on_batch_end(self, batch, logs=None):
self._call_batch_hook(_TRAIN, 'end', batch, logs=logs)
def on_epoch_begin(self, epoch, logs=None):
"""Calls the `on_epoch_begin` methods of its callbacks.
This function should only be called during train mode.
# Arguments
epoch: integer, index of epoch.
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_epoch_begin(epoch, logs)
self._reset_batch_timing()
def on_epoch_end(self, epoch, logs=None):
"""Calls the `on_epoch_end` methods of its callbacks.
This function should only be called during train mode.
# Arguments
epoch: integer, index of epoch.
logs: dict, metric results for this training epoch, and for the
validation epoch if validation is performed. Validation result keys
are prefixed with `val_`.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_epoch_end(epoch, logs)
def on_train_batch_begin(self, batch, logs=None):
"""Calls the `on_train_batch_begin` methods of its callbacks.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, has keys `batch` and `size` representing the current
batch number and the size of the batch.
"""
self._call_batch_hook(_TRAIN, 'begin', batch, logs=logs)
def on_train_batch_end(self, batch, logs=None):
"""Calls the `on_train_batch_end` methods of its callbacks.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, metric results for this batch.
"""
self._call_batch_hook(_TRAIN, 'end', batch, logs=logs)
def on_test_batch_begin(self, batch, logs=None):
"""Calls the `on_test_batch_begin` methods of its callbacks.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, has keys `batch` and `size` representing the current
batch number and the size of the batch.
"""
self._call_batch_hook(_TEST, 'begin', batch, logs=logs)
def on_test_batch_end(self, batch, logs=None):
"""Calls the `on_test_batch_end` methods of its callbacks.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, metric results for this batch.
"""
self._call_batch_hook(_TEST, 'end', batch, logs=logs)
def on_predict_batch_begin(self, batch, logs=None):
"""Calls the `on_predict_batch_begin` methods of its callbacks.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, has keys `batch` and `size` representing the current
batch number and the size of the batch.
"""
self._call_batch_hook(_PREDICT, 'begin', batch, logs=logs)
def on_predict_batch_end(self, batch, logs=None):
"""Calls the `on_predict_batch_end` methods of its callbacks.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, metric results for this batch.
"""
self._call_batch_hook(_PREDICT, 'end', batch, logs=logs)
def on_train_begin(self, logs=None):
"""Calls the `on_train_begin` methods of its callbacks.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
for callback in self.callbacks:
callback.on_train_begin(logs)
def on_train_end(self, logs=None):
"""Calls the `on_train_end` methods of its callbacks.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
for callback in self.callbacks:
callback.on_train_end(logs)
def on_test_begin(self, logs=None):
"""Calls the `on_test_begin` methods of its callbacks.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
for callback in self.callbacks:
callback.on_test_begin(logs)
def on_test_end(self, logs=None):
"""Calls the `on_test_end` methods of its callbacks.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
for callback in self.callbacks:
callback.on_test_end(logs)
def on_predict_begin(self, logs=None):
"""Calls the `on_predict_begin` methods of its callbacks.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
for callback in self.callbacks:
callback.on_predict_begin(logs)
def on_predict_end(self, logs=None):
"""Calls the `on_predict_end` methods of its callbacks.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
for callback in self.callbacks:
callback.on_predict_end(logs)
def __iter__(self):
return iter(self.callbacks)
class Callback(object):
"""Abstract base class used to build new callbacks.
# Properties
params: dict. Training parameters
(e.g. verbosity, batch size, number of epochs...).
model: instance of `keras.models.Model`.
Reference of the model being trained.
The `logs` dictionary that callback methods
take as argument will contain keys for quantities relevant to
the current batch or epoch.
Currently, the `.fit()` method of the `Sequential` model class
will include the following quantities in the `logs` that
it passes to its callbacks:
on_epoch_end: logs include `acc` and `loss`, and
optionally include `val_loss`
(if validation is enabled in `fit`), and `val_acc`
(if validation and accuracy monitoring are enabled).
on_batch_begin: logs include `size`,
the number of samples in the current batch.
on_batch_end: logs include `loss`, and optionally `acc`
(if accuracy monitoring is enabled).
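# Example
A minimal sketch of a custom callback; the class name and the printed
message are illustrative:
```python
class LossPrinter(Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        print('Epoch %d: loss=%s' % (epoch + 1, logs.get('loss')))

model.fit(X_train, Y_train, callbacks=[LossPrinter()])
```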
"""
def __init__(self):
self.validation_data = None
self.model = None
def set_params(self, params):
self.params = params
def set_model(self, model):
self.model = model
def on_batch_begin(self, batch, logs=None):
"""A backwards compatibility alias for `on_train_batch_begin`."""
def on_batch_end(self, batch, logs=None):
"""A backwards compatibility alias for `on_train_batch_end`."""
def on_epoch_begin(self, epoch, logs=None):
"""Called at the start of an epoch.
Subclasses should override for any actions to run. This function should only
be called during train mode.
# Arguments
epoch: integer, index of epoch.
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
def on_epoch_end(self, epoch, logs=None):
"""Called at the end of an epoch.
Subclasses should override for any actions to run. This function should only
be called during train mode.
# Arguments
epoch: integer, index of epoch.
logs: dict, metric results for this training epoch, and for the
validation epoch if validation is performed. Validation result keys
are prefixed with `val_`.
"""
def on_train_batch_begin(self, batch, logs=None):
"""Called at the beginning of a training batch in `fit` methods.
Subclasses should override for any actions to run.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, has keys `batch` and `size` representing the current
batch number and the size of the batch.
"""
# For backwards compatibility
self.on_batch_begin(batch, logs=logs)
def on_train_batch_end(self, batch, logs=None):
"""Called at the end of a training batch in `fit` methods.
Subclasses should override for any actions to run.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, metric results for this batch.
"""
# For backwards compatibility
self.on_batch_end(batch, logs=logs)
def on_test_batch_begin(self, batch, logs=None):
"""Called at the beginning of a batch in `evaluate` methods.
Also called at the beginning of a validation batch in the `fit` methods,
if validation data is provided.
Subclasses should override for any actions to run.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, has keys `batch` and `size` representing the current
batch number and the size of the batch.
"""
def on_test_batch_end(self, batch, logs=None):
"""Called at the end of a batch in `evaluate` methods.
Also called at the end of a validation batch in the `fit` methods,
if validation data is provided.
Subclasses should override for any actions to run.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, metric results for this batch.
"""
def on_predict_batch_begin(self, batch, logs=None):
"""Called at the beginning of a batch in `predict` methods.
Subclasses should override for any actions to run.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, has keys `batch` and `size` representing the current
batch number and the size of the batch.
"""
def on_predict_batch_end(self, batch, logs=None):
"""Called at the end of a batch in `predict` methods.
Subclasses should override for any actions to run.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, metric results for this batch.
"""
def on_train_begin(self, logs=None):
"""Called at the beginning of training.
Subclasses should override for any actions to run.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
def on_train_end(self, logs=None):
"""Called at the end of training.
Subclasses should override for any actions to run.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
def on_test_begin(self, logs=None):
"""Called at the beginning of evaluation or validation.
Subclasses should override for any actions to run.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
def on_test_end(self, logs=None):
"""Called at the end of evaluation or validation.
Subclasses should override for any actions to run.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
def on_predict_begin(self, logs=None):
"""Called at the beginning of prediction.
Subclasses should override for any actions to run.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
def on_predict_end(self, logs=None):
"""Called at the end of prediction.
Subclasses should override for any actions to run.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
class BaseLogger(Callback):
"""Callback that accumulates epoch averages of metrics.
This callback is automatically applied to every Keras model.
# Arguments
stateful_metrics: Iterable of string names of metrics that
should *not* be averaged over an epoch.
Metrics in this list will be logged as-is in `on_epoch_end`.
All others will be averaged in `on_epoch_end`.
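# Example
Although this callback is applied automatically, it can be constructed
explicitly; a sketch, where the metric name is illustrative and must
match one the model actually reports:
```python
base_logger = BaseLogger(stateful_metrics=['lr'])
```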
"""
def __init__(self, stateful_metrics=None):
if stateful_metrics:
self.stateful_metrics = set(stateful_metrics)
else:
self.stateful_metrics = set()
def on_epoch_begin(self, epoch, logs=None):
self.seen = 0
self.totals = {}
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
self.seen += batch_size
for k, v in logs.items():
if k in self.stateful_metrics:
self.totals[k] = v
else:
if k in self.totals:
self.totals[k] += v * batch_size
else:
self.totals[k] = v * batch_size
def on_epoch_end(self, epoch, logs=None):
if logs is not None:
for k in self.params['metrics']:
if k in self.totals:
# Make value available to next callbacks.
if k in self.stateful_metrics:
logs[k] = self.totals[k]
else:
logs[k] = self.totals[k] / self.seen
class TerminateOnNaN(Callback):
"""Callback that terminates training when a NaN loss is encountered.
"""
def on_batch_end(self, batch, logs=None):
logs = logs or {}
loss = logs.get('loss')
if loss is not None:
if np.isnan(loss) or np.isinf(loss):
print('Batch %d: Invalid loss, terminating training' % (batch))
self.model.stop_training = True
class ProgbarLogger(Callback):
"""Callback that prints metrics to stdout.
# Arguments
count_mode: One of "steps" or "samples".
Whether the progress bar should
count samples seen or steps (batches) seen.
stateful_metrics: Iterable of string names of metrics that
should *not* be averaged over an epoch.
Metrics in this list will be logged as-is.
All others will be averaged over time (e.g. loss).
# Raises
ValueError: In case of invalid `count_mode`.
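# Example
A usage sketch; counting by steps suits generator-based training,
where the number of samples per epoch may not be known in advance:
```python
progbar = ProgbarLogger(count_mode='steps')
model.fit_generator(generator, steps_per_epoch=100, epochs=10,
                    callbacks=[progbar])
```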
"""
def __init__(self, count_mode='samples',
stateful_metrics=None):
super(ProgbarLogger, self).__init__()
if count_mode == 'samples':
self.use_steps = False
elif count_mode == 'steps':
self.use_steps = True
else:
raise ValueError('Unknown `count_mode`: ' + str(count_mode))
if stateful_metrics:
self.stateful_metrics = set(stateful_metrics)
else:
self.stateful_metrics = set()
def on_train_begin(self, logs=None):
self.verbose = self.params['verbose']
self.epochs = self.params['epochs']
def on_epoch_begin(self, epoch, logs=None):
if self.verbose:
print('Epoch %d/%d' % (epoch + 1, self.epochs))
if self.use_steps:
target = self.params['steps']
else:
target = self.params['samples']
self.target = target
self.progbar = Progbar(target=self.target,
verbose=self.verbose,
stateful_metrics=self.stateful_metrics)
self.seen = 0
def on_batch_begin(self, batch, logs=None):
if self.seen < self.target:
self.log_values = []
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
if self.use_steps:
self.seen += 1
else:
self.seen += batch_size
for k in self.params['metrics']:
if k in logs:
self.log_values.append((k, logs[k]))
# Skip progbar update for the last batch;
# will be handled by on_epoch_end.
if self.verbose and self.seen < self.target:
self.progbar.update(self.seen, self.log_values)
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
for k in self.params['metrics']:
if k in logs:
self.log_values.append((k, logs[k]))
if self.verbose:
self.progbar.update(self.seen, self.log_values)
class History(Callback):
"""Callback that records events into a `History` object.
This callback is automatically applied to
every Keras model. The `History` object
gets returned by the `fit` method of models.
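# Example
A sketch of reading the recorded metrics after training:
```python
history = model.fit(X_train, Y_train,
                    validation_data=(X_val, Y_val), epochs=10)
print(history.history['loss'])      # per-epoch training loss
print(history.history['val_loss'])  # per-epoch validation loss
```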
"""
def on_train_begin(self, logs=None):
self.epoch = []
self.history = {}
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epoch.append(epoch)
for k, v in logs.items():
self.history.setdefault(k, []).append(v)
class ModelCheckpoint(Callback):
"""Save the model after every epoch.
`filepath` can contain named formatting options,
which will be filled with the values of `epoch` and
keys in `logs` (passed in `on_epoch_end`).
For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
then the model checkpoints will be saved with the epoch number and
the validation loss in the filename.
# Arguments
filepath: string, path to save the model file.
monitor: quantity to monitor.
verbose: verbosity mode, 0 or 1.
save_best_only: if `save_best_only=True`,
the latest best model according to
the quantity monitored will not be overwritten.
save_weights_only: if True, then only the model's weights will be
saved (`model.save_weights(filepath)`), else the full model
is saved (`model.save(filepath)`).
mode: one of {auto, min, max}.
If `save_best_only=True`, the decision
to overwrite the current save file is made
based on either the maximization or the
minimization of the monitored quantity. For `val_acc`,
this should be `max`, for `val_loss` this should
be `min`, etc. In `auto` mode, the direction is
automatically inferred from the name of the monitored quantity.
period: Interval (number of epochs) between checkpoints.
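# Example
A usage sketch; the file name pattern and monitored quantity are
illustrative:
```python
checkpoint = ModelCheckpoint('weights.{epoch:02d}-{val_loss:.2f}.hdf5',
                             monitor='val_loss',
                             save_best_only=True)
model.fit(X_train, Y_train, validation_data=(X_val, Y_val),
          callbacks=[checkpoint])
```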
"""
def __init__(self, filepath, monitor='val_loss', verbose=0,
save_best_only=False, save_weights_only=False,
mode='auto', period=1):
super(ModelCheckpoint, self).__init__()
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
self.save_best_only = save_best_only
self.save_weights_only = save_weights_only
self.period = period
self.epochs_since_last_save = 0
if mode not in ['auto', 'min', 'max']:
warnings.warn('ModelCheckpoint mode %s is unknown, '
'fallback to auto mode.' % (mode),
RuntimeWarning)
mode = 'auto'
if mode == 'min':
self.monitor_op = np.less
self.best = np.Inf
elif mode == 'max':
self.monitor_op = np.greater
self.best = -np.Inf
else:
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
self.monitor_op = np.greater
self.best = -np.Inf
else:
self.monitor_op = np.less
self.best = np.Inf
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epochs_since_last_save += 1
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save = 0
filepath = self.filepath.format(epoch=epoch + 1, **logs)
if self.save_best_only:
current = logs.get(self.monitor)
if current is None:
warnings.warn('Can save best model only with %s available, '
'skipping.' % (self.monitor), RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.verbose > 0:
print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s'
% (epoch + 1, self.monitor, self.best,
current, filepath))
self.best = current
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
else:
if self.verbose > 0:
print('\nEpoch %05d: %s did not improve from %0.5f' %
(epoch + 1, self.monitor, self.best))
else:
if self.verbose > 0:
print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
class EarlyStopping(Callback):
"""Stop training when a monitored quantity has stopped improving.
# Arguments
monitor: quantity to be monitored.
min_delta: minimum change in the monitored quantity
to qualify as an improvement, i.e. an absolute
change of less than min_delta will count as no
improvement.
patience: number of epochs with no improvement
after which training will be stopped.
Validation quantities may not be produced for every
epoch, if the validation frequency
(`model.fit(validation_freq=5)`) is greater than one.
verbose: verbosity mode.
mode: one of {auto, min, max}. In `min` mode,
training will stop when the quantity
monitored has stopped decreasing; in `max`
mode it will stop when the quantity
monitored has stopped increasing; in `auto`
mode, the direction is automatically inferred
from the name of the monitored quantity.
baseline: Baseline value for the monitored quantity to reach.
Training will stop if the model doesn't show improvement
over the baseline.
restore_best_weights: whether to restore model weights from
the epoch with the best value of the monitored quantity.
If False, the model weights obtained at the last step of
training are used.
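# Example
A usage sketch; the patience value is illustrative:
```python
early_stopping = EarlyStopping(monitor='val_loss', patience=3,
                               restore_best_weights=True)
model.fit(X_train, Y_train, validation_data=(X_val, Y_val),
          callbacks=[early_stopping])
```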
"""
def __init__(self,
monitor='val_loss',
min_delta=0,
patience=0,
verbose=0,
mode='auto',
baseline=None,
restore_best_weights=False):
super(EarlyStopping, self).__init__()
self.monitor = monitor
self.baseline = baseline
self.patience = patience
self.verbose = verbose
self.min_delta = min_delta
self.wait = 0
self.stopped_epoch = 0
self.restore_best_weights = restore_best_weights
self.best_weights = None
if mode not in ['auto', 'min', 'max']:
warnings.warn('EarlyStopping mode %s is unknown, '
'fallback to auto mode.' % mode,
RuntimeWarning)
mode = 'auto'
if mode == 'min':
self.monitor_op = np.less
elif mode == 'max':
self.monitor_op = np.greater
else:
if 'acc' in self.monitor:
self.monitor_op = np.greater
else:
self.monitor_op = np.less
if self.monitor_op == np.greater:
self.min_delta *= 1
else:
self.min_delta *= -1
def on_train_begin(self, logs=None):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
if self.baseline is not None:
self.best = self.baseline
else:
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
def on_epoch_end(self, epoch, logs=None):
current = self.get_monitor_value(logs)
if current is None:
return
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
if self.restore_best_weights:
self.best_weights = self.model.get_weights()
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
if self.restore_best_weights:
if self.verbose > 0:
print('Restoring model weights from the end of '
'the best epoch')
self.model.set_weights(self.best_weights)
def on_train_end(self, logs=None):
if self.stopped_epoch > 0 and self.verbose > 0:
print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
def get_monitor_value(self, logs):
monitor_value = logs.get(self.monitor)
if monitor_value is None:
warnings.warn(
'Early stopping conditioned on metric `%s` '
'which is not available. Available metrics are: %s' %
(self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
)
return monitor_value
class RemoteMonitor(Callback):
"""Callback used to stream events to a server.
Requires the `requests` library.
Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
HTTP POST, with a `data` argument which is a
JSON-encoded dictionary of event data.
If `send_as_json` is set to True, the content type of the request will be
`"application/json"`. Otherwise the serialized JSON will be sent within a form.
# Arguments
root: String; root url of the target server.
path: String; path relative to `root` to which the events will be sent.
field: String; JSON field under which the data will be stored.
The field is used only if the payload is sent within a form
(i.e. send_as_json is set to False).
headers: Dictionary; optional custom HTTP headers.
send_as_json: Boolean; whether the request should be sent as
`"application/json"`.
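# Example
A usage sketch; the server URL is illustrative and assumed to accept
HTTP POST requests:
```python
monitor = RemoteMonitor(root='http://localhost:9000', send_as_json=True)
model.fit(X_train, Y_train, callbacks=[monitor])
```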
"""
def __init__(self,
root='http://localhost:9000',
path='/publish/epoch/end/',
field='data',
headers=None,
send_as_json=False):
super(RemoteMonitor, self).__init__()
self.root = root
self.path = path
self.field = field
self.headers = headers
self.send_as_json = send_as_json
def on_epoch_end(self, epoch, logs=None):
if requests is None:
raise ImportError('RemoteMonitor requires '
'the `requests` library.')
logs = logs or {}
send = {}
send['epoch'] = epoch
for k, v in logs.items():
if isinstance(v, (np.ndarray, np.generic)):
send[k] = v.item()
else:
send[k] = v
try:
if self.send_as_json:
requests.post(self.root + self.path, json=send, headers=self.headers)
else:
requests.post(self.root + self.path,
{self.field: json.dumps(send)},
headers=self.headers)
except requests.exceptions.RequestException:
warnings.warn('Warning: could not reach RemoteMonitor '
'root server at ' + str(self.root))
class LearningRateScheduler(Callback):
"""Learning rate scheduler.
# Arguments
schedule: a function that takes an epoch index as input
(integer, indexed from 0) and current learning rate
and returns a new learning rate as output (float).
verbose: int. 0: quiet, 1: update messages.
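# Example
A sketch of a schedule; the decay rule (halve the rate every 10
epochs) is illustrative:
```python
def schedule(epoch, lr):
    if epoch > 0 and epoch % 10 == 0:
        return lr * 0.5
    return lr

lr_scheduler = LearningRateScheduler(schedule, verbose=1)
model.fit(X_train, Y_train, callbacks=[lr_scheduler])
```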
"""
def __init__(self, schedule, verbose=0):
super(LearningRateScheduler, self).__init__()
self.schedule = schedule
self.verbose = verbose
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'lr'):
raise ValueError('Optimizer must have an "lr" attribute.')
lr = float(K.get_value(self.model.optimizer.lr))
try: # new API
lr = self.schedule(epoch, lr)
except TypeError: # old API for backward compatibility
lr = self.schedule(epoch)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function '
'should be float.')
K.set_value(self.model.optimizer.lr, lr)
if self.verbose > 0:
print('\nEpoch %05d: LearningRateScheduler setting learning '
'rate to %s.' % (epoch + 1, lr))
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
logs['lr'] = K.get_value(self.model.optimizer.lr)
class ReduceLROnPlateau(Callback):
"""Reduce learning rate when a metric has stopped improving.
Models often benefit from reducing the learning rate by a factor
of 2-10 once learning stagnates. This callback monitors a
quantity and if no improvement is seen for a 'patience' number
of epochs, the learning rate is reduced.
# Example
```python
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
patience=5, min_lr=0.001)
model.fit(X_train, Y_train, callbacks=[reduce_lr])
```
# Arguments
monitor: quantity to be monitored.
factor: factor by which the learning rate will
be reduced. new_lr = lr * factor
patience: number of epochs with no improvement
after which the learning rate will be reduced.
Validation quantities may not be produced for every
epoch, if the validation frequency
(`model.fit(validation_freq=5)`) is greater than one.
verbose: int. 0: quiet, 1: update messages.
mode: one of {auto, min, max}. In `min` mode,
lr will be reduced when the quantity
monitored has stopped decreasing; in `max`
mode it will be reduced when the quantity
monitored has stopped increasing; in `auto`
mode, the direction is automatically inferred
from the name of the monitored quantity.
min_delta: threshold for measuring the new optimum,
to only focus on significant changes.
cooldown: number of epochs to wait before resuming
normal operation after lr has been reduced.
min_lr: lower bound on the learning rate.
"""
def __init__(self, monitor='val_loss', factor=0.1, patience=10,
verbose=0, mode='auto', min_delta=1e-4, cooldown=0, min_lr=0,
**kwargs):
super(ReduceLROnPlateau, self).__init__()
self.monitor = monitor
if factor >= 1.0:
raise ValueError('ReduceLROnPlateau '
'does not support a factor >= 1.0.')
if 'epsilon' in kwargs:
min_delta = kwargs.pop('epsilon')
warnings.warn('`epsilon` argument is deprecated and '
'will be removed, use `min_delta` instead.')
self.factor = factor
self.min_lr = min_lr
self.min_delta = min_delta
self.patience = patience
self.verbose = verbose
self.cooldown = cooldown
self.cooldown_counter = 0 # Cooldown counter.
self.wait = 0
self.best = 0
self.mode = mode
self.monitor_op = None
self._reset()
def _reset(self):
"""Resets wait counter and cooldown counter.
"""
if self.mode not in ['auto', 'min', 'max']:
warnings.warn('Learning Rate Plateau Reducing mode %s is unknown, '
'fallback to auto mode.' % (self.mode),
RuntimeWarning)
self.mode = 'auto'
if (self.mode == 'min' or
(self.mode == 'auto' and 'acc' not in self.monitor)):
self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
self.best = np.Inf
else:
self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
self.best = -np.Inf
self.cooldown_counter = 0
self.wait = 0
def on_train_begin(self, logs=None):
self._reset()
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
logs['lr'] = K.get_value(self.model.optimizer.lr)
current = logs.get(self.monitor)
if current is None:
warnings.warn(
'Reduce LR on plateau conditioned on metric `%s` '
'which is not available. Available metrics are: %s' %
(self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
)
else:
if self.in_cooldown():
self.cooldown_counter -= 1
self.wait = 0
if self.monitor_op(current, self.best):
self.best = current
self.wait = 0
elif not self.in_cooldown():
self.wait += 1
if self.wait >= self.patience:
old_lr = float(K.get_value(self.model.optimizer.lr))
if old_lr > self.min_lr:
new_lr = old_lr * self.factor
new_lr = max(new_lr, self.min_lr)
K.set_value(self.model.optimizer.lr, new_lr)
if self.verbose > 0:
print('\nEpoch %05d: ReduceLROnPlateau reducing '
'learning rate to %s.' % (epoch + 1, new_lr))
self.cooldown_counter = self.cooldown
self.wait = 0
def in_cooldown(self):
return self.cooldown_counter > 0
class CSVLogger(Callback):
"""Callback that streams epoch results to a csv file.
Supports all values that can be represented as a string,
including 1D iterables such as np.ndarray.
# Example
```python
csv_logger = CSVLogger('training.log')
model.fit(X_train, Y_train, callbacks=[csv_logger])
```
# Arguments
filename: filename of the csv file, e.g. 'run/log.csv'.
separator: string used to separate elements in the csv file.
append: True: append if file exists (useful for continuing
training). False: overwrite existing file.
"""
def __init__(self, filename, separator=',', append=False):
self.sep = separator
self.filename = filename
self.append = append
self.writer = None
self.keys = None
self.append_header = True
if six.PY2:
self.file_flags = 'b'
self._open_args = {}
else:
self.file_flags = ''
self._open_args = {'newline': '\n'}
super(CSVLogger, self).__init__()
def on_train_begin(self, logs=None):
if self.append:
if os.path.exists(self.filename):
with open(self.filename, 'r' + self.file_flags) as f:
self.append_header = not bool(len(f.readline()))
mode = 'a'
else:
mode = 'w'
self.csv_file = io.open(self.filename,
mode + self.file_flags,
**self._open_args)
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
def handle_value(k):
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
if isinstance(k, six.string_types):
return k
elif isinstance(k, Iterable) and not is_zero_dim_ndarray:
return '"[%s]"' % (', '.join(map(str, k)))
else:
return k
if self.keys is None:
self.keys = sorted(logs.keys())
if self.model.stop_training:
# We set NA so that csv parsers do not fail for this last epoch.
logs = dict([(k, logs[k] if k in logs else 'NA') for k in self.keys])
if not self.writer:
class CustomDialect(csv.excel):
delimiter = self.sep
fieldnames = ['epoch'] + self.keys
if six.PY2:
fieldnames = [unicode(x) for x in fieldnames]
self.writer = csv.DictWriter(self.csv_file,
fieldnames=fieldnames,
dialect=CustomDialect)
if self.append_header:
self.writer.writeheader()
row_dict = OrderedDict({'epoch': epoch})
row_dict.update((key, handle_value(logs[key])) for key in self.keys)
self.writer.writerow(row_dict)
self.csv_file.flush()
def on_train_end(self, logs=None):
self.csv_file.close()
self.writer = None
def __del__(self):
if hasattr(self, 'csv_file') and not self.csv_file.closed:
self.csv_file.close()
class LambdaCallback(Callback):
r"""Callback for creating simple, custom callbacks on-the-fly.
This callback is constructed with anonymous functions that will be called
at the appropriate time. Note that the callbacks expect positional
arguments, as:
- `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
`epoch`, `logs`
- `on_batch_begin` and `on_batch_end` expect two positional arguments:
`batch`, `logs`
- `on_train_begin` and `on_train_end` expect one positional argument:
`logs`
# Arguments
on_epoch_begin: called at the beginning of every epoch.
on_epoch_end: called at the end of every epoch.
on_batch_begin: called at the beginning of every batch.
on_batch_end: called at the end of every batch.
on_train_begin: called at the beginning of model training.
on_train_end: called at the end of model training.
# Example
```python
# Print the batch number at the beginning of every batch.
batch_print_callback = LambdaCallback(
on_batch_begin=lambda batch,logs: print(batch))
# Stream the epoch loss to a file in JSON format. The file content
# is not well-formed JSON but rather has a JSON object per line.
import json
json_log = open('loss_log.json', mode='wt', buffering=1)
json_logging_callback = LambdaCallback(
on_epoch_end=lambda epoch, logs: json_log.write(
json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
on_train_end=lambda logs: json_log.close()
)
# Terminate some processes after having finished model training.
processes = ...
cleanup_callback = LambdaCallback(
on_train_end=lambda logs: [
p.terminate() for p in processes if p.is_alive()])
model.fit(...,
callbacks=[batch_print_callback,
json_logging_callback,
cleanup_callback])
```
"""
def __init__(self,
on_epoch_begin=None,
on_epoch_end=None,
on_batch_begin=None,
on_batch_end=None,
on_train_begin=None,
on_train_end=None,
**kwargs):
super(LambdaCallback, self).__init__()
self.__dict__.update(kwargs)
if on_epoch_begin is not None:
self.on_epoch_begin = on_epoch_begin
else:
self.on_epoch_begin = lambda epoch, logs: None
if on_epoch_end is not None:
self.on_epoch_end = on_epoch_end
else:
self.on_epoch_end = lambda epoch, logs: None
if on_batch_begin is not None:
self.on_batch_begin = on_batch_begin
else:
self.on_batch_begin = lambda batch, logs: None
if on_batch_end is not None:
self.on_batch_end = on_batch_end
else:
self.on_batch_end = lambda batch, logs: None
if on_train_begin is not None:
self.on_train_begin = on_train_begin
else:
self.on_train_begin = lambda logs: None
if on_train_end is not None:
self.on_train_end = on_train_end
else:
self.on_train_end = lambda logs: None