Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cupy.piecewise #3329

Merged
merged 30 commits into from Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d65c418
initial commit
Dahlia-Chehata May 11, 2020
f6214b8
fixes
Dahlia-Chehata May 12, 2020
0c3efeb
restore unrelated fixes
Dahlia-Chehata May 15, 2020
4e4519c
update docs
Dahlia-Chehata May 15, 2020
441a566
fix docs
Dahlia-Chehata May 15, 2020
33ccc0c
add tests with bool,complex types and overlapping conditions
Dahlia-Chehata May 16, 2020
b25d70a
Kernel fixes
Dahlia-Chehata May 16, 2020
b1db29f
minor fix
Dahlia-Chehata May 16, 2020
93e4fc6
fix otherwise test
Dahlia-Chehata May 17, 2020
1229486
fixes
Dahlia-Chehata May 18, 2020
e83f975
remove scalar support
Dahlia-Chehata May 18, 2020
4a1ab72
fix kernel
Dahlia-Chehata May 20, 2020
179271c
fix docs + 0dim/ndarray tests
Dahlia-Chehata May 20, 2020
aa73d23
fixes
Dahlia-Chehata May 22, 2020
41554e6
n-dim input with mismatched condlist shape
Dahlia-Chehata May 26, 2020
61a07c2
tuples support
Dahlia-Chehata May 31, 2020
1afe48c
tests fixes
Dahlia-Chehata Jun 1, 2020
2c12f61
tests fixes
Dahlia-Chehata Jun 2, 2020
e6f96db
remove additional memory
Dahlia-Chehata Jun 4, 2020
a7a5519
fixes
Dahlia-Chehata Jun 5, 2020
2b7e840
fix test
Dahlia-Chehata Jun 7, 2020
899b07e
fixes
Dahlia-Chehata Jun 8, 2020
00f2c9e
flake fix
Dahlia-Chehata Jun 8, 2020
4576eac
eliminate tuples and list of scalars support
Dahlia-Chehata Jun 9, 2020
98f9cbb
edit tests
Dahlia-Chehata Jun 9, 2020
2e2d5de
Cosmetic changes
asi1024 Jun 10, 2020
e74cbab
Drop unsupproted condlist type
asi1024 Jun 10, 2020
c7365d2
Reduce cpu-gpu device synchronozation for cupy.ndarray funclist
asi1024 Jun 10, 2020
72035ef
Enhance test converage
asi1024 Jun 10, 2020
38cd394
Merge remote-tracking branch 'cupy/master' into piecewise
asi1024 Jun 11, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions cupy/__init__.py
Expand Up @@ -56,6 +56,7 @@ def is_available():
import cupy.core.fusion # NOQA
from cupy import creation # NOQA
from cupy import fft # NOQA
from cupy import functional # NOQA
from cupy import indexing # NOQA
from cupy import io # NOQA
from cupy import linalg # NOQA
Expand Down Expand Up @@ -278,6 +279,11 @@ def is_available():
from cupy.creation.matrix import tril # NOQA
from cupy.creation.matrix import triu # NOQA

# -----------------------------------------------------------------------------
# Functional routines
# -----------------------------------------------------------------------------
from cupy.functional.piecewise import piecewise # NOQA

# -----------------------------------------------------------------------------
# Array manipulation routines
# -----------------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions cupy/functional/__init__.py
@@ -0,0 +1,2 @@
# "NOQA" to suppress flake8 warning
from cupy.functional import piecewise # NOQA
68 changes: 68 additions & 0 deletions cupy/functional/piecewise.py
@@ -0,0 +1,68 @@
import cupy

from cupy import core

_piecewise_krnl = core.ElementwiseKernel(
'U condlist, raw S funclist, int64 xsiz',
'raw I prev, raw T y',
'''
if (condlist){
__syncthreads();
if (prev[i % xsiz] < i)
prev[i % xsiz] = i;
__syncthreads();
if (prev[i % xsiz] == i)
y[i % xsiz] = funclist[i / xsiz];
}
''',
'piecewise_kernel'
)


def piecewise(x, condlist, funclist):
"""Evaluate a piecewise-defined function.

Args:
x (cupy.ndarray): input domain
condlist (cupy.ndarray or list):
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
conditions list ( boolean arrays or boolean scalars).
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
Each boolean array/ scalar corresponds to a function
in funclist. Length of funclist is equal to that
of condlist. If one extra function is given, it is used
as otherwise condition
asi1024 marked this conversation as resolved.
Show resolved Hide resolved
funclist (cupy.ndarray or list): list of scalar functions.
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved

Returns:
cupy.ndarray: the result of calling the functions in funclist
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
on portions of x defined by condlist.

.. warning::

This function currently doesn't support callable functions,
args and kw parameters.

.. seealso:: :func:`numpy.piecewise`
"""

