Skip to content

Commit

Permalink
Adding support for Loss instances in model compile. (#12915)
Browse files Browse the repository at this point in the history
* Updating training utils.

* training changes for supporting Loss instances.

* Add weight broadcasting.

* Saving test

* Add correctness test.
  • Loading branch information
pavithrasv authored and fchollet committed Jun 6, 2019
1 parent ab3ef6f commit 910e124
Show file tree
Hide file tree
Showing 10 changed files with 469 additions and 246 deletions.
6 changes: 5 additions & 1 deletion keras/engine/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np

from .. import backend as K
from .. import losses
from .. import optimizers
from ..utils.io_utils import H5Dict
from ..utils.io_utils import ask_to_proceed_with_overwrite
Expand Down Expand Up @@ -347,7 +348,10 @@ def convert_custom_objects(obj):
custom_objects=custom_objects)

# Recover loss functions and metrics.
loss = convert_custom_objects(training_config['loss'])
loss_config = training_config['loss'] # Deserialize loss class.
if isinstance(loss_config, dict) and 'class_name' in loss_config:
loss_config = losses.get(loss_config)
loss = convert_custom_objects(loss_config)
metrics = convert_custom_objects(training_config['metrics'])
sample_weight_mode = training_config['sample_weight_mode']
loss_weights = training_config['loss_weights']
Expand Down
360 changes: 138 additions & 222 deletions keras/engine/training.py

Large diffs are not rendered by default.

219 changes: 208 additions & 11 deletions keras/engine/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import collections
import copy
import numpy as np
import six
import warnings

from .. import backend as K
from .. import losses
from ..utils import Sequence
from ..utils.generic_utils import to_list
from ..utils import generic_utils
from ..utils import losses_utils


def standardize_single_array(x):
Expand Down Expand Up @@ -252,7 +254,8 @@ def set_of_lengths(x):
def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
"""Does validation on the compatibility of targets and loss functions.
This helps prevent users from using loss functions incorrectly.
This helps prevent users from using loss functions incorrectly. This check
is purely for UX purposes.
# Arguments
targets: list of Numpy arrays of targets.
Expand All @@ -263,13 +266,16 @@ def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
ValueError: if a loss function or target array
is incompatible with an output.
"""
key_losses = {losses.mean_squared_error,
losses.binary_crossentropy,
losses.categorical_crossentropy}
key_loss_fns = {
losses.mean_squared_error, losses.binary_crossentropy,
losses.categorical_crossentropy
}
key_loss_classes = (losses.MeanSquaredError, losses.BinaryCrossentropy,
losses.CategoricalCrossentropy)
for y, loss, shape in zip(targets, loss_fns, output_shapes):
if y is None or loss is None:
continue
if loss is losses.categorical_crossentropy:
if losses.is_categorical_crossentropy(loss):
if y.shape[-1] == 1:
raise ValueError(
'You are passing a target array of shape ' + str(y.shape) +
Expand All @@ -287,15 +293,20 @@ def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
'Alternatively, you can use the loss function '
'`sparse_categorical_crossentropy` instead, '
'which does expect integer targets.')
if loss in key_losses:
is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper)
if (isinstance(loss, key_loss_classes) or (is_loss_wrapper and
(loss.fn in key_loss_fns))):
for target_dim, out_dim in zip(y.shape[1:], shape[1:]):
if out_dim is not None and target_dim != out_dim:
loss_name = loss.name
if loss_name is None:
loss_type = loss.fn if is_loss_wrapper else type(loss)
loss_name = loss_type.__name__
raise ValueError(
'A target array with shape ' + str(y.shape) +
' was passed for an output of shape ' + str(shape) +
' while using as loss `' + loss.__name__ + '`. '
'This loss expects '
'targets to have the same shape '
' while using as loss `' + loss_name + '`. '
'This loss expects targets to have the same shape '
'as the output.')


Expand Down Expand Up @@ -347,7 +358,7 @@ def collect_metrics(metrics, output_names):
.format(unknown_output_names, output_names))
for name in output_names:
output_metrics = metrics.get(name, [])
output_metrics = to_list(output_metrics)
output_metrics = generic_utils.to_list(output_metrics)
nested_metrics.append(output_metrics)
return nested_metrics
else:
Expand Down Expand Up @@ -713,3 +724,189 @@ def _is_graph_model(layer):
if hasattr(layer, '_batch_input_shape'):
return layer._batch_input_shape, layer.dtype
return None, None


def get_loss_function(loss):
"""Returns the loss corresponding to the loss input in `compile` API."""
if loss is None or isinstance(loss, losses.Loss):
return loss

# Deserialize loss configuration, if needed.
if isinstance(loss, collections.Mapping):
loss = losses.get(loss)

