Skip to content

Commit

Permalink
Update to update footprints with scarlet lite models
Browse files Browse the repository at this point in the history
  • Loading branch information
fred3m committed Jun 8, 2022
1 parent 226e1bc commit 9ce9a48
Showing 1 changed file with 107 additions and 20 deletions.
127 changes: 107 additions & 20 deletions python/lsst/meas/base/forcedPhotCoadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ class ForcedPhotCoaddConnections(pipeBase.PipelineTaskConnections,
storageClass="SourceCatalog",
dimensions=("band", "skymap", "tract", "patch")
)
footprintCatInBand = pipeBase.connectionTypes.Input(
doc="Catalog of footprints to attach to sources",
name="{inputCoaddName}Coadd_deblendedFlux",
storageClass="SourceCatalog",
dimensions=("band", "skymap", "tract", "patch")
)
scarletModels = pipeBase.connectionTypes.Input(
doc="Multiband scarlet models produced by the deblender",
name="{inputCoaddName}Coadd_scarletModelData",
storageClass="ScarletModelData",
dimensions=("tract", "patch", "skymap"),
)
refWcs = pipeBase.connectionTypes.Input(
doc="Reference world coordinate system.",
name="{inputCoaddName}Coadd.wcs",
Expand All @@ -87,6 +99,14 @@ class ForcedPhotCoaddConnections(pipeBase.PipelineTaskConnections,
dimensions=["band", "skymap", "tract", "patch"],
)

def __init__(self, *, config=None):
super().__init__(config=config)
if config.footprintDatasetName != "ScarletModelData":
self.inputs.remove("scarletModels")
if config.footprintDatasetName != "DeblendedFlux":
self.inputs.remove("footprintCatInBand")
print("forced_src inputs\n", self.inputs)


class ForcedPhotCoaddConfig(pipeBase.PipelineTaskConfig,
pipelineConnections=ForcedPhotCoaddConnections):
Expand Down Expand Up @@ -121,9 +141,21 @@ class ForcedPhotCoaddConfig(pipeBase.PipelineTaskConfig,
"Must have IDs that match those of the reference catalog."
"If None, Footprints will be generated by transforming the reference Footprints.",
dtype=str,
default="meas",
default="ScarletModelData",
optional=True
)
doConserveFlux = lsst.pex.config.Field(
dtype=bool,
default=True,
doc="Whether to use the deblender models as templates to re-distribute the flux "
"from the 'exposure' (True), or to perform measurements on the deblender model footprints. "
"If footprintDatasetName != 'ScarletModelData' then this field is ignored.")
doStripFootprints = lsst.pex.config.Field(
dtype=bool,
default=True,
doc="Whether to strip footprints from the output catalog before "
"saving to disk. "
"This is usually done when using scarlet models to save disk space.")
hasFakes = lsst.pex.config.Field(
dtype=bool,
default=False,
Expand Down Expand Up @@ -198,16 +230,29 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)

refCatInBand = inputs.pop('refCatInBand')
if self.config.footprintDatasetName == "ScarletModelData":
footprintData = inputs.pop("scarletModels")
elif self.config.footprintDatasetName == "DeblendedFlux":
footprintData = inputs.pop("footprintCatIndBand")
else:
footprintData = None
inputs['measCat'], inputs['exposureId'] = self.generateMeasCat(inputRefs.exposure.dataId,
inputs['exposure'],
inputs['refCat'],
refCatInBand,
inputs['refWcs'],
"tract_patch")
"tract_patch",
footprintData)
outputs = self.run(**inputs)
# Strip HeavyFootprints to save space on disk
if self.config.footprintDatasetName == "ScarletModelData" and self.config.doStripFootprints:
sources = outputs.measCat
for source in sources[sources["parent"] != 0]:
source.setFootprint(None)
butlerQC.put(outputs, outputRefs)

