Add Average Recall Metric (#34)
Add an average recall metric, and apply a softmax to the test model output.
YiqinZhao authored and ybubnov committed Mar 20, 2019
1 parent 94ae5bb commit 41de08b
Showing 4 changed files with 100 additions and 0 deletions.
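
For context, a minimal usage sketch of the new metric — assuming the package is imported as keras_metrics and that extra keyword arguments such as classes are forwarded to the metric constructor, as the metric_fn wiring below suggests:

import keras
import keras_metrics as km

model = keras.models.Sequential()
model.add(keras.layers.Dense(8, activation="relu", input_dim=4))
model.add(keras.layers.Dense(3, activation="softmax"))

# categorical_average_recall is one of the exports added in this commit.
model.compile(optimizer="sgd",
              loss="categorical_crossentropy",
              metrics=[km.categorical_average_recall(classes=3)])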
16 changes: 16 additions & 0 deletions keras_metrics/__init__.py
@@ -25,6 +25,18 @@ def fn(label=0, **kwargs):
sparse_categorical_metric = partial(
metric_fn, cast_strategy=casts.sparse_categorical)

binary_average_metric = partial(
    metric_fn, cast_strategy=casts.binary_argmax)

categorical_average_metric = partial(
    metric_fn, cast_strategy=casts.argmax)

sparse_categorical_average_metric = partial(
    metric_fn, cast_strategy=casts.sparse_argmax)


binary_true_positive = binary_metric(m.true_positive)
binary_true_negative = binary_metric(m.true_negative)
@@ -33,6 +45,7 @@ def fn(label=0, **kwargs):
binary_precision = binary_metric(m.precision)
binary_recall = binary_metric(m.recall)
binary_f1_score = binary_metric(m.f1_score)
binary_average_recall = binary_average_metric(m.average_recall)


categorical_true_positive = categorical_metric(m.true_positive)
@@ -42,6 +55,7 @@ def fn(label=0, **kwargs):
categorical_precision = categorical_metric(m.precision)
categorical_recall = categorical_metric(m.recall)
categorical_f1_score = categorical_metric(m.f1_score)
categorical_average_recall = categorical_average_metric(m.average_recall)


sparse_categorical_true_positive = sparse_categorical_metric(m.true_positive)
@@ -51,6 +65,8 @@ def fn(label=0, **kwargs):
sparse_categorical_precision = sparse_categorical_metric(m.precision)
sparse_categorical_recall = sparse_categorical_metric(m.recall)
sparse_categorical_f1_score = sparse_categorical_metric(m.f1_score)
sparse_categorical_average_recall = sparse_categorical_average_metric(
m.average_recall)


# For backward compatibility.
21 changes: 21 additions & 0 deletions keras_metrics/casts.py
@@ -24,3 +24,24 @@ def sparse_categorical(y_true, y_pred, dtype="int32", label=0):
y_pred = K.cast(K.round(y_pred), dtype)

return y_true, y_pred


def binary_argmax(y_true, y_pred, dtype="int32", label=0):
    # Drop the trailing singleton axis and cast both tensors to integer
    # class labels.
    y_true, y_pred = K.squeeze(y_true, axis=-1), K.squeeze(y_pred, axis=-1)
    y_true, y_pred = K.cast(y_true, dtype=dtype), K.cast(y_pred, dtype=dtype)

    return y_true, y_pred


def argmax(y_true, y_pred, dtype="int32", label=0):
    # Reduce one-hot targets and class scores to integer class indices.
    y_true, y_pred = K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)
    y_true, y_pred = K.cast(y_true, dtype=dtype), K.cast(y_pred, dtype=dtype)

    return y_true, y_pred


def sparse_argmax(y_true, y_pred, dtype="int32", label=0):
    # Labels arrive as a (batch, 1) index column, so they are squeezed
    # rather than argmax-ed; predictions are class scores.
    y_true, y_pred = K.squeeze(y_true, axis=-1), K.argmax(y_pred, axis=-1)
    y_true, y_pred = K.cast(y_true, dtype=dtype), K.cast(y_pred, dtype=dtype)

    return y_true, y_pred
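
A quick numpy illustration of what the new cast strategies compute — a sketch with made-up tensors, mirroring the shapes the functions above expect:

import numpy as np

# Categorical case: one-hot targets and softmax scores for 3 classes.
y_true = np.array([[1, 0, 0], [0, 0, 1]])
y_pred = np.array([[0.7, 0.2, 0.1], [0.3, 0.3, 0.4]])

# argmax: both tensors become integer class indices.
t = y_true.argmax(axis=-1).astype("int32")   # [0, 2]
p = y_pred.argmax(axis=-1).astype("int32")   # [0, 2]

# sparse_argmax: labels arrive as a (batch, 1) index column instead,
# so they are squeezed rather than argmax-ed.
y_sparse = np.array([[0], [2]])
t_sparse = y_sparse.squeeze(axis=-1).astype("int32")   # [0, 2]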
51 changes: 51 additions & 0 deletions keras_metrics/metrics.py
@@ -226,3 +226,54 @@ def __call__(self, y_true, y_pred):
self.add_update(self.recall.updates)

return 2 * truediv(pr * rec, pr + rec + K.epsilon())


class average_recall(layer):
    """Create a metric for the average recall calculation.

    Average recall is the unweighted mean of the per-class recalls
    (balanced accuracy in the binary case).
    """

def __init__(self, name="average_recall", classes=2, **kwargs):
super(average_recall, self).__init__(name=name, **kwargs)

self.classes = classes

self.true = K.zeros(classes, dtype="int32")
self.pred = K.zeros(classes, dtype="int32")

def reset_states(self):
K.set_value(self.true, [0 for v in range(self.classes)])
K.set_value(self.pred, [0 for v in range(self.classes)])

    def __call__(self, y_true, y_pred):
        # Cast the inputs to class indices (float64, for the bias trick below).
        t, p = self.cast(y_true, y_pred, dtype="float64")

        # Bias vector with b[v] = 1 / (v + 1), so that (x + 1) * b - 1
        # vanishes exactly at index x.
        b = K.variable([truediv(1, (v + 1)) for v in range(self.classes)],
                       dtype="float64")

        # Simulate a to_categorical operation: zero out the label position.
        t, p = K.expand_dims(t, axis=-1), K.expand_dims(p, axis=-1)
        t, p = (t + 1) * b - 1, (p + 1) * b - 1

        # The bool/int round trip fills the correct position with 1,
        # producing one-hot rows.
        t, p = K.cast(t, "bool"), K.cast(p, "bool")
        t, p = 1 - K.cast(t, "int32"), 1 - K.cast(p, "int32")

        t, p = K.transpose(t), K.transpose(p)

        # Results for the current batch: samples per true class, and
        # correct predictions per class.
        batch_t = K.sum(t, axis=-1)
        batch_p = K.sum(t * p, axis=-1)

        # Accumulated results; `* 1` reads the accumulators as tensors
        # before the updates below are applied.
        total_t = self.true * 1 + batch_t
        total_p = self.pred * 1 + batch_p

        self.add_update(K.update_add(self.true, batch_t))
        self.add_update(K.update_add(self.pred, batch_p))

        tp = K.cast(total_p, dtype="float64")
        tt = K.cast(total_t, dtype="float64")

        # Mean of the per-class recalls.
        return K.mean(truediv(tp, (tt + self.epsilon)))
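
The bias-vector trick above rebuilds a one-hot encoding inside the graph: with b[v] = 1 / (v + 1), the expression (x + 1) * b - 1 is zero exactly at index x, so the bool/int round trip leaves a 1 only at the label's position. A numpy sketch of the same computation on made-up labels:

import numpy as np

classes = 3
t = np.array([0., 2., 1., 2.])   # true class indices
p = np.array([0., 1., 1., 2.])   # predicted class indices

b = 1.0 / (np.arange(classes) + 1.0)   # [1, 1/2, 1/3]

# (x + 1) * b - 1 vanishes where the column index equals the label.
T = 1 - ((t[:, None] + 1.0) * b - 1.0).astype(bool).astype("int32")
P = 1 - ((p[:, None] + 1.0) * b - 1.0).astype(bool).astype("int32")

per_class_total = T.T.sum(axis=-1)          # true samples per class: [1, 1, 2]
per_class_hits = (T.T * P.T).sum(axis=-1)   # correct predictions:    [1, 1, 1]

print((per_class_hits / per_class_total).mean())   # 0.8333... = average recall

Note the trick relies on (x + 1) * b[x] evaluating to exactly 1.0 in float64, which holds for small class counts.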
12 changes: 12 additions & 0 deletions tests/test_metrics.py
@@ -21,6 +21,7 @@ class TestMetrics(unittest.TestCase):
km.binary_precision,
km.binary_recall,
km.binary_f1_score,
km.binary_average_recall,
]

categorical_metrics = [
@@ -31,6 +32,7 @@ class TestMetrics(unittest.TestCase):
km.categorical_precision,
km.categorical_recall,
km.categorical_f1_score,
km.categorical_average_recall,
]

sparse_categorical_metrics = [
@@ -41,6 +43,7 @@ class TestMetrics(unittest.TestCase):
km.sparse_categorical_precision,
km.sparse_categorical_recall,
km.sparse_categorical_f1_score,
km.sparse_categorical_average_recall,
]

def create_binary_samples(self, n):
@@ -60,6 +63,9 @@ def create_model(self, outputs, loss, metrics_fns):
model.add(keras.layers.Activation(keras.backend.sin))
model.add(keras.layers.Activation(keras.backend.abs))
model.add(keras.layers.Lambda(lambda x: K.concatenate([x]*outputs)))
        # Scale each copied output differently before the softmax, so the
        # model produces distinct per-class scores.
        scale = [v + 1 for v in range(outputs)]
        model.add(keras.layers.Lambda(lambda x: (0.5 - x) * scale + 1))
        model.add(keras.layers.Softmax())
model.compile(optimizer="sgd",
loss=loss,
metrics=self.create_metrics(metrics_fns))
@@ -126,10 +132,14 @@ def assert_metrics(self, model, samples_fn):
precision = metrics[4]
recall = metrics[5]
f1 = metrics[6]
average_recall = metrics[7]

        expected_precision = tp_val / (tp_val + fp_val)
        expected_recall = tp_val / (tp_val + fn_val)

        # Mean of the positive- and negative-class recalls.
        expected_average_recall = (
            expected_recall + (tn_val / (fp_val + tn_val))) / 2

        f1_dividend = (expected_precision * expected_recall)
        f1_divisor = (expected_precision + expected_recall)
        expected_f1 = (2 * f1_dividend / f1_divisor)
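
The expected_average_recall computed above is the unweighted mean of the positive- and negative-class recalls (balanced accuracy). A worked example with hypothetical confusion counts:

# Hypothetical counts, for illustration only.
tp, fp, tn, fn = 40, 10, 35, 15

recall_pos = tp / (tp + fn)   # 40 / 55 = 0.7272...
recall_neg = tn / (tn + fp)   # 35 / 45 = 0.7777...

print((recall_pos + recall_neg) / 2)   # 0.7525... = expected average recall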
@@ -145,6 +155,8 @@ def assert_metrics(self, model, samples_fn):
self.assertAlmostEqual(expected_precision, precision, places=places)
self.assertAlmostEqual(expected_recall, recall, places=places)
self.assertAlmostEqual(expected_f1, f1, places=places)
self.assertAlmostEqual(expected_average_recall,
average_recall, places=places)

def test_binary_metrics(self):
model = self.create_model(1, "binary_crossentropy",