Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 140 additions & 17 deletions python/lsst/meas/algorithms/maskStreaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,11 @@ class LineProfile:
Guess for position of line. Data far from line guess is masked out.
Defaults to None, in which case only data with `weights` = 0 is masked
out.
detectionMask : `np.ndarray`, optional
2-d boolean array where detected pixels are True.
"""

def __init__(self, data, weights, line=None):
def __init__(self, data, weights, line=None, detectionMask=None):
self.data = data
self.weights = weights
self._ymax, self._xmax = data.shape
Expand All @@ -145,7 +147,7 @@ def __init__(self, data, weights, line=None):
self.mask = (weights != 0)

self._initLine = line
self.setLineMask(line, maxStreakWidth=0, nSigmaMask=5)
self.setLineMask(line, maxStreakWidth=0, nSigmaMask=10, detectionMask=detectionMask)

def getLineXY(self, line):
"""Return the pixel coordinates of the ends of the line.
Expand Down Expand Up @@ -191,13 +193,21 @@ def getLineXY(self, line):

return boxIntersections

def setLineMask(self, line, maxStreakWidth, nSigmaMask, logger=None):
def setLineMask(self, line, maxStreakWidth, nSigmaMask, logger=None, detectionMask=None):
"""Set mask around the image region near the line.

Parameters
----------
line : `Line`
Parameters of line in the image.
maxStreakWidth : `float`
Maximum width in pixels of streak mask.
nSigmaMask : `float`
Factor by which to multiply the line's width to get the mask width.
logger : `lsst.utils.logging.LsstLogAdapter`, optional
Logger to use for reporting when maxStreakWidth is reached.
detectionMask : `np.ndarray`, optional
2-d boolean array where detected pixels are True.
"""
if line:
# Only fit pixels within nSigmaMask * sigma of the estimated line
Expand All @@ -212,10 +222,38 @@ def setLineMask(self, line, maxStreakWidth, nSigmaMask, logger=None):
width = maxStreakWidth
m = (abs(distance) < width/2)
self.lineMask = self.mask & m
if detectionMask is not None:
# Mask out areas where there are no detected pixels. This
# happens when, for example, the streak ends in the middle of
# the image.
lineEnds = self.getLineXY(line)
xA = lineEnds[0, 0] - self._xmax / 2.
yA = lineEnds[0, 1] - self._ymax / 2.

radtheta = np.deg2rad(line.theta)
costheta = np.cos(radtheta)
sintheta = np.sin(radtheta)

maskDetections = detectionMask[self.lineMask] != 0
distanceFromLineEnd = (- sintheta * self._xmesh[self.lineMask]
+ costheta * self._ymesh[self.lineMask]
+ sintheta * xA
- costheta * yA)
lineBins = np.arange(distanceFromLineEnd.min(), distanceFromLineEnd.max() + 5.1, 5)
# Get the chi2 of the pixels perpendicular to the streak:
detectionsAlongStreak, _, binnumber = scipy.stats.binned_statistic(distanceFromLineEnd,
maskDetections,
statistic='sum',
bins=lineBins)
countAlongStreak, *_ = scipy.stats.binned_statistic(distanceFromLineEnd, maskDetections,
statistic='count', bins=lineBins)
detectionFraction = detectionsAlongStreak / countAlongStreak
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having just squashed a divide by zero runtime error yesterday I thought was impossible... think if truly impossible

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this can give a runtime error, because these are arrays, so dividing by zero gives a RuntimeWarning, not an error.

emptyRows = detectionFraction < (0.5 * np.median(detectionFraction[detectionFraction != 0]))
emptyDetections = emptyRows[binnumber - 1]
self.lineMask[self.lineMask] = ~emptyDetections
else:
self.lineMask = np.copy(self.mask)

self.lineMaskSize = self.lineMask.sum()
self._maskData = self.data[self.lineMask]
self._maskWeights = self.weights[self.lineMask]
self._mxmesh = self._xmesh[self.lineMask]
Expand Down Expand Up @@ -321,18 +359,77 @@ def _lineChi2(self, line, grad=True):
# Calculate chi2
model, dModel = self._makeMaskedProfile(line)
chi2 = (self._maskWeights * (self._maskData - model)**2).sum()
maskSize = (self._maskWeights != 0).sum()
if not grad:
return chi2.sum() / self.lineMaskSize
return chi2.sum() / maskSize

