Skip to content

Commit

Permalink
Add bright star subtraction task (LSK edits 2023)
Browse files Browse the repository at this point in the history
  • Loading branch information
leeskelvin committed Mar 16, 2023
1 parent 4071124 commit e7e24ad
Showing 1 changed file with 111 additions and 86 deletions.
197 changes: 111 additions & 86 deletions python/lsst/pipe/tasks/subtractBrightStars.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,49 @@
single visit) level.
"""

__all__ = ["SubtractBrightStarsTask"]
__all__ = ["SubtractBrightStarsConnections", "SubtractBrightStarsConfig", "SubtractBrightStarsTask"]

import numpy as np
from operator import ior
from functools import reduce
from operator import ior

from lsst.pipe import base as pipeBase
import numpy as np
from lsst.afw.image import Exposure, ExposureF, MaskedImageF
from lsst.afw.math import (
StatisticsControl,
WarpingControl,
makeStatistics,
rotateImageBy90,
stringToStatisticsProperty,
warpImage,
)
from lsst.geom import Box2I, Point2D, Point2I
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as cT
from lsst.pex import config as pexConfig
from lsst.afw import math as afwMath
from lsst.afw import image as afwImage
from lsst import geom


class SubtractBrightStarsConnections(pipeBase.PipelineTaskConnections,
dimensions=("instrument", "visit", "detector"),
defaultTemplates={"outputExposureName": "brightStar_subtracted",
"outputBackgroundName": "brightStars"}):
class SubtractBrightStarsConnections(
PipelineTaskConnections,
dimensions=("instrument", "visit", "detector"),
defaultTemplates={"outputExposureName": "brightStar_subtracted", "outputBackgroundName": "brightStars"},
):
inputExposure = cT.Input(
doc="Input exposure from which to subtract bright star stamps.",
name="calexp",
storageClass="ExposureF",
dimensions=("visit", "detector",),
dimensions=(
"visit",
"detector",
),
)
inputBrightStarStamps = cT.Input(
doc="Set of preprocessed postage stamps, each centered on a single bright star.",
name="brightStarStamps",
storageClass="BrightStarStamps",
dimensions=("visit", "detector",),
dimensions=(
"visit",
"detector",
),
)
inputExtendedPsf = cT.Input(
doc="Extended PSF model.",
Expand All @@ -63,19 +77,29 @@ class SubtractBrightStarsConnections(pipeBase.PipelineTaskConnections,
doc="Input Sky Correction to be subtracted from the calexp if ``doApplySkyCorr``=True.",
name="skyCorr",
storageClass="Background",
dimensions=("instrument", "visit", "detector",),
dimensions=(
"instrument",
"visit",
"detector",
),
)
outputExposure = cT.Output(
doc="Exposure with bright stars subtracted.",
name="{outputExposureName}_calexp",
storageClass="ExposureF",
dimensions=("visit", "detector",),
dimensions=(
"visit",
"detector",
),
)
outputBackgroundExposure = cT.Output(
doc="Exposure containing only the modelled bright stars.",
name="{outputBackgroundName}_calexp_background",
storageClass="ExposureF",
dimensions=("visit", "detector",),
dimensions=(
"visit",
"detector",
),
)

def __init__(self, *, config=None):
Expand All @@ -84,26 +108,25 @@ def __init__(self, *, config=None):
self.inputs.remove("skyCorr")


class SubtractBrightStarsConfig(pipeBase.PipelineTaskConfig,
pipelineConnections=SubtractBrightStarsConnections):
"""Configuration parameters for SubtractBrightStarsTask
"""
doWriteSubtractor = pexConfig.Field(
class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=SubtractBrightStarsConnections):
"""Configuration parameters for SubtractBrightStarsTask"""

doWriteSubtractor = Field(
dtype=bool,
doc="Should an exposure containing all bright star models be written to disk?",
default=True
default=True,
)
doWriteSubtractedExposure = pexConfig.Field(
doWriteSubtractedExposure = Field(
dtype=bool,
doc="Should an exposure with bright stars subtracted be written to disk?",
default=True
default=True,
)
magLimit = pexConfig.Field(
magLimit = Field(
dtype=float,
doc="Magnitude limit, in Gaia G; all stars brighter than this value will be subtracted",
default=18
default=18,
)
warpingKernelName = pexConfig.ChoiceField(
warpingKernelName = ChoiceField(
dtype=str,
doc="Warping kernel",
default="lanczos5",
Expand All @@ -114,37 +137,35 @@ class SubtractBrightStarsConfig(pipeBase.PipelineTaskConfig,
"lanczos5": "Lanczos kernel of order 5",
"lanczos6": "Lanczos kernel of order 6",
"lanczos7": "Lanczos kernel of order 7",
}
},
)
scalingType = pexConfig.ChoiceField(
scalingType = ChoiceField(
dtype=str,
doc="How the model should be scaled to each bright star; implemented options are "
"`annularFlux` to reuse the annular flux of each stamp, or `leastSquare` to perform "
"least square fitting on each pixel with no bad mask plane set.",
"`annularFlux` to reuse the annular flux of each stamp, or `leastSquare` to perform "
"least square fitting on each pixel with no bad mask plane set.",
default="leastSquare",
allowed={
"annularFlux": "reuse BrightStarStamp annular flux measurement",
"leastSquare": "find least square scaling factor",
}
},
)
badMaskPlanes = pexConfig.ListField(
badMaskPlanes = ListField(
dtype=str,
doc="Mask planes that, if set, lead to associated pixels not being included in the computation of "
"the scaling factor (`BAD` should always be included). Ignored if scalingType is `annularFlux`, "
"as the stamps are expected to already be normalized.",
# Note that `BAD` should always be included, as secondary detected
# sources (i.e., detected sources other than the primary source of
# interest) also get set to `BAD`.
default=('BAD', 'CR', 'CROSSTALK', 'EDGE', 'NO_DATA', 'SAT', 'SUSPECT', 'UNMASKEDNAN')
"the scaling factor (`BAD` should always be included). Ignored if scalingType is `annularFlux`, "
"as the stamps are expected to already be normalized.",
# Note that `BAD` should always be included, as secondary detected
# sources (i.e., detected sources other than the primary source of
# interest) also get set to `BAD`.
default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"),
)
doApplySkyCorr = pexConfig.Field(
dtype=bool,
doc="Apply full focal plane sky correction before extracting stars?",
default=True
doApplySkyCorr = Field(
dtype=bool, doc="Apply full focal plane sky correction before extracting stars?", default=True
)


class SubtractBrightStarsTask(pipeBase.PipelineTask, pipeBase.CmdLineTask):
class SubtractBrightStarsTask(PipelineTask):
"""Use an extended PSF model to subtract bright stars from a calibrated
exposure (i.e. at single-visit level).
Expand All @@ -153,11 +174,12 @@ class SubtractBrightStarsTask(pipeBase.PipelineTask, pipeBase.CmdLineTask):
and an extended PSF model produced by
`~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`.
"""

ConfigClass = SubtractBrightStarsConfig
_DefaultName = "subtractBrightStars"

def __init__(self, initInputs=None, *args, **kwargs):
super().__init__(self, *args, **kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Placeholders to set up Statistics if scalingType is leastSquare.
self.statsControl, self.statsFlag = None, None

Expand All @@ -166,11 +188,11 @@ def _setUpStatistics(self, exampleMask):
`leastSquare`.
"""
if self.config.scalingType == "leastSquare":
self.statsControl = afwMath.StatisticsControl()
self.statsControl = StatisticsControl()
# Set the mask planes which will be ignored.
andMask = reduce(ior, (exampleMask.getPlaneBitMask(bm) for bm in self.config.badMaskPlanes))
self.statsControl.setAndMask(andMask)
self.statsFlag = afwMath.stringToStatisticsProperty("SUM")
self.statsFlag = stringToStatisticsProperty("SUM")

