In [1]:
import tensorflow as tf
import numpy as np
import time

from tensorflow.keras.losses import Loss
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops.losses import util as tf_losses_util
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras import backend as K
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
import six

In [22]:
class LossFunctionWrapper(Loss):
  """Wraps a loss function in the `Loss` class."""

  def __init__(self,
               fn,
               name=None,
               **kwargs):
    
    super(LossFunctionWrapper, self).__init__(name=name)
    self.fn = fn
    self._fn_kwargs = kwargs

  def call(self, y_true, y_pred):
    """Invokes the `LossFunctionWrapper` instance.
    Args:
      y_true: Ground truth values.
      y_pred: The predicted values.
    Returns:
      Loss values per sample.
    """
    if tensor_util.is_tensor(y_pred) and tensor_util.is_tensor(y_true):
      y_pred, y_true = tf_losses_util.squeeze_or_expand_dimensions(
          y_pred, y_true)
    ag_fn = autograph.tf_convert(self.fn, ag_ctx.control_status_ctx())
    return ag_fn(y_true, y_pred, **self._fn_kwargs)

  def get_config(self):
    config = {}
    for k, v in six.iteritems(self._fn_kwargs):
      config[k] = K.eval(v) if tf_utils.is_tensor_or_variable(v) else v
    base_config = super(LossFunctionWrapper, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class BinaryFocalCrossentropy(LossFunctionWrapper):
    def __init__(self, alpha=0.25, gamma=2.0, from_logits=False, axis=-1):
        super().__init__(fn=binary_focal_crossentropy)
        self.alpha = alpha
        self.gamma = gamma
        self.axis = axis
        self.from_logits = from_logits

def binary_focal_crossentropy(y_true, y_pred, alpha=0.25, gamma=2, from_logits=True, axis=-1):
    y_pred = ops.convert_to_tensor_v2(y_pred)
    if from_logits:
        # Transform logits to probabilities
        def sigmoid(x):
            return 1 / (1 + np.exp(-x))
        y_pred = sigmoid(y_pred)
    else:
        # Clip probabilities for numerical stability
        y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())

    y_true = math_ops.cast(y_true, y_pred.dtype)
    
    term_1 = y_true * alpha * tf.math.pow(1 - y_pred, gamma) * tf.math.log(y_pred)
    term_0 = (1 - y_true) * (1 - alpha) * tf.math.pow(y_pred, gamma) * tf.math.log(1 - y_pred)
    focal_ce = -(term_1 + term_0)

    return K.mean(focal_ce, axis=axis)

In [23]:
y_true = np.array([0, 1, 0, 0])
y_pred = np.array([-18.6, 0.51, 2.94, -12.8])
bce = BinaryFocalCrossentropy()
round(bce(y_true, y_pred).numpy(), 5)

0.51013

: 

In our data, y_pred and y_true have the format: batch_size X nr_classes (rows X columns).

In [16]:
# Test the implementation of binary_focal_crossentropy() doing it by hand below
y_true = np.array([1., 0., 1., 1., 1., 0.])
y_pred = np.array([0.8, 0.2, 0.7, 0.9, 0.8, 0.1])
start = time.time()
bce = BinaryFocalCrossentropy()
end = time.time()
print(f'{round(bce(y_true, y_pred).numpy(), 5)}, time: {end-start}')


# By 'hand':
y_true = np.array([1., 0., 1., 1., 1., 0.])
y_pred = np.array([0.8, 0.2, 0.7, 0.9, 0.8, 0.1])
alpha = 0.25
gamma = 2
term_1 = 0
term_0 = 0
start = time.time()
for i in range(len(y_true)):
    if y_true[i] == 1:
        term_1 += alpha * (1 - y_pred[i])**gamma * np.log(y_pred[i])
    else:
        term_0 += (1 - alpha) * y_pred[i]**gamma * np.log(1 - y_pred[i])
focal_loss = -(term_1 + term_0) / len(y_pred)
end = time.time()
print(f'{round(focal_loss, 5)}, time: {end-start}')

# Or vectorized:
start = time.time()
term_1 = y_true * alpha * (1 - y_pred)**gamma * np.log(y_pred)
term_0 = (1 - y_true) * (1 - alpha) * y_pred**gamma * np.log(1 - y_pred)
focal_loss = np.mean(-(term_1 + term_0))
end = time.time()
print(f'{round(focal_loss, 5)}, time: {end-start}')


0.00337, time: 0.00016617774963378906
0.00337, time: 0.0002703666687011719
0.00337, time: 0.0003037452697753906


BCE

In [62]:
y_pred = ops.convert_to_tensor_v2(y_pred)
y_true = math_ops.cast(y_true, y_pred.dtype)

term_0 = (1 - y_true) * K.log(1 - y_pred + K.epsilon())
term_1 = y_true * K.log(y_pred + K.epsilon())
bce = -K.mean(term_0 + term_1)
bce


<tf.Tensor: shape=(), dtype=float64, numpy=0.22314342631421757>

BFCE

In [63]:
y_pred = ops.convert_to_tensor_v2(y_pred)
y_true = math_ops.cast(y_true, y_pred.dtype)
alpha = 0.25
gamma = 2

term_1 = -alpha * K.pow(1 - y_pred, gamma) * K.log(y_pred + K.epsilon())

term_0 = -(1 - alpha) * K.pow(y_pred, gamma) * K.log(1 - y_pred + K.epsilon())

bfce = -K.mean(term_0 + term_1)
bfce


<tf.Tensor: shape=(), dtype=float64, numpy=-0.7747613922315705>

In a scenario where we have 1M samples with label 1 and with p_t 0.99, and 10 samples with label 0 predicted with p_t 0.01. Then we have the following two cases happening in cross entropy and focal loss cross entropy:

In [48]:
# Cross entropy:
ce_label_1 = 10 * np.log(0.01) # 1_000_000 images
ce_label_0 = 1_000_000 * np.log(0.99) # 10 images
total_ce = -(ce_label_1 + ce_label_0)
print(f'The fraction of the minority class in the CE loss is: {ce_label_1/total_ce*100:.4f}%')

# Focal cross entropy:
alpha = 0.25
gamma = 2
fce_label_1 = 10 * (-alpha) * ((1 - 0.01)**gamma) * np.log(0.01) # 1_000_000 images
fce_label_0 = 1_000_000 * (-(1 - alpha)) * (0.01**gamma) * np.log(1- 0.01) # 10 images
total_fce = -(fce_label_1 + fce_label_0)
print(f'The fraction of the minority class in the CE loss is: {fce_label_1/total_fce*100:.4f}%')


The fraction of the minority class in the CE loss is: -0.4561%
The fraction of the minority class in the CE loss is: -93.7382%


Notice thus that the minority class, in the Focal Cross Estropy Loss function, receives much less importance. Thus, the gradients will be updated to repair the loss on the less known classes.