Skip to content
176 changes: 168 additions & 8 deletions numpy/lib/index_tricks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import division, absolute_import, print_function

import itertools
import sys
import math

Expand Down Expand Up @@ -517,7 +518,7 @@ class CClass(AxisConcatenator):
useful because of its common occurrence. In particular, arrays will be
stacked along their last axis after being upgraded to at least 2-D with
1's post-pended to the shape (column vectors made out of 1-D arrays).

See Also
--------
column_stack : Stack 1-D arrays as columns into a 2-D array.
Expand Down Expand Up @@ -602,6 +603,10 @@ class ndindex(object):
`*args` : ints
The size of each dimension of the array.

slices : tuple of slices (or single slice)
The slice which would be taken in the desired array of the provided
shape.

See Also
--------
ndenumerate, flatiter
Expand All @@ -617,15 +622,88 @@ class ndindex(object):
(2, 0, 0)
(2, 1, 0)

The provided shape can be in the form of a tuple as well

>>> for index in np.ndindex((4, 2)):
... print(index)
(0, 0)
(0, 1)
(1, 0)
(1, 1)
(2, 0)
(2, 1)
(3, 0)
(3, 1)

Slicing into the shape is also possible

>>> for index in np.ndindex((10, 4), np.s_[::6, ::2]):
... print(index)
(0, 0)
(0, 2)
(6, 0)
(6, 2)

Iterating in reversed order is also possible

>>> for i in reversed(np.ndindex(3, 2)):
...: print(i)
(2, 1)
(2, 0)
(1, 1)
(1, 0)
(0, 1)
(0, 0)

"""

def __init__(self, *shape):
def __init__(self, *shape, slices=(), order='C', reverse=False):
# UGLY UGLY Hack to ensure that the following works
# np.ndindex((3, 2), np.s_[::2, ::2])
if len(shape) != 0 and isinstance(shape[-1], bool):
reverse = shape[-1]
shape = shape[:-1]

if len(shape) != 0 and isinstance(shape[-1], str):
order = shape[-1]
shape = shape[:-1]

if slices == () and len(shape) != 0:
if (isinstance(shape[-1], slice) or
(isinstance(shape[-1], tuple) and
len(shape[-1]) != 0 and
isinstance(shape[-1][0], slice))):
slices = shape[-1]
shape = shape[:-1]

if len(shape) == 1 and isinstance(shape[0], tuple):
shape = shape[0]
x = as_strided(_nx.zeros(1), shape=shape,
strides=_nx.zeros_like(shape))
self._it = _nx.nditer(x, flags=['multi_index', 'zerosize_ok'],
order='C')

if isinstance(slices, slice):
slices = (slices,)

if len(slices) > len(shape):
raise ValueError('too many slices for shape')
# append some None slices to ensure slices has at least
# the same dimensions as shape match up
# with python 3, one could use itertools.zip_longest
slices = slices + (slice(None),) * (len(shape) - len(slices))

self._slices = slices
self._shape = shape
self._order = order
self._reverse = reverse

self._range_indices = tuple(sl.indices(s)
for s, sl in zip(shape, slices))
if self._order == 'F':
self._it = itertools.product(
*(range(*i) if not reverse else reversed(range(*i))
for i in reversed(self._range_indices)))
else:
self._it = itertools.product(
*(range(*i) if not reverse else reversed(range(*i))
for i in self._range_indices))

def __iter__(self):
return self
Expand All @@ -650,8 +728,90 @@ def __next__(self):
iteration.

"""
next(self._it)
return self._it.multi_index
index = next(self._it)
if self._order == 'C':
return index
else:
return index[::-1]

def __contains__(self, index):
"""
Standard membership method. Checks if a given tuple is in the iterator.

Returns
-------
membership: bool
True if the iterator contains the member.

"""
# Make it a tuple
if not isinstance(index, tuple):
index = (index,)
if len(index) != len(self._shape):
return False

# We can't check containement in self._it as it will traverse it.
# Do we need to check if we have already passed the value?
# If so, how? do we need to cache that state too?
# How would we even check?
# What happens if the slice is np.s_[::-1, ::1, ::-1, ::1]
# i.e. not strictly C or Fortran
#
# The python built-in function range doesn't behave like this.
# Range itself is just an object.
# __iter__ returns a seperate iterator.
# for loops are free to traverse that iterator.
# containement in that iterator is checked by traversal
# containment in range is not.
#
# I suggest I refactor all these addition into a new
# class `ndrange`. That classe's goals would be to behave more like
# python3 range
# At the same time, I would refactoro the first parameter and force
# the shape to be a tuple
"""
In [11]: five = range(5)

In [12]: iter_five = five.__iter__()

In [13]: 3 in five
Out[13]: True

In [14]: 3 in five
Out[14]: True

In [15]: 3 in five
Out[15]: True

In [16]: 3 in iter_five
Out[16]: True

In [17]: 3 in iter_five
Out[17]: False

In [18]: iter_five = iter(five)

In [19]: 3 in iter_five
Out[19]: True

In [20]: 3 in iter_five
Out[20]: False
"""

for i, range_indices in zip(index, self._range_indices):
if i not in range(*range_indices):
return False
else:
return True

def __reversed__(self):
"""
Standard iteration reversal method. This is useful to generate an
Fortran ordered iterator.

"""
return ndindex(self._shape, slices=self._slices,
order=self._order, reverse=not self._reverse)

next = __next__

Expand Down
52 changes: 52 additions & 0 deletions numpy/lib/tests/test_index_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
index_exp, ndindex, r_, s_, ix_
)

import pytest
from random import shuffle


class TestRavelUnravelIndex(object):
def test_basic(self):
Expand Down Expand Up @@ -376,3 +379,52 @@ def test_ndindex():
# Make sure 0-sized ndindex works correctly
x = list(ndindex(*[0]))
assert_equal(x, [])


def ndindex_tester_helper(expected, shape, slices=(), order=None,
reversals=0, use_keyword=False):

if not use_keyword and order is None:
i = ndindex(shape, slices)
elif not use_keyword and order is not None:
i = ndindex(shape, slices, order)
elif use_keyword and order is None:
i = ndindex(shape, slices=slices)
else: # use_keywords and order is not None
i = ndindex(shape, slices=slices, order=order)

for _ in range(reversals):
i = reversed(i)
expected.reverse()

shuffled = expected.copy()
shuffle(shuffled)
# Make sure we can assert identity in a random order.
# This is important since if `__contains__` traverses the iterator,
# Then it won't find an element when probed repeatedly.
for e in shuffled:
assert e in i

x = list(i)
assert_array_equal(x, expected)

@pytest.mark.parametrize('reversals', [0, 1, 2])
@pytest.mark.parametrize('use_keyword', [False, True])
def test_ndindex_strided(reversals, use_keyword):
# 1D
expected = [(0,), (2,)]
ndindex_tester_helper(expected, shape=5, slices=slice(0, 4, 2),
use_keyword=use_keyword)

# 2D
expected = [(1, 0), (1, 3), (1, 6),
(3, 0), (3, 3), (3, 6)]
ndindex_tester_helper(expected, shape=(4, 9), slices=np.s_[1::2, ::3],
use_keyword=use_keyword)
ndindex_tester_helper(expected, shape=(4, 9), slices=np.s_[1::2, ::3],
order='C', use_keyword=use_keyword)
expected = [(1, 0), (3, 0),
(1, 3), (3, 3),
(1, 6), (3, 6)]
ndindex_tester_helper(expected, shape=(4, 9), slices=np.s_[1::2, ::3],
order='F', use_keyword=use_keyword)