Skip to content

Commit

Permalink
Implemented new version of transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
unnonouno committed Aug 15, 2017
1 parent a1320b9 commit 9c7c692
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions chainer/functions/array/transpose.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import numpy

from chainer import function
from chainer import function_node
from chainer.utils import type_check


class Transpose(function.Function):
class Transpose(function_node.FunctionNode):
"""Permute the dimensions of an array."""

def __init__(self, axes=None):
Expand All @@ -23,13 +23,13 @@ def forward(self, inputs):
y = x.transpose(self.axes)
return y,

def backward(self, inputs, grad_outputs):
def backward(self, indexes, grad_outputs):
gy = grad_outputs[0]
inv_axes = self.axes
if self.axes:
axes = tuple(ax % len(self.axes) for ax in self.axes)
inv_axes = tuple(numpy.argsort(axes))
gx = gy.transpose(inv_axes)
gx = transpose(gy, inv_axes)
return gx,


Expand All @@ -45,4 +45,4 @@ def transpose(x, axes=None):
~chainer.Variable: Variable whose axes are permuted.
"""
return Transpose(axes)(x)
return Transpose(axes).apply((x,))[0]

0 comments on commit 9c7c692

Please sign in to comment.