Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-40412: Add edge flag filtering to trailedSourceFilter.py #180

Merged
merged 2 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions data/association-flag-map.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,6 @@ columns:
- name: slot_Shape_flag_parent_source
bit: 26
doc: parent source, ignored; only valid for HsmShape
- name: ext_trailedSources_Naive_flag_edge
bit: 27
doc: source is trailed and extends off chip
31 changes: 20 additions & 11 deletions python/lsst/ap/association/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,32 +112,41 @@ def run(self,
matched to new DiaSources. (`int`)
- ``nUnassociatedDiaObjects`` : Number of DiaObjects that were
not matched to a new DiaSource. (`int`)
bsmartradio marked this conversation as resolved.
Show resolved Hide resolved
- ``longTrailedSources`` : DiaSources which have trail lengths
greater than max_trail_length/second*exposure_time.
(`pandas.DataFrame`)
"""
diaSources = self.check_dia_source_radec(diaSources)

if self.config.doTrailedSourceFilter:
diaTrailedResult = self.trailedSourceFilter.run(diaSources, exposure_time)
diaSources = diaTrailedResult.diaSources
longTrailedSources = diaTrailedResult.longTrailedDiaSources

self.log.info("%i DiaSources exceed max_trail_length, dropping from source "
"catalog." % len(diaTrailedResult.longTrailedDiaSources))
self.metadata.add("num_filtered", len(diaTrailedResult.longTrailedDiaSources))
else:
longTrailedSources = pd.DataFrame(columns=diaSources.columns)

if len(diaObjects) == 0:
return pipeBase.Struct(
matchedDiaSources=pd.DataFrame(columns=diaSources.columns),
unAssocDiaSources=diaSources,
nUpdatedDiaObjects=0,
nUnassociatedDiaObjects=0)
nUnassociatedDiaObjects=0,
longTrailedSources=longTrailedSources)

if self.config.doTrailedSourceFilter:
diaTrailedResult = self.trailedSourceFilter.run(diaSources, exposure_time)
matchResult = self.associate_sources(diaObjects, diaTrailedResult.diaSources)

self.log.info("%i DIASources exceed max_trail_length, dropping "
"from source catalog." % len(diaTrailedResult.trailedDiaSources))

else:
matchResult = self.associate_sources(diaObjects, diaSources)
matchResult = self.associate_sources(diaObjects, diaSources)
bsmartradio marked this conversation as resolved.
Show resolved Hide resolved

mask = matchResult.diaSources["diaObjectId"] != 0

return pipeBase.Struct(
matchedDiaSources=matchResult.diaSources[mask].reset_index(drop=True),
unAssocDiaSources=matchResult.diaSources[~mask].reset_index(drop=True),
nUpdatedDiaObjects=matchResult.nUpdatedDiaObjects,
nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects)
nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects,
longTrailedSources=longTrailedSources)

def check_dia_source_radec(self, dia_sources):
"""Check that all DiaSources have non-NaN values for RA/DEC.
Expand Down
31 changes: 23 additions & 8 deletions python/lsst/ap/association/trailedSourceFilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@

__all__ = ("TrailedSourceFilterTask", "TrailedSourceFilterConfig")

import os
import numpy as np

import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod
from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags
import lsst.utils as utils


class TrailedSourceFilterConfig(pexConfig.Config):
Expand Down Expand Up @@ -72,39 +77,49 @@ def run(self, dia_sources, exposure_time):
result : `lsst.pipe.base.Struct`
Results struct with components.

- ``dia_sources`` : DIASource table that is free from unwanted
- ``diaSources`` : DIASource table that is free from unwanted
trailed sources. (`pandas.DataFrame`)

- ``trailed_dia_sources`` : DIASources that have trails which
exceed max_trail_length/second*exposure_time.
- ``longTrailedDiaSources`` : DIASources that have trails which
exceed max_trail_length/second*exposure_time (seconds).
(`pandas.DataFrame`)
"""
trail_mask = self._check_dia_source_trail(dia_sources, exposure_time)

flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml")
unpacker = UnpackApdbFlags(flag_map, "DiaSource")
flags = unpacker.unpack(dia_sources["flags"], "flags")

trail_mask = self._check_dia_source_trail(dia_sources, exposure_time, flags)

return pipeBase.Struct(
diaSources=dia_sources[~trail_mask].reset_index(drop=True),
trailedDiaSources=dia_sources[trail_mask].reset_index(drop=True))
longTrailedDiaSources=dia_sources[trail_mask].reset_index(drop=True))

def _check_dia_source_trail(self, dia_sources, exposure_time):
def _check_dia_source_trail(self, dia_sources, exposure_time, flags):
"""Find DiaSources that have long trails.

Return a mask of sources with lengths greater than
``config.max_trail_length`` multiplied by the exposure time.
``config.max_trail_length`` multiplied by the exposure time in seconds
or have ext_trailedSources_Naive_flag_edge set.

Parameters
----------
dia_sources : `pandas.DataFrame`
Input DIASources to check for trail lengths.
exposure_time : `float`
Exposure time from difference image.
bsmartradio marked this conversation as resolved.
Show resolved Hide resolved
flags : `numpy.ndarray`
Boolean array of flags from the DIASources.

Returns
-------
trail_mask : `numpy.ndarray`
Boolean mask for DIASources which are greater than the
cutoff length.
cutoff length or have the edge flag set.
"""
trail_mask = (dia_sources.loc[:, "trailLength"].values[:]
>= (self.config.max_trail_length*exposure_time))

trail_mask[np.where(flags['ext_trailedSources_Naive_flag_edge'])] = True

return trail_mask
8 changes: 5 additions & 3 deletions tests/test_association_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import numpy as np
import pandas as pd
import unittest

import lsst.geom as geom
import lsst.utils.tests

from lsst.ap.association import AssociationTask


Expand All @@ -45,12 +45,14 @@ def setUp(self):
self.diaSources = pd.DataFrame(data=[
{"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
"dec": 0.04*idx + scatter*rng.uniform(-1, 1),
"diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx}
"diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx,
"flags": 0}
for idx in range(self.nSources)])
self.diaSourceZeroScatter = pd.DataFrame(data=[
{"ra": 0.04*idx,
"dec": 0.04*idx,
"diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx}
"diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx,
"flags": 0}
for idx in range(self.nSources)])
self.exposure_time = 30.0

Expand Down
65 changes: 55 additions & 10 deletions tests/test_trailedSourceFilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import unittest
from lsst.ap.association import TrailedSourceFilterTask
import os
import numpy as np
import pandas as pd

import lsst.utils.tests
import lsst.utils as utils
from lsst.ap.association import TrailedSourceFilterTask
from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags


class TestTrailedSourceFilterTask(unittest.TestCase):
Expand All @@ -40,10 +44,21 @@ def setUp(self):
self.diaSources = pd.DataFrame(data=[
{"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
"dec": 0.04*idx + scatter*rng.uniform(-1, 1),
"diaSourceId": idx, "diaObjectId": 0, "trailLength": 5.5*idx}
"diaSourceId": idx, "diaObjectId": 0, "trailLength": 5.5*idx,
"flags": 0}
for idx in range(self.nSources)])
self.exposure_time = 30.0

# For use only with testing the edge flag
self.edgeDiaSources = pd.DataFrame(data=[
{"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
"dec": 0.04*idx + scatter*rng.uniform(-1, 1),
"diaSourceId": idx, "diaObjectId": 0, "trailLength": 0,
"flags": 0}
for idx in range(self.nSources)])

self.edgeDiaSources.loc[[1, 4], 'flags'] = np.power(2, 27)

def test_run(self):
"""Run trailedSourceFilterTask with the default max distance.

Expand All @@ -52,11 +67,12 @@ def test_run(self):
filtered out of the final results and put into results.longTrailedDiaSources.
"""
trailedSourceFilterTask = TrailedSourceFilterTask()

results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time)

self.assertEqual(len(results.diaSources), 3)
np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 1, 2])
np.testing.assert_array_equal(results.trailedDiaSources['diaSourceId'].values, [3, 4])
np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [3, 4])

def test_run_short_max_trail(self):
"""Run trailedSourceFilterTask with aggressive trail length cutoff
Expand All @@ -73,7 +89,7 @@ def test_run_short_max_trail(self):

self.assertEqual(len(results.diaSources), 1)
np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0])
np.testing.assert_array_equal(results.trailedDiaSources['diaSourceId'].values, [1, 2, 3, 4])
np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [1, 2, 3, 4])

