Skip to content

Commit

Permalink
Merge pull request #1 from stuartarchibald/pr_6822_continued
Browse files Browse the repository at this point in the history
More constant inference for 6822.
  • Loading branch information
braniii authored Mar 22, 2021
2 parents 6e59554 + 85ab2e2 commit 700c0b3
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 26 deletions.
26 changes: 12 additions & 14 deletions numba/np/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,33 +1566,31 @@ def array_T(context, builder, typ, value):
@overload(np.rot90)
def numpy_rot90(arr, k=1):
# supporting axes argument it needs to be included in np.flip
if not isinstance(k, types.Integer):
if not isinstance(k, (int, types.Integer)):
raise errors.TypingError('The second argument "k" must be an integer')
if not isinstance(arr, types.Array):
raise errors.TypingError('The first argument "arr" must be an array')

const_arr_ndim = arr.ndim
if arr.ndim < 2:
raise ValueError('Input must be >= 1-d.')

axes_list = tuple([1, 0] + [*range(2, arr.ndim)])

def impl(arr, k=1):
arr = np.asarray(arr)

if arr.ndim < 2:
raise ValueError('Input must be >= 1-d.')

k = k % 4
if k == 0:
return arr[:]
if k == 2:
return np.flipud(np.fliplr(arr))

axes_list = np.arange(arr.ndim)
axes_list[:2] = [1, 0]
axes_list = to_fixed_tuple(axes_list, const_arr_ndim)

if k == 1:
elif k == 1:
return np.transpose(np.fliplr(arr), axes_list)
# k == 3
return np.fliplr(np.transpose(arr, axes_list))
elif k == 2:
return np.flipud(np.fliplr(arr))
elif k == 3:
return np.fliplr(np.transpose(arr, axes_list))
else:
assert 0 # unreachable
return impl


Expand Down
40 changes: 28 additions & 12 deletions numba/tests/test_np_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,11 @@ def flip(a):
return np.flip(a)


def rot90(a, k=1):
return np.rot90(a, k)
def rot90(a):
return np.rot90(a)

def rot90_k(a, k=1):
return np.rot90(a, k)

def array_split(a, indices, axis=0):
return np.array_split(a, indices, axis=axis)
Expand Down Expand Up @@ -2328,13 +2330,35 @@ def test_rot90_basic(self):
def a_variations():
yield np.arange(10).reshape(5, 2)
yield np.arange(20).reshape(5, 2, 2)
yield np.arange(64).reshape(2, 2, 2, 2, 2, 2)

for a in a_variations():
expected = pyfunc(a)
got = cfunc(a)
self.assertPreciseEqual(expected, got)

def test_rot90_with_k_basic(self):
pyfunc = rot90_k
cfunc = jit(nopython=True)(pyfunc)

def a_variations():
yield np.arange(10).reshape(5, 2)
yield np.arange(20).reshape(5, 2, 2)
yield np.arange(64).reshape(2, 2, 2, 2, 2, 2)

for a in a_variations():
for k in range(-3, 13):
for k in range(-5, 6):
expected = pyfunc(a, k)
got = cfunc(a, k)
self.assertPreciseEqual(expected, got)

def test_rot90_exception(self):
pyfunc = rot90_k
cfunc = jit(nopython=True)(pyfunc)

# Exceptions leak references
self.disable_leak_check()

with self.assertRaises(TypingError) as raises:
cfunc("abc")

Expand All @@ -2347,18 +2371,10 @@ def a_variations():
self.assertIn('The second argument "k" must be an integer',
str(raises.exception))

def test_rot90_exception(self):
pyfunc = rot90
cfunc = jit(nopython=True)(pyfunc)

# Exceptions leak references
self.disable_leak_check()

with self.assertRaises(TypingError) as raises:
cfunc(np.arange(3))

self.assertIn("cannot index array", str(raises.exception))
self.assertIn("with 2 indices", str(raises.exception))
self.assertIn("Input must be >= 1-d.", str(raises.exception))

def _check_split(self, func):
# Since np.split and np.array_split are very similar
Expand Down

0 comments on commit 700c0b3

Please sign in to comment.