/
minmax.py
193 lines (145 loc) · 5.59 KB
/
minmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import numpy
from chainer.backends import cuda
from chainer import function_node
import chainer.functions
import chainer.utils
from chainer.utils import type_check
class SelectorBase(function_node.FunctionNode):

    """Select an array element from a given axis or set of axes."""

    def __init__(self, axis=None, keepdims=False):
        self.keepdims = keepdims
        if axis is None:
            self.axis = None
            return
        if isinstance(axis, int):
            # Normalize a single integer axis to a one-element tuple.
            self.axis = (axis,)
            return
        if not (isinstance(axis, tuple)
                and all(isinstance(a, int) for a in axis)):
            raise TypeError('None, int or tuple of int are required')
        if len(set(axis)) != len(axis):
            raise ValueError('duplicate value in axis: ({})'.format(
                ', '.join(map(str, axis))))
        self.axis = axis

    def _fwd(self, x, xp):
        # Subclasses provide the actual reduction (e.g. amax / amin).
        raise NotImplementedError('_fwd should be implemented in sub-class.')

    def check_type_forward(self, in_types):
        type_check.expect(
            in_types.size() == 1,
            in_types[0].dtype.kind == 'f'
        )
        if self.axis is None:
            return
        # Each axis, positive or negative, must index into the input's
        # dimensions.
        for ax in self.axis:
            if ax >= 0:
                type_check.expect(
                    ax < in_types[0].ndim,
                )
            else:
                type_check.expect(
                    -ax - 1 < in_types[0].ndim,
                )

    def forward(self, x):
        # Keep both the input and the output around for backward().
        self.retain_inputs((0,))
        self.retain_outputs((0,))
        xp = cuda.get_array_module(*x)
        return xp.asarray(self._fwd(x[0], xp)),

    def backward(self, indexes, gy):
        x = self.get_retained_inputs()[0]
        y = self.get_retained_outputs()[0]
        if self.axis is None:
            reduced = range(x.ndim)
        else:
            reduced = [ax % x.ndim for ax in self.axis]
        # Re-insert a size-1 dimension for every axis the forward
        # reduction collapsed, so y and gy broadcast against x.
        shape = [1 if ax in reduced else size
                 for ax, size in enumerate(x.shape)]
        grad = gy[0].reshape(shape)
        y = y.reshape(shape)
        # Route the gradient only to positions holding the selected value.
        cond = (x.data == y.data)
        grad = chainer.functions.broadcast_to(grad, cond.shape)
        return grad * cond,
class Max(SelectorBase):

    def _fwd(self, x, xp):
        # Reduce with the maximum over the configured axes.
        # (``xp.max`` is the documented alias of ``xp.amax``.)
        return xp.max(x, axis=self.axis, keepdims=self.keepdims)
class Min(SelectorBase):

    def _fwd(self, x, xp):
        # Reduce with the minimum over the configured axes.
        # (``xp.min`` is the documented alias of ``xp.amin``.)
        return xp.min(x, axis=self.axis, keepdims=self.keepdims)
class IndexSelectorBase(function_node.FunctionNode):

    """Select index of an array element from a given axis."""

    def __init__(self, axis=None):
        # Only a single axis (or None for the flattened array) is allowed.
        if axis is not None and not isinstance(axis, int):
            raise TypeError('None or int are required')
        self.axis = axis

    def _fwd(self, x, xp):
        # Subclasses provide the index reduction (argmax / argmin).
        raise NotImplementedError('_fwd should be implemented in sub-class.')

    def check_type_forward(self, in_types):
        type_check.expect(
            in_types.size() == 1,
            in_types[0].dtype.kind == 'f'
        )
        if self.axis is None:
            return
        if self.axis >= 0:
            type_check.expect(
                self.axis < in_types[0].ndim,
            )
        else:
            type_check.expect(
                -self.axis - 1 < in_types[0].ndim,
            )

    def forward(self, x):
        xp = cuda.get_array_module(*x)
        return xp.asarray(self._fwd(x[0], xp)),

    def backward(self, indexes, grad_outputs):
        # Index selection is not differentiable: no gradient flows back.
        return None,
class ArgMin(IndexSelectorBase):

    def _fwd(self, x, xp):
        # Indices come back as the platform integer; cast to int32 so the
        # output dtype is consistent across backends.
        indices = xp.argmin(x, axis=self.axis)
        return indices.astype(numpy.int32)
class ArgMax(IndexSelectorBase):

    def _fwd(self, x, xp):
        # Indices come back as the platform integer; cast to int32 so the
        # output dtype is consistent across backends.
        indices = xp.argmax(x, axis=self.axis)
        return indices.astype(numpy.int32)
def max(x, axis=None, keepdims=False):
    """Maximum of array elements over a given axis.

    Args:
        x (~chainer.Variable): Array to be maximized.
        axis (None, int, or tuple of int): Axis over which a max is
            performed. The default (``axis = None``) performs a max over
            all the dimensions of the input array.
        keepdims (bool): If ``True``, the reduced axes are retained in the
            output as dimensions of size one.

    Returns:
        ~chainer.Variable: Output variable.

    """
    y, = Max(axis, keepdims).apply((x,))
    return y
def min(x, axis=None, keepdims=False):
    """Minimum of array elements over a given axis.

    Args:
        x (~chainer.Variable): Array to be minimized.
        axis (None, int, or tuple of int): Axis over which a min is
            performed. The default (``axis = None``) performs a min over
            all the dimensions of the input array.
        keepdims (bool): If ``True``, the reduced axes are retained in the
            output as dimensions of size one.

    Returns:
        ~chainer.Variable: Output variable.

    """
    y, = Min(axis, keepdims).apply((x,))
    return y
def argmax(x, axis=None):
    """Returns index which holds maximum of array elements over a given axis.

    Args:
        x (~chainer.Variable): Array to find maximum elements.
        axis (None or int): Axis over which a max is performed. The
            default (``axis = None``) performs a max over all the
            dimensions of the input array.

    Returns:
        ~chainer.Variable: Output variable.

    """
    y, = ArgMax(axis).apply((x,))
    return y
def argmin(x, axis=None):
    """Returns index which holds minimum of array elements over a given axis.

    Args:
        x (~chainer.Variable): Array to find minimum elements.
        axis (None or int): Axis over which a min is performed. The
            default (``axis = None``) performs a min over all the
            dimensions of the input array.

    Returns:
        ~chainer.Variable: Output variable.

    """
    y, = ArgMin(axis).apply((x,))
    return y