Skip to content

Commit

Permalink
Merge pull request chainer#5497 from toslunar/softmax-axis
Browse files Browse the repository at this point in the history
Support negative `axis` for `F.softmax`
  • Loading branch information
okuta committed Oct 21, 2018
1 parent c4ffdc8 commit 6b222b5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
3 changes: 1 addition & 2 deletions chainer/functions/activation/softmax.py
Expand Up @@ -44,8 +44,7 @@ def check_type_forward(self, in_types):

type_check.expect(
x_type.dtype.kind == 'f',
x_type.ndim > 1,
self.axis < x_type.ndim
-x_type.ndim <= self.axis < x_type.ndim,
)

def forward(self, x):
Expand Down
Expand Up @@ -10,20 +10,25 @@
from chainer.testing import attr


@testing.parameterize(*testing.product({
'shape_axis':
[{'shape': None, 'axis': 1}, ] +
testing.product({'shape': ((2, 3),), 'axis': (0, 1)}) +
testing.product({'shape': ((2, 3, 4),), 'axis': (0, -1)}) +
testing.product({'shape': ((2, 3, 2, 3),), 'axis': (-3, 3)}),
'dtype': [numpy.float16, numpy.float32, numpy.float64],
}))
@testing.parameterize(*testing.product_dict(
[
{'shape': None, 'axis': 1},
{'shape': (5,), 'axis': 0},
{'shape': (2, 3), 'axis': 0},
{'shape': (2, 3), 'axis': 1},
{'shape': (2, 3, 4), 'axis': 0},
{'shape': (2, 3, 4), 'axis': -1},
{'shape': (2, 3, 2, 3), 'axis': -3},
{'shape': (2, 3, 2, 3), 'axis': 3},
],
testing.product({
'dtype': [numpy.float16, numpy.float32, numpy.float64],
}),
))
@testing.fix_random()
class TestSoftmax(unittest.TestCase):

def setUp(self):
self.shape = self.shape_axis['shape']
self.axis = self.shape_axis['axis']
if self.shape is None:
# For checking numerical stability
value = -5 if self.dtype == numpy.float16 else -1000
Expand Down

0 comments on commit 6b222b5

Please sign in to comment.