Skip to content

Commit

Permalink
Merge pull request #758 from lsst/tickets/DM-38091
Browse files Browse the repository at this point in the history
DM-38091: Use InMemoryDatasetHandle
  • Loading branch information
timj committed Apr 15, 2023
2 parents a751136 + 34e466a commit bbe75d3
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 307 deletions.
8 changes: 8 additions & 0 deletions python/lsst/pipe/tasks/assembleCoadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,9 +653,17 @@ def run(self, skyInfo, tempExpRefList, imageScalerList, weightList,
Input list of image scalers (`list`) (unmodified).
``weightList``
Input list of weights (`list`) (unmodified).
Raises
------
lsst.pipe.base.NoWorkFound
Raised if no data references are provided.
"""
tempExpName = self.getTempExpDatasetName(self.warpType)
self.log.info("Assembling %s %s", len(tempExpRefList), tempExpName)
if not tempExpRefList:
raise pipeBase.NoWorkFound("No exposures provided for co-addition.")

stats = self.prepareStats(mask=mask)

if altMaskList is None:
Expand Down
57 changes: 34 additions & 23 deletions python/lsst/pipe/tasks/functors.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
from astropy.coordinates import SkyCoord

from lsst.utils import doImport
from lsst.utils.introspection import get_full_type_name
from lsst.daf.butler import DeferredDatasetHandle
from lsst.pipe.base import InMemoryDatasetHandle
import lsst.geom as geom
import lsst.sphgeom as sphgeom

Expand Down Expand Up @@ -94,7 +96,7 @@ class Functor(object):
"""Define and execute a calculation on a ParquetTable
The `__call__` method accepts either a `ParquetTable` object or a
`DeferredDatasetHandle`, and returns the
`DeferredDatasetHandle` or `InMemoryDatasetHandle`, and returns the
result of the calculation as a single column. Each functor defines what
columns are needed for the calculation, and only these columns are read
from the `ParquetTable`.
Expand Down Expand Up @@ -184,12 +186,14 @@ def _get_data_columnLevels(self, data, columnIndex=None):
Parameters
----------
data : `MultilevelParquetTable` or `DeferredDatasetHandle`
data : various
The data to be read, can be a `MultilevelParquetTable`,
`DeferredDatasetHandle`, or `InMemoryDatasetHandle`.
columnIndex (optional): pandas `Index` object
If not passed, then it is read from the `DeferredDatasetHandle`
or `InMemoryDatasetHandle`.
"""
if isinstance(data, DeferredDatasetHandle):
if isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
if columnIndex is None:
columnIndex = data.get(component="columns")
if columnIndex is not None:
Expand All @@ -202,11 +206,13 @@ def _get_data_columnLevels(self, data, columnIndex=None):
def _get_data_columnLevelNames(self, data, columnIndex=None):
"""Gets the content of each of the column levels for a multilevel table
Similar to `_get_data_columnLevels`, this enables backward compatibility with gen2.
Similar to `_get_data_columnLevels`, this enables backward
compatibility with gen2.
Mirrors original gen2 implementation within `pipe.tasks.parquetTable.MultilevelParquetTable`
Mirrors original gen2 implementation within
`pipe.tasks.parquetTable.MultilevelParquetTable`
"""
if isinstance(data, DeferredDatasetHandle):
if isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
if columnIndex is None:
columnIndex = data.get(component="columns")
if columnIndex is not None:
Expand Down Expand Up @@ -252,18 +258,18 @@ def multilevelColumns(self, data, columnIndex=None, returnTuple=False):
Parameters
----------
data : `MultilevelParquetTable` or `DeferredDatasetHandle`
data : various
The data as either `MultilevelParquetTable`,
`DeferredDatasetHandle`, or `InMemoryDatasetHandle`.
columnIndex (optional): pandas `Index` object
either passed or read in from `DeferredDatasetHandle`.
`returnTuple` : bool
`returnTuple` : `bool`
If true, then return a list of tuples rather than the column dictionary
specification. This is set to `True` by `CompositeFunctor` in order to be able to
combine columns from the various component functors.
"""
if isinstance(data, DeferredDatasetHandle) and columnIndex is None:
if isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)) and columnIndex is None:
columnIndex = data.get(component="columns")

