diff --git a/nbodykit/base/catalog.py b/nbodykit/base/catalog.py index 85532a924..4ca6a49d4 100644 --- a/nbodykit/base/catalog.py +++ b/nbodykit/base/catalog.py @@ -265,11 +265,11 @@ def __iter__(self): def __contains__(self, col): return col in self.columns - def __slice__(self, index): + def _get_slice(self, index): """ Select a subset of ``self`` according to a boolean index array. - Returns a new object of the same type as ``selff`` holding only the + Returns a new object of the same type as ``self`` holding only the data that satisfies the slice index. Parameters @@ -278,27 +278,28 @@ def __slice__(self, index): either a dask or numpy boolean array; this determines which rows are included in the returned object """ - # compute the index slice if needed and get the size - if isinstance(index, da.Array): - index = self.compute(index) - elif isinstance(index, list): - index = numpy.array(index) + if index is Ellipsis: + return self + elif isinstance(index, slice): + start, stop, stride = index.indices(self.size) + size = (stop - start) // stride + else: + # compute the index slice if needed and get the size + index = CatalogSourceBase.make_column(index) - if getattr(self, 'size', NotImplemented) is NotImplemented: - raise ValueError("cannot make catalog subset; self catalog doest not have a size") + if index.dtype == numpy.dtype('?'): + # verify the index is a boolean array + if len(index) != self.size: + raise KeyError("slice index has length %d; should be %d" %(len(index), self.size)) - # verify the index is a boolean array - if len(index) != self.size: - raise ValueError("slice index has length %d; should be %d" %(len(index), self.size)) - if getattr(index, 'dtype', None) != '?': - raise ValueError("index used to slice CatalogSource must be boolean and array-like") + # new size is just number of True entries + size = index.sum().compute() + else: - # new size is just number of True entries - size = index.sum() + if len(index) > 0 and index.dtype != numpy.integer: + raise KeyError("slice index has must be boolean, integer. got %s" %(index.dtype)) - # if collective size is unchanged, just return self - if self.comm.allreduce(size) == self.csize: - return self.base if self.base is not None else self + size = len(index) # initialize subset Source of right size subset_data = {col:self[col][index] for col in self} @@ -324,7 +325,7 @@ def __getitem__(self, sel): Notes ----- - - Slicing with a boolean array is a **collective** operation + - Slicing is a **collective** operation - If the :attr:`base` attribute is set, columns will be returned from :attr:`base` instead of from ``self``. """ @@ -336,10 +337,10 @@ def __getitem__(self, sel): raise ValueError("cannot perform selection due to NotImplemented size") # convert slices to boolean arrays - if isinstance(sel, (slice, list)): + if isinstance(sel, (list, da.Array, numpy.ndarray)): # select a subset of list of string column names - if isinstance(sel, list) and all(isinstance(ss, string_types) for ss in sel): + if len(sel) > 0 and all(isinstance(ss, string_types) for ss in sel): invalid = set(sel) - set(self.columns) if len(invalid): msg = "cannot select subset of columns from " @@ -352,40 +353,28 @@ def __getitem__(self, sel): toret.attrs.update(self.attrs) return toret - # list must be all integers - if isinstance(sel, list) and not numpy.array(sel).dtype == numpy.integer: - raise KeyError("array like indexing via a list should be a list of integers") - - # convert into slice into boolean array - index = numpy.zeros(self.size, dtype='?') - index[sel] = True; sel = index - - # do the slicing - if not numpy.isscalar(sel): - return self.__slice__(sel) - else: - raise KeyError("strings and boolean arrays are the only supported indexing methods") - - # owner of the memory (either self or base) - if self.base is None: - # get the right column - is_default = False - if sel in self._overrides: - r = self._overrides[sel] - elif sel in self.hardcolumns: - r = self.get_hardcolumn(sel) - elif sel in self._defaults: - r = getattr(self, sel)() - is_default = True - else: - raise KeyError("column `%s` is not defined in this source; " %sel + \ - "try adding column via `source[column] = data`") - # return a ColumnAccessor for pretty prints - return ColumnAccessor(self, r, is_default=is_default) + return self._get_slice(sel) else: - # chain to the memory owner - # this will not work if there are overrides - return self.base.__getitem__(sel) + # owner of the memory (either self or base) + if self.base is None: + # get the right column + is_default = False + if sel in self._overrides: + r = self._overrides[sel] + elif sel in self.hardcolumns: + r = self.get_hardcolumn(sel) + elif sel in self._defaults: + r = getattr(self, sel)() + is_default = True + else: + raise KeyError("column `%s` is not defined in this source; " %sel + \ + "try adding column via `source[column] = data`") + # return a ColumnAccessor for pretty prints + return ColumnAccessor(self, r, is_default=is_default) + else: + # chain to the memory owner + # this will not work if there are overrides + return self.base.__getitem__(sel) def __setitem__(self, col, value): @@ -944,9 +933,15 @@ def gslice(self, start, stop, end=1, redistribute=True): Execute a global slice of a CatalogSource. .. note:: + After the global slice is performed, the data is scattered evenly across all ranks. + .. note:: + + The current algorithm generates an index on the root rank + and does not scale well. + Parameters ---------- start : int diff --git a/nbodykit/base/tests/test_catalog.py b/nbodykit/base/tests/test_catalog.py index d1d3ff888..c8f4c872a 100644 --- a/nbodykit/base/tests/test_catalog.py +++ b/nbodykit/base/tests/test_catalog.py @@ -153,16 +153,21 @@ def test_bad_column(comm): with pytest.raises(ValueError): data = source.get_hardcolumn('BAD_COLUMN') -@MPITest([4]) +@MPITest([1, 4]) def test_empty_slice(comm): source = UniformCatalog(nbar=2e-4, BoxSize=512., seed=42, comm=comm) - # empty slice returns self - source2 = source[source['Selection']] + # Ellipsis slice returns self + source2 = source[...] assert source is source2 - # non-empty selection on root only + # Empty slice dos not crash + subset = source[[]] + assert all(col in subset for col in source.columns) + assert isinstance(subset, source.__class__) + + # any selection on root only sel = source.rng.choice([True, False]) if comm.rank != 0: sel[...] = True @@ -172,11 +177,12 @@ def test_empty_slice(comm): assert source is not source2 -@MPITest([4]) +@MPITest([1, 4]) def test_slice(comm): source = UniformCatalog(nbar=2e-4, BoxSize=512., seed=42, comm=comm) + source['NZ'] = 1 # slice a subset subset = source[:10] assert all(col in subset for col in source.columns)