Skip to content

Commit

Permalink
Merge branch 'tickets/DM-23786'
Browse files Browse the repository at this point in the history
  • Loading branch information
leeskelvin committed Mar 31, 2023
2 parents 59f6dce + fda16d1 commit 3b994c1
Showing 1 changed file with 355 additions and 0 deletions.
355 changes: 355 additions & 0 deletions python/lsst/pipe/tasks/subtractBrightStars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,355 @@
# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
"""Retrieve extended PSF model and subtract bright stars at calexp (ie
single visit) level.
"""

__all__ = ["SubtractBrightStarsConnections", "SubtractBrightStarsConfig", "SubtractBrightStarsTask"]

from functools import reduce
from operator import ior

import numpy as np
from lsst.afw.image import Exposure, ExposureF, MaskedImageF
from lsst.afw.math import (
StatisticsControl,
WarpingControl,
makeStatistics,
rotateImageBy90,
stringToStatisticsProperty,
warpImage,
)
from lsst.geom import Box2I, Point2D, Point2I
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as cT


class SubtractBrightStarsConnections(
PipelineTaskConnections,
dimensions=("instrument", "visit", "detector"),
defaultTemplates={"outputExposureName": "brightStar_subtracted", "outputBackgroundName": "brightStars"},
):
inputExposure = cT.Input(
doc="Input exposure from which to subtract bright star stamps.",
name="calexp",
storageClass="ExposureF",
dimensions=(
"visit",
"detector",
),
)
inputBrightStarStamps = cT.Input(
doc="Set of preprocessed postage stamps, each centered on a single bright star.",
name="brightStarStamps",
storageClass="BrightStarStamps",
dimensions=(
"visit",
"detector",
),
)
inputExtendedPsf = cT.Input(
doc="Extended PSF model.",
name="extended_psf",
storageClass="ExtendedPsf",
dimensions=("band",),
)
skyCorr = cT.Input(
doc="Input Sky Correction to be subtracted from the calexp if ``doApplySkyCorr``=True.",
name="skyCorr",
storageClass="Background",
dimensions=(
"instrument",
"visit",
"detector",
),
)
outputExposure = cT.Output(
doc="Exposure with bright stars subtracted.",
name="{outputExposureName}_calexp",
storageClass="ExposureF",
dimensions=(
"visit",
"detector",
),
)
outputBackgroundExposure = cT.Output(
doc="Exposure containing only the modelled bright stars.",
name="{outputBackgroundName}_calexp_background",
storageClass="ExposureF",
dimensions=(
"visit",
"detector",
),
)

def __init__(self, *, config=None):
super().__init__(config=config)
if not config.doApplySkyCorr:
self.inputs.remove("skyCorr")


class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=SubtractBrightStarsConnections):
"""Configuration parameters for SubtractBrightStarsTask"""

doWriteSubtractor = Field[bool](
dtype=bool,
doc="Should an exposure containing all bright star models be written to disk?",
default=True,
)
doWriteSubtractedExposure = Field[bool](
dtype=bool,
doc="Should an exposure with bright stars subtracted be written to disk?",
default=True,
)
magLimit = Field[float](
dtype=float,
doc="Magnitude limit, in Gaia G; all stars brighter than this value will be subtracted",
default=18,
)
warpingKernelName = ChoiceField[str](
dtype=str,
doc="Warping kernel",
default="lanczos5",
allowed={
"bilinear": "bilinear interpolation",
"lanczos3": "Lanczos kernel of order 3",
"lanczos4": "Lanczos kernel of order 4",
"lanczos5": "Lanczos kernel of order 5",
"lanczos6": "Lanczos kernel of order 6",
"lanczos7": "Lanczos kernel of order 7",
},
)
scalingType = ChoiceField[str](
dtype=str,
doc="How the model should be scaled to each bright star; implemented options are "
"`annularFlux` to reuse the annular flux of each stamp, or `leastSquare` to perform "
"least square fitting on each pixel with no bad mask plane set.",
default="leastSquare",
allowed={
"annularFlux": "reuse BrightStarStamp annular flux measurement",
"leastSquare": "find least square scaling factor",
},
)
badMaskPlanes = ListField[str](
dtype=str,
doc="Mask planes that, if set, lead to associated pixels not being included in the computation of "
"the scaling factor (`BAD` should always be included). Ignored if scalingType is `annularFlux`, "
"as the stamps are expected to already be normalized.",
# Note that `BAD` should always be included, as secondary detected
# sources (i.e., detected sources other than the primary source of
# interest) also get set to `BAD`.
default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"),
)
doApplySkyCorr = Field[bool](
dtype=bool,
doc="Apply full focal plane sky correction before extracting stars?",
default=True,
)


