Skip to content

Commit

Permalink
Refactor so that N masks are never stored in memory
Browse files Browse the repository at this point in the history
  • Loading branch information
yalsayyad committed Jan 26, 2018
1 parent 349b62b commit d3ef8f0
Showing 1 changed file with 77 additions and 77 deletions.
154 changes: 77 additions & 77 deletions python/lsst/pipe/tasks/assembleCoadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def assembleSubregion(self, coaddExposure, bbox, tempExpRefList, imageScalerList
\param[in] imageScalerList: List of image scalers
\param[in] weightList: List of weights
\param[in] altMaskList: List of alternate masks to use rather than those stored with tempExp, or None
Each element is dict with keys = mask plane name to which to add the spans
\param[in] statsFlags: afwMath.Property object for statistic for coadd
\param[in] statsCtrl: Statistics control object for coadd
\param[in] nImage: optional ImageU keeps track of exposure count for each pixel
Expand All @@ -601,9 +602,9 @@ def assembleSubregion(self, coaddExposure, bbox, tempExpRefList, imageScalerList
for tempExpRef, imageScaler, altMask in zip(tempExpRefList, imageScalerList, altMaskList):
exposure = tempExpRef.get(tempExpName + "_sub", bbox=bbox)
maskedImage = exposure.getMaskedImage()
if altMask:
altMaskSub = altMask.Factory(altMask, bbox, afwImage.PARENT)
maskedImage.getMask().swap(altMaskSub)
mask = maskedImage.getMask()
if altMask is not None:
self.applyAltMaskPlanes(mask, altMask)
imageScaler.scaleMaskedImage(maskedImage)

# Add 1 for each pixel which is not excluded by the exclude mask.
Expand All @@ -628,6 +629,22 @@ def assembleSubregion(self, coaddExposure, bbox, tempExpRefList, imageScalerList
if nImage is not None:
nImage.assign(subNImage, bbox)

def applyAltMaskPlanes(self, mask, altMaskSpans):
"""!
\brief Apply in place alt mask formatted as SpanSets to a mask
@param mask: original mask
@param altMaskSpans: Dictionary containing spanSet lists to apply.
Each element contains the new mask plane name
(e.g. "CLIPPED and/or "NO_DATA") as the key,
and list of SpanSets to apply to the mask
"""
for plane, spanSetList in altMaskSpans.items():
maskClipValue = mask.addMaskPlane(plane)
for spanSet in spanSetList:
spanSet.clippedTo(mask.getBBox()).setMask(mask, 2**maskClipValue)
return mask

def readBrightObjectMasks(self, dataRef):
"""Returns None on failure"""
try:
Expand Down Expand Up @@ -1055,12 +1072,12 @@ def assemble(self, skyInfo, tempExpRefList, imageScalerList, weightList, *args,

self.log.info('Found %d clipped objects', len(result.clipFootprints))

# Go to individual visits for big footprints
maskClipValue = mask.getPlaneBitMask("CLIPPED")
maskDetValue = mask.getPlaneBitMask("DETECTED") | mask.getPlaneBitMask("DETECTED_NEGATIVE")
bigFootprints = self.detectClipBig(result.tempExpClipList, result.clipFootprints, result.clipIndices,
maskClipValue, maskDetValue)

# Append big footprints from individual Warps to result.clipSpans
bigFootprints = self.detectClipBig(result.clipSpans, result.clipFootprints, result.clipIndices,
result.detectionFootprints, maskClipValue, maskDetValue,
exp.getBBox())
# Create mask of the current clipped footprints
maskClip = mask.Factory(mask.getBBox(afwImage.PARENT))
afwDet.setMaskFromFootprintList(maskClip, result.clipFootprints, maskClipValue)
Expand All @@ -1074,7 +1091,7 @@ def assemble(self, skyInfo, tempExpRefList, imageScalerList, weightList, *args,
badMaskPlanes.append("CLIPPED")
badPixelMask = afwImage.Mask.getPlaneBitMask(badMaskPlanes)
return AssembleCoaddTask.assemble(self, skyInfo, tempExpRefList, imageScalerList, weightList,
result.tempExpClipList, mask=badPixelMask)
result.clipSpans, mask=badPixelMask)

def buildDifferenceImage(self, skyInfo, tempExpRefList, imageScalerList, weightList):
"""!
Expand Down Expand Up @@ -1129,12 +1146,13 @@ def detectClip(self, exp, tempExpRefList):
\param[in] exp: Exposure to run detection on
\param[in] tempExpRefList: List of data reference to tempExp
\return struct containing:
- clippedFootprints: list of clipped footprints
- clippedIndices: indices for each clippedFootprint in tempExpRefList
- tempExpClipList: list of new masks for tempExp
- clipFootprints: list of clipped footprints
- clipIndices: indices for each clippedFootprint in tempExpRefList
- clipSpans: List of dictionaries containing spanSet lists to clip. Each element contains the new
maskplane name ("CLIPPED")" as the key and list of SpanSets as value
- detectionFootprints: List of DETECTED/DETECTED_NEGATIVE plane compressed into footprints
"""
mask = exp.getMaskedImage().getMask()
maskClipValue = mask.getPlaneBitMask("CLIPPED")
maskDetValue = mask.getPlaneBitMask("DETECTED") | mask.getPlaneBitMask("DETECTED_NEGATIVE")
fpSet = self.clipDetection.detectFootprints(exp, doSmooth=True, clearMask=True)
# Merge positive and negative together footprints together
Expand All @@ -1145,28 +1163,43 @@ def detectClip(self, exp, tempExpRefList):

clipFootprints = []
clipIndices = []
artifactSpanSets = [{'CLIPPED': list()} for _ in tempExpRefList]

# for use by detectClipBig
visitDetectionFootprints = []

dims = [len(tempExpRefList), len(footprints.getFootprints())]
overlapDetArr = numpy.zeros(dims, dtype=numpy.uint16)
ignoreArr = numpy.zeros(dims, dtype=numpy.uint16)

# Loop over masks once and extract/store only relevant overlap metrics and detection footprints
for i, warpRef in enumerate(tempExpRefList):
tmpExpMask = warpRef.get(self.getTempExpDatasetName(self.warpType),
immediate=True).getMaskedImage().getMask()
maskVisitDet = tmpExpMask.Factory(tmpExpMask, tmpExpMask.getBBox(afwImage.PARENT),
afwImage.PARENT, True)
maskVisitDet &= maskDetValue
visitFootprints = afwDet.FootprintSet(maskVisitDet, afwDet.Threshold(1))
visitDetectionFootprints.append(visitFootprints)

# build a list with a mask for each visit which can be modified with clipping information
tempExpClipList = [tmpExpRef.get(self.getTempExpDatasetName(self.warpType),
immediate=True).getMaskedImage().getMask() for
tmpExpRef in tempExpRefList]
for j, footprint in enumerate(footprints.getFootprints()):
ignoreArr[i, j] = countMaskFromFootprint(tmpExpMask, footprint, ignoreMask, 0x0)
overlapDetArr[i, j] = countMaskFromFootprint(tmpExpMask, footprint, maskDetValue, ignoreMask)

for footprint in footprints.getFootprints():
# build a list of clipped spans for each visit
for j, footprint in enumerate(footprints.getFootprints()):
nPixel = footprint.getArea()
overlap = [] # hold the overlap with each visit
maskList = [] # which visit mask match
indexList = [] # index of visit in global list
for i, tmpExpMask in enumerate(tempExpClipList):
# Determine the overlap with the footprint
ignore = countMaskFromFootprint(tmpExpMask, footprint, ignoreMask, 0x0)
overlapDet = countMaskFromFootprint(tmpExpMask, footprint, maskDetValue, ignoreMask)
for i in range(len(tempExpRefList)):
ignore = ignoreArr[i, j]
overlapDet = overlapDetArr[i, j]
totPixel = nPixel - ignore

# If we have more bad pixels than detection skip
if ignore > overlapDet or totPixel <= 0.5*nPixel or overlapDet == 0:
continue
overlap.append(overlapDet/float(totPixel))
maskList.append(tmpExpMask)
indexList.append(i)

overlap = numpy.array(overlap)
Expand Down Expand Up @@ -1200,38 +1233,37 @@ def detectClip(self, exp, tempExpRefList):
continue

for index in keepIndex:
footprint.spans.setMask(maskList[index], maskClipValue)
globalIndex = indexList[index]
artifactSpanSets[globalIndex]['CLIPPED'].append(footprint.spans)

clipIndices.append(numpy.array(indexList)[keepIndex])
clipFootprints.append(footprint)

return pipeBase.Struct(clipFootprints=clipFootprints, clipIndices=clipIndices,
tempExpClipList=tempExpClipList)
clipSpans=artifactSpanSets, detectionFootprints=visitDetectionFootprints)

def detectClipBig(self, tempExpClipList, clipFootprints, clipIndices, maskClipValue, maskDetValue):
def detectClipBig(self, clipList, clipFootprints, clipIndices, detectionFootprints,
maskClipValue, maskDetValue, coaddBBox):
"""!
\brief Find footprints from individual tempExp footprints for large footprints.
\brief Return individual warp footprints for large artifacts and append them to clipList in place
Identify big footprints composed of many sources in the coadd difference that may have originated in a
large diffuse source in the coadd. We do this by indentifying all clipped footprints that overlap
significantly with each source in all the coaddTempExps.
\param[in] tempExpClipList: List of tempExp masks with clipping information
\param[in] clipList: List of alt mask SpanSets with clipping information. Modified.
\param[in] clipFootprints: List of clipped footprints
\param[in] clipIndices: List of which entries in tempExpClipList each footprint belongs to
\param[in] maskClipValue: Mask value of clipped pixels
\param[in] maskClipValue: Mask value of detected pixels
\param[in] maskDetValue: Mask value of detected pixels
\param[in] coaddBBox: BBox of the coadd and warps
\return list of big footprints
"""
bigFootprintsCoadd = []
ignoreMask = self.getBadPixelMask()
for index, tmpExpMask in enumerate(tempExpClipList):

# Create list of footprints from the DETECTED pixels
maskVisitDet = tmpExpMask.Factory(tmpExpMask, tmpExpMask.getBBox(afwImage.PARENT),
afwImage.PARENT, True)
maskVisitDet &= maskDetValue
visitFootprints = afwDet.FootprintSet(maskVisitDet, afwDet.Threshold(1))
for index, (clippedSpans, visitFootprints) in enumerate(zip(clipList, detectionFootprints)):
maskVisitDet = afwImage.MaskX(coaddBBox, 0x0)
for footprint in visitFootprints.getFootprints():
footprint.spans.setMask(maskVisitDet, maskDetValue)

# build a mask of clipped footprints that are in this visit
clippedFootprintsVisit = []
Expand All @@ -1251,10 +1283,8 @@ def detectClipBig(self, tempExpClipList, clipFootprints, clipIndices, maskClipVa
bigFootprintsVisit.append(foot)
bigFootprintsCoadd.append(foot)

# Update single visit masks
maskVisitClip.clearAllMaskPlanes()
afwDet.setMaskFromFootprintList(maskVisitClip, bigFootprintsVisit, maskClipValue)
tmpExpMask |= maskVisitClip
for footprint in bigFootprintsVisit:
clippedSpans["CLIPPED"].append(footprint.spans)

return bigFootprintsCoadd

Expand Down Expand Up @@ -1384,8 +1414,6 @@ class CompareWarpAssembleCoaddTask(AssembleCoaddTask):
<dl>
<dt>`saveCountIm`
<dd> If True then save the Epoch Count Image as a fits file in the `figPath`
<dt> `saveAltMask`
<dd> If True then save the new masks with CLIPPED planes as fits files to the `figPath`
<dt> `figPath`
<dd> Path to save the debug fits images and figures
</dl>
Expand All @@ -1397,7 +1425,6 @@ def DebugInfo(name):
di = lsstDebug.getInfo(name)
if name == "lsst.pipe.tasks.assembleCoadd":
di.saveCountIm = True
di.saveAltMask = True
di.figPath = "/desired/path/to/debugging/output/images"
return di
lsstDebug.Info = DebugInfo
Expand Down Expand Up @@ -1515,13 +1542,12 @@ def assemble(self, skyInfo, tempExpRefList, imageScalerList, weightList,
"""
templateCoadd = supplementaryData.templateCoadd
spanSetMaskList = self.findArtifacts(templateCoadd, tempExpRefList, imageScalerList)
maskList = self.computeAltMaskList(tempExpRefList, spanSetMaskList)
badMaskPlanes = self.config.badMaskPlanes[:]
badMaskPlanes.append("CLIPPED")
badPixelMask = afwImage.Mask.getPlaneBitMask(badMaskPlanes)

return AssembleCoaddTask.assemble(self, skyInfo, tempExpRefList, imageScalerList, weightList,
maskList, mask=badPixelMask)
spanSetMaskList, mask=badPixelMask)

def findArtifacts(self, templateCoadd, tempExpRefList, imageScalerList):
"""!
Expand Down Expand Up @@ -1587,37 +1613,11 @@ def findArtifacts(self, templateCoadd, tempExpRefList, imageScalerList):
filteredSpanSetList = self._filterArtifacts(spanSetList, epochCountImage, nImage)
spanSetArtifactList[i] = filteredSpanSetList

return pipeBase.Struct(artifacts=spanSetArtifactList,
noData=spanSetNoDataMaskList)

def computeAltMaskList(self, tempExpRefList, maskSpanSets):
"""!
\brief Apply artifact span set lists to masks
@param tempExpRefList: List of data references to warps
@param maskSpanSets: Struct containing artifact and noData spanSet lists to apply
return List of alternative masks
Add artifact span set list as "CLIPPED" plane and NaNs to existing "NO_DATA" plane
"""
spanSetMaskList = maskSpanSets.artifacts
spanSetNoDataList = maskSpanSets.noData
altMaskList = []
for warpRef, artifacts, noData in zip(tempExpRefList, spanSetMaskList, spanSetNoDataList):
warp = warpRef.get(self.getTempExpDatasetName(self.config.warpType), immediate=True)
mask = warp.maskedImage.mask
maskClipValue = mask.addMaskPlane("CLIPPED")
noDataValue = mask.addMaskPlane("NO_DATA")
for artifact in artifacts:
artifact.clippedTo(mask.getBBox()).setMask(mask, 2**maskClipValue)
for noDataRegion in noData:
noDataRegion.clippedTo(mask.getBBox()).setMask(mask, 2**noDataValue)
altMaskList.append(mask)
if lsstDebug.Info(__name__).saveAltMask:
mask.writeFits(self._dataRef2DebugPath("altMask", warpRef))

return altMaskList
altMasks = []
for artifacts, noData in zip(spanSetArtifactList, spanSetNoDataMaskList):
altMasks.append({'CLIPPED': artifacts,
'NO_DATA': noData})
return altMasks

def _filterArtifacts(self, spanSetList, epochCountImage, nImage):
"""!
Expand Down

0 comments on commit d3ef8f0

Please sign in to comment.