Skip to content

Commit

Permalink
Merge pull request #6262 from stuartarchibald/fix/2201
Browse files Browse the repository at this point in the history
Support dtype from str literal.
  • Loading branch information
stuartarchibald committed Sep 22, 2020
2 parents 8179143 + 3be2086 commit db74d8f
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 4 deletions.
5 changes: 4 additions & 1 deletion numba/core/typing/npydecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,10 @@ def parse_dtype(dtype):
return dtype.dtype
elif isinstance(dtype, types.TypeRef):
return dtype.instance_type
elif isinstance(dtype, types.StringLiteral):
dt = getattr(np, dtype.literal_value, None)
if dt is not None:
return from_dtype(dt)

def _parse_nested_sequence(context, typ):
"""
Expand Down Expand Up @@ -472,7 +476,6 @@ def _parse_nested_sequence(context, typ):
return 0, typ



@infer_global(np.array)
class NpArray(CallableTemplate):
"""
Expand Down
9 changes: 9 additions & 0 deletions numba/np/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -3507,6 +3507,7 @@ def numpy_empty_nd(context, builder, sig, args):

@lower_builtin(np.empty_like, types.Any)
@lower_builtin(np.empty_like, types.Any, types.DTypeSpec)
@lower_builtin(np.empty_like, types.Any, types.StringLiteral)
def numpy_empty_like_nd(context, builder, sig, args):
arrtype, shapes = _parse_empty_like_args(context, builder, sig, args)
ary = _empty_nd_impl(context, builder, arrtype, shapes)
Expand All @@ -3524,6 +3525,7 @@ def numpy_zeros_nd(context, builder, sig, args):

@lower_builtin(np.zeros_like, types.Any)
@lower_builtin(np.zeros_like, types.Any, types.DTypeSpec)
@lower_builtin(np.zeros_like, types.Any, types.StringLiteral)
def numpy_zeros_like_nd(context, builder, sig, args):
arrtype, shapes = _parse_empty_like_args(context, builder, sig, args)
ary = _empty_nd_impl(context, builder, arrtype, shapes)
Expand All @@ -3545,6 +3547,7 @@ def full(shape, value):


@lower_builtin(np.full, types.Any, types.Any, types.DTypeSpec)
@lower_builtin(np.full, types.Any, types.Any, types.StringLiteral)
def numpy_full_dtype_nd(context, builder, sig, args):

def full(shape, value, dtype):
Expand All @@ -3571,6 +3574,7 @@ def full_like(arr, value):


@lower_builtin(np.full_like, types.Any, types.Any, types.DTypeSpec)
@lower_builtin(np.full_like, types.Any, types.Any, types.StringLiteral)
def numpy_full_like_nd_type_spec(context, builder, sig, args):

def full_like(arr, value, dtype):
Expand Down Expand Up @@ -3599,6 +3603,7 @@ def ones(shape):


@lower_builtin(np.ones, types.Any, types.DTypeSpec)
@lower_builtin(np.ones, types.Any, types.StringLiteral)
def numpy_ones_dtype_nd(context, builder, sig, args):

def ones(shape, dtype):
Expand All @@ -3625,6 +3630,7 @@ def ones_like(arr):


@lower_builtin(np.ones_like, types.Any, types.DTypeSpec)
@lower_builtin(np.ones_like, types.Any, types.StringLiteral)
def numpy_ones_like_dtype_nd(context, builder, sig, args):

def ones_like(arr, dtype):
Expand All @@ -3651,6 +3657,7 @@ def identity(n):


@lower_builtin(np.identity, types.Integer, types.DTypeSpec)
@lower_builtin(np.identity, types.Integer, types.StringLiteral)
def numpy_identity_type_spec(context, builder, sig, args):

def identity(n, dtype):
Expand Down Expand Up @@ -4092,6 +4099,7 @@ def array_astype(context, builder, sig, args):

