Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-42981: Convert AstrometryTask to new exception handling system #190

Merged
merged 7 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions python/lsst/meas/astrom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .ref_match import *
from .astrometry import *
from .approximateWcs import *
from .exceptions import *
from .match_probabilistic_task import *
from .matcher_probabilistic import *
from .matchPessimisticB import *
Expand Down
187 changes: 80 additions & 107 deletions python/lsst/meas/astrom/astrometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod
from . import exceptions
from .ref_match import RefMatchTask, RefMatchConfig
from .fitTanSipWcs import FitTanSipWcsTask
from .display import displayAstrometry
Expand Down Expand Up @@ -61,11 +62,11 @@ class AstrometryConfig(RefMatchConfig):
min=0,
)
maxMeanDistanceArcsec = pexConfig.RangeField(
doc="Maximum mean on-sky distance (in arcsec) between matched source and rerference "
"objects post-fit. A mean distance greater than this threshold raises a TaskError "
"and the WCS fit is considered a failure. The default is set to the maximum tolerated "
"by the external global calibration (e.g. jointcal) step for conceivable recovery. "
"Appropriate value will be dataset and workflow dependent.",
doc="Maximum mean on-sky distance (in arcsec) between matched source and reference "
"objects post-fit. A mean distance greater than this threshold raises BadAstrometryFit "
"and the WCS fit is considered a failure. The default is set to the maximum tolerated "
"by the external global calibration (e.g. jointcal) step for conceivable recovery; "
"the appropriate value will be dataset and workflow dependent.",
dtype=float,
default=0.5,
min=0,
Expand Down Expand Up @@ -206,7 +207,7 @@ def solve(self, exposure, sourceCat):

Raises
------
TaskError
BadAstrometryFit
If the measured mean on-sky distance between the matched source and
reference objects is greater than
``self.config.maxMeanDistanceArcsec``.
Expand All @@ -220,134 +221,106 @@ def solve(self, exposure, sourceCat):
import lsstDebug
debug = lsstDebug.Info(__name__)

expMd = self._getExposureMetadata(exposure)
epoch = exposure.visitInfo.date.toAstropy()

sourceSelection = self.sourceSelector.run(sourceCat)

self.log.info("Purged %d sources, leaving %d good sources",
len(sourceCat) - len(sourceSelection.sourceCat),
len(sourceSelection.sourceCat))

loadRes = self.refObjLoader.loadPixelBox(
bbox=expMd.bbox,
wcs=expMd.wcs,
filterName=expMd.filterName,
epoch=expMd.epoch,
loadResult = self.refObjLoader.loadPixelBox(
bbox=exposure.getBBox(),
wcs=exposure.wcs,
filterName=exposure.filter.bandLabel,
epoch=epoch,
)

refSelection = self.referenceSelector.run(loadRes.refCat)

matchMeta = self.refObjLoader.getMetadataBox(
bbox=expMd.bbox,
wcs=expMd.wcs,
filterName=expMd.filterName,
epoch=expMd.epoch,
)
refSelection = self.referenceSelector.run(loadResult.refCat)

if debug.display:
frame = int(debug.frame)
displayAstrometry(
refCat=refSelection.sourceCat,
sourceCat=sourceSelection.sourceCat,
exposure=exposure,
bbox=expMd.bbox,
bbox=exposure.getBBox(),
frame=frame,
title="Reference catalog",
)

res = None
wcs = expMd.wcs
match_tolerance = None
fitFailed = False
for i in range(self.config.maxIter):
if not fitFailed:
iterNum = i + 1
try:
tryRes = self._matchAndFitWcs(
refCat=refSelection.sourceCat,
sourceCat=sourceCat,
goodSourceCat=sourceSelection.sourceCat,
refFluxField=loadRes.fluxField,
bbox=expMd.bbox,
wcs=wcs,
exposure=exposure,
match_tolerance=match_tolerance,
)
except Exception as e:
# If we have had a succeessful iteration then use that;
# otherwise fail.
if i > 0:
self.log.info("Fit WCS iter %d failed; using previous iteration: %s", iterNum, e)
iterNum -= 1
break
else:
self.log.info("Fit WCS iter %d failed: %s" % (iterNum, e))
fitFailed = True

if not fitFailed:
match_tolerance = tryRes.match_tolerance
tryMatchDist = self._computeMatchStatsOnSky(tryRes.matches)
self.log.debug(
"Match and fit WCS iteration %d: found %d matches with on-sky distance mean and "
"scatter = %0.3f +- %0.3f arcsec; max match distance = %0.3f arcsec",
iterNum, len(tryRes.matches), tryMatchDist.distMean.asArcseconds(),
tryMatchDist.distStdDev.asArcseconds(), tryMatchDist.maxMatchDist.asArcseconds())

maxMatchDist = tryMatchDist.maxMatchDist
res = tryRes
wcs = res.wcs
if maxMatchDist.asArcseconds() < self.config.minMatchDistanceArcSec:
self.log.debug(
"Max match distance = %0.3f arcsec < %0.3f = config.minMatchDistanceArcSec; "
"that's good enough",
maxMatchDist.asArcseconds(), self.config.minMatchDistanceArcSec)
break
match_tolerance.maxMatchDist = maxMatchDist

if not fitFailed:
self.log.info("Matched and fit WCS in %d iterations; "
"found %d matches with mean and scatter = %0.3f +- %0.3f arcsec" %
(iterNum, len(tryRes.matches), tryMatchDist.distMean.asArcseconds(),
tryMatchDist.distStdDev.asArcseconds()))
if tryMatchDist.distMean.asArcseconds() > self.config.maxMeanDistanceArcsec:
self.log.info("Assigning as a fit failure: mean on-sky distance = %0.3f arcsec > %0.3f "
"(maxMeanDistanceArcsec)" % (tryMatchDist.distMean.asArcseconds(),
self.config.maxMeanDistanceArcsec))
fitFailed = True

if fitFailed:
self.log.warning("WCS fit failed. Setting exposure's WCS to None and coord_ra & coord_dec "
"cols in sourceCat to nan.")
sourceCat["coord_ra"] = np.nan
sourceCat["coord_dec"] = np.nan
exposure.setWcs(None)
matches = None
scatterOnSky = None
else:
for m in res.matches:
if self.usedKey:
m.second.set(self.usedKey, True)
exposure.setWcs(res.wcs)
matches = res.matches
scatterOnSky = res.scatterOnSky
result = pipeBase.Struct(matchTolerance=None)
maxMatchDistance = np.inf
i = 0
while (maxMatchDistance > self.config.minMatchDistanceArcSec and i < self.config.maxIter):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised we don't consider it an error (or at least some sort of qualified success) if we hit the maximum number of iterations without satisfying the max distance criteria. If we add that, then I think the old form of the loop may be closer to what we want; it's a very natural for-else:

for i in range(self.config.maxIter):
    ... # do fit
    if maxMatchDistance <= self.config.minMatchDistanceArcSec:
        # unqualified success!
        break
else:
    self.log.warning("Maximum iterations exceeded...")

Of course, for-else isn't exactly super familiar to people, and that makes me wonder if wrapping the whole loop up in a separate method that can return or raise an already-updated exception would be even better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like for-else in general, but I'm not sure it's right here. I read minMatchDistanceArcSec as "definitely stop here, that's plenty good", and not a quality threshold. But also, on DM-44012, I'm going to look at whether this outer loop is even necessary at all: I believe that with the pessimistic matcher and the affine fitter, there's no need for this outer loop (and the ticket it was added on, DM-2755, has an argument between Paul and Russell about whether we should even do that loop. We now have a much better matcher, so the original rational mostly doesn't exist.

AstrometryTask has grown organically without an overarching vision, and I think there are too many different kinds of stop/quality criteria. That's a bigger question than this ticket, though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, fine to leave this as-is.

i += 1
try:
result = self._matchAndFitWcs(
refCat=refSelection.sourceCat,
sourceCat=sourceCat,
goodSourceCat=sourceSelection.sourceCat,
refFluxField=loadResult.fluxField,
bbox=exposure.getBBox(),
wcs=exposure.wcs,
exposure=exposure,
matchTolerance=result.matchTolerance,
)
exposure.setWcs(result.wcs)
except exceptions.AstrometryError as e:
e._metadata['iterations'] = i
sourceCat["coord_ra"] = np.nan
sourceCat["coord_dec"] = np.nan
exposure.setWcs(None)
self.log.error("Failure fitting astrometry. %s: %s", type(e).__name__, e)
raise

result.stats = self._computeMatchStatsOnSky(result.matches)
maxMatchDistance = result.stats.maxMatchDist.asArcseconds()
distMean = result.stats.distMean.asArcseconds()
distStdDev = result.stats.distMean.asArcseconds()
self.log.info("Astrometric fit iteration %d: found %d matches with mean separation "
"= %0.3f +- %0.3f arcsec; max match distance = %0.3f arcsec.",
i, len(result.matches), distMean, distStdDev, maxMatchDistance)

# If fitter converged, record the scatter in the exposure metadata
# even if the fit was deemed a failure according to the value of
# the maxMeanDistanceArcsec config.
if res is not None:
md = exposure.getMetadata()
md['SFM_ASTROM_OFFSET_MEAN'] = tryMatchDist.distMean.asArcseconds()
md['SFM_ASTROM_OFFSET_STD'] = tryMatchDist.distStdDev.asArcseconds()
md = exposure.getMetadata()
md['SFM_ASTROM_OFFSET_MEAN'] = distMean
md['SFM_ASTROM_OFFSET_STD'] = distStdDev

# Poor quality fits are a failure.
if distMean > self.config.maxMeanDistanceArcsec:
exception = exceptions.BadAstrometryFit(nMatches=len(result.matches), iterations=i,
distMean=distMean,
maxMeanDist=self.config.maxMeanDistanceArcsec,
distMedian=result.scatterOnSky.asArcseconds())
exposure.setWcs(None)
sourceCat["coord_ra"] = np.nan
sourceCat["coord_dec"] = np.nan
self.log.error(exception)
raise exception

if self.usedKey:
for m in result.matches:
m.second.set(self.usedKey, True)

matchMeta = self.refObjLoader.getMetadataBox(
bbox=exposure.getBBox(),
wcs=exposure.wcs,
filterName=exposure.filter.bandLabel,
epoch=epoch,
)

return pipeBase.Struct(
refCat=refSelection.sourceCat,
matches=matches,
scatterOnSky=scatterOnSky,
matches=result.matches,
scatterOnSky=result.scatterOnSky,
matchMeta=matchMeta,
)

@timeMethod
def _matchAndFitWcs(self, refCat, sourceCat, goodSourceCat, refFluxField, bbox, wcs, match_tolerance,
def _matchAndFitWcs(self, refCat, sourceCat, goodSourceCat, refFluxField, bbox, wcs, matchTolerance,
exposure=None):
"""Match sources to reference objects and fit a WCS.

Expand All @@ -365,11 +338,11 @@ def _matchAndFitWcs(self, refCat, sourceCat, goodSourceCat, refFluxField, bbox,
bounding box of exposure
wcs : `lsst.afw.geom.SkyWcs`
initial guess for WCS of exposure
match_tolerance : `lsst.meas.astrom.MatchTolerance`
matchTolerance : `lsst.meas.astrom.MatchTolerance`
a MatchTolerance object (or None) specifying
internal tolerances to the matcher. See the MatchTolerance
definition in the respective matcher for the class definition.
exposure : `lsst.afw.image.Exposure`
exposure : `lsst.afw.image.Exposure`, optional
exposure whose WCS is to be fit, or None; used only for the debug
display.

Expand All @@ -395,7 +368,7 @@ def _matchAndFitWcs(self, refCat, sourceCat, goodSourceCat, refFluxField, bbox,
wcs=wcs,
sourceFluxField=sourceFluxField,
refFluxField=refFluxField,
match_tolerance=match_tolerance,
matchTolerance=matchTolerance,
)
self.log.debug("Found %s matches", len(matchRes.matches))
if debug.display:
Expand Down Expand Up @@ -442,7 +415,7 @@ def _matchAndFitWcs(self, refCat, sourceCat, goodSourceCat, refFluxField, bbox,
matches=matches,
wcs=fitWcs,
scatterOnSky=scatterOnSky,
match_tolerance=matchRes.match_tolerance,
matchTolerance=matchRes.matchTolerance,
)

def _removeMagnitudeOutliers(self, sourceFluxField, refFluxField, matchesIn):
Expand Down
79 changes: 79 additions & 0 deletions python/lsst/meas/astrom/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# This file is part of meas_astrom.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["AstrometryError", "AstrometryFitFailure", "BadAstrometryFit", "MatcherFailure"]

import lsst.pipe.base


class AstrometryError(lsst.pipe.base.AlgorithmError):
"""Parent class for failures in astrometric fitting.

Parameters
----------
msg : `str`
Informative message about the nature of the error.
**kwargs
All other arguments are added to a ``_metadata`` attribute, which is
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels a little strange to have the private attribute's name included in the docs like this. Is the name not actually an implementation detail?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm torn on that: I directly modify _metadata in AstrometryTask to add more information to exceptions that came from the matcher or fitter. I could name it something non-"private", but metadata (the read-only property) is the most obvious, but taken, choice.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I see that it's effectively "protected" rather than just "private", and those are always in a nebulous state in Python already. Fine to leave it as-is.

used to generate the metadata property for Task annotation.
"""
def __init__(self, msg, **kwargs):
self.msg = msg
self._metadata = kwargs
super().__init__(msg, kwargs)

def __str__(self):
# Exception doesn't handle **kwargs, so we need a custom str.
return f"{self.msg}: {self.metadata}"

@property
def metadata(self):
for key, value in self._metadata.items():
if not (isinstance(value, int) or isinstance(value, float) or isinstance(value, str)):
raise TypeError(f"{key} is of type {type(value)}, but only (int, float, str) are allowed.")
return self._metadata


class BadAstrometryFit(AstrometryError):
"""Raised if the quality of the astrometric fit is worse than some
threshold.

Parameters
----------
distMean : `float`
Mean on-sky separation of matched sources, in arcseconds.
distMedian : `float`
Median on-sky separation of matched sources, in arcseconds.
"""
def __init__(self, distMean, maxMeanDist, distMedian, **kwargs):
msg = f'Poor quality astrometric fit, {distMean}" > {maxMeanDist}"'
super().__init__(msg, **kwargs)
self._metadata["distMean"] = distMean
self._metadata["maxMeanDist"] = distMean
self._metadata["distMedian"] = distMedian


class AstrometryFitFailure(AstrometryError):
"""Raised if the astrometry fitter fails."""


class MatcherFailure(AstrometryError):
"""Raised if the matcher fails."""
3 changes: 2 additions & 1 deletion python/lsst/meas/astrom/fitSipDistortion.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
SipForwardTransform, SipReverseTransform,
makeMatchStatisticsInRadians, makeWcs)

from . import exceptions
from .setMatchDistance import setMatchDistance


Expand Down Expand Up @@ -252,7 +253,7 @@ def fitWcs(self, matches, initWcs, bbox=None, refCat=None, sourceCat=None, expos
scatterOnSky = stats.getValue()*lsst.geom.radians

if scatterOnSky.asArcseconds() > self.config.maxScatterArcsec:
raise lsst.pipe.base.TaskError(
raise exceptions.AstrometryFitFailure(
"Fit failed: median scatter on sky = %0.3f arcsec > %0.3f config.maxScatterArcsec" %
(scatterOnSky.asArcseconds(), self.config.maxScatterArcsec))

Expand Down