/
sign.py
36 lines (26 loc) · 893 Bytes
/
sign.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import chainer
from chainer import backend
from chainer import utils
def sign(x):
    """Compute the sign of every element of the input.

    Each element :math:`x` of the input is mapped to :math:`sgn(x)`, where

    .. math::

        sgn(x) = \\left \\{ \\begin{array}{cc}
        -1 & {\\rm if~x < 0} \\\\
        0 & {\\rm if~x = 0} \\\\
        1 & {\\rm if~x > 0} \\\\
        \\end{array} \\right.

    .. note::

        No gradient is defined for this function (it is ``None``
        everywhere), so the returned variable is cut off from the
        computational graph of the input.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input whose elementwise sign is taken.

    Returns:
        ~chainer.Variable: Variable wrapping the array of signs.

    """
    # Operate on the raw array; a Variable input is unwrapped first, which
    # is also what drops the link to the computational graph.
    data = x.array if isinstance(x, chainer.variable.Variable) else x

    # Dispatch to the array module that matches the input array.
    array_module = backend.get_array_module(data)
    signs = array_module.sign(data)

    # NOTE(review): force_array presumably guards against ``sign`` handing
    # back a scalar instead of an array -- confirm against chainer.utils.
    return chainer.as_variable(utils.force_array(signs))