Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-23786: Add bright star subtraction task #555

Merged
merged 3 commits into from
Mar 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
355 changes: 355 additions & 0 deletions python/lsst/pipe/tasks/subtractBrightStars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,355 @@
# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
"""Retrieve extended PSF model and subtract bright stars at calexp (ie
single visit) level.
"""

__all__ = ["SubtractBrightStarsConnections", "SubtractBrightStarsConfig", "SubtractBrightStarsTask"]

from functools import reduce
from operator import ior

import numpy as np
from lsst.afw.image import Exposure, ExposureF, MaskedImageF
from lsst.afw.math import (
StatisticsControl,
WarpingControl,
makeStatistics,
rotateImageBy90,
stringToStatisticsProperty,
warpImage,
)
from lsst.geom import Box2I, Point2D, Point2I
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as cT


class SubtractBrightStarsConnections(
PipelineTaskConnections,
dimensions=("instrument", "visit", "detector"),
defaultTemplates={"outputExposureName": "brightStar_subtracted", "outputBackgroundName": "brightStars"},
):
inputExposure = cT.Input(
doc="Input exposure from which to subtract bright star stamps.",
name="calexp",
storageClass="ExposureF",
dimensions=(
"visit",
"detector",
),
)
inputBrightStarStamps = cT.Input(
doc="Set of preprocessed postage stamps, each centered on a single bright star.",
name="brightStarStamps",
storageClass="BrightStarStamps",
dimensions=(
"visit",
"detector",
),
)
inputExtendedPsf = cT.Input(
doc="Extended PSF model.",
name="extended_psf",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has it been already defined elsewhere that the name of this input will be extended_psf? While it is within the rules, having some variables in camelCase and some in snake_case is going to get confusing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this nomenclature has already been established in prior tickets. Future tickets should aim to standardize all the names associated with bright star subtraction.

storageClass="ExtendedPsf",
dimensions=("band",),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(another) (very very) minor comment - also, some tuples and/or connection blocks have terminating commas, and some don't.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added trailing commas.

)
skyCorr = cT.Input(
doc="Input Sky Correction to be subtracted from the calexp if ``doApplySkyCorr``=True.",
name="skyCorr",
storageClass="Background",
dimensions=(
"instrument",
"visit",
"detector",
),
)
outputExposure = cT.Output(
doc="Exposure with bright stars subtracted.",
name="{outputExposureName}_calexp",
storageClass="ExposureF",
dimensions=(
"visit",
"detector",
),
)
outputBackgroundExposure = cT.Output(
doc="Exposure containing only the modelled bright stars.",
name="{outputBackgroundName}_calexp_background",
storageClass="ExposureF",
dimensions=(
"visit",
"detector",
),
)

def __init__(self, *, config=None):
super().__init__(config=config)
if not config.doApplySkyCorr:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not 100% sure on when the lines in Connections get triggered. Is there a test case that checks that both values of doApplySkyCorr do what we want it to do?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A future ticket which adds unit testing capability should run this with doApplySkyCorr set to both False and True (i.e., two separate tests).

self.inputs.remove("skyCorr")


class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=SubtractBrightStarsConnections):
"""Configuration parameters for SubtractBrightStarsTask"""

doWriteSubtractor = Field[bool](
dtype=bool,
doc="Should an exposure containing all bright star models be written to disk?",
default=True,
)
doWriteSubtractedExposure = Field[bool](
dtype=bool,
doc="Should an exposure with bright stars subtracted be written to disk?",
default=True,
)
magLimit = Field[float](
dtype=float,
doc="Magnitude limit, in Gaia G; all stars brighter than this value will be subtracted",
default=18,
)
warpingKernelName = ChoiceField[str](
dtype=str,
doc="Warping kernel",
default="lanczos5",
allowed={
"bilinear": "bilinear interpolation",
"lanczos3": "Lanczos kernel of order 3",
"lanczos4": "Lanczos kernel of order 4",
"lanczos5": "Lanczos kernel of order 5",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be preferable to allow up tolanczos7 as well. In general, the default value being one of the extremes can become inconvenient if we'd like to check the behavior around the default.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment above (added 6 and 7).

"lanczos6": "Lanczos kernel of order 6",
"lanczos7": "Lanczos kernel of order 7",
},
)
scalingType = ChoiceField[str](
dtype=str,
doc="How the model should be scaled to each bright star; implemented options are "
"`annularFlux` to reuse the annular flux of each stamp, or `leastSquare` to perform "
"least square fitting on each pixel with no bad mask plane set.",
default="leastSquare",
allowed={
"annularFlux": "reuse BrightStarStamp annular flux measurement",
"leastSquare": "find least square scaling factor",
},
)
badMaskPlanes = ListField[str](
dtype=str,
doc="Mask planes that, if set, lead to associated pixels not being included in the computation of "
"the scaling factor (`BAD` should always be included). Ignored if scalingType is `annularFlux`, "
"as the stamps are expected to already be normalized.",
# Note that `BAD` should always be included, as secondary detected
# sources (i.e., detected sources other than the primary source of
# interest) also get set to `BAD`.
default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"),
)
doApplySkyCorr = Field[bool](
dtype=bool,
doc="Apply full focal plane sky correction before extracting stars?",
default=True,
)


