-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
cast.py
56 lines (41 loc) · 1.41 KB
/
cast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import chainer
from chainer import function_node
from chainer.utils import type_check
class Cast(function_node.FunctionNode):
    """Function node that converts a float array to another data type."""

    def __init__(self, typ):
        # Target dtype; handed to ``astype`` in ``forward``.
        self.type = typ

    def check_type_forward(self, in_types):
        """Label the input as ``x`` and expect its dtype kind to be ``'f'``."""
        type_check._argname(in_types, ('x',))
        in_dtype = in_types[0].dtype
        type_check.expect(in_dtype.kind == 'f')

    def forward(self, x):
        """Cast the first input array to the target dtype.

        The scalar type of the input is stored on the instance so that
        ``backward`` can cast the gradient back to it.

        Returns:
            tuple: One-element tuple holding the cast array.
        """
        src = x[0]
        self._in_type = src.dtype.type
        # ``copy=False``: under NumPy semantics ``astype`` may hand back the
        # input itself when no conversion is needed.
        out = src.astype(self.type, copy=False)
        return out,

    def backward(self, indexes, g):
        """Cast the first upstream gradient back to the input's type.

        Returns:
            tuple: One-element tuple holding the cast gradient.
        """
        grad = cast(g[0], self._in_type)
        return grad,
def cast(x, typ):
    """Convert an input variable to the given data type.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input variable to be cast: a float array of shape
            :math:`(s_1, s_2, ..., s_N)`.
        typ (:class:`str` of dtype or :class:`numpy.dtype`):
            Typecode or data type to cast to.

    Returns:
        ~chainer.Variable: Variable holding the cast array.

    .. admonition:: Example

        >>> x = np.arange(0, 3, dtype=np.float64)
        >>> x.dtype
        dtype('float64')
        >>> y = F.cast(x, np.float32)
        >>> y.dtype
        dtype('float32')
        >>> y = F.cast(x, 'float16')
        >>> y.dtype
        dtype('float16')

    """
    # Shortcut: when the dtype already matches and backprop is disabled, the
    # input is only wrapped as a variable and no ``Cast`` node is applied.
    # The dtype comparison runs first, so the config is consulted only when
    # the dtypes are equal.
    already_typed = x.dtype == typ
    if already_typed and not chainer.config.enable_backprop:
        return chainer.as_variable(x)

    outputs = Cast(typ).apply((x,))
    return outputs[0]