Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-37357: Update masking in parallel overscan #248

Merged
merged 10 commits into from
Jan 31, 2023
6 changes: 5 additions & 1 deletion python/lsst/ip/isr/isrTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,11 @@ def run(self, ccdExposure, *, camera=None, bias=None, linearizer=None,

# Amplifier level processing.
overscans = []

if self.config.doOverscan and self.config.overscan.doParallelOverscan:
# This will attempt to mask bleed pixels across all amplifiers.
self.overscan.maskParallelOverscan(ccdExposure, ccd)

for amp in ccd:
# if ccdExposure is one amp,
# check for coverage to prevent performing ops multiple times
Expand Down Expand Up @@ -1884,7 +1889,6 @@ def overscanCorrection(self, ccdExposure, amp):
See Also
--------
lsst.ip.isr.overscan.OverscanTask

"""
if amp.getRawHorizontalOverscanBBox().isEmpty():
self.log.info("ISR_OSCAN: No overscan region. Not performing overscan correction.")
Expand Down
244 changes: 189 additions & 55 deletions python/lsst/ip/isr/overscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import lsst.pex.config as pexConfig

from .isr import fitOverscanImage
from .isrFunctions import makeThresholdMask, countMaskedPixels
from .isrFunctions import makeThresholdMask


class OverscanCorrectionTaskConfig(pexConfig.Config):
Expand Down Expand Up @@ -80,15 +80,17 @@ class OverscanCorrectionTaskConfig(pexConfig.Config):
doc="Correct using parallel overscan after serial overscan correction?",
default=False,
)
parallelOverscanMaskThreshold = pexConfig.RangeField(
erykoff marked this conversation as resolved.
Show resolved Hide resolved
dtype=float,
doc="Minimum fraction of pixels in parallel overscan region necessary "
"for parallel overcan correction.",
default=0.1,
min=0.0,
max=1.0,
inclusiveMin=True,
inclusiveMax=True,
parallelOverscanMaskThreshold = pexConfig.Field(
dtype=int,
doc="Threshold above which pixels in the parallel overscan are masked as bleeds.",
default=100000,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a config that will be per camera, per detector, or per amp? I know we don't currently support running with different isr configs, but I'm curious about this one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It probably needs to be something that the per-amp overscan config will need to include, but I think it's the same per device (with possible exceptions https://lsstc.slack.com/archives/CBV7K0DK6/p1674082048054619 suggests some amplifiers may allow bleeds at lower flux levels).

)
parallelOverscanMaskGrowSize = pexConfig.Field(
dtype=int,
doc="Masks created from saturated bleeds should be grown by this many "
"pixels during construction of the parallel overscan mask. "
"This value determined from the ITL chip in the LATISS camera",
default=7,
erykoff marked this conversation as resolved.
Show resolved Hide resolved
)

leadingColumnsToSkip = pexConfig.Field(
Expand Down Expand Up @@ -221,32 +223,26 @@ def run(self, exposure, amp, isTransposed=False):
maskIm = exposure.getMaskedImage()
maskIm = maskIm.Factory(maskIm, parallelOverscanBBox)

# The serial overscan correction has removed the majority
# of the signal in the parallel overscan region, so the
# mean should be close to zero. The noise in both should
# be similar, so we can use the noise from the serial
# overscan region to set the threshold for bleed
# detection.
thresholdLevel = self.config.numSigmaClip * serialResults.overscanSigmaResidual
makeThresholdMask(maskIm, threshold=thresholdLevel, growFootprints=0)
maskPix = countMaskedPixels(maskIm, self.config.maskPlanes)
xSize, ySize = parallelOverscanBBox.getDimensions()
if maskPix > xSize*ySize*self.config.parallelOverscanMaskThreshold:
self.log.warning('Fraction of masked pixels for parallel overscan calculation larger'
' than %f of total pixels (i.e. %f masked pixels) on amp %s.',
self.config.parallelOverscanMaskThreshold, maskPix, amp.getName())
self.log.warning('Not doing parallel overscan correction.')
else:
parallelResults = self.correctOverscan(exposure, amp,
imageBBox, parallelOverscanBBox,
isTransposed=not isTransposed)

overscanMean = (overscanMean, parallelResults.overscanMean)
overscanMedian = (overscanMedian, parallelResults.overscanMedian)
overscanSigma = (overscanSigma, parallelResults.overscanSigma)
residualMean = (residualMean, parallelResults.overscanMeanResidual)
residualMedian = (residualMedian, parallelResults.overscanMedianResidual)
residualSigma = (residualSigma, parallelResults.overscanSigmaResidual)
# The serial overscan correction has removed some signal
# from the parallel overscan region, but that is largely a
# constant offset. The collapseArray method now attempts
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You say "constant offset" but that's constant per what unit of area?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm very confused about the reference to collapseArray here because I don't see how it's called in the parallel overscan code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's calculated from the double-overscan corner, subtracting one value per row. All rows should be the same within the read noise, as there's no real signal. For the images I've been testing, the serial overscan median value is ~13000-15000, and subtracting it pulls the parallel overscan region down so that the parallel overscan median value is ~10-20.

The overscan code is far more confusing than I'd like (definitely my fault). The parallel overscan uses the same code path as the serial, just with some extra transposes. correctOverscan calls fitOverscan, which calls either measureConstantOverscan or measureVectorOverscan. Continuing with the vector case, that method calls collapseArray for all except MEDIAN_PER_ROW, which instead calls fitOverscanImage (the C++ code), then the fillMaskedPixels method that is also called by collapseArray. I'm not happy with the code-spaghetti I've made.

# to fill fully masked columns with the median of
# neighboring values, with a fallback to the median of the
# correction in all other columns. Filling with neighbor
# values ensures that large variations in the parallel
# overscan do not create new outlier points. The
# MEDIAN_PER_ROW method does this filling as a separate
# operation, using the same method.
parallelResults = self.correctOverscan(exposure, amp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't comment there, but correctOverscan doesn't have a docstring.

imageBBox, parallelOverscanBBox,
isTransposed=not isTransposed)
overscanMean = (overscanMean, parallelResults.overscanMean)
overscanMedian = (overscanMedian, parallelResults.overscanMedian)
overscanSigma = (overscanSigma, parallelResults.overscanSigma)
residualMean = (residualMean, parallelResults.overscanMeanResidual)
residualMedian = (residualMedian, parallelResults.overscanMedianResidual)
residualSigma = (residualSigma, parallelResults.overscanSigmaResidual)

parallelOverscanFit = parallelResults.overscanOverscanModel if parallelResults else None
parallelOverscanImage = parallelResults.overscanImage if parallelResults else None

Expand All @@ -264,7 +260,55 @@ def run(self, exposure, amp, isTransposed=False):
residualSigma=residualSigma)

def correctOverscan(self, exposure, amp, imageBBox, overscanBBox, isTransposed=True):
"""
"""Trim the exposure, fit the overscan, subtract the fit, and
calculate statistics.

Parameters
----------
exposure : `lsst.afw.image.Exposure`
Exposure containing the data.
amp : `lsst.afw.cameraGeom.Amplifier`
The amplifier that is to be corrected.
imageBBox: `lsst.geom.Box2I`
Bounding box of the image data that will have the overscan
subtracted. If parallel overscan will be performed, that
area is added to the image bounding box during serial
overscan correction.
overscanBBox: `lsst.geom.Box2I`
Bounding box for the overscan data.
isTransposed: `bool`
If true, then the data will be transposed before fitting
the overscan.

Returns
-------
results : `lsst.pipe.base.Struct`
``ampOverscanModel``
Overscan model broadcast to the full image size.
(`lsst.afw.image.Exposure`)
``overscanOverscanModel``
Overscan model broadcast to the full overscan image
size. (`lsst.afw.image.Exposure`)
``overscanImage``
Overscan image with the overscan fit subtracted.
(`lsst.afw.image.Exposure`)
``overscanValue``
Overscan model. (`float` or `np.array`)
``overscanMean``
Mean value of the overscan fit. (`float`)
``overscanMedian``
Median value of the overscan fit. (`float`)
``overscanSigma``
Standard deviation of the overscan fit. (`float`)
``overscanMeanResidual``
Mean value of the overscan region after overscan
subtraction. (`float`)
``overscanMedianResidual``
Median value of the overscan region after overscan
subtraction. (`float`)
``overscanSigmaResidual``
Standard deviation of the overscan region after
overscan subtraction. (`float`)
"""
overscanBox = self.trimOverscan(exposure, amp, overscanBBox,
self.config.leadingColumnsToSkip,
Expand All @@ -279,7 +323,7 @@ def correctOverscan(self, exposure, amp, imageBBox, overscanBBox, isTransposed=T

median = np.ma.median(np.ma.masked_where(overscanMask, overscanArray))
bad = np.where(np.abs(overscanArray - median) > self.config.maxDeviation)
overscanMask[bad] = overscanImage.mask.getPlaneBitMask("SAT")
overscanImage.mask.array[bad] = overscanImage.mask.getPlaneBitMask("SAT")

# Do overscan fit.
# CZW: Handle transposed correctly.
Expand All @@ -298,10 +342,9 @@ def correctOverscan(self, exposure, amp, imageBBox, overscanBBox, isTransposed=T
# CZW: Transposed?
overscanOverscanModel = self.broadcastFitToImage(overscanResults.overscanValue,
overscanImage.image.array)
self.debugView(overscanImage, overscanResults.overscanValue, amp, isTransposed=isTransposed)
overscanImage.image.array -= overscanOverscanModel

self.debugView(overscanImage, overscanResults.overscanValue, amp)

# Find residual fit statistics.
stats = afwMath.makeStatistics(overscanImage.getMaskedImage(),
afwMath.MEAN | afwMath.MEDIAN | afwMath.STDEVCLIP, self.statControl)
Expand Down Expand Up @@ -483,6 +526,41 @@ def integerConvert(image):
image, type(image), dir(image))
return outI

def maskParallelOverscan(self, exposure, detector):
"""Mask the union of high values on all amplifiers in the parallel
overscan.

This operates on the image in-place.

Parameters
----------
exposure : `lsst.afw.image.Exposure`
An untrimmed raw exposure.
detector : `lsst.afw.cameraGeom.Detector`
The detetor to use for amplifier geometry.
"""
parallelMask = None

for amp in detector:
dataView = afwImage.MaskedImageF(exposure.getMaskedImage(),
amp.getRawParallelOverscanBBox(),
afwImage.PARENT)
makeThresholdMask(
maskedImage=dataView,
threshold=self.config.parallelOverscanMaskThreshold,
growFootprints=self.config.parallelOverscanMaskGrowSize,
maskName="BAD"
)
if parallelMask is None:
parallelMask = dataView.mask.array
else:
parallelMask |= dataView.mask.array
for amp in detector:
dataView = afwImage.MaskedImageF(exposure.getMaskedImage(),
amp.getRawParallelOverscanBBox(),
afwImage.PARENT)
dataView.mask.array |= parallelMask

# Constant methods
def measureConstantOverscan(self, image):
"""Measure a constant overscan value.
Expand All @@ -499,12 +577,8 @@ def measureConstantOverscan(self, image):
- ``overscanValue``: Overscan value to subtract (`float`)
- ``isTransposed``: Orientation of the overscan (`bool`)
"""
if self.config.fitType == 'MEDIAN':
calcImage = self.integerConvert(image)
else:
calcImage = image
fitType = afwMath.stringToStatisticsProperty(self.config.fitType)
overscanValue = afwMath.makeStatistics(calcImage, fitType, self.statControl).getValue()
overscanValue = afwMath.makeStatistics(image, fitType, self.statControl).getValue()

return pipeBase.Struct(overscanValue=overscanValue,
isTransposed=False)
Expand Down Expand Up @@ -547,27 +621,81 @@ def maskOutliers(self, imageArray):
axisMedians = median
axisStdev = 0.74*(uq - lq) # robust stdev

# Replace pixels that have excessively large stdev values
# with the median of stdev values. A large stdev likely
# indicates a bleed is spilling into the overscan.
axisStdev = np.where(axisStdev > 2.0 * np.median(axisStdev),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does this 2.0 come from and should it be configurable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can take almost any value and do a similar thing. It was added for the case below:

overscan = [[0.0 0.0 0.0 0.0 0.0] # Furthest row from imaging section.
            [0.0 0.0 100000.0 0.0 0.0]
            [0.0 0.0 100000.0 0.0 0.0]] # Closest row to the imaging section.
axisMedian = [0.0 0.0 100000.0 0.0 0.0]
axisStdev = [0.0 0.0 37000.0 0.0 0.0]

With these values, the "bleed" column rejects nothing, as all pixels are within 3 sigma, so some sigma clipping of the STDEV values is required. I chose 2.0 because in all the plots I looked at, the real STDEV values were tightly clustered about the median, so the penalty for over clipping makes little difference to the masking step.

np.median(axisStdev), axisStdev)

# Mask pixels that are N-sigma away from their array medians.
diff = np.abs(imageArray - axisMedians[:, np.newaxis])
return np.ma.masked_where(diff > self.statControl.getNumSigmaClip()
* axisStdev[:, np.newaxis], imageArray)
masked = np.ma.masked_where(diff > self.statControl.getNumSigmaClip()
* axisStdev[:, np.newaxis], imageArray)

@staticmethod
def collapseArray(maskedArray):
return masked

def fillMaskedPixels(self, overscanVector):
"""Fill masked/NaN pixels in the overscan.

Parameters
----------
overscanVector : `np.array` or `np.ma.masked_array`
Overscan vector to fill.

Returns
-------
overscanVector : `np.ma.masked_array`
Filled vector.

Notes
-----
Each maskSlice is a section of overscan with contiguous masks.
Ideally this adds 5 pixels from the left and right of that
mask slice, and takes the median of those values to fill the
slice. If this isn't possible, the median of all non-masked
values is used. The mask is removed for the pixels filled.
"""
workingCopy = overscanVector
if not isinstance(overscanVector, np.ma.MaskedArray):
workingCopy = np.ma.masked_array(overscanVector,
mask=~np.isfinite(overscanVector))

defaultValue = np.median(workingCopy.data[~workingCopy.mask])
for maskSlice in np.ma.clump_masked(workingCopy):
neighborhood = []
if maskSlice.start > 5:
neighborhood.extend(workingCopy[maskSlice.start - 5:maskSlice.start].data)
if maskSlice.stop < workingCopy.size - 5:
neighborhood.extend(workingCopy[maskSlice.stop:maskSlice.stop+5].data)
if len(neighborhood) > 0:
workingCopy.data[maskSlice] = np.nanmedian(neighborhood)
workingCopy.mask[maskSlice] = False
else:
workingCopy.data[maskSlice] = defaultValue
workingCopy.mask[maskSlice] = False
return workingCopy

def collapseArray(self, maskedArray, fillMasked=True):
"""Collapse overscan array (and mask) to a 1-D vector of values.

Parameters
----------
maskedArray : `numpy.ma.masked_array`
Masked array of input overscan data.
fillMasked : `bool`, optional
If true, fill any pixels that are masked with a median of
neighbors.

Returns
-------
collapsed : `numpy.ma.masked_array`
Single dimensional overscan data, combined with the mean.

"""
collapsed = np.mean(maskedArray, axis=1)
if collapsed.mask.sum() > 0:
collapsed.data[collapsed.mask] = np.mean(maskedArray.data[collapsed.mask], axis=1)
if collapsed.mask.sum() > 0 and fillMasked:
collapsed = self.fillMaskedPixels(collapsed)

return collapsed

def collapseArrayMedian(self, maskedArray):
Expand Down Expand Up @@ -735,6 +863,7 @@ def measureVectorOverscan(self, image, isTransposed=False):
mi.mask.array[:, :] = masked.mask[:, :]

overscanVector = fitOverscanImage(mi, self.config.maskPlanes, isTransposed)
overscanVector = self.fillMaskedPixels(overscanVector)
maskArray = self.maskExtrapolated(overscanVector)
else:
collapsed = self.collapseArray(masked)
Expand Down Expand Up @@ -770,7 +899,7 @@ def measureVectorOverscan(self, image, isTransposed=False):
maskArray=maskArray,
isTransposed=isTransposed)

def debugView(self, image, model, amp=None):
def debugView(self, image, model, amp=None, isTransposed=True):
"""Debug display for the final overscan solution.

Parameters
Expand All @@ -781,6 +910,8 @@ def debugView(self, image, model, amp=None):
Overscan model determined for the image.
amp : `lsst.afw.cameraGeom.Amplifier`, optional
Amplifier to extract diagnostic information.
isTransposed : `bool`, optional
Does the data need to be transposed before display?
"""
import lsstDebug
if not lsstDebug.Info(__name__).display:
Expand All @@ -790,12 +921,14 @@ def debugView(self, image, model, amp=None):

calcImage = self.getImageArray(image)
# CZW: Check that this is ok
calcImage = np.transpose(calcImage)
if isTransposed:
calcImage = np.transpose(calcImage)
masked = self.maskOutliers(calcImage)
collapsed = self.collapseArray(masked)
collapsed = self.collapseArray(masked, fillMasked=False)

num = len(collapsed)
indices = 2.0 * np.arange(num)/float(num) - 1.0
indices = np.arange(num)

if np.ma.is_masked(collapsed):
collapsedMask = collapsed.mask
Expand All @@ -814,8 +947,9 @@ def debugView(self, image, model, amp=None):
else:
plotModel = np.zeros_like(indices)
plotModel += model

axes.plot(indices, plotModel, 'r-')
plot.xlabel("centered/scaled position along overscan region")
plot.xlabel("position along overscan region")
plot.ylabel("pixel value/fit value")
if amp:
plot.title(f"{amp.getName()} DataX: "
Expand Down