[toc]

# TensorFlow metrics

| | predicted positive | predicted negative |
| -- | -- | -- |
| prediction correct (true) | TP | TN |
| prediction wrong (false) | FP | FN |

$$
\text{precision} = \frac{TP}{TP + FP}
$$

For a single batch (here, all the data at once), it can be computed like this:

In [59]:
import tensorflow as tf
import numpy as np

labels = np.array([[1,1,1,0],
                   [1,1,1,0],
                   [1,1,1,0],
                   [1,1,1,0]], dtype=np.uint8)

predictions = np.array([[1,0,0,0],
                        [1,1,0,0],
                        [1,1,1,0],
                        [0,1,1,1]], dtype=np.uint8)

n_batches = len(labels)

# predictions positive
pred_p = (predictions > 0).sum()

# true positives: labels == 1 and predictions == 1
true_p = (labels * predictions > 0).sum()

precision = true_p / pred_p
print("Precision :%1.4f" %(precision))

Precision :0.8889
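
As a sanity check on the printed value, the counts can be worked out by hand from the two arrays above. Row by row, the true positives are 1, 2, 3 and 2, and the predicted positives are 1, 2, 3 and 3:

$$
TP = 1 + 2 + 3 + 2 = 8, \qquad TP + FP = 1 + 2 + 3 + 3 = 9, \qquad \text{precision} = \frac{8}{9} \approx 0.8889
$$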


- True: A A A C B C A B B C
- Predicted: A A C B A C A C B C

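Because of hardware limits, this approach does not scale to large datasets: when the dataset is large, it cannot fit in memory all at once. To make it scalable, we want the metric to be updated incrementally with each new batch of predictions and labels. To do that, we need to keep track of two values.

*   TP: the number of positive samples that were predicted correctly
*   TP + FP: the total number of samples that were predicted as positive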

## So here is what we do

We therefore change the code to the following version.

In [61]:
# Create and initialize the running variables
TP = 0      # running count of true positives
PRED_P = 0  # running count of predicted positives (TP + FP)

def reset_running_variables():
    """ Resets the previous values of running variables to zero """
    global TP, PRED_P
    TP = 0
    PRED_P = 0

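def update_running_variables(labels, preds):
    """ Adds the counts from one batch to the running variables """
    global TP, PRED_P
    TP += ((labels * preds) > 0).sum()
    PRED_P += (preds > 0).sum()

def calculate_precision():
    """ Returns the precision computed from the current running variables """
    global TP, PRED_P
    return float(TP) / PRED_P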

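## How do we use the functions above?

The next two examples show the functions in use, and they also make it easier to understand how `tf.metrics.precision()` computes its result and what its outputs mean.

### Overall precision over all samples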

In [64]:
# Overall precision
reset_running_variables()

for i in range(n_batches):
    update_running_variables(labels=labels[i], preds=predictions[i])

precision = calculate_precision()
print("[NP] SCORE: %1.4f" %precision)

[NP] SCORE: 0.8889


### Per-batch precision

In [69]:
# Batch precision

for i in range(n_batches):
    reset_running_variables() # reset for every batch
    update_running_variables(labels=labels[i], preds=predictions[i])
    prec = calculate_precision()
    print("[NP] batch %d score: %1.4f" %(i, prec))

[NP] batch 0 score: 1.0000
[NP] batch 1 score: 1.0000
[NP] batch 2 score: 1.0000
[NP] batch 3 score: 0.6667
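
The two loops answer different questions, and the overall score is not simply the mean of the per-batch scores. The plain-Python sketch below recomputes both from the same four batches. It avoids NumPy to stay dependency-free, and the helper name `batch_counts` is made up for this illustration.

```python
# Same data as above, written as plain lists so the sketch has no dependencies.
labels = [[1, 1, 1, 0]] * 4
predictions = [[1, 0, 0, 0],
               [1, 1, 0, 0],
               [1, 1, 1, 0],
               [0, 1, 1, 1]]

def batch_counts(batch_labels, batch_preds):
    """Return (true positives, predicted positives) for one batch."""
    tp = sum(1 for l, p in zip(batch_labels, batch_preds) if l == 1 and p == 1)
    pred_p = sum(1 for p in batch_preds if p == 1)
    return tp, pred_p

counts = [batch_counts(l, p) for l, p in zip(labels, predictions)]

# Streaming (overall) precision: add up the counts first, divide once.
overall = sum(tp for tp, _ in counts) / sum(pp for _, pp in counts)

# Mean of per-batch precisions: divide per batch, then average.
per_batch = [tp / pp for tp, pp in counts]
mean_of_batches = sum(per_batch) / len(per_batch)

print("overall: %1.4f" % overall)                  # overall: 0.8889
print("mean of batches: %1.4f" % mean_of_batches)  # mean of batches: 0.9167
```

The last batch has more predicted positives than the first, so it weighs more in the overall score than in the plain average.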


*  The part of the GitHub source that explains `precision`:

> The `precision` function creates **two local variables**, `true_positives` and `false_positives`, that are used to compute the precision. This value is ultimately returned as `precision`, an idempotent operation that simply divides `true_positives` by the sum of `true_positives` and `false_positives`. For estimation of the metric over a stream of data, the function creates an `update_op` operation that updates these variables and returns the `precision`.

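- How the two variables relate to `tf.metrics.precision()`

The **two local variables** mentioned in the official documentation, `true_positives` and `false_positives`, correspond to the two running variables defined above: `true_positives` is `TP`, and `false_positives` is `PRED_P - TP`. Dividing `true_positives` by `true_positives + false_positives` is therefore the same as `TP / PRED_P`.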

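### The three functions and the confusing `update_op`

The `update_op` and `precision` mentioned in the official documentation correspond to two of the functions defined above:
- precision: calculate_precision()
- update_op: update_running_variables()

Don't let `update_op` confuse you. Read literally, it is just an operation that updates the variables. In the code above, where `reset_running_variables()` is called decides when the running variables are reset, and that reset corresponds to running `tf.variables_initializer()`.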
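
One way to keep the three roles apart is to put them on a single object. The class below is a plain-Python analogy, not TensorFlow's implementation; the name `StreamingPrecision` and its method names are invented for this sketch. `reset()` stands in for running `tf.variables_initializer()`, `update()` stands in for `update_op`, and `result()` stands in for the `precision` value tensor.

```python
class StreamingPrecision:
    """Plain-Python analogy of the (value, update_op, initializer) trio."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Analogue of running the variables initializer: zero the counters."""
        self.true_positives = 0
        self.false_positives = 0

    def update(self, labels, preds):
        """Analogue of update_op: accumulate one batch, return the new value."""
        for label, pred in zip(labels, preds):
            if pred == 1 and label == 1:
                self.true_positives += 1
            elif pred == 1:
                self.false_positives += 1
        return self.result()

    def result(self):
        """Analogue of the value tensor: read the counters, change nothing."""
        total = self.true_positives + self.false_positives
        return self.true_positives / total if total else 0.0


batches = [([1, 1, 1, 0], [1, 0, 0, 0]),
           ([1, 1, 1, 0], [1, 1, 0, 0]),
           ([1, 1, 1, 0], [1, 1, 1, 0]),
           ([1, 1, 1, 0], [0, 1, 1, 1])]

metric = StreamingPrecision()
for labels, preds in batches:
    metric.update(labels, preds)
print("%1.4f" % metric.result())  # 0.8889
```

Calling `reset()` once before the loop gives the overall score; calling it at the top of every iteration gives the per-batch scores.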

### Overall version

In [77]:
# Overall precision using tensorflow
import tensorflow as tf

# Define the nodes inside the graph
graph = tf.Graph()
with graph.as_default():
    # Placeholders to take in batches of data
    tf_label = tf.placeholder(dtype=tf.int32, shape=[None])
    tf_prediction = tf.placeholder(dtype=tf.int32, shape=[None])

    # Define the metric and update operations
    tf_metric, tf_metric_update = tf.metrics.precision(tf_label,
                                                      tf_prediction,
                                                      name="my_metric")

    # Isolate the variables stored behind the scenes by the metric operation
    running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="my_metric")

    # Define initializer to initialize/reset running variables
    running_vars_initializer = tf.variables_initializer(var_list=running_vars)


with tf.Session(graph=graph) as session:
    session.run(tf.global_variables_initializer())

    # initialize/reset the running variables
    session.run(running_vars_initializer)

    for i in range(n_batches):
        # Update the running variables on new batch of samples
        feed_dict = {tf_label: labels[i], tf_prediction: predictions[i]}
        session.run(tf_metric_update, feed_dict=feed_dict)

    # Calculate the score
    score = session.run(tf_metric)
    print("[TF] SCORE: %1.4f" %score)

[TF] SCORE: 0.8889


### Per-batch version

In [79]:
import tensorflow as tf

# Define the nodes inside the graph
graph = tf.Graph()
with graph.as_default():
    # Placeholders to take in batches of data
    tf_label = tf.placeholder(dtype=tf.int32, shape=[None])
    tf_prediction = tf.placeholder(dtype=tf.int32, shape=[None])

    # Define the metric and update operations
    tf_metric, tf_metric_update = tf.metrics.precision(tf_label,
                                                      tf_prediction,
                                                      name="my_metric")

    # Isolate the variables stored behind the scenes by the metric operation
    running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="my_metric")

    # Define initializer to initialize/reset running variables
    running_vars_initializer = tf.variables_initializer(var_list=running_vars)


with tf.Session(graph=graph) as session:
    session.run(tf.global_variables_initializer())


    for i in range(n_batches):
        
        # initialize/reset the running variables
        session.run(running_vars_initializer)
        # Update the running variables on new batch of samples
        feed_dict = {tf_label: labels[i], tf_prediction: predictions[i]}
        session.run(tf_metric_update, feed_dict=feed_dict)

        # Calculate the score
        score = session.run(tf_metric)
        print("[TF] batch %d score: %1.4f" %(i, score))

[TF] batch 0 score: 1.0000
[TF] batch 1 score: 1.0000
[TF] batch 2 score: 1.0000
[TF] batch 3 score: 0.6667


In [47]:
import tensorflow as tf
import numpy as np

tf.reset_default_graph()

predictions = np.array([0, 0, 2, 1, 0, 2, 0, 2, 1, 2])
targets = np.array([0, 0, 0, 2, 1, 2, 0, 1, 1, 2])

placeholder_predictions = tf.placeholder(tf.int32, [None])
placeholder_targets = tf.placeholder(tf.int32, [None])

recall = tf.metrics.recall(labels=placeholder_targets, predictions=placeholder_predictions)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    ret = sess.run(recall, feed_dict={placeholder_predictions: predictions, 
                                      placeholder_targets: targets})
    print(ret)

(0.0, 0.8333333)
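
The output can look surprising for three-class data. `tf.metrics.recall` is a binary metric that casts both inputs to `bool`, so classes 1 and 2 both count as positive. The plain-Python sketch below mimics that cast (it does not call TensorFlow) and arrives at the same 0.8333. The leading `0.0` in the tuple is most likely the value tensor being read before the update in the same `sess.run` call took effect.

```python
predictions = [0, 0, 2, 1, 0, 2, 0, 2, 1, 2]
targets = [0, 0, 0, 2, 1, 2, 0, 1, 1, 2]

# Mimic the cast to bool: any non-zero class id becomes "positive".
pred_pos = [bool(p) for p in predictions]
label_pos = [bool(t) for t in targets]

# True positives: label positive and prediction positive.
tp = sum(1 for t, p in zip(label_pos, pred_pos) if t and p)
# False negatives: label positive but prediction negative.
fn = sum(1 for t, p in zip(label_pos, pred_pos) if t and not p)

binary_recall = tp / (tp + fn)
print(tp, fn, "%.7f" % binary_recall)  # 5 1 0.8333333
```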


In [113]:
from sklearn.metrics import recall_score

predictions = np.array([0, 0, 2, 1, 0, 2, 0, 2, 1, 2])
targets = np.array([0, 0, 0, 2, 1, 2, 0, 1, 1, 2])

micro_recall = recall_score(y_true=targets, y_pred=predictions, average='micro')
macro_recall = recall_score(y_true=targets, y_pred=predictions, average='macro')
print(micro_recall)
print(macro_recall)

0.6
0.5833333333333334
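
For comparison with the binary result, here is a small plain-Python sketch of what micro and macro averaging mean for recall on the same ten samples, taking `targets` as the ground truth. It is a hand-rolled illustration with a hypothetical helper `per_class_counts`, not scikit-learn's implementation.

```python
predictions = [0, 0, 2, 1, 0, 2, 0, 2, 1, 2]
targets = [0, 0, 0, 2, 1, 2, 0, 1, 1, 2]
classes = [0, 1, 2]

def per_class_counts(y_true, y_pred, k):
    """Return (true positives, false negatives) treating class k as positive."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == k and p == k)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == k and p != k)
    return tp, fn

counts = [per_class_counts(targets, predictions, k) for k in classes]

# Macro: compute recall per class, then take the unweighted mean.
recalls = [tp / (tp + fn) for tp, fn in counts]
macro_recall = sum(recalls) / len(recalls)

# Micro: pool the counts over all classes, then divide once.
micro_recall = sum(tp for tp, _ in counts) / sum(tp + fn for tp, fn in counts)

print("micro: %.4f, macro: %.4f" % (micro_recall, macro_recall))  # micro: 0.6000, macro: 0.5833
```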


## Multi-class classification

In [100]:
tf.reset_default_graph()

placeholder_predictions = tf.placeholder(tf.int32, [None])
placeholder_targets = tf.placeholder(tf.int32, [None])

num_labels = 3

def metric(y_true, y_pred):
    recall_n = [0] * num_labels
    precision_n = [0] * num_labels
    update_op_rec_n = [[]] * num_labels
    update_op_pre_n = [[]] * num_labels
    for k in range(num_labels):
        recall_n[k], update_op_rec_n[k] = tf.metrics.recall(
            labels=tf.equal(y_true, k),
            predictions=tf.equal(y_pred, k)
        )    
        precision_n[k], update_op_pre_n[k] = tf.metrics.precision(
            labels=tf.equal(y_true, k),
            predictions=tf.equal(y_pred, k)
        )
    recall_value = sum(recall_n) * 1.0 / num_labels
    precision_value = sum(precision_n) * 1.0 / num_labels
    update_op_rec = sum(update_op_rec_n) * 1.0 / num_labels
    update_op_pre = sum(update_op_pre_n) * 1.0 / num_labels
    recall = (recall_value, update_op_rec)
    precision = (precision_value, update_op_pre)
    return recall, precision

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(metric(placeholder_targets, placeholder_predictions), 
             feed_dict={placeholder_predictions: predictions, 
                        placeholder_targets: targets})

FailedPreconditionError: Attempting to use uninitialized value recall/false_negatives/count
	 [[node recall/false_negatives/count/read (defined at <ipython-input-100-48d88e54ae9d>:16) ]]

Original stack trace for 'recall/false_negatives/count/read':
  ...
  File "<ipython-input-100-48d88e54ae9d>", line 33, in <module>
    sess.run(metric(placeholder_targets, placeholder_predictions),
  File "<ipython-input-100-48d88e54ae9d>", line 16, in metric
    predictions=tf.equal(y_pred, k)
  File "/opt/anaconda3/envs/tars/lib/python3.7/site-packages/tensorflow/python/ops/metrics_impl.py", line 2196, in recall
    name=None)
  ...

The error occurs because `metric(...)` is called inside `sess.run(...)`, after `tf.local_variables_initializer()` has already run. `tf.metrics.recall` and `tf.metrics.precision` create their local variables only when they are called, so these variables were never initialized. Building the metric ops before running the initializer avoids the error.


In [40]:
def metric_fn(per_example_loss, label_ids, logits, num_labels):
    """Builds eval metrics: accuracy, mean loss, and class-averaged recall and precision."""
    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
    accuracy = tf.metrics.accuracy(label_ids, predictions)
    y_true = label_ids
    y_pred = tf.argmax(logits, 1)
    recall_n = [0] * num_labels
    precision_n = [0] * num_labels
    update_op_rec_n = [[]] * num_labels
    update_op_pre_n = [[]] * num_labels
    # Treat each class k as a one-vs-rest binary problem
    for k in range(num_labels):
        recall_n[k], update_op_rec_n[k] = tf.metrics.recall(
            labels=tf.equal(y_true, k),
            predictions=tf.equal(y_pred, k)
        )    
        precision_n[k], update_op_pre_n[k] = tf.metrics.precision(
            labels=tf.equal(y_true, k),
            predictions=tf.equal(y_pred, k)
        )    
    recall_value = sum(recall_n) * 1.0 / num_labels
    precision_value = sum(precision_n) * 1.0 / num_labels
    update_op_rec = sum(update_op_rec_n) * 1.0 / num_labels
    update_op_pre = sum(update_op_pre_n) * 1.0 / num_labels
    recall = (recall_value, update_op_rec)
    precision = (precision_value, update_op_pre)
    loss = tf.metrics.mean(per_example_loss)
    return {
        "eval_accuracy": accuracy,
        "eval_loss": loss,
        "recall": recall,
        "precision": precision,
    }

# References

1. [TensorFlow: “Attempting to use uninitialized value” in variable initialization - Stack Overflow](https://stackoverflow.com/questions/44624648/tensorflow-attempting-to-use-uninitialized-value-in-variable-initialization/44630421)

2. [Using tf.metrics in TensorFlow to compute multi-class metrics - Zhihu](https://zhuanlan.zhihu.com/p/82183796)

3. [【0.2】TensorFlow pitfalls: tf.metrics - Zhihu](https://zhuanlan.zhihu.com/p/43359894)

In [120]:
"""Multiclass"""

__author__ = "Guillaume Genthial"

import numpy as np
import tensorflow as tf
from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix


def precision(labels, predictions, num_classes, pos_indices=None,
              weights=None, average='micro'):
    """Multi-class precision metric for Tensorflow
    Parameters
    ----------
    labels : Tensor of tf.int32 or tf.int64
        The true labels
    predictions : Tensor of tf.int32 or tf.int64
        The predictions, same shape as labels
    num_classes : int
        The number of classes
    pos_indices : list of int, optional
        The indices of the positive classes, default is all
    weights : Tensor of tf.int32, optional
        Mask, must be of compatible shape with labels
    average : str, optional
        'micro': counts the total number of true positives, false
            positives, and false negatives for the classes in
            `pos_indices` and infer the metric from it.
        'macro': will compute the metric separately for each class in
            `pos_indices` and average. Will not account for class
            imbalance.
        'weighted': will compute the metric separately for each class in
            `pos_indices` and perform a weighted average by the total
            number of true labels for each class.
    Returns
    -------
    tuple of (scalar float Tensor, update_op)
    """
    cm, op = _streaming_confusion_matrix(
        labels, predictions, num_classes, weights)
    pr, _, _ = metrics_from_confusion_matrix(
        cm, pos_indices, average=average)
    op, _, _ = metrics_from_confusion_matrix(
        op, pos_indices, average=average)
    return (pr, op)


def recall(labels, predictions, num_classes, pos_indices=None, weights=None,
           average='micro'):
    """Multi-class recall metric for Tensorflow
    Parameters
    ----------
    labels : Tensor of tf.int32 or tf.int64
        The true labels
    predictions : Tensor of tf.int32 or tf.int64
        The predictions, same shape as labels
    num_classes : int
        The number of classes
    pos_indices : list of int, optional
        The indices of the positive classes, default is all
    weights : Tensor of tf.int32, optional
        Mask, must be of compatible shape with labels
    average : str, optional
        'micro': counts the total number of true positives, false
            positives, and false negatives for the classes in
            `pos_indices` and infer the metric from it.
        'macro': will compute the metric separately for each class in
            `pos_indices` and average. Will not account for class
            imbalance.
        'weighted': will compute the metric separately for each class in
            `pos_indices` and perform a weighted average by the total
            number of true labels for each class.
    Returns
    -------
    tuple of (scalar float Tensor, update_op)
    """
    cm, op = _streaming_confusion_matrix(
        labels, predictions, num_classes, weights)
    _, re, _ = metrics_from_confusion_matrix(
        cm, pos_indices, average=average)
    _, op, _ = metrics_from_confusion_matrix(
        op, pos_indices, average=average)
    return (re, op)


def f1(labels, predictions, num_classes, pos_indices=None, weights=None,
       average='micro'):
    return fbeta(labels, predictions, num_classes, pos_indices, weights,
                 average)


def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None,
          average='micro', beta=1):
    """Multi-class fbeta metric for Tensorflow
    Parameters
    ----------
    labels : Tensor of tf.int32 or tf.int64
        The true labels
    predictions : Tensor of tf.int32 or tf.int64
        The predictions, same shape as labels
    num_classes : int
        The number of classes
    pos_indices : list of int, optional
        The indices of the positive classes, default is all
    weights : Tensor of tf.int32, optional
        Mask, must be of compatible shape with labels
    average : str, optional
        'micro': counts the total number of true positives, false
            positives, and false negatives for the classes in
            `pos_indices` and infer the metric from it.
        'macro': will compute the metric separately for each class in
            `pos_indices` and average. Will not account for class
            imbalance.
        'weighted': will compute the metric separately for each class in
            `pos_indices` and perform a weighted average by the total
            number of true labels for each class.
    beta : int, optional
        Weight of precision in harmonic mean
    Returns
    -------
    tuple of (scalar float Tensor, update_op)
    """
    cm, op = _streaming_confusion_matrix(
        labels, predictions, num_classes, weights)
    _, _, fbeta = metrics_from_confusion_matrix(
        cm, pos_indices, average=average, beta=beta)
    _, _, op = metrics_from_confusion_matrix(
        op, pos_indices, average=average, beta=beta)
    return (fbeta, op)


def safe_div(numerator, denominator):
    """Safe division, return 0 if denominator is 0"""
    numerator, denominator = tf.to_float(numerator), tf.to_float(denominator)
    zeros = tf.zeros_like(numerator, dtype=numerator.dtype)
    denominator_is_zero = tf.equal(denominator, zeros)
    return tf.where(denominator_is_zero, zeros, numerator / denominator)


def pr_re_fbeta(cm, pos_indices, beta=1):
    """Uses a confusion matrix to compute precision, recall and fbeta"""
    num_classes = cm.shape[0]
    neg_indices = [i for i in range(num_classes) if i not in pos_indices]
    cm_mask = np.ones([num_classes, num_classes])
    cm_mask[neg_indices, neg_indices] = 0
    diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask))

    cm_mask = np.ones([num_classes, num_classes])
    cm_mask[:, neg_indices] = 0
    tot_pred = tf.reduce_sum(cm * cm_mask)

    cm_mask = np.ones([num_classes, num_classes])
    cm_mask[neg_indices, :] = 0
    tot_gold = tf.reduce_sum(cm * cm_mask)

    pr = safe_div(diag_sum, tot_pred)
    re = safe_div(diag_sum, tot_gold)
    fbeta = safe_div((1. + beta**2) * pr * re, beta**2 * pr + re)

    return pr, re, fbeta


def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro',
                                  beta=1):
    """Precision, Recall and F1 from the confusion matrix
    Parameters
    ----------
    cm : tf.Tensor of type tf.int32, of shape (num_classes, num_classes)
        The streaming confusion matrix.
    pos_indices : list of int, optional
        The indices of the positive classes
    beta : int, optional
        Weight of precision in harmonic mean
    average : str, optional
        'micro', 'macro' or 'weighted'
    """
    num_classes = cm.shape[0]
    if pos_indices is None:
        pos_indices = [i for i in range(num_classes)]

    if average == 'micro':
        return pr_re_fbeta(cm, pos_indices, beta)
    elif average in {'macro', 'weighted'}:
        precisions, recalls, fbetas, n_golds = [], [], [], []
        for idx in pos_indices:
            pr, re, fbeta = pr_re_fbeta(cm, [idx], beta)
            precisions.append(pr)
            recalls.append(re)
            fbetas.append(fbeta)
            cm_mask = np.zeros([num_classes, num_classes])
            cm_mask[idx, :] = 1
            n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask)))

        if average == 'macro':
            pr = tf.reduce_mean(precisions)
            re = tf.reduce_mean(recalls)
            fbeta = tf.reduce_mean(fbetas)
            return pr, re, fbeta
        if average == 'weighted':
            n_gold = tf.reduce_sum(n_golds)
            pr_sum = sum(p * n for p, n in zip(precisions, n_golds))
            pr = safe_div(pr_sum, n_gold)
            re_sum = sum(r * n for r, n in zip(recalls, n_golds))
            re = safe_div(re_sum, n_gold)
            fbeta_sum = sum(f * n for f, n in zip(fbetas, n_golds))
            fbeta = safe_div(fbeta_sum, n_gold)
            return pr, re, fbeta

    else:
        raise NotImplementedError()

In [128]:
import tensorflow as tf

y_true = [0, 0, 2, 1, 0, 2, 0, 2, 1, 2]
y_pred = [0, 0, 0, 2, 1, 2, 0, 1, 1, 2]

pos_indices = None  # None means every class is treated as positive
num_classes = 3
average = 'macro'

# Tuple of (value, update_op)
macro_precision = precision(y_true, y_pred, num_classes, pos_indices, average='macro')
micro_precision = precision(y_true, y_pred, num_classes, pos_indices, average='micro')
# recall = recall(y_true, y_pred, num_classes, pos_indices, average=average)
# f2 = fbeta(y_true, y_pred, num_classes, pos_indices, average=average, beta=2)
# f1 =f1(y_true, y_pred, num_classes, pos_indices, average=average)

# Run the update op and get the updated value
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    print(sess.run(macro_precision[1]))
    print(sess.run(micro_precision[1]))

0.5833333
0.6
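
The two printed values can be cross-checked without TensorFlow. The sketch below builds the 3x3 confusion matrix for the same `y_true` and `y_pred` in plain Python and derives micro, macro and weighted precision from it, following the definitions in the docstrings above. It is an independent illustration with variable names chosen here, not a port of the functions above.

```python
y_true = [0, 0, 2, 1, 0, 2, 0, 2, 1, 2]
y_pred = [0, 0, 0, 2, 1, 2, 0, 1, 1, 2]
num_classes = 3
ks = range(num_classes)

# cm[i][j] counts samples whose true class is i and predicted class is j.
cm = [[0] * num_classes for _ in ks]
for t, p in zip(y_true, y_pred):
    cm[t][p] += 1

diag = [cm[k][k] for k in ks]                 # true positives per class
n_pred = [sum(cm[i][k] for i in ks) for k in ks]  # column sums: predicted as k
n_gold = [sum(cm[k]) for k in ks]             # row sums: truly k

# Per-class precision, with 0 when a class is never predicted.
per_class = [d / n if n else 0.0 for d, n in zip(diag, n_pred)]

micro = sum(diag) / sum(n_pred)
macro = sum(per_class) / num_classes
weighted = sum(p * n for p, n in zip(per_class, n_gold)) / sum(n_gold)

print(cm)  # [[3, 1, 0], [0, 1, 1], [1, 1, 2]]
print("micro: %.4f, macro: %.4f, weighted: %.4f" % (micro, macro, weighted))
# micro: 0.6000, macro: 0.5833, weighted: 0.6333
```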
