diff --git a/chainer/functions/array/transpose.py b/chainer/functions/array/transpose.py index b25b84a4ccf2..9c13d0a7f8e2 100644 --- a/chainer/functions/array/transpose.py +++ b/chainer/functions/array/transpose.py @@ -23,13 +23,11 @@ def forward(self, inputs): return y, def backward(self, indexes, grad_outputs): - gy = grad_outputs[0] inv_axes = self.axes - if self.axes: - axes = tuple(ax % len(self.axes) for ax in self.axes) - inv_axes = tuple(numpy.argsort(axes)) - gx = transpose(gy, inv_axes) - return gx, + if inv_axes: + axes_len = len(inv_axes) + inv_axes = tuple(numpy.argsort([ax % axes_len for ax in inv_axes])) + return Transpose(inv_axes).apply(grad_outputs) def transpose(x, axes=None):