-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
broadcast.py
137 lines (104 loc) · 4.13 KB
/
broadcast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import six
import chainer
from chainer import backend
from chainer import function_node
from chainer.utils import type_check
import chainerx
class Broadcast(function_node.FunctionNode):

    """Function node that broadcasts all given arrays against each other."""

    def check_type_forward(self, in_types):
        # At least one input is required, and all input shapes must be
        # mutually broadcastable.
        type_check.expect(in_types.size() > 0)
        type_check.expect_broadcast_shapes(
            *[in_type.shape for in_type in in_types])

    def forward(self, inputs):
        # Remember the backend module and per-input shapes/dtypes for
        # later use by the framework.
        self._xp = backend.get_array_module(*inputs)
        self._in_shapes = [array.shape for array in inputs]
        self._in_dtypes = [array.dtype for array in inputs]
        broadcasted = self._xp.broadcast_arrays(*inputs)
        return tuple(broadcasted)

    def backward(self, indexes, grad_outputs):
        # Each gradient is reduced (summed) back down to the shape of the
        # corresponding input; missing gradients propagate as None.
        grads = []
        for i in indexes:
            gy = grad_outputs[i]
            if gy is None:
                grads.append(None)
            else:
                grads.append(
                    chainer.functions.sum_to(gy, self.inputs[i].shape))
        return tuple(grads)
def broadcast(*args):
    """Broadcast given variables.

    Args:
        args (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input variables to be broadcasted. Each dimension of the shapes \
            of the input variables must have the same size.

    Returns:
        ~chainer.Variable: :class:`~chainer.Variable` or tuple of \
            :class:`~chainer.Variable` objects which are broadcasted \
            from the given arguments.

    .. admonition:: Example

        >>> x = np.random.uniform(0, 1, (3, 2)).astype(np.float32)
        >>> y = F.broadcast(x)
        >>> np.all(x == y.array)
        True
        >>> z = np.random.uniform(0, 1, (3, 2)).astype(np.float32)
        >>> y, w = F.broadcast(x, z)
        >>> np.all(x == y.array) & np.all(z == w.array)
        True

    """
    # A single argument needs no broadcasting; just wrap it as a Variable.
    if len(args) != 1:
        return Broadcast().apply(args)
    return chainer.as_variable(args[0])
class BroadcastTo(function_node.FunctionNode):

    """Function that broadcasts an array to a new shape."""

    def __init__(self, shape):
        # Normalize to a tuple so later indexing/comparisons are stable.
        self._shape = tuple(shape)

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('x',))
        ndim = type_check.make_variable(len(self._shape), 'len(shape)')
        type_check.expect(in_types[0].ndim <= ndim)

        shape = type_check.eval(in_types[0].shape)
        # Check compatibility from the trailing axis backwards, because
        # broadcasting aligns shapes from the right.
        for i in six.moves.range(-1, -len(shape) - 1, -1):
            if shape[i] == self._shape[i] or shape[i] == 1:
                continue
            expect = 'in_type[0].shape[%d] == %d' % (i, self._shape[i])
            if self._shape[i] != 1:
                expect += ' or in_type[0].shape[%d] == 1' % i
            actual = 'in_type[0].shape: %s' % str(shape)
            raise type_check.InvalidType(expect, actual)

    def broadcast_to(self, inputs):
        # NOTE(review): this looks like it is meant to be the
        # ``forward_chainerx`` hook of FunctionNode -- confirm against the
        # upstream source before renaming; the name is kept here to avoid
        # changing the class interface.
        x, = inputs
        # Bug fix: the target shape is stored in ``self._shape`` (set in
        # ``__init__``); ``self.shape`` does not exist and raised
        # AttributeError.
        return chainerx.broadcast_to(x, self._shape),

    def forward(self, inputs):
        x, = inputs
        xp = backend.get_array_module(x)
        if hasattr(xp, 'broadcast_to'):
            return xp.broadcast_to(x, self._shape),
        else:
            # numpy 1.9 doesn't support broadcast_to method; emulate it by
            # broadcasting against a dummy array of the target shape.
            dummy = xp.empty(self._shape)
            bx, _ = xp.broadcast_arrays(x, dummy)
            return bx,

    def backward(self, indexes, grad_outputs):
        gx, = grad_outputs
        x_node, = self.inputs
        # sum_to reduces the broadcasted axes back to the input's shape.
        return chainer.functions.sum_to(gx, x_node.shape),
def broadcast_to(x, shape):
    """Broadcast a given variable to a given shape.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input variable to be broadcasted. A \
            :math:`(s_1, s_2, ..., s_N)`-shaped float array.
        shape (tuple): Tuple of :class:`int` of the shape of the \
            output variable.

    Returns:
        ~chainer.Variable: Output variable broadcasted to the given shape.

    .. admonition:: Example

        >>> x = np.arange(0, 3)
        >>> x
        array([0, 1, 2])
        >>> y = F.broadcast_to(x, (3, 3))
        >>> y.array
        array([[0, 1, 2],
               [0, 1, 2],
               [0, 1, 2]])

    """
    # Shortcut: when the shape already matches, no function node is needed.
    if x.shape != shape:
        out, = BroadcastTo(shape).apply((x,))
        return out
    return chainer.as_variable(x)