-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DM-43332: Improve match_probabilistic performance #188
Changes from all commits
25bfc9d
1e3d556
691aaa4
bc46639
ea29c13
9ab0df9
a0e6f2e
7301cf6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
import numpy as np | ||
import pandas as pd | ||
from scipy.spatial import cKDTree | ||
from smatch.matcher import Matcher | ||
import time | ||
from typing import Callable, Set | ||
|
||
|
@@ -122,63 +123,57 @@ class ComparableCatalog: | |
class ConvertCatalogCoordinatesConfig(pexConfig.Config): | ||
"""Configuration for the MatchProbabilistic matcher.""" | ||
|
||
column_ref_coord1 = pexConfig.Field( | ||
dtype=str, | ||
column_ref_coord1 = pexConfig.Field[str]( | ||
default='ra', | ||
doc='The reference table column for the first spatial coordinate (usually x or ra).', | ||
) | ||
column_ref_coord2 = pexConfig.Field( | ||
dtype=str, | ||
column_ref_coord2 = pexConfig.Field[str]( | ||
default='dec', | ||
doc='The reference table column for the second spatial coordinate (usually y or dec).' | ||
'Units must match column_ref_coord1.', | ||
) | ||
column_target_coord1 = pexConfig.Field( | ||
dtype=str, | ||
column_target_coord1 = pexConfig.Field[str]( | ||
default='coord_ra', | ||
doc='The target table column for the first spatial coordinate (usually x or ra).' | ||
'Units must match column_ref_coord1.', | ||
) | ||
column_target_coord2 = pexConfig.Field( | ||
dtype=str, | ||
column_target_coord2 = pexConfig.Field[str]( | ||
default='coord_dec', | ||
doc='The target table column for the second spatial coordinate (usually y or dec).' | ||
'Units must match column_ref_coord2.', | ||
) | ||
coords_spherical = pexConfig.Field( | ||
dtype=bool, | ||
coords_spherical = pexConfig.Field[bool]( | ||
default=True, | ||
doc='Whether column_*_coord[12] are spherical coordinates (ra/dec) or not (pixel x/y)', | ||
doc='Whether column_*_coord[12] are spherical coordinates (ra/dec) or not (pixel x/y).', | ||
) | ||
coords_ref_factor = pexConfig.Field( | ||
dtype=float, | ||
coords_ref_factor = pexConfig.Field[float]( | ||
default=1.0, | ||
doc='Multiplicative factor for reference catalog coordinates.' | ||
'If coords_spherical is true, this must be the number of degrees per unit increment of ' | ||
'column_ref_coord[12]. Otherwise, it must convert the coordinate to the same units' | ||
' as the target coordinates.', | ||
) | ||
coords_target_factor = pexConfig.Field( | ||
dtype=float, | ||
coords_target_factor = pexConfig.Field[float]( | ||
default=1.0, | ||
doc='Multiplicative factor for target catalog coordinates.' | ||
'If coords_spherical is true, this must be the number of degrees per unit increment of ' | ||
'column_target_coord[12]. Otherwise, it must convert the coordinate to the same units' | ||
' as the reference coordinates.', | ||
) | ||
coords_ref_to_convert = pexConfig.DictField( | ||
coords_ref_to_convert = pexConfig.DictField[str, str]( | ||
default=None, | ||
optional=True, | ||
keytype=str, | ||
itemtype=str, | ||
dictCheck=lambda x: len(x) == 2, | ||
doc='Dict mapping sky coordinate columns to be converted to pixel columns', | ||
doc='Dict mapping sky coordinate columns to be converted to pixel columns.', | ||
) | ||
mag_zeropoint_ref = pexConfig.Field( | ||
dtype=float, | ||
mag_zeropoint_ref = pexConfig.Field[float]( | ||
default=31.4, | ||
doc='Magnitude zeropoint for reference catalog.', | ||
) | ||
return_converted_coords = pexConfig.Field[float]( | ||
default=True, | ||
doc='Whether to return converted coordinates for matching or only write them.', | ||
) | ||
|
||
def format_catalogs( | ||
self, | ||
|
@@ -187,7 +182,6 @@ def format_catalogs( | |
select_ref: np.array = None, | ||
select_target: np.array = None, | ||
radec_to_xy_func: Callable = None, | ||
return_converted_columns: bool = False, | ||
**kwargs, | ||
): | ||
"""Format matched catalogs that may require coordinate conversions. | ||
|
@@ -206,10 +200,8 @@ def format_catalogs( | |
Function taking equal-length ra, dec arrays and returning an ndarray of | ||
- ``x``: current parameter (`float`). | ||
- ``extra_args``: additional arguments (`dict`). | ||
return_converted_columns : `bool` | ||
Whether to return converted columns in the `coord1` and `coord2` | ||
attributes, rather than keep the original values. | ||
kwargs | ||
Additional keyword arguments to pass to radec_to_xy_func. | ||
|
||
Returns | ||
------- | ||
|
@@ -246,9 +238,10 @@ def format_catalogs( | |
for idx_coord, column_out in enumerate(self.coords_ref_to_convert.values()): | ||
coord = np.array([xy[idx_coord] for xy in xy_ref]) | ||
catalog[column_out] = coord | ||
if convert_ref and return_converted_columns: | ||
if convert_ref: | ||
column1, column2 = self.coords_ref_to_convert.values() | ||
coord1, coord2 = catalog[column1], catalog[column2] | ||
if self.return_converted_coords: | ||
coord1, coord2 = catalog[column1], catalog[column2] | ||
if isinstance(coord1, pd.Series): | ||
coord1 = coord1.values | ||
if isinstance(coord2, pd.Series): | ||
|
@@ -388,10 +381,11 @@ def columns_in_target(self) -> Set[str]: | |
default=10, | ||
optional=True, | ||
doc='Maximum number of spatial matches to consider (in ascending distance order).', | ||
check=lambda x: x >= 1, | ||
) | ||
match_n_finite_min = pexConfig.Field( | ||
dtype=int, | ||
default=3, | ||
default=2, | ||
optional=True, | ||
doc='Minimum number of columns with a finite value to measure match likelihood', | ||
) | ||
|
@@ -450,14 +444,14 @@ def __init__( | |
self.config = config | ||
|
||
def match( | ||
self, | ||
catalog_ref: pd.DataFrame, | ||
catalog_target: pd.DataFrame, | ||
select_ref: np.array = None, | ||
select_target: np.array = None, | ||
logger: logging.Logger = None, | ||
logging_n_rows: int = None, | ||
**kwargs | ||
self, | ||
catalog_ref: pd.DataFrame, | ||
catalog_target: pd.DataFrame, | ||
select_ref: np.array = None, | ||
select_target: np.array = None, | ||
logger: logging.Logger = None, | ||
logging_n_rows: int = None, | ||
**kwargs | ||
): | ||
"""Match catalogs. | ||
|
||
|
@@ -492,6 +486,7 @@ def match( | |
if logger is None: | ||
logger = logger_default | ||
|
||
t_init = time.process_time() | ||
config = self.config | ||
|
||
# Transform any coordinates, if required | ||
|
@@ -522,27 +517,39 @@ def match( | |
|
||
n_ref_select = len(ref.extras.indices) | ||
|
||
match_dist_max = config.match_dist_max | ||
coords_spherical = config.coord_format.coords_spherical | ||
if coords_spherical: | ||
match_dist_max = np.radians(match_dist_max / 3600.) | ||
|
||
# Convert ra/dec sky coordinates to spherical vectors for accurate distances | ||
func_convert = _radec_to_xyz if coords_spherical else np.vstack | ||
vec_ref, vec_target = ( | ||
func_convert(cat.coord1[cat.extras.select], cat.coord2[cat.extras.select]) | ||
coords_ref, coords_target = ( | ||
(cat.coord1[cat.extras.select], cat.coord2[cat.extras.select]) | ||
for cat in (ref, target) | ||
) | ||
|
||
# Generate K-d tree to compute distances | ||
logger.info('Generating cKDTree with match_n_max=%d', config.match_n_max) | ||
tree_obj = cKDTree(vec_target) | ||
|
||
scores, idxs_target_select = tree_obj.query( | ||
vec_ref, | ||
distance_upper_bound=match_dist_max, | ||
k=config.match_n_max, | ||
) | ||
if coords_spherical: | ||
match_dist_max = config.match_dist_max/3600. | ||
with Matcher(coords_target[0], coords_target[1]) as matcher: | ||
idxs_target_select = matcher.query_knn( | ||
coords_ref[0], coords_ref[1], | ||
distance_upper_bound=match_dist_max, | ||
k=config.match_n_max, | ||
) | ||
# Call scipy for non-spherical case | ||
# The spherical case won't trigger, but the implementation is left for comparison, if needed | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do you mean "The non-spherical case won't trigger"? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How do you anticipate such comparisons becoming necessary? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No, I mean that the spherical case won't use scipy but will use smatch instead. I anticipate most users will match on spherical coordinates, and so the else block won't get called much. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, I see what you're getting at now. Sorry, that comment is quite confusing in the current version of the code. I mistakenly thought you actually wanted to say "The non-spherical case won't trigger in general" to point out that "The non-spherical case (the else block) won't get called much, and not at all by default." There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So within this else block intended for non-spherical coordinates where |
||
else: | ||
match_dist_max = np.radians(config.match_dist_max/3600.) | ||
# Convert ra/dec sky coordinates to spherical vectors for accurate distances | ||
func_convert = _radec_to_xyz if coords_spherical else np.vstack | ||
vec_ref, vec_target = ( | ||
func_convert(coords[0], coords[1]) | ||
for coords in (coords_ref, coords_target) | ||
) | ||
tree_obj = cKDTree(vec_target) | ||
_, idxs_target_select = tree_obj.query( | ||
vec_ref, | ||
distance_upper_bound=match_dist_max, | ||
k=config.match_n_max, | ||
) | ||
|
||
n_target_select = len(target.extras.indices) | ||
n_matches = np.sum(idxs_target_select != n_target_select, axis=1) | ||
|
@@ -562,7 +569,7 @@ def match( | |
ref_chisq = np.full(ref.extras.n, np.nan, dtype=float) | ||
|
||
# Need the original reference row indices for output | ||
idx_orig_ref, idx_orig_target = (np.argwhere(cat.extras.select) for cat in (ref, target)) | ||
idx_orig_ref, idx_orig_target = (np.argwhere(cat.extras.select)[:, 0] for cat in (ref, target)) | ||
|
||
# Retrieve required columns, including any converted ones (default to original column name) | ||
columns_convert = config.coord_format.coords_ref_to_convert | ||
|
@@ -577,21 +584,45 @@ def match( | |
exceptions = {} | ||
# The kdTree uses len(inputs) as a sentinel value for no match | ||
matched_target = {n_target_select, } | ||
index_ref = idx_orig_ref[order] | ||
# Fill in the candidate column | ||
ref_candidate_match[index_ref] = True | ||
|
||
# Count this as the time when disambiguation begins | ||
t_begin = time.process_time() | ||
|
||
logger.info('Matching n_indices=%d/%d', len(order), len(ref.catalog)) | ||
# Exclude unmatched sources | ||
matched_ref = idxs_target_select[order, 0] != n_target_select | ||
order = order[matched_ref] | ||
idx_first = idxs_target_select[order, 0] | ||
chi_0 = (data_target.iloc[idx_first].values - data_ref.iloc[matched_ref].values)/( | ||
errors_target.iloc[idx_first].values) | ||
chi_finite_0 = np.isfinite(chi_0) | ||
n_finite_0 = np.sum(chi_finite_0, axis=1) | ||
chi_0[~chi_finite_0] = 0 | ||
chisq_sum_0 = np.sum(chi_0*chi_0, axis=1) | ||
|
||
logger.info('Disambiguating %d/%d matches/targets', len(order), len(ref.catalog)) | ||
for index_n, index_row_select in enumerate(order): | ||
index_row = idx_orig_ref[index_row_select] | ||
ref_candidate_match[index_row] = True | ||
found = idxs_target_select[index_row_select, :] | ||
# Select match candidates from nearby sources not already matched | ||
# Note: set lookup is apparently fast enough that this is a few percent faster than: | ||
# found = [x for x in found[found != n_target_select] if x not in matched_target] | ||
# ... at least for ~1M sources | ||
found = [x for x in found if x not in matched_target] | ||
n_found = len(found) | ||
if n_found > 0: | ||
# Unambiguous match, short-circuit some evaluations | ||
if (found[1] == n_target_select) and (found[0] not in matched_target): | ||
n_finite = n_finite_0[index_n] | ||
if not (n_finite >= config.match_n_finite_min): | ||
continue | ||
idx_chisq_min = 0 | ||
n_matched = 1 | ||
chisq_sum = chisq_sum_0[index_n] | ||
else: | ||
# Select match candidates from nearby sources not already matched | ||
# Note: set lookup is apparently fast enough that this is a few percent faster than: | ||
# found = [x for x in found[found != n_target_select] if x not in matched_target] | ||
# ... at least for ~1M sources | ||
found = [x for x in found if x not in matched_target] | ||
n_found = len(found) | ||
if n_found == 0: | ||
continue | ||
# This is an ndarray of n_found rows x len(data_ref/target) columns | ||
chi = ( | ||
(data_target.iloc[found].values - data_ref.iloc[index_n].values) | ||
|
@@ -601,24 +632,28 @@ def match( | |
n_finite = np.sum(finite, axis=1) | ||
# Require some number of finite chi_sq to match | ||
chisq_good = n_finite >= config.match_n_finite_min | ||
if np.any(chisq_good): | ||
try: | ||
chisq_sum = np.zeros(n_found, dtype=float) | ||
chisq_sum[chisq_good] = np.nansum(chi[chisq_good, :] ** 2, axis=1) | ||
idx_chisq_min = np.nanargmin(chisq_sum / n_finite) | ||
ref_match_meas_finite[index_row] = n_finite[idx_chisq_min] | ||
ref_match_count[index_row] = len(chisq_good) | ||
ref_chisq[index_row] = chisq_sum[idx_chisq_min] | ||
idx_match_select = found[idx_chisq_min] | ||
row_target = target.extras.indices[idx_match_select] | ||
ref_row_match[index_row] = row_target | ||
|
||
target_row_match[row_target] = index_row | ||
matched_target.add(idx_match_select) | ||
except Exception as error: | ||
# Can't foresee any exceptions, but they shouldn't prevent | ||
# matching subsequent sources | ||
exceptions[index_row] = error | ||
if not any(chisq_good): | ||
continue | ||
try: | ||
chisq_sum = np.zeros(n_found, dtype=float) | ||
chisq_sum[chisq_good] = np.nansum(chi[chisq_good, :] ** 2, axis=1) | ||
idx_chisq_min = np.nanargmin(chisq_sum / n_finite) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are we always safeguarded against division by zero here? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As long as n_finite_min > 0, yes. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, that's why I suggested you add a |
||
n_finite = n_finite[idx_chisq_min] | ||
n_matched = len(chisq_good) | ||
chisq_sum = chisq_sum[idx_chisq_min] | ||
except Exception as error: | ||
# Can't foresee any exceptions, but they shouldn't prevent | ||
# matching subsequent sources | ||
exceptions[index_row] = error | ||
ref_match_meas_finite[index_row] = n_finite | ||
ref_match_count[index_row] = n_matched | ||
ref_chisq[index_row] = chisq_sum | ||
idx_match_select = found[idx_chisq_min] | ||
row_target = target.extras.indices[idx_match_select] | ||
ref_row_match[index_row] = row_target | ||
|
||
target_row_match[row_target] = index_row | ||
matched_target.add(idx_match_select) | ||
|
||
if logging_n_rows and ((index_n + 1) % logging_n_rows == 0): | ||
t_elapsed = time.process_time() - t_begin | ||
|
@@ -648,7 +683,7 @@ def match( | |
ref, | ||
target, | ||
target_row_match, | ||
'reference', | ||
'target', | ||
), | ||
( | ||
self.config.columns_target_copy, | ||
|
@@ -657,7 +692,7 @@ def match( | |
target, | ||
ref, | ||
ref_row_match, | ||
'target', | ||
'reference', | ||
), | ||
): | ||
matched = matches >= 0 | ||
|
@@ -686,6 +721,12 @@ def match( | |
column_match[matched] = in_original.catalog[column][idx_matched] | ||
out_matched[f'match_{column}'] = column_match | ||
|
||
logger.info( | ||
'Completed match disambiguating in %.2fs (total %.2fs)', | ||
time.process_time() - t_begin, | ||
time.process_time() - t_init, | ||
) | ||
|
||
catalog_out_ref = pd.DataFrame(data_ref) | ||
catalog_out_target = pd.DataFrame(data_target) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add a `check` here to specify the valid domain?