def test_run_no_trails(self):
"""Run trailedSourceFilterTask with a long trail length so that
Expand All @@ -90,18 +106,47 @@ def test_run_no_trails(self):
results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time)

self.assertEqual(len(results.diaSources), 5)
self.assertEqual(len(results.trailedDiaSources), 0)
self.assertEqual(len(results.longTrailedDiaSources), 0)
np.testing.assert_array_equal(results.diaSources["diaSourceId"].values, [0, 1, 2, 3, 4])
np.testing.assert_array_equal(results.trailedDiaSources["diaSourceId"].values, [])
np.testing.assert_array_equal(results.longTrailedDiaSources["diaSourceId"].values, [])

def test_run_edge(self):
"""Run trailedSourceFilterTask on sources that carry the edge flag.

The ``edgeDiaSources`` fixture gives every source a trail length of
zero and sets bit 27 (``ext_trailedSources_Naive_flag_edge``) on
sources 1 and 4, so only those two sources should be filtered out of
the final results and put into results.longTrailedDiaSources.
"""
trailedSourceFilterTask = TrailedSourceFilterTask()

results = trailedSourceFilterTask.run(self.edgeDiaSources, self.exposure_time)

self.assertEqual(len(results.diaSources), 3)
np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 2, 3])
np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [1, 4])

def test_check_dia_source_trail(self):
bsmartradio marked this conversation as resolved.
Show resolved Hide resolved
"""Test the source trail mask filter.
"""Test that the DiaSource trail checker is correctly identifying
long trails and edge-flagged sources.

Test that the mask filter returns the expected mask array.
Test that the trail source mask filter returns the expected mask array.
"""
trailedSourceFilterTask = TrailedSourceFilterTask()
mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources, self.exposure_time)
np.testing.assert_array_equal(mask, [False, False, False, True, True])
flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml")
unpacker = UnpackApdbFlags(flag_map, "DiaSource")
flags = unpacker.unpack(self.diaSources["flags"], "flags")
trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources,
self.exposure_time, flags)

np.testing.assert_array_equal(trailed_source_mask, [False, False, False, True, True])

flags = unpacker.unpack(self.edgeDiaSources["flags"], "flags")
trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.edgeDiaSources,
self.exposure_time, flags)
np.testing.assert_array_equal(trailed_source_mask, [False, True, False, False, True])

# Mixing the flags from edgeDiaSources and diaSources means the mask
# will be set using both criteria.
trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources,
self.exposure_time, flags)
np.testing.assert_array_equal(trailed_source_mask, [False, True, False, True, True])


class MemoryTester(lsst.utils.tests.MemoryTestCase):
Expand Down