In [1]:
"""Classification with abstention metric classes and functions."""

import numpy as np
import pandas as pd
import tensorflow as tf

__author__ = "Elizabeth A. Barnes and Randal J. Barnes"
__date__ = "January 11, 2021"

# np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)

# ------------------------------------------------------------------------
# FUNCTIONS
#
#   The following metric functions are used for comparison purposes and
#   plotting.  These are not necessarily tensorflow compliant.
# ------------------------------------------------------------------------


def compute_dnn_accuracy(y_true, y_pred, perc, tranquil=np.nan):
    """Categorical accuracy over the most confident predictions.

    A sample's confidence is the largest value in its row of ``y_pred``.
    Only samples whose confidence is at or above the ``100 - perc``
    percentile of all confidences contribute to the accuracy.

    Args:
        y_true: Targets in the layout expected by
            ``tf.keras.metrics.CategoricalAccuracy``; indexed as
            ``y_true[rows, :]``.
        y_pred: Per-class scores; indexed as ``y_pred[rows, :]``.
        perc: Percentage (0-100) of the most confident samples to keep.
        tranquil: Unused; retained so existing callers keep working.

    Returns:
        The value of ``CategoricalAccuracy.result().numpy()`` for the
        retained samples.
    """
    confidence = np.max(y_pred, axis=-1)
    cutoff = np.percentile(confidence, 100 - perc)
    # Ties at the cutoff are kept, so slightly more than `perc` percent of
    # the samples may be retained.
    keep = np.nonzero(confidence >= cutoff)[0]

    scorer = tf.keras.metrics.CategoricalAccuracy()
    scorer.update_state(y_true[keep, :], y_pred[keep, :])
    return scorer.result().numpy()


def compute_dac_accuracy(y_true, y_pred, abstain):
    """Categorical accuracy of the predictions, excluding abstentions.

    A sample counts as an abstention when the arg-max of its row in
    ``y_pred`` equals ``abstain``; such samples are dropped before the
    accuracy is computed.

    Args:
        y_true: Targets in the layout expected by
            ``tf.keras.metrics.CategoricalAccuracy``.
        y_pred: Per-class scores, including the abstention class.
        abstain: Class index that denotes abstention.

    Returns:
        The value of ``CategoricalAccuracy.result().numpy()`` for the
        non-abstaining samples.
    """
    predicted_class = tf.math.argmax(y_pred, axis=-1)
    is_prediction = tf.math.not_equal(predicted_class, abstain)

    kept_true = tf.boolean_mask(y_true, is_prediction)
    kept_pred = tf.boolean_mask(y_pred, is_prediction)

    scorer = tf.keras.metrics.CategoricalAccuracy()
    scorer.update_state(kept_true, kept_pred)
    return scorer.result().numpy()

# ------------------------------------------------------------------------
# CLASSES
#
#   The following metrics classes are tensorflow compliant.
#
#   See page 390 of Geron, 2019, for a prototype of a metric class. See also,
#   https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Metric.
# ------------------------------------------------------------------------

