Align subclassing guide with docstring of Layer.add_loss() #1144
Conversation
This aligns the guide with the docstring of `Layer.add_loss()`, which uses `reduce_mean()`, not `reduce_sum()`, to achieve independence of batch size, consistent with Keras' default loss reduction of `SUM_OVER_BATCH_SIZE`. Along the way, it clarifies the difference between `add_loss()` and `Loss`.
Addresses review comment.
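For reference, the change to the guide's example layer boils down to something like this (a sketch along the lines of the guide's `ActivityRegularizationLayer`, not the literal diff):

```python
import tensorflow as tf

class ActivityRegularizationLayer(tf.keras.layers.Layer):
    """Illustrative layer that penalizes large activations via add_loss()."""

    def __init__(self, rate=1e-2):
        super().__init__()
        self.rate = rate

    def call(self, inputs):
        # reduce_mean() keeps the penalty independent of the batch size,
        # matching Keras' default SUM_OVER_BATCH_SIZE loss reduction;
        # reduce_sum() would grow linearly with the number of examples.
        self.add_loss(self.rate * tf.reduce_mean(inputs))
        return inputs
```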
The CI failure looks unrelated:
| """ | ||
| Notice that `add_loss()` can take the result of plain TensorFlow operations. | ||
| There is no need to call a `Loss` object here. |
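To illustrate the quoted guidance, a small usage sketch (reusing the illustrative layer above; shapes and layer sizes are arbitrary): the result of a plain TensorFlow op passed to `add_loss()` simply shows up in `model.losses` and is added to the main loss during training.

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
x = tf.keras.layers.Dense(8, activation="relu")(inputs)
x = ActivityRegularizationLayer(rate=1e-2)(x)  # layer from the sketch above
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)

model(tf.ones((3, 4)))  # run once so the loss tensor is created
print(model.losses)     # one scalar tensor per add_loss() call
```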
But I assume `Loss` objects can be used as well, is that correct? The wording makes it sound as if they were discouraged.
My motivation for this change stems from design discussions with colleagues who had the impression that subclasses of Loss would be the natural way to encapsulate some custom loss functions that we planned to use with add_loss().
So let's consider what class Loss adds on top of a simple tensor-to-tensor function from (y_true, y_pred) to a total loss:
- A split between init-time hparams and call-time inputs, providing a parametrized callback for the training loop.
- Serialization and deserialization as part of a Keras Model config.
- Wrapping of a user-level `call()` function with automated application of per-example weights...
- ...and reduction across examples (sketched below).
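To make those four points concrete, here is a hypothetical `Loss` subclass (the class, its name, and its `delta` parameter are made up for illustration); the numbers in the comments refer to the list above:

```python
import tensorflow as tf

class WeightedHuber(tf.keras.losses.Loss):
    """Hypothetical example of what subclassing Loss buys you."""

    def __init__(self, delta=1.0, name="weighted_huber", **kwargs):
        super().__init__(name=name, **kwargs)  # reduction handled by the base class (4)
        self.delta = delta                     # (1) init-time hyperparameter

    def call(self, y_true, y_pred):            # (1) call-time inputs; (3) per-example
        error = y_true - y_pred                #     weighting is applied by the wrapper
        abs_error = tf.abs(error)
        quadratic = tf.minimum(abs_error, self.delta)
        linear = abs_error - quadratic
        return 0.5 * tf.square(quadratic) + self.delta * linear

    def get_config(self):                      # (2) serialization with the Model config
        return {**super().get_config(), "delta": self.delta}
```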
For use in Model.compile(loss=...), all of this makes sense.
For use with Layer.add_loss(...), not so much:
- `add_loss()` only takes the final result; there is no need to pass a parametrized callback.
- `add_loss()` does not track a `Loss` object for serialization as such.
- Example weights are typically not available within the model definition.
- The one type of reduction that makes sense for `add_loss()` is `SUM_OVER_BATCH_SIZE`, and that will unexpectedly blow up in your face the moment you try to move the Model under a distribution strategy (see the sketch after this list).
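To make item 4 concrete, a sketch of the contrast (exact error messages and the TF versions affected vary, so take the commented-out part as the failure mode described above rather than a guarantee):

```python
import tensorflow as tf

class MyRegularizingLayer(tf.keras.layers.Layer):
    def call(self, inputs):
        # Plain tensor: explicit, batch-size independent, and well-defined
        # under a tf.distribute.Strategy.
        self.add_loss(1e-2 * tf.reduce_mean(tf.square(inputs)))

        # Loss object with the default reduction (AUTO -> SUM_OVER_BATCH_SIZE):
        # per the point above, this is the variant that can raise a ValueError
        # once the Model is built or trained under a distribution strategy,
        # because the framework cannot infer the global batch size here.
        # mse = tf.keras.losses.MeanSquaredError()
        # self.add_loss(mse(tf.zeros_like(inputs), inputs))
        return inputs
```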
So, not only did the similarity in name point us towards an abstraction that was overly complex (items 1-3), it actually hurt us as soon as we wanted to leverage TensorFlow's key feature: scaling up with distributed training (item 4).
Bottom line: Yes, I think it is appropriate to gently discourage the unreflective use of Loss objects with add_loss(). I would have liked to leave a stronger hint at item 4 (see first commit of this PR), but I couldn't find a good link target to provide the necessary context.
Friendly ping... I think the proposed wording is completely correct and useful guidance for the user: while calling a `Loss` object is possible with some care, a plain tensor is often more appropriate (more often than the similar names would suggest).
I agree. Especially given that the majority of the Loss classes have corresponding loss functions to use, there really isn't a need to complicate the workflow with a class wrapper. It doesn't seem appropriate for a model builder to easily make a mistake by using a reduction method that doesn't work with distribution strategies either, so discouraging that is an option too. In any case, the way it's stated here lgtm.
Thanks again for the change and your patience!
Thank you, @rchao! The GitHub checks are telling me I need review from a code owner, and @MarkDaoust was requested to do it.
fchollet left a comment
LGTM.