diff --git a/keras/engine/saving.py b/keras/engine/saving.py
index 1b0685b2cd2..49f249210db 100644
--- a/keras/engine/saving.py
+++ b/keras/engine/saving.py
@@ -17,6 +17,7 @@
 import numpy as np
 
 from .. import backend as K
+from .. import losses
 from .. import optimizers
 from ..utils.io_utils import H5Dict
 from ..utils.io_utils import ask_to_proceed_with_overwrite
@@ -347,7 +348,10 @@ def convert_custom_objects(obj):
                 custom_objects=custom_objects)
 
         # Recover loss functions and metrics.
-        loss = convert_custom_objects(training_config['loss'])
+        loss_config = training_config['loss']  # Deserialize loss class.
+        if isinstance(loss_config, dict) and 'class_name' in loss_config:
+            loss_config = losses.get(loss_config)
+        loss = convert_custom_objects(loss_config)
         metrics = convert_custom_objects(training_config['metrics'])
         sample_weight_mode = training_config['sample_weight_mode']
         loss_weights = training_config['loss_weights']
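The new branch above exists because compiling with a `Loss` instance makes `training_config['loss']` a serialization dict rather than a plain name string. A standalone sketch of the round trip it handles, assuming the usual `{'class_name', 'config'}` layout produced by `losses.serialize`:

    from keras import losses

    # Serialized form stored in `training_config['loss']`, e.g.
    # {'class_name': 'MeanSquaredError', 'config': {...}}.
    loss_config = losses.serialize(losses.MeanSquaredError())

    # `load_model` now rebuilds the instance before converting
    # custom objects.
    if isinstance(loss_config, dict) and 'class_name' in loss_config:
        loss_config = losses.get(loss_config)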
diff --git a/keras/engine/training.py b/keras/engine/training.py
index 868f8c656a5..3b9000ec9a6 100644
--- a/keras/engine/training.py
+++ b/keras/engine/training.py
@@ -10,17 +10,8 @@
 from .network import Network
 from .base_layer import Layer
-from .training_utils import collect_metrics
-from .training_utils import check_array_length_consistency
-from .training_utils import check_loss_and_target_compatibility
-from .training_utils import check_generator_arguments
-from .training_utils import standardize_class_weights
-from .training_utils import standardize_input_data
-from .training_utils import standardize_sample_weights
-from .training_utils import standardize_weights
-from .training_utils import weighted_masked_objective
-from .training_utils import get_static_batch_size
-from .training_utils import is_generator_or_sequence
+from . import training_utils
+from ..utils import losses_utils
 from . import training_arrays
 from . import training_generator
 from .. import backend as K
@@ -51,8 +41,8 @@ def compile(self, optimizer,
         # Arguments
             optimizer: String (name of optimizer) or optimizer instance.
                 See [optimizers](/optimizers).
-            loss: String (name of objective function) or objective function.
-                See [losses](/losses).
+            loss: String (name of objective function) or objective function or
+                `Loss` instance. See [losses](/losses).
                 If the model has multiple outputs, you can use a different loss
                 on each output by passing a dictionary or a list of losses.
                 The loss value that will be minimized by the model
@@ -98,7 +88,7 @@ def compile(self, optimizer,
             `optimizer`, `loss`, `metrics` or `sample_weight_mode`.
         """
         self.optimizer = optimizers.get(optimizer)
-        self.loss = loss or []
+        self.loss = loss or {}
         self.metrics = metrics or []
         self.loss_weights = loss_weights
         self.sample_weight_mode = sample_weight_mode
@@ -113,80 +103,30 @@ def compile(self, optimizer,
             return
         self._is_compiled = True
 
-        # Prepare loss functions.
-        if isinstance(loss, dict):
-            for name in loss:
-                if name not in self.output_names:
-                    raise ValueError('Unknown entry in loss '
-                                     'dictionary: "' + name + '". '
-                                     'Only expected the following keys: ' +
-                                     str(self.output_names))
-            loss_functions = []
-            for name in self.output_names:
-                if name not in loss:
-                    warnings.warn('Output "' + name +
-                                  '" missing from loss dictionary. '
-                                  'We assume this was done on purpose, '
-                                  'and we will not be expecting '
-                                  'any data to be passed to "' + name +
-                                  '" during training.', stacklevel=2)
-                loss_functions.append(losses.get(loss.get(name)))
-        elif isinstance(loss, list):
-            if len(loss) != len(self.outputs):
-                raise ValueError('When passing a list as loss, '
-                                 'it should have one entry per model outputs. '
-                                 'The model has ' + str(len(self.outputs)) +
-                                 ' outputs, but you passed loss=' +
-                                 str(loss))
-            loss_functions = [losses.get(l) for l in loss]
-        else:
-            loss_function = losses.get(loss)
-            loss_functions = [loss_function for _ in range(len(self.outputs))]
-        self.loss_functions = loss_functions
-        weighted_losses = [
-            weighted_masked_objective(fn) for fn in loss_functions]
-        skip_target_indices = []
-        skip_target_weighing_indices = []
+        # Prepare list of loss functions, same size as model outputs.
+        self.loss_functions = training_utils.prepare_loss_functions(
+            self.loss, self.output_names)
+
         self._feed_outputs = []
         self._feed_output_names = []
         self._feed_output_shapes = []
         self._feed_loss_fns = []
-        for i in range(len(weighted_losses)):
-            if weighted_losses[i] is None:
+
+        # If loss function is None, then this output will be skipped during
+        # total loss calculation and feed targets preparation.
+        skip_target_indices = []
+        skip_target_weighing_indices = []
+        for i, loss_function in enumerate(self.loss_functions):
+            if loss_function is None:
                 skip_target_indices.append(i)
                 skip_target_weighing_indices.append(i)
 
         # Prepare output masks.
-        masks = self.compute_mask(self.inputs, mask=None)
-        if masks is None:
-            masks = [None for _ in self.outputs]
-        masks = to_list(masks)
-
-        # Prepare loss weights.
-        if loss_weights is None:
-            loss_weights_list = [1. for _ in range(len(self.outputs))]
-        elif isinstance(loss_weights, dict):
-            for name in loss_weights:
-                if name not in self.output_names:
-                    raise ValueError('Unknown entry in loss_weights '
-                                     'dictionary: "' + name + '". '
-                                     'Only expected the following keys: ' +
-                                     str(self.output_names))
-            loss_weights_list = []
-            for name in self.output_names:
-                loss_weights_list.append(loss_weights.get(name, 1.))
-        elif isinstance(loss_weights, list):
-            if len(loss_weights) != len(self.outputs):
-                raise ValueError('When passing a list as loss_weights, '
-                                 'it should have one entry per model output. '
-                                 'The model has ' + str(len(self.outputs)) +
-                                 ' outputs, but you passed loss_weights=' +
-                                 str(loss_weights))
-            loss_weights_list = loss_weights
-        else:
-            raise TypeError('Could not interpret loss_weights argument: ' +
-                            str(loss_weights) +
-                            ' - expected a list of dicts.')
+        masks = [getattr(x, '_keras_mask', None) for x in self.outputs]
+
+        # Prepare list of loss weights, same size as model outputs.
+        self.loss_weights_list = training_utils.prepare_loss_weights(
+            self.output_names, loss_weights)
 
         # Prepare targets of model.
         self.targets = []
@@ -250,124 +190,18 @@ def compile(self, optimizer,
             self.targets.append(target)
 
         # Prepare sample weights.
-        sample_weights = []
-        sample_weight_modes = []
-        if isinstance(sample_weight_mode, dict):
-            for name in sample_weight_mode:
-                if name not in self.output_names:
-                    raise ValueError('Unknown entry in '
-                                     'sample_weight_mode dictionary: "' +
-                                     name + '". '
-                                     'Only expected the following keys: ' +
-                                     str(self.output_names))
-            for i, name in enumerate(self.output_names):
-                if i in skip_target_weighing_indices:
-                    weight = None
-                    sample_weight_modes.append(None)
-                else:
-                    if name not in sample_weight_mode:
-                        raise ValueError('Output "' + name +
-                                         '" missing from sample_weight_modes '
-                                         'dictionary')
-                    if sample_weight_mode.get(name) == 'temporal':
-                        weight = K.placeholder(ndim=2,
-                                               name=name + '_sample_weights')
-                        sample_weight_modes.append('temporal')
-                    else:
-                        weight = K.placeholder(ndim=1,
-                                               name=name + '_sample_weights')
-                        sample_weight_modes.append(None)
-                sample_weights.append(weight)
-        elif isinstance(sample_weight_mode, list):
-            if len(sample_weight_mode) != len(self.outputs):
-                raise ValueError('When passing a list as sample_weight_mode, '
-                                 'it should have one entry per model output. '
-                                 'The model has ' + str(len(self.outputs)) +
-                                 ' outputs, but you passed '
-                                 'sample_weight_mode=' +
-                                 str(sample_weight_mode))
-            for i in range(len(self.output_names)):
-                if i in skip_target_weighing_indices:
-                    weight = None
-                    sample_weight_modes.append(None)
-                else:
-                    mode = sample_weight_mode[i]
-                    name = self.output_names[i]
-                    if mode == 'temporal':
-                        weight = K.placeholder(ndim=2,
-                                               name=name + '_sample_weights')
-                        sample_weight_modes.append('temporal')
-                    else:
-                        weight = K.placeholder(ndim=1,
-                                               name=name + '_sample_weights')
-                        sample_weight_modes.append(None)
-                sample_weights.append(weight)
-        else:
-            for i, name in enumerate(self.output_names):
-                if i in skip_target_weighing_indices:
-                    sample_weight_modes.append(None)
-                    sample_weights.append(None)
-                else:
-                    if sample_weight_mode == 'temporal':
-                        sample_weights.append(
-                            K.placeholder(ndim=2,
-                                          name=name + '_sample_weights'))
-                        sample_weight_modes.append('temporal')
-                    else:
-                        sample_weights.append(
-                            K.placeholder(ndim=1,
-                                          name=name + '_sample_weights'))
-                        sample_weight_modes.append(None)
-        self.sample_weight_modes = sample_weight_modes
-        self._feed_sample_weight_modes = []
-        for i in range(len(self.outputs)):
-            if i not in skip_target_weighing_indices:
-                self._feed_sample_weight_modes.append(
-                    self.sample_weight_modes[i])
+        self._set_sample_weight_attributes(
+            sample_weight_mode, skip_target_weighing_indices)
 
         # Prepare metrics.
         self.metrics_names = ['loss']
         self.metrics_tensors = []
 
-        # Compute total loss.
-        total_loss = None
-        with K.name_scope('loss'):
-            for i in range(len(self.outputs)):
-                if i in skip_target_indices:
-                    continue
-                y_true = self.targets[i]
-                y_pred = self.outputs[i]
-                weighted_loss = weighted_losses[i]
-                sample_weight = sample_weights[i]
-                mask = masks[i]
-                loss_weight = loss_weights_list[i]
-                with K.name_scope(self.output_names[i] + '_loss'):
-                    output_loss = weighted_loss(y_true, y_pred,
-                                                sample_weight, mask)
-                if len(self.outputs) > 1:
-                    self.metrics_tensors.append(output_loss)
-                    self.metrics_names.append(self.output_names[i] + '_loss')
-                if total_loss is None:
-                    total_loss = loss_weight * output_loss
-                else:
-                    total_loss += loss_weight * output_loss
-        if total_loss is None:
-            if not self.losses:
-                raise ValueError('The model cannot be compiled '
-                                 'because it has no loss to optimize.')
-            else:
-                total_loss = 0.
-
-        # Add regularization penalties
-        # and other layer-specific losses.
-        for loss_tensor in self.losses:
-            total_loss += loss_tensor
-
         # List of same size as output_names.
         # contains tuples (metrics for output, names of metrics).
-        nested_metrics = collect_metrics(metrics, self.output_names)
-        nested_weighted_metrics = collect_metrics(weighted_metrics,
-                                                  self.output_names)
+        nested_metrics = training_utils.collect_metrics(metrics, self.output_names)
+        nested_weighted_metrics = training_utils.collect_metrics(
+            weighted_metrics, self.output_names)
         self.metrics_updates = []
         self.stateful_metric_names = []
         self.stateful_metric_functions = []
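The total-loss block removed above reappears below as `Model._prepare_total_loss`. Its arithmetic, reduced to a rough standalone sketch (names illustrative, not part of the patch):

    # total = sum_i loss_weights[i] * loss_fns[i](targets[i], outputs[i])
    #         + sum(model.losses)  # regularization and layer losses
    def combine_losses(per_output_losses, loss_weights, layer_losses):
        total = sum(w * l for w, l in zip(loss_weights, per_output_losses))
        return total + sum(layer_losses)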
@@ -406,11 +240,13 @@ def handle_metrics(metrics, weights=None):
                         suffix = 'acc'
                     elif metric in ('crossentropy', 'ce'):
                         suffix = 'ce'
-                    weighted_metric_fn = weighted_masked_objective(metric_fn)
+                    weighted_metric_fn = training_utils.weighted_masked_objective(
+                        metric_fn)
                     metric_name = metric_name_prefix + suffix
                 else:
                     metric_fn = metrics_module.get(metric)
-                    weighted_metric_fn = weighted_masked_objective(metric_fn)
+                    weighted_metric_fn = training_utils.weighted_masked_objective(
+                        metric_fn)
                     # Get metric name as string
                     if hasattr(metric_fn, 'name'):
                         metric_name = metric_fn.name
@@ -449,19 +285,18 @@ def handle_metrics(metrics, weights=None):
 
             y_true = self.targets[i]
             y_pred = self.outputs[i]
-            weights = sample_weights[i]
+            weights = self.sample_weights[i]
             output_metrics = nested_metrics[i]
             output_weighted_metrics = nested_weighted_metrics[i]
             handle_metrics(output_metrics)
             handle_metrics(output_weighted_metrics, weights=weights)
 
-        # Prepare gradient updates and state updates.
-        self.total_loss = total_loss
-        self.sample_weights = sample_weights
-        self._feed_sample_weights = []
-        for i in range(len(self.sample_weights)):
-            if i not in skip_target_weighing_indices:
-                self._feed_sample_weights.append(sample_weights[i])
+        # Compute total loss.
+        # Used to keep track of the total loss value (stateless).
+        # e.g., total_loss = loss_weight_1 * output_1_loss_fn(...) +
+        #       loss_weight_2 * output_2_loss_fn(...) +
+        #       layer losses.
+        self.total_loss = self._prepare_total_loss(skip_target_indices, masks)
 
         # Functions for train, test and predict will
         # be compiled lazily when required.
@@ -750,7 +585,7 @@ def _standardize_user_data(self, x,
             feed_input_shapes = self._feed_input_shapes
 
         # Standardize the inputs.
-        x = standardize_input_data(
+        x = training_utils.standardize_input_data(
             x,
             feed_input_names,
             feed_input_shapes,
@@ -770,25 +605,29 @@ def _standardize_user_data(self, x,
             feed_output_shapes = []
             for output_shape, loss_fn in zip(self._feed_output_shapes,
                                              self._feed_loss_fns):
-                if loss_fn is losses.sparse_categorical_crossentropy:
+                if ((isinstance(loss_fn, losses.LossFunctionWrapper) and
+                     loss_fn.fn == losses.sparse_categorical_crossentropy)) or (
+                         isinstance(
+                             loss_fn, losses.SparseCategoricalCrossentropy)):
                     if K.image_data_format() == 'channels_first' and len(
                             output_shape) in [4, 5]:
                         feed_output_shapes.append(
                             (output_shape[0], 1) + output_shape[2:])
                     else:
                         feed_output_shapes.append(output_shape[:-1] + (1,))
-                elif (not hasattr(loss_fn, '__name__') or
-                      getattr(losses, loss_fn.__name__, None) is None):
-                    # If `loss_fn` is not a function (e.g. callable class)
-                    # or if it not in the `losses` module, then
-                    # it is a user-defined loss and we make no assumptions
-                    # about it.
+                elif (not isinstance(loss_fn, losses.Loss) or
+                      (isinstance(loss_fn, losses.LossFunctionWrapper) and
+                       (getattr(losses, loss_fn.fn.__name__, None) is None))):
+                    # If the given loss is not an instance of the `Loss` class
+                    # (custom class) or if the loss function that is wrapped is
+                    # not in the `losses` module, then it is a user-defined loss
+                    # and we make no assumptions about it.
                     feed_output_shapes.append(None)
                 else:
                     feed_output_shapes.append(output_shape)
 
             # Standardize the outputs.
-            y = standardize_input_data(
+            y = training_utils.standardize_input_data(
                 y,
                 feed_output_names,
                 feed_output_shapes,
@@ -797,23 +636,23 @@ def _standardize_user_data(self, x,
 
             # Generate sample-wise weight values given the `sample_weight` and
             # `class_weight` arguments.
-            sample_weights = standardize_sample_weights(
+            sample_weights = training_utils.standardize_sample_weights(
                 sample_weight, feed_output_names)
-            class_weights = standardize_class_weights(
+            class_weights = training_utils.standardize_class_weights(
                 class_weight, feed_output_names)
             sample_weights = [
-                standardize_weights(ref, sw, cw, mode)
+                training_utils.standardize_weights(ref, sw, cw, mode)
                 for (ref, sw, cw, mode) in
                 zip(y, sample_weights, class_weights,
                     feed_sample_weight_modes)
             ]
             # Check that all arrays have the same length.
             if check_array_lengths:
-                check_array_length_consistency(x, y, sample_weights)
+                training_utils.check_array_length_consistency(x, y, sample_weights)
             if self._is_graph_network:
                 # Additional checks to avoid users mistakenly
                 # using improper loss fns.
-                check_loss_and_target_compatibility(
+                training_utils.check_loss_and_target_compatibility(
                     y, self._feed_loss_fns, feed_output_shapes)
         else:
             y = []
@@ -830,6 +669,65 @@ def _standardize_user_data(self, x,
                                  str(x[0].shape[0]) + ' samples')
         return x, y, sample_weights
 
+    def _prepare_total_loss(self, skip_target_indices=None, masks=None):
+        """Computes total loss from loss functions.
+
+        # Arguments
+            skip_target_indices: A list of indices of model outputs where loss
+                function is None.
+            masks: List of mask values corresponding to each model output.
+
+        # Returns
+            Total loss tensor.
+        """
+        skip_target_indices = skip_target_indices or []
+        total_loss = None
+        with K.name_scope('loss'):
+            zipped_inputs = zip(self.targets, self.outputs, self.loss_functions,
+                                self.sample_weights, masks, self.loss_weights_list)
+            for i, (y_true, y_pred, loss_fn, sample_weight, mask,
+                    loss_weight) in enumerate(zipped_inputs):
+                if i in skip_target_indices:
+                    continue
+                loss_name = self.output_names[i] + '_loss'
+                with K.name_scope(loss_name):
+                    if mask is not None:
+                        mask = K.cast(mask, K.dtype(y_pred))
+                        # Update weights with mask.
+                        if sample_weight is None:
+                            sample_weight = mask
+                        else:
+                            # Update dimensions of weights to match with mask.
+                            mask, _, sample_weight = (
+                                losses_utils.squeeze_or_expand_dimensions(
+                                    mask, None, sample_weight))
+                            sample_weight *= mask
+
+                    output_loss = loss_fn(
+                        y_true, y_pred, sample_weight=sample_weight)
+
+                if len(self.outputs) > 1:
+                    self.metrics_tensors.append(output_loss)
+                    self.metrics_names.append(self.output_names[i] + '_loss')
+
+                if total_loss is None:
+                    total_loss = loss_weight * output_loss
+                else:
+                    total_loss += loss_weight * output_loss
+
+            if total_loss is None:
+                if not self.losses:
+                    raise ValueError('The model cannot be compiled '
+                                     'because it has no loss to optimize.')
+                else:
+                    total_loss = 0.
+
+            # Add regularization penalties and other layer-specific losses.
+            for loss_tensor in self.losses:
+                total_loss += loss_tensor
+
+        return total_loss
+
     def _get_callback_model(self):
         """Returns the Callback Model for this Model."""
         if hasattr(self, 'callback_model') and self.callback_model:
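The mask handling in `_prepare_total_loss` folds an output mask into the sample weights before calling the loss object. A numpy sketch of that folding, under the assumption that mask and weights are per-sample 1-D arrays:

    import numpy as np

    def fold_mask_into_weights(mask, sample_weight):
        # Mirrors the branch above: the mask becomes the weight when no
        # sample_weight is given, otherwise the two are multiplied.
        if mask is None:
            return sample_weight
        mask = mask.astype('float32')
        if sample_weight is None:
            return mask
        return sample_weight * mask

    w = fold_mask_into_weights(np.array([1, 1, 0]), np.array([0.5, 2.0, 3.0]))
    # -> [0.5, 2.0, 0.0]; the masked sample no longer contributes.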
""" - if batch_size is not None and is_generator_or_sequence(x): + if batch_size is not None and training_utils.is_generator_or_sequence(x): raise ValueError('The `batch_size` argument must not be specified when' ' using a generator or Sequence as an input.') layers = super(Model, self).layers # Avoids the override in Sequential. if layers: first_layer = layers[0] - static_batch_size = get_static_batch_size(first_layer) + static_batch_size = training_utils.get_static_batch_size(first_layer) if static_batch_size is not None: # Check `batch_size` argument is consistent with InputLayer. @@ -888,6 +786,24 @@ def _validate_or_infer_batch_size(self, batch_size, steps, x): batch_size = 32 return batch_size + def _set_sample_weight_attributes(self, sample_weight_mode, + skip_target_weighing_indices): + """Sets sample weight related attributes on the model.""" + sample_weights, sample_weight_modes = training_utils.prepare_sample_weights( + self.output_names, sample_weight_mode, skip_target_weighing_indices) + self.sample_weights = sample_weights + self.sample_weight_modes = sample_weight_modes + self._feed_sample_weight_modes = [ + sample_weight_modes[i] + for i in range(len(self.outputs)) + if i not in skip_target_weighing_indices + ] + self._feed_sample_weights = [ + sample_weights[i] + for i in range(len(sample_weights)) + if i not in skip_target_weighing_indices + ] + def fit(self, x=None, y=None, @@ -1063,8 +979,8 @@ def fit(self, # Case 1: generator-like. Input is Python generator, # or Sequence object, or iterator. - if is_generator_or_sequence(x): - check_generator_arguments( + if training_utils.is_generator_or_sequence(x): + training_utils.check_generator_arguments( y, sample_weight, validation_split=validation_split) return self.fit_generator( x, @@ -1266,8 +1182,8 @@ def evaluate(self, batch_size = self._validate_or_infer_batch_size(batch_size, steps, x) # Case 1: generator-like. Input is Python generator, or Sequence object. - if is_generator_or_sequence(x): - check_generator_arguments(y, sample_weight) + if training_utils.is_generator_or_sequence(x): + training_utils.check_generator_arguments(y, sample_weight) return self.evaluate_generator( x, steps=steps, @@ -1362,7 +1278,7 @@ def predict(self, x, batch_size = self._validate_or_infer_batch_size(batch_size, steps, x) # Case 1: generator-like. Input is Python generator, or Sequence object. - if is_generator_or_sequence(x): + if training_utils.is_generator_or_sequence(x): return self.predict_generator( x, steps=steps, diff --git a/keras/engine/training_utils.py b/keras/engine/training_utils.py index a57217c2628..5135784bedb 100644 --- a/keras/engine/training_utils.py +++ b/keras/engine/training_utils.py @@ -8,12 +8,14 @@ import collections import copy import numpy as np +import six import warnings from .. import backend as K from .. import losses from ..utils import Sequence -from ..utils.generic_utils import to_list +from ..utils import generic_utils +from ..utils import losses_utils def standardize_single_array(x): @@ -252,7 +254,8 @@ def set_of_lengths(x): def check_loss_and_target_compatibility(targets, loss_fns, output_shapes): """Does validation on the compatibility of targets and loss functions. - This helps prevent users from using loss functions incorrectly. + This helps prevent users from using loss functions incorrectly. This check + is purely for UX purposes. # Arguments targets: list of Numpy arrays of targets. 
@@ -263,13 +266,16 @@
         ValueError: if a loss function or target array
             is incompatible with an output.
     """
-    key_losses = {losses.mean_squared_error,
-                  losses.binary_crossentropy,
-                  losses.categorical_crossentropy}
+    key_loss_fns = {
+        losses.mean_squared_error, losses.binary_crossentropy,
+        losses.categorical_crossentropy
+    }
+    key_loss_classes = (losses.MeanSquaredError, losses.BinaryCrossentropy,
+                        losses.CategoricalCrossentropy)
     for y, loss, shape in zip(targets, loss_fns, output_shapes):
         if y is None or loss is None:
             continue
-        if loss is losses.categorical_crossentropy:
+        if losses.is_categorical_crossentropy(loss):
             if y.shape[-1] == 1:
                 raise ValueError(
                     'You are passing a target array of shape ' + str(y.shape) +
@@ -287,15 +293,20 @@ def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
                     'Alternatively, you can use the loss function '
                     '`sparse_categorical_crossentropy` instead, '
                     'which does expect integer targets.')
-        if loss in key_losses:
+        is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper)
+        if (isinstance(loss, key_loss_classes) or (is_loss_wrapper and
+                                                   (loss.fn in key_loss_fns))):
             for target_dim, out_dim in zip(y.shape[1:], shape[1:]):
                 if out_dim is not None and target_dim != out_dim:
+                    loss_name = loss.name
+                    if loss_name is None:
+                        loss_type = loss.fn if is_loss_wrapper else type(loss)
+                        loss_name = loss_type.__name__
                     raise ValueError(
                         'A target array with shape ' + str(y.shape) +
                         ' was passed for an output of shape ' + str(shape) +
-                        ' while using as loss `' + loss.__name__ + '`. '
-                        'This loss expects '
-                        'targets to have the same shape '
+                        ' while using as loss `' + loss_name + '`. '
+                        'This loss expects targets to have the same shape '
                         'as the output.')
@@ -347,7 +358,7 @@ def collect_metrics(metrics, output_names):
                 .format(unknown_output_names, output_names))
         for name in output_names:
             output_metrics = metrics.get(name, [])
-            output_metrics = to_list(output_metrics)
+            output_metrics = generic_utils.to_list(output_metrics)
             nested_metrics.append(output_metrics)
         return nested_metrics
     else:
@@ -713,3 +724,189 @@ def _is_graph_model(layer):
     if hasattr(layer, '_batch_input_shape'):
         return layer._batch_input_shape, layer.dtype
     return None, None
+
+
+def get_loss_function(loss):
+    """Returns the loss corresponding to the loss input in `compile` API."""
+    if loss is None or isinstance(loss, losses.Loss):
+        return loss
+
+    # Deserialize loss configuration, if needed.
+    if isinstance(loss, collections.Mapping):
+        loss = losses.get(loss)
+
+    # Custom callable class.
+    if callable(loss) and not hasattr(loss, '__name__'):
+        return loss
+
+    # Wrap loss function with signature `(y_true, y_pred, **kwargs)`
+    # in `LossFunctionWrapper` class.
+    loss_fn = losses.get(loss)
+
+    # For losses which are given as strings/functions in the compile API,
+    # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE`.
+    return losses.LossFunctionWrapper(
+        loss_fn,
+        name=loss_fn.__name__,
+        reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE)
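`get_loss_function` funnels every accepted `loss` value into something callable as `fn(y_true, y_pred, sample_weight=...)`. A sketch of the dispatch, consistent with the asserts in the new `test_loss_wrapper` test further down:

    # 1. None or a `Loss` instance: returned unchanged.
    # 2. A serialization dict: deserialized via `losses.get` first.
    # 3. A custom callable class (no `__name__`): returned unchanged.
    # 4. A string or plain function: wrapped in `LossFunctionWrapper`.
    loss_fn = get_loss_function('mse')
    assert loss_fn.name == 'mean_squared_error'
    assert loss_fn.reduction == losses_utils.Reduction.SUM_OVER_BATCH_SIZE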
+
+
+def get_output_sample_weight_and_mode(skip_target_weighing_indices,
+                                      sample_weight_mode, output_name,
+                                      output_index):
+    """Returns the sample weight and weight mode for a single output."""
+    if output_index in skip_target_weighing_indices:
+        return None, None
+
+    if sample_weight_mode == 'temporal':
+        shape = [None, None]
+        mode = 'temporal'
+    else:
+        shape = [None]
+        mode = None
+    weight = K.placeholder(
+        shape=shape,
+        name=output_name + '_sample_weights')
+    return weight, mode
+
+
+def prepare_sample_weights(output_names, sample_weight_mode,
+                           skip_target_weighing_indices):
+    """Prepares sample weights for the model.
+
+    # Arguments
+        output_names: List of model output names.
+        sample_weight_mode: sample weight mode user input passed from compile API.
+        skip_target_weighing_indices: Indices of output for which sample weights
+            should be skipped.
+
+    # Returns
+        A pair of list of sample weights and sample weight modes
+            (one for each output).
+
+    # Raises
+        ValueError: In case of invalid `sample_weight_mode` input.
+    """
+    sample_weights = []
+    sample_weight_modes = []
+    if isinstance(sample_weight_mode, dict):
+        unknown_output = set(sample_weight_mode.keys()) - set(output_names)
+        if unknown_output:
+            raise ValueError(
+                'Unknown entry in '
+                'sample_weight_mode dictionary: "' + str(unknown_output) +
+                '". Only expected the following keys: ' + str(output_names))
+        for i, name in enumerate(output_names):
+            if (i not in skip_target_weighing_indices and
+                    name not in sample_weight_mode):
+                raise ValueError(
+                    'Output "' + name + '" missing from sample_weight_modes '
+                    'dictionary')
+            weight, mode = get_output_sample_weight_and_mode(
+                skip_target_weighing_indices,
+                sample_weight_mode.get(name),
+                name,
+                i)
+            sample_weights.append(weight)
+            sample_weight_modes.append(mode)
+    elif isinstance(sample_weight_mode, list):
+        if len(sample_weight_mode) != len(output_names):
+            raise ValueError('When passing a list as sample_weight_mode, '
+                             'it should have one entry per model output. '
+                             'The model has ' + str(len(output_names)) +
+                             ' outputs, but you passed ' +
+                             str(len(sample_weight_mode)) + ' sample_weight_modes')
+        for i, name in enumerate(output_names):
+            weight, mode = get_output_sample_weight_and_mode(
+                skip_target_weighing_indices, sample_weight_mode[i], name, i)
+            sample_weights.append(weight)
+            sample_weight_modes.append(mode)
+    else:
+        for i, name in enumerate(output_names):
+            weight, mode = get_output_sample_weight_and_mode(
+                skip_target_weighing_indices, sample_weight_mode, name, i)
+            sample_weights.append(weight)
+            sample_weight_modes.append(mode)
+    return sample_weights, sample_weight_modes
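For a concrete picture of what `prepare_sample_weights` returns, a hedged sketch for a two-output model (output names hypothetical):

    weights, modes = prepare_sample_weights(
        output_names=['main', 'aux'],
        sample_weight_mode={'main': 'temporal', 'aux': None},
        skip_target_weighing_indices=[])
    # weights[0]: placeholder of shape (None, None); modes[0] == 'temporal'
    # weights[1]: placeholder of shape (None,); modes[1] is None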
+ """ + if isinstance(loss, collections.Mapping): + generic_utils.check_for_unexpected_keys('loss', loss, output_names) + loss_functions = [] + for name in output_names: + if name not in loss: + warnings.warn( + 'Output {0} missing from loss dictionary. We assume ' + 'this was done on purpose. The fit and evaluate APIs will not ' + 'be expecting any data to be passed to {0}.'.format(name)) + loss_functions.append(get_loss_function(loss.get(name, None))) + elif isinstance(loss, six.string_types): + loss_functions = [get_loss_function(loss) for _ in output_names] + elif isinstance(loss, collections.Sequence): + if len(loss) != len(output_names): + raise ValueError('When passing a list as loss, it should have one entry ' + 'per model outputs. The model has {} outputs, but you ' + 'passed loss={}'.format(len(output_names), loss)) + loss_functions = [get_loss_function(l) for l in loss] + else: + loss_functions = [get_loss_function(loss) for _ in range(len(output_names))] + + return loss_functions + + +def prepare_loss_weights(output_names, loss_weights=None): + """Converts loss weights to a list of loss weights. + + # Arguments + output_names: List of model output names. + loss_weights: Optional list or dictionary specifying scalar coefficients + (Python floats) to weight the loss contributions of different model + outputs. The loss value that will be minimized by the model will then be + the *weighted sum* of all individual losses, weighted by the + `loss_weights` coefficients. If a list, it is expected to have a 1:1 + mapping to the model's outputs. If a dict, it is expected to map + output names (strings) to scalar coefficients. + + # Returns + A list of loss weights of python floats. + + # Raises + ValueError: If loss weight is a dict with key not in model output names, + or if loss is a list with len not equal to model outputs. + """ + if loss_weights is None: + weights_list = [1.] * len(output_names) + elif isinstance(loss_weights, collections.Mapping): + generic_utils.check_for_unexpected_keys('loss_weights', loss_weights, + output_names) + weights_list = [loss_weights.get(name, 1.) for name in output_names] + elif isinstance(loss_weights, list): + if len(loss_weights) != len(output_names): + raise ValueError('When passing a list as loss_weights, ' + 'it should have one entry per model output. 
+
+
+def prepare_loss_weights(output_names, loss_weights=None):
+    """Converts loss weights to a list of loss weights.
+
+    # Arguments
+        output_names: List of model output names.
+        loss_weights: Optional list or dictionary specifying scalar coefficients
+            (Python floats) to weight the loss contributions of different model
+            outputs. The loss value that will be minimized by the model will then
+            be the *weighted sum* of all individual losses, weighted by the
+            `loss_weights` coefficients. If a list, it is expected to have a 1:1
+            mapping to the model's outputs. If a dict, it is expected to map
+            output names (strings) to scalar coefficients.
+
+    # Returns
+        A list of loss weights of python floats.
+
+    # Raises
+        ValueError: If loss weight is a dict with key not in model output names,
+            or if loss is a list with len not equal to model outputs.
+    """
+    if loss_weights is None:
+        weights_list = [1.] * len(output_names)
+    elif isinstance(loss_weights, collections.Mapping):
+        generic_utils.check_for_unexpected_keys('loss_weights', loss_weights,
+                                                output_names)
+        weights_list = [loss_weights.get(name, 1.) for name in output_names]
+    elif isinstance(loss_weights, list):
+        if len(loss_weights) != len(output_names):
+            raise ValueError('When passing a list as loss_weights, '
+                             'it should have one entry per model output. '
+                             'The model has ' + str(len(output_names)) +
+                             ' outputs, but you passed loss_weights=' +
+                             str(loss_weights))
+        weights_list = loss_weights
+    else:
+        raise TypeError('Could not interpret loss_weights argument: ' +
+                        str(loss_weights) + ' - expected a list or dict.')
+
+    return weights_list
diff --git a/keras/losses.py b/keras/losses.py
index 633c2d77057..0b47a3735e3 100644
--- a/keras/losses.py
+++ b/keras/losses.py
@@ -557,12 +557,12 @@ def cosine_proximity(y_true, y_pred):
 
 
 def is_categorical_crossentropy(loss):
-    return (isinstance(loss, CategoricalCrossentropy or
+    return (isinstance(loss, CategoricalCrossentropy) or
             (isinstance(loss, LossFunctionWrapper) and
              loss.fn == categorical_crossentropy) or
             (hasattr(loss, '__name__') and
              loss.__name__ == 'categorical_crossentropy') or
-            loss == 'categorical_crossentropy'))
+            loss == 'categorical_crossentropy')
 
 
 def serialize(loss):
diff --git a/keras/utils/generic_utils.py b/keras/utils/generic_utils.py
index 0c7ae06b892..9b5118d5a91 100644
--- a/keras/utils/generic_utils.py
+++ b/keras/utils/generic_utils.py
@@ -613,3 +613,11 @@ def transpose_shape(shape, target_format, spatial_axes):
     raise ValueError('The `data_format` argument must be one of '
                      '"channels_first", "channels_last". Received: ' +
                      str(target_format))
+
+
+def check_for_unexpected_keys(name, input_dict, expected_values):
+    unknown = set(input_dict.keys()).difference(expected_values)
+    if unknown:
+        raise ValueError('Unknown entries in {} dictionary: {}. Only expected '
+                         'the following keys: {}'.format(name, list(unknown),
+                                                         expected_values))
diff --git a/keras/utils/losses_utils.py b/keras/utils/losses_utils.py
index 9d037d3deb8..102f6158710 100644
--- a/keras/utils/losses_utils.py
+++ b/keras/utils/losses_utils.py
@@ -128,7 +128,38 @@ def compute_weighted_loss(losses,
         losses, _, sample_weight = squeeze_or_expand_dimensions(
             losses, None, sample_weight)
 
-    weighted_losses = losses * sample_weight
+    # Broadcast weights if possible.
+    weights_shape = K.int_shape(sample_weight)
+    losses_shape = K.int_shape(losses)
+    if losses_shape != weights_shape:
+        weights_rank = K.ndim(sample_weight)
+        losses_rank = K.ndim(losses)
+
+        # Raise error if ndim of weights is > losses.
+        if weights_rank > losses_rank:
+            raise ValueError(
+                'Incompatible shapes: `losses` {} vs `sample_weight` {}'.format(
+                    losses_shape, weights_shape))
+
+        # Expand dim of weights to match ndim of losses, if required.
+        for i in range(weights_rank, losses_rank):
+            sample_weight = K.expand_dims(sample_weight, axis=i)
+
+        for i in range(weights_rank):
+            if (weights_shape[i] is not None and losses_shape[i] is not None and
+                    weights_shape[i] != losses_shape[i]):
+                # Cannot be broadcasted.
+                if weights_shape[i] != 1:
+                    raise ValueError(
+                        'Incompatible shapes: `losses` {} vs '
+                        '`sample_weight` {}'.format(
+                            losses_shape, weights_shape))
+                sample_weight = K.repeat_elements(
+                    sample_weight, losses_shape[i], axis=i)
+
+    # Apply weights to losses.
+    weighted_losses = sample_weight * losses
+
     # Apply reduction function to the individual weighted losses.
     loss = reduce_weighted_loss(weighted_losses, reduction)
     # Convert the result back to the input type.
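The broadcasting block above replaces a bare elementwise multiply so that per-sample weights of lower rank than the losses still apply; `K.repeat_elements` does the tiling explicitly, presumably because not every Keras backend broadcasts implicitly. The same idea in plain numpy (sketch only):

    import numpy as np

    losses_arr = np.ones((2, 3))          # e.g. per-timestep losses
    sample_weight = np.array([1.2, 0.5])  # rank 1 < rank 2

    # Expand trailing dims: (2,) -> (2, 1); the size-1 axis then
    # broadcasts (numpy) or is tiled to 3 (K.repeat_elements).
    while sample_weight.ndim < losses_arr.ndim:
        sample_weight = np.expand_dims(sample_weight, -1)
    weighted = sample_weight * losses_arr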
diff --git a/keras/wrappers/scikit_learn.py b/keras/wrappers/scikit_learn.py
index 83c3e3c44b3..ee510c4e396 100644
--- a/keras/wrappers/scikit_learn.py
+++ b/keras/wrappers/scikit_learn.py
@@ -9,6 +9,7 @@
 
 import numpy as np
 
+from .. import losses
 from ..utils.np_utils import to_categorical
 from ..utils.generic_utils import has_arg
 from ..utils.generic_utils import to_list
@@ -140,10 +141,8 @@ def fit(self, x, y, **kwargs):
         else:
             self.model = self.build_fn(**self.filter_sk_params(self.build_fn))
 
-        loss_name = self.model.loss
-        if hasattr(loss_name, '__name__'):
-            loss_name = loss_name.__name__
-        if loss_name == 'categorical_crossentropy' and len(y.shape) != 2:
+        if (losses.is_categorical_crossentropy(self.model.loss) and
+                len(y.shape) != 2):
             y = to_categorical(y)
 
         fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit))
diff --git a/tests/keras/engine/test_training.py b/tests/keras/engine/test_training.py
index b4bc2e46942..2b224c3e070 100644
--- a/tests/keras/engine/test_training.py
+++ b/tests/keras/engine/test_training.py
@@ -10,7 +10,7 @@
 
 import keras
 from keras import losses
-from keras.layers import Activation, Dense, Dropout, Conv2D, Concatenate
+from keras.layers import Layer, Activation, Dense, Dropout, Conv2D, Concatenate
 from keras.engine import Input
 from keras.engine.training import Model
 from keras.engine import training_utils
@@ -665,6 +665,34 @@ def expected_shape(batch_size, n_batches):
     assert np.shape(out) == shape_0
 
 
+def test_training_with_loss_instance():
+    a = Input(shape=(3,), name='input_a')
+    b = Input(shape=(3,), name='input_b')
+
+    dense = Dense(4, name='dense')
+    c = dense(a)
+    d = dense(b)
+    e = Dropout(0.5, name='dropout')(c)
+
+    model = Model([a, b], [d, e])
+    loss_weights = [1., 0.5]
+    model.compile(
+        'sgd',
+        loss=losses.MeanSquaredError(),
+        metrics=['mae'],
+        loss_weights=loss_weights)
+
+    input_a_np = np.random.random((10, 3))
+    input_b_np = np.random.random((10, 3))
+
+    output_d_np = np.random.random((10, 4))
+    output_e_np = np.random.random((10, 4))
+
+    model.fit([input_a_np, input_b_np], [output_d_np, output_e_np],
+              epochs=1,
+              batch_size=5)
+
+
 @pytest.mark.skipif(sys.version_info < (3,),
                     reason='Cannot catch warnings in python 2')
 def test_warnings():
@@ -788,7 +816,7 @@ def test_check_last_is_one():
     a = np.random.random((2, 3, 1))
     with pytest.raises(ValueError) as exc:
         training_utils.check_loss_and_target_compatibility(
-            [a], [losses.categorical_crossentropy], [a.shape])
+            [a], [losses.CategoricalCrossentropy()], [a.shape])
 
     assert 'You are passing a target array' in str(exc)
 
@@ -797,7 +825,7 @@ def test_check_bad_shape():
     a = np.random.random((2, 3, 5))
     with pytest.raises(ValueError) as exc:
         training_utils.check_loss_and_target_compatibility(
-            [a], [losses.categorical_crossentropy], [(2, 3, 6)])
+            [a], [losses.CategoricalCrossentropy()], [(2, 3, 6)])
 
     assert 'targets to have the same shape' in str(exc)
 
@@ -1723,5 +1751,27 @@ def on_test_begin(self, logs=None):
     assert val_counter.val_runs == 3
 
 
+def test_loss_correctness():
+    class Bias(Layer):
+
+        def build(self, input_shape):
+            self.bias = self.add_weight('bias', (1,), initializer='zeros')
+
+        def call(self, inputs):
+            return inputs + self.bias
+
+    inp = Input(shape=(1,))
+    out = Bias()(inp)
+    model = Model(inp, out)
+    model.compile(
+        keras.optimizers.SGD(lr=0.1),
+        loss=keras.losses.MeanAbsoluteError())
+
+    x = np.array([[0.], [1.], [2.]])
+    y = np.array([[0.5], [2.], [3.5]])
+    history = model.fit(x, y, batch_size=3, epochs=5)
+    assert np.allclose(history.history['loss'], [1., 0.9, 0.8, 0.7, 0.6])
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
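The expected history in `test_loss_correctness` can be checked by hand: with the bias at zero, MAE = mean(|[0.5, 1.0, 1.5] - b|) = 1.0 - b for b <= 0.5, its gradient with respect to the bias is -1, so SGD with lr=0.1 raises the bias by 0.1 per epoch, and each epoch logs the pre-update loss:

    b, lr, residuals = 0.0, 0.1, [0.5, 1.0, 1.5]
    for _ in range(5):
        print(sum(abs(r - b) for r in residuals) / 3)  # 1.0, 0.9, 0.8, 0.7, 0.6
        b += lr  # d(MAE)/db = -1 while b < min(residuals)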
diff --git a/tests/keras/losses_test.py b/tests/keras/losses_test.py
index a4202e74ec9..86f20fdb7ba 100644
--- a/tests/keras/losses_test.py
+++ b/tests/keras/losses_test.py
@@ -28,7 +28,7 @@ class MSE_MAE_loss:
     def __init__(self, mse_fraction):
         self.mse_fraction = mse_fraction
 
-    def __call__(self, y_true, y_pred):
+    def __call__(self, y_true, y_pred, sample_weight=None):
         return (self.mse_fraction * losses.mse(y_true, y_pred) +
                 (1 - self.mse_fraction) * losses.mae(y_true, y_pred))
 
@@ -124,6 +124,24 @@ def test_serializing_model_with_loss_class(self, tmpdir):
         loaded_model = keras.models.load_model(model_filename)
         loaded_model.predict(np.random.rand(128, 2))
 
+    def test_loss_wrapper(self):
+        loss_fn = losses.get('mse')
+        mse_obj = losses.LossFunctionWrapper(loss_fn, name=loss_fn.__name__)
+
+        assert mse_obj.name == 'mean_squared_error'
+        assert (mse_obj.reduction ==
+                losses_utils.Reduction.SUM_OVER_BATCH_SIZE)
+
+        y_true = K.constant([[1., 9.], [2., 5.]])
+        y_pred = K.constant([[4., 8.], [12., 3.]])
+        sample_weight = K.constant([1.2, 0.5])
+        loss = mse_obj(y_true, y_pred, sample_weight=sample_weight)
+
+        # mse = [((4 - 1)^2 + (8 - 9)^2) / 2, ((12 - 2)^2 + (3 - 5)^2) / 2]
+        #     = [5, 52]
+        # weighted_mse = [5 * 1.2, 52 * 0.5] = [6, 26]
+        # reduced_weighted_mse = (6 + 26) / 2 = 16
+        assert np.allclose(K.eval(loss), 16, atol=1e-2)
+
 
 class TestMeanSquaredError:
diff --git a/tests/test_model_saving.py b/tests/test_model_saving.py
index a404112056b..84215c691c7 100644
--- a/tests/test_model_saving.py
+++ b/tests/test_model_saving.py
@@ -38,7 +38,7 @@ def test_sequential_model_saving():
     model.add(Dense(2, input_shape=(3,)))
     model.add(RepeatVector(3))
     model.add(TimeDistributed(Dense(3)))
-    model.compile(loss=losses.MSE,
+    model.compile(loss=losses.MeanSquaredError(),
                   optimizer=optimizers.RMSprop(lr=0.0001),
                   metrics=[metrics.categorical_accuracy],
                   sample_weight_mode='temporal')
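End to end, the patch enables the following flow; a minimal usage sketch (file name arbitrary, saving requires h5py):

    import numpy as np
    from keras import losses, models
    from keras.layers import Dense, Input

    inp = Input(shape=(3,))
    model = models.Model(inp, Dense(1)(inp))
    model.compile('sgd', loss=losses.MeanSquaredError())
    model.fit(np.random.random((8, 3)), np.random.random((8, 1)), epochs=1)

    # The `Loss` instance now survives a save/load round trip.
    model.save('model_with_loss_instance.h5')
    reloaded = models.load_model('model_with_loss_instance.h5')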