class PredictionAccuracy(tf.keras.metrics.Metric):
    """Compute the prediction accuracy for an epoch.

    The prediction accuracy does not include abstentions. The prediction
    accuracy is the total number of correct predictions divided by the
    total number of predictions, across the entire epoch. This is not the
    same as the average of batch prediction accuracies.

    The computation is done by maintaining running sums of total predictions
    and correct predictions made across all batches in an epoch. The running
    sums are reset at the end of each epoch.

    Args:
        abstain: Class index that denotes abstention. A sample whose arg-max
            prediction equals this index is excluded from both sums.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric`` (e.g. ``name``).

    """
    def __init__(self, abstain, **kwargs):
        super().__init__(**kwargs)
        self.abstain = abstain
        # Pass the weight name by keyword: newer Keras releases take `shape`
        # as the first positional argument of `add_weight`.
        self.correct = self.add_weight(name="correct", initializer="zeros")
        self.total = self.add_weight(name="total", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the correct and total prediction counts for one batch.

        ``sample_weight`` is accepted for Keras API compatibility but is
        ignored; every non-abstaining sample counts equally.
        """
        cat_pred = tf.math.argmax(y_pred, axis=-1)
        cat_true = tf.math.argmax(y_true, axis=-1)

        # Keep only the samples on which the network did not abstain.
        mask = tf.math.not_equal(cat_pred, self.abstain)
        cat_pred = tf.boolean_mask(cat_pred, mask)
        cat_true = tf.boolean_mask(cat_true, mask)

        batch_correct = tf.math.count_nonzero(tf.math.equal(cat_pred, cat_true))
        batch_total = tf.math.count_nonzero(mask)

        self.correct.assign_add(tf.cast(batch_correct, tf.float32))
        self.total.assign_add(tf.cast(batch_total, tf.float32))

    def result(self):
        """Return correct / total over everything accumulated so far.

        The value is NaN (0 / 0) while no non-abstaining prediction has
        been accumulated.
        """
        return self.correct / self.total

    def get_config(self):
        """Return the constructor arguments needed to rebuild this metric."""
        base_config = super().get_config()
        # `abstain` is a required constructor argument, so it must be part of
        # the config for `from_config` to be able to recreate the metric.
        return {**base_config, "abstain": self.abstain}


class PredictionLoss(tf.keras.metrics.Metric):
    """Compute the prediction loss for an epoch.

    The prediction loss does not include abstentions. For each sample on
    which the network did not abstain, the loss is
    ``log(1 - p_abstain) - log(p_true)``, i.e. the cross entropy of the
    true-class probability after renormalizing over the non-abstention
    classes.

    The prediction loss is the sum of the prediction losses divided by the
    total number of predictions, across the entire epoch. This is not the
    same as the average of batch prediction losses.

    The computation is done by maintaining running sums of prediction losses
    and prediction counts, across the entire epoch. The running sums are
    reset at the end of each epoch.

    Args:
        abstain: Class index that denotes abstention. A sample whose arg-max
            prediction equals this index is excluded from both sums.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric`` (e.g. ``name``).

    """
    def __init__(self, abstain, **kwargs):
        super().__init__(**kwargs)
        self.abstain = abstain
        # Pass the weight name by keyword: newer Keras releases take `shape`
        # as the first positional argument of `add_weight`.
        self.count = self.add_weight(name="count", initializer="zeros")
        self.total = self.add_weight(name="total", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the summed loss and prediction count for one batch.

        ``sample_weight`` is accepted for Keras API compatibility but is
        ignored; every non-abstaining sample counts equally.
        """
        predicted = tf.math.argmax(y_pred, axis=-1)

        # Probability mass assigned to the non-abstention classes.
        # NOTE(review): the abstention probability is read from the LAST
        # column of y_pred, while the mask below uses `self.abstain`; the two
        # agree only when `abstain` is the last class index -- confirm.
        q = 1 - y_pred[:, -1]
        logq = tf.math.log(q)

        # Probability assigned to the true class. Using y_true as the mask
        # yields one value per sample only if each row of y_true has exactly
        # one non-zero entry (one-hot targets).
        r = tf.boolean_mask(y_pred, y_true)
        logr = tf.math.log(r)

        # Keep only the samples on which the network did not abstain.
        mask = tf.math.not_equal(predicted, self.abstain)
        loss = tf.boolean_mask(logq - logr, mask)

        batch_count = tf.math.count_nonzero(mask)
        batch_total = tf.math.reduce_sum(loss)

        self.count.assign_add(tf.cast(batch_count, tf.float32))
        self.total.assign_add(tf.cast(batch_total, tf.float32))

    def result(self):
        """Return total loss / prediction count accumulated so far.

        The value is NaN (0 / 0) while no non-abstaining prediction has
        been accumulated.
        """
        # Divide the variables directly rather than via float(): this stays a
        # tensor op that works in both eager and graph execution, and matches
        # PredictionAccuracy.result().
        return self.total / self.count

    def get_config(self):
        """Return the constructor arguments needed to rebuild this metric."""
        base_config = super().get_config()
        # `abstain` is a required constructor argument, so it must be part of
        # the config for `from_config` to be able to recreate the metric.
        return {**base_config, "abstain": self.abstain}

2023-10-10 14:54:11.408131: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def threat_score(true_pos, false_pos, false_neg):
    """
    Compute the threat score, also known as the Critical Success Index (CSI).

    Formula: TS = TP / (TP + FP + FN)

    Args:
    true_pos (int): Number of true positives (hits).
    false_pos (int): Number of false positives (false alarms).
    false_neg (int): Number of false negatives (misses).

    Returns:
    float: Threat score. For non-negative counts the value lies between
    0 and 1, with higher values indicating better forecasts. Returns 0.0
    when all three counts sum to zero.
    """
    events = true_pos + false_pos + false_neg
    # With no hits, false alarms or misses the ratio is undefined; report 0.0.
    return true_pos / events if events != 0 else 0.0

In [4]:
def gilbert_skill_score(true_pos, false_pos, false_neg, chance_hit):
    """
    Compute the Gilbert Skill Score (GS).

    A skill-corrected verification measure of categorical forecast
    performance, similar to the critical success index (CSI) but with the
    number of hits expected from random chance removed from both the
    numerator and the denominator.

    Formula: GS = (A - CH) / (A + B + C - CH)

    Args:
    true_pos (int): Number of true positives (A).
    false_pos (int): Number of false positives.
    false_neg (int): Number of false negatives.
    chance_hit: Number of correct hits due to purely random chance,
        supplied by the caller; CH = (A+B)(A+C)/n.

    Returns:
    float: Gilbert skill score. 1 is a perfect score, 0 means no more hits
    than chance_hit, and the value is negative when true_pos is smaller
    than chance_hit. Returns 0.0 when the denominator is zero.
    """
    hits_beyond_chance = true_pos - chance_hit
    events_beyond_chance = true_pos + false_pos + false_neg - chance_hit
    # The ratio is undefined for a zero denominator; report 0.0 instead.
    return hits_beyond_chance / events_beyond_chance if events_beyond_chance != 0 else 0.0

In [3]:
def confusion_matrix(predclasses, targclasses, class_labels=None):
    """Display a confusion matrix, in percent of each true class.

    The classes are the unique values found in ``targclasses``. The entry in
    row ``P(p)`` and column ``T(t)`` is the percentage of the samples whose
    target class is ``t`` that were predicted as class ``p``. A predicted
    class that never occurs in ``targclasses`` gets no row, so a column sums
    to 100 only when all of its predictions fall within the target classes.

    The table is rendered with ``display`` (available in IPython/Jupyter)
    and nothing is returned.

    Args:
        predclasses: Predicted class of every sample (array-like).
        targclasses: Target (true) class of every sample (array-like, same
            length as ``predclasses``).
        class_labels: Optional display names, one per unique target class,
            in sorted class order. Defaults to ``("Light", "Heavy")`` when
            there are exactly two classes, otherwise to the class values
            themselves.

    Raises:
        ValueError: If ``class_labels`` is given and its length differs from
            the number of unique target classes.
    """
    # Coerce to arrays so the boolean-mask indexing below also works for
    # plain Python lists.
    predclasses = np.asarray(predclasses)
    targclasses = np.asarray(targclasses)

    class_names = np.unique(targclasses)

    if class_labels is None:
        if len(class_names) == 2:
            # Historical labels for the two-class (light/heavy) case.
            class_labels = ["Light", "Heavy"]
        else:
            class_labels = [str(name) for name in class_names]
    elif len(class_labels) != len(class_names):
        raise ValueError(
            f"class_labels has {len(class_labels)} entries but targclasses "
            f"contains {len(class_names)} unique classes"
        )

    # Rows are predicted classes, columns are true classes.
    table = [
        [
            100 * np.mean(predclasses[targclasses == true_class] == pred_class)
            for true_class in class_names
        ]
        for pred_class in class_names
    ]

    conf_matrix = pd.DataFrame(
        table,
        index=[f"P({label})" for label in class_labels],
        columns=[f"T({label})" for label in class_labels],
    )
    display(conf_matrix.style.background_gradient(cmap='Blues').format("{:.1f}"))