/
get_item.py
148 lines (114 loc) · 4.67 KB
/
get_item.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import numpy
import chainer
from chainer.backends import cuda
from chainer import function_node
from chainer import utils
from chainer.utils import type_check
from chainer import variable
_numpy_supports_0d_bool_index = \
numpy.lib.NumpyVersion(numpy.__version__) >= '1.13.0'
class GetItem(function_node.FunctionNode):

    """Function that slices array and extract elements.

    ``slices`` is normalized to a tuple in ``__init__``; the same tuple is
    used for indexing in ``forward`` and handed to :class:`GetItemGrad` in
    ``backward``.
    """

    def __init__(self, slices):
        if isinstance(slices, list):
            # A list made only of integers is a single integer-array index
            # (NumPy semantics), so wrap it to keep it as one index.
            # ``numpy.integer`` is accepted as well: otherwise a list of
            # NumPy integer scalars (e.g. numpy.int64) would be turned into
            # a per-axis tuple and select a different element than
            # ``numpy.ndarray.__getitem__`` does.
            if all(isinstance(s, (int, numpy.integer)) for s in slices):
                slices = slices,
            # Any other list is treated as a tuple of per-axis indices.
            slices = tuple(slices)
        elif not isinstance(slices, tuple):
            slices = slices,

        if chainer.is_debug():
            # Only validated in debug mode: at most one ``...`` is allowed.
            n_ellipses = sum(1 for s in slices if s is Ellipsis)
            if n_ellipses > 1:
                raise ValueError('Only one Ellipsis is allowed')

        self.slices = slices

    def check_type_forward(self, in_types):
        """Check that exactly one input array is given."""
        type_check.expect(in_types.size() == 1)

    def forward(self, xs):
        """Index the single input with the normalized ``slices``."""
        # NOTE(review): force_array presumably converts a scalar result of
        # full integer indexing back into an ndarray — confirm in
        # chainer.utils.
        return utils.force_array(xs[0][self.slices]),

    def backward(self, indexes, gy):
        """Scatter ``gy`` back to an array shaped like the input."""
        return GetItemGrad(
            self.slices, self.inputs[0].shape, self.inputs[0].dtype).apply(gy)
class GetItemGrad(function_node.FunctionNode):

    """Gradient of :class:`GetItem`.

    Adds the incoming gradient into a zero array that has the shape and
    dtype of the original ``GetItem`` input, at the positions selected by
    ``slices``.
    """

    def __init__(self, slices, in_shape, in_dtype):
        self.slices = slices
        self._in_shape = in_shape
        self._in_dtype = in_dtype

    def forward(self, inputs):
        """Return ``gx`` with ``gy`` accumulated at ``self.slices``."""
        gy, = inputs
        xp = cuda.get_array_module(*inputs)
        gx = xp.zeros(self._in_shape, self._in_dtype)

        if xp is not numpy:
            # GPU path: CuPy's scatter_add performs the accumulation.
            gx.scatter_add(self.slices, gy)
            return gx,

        try:
            # numpy.add.at is unbuffered, so repeated indices accumulate.
            numpy.add.at(gx, self.slices, gy)
        except IndexError:
            # In numpy<1.13, numpy.add.at rejects a 0-dim boolean index even
            # though arr.__getitem__ accepts it. Such a mask is retried below
            # as a 1-dim mask over a 1-dim leading axis.
            bool_mask = None
            if not _numpy_supports_0d_bool_index and len(self.slices) == 1:
                candidate = numpy.asanyarray(self.slices[0])
                if candidate.dtype == numpy.dtype(bool):
                    bool_mask = candidate
            if bool_mask is None:
                msg = '''
GetItem does not support backward for this slices. The slices argument is not
supported by numpy.add.at, while it is supported by numpy.ndarray.__getitem__.
Please report this error to the issue tracker with the stack trace,
the information of your environment, and your script:
https://github.com/chainer/chainer/issues/new.
'''
                raise IndexError(msg)
            numpy.add.at(gx[None], bool_mask[None], gy)
        return gx,

    def backward(self, indexes, ggx):
        """Second-order gradient: re-apply the forward selection."""
        return GetItem(self.slices).apply(ggx)
def get_item(x, slices):
    """Extract elements from an array by NumPy-style indexing.

    ``F.get_item(x, slices)`` computes ``x[slices]`` as a differentiable
    operation. In the backward pass the output gradient is added into a
    zero array shaped like ``x`` at the selected positions.

    Args:
        x (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
        :class:`cupy.ndarray`): A variable to be sliced.
        slices (int, slice, Ellipsis, None, integer array-like, boolean\
            array-like or tuple of them):
            An object to specify the selection of elements. A list whose
            items are all Python ``int`` is used as one integer-array
            index; any other list is converted to a tuple of per-axis
            indices.

    Returns:
        A :class:`~chainer.Variable` object which contains sliced array of
        ``x``.

    .. note::

       It only supports types that are supported by CUDA's atomicAdd when
       an integer array is included in ``slices``.
       The supported types are ``numpy.float32``, ``numpy.int32``,
       ``numpy.uint32``, ``numpy.uint64`` and ``numpy.ulonglong``.

    .. note::

       It does not support ``slices`` that contains multiple boolean arrays.

    .. note::

       See NumPy document for details of `indexing
       <https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html>`_.

    .. admonition:: Example

        >>> x = np.arange(12).reshape((2, 2, 3))
        >>> x
        array([[[ 0,  1,  2],
                [ 3,  4,  5]],
        <BLANKLINE>
               [[ 6,  7,  8],
                [ 9, 10, 11]]])
        >>> F.get_item(x, 0)
        variable([[0, 1, 2],
                  [3, 4, 5]])
        >>> F.get_item(x, (0, 0, slice(0, 2, 1)))  # equals x[0, 0, 0:2:1]
        variable([0, 1])
        >>> F.get_item(x, (Ellipsis, 2))  # equals x[..., 2]
        variable([[ 2,  5],
                  [ 8, 11]])
        >>> F.get_item(x, (1, np.newaxis, 1, 0))  # equals x[1, None, 1, 0]
        variable([9])

    """
    return GetItem(slices).apply((x,))[0]
def install_variable_get_item():
    """Route ``Variable.__getitem__`` to :func:`get_item`.

    After this call, ``var[slices]`` on a :class:`~chainer.Variable` is
    evaluated as ``get_item(var, slices)``.
    """
    setattr(variable.Variable, '__getitem__', get_item)