if any(callable(item) for item in funclist):
raise NotImplementedError(
'Callable functions are not supported currently')
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
if cupy.isscalar(condlist):
condlist = cupy.full(shape=x.shape, fill_value=condlist)
if not isinstance(condlist[0], (list, cupy.ndarray)) and x.ndim != 0:
asi1024 marked this conversation as resolved.
Show resolved Hide resolved
condlist = [condlist]
condlist = cupy.array(condlist, dtype=bool)
condlen = len(condlist)
funclen = len(funclist)
if condlen == funclen:
y = cupy.zeros(shape=x.shape, dtype=x.dtype)
elif condlen + 1 == funclen: # o.w
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
y = cupy.full(shape=x.shape, fill_value=funclist[-1], dtype=x.dtype)
funclist = funclist[:-1]
else:
raise ValueError('with {} condition(s), either {} or {} functions'
' are expected'.format(condlen, condlen, condlen + 1))
funclist = cupy.asarray(funclist)
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
prev = cupy.full(shape=x.shape, fill_value=-1, dtype=int)
_piecewise_krnl(condlist, funclist, x.size, prev, y)
return y
10 changes: 10 additions & 0 deletions docs/source/reference/functional.rst
@@ -0,0 +1,10 @@
Functional programming
======================

.. https://docs.scipy.org/doc/numpy/reference/routines.functional.html

.. autosummary::
:toctree: generated/
:nosignatures:

cupy.piecewise
1 change: 1 addition & 0 deletions docs/source/reference/routines.rst
Expand Up @@ -16,6 +16,7 @@ These functions cover a subset of
binary
dtype
fft
functional
indexing
io
linalg
Expand Down
81 changes: 81 additions & 0 deletions tests/cupy_tests/functional_tests/test_piecewise.py
@@ -0,0 +1,81 @@
import unittest

import pytest

import cupy
from cupy import testing


class TestPiecewise(unittest.TestCase):
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved

@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
def test_piecewise(self, xp, dtype):
x = xp.linspace(2.5, 12.5, 6, dtype=dtype)
condlist = [x < 0, x >= 0, x < 5, x >= 1.5]
funclist = [-1, 1, 2, 5]
return xp.piecewise(x, condlist, funclist)

@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
def test_scalar_input(self, xp, dtype):
x = dtype(2)
condlist = [x < 0, x >= 0]
funclist = [-10, 10]
return xp.piecewise(x, condlist, funclist)

@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
def test_scalar_condition(self, xp, dtype):
x = testing.shaped_random(shape=(2, 3, 5), xp=xp, dtype=dtype)
condlist = True
funclist = [-10, 10]
return xp.piecewise(x, condlist, funclist)

@testing.for_signed_dtypes()
@testing.numpy_cupy_array_equal()
def test_otherwise_condition1(self, xp, dtype):
x = xp.linspace(-2, 20, 12, dtype=dtype)
condlist = [x > 15, x <= 5, x == 0, x == 10]
funclist = [-1, 0, 2, 3, -5]
return xp.piecewise(x, condlist, funclist)

@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
def test_otherwise_condition2(self, xp, dtype):
x = cupy.array([-10, 20, 30, 40], dtype=dtype)
condlist = [[True, False, False, True], [True, False, False, True]]
funclist = [-1, 1, 2]
return xp.piecewise(x, condlist, funclist)

@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
def test_piecewise_ndim(self, xp, dtype):
x = testing.shaped_random(shape=(2, 3, 5), xp=xp, dtype=dtype)
condlist = [x < 0, x > 0]
funclist = [-1, 1, 2]
return xp.piecewise(x, condlist, funclist)

@testing.for_all_dtypes()
def test_mismatched_lengths(self, dtype):
x = cupy.linspace(-2, 4, 6, dtype=dtype)
condlist = [x < 0, x >= 0]
funclist = [-1, 0, 2, 4, 5]
with pytest.raises(ValueError):
cupy.piecewise(x, condlist, funclist)

@testing.for_all_dtypes()
def test_callable_funclist(self, dtype):
x = cupy.linspace(-2, 4, 6, dtype=dtype)
condlist = [x < 0, x > 0]
funclist = [lambda x: -x, lambda x: x]
with pytest.raises(NotImplementedError):
cupy.piecewise(x, condlist, funclist)

@testing.for_all_dtypes()
def test_mixed_funclist(self, dtype):
x = cupy.linspace(-2, 2, 6, dtype=dtype)
condlist = [x < 0, x == 0, x > 0]
funclist = [-10, lambda x: -x, 10, lambda x: x]
with pytest.raises(NotImplementedError):
cupy.piecewise(x, condlist, funclist)