Skip to content

Commit

Permalink
add NumDiaSourcesAllMetric, CountUniqueAction, GoodDiaSourceSelector
Browse files Browse the repository at this point in the history
fixing mistake in __all__
  • Loading branch information
erinleighh committed Nov 30, 2022
1 parent 3f7b300 commit 194e816
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 11 deletions.
1 change: 1 addition & 0 deletions pipelines/apCcdVisitQualityCore.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ tasks:
class: lsst.analysis.tools.tasks.DiaSourceTableCcdVisitAnalysisTask
config:
metrics.numDiaSources: NumDiaSourcesMetric
metrics.numDiaSourcesAll: NumDiaSourcesAllMetric
metrics.numDipoles: NumDipolesMetric
metrics.numSsObjects: NumSsObjectsMetric
connections.outputName: diaSourceTableCore
Expand Down
25 changes: 25 additions & 0 deletions python/lsst/analysis/tools/actions/scalar/scalarActions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,31 @@ def __call__(self, data: KeyedData, **kwargs) -> Scalar:
return cast(Scalar, len(arr))


class CountUniqueAction(ScalarAction):
"""Counts the number of unique rows in a given column.
Parameters
----------
data : `KeyedData`
Returns
-------
count : `Scalar`
The number of unique rows in a given column.
"""

vectorKey = Field[str](doc="Name of column.")

def getInputSchema(self) -> KeyedDataSchema:
return ((self.vectorKey, Vector),)

def __call__(self, data: KeyedData, **kwargs) -> Scalar:
mask = self.getMask(**kwargs)
values = cast(Vector, data[self.vectorKey.format(**kwargs)])[mask]
count = len(np.unique(values))
return cast(Scalar, count)


class ApproxFloor(ScalarAction):
vectorKey = Field[str](doc="Key for the vector to perform action on", optional=False)

Expand Down
28 changes: 28 additions & 0 deletions python/lsst/analysis/tools/actions/vector/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
"SnSelector",
"ExtendednessSelector",
"SkyObjectSelector",
"SkySourceSelector",
"GoodDiaSourceSelector",
"StarSelector",
"GalaxySelector",
"UnknownSelector",
Expand Down Expand Up @@ -310,6 +312,32 @@ def setDefaults(self):
self.selectWhenTrue = ["sky_source"]


class GoodDiaSourceSelector(FlagSelector):
"""Selects good DIA sources from diaSourceTables"""

def getInputSchema(self) -> KeyedDataSchema:
yield from super().getInputSchema()

def __call__(self, data: KeyedData, **kwargs) -> Vector:
result: Optional[Vector] = None
temp = super().__call__(data, **(kwargs))
if result is not None:
result &= temp # type: ignore
else:
result = temp
return result

def setDefaults(self):
self.selectWhenFalse = [
"base_PixelFlags_flag_bad",
"base_PixelFlags_flag_suspect",
"base_PixelFlags_flag_saturatedCenter",
"base_PixelFlags_flag_interpolated",
"base_PixelFlags_flag_interpolatedCenter",
"base_PixelFlags_flag_edge",
]


class ExtendednessSelector(VectorAction):
vectorKey = Field[str](
doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
Expand Down
28 changes: 17 additions & 11 deletions python/lsst/analysis/tools/analysisMetrics/apDiaSourceMetrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,29 @@
from __future__ import annotations

__all__ = (
"NumDiaSourcesAllMetric",
"NumDiaSourcesMetric",
"NumDipolesMetric",
)

from ..actions.scalar import CountAction
from ..actions.vector import FlagSelector
from ..actions.vector import FlagSelector, GoodDiaSourceSelector
from ..interfaces import AnalysisMetric


class NumDiaSourcesAllMetric(AnalysisMetric):
"""Calculate the number of DIA Sources."""

def setDefaults(self):
super().setDefaults()

# Count the number of dia sources
self.process.calculateActions.NumDiaSourcesMetricAll = CountAction(vectorKey="diaSourceId")

# the units for the quantity (count, an astropy quantity)
self.produce.units = {"NumDiaSourcesAll": "ct"}


class NumDiaSourcesMetric(AnalysisMetric):
"""Calculate the number of DIA Sources that do not have known
bad/quality flags set to true.
Expand All @@ -38,16 +52,8 @@ class NumDiaSourcesMetric(AnalysisMetric):
def setDefaults(self):
super().setDefaults()

# filter out DIA sources with bad flags
self.prep.selectors.flagSelector = FlagSelector()
self.prep.selectors.flagSelector.selectWhenFalse = [
"base_PixelFlags_flag_bad",
"base_PixelFlags_flag_suspect",
"base_PixelFlags_flag_saturatedCenter",
"base_PixelFlags_flag_interpolated",
"base_PixelFlags_flag_interpolatedCenter",
"base_PixelFlags_flag_edge",
]
# select dia sources that do not have bad flags
self.prep.selectors.goodDiaSourceSelector = GoodDiaSourceSelector()

# Count the number of dia sources left after filtering
self.process.calculateActions.numDiaSources = CountAction(vectorKey="diaSourceId")
Expand Down

0 comments on commit 194e816

Please sign in to comment.