-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes to FlagHandler so that it can be used for Python only plugins #41
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,8 +39,17 @@ namespace lsst { namespace meas { namespace base { | |
C-strings so we can create these arrays using initializer lists even in C++98. | ||
*/ | ||
struct FlagDefinition { | ||
char const * name; | ||
char const * doc; | ||
|
||
FlagDefinition() { | ||
} | ||
|
||
FlagDefinition(std::string _name, std::string _doc) { | ||
name = _name; | ||
doc = _doc; | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am slightly worried what is going to happen when you have python strings mapped to c style pointers. If the python strings somehow get cleaned up. What will happen to the pointers. Is it practical to use std::strings, or am I being overly worried? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch. I am not sure about the lifetime issue here, but it certainly should be thought about. Doesn't the same issue come put for the original addFields() call? Seems like what protects those pointers is that they are attached to a static member of the class which is using them. That means that someone abusing the original addFields() could probably create the same issue. But I propose that I adopt the same strategy with Python, and use a static list for the flagDefs and only call the addFields from the plugin constructor. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would indeed cause problems if called from Python. But the reason we needed to use raw pointers instead of |
||
std::string name; | ||
std::string doc; | ||
}; | ||
|
||
/** | ||
|
@@ -119,6 +128,25 @@ class FlagHandler { | |
FlagDefinition const * end | ||
); | ||
|
||
/** | ||
* Add Flag fields to a schema, creating a FlagHandler object to manage them. | ||
* | ||
* This is the way FlagHandlers will typically be constructed for new algorithms. | ||
* | ||
* @param[out] schema Schema to which fields should be added. | ||
* @param[in] prefix String name of the algorithm or algorithm component. Field names will | ||
* be constructed by using schema.join() on this and the flag name from the | ||
* FlagDefinition array. | ||
* @param[in] flagDefs std::vector of FlagDefinitions | ||
* | ||
* This variation of addFields is for Python plugins | ||
*/ | ||
static FlagHandler addFields( | ||
afw::table::Schema & schema, | ||
std::string const & prefix, | ||
std::vector<FlagDefinition> const * flagDefs | ||
); | ||
|
||
/** | ||
* Construct a FlagHandler to manage fields already added to a schema. | ||
* | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from lsst.meas.base import FlagDefinition, FlagDefinitionVector, FlagHandler | ||
|
||
|
||
def addFlagHandler(*args): | ||
''' | ||
Class decorator to create a flag handler for a plugin. Adds the class variables FLAGDEFS and | ||
ErrEnum. An instnace variable flagHandler is added to the __init__ function to be created at | ||
initialization. The arguments to this function are tuples which have the name of the failure | ||
as the first element, and the documentation for the failure as the second element. | ||
|
||
Usage: | ||
@addFlagHandler(("name_of_failure", "Doc for failure), ("name_of_second_faiure", "Doc of" | ||
" second failure"), .....) | ||
''' | ||
def classFactory(cls): | ||
cls.FLAGDEFS = [FlagDefinition(name, doc) for name, doc in args] | ||
# Verify all flag names are unique | ||
names = [entry[0] for entry in args] | ||
if len(names) != len(set(names)): | ||
raise ValueError("All flag names must be unique, given {}".format(names)) | ||
# Better to have a scoped enumeration rather than attach variables strait to the class to | ||
# prevent shadowing | ||
cls.ErrEnum = type('ErrEnum', (), {entry.name: i for i, entry in enumerate(cls.FLAGDEFS)}) | ||
oldInit = cls.__init__ | ||
|
||
def newInit(self, *args, **kwargs): | ||
oldInit(self, *args, **kwargs) | ||
if 'schema' in kwargs: | ||
schema = kwargs['schema'] | ||
else: | ||
schema = args[2] | ||
if 'name' in kwargs: | ||
name = kwargs['name'] | ||
else: | ||
name = args[1] | ||
self.flagHandler = FlagHandler.addFields(schema, name, FlagDefinitionVector(self.FLAGDEFS)) | ||
cls.__init__ = newInit | ||
return cls | ||
return classFactory |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,6 +44,24 @@ FlagHandler FlagHandler::addFields( | |
return r; | ||
} | ||
|
||
FlagHandler FlagHandler::addFields( | ||
afw::table::Schema & schema, | ||
std::string const & prefix, | ||
std::vector<FlagDefinition> const * flagDefs | ||
) { | ||
FlagHandler r; | ||
r._vector.reserve(flagDefs->size()); | ||
for (unsigned int i = 0; i < flagDefs->size(); i++) { | ||
r._vector.push_back( | ||
std::make_pair( | ||
flagDefs->at(i), | ||
schema.addField<afw::table::Flag>(schema.join(prefix, flagDefs->at(i).name), flagDefs->at(i).doc) | ||
) | ||
); | ||
} | ||
return r; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we are now using c++ 11, please switch this to a range style for loop. It will make the code more readable, and remove all the extra dereferencing -> There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I actually think this is a weird method name, as it is really an alternative constructor for a flaghandler, only adding schemas as a side effect. However as this mirrors the existing method it probably is appropriate. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this code is a copy of the original addFields, we should probably change both or neither. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm leaving it up to Jim whether he wants to make this change. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's not change it. I think of the side-effect of adding fields as actually being more important than the fact that this constructs a FlagHandler, and the same name is used for similar methods on FunctorKeys in many other places. |
||
} | ||
|
||
FlagHandler::FlagHandler( | ||
afw::table::SubSchema const & s, | ||
FlagDefinition const * begin, | ||
|
@@ -66,4 +84,5 @@ void FlagHandler::handleFailure(afw::table::BaseRecord & record, MeasurementErro | |
} | ||
} | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please don't introduce extra diff churn by adding extra spaces if they don't add any clarity |
||
}}} // lsst::meas::base |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
#!/usr/bin/env python | ||
# | ||
# LSST Data Management System | ||
# Copyright 2008-2013 LSST Corporation. | ||
# | ||
# This product includes software developed by the | ||
# LSST Project (http://www.lsst.org/). | ||
# | ||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU General Public License for more details. | ||
# | ||
# You should have received a copy of the LSST License Statement and | ||
# the GNU General Public License along with this program. If not, | ||
# see <http://www.lsstcorp.org/LegalNotices/>. | ||
# | ||
import os | ||
import unittest | ||
|
||
import lsst.utils.tests | ||
import lsst.meas.base | ||
import lsst.meas.base.tests | ||
import lsst.afw.table | ||
from lsst.meas.base.baseLib import MeasurementError | ||
from lsst.meas.base import FlagDefinition, FlagDefinitionVector, FlagHandler | ||
from lsst.meas.base.tests import (AlgorithmTestCase) | ||
|
||
import lsst.pex.exceptions | ||
from lsst.meas.base.pluginRegistry import register | ||
from lsst.meas.base.sfm import SingleFramePluginConfig, SingleFramePlugin | ||
from lsst.meas.base.baseLib import MeasurementError | ||
from lsst.meas.base import FlagDefinition, FlagDefinitionVector, FlagHandler | ||
from lsst.meas.base.flagDecorator import addFlagHandler | ||
|
||
|
||
class PythonPluginConfig(SingleFramePluginConfig): | ||
|
||
failureType = lsst.pex.config.Field(dtype=int, default=None, optional=False, | ||
doc="A failure mode to test") | ||
|
||
@register("test_PythonPlugin") | ||
@addFlagHandler(("flag", "General Failure error"), ("flag_error1","First type of Failure occured."), | ||
("flag_error2", "Second type of failure occured.")) | ||
class PythonPlugin(SingleFramePlugin): | ||
''' | ||
This is a sample Python plugin. The flag handler for this plugin is created | ||
during construction, and is called using the method fail(). All plugins are | ||
required to implement this method, which is used to set the flags in the | ||
output source record if an error occurs. | ||
''' | ||
ConfigClass = PythonPluginConfig | ||
# Class variables ErrEnum and FLAGDEFS are added by the decorator | ||
|
||
@classmethod | ||
def getExecutionOrder(cls): | ||
return cls.SHAPE_ORDER | ||
|
||
def __init__(self, config, name, schema, metadata): | ||
SingleFramePlugin.__init__(self, config, name, schema, metadata) | ||
# The instance variable flagHandler is added by the decorator | ||
|
||
# This is a measure routine which does nothing except to raise Exceptions | ||
# as requested by the caller. Errors normally don't occur unless there is | ||
# something wrong in the inputs, or if there is an error during the measurement | ||
def measure(self, measRecord, exposure): | ||
if not self.config.failureType is None: | ||
if self.config.failureType == PythonPlugin.ErrEnum.flag_error1: | ||
raise MeasurementError(self.flagHandler.getDefinition(PythonPlugin.ErrEnum.flag_error1).doc, | ||
PythonPlugin.ErrEnum.flag_error1) | ||
if self.config.failureType == PythonPlugin.ErrEnum.flag_error2: | ||
raise MeasurementError(self.flagHandler.getDefinition(PythonPlugin.ErrEnum.flag_error2).doc, | ||
PythonPlugin.ErrEnum.flag_error2) | ||
raise RuntimeError("An unexpected error occurred") | ||
|
||
# This routine responds to the standard failure call in baseMeasurement | ||
# If the exception is a MeasurementError, the error will be passed to the | ||
# fail method by the MeasurementFramework. | ||
def fail(self, measRecord, error=None): | ||
if error is None: | ||
self.flagHandler.handleFailure(measRecord) | ||
else: | ||
self.flagHandler.handleFailure(measRecord, error.cpp) | ||
|
||
class FlagHandlerTestCase(AlgorithmTestCase): | ||
|
||
# Setup a configuration and datasource to be used by the plugin tests | ||
def setUp(self): | ||
self.algName = "test_PythonPlugin" | ||
bbox = lsst.afw.geom.Box2I(lsst.afw.geom.Point2I(0,0), lsst.afw.geom.Point2I(100, 100)) | ||
self.dataset = lsst.meas.base.tests.TestDataset(bbox) | ||
self.dataset.addSource(flux=1E5, centroid=lsst.afw.geom.Point2D(25, 26)) | ||
config = lsst.meas.base.SingleFrameMeasurementConfig() | ||
config.plugins = [self.algName] | ||
config.slots.centroid = None | ||
config.slots.apFlux = None | ||
config.slots.calibFlux = None | ||
config.slots.instFlux = None | ||
config.slots.modelFlux = None | ||
config.slots.psfFlux = None | ||
config.slots.shape = None | ||
self.config = config | ||
|
||
def tearDown(self): | ||
del self.config | ||
del self.dataset | ||
|
||
# Standalone test to create a flaghandler and call it | ||
# This is not a real world example, just a simple unit test | ||
def testFlagHandler(self): | ||
control = lsst.meas.base.GaussianCentroidControl() | ||
alg = lsst.meas.base.GaussianCentroidAlgorithm | ||
schema = lsst.afw.table.SourceTable.makeMinimalSchema() | ||
plugin = alg(control, 'test', schema) | ||
cat = lsst.afw.table.SourceCatalog(schema) | ||
subSchema = schema["test"] | ||
|
||
# This is a FlagDefinition structure like a plugin might have | ||
FAILURE = 0 | ||
FIRST = 1 | ||
SECOND = 2 | ||
flagDefs = [ FlagDefinition("General Failure", "general failure error"), | ||
FlagDefinition("1st error", "this is the first failure type"), | ||
FlagDefinition("2nd error", "this is the second failure type") | ||
] | ||
fh = FlagHandler.addFields(schema, "test", | ||
FlagDefinitionVector(flagDefs)) | ||
|
||
# Check to be sure that the FlagHandler was correctly initialized | ||
for index, flagDef in enumerate(flagDefs): | ||
assert(flagDef.name == fh.getDefinition(index).name) | ||
assert(flagDef.doc == fh.getDefinition(index).doc) | ||
|
||
catalog = lsst.afw.table.SourceCatalog(schema) | ||
|
||
# Now check to be sure that all of the known failures set the bits correctly | ||
record = catalog.addNew() | ||
fh.handleFailure(record) | ||
self.assertTrue(fh.getValue(record, FAILURE)) | ||
self.assertFalse(fh.getValue(record, FIRST)) | ||
self.assertFalse(fh.getValue(record, SECOND)) | ||
record = catalog.addNew() | ||
|
||
error = MeasurementError(fh.getDefinition(FAILURE).doc, FAILURE) | ||
fh.handleFailure(record, error.cpp) | ||
self.assertTrue(fh.getValue(record, FAILURE)) | ||
self.assertFalse(fh.getValue(record, FIRST)) | ||
self.assertFalse(fh.getValue(record, SECOND)) | ||
|
||
record = catalog.addNew() | ||
error = MeasurementError(fh.getDefinition(FIRST).doc, FIRST) | ||
fh.handleFailure(record, error.cpp) | ||
self.assertTrue(fh.getValue(record, FAILURE)) | ||
self.assertTrue(fh.getValue(record, FIRST)) | ||
self.assertFalse(fh.getValue(record, SECOND)) | ||
|
||
record = catalog.addNew() | ||
error = MeasurementError(fh.getDefinition(SECOND).doc, SECOND) | ||
fh.handleFailure(record, error.cpp) | ||
self.assertTrue(fh.getValue(record, FAILURE)) | ||
self.assertFalse(fh.getValue(record, FIRST)) | ||
self.assertTrue(fh.getValue(record, SECOND)) | ||
|
||
def testNoError(self): | ||
schema = self.dataset.makeMinimalSchema() | ||
task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config) | ||
exposure, cat = self.dataset.realize(noise=100.0, schema=schema) | ||
task.run(cat, exposure) | ||
source = cat[0] | ||
self.assertEqual(source.get(self.algName + "_flag"), False) | ||
self.assertEqual(source.get(self.algName + "_flag_error1"), False) | ||
self.assertEqual(source.get(self.algName + "_flag_error2"), False) | ||
|
||
def testUnexpectedError(self): | ||
self.config.plugins[self.algName].failureType = -1 # any unknown error type will do | ||
schema = self.dataset.makeMinimalSchema() | ||
task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config) | ||
exposure, cat = self.dataset.realize(noise=100.0, schema=schema) | ||
task.log.setThreshold(task.log.FATAL) | ||
task.run(cat, exposure) | ||
source = cat[0] | ||
self.assertEqual(source.get(self.algName + "_flag"), True) | ||
self.assertEqual(source.get(self.algName + "_flag_error1"), False) | ||
self.assertEqual(source.get(self.algName + "_flag_error2"), False) | ||
|
||
def testError1(self): | ||
self.config.plugins[self.algName].failureType = PythonPlugin.ErrEnum.flag_error1 | ||
schema = self.dataset.makeMinimalSchema() | ||
task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config) | ||
exposure, cat = self.dataset.realize(noise=100.0, schema=schema) | ||
task.run(cat, exposure) | ||
source = cat[0] | ||
self.assertEqual(source.get(self.algName + "_flag"), True) | ||
self.assertEqual(source.get(self.algName + "_flag_error1"), True) | ||
self.assertEqual(source.get(self.algName + "_flag_error2"), False) | ||
|
||
def testError2(self): | ||
self.config.plugins[self.algName].failureType = PythonPlugin.ErrEnum.flag_error2 | ||
schema = self.dataset.makeMinimalSchema() | ||
task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config) | ||
exposure, cat = self.dataset.realize(noise=0.0, schema=schema) | ||
task.run(cat, exposure) | ||
source = cat[0] | ||
self.assertEqual(source.get(self.algName + "_flag"), True) | ||
self.assertEqual(source.get(self.algName + "_flag_error1"), False) | ||
self.assertEqual(source.get(self.algName + "_flag_error2"), True) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would actually be useful if you wrote a simple plugin that uses the flag handler. This could serve as an example for anyone going to use this in the future. Alternatively you could maybe just create a small example that uses the flag handler. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I agree. I was thinking of the ngmix plugin as the "sample", but it is easy enough to transfer that code to a simple example. |
||
def suite(): | ||
"""Returns a suite containing all the test cases in this module.""" | ||
|
||
lsst.utils.tests.init() | ||
|
||
suites = [] | ||
suites += unittest.makeSuite(FlagHandlerTestCase) | ||
suites += unittest.makeSuite(lsst.utils.tests.MemoryTestCase) | ||
return unittest.TestSuite(suites) | ||
|
||
def run(shouldExit=False): | ||
"""Run the tests""" | ||
lsst.utils.tests.run(suite(), shouldExit) | ||
|
||
if __name__ == "__main__": | ||
run(True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update the comment above to match the changes you've made below.