Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-22521: enable partial image reads in cp_pipe combine to avoid memory issues #126

Merged
merged 4 commits into from
Apr 19, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion pipelines/cpBias.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ tasks:
cpBiasCombine:
class: lsst.cp.pipe.cpCombine.CalibCombineTask
config:
connections.inputExps: 'cpBiasProc'
connections.inputExpHandles: 'cpBiasProc'
connections.outputData: 'bias'
calibrationType: 'bias'
exposureScaling: "Unity"
Expand Down
2 changes: 1 addition & 1 deletion pipelines/cpDark.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ tasks:
cpDarkCombine:
class: lsst.cp.pipe.cpCombine.CalibCombineTask
config:
connections.inputExps: 'cpDarkProc'
connections.inputExpHandles: 'cpDarkProc'
connections.outputData: 'dark'
calibrationType: 'dark'
exposureScaling: "DarkTime"
Expand Down
2 changes: 1 addition & 1 deletion pipelines/cpFlat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ tasks:
cpFlatCombine:
class: lsst.cp.pipe.cpCombine.CalibCombineByFilterTask
config:
connections.inputExps: 'cpFlatProc'
connections.inputExpHandles: 'cpFlatProc'
connections.inputScales: 'cpFlatNormScales'
connections.outputData: 'flat'
calibrationType: 'flat'
Expand Down
2 changes: 1 addition & 1 deletion pipelines/cpFlatSingleChip.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ tasks:
cpFlatCombine:
class: lsst.cp.pipe.cpCombine.CalibCombineByFilterTask
config:
connections.inputExps: 'cpFlatProc'
connections.inputExpHandles: 'cpFlatProc'
connections.outputData: 'flat'
calibrationType: 'flat'
exposureScaling: MeanStats
Expand Down
2 changes: 1 addition & 1 deletion pipelines/cpFringe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ tasks:
cpFringeCombine:
class: lsst.cp.pipe.cpCombine.CalibCombineByFilterTask
config:
connections.inputExps: 'cpFringeProc'
connections.inputExpHandles: 'cpFringeProc'
connections.outputData: 'fringe'
calibrationType: 'fringe'
exposureScaling: "Unity"
Expand Down
143 changes: 94 additions & 49 deletions python/lsst/cp/pipe/cpCombine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
import time

import lsst.geom as geom
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.base.connectionTypes as cT
Expand Down Expand Up @@ -105,12 +106,13 @@ def run(self, exposureOrImage):

