/
rollaxis.py
73 lines (55 loc) · 2.04 KB
/
rollaxis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import six
from chainer import backend
from chainer import function_node
from chainer.utils import type_check
class Rollaxis(function_node.FunctionNode):

    """Function node that moves a single axis of its input to a new place.

    The forward pass delegates to ``rollaxis`` of the array module that
    :func:`chainer.backend.get_array_module` returns for the input, following
    NumPy's ``rollaxis`` convention. The backward pass applies the inverse
    roll to the incoming gradient.

    Args:
        axis (int): Index of the axis to move. Negative values count from
            the end.
        start (int): Position the axis is rolled to. Negative values count
            from the end.

    Raises:
        TypeError: If ``axis`` or ``start`` is not an instance of
            ``six.integer_types``.
    """

    def __init__(self, axis, start):
        # Both arguments are validated before anything is stored.
        if not isinstance(axis, six.integer_types):
            raise TypeError('axis must be int')
        if not isinstance(start, six.integer_types):
            raise TypeError('start must be int')
        self.axis, self.start = axis, start

    def check_type_forward(self, in_types):
        # Exactly one input, exposed under the name ``x``.
        type_check._argname(in_types, ('x',))
        ndim = in_types[0].ndim

        # ``axis`` must refer to an existing axis: -ndim <= axis < ndim.
        if self.axis < 0:
            type_check.expect(ndim > -self.axis - 1)
        else:
            type_check.expect(ndim > self.axis)

        # ``start`` may additionally equal ndim (roll to the very end):
        # -ndim <= start <= ndim.
        if self.start < 0:
            type_check.expect(ndim > -self.start - 1)
        else:
            type_check.expect(ndim >= self.start)

    def forward(self, inputs):
        # backward() only needs the input's rank, not its values, so no
        # inputs are retained; the rank is recorded on the node instead.
        self.retain_inputs(())
        x = inputs[0]
        self._in_ndim = x.ndim
        xp = backend.get_array_module(*inputs)
        y = xp.rollaxis(x, self.axis, self.start)
        return y,

    def backward(self, indexes, gy):
        ndim = self._in_ndim
        # Resolve negative indices against the rank seen in forward().
        src = self.axis + ndim if self.axis < 0 else self.axis
        dst = self.start + ndim if self.start < 0 else self.start

        if src > dst:
            # Forward moved the axis left to ``dst``; rolling it to just
            # before ``src + 1`` puts it back at ``src``.
            return Rollaxis(dst, src + 1).apply(gy)
        if src < dst:
            # Forward moved the axis right, so it ended up at ``dst - 1``.
            return Rollaxis(dst - 1, src).apply(gy)
        # src == dst: the forward roll was the identity, and so is this one.
        return Rollaxis(dst, src).apply(gy)
def rollaxis(x, axis, start=0):
    """Roll an axis of ``x`` backwards until it reaches ``start``.

    Kept for backward compatibility only; new code should use
    ``chainer.functions.moveaxis(x, source, destination)`` instead.
    See :func:`chainer.functions.moveaxis`.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`): Input variable.
        axis (int): The axis to roll backwards. Negative values count from
            the end.
        start (int): The position the axis is moved to. Negative values
            count from the end.

    Returns:
        ~chainer.Variable: Variable with the requested axis rolled.

    Raises:
        TypeError: If ``axis`` or ``start`` is not an integer.

    """
    node = Rollaxis(axis, start)
    outputs = node.apply((x,))
    return outputs[0]