Skip to content

Commit

Permalink
Add a test plugin in Python which uses the FlagHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
pgee2000 committed May 25, 2016
1 parent 2d4cb19 commit a7ad3ca
Showing 1 changed file with 149 additions and 24 deletions.
173 changes: 149 additions & 24 deletions tests/testFlagHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# the GNU General Public License along with this program. If not,
# see <http://www.lsstcorp.org/LegalNotices/>.
#

import os
import unittest

import lsst.utils.tests
Expand All @@ -31,15 +31,101 @@
from lsst.meas.base import FlagDefinition, FlagDefinitionVector, FlagHandler
from lsst.meas.base.tests import (AlgorithmTestCase)

import lsst.pex.exceptions
from lsst.meas.base.pluginRegistry import register
from lsst.meas.base.sfm import SingleFramePluginConfig, SingleFramePlugin
from lsst.meas.base.baseLib import MeasurementError
from lsst.meas.base import FlagDefinition, FlagDefinitionVector, FlagHandler


class PythonPluginConfig(SingleFramePluginConfig):

failureType = lsst.pex.config.Field(dtype=int, default=None, optional=False,
doc="A failure mode to test")

@register("test_PythonPlugin")
class PythonPlugin(SingleFramePlugin):
'''
This is a sample Python plugin. The flag handler for this plugin is created
during construction, and is called using the method fail(). All plugins are
required to implement this method, which is used to set the flags in the
output source record if an error occurs.
'''
ConfigClass = PythonPluginConfig

# Constants to identify the failures. Must match flagDefs below.
# These are class statics which can be used during run() to identify
# known error conditions
FAILURE = 0
ERROR1 = 1
ERROR2 = 2

FLAGDEFS = [ FlagDefinition("flag", "General failure error"),
FlagDefinition("flag_error1", "First type of failure occured."),
FlagDefinition("flag_error2", "Second type of failure occured."),
]

@classmethod
def getExecutionOrder(cls):
return cls.SHAPE_ORDER

def __init__(self, config, name, schema, metadata):
SingleFramePlugin.__init__(self, config, name, schema, metadata)
self.flagHandler = FlagHandler.addFields(schema, name, FlagDefinitionVector(self.FLAGDEFS))

# This is a measure routine which does nothing except to raise Exceptions
# as requested by the caller. Errors normally don't occur unless there is
# something wrong in the inputs, or if there is an error during the measurement
def measure(self, measRecord, exposure):
if not self.config.failureType is None:
if self.config.failureType == PythonPlugin.ERROR1:
raise MeasurementError(self.flagHandler.getDefinition(PythonPlugin.ERROR1).doc,
PythonPlugin.ERROR1)
if self.config.failureType == PythonPlugin.ERROR2:
raise MeasurementError(self.flagHandler.getDefinition(PythonPlugin.ERROR2).doc,
PythonPlugin.ERROR2)
raise RuntimeError("An unexpected error occurred")

# This routine responds to the standard failure call in baseMeasurement
# If the exception is a MeasurementError, the error will be passed to the
# fail method by the MeasurementFramework.
def fail(self, measRecord, error=None):
if error is None:
self.flagHandler.handleFailure(measRecord)
else:
self.flagHandler.handleFailure(measRecord, error.cpp)

class FlagHandlerTestCase(AlgorithmTestCase):


# Setup a configuration and datasource to be used by the plugin tests
def setUp(self):
self.algName = "test_PythonPlugin"
bbox = lsst.afw.geom.Box2I(lsst.afw.geom.Point2I(0,0), lsst.afw.geom.Point2I(100, 100))
self.dataset = lsst.meas.base.tests.TestDataset(bbox)
self.dataset.addSource(flux=1E5, centroid=lsst.afw.geom.Point2D(25, 26))
config = lsst.meas.base.SingleFrameMeasurementConfig()
config.plugins = [self.algName]
config.slots.centroid = None
config.slots.apFlux = None
config.slots.calibFlux = None
config.slots.instFlux = None
config.slots.modelFlux = None
config.slots.psfFlux = None
config.slots.shape = None
self.config = config

def tearDown(self):
del self.config
del self.dataset

# Standalone test to create a flaghandler and call it
# This is not a real world example, just a simple unit test
def testFlagHandler(self):
control = lsst.meas.base.GaussianCentroidControl()
alg = lsst.meas.base.GaussianCentroidAlgorithm
schema = lsst.afw.table.SourceTable.makeMinimalSchema()
plugin = alg(control, 'test', schema)
cat = lsst.afw.table.SourceCatalog(schema)
schema = cat.getSchema()
subSchema = schema["test"]

# This is a FlagDefinition structure like a plugin might have
Expand All @@ -55,45 +141,82 @@ def testFlagHandler(self):