# Confirm that the dataset has the column levels the functor is expecting it to have.
Expand All @@ -287,11 +293,12 @@ def multilevelColumns(self, data, columnIndex=None, returnTuple=False):

if isinstance(data, MultilevelParquetTable):
return data._colsFromDict(columnDict)
elif isinstance(data, DeferredDatasetHandle):
elif isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
if returnTuple:
return self._colsFromDict(columnDict, columnIndex=columnIndex)
else:
return columnDict
raise RuntimeError(f"Unexpected data type. Got {get_full_type_name(data)}.")

def _func(self, df, dropna=True):
raise NotImplementedError('Must define calculation on dataframe')
def _get_columnIndex(self, data):
    """Return the column index of ``data``, if one is available.

    Parameters
    ----------
    data : various
        A `DeferredDatasetHandle` or `InMemoryDatasetHandle`; any other
        type yields `None`.

    Returns
    -------
    columnIndex : `pandas.Index` or `None`
        The column index read from the handle via its ``columns``
        component, or `None` when ``data`` is not a handle type.
    """
    # Guard clause: only dataset handles carry a readable column index.
    if not isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
        return None
    return data.get(component="columns")
Expand Down Expand Up @@ -336,9 +343,11 @@ def _get_data(self, data):
if isinstance(data, MultilevelParquetTable):
# Load in-memory dataframe with appropriate columns the gen2 way
df = data.toDataFrame(columns=columns, droplevels=False)
elif isinstance(data, DeferredDatasetHandle):
elif isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
# Load in-memory dataframe with appropriate columns the gen3 way
df = data.get(parameters={"columns": columns})
else:
raise RuntimeError(f"Unexpected type provided for data. Got {get_full_type_name(data)}.")

# Drop unnecessary column levels
if is_multiLevel:
def _dropna(self, vals):
    """Return ``vals`` with NA entries removed.

    Parameters
    ----------
    vals : pandas object
        Values to filter; assumed to support ``dropna()`` (e.g. a
        `pandas.Series`).

    Returns
    -------
    pandas object
        ``vals`` with all NA entries dropped.
    """
    cleaned = vals.dropna()
    return cleaned

