/
squared_difference.py
44 lines (35 loc) · 1.27 KB
/
squared_difference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from chainer.backends import cuda
from chainer import function_node
from chainer import utils
from chainer.utils import type_check
class SquaredDifference(function_node.FunctionNode):

    """Function node computing ``(x1 - x2) ** 2`` element-wise.

    Both inputs are retained during the forward pass so that the backward
    pass can rebuild the difference from the retained input variables.
    """

    def check_type_forward(self, in_types):
        # Exactly two inputs: floating-point, same dtype, same shape
        # (no broadcasting is performed by this node).
        type_check.expect(in_types.size() == 2)
        lhs_type = in_types[0]
        rhs_type = in_types[1]
        type_check.expect(
            lhs_type.dtype.kind == 'f',
            lhs_type.dtype == rhs_type.dtype,
            lhs_type.shape == rhs_type.shape
        )

    def forward(self, inputs):
        # Keep both inputs alive; backward() reads them back.
        self.retain_inputs((0, 1))
        lhs, rhs = inputs
        xp = cuda.get_array_module(lhs, rhs)
        squared = xp.square(lhs - rhs)
        # force_array guarantees an ndarray (not a 0-dim scalar) with the
        # input dtype.
        return utils.force_array(squared, dtype=lhs.dtype),

    def backward(self, indexes, grads):
        (grad_out,) = grads
        lhs, rhs = self.get_retained_inputs()
        # d/dx1 (x1 - x2)^2 = 2 (x1 - x2); d/dx2 is the negation.
        grad_lhs = grad_out * 2 * (lhs - rhs)
        return grad_lhs, -grad_lhs
def squared_difference(x1, x2):
    """Compute the element-wise squared difference of two inputs.

    Both inputs must be floating-point with the same dtype and shape, as
    enforced by ``SquaredDifference.check_type_forward``.

    Args:
        x1 (~chainer.Variable): First input variable to be compared.
        x2 (~chainer.Variable): Second input variable to be compared.

    Returns:
        ~chainer.Variable: ``(x1 - x2) ** 2`` element-wise.

    """
    node = SquaredDifference()
    # The node has a single output, so unpack the one-element tuple.
    result, = node.apply((x1, x2))
    return result