# Custom callable class.
if callable(loss) and not hasattr(loss, '__name__'):
return loss

# Wrap loss function with signature `(y_true, y_pred, **kwargs)`
# in `LossFunctionWrapper` class.
loss_fn = losses.get(loss)

# For losses which are given as strings/functions in the compile API,
# we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE`..
return losses.LossFunctionWrapper(
loss_fn,
name=loss_fn.__name__,
reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE)


def get_output_sample_weight_and_mode(skip_target_weighing_indices,
sample_weight_mode, output_name,
output_index):
"""Returns the sample weight and weight mode for a single output."""
if output_index in skip_target_weighing_indices:
return None, None

if sample_weight_mode == 'temporal':
shape = [None, None]
mode = 'temporal'
else:
shape = [None]
mode = None
weight = K.placeholder(
shape=shape,
name=output_name + '_sample_weights')
return weight, mode


def prepare_sample_weights(output_names, sample_weight_mode,
skip_target_weighing_indices):
"""Prepares sample weights for the model.
# Arguments
output_names: List of model output names.
sample_weight_mode: sample weight mode user input passed from compile API.
skip_target_weighing_indices: Indices of output for which sample weights
should be skipped.
# Returns
A pair of list of sample weights and sample weight modes
(one for each output).
# Raises
ValueError: In case of invalid `sample_weight_mode` input.
"""
sample_weights = []
sample_weight_modes = []
if isinstance(sample_weight_mode, dict):
unknown_output = set(sample_weight_mode.keys()) - set(output_names)
if unknown_output:
raise ValueError(
'Unknown entry in '
'sample_weight_mode dictionary: "' + str(unknown_output) +
'". Only expected the following keys: ' + str(output_names))
for i, name in enumerate(output_names):
if (i not in skip_target_weighing_indices and
name not in sample_weight_mode):
raise ValueError(
'Output missing from sample_weight_modes dictionary')
weight, mode = get_output_sample_weight_and_mode(
skip_target_weighing_indices,
sample_weight_mode.get(name),
name,
i)
sample_weights.append(weight)
sample_weight_modes.append(mode)
elif isinstance(sample_weight_mode, list):
if len(sample_weight_mode) != len(output_names):
raise ValueError('When passing a list as sample_weight_mode, '
'it should have one entry per model output. '
'The model has ' + str(len(output_names)) +
' outputs, but you passed ' +
str(len(sample_weight_mode)) + 'sample_weight_modes')
for i, name in enumerate(output_names):
weight, mode = get_output_sample_weight_and_mode(
skip_target_weighing_indices, sample_weight_mode[i], name, i)
sample_weights.append(weight)
sample_weight_modes.append(mode)
else:
for i, name in enumerate(output_names):
weight, mode = get_output_sample_weight_and_mode(
skip_target_weighing_indices, sample_weight_mode, name, i)
sample_weights.append(weight)
sample_weight_modes.append(mode)
return sample_weights, sample_weight_modes


def prepare_loss_functions(loss, output_names):
"""Converts loss to a list of loss functions.
# Arguments
loss: String (name of objective function), objective function or
`Loss` instance. If the model has multiple outputs, you can use
a different loss on each output by passing a dictionary or a
list of losses. The loss value that will be minimized by the model
will then be the sum of all individual losses.
output_names: List of model output names.
# Returns
A list of loss objective functions.
# Raises:
ValueError: If loss is a dict with keys not in model output names,
or if loss is a list with len not equal to model outputs.
"""
if isinstance(loss, collections.Mapping):
generic_utils.check_for_unexpected_keys('loss', loss, output_names)
loss_functions = []
for name in output_names:
if name not in loss:
warnings.warn(
'Output {0} missing from loss dictionary. We assume '
'this was done on purpose. The fit and evaluate APIs will not '
'be expecting any data to be passed to {0}.'.format(name))
loss_functions.append(get_loss_function(loss.get(name, None)))
elif isinstance(loss, six.string_types):
loss_functions = [get_loss_function(loss) for _ in output_names]
elif isinstance(loss, collections.Sequence):
if len(loss) != len(output_names):
raise ValueError('When passing a list as loss, it should have one entry '
'per model outputs. The model has {} outputs, but you '
'passed loss={}'.format(len(output_names), loss))
loss_functions = [get_loss_function(l) for l in loss]
else:
loss_functions = [get_loss_function(loss) for _ in range(len(output_names))]

return loss_functions