class CalibCombineConnections(pipeBase.PipelineTaskConnections,
dimensions=("instrument", "detector")):
inputExps = cT.Input(
inputExpHandles = cT.Input(
name="cpInputs",
doc="Input pre-processed exposures to combine.",
storageClass="Exposure",
dimensions=("instrument", "detector", "exposure"),
multiple=True,
deferLoad=True,
)
inputScales = cT.Input(
name="cpScales",
Expand Down Expand Up @@ -173,6 +175,13 @@ class CalibCombineConfig(pipeBase.PipelineTaskConfig,
default=5,
doc="Maximum number of visits to estimate variance from input variance, not per-pixel spread",
)
subregionSize = pexConfig.ListField(
dtype=int,
doc="Width, height of subregion size.",
length=2,
# This is 200 rows for all detectors smaller than 10k in width.
default=(10000, 200),
)

doVignette = pexConfig.Field(
dtype=bool,
Expand Down Expand Up @@ -220,19 +229,19 @@ def __init__(self, **kwargs):
def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)

dimensions = [exp.dataId.byName() for exp in inputRefs.inputExps]
dimensions = [expHandle.dataId.byName() for expHandle in inputRefs.inputExpHandles]
inputs['inputDims'] = dimensions

outputs = self.run(**inputs)
butlerQC.put(outputs, outputRefs)

def run(self, inputExps, inputScales=None, inputDims=None):
def run(self, inputExpHandles, inputScales=None, inputDims=None):
"""Combine calib exposures for a single detector.

Parameters
----------
inputExps : `list` [`lsst.afw.image.Exposure`]
Input list of exposures to combine.
inputExpHandles : `list` [`lsst.daf.butler.DeferredDatasetHandle`]
Input list of exposure handles to combine.
inputScales : `dict` [`dict` [`dict` [`float`]]], optional
Dictionary of scales, indexed by detector (`int`),
amplifier (`int`), and exposure (`int`). Used for
Expand Down Expand Up @@ -262,29 +271,16 @@ def run(self, inputExps, inputScales=None, inputDims=None):
config.exposureScaling == InputList, and a necessary scale
was not found.
"""
width, height = self.getDimensions(inputExps)
width, height = self.getDimensions(inputExpHandles)
stats = afwMath.StatisticsControl(self.config.clip, self.config.nIter,
afwImage.Mask.getPlaneBitMask(self.config.mask))
numExps = len(inputExps)
numExps = len(inputExpHandles)
if numExps < 1:
raise RuntimeError("No valid input data")
if numExps < self.config.maxVisitsToCalcErrorFromInputVariance:
stats.setCalcErrorFromInputVariance(True)

# Check that all inputs either share the same detector (based
# on detId), or that no inputs have any detector.
detectorList = [exp.getDetector() for exp in inputExps]
if None in detectorList:
self.log.warning("Not all input detectors defined.")
detectorIds = [det.getId() if det is not None else None for det in detectorList]
detectorSerials = [det.getId() if det is not None else None for det in detectorList]
numDetectorIds = len(set(detectorIds))
numDetectorSerials = len(set(detectorSerials))
numDetectors = len(set([numDetectorIds, numDetectorSerials]))
if numDetectors != 1:
raise RuntimeError("Input data contains multiple detectors.")

inputDetector = inputExps[0].getDetector()
inputDetector = inputExpHandles[0].get(component='detector')

# Create output exposure for combined data.
combined = afwImage.MaskedImageF(width, height)
Expand All @@ -293,20 +289,20 @@ def run(self, inputExps, inputScales=None, inputDims=None):
# Apply scaling:
expScales = []
if inputDims is None:
inputDims = [dict() for i in inputExps]
inputDims = [dict() for i in inputExpHandles]

for index, (exp, dims) in enumerate(zip(inputExps, inputDims)):
for index, (expHandle, dims) in enumerate(zip(inputExpHandles, inputDims)):
scale = 1.0
if exp is None:
self.log.warning("Input %d is None (%s); unable to scale exp.", index, dims)
continue

visitInfo = expHandle.get(component='visitInfo')
czwa marked this conversation as resolved.
Show resolved Hide resolved
if self.config.exposureScaling == "ExposureTime":
scale = exp.getInfo().getVisitInfo().getExposureTime()
scale = visitInfo.getExposureTime()
elif self.config.exposureScaling == "DarkTime":
scale = exp.getInfo().getVisitInfo().getDarkTime()
scale = visitInfo.getDarkTime()
elif self.config.exposureScaling == "MeanStats":
# Note: there may be a bug freeing memory here. TBD.
exp = expHandle.get()
scale = self.stats.run(exp)
del exp
elif self.config.exposureScaling == "InputList":
visitId = dims.get('exposure', None)
detectorId = dims.get('detector', None)
Expand All @@ -323,7 +319,7 @@ def run(self, inputExps, inputScales=None, inputDims=None):
scale = inputScales['expScale'][detectorId][visitId]
elif self.config.scalingLevel == 'AMP':
scale = [inputScales['expScale'][detectorId][amp.getName()][visitId]
for amp in exp.getDetector()]
for amp in inputDetector]
else:
raise RuntimeError(f"Unknown scaling level: {self.config.scalingLevel}")
elif self.config.exposureScaling == 'Unity':
Expand All @@ -333,18 +329,17 @@ def run(self, inputExps, inputScales=None, inputDims=None):

expScales.append(scale)
self.log.info("Scaling input %d by %s", index, scale)
self.applyScale(exp, scale)

self.combine(combined, inputExps, stats)
self.combine(combinedExp, inputExpHandles, expScales, stats)

self.interpolateNans(combined)

if self.config.doVignette:
polygon = inputExps[0].getInfo().getValidPolygon()
polygon = inputExpHandles[0].get(component='validPolygon')
maskVignettedRegion(combined, polygon=polygon, vignetteValue=0.0)

# Combine headers
self.combineHeaders(inputExps, combinedExp,
self.combineHeaders(inputExpHandles, combinedExp,
calibType=self.config.calibrationType, scales=expScales)

# Set the detector
Expand All @@ -355,20 +350,21 @@ def run(self, inputExps, inputScales=None, inputDims=None):
outputData=combinedExp,
)

def getDimensions(self, expList):
def getDimensions(self, expHandleList):
"""Get dimensions of the inputs.

Parameters
----------
expList : `list` [`lsst.afw.image.Exposure`]
Exps to check the sizes of.
expHandleList : `list` [`lsst.daf.butler.DeferredDatasetHandle`]
Exposure handles to check the sizes of.

Returns
-------
width, height : `int`
Unique set of input dimensions.
"""
dimList = [exp.getDimensions() for exp in expList if exp is not None]
dimList = [expHandle.get(component='bbox').getDimensions() for expHandle in expHandleList]

return self.getSize(dimList)

def getSize(self, dimList):
Expand Down Expand Up @@ -416,30 +412,76 @@ def applyScale(self, exposure, scale=None):
else:
mi /= scale

def combine(self, target, expList, stats):
@staticmethod
def _subBBoxIter(bbox, subregionSize):
    """Generate the sub-boxes that tile a bounding box.

    The box is walked row by row: the vertical offset advances in the
    outer loop and the horizontal offset in the inner loop.

    Parameters
    ----------
    bbox : `lsst.geom.Box2I`
        Bounding box to split into sub-boxes.
    subregionSize : `lsst.geom.Extent2I`
        Nominal (width, height) of each sub-box.

    Yields
    ------
    subBBox : `lsst.geom.Box2I`
        Next sub-box, clipped to ``bbox``.  It is therefore never
        larger than ``subregionSize``, may be smaller along the far
        edges of ``bbox``, and is never empty.

    Raises
    ------
    RuntimeError
        Raised if ``bbox`` is empty, if either element of
        ``subregionSize`` is less than 1, or if a clipped sub-box is
        unexpectedly empty.  This is a generator, so the argument
        checks only run once the first sub-box is requested.
    """
    if bbox.isEmpty():
        raise RuntimeError("bbox %s is empty" % (bbox,))
    if not (subregionSize[0] >= 1 and subregionSize[1] >= 1):
        raise RuntimeError("subregionSize %s must be nonzero" % (subregionSize,))

    for yOffset in range(0, bbox.getHeight(), subregionSize[1]):
        for xOffset in range(0, bbox.getWidth(), subregionSize[0]):
            corner = bbox.getMin() + geom.Extent2I(xOffset, yOffset)
            tile = geom.Box2I(corner, subregionSize)
            # Trim the parts of the tile that overhang the parent box.
            tile.clip(bbox)
            if tile.isEmpty():
                # Unreachable unless the stepping logic above is wrong.
                raise RuntimeError("Bug: empty bbox! bbox=%s, subregionSize=%s, "
                                   "colShift=%s, rowShift=%s" %
                                   (bbox, subregionSize, xOffset, yOffset))
            yield tile

def combine(self, target, expHandleList, expScaleList, stats):
"""Combine multiple images.

Parameters
----------
target : `lsst.afw.image.Exposure`
Output exposure to construct.
expList : `list` [`lsst.afw.image.Exposure`]
Input exposures to combine.
expHandleList : `list` [`lsst.daf.butler.DeferredDatasetHandle`]
Input exposure handles to combine.
expScaleList : `list` [`float`]
List of scales to apply to each input image.
stats : `lsst.afw.math.StatisticsControl`
Control explaining how to combine the input images.
"""
images = [img.getMaskedImage() for img in expList if img is not None]
combineType = afwMath.stringToStatisticsProperty(self.config.combine)
afwMath.statisticsStack(target, images, combineType, stats)

def combineHeaders(self, expList, calib, calibType="CALIB", scales=None):
subregionSizeArr = self.config.subregionSize
subregionSize = geom.Extent2I(subregionSizeArr[0], subregionSizeArr[1])
for subBbox in self._subBBoxIter(target.getBBox(), subregionSize):
images = []
for expHandle, expScale in zip(expHandleList, expScaleList):
inputExp = expHandle.get(parameters={'bbox': subBbox})
self.applyScale(inputExp, expScale)
images.append(inputExp.getMaskedImage())

combinedSubregion = afwMath.statisticsStack(images, combineType, stats)
target.maskedImage.assign(combinedSubregion, subBbox)

def combineHeaders(self, expHandleList, calib, calibType="CALIB", scales=None):
"""Combine input headers to determine the set of common headers,
supplemented by calibration inputs.

Parameters
----------
expList : `list` [`lsst.afw.image.Exposure`]
Input list of exposures to combine.
expHandleList : `list` [`lsst.daf.butler.DeferredDatasetHandle`]
Input list of exposure handles to combine.
calib : `lsst.afw.image.Exposure`
Output calibration to construct headers for.
calibType : `str`, optional
Expand Down Expand Up @@ -473,16 +515,19 @@ def combineHeaders(self, expList, calib, calibType="CALIB", scales=None):
header.set("CALIB_CREATE_TIME", calibTime)

# Merge input headers
inputHeaders = [exp.getMetadata() for exp in expList if exp is not None]
inputHeaders = [expHandle.get(component='metadata') for expHandle in expHandleList]
merged = merge_headers(inputHeaders, mode='drop')

# Scan the first header for items that were dropped due to
# conflict, and replace them.
for k, v in merged.items():
if k not in header:
md = expList[0].getMetadata()
md = inputHeaders[0]
comment = md.getComment(k) if k in md else None
header.set(k, v, comment=comment)

# Construct list of visits
visitInfoList = [exp.getInfo().getVisitInfo() for exp in expList if exp is not None]
visitInfoList = [expHandle.get(component='visitInfo') for expHandle in expHandleList]
for i, visit in enumerate(visitInfoList):
if visit is None:
continue
Expand Down
13 changes: 7 additions & 6 deletions python/lsst/cp/pipe/cpSkyTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,12 +349,13 @@ class CpSkyCombineConnections(pipeBase.PipelineTaskConnections,
dimensions=("instrument", "exposure", "detector"),
multiple=True,
)
inputExps = cT.Input(
inputExpHandles = cT.Input(
name="cpSkyMaskedIsr",
doc="Masked post-ISR image.",
storageClass="Exposure",
dimensions=("instrument", "exposure", "detector"),
multiple=True,
deferLoad=True,
)

outputCalib = cT.Output(
Expand Down Expand Up @@ -391,15 +392,15 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
self.makeSubtask("sky")

def run(self, inputBkgs, inputExps):
def run(self, inputBkgs, inputExpHandles):
"""Merge per-exposure measurements into a detector level calibration.

Parameters
----------
inputBkgs : `list` [`lsst.afw.math.BackgroundList`]
Remnant backgrounds from each exposure.
inputExps : `list` [`lsst.afw.image.Exposure`]
The ISR processed, detection masked images.
inputExpHandles : `list` [`lsst.daf.butler.DeferredDatasetHandle`]
The Butler handles to the ISR processed, detection masked images.

Returns
-------
Expand All @@ -410,9 +411,9 @@ def run(self, inputBkgs, inputExps):
The final sky calibration product.
"""
skyCalib = self.sky.averageBackgrounds(inputBkgs)
skyCalib.setDetector(inputExps[0].getDetector())
skyCalib.setDetector(inputExpHandles[0].get(component='detector'))

CalibCombineTask().combineHeaders(inputExps, skyCalib, calibType='SKY')
CalibCombineTask().combineHeaders(inputExpHandles, skyCalib, calibType='SKY')

return pipeBase.Struct(
outputCalib=skyCalib,
Expand Down