Skip to content

Commit

Permalink
Recover WCS for input with astrometry failures
Browse files Browse the repository at this point in the history
  • Loading branch information
cmsaunders committed Dec 15, 2023
1 parent b73e2dd commit f88401a
Showing 1 changed file with 56 additions and 18 deletions.
74 changes: 56 additions & 18 deletions python/lsst/drp/tasks/gbdesAstrometricFit.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def run(
self.log.info("Fit the WCSs")
# Set up a YAML-type string using the config variables and a sample
# visit
inputYAML = self.make_yaml(inputVisitSummaries[0])
inputYAML, mapTemplate = self.make_yaml(inputVisitSummaries[0])

# Set the verbosity level for WCSFit from the task log level.
# TODO: DM-36850, Add lsst.log to gbdes so that log messages are
Expand Down Expand Up @@ -590,7 +590,7 @@ def run(
)
self.log.info("WCS fitting done")

outputWCSs = self._make_outputs(wcsf, inputVisitSummaries, exposureInfo)
outputWCSs = self._make_outputs(wcsf, inputVisitSummaries, exposureInfo, mapTemplate=mapTemplate)
outputCatalog = wcsf.getOutputCatalog()
starCatalog = wcsf.getStarCatalog()
modelParams = self._compute_model_params(wcsf) if self.config.saveModelParams else None
Expand Down Expand Up @@ -630,6 +630,7 @@ def _prep_sky(self, inputVisitSummaries, epoch, fieldName="Field"):
detectorCorners = [
lsst.geom.SpherePoint(ra, dec, lsst.geom.degrees).getVector()
for (ra, dec) in zip(visSum["raCorners"].ravel(), visSum["decCorners"].ravel())
if (np.isfinite(ra) and (np.isfinite(dec)))
]
allDetectorCorners.extend(detectorCorners)
boundingCircle = lsst.sphgeom.ConvexPolygon.convexHull(allDetectorCorners).getBoundingCircle()
Expand Down Expand Up @@ -738,6 +739,23 @@ def _get_exposure_info(

for row in visitSummary:
detector = row["id"]

wcs = row.getWcs()
if wcs is None:
self.log.warning(
"WCS is None for visit %d, detector %d: this extension will be dropped.",
visit,
detector,
)
continue
else:
wcsRA = wcs.getSkyOrigin().getRa().asRadians()
wcsDec = wcs.getSkyOrigin().getDec().asRadians()
tangentPoint = wcsfit.Gnomonic(wcsRA, wcsDec)
mapping = wcs.getFrameDict().getMapping("PIXELS", "IWC")
gbdes_wcs = wcsfit.Wcs(wcsfit.ASTMap(mapping), tangentPoint)
wcss.append(gbdes_wcs)

if detector not in detectors:
detectors.append(detector)
detectorBounds = wcsfit.Bounds(
Expand All @@ -752,14 +770,6 @@ def _get_exposure_info(
extensionDetectors.append(detector)
extensionType.append("SCIENCE")

wcs = row.getWcs()
wcsRA = wcs.getSkyOrigin().getRa().asRadians()
wcsDec = wcs.getSkyOrigin().getDec().asRadians()
tangentPoint = wcsfit.Gnomonic(wcsRA, wcsDec)
mapping = wcs.getFrameDict().getMapping("PIXELS", "IWC")
gbdes_wcs = wcsfit.Wcs(wcsfit.ASTMap(mapping), tangentPoint)
wcss.append(gbdes_wcs)

fieldNumbers = list(np.ones(len(exposureNames), dtype=int) * fieldNumber)
instrumentNumbers = list(np.ones(len(exposureNames), dtype=int) * instrumentNumber)

Expand Down Expand Up @@ -999,9 +1009,14 @@ class `wcsfit.FoFClass`, associating them into matches as you go.
goodInds = selected.selected & goodShapes

isStar = np.ones(goodInds.sum())
extensionIndex = np.flatnonzero(
findExtension = np.flatnonzero(
(extensionInfo.visit == visit) & (extensionInfo.detector == detector)
)[0]
)
if len(findExtension) == 0:
# This extension does not have information necessary for
# fit. Skip these detections.
continue
extensionIndex = findExtension[0]
detectorIndex = extensionInfo.detectorIndex[extensionIndex]
visitIndex = extensionInfo.visitIndex[extensionIndex]

Expand Down Expand Up @@ -1087,6 +1102,8 @@ def make_yaml(self, inputVisitSummary, inputFile=None):
-------
inputYAML : `wcsfit.YAMLCollector`
YAML object containing the model description.
inputDict : `dict` [`str`, `str`]
Dictionary containing the model description.
"""
if inputFile is not None:
inputYAML = wcsfit.YAMLCollector(inputFile, "PixelMapCollection")
Expand Down Expand Up @@ -1137,7 +1154,7 @@ def make_yaml(self, inputVisitSummary, inputFile=None):
inputYAML.addInput(yaml.dump(inputDict))
inputYAML.addInput("Identity:\n Type: Identity\n")

return inputYAML
return inputYAML, inputDict

def _add_objects(self, wcsf, inputCatalogRefs, sourceIndices, extensionInfo, columns):
"""Add science sources to the wcsfit.WCSFit object.
Expand All @@ -1164,9 +1181,15 @@ def _add_objects(self, wcsf, inputCatalogRefs, sourceIndices, extensionInfo, col
for detector in detectors:
detectorSources = inputCatalog[inputCatalog["detector"] == detector]

extensionIndex = np.flatnonzero(
findExtension = np.flatnonzero(
(extensionInfo.visit == visit) & (extensionInfo.detector == detector)
)[0]
)
if len(findExtension) == 0:
# This extension does not have information necessary for
# fit. Skip these detections.
continue
extensionIndex = findExtension[0]

sourceCat = detectorSources[sourceIndices[extensionIndex]]

xCov = sourceCat["xErr"] ** 2
Expand Down Expand Up @@ -1288,7 +1311,7 @@ def _make_afw_wcs(self, mapDict, centerRA, centerDec, doNormalizePixels=False, x
outWCS = afwgeom.SkyWcs(frameDict)
return outWCS

def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo):
def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate=None):
"""Make a WCS object out of the WCS models.
Parameters
Expand Down Expand Up @@ -1335,8 +1358,23 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo):

for d, detector in enumerate(visitSummary["id"]):
mapName = f"{visit}/{detector}"

mapElements = wcsf.mapCollection.orderAtoms(f"{mapName}/base")
if mapName in wcsf.mapCollection.allMapNames():
mapElements = wcsf.mapCollection.orderAtoms(f"{mapName}/base")
else:
# This extension was not fit, but try to recover WCS
genericElements = mapTemplate["EXPOSURE/DEVICE/base"]["Elements"]
mapElements = []
instrument = visitSummary[0].getVisitInfo().getInstrumentLabel()
for component in genericElements:
elements = mapTemplate[component]["Elements"]
for element in elements:
for generic, specific in {
"BAND": instrument,
"EXPOSURE": visit,
"DEVICE": detector,
}.items():
element = element.replace(generic, str(specific))
mapElements.append(element)
mapDict = {}
for m, mapElement in enumerate(mapElements):
mapType = wcsf.mapCollection.getMapType(mapElement)
Expand Down

0 comments on commit f88401a

Please sign in to comment.