/
flip.py
56 lines (41 loc) · 1.58 KB
/
flip.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import six
from chainer.backends import cuda
from chainer import function_node
from chainer.utils import type_check
def _flip(array, axis):
    """Reverse ``array`` along ``axis`` using basic slicing.

    Fallback for array modules without a native ``flip``. Every dimension
    is selected whole, except ``axis``, which is walked backwards. ``axis``
    may be negative (counted from the end). An out-of-range ``axis``,
    including any axis of a 0-dim array, raises ``IndexError``.
    """
    keep_all = slice(None)
    selector = [keep_all for _ in range(array.ndim)]
    selector[axis] = slice(None, None, -1)
    return array[tuple(selector)]
class Flip(function_node.FunctionNode):

    """Function node that reverses its single input along a fixed axis.

    The axis is chosen at construction time. Negative values count from the
    last dimension, and both signs are validated against the input's rank
    in :meth:`check_type_forward`.
    """

    def __init__(self, axis):
        # Only values of six.integer_types are accepted.
        # NOTE(review): on Python 3 numpy integer scalars such as
        # numpy.int64 are not int instances and are rejected here -- confirm
        # this is intended.
        is_integer = isinstance(axis, six.integer_types)
        if not is_integer:
            raise TypeError('axis must be int')
        self.axis = axis

    def check_type_forward(self, in_types):
        type_check.expect(in_types.size() == 1)
        in_type = in_types[0]
        # There must be at least one dimension to reverse.
        type_check.expect(in_type.ndim > 0)
        axis = self.axis
        if axis < 0:
            # axis == -k addresses the k-th dimension from the end.
            type_check.expect(in_type.ndim >= -axis)
        else:
            type_check.expect(in_type.ndim > axis)

    def forward(self, inputs):
        xp = cuda.get_array_module(*inputs)
        x = inputs[0]
        native_flip = getattr(xp, 'flip', None)
        if native_flip is None:
            # numpy.flip is only available from numpy 1.12.0; slice manually.
            return _flip(x, self.axis),
        return native_flip(x, self.axis),

    def backward(self, indexes, grad_outputs):
        # Reversal is its own inverse, so the gradient is the upstream
        # gradient flipped back along the same axis.
        gy = grad_outputs[0]
        return flip(gy, self.axis),
def flip(x, axis):
    """Reverse the elements of ``x`` along ``axis``.

    Args:
        x (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
        :class:`cupy.ndarray`):
            Input variable.
        axis (int): Axis to reverse. Negative values count from the last
            dimension (e.g. ``-1`` is the last axis).

    Returns:
        ~chainer.Variable: Output variable holding the entries of ``x`` in
        reversed order along ``axis``.

    Raises:
        TypeError: If ``axis`` is not a Python integer.

    """
    node = Flip(axis)
    outputs = node.apply((x,))
    return outputs[0]