DM-52667: Add model_extendedness columns #1191
Merged
@@ -33,7 +33,9 @@
     "MakeVisitTableConfig", "MakeVisitTableTask",
     "WriteForcedSourceTableConfig", "WriteForcedSourceTableTask",
     "TransformForcedSourceTableConfig", "TransformForcedSourceTableTask",
-    "ConsolidateTractConfig", "ConsolidateTractTask"]
+    "ConsolidateTractConfig", "ConsolidateTractTask",
+    "ComputeColumnsAction", "ModelExtendednessColumnAction",
+    ]
 
 from collections import defaultdict
 import dataclasses

@@ -42,16 +44,19 @@
import logging
import numbers
import os
from typing import Iterable

import numpy as np
import pandas as pd
import astropy.table
from numpy.typing import NDArray

import lsst.geom
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.daf.base as dafBase
from lsst.daf.butler.formatters.parquet import pandas_to_astropy
from lsst.pex.config.configurableActions import ConfigurableAction, ConfigurableActionStructField
from lsst.pipe.base import NoWorkFound, UpstreamFailureNoWorkFound, connectionTypes
import lsst.afw.table as afwTable
from lsst.afw.image import ExposureSummaryStats, ExposureF

@@ -1088,6 +1093,200 @@ def run(self, handle, funcs=None, dataId=None, band=None, **kwargs):
        return pipeBase.Struct(outputCatalog=tbl)


class ComputeColumnsAction(ConfigurableAction):
    """An action that computes multiple vectors from an input.

    This class is meant to be compatible with analysis_tools'
    AnalysisAction class, which cannot be a dependency of pipe_tasks."""

    def getInputSchema(self) -> Iterable[tuple[str, type[NDArray]]]:
        """Return the required inputs for this action.

        This method is meant to be compatible with the corresponding method
        of analysis_tools' AnalysisAction class.
        """
        raise NotImplementedError("This method must be overloaded in subclasses")

    def __call__(self, table: astropy.table.Table) -> dict[str, NDArray]:
        """This method must return a dict of computed columns."""
        raise NotImplementedError("This method must be overloaded in subclasses")


class ExtendednessColumnActionBase(ComputeColumnsAction):
    bands = pexConfig.ListField[str](
        doc="The bands to make single-band outputs for.",
        default=["u", "g", "r", "i", "z", "y"]
    )
    bands_combined = pexConfig.DictField[str, str](
        doc="Multiband classification column specialization. Keys specify the"
            " name of the column and values are a comma-separated list of"
            " bands, all of which must be contained in the bands listed.",
        default={"griz": "g,r,i,z"},
        itemCheck=lambda x: (len(y := x.split(",")) > 1) & (len(set(y)) == len(y)),
    )
    model_column_flux = pexConfig.Field[str](
        doc="The model flux column to use for computing the difference to"
            " the S/N flux. Must contain the {band} and {model} templates.",
        default="{band}_{model}Flux",
        check=lambda x: ("{band}" in x) and ("{model}" in x),
    )
    model_column_flux_err = pexConfig.Field[str](
        doc="The model flux error column to use for computing the difference"
            " to the S/N flux. Must contain the {band} and {model} templates.",
        default="{band}_{model}FluxErr",
        check=lambda x: ("{band}" in x) and ("{model}" in x),
    )
    model_flux_name = pexConfig.Field[str](
        doc="The extended object model to compare to PSF model fluxes.",
        default="sersic",
    )
    output_column = pexConfig.Field[str](
        doc="Name of the output column. Must contain the {band} template.",
        default="{band}_model_extendedness",
        check=lambda x: "{band}" in x,
    )
    psf_column_flux = pexConfig.Field[str](
        doc="The name of the PSF flux column. Must contain the {band} template.",
        default="{band}_psfFlux",
        check=lambda x: "{band}" in x,
    )
    psf_column_flux_err = pexConfig.Field[str](
        doc="The name of the PSF flux error column. Must contain the {band} template.",
        default="{band}_psfFluxErr",
        check=lambda x: "{band}" in x,
    )
    size_column = pexConfig.Field[str](
        doc="The column to use for applying size cuts. Must contain the {axis} template.",
        default="exponential_reff_{axis}",
    )

    def getInputSchema(self) -> Iterable[tuple[str, type[NDArray]]]:
        size_column = self.size_column
        schema = [
            (size_column.format(axis=axis), NDArray[float]) for axis in ("x", "y")
        ]
        model = self.model_flux_name
        for column in (
            self.psf_column_flux, self.psf_column_flux_err,
            self.model_column_flux, self.model_column_flux_err,
        ):
            schema.extend([
                (column.format(band=band, model=model), NDArray[float]) for band in self.bands
            ])

        return schema

    def validate(self):
        super().validate()
        errors = []
        for name, band_combined in self.bands_combined.items():
            bands = band_combined.split(",")
            bands_missing = [band for band in bands if band not in self.bands]
            if bands_missing:
                errors.append(
                    f"self.bands_combined[{name}] contains bands={bands_missing} not in {self.bands=}"
                )
        if errors:
            raise ValueError(f"Validation failed due to errors: {'; '.join(errors)}")

    def _get_fluxes(self, table, band: str):
        model = self.model_flux_name
        flux_psf, fluxerr_psf, flux_model, fluxerr_model = (
            np.array(table[column.format(band=band, model=model)])
            for column in (
                self.psf_column_flux, self.psf_column_flux_err,
                self.model_column_flux, self.model_column_flux_err,
            )
        )
        return flux_psf, fluxerr_psf, flux_model, fluxerr_model

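As an aside, not part of the diff: a minimal sketch of how getInputSchema on the base action above might be exercised, assuming the usual pexConfig behaviour that field values can be passed as keyword arguments at construction.

```python
from lsst.pipe.tasks.postprocess import ExtendednessColumnActionBase

# Restrict to two bands; all other fields keep the defaults defined above.
action = ExtendednessColumnActionBase(bands=["g", "r"])
for column, dtype in action.getInputSchema():
    print(column, dtype)
# Expected to list exponential_reff_x and exponential_reff_y followed by the
# psfFlux, psfFluxErr, sersicFlux and sersicFluxErr columns for the g and r
# bands, each typed as NDArray[float].
```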

class ModelExtendednessColumnAction(ExtendednessColumnActionBase):
    fluxerr_coefficent = pexConfig.Field[float](
        doc="The coefficient to multiply the flux error by when adding to the model flux.",
        default=0.5,
        check=lambda x: x >= 0,
    )
    fluxerr_stretch = pexConfig.Field[float](
        doc="The factor to multiply flux error-scaled ratios by to derive extendedness.",
        default=5.0,
        check=lambda x: x > 0,
    )
    good_sn_min = pexConfig.Field[float](
        doc="Minimum PSF S/N to include objects if"
            " min_n_good_to_shift_flux_ratio is > 0; ignored otherwise.",
        default=10.,
    )
    max_reff_compact = pexConfig.Field[float](
        doc="The maximum effective radius in pixels below which an object is"
            " classified as not extended, regardless of other parameter values.",
        default=0.25,
    )
    min_n_good_to_shift_flux_ratio = pexConfig.Field[int](
        doc="Minimum number of objects with PSF S/N > good_sn_min and with"
            " size larger than max_reff_compact to use to compute the median"
            " PSF-to-model flux ratio, which is assumed to be 1 otherwise."
            " If this value is not >0, the median flux ratio will be kept 1.",
        default=0,
    )

    def __call__(self, table: astropy.table.Table) -> dict[str, NDArray]:
        size_column = self.size_column
        size_model = np.sqrt(
            0.5*(table[size_column.format(axis='x')]**2 + table[size_column.format(axis='y')]**2)
        )
        small = size_model < self.max_reff_compact
        n_obj = len(table)
        band_mappings_to_process = {band: [band] for band in self.bands}
        band_mappings_to_process.update({k: v.split(",") for k, v in self.bands_combined.items()})
        output = {}
        for output_band, input_bands in band_mappings_to_process.items():
            if len(input_bands) > 1:
                flux_psf, fluxerr_psf_sq, flux_model, fluxerr_model_sq = (
                    np.zeros(n_obj, dtype=float) for _ in range(4))
                for input_band in input_bands:
                    flux_psf_b, fluxerr_psf_b, flux_model_b, fluxerr_model_b = self._get_fluxes(
                        table, band=input_band)
                    # There's no point adding S/N < 0 fluxes
                    good = np.isfinite(flux_psf_b) & np.isfinite(flux_model_b) & (
                        flux_psf_b > 0) & (flux_model_b > 0) & (fluxerr_psf_b > 0) & (fluxerr_model_b > 0)
                    flux_psf[good] += flux_psf_b[good]
                    flux_model[good] += flux_model_b[good]
                    fluxerr_psf_sq[good] += fluxerr_psf_b[good]**2
                    fluxerr_model_sq[good] += fluxerr_model_b[good]**2
                fluxerr_psf = np.sqrt(fluxerr_psf_sq)
                fluxerr_model = np.sqrt(fluxerr_model_sq)
                fluxerr_psf[fluxerr_psf == 0] = np.inf
                fluxerr_model[fluxerr_model == 0] = np.inf
            else:
                flux_psf, fluxerr_psf, flux_model, fluxerr_model = self._get_fluxes(
                    table, band=input_bands[0])

            psf_sn = flux_psf/fluxerr_psf
            flux_ratio = np.array(flux_psf / flux_model)

            if self.min_n_good_to_shift_flux_ratio > 0:
                good = small & (psf_sn > self.good_sn_min)
                # Attempt to correct any flux-independent systematic offset
                # Might need to be a function of S/N
                if np.sum(good == True) > self.min_n_good_to_shift_flux_ratio:  # noqa: E712
                    flux_ratio *= 1./np.nanmedian(flux_ratio[good])

            flux_ratio_err = np.sqrt(
                (fluxerr_psf/flux_model)**2 + (fluxerr_model*fluxerr_psf/flux_model**2)**2
            )
            extendedness = (1 - flux_ratio) + self.fluxerr_coefficent*flux_ratio_err
            extendedness *= np.sqrt(size_model/self.max_reff_compact)
            extendedness[(extendedness < 0) & (extendedness > -np.inf)] = 0
            # Make it sigmoid-like with a stretch
            stretch = self.fluxerr_stretch
            extendedness *= stretch
            extendedness = np.clip((stretch + 1)/stretch*extendedness/(1 + extendedness), 0, 1)

            column_out = self.output_column.format(band=output_band)
            output[column_out] = extendedness
        return output

Member (inline review comment on the self._get_fluxes(table, band=input_band) call above): If _get_fluxes returned a class, you could have a free function that took a list of instances of those classes and returned the sum. I think that'd keep the main flow of the computation cleaner.

class ConsolidateObjectTableConnections(pipeBase.PipelineTaskConnections,
                                        dimensions=("tract", "skymap")):
    inputCatalogs = connectionTypes.Input(

@@ -1108,12 +1307,18 @@ class ConsolidateObjectTableConnections(pipeBase.PipelineTaskConnections,
 
 class ConsolidateObjectTableConfig(pipeBase.PipelineTaskConfig,
                                    pipelineConnections=ConsolidateObjectTableConnections):
-    coaddName = pexConfig.Field(
-        dtype=str,
+    actions = ConfigurableActionStructField[ComputeColumnsAction](
+        doc="Actions to add columns to the final object table",
+    )
+    coaddName = pexConfig.Field[str](
         default="deep",
         doc="Name of coadd"
     )
+
+    def setDefaults(self):
+        super().setDefaults()
+        self.actions.extendedness = ModelExtendednessColumnAction()
 
 
 class ConsolidateObjectTableTask(pipeBase.PipelineTask):
     """Write patch-merged source tables to a tract-level DataFrame Parquet file.

@@ -1131,6 +1336,10 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
        self.log.info("Concatenating %s per-patch Object Tables",
                      len(inputs["inputCatalogs"]))
        table = TableVStack.vstack_handles(inputs["inputCatalogs"])
        for action in self.config.actions:
            computed = action(table)
            for key, values in computed.items():
                table[key] = values.astype(np.float32) if values.dtype == np.float64 else values
        butlerQC.put(pipeBase.Struct(outputCatalog=table), outputRefs)
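For orientation, not part of the diff: a minimal sketch of how the new actions field might be adjusted from Python, assuming the usual pexConfig behaviour that setDefaults runs on construction and that ConfigurableActionStructField exposes its actions as attributes; the override values are purely illustrative.

```python
from lsst.pipe.tasks.postprocess import ConsolidateObjectTableConfig

config = ConsolidateObjectTableConfig()
# setDefaults() has already attached the default ModelExtendednessColumnAction
# under the name "extendedness"; its fields can be overridden like any config.
config.actions.extendedness.max_reff_compact = 0.3
config.actions.extendedness.min_n_good_to_shift_flux_ratio = 100
config.actions.extendedness.validate()
```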
@@ -0,0 +1,90 @@
# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import pytest
import unittest

import lsst.utils.tests
from lsst.pipe.tasks.postprocess import ModelExtendednessColumnAction

from astropy.table import Table
import numpy as np


class ModelExtendednessColumnActionTestCase(lsst.utils.tests.TestCase):
    """Test the ModelExtendednessColumnAction."""

    def setUp(self):
        self.bands = ("g", "r", "i")
        action = ModelExtendednessColumnAction(bands=self.bands, min_n_good_to_shift_flux_ratio=1)
        self.action = action
        data = {
            action.size_column.format(axis="x"): [1., 3., 0.01, 0.4, 0.02],
            action.size_column.format(axis="y"): [0.2, 6, 0.01, 0.2, 0.01],
        }
        model = action.model_flux_name
        factors = np.array([1.01, 1.12, 0.995, 0.998, 1.6])
        fluxes = np.array([1.5e3, 2.5e3, 6.8e3, 3.4e3, 5.5e3])
        for column_flux, column_flux_err, factors in (
            (action.model_column_flux, action.model_column_flux_err, factors),
            (action.psf_column_flux, action.psf_column_flux_err, None),
        ):
            for idx, band in enumerate(action.bands):
                flux = np.sqrt(idx + 1.)*fluxes
                if factors is not None:
                    flux *= factors
                data[column_flux.format(band=band, model=model)] = flux
                data[column_flux_err.format(band=band, model=model)] = np.sqrt(flux)
        self.data = Table(data)

    def testExtendednessColumnAction(self):
        action = self.action
        with pytest.raises(ValueError):
            action.validate()
        action.bands_combined = {"gri": "g,r,i"}
        action.validate()
        schema = action.getInputSchema()
        n_values = len(self.data[schema[0][0]])
        assert all([len(self.data[col]) == n_values for col, _ in schema[1:]])

        result = self.action(self.data)
        columns_expected = [
            action.output_column.format(band=band)
            for band in list(action.bands) + list(action.bands_combined.keys())
        ]
        assert list(result.keys()) == columns_expected

        for column, values in result.items():
            assert len(values) == n_values
            assert all((values >= 0) & (values <= 1))


class MemoryTester(lsst.utils.tests.MemoryTestCase):
    pass


def setup_module(module):
    lsst.utils.tests.init()


if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()
Comment: This four-tuple is a good candidate for making a tiny class, so you can pass these as a single unit.
Reply: Hmm, it could be a dataclass and I used to be more enthusiastic about using those in such cases, but I'd prefer to leave this be for now.
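A rough sketch of the kind of refactor being discussed, purely illustrative and not part of the PR; the FluxSet and sum_flux_sets names are hypothetical.

```python
import dataclasses

import numpy as np
from numpy.typing import NDArray


@dataclasses.dataclass
class FluxSet:
    """Hypothetical container for the four arrays returned by _get_fluxes."""
    flux_psf: NDArray
    fluxerr_psf: NDArray
    flux_model: NDArray
    fluxerr_model: NDArray


def sum_flux_sets(flux_sets: list[FluxSet]) -> FluxSet:
    """Sum per-band fluxes (errors added in quadrature), skipping non-finite
    or non-positive values, as in the multiband branch of __call__."""
    n_obj = len(flux_sets[0].flux_psf)
    total = FluxSet(*(np.zeros(n_obj) for _ in range(4)))
    fluxerr_psf_sq = np.zeros(n_obj)
    fluxerr_model_sq = np.zeros(n_obj)
    for fluxes in flux_sets:
        good = (
            np.isfinite(fluxes.flux_psf) & np.isfinite(fluxes.flux_model)
            & (fluxes.flux_psf > 0) & (fluxes.flux_model > 0)
            & (fluxes.fluxerr_psf > 0) & (fluxes.fluxerr_model > 0)
        )
        total.flux_psf[good] += fluxes.flux_psf[good]
        total.flux_model[good] += fluxes.flux_model[good]
        fluxerr_psf_sq[good] += fluxes.fluxerr_psf[good]**2
        fluxerr_model_sq[good] += fluxes.fluxerr_model[good]**2
    total.fluxerr_psf = np.sqrt(fluxerr_psf_sq)
    total.fluxerr_model = np.sqrt(fluxerr_model_sq)
    total.fluxerr_psf[total.fluxerr_psf == 0] = np.inf
    total.fluxerr_model[total.fluxerr_model == 0] = np.inf
    return total
```

With something like this, the multiband branch of __call__ could build one FluxSet per band and hand the list to the free function, which is what the comment above is getting at.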