Skip to content

Commit

Permalink
Merge branch 'tickets/DM-43332'
Browse files Browse the repository at this point in the history
  • Loading branch information
taranu committed May 16, 2024
2 parents 991b906 + 7301cf6 commit 4be5004
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 82 deletions.
201 changes: 121 additions & 80 deletions python/lsst/meas/astrom/matcher_probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree
from smatch.matcher import Matcher
import time
from typing import Callable, Set

Expand Down Expand Up @@ -122,63 +123,57 @@ class ComparableCatalog:
class ConvertCatalogCoordinatesConfig(pexConfig.Config):
"""Configuration for the MatchProbabilistic matcher."""

column_ref_coord1 = pexConfig.Field(
dtype=str,
column_ref_coord1 = pexConfig.Field[str](
default='ra',
doc='The reference table column for the first spatial coordinate (usually x or ra).',
)
column_ref_coord2 = pexConfig.Field(
dtype=str,
column_ref_coord2 = pexConfig.Field[str](
default='dec',
doc='The reference table column for the second spatial coordinate (usually y or dec).'
'Units must match column_ref_coord1.',
)
column_target_coord1 = pexConfig.Field(
dtype=str,
column_target_coord1 = pexConfig.Field[str](
default='coord_ra',
doc='The target table column for the first spatial coordinate (usually x or ra).'
'Units must match column_ref_coord1.',
)
column_target_coord2 = pexConfig.Field(
dtype=str,
column_target_coord2 = pexConfig.Field[str](
default='coord_dec',
doc='The target table column for the second spatial coordinate (usually y or dec).'
'Units must match column_ref_coord2.',
)
coords_spherical = pexConfig.Field(
dtype=bool,
coords_spherical = pexConfig.Field[bool](
default=True,
doc='Whether column_*_coord[12] are spherical coordinates (ra/dec) or not (pixel x/y)',
doc='Whether column_*_coord[12] are spherical coordinates (ra/dec) or not (pixel x/y).',
)
coords_ref_factor = pexConfig.Field(
dtype=float,
coords_ref_factor = pexConfig.Field[float](
default=1.0,
doc='Multiplicative factor for reference catalog coordinates.'
'If coords_spherical is true, this must be the number of degrees per unit increment of '
'column_ref_coord[12]. Otherwise, it must convert the coordinate to the same units'
' as the target coordinates.',
)
coords_target_factor = pexConfig.Field(
dtype=float,
coords_target_factor = pexConfig.Field[float](
default=1.0,
doc='Multiplicative factor for target catalog coordinates.'
'If coords_spherical is true, this must be the number of degrees per unit increment of '
'column_target_coord[12]. Otherwise, it must convert the coordinate to the same units'
' as the reference coordinates.',
)
coords_ref_to_convert = pexConfig.DictField(
coords_ref_to_convert = pexConfig.DictField[str, str](
default=None,
optional=True,
keytype=str,
itemtype=str,
dictCheck=lambda x: len(x) == 2,
doc='Dict mapping sky coordinate columns to be converted to pixel columns',
doc='Dict mapping sky coordinate columns to be converted to pixel columns.',
)
mag_zeropoint_ref = pexConfig.Field(
dtype=float,
mag_zeropoint_ref = pexConfig.Field[float](
default=31.4,
doc='Magnitude zeropoint for reference catalog.',
)
return_converted_coords = pexConfig.Field[float](
default=True,
doc='Whether to return converted coordinates for matching or only write them.',
)

def format_catalogs(
self,
Expand All @@ -187,7 +182,6 @@ def format_catalogs(
select_ref: np.array = None,
select_target: np.array = None,
radec_to_xy_func: Callable = None,
return_converted_columns: bool = False,
**kwargs,
):
"""Format matched catalogs that may require coordinate conversions.
Expand All @@ -206,10 +200,8 @@ def format_catalogs(
Function taking equal-length ra, dec arrays and returning an ndarray of
- ``x``: current parameter (`float`).
- ``extra_args``: additional arguments (`dict`).
return_converted_columns : `bool`
Whether to return converted columns in the `coord1` and `coord2`
attributes, rather than keep the original values.
kwargs
Additional keyword arguments to pass to radec_to_xy_func.
Returns
-------
Expand Down Expand Up @@ -246,9 +238,10 @@ def format_catalogs(
for idx_coord, column_out in enumerate(self.coords_ref_to_convert.values()):
coord = np.array([xy[idx_coord] for xy in xy_ref])
catalog[column_out] = coord
if convert_ref and return_converted_columns:
if convert_ref:
column1, column2 = self.coords_ref_to_convert.values()
coord1, coord2 = catalog[column1], catalog[column2]
if self.return_converted_coords:
coord1, coord2 = catalog[column1], catalog[column2]
if isinstance(coord1, pd.Series):
coord1 = coord1.values
if isinstance(coord2, pd.Series):
Expand Down Expand Up @@ -388,10 +381,11 @@ def columns_in_target(self) -> Set[str]:
default=10,
optional=True,
doc='Maximum number of spatial matches to consider (in ascending distance order).',
check=lambda x: x >= 1,
)
match_n_finite_min = pexConfig.Field(
dtype=int,
default=3,
default=2,
optional=True,
doc='Minimum number of columns with a finite value to measure match likelihood',
)
Expand Down Expand Up @@ -450,14 +444,14 @@ def __init__(
self.config = config

def match(
self,
catalog_ref: pd.DataFrame,
catalog_target: pd.DataFrame,
select_ref: np.array = None,
select_target: np.array = None,
logger: logging.Logger = None,
logging_n_rows: int = None,
**kwargs
self,
catalog_ref: pd.DataFrame,
catalog_target: pd.DataFrame,
select_ref: np.array = None,
select_target: np.array = None,
logger: logging.Logger = None,
logging_n_rows: int = None,
**kwargs
):
"""Match catalogs.
Expand Down Expand Up @@ -492,6 +486,7 @@ def match(
if logger is None:
logger = logger_default

t_init = time.process_time()
config = self.config

# Transform any coordinates, if required
Expand Down Expand Up @@ -522,27 +517,39 @@ def match(

n_ref_select = len(ref.extras.indices)

match_dist_max = config.match_dist_max
coords_spherical = config.coord_format.coords_spherical
if coords_spherical:
match_dist_max = np.radians(match_dist_max / 3600.)

# Convert ra/dec sky coordinates to spherical vectors for accurate distances
func_convert = _radec_to_xyz if coords_spherical else np.vstack
vec_ref, vec_target = (
func_convert(cat.coord1[cat.extras.select], cat.coord2[cat.extras.select])
coords_ref, coords_target = (
(cat.coord1[cat.extras.select], cat.coord2[cat.extras.select])
for cat in (ref, target)
)

# Generate K-d tree to compute distances
logger.info('Generating cKDTree with match_n_max=%d', config.match_n_max)
tree_obj = cKDTree(vec_target)

scores, idxs_target_select = tree_obj.query(
vec_ref,
distance_upper_bound=match_dist_max,
k=config.match_n_max,
)
if coords_spherical:
match_dist_max = config.match_dist_max/3600.
with Matcher(coords_target[0], coords_target[1]) as matcher:
idxs_target_select = matcher.query_knn(
coords_ref[0], coords_ref[1],
distance_upper_bound=match_dist_max,
k=config.match_n_max,
)
# Call scipy for non-spherical case
# The spherical case won't trigger, but the implementation is left for comparison, if needed
else:
match_dist_max = np.radians(config.match_dist_max/3600.)
# Convert ra/dec sky coordinates to spherical vectors for accurate distances
func_convert = _radec_to_xyz if coords_spherical else np.vstack
vec_ref, vec_target = (
func_convert(coords[0], coords[1])
for coords in (coords_ref, coords_target)
)
tree_obj = cKDTree(vec_target)
_, idxs_target_select = tree_obj.query(
vec_ref,
distance_upper_bound=match_dist_max,
k=config.match_n_max,
)

n_target_select = len(target.extras.indices)
n_matches = np.sum(idxs_target_select != n_target_select, axis=1)
Expand All @@ -562,7 +569,7 @@ def match(
ref_chisq = np.full(ref.extras.n, np.nan, dtype=float)

# Need the original reference row indices for output
idx_orig_ref, idx_orig_target = (np.argwhere(cat.extras.select) for cat in (ref, target))
idx_orig_ref, idx_orig_target = (np.argwhere(cat.extras.select)[:, 0] for cat in (ref, target))

# Retrieve required columns, including any converted ones (default to original column name)
columns_convert = config.coord_format.coords_ref_to_convert
Expand All @@ -577,21 +584,45 @@ def match(
exceptions = {}
# The kdTree uses len(inputs) as a sentinel value for no match
matched_target = {n_target_select, }
index_ref = idx_orig_ref[order]
# Fill in the candidate column
ref_candidate_match[index_ref] = True

# Count this as the time when disambiguation begins
t_begin = time.process_time()

logger.info('Matching n_indices=%d/%d', len(order), len(ref.catalog))
# Exclude unmatched sources
matched_ref = idxs_target_select[order, 0] != n_target_select
order = order[matched_ref]
idx_first = idxs_target_select[order, 0]
chi_0 = (data_target.iloc[idx_first].values - data_ref.iloc[matched_ref].values)/(
errors_target.iloc[idx_first].values)
chi_finite_0 = np.isfinite(chi_0)
n_finite_0 = np.sum(chi_finite_0, axis=1)
chi_0[~chi_finite_0] = 0
chisq_sum_0 = np.sum(chi_0*chi_0, axis=1)

logger.info('Disambiguating %d/%d matches/targets', len(order), len(ref.catalog))
for index_n, index_row_select in enumerate(order):
index_row = idx_orig_ref[index_row_select]
ref_candidate_match[index_row] = True
found = idxs_target_select[index_row_select, :]
# Select match candidates from nearby sources not already matched
# Note: set lookup is apparently fast enough that this is a few percent faster than:
# found = [x for x in found[found != n_target_select] if x not in matched_target]
# ... at least for ~1M sources
found = [x for x in found if x not in matched_target]
n_found = len(found)
if n_found > 0:
# Unambiguous match, short-circuit some evaluations
if (found[1] == n_target_select) and (found[0] not in matched_target):
n_finite = n_finite_0[index_n]
if not (n_finite >= config.match_n_finite_min):
continue
idx_chisq_min = 0
n_matched = 1
chisq_sum = chisq_sum_0[index_n]
else:
# Select match candidates from nearby sources not already matched
# Note: set lookup is apparently fast enough that this is a few percent faster than:
# found = [x for x in found[found != n_target_select] if x not in matched_target]
# ... at least for ~1M sources
found = [x for x in found if x not in matched_target]
n_found = len(found)
if n_found == 0:
continue
# This is an ndarray of n_found rows x len(data_ref/target) columns
chi = (
(data_target.iloc[found].values - data_ref.iloc[index_n].values)
Expand All @@ -601,24 +632,28 @@ def match(
n_finite = np.sum(finite, axis=1)
# Require some number of finite chi_sq to match
chisq_good = n_finite >= config.match_n_finite_min
if np.any(chisq_good):
try:
chisq_sum = np.zeros(n_found, dtype=float)
chisq_sum[chisq_good] = np.nansum(chi[chisq_good, :] ** 2, axis=1)
idx_chisq_min = np.nanargmin(chisq_sum / n_finite)
ref_match_meas_finite[index_row] = n_finite[idx_chisq_min]
ref_match_count[index_row] = len(chisq_good)
ref_chisq[index_row] = chisq_sum[idx_chisq_min]
idx_match_select = found[idx_chisq_min]
row_target = target.extras.indices[idx_match_select]
ref_row_match[index_row] = row_target

target_row_match[row_target] = index_row
matched_target.add(idx_match_select)
except Exception as error:
# Can't foresee any exceptions, but they shouldn't prevent
# matching subsequent sources
exceptions[index_row] = error
if not any(chisq_good):
continue
try:
chisq_sum = np.zeros(n_found, dtype=float)
chisq_sum[chisq_good] = np.nansum(chi[chisq_good, :] ** 2, axis=1)
idx_chisq_min = np.nanargmin(chisq_sum / n_finite)
n_finite = n_finite[idx_chisq_min]
n_matched = len(chisq_good)
chisq_sum = chisq_sum[idx_chisq_min]
except Exception as error:
# Can't foresee any exceptions, but they shouldn't prevent
# matching subsequent sources
exceptions[index_row] = error
ref_match_meas_finite[index_row] = n_finite
ref_match_count[index_row] = n_matched
ref_chisq[index_row] = chisq_sum
idx_match_select = found[idx_chisq_min]
row_target = target.extras.indices[idx_match_select]
ref_row_match[index_row] = row_target

target_row_match[row_target] = index_row
matched_target.add(idx_match_select)

if logging_n_rows and ((index_n + 1) % logging_n_rows == 0):
t_elapsed = time.process_time() - t_begin
Expand Down Expand Up @@ -648,7 +683,7 @@ def match(
ref,
target,
target_row_match,
'reference',
'target',
),
(
self.config.columns_target_copy,
Expand All @@ -657,7 +692,7 @@ def match(
target,
ref,
ref_row_match,
'target',
'reference',
),
):
matched = matches >= 0
Expand Down Expand Up @@ -686,6 +721,12 @@ def match(
column_match[matched] = in_original.catalog[column][idx_matched]
out_matched[f'match_{column}'] = column_match

logger.info(
'Completed match disambiguating in %.2fs (total %.2fs)',
time.process_time() - t_begin,
time.process_time() - t_init,
)

catalog_out_ref = pd.DataFrame(data_ref)
catalog_out_target = pd.DataFrame(data_target)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_matcher_probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def setUp(self):
columns_target_meas=["x", "y"],
columns_target_err=["xErr", "yErr"],
)
configs_bad = {"too_few_finite": MatchProbabilisticConfig(**kwargs)}
self.config_good = MatchProbabilisticConfig(match_n_finite_min=2, **kwargs)
self.config_good = MatchProbabilisticConfig(**kwargs)
configs_bad = {}
kwargs["columns_target_meas"] = ["x"]
configs_bad["too_few_target_meas"] = MatchProbabilisticConfig(**kwargs)
kwargs["columns_target_meas"] = ["x", "y", "z"]
Expand Down

0 comments on commit 4be5004

Please sign in to comment.