def applySkyCorr(self, calexp, skyCorr):
"""Apply correction to the sky background level.
Expand All @@ -188,7 +210,7 @@ def applySkyCorr(self, calexp, skyCorr):
Full focal plane sky correction, obtained by running
`~lsst.pipe.drivers.skyCorrection.SkyCorrectionTask`.
"""
if isinstance(calexp, afwImage.Exposure):
if isinstance(calexp, Exposure):
calexp = calexp.getMaskedImage()
calexp -= skyCorr.getImage()

Expand Down Expand Up @@ -221,7 +243,7 @@ def scaleModel(self, model, star, inPlace=True, nb90Rots=0):
self._setUpStatistics(star.stamp_im.mask)
starIm = star.stamp_im.clone()
# Rotate the star postage stamp.
starIm = afwMath.rotateImageBy90(starIm, nb90Rots)
starIm = rotateImageBy90(starIm, nb90Rots)
# Reverse the prior star flux normalization ("unnormalize").
starIm *= star.annularFlux
# The estimator of the scalingFactor (f) that minimizes (Y-fX)^2
Expand All @@ -231,17 +253,33 @@ def scaleModel(self, model, star, inPlace=True, nb90Rots=0):
xx = starIm.clone()
xx.image.array = model.image.array**2
# Compute the least squares scaling factor.
xySum = afwMath.makeStatistics(xy, self.statsFlag, self.statsControl).getValue()
xxSum = afwMath.makeStatistics(xx, self.statsFlag, self.statsControl).getValue()
xySum = makeStatistics(xy, self.statsFlag, self.statsControl).getValue()
xxSum = makeStatistics(xx, self.statsFlag, self.statsControl).getValue()
scalingFactor = xySum / xxSum if xxSum else 0
else:
raise AttributeError(f'Unknown scalingType "{self.config.scalingType}"; implemented options '
'are "annularFlux" and "leastSquare".')
raise AttributeError(
f'Unknown scalingType "{self.config.scalingType}"; implemented options '
'are "annularFlux" and "leastSquare".'
)
if inPlace:
model.image *= scalingFactor
return scalingFactor

