Skip to content

Commit

Permalink
Allow exposures with fakes in to be used in future processing steps
Browse files Browse the repository at this point in the history
  • Loading branch information
sr525 committed Apr 23, 2019
1 parent 9730928 commit 479bbc8
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 19 deletions.
21 changes: 17 additions & 4 deletions python/lsst/pipe/drivers/coaddDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,17 @@ class CoaddDriverConfig(Config):
doc="Run detection on the coaddition product")
detectCoaddSources = ConfigurableField(
target=DetectCoaddSourcesTask, doc="Detect sources on coadd")
hasFakes = Field(dtype=bool, default=False,
doc="Should be set to True if fake sources were added to the data before processing.")

def setDefaults(self):
self.makeCoaddTempExp.select.retarget(NullSelectImagesTask)
self.assembleCoadd.select.retarget(NullSelectImagesTask)
self.assembleCoadd.doWrite = False
if self.hasFakes:
self.detectCoaddSources.hasFakes = True
self.makeCoaddTempExp.hasFakes = True
self.assembleCoadd.hasFakes = True

def validate(self):
if self.makeCoaddTempExp.coaddName != self.coaddName:
Expand Down Expand Up @@ -94,6 +100,11 @@ def __init__(self, reuse=tuple(), **kwargs):
self.makeSubtask("backgroundReference")
self.makeSubtask("assembleCoadd")
self.makeSubtask("detectCoaddSources")
if self.config.hasFakes:
self.calexpType = "fakes_calexp"
else:
self.calexpType = "calexp"


def __reduce__(self):
"""Pickler"""
Expand Down Expand Up @@ -157,10 +168,9 @@ def runDataRef(self, tractPatchRefList, butler, selectIdList=[]):
self.log.info("Non-empty tracts (%d): %s" % (len(tractPatchRefList),
[patchRefList[0].dataId["tract"] for patchRefList in
tractPatchRefList]))

# Install the dataRef in the selectDataList
for data in selectDataList:
data.dataRef = getDataRef(butler, data.dataId, "calexp")
data.dataRef = getDataRef(butler, data.dataId, self.calexpType)

# Process the non-empty tracts
return [self.run(patchRefList, butler, selectDataList) for patchRefList in tractPatchRefList]
Expand Down Expand Up @@ -206,7 +216,7 @@ def readSelection(self, cache, selectId):
@return a SelectStruct with a dataId instead of dataRef
"""
try:
ref = getDataRef(cache.butler, selectId, "calexp")
ref = getDataRef(cache.butler, selectId, self.calexpType)
self.log.info("Reading Wcs from %s" % (selectId,))
md = ref.get("calexp_md", immediate=True)
wcs = afwGeom.makeSkyWcs(md)
Expand Down Expand Up @@ -322,7 +332,10 @@ def coadd(self, cache, data):
detResults = self.detectCoaddSources.run(coadd, idFactory, expId=expId)
self.detectCoaddSources.write(detResults, patchRef)
else:
patchRef.put(coadd, self.assembleCoadd.config.coaddName+"Coadd")
if self.config.hasFakes:
patchRef.put(coadd, "fakes_" + self.assembleCoadd.config.coaddName + "Coadd")
else:
patchRef.put(coadd, self.assembleCoadd.config.coaddName + "Coadd")

def selectExposures(self, patchRef, selectDataList):
"""!Select exposures to operate upon, via the SelectImagesTask
Expand Down
41 changes: 28 additions & 13 deletions python/lsst/pipe/drivers/multiBandDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,25 @@ class MultiBandDriverConfig(Config):
"if we consider those footprints important to recover."),
)

hasFakes = Field(
dtype=bool,
default=False,
doc="Should be set to True if fakes were inserted into the data being processed."
)

def setDefaults(self):
Config.setDefaults(self)
self.forcedPhotCoadd.references.retarget(MultiBandReferencesTask)

if self.hasFakes:
self.detectCoaddSources.hasFakes = True
self.deblendCoaddSources.hasFakes = True
self.measureCoaddSources.hasFakes = True
self.forcedPhotCoadd.hasFakes = True


