Skip to content

Commit

Permalink
Add centerAll flags and expand tests
Browse files Browse the repository at this point in the history
Rename bits in file-local FootprintBits to distinguish the two kinds.
  • Loading branch information
parejkoj committed Nov 27, 2023
1 parent d96bc7e commit 219b5ca
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 8 deletions.
1 change: 1 addition & 0 deletions include/lsst/meas/base/PixelFlags.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class PixelFlagsAlgorithm : public SimpleAlgorithm {
private:
Control _ctrl;
KeyMap _centerKeys;
KeyMap _centerAllKeys;
KeyMap _anyKeys;
afw::table::Key<afw::table::Flag> _generalFailureKey;
afw::table::Key<afw::table::Flag> _offImageKey;
Expand Down
61 changes: 53 additions & 8 deletions src/PixelFlags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,27 @@ namespace {
template <typename MaskedImageT>
class FootprintBits {
public:
explicit FootprintBits() : _bits(0) {}
explicit FootprintBits() : _anyBits(0), _allBits(~static_cast<typename MaskedImageT::Mask::Pixel>(0x0)) {}

/// \brief Reset everything for a new Footprint
void reset() { _bits = 0x0; }
void reset() {
_anyBits = 0x0;
_allBits = ~static_cast<typename MaskedImageT::Mask::Pixel>(0x0);
}

void operator()(geom::Point2I const& point, typename MaskedImageT::Mask::Pixel const& value) {
_bits |= value;
_anyBits |= value;
_allBits &= value;
}

/// Return the union of the bits set anywhere in the Footprint
typename MaskedImageT::Mask::Pixel getBits() const { return _bits; }
/// Return the union of the bits set anywhere in the Footprint.
typename MaskedImageT::Mask::Pixel getAnyBits() const { return _anyBits; }
/// Return the union of the bits set everywhere in the Footprint.
typename MaskedImageT::Mask::Pixel getAllBits() const { return _allBits; }

private:
typename MaskedImageT::Mask::Pixel _bits;
typename MaskedImageT::Mask::Pixel _anyBits;
typename MaskedImageT::Mask::Pixel _allBits;
};

typedef afw::image::MaskedImage<float> MaskedImageF;
Expand All @@ -64,14 +71,29 @@ void updateFlags(PixelFlagsAlgorithm::KeyMap const& maskFlagToPixelFlag,
const FootprintBits<MaskedImageF>& func, afw::table::SourceRecord& measRecord) {
for (auto const& i : maskFlagToPixelFlag) {
try {
if (func.getBits() & MaskedImageF::Mask::getPlaneBitMask(i.first)) {
if (func.getAnyBits() & MaskedImageF::Mask::getPlaneBitMask(i.first)) {
measRecord.set(i.second, true);
}
} catch (pex::exceptions::InvalidParameterError& err) {
throw LSST_EXCEPT(FatalAlgorithmError, err.what());
}
}
}

// Set flags when all pixels in func have the mask bit set.
void updateFlagsAll(PixelFlagsAlgorithm::KeyMap const& maskFlagToPixelFlag,
const FootprintBits<MaskedImageF>& func, afw::table::SourceRecord& measRecord) {
for (auto const& i : maskFlagToPixelFlag) {
try {
if (func.getAllBits() & MaskedImageF::Mask::getPlaneBitMask(i.first)) {
measRecord.set(i.second, true);
}
} catch (pex::exceptions::InvalidParameterError& err) {
throw LSST_EXCEPT(FatalAlgorithmError, err.what());
}
}
}

} // end anonymous namespace

PixelFlagsAlgorithm::PixelFlagsAlgorithm(Control const& ctrl, std::string const& name,
Expand Down Expand Up @@ -107,13 +129,35 @@ PixelFlagsAlgorithm::PixelFlagsAlgorithm(Control const& ctrl, std::string const&
_centerKeys["SUSPECT"] = schema.addField<afw::table::Flag>(
name + "_flag_suspectCenter", "Suspect pixel in the 3x3 region around the centroid.");

// Flags that correspond to mask bits which are set on all of the 3x3 central pixels of the object.
_centerAllKeys["INTRP"] = schema.addField<afw::table::Flag>(
name + "_flag_interpolatedCenterAll",
"All pixels in the 3x3 region around the centroid are interpolated.");
_centerAllKeys["SAT"] = schema.addField<afw::table::Flag>(
name + "_flag_saturatedCenterAll",
"All pixels in the 3x3 region around the centroid are saturated.");
_centerAllKeys["CR"] = schema.addField<afw::table::Flag>(
name + "_flag_crCenterAll",
"All pixels in the 3x3 region around the centroid have the cosmic ray mask bit.");
_centerAllKeys["BAD"] = schema.addField<afw::table::Flag>(
name + "_flag_badCenterAll", "All pixels in the 3x3 region around the centroid are bad.");
_centerAllKeys["SUSPECT"] = schema.addField<afw::table::Flag>(
name + "_flag_suspectCenterAll", "All pixels in the 3x3 region around the centroid are suspect.");

// Read in the flags passed from the configuration, and add them to the schema
for (auto const& i : _ctrl.masksFpCenter) {
std::string maskName(i);
std::transform(maskName.begin(), maskName.end(), maskName.begin(), ::tolower);
_centerKeys[i] = schema.addField<afw::table::Flag>(
name + "_flag_" + maskName + "Center", "3x3 region around the centroid has " + i + " pixels");
}
for (auto const& i : _ctrl.masksFpCenter) {
std::string maskName(i);
std::transform(maskName.begin(), maskName.end(), maskName.begin(), ::tolower);
_centerAllKeys[i] = schema.addField<afw::table::Flag>(
name + "_flag_" + maskName + "CenterAll",
"All pixels in the 3x3 region around the source centroid are " + i + " pixels");
}

for (auto const& i : _ctrl.masksFpAnywhere) {
std::string maskName(i);
Expand Down Expand Up @@ -176,7 +220,7 @@ void PixelFlagsAlgorithm::measure(afw::table::SourceRecord& measRecord,

// Set the EDGE flag if the bitmask has NO_DATA set
try {
if (func.getBits() & MaskedImageF::Mask::getPlaneBitMask("NO_DATA")) {
if (func.getAnyBits() & MaskedImageF::Mask::getPlaneBitMask("NO_DATA")) {
measRecord.set(_anyKeys.at("EDGE"), true);
}
} catch (pex::exceptions::InvalidParameterError& err) {
Expand All @@ -197,6 +241,7 @@ void PixelFlagsAlgorithm::measure(afw::table::SourceRecord& measRecord,

// Update the flags which have to do with the center of the footprint
updateFlags(_centerKeys, func, measRecord);
updateFlagsAll(_centerAllKeys, func, measRecord);
}

void PixelFlagsAlgorithm::fail(afw::table::SourceRecord& measRecord, MeasurementError* error) const {
Expand Down
31 changes: 31 additions & 0 deletions tests/test_PixelFlags.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,43 @@ def testNoFlags(self):
self.assertFalse(record.get("base_PixelFlags_flag_edge"))
self.assertFalse(record.get("base_PixelFlags_flag_interpolated"))
self.assertFalse(record.get("base_PixelFlags_flag_interpolatedCenter"))
self.assertFalse(record.get("base_PixelFlags_flag_interpolatedCenterAll"))
self.assertFalse(record.get("base_PixelFlags_flag_saturated"))
self.assertFalse(record.get("base_PixelFlags_flag_saturatedCenter"))
self.assertFalse(record.get("base_PixelFlags_flag_saturatedCenterAll"))
self.assertFalse(record.get("base_PixelFlags_flag_cr"))
self.assertFalse(record.get("base_PixelFlags_flag_crCenter"))
self.assertFalse(record.get("base_PixelFlags_flag_crCenterAll"))
self.assertFalse(record.get("base_PixelFlags_flag_bad"))
self.assertFalse(record.get("base_PixelFlags_flag_badCenter"))
self.assertFalse(record.get("base_PixelFlags_flag_badCenterAll"))

def testSomeFlags(self):
task = self.makeSingleFrameMeasurementTask("base_PixelFlags")
exposure, catalog = self.dataset.realize(10.0, task.schema, randomSeed=0)
# one cr pixel outside the center
cosmicray = exposure.mask.getPlaneBitMask("CR")
x = round(self.center.x)
y = round(self.center.y)
exposure.mask[x+3, y+4] |= cosmicray
# one interpolated pixel near the center
interpolated = exposure.mask.getPlaneBitMask("INTRP")
exposure.mask[self.center] |= interpolated
# all pixels in the center are bad
bad = exposure.mask.getPlaneBitMask("BAD")
exposure.mask[x-1:x+2, y-1:y+2] |= bad
task.run(catalog, exposure)
record = catalog[0]

self.assertTrue(record.get("base_PixelFlags_flag_cr"))
self.assertFalse(record.get("base_PixelFlags_flag_crCenter"))
self.assertFalse(record.get("base_PixelFlags_flag_crCenterAll"))
self.assertTrue(record.get("base_PixelFlags_flag_interpolated"))
self.assertTrue(record.get("base_PixelFlags_flag_interpolatedCenter"))
self.assertFalse(record.get("base_PixelFlags_flag_interpolatedCenterAll"))
self.assertTrue(record.get("base_PixelFlags_flag_bad"))
self.assertTrue(record.get("base_PixelFlags_flag_badCenter"))
self.assertTrue(record.get("base_PixelFlags_flag_badCenterAll"))


class TestMemory(lsst.utils.tests.MemoryTestCase):
Expand Down

0 comments on commit 219b5ca

Please sign in to comment.