Skip to content

Commit

Permalink
Read plugin registry from measurement config.
Browse files Browse the repository at this point in the history
This is a nice simplification to the logic, and was (in fact) possible all
along. Thanks to Jim Bosch for pointing it out.
  • Loading branch information
jdswinbank committed Jun 2, 2015
1 parent 20fe321 commit c98b0e6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 19 deletions.
20 changes: 3 additions & 17 deletions python/lsst/pipe/tasks/transformMeasurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,11 @@ class TransformTask(pipeBase.Task):
ConfigClass = TransformConfig
_DefaultName = "transform"

def __init__(self, measConfig, pluginRegistry, inputSchema, outputDataset, *args, **kwargs):
def __init__(self, measConfig, inputSchema, outputDataset, *args, **kwargs):
"""!Initialize TransformTask.
@param[in] measConfig Configuration for the measurement task which
produced the measurments being transformed.
@param[in] pluginRegistry A PluginRegistry which maps plugin names to measurement algorithms.
@param[in] inputSchema The schema of the input catalog.
@param[in] outputDataset The butler dataset type of the output catalog.
@param[in] *args Passed through to pipeBase.Task.__init__()
Expand All @@ -121,7 +120,7 @@ def __init__(self, measConfig, pluginRegistry, inputSchema, outputDataset, *args
self.transforms = []
for name in measConfig.plugins.names:
config = measConfig.plugins.get(name)
transformClass = pluginRegistry.get(name).PluginClass.getTransformClass()
transformClass = measConfig.plugins.registry.get(name).PluginClass.getTransformClass()
self.transforms.append(transformClass(config, name, self.mapper))

def getSchemaCatalogs(self):
Expand Down Expand Up @@ -200,9 +199,6 @@ class RunTransformTaskBase(pipeBase.CmdLineTask):
# Standard CmdLineTask attributes:
_DefaultName = None

# Boolean; True if the measurement operation was forced, otherwise False.
wasForced = None

# Butler dataset type of the source type to be transformed ("src", "forced_src", etc):
sourceType = None

Expand Down Expand Up @@ -242,15 +238,8 @@ def measurementConfig(self):

def __init__(self, *args, **kwargs):
pipeBase.CmdLineTask.__init__(self, *args, config=kwargs['config'], log=kwargs['log'])
if self.wasForced:
pluginRegistry = measBase.forcedMeasurement.ForcedPlugin.registry
else:
pluginRegistry = measBase.sfm.SingleFramePlugin.registry

self.butler = kwargs['butler']

self.makeSubtask('transform', pluginRegistry=pluginRegistry,
measConfig=self.measurementConfig,
self.makeSubtask('transform', measConfig=self.measurementConfig,
inputSchema=self.butler.get(self.inputSchemaType).schema,
outputDataset=self.outputDataset)

Expand Down Expand Up @@ -290,7 +279,6 @@ class SrcTransformTask(RunTransformTaskBase):
operates on ``src`` measurements. Refer to the parent documentation for details.
"""
_DefaultName = "transformSrcMeasurement"
wasForced = False
sourceType = 'src'
calexpType = 'calexp'

Expand All @@ -312,7 +300,6 @@ class ForcedSrcTransformTask(RunTransformTaskBase):
operates on ``forced_src`` measurements. Refer to the parent documentation for details.
"""
_DefaultName = "transformForcedSrcMeasurement"
wasForced = True
sourceType = 'forced_src'
calexpType = 'calexp'

Expand All @@ -334,7 +321,6 @@ class CoaddSrcTransformTask(RunTransformTaskBase):
operates on measurements made on coadds. Refer to the parent documentation for details.
"""
_DefaultName = "transformCoaddSrcMeasurement"
wasForced = False

@property
def coaddName(self):
Expand Down
2 changes: 0 additions & 2 deletions tests/testTransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ def testSingleFrameMeasurementTransform(self):
setattr(sfmConfig.slots, key, None)
sfmTask = measBase.SingleFrameMeasurementTask(schema, config=sfmConfig)
transformTask = TransformTask(measConfig=sfmConfig,
pluginRegistry=measBase.sfm.SingleFramePlugin.registry,
inputSchema=sfmTask.schema, outputDataset="src")
self._transformAndCheck(sfmConfig, sfmTask.schema, transformTask)

Expand All @@ -189,7 +188,6 @@ def testForcedMeasurementTransform(self):
forcedTask = measBase.ForcedMeasurementTask(schema, config=forcedConfig)
transformConfig = TransformConfig(copyFields=("objectId", "coord"))
transformTask = TransformTask(measConfig=forcedConfig,
pluginRegistry=measBase.forcedMeasurement.ForcedPlugin.registry,
inputSchema=forcedTask.schema, outputDataset="forced_src",
config=transformConfig)
self._transformAndCheck(forcedConfig, forcedTask.schema, transformTask)
Expand Down

0 comments on commit c98b0e6

Please sign in to comment.