From 77f9640ea12e95b160e8ca9464dcaac0b946a637 Mon Sep 17 00:00:00 2001 From: Clare Saunders Date: Thu, 22 May 2025 09:40:09 -0700 Subject: [PATCH] Improve fits with non-flat backgrounds --- python/lsst/meas/algorithms/maskStreaks.py | 157 ++++++++++++++++++--- 1 file changed, 140 insertions(+), 17 deletions(-) diff --git a/python/lsst/meas/algorithms/maskStreaks.py b/python/lsst/meas/algorithms/maskStreaks.py index 329ca431..5f0afc8a 100644 --- a/python/lsst/meas/algorithms/maskStreaks.py +++ b/python/lsst/meas/algorithms/maskStreaks.py @@ -131,9 +131,11 @@ class LineProfile: Guess for position of line. Data far from line guess is masked out. Defaults to None, in which case only data with `weights` = 0 is masked out. + detectionMask : `np.ndarray`, optional + 2-d boolean array where detected pixels are True. """ - def __init__(self, data, weights, line=None): + def __init__(self, data, weights, line=None, detectionMask=None): self.data = data self.weights = weights self._ymax, self._xmax = data.shape @@ -145,7 +147,7 @@ def __init__(self, data, weights, line=None): self.mask = (weights != 0) self._initLine = line - self.setLineMask(line, maxStreakWidth=0, nSigmaMask=5) + self.setLineMask(line, maxStreakWidth=0, nSigmaMask=10, detectionMask=detectionMask) def getLineXY(self, line): """Return the pixel coordinates of the ends of the line. @@ -191,13 +193,21 @@ def getLineXY(self, line): return boxIntersections - def setLineMask(self, line, maxStreakWidth, nSigmaMask, logger=None): + def setLineMask(self, line, maxStreakWidth, nSigmaMask, logger=None, detectionMask=None): """Set mask around the image region near the line. Parameters ---------- line : `Line` Parameters of line in the image. + maxStreakWidth : `float` + Maximum width in pixels of streak mask. + nSigmaMask : `float` + Factor by which to multiply the line's width to get the mask width. + logger : `lsst.utils.logging.LsstLogAdapter`, optional + Logger to use for reporting when maxStreakWidth is reached. + detectionMask : `np.ndarray`, optional + 2-d boolean array where detected pixels are True. """ if line: # Only fit pixels within nSigmaMask * sigma of the estimated line @@ -212,10 +222,38 @@ def setLineMask(self, line, maxStreakWidth, nSigmaMask, logger=None): width = maxStreakWidth m = (abs(distance) < width/2) self.lineMask = self.mask & m + if detectionMask is not None: + # Mask out areas where there are no detected pixels. This + # happens when, for example, the streak ends in the middle of + # the image. + lineEnds = self.getLineXY(line) + xA = lineEnds[0, 0] - self._xmax / 2. + yA = lineEnds[0, 1] - self._ymax / 2. + + radtheta = np.deg2rad(line.theta) + costheta = np.cos(radtheta) + sintheta = np.sin(radtheta) + + maskDetections = detectionMask[self.lineMask] != 0 + distanceFromLineEnd = (- sintheta * self._xmesh[self.lineMask] + + costheta * self._ymesh[self.lineMask] + + sintheta * xA + - costheta * yA) + lineBins = np.arange(distanceFromLineEnd.min(), distanceFromLineEnd.max() + 5.1, 5) + # Get the chi2 of the pixels perpendicular to the streak: + detectionsAlongStreak, _, binnumber = scipy.stats.binned_statistic(distanceFromLineEnd, + maskDetections, + statistic='sum', + bins=lineBins) + countAlongStreak, *_ = scipy.stats.binned_statistic(distanceFromLineEnd, maskDetections, + statistic='count', bins=lineBins) + detectionFraction = detectionsAlongStreak / countAlongStreak + emptyRows = detectionFraction < (0.5 * np.median(detectionFraction[detectionFraction != 0])) + emptyDetections = emptyRows[binnumber - 1] + self.lineMask[self.lineMask] = ~emptyDetections else: self.lineMask = np.copy(self.mask) - self.lineMaskSize = self.lineMask.sum() self._maskData = self.data[self.lineMask] self._maskWeights = self.weights[self.lineMask] self._mxmesh = self._xmesh[self.lineMask] @@ -321,18 +359,77 @@ def _lineChi2(self, line, grad=True): # Calculate chi2 model, dModel = self._makeMaskedProfile(line) chi2 = (self._maskWeights * (self._maskData - model)**2).sum() + maskSize = (self._maskWeights != 0).sum() if not grad: - return chi2.sum() / self.lineMaskSize + return chi2.sum() / maskSize # Calculate derivative and Hessian of chi2 derivChi2 = ((-2 * self._maskWeights * (self._maskData - model))[None, :] * dModel).sum(axis=1) hessianChi2 = (2 * self._maskWeights * dModel[:, None, :] * dModel[None, :, :]).sum(axis=2) - reducedChi = chi2 / self.lineMaskSize - reducedDChi = derivChi2 / self.lineMaskSize - reducedHessianChi = hessianChi2 / self.lineMaskSize + reducedChi = chi2 / maskSize + reducedDChi = derivChi2 / maskSize + reducedHessianChi = hessianChi2 / maskSize return reducedChi, reducedDChi, reducedHessianChi + def _rejectOutliers(self, line): + """Reject outlier pixels. + + This calculates the chi2/dof in bins of pixels perpendicular to the + streak direction and removes outliers. This is done so that the profile + fitter ignores regions like the area around bright stars. + + Parameters + ---------- + line : `Line` + `Line` parameters for which to build model and calculate chi2. + + Returns + ------- + nOutliers : `int` + Number of outlier pixels. + """ + model, _ = self._makeMaskedProfile(line) + pixelChi2 = (self._maskWeights * (self._maskData - model)**2) + + lineEnds = self.getLineXY(line) + + if lineEnds.shape == (2, 2): + xA = lineEnds[0, 0] - self._xmax / 2. + yA = lineEnds[0, 1] - self._ymax / 2. + else: + # Line profile is outside the detector bounding box. Exit outlier rejection. + return 0 + + radtheta = np.deg2rad(line.theta) + costheta = np.cos(radtheta) + sintheta = np.sin(radtheta) + + distanceFromLineEnd = (- sintheta * self._mxmesh + costheta * self._mymesh + sintheta * xA + - costheta * yA) + + distanceFromLineEnd = distanceFromLineEnd[self._maskWeights != 0] + nonZeroPixelChi2 = pixelChi2[self._maskWeights != 0] + lineBins = np.arange(distanceFromLineEnd.min(), distanceFromLineEnd.max() + 5.1, 5) + # Get the chi2 of the pixels perpendicular to the streak: + chi2AlongStreak, _, binnumber = scipy.stats.binned_statistic(distanceFromLineEnd, nonZeroPixelChi2, + statistic='sum', bins=lineBins) + countAlongStreak, *_ = scipy.stats.binned_statistic(distanceFromLineEnd, nonZeroPixelChi2, + statistic='count', bins=lineBins) + + rChi2AlongStreak = chi2AlongStreak / countAlongStreak + outliers = rChi2AlongStreak > (np.nanmean(rChi2AlongStreak[rChi2AlongStreak != 0]) + + 3 * np.nanstd(rChi2AlongStreak[rChi2AlongStreak != 0])) + + outlierPix = outliers[binnumber - 1] + tmpWeights = self._maskWeights[self._maskWeights != 0] + tmpWeights[outlierPix] = 0 + self._maskWeights[self._maskWeights != 0] = tmpWeights + + nOutliers = outlierPix.sum() + + return nOutliers + def fit(self, dChi2Tol=0.1, maxIter=100, log=None): """Perform Newton-Raphson minimization to find line parameters. @@ -368,6 +465,7 @@ def fit(self, dChi2Tol=0.1, maxIter=100, log=None): dChi2 = 1 iter = 0 oldChi2 = 0 + nOutliers = 1 fitFailure = False def line_search(c, dx): @@ -375,7 +473,7 @@ def line_search(c, dx): testLine = Line(testx[0], testx[1], testx[2]**-1) return self._lineChi2(testLine, grad=False) - while abs(dChi2) > dChi2Tol: + while (abs(dChi2) > dChi2Tol) or (nOutliers != 0): line = Line(x[0], x[1], x[2]**-1) chi2, b, A = self._lineChi2(line) if chi2 == 0: @@ -401,10 +499,11 @@ def line_search(c, dx): fitFailure = True break oldChi2 = chi2 + + nOutliers = self._rejectOutliers(line) iter += 1 outline = Line(x[0], x[1], abs(x[2])**-1, chi2) - return outline, fitFailure @@ -453,7 +552,7 @@ class MaskStreaksConfig(pexConfig.Config): doc="Binsize in pixels for position parameter rho when finding " "clusters of detected lines", dtype=float, - default=30, + default=40, ) thetaBinSize = pexConfig.Field( doc="Binsize in degrees for angle parameter theta when finding " @@ -516,6 +615,12 @@ class MaskStreaksConfig(pexConfig.Config): dtype=float, default=0., ) + saturatedDetectionsDilation = pexConfig.Field( + doc="Mask out the region around saturated detections by dilating the " + "existing mask by this number of pixels.", + dtype=int, + default=250, + ) class MaskStreaksTask(pipeBase.Task): @@ -568,12 +673,30 @@ def find(self, maskedImage): initEdges = self._cannyFilter(detectionMask) # Ignore regions with known bad masks, adding a one-pixel buffer around # each to ensure that the edges of bad regions are also ignored. + ignoreMask = mask.clone() + badPixelMask = mask.getPlaneBitMask(self.config.badMaskPlanes) badMaskSpanSet = SpanSet.fromMask(mask, badPixelMask).split() for sset in badMaskSpanSet: sset_dilated = sset.dilated(1) - sset_dilated.clippedTo(mask.getBBox()).setMask(mask, mask.getPlaneBitMask("BAD")) - dilatedBadMask = (mask.array & badPixelMask) > 0 + sset_dilated.clippedTo( + ignoreMask.getBBox()).setMask(ignoreMask, ignoreMask.getPlaneBitMask("BAD")) + + # TODO: DM-52769, replace this with a model for the diffraction spikes + # around bright stars once DM-52541 is done. + if self.config.saturatedDetectionsDilation: + # Dilate spansets that are both detected and saturated mask by a lot more: + satMask = mask.getPlaneBitMask("SAT") + satMask = (mask.array & mask.getPlaneBitMask("SAT")) + satDetMask = (satMask != 0) & (detectionMask != 0) + satDetIm = lsst.afw.image.Mask(satDetMask.astype(np.int32)) + satSpanSet = SpanSet.fromMask(satDetIm, 1).split() + for sset in satSpanSet: + sset_dilated = sset.dilated(self.config.saturatedDetectionsDilation) + sset_dilated.clippedTo( + ignoreMask.getBBox()).setMask(ignoreMask, ignoreMask.getPlaneBitMask("BAD")) + + dilatedBadMask = (ignoreMask.array & badPixelMask) > 0 self.edges = initEdges & ~dilatedBadMask self.lines = self._runKHT(self.edges) @@ -583,7 +706,7 @@ def find(self, maskedImage): clusters = LineCollection([], []) else: clusters = self._findClusters(self.lines) - fitLines, lineMask = self._fitProfile(clusters, maskedImage) + fitLines, lineMask = self._fitProfile(clusters, maskedImage, detectionMask=detectionMask) if self.config.onlyMaskDetected: # The output mask is the intersection of the fit streaks and the image detections @@ -728,7 +851,7 @@ def _findClusters(self, lines): return result - def _fitProfile(self, lines, maskedImage): + def _fitProfile(self, lines, maskedImage, detectionMask=None): """Fit the profile of the streak. Given the initial parameters of detected lines, fit a model for the @@ -764,9 +887,9 @@ def _fitProfile(self, lines, maskedImage): nFitFailures = 0 for line in lines: line.sigma = self.config.invSigma**-1 - lineModel = LineProfile(data, weights, line=line) + lineModel = LineProfile(data, weights, line=line, detectionMask=detectionMask) # Skip any lines that do not cover any data (sometimes happens because of chip gaps) - if lineModel.lineMaskSize == 0: + if lineModel.lineMask.sum() == 0: continue fit, fitFailure = lineModel.fit(dChi2Tol=self.config.dChi2Tolerance, log=self.log,