diff --git a/nbodykit/base/catalog.py b/nbodykit/base/catalog.py index 5a6318a75..f1598c4e3 100644 --- a/nbodykit/base/catalog.py +++ b/nbodykit/base/catalog.py @@ -143,7 +143,7 @@ def __getitem__(self, sel): 1. strings specifying a column in the CatalogSource; returns a dask array holding the column data 2. boolean arrays specifying a slice of the CatalogSource; - returns a SlicedCatalogSource holding only the revelant slice + returns a CatalogSubset holding only the relevant slice 3. slice object specifying which particles to select """ # handle boolean array slices @@ -161,7 +161,7 @@ def __getitem__(self, sel): # do the slicing if not numpy.isscalar(sel): - return SliceCatalogSource(self, sel) + return get_catalog_subset(self, sel) else: raise KeyError("strings and boolean arrays are the only supported indexing methods") @@ -498,11 +498,43 @@ def to_mesh(self, Nmesh=None, BoxSize=None, dtype='f4', r.window = window return r + +class CatalogSubset(CatalogSource): + """ + A subset of a CatalogSource holding only a portion + of the original source + """ + def __init__(self, size, comm, use_cache=False, **columns): + """ + Parameters + ---------- + size : int + the size of the new source; this was likely determined by + the number of particles passing the selection criterion + comm : MPI communicator + the MPI communicator; this should be the same as the + comm of the object that we are selecting from + use_cache : bool; optional + whether to cache results + **columns : + the data arrays that will be added to this source; keys + represent the column names + """ + self._size = size + CatalogSource.__init__(self, comm=comm, use_cache=use_cache) + + # store the column arrays + for name in columns: + self[name] = columns[name] + + @property + def size(self): + return self._size -def SliceCatalogSource(parent, index): +def get_catalog_subset(parent, index): """ - Slice a `CatalogSource` according to a boolean index array, - returning a new CatalogSource holding only the
data that satisfies + Select a subset of a CatalogSource according to a boolean index array, + returning a CatalogSubset holding only the data that satisfies the slice criterion Parameters @@ -515,7 +547,7 @@ def SliceCatalogSource(parent, index): Returns ------- - sliced : SlicedCatalogSource + subset : CatalogSubset the particle source with the same meta-data as `parent`, and with the sliced data arrays """ @@ -534,17 +566,9 @@ def SliceCatalogSource(parent, index): # new size is just number of True entries size = index.sum() - class SlicedCatalogSource(CatalogSource): - @property - def size(self): - return size - - # initialize empty Source of right size - toret = SlicedCatalogSource(parent.comm, use_cache=parent.use_cache) - - # copy over columns - for column in parent: - toret[column] = parent[column][index] + # initialize subset Source of right size + subset_data = {col:parent[col][index] for col in parent} + toret = CatalogSubset(size, parent.comm, use_cache=parent.use_cache, **subset_data) # and the meta-data toret.attrs.update(parent.attrs)