diff --git a/optax/_src/loss.py b/optax/_src/loss.py
index b19e40187..bc6b1b9fc 100644
--- a/optax/_src/loss.py
+++ b/optax/_src/loss.py
@@ -117,23 +117,24 @@ def smooth_labels(
 
 
 def sigmoid_binary_cross_entropy(logits, labels):
-  """Computes sigmoid cross entropy given logits and multiple class labels.
+  """Computes element-wise sigmoid cross entropy given logits and labels.
 
-  Measures the probability error in discrete classification tasks in which
-  each class is an independent binary prediction and different classes are
-  not mutually exclusive. This may be used for multilabel image classification
-  for instance a model may predict that an image contains both a cat and a dog.
+  This can be used to measure the error in discrete classification tasks in
+  which each class is an independent binary prediction and different classes
+  are not mutually exclusive. This may be used for multilabel image
+  classification; for instance, a model may predict that an image contains
+  both a cat and a dog.
 
   References:
     [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)
 
   Args:
-    logits: Unnormalized log probabilities, with shape `[..., num_classes]`.
-    labels: The target probabilities for each class, must have a shape
-      broadcastable to that of `logits`;
+    logits: The unnormalized log probability of each binary prediction.
+    labels: The target probabilities, must have a shape broadcastable to that
+      of `logits`.
 
   Returns:
-    cross entropy for each binary class prediction, shape `[..., num_classes]`.
+    cross entropy for each binary prediction, same shape as `logits`.
   """
   chex.assert_type([logits], float)
   log_p = jax.nn.log_sigmoid(logits)
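
For reference, a minimal usage sketch of the behavior the updated docstring describes. The shapes and label values below are made up for illustration; only the call to optax.sigmoid_binary_cross_entropy comes from the code being changed.

    # Illustrative only: one example with three independent binary labels
    # (e.g. cat, dog, car). Classes are not mutually exclusive, so several
    # labels may be 1 at the same time.
    import jax.numpy as jnp
    import optax

    logits = jnp.array([[2.0, -1.0, 0.5]])
    labels = jnp.array([[1.0, 0.0, 1.0]])

    loss = optax.sigmoid_binary_cross_entropy(logits, labels)
    # `loss` has the same shape as `logits`: one cross entropy value per
    # binary prediction.
    print(loss.shape)  # (1, 3)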