Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-35292: Replace MeasurementError with flag handler #10

Merged
merged 2 commits into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 19 additions & 11 deletions python/lsst/meas/extensions/trailedSources/NaivePlugin.py
Expand Up @@ -21,16 +21,15 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

import logging
import numpy as np
import scipy.optimize as sciOpt
from scipy.special import erf

import lsst.log
from lsst.geom import Point2D
from lsst.meas.base.pluginRegistry import register
from lsst.meas.base import SingleFramePlugin, SingleFramePluginConfig
from lsst.meas.base import FlagHandler, FlagDefinitionList, SafeCentroidExtractor
from lsst.meas.base import MeasurementError

from ._trailedSources import VeresModel
from .utils import getMeasurementCutout
Expand Down Expand Up @@ -87,8 +86,10 @@ def getExecutionOrder(cls):
# VeresPlugin is run after, which requires image data.
return cls.APCORR_ORDER + 0.1

def __init__(self, config, name, schema, metadata):
super().__init__(config, name, schema, metadata)
def __init__(self, config, name, schema, metadata, logName=None):
if logName is None:
logName = __name__
super().__init__(config, name, schema, metadata, logName=logName)

# Measurement Keys
self.keyRa = schema.addField(name + "_ra", type="D", doc="Trail centroid right ascension.")
Expand Down Expand Up @@ -117,14 +118,15 @@ def __init__(self, config, name, schema, metadata):
self.keyAngleErr = schema.addField(name + "_angleErr", type="D", doc="Trail angle error.")

flagDefs = FlagDefinitionList()
flagDefs.addFailureFlag("No trailed-source measured")
self.FAILURE = flagDefs.addFailureFlag("No trailed-source measured")
self.NO_FLUX = flagDefs.add("flag_noFlux", "No suitable prior flux measurement")
self.NO_CONVERGE = flagDefs.add("flag_noConverge", "The root finder did not converge")
self.NO_SIGMA = flagDefs.add("flag_noSigma", "No PSF width (sigma)")
self.SAFE_CENTROID = flagDefs.add("flag_safeCentroid", "Fell back to safe centroid extractor")
self.flagHandler = FlagHandler.addFields(schema, name, flagDefs)

self.centriodExtractor = SafeCentroidExtractor(schema, name)
self.log = logging.getLogger(self.logName)

