From 2fcb74e6a5ae76e2191c85556b6a4fe9ec6f22d4 Mon Sep 17 00:00:00 2001 From: Ryosuke Okuta Date: Wed, 16 Aug 2017 01:57:54 +0900 Subject: [PATCH] Small improvement for transpose function --- chainer/functions/array/transpose.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/chainer/functions/array/transpose.py b/chainer/functions/array/transpose.py index b25b84a4ccf2..9c13d0a7f8e2 100644 --- a/chainer/functions/array/transpose.py +++ b/chainer/functions/array/transpose.py @@ -23,13 +23,11 @@ def forward(self, inputs): return y, def backward(self, indexes, grad_outputs): - gy = grad_outputs[0] inv_axes = self.axes - if self.axes: - axes = tuple(ax % len(self.axes) for ax in self.axes) - inv_axes = tuple(numpy.argsort(axes)) - gx = transpose(gy, inv_axes) - return gx, + if inv_axes: + axes_len = len(inv_axes) + inv_axes = tuple(numpy.argsort([ax % axes_len for ax in inv_axes])) + return Transpose(inv_axes).apply(grad_outputs) def transpose(x, axes=None):