Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cupy.piecewise #3329

Merged
merged 30 commits into from Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d65c418
initial commit
Dahlia-Chehata May 11, 2020
f6214b8
fixes
Dahlia-Chehata May 12, 2020
0c3efeb
restore unrelated fixes
Dahlia-Chehata May 15, 2020
4e4519c
update docs
Dahlia-Chehata May 15, 2020
441a566
fix docs
Dahlia-Chehata May 15, 2020
33ccc0c
add tests with bool,complex types and overlapping conditions
Dahlia-Chehata May 16, 2020
b25d70a
Kernel fixes
Dahlia-Chehata May 16, 2020
b1db29f
minor fix
Dahlia-Chehata May 16, 2020
93e4fc6
fix otherwise test
Dahlia-Chehata May 17, 2020
1229486
fixes
Dahlia-Chehata May 18, 2020
e83f975
remove scalar support
Dahlia-Chehata May 18, 2020
4a1ab72
fix kernel
Dahlia-Chehata May 20, 2020
179271c
fix docs + 0dim/ndarray tests
Dahlia-Chehata May 20, 2020
aa73d23
fixes
Dahlia-Chehata May 22, 2020
41554e6
n-dim input with mismatched condlist shape
Dahlia-Chehata May 26, 2020
61a07c2
tuples support
Dahlia-Chehata May 31, 2020
1afe48c
tests fixes
Dahlia-Chehata Jun 1, 2020
2c12f61
tests fixes
Dahlia-Chehata Jun 2, 2020
e6f96db
remove additional memory
Dahlia-Chehata Jun 4, 2020
a7a5519
fixes
Dahlia-Chehata Jun 5, 2020
2b7e840
fix test
Dahlia-Chehata Jun 7, 2020
899b07e
fixes
Dahlia-Chehata Jun 8, 2020
00f2c9e
flake fix
Dahlia-Chehata Jun 8, 2020
4576eac
eliminate tuples and list of scalars support
Dahlia-Chehata Jun 9, 2020
98f9cbb
edit tests
Dahlia-Chehata Jun 9, 2020
2e2d5de
Cosmetic changes
asi1024 Jun 10, 2020
e74cbab
Drop unsupproted condlist type
asi1024 Jun 10, 2020
c7365d2
Reduce cpu-gpu device synchronozation for cupy.ndarray funclist
asi1024 Jun 10, 2020
72035ef
Enhance test converage
asi1024 Jun 10, 2020
38cd394
Merge remote-tracking branch 'cupy/master' into piecewise
asi1024 Jun 11, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 5 additions & 9 deletions cupy/__init__.py
Expand Up @@ -7,11 +7,9 @@
from cupy import _environment
from cupy import _version


Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
if sys.platform.startswith('win32') and (3, 8) <= sys.version_info: # NOQA
_environment._setup_win32_dll_directory() # NOQA


try:
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=ImportWarning,
Expand Down Expand Up @@ -39,7 +37,6 @@

raise ImportError(msg) from e


from cupy import cuda
# Do not make `cupy.cupyx` available because it is confusing.
import cupyx as _cupyx
Expand All @@ -51,11 +48,11 @@ def is_available():

__version__ = _version.__version__


from cupy import binary # NOQA
import cupy.core.fusion # NOQA
from cupy import creation # NOQA
from cupy import fft # NOQA
from cupy import functional # NOQA
from cupy import indexing # NOQA
from cupy import io # NOQA
from cupy import linalg # NOQA
Expand All @@ -69,12 +66,10 @@ def is_available():
from cupy import util # NOQA
from cupy import lib # NOQA


# import class and function
from cupy.core import ndarray # NOQA
from cupy.core import ufunc # NOQA


# =============================================================================
# Constants (borrowed from NumPy)
# =============================================================================
Expand All @@ -94,7 +89,6 @@ def is_available():
from numpy import PINF # NOQA
from numpy import PZERO # NOQA


# =============================================================================
# Data types (borrowed from NumPy)
#
Expand Down Expand Up @@ -643,13 +637,16 @@ def isscalar(element):
from cupy.misc import may_share_memory # NOQA
from cupy.misc import shares_memory # NOQA

# -----------------------------------------------------------------------------
# Functional routines
# -----------------------------------------------------------------------------
from cupy.functional.piecewise import piecewise # NOQA

# -----------------------------------------------------------------------------
# Padding
# -----------------------------------------------------------------------------
pad = padding.pad.pad


# -----------------------------------------------------------------------------
# Sorting, searching, and counting
# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -784,7 +781,6 @@ def get_array_module(*args):

disable_experimental_feature_warning = False


# set default allocator
_default_memory_pool = cuda.MemoryPool()
_default_pinned_memory_pool = cuda.PinnedMemoryPool()
Expand Down
2 changes: 2 additions & 0 deletions cupy/functional/__init__.py
@@ -0,0 +1,2 @@
# "NOQA" to suppress flake8 warning
from cupy.functional import piecewise # NOQA
76 changes: 76 additions & 0 deletions cupy/functional/piecewise.py
@@ -0,0 +1,76 @@
import cupy

from cupy import core
from cupy.core.core import ndarray
from cupy.core._reduction import ReductionKernel

_piecewise_krnl = ReductionKernel(
'S x1, T x2',
'U y',
'x1 ? x2 : NULL',
'b == NULL? a : b',
'y = a',
'NULL',
'piecewise'
)
asi1024 marked this conversation as resolved.
Show resolved Hide resolved