def measure(self, measRecord, exposure):
"""Run the Naive trailed source measurement algorithm.
Expand All @@ -147,7 +149,9 @@ def measure(self, measRecord, exposure):
yc = measRecord.get("base_SdssShape_y")
if not np.isfinite(xc) or not np.isfinite(yc):
xc, yc = self.centroidExtractor(measRecord, self.flagHandler)
raise MeasurementError(self.SAFE_CENTROID.doc, self.SAFE_CENTROID.number)
self.flagHandler.setValue(measRecord, self.SAFE_CENTROID.number)
self.flagHandler.setValue(measRecord, self.FAILURE.number)
return

ra, dec = self.computeRaDec(exposure, xc, yc)

Expand All @@ -163,14 +167,16 @@ def measure(self, measRecord, exposure):
# Measure the trail length
# Check if the second-moments are weighted
if measRecord.get("base_SdssShape_flag_unweighted"):
lsst.log.debug("Unweighed")
self.log.debug("Unweighted")
length, gradLength = self.computeLength(a2, b2)
else:
lsst.log.debug("Weighted")
self.log.debug("Weighted")
length, gradLength, results = self.findLength(a2, b2)
if not results.converged:
lsst.log.info(results.flag)
raise MeasurementError(self.NO_CONVERGE.doc, self.NO_CONVERGE.number)
self.log.info("Results not converged: %s", results.flag)
self.flagHandler.setValue(measRecord, self.NO_CONVERGE.number)
self.flagHandler.setValue(measRecord, self.FAILURE.number)
return

# Compute the angle of the trail from the x-axis
theta = 0.5 * np.arctan2(2.0 * Ixy, xmy)
Expand All @@ -197,7 +203,9 @@ def measure(self, measRecord, exposure):
if np.isfinite(measRecord.getApInstFlux()):
flux = measRecord.getApInstFlux()
else:
raise MeasurementError(self.NO_FLUX.doc, self.NO_FLUX.number)
self.flagHandler.setValue(measRecord, self.NO_FLUX.number)
self.flagHandler.setValue(measRecord, self.FAILURE.number)
return

# Propogate errors from second moments and centroid
IxxErr2, IyyErr2, IxyErr2 = np.diag(measRecord.getShapeErr())
Expand Down
15 changes: 9 additions & 6 deletions python/lsst/meas/extensions/trailedSources/VeresPlugin.py
Expand Up @@ -29,7 +29,6 @@
from lsst.meas.base.pluginRegistry import register
from lsst.meas.base import SingleFramePlugin, SingleFramePluginConfig
from lsst.meas.base import FlagHandler, FlagDefinitionList, SafeCentroidExtractor
from lsst.meas.base import MeasurementError

from ._trailedSources import VeresModel
from .NaivePlugin import SingleFrameNaiveTrailPlugin
Expand Down Expand Up @@ -95,8 +94,8 @@ def getExecutionOrder(cls):
# Make sure this always runs after NaivePlugin.
return SingleFrameNaiveTrailPlugin.getExecutionOrder() + 0.1

def __init__(self, config, name, schema, metadata):
super().__init__(config, name, schema, metadata)
def __init__(self, config, name, schema, metadata, logName=None):
super().__init__(config, name, schema, metadata, logName=logName)

self.keyXC = schema.addField(
name + "_centroid_x", type="D", doc="Trail centroid X coordinate.", units="pixel")
Expand All @@ -112,7 +111,7 @@ def __init__(self, config, name, schema, metadata):
self.keyRChiSq = schema.addField(name + "_rChiSq", type="D", doc="Reduced chi-squared of fit")

flagDefs = FlagDefinitionList()
flagDefs.addFailureFlag("No trailed-sources measured")
self.FAILURE = flagDefs.addFailureFlag("No trailed-sources measured")
self.NON_CONVERGE = flagDefs.add("flag_nonConvergence", "Optimizer did not converge")
self.NO_NAIVE = flagDefs.add("flag_noNaive", "Naive measurement contains NaNs")
self.flagHandler = FlagHandler.addFields(schema, name, flagDefs)
Expand Down Expand Up @@ -141,7 +140,9 @@ def measure(self, measRecord, exposure):
length = measRecord.get("ext_trailedSources_Naive_length")
theta = measRecord.get("ext_trailedSources_Naive_angle")
if not np.isfinite(flux) or not np.isfinite(length) or not np.isfinite(theta):
raise MeasurementError(self.NO_NAIVE.doc, self.NO_NAIVE.number)
self.flagHandler.setValue(measRecord, self.NO_NAIVE.number)
self.flagHandler.setValue(measRecord, self.FAILURE.number)
return

# Get exposure cutout
# sigma = exposure.getPsf().getSigma()
Expand All @@ -158,7 +159,9 @@ def measure(self, measRecord, exposure):

# Check if optimizer converged
if not results.success:
raise MeasurementError(self.NON_CONVERGE.doc, self.NON_CONVERGE.number)
self.flagHandler.setValue(measRecord, self.NON_CONVERGE.number)
self.flagHandler.setValue(measRecord, self.FAILURE.number)
return

# Calculate end points and reduced chi-squared
xc_fit, yc_fit, flux_fit, length_fit, theta_fit = results.x
Expand Down
1 change: 0 additions & 1 deletion tests/test_trailedSources.py
Expand Up @@ -33,7 +33,6 @@
from lsst.meas.extensions.trailedSources.utils import getMeasurementCutout
from lsst.utils.tests import classParameters

import lsst.log

# Trailed-source length, angle, and centroid.
rng = np.random.default_rng(432)
Expand Down