def validate(self):

for subtask in ("mergeCoaddDetections", "deblendCoaddSources", "measureCoaddSources",
"mergeCoaddMeasurements", "forcedPhotCoadd"):
coaddName = getattr(self, subtask).coaddName
Expand Down Expand Up @@ -142,6 +156,10 @@ def __init__(self, butler=None, schema=None, refObjLoader=None, reuse=tuple(), *
self.measureCoaddSources.schema))
self.makeSubtask("forcedPhotCoadd", refSchema=afwTable.Schema(
self.mergeCoaddMeasurements.schema))
if self.config.hasFakes:
self.coaddType = "fakes_" + self.config.coaddName
else:
self.coaddType = self.config.coaddName

def __reduce__(self):
"""Pickler"""
Expand Down Expand Up @@ -193,7 +211,6 @@ def runDataRef(self, patchRefList):
pool = Pool("all")
pool.cacheClear()
pool.storeSet(butler=butler)

# MultiBand measurements require that the detection stage be completed
# before measurements can be made.
#
Expand All @@ -211,11 +228,11 @@ def runDataRef(self, patchRefList):
detectionList = []
for patchRef in patchRefList:
if ("detectCoaddSources" in self.reuse and
patchRef.datasetExists(self.config.coaddName + "Coadd_calexp", write=True)):
patchRef.datasetExists(self.coaddType + "Coadd_calexp", write=True)):
self.log.info("Skipping detectCoaddSources for %s; output already exists." %
patchRef.dataId)
continue
if not patchRef.datasetExists(self.config.coaddName + "Coadd"):
if not patchRef.datasetExists(self.coaddType + "Coadd"):
self.log.debug("Not processing %s; required input %sCoadd missing." %
(patchRef.dataId, self.config.coaddName))
continue
Expand All @@ -224,7 +241,7 @@ def runDataRef(self, patchRefList):
pool.map(self.runDetection, detectionList)

patchRefList = [patchRef for patchRef in patchRefList if
patchRef.datasetExists(self.config.coaddName + "Coadd_calexp") and
patchRef.datasetExists(self.coaddType + "Coadd_calexp") and
patchRef.datasetExists(self.config.coaddName + "Coadd_det",
write=self.config.doDetection)]
dataIdList = [patchRef.dataId for patchRef in patchRefList]
Expand Down Expand Up @@ -292,7 +309,7 @@ def runDataRef(self, patchRefList):
# and we're starting over
patchReprocessing[patchId] = True

