Skip to content

Commit

Permalink
PR #17746: Minor improvements and code refactoring in backend.py
Browse files Browse the repository at this point in the history
Imported from GitHub PR #17746

Small changes in backend.py, some of were discussed in the PR #17651
Copybara import of the project:

--
0f89165 by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Small fixes on focal losses and cat.crossentropy

--
3c193de by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Fix linting and sigmoid func

--
b87b656 by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Revert the redirection of the internal function

Merging this change closes #17746

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17746 from Frightera:frightera_small_loss_fixes b87b656
PiperOrigin-RevId: 522179031
  • Loading branch information
tensorflower-gardener committed Apr 5, 2023
1 parent ecb4f98 commit a188434
Showing 1 changed file with 23 additions and 18 deletions.
41 changes: 23 additions & 18 deletions keras/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5566,8 +5566,12 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
labels=target, logits=output, axis=axis
)

# scale preds so that the class probas of each sample sum to 1
# Adjust the predictions so that the probability of
# each class for every sample adds up to 1
# This is needed to ensure that the cross entropy is
# computed correctly.
output = output / tf.reduce_sum(output, axis, True)

# Compute cross entropy from probabilities.
epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)
Expand Down Expand Up @@ -5844,28 +5848,29 @@ def binary_focal_crossentropy(
where `alpha` is a float in the range of `[0, 1]`.
Args:
target: A tensor with the same shape as `output`.
output: A tensor.
apply_class_balancing: A bool, whether to apply weight balancing on the
binary classes 0 and 1.
alpha: A weight balancing factor for class 1, default is `0.25` as
mentioned in the reference. The weight for class 0 is `1.0 - alpha`.
gamma: A focusing parameter, default is `2.0` as mentioned in the
reference.
from_logits: Whether `output` is expected to be a logits tensor. By
default, we consider that `output` encodes a probability distribution.
target: A tensor with the same shape as `output`.
output: A tensor.
apply_class_balancing: A bool, whether to apply weight balancing on the
binary classes 0 and 1.
alpha: A weight balancing factor for class 1, default is `0.25` as
mentioned in the reference. The weight for class 0 is `1.0 - alpha`.
gamma: A focusing parameter, default is `2.0` as mentioned in the
reference.
from_logits: Whether `output` is expected to be a logits tensor. By
default, we consider that `output` encodes a probability
distribution.
Returns:
A tensor.
A tensor.
"""
sigmoidal = tf.__internal__.smart_cond.smart_cond(
from_logits,
lambda: sigmoid(output),
lambda: output,
)

sigmoidal = sigmoid(output) if from_logits else output

p_t = target * sigmoidal + (1 - target) * (1 - sigmoidal)

# Calculate focal factor
focal_factor = tf.pow(1.0 - p_t, gamma)

# Binary crossentropy
bce = binary_crossentropy(
target=target,
Expand Down Expand Up @@ -5893,7 +5898,7 @@ def sigmoid(x):
Returns:
A tensor.
"""
return tf.sigmoid(x)
return tf.math.sigmoid(x)


@keras_export("keras.backend.hard_sigmoid")
Expand Down

0 comments on commit a188434

Please sign in to comment.