/
reshape.py
97 lines (78 loc) · 2.76 KB
/
reshape.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import chainer
from chainer import function_node
from chainer.utils import type_check
def _count_unknown_dims(shape):
cnt = 0
for dim in shape:
cnt += dim < 0
return cnt
class Reshape(function_node.FunctionNode):

    """Reshapes an input array without copy.

    The target ``shape`` may contain at most one negative entry, which
    marks a dimension left unspecified.
    """

    def __init__(self, shape):
        """Store the target shape and count its unspecified dimensions."""
        self.shape = shape
        unknown = _count_unknown_dims(shape)
        self._cnt = unknown
        # At most one dimension may be left unspecified.
        assert unknown <= 1

    def check_type_forward(self, in_types):
        """Check that the single input is compatible with the target shape.

        With a fully specified shape the element counts must match exactly.
        With one unspecified dimension the input's element count only has
        to be divisible by the product of the positive target dimensions.
        """
        type_check.expect(in_types.size() == 1)
        (x_type,) = in_types

        if self._cnt == 0:
            # Fully specified: total sizes have to agree.
            input_size = type_check.prod(x_type.shape)
            target_size = type_check.prod(self.shape)
            type_check.expect(input_size == target_size)
            return

        # Multiply only the strictly positive dimensions; the unspecified
        # (negative) entry and any zero entry are left out of the product.
        known_size = 1
        for extent in [d for d in self.shape if d > 0]:
            known_size *= extent

        size_var = type_check.make_variable(
            known_size, 'known_size(=%d)' % known_size)
        remainder = type_check.prod(x_type.shape) % size_var
        type_check.expect(remainder == 0)

    def forward(self, inputs):
        """Return a 1-tuple holding the input reshaped to ``self.shape``."""
        (array,) = inputs
        reshaped = array.reshape(self.shape)
        return (reshaped,)

    def backward(self, indexes, grad_outputs):
        """Reshape the output gradient back to the shape of the input."""
        (grad,) = grad_outputs
        input_shape = self.inputs[0].shape
        return (reshape(grad, input_shape),)
def reshape(x, shape):
    """Reshapes an input variable without copy.

    If ``x`` already has exactly the requested ``shape``, it is only wrapped
    by :func:`chainer.as_variable`; otherwise a :class:`Reshape` function
    node is applied to it.

    Args:
        x (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
        :class:`cupy.ndarray`): Input variable.
        shape (:class:`tuple` of :class:`int` s):
            Expected shape of the output array. The number of elements which
            the array of ``shape`` contains must be equal to that of input
            array. One shape dimension can be -1. In this case, the value is
            inferred from the length of the array and remaining dimensions.

    Returns:
        ~chainer.Variable:
            Variable that holds a reshaped version of the input variable.

    .. seealso:: :func:`numpy.reshape`, :func:`cupy.reshape`

    .. admonition:: Example

        >>> x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
        >>> y = F.reshape(x, (8,))
        >>> y.shape
        (8,)
        >>> y.data
        array([1, 2, 3, 4, 5, 6, 7, 8])
        >>> y = F.reshape(x, (4, -1))  # the shape of output is inferred
        >>> y.shape
        (4, 2)
        >>> y.data
        array([[1, 2],
               [3, 4],
               [5, 6],
               [7, 8]])
        >>> y = F.reshape(x, (4, 3)) \
        # the shape of input and output are not consistent
        Traceback (most recent call last):
        ...
        Invalid operation is performed in: Reshape (Forward)
        Expect: prod(in_types[0].shape) == prod((4, 3))
        Actual: 8 != 12

    """
    # Nothing to reshape when the input already has the requested shape.
    already_shaped = x.shape == shape
    if already_shaped:
        return chainer.as_variable(x)

    node = Reshape(shape)
    (y,) = node.apply((x,))
    return y