Merged
guides/making_new_layers_and_models_via_subclassing.py: 5 changes (4 additions & 1 deletion)
@@ -236,11 +236,14 @@ def __init__(self, rate=1e-2):
         self.rate = rate
 
     def call(self, inputs):
-        self.add_loss(self.rate * tf.reduce_sum(inputs))
+        self.add_loss(self.rate * tf.reduce_mean(inputs))
         return inputs
 
 
 """
+Notice that `add_loss()` can take the result of plain TensorFlow operations.
+There is no need to call a `Loss` object here.
Contributor
But I assume `Loss` objects can be used as well; is this correct? The wording makes it sound like their use is discouraged.

Contributor Author

My motivation for this change stems from design discussions with colleagues who had the impression that subclassing `Loss` would be the natural way to encapsulate the custom loss functions we planned to use with `add_loss()`.

So let's consider what the `Loss` class adds on top of a simple tensor-to-tensor function from `(y_true, y_pred)` to a total loss (a minimal sketch follows the list):

  1. A split between init-time hparams and call-time inputs, providing a parametrized callback for the training loop.
  2. Serialization and deserialization as part of a Keras `Model` config.
  3. Wrapping of a user-level `call()` function with automated application of per-example weights...
  4. ...and reduction across examples.
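To make those four items concrete, here is a minimal sketch of a `Loss` subclass; `MarginLoss` and its `margin` hparam are invented purely for illustration:

```python
import tensorflow as tf

class MarginLoss(tf.keras.losses.Loss):
    """Hypothetical hinge-style loss with a configurable margin."""

    def __init__(self, margin=1.0, **kwargs):
        super().__init__(**kwargs)  # item 1: init-time hparams...
        self.margin = margin        # ...separated from call-time inputs

    def call(self, y_true, y_pred):
        # items 3 and 4: the base class wraps this per-example result with
        # sample-weight application and the configured reduction.
        return tf.maximum(0.0, self.margin - y_true * y_pred)

    def get_config(self):
        # item 2: (de)serialization as part of a Keras Model config.
        return {**super().get_config(), "margin": self.margin}
```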

For use in `Model.compile(loss=...)`, all of this makes sense.

For use with `Layer.add_loss(...)`, not so much:

  1. `add_loss()` only takes the final result; there is no need to pass a parametrized callback.
  2. `add_loss()` does not track a `Loss` object for serialization as such.
  3. Example weights are typically not available within the model definition.
  4. The one type of reduction that makes sense for `add_loss()` is `SUM_OVER_BATCH_SIZE`, and that will unexpectedly blow up in your face the moment you try to move the `Model` under a distribution strategy (see the sketch below).
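For comparison, the plain-tensor pattern from this PR sidesteps all four points; the reduction is an explicit `tf.reduce_mean()` in user code rather than hidden `Loss` machinery:

```python
import tensorflow as tf

class ActivityRegularizationLayer(tf.keras.layers.Layer):
    """The layer from the diff above: penalizes large activations."""

    def __init__(self, rate=1e-2):
        super().__init__()
        self.rate = rate

    def call(self, inputs):
        # Plain TensorFlow ops; no Loss object, no implicit reduction.
        self.add_loss(self.rate * tf.reduce_mean(inputs))
        return inputs
```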

So the similarity in name not only pointed us toward an abstraction that was overly complex (items 1-3); it actually hurt us as soon as we wanted to leverage TensorFlow's key feature, scaling up with distributed training (item 4).

Bottom line: yes, I think it is appropriate to gently discourage uncritical use of `Loss` objects with `add_loss()`. I would have liked to leave a stronger hint at item 4 (see the first commit of this PR), but I couldn't find a good link target to provide the necessary context.

Contributor Author

Friendly ping... I think the proposed wording is completely correct and useful guidance for the user: while calling a `Loss` object is possible with some care, a plain tensor is often more appropriate (more often than the similar names suggest).
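As a sketch of what "with some care" could look like (the layer and its zero target are invented for illustration): pick `Reduction.NONE` and reduce explicitly, instead of relying on the default `SUM_OVER_BATCH_SIZE`:

```python
import tensorflow as tf

class ReconstructionPenaltyLayer(tf.keras.layers.Layer):
    """Hypothetical layer that adds a loss computed by a Loss object."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Reduction.NONE returns per-example values; we reduce ourselves,
        # so nothing depends implicitly on the per-replica batch size.
        self.mse = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE
        )

    def call(self, inputs):
        targets = tf.zeros_like(inputs)  # placeholder target for the sketch
        per_example = self.mse(targets, inputs)
        self.add_loss(tf.reduce_mean(per_example))  # explicit reduction
        return inputs
```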

Contributor
I agree. Especially given that the majority of the `Loss` classes have corresponding loss functions, there really isn't a need to complicate the workflow with a class wrapper. It also doesn't seem appropriate to let a model builder easily make the mistake of using a reduction method that doesn't work with distribution strategies, so discouraging that is an option too. In any case, the way it's stated here LGTM.

Thanks again for the change, and for your patience!


These losses (including those created by any inner layer) can be retrieved via
`layer.losses`. This property is reset at the start of every `__call__()` to
the top-level layer, so that `layer.losses` always contains the loss values
created during the last forward pass.
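A quick check of that reset behavior, reusing the `ActivityRegularizationLayer` from the diff above (the input shape is arbitrary):

```python
import tensorflow as tf

layer = ActivityRegularizationLayer(rate=1e-2)
assert len(layer.losses) == 0  # never called, so nothing recorded yet

_ = layer(tf.ones((2, 3)))
assert len(layer.losses) == 1  # one loss value from this forward pass

_ = layer(tf.ones((2, 3)))
assert len(layer.losses) == 1  # reset at the start of each __call__
```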