Skip to content

Commit

Permalink
Add GetMultiTractCoaddTemplateTask
Browse files Browse the repository at this point in the history
warps coadds from multiple tracts onto detector geometry
  • Loading branch information
yalsayyad committed Sep 1, 2021
1 parent 5559182 commit a46130a
Showing 1 changed file with 237 additions and 2 deletions.
239 changes: 237 additions & 2 deletions python/lsst/ip/diffim/getTemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,18 @@
import lsst.afw.image as afwImage
import lsst.geom as geom
import lsst.afw.geom as afwGeom
import lsst.afw.table as afwTable
import lsst.afw.math as afwMath
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.skymap import BaseSkyMap
from lsst.daf.butler import DeferredDatasetHandle
from lsst.ip.diffim.dcrModel import DcrModel
from lsst.meas.algorithms import CoaddPsf, WarpedPsf, CoaddPsfConfig

__all__ = ["GetCoaddAsTemplateTask", "GetCoaddAsTemplateConfig",
"GetCalexpAsTemplateTask", "GetCalexpAsTemplateConfig"]
"GetCalexpAsTemplateTask", "GetCalexpAsTemplateConfig",
"GetMultiTractCoaddTemplateTask", "GetMultiTractCoaddTemplateConfig"]


class GetCoaddAsTemplateConfig(pexConfig.Config):
Expand Down Expand Up @@ -169,7 +174,8 @@ def runQuantum(self, exposure, butlerQC, skyMapRef, coaddExposureRefs):
if tracts.count(tracts[0]) == len(tracts):
tractInfo = skyMap[tracts[0]]
else:
raise RuntimeError("Templates constructed from multiple Tracts not yet supported")
raise RuntimeError("Templates constructed from multiple Tracts not supported by this task. "
"Use GetMultiTractCoaddTemplateTask instead.")

detectorBBox = exposure.getBBox()
detectorWcs = exposure.getWcs()
Expand Down Expand Up @@ -474,3 +480,232 @@ def runDataRef(self, *args, **kwargs):

def runQuantum(self, **kwargs):
raise NotImplementedError("Calexp template is not supported with gen3 middleware")


class GetMultiTractCoaddTemplateConnections(pipeBase.PipelineTaskConnections,
dimensions=("instrument", "visit", "detector", "skymap"),
defaultTemplates={"coaddName": "goodSeeing",
"warpTypeSuffix": "",
"fakesType": ""}):
bbox = pipeBase.connectionTypes.Input(
doc="BBoxes of calexp used determine geometry of output template",
name="{fakesType}calexp.bbox",
storageClass="Box2I",
dimensions=("instrument", "visit", "detector"),
)
wcs = pipeBase.connectionTypes.Input(
doc="WCSs of calexps that we want to fetch the template for",
name="{fakesType}calexp.wcs",
storageClass="Wcs",
dimensions=("instrument", "visit", "detector"),
)
skyMap = pipeBase.connectionTypes.Input(
doc="Input definition of geometry/bbox and projection/wcs for template exposures",
name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
dimensions=("skymap", ),
storageClass="SkyMap",
)
# TODO DM-31292: Add option to use global external wcs from jointcal
# Needed for DRP HSC
coaddExposures = pipeBase.connectionTypes.Input(
doc="Input template to match and subtract from the exposure",
dimensions=("tract", "patch", "skymap", "band"),
storageClass="ExposureF",
name="{fakesType}{coaddName}Coadd{warpTypeSuffix}",
multiple=True,
deferLoad=True
)
outputExposure = pipeBase.connectionTypes.Output(
doc="Warped template used to create `subtractedExposure`.",
dimensions=("instrument", "visit", "detector"),
storageClass="ExposureF",
name="{fakesType}{coaddName}Diff_templateExp{warpTypeSuffix}",
)


class GetMultiTractCoaddTemplateConfig(pipeBase.PipelineTaskConfig, GetCoaddAsTemplateConfig,
pipelineConnections=GetMultiTractCoaddTemplateConnections):
warp = pexConfig.ConfigField(
dtype=afwMath.Warper.ConfigClass,
doc="warper configuration",
)
coaddPsf = pexConfig.ConfigField(
doc="Configuration for CoaddPsf",
dtype=CoaddPsfConfig,
)

def setDefaults(self):
self.warp.warpingKernelName = 'lanczos5'
self.coaddPsf.warpingKernelName = 'lanczos5'


class GetMultiTractCoaddTemplateTask(pipeBase.PipelineTask):
ConfigClass = GetMultiTractCoaddTemplateConfig
_DefaultName = "getMultiTractCoaddTemplateTask"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.warper = afwMath.Warper.fromConfig(self.config.warp)

def runQuantum(self, butlerQC, inputRefs, outputRefs):
# Read in all inputs.
inputs = butlerQC.get(inputRefs)
inputs['coaddExposures'] = self.getOverlappingExposures(inputs)
# SkyMap only needed for filtering without
inputs.pop('skyMap')
outputs = self.run(**inputs)
butlerQC.put(outputs, outputRefs)

