Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 45 additions & 89 deletions python/lsst/analysis/tools/tasks/associatedSourcesTractAnalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,13 @@
import astropy.units as u
import lsst.pex.config as pexConfig
import numpy as np
import pandas as pd
from astropy.table import Table, join, vstack
from astropy.table import join, vstack
from lsst.daf.butler import DatasetProvenance
from lsst.drp.tasks.gbdesAstrometricFit import calculate_apparent_motion
from lsst.geom import Box2D
from lsst.pipe.base import NoWorkFound
from lsst.pipe.base import connectionTypes as ct
from lsst.skymap import BaseSkyMap
from smatch import Matcher

from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask

Expand Down Expand Up @@ -87,14 +85,12 @@ class AssociatedSourcesTractAnalysisConnections(
dimensions=("instrument",),
isCalibration=True,
)

astrometricCorrectionCatalogs = ct.Input(
doc="Catalog containing proper motion and parallax.",
name="gbdesAstrometricFit_starCatalog",
storageClass="ArrowNumpyDict",
dimensions=("instrument", "skymap", "tract", "physical_filter"),
multiple=True,
astrometricCorrectionCatalog = ct.Input(
doc="Catalog with proper motion and parallax information.",
name="isolated_star_stellar_motions",
storageClass="ArrowAstropy",
deferLoad=True,
dimensions=("instrument", "skymap", "tract"),
)

visitTable = ct.Input(
Expand All @@ -108,7 +104,7 @@ def __init__(self, *, config=None):
super().__init__(config=config)

if not config.applyAstrometricCorrections:
self.inputs.remove("astrometricCorrectionCatalogs")
self.inputs.remove("astrometricCorrectionCatalog")
self.inputs.remove("visitTable")


Expand All @@ -120,23 +116,16 @@ class AssociatedSourcesTractAnalysisConfig(
default=True,
doc="Apply proper motion and parallax corrections to source positions.",
)
matchingRadius = pexConfig.Field(
dtype=float,
default=0.2,
doc=(
"Radius in mas with which to match the mean positions of the sources with the positions in the"
" astrometricCorrectionCatalogs."
),
)
astrometricCorrectionParameters = pexConfig.DictField(
keytype=str,
itemtype=str,
default={
"ra": "ra",
"dec": "dec",
"pmRA": "pm_ra",
"pmDec": "pm_dec",
"pmRA": "raPM",
"pmDec": "decPM",
"parallax": "parallax",
"isolated_star_id": "isolated_star_id",
},
doc="Column names for position and motion parameters in the astrometric correction catalogs.",
)
Expand All @@ -162,7 +151,7 @@ def callback(self, inputs, dataId):
inputs["sourceCatalogs"],
inputs["associatedSources"],
inputs["associatedSourceIds"],
inputs["astrometricCorrectionCatalogs"],
inputs["astrometricCorrectionCatalog"],
inputs["visitTable"],
)

Expand All @@ -173,7 +162,7 @@ def prepareAssociatedSources(
sourceCatalogs,
associatedSources,
associatedSourceIds,
astrometricCorrectionCatalogs=None,
astrometricCorrectionCatalog=None,
visitTable=None,
):
"""Concatenate source catalogs and join on associated source IDs."""
Expand All @@ -194,8 +183,8 @@ def prepareAssociatedSources(
sourceCatalogStack = vstack(sourceCatalogs, join_type="exact")
dataJoined = join(sourceCatalogStack, associatedSources, keys="sourceId", join_type="inner")

if astrometricCorrectionCatalogs is not None:
self.applyAstrometricCorrections(dataJoined, astrometricCorrectionCatalogs, visitTable)
if astrometricCorrectionCatalog is not None:
self.applyAstrometricCorrections(dataJoined, astrometricCorrectionCatalog, visitTable)

# Determine which sources are contained in tract
ra = np.radians(dataJoined["coord_ra"])
Expand All @@ -211,16 +200,16 @@ def prepareAssociatedSources(

return dataFiltered

def applyAstrometricCorrections(self, dataJoined, astrometricCorrectionCatalogs, visitTable):
def applyAstrometricCorrections(self, dataJoined, astrometricCorrectionCatalog, visitTable):
"""Use proper motion/parallax catalogs to shift positions to median
epoch of the visits.

Parameters
----------
dataJoined : `astropy.table.Table`
Table containing source positions, which will be modified in place.
astrometricCorrectionCatalogs: `dict` [`pd.DataFrame`]
Dictionary keyed by band with proper motion and parallax catalogs.
astrometricCorrectionCatalog : `astropy.table.Table`
Proper motion and parallax catalog.
visitTable : `pd.DataFrame`
Table containing the MJDs of the visits.
"""
Expand All @@ -229,63 +218,33 @@ def applyAstrometricCorrections(self, dataJoined, astrometricCorrectionCatalogs,
# the table was written originally as a DataFrame or something else
# Parquet-friendly.
visitTable.set_index("visitId", inplace=True)
for band in np.unique(dataJoined["band"]):
bandInd = dataJoined["band"] == band
bandSources = dataJoined[bandInd]
# Add key for sorting below.
bandSources["__index__"] = np.arange(len(bandSources))
bandSourcesDf = bandSources.to_pandas()
meanRAs = bandSourcesDf.groupby("obj_index")["coord_ra"].aggregate("mean")
meanDecs = bandSourcesDf.groupby("obj_index")["coord_dec"].aggregate("mean")

bandPMs = astrometricCorrectionCatalogs[band]
with Matcher(meanRAs, meanDecs) as m:
idx, i1, i2, d = m.query_radius(
bandPMs[self.config.astrometricCorrectionParameters["ra"]],
bandPMs[self.config.astrometricCorrectionParameters["dec"]],
(self.config.matchingRadius * u.mas).to(u.degree),
return_indices=True,
)

catRAs = np.zeros_like(meanRAs)
catDecs = np.zeros_like(meanRAs)
pmRAs = np.zeros_like(meanRAs)
pmDecs = np.zeros_like(meanRAs)
parallaxes = np.zeros(len(meanRAs))
catRAs[i1] = bandPMs[self.config.astrometricCorrectionParameters["ra"]][i2]
catDecs[i1] = bandPMs[self.config.astrometricCorrectionParameters["dec"]][i2]
pmRAs[i1] = bandPMs[self.config.astrometricCorrectionParameters["pmRA"]][i2]
pmDecs[i1] = bandPMs[self.config.astrometricCorrectionParameters["pmDec"]][i2]
parallaxes[i1] = bandPMs[self.config.astrometricCorrectionParameters["parallax"]][i2]

pmDf = Table(
{
"ra": catRAs * u.degree,
"dec": catDecs * u.degree,
"pmRA": pmRAs * u.mas / u.yr,
"pmDec": pmDecs * u.mas / u.yr,
"parallax": parallaxes * u.mas,
"obj_index": meanRAs.index,
}
)

dataWithPM = join(bandSources, pmDf, keys="obj_index", join_type="left")
# Get the stellar motion catalog into the right format:
for key, value in self.config.astrometricCorrectionParameters.items():
astrometricCorrectionCatalog.rename_column(value, key)
astrometricCorrectionCatalog["ra"] *= u.degree
astrometricCorrectionCatalog["dec"] *= u.degree
astrometricCorrectionCatalog["pmRA"] *= u.mas / u.yr
astrometricCorrectionCatalog["pmDec"] *= u.mas / u.yr
astrometricCorrectionCatalog["parallax"] *= u.mas

dataWithPM = join(
dataJoined,
astrometricCorrectionCatalog,
keys="isolated_star_id",
join_type="left",
keep_order=True,
)

visits = bandSourcesDf["visit"].unique()
mjds = [visitTable.loc[visit]["expMidptMJD"] for visit in visits]
mjdTable = Table(
[astropy.time.Time(mjds, format="mjd", scale="tai"), visits], names=["MJD", "visit"]
)
dataWithMJD = join(dataWithPM, mjdTable, keys="visit", join_type="left")
# After astropy 7.0, it should be possible to use "keep_order=True"
# in the join and avoid sorting.
dataWithMJD.sort("__index__")
medianMJD = astropy.time.Time(np.median(mjds), format="mjd", scale="tai")
mjds = visitTable.loc[dataWithPM["visit"]]["expMidptMJD"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably outside the scope of this ticket but can we one day get the visitTable to be astropy and remove the uses of pandas.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would do it if I had finished 12 hours earlier...I think we can get rid of the rest of the pandas pretty easily on another ticket now though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know that feeling, a future ticket would be great, thank you.

times = astropy.time.Time(mjds, format="mjd", scale="tai")
dataWithPM["MJD"] = times
medianMJD = astropy.time.Time(np.median(mjds), format="mjd", scale="tai")

raCorrection, decCorrection = calculate_apparent_motion(dataWithMJD, medianMJD)
raCorrection, decCorrection = calculate_apparent_motion(dataWithPM, medianMJD)

dataJoined["coord_ra"][bandInd] = dataWithMJD["coord_ra"] - raCorrection.value
dataJoined["coord_dec"][bandInd] = dataWithMJD["coord_dec"] - decCorrection.value
dataJoined["coord_ra"] = dataWithPM["coord_ra"] - raCorrection.value
dataJoined["coord_dec"] = dataWithPM["coord_dec"] - decCorrection.value

def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)
Expand All @@ -303,15 +262,12 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs["sourceCatalogs"] = sourceCatalogs

if self.config.applyAstrometricCorrections:
astrometricCorrections = {}
for pmCatRef in inputs["astrometricCorrectionCatalogs"]:
pmCat = pmCatRef.get(
parameters={"columns": self.config.astrometricCorrectionParameters.values()}
)
astrometricCorrections[pmCatRef.dataId["band"]] = pd.DataFrame(pmCat)
inputs["astrometricCorrectionCatalogs"] = astrometricCorrections
astrometricCorrections = inputs["astrometricCorrectionCatalog"].get(
parameters={"columns": self.config.astrometricCorrectionParameters.values()}
)
inputs["astrometricCorrectionCatalog"] = astrometricCorrections
else:
inputs["astrometricCorrectionCatalogs"] = None
inputs["astrometricCorrectionCatalog"] = None
inputs["visitTable"] = None

dataId = butlerQC.quantum.dataId
Expand Down
Loading
Loading