In [None]:
# default_exp metrics

# Metrics

> Definitions of the metrics that can be used when training models.

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#export 
import tensorflow as tf
from tensorflow.python.ops import math_ops, confusion_matrix, array_ops, init_ops
from tensorflow.python.framework import dtypes

## Intersection-Over-Union
Metrics that complement the official TensorFlow metrics (https://www.tensorflow.org/api_docs/python/tf/keras/metrics).

In [None]:
#export
class MeanIoU2(tf.keras.metrics.MeanIoU):
    """
    Computes the Intersection-Over-Union metric for probabilistic predictions.

    Adjusted from tf.keras.metrics.MeanIoU
    (https://www.tensorflow.org/api_docs/python/tf/keras/metrics/MeanIoU)
    to handle probabilistic predictions and to optionally report the IoU of a
    single semantic class.

    Mean Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation, which first computes the IOU for each semantic class and then
    computes the average over classes. IOU is defined as follows:
    IOU = true_positive / (true_positive + false_positive + false_negative).
    The predictions are accumulated in a confusion matrix, weighted by
    `sample_weight` and the metric is then calculated from it.

    Predictions are rounded to the nearest integer before being accumulated,
    so binary probabilities are thresholded at 0.5 (`tf.math.round` rounds
    half to even, so exactly 0.5 becomes 0). Labels are not rounded and should
    therefore already be integer class ids (or 0/1 masks).

    If `class_id` is specified, only channel `class_id` of the last axis of
    `y_true` / `y_pred` is kept (presumably one-hot / per-class masks -- confirm
    against callers). That channel is treated as a binary mask and `result()`
    returns the IoU of its foreground (value 1) instead of the mean over
    classes. `num_classes` must be at least 2 in that case.

    Args:
      num_classes: Number of possible class values; the confusion matrix has
        shape `(num_classes, num_classes)`.
      class_id: Optional index of the last-axis channel to evaluate on its own.
      name: Optional name of the metric instance.
      dtype: Optional data type of the metric result.
    """
    def __init__(self,
                 num_classes,
                 class_id = None,
                 name=None,
                 dtype=None):
        super().__init__(num_classes=num_classes, name=name, dtype=dtype)
        self.class_id = class_id

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates the confusion matrix statistics.

        Args:
          y_true: The ground truth values.
          y_pred: The predicted values; rounded to the nearest integer before
            being compared with `y_true`.
          sample_weight: Optional weighting of each example. Defaults to 1.
            It is flattened alongside the labels, so after flattening it must
            hold one weight per (flattened, `class_id`-selected) label.
        Returns:
          Update op (the updated confusion-matrix variable).
        """
        y_true = tf.cast(y_true, self._dtype)
        # Turn probabilistic predictions into class ids.
        y_pred = tf.math.round(tf.cast(y_pred, self._dtype))

        if self.class_id is not None:
            # Keep only the requested class channel: a binary mask in which
            # 1 marks that class and 0 everything else.
            y_true = y_true[..., self.class_id]
            y_pred = y_pred[..., self.class_id]

        # Flatten to 1-D. reshape is a no-op for rank-1 input and, unlike a
        # `shape.ndims > 1` check, also handles scalars and unknown ranks.
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1])
        if sample_weight is not None:
            # Cast first so plain Python lists / numpy arrays are accepted.
            sample_weight = tf.reshape(tf.cast(sample_weight, self._dtype), [-1])

        # Accumulate the prediction to current confusion matrix.
        current_cm = tf.math.confusion_matrix(
            y_true,
            y_pred,
            num_classes=self.num_classes,
            weights=sample_weight,
            dtype=tf.float64)
        return self.total_cm.assign_add(current_cm)

    def result(self):
        """Computes the IoU from the accumulated confusion matrix.

        Returns:
          The mean IoU over the classes that occur in the labels or the
          predictions or, when `class_id` is set, the IoU of that class's
          foreground. Returns 0 when nothing contributes (empty union).
        """
        # Rows of the confusion matrix are true labels, columns predictions.
        predicted_per_class = tf.cast(
            tf.reduce_sum(self.total_cm, axis=0), dtype=self._dtype)
        actual_per_class = tf.cast(
            tf.reduce_sum(self.total_cm, axis=1), dtype=self._dtype)
        true_positives = tf.cast(
            tf.linalg.diag_part(self.total_cm), dtype=self._dtype)

        if self.class_id is not None:
            # update_state reduced this class to a binary 0/1 mask, so its
            # foreground always sits at index 1. Indexing with `class_id`
            # would pick the background (class_id=0) or an empty row
            # (class_id>=2).
            predicted_per_class = predicted_per_class[1]
            actual_per_class = actual_per_class[1]
            true_positives = true_positives[1]

        # predicted + actual = 2 * TP + FP + FN, so subtracting TP once gives
        # the union TP + FP + FN.
        denominator = predicted_per_class + actual_per_class - true_positives

        # Only classes that appear in the labels or the predictions (non-zero
        # union) count towards the mean.
        num_valid_entries = tf.reduce_sum(
            tf.cast(tf.not_equal(denominator, 0), dtype=self._dtype))
        iou = tf.math.divide_no_nan(true_positives, denominator)
        return tf.math.divide_no_nan(
            tf.reduce_sum(iou, name='mean_iou'), num_valid_entries)

    def get_config(self):
        """Returns the constructor arguments so the metric can be re-created.

        Built on the base `Metric` config (name, dtype) instead of
        `MeanIoU.get_config()`. Newer TensorFlow releases may add keys there
        (e.g. `ignore_class`) that this class's `__init__` does not accept.
        """
        config = tf.keras.metrics.Metric.get_config(self)
        config.update({'num_classes': self.num_classes,
                       'class_id': self.class_id})
        return config

### Example


In [None]:
# Binary example: the predicted probabilities are rounded to 0/1 before they
# enter the confusion matrix.
m = MeanIoU2(num_classes=2)
m.update_state(y_true=[1, 0, 1, 1], y_pred=[0.8, 0.1, 0.99, 0.56])

<tf.Tensor: shape=(2, 2), dtype=float64, numpy=
array([[1., 0.],
       [0., 3.]])>

Show the result:

In [None]:
# All four rounded predictions match their labels, so both classes have IoU 1.
m.result().numpy()

1.0

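__`tf.keras.metrics.MeanIoU` shows a different behaviour:__ it does not round the probabilities, so every prediction below 1 ends up in class 0 (see the confusion matrix below).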

In [None]:
# The same inputs fed to the stock metric, which does not round the predictions.
m = tf.keras.metrics.MeanIoU(num_classes=2)
m.update_state(y_true=[1, 0, 1, 1], y_pred=[0.8, 0.1, 0.99, 0.56])

<tf.Variable 'UnreadVariable' shape=(2, 2) dtype=float64, numpy=
array([[1., 0.],
       [3., 0.]])>

Show the result:

In [None]:
# Class 0: IoU = 1 / (4 + 1 - 1) = 0.25; class 1: IoU = 0 / 3 = 0 -> mean 0.125.
m.result().numpy()

0.125