@lower_builtin(np.frombuffer, types.Buffer)
@lower_builtin(np.frombuffer, types.Buffer, types.DTypeSpec)
@lower_builtin(np.frombuffer, types.Buffer, types.StringLiteral)
def np_frombuffer(context, builder, sig, args):
bufty = sig.args[0]
aryty = sig.return_type
Expand Down Expand Up @@ -4327,6 +4335,7 @@ def assign(seqty, seq, shapes, indices):

@lower_builtin(np.array, types.Any)
@lower_builtin(np.array, types.Any, types.DTypeSpec)
@lower_builtin(np.array, types.Any, types.StringLiteral)
def np_array(context, builder, sig, args):
arrty = sig.return_type
ndim = arrty.ndim
Expand Down
4 changes: 2 additions & 2 deletions numba/tests/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def compile(self, func, args, return_type=None, flags=DEFAULT_FLAGS):
from numba.core.registry import cpu_target

cache_key = (func, args, return_type, flags)
try:
if cache_key in self.cr_cache:
cr = self.cr_cache[cache_key]
except KeyError:
else:
# Register the contexts in case for nested @jit or @overload calls
# (same as compile_isolated())
with cpu_target.nested_context(self.typingctx, self.targetctx):
Expand Down
19 changes: 19 additions & 0 deletions numba/tests/test_array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ def np_frombuffer(b):
def np_frombuffer_dtype(b):
return np.frombuffer(b, dtype=np.complex64)

def np_frombuffer_dtype_str(b):
return np.frombuffer(b, dtype='complex64')

