/
cast.py
59 lines (44 loc) · 1.44 KB
/
cast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy
import chainer
from chainer import function_node
from chainer.utils import type_check
class Cast(function_node.FunctionNode):

    """Function node converting its single input array to another dtype.

    The forward pass remembers the input's scalar type. The backward pass
    casts the incoming gradient back to that type when it is a
    floating-point type (numpy dtype kind ``'f'``). For any other input
    kind, no gradient is returned (``None``).
    """

    def __init__(self, typ):
        # Target data type, passed straight to ``ndarray.astype``.
        self.type = typ

    def check_type_forward(self, in_types):
        # Name the single expected input ``x`` for chainer's type checking.
        type_check._argname(in_types, ('x',))

    def forward(self, x):
        source = x[0]
        # Recorded so that backward can restore the original dtype.
        self._in_type = source.dtype.type
        # copy=False lets astype return the input itself when no
        # conversion is required.
        converted = source.astype(self.type, copy=False)
        return converted,

    def backward(self, indexes, g):
        # Only floating-point inputs are differentiable here.
        input_is_float = numpy.dtype(self._in_type).kind == 'f'
        gx = cast(g[0], self._in_type) if input_is_float else None
        return gx,


def cast(x, typ):
    """Cast an input variable to a given type.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input variable to be cast. A \
            :math:`(s_1, s_2, ..., s_N)`-shaped array.
        typ (:class:`str` of dtype or :class:`numpy.dtype`):
            Typecode or data type to cast to.

    Returns:
        ~chainer.Variable: Variable holding the cast array.

    .. admonition:: Example

        >>> x = np.arange(0, 3, dtype=np.float64)
        >>> x.dtype
        dtype('float64')
        >>> y = F.cast(x, np.float32)
        >>> y.dtype
        dtype('float32')
        >>> y = F.cast(x, 'float16')
        >>> y.dtype
        dtype('float16')

    """
    matches_target = x.dtype == typ
    # Shortcut: the dtype already matches and backprop is disabled, so the
    # input is only wrapped, with no Cast node. When backprop is enabled a
    # Cast node is still applied, presumably so the result is always a new
    # node in the computational graph.
    if matches_target and not chainer.config.enable_backprop:
        return chainer.as_variable(x)
    outputs = Cast(typ).apply((x,))
    return outputs[0]