def prepare_loss_weights(output_names, loss_weights=None):
"""Converts loss weights to a list of loss weights.
# Arguments
output_names: List of model output names.
loss_weights: Optional list or dictionary specifying scalar coefficients
(Python floats) to weight the loss contributions of different model
outputs. The loss value that will be minimized by the model will then be
the *weighted sum* of all individual losses, weighted by the
`loss_weights` coefficients. If a list, it is expected to have a 1:1
mapping to the model's outputs. If a dict, it is expected to map
output names (strings) to scalar coefficients.
# Returns
A list of loss weights of python floats.
# Raises
ValueError: If loss weight is a dict with key not in model output names,
or if loss is a list with len not equal to model outputs.
"""
if loss_weights is None:
weights_list = [1.] * len(output_names)
elif isinstance(loss_weights, collections.Mapping):
generic_utils.check_for_unexpected_keys('loss_weights', loss_weights,
output_names)
weights_list = [loss_weights.get(name, 1.) for name in output_names]
elif isinstance(loss_weights, list):
if len(loss_weights) != len(output_names):
raise ValueError('When passing a list as loss_weights, '
'it should have one entry per model output. '
'The model has ' + str(len(output_names)) +
' outputs, but you passed loss_weights=' +
str(loss_weights))
weights_list = loss_weights
else:
raise TypeError('Could not interpret loss_weights argument: ' +
str(loss_weights) + ' - expected a list of dicts.')

return weights_list
4 changes: 2 additions & 2 deletions keras/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,12 +557,12 @@ def cosine_proximity(y_true, y_pred):


def is_categorical_crossentropy(loss):
return (isinstance(loss, CategoricalCrossentropy or
return (isinstance(loss, CategoricalCrossentropy) or
(isinstance(loss, LossFunctionWrapper) and
loss.fn == categorical_crossentropy) or
(hasattr(loss, '__name__') and
loss.__name__ == 'categorical_crossentropy') or
loss == 'categorical_crossentropy'))
loss == 'categorical_crossentropy')


def serialize(loss):
Expand Down
8 changes: 8 additions & 0 deletions keras/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,3 +613,11 @@ def transpose_shape(shape, target_format, spatial_axes):
raise ValueError('The `data_format` argument must be one of '
'"channels_first", "channels_last". Received: ' +
str(target_format))


def check_for_unexpected_keys(name, input_dict, expected_values):
unknown = set(input_dict.keys()).difference(expected_values)
if unknown:
raise ValueError('Unknown entries in {} dictionary: {}. Only expected '
'following keys: {}'.format(name, list(unknown),
expected_values))
33 changes: 32 additions & 1 deletion keras/utils/losses_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,38 @@ def compute_weighted_loss(losses,
losses, _, sample_weight = squeeze_or_expand_dimensions(
losses, None, sample_weight)

weighted_losses = losses * sample_weight
# Broadcast weights if possible.
weights_shape = K.int_shape(sample_weight)
losses_shape = K.int_shape(losses)
if losses_shape != weights_shape:
weights_rank = K.ndim(sample_weight)
losses_rank = K.ndim(losses)

# Raise error if ndim of weights is > losses.
if weights_rank > losses_rank:
raise ValueError(
'Incompatible shapes: `losses` {} vs `sample_weight` {}'.format(
losses_shape, weights_shape))

# Expand dim of weights to match ndim of losses, if required.
for i in range(weights_rank, losses_rank):
sample_weight = K.expand_dims(sample_weight, axis=i)

for i in range(weights_rank):
if (weights_shape[i] is not None and losses_shape[i] is not None and
weights_shape[i] != losses_shape[i]):
# Cannot be broadcasted.
if weights_shape[i] != 1:
raise ValueError(
'Incompatible shapes: `losses` {} vs '
'`sample_weight` {}'.format(
losses_shape, weights_shape))
sample_weight = K.repeat_elements(
sample_weight, losses_shape[i], axis=i)

# Apply weights to losses.
weighted_losses = sample_weight * losses

# Apply reduction function to the individual weighted losses.
loss = reduce_weighted_loss(weighted_losses, reduction)
# Convert the result back to the input type.
Expand Down
7 changes: 3 additions & 4 deletions keras/wrappers/scikit_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np

from .. import losses
from ..utils.np_utils import to_categorical
from ..utils.generic_utils import has_arg
from ..utils.generic_utils import to_list
Expand Down Expand Up @@ -140,10 +141,8 @@ def fit(self, x, y, **kwargs):
else:
self.model = self.build_fn(**self.filter_sk_params(self.build_fn))

loss_name = self.model.loss
if hasattr(loss_name, '__name__'):
loss_name = loss_name.__name__
if loss_name == 'categorical_crossentropy' and len(y.shape) != 2:
if (losses.is_categorical_crossentropy(self.model.loss) and
len(y.shape) != 2):
y = to_categorical(y)

fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit))
Expand Down
Loading

0 comments on commit 910e124

Please sign in to comment.