Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-43332: Improve match_probabilistic performance #188

Merged
merged 8 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 121 additions & 80 deletions python/lsst/meas/astrom/matcher_probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree
from smatch.matcher import Matcher
import time
from typing import Callable, Set

Expand Down Expand Up @@ -122,63 +123,57 @@ class ComparableCatalog:
class ConvertCatalogCoordinatesConfig(pexConfig.Config):
"""Configuration for the MatchProbabilistic matcher."""

column_ref_coord1 = pexConfig.Field(
dtype=str,
column_ref_coord1 = pexConfig.Field[str](
default='ra',
doc='The reference table column for the first spatial coordinate (usually x or ra).',
)
column_ref_coord2 = pexConfig.Field(
dtype=str,
column_ref_coord2 = pexConfig.Field[str](
default='dec',
doc='The reference table column for the second spatial coordinate (usually y or dec).'
'Units must match column_ref_coord1.',
)
column_target_coord1 = pexConfig.Field(
dtype=str,
column_target_coord1 = pexConfig.Field[str](
default='coord_ra',
doc='The target table column for the first spatial coordinate (usually x or ra).'
'Units must match column_ref_coord1.',
)
column_target_coord2 = pexConfig.Field(
dtype=str,
column_target_coord2 = pexConfig.Field[str](
default='coord_dec',
doc='The target table column for the second spatial coordinate (usually y or dec).'
'Units must match column_ref_coord2.',
)
coords_spherical = pexConfig.Field(
dtype=bool,
coords_spherical = pexConfig.Field[bool](
default=True,
doc='Whether column_*_coord[12] are spherical coordinates (ra/dec) or not (pixel x/y)',
doc='Whether column_*_coord[12] are spherical coordinates (ra/dec) or not (pixel x/y).',
)
coords_ref_factor = pexConfig.Field(
dtype=float,
coords_ref_factor = pexConfig.Field[float](
default=1.0,
doc='Multiplicative factor for reference catalog coordinates.'
'If coords_spherical is true, this must be the number of degrees per unit increment of '
'column_ref_coord[12]. Otherwise, it must convert the coordinate to the same units'
' as the target coordinates.',
)
coords_target_factor = pexConfig.Field(
dtype=float,
coords_target_factor = pexConfig.Field[float](
default=1.0,
doc='Multiplicative factor for target catalog coordinates.'
'If coords_spherical is true, this must be the number of degrees per unit increment of '
'column_target_coord[12]. Otherwise, it must convert the coordinate to the same units'
' as the reference coordinates.',
)
coords_ref_to_convert = pexConfig.DictField(
coords_ref_to_convert = pexConfig.DictField[str, str](
default=None,
optional=True,
keytype=str,
itemtype=str,
dictCheck=lambda x: len(x) == 2,
doc='Dict mapping sky coordinate columns to be converted to pixel columns',
doc='Dict mapping sky coordinate columns to be converted to pixel columns.',
)
mag_zeropoint_ref = pexConfig.Field(
dtype=float,
mag_zeropoint_ref = pexConfig.Field[float](
default=31.4,
doc='Magnitude zeropoint for reference catalog.',
)
return_converted_coords = pexConfig.Field[float](
default=True,
doc='Whether to return converted coordinates for matching or only write them.',
)

def format_catalogs(
self,
Expand All @@ -187,7 +182,6 @@ def format_catalogs(
select_ref: np.array = None,
select_target: np.array = None,
radec_to_xy_func: Callable = None,
return_converted_columns: bool = False,
**kwargs,
):
"""Format matched catalogs that may require coordinate conversions.
Expand All @@ -206,10 +200,8 @@ def format_catalogs(
Function taking equal-length ra, dec arrays and returning an ndarray of
- ``x``: current parameter (`float`).
- ``extra_args``: additional arguments (`dict`).
return_converted_columns : `bool`
Whether to return converted columns in the `coord1` and `coord2`
attributes, rather than keep the original values.
kwargs
Additional keyword arguments to pass to radec_to_xy_func.

Returns
-------
Expand Down Expand Up @@ -246,9 +238,10 @@ def format_catalogs(
for idx_coord, column_out in enumerate(self.coords_ref_to_convert.values()):
coord = np.array([xy[idx_coord] for xy in xy_ref])
catalog[column_out] = coord
if convert_ref and return_converted_columns:
if convert_ref:
column1, column2 = self.coords_ref_to_convert.values()
coord1, coord2 = catalog[column1], catalog[column2]
if self.return_converted_coords:
coord1, coord2 = catalog[column1], catalog[column2]
if isinstance(coord1, pd.Series):
coord1 = coord1.values
if isinstance(coord2, pd.Series):
Expand Down Expand Up @@ -388,10 +381,11 @@ def columns_in_target(self) -> Set[str]:
default=10,
optional=True,
doc='Maximum number of spatial matches to consider (in ascending distance order).',
check=lambda x: x >= 1,
)
match_n_finite_min = pexConfig.Field(
dtype=int,
default=3,
default=2,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a check here to specify the valid domain?

optional=True,
doc='Minimum number of columns with a finite value to measure match likelihood',
)
Expand Down Expand Up @@ -450,14 +444,14 @@ def __init__(
self.config = config

def match(
self,
catalog_ref: pd.DataFrame,
catalog_target: pd.DataFrame,
select_ref: np.array = None,
select_target: np.array = None,
logger: logging.Logger = None,
logging_n_rows: int = None,
**kwargs
self,
catalog_ref: pd.DataFrame,
catalog_target: pd.DataFrame,
select_ref: np.array = None,
select_target: np.array = None,
logger: logging.Logger = None,
logging_n_rows: int = None,
**kwargs
):
"""Match catalogs.

Expand Down Expand Up @@ -492,6 +486,7 @@ def match(
if logger is None:
logger = logger_default

t_init = time.process_time()
config = self.config

# Transform any coordinates, if required
Expand Down Expand Up @@ -522,27 +517,39 @@ def match(

n_ref_select = len(ref.extras.indices)

match_dist_max = config.match_dist_max
coords_spherical = config.coord_format.coords_spherical
if coords_spherical:
match_dist_max = np.radians(match_dist_max / 3600.)

# Convert ra/dec sky coordinates to spherical vectors for accurate distances
func_convert = _radec_to_xyz if coords_spherical else np.vstack
vec_ref, vec_target = (
func_convert(cat.coord1[cat.extras.select], cat.coord2[cat.extras.select])
coords_ref, coords_target = (
(cat.coord1[cat.extras.select], cat.coord2[cat.extras.select])
for cat in (ref, target)
)

# Generate K-d tree to compute distances
logger.info('Generating cKDTree with match_n_max=%d', config.match_n_max)
tree_obj = cKDTree(vec_target)

scores, idxs_target_select = tree_obj.query(
vec_ref,
distance_upper_bound=match_dist_max,
k=config.match_n_max,
)
if coords_spherical:
match_dist_max = config.match_dist_max/3600.
with Matcher(coords_target[0], coords_target[1]) as matcher:
idxs_target_select = matcher.query_knn(
coords_ref[0], coords_ref[1],
distance_upper_bound=match_dist_max,
k=config.match_n_max,
)
# Fall back to scipy's cKDTree for the non-spherical case
# Note: within this else block coords_spherical is False, so the spherical
# (_radec_to_xyz) branch below can never trigger; it is kept only for
# comparison with the previous implementation, if needed

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean "The non-spherical case won't trigger"?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you anticipate such comparisons becoming necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I mean that the spherical case won't use scipy but will use smatch instead. So the _radec_to_xyz if coords_spherical else np.vstack on line 542 will never actually call _radec_to_xyz.

I anticipate most users will match on spherical coordinates and so the else block won't get called much.

Copy link

@enourbakhsh enourbakhsh May 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see what you're getting at now. Sorry, that comment is quite confusing in the current version of the code. I mistakenly thought you wanted to say "The non-spherical case won't trigger in general" — i.e., to point out that the non-spherical case (the else block) won't get called much, and not at all by default.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So within this else block intended for non-spherical coordinates where coords_spherical is False, there's an if coords_spherical condition that obviously won't be met but might be useful for comparison. I suggest adding a more explicit comment to convey this and clear up any confusion.

else:
match_dist_max = np.radians(config.match_dist_max/3600.)
# Convert ra/dec sky coordinates to spherical vectors for accurate distances
func_convert = _radec_to_xyz if coords_spherical else np.vstack
vec_ref, vec_target = (
func_convert(coords[0], coords[1])
for coords in (coords_ref, coords_target)
)
tree_obj = cKDTree(vec_target)
_, idxs_target_select = tree_obj.query(
vec_ref,
distance_upper_bound=match_dist_max,
k=config.match_n_max,
)

n_target_select = len(target.extras.indices)
n_matches = np.sum(idxs_target_select != n_target_select, axis=1)
Expand All @@ -562,7 +569,7 @@ def match(
ref_chisq = np.full(ref.extras.n, np.nan, dtype=float)

# Need the original reference row indices for output
idx_orig_ref, idx_orig_target = (np.argwhere(cat.extras.select) for cat in (ref, target))
idx_orig_ref, idx_orig_target = (np.argwhere(cat.extras.select)[:, 0] for cat in (ref, target))

# Retrieve required columns, including any converted ones (default to original column name)
columns_convert = config.coord_format.coords_ref_to_convert
Expand All @@ -577,21 +584,45 @@ def match(
exceptions = {}
# The kdTree uses len(inputs) as a sentinel value for no match
matched_target = {n_target_select, }
index_ref = idx_orig_ref[order]
# Fill in the candidate column
ref_candidate_match[index_ref] = True

# Count this as the time when disambiguation begins
t_begin = time.process_time()

logger.info('Matching n_indices=%d/%d', len(order), len(ref.catalog))
# Exclude unmatched sources
matched_ref = idxs_target_select[order, 0] != n_target_select
order = order[matched_ref]
idx_first = idxs_target_select[order, 0]
chi_0 = (data_target.iloc[idx_first].values - data_ref.iloc[matched_ref].values)/(
errors_target.iloc[idx_first].values)
chi_finite_0 = np.isfinite(chi_0)
n_finite_0 = np.sum(chi_finite_0, axis=1)
chi_0[~chi_finite_0] = 0
chisq_sum_0 = np.sum(chi_0*chi_0, axis=1)

logger.info('Disambiguating %d/%d matches/targets', len(order), len(ref.catalog))
for index_n, index_row_select in enumerate(order):
index_row = idx_orig_ref[index_row_select]
ref_candidate_match[index_row] = True
found = idxs_target_select[index_row_select, :]
# Select match candidates from nearby sources not already matched
# Note: set lookup is apparently fast enough that this is a few percent faster than:
# found = [x for x in found[found != n_target_select] if x not in matched_target]
# ... at least for ~1M sources
found = [x for x in found if x not in matched_target]
n_found = len(found)
if n_found > 0:
# Unambiguous match, short-circuit some evaluations
if (found[1] == n_target_select) and (found[0] not in matched_target):
n_finite = n_finite_0[index_n]
if not (n_finite >= config.match_n_finite_min):
continue
idx_chisq_min = 0
n_matched = 1
chisq_sum = chisq_sum_0[index_n]
else:
# Select match candidates from nearby sources not already matched
# Note: set lookup is apparently fast enough that this is a few percent faster than:
# found = [x for x in found[found != n_target_select] if x not in matched_target]
# ... at least for ~1M sources
found = [x for x in found if x not in matched_target]
n_found = len(found)
if n_found == 0:
continue
# This is an ndarray of n_found rows x len(data_ref/target) columns
chi = (
(data_target.iloc[found].values - data_ref.iloc[index_n].values)
Expand All @@ -601,24 +632,28 @@ def match(
n_finite = np.sum(finite, axis=1)
# Require some number of finite chi_sq to match
chisq_good = n_finite >= config.match_n_finite_min
if np.any(chisq_good):
try:
chisq_sum = np.zeros(n_found, dtype=float)
chisq_sum[chisq_good] = np.nansum(chi[chisq_good, :] ** 2, axis=1)
idx_chisq_min = np.nanargmin(chisq_sum / n_finite)
ref_match_meas_finite[index_row] = n_finite[idx_chisq_min]
ref_match_count[index_row] = len(chisq_good)
ref_chisq[index_row] = chisq_sum[idx_chisq_min]
idx_match_select = found[idx_chisq_min]
row_target = target.extras.indices[idx_match_select]
ref_row_match[index_row] = row_target

target_row_match[row_target] = index_row
matched_target.add(idx_match_select)
except Exception as error:
# Can't foresee any exceptions, but they shouldn't prevent
# matching subsequent sources
exceptions[index_row] = error
if not any(chisq_good):
continue
try:
chisq_sum = np.zeros(n_found, dtype=float)
chisq_sum[chisq_good] = np.nansum(chi[chisq_good, :] ** 2, axis=1)
idx_chisq_min = np.nanargmin(chisq_sum / n_finite)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we always safeguarded against division by zero here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As long as n_finite_min > 0, yes.

Copy link

@enourbakhsh enourbakhsh May 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's why I suggested you add a check in match_n_finite_min = pexConfig.Field(...).

n_finite = n_finite[idx_chisq_min]
n_matched = len(chisq_good)
chisq_sum = chisq_sum[idx_chisq_min]
except Exception as error:
# Can't foresee any exceptions, but they shouldn't prevent
# matching subsequent sources
exceptions[index_row] = error
ref_match_meas_finite[index_row] = n_finite
ref_match_count[index_row] = n_matched
ref_chisq[index_row] = chisq_sum
idx_match_select = found[idx_chisq_min]
row_target = target.extras.indices[idx_match_select]
ref_row_match[index_row] = row_target

target_row_match[row_target] = index_row
matched_target.add(idx_match_select)

if logging_n_rows and ((index_n + 1) % logging_n_rows == 0):
t_elapsed = time.process_time() - t_begin
Expand Down Expand Up @@ -648,7 +683,7 @@ def match(
ref,
target,
target_row_match,
'reference',
'target',
),
(
self.config.columns_target_copy,
Expand All @@ -657,7 +692,7 @@ def match(
target,
ref,
ref_row_match,
'target',
'reference',
),
):
matched = matches >= 0
Expand Down Expand Up @@ -686,6 +721,12 @@ def match(
column_match[matched] = in_original.catalog[column][idx_matched]
out_matched[f'match_{column}'] = column_match

logger.info(
'Completed match disambiguating in %.2fs (total %.2fs)',
time.process_time() - t_begin,
time.process_time() - t_init,
)

catalog_out_ref = pd.DataFrame(data_ref)
catalog_out_target = pd.DataFrame(data_target)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_matcher_probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def setUp(self):
columns_target_meas=["x", "y"],
columns_target_err=["xErr", "yErr"],
)
configs_bad = {"too_few_finite": MatchProbabilisticConfig(**kwargs)}
self.config_good = MatchProbabilisticConfig(match_n_finite_min=2, **kwargs)
self.config_good = MatchProbabilisticConfig(**kwargs)
configs_bad = {}
kwargs["columns_target_meas"] = ["x"]
configs_bad["too_few_target_meas"] = MatchProbabilisticConfig(**kwargs)
kwargs["columns_target_meas"] = ["x", "y", "z"]
Expand Down
Loading