Skip to content

Commit

Permalink
Improve ndindex execution speed.
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanv authored and certik committed Sep 12, 2012
1 parent 2f28db6 commit 5a471b5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 42 deletions.
49 changes: 8 additions & 41 deletions numpy/lib/index_tricks.py
Expand Up @@ -18,6 +18,7 @@
import numpy.matrixlib as matrix
from function_base import diff
from numpy.lib._compiled_base import ravel_multi_index, unravel_index
from numpy.lib.stride_tricks import as_strided
makemat = matrix.matrix

def ix_(*args):
Expand Down Expand Up @@ -531,37 +532,12 @@ class ndindex(object):
(2, 1, 0)
"""
def __init__(self, *shape):
x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape))
self._it = _nx.nditer(x, flags=['multi_index'])

def __init__(self, *args):
if len(args) == 1 and isinstance(args[0], tuple):
args = args[0]
self.nd = len(args)
self.ind = [0]*self.nd
self.index = 0
self.maxvals = args
tot = 1
for k in range(self.nd):
tot *= args[k]
self.total = tot

def _incrementone(self, axis):
if (axis < 0): # base case
return
if (self.ind[axis] < self.maxvals[axis]-1):
self.ind[axis] += 1
else:
self.ind[axis] = 0
self._incrementone(axis-1)

def ndincr(self):
"""
Increment the multi-dimensional index by one.
`ndincr` takes care of the "wrapping around" of the axes.
It is called by `ndindex.next` and not normally used directly.
"""
self._incrementone(self.nd-1)
def __iter__(self):
return self

def next(self):
"""
Expand All @@ -573,17 +549,8 @@ def next(self):
Returns a tuple containing the indices of the current iteration.
"""
if (self.index >= self.total):
raise StopIteration
val = tuple(self.ind)
self.index += 1
self.ndincr()
return val

def __iter__(self):
return self


self._it.next()
return self._it.multi_index


# You can do all this with slice() plus a few special objects,
Expand Down
8 changes: 7 additions & 1 deletion numpy/lib/tests/test_index_tricks.py
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from numpy import ( array, ones, r_, mgrid, unravel_index, zeros, where,
ndenumerate, fill_diagonal, diag_indices,
diag_indices_from, s_, index_exp )
diag_indices_from, s_, index_exp, ndindex )

class TestRavelUnravelIndex(TestCase):
def test_basic(self):
Expand Down Expand Up @@ -237,5 +237,11 @@ def test_diag_indices_from():
assert_array_equal(c, np.arange(4))


def test_ndindex():
x = list(np.ndindex(1, 2, 3))
expected = [ix for ix, e in np.ndenumerate(np.zeros((1, 2, 3)))]
assert_array_equal(x, expected)


if __name__ == "__main__":
run_module_suite()

0 comments on commit 5a471b5

Please sign in to comment.