def piecewise(x, condlist, funclist):
"""
Evaluate a piecewise-defined function.

Args:
:param x: input domain
:param condlist: conditions list ( boolean arrays or boolean scalars)
Each boolean array/ scalar corresponds to a function
in funclist. Length of functionlist is equal to that
of condlist. If one extra function is given, it is used
as otherwise condition
:param funclist: list of scalar functions.
asi1024 marked this conversation as resolved.
Show resolved Hide resolved

Returns:
cupy.ndarray: the result of calling the functions in funclist
on portions of x defined by condlist.

.. warning::

This function currently doesn't support callable functions
args and kw parameters are not supported

.. seealso:: :func:`numpy.piecewise`
"""

if any(callable(item) for item in funclist):
raise NotImplementedError(
'Callable functions are not supported currently')
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
if cupy.isscalar(x):
x = cupy.asarray(x)
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
scalar = 0
if cupy.isscalar(condlist):
scalar = 1
if cupy.isscalar(condlist) or (
(not isinstance(condlist[0], (list, ndarray))) and x.ndim != 0):
condlist = [condlist]
condlist = cupy.array(condlist, dtype=bool)
condlen = len(condlist)
funclen = len(funclist)
if condlen + 1 == funclen: # o.w
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
condelse = ~ cupy.any(condlist, axis=0, keepdims=True)
condlist = core.concatenate_method([condlist, condelse], 0)
condlen += 1
elif condlen != funclen:
raise ValueError('with {} condition(s), either {} or {} functions'
' are expected'.format(condlen, condlen, condlen + 1))

funclist = cupy.asarray(funclist)
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
if scalar:
funclist = funclist[condlist]
condlist = cupy.ones(shape=(1, x.size), dtype=bool)

y = cupy.zeros(x.shape, x.dtype)
if not x.ndim:
_piecewise_krnl(condlist, funclist, y)
else:
for i in range(x.size):
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
_piecewise_krnl(condlist[:, i], funclist, y[i])
return y
73 changes: 73 additions & 0 deletions tests/cupy_tests/functional_tests/test_piecewise.py
@@ -0,0 +1,73 @@
import unittest

import pytest

import cupy
from cupy import testing


class TestPiecewise(unittest.TestCase):
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved

@testing.for_all_dtypes(no_bool=True, no_complex=True)
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
@testing.numpy_cupy_array_equal()
def test_linespace(self, xp, dtype):
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
x = xp.linspace(-2.5, 2.5, 6, dtype=dtype)
condlist = [x < 0, x >= 0]
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
funclist = [-1, 1]
return xp.piecewise(x, condlist, funclist)

@testing.for_all_dtypes(no_bool=True, no_complex=True)
@testing.numpy_cupy_array_equal()
def test_scalar_value(self, xp, dtype):
Dahlia-Chehata marked this conversation as resolved.
Show resolved Hide resolved
x = dtype(2)
condlist = [x < 0, x >= 0]
funclist = [-10, 10]
return xp.piecewise(x, condlist, funclist)

@testing.for_all_dtypes(no_bool=True, no_complex=True)
@testing.numpy_cupy_array_equal()
def test_scalar_condition(self, xp, dtype):
x = xp.linspace(-2.5, 2.5, 4, dtype=dtype)
condlist = True
funclist = [-10, 10]
return xp.piecewise(x, condlist, funclist)

@testing.for_all_dtypes(no_bool=True, no_complex=True)
@testing.numpy_cupy_array_equal()
def test_otherwise_condition1(self, xp, dtype):
x = xp.linspace(-2, 4, 4, dtype=dtype)
condlist = [x < 0, x >= 0]
funclist = [-1, 0, 2]
return xp.piecewise(x, condlist, funclist)

@testing.for_all_dtypes(no_bool=True, no_complex=True)
@testing.numpy_cupy_array_equal()
def test_otherwise_condition2(self, xp, dtype):
x = cupy.array([-10, 20, 30, 40], dtype=dtype)
condlist = [[True, False, False, True], [True, False, False, True]]
funclist = [-1, 1, 2]
return xp.piecewise(x, condlist, funclist)

@testing.for_all_dtypes(no_bool=True, no_complex=True)
def test_mismatched_lengths(self, dtype):
x = cupy.linspace(-2, 4, 6, dtype=dtype)
condlist = [x < 0, x >= 0]
funclist = [-1, 0, 2, 4, 5]
with pytest.raises(ValueError):
cupy.piecewise(x, condlist, funclist)

@testing.for_all_dtypes(no_bool=True, no_complex=True)
def test_callable_funclist(self, dtype):
x = cupy.linspace(-2, 4, 6, dtype=dtype)
condlist = [x < 0, x > 0]
funclist = [lambda x: -x, lambda x: x]
with pytest.raises(NotImplementedError):
cupy.piecewise(x, condlist, funclist)

@testing.for_all_dtypes(no_bool=True, no_complex=True)
def test_mixed_funclist(self, dtype):
x = cupy.linspace(-2, 2, 6, dtype=dtype)
condlist = [x < 0, x == 0, x > 0]
funclist = [-10, lambda x: -x, 10, lambda x: x]
with pytest.raises(NotImplementedError):
cupy.piecewise(x, condlist, funclist)