Skip to content

Commit

Permalink
Merge pull request chainer#3118 from unnonouno/average-tuple-axis
Browse files Browse the repository at this point in the history
Suppot tuple for axis argument in average
  • Loading branch information
niboshi committed Aug 29, 2017
1 parent f76bdd7 commit 8553ddc
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
29 changes: 24 additions & 5 deletions chainer/functions/math/average.py
@@ -1,3 +1,5 @@
import six

from chainer.functions.array import broadcast
from chainer.functions.array import reshape
from chainer.functions.math import sum as sum_mod
Expand All @@ -8,7 +10,7 @@ def average(x, axis=None, weights=None, keepdims=False):
Args:
x (~chainer.Variable): Elements to sum.
axis (None or int): Axis which the method is performed.
axis (None or int or tuple of int): Axis which the method is performed.
With the default (axis = None) it performs a mean over all the
dimensions of the input array.
weights (None or chainer.Variable): An array holding weights to
Expand All @@ -24,12 +26,27 @@ def average(x, axis=None, weights=None, keepdims=False):
~chainer.Variable: Output variable.
"""
if axis is None:
pass
elif isinstance(axis, tuple):
axis = [a + x.ndim if a < 0 else a for a in axis]
axis.sort()
for a, b in six.moves.zip(axis, axis[1:]):
if a == b:
raise ValueError('duplicate value in \'axis\'')
axis = tuple(axis)
else:
if axis < 0:
axis += x.ndim
axis = (axis,)

if weights is not None:
if axis is not None and len(axis) > 1:
raise ValueError(
'tuple axis is not supported when weights is given')
divider = sum_mod.sum(weights)
if axis is not None:
if axis < 0:
axis += x.ndim
w_shape = [d if i == axis else 1 for i, d in enumerate(x.shape)]
w_shape = [d if i in axis else 1 for i, d in enumerate(x.shape)]
weights = broadcast.broadcast_to(
reshape.reshape(weights, w_shape), x.shape)

Expand All @@ -38,7 +55,9 @@ def average(x, axis=None, weights=None, keepdims=False):
if axis is None:
divider = x.size
else:
divider = x.shape[axis]
divider = 1
for a in axis:
divider *= x.shape[a]

x_sum = sum_mod.sum(x, axis, keepdims)
if weights is not None:
Expand Down
36 changes: 34 additions & 2 deletions tests/chainer_tests/functions_tests/math_tests/test_average.py
Expand Up @@ -15,7 +15,7 @@
@testing.parameterize(*(
testing.product({
'shape': [(3, 2, 4)],
'axis': [None, 0, 1, 2, -1],
'axis': [None, 0, 1, 2, -1, (0, 1), (1, -1)],
'dtype': [numpy.float16, numpy.float32, numpy.float64],
'use_weights': [True, False],
'keepdims': [True, False],
Expand All @@ -34,17 +34,23 @@ def setUp(self):
self.x = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype)
if self.axis is None:
w_shape = self.shape
else:
elif isinstance(self.axis, int):
axis = self.axis
if axis < 0:
axis += ndim
w_shape = self.shape[axis],
else:
w_shape = tuple(self.shape[a] for a in self.axis)

g_shape = self.x.sum(axis=self.axis, keepdims=self.keepdims).shape
self.gy = numpy.random.uniform(-1, 1, g_shape).astype(self.dtype)
self.w = numpy.random.uniform(-1, 1, w_shape).astype(self.dtype)

def check_forward(self, x_data, axis, weights):
if self.use_weights and isinstance(self.axis, tuple):
# This condition is not supported
return

x = chainer.Variable(x_data)
if self.use_weights:
w = chainer.Variable(weights)
Expand Down Expand Up @@ -87,6 +93,10 @@ def test_forward_gpu(self):
cuda.to_gpu(self.x), self.axis, cuda.to_gpu(self.w))

def check_backward(self, x_data, y_grad, axis, w_data):
if self.use_weights and isinstance(self.axis, tuple):
# This condition is not supported
return

if self.use_weights:
def f(x, w):
return functions.average(
Expand All @@ -113,4 +123,26 @@ def test_backward_gpu(self):
cuda.to_gpu(self.w))


@testing.parameterize(*testing.product({
'dtype': [numpy.float16, numpy.float32, numpy.float64],
}))
class TestAverageDuplicateValueInAxis(unittest.TestCase):

def test_duplicate_value(self):
x = numpy.random.uniform(-1, 1, 24).reshape(2, 3, 4).astype(self.dtype)
with self.assertRaises(ValueError):
functions.average(x, axis=(0, 0))

def test_duplicate_value_negative(self):
x = numpy.random.uniform(-1, 1, 24).reshape(2, 3, 4).astype(self.dtype)
with self.assertRaises(ValueError):
functions.average(x, axis=(1, -2))

def test_weights_and_axis(self):
x = numpy.random.uniform(-1, 1, 24).reshape(2, 3, 4).astype(self.dtype)
w = numpy.random.uniform(-1, 1, 6).reshape(2, 3).astype(self.dtype)
with self.assertRaises(ValueError):
functions.average(x, axis=(0, 1), weights=w)


testing.run_module(__name__, __file__)

0 comments on commit 8553ddc

Please sign in to comment.