In [3]:
import tensorflow as tf
from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.keras.utils.generic_utils import to_list
from tensorflow.python.keras import backend as K

In [64]:
class BlanacedBinaryAccuracy(tf.keras.metrics.Metric):
    """Streaming balanced accuracy for binary classification.

    Balanced accuracy is the mean of the true-positive rate and the
    true-negative rate:

        0.5 * (TP / (TP + FN) + TN / (TN + FP))

    Confusion-matrix counts are accumulated separately for every threshold in
    ``thresholds``. ``result()`` reports the best (maximum) balanced accuracy
    across those thresholds.

    Args:
        thresholds: A number, or a sequence of numbers, used to binarize
            ``y_pred``. Defaults to ``[0.5]`` when ``None``.
        name: Optional name of the metric.
        dtype: Optional dtype of the metric's state and result.
    """

    def __init__(
        self,
        thresholds=None,
        name=None,
        dtype=None
    ):
        super().__init__(name=name, dtype=dtype)
        # Keep the user-supplied value verbatim so get_config() round-trips it.
        self.init_thresholds = thresholds

        if thresholds is None:
            thresholds = [0.5]
        elif isinstance(thresholds, (int, float)):
            # A bare number means a single threshold (ints used to crash on len()).
            thresholds = [float(thresholds)]
        else:
            thresholds = list(thresholds)
        self.thresholds = thresholds

        num_thresholds = len(self.thresholds)
        # One counter per threshold for each cell of the confusion matrix.
        self.true_positives = self.add_weight(
            'true_positives', shape=(num_thresholds,),
            initializer=init_ops.zeros_initializer)
        self.false_negatives = self.add_weight(
            'false_negatives', shape=(num_thresholds,),
            initializer=init_ops.zeros_initializer)
        self.true_negatives = self.add_weight(
            'true_negatives', shape=(num_thresholds,),
            initializer=init_ops.zeros_initializer)
        self.false_positives = self.add_weight(
            'false_positives', shape=(num_thresholds,),
            initializer=init_ops.zeros_initializer)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate confusion-matrix counts for one batch.

        Args:
            y_true: Ground-truth binary labels.
            y_pred: Predicted scores (presumably probabilities in [0, 1];
                binarized against each threshold by
                ``metrics_utils.update_confusion_matrix_variables``).
            sample_weight: Optional per-sample weights. They are forwarded to
                the confusion-matrix update (previously ignored).

        Returns:
            Whatever ``update_confusion_matrix_variables`` returns (its update op).
        """
        return metrics_utils.update_confusion_matrix_variables(
            {
                metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
                metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
                metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
                metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
            },
            y_true,
            y_pred,
            thresholds=self.thresholds,
            sample_weight=sample_weight,
        )

    def result(self):
        """Return the maximum balanced accuracy over all thresholds."""
        # div_no_nan yields 0 rather than NaN when a class has no samples yet.
        pos_acc = math_ops.div_no_nan(
            self.true_positives, self.true_positives + self.false_negatives)
        neg_acc = math_ops.div_no_nan(
            self.true_negatives, self.true_negatives + self.false_positives)
        acc = (pos_acc + neg_acc) / 2
        # With a single threshold this is just that threshold's value.
        return tf.reduce_max(acc)

    def reset_state(self):
        """Zero all per-threshold counters."""
        num_thresholds = len(self.thresholds)
        # A plain Python list avoids depending on a notebook-level `np` import.
        K.batch_set_value(
            [(v, [0.0] * num_thresholds) for v in self.variables]
        )

    def reset_states(self):
        """Deprecated alias of ``reset_state()`` kept for older Keras callers."""
        return self.reset_state()

    def get_config(self):
        """Return the serializable config, including the original thresholds."""
        config = {
            'thresholds': self.init_thresholds,
        }
        base_config = super().get_config()
        return {**base_config, **config}


# Correctly spelled alias; the original (misspelled) name is kept for callers.
BalancedBinaryAccuracy = BlanacedBinaryAccuracy


In [70]:
from sklearn.metrics import balanced_accuracy_score, confusion_matrix
import numpy as np

# Cross-check the streaming Keras metric against scikit-learn. The metric
# reports the best balanced accuracy over its thresholds, so compare it with
# the maximum of sklearn's score at each threshold. `ys > th` uses a strict
# "score above threshold" rule to binarize the predictions.
ths = [0.1, 0.5, 0.9]
acc = BlanacedBinaryAccuracy(ths)
xs = np.array([0, 0, 1, 0, 1, 1])              # ground-truth labels
ys = np.array([0.3, 0.4, 0.5, 0.6, 0.9, 0.9])  # predicted scores

keras_score = acc(tf.convert_to_tensor(xs), tf.convert_to_tensor(ys)).numpy()
sklearn_score = np.amax([balanced_accuracy_score(xs, ys > th) for th in ths])
print(keras_score)
print(sklearn_score)
# The Keras result is float32, so allow single-precision rounding error.
np.testing.assert_allclose(keras_score, sklearn_score, rtol=1e-6)


TensorShape([6]) TensorShape([6])
0.6666667
0.6666666666666666