def getOverlappingExposures(self, inputs):
"""Return list of coaddExposure DeferredDatasetHandles that overlap detector
The spatial index in the registry has generous padding and often supplies
patches near, but not directly overlapping the detector.
Filters inputs so that we don't have to read in all input coadds.
Parameters
----------
inputs : `dict` of task Inputs
Returns
-------
coaddExposures : list of elements of type
`lsst.daf.butler.DeferredDatasetHandle` of
`lsst.afw.image.Exposure`
Raises
------
NoWorkFound
Raised if no patches overlap the input detector bbox
"""
# Check that the patches actually overlap the detector
# Exposure's validPolygon would be more accurate
detectorPolygon = geom.Box2D(inputs['bbox'])
overlappingArea = 0
coaddExposureList = []
for coaddRef in inputs['coaddExposures']:
dataId = coaddRef.dataId
patchWcs = inputs['skyMap'][dataId['tract']].getWcs()
patchBBox = inputs['skyMap'][dataId['tract']][dataId['patch']].getOuterBBox()
patchCorners = patchWcs.pixelToSky(geom.Box2D(patchBBox).getCorners())
patchPolygon = afwGeom.Polygon(inputs['wcs'].skyToPixel(patchCorners))
if patchPolygon.intersection(detectorPolygon):
overlappingArea += patchPolygon.intersectionSingle(detectorPolygon).calculateArea()
self.log.info("Using template input tract=%s, patch=%s" %
(dataId['tract'], dataId['patch']))
coaddExposureList.append(coaddRef)

if not overlappingArea:
raise pipeBase.NoWorkFound('No patches overlap detector')

return coaddExposureList

def run(self, coaddExposures, bbox, wcs):
"""Warp coadds from multiple tracts to form a template for image diff.
Where the tracts overlap, the resulting template image is averaged.
The PSF on the template is created by combining the CoaddPsf on each
template image into a meta-CoaddPsf.
Parameters
----------
coaddExposures: list of DeferredDatasetHandle to `lsst.afw.image.Exposure`
Coadds to be mosaicked
bbox : `lsst.geom.Box2I`
Template Bounding box of the detector geometry onto which to
resample the coaddExposures
wcs : `lsst.afw.geom.SkyWcs`
Template WCS onto which to resample the coaddExposures
Returns
-------
result : `struct`
return a pipeBase.Struct:
- ``outputExposure`` : a template coadd exposure assembled out of patches
Raises
------
NoWorkFound
Raised if no patches overlatp the input detector bbox
"""
# Table for CoaddPSF
tractsSchema = afwTable.ExposureTable.makeMinimalSchema()
tractKey = tractsSchema.addField('tract', type=np.int32, doc='Which tract')
patchKey = tractsSchema.addField('patch', type=np.int32, doc='Which patch')
weightKey = tractsSchema.addField('weight', type=float, doc='Weight for each tract, should be 1')
tractsCatalog = afwTable.ExposureCatalog(tractsSchema)

finalWcs = wcs
bbox.grow(self.config.templateBorderSize)
finalBBox = bbox

nPatchesFound = 0
maskedImageList = []
weightList = []

for coaddExposure in coaddExposures:
coaddPatch = coaddExposure.get()

# warp to detector WCS
xyTransform = afwGeom.makeWcsPairTransform(coaddPatch.getWcs(), finalWcs)
psfWarped = WarpedPsf(coaddPatch.getPsf(), xyTransform)
warped = self.warper.warpExposure(finalWcs, coaddPatch, maxBBox=finalBBox)

# Check if warped image is viable
if not np.any(np.isfinite(warped.image.array)):
self.log.info("No overlap for warped %s. Skipping" % coaddExposure.ref.dataId)
continue

warped.setPsf(psfWarped)

exp = afwImage.ExposureF(finalBBox, finalWcs)
exp.maskedImage.set(np.nan, afwImage.Mask.getPlaneBitMask("NO_DATA"), np.nan)
exp.maskedImage.assign(warped.maskedImage, warped.getBBox())

maskedImageList.append(exp.maskedImage)
weightList.append(1)
record = tractsCatalog.addNew()
record.setPsf(psfWarped)
record.setWcs(finalWcs)
record.setPhotoCalib(coaddPatch.getPhotoCalib())
record.setBBox(warped.getBBox())
record.set(tractKey, coaddExposure.ref.dataId['tract'])
record.set(patchKey, coaddExposure.ref.dataId['patch'])
record.set(weightKey, 1.)
nPatchesFound += 1

if nPatchesFound == 0:
raise pipeBase.NoWorkFound("No patches found to overlap detector")

# Combine images from individual patches together
statsFlags = afwMath.stringToStatisticsProperty('MEAN')
statsCtrl = afwMath.StatisticsControl()
statsCtrl.setNanSafe(True)
statsCtrl.setWeighted(True)
statsCtrl.setCalcErrorFromInputVariance(True)

templateExposure = afwImage.ExposureF(finalBBox, finalWcs)
templateExposure.maskedImage.set(np.nan, afwImage.Mask.getPlaneBitMask("NO_DATA"), np.nan)
xy0 = templateExposure.getXY0()
# Do not mask any values
templateExposure.maskedImage = afwMath.statisticsStack(maskedImageList, statsFlags, statsCtrl,
weightList, clipped=0, maskMap=[])
templateExposure.maskedImage.setXY0(xy0)

# CoaddPsf centroid not only must overlap image, but must overlap the part of
# image with data. Use centroid of region with data
boolmask = templateExposure.mask.array & templateExposure.mask.getPlaneBitMask('NO_DATA') == 0
maskx = afwImage.makeMaskFromArray(boolmask.astype(afwImage.MaskPixel))
centerCoord = afwGeom.SpanSet.fromMask(maskx, 1).computeCentroid()

ctrl = self.config.coaddPsf.makeControl()
coaddPsf = CoaddPsf(tractsCatalog, finalWcs, centerCoord, ctrl.warpingKernelName, ctrl.cacheSize)
if coaddPsf is None:
raise RuntimeError("CoaddPsf could not be constructed")

templateExposure.setPsf(coaddPsf)
templateExposure.setFilterLabel(coaddPatch.getFilterLabel())
templateExposure.setPhotoCalib(coaddPatch.getPhotoCalib())
return pipeBase.Struct(outputExposure=templateExposure)

0 comments on commit a46130a

Please sign in to comment.