/
select_item.py
116 lines (93 loc) · 3.32 KB
/
select_item.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy
import six
import chainer
from chainer import cuda
from chainer import function_node
from chainer.utils import type_check
class SelectItem(function_node.FunctionNode):

    """Pick one element from each row of a matrix by column index.

    Given a two-dimensional input ``x`` and a one-dimensional integer
    input ``t`` whose length equals ``x.shape[0]``, the forward pass
    produces ``y`` with ``y[i] == x[i, t[i]]``.
    """

    def check_type_forward(self, in_types):
        """Validate that the inputs are a 2-D array and a 1-D int array.

        Args:
            in_types: Type information of the two inputs ``(x, t)``.
        """
        type_check.expect(in_types.size() == 2)
        x_type, t_type = in_types

        type_check.expect(
            t_type.dtype.kind == 'i',
            x_type.ndim == 2,
            t_type.ndim == 1,
            x_type.shape[0] == t_type.shape[0],
        )

    def forward(self, inputs):
        """Gather ``x[i, t[i]]`` for every row ``i``.

        Args:
            inputs: Pair ``(x, t)`` of input arrays.

        Returns:
            tuple: One-element tuple holding the selected values.

        Raises:
            ValueError: In debug mode, if any entry of ``t`` lies outside
                ``[0, x.shape[1])``.
        """
        # Only the index array is needed again in ``backward``.
        self.retain_inputs((1,))
        data, index = inputs

        # Remember what the gradient w.r.t. ``x`` has to look like.
        self._in_shape = data.shape
        self._in_dtype = data.dtype

        if chainer.is_debug():
            in_range = (
                (0 <= index).all() and (index < data.shape[1]).all())
            if not in_range:
                raise ValueError(
                    'Each label `t` need to satisfty `0 <= t < x.shape[1]`')

        if cuda.get_array_module(data) is not numpy:
            # Non-NumPy path: one kernel invocation per element of ``t``,
            # each reading ``x`` at the position (i, t).
            picked = cuda.elementwise(
                'S t, raw T x',
                'T y',
                'int ind[] = {i, t}; y = x[ind];',
                'getitem_fwd'
            )(index, data)
            return picked,

        # NumPy path. Equivalent to `t.choose(x.T)`, but `numpy.choose`
        # does not work when `x.shape[1] > 32`, so index explicitly.
        rows = six.moves.range(index.size)
        return data[rows, index],

    def backward(self, indexes, gy):
        """Build the gradients requested by ``indexes``.

        Args:
            indexes: Indices of the inputs that need a gradient.
            gy: Tuple holding the gradient w.r.t. the output.

        Returns:
            list: One entry per requested input, in the order ``x`` then
            ``t``. The entry for ``t`` is always ``None``.
        """
        t = self.get_retained_inputs()[0]

        grads = []
        if 0 in indexes:
            # Scatter the output gradient back to the selected positions.
            scatter = Assign(self._in_shape, self._in_dtype, t)
            grads.append(scatter.apply(gy)[0])
        if 1 in indexes:
            # No gradient is propagated to the indices.
            grads.append(None)
        return grads
class Assign(function_node.FunctionNode):

    """Scatter a vector into a zero-filled matrix at given column indices.

    The forward pass creates zeros of the stored shape and dtype and
    writes the ``i``-th input element to position ``(i, t[i])``.
    """

    def __init__(self, shape, dtype, t):
        """Store the output layout and the column indices.

        Args:
            shape: Shape of the array created in the forward pass.
            dtype: Dtype of the array created in the forward pass.
            t: Object whose ``data`` attribute is the index array.
        """
        self.shape = shape
        self.dtype = dtype
        # Keep the raw array rather than the wrapper passed in.
        self.t = t.data

    def forward_cpu(self, inputs):
        """Scatter ``inputs[0]`` into a fresh NumPy zero array."""
        values = inputs[0]
        out = numpy.zeros(self.shape, self.dtype)
        rows = six.moves.range(self.t.size)
        out[rows, self.t] = values
        return out,

    def forward_gpu(self, inputs):
        """Scatter ``inputs[0]`` into a fresh CuPy zero array."""
        values = inputs[0]
        out = cuda.cupy.zeros(self.shape, self.dtype)
        # The kernel writes into ``out``, which is passed as the output
        # argument after the two inputs.
        scatter = cuda.elementwise(
            'S t, T gloss',
            'raw T gx',
            'int ind[] = {i, t}; gx[ind] = gloss;',
            'getitem_bwd'
        )
        out = scatter(self.t, values, out)
        return out,

    def backward(self, indexes, gy):
        """Select the incoming gradient at the stored indices.

        Returns:
            The result of applying ``SelectItem`` to ``(gy[0], self.t)``.
        """
        upstream = gy[0]
        return SelectItem().apply((upstream, self.t))
def select_item(x, t):
    """Select elements stored in given indices.

    This function returns ``t.choose(x.T)``, that means
    ``y[i] == x[i, t[i]]`` for all ``i``.

    Args:
        x (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
        :class:`cupy.ndarray`):
            Variable storing arrays. A two-dimensional float array.
        t (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
        :class:`cupy.ndarray`):
            Variable storing index numbers. A one-dimensional int array.
            Length of the ``t`` should be equal to ``x.shape[0]``.

    Returns:
        ~chainer.Variable: Variable that holds ``t``-th element of ``x``.

    .. admonition:: Example

        >>> x = np.array([[0, 1, 2], [3, 4, 5]], 'f')
        >>> t = np.array([0, 2], 'i')
        >>> y = F.select_item(x, t)
        >>> y.shape
        (2,)
        >>> y.data
        array([ 0., 5.], dtype=float32)

    """
    # ``apply`` returns a tuple of outputs; this function has exactly one.
    outputs = SelectItem().apply((x, t))
    return outputs[0]