Skip to content

Commit

Permalink
Add tests and associated data files for global gbdes tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
cmsaunders committed Nov 21, 2023
1 parent bdd09ab commit aa5f7bf
Show file tree
Hide file tree
Showing 7 changed files with 171,065 additions and 62 deletions.
66 changes: 37 additions & 29 deletions python/lsst/drp/tasks/gbdesAstrometricFit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,19 @@
import lsst.pipe.base as pipeBase
import lsst.sphgeom
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from smatch.matcher import Matcher
import wcsfit
import yaml
from lsst.meas.algorithms import (
LoadReferenceObjectsConfig,
ReferenceObjectLoader,
ReferenceSourceSelectorTask,
ReferenceSourceSelectorTask
)
from lsst.meas.algorithms.sourceSelector import sourceSelectorRegistry
from lsst.skymap import BaseSkyMap
from sklearn.cluster import AgglomerativeClustering
from smatch.matcher import Matcher

__all__ = ['GbdesAstrometricFitConnections', 'GbdesAstrometricFitConfig', 'GbdesAstrometricFitTask',
'GbdesGlobalAstrometricFitTask']
'GbdesGlobalAstrometricFitConfig', 'GbdesGlobalAstrometricFitTask']


def _make_ref_covariance_matrix(
Expand Down Expand Up @@ -686,7 +685,8 @@ def _get_exposure_info(
ras.append(raDec.getRa().asRadians())
decs.append(raDec.getDec().asRadians())
if fieldRegions:
inField = np.flatnonzero([region.contains(raDec.getVector()) for region in fieldRegions.values()])
inField = np.flatnonzero([region.contains(raDec.getVector()) for region in
fieldRegions.values()])
inField2 = [r for r, region in fieldRegions.items() if region.contains(raDec.getVector())]
assert inField[0] == inField2[0]
if len(inField) != 1:
Expand Down Expand Up @@ -1087,7 +1087,7 @@ def _add_objects(self, wcsf, inputCatalogRefs, sourceIndices, extensionInfo, col
"""
for inputCatalogRef in inputCatalogRefs:
visit = inputCatalogRef.dataId["visit"]
inputCatalog = inputCatalogRef.get(parameters={"columns": columns + ["ra", "dec"]})
inputCatalog = inputCatalogRef.get(parameters={"columns": columns})
detectors = np.unique(inputCatalog["detector"])

for detector in detectors:
Expand Down Expand Up @@ -1352,11 +1352,18 @@ def getSpatialBoundsConnections(self):

class GbdesGlobalAstrometricFitConfig(GbdesAstrometricFitConfig,
                                      pipelineConnections=GbdesGlobalAstrometricFitConnections):
    """Configuration for GbdesGlobalAstrometricFitTask.

    Extends the single-field fit configuration with a knob controlling how
    visits are clustered into fields.
    """

    # Fraction of the field-of-view radius used as the single-linkage
    # clustering threshold when grouping visits into fields (see _prep_sky).
    visitOverlap = pexConfig.Field(
        dtype=float,
        default=1.0,
        doc=("The linkage distance threshold above which clusters of visits "
             "will not be merged. Calculated using the minimum distance "
             "between all visits in each set, as a fraction of the "
             "field-of-view radius")
    )


class GbdesGlobalAstrometricFitTask(GbdesAstrometricFitTask):
"""Calibrate the WCS across multiple visits of the same field using the
"""Calibrate the WCS across multiple visits and multiple fields using the
GBDES package.
"""

Expand All @@ -1376,16 +1383,16 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputRefCatRefs = [inputRefs.referenceCatalog[htm7] for htm7 in inputRefHtm7s.argsort()]
inputRefCats = np.array([inputRefCat.dataId['htm7'] for inputRefCat in inputs['referenceCatalog']])
inputs['referenceCatalog'] = [inputs['referenceCatalog'][v] for v in inputRefCats.argsort()]
inputISSTracts = np.array([isolatedStarSource.dataId['tract'] for isolatedStarSource in
inputISSTracts = np.array([isolatedStarSource.dataId['tract'] for isolatedStarSource in
inputs['isolatedStarSources']])
inputISCTracts = np.array([isolatedStarCatalog.dataId['tract'] for isolatedStarCatalog in
inputISCTracts = np.array([isolatedStarCatalog.dataId['tract'] for isolatedStarCatalog in
inputs['isolatedStarCatalogs']])
for tract in inputISCTracts:
if tract not in inputISSTracts:
raise RuntimeError(f"tract {tract} in isolated_star_cats but not isolated_star_sources")
inputs['isolatedStarSources'] = np.array([inputs['isolatedStarSources'][t] for t in
inputs['isolatedStarSources'] = np.array([inputs['isolatedStarSources'][t] for t in
inputISSTracts.argsort()])
inputs['isolatedStarCatalogs'] = np.array([inputs['isolatedStarCatalogs'][t] for t in
inputs['isolatedStarCatalogs'] = np.array([inputs['isolatedStarCatalogs'][t] for t in
inputISSTracts.argsort()])

sampleRefCat = inputs['referenceCatalog'][0].get()
Expand Down Expand Up @@ -1476,8 +1483,9 @@ def run(self, inputVisitSummaries, isolatedStarSources, isolatedStarCatalogs,
verbose = 1
else:
verbose = 2
#import ipdb; ipdb.set_trace()
# Set up the WCS-fitting class using the results of the FOF associator

# Set up the WCS-fitting class using the source matches from the
# isolated star sources plus the reference catalog.
wcsf = wcsfit.WCSFit(fields, [instrument], exposuresHelper,
extensionInfo.visitIndex, extensionInfo.detectorIndex,
inputYAML, extensionInfo.wcs, associations.sequence, associations.extn,
Expand Down Expand Up @@ -1546,8 +1554,8 @@ def _prep_sky(self, inputVisitSummaries):
mjds.append(obsMJD)

# Find groups of visits where any one of the visits overlaps another by
# at least half the field of view.
distance = np.median(radii)
# a given fraction of the field-of-view radius.
distance = self.config.visitOverlap * np.median(radii)
clustering = AgglomerativeClustering(distance_threshold=distance.asDegrees(), n_clusters=None,
linkage='single')
clusters = clustering.fit(np.array(radecs))
Expand Down Expand Up @@ -1579,8 +1587,8 @@ def _prep_sky(self, inputVisitSummaries):

return fields, fieldRegions

def _associate_from_isolated_sources(self, isolatedStarSourceRefs, isolatedStarCatalogRefs, extensionInfo,
refObjects):
def _associate_from_isolated_sources(self, isolatedStarSourceRefs, isolatedStarCatalogRefs,
extensionInfo, refObjects):
"""Match the input catalog of isolated stars with the reference catalog
and transform the combined isolated star sources and reference source
into the format needed for gbdes.
Expand Down Expand Up @@ -1609,22 +1617,24 @@ def _associate_from_isolated_sources(self, isolatedStarSourceRefs, isolatedStarC
extensions = []
object_indices = []

sourceColumns = ['x', 'y', 'xErr', 'yErr', 'ixx', 'ixy', 'iyy', 'obj_index', 'visit', 'detector']
catalogColumns = ['ra', 'dec']

sourceDict = dict([(visit, {}) for visit in np.unique(extensionInfo.visit)])
for (visit, detector) in zip(extensionInfo.visit, extensionInfo.detector):
sourceDict[visit][detector] = {'x': [], 'y': [], 'xCov': [], 'yCov': [], 'xyCov': []}

for (isolatedStarCatalogRef, isolatedStarSourceRef) in zip(isolatedStarCatalogRefs,
isolatedStarSourceRefs):
isolatedStarCatalog = isolatedStarCatalogRef.get()
isolatedStarSources = isolatedStarSourceRef.get()
isolatedStarCatalog = isolatedStarCatalogRef.get(parameters={'columns': catalogColumns})
isolatedStarSources = isolatedStarSourceRef.get(parameters={'columns': sourceColumns})
if len(isolatedStarCatalog) == 0:
continue

# Match the reference stars to the existing isolated stars, then
# insert the reference stars into the isolated star sources.
allVisits = np.copy(isolatedStarSources['visit'])
allDetectors = np.copy(isolatedStarSources['detector'])
allSourceRows = np.copy(isolatedStarSources['source_row'])
allObjectIndices = np.copy(isolatedStarSources['obj_index'])
issIndices = np.copy(isolatedStarSources.index)
for f, regionRefObjects in refObjects.items():
Expand All @@ -1633,8 +1643,8 @@ def _associate_from_isolated_sources(self, isolatedStarSourceRefs, isolatedStarC
with Matcher(isolatedStarCatalog['ra'].to_numpy(),
isolatedStarCatalog['dec'].to_numpy()) as matcher:
idx, i1, i2, d = matcher.query_radius(
np.array(regionRefObjects['ra']), np.array(regionRefObjects['dec']), self.config.matchRadius/3600.,
return_indices=True,
np.array(regionRefObjects['ra']), np.array(regionRefObjects['dec']),
self.config.matchRadius/3600., return_indices=True,
)

refSort = np.searchsorted(isolatedStarSources['obj_index'], i1)
Expand All @@ -1644,7 +1654,6 @@ def _associate_from_isolated_sources(self, isolatedStarSourceRefs, isolatedStarC

allVisits = np.insert(allVisits, refSort, refVisit)
allDetectors = np.insert(allDetectors, refSort, refDetector)
allSourceRows = np.insert(allSourceRows, refSort, i2)
allObjectIndices = np.insert(allObjectIndices, refSort, i1)
issIndices = np.insert(issIndices, refSort, i2)

Expand All @@ -1654,7 +1663,8 @@ def _associate_from_isolated_sources(self, isolatedStarSourceRefs, isolatedStarC
extensionIndex = np.flatnonzero((extensionInfo.visit == vis)
& (extensionInfo.detector == det))
if len(extensionIndex) == 0:
# This happens for runs where you are not using all the visits on a tract
# This happens for runs where you are not using all the
# visits on a tract
continue
else:
extensionIndex = extensionIndex[0]
Expand Down Expand Up @@ -1698,8 +1708,6 @@ def _add_objects(self, wcsf, sourceDict, extensionInfo):
Dictionary containing the source centroids for each visit.
extensionInfo : `lsst.pipe.base.Struct`
Struct containing properties for each extension.
columns : `list` of `str`
List of columns needed from source tables.
"""
for visit, visitSources in sourceDict.items():
# Visit numbers equal or below zero connote the reference catalog.
Expand All @@ -1712,6 +1720,6 @@ def _add_objects(self, wcsf, sourceDict, extensionInfo):
& (extensionInfo.detector == detector))[0]

d = {'x': np.array(sourceCat['x']), 'y': np.array(sourceCat['y']),
'xCov': np.array(sourceCat['xCov']), 'yCov': np.array(sourceCat['yCov']),
'xCov': np.array(sourceCat['xCov']), 'yCov': np.array(sourceCat['yCov']),
'xyCov': np.array(sourceCat['xyCov'])}
wcsf.setObjects(extensionIndex, d, 'x', 'y', ['xCov', 'yCov', 'xyCov'])

0 comments on commit aa5f7bf

Please sign in to comment.