Skip to content

Commit

Permalink
Merge pull request #3630 from rjenc29/where_broadcast
Browse files Browse the repository at this point in the history
np.where with broadcasting
  • Loading branch information
sklam committed Jan 30, 2019
2 parents 8cbe10b + 300ea7d commit 92a40f4
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 8 deletions.
65 changes: 62 additions & 3 deletions numba/targets/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import math
from collections import namedtuple
from enum import IntEnum
from functools import partial

import numpy as np

Expand Down Expand Up @@ -1644,7 +1645,7 @@ def determine_dtype(array_like):
array_like_dt = np.float64
if isinstance(array_like, types.Array):
array_like_dt = as_dtype(array_like.dtype)
elif isinstance(array_like, types.Number):
elif isinstance(array_like, (types.Number, types.Boolean)):
array_like_dt = as_dtype(array_like)
elif isinstance(array_like, (types.UniTuple, types.Tuple)):
coltypes = set()
Expand Down Expand Up @@ -2156,11 +2157,69 @@ def where_impl(cond, x, y):
return impl_ret_untracked(context, builder, sig.return_type, res)


@register_jitable
def _where_x_y_scalar(cond, x, y, res):
for idx, c in np.ndenumerate(cond):
res[idx] = x if c else y
return res


@register_jitable
def _where_x_scalar(cond, x, y, res):
for idx, c in np.ndenumerate(cond):
res[idx] = x if c else y[idx]
return res


@register_jitable
def _where_y_scalar(cond, x, y, res):
for idx, c in np.ndenumerate(cond):
res[idx] = x[idx] if c else y
return res


def _where_inner(context, builder, sig, args, impl):
cond, x, y = sig.args

x_dt = determine_dtype(x)
y_dt = determine_dtype(y)
npty = np.promote_types(x_dt, y_dt)

if cond.layout == 'F':
def where_impl(cond, x, y):
res = np.asfortranarray(np.empty(cond.shape, dtype=npty))
return impl(cond, x, y, res)
else:
def where_impl(cond, x, y):
res = np.empty(cond.shape, dtype=npty)
return impl(cond, x, y, res)

res = context.compile_internal(builder, where_impl, sig, args)
return impl_ret_untracked(context, builder, sig.return_type, res)


array_scalar_scalar_where = partial(_where_inner, impl=_where_x_y_scalar)
array_array_scalar_where = partial(_where_inner, impl=_where_y_scalar)
array_scalar_array_where = partial(_where_inner, impl=_where_x_scalar)


@lower_builtin(np.where, types.Any, types.Any, types.Any)
def any_where(context, builder, sig, args):
cond = sig.args[0]
cond, x, y = sig.args

if isinstance(cond, types.Array):
return array_where(context, builder, sig, args)
if isinstance(x, types.Array):
if isinstance(y, types.Array):
impl = array_where
elif isinstance(y, (types.Number, types.Boolean)):
impl = array_array_scalar_where
elif isinstance(x, (types.Number, types.Boolean)):
if isinstance(y, types.Array):
impl = array_scalar_array_where
elif isinstance(y, (types.Number, types.Boolean)):
impl = array_scalar_scalar_where

return impl(context, builder, sig, args)

def scalar_where_impl(cond, x, y):
"""
Expand Down
65 changes: 65 additions & 0 deletions numba/tests/test_array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,71 @@ def check_scal(scal):
for x in (0, 1, True, False, 2.5, 0j):
check_scal(x)

def test_np_where_3_broadcast_x_y_scalar(self):
pyfunc = np_where_3
cfunc = jit(nopython=True)(pyfunc)

def check_ok(args):
expected = pyfunc(*args)
got = cfunc(*args)
self.assertPreciseEqual(got, expected)

def a_variations():
a = np.linspace(-2, 4, 20)
self.random.shuffle(a)
yield a
yield a.reshape(2, 5, 2)
yield a.reshape(4, 5, order='F')
yield a.reshape(2, 5, 2)[::-1]

for a in a_variations():
params = (a > 0, 0, 1)
check_ok(params)

params = (a < 0, np.nan, 1 + 4j)
check_ok(params)

params = (a > 1, True, False)
check_ok(params)

def test_np_where_3_broadcast_x_or_y_scalar(self):
pyfunc = np_where_3
cfunc = jit(nopython=True)(pyfunc)

def check_ok(args):
condition, x, y = args

expected = pyfunc(condition, x, y)
got = cfunc(condition, x, y)
self.assertPreciseEqual(got, expected)

# swap x and y
expected = pyfunc(condition, y, x)
got = cfunc(condition, y, x)
self.assertPreciseEqual(got, expected)

def array_permutations():
x = np.arange(9).reshape(3, 3)
yield x
yield x * 1.1
yield np.asfortranarray(x)
yield x[::-1]
yield np.linspace(-10, 10, 60).reshape(3, 4, 5) * 1j

def scalar_permutations():
yield 0
yield 4.3
yield np.nan
yield True
yield 8 + 4j

for x in array_permutations():
for y in scalar_permutations():
x_mean = np.mean(x)
condition = x > x_mean
params = (condition, x, y)
check_ok(params)

def test_item(self):
pyfunc = array_item
cfunc = jit(nopython=True)(pyfunc)
Expand Down
15 changes: 10 additions & 5 deletions numba/typing/npydecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,11 +1147,16 @@ def generic(self, args, kws):
as_dtype(getattr(args[2], 'dtype', args[2]))))
if isinstance(cond, types.Array):
# array where()
if (cond.ndim == x.ndim == y.ndim):
if x.layout == y.layout == cond.layout:
retty = types.Array(retdty, x.ndim, x.layout)
else:
retty = types.Array(retdty, x.ndim, 'C')
if isinstance(x, types.Array) and isinstance(y, types.Array):
if (cond.ndim == x.ndim == y.ndim):
if x.layout == y.layout == cond.layout:
retty = types.Array(retdty, x.ndim, x.layout)
else:
retty = types.Array(retdty, x.ndim, 'C')
return signature(retty, *args)
else:
# x and y both scalar
retty = types.Array(retdty, cond.ndim, cond.layout)
return signature(retty, *args)
else:
# scalar where()
Expand Down

0 comments on commit 92a40f4

Please sign in to comment.