class SubtractBrightStarsTask(PipelineTask):
"""Use an extended PSF model to subtract bright stars from a calibrated
exposure (i.e. at single-visit level).
This task uses both a set of bright star stamps produced by
`~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`
and an extended PSF model produced by
`~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`.
"""

ConfigClass = SubtractBrightStarsConfig
_DefaultName = "subtractBrightStars"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Placeholders to set up Statistics if scalingType is leastSquare.
self.statsControl, self.statsFlag = None, None

def _setUpStatistics(self, exampleMask):
"""Configure statistics control and flag, for use if ``scalingType`` is
`leastSquare`.
"""
if self.config.scalingType == "leastSquare":
self.statsControl = StatisticsControl()
# Set the mask planes which will be ignored.
andMask = reduce(ior, (exampleMask.getPlaneBitMask(bm) for bm in self.config.badMaskPlanes))
self.statsControl.setAndMask(andMask)
self.statsFlag = stringToStatisticsProperty("SUM")

def applySkyCorr(self, calexp, skyCorr):
"""Apply correction to the sky background level.
Sky corrections can be generated via the SkyCorrectionTask within the
pipe_tools module. Because the sky model used by that code extends over
the entire focal plane, this can produce better sky subtraction.
The calexp is updated in-place.
Parameters
----------
calexp : `~lsst.afw.image.Exposure` or `~lsst.afw.image.MaskedImage`
Calibrated exposure.
skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`
Full focal plane sky correction, obtained by running
`~lsst.pipe.drivers.skyCorrection.SkyCorrectionTask`.
"""
if isinstance(calexp, Exposure):
calexp = calexp.getMaskedImage()
calexp -= skyCorr.getImage()

def scaleModel(self, model, star, inPlace=True, nb90Rots=0):
"""Compute scaling factor to be applied to the extended PSF so that its
amplitude matches that of an individual star.
Parameters
----------
model : `~lsst.afw.image.MaskedImageF`
The extended PSF model, shifted (and potentially warped) to match
the bright star's positioning.
star : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp`
A stamp centered on the bright star to be subtracted.
inPlace : `bool`
Whether the model should be scaled in place. Default is `True`.
nb90Rots : `int`
The number of 90-degrees rotations to apply to the star stamp.
Returns
-------
scalingFactor : `float`
The factor by which the model image should be multiplied for it
to be scaled to the input bright star.
"""
if self.config.scalingType == "annularFlux":
scalingFactor = star.annularFlux
elif self.config.scalingType == "leastSquare":
if self.statsControl is None:
self._setUpStatistics(star.stamp_im.mask)
starIm = star.stamp_im.clone()
# Rotate the star postage stamp.
starIm = rotateImageBy90(starIm, nb90Rots)
# Reverse the prior star flux normalization ("unnormalize").
starIm *= star.annularFlux
# The estimator of the scalingFactor (f) that minimizes (Y-fX)^2
# is E[XY]/E[XX].
xy = starIm.clone()
xy.image.array *= model.image.array
xx = starIm.clone()
xx.image.array = model.image.array**2
# Compute the least squares scaling factor.
xySum = makeStatistics(xy, self.statsFlag, self.statsControl).getValue()
xxSum = makeStatistics(xx, self.statsFlag, self.statsControl).getValue()
scalingFactor = xySum / xxSum if xxSum else 1
if inPlace:
model.image *= scalingFactor
return scalingFactor

