Skip to content

Commit

Permalink
Convert MeasureMergedCoaddSourcesTask to PipelineTask
Browse files Browse the repository at this point in the history
  • Loading branch information
natelust committed Jan 19, 2019
1 parent 37237b9 commit 9fc727f
Showing 1 changed file with 239 additions and 28 deletions.
267 changes: 239 additions & 28 deletions python/lsst/pipe/tasks/multiBand.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
#
from lsst.coadd.utils.coaddDataIdContainer import ExistingCoaddDataIdContainer
from lsst.pipe.base import (CmdLineTask, Struct, ArgumentParser, ButlerInitializedTaskRunner,
PipelineTask, InitOutputDatasetField, InputDatasetField, OutputDatasetField,
PipelineTask, PipelineTaskConfig, InitInputDatasetField,
InitOutputDatasetField, InputDatasetField, OutputDatasetField,
QuantumConfig)
from lsst.pex.config import Config, Field, ConfigurableField
from lsst.meas.algorithms import DynamicDetectionTask
from lsst.meas.algorithms import DynamicDetectionTask, ReferenceObjectLoader
from lsst.meas.base import SingleFrameMeasurementTask, ApplyApCorrTask, CatalogCalculationTask
from lsst.meas.deblender import SourceDeblendTask, MultibandDeblendTask
from lsst.pipe.tasks.coaddBase import getSkyInfo
Expand Down Expand Up @@ -580,7 +581,7 @@ def getExposureId(self, dataRef):
return int(dataRef.get(self.config.coaddName + "CoaddId"))