def __call__(self, data, dropna=False):
df = self._get_data(data)
try:
df = self._get_data(data)
vals = self._func(df)
except Exception as e:
self.log.error("Exception in %s call: %s: %s", self.name, type(e).__name__, e)
Expand Down Expand Up @@ -475,10 +484,12 @@ def __call__(self, data, **kwargs):
Parameters
----------
data : `lsst.daf.butler.DeferredDatasetHandle`,
`lsst.pipe.tasks.parquetTable.MultilevelParquetTable`,
`lsst.pipe.tasks.parquetTable.ParquetTable`,
or `pandas.DataFrame`.
data : various
The data represented as `lsst.daf.butler.DeferredDatasetHandle`,
`lsst.pipe.tasks.parquetTable.MultilevelParquetTable`,
`lsst.pipe.tasks.parquetTable.ParquetTable`,
`lsst.pipe.base.InMemoryDatasetHandle`,
or `pandas.DataFrame`.
The table or a pointer to a table on disk from which columns can
be accessed
"""
Expand All @@ -494,7 +505,7 @@ def __call__(self, data, **kwargs):
if isinstance(data, MultilevelParquetTable):
# Read data into memory the gen2 way
df = data.toDataFrame(columns=columns, droplevels=False)
elif isinstance(data, DeferredDatasetHandle):
elif isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
# Read data into memory the gen3 way
df = data.get(parameters={"columns": columns})

Expand All @@ -513,7 +524,7 @@ def __call__(self, data, **kwargs):
raise e

else:
if isinstance(data, DeferredDatasetHandle):
if isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
# Input is a Gen 3 handle (deferLoad=True).
df = data.get(parameters={"columns": self.columns})
elif isinstance(data, pd.DataFrame):
Expand Down
64 changes: 10 additions & 54 deletions tests/assembleCoaddTestUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from lsst.afw.cameraGeom.testUtils import DetectorWrapper
import lsst.afw.geom as afwGeom
import lsst.afw.image as afwImage
import lsst.daf.butler
import lsst.geom as geom
from lsst.geom import arcseconds, degrees
from lsst.meas.algorithms.testUtils import plantSources
Expand All @@ -46,38 +45,20 @@
__all__ = ["MockWarpReference", "makeMockSkyInfo", "MockCoaddTestData"]


class MockWarpReference(lsst.daf.butler.DeferredDatasetHandle):
class MockWarpReference(pipeBase.InMemoryDatasetHandle):
"""Very simple object that looks like a Gen 3 data reference to a warped
exposure.
Parameters
----------
exposure : `lsst.afw.image.Exposure`
The exposure to be retrieved by the data reference.
coaddName : `str`
The type of coadd being produced. Typically 'deep'.
patch : `int`
Unique identifier for a subdivision of a tract.
tract : `int`
Unique identifier for a tract of a skyMap
visit : `int`
Unique identifier for an observation,
potentially consisting of multiple ccds.
"""
def __init__(self, exposure, coaddName='deep', patch=42, tract=0, visit=100):
self.coaddName = coaddName
self.exposure = exposure
self.tract = tract
self.patch = patch
self.visit = visit

def get(self, bbox=None, component=None, parameters=None):
def get(self, *, component=None, parameters=None):
"""Retrieve the specified dataset using the API of the Gen 3 Butler.
Parameters
----------
bbox : `lsst.geom.box.Box2I`, optional
If supplied, retrieve only a subregion of the exposure.
component : `str`, optional
If supplied, return the named metadata of the exposure.
parameters : `dict`, optional
Expand All @@ -88,38 +69,13 @@ def get(self, bbox=None, component=None, parameters=None):
-------
`lsst.afw.image.Exposure` or `lsst.afw.image.VisitInfo`
or `lsst.meas.algorithms.SingleGaussianPsf`
Either the exposure or its metadata, depending on the datasetType.
"""
if component == 'psf':
return self.exposure.getPsf()
elif component == 'visitInfo':
return self.exposure.getInfo().getVisitInfo()
if parameters is not None:
if "bbox" in parameters:
bbox = parameters["bbox"]
exp = self.exposure.clone()
if bbox is not None:
return exp[bbox]
else:
return exp

@property
def dataId(self):
"""Generate a valid data identifier.
Returns
-------
dataId : `lsst.daf.butler.DataCoordinate`
Data identifier dict for the patch.
Either the exposure or its metadata, depending on the component
requested.
"""
return lsst.daf.butler.DataCoordinate.standardize(
tract=self.tract,
patch=self.patch,
visit=self.visit,
instrument="DummyCam",
skymap="Skymap",
universe=lsst.daf.butler.DimensionUniverse(),
)
exp = super().get(component=component, parameters=parameters)
if isinstance(exp, afwImage.ExposureF):
exp = exp.clone()
return exp


def makeMockSkyInfo(bbox, wcs, patch):
Expand Down Expand Up @@ -479,7 +435,7 @@ def makeDataRefList(exposures, matchedExposures, warpType, tract=0, patch=42, co
exposure = matchedExposures[expId]
else:
raise ValueError("warpType must be one of 'direct' or 'psfMatched'")
dataRef = MockWarpReference(exposure, coaddName=coaddName,
tract=tract, patch=patch, visit=expId)
dataRef = MockWarpReference(exposure, storageClass="ExposureF",
tract=tract, patch=patch, visit=expId, coaddName=coaddName)
dataRefList.append(dataRef)
return dataRefList

0 comments on commit bbe75d3

Please sign in to comment.