def generateMeasCat(self, exposureDataId, exposure, refCat, refCatInBand, refWcs, idPackerName):
def generateMeasCat(self, exposureDataId, exposure, refCat, refCatInBand, refWcs, idPackerName,
footprintData):
"""Generate a measurement catalog for Gen3.
Parameters
Expand All @@ -225,6 +270,11 @@ def generateMeasCat(self, exposureDataId, exposure, refCat, refCatInBand, refWcs
Reference world coordinate system.
idPackerName : `str`
Type of ID packer to construct from the registry.
footprintData : `ScarletDataModel` or `lsst.afw.table.SourceCatalog`
Either the scarlet data models or the deblended catalog
containing footprints.
If `footprintData` is `None` then the footprints contained
in `refCatInBand` are used.
Returns
-------
Expand All @@ -247,13 +297,26 @@ def generateMeasCat(self, exposureDataId, exposure, refCat, refCatInBand, refWcs
idFactory=idFactory)
# attach footprints here, as the attachFootprints method is geared for gen2
# and is not worth modifying, as this can naturally live inside this method
for srcRecord in measCat:
fpRecord = refCatInBand.find(srcRecord.getId())
if fpRecord is None:
raise LookupError("Cannot find Footprint for source {}; please check that {} "
"IDs are compatible with reference source IDs"
.format(srcRecord.getId(), self.config.connections.refCatInBand))
srcRecord.setFootprint(fpRecord.getFootprint())
if self.config.footprintDatasetName == "ScarletModelData":
# Load the scarlet models
self._attachScarletFootprints(
catalog=measCat,
modelData=footprintData,
exposure=exposure,
band=exposureDataId["band"]
)
else:
if self.config.footprintDatasetName is None:
footprintCat = refCatInBand
else:
footprintCat = footprintData
for srcRecord in measCat:
fpRecord = footprintCat.find(srcRecord.getId())
if fpRecord is None:
raise LookupError("Cannot find Footprint for source {}; please check that {} "
"IDs are compatible with reference source IDs"
.format(srcRecord.getId(), footprintCat))
srcRecord.setFootprint(fpRecord.getFootprint())
return measCat, exposureIdInfo.expId

def runDataRef(self, dataRef, psfCache=None):
Expand Down Expand Up @@ -399,16 +462,40 @@ def attachFootprints(self, sources, refCat, exposure, refWcs, dataRef):

self.log.info("Loading deblended footprints for sources from %s, %s",
self.config.footprintDatasetName, dataRef.dataId)
fpCat = dataRef.get("%sCoadd_%s" % (self.config.coaddName, self.config.footprintDatasetName),
immediate=True)
for refRecord, srcRecord in zip(refCat, sources):
fpRecord = fpCat.find(refRecord.getId())
if fpRecord is None:
raise LookupError("Cannot find Footprint for source %s; please check that %sCoadd_%s "
"IDs are compatible with reference source IDs" %
(srcRecord.getId(), self.config.coaddName,
self.config.footprintDatasetName))
srcRecord.setFootprint(fpRecord.getFootprint())

if self.config.footprintDatasetName == "ScarletModelData":
# Load the scarlet models
dataModel = dataRef.get("%sCoadd_%s" % (self.config.coaddName, self.config.footprintDatasetName),
immediate=True)
self._attachScarletFootprints(refCat, dataModel, exposure, dataRef.dataId["band"])
else:
fpCat = dataRef.get("%sCoadd_%s" % (self.config.coaddName, self.config.footprintDatasetName),
immediate=True)
for refRecord, srcRecord in zip(refCat, sources):
fpRecord = fpCat.find(refRecord.getId())
if fpRecord is None:
raise LookupError("Cannot find Footprint for source %s; please check that %sCoadd_%s "
"IDs are compatible with reference source IDs" %
(srcRecord.getId(), self.config.coaddName,
self.config.footprintDatasetName))
srcRecord.setFootprint(fpRecord.getFootprint())

def _attachScarletFootprints(self, catalog, modelData, exposure, band):
"""Attach scarlet models as HeavyFootprints
"""
if self.config.doConserveFlux:
redistributeImage = exposure.image
else:
redistributeImage = None
# Attach the footprints
modelData.updateCatalogFootprints(
catalog=catalog,
band=band,
psfModel=exposure.getPsf(),
redistributeImage=redistributeImage,
removeScarletData=True,
updateFluxColumns=False,
)

def getExposure(self, dataRef):
"""Read input exposure on which measurement will be performed.
Expand Down

0 comments on commit 9ce9a48

Please sign in to comment.