class SubtractBrightStarsTask(PipelineTask):
"""Use an extended PSF model to subtract bright stars from a calibrated
exposure (i.e. at single-visit level).

This task uses both a set of bright star stamps produced by
`~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`
and an extended PSF model produced by
`~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`.
"""

ConfigClass = SubtractBrightStarsConfig
_DefaultName = "subtractBrightStars"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Placeholders to set up Statistics if scalingType is leastSquare.
self.statsControl, self.statsFlag = None, None

def _setUpStatistics(self, exampleMask):
"""Configure statistics control and flag, for use if ``scalingType`` is
`leastSquare`.
"""
if self.config.scalingType == "leastSquare":
self.statsControl = StatisticsControl()
# Set the mask planes which will be ignored.
andMask = reduce(ior, (exampleMask.getPlaneBitMask(bm) for bm in self.config.badMaskPlanes))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry to quibble about this variable name, but calling this andMask when you get it from repeated OR seems weird.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving as-is for now, due to fact that it uses the setAndMask method in statsControl.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that's why I called it that. Note there's precedent here - not that that's a good reason to keep using it :)

self.statsControl.setAndMask(andMask)
self.statsFlag = stringToStatisticsProperty("SUM")

def applySkyCorr(self, calexp, skyCorr):
"""Apply correction to the sky background level.
Sky corrections can be generated via the SkyCorrectionTask within the
pipe_tools module. Because the sky model used by that code extends over
the entire focal plane, this can produce better sky subtraction.
The calexp is updated in-place.

Parameters
----------
calexp : `~lsst.afw.image.Exposure` or `~lsst.afw.image.MaskedImage`
Calibrated exposure.
skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`
Full focal plane sky correction, obtained by running
`~lsst.pipe.drivers.skyCorrection.SkyCorrectionTask`.
"""
if isinstance(calexp, Exposure):
calexp = calexp.getMaskedImage()
calexp -= skyCorr.getImage()

def scaleModel(self, model, star, inPlace=True, nb90Rots=0):
"""Compute scaling factor to be applied to the extended PSF so that its
amplitude matches that of an individual star.

Parameters
----------
model : `~lsst.afw.image.MaskedImageF`
The extended PSF model, shifted (and potentially warped) to match
the bright star's positioning.
star : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp`
A stamp centered on the bright star to be subtracted.
inPlace : `bool`
Whether the model should be scaled in place. Default is `True`.
nb90Rots : `int`
The number of 90-degrees rotations to apply to the star stamp.

Returns
-------
scalingFactor : `float`
The factor by which the model image should be multiplied for it
to be scaled to the input bright star.
"""
if self.config.scalingType == "annularFlux":
scalingFactor = star.annularFlux
elif self.config.scalingType == "leastSquare":
if self.statsControl is None:
self._setUpStatistics(star.stamp_im.mask)
starIm = star.stamp_im.clone()
# Rotate the star postage stamp.
starIm = rotateImageBy90(starIm, nb90Rots)
# Reverse the prior star flux normalization ("unnormalize").
starIm *= star.annularFlux
# The estimator of the scalingFactor (f) that minimizes (Y-fX)^2
# is E[XY]/E[XX].
xy = starIm.clone()
xy.image.array *= model.image.array
xx = starIm.clone()
xx.image.array = model.image.array**2
# Compute the least squares scaling factor.
xySum = makeStatistics(xy, self.statsFlag, self.statsControl).getValue()
xxSum = makeStatistics(xx, self.statsFlag, self.statsControl).getValue()
scalingFactor = xySum / xxSum if xxSum else 1
if inPlace:
model.image *= scalingFactor
return scalingFactor

