Skip to content

Commit

Permalink
Use primary source flag in source count metrics.
Browse files Browse the repository at this point in the history
  • Loading branch information
kfindeisen committed Apr 26, 2021
1 parent e05d09e commit 7b3e7e9
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 9 deletions.
22 changes: 15 additions & 7 deletions python/lsst/ip/diffim/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ class NumberSciSourcesMetricTask(MetricTask):
Notes
-----
The task excludes any sky sources in the catalog, but it does not require
that the catalog include a ``sky_sources`` column.
The task excludes any non-primary sources in the catalog, but it does
not require that the catalog include a ``detect_isPrimary`` or
``sky_sources`` column.
"""
_DefaultName = "numSciSources"
ConfigClass = NumberSciSourcesMetricConfig
Expand Down Expand Up @@ -128,8 +129,9 @@ class FractionDiaSourcesToSciSourcesMetricTask(MetricTask):
Notes
-----
The task excludes any sky sources in the direct source catalog, but it
does not require that either catalog include a ``sky_sources`` column.
The task excludes any non-primary sources in the catalog, but it does
not require that the catalog include a ``detect_isPrimary`` or
``sky_sources`` column.
"""
_DefaultName = "fracDiaSourcesToSciSources"
ConfigClass = FractionDiaSourcesToSciSourcesMetricConfig
Expand Down Expand Up @@ -171,8 +173,10 @@ def run(self, sciSources, diaSources):
def _countRealSources(catalog):
"""Return the number of valid sources in a catalog.
At present, this definition excludes sky sources. If a catalog does not
have a ``sky_source`` flag, all sources are assumed to be non-sky.
At present, this definition includes only primary sources. If a catalog
does not have a ``detect_isPrimary`` flag, this function counts non-sky
sources. If it does not have a ``sky_source`` flag, either, all sources
are counted.
Parameters
----------
Expand All @@ -184,7 +188,11 @@ def _countRealSources(catalog):
count : `int`
The number of sources that satisfy the criteria.
"""
if "sky_source" in catalog.schema:
# E712 is not applicable, because afw.table.SourceRecord.ColumnView
# is not a bool.
if "detect_isPrimary" in catalog.schema:
return np.count_nonzero(catalog["detect_isPrimary"] == True) # noqa: E712
elif "sky_source" in catalog.schema:
return np.count_nonzero(catalog["sky_source"] == False) # noqa: E712
else:
return len(catalog)
29 changes: 27 additions & 2 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
FractionDiaSourcesToSciSourcesMetricTask


def _makeDummyCatalog(size, skyFlag=False):
def _makeDummyCatalog(size, skyFlag=False, priFlag=False):
"""Create a trivial catalog for testing source counts.
Parameters
Expand All @@ -47,6 +47,8 @@ def _makeDummyCatalog(size, skyFlag=False):
If set, the schema is guaranteed to have the ``sky_source`` flag, and
one row has it set to `True`. If not set, the ``sky_source`` flag is
not present.
priFlag : `bool`
As ``skyFlag``, but for a ``detect_isPrimary`` flag.
Returns
-------
Expand All @@ -55,10 +57,14 @@ def _makeDummyCatalog(size, skyFlag=False):
"""
schema = SourceCatalog.Table.makeMinimalSchema()
if skyFlag:
schema.addField("sky_source", type="Flag", doc="Sky objects.")
schema.addField("sky_source", type="Flag", doc="Sky source.")
if priFlag:
schema.addField("detect_isPrimary", type="Flag", doc="Primary source.")
catalog = SourceCatalog(schema)
for i in range(size):
record = catalog.addNew()
if priFlag and size > 0:
record["detect_isPrimary"] = True
if skyFlag and size > 0:
record["sky_source"] = True
return catalog
Expand Down Expand Up @@ -97,6 +103,15 @@ def testSkySources(self):
self.assertEqual(meas.metric_name, Name(metric="ip_diffim.numSciSources"))
assert_quantity_allclose(meas.quantity, (len(catalog) - 1) * u.count)

def testPrimarySources(self):
catalog = _makeDummyCatalog(3, priFlag=True)
result = self.task.run(catalog)
lsst.pipe.base.testUtils.assertValidOutput(self.task, result)
meas = result.measurement

self.assertEqual(meas.metric_name, Name(metric="ip_diffim.numSciSources"))
assert_quantity_allclose(meas.quantity, 1 * u.count)

def testMissingData(self):
result = self.task.run(None)
lsst.pipe.base.testUtils.assertValidOutput(self.task, result)
Expand Down Expand Up @@ -165,6 +180,16 @@ def testSkySources(self):
assert_quantity_allclose(meas.quantity,
len(diaCatalog) / (len(sciCatalog) - 1) * u.dimensionless_unscaled)

def testPrimarySources(self):
sciCatalog = _makeDummyCatalog(5, skyFlag=True, priFlag=True)
diaCatalog = _makeDummyCatalog(3)
result = self.task.run(sciCatalog, diaCatalog)
lsst.pipe.base.testUtils.assertValidOutput(self.task, result)
meas = result.measurement

self.assertEqual(meas.metric_name, Name(metric="ip_diffim.fracDiaSourcesToSciSources"))
assert_quantity_allclose(meas.quantity, len(diaCatalog) * u.dimensionless_unscaled)


# Hack around unittest's hacky test setup system
del MetricTaskTestCase
Expand Down

0 comments on commit 7b3e7e9

Please sign in to comment.