Skip to content

Commit

Permalink
Merge b480407 into 3b0fe15
Browse files Browse the repository at this point in the history
  • Loading branch information
naoto0804 committed Aug 25, 2017
2 parents 3b0fe15 + b480407 commit 4bb7451
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions chainer/functions/evaluation/binary_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,46 @@ def binary_accuracy(y, t):
"""Computes binary classification accuracy of the minibatch.
Args:
y (Variable): Variable holding a matrix whose i-th element
indicates the score of positive at the i-th example.
t (Variable): Variable holding an int32 vector of ground truth labels.
If ``t[i] == -1``, corresponding ``x[i]`` is ignored.
y (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`):
Array whose i-th element indicates the score of
positive at the i-th sample.
The prediction label :math:`\\hat t[i]` is ``1`` if
``y[i] >= 0``, otherwise ``0``.
t (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`):
Array holding an int32 vector of ground truth labels.
If ``t[i] == 1``, it indicates that i-th sample is positive.
If ``t[i] == 0``, it indicates that i-th sample is negative.
If ``t[i] == -1``, corresponding ``y[i]`` is ignored.
Accuracy is zero if all ground truth labels are ``-1``.
Returns:
Variable: A variable holding a scalar array of the accuracy.
.. note:: This function is non-differentiable.
.. admonition:: Example
We show the most common case, when ``y`` is the two dimensional array.
>>> y = np.array([[-2.0, 0.0], # prediction labels are [0, 1]
... [3.0, -5.0]]) # prediction labels are [1, 0]
>>> t = np.array([[0, 1],
... [1, 0]], 'i')
>>> F.binary_accuracy(y, t).data \
# 100% accuracy because all samples are correct.
array(1.0)
>>> t = np.array([[0, 0],
... [1, 1]], 'i')
>>> F.binary_accuracy(y, t).data \
# 50% accuracy because y[0][0] and y[1][0] are correct.
array(0.5)
>>> t = np.array([[0, -1],
... [1, -1]], 'i')
>>> F.binary_accuracy(y, t).data \
# 100% accuracy because of ignoring y[0][1] and y[1][1].
array(1.0)
"""
return BinaryAccuracy()(y, t)

0 comments on commit 4bb7451

Please sign in to comment.