/
rounding.py
84 lines (53 loc) · 1.96 KB
/
rounding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import warnings
from cupy import _core
from cupy._core import fusion
from cupy._math import ufunc
def around(a, decimals=0, out=None):
    """Round an array to the requested number of decimal places.

    Args:
        a (cupy.ndarray): Array whose elements are rounded.
        decimals (int): How many decimal places to keep (default: 0).
            A negative value counts positions to the left of the decimal
            point instead.
        out (cupy.ndarray): Output array.

    Returns:
        cupy.ndarray: The rounded array.

    .. seealso:: :func:`numpy.around`

    """
    if not fusion._is_fusing():
        # Eager path: pass the input through ``_core.array`` with
        # ``copy=False`` and delegate to the result's ``round`` method.
        source = _core.array(a, copy=False)
        return source.round(decimals, out=out)

    # Fusion path: dispatch through the round ufunc instead.
    round_ufunc = _core.core._round_ufunc
    return fusion._call_ufunc(round_ufunc, a, decimals, out=out)
def round(a, decimals=0, out=None):
    """Rounds to the given number of decimals.

    This simply forwards every argument to :func:`around`.

    Args:
        a: The source array, passed through to :func:`around`.
        decimals (int): Number of decimal places to round to (default: 0).
        out: Output array, passed through to :func:`around`.

    Returns:
        The value returned by :func:`around`.

    """
    rounded = around(a, decimals=decimals, out=out)
    return rounded
def round_(a, decimals=0, out=None):
    """Deprecated spelling of :func:`round`.

    Emits a ``DeprecationWarning`` and then forwards every argument to
    :func:`around`.

    Args:
        a: The source array, passed through to :func:`around`.
        decimals (int): Number of decimal places to round to (default: 0).
        out: Output array, passed through to :func:`around`.

    Returns:
        The value returned by :func:`around`.

    """
    # stacklevel=2 attributes the warning to the caller's line rather than
    # to this module, so the default warning filters can surface it in the
    # user's own code.
    warnings.warn(
        'Please use `round` instead.', DeprecationWarning, stacklevel=2)
    return around(a, decimals, out=out)
# Module-level ``rint`` object produced by ``ufunc.create_math_ufunc``; the
# last positional argument is its user-facing documentation text.
# NOTE(review): the positional arguments look like (math function name,
# number of inputs, kernel name, doc) -- confirm against
# ``cupy._math.ufunc.create_math_ufunc``.
rint = ufunc.create_math_ufunc(
'rint', 1, 'cupy_rint',
'''Rounds each element of an array to the nearest integer.
.. seealso:: :data:`numpy.rint`
''')
# Module-level ``floor`` object produced by ``ufunc.create_math_ufunc``; it
# is created with ``support_complex=False``.
# NOTE(review): the positional arguments look like (math function name,
# number of inputs, kernel name, doc) -- confirm against
# ``cupy._math.ufunc.create_math_ufunc``.
floor = ufunc.create_math_ufunc(
'floor', 1, 'cupy_floor',
'''Rounds each element of an array to its floor integer.
.. seealso:: :data:`numpy.floor`
''', support_complex=False)
# Module-level ``ceil`` object produced by ``ufunc.create_math_ufunc``; it
# is created with ``support_complex=False``.
# NOTE(review): the positional arguments look like (math function name,
# number of inputs, kernel name, doc) -- confirm against
# ``cupy._math.ufunc.create_math_ufunc``.
ceil = ufunc.create_math_ufunc(
'ceil', 1, 'cupy_ceil',
'''Rounds each element of an array to its ceiling integer.
.. seealso:: :data:`numpy.ceil`
''', support_complex=False)
# Module-level ``trunc`` object produced by ``ufunc.create_math_ufunc``; it
# is created with ``support_complex=False``.
# NOTE(review): the positional arguments look like (math function name,
# number of inputs, kernel name, doc) -- confirm against
# ``cupy._math.ufunc.create_math_ufunc``.
trunc = ufunc.create_math_ufunc(
'trunc', 1, 'cupy_trunc',
'''Rounds each element of an array towards zero.
.. seealso:: :data:`numpy.trunc`
''', support_complex=False)
# ``fix`` rounds toward zero: the kernel expression applies ``floor`` to
# inputs that compare ``>= 0.0`` and ``ceil`` to everything else.
# NOTE(review): only the 'e', 'f' and 'd' signatures are registered --
# presumably float16/float32/float64; confirm against ``_core.create_ufunc``.
fix = _core.create_ufunc(
    'cupy_fix', ('e->e', 'f->f', 'd->d'),
    'out0 = (in0 >= 0.0) ? floor(in0): ceil(in0)',
    doc='''Rounds each element of an array towards zero.

    If the given value ``x`` is non-negative, it returns ``floor(x)``.
    Otherwise, it returns ``ceil(x)``.

    .. seealso:: :func:`numpy.fix`

    ''')