def runQuantum(self, butlerQC, inputRefs, outputRefs):
# Docstring inherited.
inputs = butlerQC.get(inputRefs)
dataId = butlerQC.quantum.dataId
subtractor, _ = self.run(**inputs, dataId=dataId)
if self.config.doWriteSubtractedExposure:
outputExposure = inputs["inputExposure"].clone()
outputExposure.image -= subtractor.image
else:
outputExposure = None
outputBackgroundExposure = subtractor if self.config.doWriteSubtractor else None
output = Struct(outputExposure=outputExposure, outputBackgroundExposure=outputBackgroundExposure)
butlerQC.put(output, outputRefs)

def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, dataId, skyCorr=None):
"""Iterate over all bright stars in an exposure to scale the extended
PSF model before subtracting bright stars.
Parameters
----------
inputExposure : `~lsst.afw.image.exposure.exposure.ExposureF`
The image from which bright stars should be subtracted.
inputBrightStarStamps :
`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`
Set of stamps centered on each bright star to be subtracted,
produced by running
`~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`.
inputExtendedPsf : `~lsst.pipe.tasks.extended_psf.ExtendedPsf`
Extended PSF model, produced by
`~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`.
dataId : `dict` or `~lsst.daf.butler.DataCoordinate`
The dataId of the exposure (and detector) bright stars should be
subtracted from.
skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional
Full focal plane sky correction, obtained by running
`~lsst.pipe.drivers.skyCorrection.SkyCorrectionTask`. If
`doApplySkyCorr` is set to `True`, `skyCorr` cannot be `None`.
Returns
-------
subtractorExp : `~lsst.afw.image.ExposureF`
An Exposure containing a scaled bright star model fit to every
bright star profile; its image can then be subtracted from the
input exposure.
invImages : `list` [`~lsst.afw.image.MaskedImageF`]
A list of small images ("stamps") containing the model, each scaled
to its corresponding input bright star.
"""
inputExpBBox = inputExposure.getBBox()
if self.config.doApplySkyCorr and (skyCorr is not None):
self.log.info(
"Applying sky correction to exposure %s (exposure will be modified in-place).", dataId
)
self.applySkyCorr(inputExposure, skyCorr)
# Create an empty image the size of the exposure.
# TODO: DM-31085 (set mask planes).
subtractorExp = ExposureF(bbox=inputExposure.getBBox())
subtractor = subtractorExp.maskedImage
# Make a copy of the input model.
model = inputExtendedPsf(dataId["detector"]).clone()
modelStampSize = model.getDimensions()
inv90Rots = 4 - inputBrightStarStamps.nb90Rots % 4
model = rotateImageBy90(model, inv90Rots)
warpCont = WarpingControl(self.config.warpingKernelName)
invImages = []
# Loop over bright stars, computing the inverse transformed and scaled
# postage stamp for each.
for star in inputBrightStarStamps:
if star.gaiaGMag < self.config.magLimit:
# Set the origin.
model.setXY0(star.position)
# Create an empty destination image.
invTransform = star.archive_element.inverted()
invOrigin = Point2I(invTransform.applyForward(Point2D(star.position)))
bbox = Box2I(corner=invOrigin, dimensions=modelStampSize)
invImage = MaskedImageF(bbox)
# Apply inverse transform.
goodPix = warpImage(invImage, model, invTransform, warpCont)
if not goodPix:
self.log.debug(
f"Warping of a model failed for star {star.gaiaId}: " "no good pixel in output"
)
# Scale the model.
self.scaleModel(invImage, star, inPlace=True, nb90Rots=inv90Rots)
# Replace NaNs before subtraction (note all NaN pixels have
# the NO_DATA flag).
invImage.image.array[np.isnan(invImage.image.array)] = 0
bbox.clip(inputExpBBox)
if bbox.getArea() > 0:
subtractor[bbox] += invImage[bbox]
invImages.append(invImage)
return subtractorExp, invImages

0 comments on commit 3b994c1

Please sign in to comment.