# Calculate derivative and Hessian of chi2
derivChi2 = ((-2 * self._maskWeights * (self._maskData - model))[None, :] * dModel).sum(axis=1)
hessianChi2 = (2 * self._maskWeights * dModel[:, None, :] * dModel[None, :, :]).sum(axis=2)

reducedChi = chi2 / self.lineMaskSize
reducedDChi = derivChi2 / self.lineMaskSize
reducedHessianChi = hessianChi2 / self.lineMaskSize
reducedChi = chi2 / maskSize
reducedDChi = derivChi2 / maskSize
reducedHessianChi = hessianChi2 / maskSize
return reducedChi, reducedDChi, reducedHessianChi

def _rejectOutliers(self, line):
"""Reject outlier pixels.

This calculates the chi2/dof in bins of pixels perpendicular to the
streak direction and removes outliers. This is done so that the profile
fitter ignores regions like the area around bright stars.

Parameters
----------
line : `Line`
`Line` parameters for which to build model and calculate chi2.

Returns
-------
nOutliers : `int`
Number of outlier pixels.
"""
model, _ = self._makeMaskedProfile(line)
pixelChi2 = (self._maskWeights * (self._maskData - model)**2)

lineEnds = self.getLineXY(line)

if lineEnds.shape == (2, 2):
xA = lineEnds[0, 0] - self._xmax / 2.
yA = lineEnds[0, 1] - self._ymax / 2.
else:
# Line profile is outside the detector bounding box. Exit outlier rejection.
return 0

radtheta = np.deg2rad(line.theta)
costheta = np.cos(radtheta)
sintheta = np.sin(radtheta)

distanceFromLineEnd = (- sintheta * self._mxmesh + costheta * self._mymesh + sintheta * xA
- costheta * yA)

distanceFromLineEnd = distanceFromLineEnd[self._maskWeights != 0]
nonZeroPixelChi2 = pixelChi2[self._maskWeights != 0]
lineBins = np.arange(distanceFromLineEnd.min(), distanceFromLineEnd.max() + 5.1, 5)
# Get the chi2 of the pixels perpendicular to the streak:
chi2AlongStreak, _, binnumber = scipy.stats.binned_statistic(distanceFromLineEnd, nonZeroPixelChi2,
statistic='sum', bins=lineBins)
countAlongStreak, *_ = scipy.stats.binned_statistic(distanceFromLineEnd, nonZeroPixelChi2,
statistic='count', bins=lineBins)

rChi2AlongStreak = chi2AlongStreak / countAlongStreak
outliers = rChi2AlongStreak > (np.nanmean(rChi2AlongStreak[rChi2AlongStreak != 0])
+ 3 * np.nanstd(rChi2AlongStreak[rChi2AlongStreak != 0]))

outlierPix = outliers[binnumber - 1]
tmpWeights = self._maskWeights[self._maskWeights != 0]
tmpWeights[outlierPix] = 0
self._maskWeights[self._maskWeights != 0] = tmpWeights

nOutliers = outlierPix.sum()

return nOutliers

