diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 6f2aa1d02ed2..112a79bd1f74 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -18,6 +18,7 @@ import numpy.matrixlib as matrix from function_base import diff from numpy.lib._compiled_base import ravel_multi_index, unravel_index +from numpy.lib.stride_tricks import as_strided makemat = matrix.matrix def ix_(*args): @@ -531,37 +532,12 @@ class ndindex(object): (2, 1, 0) """ + def __init__(self, *shape): + x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape)) + self._it = _nx.nditer(x, flags=['multi_index']) - def __init__(self, *args): - if len(args) == 1 and isinstance(args[0], tuple): - args = args[0] - self.nd = len(args) - self.ind = [0]*self.nd - self.index = 0 - self.maxvals = args - tot = 1 - for k in range(self.nd): - tot *= args[k] - self.total = tot - - def _incrementone(self, axis): - if (axis < 0): # base case - return - if (self.ind[axis] < self.maxvals[axis]-1): - self.ind[axis] += 1 - else: - self.ind[axis] = 0 - self._incrementone(axis-1) - - def ndincr(self): - """ - Increment the multi-dimensional index by one. - - `ndincr` takes care of the "wrapping around" of the axes. - It is called by `ndindex.next` and not normally used directly. - - """ - self._incrementone(self.nd-1) + def __iter__(self): + return self def next(self): """ @@ -573,17 +549,8 @@ def next(self): Returns a tuple containing the indices of the current iteration. """ - if (self.index >= self.total): - raise StopIteration - val = tuple(self.ind) - self.index += 1 - self.ndincr() - return val - - def __iter__(self): - return self - - + self._it.next() + return self._it.multi_index # You can do all this with slice() plus a few special objects, diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index beda2d1462b1..0ede40d5a337 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -2,7 +2,7 @@ import numpy as np from numpy import ( array, ones, r_, mgrid, unravel_index, zeros, where, ndenumerate, fill_diagonal, diag_indices, - diag_indices_from, s_, index_exp ) + diag_indices_from, s_, index_exp, ndindex ) class TestRavelUnravelIndex(TestCase): def test_basic(self): @@ -237,5 +237,11 @@ def test_diag_indices_from(): assert_array_equal(c, np.arange(4)) +def test_ndindex(): + x = list(np.ndindex(1, 2, 3)) + expected = [ix for ix, e in np.ndenumerate(np.zeros((1, 2, 3)))] + assert_array_equal(x, expected) + + if __name__ == "__main__": run_module_suite()