def np_frombuffer_allocated(shape):
"""
np.frombuffer() on a Numba-allocated buffer.
Expand Down Expand Up @@ -513,6 +516,22 @@ def test_np_frombuffer(self):
def test_np_frombuffer_dtype(self):
self.check_np_frombuffer(np_frombuffer_dtype)

def test_np_frombuffer_dtype_str(self):
self.check_np_frombuffer(np_frombuffer_dtype_str)

def test_np_frombuffer_dtype_non_const_str(self):
@jit(nopython=True)
def func(buf, dt):
np.frombuffer(buf, dtype=dt)

with self.assertRaises(TypingError) as raises:
func(bytearray(range(16)), 'int32')

excstr = str(raises.exception)
self.assertIn('No match', excstr)
self.assertIn('frombuffer(bytearray(uint8, 1d, C), dtype=unicode_type)',
excstr)

def check_layout_dependent_func(self, pyfunc, fac=np.arange):
def is_same(a, b):
return a.ctypes.data == b.ctypes.data
Expand Down
141 changes: 140 additions & 1 deletion numba/tests/test_dyn_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,32 @@ def func(n):
return pyfunc(n, _dtype)
self.check_1d(func)

def test_1d_dtype_str(self):
pyfunc = self.pyfunc
_dtype = 'int32'
def func(n):
return pyfunc(n, _dtype)
self.check_1d(func)

def func(n):
return pyfunc(n, 'complex128')
self.check_1d(func)

def test_1d_dtype_non_const_str(self):
pyfunc = self.pyfunc

@njit
def func(n, dt):
return pyfunc(n, dt)

with self.assertRaises(TypingError) as raises:
func(5, 'int32')

excstr = str(raises.exception)
self.assertIn('No match', excstr)
self.assertIn('{}({}, unicode_type)'.format(pyfunc.__name__,
np.intp.__name__), excstr)

def test_2d(self):
pyfunc = self.pyfunc
def func(m, n):
Expand Down Expand Up @@ -641,6 +667,12 @@ def func(m, n):
return pyfunc((m, n), dtype=np.complex64)
self.check_2d(func)

def test_2d_dtype_str_kwarg(self):
pyfunc = self.pyfunc
def func(m, n):
return pyfunc((m, n), dtype='complex64')
self.check_2d(func)

def test_alloc_size(self):
pyfunc = self.pyfunc
width = types.intp.bitwidth
Expand Down Expand Up @@ -684,6 +716,25 @@ def func(n):
return np.full(n, 4.5, dtype)
self.check_1d(func)

def test_1d_dtype_str(self):
def func(n):
return np.full(n, 4.5, 'bool_')
self.check_1d(func)

def test_1d_dtype_non_const_str(self):

@njit
def func(n, fv, dt):
return np.full(n, fv, dt)

with self.assertRaises(TypingError) as raises:
func((5,), 4.5, 'int32')

excstr = str(raises.exception)
self.assertIn('No match', excstr)
self.assertIn('full(UniTuple({} x 1), float64, unicode_type)'.format(
np.intp.__name__), excstr)

def test_2d(self):
def func(m, n):
return np.full((m, n), 4.5)
Expand Down Expand Up @@ -828,6 +879,35 @@ def func(arr):
return pyfunc(arr, dtype=np.int32)
self.check_like(func, np.float64)

def test_like_dtype_str_kwarg(self):
pyfunc = self.pyfunc
def func(arr):
return pyfunc(arr, dtype='int32')
self.check_like(func, np.float64)

def test_like_dtype_str_kwarg(self):
pyfunc = self.pyfunc
def func(arr):
return pyfunc(arr, dtype='int32')
self.check_like(func, np.float64)

def test_like_dtype_non_const_str(self):
pyfunc = self.pyfunc

@njit
def func(n, dt):
return pyfunc(n, dt)

with self.assertRaises(TypingError) as raises:
func(np.ones(4), 'int32')

excstr = str(raises.exception)

self.assertIn('No match', excstr)
self.assertIn(
'{}(array(float64, 1d, C), unicode_type)'.format(pyfunc.__name__),
excstr)


class TestNdZerosLike(TestNdEmptyLike):

Expand Down Expand Up @@ -897,6 +977,25 @@ def func(arr):
return np.full_like(arr, 4.5, dtype=np.bool_)
self.check_like(func, np.float64)

def test_like_dtype_str_kwarg(self):
def func(arr):
return np.full_like(arr, 4.5, 'bool_')
self.check_like(func, np.float64)

def test_like_dtype_non_const_str_kwarg(self):

@njit
def func(arr, fv, dt):
return np.full_like(arr, fv, dt)

with self.assertRaises(TypingError) as raises:
func(np.ones(3,), 4.5, 'int32')

excstr = str(raises.exception)
self.assertIn('No match', excstr)
self.assertIn('full_like(array(float64, 1d, C), float64, unicode_type)',
excstr)


class TestNdIdentity(BaseTest):

Expand All @@ -909,11 +1008,26 @@ def func(n):
self.check_identity(func)

def test_identity_dtype(self):
for dtype in (np.complex64, np.int16, np.bool_, np.dtype('bool')):
for dtype in (np.complex64, np.int16, np.bool_, np.dtype('bool'),
'bool_'):
def func(n):
return np.identity(n, dtype)
self.check_identity(func)

def test_like_dtype_non_const_str_kwarg(self):

@njit
def func(n, dt):
return np.identity(n, dt)

with self.assertRaises(TypingError) as raises:
func(4, 'int32')

excstr = str(raises.exception)
self.assertIn('No match', excstr)
self.assertIn('identity({}, unicode_type)'.format(np.intp.__name__),
excstr)


class TestNdEye(BaseTest):

Expand Down Expand Up @@ -1136,6 +1250,31 @@ def pyfunc(arg):
((),),
])

def test_1d_with_str_dtype(self):
def pyfunc(arg):
return np.array(arg, dtype='float32')

self.check_outputs(pyfunc,
[([2, 42],),
([3.5, 1.0],),
((1, 3.5, 42),),
((),),
])

def test_1d_with_non_const_str_dtype(self):

@njit
def func(arg, dt):
return np.array(arg, dtype=dt)

with self.assertRaises(TypingError) as raises:
func((5, 3), 'int32')

excstr = str(raises.exception)
self.assertIn('No match', excstr)
self.assertIn('array(UniTuple({} x 2), dtype=unicode_type)'.format(
np.intp.__name__), excstr)

def test_2d(self):
def pyfunc(arg):
return np.array(arg)
Expand Down

0 comments on commit db74d8f

Please sign in to comment.