Skip to content

Commit

Permalink
Merge pull request #532 from rainwoodman/faster-slice
Browse files Browse the repository at this point in the history
Faster slice
  • Loading branch information
rainwoodman committed Nov 13, 2018
2 parents 16b2042 + 1969b75 commit e836a3f
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 60 deletions.
105 changes: 50 additions & 55 deletions nbodykit/base/catalog.py
Expand Up @@ -265,11 +265,11 @@ def __iter__(self):
def __contains__(self, col):
return col in self.columns

def __slice__(self, index):
def _get_slice(self, index):
"""
Select a subset of ``self`` according to a boolean index array.
Returns a new object of the same type as ``selff`` holding only the
Returns a new object of the same type as ``self`` holding only the
data that satisfies the slice index.
Parameters
Expand All @@ -278,27 +278,28 @@ def __slice__(self, index):
either a dask or numpy boolean array; this determines which
rows are included in the returned object
"""
# compute the index slice if needed and get the size
if isinstance(index, da.Array):
index = self.compute(index)
elif isinstance(index, list):
index = numpy.array(index)
if index is Ellipsis:
return self
elif isinstance(index, slice):
start, stop, stride = index.indices(self.size)
size = (stop - start) // stride
else:
# compute the index slice if needed and get the size
index = CatalogSourceBase.make_column(index)

if getattr(self, 'size', NotImplemented) is NotImplemented:
raise ValueError("cannot make catalog subset; self catalog doest not have a size")
if index.dtype == numpy.dtype('?'):
# verify the index is a boolean array
if len(index) != self.size:
raise KeyError("slice index has length %d; should be %d" %(len(index), self.size))

# verify the index is a boolean array
if len(index) != self.size:
raise ValueError("slice index has length %d; should be %d" %(len(index), self.size))
if getattr(index, 'dtype', None) != '?':
raise ValueError("index used to slice CatalogSource must be boolean and array-like")
# new size is just number of True entries
size = index.sum().compute()
else:

# new size is just number of True entries
size = index.sum()
if len(index) > 0 and index.dtype != numpy.integer:
raise KeyError("slice index has must be boolean, integer. got %s" %(index.dtype))

# if collective size is unchanged, just return self
if self.comm.allreduce(size) == self.csize:
return self.base if self.base is not None else self
size = len(index)

# initialize subset Source of right size
subset_data = {col:self[col][index] for col in self}
Expand All @@ -324,7 +325,7 @@ def __getitem__(self, sel):
Notes
-----
- Slicing with a boolean array is a **collective** operation
- Slicing is a **collective** operation
- If the :attr:`base` attribute is set, columns will be returned
from :attr:`base` instead of from ``self``.
"""
Expand All @@ -336,10 +337,10 @@ def __getitem__(self, sel):
raise ValueError("cannot perform selection due to NotImplemented size")

# convert slices to boolean arrays
if isinstance(sel, (slice, list)):
if isinstance(sel, (list, da.Array, numpy.ndarray)):

# select a subset of list of string column names
if isinstance(sel, list) and all(isinstance(ss, string_types) for ss in sel):
if len(sel) > 0 and all(isinstance(ss, string_types) for ss in sel):
invalid = set(sel) - set(self.columns)
if len(invalid):
msg = "cannot select subset of columns from "
Expand All @@ -352,40 +353,28 @@ def __getitem__(self, sel):
toret.attrs.update(self.attrs)
return toret

# list must be all integers
if isinstance(sel, list) and not numpy.array(sel).dtype == numpy.integer:
raise KeyError("array like indexing via a list should be a list of integers")

# convert into slice into boolean array
index = numpy.zeros(self.size, dtype='?')
index[sel] = True; sel = index

# do the slicing
if not numpy.isscalar(sel):
return self.__slice__(sel)
else:
raise KeyError("strings and boolean arrays are the only supported indexing methods")

# owner of the memory (either self or base)
if self.base is None:
# get the right column
is_default = False
if sel in self._overrides:
r = self._overrides[sel]
elif sel in self.hardcolumns:
r = self.get_hardcolumn(sel)
elif sel in self._defaults:
r = getattr(self, sel)()
is_default = True
else:
raise KeyError("column `%s` is not defined in this source; " %sel + \
"try adding column via `source[column] = data`")
# return a ColumnAccessor for pretty prints
return ColumnAccessor(self, r, is_default=is_default)
return self._get_slice(sel)
else:
# chain to the memory owner
# this will not work if there are overrides
return self.base.__getitem__(sel)
# owner of the memory (either self or base)
if self.base is None:
# get the right column
is_default = False
if sel in self._overrides:
r = self._overrides[sel]
elif sel in self.hardcolumns:
r = self.get_hardcolumn(sel)
elif sel in self._defaults:
r = getattr(self, sel)()
is_default = True
else:
raise KeyError("column `%s` is not defined in this source; " %sel + \
"try adding column via `source[column] = data`")
# return a ColumnAccessor for pretty prints
return ColumnAccessor(self, r, is_default=is_default)
else:
# chain to the memory owner
# this will not work if there are overrides
return self.base.__getitem__(sel)


def __setitem__(self, col, value):
Expand Down Expand Up @@ -944,9 +933,15 @@ def gslice(self, start, stop, end=1, redistribute=True):
Execute a global slice of a CatalogSource.
.. note::
After the global slice is performed, the data is scattered
evenly across all ranks.
.. note::
The current algorithm generates an index on the root rank
and does not scale well.
Parameters
----------
start : int
Expand Down
16 changes: 11 additions & 5 deletions nbodykit/base/tests/test_catalog.py
Expand Up @@ -153,16 +153,21 @@ def test_bad_column(comm):
with pytest.raises(ValueError):
data = source.get_hardcolumn('BAD_COLUMN')

@MPITest([4])
@MPITest([1, 4])
def test_empty_slice(comm):

source = UniformCatalog(nbar=2e-4, BoxSize=512., seed=42, comm=comm)

# empty slice returns self
source2 = source[source['Selection']]
# Ellipsis slice returns self
source2 = source[...]
assert source is source2

# non-empty selection on root only
# Empty slice does not crash
subset = source[[]]
assert all(col in subset for col in source.columns)
assert isinstance(subset, source.__class__)

# any selection on root only
sel = source.rng.choice([True, False])
if comm.rank != 0:
sel[...] = True
Expand All @@ -172,11 +177,12 @@ def test_empty_slice(comm):
assert source is not source2


@MPITest([4])
@MPITest([1, 4])
def test_slice(comm):

source = UniformCatalog(nbar=2e-4, BoxSize=512., seed=42, comm=comm)

source['NZ'] = 1
# slice a subset
subset = source[:10]
assert all(col in subset for col in source.columns)
Expand Down

0 comments on commit e836a3f

Please sign in to comment.