DM-28394: Add Tasks to write, transform, and consolidate ForcedSources #571

Merged (2 commits) on Sep 16, 2021
4 changes: 4 additions & 0 deletions pipelines/DRP.yaml
@@ -93,6 +93,9 @@ tasks:
connections.measCat: forced_diff
connections.outputSchema: forced_diff_schema
connections.exposure: goodSeeingDiff_differenceExp
writeForcedSourceTable: lsst.pipe.tasks.postprocess.WriteForcedSourceTableTask
transformForcedSourceTable: lsst.pipe.tasks.postprocess.TransformForcedSourceTableTask
consolidateForcedSourceTable: lsst.pipe.tasks.postprocess.ConsolidateForcedSourceTableTask
forcedPhotCcdOnDiaObjects:
class: lsst.meas.base.ForcedPhotCcdFromDataFrameTask
forcedPhotDiffOnDiaObjects:
@@ -208,3 +211,4 @@ contracts:
- transformDiaSourceCat.connections.diaSourceTable == drpAssociation.connections.diaSourceTables
- drpAssociation.connections.assocDiaSourceTable == drpDiaCalculation.connections.assocDiaSourceTable
- drpAssociation.connections.diaObjectTable == drpDiaCalculation.connections.diaObjectTable
- forcedPhotDiffim.connections.refCat == forcedPhotCcd.connections.refCat
27 changes: 12 additions & 15 deletions python/lsst/pipe/tasks/functors.py
@@ -103,18 +103,18 @@ class Functor(object):
is anything other than `'ref'`, then an error will be raised when trying to
perform the calculation.

As currently implemented, `Functor` is only set up to expect a
dataset of the format of the `deepCoadd_obj` dataset; that is, a
dataframe with a multi-level column index,
with the levels of the column index being `band`,
`dataset`, and `column`. This is defined in the `_columnLevels` attribute,
as well as being implicit in the role of the `filt` and `dataset` attributes
defined at initialization. In addition, the `_get_data` method that reads
Originally, `Functor` was set up to expect
datasets formatted like the `deepCoadd_obj` dataset; that is, a
dataframe with a multi-level column index, with the levels of the
column index being `band`, `dataset`, and `column`.
It has since been generalized to apply to dataframes without multi-level
indices, and to dataframes whose multi-level indices contain just the
`dataset` and `column` levels.
In addition, the `_get_data` method that reads
the dataframe from the `ParquetTable` will return a dataframe with column
index levels defined by the `_dfLevels` attribute; by default, this is
`column`.

The `_columnLevels` and `_dfLevels` attributes should generally not need to
The `_dfLevels` attribute should generally not need to
Review comment (Contributor): _columnLevels is still mentioned in the documentation above here at line 110

be changed, unless `_func` needs columns from multiple filters or datasets
to do the calculation.
An example of this is the `lsst.pipe.tasks.functors.Color` functor, for
@@ -133,7 +133,6 @@ class Functor(object):
"""

_defaultDataset = 'ref'
_columnLevels = ('band', 'dataset', 'column')
_dfLevels = ('column',)
_defaultNoDup = False

@@ -250,12 +249,6 @@ def multilevelColumns(self, data, columnIndex=None, returnTuple=False):
# Confirm that the dataset has the column levels the functor is expecting it to have.
columnLevels = self._get_data_columnLevels(data, columnIndex)

if not set(columnLevels) == set(self._columnLevels):
raise ValueError(
"ParquetTable does not have the expected column levels. "
f"Got {columnLevels}; expected {self._columnLevels}."
)

columnDict = {'column': self.columns,
'dataset': self.dataset}
if self.filt is None:
@@ -547,6 +540,10 @@ def from_yaml(cls, translationDefinition, **kwargs):
else:
renameRules = None

if 'calexpFlags' in translationDefinition:
for flag in translationDefinition['calexpFlags']:
funcs[cls.renameCol(flag, renameRules)] = Column(flag, dataset='calexp')

if 'refFlags' in translationDefinition:
for flag in translationDefinition['refFlags']:
funcs[cls.renameCol(flag, renameRules)] = Column(flag, dataset='ref')
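For orientation, a minimal pandas sketch of the two column-index layouts the revised docstring refers to; the band, dataset, and column names here are made up for illustration and are not taken from the pipeline:

```python
import pandas as pd

# Toy stand-in for a deepCoadd_obj-style table: a three-level column
# index with levels ('band', 'dataset', 'column').
columns = pd.MultiIndex.from_tuples(
    [('g', 'ref', 'coord_ra'), ('g', 'meas', 'base_SdssShape_xx'),
     ('r', 'ref', 'coord_ra'), ('r', 'meas', 'base_SdssShape_xx')],
    names=('band', 'dataset', 'column'))
df = pd.DataFrame([[10.0, 1.2, 10.0, 1.3]], index=[42], columns=columns)

# Selecting one band and one dataset leaves only the 'column' level,
# which is what a functor's _func sees by default (_dfLevels = ('column',)).
sub = df.xs('g', level='band', axis=1).xs('meas', level='dataset', axis=1)
print(list(sub.columns))  # ['base_SdssShape_xx']

# The generalization described in the new docstring also admits two-level
# ('dataset', 'column') tables, such as the merged forced-source table
# produced by WriteForcedSourceTableTask below, and plain flat tables.
```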
185 changes: 182 additions & 3 deletions python/lsst/pipe/tasks/postprocess.py
@@ -37,7 +37,7 @@

from .parquetTable import ParquetTable
from .multiBandUtils import makeMergeArgumentParser, MergeSourcesRunner
from .functors import CompositeFunctor, RAColumn, DecColumn, Column
from .functors import CompositeFunctor, Column


def flattenFilters(df, noDupCols=['coord_ra', 'coord_dec'], camelCase=False, inputBands=None):
@@ -458,8 +458,7 @@ class PostprocessAnalysis(object):
only run during multi-band forced-photometry.
"""
_defaultRefFlags = []
_defaultFuncs = (('coord_ra', RAColumn()),
('coord_dec', DecColumn()))
_defaultFuncs = ()

def __init__(self, parq, functors, filt=None, flags=None, refFlags=None, forcedFlags=None):
self.parq = parq
@@ -1486,3 +1485,183 @@ def run(self, visitSummaries):

outputCatalog = pd.DataFrame(data=visitEntries)
return pipeBase.Struct(outputCatalog=outputCatalog)


class WriteForcedSourceTableConnections(pipeBase.PipelineTaskConnections,
dimensions=("instrument", "visit", "detector", "skymap", "tract")):

inputCatalog = connectionTypes.Input(
doc="Primary per-detector, single-epoch forced-photometry catalog. "
"By default, it is the output of ForcedPhotCcdTask on calexps",
name="forced_src",
storageClass="SourceCatalog",
dimensions=("instrument", "visit", "detector", "skymap", "tract")
)
inputCatalogDiff = connectionTypes.Input(
doc="Secondary multi-epoch, per-detector, forced photometry catalog. "
"By default, it is the output of ForcedPhotCcdTask run on image differences.",
name="forced_diff",
storageClass="SourceCatalog",
dimensions=("instrument", "visit", "detector", "skymap", "tract")
)
outputCatalog = connectionTypes.Output(
doc="InputCatalogs horizonatally joined on `objectId` in Parquet format",
name="forcedSource",
storageClass="DataFrame",
dimensions=("instrument", "visit", "detector")
)


class WriteForcedSourceTableConfig(WriteSourceTableConfig,
pipelineConnections=WriteForcedSourceTableConnections):
pass


class WriteForcedSourceTableTask(pipeBase.PipelineTask):
"""Merge and convert per-detector forced source catalogs to parquet
"""
_DefaultName = "writeForcedSourceTable"
ConfigClass = WriteForcedSourceTableConfig

def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)
# Add ccdVisitId to allow joining with CcdVisitTable
inputs['ccdVisitId'] = butlerQC.quantum.dataId.pack("visit_detector")
inputs['band'] = butlerQC.quantum.dataId.full['band']

outputs = self.run(**inputs)
butlerQC.put(outputs, outputRefs)

def run(self, inputCatalog, inputCatalogDiff, ccdVisitId=None, band=None):
dfs = []
for table, dataset, in zip((inputCatalog, inputCatalogDiff), ('calexp', 'diff')):
df = table.asAstropy().to_pandas().set_index('objectId', drop=False)
df = df.reindex(sorted(df.columns), axis=1)
df['ccdVisitId'] = ccdVisitId if ccdVisitId else pd.NA
df['band'] = band if band else pd.NA
df.columns = pd.MultiIndex.from_tuples([(dataset, c) for c in df.columns],
names=('dataset', 'column'))

dfs.append(df)

outputCatalog = functools.reduce(lambda d1, d2: d1.join(d2), dfs)
return pipeBase.Struct(outputCatalog=outputCatalog)
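
The run method above is essentially a horizontal join of the two per-detector catalogs on objectId, with a (dataset, column) MultiIndex separating the calexp and diff measurements. A minimal pandas sketch with toy values (the psfFlux column name is illustrative):

```python
import functools
import pandas as pd

def to_multilevel(df, dataset):
    """Index a flat catalog by objectId and label its columns with a
    (dataset, column) MultiIndex, mirroring WriteForcedSourceTableTask.run."""
    df = df.set_index('objectId', drop=False)
    df = df.reindex(sorted(df.columns), axis=1)
    df.columns = pd.MultiIndex.from_tuples([(dataset, c) for c in df.columns],
                                           names=('dataset', 'column'))
    return df

# Toy forced_src and forced_diff measurements for the same two objects.
calexp = pd.DataFrame({'objectId': [1, 2], 'psfFlux': [10.0, 20.0]})
diff = pd.DataFrame({'objectId': [1, 2], 'psfFlux': [0.1, -0.2]})

dfs = [to_multilevel(calexp, 'calexp'), to_multilevel(diff, 'diff')]
merged = functools.reduce(lambda d1, d2: d1.join(d2), dfs)

print(merged[('calexp', 'psfFlux')].tolist())  # [10.0, 20.0]
print(merged[('diff', 'psfFlux')].tolist())    # [0.1, -0.2]
```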


class TransformForcedSourceTableConnections(pipeBase.PipelineTaskConnections,
dimensions=("instrument", "skymap", "patch", "tract")):

inputCatalogs = connectionTypes.Input(
doc="Parquet table of merged ForcedSources produced by WriteForcedSourceTableTask",
name="forcedSource",
storageClass="DataFrame",
dimensions=("instrument", "visit", "detector"),
multiple=True,
deferLoad=True
)
referenceCatalog = connectionTypes.Input(
doc="Reference catalog which was used to seed the forcedPhot. Columns "
"objectId, detect_isPrimary, detect_isTractInner, detect_isPatchInner "
"are expected.",
name="objectTable",
storageClass="DataFrame",
dimensions=("tract", "patch", "skymap"),
deferLoad=True
)
outputCatalog = connectionTypes.Output(
doc="Narrower, temporally-aggregated, per-patch ForcedSource Table transformed and converted per a "
"specified set of functors",
name="ForcedSourceTable",
storageClass="DataFrame",
dimensions=("tract", "patch", "skymap")
)


class TransformForcedSourceTableConfig(TransformCatalogBaseConfig,
pipelineConnections=TransformForcedSourceTableConnections):
pass


class TransformForcedSourceTableTask(TransformCatalogBaseTask):
"""Transform/standardize a ForcedSource catalog

Transforms each wide, per-detector forcedSource parquet table per the
specification file (per-camera defaults found in ForcedSource.yaml).
All epochs that overlap the patch are aggregated into one per-patch
narrow-parquet file.

No de-duplication of rows is performed. Duplicate resolution flags are
pulled in from the referenceCatalog: `detect_isPrimary`,
`detect_isTractInner`, `detect_isPatchInner`, so that the user may de-duplicate
for analysis or compare duplicates for QA.

The resulting table includes multiple bands. Epochs (MJDs) and other useful
per-visit quantities can be retrieved by joining with the CcdVisitTable on
ccdVisitId.
"""
_DefaultName = "transformForcedSourceTable"
ConfigClass = TransformForcedSourceTableConfig

def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)
if self.funcs is None:
raise ValueError("config.functorFile is None. "
"Must be a valid path to yaml in order to run Task as a PipelineTask.")
outputs = self.run(inputs['inputCatalogs'], inputs['referenceCatalog'], funcs=self.funcs,
dataId=outputRefs.outputCatalog.dataId.full)

butlerQC.put(outputs, outputRefs)

def run(self, inputCatalogs, referenceCatalog, funcs=None, dataId=None, band=None):
dfs = []
ref = referenceCatalog.get(parameters={"columns": ['detect_isPrimary', 'detect_isTractInner',
'detect_isPatchInner']})
self.log.info("Aggregating %s input catalogs" % (len(inputCatalogs)))
for handle in inputCatalogs:
result = self.transform(None, handle, funcs, dataId)
# Filter for only rows that were detected on (overlap) the patch
dfs.append(ref.join(result.df, how='inner'))

outputCatalog = pd.concat(dfs)
self.log.info("Made a table of %d columns and %d rows",
len(outputCatalog.columns), len(outputCatalog))
return pipeBase.Struct(outputCatalog=outputCatalog)
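
The reference-catalog join in run above is what restricts each per-detector table to objects detected on the patch while attaching the de-duplication flags. A minimal pandas sketch with toy values (the psfFlux column name is illustrative):

```python
import pandas as pd

# Reference objects on this patch, indexed by objectId, carrying the
# duplicate-resolution flags named in the docstring.
ref = pd.DataFrame(
    {'detect_isPrimary': [True, False],
     'detect_isTractInner': [True, True],
     'detect_isPatchInner': [True, False]},
    index=pd.Index([1, 2], name='objectId'))

# Transformed forced sources from one visit/detector, also indexed by
# objectId; objectId 3 has no counterpart in this patch's reference catalog.
forced = pd.DataFrame({'psfFlux': [10.0, 20.0, 30.0]},
                      index=pd.Index([1, 2, 3], name='objectId'))

# The inner join drops objectId 3 and attaches the flags; the per-visit
# results are then concatenated into one per-patch table, as in run().
joined = ref.join(forced, how='inner')
print(len(joined))                           # 2
print(joined['detect_isPrimary'].tolist())   # [True, False]
```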


class ConsolidateForcedSourceTableConnections(pipeBase.PipelineTaskConnections,
defaultTemplates={"catalogType": ""},
dimensions=("instrument", "tract")):
inputCatalogs = connectionTypes.Input(
doc="Input per-patch ForcedSource Tables",
name="{catalogType}ForcedSourceTable",
storageClass="DataFrame",
dimensions=("tract", "patch", "skymap"),
multiple=True,
)

outputCatalog = connectionTypes.Output(
doc="Output per-tract concatenation of ForcedSource Tables",
name="{catalogType}ForcedSourceTable_tract",
storageClass="DataFrame",
dimensions=("tract", "skymap"),
)


class ConsolidateForcedSourceTableConfig(pipeBase.PipelineTaskConfig,
pipelineConnections=ConsolidateForcedSourceTableConnections):
pass


class ConsolidateForcedSourceTableTask(CmdLineTask, pipeBase.PipelineTask):
"""Concatenate a per-patch `ForcedSourceTable` list into a single
per-tract `ForcedSourceTable_tract`
"""
_DefaultName = 'consolidateForcedSourceTable'
ConfigClass = ConsolidateForcedSourceTableConfig

def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)
self.log.info("Concatenating %s per-patch ForcedSource Tables",
len(inputs['inputCatalogs']))
df = pd.concat(inputs['inputCatalogs'])
butlerQC.put(pipeBase.Struct(outputCatalog=df), outputRefs)
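
Consolidation is a plain concatenation of the per-patch tables; per the TransformForcedSourceTableTask docstring, per-visit quantities such as the epoch MJD can then be recovered by joining on ccdVisitId. A minimal sketch; the ccdVisit table and its expMidptMJD column are assumptions used only to illustrate that join:

```python
import pandas as pd

# Per-patch ForcedSourceTables produced by TransformForcedSourceTableTask.
patch1 = pd.DataFrame({'ccdVisitId': [100, 101], 'psfFlux': [1.0, 2.0]})
patch2 = pd.DataFrame({'ccdVisitId': [100], 'psfFlux': [3.0]})

# ConsolidateForcedSourceTableTask concatenates them into a per-tract table.
tract = pd.concat([patch1, patch2])

# Hypothetical ccdVisit table; its column names are assumptions made for
# this sketch, not taken from the pipeline schemas.
ccdVisits = pd.DataFrame({'ccdVisitId': [100, 101],
                          'expMidptMJD': [59000.1, 59001.2]})

withMjd = tract.merge(ccdVisits, on='ccdVisitId', how='left')
print(withMjd['expMidptMJD'].tolist())  # [59000.1, 59001.2, 59000.1]
```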
2 changes: 0 additions & 2 deletions tests/test_transformObject.py
@@ -62,8 +62,6 @@ def testNullFilter(self):
funcs = {'Fwhm': HsmFwhm(dataset='meas')}
df = task.run(self.parq, funcs=funcs, dataId=self.dataId)
self.assertIsInstance(df, pd.DataFrame)
for column in ('coord_ra', 'coord_dec'):
self.assertIn(column, df.columns)

for filt in config.outputBands:
self.assertIn(filt + 'Fwhm', df.columns)