def runQuantum(self, butlerQC, inputRefs, outputRefs):
# Docstring inherited.
inputs = butlerQC.get(inputRefs)
dataId = butlerQC.quantum.dataId
subtractor, _ = self.run(**inputs, dataId=dataId)
if self.config.doWriteSubtractedExposure:
outputExposure = inputs["inputExposure"].clone()
outputExposure.image -= subtractor.image
else:
outputExposure = None
outputBackgroundExposure = subtractor if self.config.doWriteSubtractor else None
output = Struct(outputExposure=outputExposure, outputBackgroundExposure=outputBackgroundExposure)
butlerQC.put(output, outputRefs)

def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, dataId, skyCorr=None):
"""Iterate over all bright stars in an exposure to scale the extended
PSF model before subtracting bright stars.

Parameters
----------
inputExposure : `~lsst.afw.image.exposure.exposure.ExposureF`
The image from which bright stars should be subtracted.
inputBrightStarStamps :
`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`
Set of stamps centered on each bright star to be subtracted,
produced by running
`~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`.
inputExtendedPsf : `~lsst.pipe.tasks.extended_psf.ExtendedPsf`
Extended PSF model, produced by
`~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`.
dataId : `dict` or `~lsst.daf.butler.DataCoordinate`
The dataId of the exposure (and detector) bright stars should be
subtracted from.
skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional
Full focal plane sky correction, obtained by running
`~lsst.pipe.drivers.skyCorrection.SkyCorrectionTask`. If
`doApplySkyCorr` is set to `True`, `skyCorr` cannot be `None`.

Returns
-------
subtractorExp : `~lsst.afw.image.ExposureF`
An Exposure containing a scaled bright star model fit to every
bright star profile; its image can then be subtracted from the
input exposure.
invImages : `list` [`~lsst.afw.image.MaskedImageF`]
A list of small images ("stamps") containing the model, each scaled
to its corresponding input bright star.
"""
inputExpBBox = inputExposure.getBBox()
if self.config.doApplySkyCorr and (skyCorr is not None):
self.log.info(
"Applying sky correction to exposure %s (exposure will be modified in-place).", dataId
)
self.applySkyCorr(inputExposure, skyCorr)
# Create an empty image the size of the exposure.
# TODO: DM-31085 (set mask planes).
subtractorExp = ExposureF(bbox=inputExposure.getBBox())
subtractor = subtractorExp.maskedImage
# Make a copy of the input model.
model = inputExtendedPsf(dataId["detector"]).clone()
modelStampSize = model.getDimensions()
inv90Rots = 4 - inputBrightStarStamps.nb90Rots % 4
model = rotateImageBy90(model, inv90Rots)
warpCont = WarpingControl(self.config.warpingKernelName)
invImages = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment line above this explaining the purpose of the following for loop, e.g.,
# loop over bright stars, computing the inverse transformed and scaled postage stamp for each

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

# Loop over bright stars, computing the inverse transformed and scaled
# postage stamp for each.
for star in inputBrightStarStamps:
if star.gaiaGMag < self.config.magLimit:
# Set the origin.
model.setXY0(star.position)
# Create an empty destination image.
invTransform = star.archive_element.inverted()
invOrigin = Point2I(invTransform.applyForward(Point2D(star.position)))
bbox = Box2I(corner=invOrigin, dimensions=modelStampSize)
invImage = MaskedImageF(bbox)
# Apply inverse transform.
goodPix = warpImage(invImage, model, invTransform, warpCont)
if not goodPix:
self.log.debug(
f"Warping of a model failed for star {star.gaiaId}: " "no good pixel in output"
)
# Scale the model.
self.scaleModel(invImage, star, inPlace=True, nb90Rots=inv90Rots)
# Replace NaNs before subtraction (note all NaN pixels have
# the NO_DATA flag).
invImage.image.array[np.isnan(invImage.image.array)] = 0
bbox.clip(inputExpBBox)
if bbox.getArea() > 0:
subtractor[bbox] += invImage[bbox]
invImages.append(invImage)
return subtractorExp, invImages