/
stack.py
90 lines (83 loc) · 2.7 KB
/
stack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import chainer
def stack(xs, axis=0):
"""Concatenate variables along a new axis.
Args:
xs (list of :class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`):
Input variables to be concatenated. The variables must have the
same shape.
axis (int): The axis along which the arrays will be stacked. The
``axis`` parameter is acceptable when
:math:`-ndim - 1 \\leq axis \\leq ndim`. (``ndim`` is the
dimension of input variables). When :math:`axis < 0`, the result
is the same with :math:`ndim + 1 - |axis|`.
Returns:
~chainer.Variable:
Output variable. Let ``x_1, x_2, ..., x_n`` and ``y`` be the input
variables and the output variable,
``y[:, ..., 0, ..., :]`` is ``x_1``,
``y[:, ..., 1, ..., :]`` is ``x_2``
and ``y[:, ..., n-1, ..., :]`` is ``x_n`` (The indexed axis
indicates the ``axis``).
.. admonition:: Example
>>> x1 = np.arange(0, 12).reshape(3, 4)
>>> x1.shape
(3, 4)
>>> x1
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> x2 = np.arange(12, 24).reshape(3, 4)
>>> x2.shape
(3, 4)
>>> x2
array([[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]])
>>> y = F.stack([x1, x2], axis=0)
>>> y.shape
(2, 3, 4)
>>> y.data
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
<BLANKLINE>
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
>>> y = F.stack([x1, x2], axis=1)
>>> y.shape
(3, 2, 4)
>>> y.data
array([[[ 0, 1, 2, 3],
[12, 13, 14, 15]],
<BLANKLINE>
[[ 4, 5, 6, 7],
[16, 17, 18, 19]],
<BLANKLINE>
[[ 8, 9, 10, 11],
[20, 21, 22, 23]]])
>>> y = F.stack([x1, x2], axis=2)
>>> y.shape
(3, 4, 2)
>>> y.data
array([[[ 0, 12],
[ 1, 13],
[ 2, 14],
[ 3, 15]],
<BLANKLINE>
[[ 4, 16],
[ 5, 17],
[ 6, 18],
[ 7, 19]],
<BLANKLINE>
[[ 8, 20],
[ 9, 21],
[10, 22],
[11, 23]]])
>>> y = F.stack([x1, x2], axis=-1)
>>> y.shape
(3, 4, 2)
"""
xs = [chainer.functions.expand_dims(x, axis) for x in xs]
return chainer.functions.concat(xs, axis)