class MeasureMergedCoaddSourcesConfig(Config):
class MeasureMergedCoaddSourcesConfig(PipelineTaskConfig):
"""!
@anchor MeasureMergedCoaddSourcesConfig_
Expand All @@ -607,6 +608,7 @@ class MeasureMergedCoaddSourcesConfig(Config):
"This format uses more disk space, but is more convenient to read."),
)
coaddName = Field(dtype=str, default="deep", doc="Name of coadd")
psfCache = Field(dtype=int, default=100, doc="Size of psfCache")
checkUnitsParseStrict = Field(
doc="Strictness of Astropy unit compatibility check, can be 'raise', 'warn' or 'silent'",
dtype=str,
Expand All @@ -630,9 +632,88 @@ class MeasureMergedCoaddSourcesConfig(Config):
target=CatalogCalculationTask,
doc="Subtask to run catalogCalculation plugins on catalog"
)
inputSchema = InitInputDatasetField(
doc="Input schema for measure merged task produced by a deblender or detection task",
nameTemplate="{inputCoaddName}Coadd_deblendedFlux_schema",
storageClass="SourceCatalog"
)
outputSchema = InitOutputDatasetField(
doc="Output schema after all new fields are added by task",
nameTemplate="{inputCoaddName}Coadd_meas_schema",
storageClass="SourceCatalog"
)
refCat = InputDatasetField(
doc="Reference catalog used to match measured sources against known sources",
name="ref_cat",
storageClass="SimpleCatalog",
dimensions=("SkyPix",),
manualLoad=True
)
exposure = InputDatasetField(
doc="Input coadd image",
nameTemplate="{inputCoaddName}Coadd_calexp",
scalar=True,
storageClass="ExposureF",
dimensions=("Tract", "Patch", "AbstractFilter", "SkyMap")
)
skyMap = InputDatasetField(
doc="SkyMap to use in processing",
nameTemplate="{inputCoaddName}Coadd_skyMap",
storageClass="SkyMap",
dimensions=("SkyMap",),
scalar=True
)
# Per-visit source catalogs; only needed when flags are propagated from the
# individual visits onto the coadd objects.
visitCatalogs = InputDatasetField(
    doc="Source catalogs for visits which overlap input tract, patch, abstract_filter. Will be "
        "further filtered in the task for the purpose of propagating flags from image calibration "
        "and characterization to coadd objects",
    name="src",
    dimensions=("Instrument", "Visit", "Detector"),
    storageClass="SourceCatalog"
)
# Catalog that seeds the measurement; the default template points at the
# deblendedFlux catalog.
intakeCatalog = InputDatasetField(
    doc=("Name of the input catalog to use. "
         "If the single band deblender was used this should be 'deblendedFlux'. "
         "If the multi-band deblender was used this should be 'deblendedModel', "
         "or 'deblendedFlux' if the multiband deblender was configured to output "
         "deblended flux catalogs. If no deblending was performed this should "
         "be 'mergeDet'."),
    nameTemplate="{inputCoaddName}Coadd_deblendedFlux",
    storageClass="SourceCatalog",
    dimensions=("Tract", "Patch", "AbstractFilter", "SkyMap"),
    scalar=True
)
outputSources = OutputDatasetField(
doc="Source catalog containing all the measurement information generated in this task",
nameTemplate="{outputCoaddName}Coadd_meas",
dimensions=("Tract", "Patch", "AbstractFilter", "SkyMap"),
storageClass="SourceCatalog",
scalar=True
)
matchResult = OutputDatasetField(
doc="Match catalog produced by configured matcher, optional on doMatchSources",
nameTemplate="{outputCoaddName}Coadd_measMatch",
dimensions=("Tract", "Patch", "AbstractFilter", "SkyMap"),
storageClass="Catalog",
scalar=True
)
denormMatches = OutputDatasetField(
doc="Denormalized Match catalog produced by configured matcher, optional on "
"doWriteMatchesDenormalized",
nameTemplate="{outputCoaddName}Coadd_measMatchFull",
dimensions=("Tract", "Patch", "AbstractFilter", "SkyMap"),
storageClass="Catalog",
scalar=True
)

@property
def refObjLoader(self):
    """Reference object loader configuration, forwarded from the ``match`` subtask config."""
    matchConfig = self.match
    return matchConfig.refObjLoader

def setDefaults(self):
Config.setDefaults(self)
super().setDefaults()
self.formatTemplateNames({"inputCoaddName": "deep", "outputCoaddName": "deep"})
self.quantum.dimensions = ("Tract", "Patch", "AbstractFilter", "SkyMap")
self.measurement.plugins.names |= ['base_InputCount', 'base_Variance']
self.measurement.plugins['base_PixelFlags'].masksFpAnywhere = ['CLIPPED', 'SENSOR_EDGE',
'INEXACT_PSF']
Expand All @@ -654,7 +735,7 @@ def getTargetList(parsedCmd, **kwargs):
return ButlerInitializedTaskRunner.getTargetList(parsedCmd, psfCache=parsedCmd.psfCache)


class MeasureMergedCoaddSourcesTask(CmdLineTask):
class MeasureMergedCoaddSourcesTask(PipelineTask, CmdLineTask):
r"""!
@anchor MeasureMergedCoaddSourcesTask_
Expand Down Expand Up @@ -778,7 +859,8 @@ def _makeArgumentParser(cls):
parser.add_argument("--psfCache", type=int, default=100, help="Size of CoaddPsf cache")
return parser