# Check to be sure that the FlagHandler was correctly initialized
for index, flagDef in enumerate(flagDefs):
assert(flagDef.name == flagDefs[index].name)
assert(flagDef.doc == flagDefs[index].doc)
assert(flagDef.name == fh.getDefinition(index).name)
assert(flagDef.doc == fh.getDefinition(index).doc)

catalog = lsst.afw.table.SourceCatalog(schema)

# Now check to be sure that all of the known failures set the bits correctly
record = catalog.addNew()
fh.handleFailure(record)
self.assertTrue(fh.getValue(record, 0))
self.assertFalse(fh.getValue(record, 1))
self.assertFalse(fh.getValue(record, 2))
self.assertTrue(fh.getValue(record, FAILURE))
self.assertFalse(fh.getValue(record, FIRST))
self.assertFalse(fh.getValue(record, SECOND))
record = catalog.addNew()

error = MeasurementError(fh.getDefinition(FAILURE).doc, FAILURE)
fh.handleFailure(record, error.cpp)
self.assertTrue(fh.getValue(record, 0))
self.assertFalse(fh.getValue(record, 1))
self.assertFalse(fh.getValue(record, 2))
self.assertTrue(fh.getValue(record, FAILURE))
self.assertFalse(fh.getValue(record, FIRST))
self.assertFalse(fh.getValue(record, SECOND))

record = catalog.addNew()
error = MeasurementError(fh.getDefinition(FIRST).doc, FIRST)
fh.handleFailure(record, error.cpp)
self.assertTrue(fh.getValue(record, 0))
self.assertTrue(fh.getValue(record, 1))
self.assertFalse(fh.getValue(record, 2))
self.assertTrue(fh.getValue(record, FAILURE))
self.assertTrue(fh.getValue(record, FIRST))
self.assertFalse(fh.getValue(record, SECOND))

record = catalog.addNew()
error = MeasurementError(fh.getDefinition(SECOND).doc, SECOND)
fh.handleFailure(record, error.cpp)
self.assertTrue(fh.getValue(record, 0))
self.assertFalse(fh.getValue(record, 1))
self.assertTrue(fh.getValue(record, 2))

record = catalog.addNew()
error = MeasurementError("Custom error message", FIRST)
fh.handleFailure(record, error.cpp)
self.assertTrue(fh.getValue(record, 0))
self.assertTrue(fh.getValue(record, 1))
self.assertFalse(fh.getValue(record, 2))
self.assertTrue(fh.getValue(record, FAILURE))
self.assertFalse(fh.getValue(record, FIRST))
self.assertTrue(fh.getValue(record, SECOND))

def testNoError(self):
schema = self.dataset.makeMinimalSchema()
task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
exposure, cat = self.dataset.realize(noise=100.0, schema=schema)
task.run(cat, exposure)
source = cat[0]
self.assertEqual(source.get(self.algName + "_flag"), False)
self.assertEqual(source.get(self.algName + "_flag_error1"), False)
self.assertEqual(source.get(self.algName + "_flag_error2"), False)

def testUnexpectedError(self):
self.config.plugins[self.algName].failureType = -1 # any unknown error type will do
schema = self.dataset.makeMinimalSchema()
task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
exposure, cat = self.dataset.realize(noise=100.0, schema=schema)
task.log.setThreshold(task.log.FATAL)
task.run(cat, exposure)
source = cat[0]
self.assertEqual(source.get(self.algName + "_flag"), True)
self.assertEqual(source.get(self.algName + "_flag_error1"), False)
self.assertEqual(source.get(self.algName + "_flag_error2"), False)

def testError1(self):
self.config.plugins[self.algName].failureType = PythonPlugin.ERROR1
schema = self.dataset.makeMinimalSchema()
task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
exposure, cat = self.dataset.realize(noise=100.0, schema=schema)
task.run(cat, exposure)
source = cat[0]
self.assertEqual(source.get(self.algName + "_flag"), True)
self.assertEqual(source.get(self.algName + "_flag_error1"), True)
self.assertEqual(source.get(self.algName + "_flag_error2"), False)

def testError2(self):
self.config.plugins[self.algName].failureType = PythonPlugin.ERROR2
schema = self.dataset.makeMinimalSchema()
task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
exposure, cat = self.dataset.realize(noise=0.0, schema=schema)
task.run(cat, exposure)
source = cat[0]
self.assertEqual(source.get(self.algName + "_flag"), True)
self.assertEqual(source.get(self.algName + "_flag_error1"), False)
self.assertEqual(source.get(self.algName + "_flag_error2"), True)

def suite():
"""Returns a suite containing all the test cases in this module."""
Expand All @@ -111,3 +234,5 @@ def run(shouldExit=False):

if __name__ == "__main__":
run(True)


0 comments on commit a7ad3ca

Please sign in to comment.