Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tickets/DM-33710: Implement ScarletDataModels #687

Merged
merged 1 commit into from
Jun 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 21 additions & 11 deletions python/lsst/pipe/tasks/deblendCoaddSourcesPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,24 @@ class DeblendCoaddSourcesMultiConnections(PipelineTaskConnections,
dimensions=("tract", "patch", "band", "skymap"),
multiple=True
TallJimbo marked this conversation as resolved.
Show resolved Hide resolved
)
deblendedCatalog = cT.Output(
doc="Catalogs produced by multiband deblending",
name="{outputCoaddName}Coadd_deblendedCatalog",
storageClass="SourceCatalog",
dimensions=("tract", "patch", "skymap"),
)
scarletModelData = cT.Output(
doc="Multiband scarlet models produced by the deblender",
name="{outputCoaddName}Coadd_scarletModelData",
storageClass="ScarletModelData",
dimensions=("tract", "patch", "skymap"),
)

def __init__(self, *, config=None):
super().__init__(config=config)
# Remove unused connections.
# TODO: deprecate once RFC-860 passes.
self.outputs -= set(("fluxCatalogs", "templateCatalogs"))


class DeblendCoaddSourcesMultiConfig(PipelineTaskConfig,
Expand Down Expand Up @@ -213,19 +231,11 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs["idFactory"] = exposureIdInfo.makeSourceIdFactory()
inputs["filters"] = [dRef.dataId["band"] for dRef in inputRefs.coadds]
outputs = self.run(**inputs)
for outRef in outputRefs.templateCatalogs:
band = outRef.dataId['band']
if (catalog := outputs.templateCatalogs.get(band)) is not None:
butlerQC.put(catalog, outRef)

for outRef in outputRefs.fluxCatalogs:
band = outRef.dataId['band']
if (catalog := outputs.fluxCatalogs.get(band)) is not None:
butlerQC.put(catalog, outRef)
butlerQC.put(outputs, outputRefs)

def run(self, coadds, filters, mergedDetections, idFactory):
sources = self._makeSourceCatalog(mergedDetections, idFactory)
multiExposure = afwImage.MultibandExposure.fromExposures(filters, coadds)
templateCatalogs, fluxCatalogs = self.multibandDeblend.run(multiExposure, sources)
retStruct = Struct(templateCatalogs=templateCatalogs, fluxCatalogs=fluxCatalogs)
catalog, modelData = self.multibandDeblend.run(multiExposure, sources)
retStruct = Struct(deblendedCatalog=catalog, scarletModelData=modelData)
return retStruct
84 changes: 75 additions & 9 deletions python/lsst/pipe/tasks/multiBand.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
# the GNU General Public License along with this program. If not,
# see <https://www.lsstcorp.org/LegalNotices/>.
#
import warnings
import numpy as np

from lsst.coadd.utils.coaddDataIdContainer import ExistingCoaddDataIdContainer
from lsst.coadd.utils.getGen3CoaddExposureId import getGen3CoaddExposureId
from lsst.pipe.base import (CmdLineTask, Struct, ArgumentParser, ButlerInitializedTaskRunner,
PipelineTask, PipelineTaskConfig, PipelineTaskConnections)
import lsst.pipe.base.connectionTypes as cT
from lsst.pex.config import Config, Field, ConfigurableField
from lsst.pex.config import Config, Field, ConfigurableField, ChoiceField
from lsst.meas.algorithms import DynamicDetectionTask, ReferenceObjectLoader, ScaleVarianceTask
from lsst.meas.base import SingleFrameMeasurementTask, ApplyApCorrTask, CatalogCalculationTask
from lsst.meas.deblender import SourceDeblendTask
Expand Down Expand Up @@ -590,6 +591,8 @@ class MeasureMergedCoaddSourcesConnections(PipelineTaskConnections,
defaultTemplates={"inputCoaddName": "deep",
"outputCoaddName": "deep",
"deblendedCatalog": "deblendedFlux"}):
warnings.warn("MeasureMergedCoaddSourcesConnections.defaultTemplates is deprecated and no longer used. "
"Use MeasureMergedCoaddSourcesConfig.inputCatalog.")
inputSchema = cT.InitInput(
doc="Input schema for measure merged task produced by a deblender or detection task",
name="{inputCoaddName}Coadd_deblendedFlux_schema",
Expand Down Expand Up @@ -659,6 +662,18 @@ class MeasureMergedCoaddSourcesConnections(PipelineTaskConnections,
storageClass="SourceCatalog",
dimensions=("tract", "patch", "band", "skymap"),
)
scarletCatalog = cT.Input(
doc="Catalogs produced by multiband deblending",
name="{inputCoaddName}Coadd_deblendedCatalog",
storageClass="SourceCatalog",
dimensions=("tract", "patch", "skymap"),
)
scarletModels = cT.Input(
doc="Multiband scarlet models produced by the deblender",
name="{inputCoaddName}Coadd_scarletModelData",
storageClass="ScarletModelData",
dimensions=("tract", "patch", "skymap"),
)
outputSources = cT.Output(
doc="Source catalog containing all the measurement information generated in this task",
name="{outputCoaddName}Coadd_meas",
Expand Down Expand Up @@ -698,6 +713,15 @@ def __init__(self, *, config=None):
self.inputs -= set(("sourceTableHandles",))
self.inputs -= set(("finalizedSourceTableHandles",))

if config.inputCatalog == "deblendedCatalog":
self.inputs -= set(("inputCatalog",))

if not config.doAddFootprints:
self.inputs -= set(("scarletModels",))
else:
self.inputs -= set(("deblendedCatalog"))
self.inputs -= set(("scarletModels",))

if config.doMatchSources is False:
self.outputs -= set(("matchResult",))

Expand All @@ -712,11 +736,29 @@ class MeasureMergedCoaddSourcesConfig(PipelineTaskConfig,

@brief Configuration parameters for the MeasureMergedCoaddSourcesTask
"""
inputCatalog = Field(dtype=str, default="deblendedFlux",
doc=("Name of the input catalog to use."
"If the single band deblender was used this should be 'deblendedFlux."
"If the multi-band deblender was used this should be 'deblendedModel."
"If no deblending was performed this should be 'mergeDet'"))
inputCatalog = ChoiceField(
dtype=str,
default="deblendedCatalog",
allowed={
"deblendedCatalog": "Output catalog from ScarletDeblendTask",
"deblendedFlux": "Output catalog from SourceDeblendTask",
"mergeDet": "The merged detections before deblending."
},
doc="The name of the input catalog.",
)
doAddFootprints = Field(dtype=bool,
default=True,
doc="Whether or not to add footprints to the input catalog from scarlet models. "
"This should be true whenever using the multi-band deblender, "
"otherwise this should be False.")
doConserveFlux = Field(dtype=bool, default=True,
doc="Whether to use the deblender models as templates to re-distribute the flux "
"from the 'exposure' (True), or to perform measurements on the deblender "
"model footprints.")
doStripFootprints = Field(dtype=bool, default=True,
doc="Whether to strip footprints from the output catalog before "
"saving to disk. "
"This is usually done when using scarlet models to save disk space.")
measurement = ConfigurableField(target=SingleFrameMeasurementTask, doc="Source measurement")
setPrimaryFlags = ConfigurableField(target=SetPrimaryFlagsTask, doc="Set flags for primary tract/patch")
doPropagateFlags = Field(
Expand Down Expand Up @@ -992,15 +1034,37 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
# Transform inputCatalog
table = afwTable.SourceTable.make(self.schema, idFactory)
sources = afwTable.SourceCatalog(table)
sources.extend(inputs.pop('inputCatalog'), self.schemaMapper)
# Load the correct input catalog
if "scarletCatalog" in inputs:
inputCatalog = inputs.pop("scarletCatalog")
catalogRef = inputRefs.scarletCatalog
else:
inputCatalog = inputs.pop("inputCatalog")
catalogRef = inputRefs.inputCatalog
sources.extend(inputCatalog, self.schemaMapper)
del inputCatalog
# Add the HeavyFootprints to the deblended sources
if self.config.doAddFootprints:
modelData = inputs.pop('scarletModels')
if self.config.doConserveFlux:
redistributeImage = inputs['exposure'].image
else:
redistributeImage = None
modelData.updateCatalogFootprints(
catalog=sources,
band=inputRefs.exposure.dataId["band"],
psfModel=inputs['exposure'].getPsf(),
redistributeImage=redistributeImage,
removeScarletData=True,
)
table = sources.getTable()
table.setMetadata(self.algMetadata) # Capture algorithm metadata to write out to the source catalog.
inputs['sources'] = sources

skyMap = inputs.pop('skyMap')
tractNumber = inputRefs.inputCatalog.dataId['tract']
tractNumber = catalogRef.dataId['tract']
tractInfo = skyMap[tractNumber]
patchInfo = tractInfo.getPatchInfo(inputRefs.inputCatalog.dataId['patch'])
patchInfo = tractInfo.getPatchInfo(catalogRef.dataId['patch'])
skyInfo = Struct(
skyMap=skyMap,
tractInfo=tractInfo,
Expand Down Expand Up @@ -1052,6 +1116,8 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs['ccdInputs'] = ccdInputs

outputs = self.run(**inputs)
# Strip HeavyFootprints to save space on disk
sources = outputs.outputSources
butlerQC.put(outputs, outputRefs)

def runDataRef(self, patchRef, psfCache=100):
Expand Down
13 changes: 10 additions & 3 deletions tests/test_isPrimaryFlag.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,17 @@ def testIsScarletPrimaryFlag(self):
src.setFootprint(foot)
src.set("merge_peak_sky", True)
# deblend
result, fluxResult = deblendTask.run(coadds, catalog)
catalog, modelData = deblendTask.run(coadds, catalog)
# Attach footprints to the catalog
modelData.updateCatalogFootprints(
catalog=catalog,
band="test",
psfModel=coadds["test"].getPsf(),
redistributeImage=None,
)
# measure
measureTask.run(result["test"], self.exposure)
outputCat = result["test"]
measureTask.run(catalog, self.exposure)
outputCat = catalog
# Set the primary flags
setPrimaryTask.run(outputCat, skyMap=skyMap, tractInfo=tractInfo, patchInfo=patchInfo)

Expand Down