/
piecewise.py
61 lines (50 loc) · 1.93 KB
/
piecewise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import cupy
from cupy import _core
# Masked elementwise assignment used by ``piecewise``: wherever ``cond`` is
# true, the output element is overwritten with ``value``; elsewhere the
# existing contents of ``y`` are kept.
_piecewise_krnl = _core.ElementwiseKernel(
    in_params='bool cond, T value',
    out_params='T y',
    operation='if (cond) y = value',
    name='cupy_piecewise_kernel',
)
def piecewise(x, condlist, funclist):
    """Evaluate a piecewise-defined function.

    Args:
        x (cupy.ndarray): input domain
        condlist (list of cupy.ndarray):
            Each boolean array/scalar corresponds to a function
            in funclist. Length of funclist is equal to that of
            condlist. If one extra function is given, it is used
            as the default value when the otherwise condition is met
        funclist (list of scalars): list of scalar functions.
            Entries that are ``cupy.ndarray`` are cast to ``x.dtype``
            before being applied.

    Returns:
        cupy.ndarray: the scalar values in funclist on portions of x
        defined by condlist.

    Raises:
        ValueError: if ``len(funclist)`` is neither ``len(condlist)`` nor
            ``len(condlist) + 1``.
        NotImplementedError: if any entry of ``funclist`` is callable.

    .. warning::

        This function currently doesn't support callable functions,
        args and kw parameters.

    .. seealso:: :func:`numpy.piecewise`
    """
    # A lone scalar condition is promoted to a one-element condition list.
    if cupy.isscalar(condlist):
        condlist = [condlist]

    condlen = len(condlist)
    funclen = len(funclist)
    if funclen not in (condlen, condlen + 1):
        raise ValueError('with {} condition(s), either {} or {} functions'
                         ' are expected'.format(condlen, condlen, condlen + 1))

    # Reject callables before allocating the output or launching any
    # kernel, so an unsupported input costs no device work.
    if any(callable(func) for func in funclist):
        raise NotImplementedError(
            'Callable functions are not supported currently')

    if funclen == condlen:
        out = cupy.zeros(x.shape, x.dtype)
    else:
        # The extra trailing entry is the "otherwise" value: pre-fill the
        # output with it, then overwrite wherever a condition holds.
        out = cupy.full(x.shape, funclist[-1], x.dtype)
        funclist = funclist[:-1]

    # Conditions are applied in order on the same output, so where several
    # conditions hold the later one wins.
    for condition, func in zip(condlist, funclist):
        if isinstance(func, cupy.ndarray):
            func = func.astype(x.dtype)
        _piecewise_krnl(condition, func, out)
    return out