
Commit 649b1ec

Refine the doc of sigmoid_binary_cross_entropy to not assume the meaning of the last dimension.

This loss is simply an elementwise loss, so the last dimension can be anything, not necessarily `num_classes`. For example, `logits` can be a rank-1 vector whose single dimension is the batch.

PiperOrigin-RevId: 474589208
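
To illustrate the point of the commit message, here is a minimal usage sketch (the values are made up) in which `logits` is a rank-1 array whose only dimension is the batch, with no trailing `num_classes` dimension:

import jax.numpy as jnp
import optax

# One binary prediction per example: shape [batch], no class dimension.
logits = jnp.array([0.5, -1.2, 3.0])
labels = jnp.array([1.0, 0.0, 1.0])

# The loss is elementwise, so the output has the same shape as `logits`.
losses = optax.sigmoid_binary_cross_entropy(logits, labels)
print(losses.shape)  # (3,)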
ppwwyyxx authored and OptaxDev committed Sep 16, 2022
1 parent 44df918 commit 649b1ec
Showing 1 changed file with 10 additions and 9 deletions: optax/_src/loss.py
@@ -117,23 +117,24 @@ def smooth_labels(
 
 
 def sigmoid_binary_cross_entropy(logits, labels):
-  """Computes sigmoid cross entropy given logits and multiple class labels.
+  """Computes element-wise sigmoid cross entropy given logits and labels.
 
-  Measures the probability error in discrete classification tasks in which
-  each class is an independent binary prediction and different classes are
-  not mutually exclusive. This may be used for multilabel image classification
-  for instance a model may predict that an image contains both a cat and a dog.
+  This can be used to measure the error in discrete classification tasks in
+  which each class is an independent binary prediction and different classes
+  are not mutually exclusive. This may be used for multilabel image
+  classification; for instance, a model may predict that an image contains
+  both a cat and a dog.
 
   References:
     [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)
 
   Args:
-    logits: Unnormalized log probabilities, with shape `[..., num_classes]`.
-    labels: The target probabilities for each class, must have a shape
-      broadcastable to that of `logits`;
+    logits: Each element is the unnormalized log probability of a binary
+      prediction.
+    labels: The target probabilities, must have a shape broadcastable to that
+      of `logits`.
 
   Returns:
-    cross entropy for each binary class prediction, shape `[..., num_classes]`.
+    cross entropy for each binary prediction, same shape as `logits`.
   """
   chex.assert_type([logits], float)
   log_p = jax.nn.log_sigmoid(logits)
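
The shown hunk stops partway through the function body. For reference, a self-contained sketch of the elementwise computation, assuming the standard numerically stable formulation via `jax.nn.log_sigmoid` (a hedged reconstruction, not verbatim from this commit):

import jax

def sigmoid_binary_cross_entropy_sketch(logits, labels):
  # loss = -labels * log(sigmoid(logits))
  #        - (1 - labels) * log(1 - sigmoid(logits)),
  # computed stably using log(1 - sigmoid(x)) == log_sigmoid(-x).
  log_p = jax.nn.log_sigmoid(logits)
  log_not_p = jax.nn.log_sigmoid(-logits)
  return -labels * log_p - (1.0 - labels) * log_not_p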