def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, skyCorr=None, dataId=None):
def runQuantum(self, butlerQC, inputRefs, outputRefs):
"""Run a single quantum of the task."""
inputs = butlerQC.get(inputRefs)
dataId = butlerQC.quantum.dataId
subtractor, _ = self.run(**inputs, dataId=dataId)
if self.config.doWriteSubtractedExposure:
outputExposure = inputs["inputExposure"].clone()
outputExposure.image -= subtractor.image
else:
outputExposure = None
outputBackgroundExposure = subtractor if self.config.doWriteSubtractor else None
output = Struct(outputExposure=outputExposure, outputBackgroundExposure=outputBackgroundExposure)
butlerQC.put(output, outputRefs)

def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, dataId, skyCorr=None):
"""Iterate over all bright stars in an exposure to scale the extended
PSF model before subtracting bright stars.
Expand All @@ -257,14 +295,13 @@ def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, skyCorr=No
inputExtendedPsf : `~lsst.pipe.tasks.extended_psf.ExtendedPsf`
Extended PSF model, produced by
`~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`.
skyCorr : `~lsst.afw.math.backgroundList.BackgroundList` or `None`,
optional
Full focal plane sky correction, obtained by running
`~lsst.pipe.drivers.skyCorrection.SkyCorrectionTask`. If
`doApplySkyCorr` is set to `True`, `skyCorr` cannot be `None`.
dataId : `dict` or `~lsst.daf.butler.DataCoordinate`
The dataId of the exposure (and detector) bright stars should be
subtracted from.
skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional
Full focal plane sky correction, obtained by running
`~lsst.pipe.drivers.skyCorrection.SkyCorrectionTask`. If
`doApplySkyCorr` is set to `True`, `skyCorr` cannot be `None`.
Returns
-------
Expand All @@ -278,19 +315,20 @@ def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, skyCorr=No
"""
inputExpBBox = inputExposure.getBBox()
if self.config.doApplySkyCorr and (skyCorr is not None):
self.log.info("Applying sky correction to exposure %s (exposure will be modified in-place).",
dataId)
self.log.info(
"Applying sky correction to exposure %s (exposure will be modified in-place).", dataId
)
self.applySkyCorr(inputExposure, skyCorr)
# Create an empty image the size of the exposure.
# TODO: DM-31085 (set mask planes).
subtractorExp = afwImage.ExposureF(bbox=inputExposure.getBBox())
subtractorExp = ExposureF(bbox=inputExposure.getBBox())
subtractor = subtractorExp.maskedImage
# Make a copy of the input model.
model = inputExtendedPsf(dataId["detector"]).clone()
modelStampSize = model.getDimensions()
inv90Rots = 4 - inputBrightStarStamps.nb90Rots
model = afwMath.rotateImageBy90(model, inv90Rots)
warpCont = afwMath.WarpingControl(self.config.warpingKernelName)
model = rotateImageBy90(model, inv90Rots)
warpCont = WarpingControl(self.config.warpingKernelName)
invImages = []
# Loop over bright stars, computing the inverse transformed and scaled
# postage stamp for each.
Expand All @@ -300,14 +338,15 @@ def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, skyCorr=No
model.setXY0(star.position)
# Create an empty destination image.
invTransform = star.archive_element.inverted()
invOrigin = geom.Point2I(invTransform.applyForward(geom.Point2D(star.position)))
bbox = geom.Box2I(corner=invOrigin, dimensions=modelStampSize)
invImage = afwImage.MaskedImageF(bbox)
invOrigin = Point2I(invTransform.applyForward(Point2D(star.position)))
bbox = Box2I(corner=invOrigin, dimensions=modelStampSize)
invImage = MaskedImageF(bbox)
# Apply inverse transform.
goodPix = afwMath.warpImage(invImage, model, invTransform, warpCont)
goodPix = warpImage(invImage, model, invTransform, warpCont)
if not goodPix:
self.log.debug(f"Warping of a model failed for star {star.gaiaId}: "
"no good pixel in output")
self.log.debug(
f"Warping of a model failed for star {star.gaiaId}: " "no good pixel in output"
)
# Scale the model.
self.scaleModel(invImage, star, inPlace=True, nb90Rots=inv90Rots)
# Replace NaNs before subtraction (note all NaN pixels have
Expand All @@ -318,17 +357,3 @@ def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, skyCorr=No
subtractor[bbox] += invImage[bbox]
invImages.append(invImage)
return subtractorExp, invImages

def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)
dataId = butlerQC.quantum.dataId
subtractor, _ = self.run(**inputs, dataId=dataId)
if self.config.doWriteSubtractedExposure:
outputExposure = inputs["inputExposure"].clone()
outputExposure.image -= subtractor.image
else:
outputExposure = None
outputBackgroundExposure = subtractor if self.config.doWriteSubtractor else None
output = pipeBase.Struct(outputExposure=outputExposure,
outputBackgroundExposure=outputBackgroundExposure)
butlerQC.put(output, outputRefs)

0 comments on commit e7e24ad

Please sign in to comment.