Add DiffMatchedTractCatalogTask
taranu committed Mar 11, 2022
1 parent 3589952 commit 0a315a1
Showing 2 changed files with 169 additions and 164 deletions.
157 changes: 74 additions & 83 deletions python/lsst/pipe/tasks/diff_matched_tract_catalog.py
@@ -32,14 +32,11 @@
import lsst.pipe.base as pipeBase
import lsst.pipe.base.connectionTypes as cT
from lsst.skymap import BaseSkyMap
from .statistic import Median, Percentile, StandardDeviation, SigmaIQR, SigmaMAD, Statistics

from abc import ABCMeta, abstractmethod
from astropy.stats import mad_std
from dataclasses import dataclass
from enum import Enum, auto
import numpy as np
import pandas as pd
from scipy.stats import iqr
from typing import Dict, Set


@@ -90,11 +87,17 @@ class DiffMatchedTractCatalogConnections(
deferLoad=True,
)
cat_diff_matched = cT.Output(
doc="Matched catalog with aggregated counts and diff statistics",
doc="Table with aggregated counts, difference and chi statistics",
name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
dimensions=("tract", "skymap"),
)
cat_matched = cT.Output(
doc="Catalog with reference and target columns for matched sources only",
name="matched_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
dimensions=("tract", "skymap"),
)
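
Both outputs are written per tract; once the pipeline has run they can be fetched from a Gen3 butler as DataFrames. A minimal sketch, assuming a local repo, a placeholder collection, and dataset type names built from the templates above (all names below are placeholders):

from lsst.daf.butler import Butler

# Sketch only: repo path, collection, data ID, and dataset type names are placeholders.
butler = Butler("/path/to/repo", collections=["u/user/diff_matched_run"])
data_id = {"tract": 9813, "skymap": "hsc_rings_v1"}

# Per-object matched catalog; reference columns carry a 'ref_' prefix.
cat_matched = butler.get("matched_truth_summary_objectTable_tract", dataId=data_id)
# Aggregated counts plus difference/chi statistics per band, magnitude bin, and type.
diff_matched = butler.get("diff_matched_truth_summary_objectTable_tract", dataId=data_id)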


class MatchedCatalogFluxesConfig(pexConfig.Config):
@@ -146,8 +149,14 @@ class DiffMatchedTractCatalogConfig(
def columns_in_ref(self) -> Set[str]:
columns_all = [self.coord_format.column_ref_coord1, self.coord_format.column_ref_coord2,
self.column_ref_extended]
for columns in (x.columns_in_ref for x in self.columns_flux.values()):
columns_all.extend(columns)
for columns_list in (
(
self.columns_ref_copy,
),
(x.columns_in_ref for x in self.columns_flux.values()),
):
for columns in columns_list:
columns_all.extend(columns)

return set(columns_all)

@@ -157,15 +166,16 @@ def columns_in_target(self) -> Set[str]:
self.column_target_extended]
if self.coord_format.coords_ref_to_convert is not None:
columns_all.extend(self.coord_format.coords_ref_to_convert.values())
for column_list in (
for columns_list in (
(
self.columns_target_coord_err,
self.columns_target_select_false,
self.columns_target_select_true,
self.columns_target_copy,
),
(x.columns_in_target for x in self.columns_flux.values()),
):
for columns in column_list:
for columns in columns_list:
columns_all.extend(columns)
return set(columns_all)

@@ -174,11 +184,21 @@ def columns_in_target(self) -> Set[str]:
itemtype=MatchedCatalogFluxesConfig,
doc="Configs for flux columns for each band",
)
columns_ref_copy = pexConfig.ListField(
dtype=str,
default=set(),
doc='Reference table columns to copy into cat_matched',
)
columns_target_coord_err = pexConfig.ListField(
dtype=str,
listCheck=lambda x: (len(x) == 2) and (x[0] != x[1]),
doc='Target table coordinate columns with standard errors (sigma)',
)
columns_target_copy = pexConfig.ListField(
dtype=str,
default=('patch',),
doc='Target table columns to copy into cat_matched',
)
columns_target_select_true = pexConfig.ListField(
dtype=str,
default=('detect_isPrimary',),
@@ -229,65 +249,25 @@ def columns_in_target(self) -> Set[str]:
default=31.4,
doc='Magnitude zeropoint for target sources',
)
percentiles = pexConfig.DictField(
keytype=str,
itemtype=float,
default={'p05': 5., 'p16': 16., 'p84': 84., 'p95': 95., },
doc='Names and values of percentiles to compute',
)
statistics = pexConfig.ListField(
dtype=str,
default=[Median.name(), StandardDeviation.name(), SigmaIQR.name(), SigmaMAD.name()],
listCheck=lambda x: (len(set(x)) == len(x)) and all((n in Statistics for n in x)),
doc='Names of statistics to compute'
)
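
The new copy-column, percentile, and statistic fields are ordinary pex_config ListField/DictField entries, so they can be set in a task config override in the usual way. A hedged fragment (the column names are placeholders and must exist in the input catalogs; the statistic module path is inferred from the relative import above):

# Config override fragment; `config` is the DiffMatchedTractCatalogConfig instance.
from lsst.pipe.tasks.statistic import Median, SigmaMAD

config.columns_target_coord_err = ["coord_raErr", "coord_decErr"]   # exactly two, distinct
config.columns_ref_copy = ["truth_type"]                            # copied into cat_matched
config.columns_target_copy = ["patch"]                              # copied into cat_matched
config.percentiles = {"p05": 5.0, "p25": 25.0, "p75": 75.0, "p95": 95.0}
config.statistics = [Median.name(), SigmaMAD.name()]                # must be unique and registered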


class Measurement(Enum):
DIFF = auto()
CHI = auto()


class Statistic(metaclass=ABCMeta):
"""A statistic that can be applied to a set of values.
"""
@abstractmethod
def value(self, values):
"""Return the value of the statistic given a set of values.
Parameters
----------
values : `Collection` [`float`]
A set of values to compute the statistic for.
Returns
-------
statistic : `float`
The value of the statistic.
"""
pass


class Median(Statistic):
"""The median of a set of values."""
def value(self, values):
return np.median(values)


class SigmaIQR(Statistic):
"""The re-scaled inter-quartile range (sigma equivalent)."""
def value(self, values):
return iqr(values, scale='normal')


class SigmaMAD(Statistic):
"""The re-scaled median absolute deviation (sigma equivalent)."""
def value(self, values):
return mad_std(values)


@dataclass(frozen=True)
class Percentile(Statistic):
"""An arbitrary percentile.
Parameters
----------
percentile : `float`
A valid percentile (0 <= p <= 100).
"""
percentile: float

def value(self, values):
return np.percentile(values, self.percentile)


def compute_stats(values_ref, values_target, errors_target, row, stats, suffixes, prefix, skip_diff=False):
"""Compute statistics on differences and store results in a row.
@@ -303,9 +283,8 @@ def compute_stats(values_ref, values_target, errors_target, row, stats, suffixes
A numpy array with pre-assigned column names.
stats : `Dict` [`str`, `Statistic`]
A dict of `Statistic` values to measure, keyed by their column suffix.
suffixes : `Dict` [`str`, `Measurement`]
A dict of measurement types are the only valid values),
keyed by the column suffix.
suffixes : `Dict` [`Measurement`, `str`]
A dict of suffixes, keyed by the `Measurement` type.
prefix : `str`
A prefix for all column names (e.g. band).
skip_diff : `bool`
@@ -383,9 +362,9 @@ def _get_columns(bands_columns: Dict, suffixes: Dict, suffixes_flux: Dict, suffi

bands = list(bands_columns.keys())
for idx, (band, config_flux) in enumerate(bands_columns.items()):
columns_suffix = [('_flux', suffixes_flux), ('_mag', suffixes_mag), ]
columns_suffix = [('flux', suffixes_flux), ('mag', suffixes_mag), ]
if idx > 0:
columns_suffix.append((f'_color_{bands[idx - 1]}-{band}', suffixes))
columns_suffix.append((f'color_{bands[idx - 1]}-{band}', suffixes))
else:
n_models = len(config_flux.columns_target_flux)
n_models_flux = len(config_flux.columns_target_flux)
@@ -401,22 +380,27 @@ def _get_columns(bands_columns: Dict, suffixes: Dict, suffixes_flux: Dict, suffi
if subtype != '':
for item in (f'n_{itype}{mtype}' for itype in ('ref', 'target')
for mtype in ('', '_match_right', '_match_wrong')):
columns[f'{band}{subtype}_{item}'] = int
columns[format_column(band, subtype, item)] = int

for item in (target.column_coord1, target.column_coord2, column_dist):
for suffix in suffixes.values():
for stat in stats.keys():
columns[f'{band}{subtype}_{item}{suffix}{stat}'] = float
columns[format_column(band, subtype, f'{item}{suffix}{stat}')] = float

for item in config_flux.columns_target_flux:
for prefix_item, suffixes_col in columns_suffix:
for suffix in suffixes_col.values():
for stat in stats.keys():
columns[f'{band}{subtype}{prefix_item}_{item}{suffix}{stat}'] = float
columns[
format_column(band, subtype, f'{prefix_item}_{item}{suffix}{stat}')] = float

return columns, n_models


def format_column(band, subtype, column):
return f'{band}{subtype}_{column}'


class DiffMatchedTractCatalogTask(pipeBase.PipelineTask):
"""Match sources in a reference tract catalog with those in a target catalog.
"""
@@ -500,6 +484,17 @@ def run(
matched_target = np.zeros(n_target, dtype=bool)
matched_target[matched_row] = True

# Create a matched table, preserving the target catalog's named index (if it has one)
cat_left = cat_target.iloc[matched_row]
has_index_left = cat_left.index.name is not None
cat_right = cat_ref[matched_ref].reset_index()
cat_matched = pd.concat((cat_left.reset_index(drop=has_index_left), cat_right), 1)
if has_index_left:
cat_matched.index = cat_left.index
cat_matched.columns.values[len(cat_target.columns):] = [f'ref_{col}' for col in cat_right.columns]
del cat_left
del cat_right
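
The block above places the matched target rows next to the matched reference rows and prefixes the reference columns with ref_. A self-contained sketch of an equivalent, slightly simplified pandas pattern on toy data (not the task's real catalogs):

import numpy as np
import pandas as pd

cat_target_toy = pd.DataFrame(
    {"objectId": [10, 11, 12], "r_cModelFlux": [1.0, 2.0, 3.0]}
).set_index("objectId")
cat_ref_toy = pd.DataFrame({"flux_r": [0.9, 2.1, 3.2, 4.0], "is_pointsource": [1, 0, 1, 1]})

matched_row_toy = np.array([2, 0])                      # target row index per matched ref row
matched_ref_toy = np.array([True, True, False, False])  # which reference rows were matched

left = cat_target_toy.iloc[matched_row_toy]
right = cat_ref_toy[matched_ref_toy].reset_index(drop=True).add_prefix("ref_")
matched_toy = pd.concat((left.reset_index(drop=True), right), axis=1)
matched_toy.index = left.index  # preserve the target catalog's named index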

# Add/compute distance columns
coord1_target_err, coord2_target_err = config.columns_target_coord_err
column_dist, column_dist_err = 'dr', 'drErr'
@@ -531,12 +526,8 @@ def run(
suffixes_flux = {Measurement.CHI: suffixes[Measurement.CHI]}
# Skip chi for magnitudes, which have strange errors
suffixes_mag = {Measurement.DIFF: suffixes[Measurement.DIFF]}
stats = {
'_median': Median(),
'_sig_iqr': SigmaIQR(),
'_sig_mad': SigmaMAD(),
}
for name, percentile in (('p05', 5.), ('p16', 16.), ('p84', 84.), ('p95', 95.)):
stats = {f'_{name}': Statistics[name]() for name in config.statistics}
for name, percentile in config.percentiles.items():
stats[f'_{name}'] = Percentile(percentile=percentile)
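
stats now maps a column suffix to a Statistic instance, while suffixes maps each Measurement to its own suffix; compute_stats composes the two when filling the output row. A rough sketch of that composition, assuming the Statistic.value(values) interface of the classes that were moved out of this file (the prefix, suffixes, and resulting keys are illustrative):

import numpy as np
from lsst.pipe.tasks.statistic import Median, SigmaMAD

stats_example = {'_median': Median(), '_sig_mad': SigmaMAD()}
suffixes_example = {Measurement.DIFF: '_diff', Measurement.CHI: '_chi'}

values_ref = np.array([1.0, 2.0, 3.0])
values_target = np.array([1.1, 1.9, 3.3])
errors_target = np.array([0.1, 0.1, 0.2])

diff = values_target - values_ref
chi = diff / errors_target

row_example = {}  # the real task writes into a pre-allocated structured numpy row
prefix = 'r_flux_r_cModelFlux'  # e.g. format_column(band, subtype, column)
for values, suffix in ((diff, suffixes_example[Measurement.DIFF]),
                       (chi, suffixes_example[Measurement.CHI])):
    for suffix_stat, stat in stats_example.items():
        row_example[f'{prefix}{suffix}{suffix_stat}'] = stat.value(values)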

# Get dict of column names
@@ -651,8 +642,8 @@ def run(
select_target_sub &= (extended_target == is_extended)
n_ref_sub = np.count_nonzero(select_ref_sub)
n_target_sub = np.count_nonzero(select_target_sub)
row[f'{band}{subtype}_n_ref'] = n_ref_sub
row[f'{band}{subtype}_n_target'] = n_target_sub
row[format_column(band, subtype, 'n_ref')] = n_ref_sub
row[format_column(band, subtype, 'n_target')] = n_target_sub

# Filter matches by magnitude bin and true class
match_row_bin = match_row.copy()
@@ -670,8 +661,8 @@ def run(
right_type = extended_target[rows_matched] == is_extended
n_total = len(right_type)
n_right = np.count_nonzero(right_type)
row[f'{band}{subtype}_n_ref_match_right'] = n_right
row[f'{band}{subtype}_n_ref_match_wrong'] = n_total - n_right
row[format_column(band, subtype, 'n_ref_match_right')] = n_right
row[format_column(band, subtype, 'n_ref_match_wrong')] = n_total - n_right

# compute stats for this bin, for all columns
for column, (column_ref, column_target, column_err_target, skip_diff) \
@@ -685,7 +676,7 @@ def run(
row,
stats,
suffixes,
prefix=f'{band}{subtype}_{column}',
prefix=format_column(band, subtype, column),
skip_diff=skip_diff,
)

@@ -701,8 +692,8 @@ def run(
right_type[match_row[matched_ref & is_extended_ref]] = True
right_type &= select_target_sub
n_right = np.count_nonzero(right_type)
row[f'{band}{subtype}_n_target_match_right'] = n_right
row[f'{band}{subtype}_n_target_match_wrong'] = n_total - n_right
row[format_column(band, subtype, 'n_target_match_right')] = n_right
row[format_column(band, subtype, 'n_target_match_wrong')] = n_total - n_right

# delete the flux/color columns since they change with each band
for prefix in ('flux_', 'mag_'):
@@ -715,5 +706,5 @@ def run(
mag_prev = mag_model
band_prev = band

retStruct = pipeBase.Struct(cat_diff_matched=pd.DataFrame(data))
retStruct = pipeBase.Struct(cat_diff_matched=pd.DataFrame(data), cat_matched=cat_matched)
return retStruct
