diff --git a/keras/backend.py b/keras/backend.py index 63e7bcd20bfe..6b6dab677c99 100644 --- a/keras/backend.py +++ b/keras/backend.py @@ -5566,8 +5566,12 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): labels=target, logits=output, axis=axis ) - # scale preds so that the class probas of each sample sum to 1 + # Adjust the predictions so that the probability of + # each class for every sample adds up to 1 + # This is needed to ensure that the cross entropy is + # computed correctly. output = output / tf.reduce_sum(output, axis, True) + # Compute cross entropy from probabilities. epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype) output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_) @@ -5844,28 +5848,29 @@ def binary_focal_crossentropy( where `alpha` is a float in the range of `[0, 1]`. Args: - target: A tensor with the same shape as `output`. - output: A tensor. - apply_class_balancing: A bool, whether to apply weight balancing on the - binary classes 0 and 1. - alpha: A weight balancing factor for class 1, default is `0.25` as - mentioned in the reference. The weight for class 0 is `1.0 - alpha`. - gamma: A focusing parameter, default is `2.0` as mentioned in the - reference. - from_logits: Whether `output` is expected to be a logits tensor. By - default, we consider that `output` encodes a probability distribution. + target: A tensor with the same shape as `output`. + output: A tensor. + apply_class_balancing: A bool, whether to apply weight balancing on the + binary classes 0 and 1. + alpha: A weight balancing factor for class 1, default is `0.25` as + mentioned in the reference. The weight for class 0 is `1.0 - alpha`. + gamma: A focusing parameter, default is `2.0` as mentioned in the + reference. + from_logits: Whether `output` is expected to be a logits tensor. By + default, we consider that `output` encodes a probability + distribution. Returns: - A tensor. + A tensor. """ - sigmoidal = tf.__internal__.smart_cond.smart_cond( - from_logits, - lambda: sigmoid(output), - lambda: output, - ) + + sigmoidal = sigmoid(output) if from_logits else output + p_t = target * sigmoidal + (1 - target) * (1 - sigmoidal) + # Calculate focal factor focal_factor = tf.pow(1.0 - p_t, gamma) + # Binary crossentropy bce = binary_crossentropy( target=target, @@ -5893,7 +5898,7 @@ def sigmoid(x): Returns: A tensor. """ - return tf.sigmoid(x) + return tf.math.sigmoid(x) @keras_export("keras.backend.hard_sigmoid")