Skip to content

Commit

Permalink
Merge pull request apache#23 from vchuravy/vc/multiout_acc
Browse files Browse the repository at this point in the history
[RFC] reformultate accuracy with multi_output in mind
  • Loading branch information
pluskid committed Nov 11, 2015
2 parents ea90b55 + d51d2af commit 7947313
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 7 deletions.
4 changes: 2 additions & 2 deletions docs/api/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ libmxnet data providers


:param prefetch_buffer: Backend Param: Number of prefetched parameters
:type prefetch_buffer: , optional, default=4
:type prefetch_buffer: long (non-negative), optional, default=4


:param rand_crop: Augmentation Param: Whether to random crop on the image
Expand Down Expand Up @@ -460,7 +460,7 @@ libmxnet data providers


:param prefetch_buffer: Backend Param: Number of prefetched parameters
:type prefetch_buffer: , optional, default=4
:type prefetch_buffer: long (non-negative), optional, default=4

:return: the constructed :class:`MXDataProvider`.

Expand Down
3 changes: 3 additions & 0 deletions docs/api/metric.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,8 @@ set.

Multiclass classification accuracy.

Calculates the mean accuracy per sample for softmax in one dimension.
For a multi-dimensional softmax the mean accuracy over all dimensions is calculated.



23 changes: 23 additions & 0 deletions docs/api/symbolic-node.rst
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,29 @@ Public APIs



.. function:: SwapAxis(...)

Apply swapaxis to input.

:param data: Input data to the SwapAxisOp.
:type data: SymbolicNode


:param dim1: the first axis to be swapped.
:type dim1: int (non-negative), optional, default=0


:param dim2: the second axis to be swapped.
:type dim2: int (non-negative), optional, default=0

:param Base.Symbol name: The name of the :class:`SymbolicNode`. (e.g. `:my_symbol`), optional.

:return: the constructed :class:`SymbolicNode`.





.. function:: exp(...)

Take exp of the src
Expand Down
47 changes: 42 additions & 5 deletions src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ abstract AbstractEvalMetric
.. class:: Accuracy
Multiclass classification accuracy.
Calculates the mean accuracy per sample for softmax in one dimension.
For a multi-dimensional softmax the mean accuracy over all dimensions is calculated.
=#
type Accuracy <: AbstractEvalMetric
acc_sum :: Float64
Expand All @@ -48,13 +51,47 @@ type Accuracy <: AbstractEvalMetric
Accuracy() = new(0.0, 0)
end

"""
Implementation taken from findmax in Julia base.
Searches for the maximum value in p_dim of a.
I and n are values for the other dimensions.
"""
function _indmax(a, I, p_dim, n)
m = a[I..., 1, n]
mi = 1
for i in 2:size(a, p_dim)
ai = a[I..., i, n]
if ai > m || m!=m
m = ai
mi = i
end
end
return mi
end

function _update_single_output(metric :: Accuracy, label :: NDArray, pred :: NDArray)
@nd_as_jl ro=(label,pred) begin
n_sample = size(pred)[end]
metric.n_sample += n_sample
for i = 1:n_sample
klass = indmax(pred[:,i])
metric.acc_sum += (klass-1) == label[i]
if ndims(pred) > 2 # Multidimensional case
# Construct cartesian index
p_dim = ndims(pred)-1
initial = tuple(fill(1,p_dim-1)...)
dims = size(pred, (1:p_dim-1)...)
crange = CartesianRange(CartesianIndex(initial), CartesianIndex(dims))

for sample in 1:size(label, ndims(label))
for i in crange
l_i = sub2ind(dims, i.I...)
klass = _indmax(pred, i.I, p_dim, sample)
metric.acc_sum += (klass-1) == label[l_i, sample]
metric.n_sample += 1
end
end
else # 1-dimensional case
for sample in 1:size(label, 1)
klass = indmax(pred[:, sample])
metric.acc_sum += (klass-1) == label[sample]
metric.n_sample += 1
end
end
end
end
Expand Down

0 comments on commit 7947313

Please sign in to comment.