def __init__(self, butler=None, schema=None, peakSchema=None, refObjLoader=None, **kwargs):
def __init__(self, butler=None, schema=None, peakSchema=None, refObjLoader=None, initInputs=None,
**kwargs):
"""!
@brief Initialize the task.
Expand All @@ -795,9 +877,11 @@ def __init__(self, butler=None, schema=None, peakSchema=None, refObjLoader=None,
This will include all fields from the input schema, as well as additional fields for all the
measurements.
"""
CmdLineTask.__init__(self, **kwargs)
super().__init__(**kwargs)
self.deblended = self.config.inputCatalog.startswith("deblended")
self.inputCatalog = "Coadd_" + self.config.inputCatalog
if initInputs is not None:
schema = initInputs['inputSchema'].schema
if schema is None:
assert butler is not None, "Neither butler nor schema is defined"
schema = butler.get(self.config.coaddName + self.inputCatalog + "_schema", immediate=True).schema
Expand All @@ -808,8 +892,6 @@ def __init__(self, butler=None, schema=None, peakSchema=None, refObjLoader=None,
self.makeSubtask("measurement", schema=self.schema, algMetadata=self.algMetadata)
self.makeSubtask("setPrimaryFlags", schema=self.schema)
if self.config.doMatchSources:
if refObjLoader is None:
assert butler is not None, "Neither butler nor refObjLoader is defined"
self.makeSubtask("match", butler=butler, refObjLoader=refObjLoader)
if self.config.doPropagateFlags:
self.makeSubtask("propagateFlags", schema=self.schema)
Expand All @@ -819,6 +901,84 @@ def __init__(self, butler=None, schema=None, peakSchema=None, refObjLoader=None,
if self.config.doRunCatalogCalculation:
self.makeSubtask("catalogCalculation", schema=self.schema)

@classmethod
def getInputDatasetTypes(cls, config):
    """Return the input dataset types required for the given configuration.

    The per-visit source catalogs are removed from the result when flag
    propagation is disabled.

    Parameters
    ----------
    config : `MeasureMergedCoaddSourcesConfig`
        Configuration for this task.

    Returns
    -------
    datasetTypes : `dict`
        Mapping returned by the parent class, keyed by config field name,
        without ``visitCatalogs`` when ``doPropagateFlags`` is false.
    """
    datasetTypes = super().getInputDatasetTypes(config)
    if config.doPropagateFlags:
        return datasetTypes
    datasetTypes.pop("visitCatalogs")
    return datasetTypes

@classmethod
def getOutputDatasetTypes(cls, config):
    """Return the output dataset types produced for the given configuration.

    Parameters
    ----------
    config : `MeasureMergedCoaddSourcesConfig`
        Configuration for this task.

    Returns
    -------
    outputDatasetTypes : `dict`
        Mapping returned by the parent class, keyed by config field name,
        with the match outputs removed when they will not be produced.
    """
    outputDatasetTypes = super().getOutputDatasetTypes(config)
    if not config.doMatchSources:
        outputDatasetTypes.pop("matchResult")
    # run() only builds denormMatches inside its doMatchSources branch, so the
    # dataset must not be declared unless both options are enabled.
    if not (config.doMatchSources and config.doWriteMatchesDenormalized):
        outputDatasetTypes.pop("denormMatches")
    return outputDatasetTypes

def getInitOutputDatasets(self):
    """Return the init-output datasets of this task.

    Returns
    -------
    datasets : `dict`
        Single entry ``outputSchema`` holding a source catalog constructed
        from this task's schema.
    """
    schemaCatalog = afwTable.SourceCatalog(self.schema)
    return dict(outputSchema=schemaCatalog)

def adaptArgsAndRun(self, inputData, inputDataIds, outputDataIds, butler):
    """Translate the quantum inputs into the arguments of `run` and call it.

    Parameters
    ----------
    inputData : `dict`
        Loaded input datasets keyed by config field name. Modified in place:
        ``intakeCatalog`` and ``skyMap`` are consumed and replaced by the
        ``sources``, ``skyInfo`` and ``exposureId`` arguments of `run`; when
        flags are propagated, ``visitCatalogs`` is filtered and ``wcsUpdates``
        and ``ccdInputs`` are added.
    inputDataIds : `dict`
        Data ids of the inputs keyed by config field name.
    outputDataIds : `dict`
        Data ids of the outputs keyed by config field name (unused here).
    butler : object
        Butler handed to the reference object loader.

    Returns
    -------
    results : `lsst.pipe.base.Struct`
        Whatever `run` returns.
    """
    # The "match" subtask is only created in __init__ when doMatchSources is
    # set, so the loader must not be attached otherwise.
    if self.config.doMatchSources:
        refObjLoader = ReferenceObjectLoader(inputDataIds['refCat'], butler,
                                             config=self.config.refObjLoader, log=self.log)
        self.match.setRefObjLoader(refObjLoader)

    # Set psfcache
    # move this to run after gen2 deprecation
    inputData['exposure'].getPsf().setCacheCapacity(self.config.psfCache)

    # Transform inputCatalog into a catalog that uses the task's output schema
    idFactory = afwTable.IdFactory.makeSimple()
    table = afwTable.SourceTable.make(self.schema, idFactory)
    sources = afwTable.SourceCatalog(table)
    sources.extend(inputData.pop('intakeCatalog'), self.schemaMapper)
    # Capture algorithm metadata to write out to the source catalog.
    sources.getTable().setMetadata(self.algMetadata)
    inputData['sources'] = sources

    # NOTE(review): the exposure id is hard-coded to 0 here -- confirm this is intended.
    inputData['exposureId'] = 0

    skyMap = inputData.pop('skyMap')
    catalogDataId = inputDataIds['intakeCatalog']
    tractInfo = skyMap[catalogDataId['tract']]
    patchInfo = tractInfo.getPatchInfo(catalogDataId['patch'])
    inputData['skyInfo'] = Struct(
        skyMap=skyMap,
        tractInfo=tractInfo,
        patchInfo=patchInfo,
        wcs=tractInfo.getWcs(),
        bbox=patchInfo.getOuterBBox()
    )

    if self.config.doPropagateFlags:
        # Filter out any visit catalog that is not coadd inputs
        ccdInputs = inputData['exposure'].getInfo().getCoaddInputs().ccds
        visitKey = ccdInputs.schema.find("visit").key
        ccdKey = ccdInputs.schema.find("ccd").key
        # Map (visit, ccd) -> Wcs for every ccd that went into the coadd
        ccdRecordsWcs = {}
        for ccdRecord in ccdInputs:
            ccdRecordsWcs[(ccdRecord.get(visitKey), ccdRecord.get(ccdKey))] = ccdRecord.getWcs()

        inputCatalogsToKeep = []
        inputCatalogWcsUpdate = []
        for dataId, catalog in zip(inputDataIds['visitCatalogs'], inputData['visitCatalogs']):
            key = (dataId['visit'], dataId['detector'])
            if key in ccdRecordsWcs:
                inputCatalogsToKeep.append(catalog)
                inputCatalogWcsUpdate.append(ccdRecordsWcs[key])
        inputData['visitCatalogs'] = inputCatalogsToKeep
        inputData['wcsUpdates'] = inputCatalogWcsUpdate
        inputData['ccdInputs'] = ccdInputs

    return self.run(**inputData)

def runDataRef(self, patchRef, psfCache=100):
"""!
@brief Deblend and measure.
Expand All @@ -834,8 +994,57 @@ def runDataRef(self, patchRef, psfCache=100):
sources = self.readSources(patchRef)
table = sources.getTable()
table.setMetadata(self.algMetadata) # Capture algorithm metadata to write out to the source catalog.
skyInfo = getSkyInfo(coaddName=self.config.coaddName, patchRef=patchRef)

results = self.run(exposure=exposure, sources=sources,
ccdInputs=self.propagateFlags.getCcdInputs(exposure),
skyInfo=skyInfo, butler=patchRef.getButler(),
exposureId=self.getExposureId(patchRef))

if self.config.doMatchSources:
self.writeMatches(patchRef, results)
self.write(patchRef, results.outputSources)

def run(self, exposure, sources, skyInfo, exposureId, ccdInputs=None, visitCatalogs=None, wcsUpdates=None,
butler=None):
"""Run measurement algorithms on the input exposure, and optionally populate the
resulting catalog with extra information.
self.measurement.run(sources, exposure, exposureId=self.getExposureId(patchRef))
Parameters
----------
exposure : `lsst.afw.exposure.Exposure`
The input exposure on which measurements are to be performed
sources : `lsst.afw.table.SourceCatalog`
A catalog built from the results of merged detections, or
deblender outputs.
skyInfo : `lsst.pipe.base.Struct`
A struct containing information about the position of the input exposure within
a `SkyMap`, the `SkyMap`, its `Wcs`, and its bounding box
exposureId : `int` or `bytes`
packed unique number or bytes unique to the input exposure
ccdInputs : `lsst.afw.table.ExposureCatalog`
Catalog containing information on the individual visits which went into making
the exposure
visitCatalogs : list of `lsst.afw.table.SourceCatalogs` or `None`
A list of source catalogs corresponding to measurements made on the individual
visits which went into the input exposure. If None and butler is `None` then
the task cannot propagate visit flags to the output catalog.
wcsUpdates : list of `lsst.afw.geom.SkyWcs` or `None`
If visitCatalogs is not `None` this should be a list of wcs objects which correspond
to the input visits. Used to put all coordinates to common system. If `None` and
butler is `None` then the task cannot propagate visit flags to the output catalog.
butler : `lsst.daf.butler.Butler` or `lsst.daf.persistence.Butler`
Either a gen2 or gen3 butler used to load visit catalogs
Returns
-------
results : `lsst.pipe.base.Struct`
Results of running measurement task. Will contain the catalog in the
sources attribute. Optionally will have results of matching to a
reference catalog in the matchResults attribute, and denormalized
matches in the denormMatches attribute.
"""
self.measurement.run(sources, exposure, exposureId=exposureId)

if self.config.doApCorr:
self.applyApCorr.run(
Expand All @@ -853,15 +1062,24 @@ def runDataRef(self, patchRef, psfCache=100):
if self.config.doRunCatalogCalculation:
self.catalogCalculation.run(sources)

skyInfo = getSkyInfo(coaddName=self.config.coaddName, patchRef=patchRef)
self.setPrimaryFlags.run(sources, skyInfo.skyMap, skyInfo.tractInfo, skyInfo.patchInfo,
includeDeblend=self.deblended)
if self.config.doPropagateFlags:
self.propagateFlags.run(patchRef.getButler(), sources, self.propagateFlags.getCcdInputs(exposure),
exposure.getWcs())
self.propagateFlags.run(butler, sources, ccdInputs, exposure.getWcs(), visitCatalogs, wcsUpdates)

results = Struct()

if self.config.doMatchSources:
self.writeMatches(patchRef, exposure, sources)
self.write(patchRef, sources)
matchResult = self.match.run(sources, exposure.getInfo().getFilter().getName())
matches = afwTable.packMatches(matchResult.matches)
matches.table.setMetadata(matchResult.matchMeta)
results.matchResult = matches
if self.config.doWriteMatchesDenormalized:
results.denormMatches = denormalizeMatches(matchResult.matches,
matchResult.matchMeta)

results.outputSources = sources
return results

def readSources(self, dataRef):
"""!
Expand All @@ -883,24 +1101,17 @@ def readSources(self, dataRef):
sources.extend(merged, self.schemaMapper)
return sources

def writeMatches(self, dataRef, exposure, sources):
def writeMatches(self, dataRef, results):
"""!
@brief Write matches of the sources to the astrometric reference catalog.
We use the Wcs in the exposure to match sources.
@param[in] dataRef: data reference
@param[in] exposure: exposure with Wcs
@param[in] sources: source catalog
@param[in] results: results struct from run method
"""
result = self.match.run(sources, exposure.getInfo().getFilter().getName())
if result.matches:
matches = afwTable.packMatches(result.matches)
matches.table.setMetadata(result.matchMeta)
dataRef.put(matches, self.config.coaddName + "Coadd_measMatch")
if self.config.doWriteMatchesDenormalized:
denormMatches = denormalizeMatches(result.matches, result.matchMeta)
dataRef.put(denormMatches, self.config.coaddName + "Coadd_measMatchFull")
if hasattr(results, "matchResult"):
dataRef.put(results.matchResult, self.config.coaddName + "Coadd_measMatch")
if hasattr(results, "denormMatches"):
dataRef.put(results.denormMatches, self.config.coaddName + "Coadd_measMatchFull")

def write(self, dataRef, sources):
"""!
Expand Down

0 comments on commit 9fc727f

Please sign in to comment.