Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Improve ndindex execution speed.

  • Loading branch information...
commit 5a471b5505e1e751e964cd3edade7db53a9596a9 1 parent 2f28db6
@stefanv stefanv authored certik committed
Showing with 15 additions and 42 deletions.
  1. +8 −41 numpy/lib/index_tricks.py
  2. +7 −1 numpy/lib/tests/test_index_tricks.py
View
49 numpy/lib/index_tricks.py
@@ -18,6 +18,7 @@
import numpy.matrixlib as matrix
from function_base import diff
from numpy.lib._compiled_base import ravel_multi_index, unravel_index
+from numpy.lib.stride_tricks import as_strided
makemat = matrix.matrix
def ix_(*args):
@@ -531,37 +532,12 @@ class ndindex(object):
(2, 1, 0)
"""
+ def __init__(self, *shape):
+ x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape))
+ self._it = _nx.nditer(x, flags=['multi_index'])
- def __init__(self, *args):
- if len(args) == 1 and isinstance(args[0], tuple):
- args = args[0]
- self.nd = len(args)
- self.ind = [0]*self.nd
- self.index = 0
- self.maxvals = args
- tot = 1
- for k in range(self.nd):
- tot *= args[k]
- self.total = tot
-
- def _incrementone(self, axis):
- if (axis < 0): # base case
- return
- if (self.ind[axis] < self.maxvals[axis]-1):
- self.ind[axis] += 1
- else:
- self.ind[axis] = 0
- self._incrementone(axis-1)
-
- def ndincr(self):
- """
- Increment the multi-dimensional index by one.
-
- `ndincr` takes care of the "wrapping around" of the axes.
- It is called by `ndindex.next` and not normally used directly.
-
- """
- self._incrementone(self.nd-1)
+ def __iter__(self):
+ return self
def next(self):
"""
@@ -573,17 +549,8 @@ def next(self):
Returns a tuple containing the indices of the current iteration.
"""
- if (self.index >= self.total):
- raise StopIteration
- val = tuple(self.ind)
- self.index += 1
- self.ndincr()
- return val
-
- def __iter__(self):
- return self
-
-
+ self._it.next()
+ return self._it.multi_index
# You can do all this with slice() plus a few special objects,
View
8 numpy/lib/tests/test_index_tricks.py
@@ -2,7 +2,7 @@
import numpy as np
from numpy import ( array, ones, r_, mgrid, unravel_index, zeros, where,
ndenumerate, fill_diagonal, diag_indices,
- diag_indices_from, s_, index_exp )
+ diag_indices_from, s_, index_exp, ndindex )
class TestRavelUnravelIndex(TestCase):
def test_basic(self):
@@ -237,5 +237,11 @@ def test_diag_indices_from():
assert_array_equal(c, np.arange(4))
+def test_ndindex():
+ x = list(np.ndindex(1, 2, 3))
+ expected = [ix for ix, e in np.ndenumerate(np.zeros((1, 2, 3)))]
+ assert_array_equal(x, expected)
+
+
if __name__ == "__main__":
run_module_suite()
Please sign in to comment.
Something went wrong with that request. Please try again.