def fit(self, dChi2Tol=0.1, maxIter=100, log=None):
"""Perform Newton-Raphson minimization to find line parameters.

Expand Down Expand Up @@ -368,14 +465,15 @@ def fit(self, dChi2Tol=0.1, maxIter=100, log=None):
dChi2 = 1
iter = 0
oldChi2 = 0
nOutliers = 1
fitFailure = False

def line_search(c, dx):
testx = x - c * dx
testLine = Line(testx[0], testx[1], testx[2]**-1)
return self._lineChi2(testLine, grad=False)

while abs(dChi2) > dChi2Tol:
while (abs(dChi2) > dChi2Tol) or (nOutliers != 0):
line = Line(x[0], x[1], x[2]**-1)
chi2, b, A = self._lineChi2(line)
if chi2 == 0:
Expand All @@ -401,10 +499,11 @@ def line_search(c, dx):
fitFailure = True
break
oldChi2 = chi2

nOutliers = self._rejectOutliers(line)
iter += 1

outline = Line(x[0], x[1], abs(x[2])**-1, chi2)

return outline, fitFailure


Expand Down Expand Up @@ -453,7 +552,7 @@ class MaskStreaksConfig(pexConfig.Config):
doc="Binsize in pixels for position parameter rho when finding "
"clusters of detected lines",
dtype=float,
default=30,
default=40,
)
thetaBinSize = pexConfig.Field(
doc="Binsize in degrees for angle parameter theta when finding "
Expand Down Expand Up @@ -516,6 +615,12 @@ class MaskStreaksConfig(pexConfig.Config):
dtype=float,
default=0.,
)
saturatedDetectionsDilation = pexConfig.Field(
doc="Mask out the region around saturated detections by dilating the "
"existing mask by this number of pixels.",
dtype=int,
default=250,
)


class MaskStreaksTask(pipeBase.Task):
Expand Down Expand Up @@ -568,12 +673,30 @@ def find(self, maskedImage):
initEdges = self._cannyFilter(detectionMask)
# Ignore regions with known bad masks, adding a one-pixel buffer around
# each to ensure that the edges of bad regions are also ignored.
ignoreMask = mask.clone()

badPixelMask = mask.getPlaneBitMask(self.config.badMaskPlanes)
badMaskSpanSet = SpanSet.fromMask(mask, badPixelMask).split()
for sset in badMaskSpanSet:
sset_dilated = sset.dilated(1)
sset_dilated.clippedTo(mask.getBBox()).setMask(mask, mask.getPlaneBitMask("BAD"))
dilatedBadMask = (mask.array & badPixelMask) > 0
sset_dilated.clippedTo(
ignoreMask.getBBox()).setMask(ignoreMask, ignoreMask.getPlaneBitMask("BAD"))

# TODO: DM-52769, replace this with a model for the diffraction spikes
# around bright stars once DM-52541 is done.
if self.config.saturatedDetectionsDilation:
# Dilate spansets that are both detected and saturated mask by a lot more:
satMask = mask.getPlaneBitMask("SAT")
satMask = (mask.array & mask.getPlaneBitMask("SAT"))
satDetMask = (satMask != 0) & (detectionMask != 0)
satDetIm = lsst.afw.image.Mask(satDetMask.astype(np.int32))
satSpanSet = SpanSet.fromMask(satDetIm, 1).split()
for sset in satSpanSet:
sset_dilated = sset.dilated(self.config.saturatedDetectionsDilation)
sset_dilated.clippedTo(
ignoreMask.getBBox()).setMask(ignoreMask, ignoreMask.getPlaneBitMask("BAD"))

dilatedBadMask = (ignoreMask.array & badPixelMask) > 0
self.edges = initEdges & ~dilatedBadMask
self.lines = self._runKHT(self.edges)

Expand All @@ -583,7 +706,7 @@ def find(self, maskedImage):
clusters = LineCollection([], [])
else:
clusters = self._findClusters(self.lines)
fitLines, lineMask = self._fitProfile(clusters, maskedImage)
fitLines, lineMask = self._fitProfile(clusters, maskedImage, detectionMask=detectionMask)

if self.config.onlyMaskDetected:
# The output mask is the intersection of the fit streaks and the image detections
Expand Down Expand Up @@ -728,7 +851,7 @@ def _findClusters(self, lines):

return result

def _fitProfile(self, lines, maskedImage):
def _fitProfile(self, lines, maskedImage, detectionMask=None):
"""Fit the profile of the streak.

Given the initial parameters of detected lines, fit a model for the
Expand Down Expand Up @@ -764,9 +887,9 @@ def _fitProfile(self, lines, maskedImage):
nFitFailures = 0
for line in lines:
line.sigma = self.config.invSigma**-1
lineModel = LineProfile(data, weights, line=line)
lineModel = LineProfile(data, weights, line=line, detectionMask=detectionMask)
# Skip any lines that do not cover any data (sometimes happens because of chip gaps)
if lineModel.lineMaskSize == 0:
if lineModel.lineMask.sum() == 0:
continue

fit, fitFailure = lineModel.fit(dChi2Tol=self.config.dChi2Tolerance, log=self.log,
Expand Down