# Only process patches that have been identified as needing it
# Only process patches that have been identifiedz as needing it
pool.map(self.runMeasurements, [dataId1 for dataId1 in dataIdList if not self.config.reprocessing or
patchReprocessing[dataId1["patch"]]])
pool.map(self.runMergeMeasurements, [idList for patchId, idList in patches.items() if
Expand Down Expand Up @@ -320,8 +337,7 @@ def runDetection(self, cache, patchRef):
"""
with self.logOperation("do detections on {}".format(patchRef.dataId)):
idFactory = self.detectCoaddSources.makeIdFactory(patchRef)
coadd = patchRef.get(self.config.coaddName + "Coadd",
immediate=True)
coadd = patchRef.get(self.coaddType + "Coadd", immediate=True)
expId = int(patchRef.get(self.config.coaddName + "CoaddId"))
self.detectCoaddSources.emptyMetadata()
detResults = self.detectCoaddSources.run(coadd, idFactory, expId=expId)
Expand All @@ -337,7 +353,7 @@ def runMergeDetections(self, cache, dataIdList):
@param dataIdList: List of data identifiers for the patch in different filters
"""
with self.logOperation("merge detections from %s" % (dataIdList,)):
dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
dataRefList = [getDataRef(cache.butler, dataId, self.coaddType + "Coadd_calexp") for
dataId in dataIdList]
if ("mergeCoaddDetections" in self.reuse and
dataRefList[0].datasetExists(self.config.coaddName + "Coadd_mergeDet", write=True)):
Expand All @@ -364,7 +380,7 @@ def runDeblendMerged(self, cache, dataIdList):
whether the patch requires reprocessing.
"""
with self.logOperation("deblending %s" % (dataIdList,)):
dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
dataRefList = [getDataRef(cache.butler, dataId, self.coaddType + "Coadd_calexp") for
dataId in dataIdList]
reprocessing = False # Does this patch require reprocessing?
if ("deblendCoaddSources" in self.reuse and
Expand Down Expand Up @@ -418,8 +434,7 @@ def runMeasurements(self, cache, dataId):
Data identifier for patch
"""
with self.logOperation("measurements on %s" % (dataId,)):
dataRef = getDataRef(cache.butler, dataId,
self.config.coaddName + "Coadd_calexp")
dataRef = getDataRef(cache.butler, dataId, self.coaddType + "Coadd_calexp")
if ("measureCoaddSources" in self.reuse and
not self.config.reprocessing and
dataRef.datasetExists(self.config.coaddName + "Coadd_meas", write=True)):
Expand All @@ -436,7 +451,7 @@ def runMergeMeasurements(self, cache, dataIdList):
@param dataIdList: List of data identifiers for the patch in different filters
"""
with self.logOperation("merge measurements from %s" % (dataIdList,)):
dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
dataRefList = [getDataRef(cache.butler, dataId, self.coaddType + "Coadd_calexp") for
dataId in dataIdList]
if ("mergeCoaddMeasurements" in self.reuse and
not self.config.reprocessing and
Expand All @@ -456,7 +471,7 @@ def runForcedPhot(self, cache, dataId):
"""
with self.logOperation("forced photometry on %s" % (dataId,)):
dataRef = getDataRef(cache.butler, dataId,
self.config.coaddName + "Coadd_calexp")
self.coaddType + "Coadd_calexp")
if ("forcedPhotCoadd" in self.reuse and
not self.config.reprocessing and
dataRef.datasetExists(self.config.coaddName + "Coadd_forced_src", write=True)):
Expand Down
11 changes: 9 additions & 2 deletions python/lsst/pipe/drivers/skyCorrection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class SkyCorrectionConfig(Config):
doBgModel2 = Field(dtype=bool, default=True, doc="Do cleanup background model subtraction?")
doSky = Field(dtype=bool, default=True, doc="Do sky frame subtraction?")
binning = Field(dtype=int, default=8, doc="Binning factor for constructing focal-plane images")
hasFakes = Field(dtype=bool, default=False,
doc="Should be set to True if fake sources were added to the data before processing.")

def setDefaults(self):
Config.setDefaults(self)
Expand All @@ -65,6 +67,11 @@ def __init__(self, *args, **kwargs):
self.makeSubtask("maskObjects")
self.makeSubtask("sky")

if self.config.hasFakes:
self.calexpType = "fakes_calexp"
else:
self.calexpType = "calexp"

@classmethod
def _makeArgumentParser(cls, *args, **kwargs):
kwargs.pop("doBatch", False)
Expand Down Expand Up @@ -122,7 +129,7 @@ def runDataRef(self, expRef):
camera = expRef.get("camera")

dataIdList = [ccdRef.dataId for ccdRef in expRef.subItems("ccd") if
ccdRef.datasetExists("calexp")]
ccdRef.datasetExists(self.calexpType)]

exposures = pool.map(self.loadImage, dataIdList)
if DEBUG:
Expand Down Expand Up @@ -201,7 +208,7 @@ def loadImage(self, cache, dataId):
Resultant exposure.
"""
cache.dataId = dataId
cache.exposure = cache.butler.get("calexp", dataId, immediate=True).clone()
cache.exposure = cache.butler.get(self.calexpType, dataId, immediate=True).clone()
bgOld = cache.butler.get("calexpBackground", dataId, immediate=True)
image = cache.exposure.getMaskedImage()

Expand Down

0 comments on commit 479bbc8

Please sign in to comment.