Add numerically stable cross entropy loss #856
Merged
I was using the parametric UMAP cross entropy loss function for another project and ran into some odd, intermittent numerical stability issues that I could not pinpoint. My solution was to modify the loss to calculate the log probabilities directly, using the reparameterized repellent term for the cross entropy from Section 8.1 of Shi et al. 2022 (https://arxiv.org/abs/2111.08851):

`log(1 - sigmoid(logits)) = log(sigmoid(logits)) - logits`

This seemed to solve the issues I was having. I have not tested this directly with your code base, but thought it might be useful nonetheless.

As a side note, I switched the (0, 1] threshold implemented with `clip_by_value` to use `sigmoid` for (0, 1). You could replace `log_sigmoid(x) = -softplus(-x)` with the equivalent rectifier `log_hard_sigmoid(x) = -relu(-x)` if you prefer to keep the hard threshold.
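For illustration, here is a minimal NumPy sketch of the reparameterized loss described above (function names are hypothetical; the actual PR targets the TensorFlow implementation in parametric UMAP):

```python
import numpy as np

def log_sigmoid(x):
    # log(sigmoid(x)) = -softplus(-x); logaddexp evaluates softplus stably.
    return -np.logaddexp(0.0, -x)

def stable_binary_cross_entropy(logits, targets):
    # Attractive term: targets * log(sigmoid(logits)).
    # Repellent term, reparameterized per Shi et al. 2022, Section 8.1:
    #   log(1 - sigmoid(logits)) = log(sigmoid(logits)) - logits
    # so neither term ever evaluates log(0), and no clipping is needed.
    log_p = log_sigmoid(logits)
    return -(targets * log_p + (1.0 - targets) * (log_p - logits))
```

To keep the hard threshold instead, one could swap `log_sigmoid` for `log_hard_sigmoid(x) = -relu(-x)`, i.e. `-np.maximum(-x, 0.0)`.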