Merge pull request #309 from lsst/tickets/DM-37955
DM-37955: Update measureApCorrTask to use robust median outlier rejection.
erykoff committed Mar 7, 2023
2 parents 84290cd + 4c50365 commit a9ef28d
Showing 2 changed files with 256 additions and 147 deletions.
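For readers skimming the diff below, the heart of the change is iterative outlier rejection around a median, using the Gaussian-scaled MAD from scipy.stats.median_abs_deviation instead of an RMS-based clip. The following is a minimal standalone sketch of that idea (hypothetical helper name clip_outliers_mad; the 4.0/4 defaults mirror the new numSigmaClip/numIter config values), not code from this commit:

import numpy as np
from scipy.stats import median_abs_deviation


def clip_outliers_mad(values, num_sigma_clip=4.0, num_iter=4):
    """Return a boolean mask of the values kept by iterative MAD clipping."""
    keep = np.ones(len(values), dtype=bool)
    model = np.median(values)  # the initial "fit" is just the median
    for _ in range(num_iter):
        resid = values[keep] - model
        # scale="normal" rescales the MAD so that it estimates a Gaussian sigma.
        sigma = median_abs_deviation(resid, scale="normal") + 1e-7
        good = np.abs(resid) <= num_sigma_clip * sigma
        # Reject the outliers found in this pass and refit the model.
        keep[np.flatnonzero(keep)[~good]] = False
        model = np.median(values[keep])
    return keep


rng = np.random.default_rng(42)
values = np.append(rng.normal(1.0, 0.01, 500), [2.0, 3.0])
mask = clip_outliers_mad(values)
print(f"kept {mask.sum()} of {len(values)} values")

In the task itself the "model" is a ChebyshevBoundedField fit to the flux ratios rather than a plain median, but the clipping logic is the same.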
273 changes: 164 additions & 109 deletions python/lsst/meas/algorithms/measureApCorr.py
@@ -23,7 +23,8 @@

__all__ = ("MeasureApCorrConfig", "MeasureApCorrTask")

import numpy
import numpy as np
from scipy.stats import median_abs_deviation

import lsst.pex.config
from lsst.afw.image import ApCorrMap
@@ -34,32 +35,39 @@
from .sourceSelector import sourceSelectorRegistry


class FluxKeys:
"""A collection of keys for a given flux measurement algorithm
"""
__slots__ = ("flux", "err", "flag", "used") # prevent accidentally adding fields
class _FluxNames:
"""A collection of flux-related names for a given flux measurement algorithm.
Parameters
----------
name : `str`
Name of flux measurement algorithm, e.g. ``base_PsfFlux``.
schema : `lsst.afw.table.Schema`
Catalog schema containing the flux field. The ``{name}_instFlux``,
``{name}_instFluxErr``, ``{name}_flag`` fields are checked for
existence, and the ``apcorr_{name}_used`` field is added.
Raises
------
KeyError if any of the instFlux, instFluxErr, or flag fields is missing.
"""
def __init__(self, name, schema):
"""Construct a FluxKeys
Parameters
----------
name : `str`
Name of flux measurement algorithm, e.g. "base_PsfFlux"
schema : `lsst.afw.table.Schema`
Catalog schema containing the flux field
read: {name}_instFlux, {name}_instFluxErr, {name}_flag
added: apcorr_{name}_used
"""
self.flux = schema.find(name + "_instFlux").key
self.err = schema.find(name + "_instFluxErr").key
self.flag = schema.find(name + "_flag").key
self.used = schema.addField("apcorr_" + name + "_used", type="Flag",
doc="set if source was used in measuring aperture correction")
self.fluxName = name + "_instFlux"
if self.fluxName not in schema:
raise KeyError("Could not find " + self.fluxName)
self.errName = name + "_instFluxErr"
if self.errName not in schema:
raise KeyError("Could not find " + self.errName)
self.flagName = name + "_flag"
if self.flagName not in schema:
raise KeyError("Could not find " + self.flagName)
self.usedName = "apcorr_" + name + "_used"
schema.addField(self.usedName, type="Flag",
doc="Set if source was used in measuring aperture correction.")


class MeasureApCorrConfig(lsst.pex.config.Config):
"""Configuration for MeasureApCorrTask
"""Configuration for MeasureApCorrTask.
"""
refFluxName = lsst.pex.config.Field(
doc="Field name prefix for the flux that other measurements should be aperture corrected to match",
@@ -68,7 +76,7 @@ class MeasureApCorrConfig(lsst.pex.config.Config):
)
sourceSelector = sourceSelectorRegistry.makeField(
doc="Selector that sets the stars that aperture corrections will be measured from.",
default="flagged",
default="science",
)
minDegreesOfFreedom = lsst.pex.config.RangeField(
doc="Minimum number of degrees of freedom (# of valid data points - # of parameters);"
@@ -79,41 +87,64 @@ class MeasureApCorrConfig(lsst.pex.config.Config):
min=1,
)
fitConfig = lsst.pex.config.ConfigField(
doc="Configuration used in fitting the aperture correction fields",
doc="Configuration used in fitting the aperture correction fields.",
dtype=ChebyshevBoundedFieldConfig,
)
numIter = lsst.pex.config.Field(
doc="Number of iterations for sigma clipping",
doc="Number of iterations for robust MAD sigma clipping.",
dtype=int,
default=4,
)
numSigmaClip = lsst.pex.config.Field(
doc="Number of standard devisations to clip at",
doc="Number of robust MAD sigma to do clipping.",
dtype=float,
default=3.0,
default=4.0,
)
allowFailure = lsst.pex.config.ListField(
doc="Allow these measurement algorithms to fail without an exception",
doc="Allow these measurement algorithms to fail without an exception.",
dtype=str,
default=[],
)

def setDefaults(self):
selector = self.sourceSelector["science"]

selector.doFluxLimit = False
selector.doFlags = True
selector.doUnresolved = True
selector.doSignalToNoise = True
selector.doIsolated = False
selector.flags.good = []
selector.flags.bad = [
"base_PixelFlags_flag_edge",
"base_PixelFlags_flag_interpolatedCenter",
"base_PixelFlags_flag_saturatedCenter",
"base_PixelFlags_flag_crCenter",
"base_PixelFlags_flag_bad",
"base_PixelFlags_flag_interpolated",
"base_PixelFlags_flag_saturated",
]
selector.signalToNoise.minimum = 200.0
selector.signalToNoise.maximum = None
selector.signalToNoise.fluxField = "base_PsfFlux_instFlux"
selector.signalToNoise.errField = "base_PsfFlux_instFluxErr"

def validate(self):
lsst.pex.config.Config.validate(self)
if self.sourceSelector.target.usesMatches:
raise lsst.pex.config.FieldValidationError(
MeasureApCorrConfig.sourceSelector,
self,
"Star selectors that require matches are not permitted")
"Star selectors that require matches are not permitted.")


class MeasureApCorrTask(Task):
r"""Task to measure aperture correction
"""Task to measure aperture correction.
"""
ConfigClass = MeasureApCorrConfig
_DefaultName = "measureApCorr"

def __init__(self, schema, **kwds):
def __init__(self, schema, **kwargs):
"""Construct a MeasureApCorrTask
For every name in lsst.meas.base.getApCorrNameSet():
@@ -122,12 +153,12 @@ def __init__(self, schema, **kwds):
- Add an entry to the self.toCorrect dict
- Otherwise silently skip the name
"""
Task.__init__(self, **kwds)
self.refFluxKeys = FluxKeys(self.config.refFluxName, schema)
Task.__init__(self, **kwargs)
self.refFluxNames = _FluxNames(self.config.refFluxName, schema)
self.toCorrect = {}  # dict of flux field name prefix: _FluxNames instance
for name in sorted(getApCorrNameSet()):
try:
self.toCorrect[name] = FluxKeys(name, schema)
self.toCorrect[name] = _FluxNames(name, schema)
except KeyError:
# if a field in the registry is missing, just ignore it.
pass
@@ -164,103 +195,127 @@ def run(self, exposure, catalog):
doPause = lsstDebug.Info(__name__).doPause

self.log.info("Measuring aperture corrections for %d flux fields", len(self.toCorrect))

# First, create a subset of the catalog that contains only selected stars
# with non-flagged reference fluxes.
subset1 = [record for record in self.sourceSelector.run(catalog, exposure=exposure).sourceCat
if (not record.get(self.refFluxKeys.flag)
and numpy.isfinite(record.get(self.refFluxKeys.flux)))]
selected = self.sourceSelector.run(catalog, exposure=exposure)

use = (
~selected.sourceCat[self.refFluxNames.flagName]
& (np.isfinite(selected.sourceCat[self.refFluxNames.fluxName]))
)
goodRefCat = selected.sourceCat[use].copy()

apCorrMap = ApCorrMap()

# Outer loop over the fields we want to correct
for name, keys in self.toCorrect.items():
fluxName = name + "_instFlux"
fluxErrName = name + "_instFluxErr"

for name, fluxNames in self.toCorrect.items():
# Create a more restricted subset with only the objects where the
# to-be-corrected flux is unflagged, finite, and positive.
fluxes = numpy.fromiter((record.get(keys.flux) for record in subset1), float)
with numpy.errstate(invalid="ignore"): # suppress NAN warnings
isGood = numpy.logical_and.reduce([
numpy.fromiter((not record.get(keys.flag) for record in subset1), bool),
numpy.isfinite(fluxes),
fluxes > 0.0,
])
subset2 = [record for record, good in zip(subset1, isGood) if good]

# Check that we have enough data points that we have at least the minimum of degrees of
# freedom specified in the config.
if len(subset2) - 1 < self.config.minDegreesOfFreedom:
fluxes = goodRefCat[fluxNames.fluxName]
with np.errstate(invalid="ignore"): # suppress NaN warnings.
isGood = (
(~goodRefCat[fluxNames.flagName])
& (np.isfinite(fluxes))
& (fluxes > 0.0)
)

# The 1 is the minimum value of ctrl.computeSize(), reached when the
# order drops to 0 in both x and y.
if (isGood.sum() - 1) < self.config.minDegreesOfFreedom:
if name in self.config.allowFailure:
self.log.warning("Unable to measure aperture correction for '%s': "
"only %d sources, but require at least %d.",
name, len(subset2), self.config.minDegreesOfFreedom + 1)
"only %d sources, but require at least %d." %
(name, isGood.sum(), self.config.minDegreesOfFreedom + 1))
continue
raise RuntimeError("Unable to measure aperture correction for required algorithm '%s': "
"only %d sources, but require at least %d." %
(name, len(subset2), self.config.minDegreesOfFreedom + 1))
else:
raise RuntimeError("Unable to measure aperture correction for required algorithm '%s': "
"only %d sources, but require at least %d." %
(name, isGood.sum(), self.config.minDegreesOfFreedom + 1))

goodCat = goodRefCat[isGood].copy()
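# The quantity fit below ("z") is the ratio of the reference flux to the
# to-be-corrected flux, evaluated at each selected star's position.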

x = goodCat['slot_Centroid_x']
y = goodCat['slot_Centroid_y']
z = goodCat[self.refFluxNames.fluxName]/goodCat[fluxNames.fluxName]
ids = goodCat['id']

# We start with an initial fit that is the median offset; this
# works well in practice.
fitValues = np.median(z)

# If we don't have enough data points to constrain the fit, reduce the order until we do
ctrl = self.config.fitConfig.makeControl()
while len(subset2) - ctrl.computeSize() < self.config.minDegreesOfFreedom:
if ctrl.orderX > 0:
ctrl.orderX -= 1
if ctrl.orderY > 0:
ctrl.orderY -= 1

# Fill numpy arrays with positions and the ratio of the reference flux to the to-correct flux
x = numpy.zeros(len(subset2), dtype=float)
y = numpy.zeros(len(subset2), dtype=float)
apCorrData = numpy.zeros(len(subset2), dtype=float)
indices = numpy.arange(len(subset2), dtype=int)
for n, record in enumerate(subset2):
x[n] = record.getX()
y[n] = record.getY()
apCorrData[n] = record.get(self.refFluxKeys.flux)/record.get(keys.flux)

for _i in range(self.config.numIter):

# Do the fit, save it in the output map
apCorrField = ChebyshevBoundedField.fit(bbox, x, y, apCorrData, ctrl)

if display:
plotApCorr(bbox, x, y, apCorrData, apCorrField, "%s, iteration %d" % (name, _i), doPause)

# Compute errors empirically, using the RMS difference between the true reference flux and the
# corrected to-be-corrected flux.
apCorrDiffs = apCorrField.evaluate(x, y)
apCorrDiffs -= apCorrData
apCorrErr = numpy.mean(apCorrDiffs**2)**0.5

# Clip bad data points
apCorrDiffLim = self.config.numSigmaClip * apCorrErr
with numpy.errstate(invalid="ignore"): # suppress NAN warning
keep = numpy.fabs(apCorrDiffs) <= apCorrDiffLim
x = x[keep]
y = y[keep]
apCorrData = apCorrData[keep]
indices = indices[keep]

# Final fit after clipping
apCorrField = ChebyshevBoundedField.fit(bbox, x, y, apCorrData, ctrl)
allBad = False
for iteration in range(self.config.numIter):
resid = z - fitValues
# We add a small (epsilon) amount of floating-point slop: for a completely
# flat residual field (as in tests) the MAD can come out vanishingly small,
# and without the slop the clip would reject sources on pure floating-point
# noise.
apCorrErr = median_abs_deviation(resid, scale="normal") + 1e-7
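# (scale="normal" rescales the MAD by ~1.4826 so that it estimates the
# sigma of a Gaussian, making numSigmaClip roughly comparable to a
# standard-deviation clip on clean data.)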
keep = np.abs(resid) <= self.config.numSigmaClip * apCorrErr

self.log.info("Aperture correction for %s: RMS %f from %d",
name, numpy.mean((apCorrField.evaluate(x, y) - apCorrData)**2)**0.5, len(indices))
self.log.debug("Removing %d sources as outliers.", len(resid) - keep.sum())

x = x[keep]
y = y[keep]
z = z[keep]
ids = ids[keep]

while (len(x) - ctrl.computeSize()) < self.config.minDegreesOfFreedom:
if ctrl.orderX > 0:
ctrl.orderX -= 1
else:
allBad = True
break
if ctrl.orderY > 0:
ctrl.orderY -= 1
else:
allBad = True
break

if allBad:
if name in self.config.allowFailure:
self.log.warning("Unable to measure aperture correction for '%s': "
"only %d sources remain, but require at least %d." %
(name, keep.sum(), self.config.minDegreesOfFreedom + 1))
break
else:
raise RuntimeError("Unable to measure aperture correction for required algorithm "
"'%s': only %d sources remain, but require at least %d." %
(name, keep.sum(), self.config.minDegreesOfFreedom + 1))

apCorrField = ChebyshevBoundedField.fit(bbox, x, y, z, ctrl)
fitValues = apCorrField.evaluate(x, y)

if allBad:
continue

self.log.info(
"Aperture correction for %s from %d stars: MAD %f, RMS %f",
name,
len(x),
median_abs_deviation(fitValues - z, scale="normal"),
np.mean((fitValues - z)**2.)**0.5,
)

if display:
plotApCorr(bbox, x, y, apCorrData, apCorrField, "%s, final" % (name,), doPause)
plotApCorr(bbox, x, y, z, apCorrField, "%s, final" % (name,), doPause)

# Record which sources were used.
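# (np.searchsorted assumes catalog['id'] is sorted; the used-flag
# assignment below relies on that ordering.)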
used = np.zeros(len(catalog), dtype=bool)
used[np.searchsorted(catalog['id'], ids)] = True
catalog[fluxNames.usedName] = used

# Save the result in the output map
# The error is constant spatially (we could imagine being
# more clever, but we're not yet sure if it's worth the effort).
# We save the errors as a 0th-order ChebyshevBoundedField
apCorrMap[fluxName] = apCorrField
apCorrErrCoefficients = numpy.array([[apCorrErr]], dtype=float)
apCorrMap[fluxErrName] = ChebyshevBoundedField(bbox, apCorrErrCoefficients)

# Record which sources were used
for i in indices:
subset2[i].set(keys.used, True)
apCorrMap[fluxNames.fluxName] = apCorrField
apCorrMap[fluxNames.errName] = ChebyshevBoundedField(
bbox,
np.array([[apCorrErr]]),
)

return Struct(
apCorrMap=apCorrMap,
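One further standalone illustration (not from the commit) of why the clip threshold is now expressed in MAD sigma rather than standard deviations: a single catastrophic outlier inflates np.std and would loosen an RMS-based clip, while the normal-scaled MAD stays near the true scatter.

import numpy as np
from scipy.stats import median_abs_deviation

rng = np.random.default_rng(12345)
resid = rng.normal(0.0, 0.05, 1000)
resid[0] = 50.0  # one catastrophic outlier

print("std estimate:", np.std(resid))  # inflated by the outlier
print("MAD sigma:", median_abs_deviation(resid, scale="normal"))  # stays near 0.05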
