
DM-38091: Use InMemoryDatasetHandle #758

Merged
12 commits merged on Apr 15, 2023
57 changes: 34 additions & 23 deletions python/lsst/pipe/tasks/functors.py
@@ -44,7 +44,9 @@
from astropy.coordinates import SkyCoord

from lsst.utils import doImport
+from lsst.utils.introspection import get_full_type_name
from lsst.daf.butler import DeferredDatasetHandle
+from lsst.pipe.base import InMemoryDatasetHandle
import lsst.geom as geom
import lsst.sphgeom as sphgeom

@@ -94,7 +96,7 @@ class Functor(object):
"""Define and execute a calculation on a ParquetTable

The `__call__` method accepts either a `ParquetTable` object or a
-    `DeferredDatasetHandle`, and returns the
+    `DeferredDatasetHandle` or `InMemoryDatasetHandle`, and returns the
result of the calculation as a single column. Each functor defines what
columns are needed for the calculation, and only these columns are read
from the `ParquetTable`.
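(Not part of the diff: a minimal usage sketch of the pattern this PR enables, assuming a DataFrame-backed handle. `Column` is a concrete functor defined later in functors.py; the column names are illustrative.)

```python
import pandas as pd

from lsst.pipe.base import InMemoryDatasetHandle
from lsst.pipe.tasks.functors import Column

# Wrap a plain DataFrame so it exposes the same get() API as a
# DeferredDatasetHandle; "DataFrame" is the butler storage class name.
df = pd.DataFrame({"coord_ra": [0.1, 0.2], "coord_dec": [-0.5, 0.3]})
handle = InMemoryDatasetHandle(df, storageClass="DataFrame")

# The functor reads only the columns it needs and returns a Series.
ra = Column("coord_ra")(handle)
```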
@@ -184,12 +186,14 @@ def _get_data_columnLevels(self, data, columnIndex=None):

Parameters
----------
-    data : `MultilevelParquetTable` or `DeferredDatasetHandle`
-
+    data : various
+        The data to be read; can be a `MultilevelParquetTable`,
+        `DeferredDatasetHandle`, or `InMemoryDatasetHandle`.
    columnIndex (optional): pandas `Index` object
-        if not passed, then it is read from the `DeferredDatasetHandle`
+        If not passed, then it is read from the `DeferredDatasetHandle`
+        or `InMemoryDatasetHandle`.
"""
-        if isinstance(data, DeferredDatasetHandle):
+        if isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
if columnIndex is None:
columnIndex = data.get(component="columns")
if columnIndex is not None:
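(A hedged illustration of the `component="columns"` read used above, reusing the `handle` from the earlier sketch.)

```python
# Both handle types expose the column index as a "columns" component,
# so column levels can be inspected without loading the full table.
columnIndex = handle.get(component="columns")  # a pandas Index
columnLevels = columnIndex.names
```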
@@ -202,11 +206,13 @@ def _get_data_columnLevelNames(self, data, columnIndex=None):
def _get_data_columnLevelNames(self, data, columnIndex=None):
"""Gets the content of each of the column levels for a multilevel table

-        Similar to `_get_data_columnLevels`, this enables backward compatibility with gen2.
+        Similar to `_get_data_columnLevels`, this enables backward
+        compatibility with gen2.

-        Mirrors original gen2 implementation within `pipe.tasks.parquetTable.MultilevelParquetTable`
+        Mirrors original gen2 implementation within
+        `pipe.tasks.parquetTable.MultilevelParquetTable`
"""
-        if isinstance(data, DeferredDatasetHandle):
+        if isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
if columnIndex is None:
columnIndex = data.get(component="columns")
if columnIndex is not None:
@@ -252,18 +258,18 @@ def multilevelColumns(self, data, columnIndex=None, returnTuple=False):

Parameters
----------
-    data : `MultilevelParquetTable` or `DeferredDatasetHandle`
-
+    data : various
+        The data as either `MultilevelParquetTable`,
+        `DeferredDatasetHandle`, or `InMemoryDatasetHandle`.
columnIndex (optional): pandas `Index` object
either passed or read in from `DeferredDatasetHandle`.

-    `returnTuple` : bool
+    `returnTuple` : `bool`
If true, then return a list of tuples rather than the column dictionary
specification. This is set to `True` by `CompositeFunctor` in order to be able to
combine columns from the various component functors.

"""
-        if isinstance(data, DeferredDatasetHandle) and columnIndex is None:
+        if isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)) and columnIndex is None:
columnIndex = data.get(component="columns")

# Confirm that the dataset has the column levels the functor is expecting it to have.
Expand All @@ -287,11 +293,12 @@ def multilevelColumns(self, data, columnIndex=None, returnTuple=False):

if isinstance(data, MultilevelParquetTable):
return data._colsFromDict(columnDict)
-        elif isinstance(data, DeferredDatasetHandle):
+        elif isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
if returnTuple:
return self._colsFromDict(columnDict, columnIndex=columnIndex)
else:
return columnDict
raise RuntimeError(f"Unexpected data type. Got {get_full_type_name}.")

def _func(self, df, dropna=True):
raise NotImplementedError('Must define calculation on dataframe')
@@ -300,7 +307,7 @@ def _get_columnIndex(self, data):
"""Return columnIndex
"""

-        if isinstance(data, DeferredDatasetHandle):
+        if isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
return data.get(component="columns")
else:
return None
@@ -336,9 +343,11 @@ def _get_data(self, data):
if isinstance(data, MultilevelParquetTable):
# Load in-memory dataframe with appropriate columns the gen2 way
df = data.toDataFrame(columns=columns, droplevels=False)
-        elif isinstance(data, DeferredDatasetHandle):
+        elif isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
# Load in-memory dataframe with appropriate columns the gen3 way
df = data.get(parameters={"columns": columns})
+        else:
+            raise RuntimeError(f"Unexpected type provided for data. Got {get_full_type_name(data)}.")

# Drop unnecessary column levels
if is_multiLevel:
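(A hedged sketch of the Gen3-style column-subset read used in `_get_data` above; `handle` is as in the earlier sketch and the column names are illustrative.)

```python
# Read only the required columns; both DeferredDatasetHandle and
# InMemoryDatasetHandle accept a "columns" read parameter for
# DataFrame datasets.
df = handle.get(parameters={"columns": ["coord_ra", "coord_dec"]})
```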
@@ -355,8 +364,8 @@ def _dropna(self, vals):
return vals.dropna()

def __call__(self, data, dropna=False):
-        df = self._get_data(data)
        try:
+            df = self._get_data(data)
vals = self._func(df)
except Exception as e:
self.log.error("Exception in %s call: %s: %s", self.name, type(e).__name__, e)
@@ -475,10 +484,12 @@ def __call__(self, data, **kwargs):

Parameters
----------
-        data : `lsst.daf.butler.DeferredDatasetHandle`,
-            `lsst.pipe.tasks.parquetTable.MultilevelParquetTable`,
-            `lsst.pipe.tasks.parquetTable.ParquetTable`,
-            or `pandas.DataFrame`.
+        data : various
+            The data represented as `lsst.daf.butler.DeferredDatasetHandle`,
+            `lsst.pipe.tasks.parquetTable.MultilevelParquetTable`,
+            `lsst.pipe.tasks.parquetTable.ParquetTable`,
+            `lsst.pipe.base.InMemoryDatasetHandle`,
+            or `pandas.DataFrame`.
            The table or a pointer to a table on disk from which columns can
            be accessed.
"""
@@ -494,7 +505,7 @@ def __call__(self, data, **kwargs):
if isinstance(data, MultilevelParquetTable):
# Read data into memory the gen2 way
df = data.toDataFrame(columns=columns, droplevels=False)
-            elif isinstance(data, DeferredDatasetHandle):
+            elif isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
# Read data into memory the gen3 way
df = data.get(parameters={"columns": columns})

@@ -513,7 +524,7 @@ def __call__(self, data, **kwargs):
raise e

else:
-            if isinstance(data, DeferredDatasetHandle):
+            if isinstance(data, (DeferredDatasetHandle, InMemoryDatasetHandle)):
# input if Gen3 deferLoad=True
df = data.get(parameters={"columns": self.columns})
elif isinstance(data, pd.DataFrame):
58 changes: 13 additions & 45 deletions tests/assembleCoaddTestUtils.py
@@ -46,38 +46,20 @@
__all__ = ["MockWarpReference", "makeMockSkyInfo", "MockCoaddTestData"]


-class MockWarpReference(lsst.daf.butler.DeferredDatasetHandle):
+class MockWarpReference(pipeBase.InMemoryDatasetHandle):
"""Very simple object that looks like a Gen 3 data reference to a warped
exposure.

Parameters
----------
exposure : `lsst.afw.image.Exposure`
The exposure to be retrieved by the data reference.
-    coaddName : `str`
-        The type of coadd being produced. Typically 'deep'.
-    patch : `int`
-        Unique identifier for a subdivision of a tract.
-    tract : `int`
-        Unique identifier for a tract of a skyMap
-    visit : `int`
-        Unique identifier for an observation,
-        potentially consisting of multiple ccds.
"""
-    def __init__(self, exposure, coaddName='deep', patch=42, tract=0, visit=100):
-        self.coaddName = coaddName
-        self.exposure = exposure
-        self.tract = tract
-        self.patch = patch
-        self.visit = visit

-    def get(self, bbox=None, component=None, parameters=None):
+    def get(self, *, component=None, parameters=None):
"""Retrieve the specified dataset using the API of the Gen 3 Butler.

Parameters
----------
-        bbox : `lsst.geom.box.Box2I`, optional
-            If supplied, retrieve only a subregion of the exposure.
component : `str`, optional
If supplied, return the named metadata of the exposure.
parameters : `dict`, optional
@@ -91,36 +73,18 @@ def get(self, bbox=None, component=None, parameters=None):
Either the exposure or its metadata, depending on the datasetType.
"""
if component == 'psf':
-            return self.exposure.getPsf()
+            return self.inMemoryDataset.getPsf()
elif component == 'visitInfo':
-            return self.exposure.getInfo().getVisitInfo()
+            return self.inMemoryDataset.getInfo().getVisitInfo()
+        bbox = None
        if parameters is not None:
-            if "bbox" in parameters:
-                bbox = parameters["bbox"]
-        exp = self.exposure.clone()
+            bbox = parameters.get("bbox")
+        exp = self.inMemoryDataset.clone()
if bbox is not None:
return exp[bbox]
else:
return exp
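(A hedged usage sketch of the mock's Gen3-style API after this change, assuming `dataRef` is a `MockWarpReference` built by `makeDataRefList`; the bounding box is illustrative.)

```python
import lsst.geom as geom

psf = dataRef.get(component="psf")
visitInfo = dataRef.get(component="visitInfo")
# Subregion reads go through the parameters dict, as with a real handle.
cutout = dataRef.get(parameters={"bbox": geom.Box2I(geom.Point2I(0, 0),
                                                    geom.Extent2I(10, 10))})
```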

-    @property
-    def dataId(self):
-        """Generate a valid data identifier.
-
-        Returns
-        -------
-        dataId : `lsst.daf.butler.DataCoordinate`
-            Data identifier dict for the patch.
-        """
-        return lsst.daf.butler.DataCoordinate.standardize(
-            tract=self.tract,
-            patch=self.patch,
-            visit=self.visit,
-            instrument="DummyCam",
-            skymap="Skymap",
-            universe=lsst.daf.butler.DimensionUniverse(),
-        )


def makeMockSkyInfo(bbox, wcs, patch):
"""Construct a `Struct` containing the geometry of the patch to be coadded.
@@ -479,7 +443,11 @@ def makeDataRefList(exposures, matchedExposures, warpType, tract=0, patch=42, co
exposure = matchedExposures[expId]
else:
raise ValueError("warpType must be one of 'direct' or 'psfMatched'")
-        dataRef = MockWarpReference(exposure, coaddName=coaddName,
-                                    tract=tract, patch=patch, visit=expId)
+        dataId = lsst.daf.butler.DataCoordinate.standardize(
+            tract=tract, patch=patch, visit=expId,
+            instrument="DummyCam", skymap="Skymap",
+            universe=lsst.daf.butler.DimensionUniverse(),
+        )
+        dataRef = MockWarpReference(exposure, dataId=dataId)
dataRefList.append(dataRef)
return dataRefList
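(A hedged sketch of what the returned references now provide; `exposures` and `matchedExposures` are assumed to come from `MockCoaddTestData` elsewhere in this module.)

```python
dataRefList = makeDataRefList(exposures, matchedExposures, "direct")
# Each reference is an InMemoryDatasetHandle subclass carrying a real
# DataCoordinate, so code under test can use dataRef.dataId directly.
tract = dataRefList[0].dataId["tract"]
visit = dataRefList[0].dataId["visit"]
```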
102 changes: 1 addition & 101 deletions tests/surveyPropertyMapsTestUtils.py
@@ -22,7 +22,6 @@
"""Utilities for HealSparsePropertyMapTask and others."""
import numpy as np

-import lsst.daf.butler
import lsst.geom as geom
from lsst.daf.base import DateTime
from lsst.afw.coord import Observatory
@@ -33,8 +32,7 @@
from lsst.afw.detection import GaussianPsf


-__all__ = ['makeMockVisitSummary', 'MockVisitSummaryReference', 'MockCoaddReference',
-           'MockInputMapReference']
+__all__ = ['makeMockVisitSummary']


def makeMockVisitSummary(visit,
@@ -178,101 +176,3 @@ def makeMockVisitSummary(visit,
row['psfArea'] = shape.getArea()

return visit_summary
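(The mock classes deleted below are superseded by `lsst.pipe.base.InMemoryDatasetHandle`; a hedged sketch of the replacement pattern, with the storage class name an assumption.)

```python
from lsst.pipe.base import InMemoryDatasetHandle

# Instead of a bespoke MockVisitSummaryReference, wrap the catalog directly.
visit_summary = makeMockVisitSummary(100)
handle = InMemoryDatasetHandle(visit_summary, storageClass="ExposureCatalog")
catalog = handle.get()
```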


-class MockVisitSummaryReference(lsst.daf.butler.DeferredDatasetHandle):
-    """Very simple object that looks like a Gen3 data reference to
-    a visit summary.
-
-    Parameters
-    ----------
-    visit_summary : `lsst.afw.table.ExposureCatalog`
-        Visit summary catalog.
-    visit : `int`
-        Visit number.
-    """
-    def __init__(self, visit_summary, visit):
-        self.visit_summary = visit_summary
-        self.visit = visit
-
-    def get(self):
-        """Retrieve the specified dataset using the API of the Gen3 Butler.
-
-        Returns
-        -------
-        visit_summary : `lsst.afw.table.ExposureCatalog`
-        """
-        return self.visit_summary
-
-
-class MockCoaddReference(lsst.daf.butler.DeferredDatasetHandle):
-    """Very simple object that looks like a Gen3 data reference to
-    a coadd.
-
-    Parameters
-    ----------
-    exposure : `lsst.afw.image.Exposure`
-        The exposure to be retrieved by the data reference.
-    coaddName : `str`
-        The type of coadd produced. Typically "deep".
-    patch : `int`
-        Unique identifier for a subdivision of a tract.
-    tract : `int`
-        Unique identifier for a tract of a skyMap
-    """
-    def __init__(self, exposure, coaddName="deep", patch=0, tract=0):
-        self.coaddName = coaddName
-        self.exposure = exposure
-        self.tract = tract
-        self.patch = patch
-
-    def get(self, component=None):
-        """Retrieve the specified dataset using the API of the Gen 3 Butler.
-
-        Parameters
-        ----------
-        component : `str`, optional
-            If supplied, return the named metadata of the exposure. Allowed
-            components are "photoCalib" or "coaddInputs".
-
-        Returns
-        -------
-        `lsst.afw.image.Exposure` (component=None) or
-        `lsst.afw.image.PhotoCalib` (component="photoCalib") or
-        `lsst.afw.image.CoaddInputs` (component="coaddInputs")
-        """
-        if component == "photoCalib":
-            return self.exposure.getPhotoCalib()
-        elif component == "coaddInputs":
-            return self.exposure.getInfo().getCoaddInputs()
-
-        return self.exposure.clone()
-
-
-class MockInputMapReference(lsst.daf.butler.DeferredDatasetHandle):
-    """Very simple object that looks like a Gen3 data reference to
-    an input map.
-
-    Parameters
-    ----------
-    input_map : `healsparse.HealSparseMap`
-        Bitwise input map.
-    patch : `int`
-        Patch number.
-    tract : `int`
-        Tract number.
-    """
-    def __init__(self, input_map, patch=0, tract=0):
-        self.input_map = input_map
-        self.tract = tract
-        self.patch = patch
-
-    def get(self):
-        """
-        Retrieve the specified dataset using the API of the Gen 3 Butler.
-
-        Returns
-        -------
-        input_map : `healsparse.HealSparseMap`
-        """
-        return self.input_map