Merge branch 'tickets/DM-23616'
kfindeisen committed Mar 26, 2020
2 parents 91c0010 + 782df4e commit 557baf8
Showing 3 changed files with 78 additions and 25 deletions.
doc/lsst.pipe.base/testing-a-pipeline-task.rst (2 changes: 1 addition & 1 deletion)
@@ -74,7 +74,7 @@ If you do need `~lsst.pipe.base.PipelineTask.runQuantum` to call `~lsst.pipe.bas
butler = butlerTests.makeTestCollection(repo)
task = AwesomeTask()
quantum = testUtils.makeQuantum(
- task, butler,
+ task, butler, dataId,
{key: dataId for key in {"input", "output"}})
run = testUtils.runTestQuantum(task, butler, quantum)
# Actual input dataset omitted for simplicity
python/lsst/pipe/base/testUtils.py (29 changes: 20 additions & 9 deletions)
@@ -24,13 +24,14 @@


import collections.abc
+ import itertools
import unittest.mock

from lsst.daf.butler import DataCoordinate, DatasetRef, Quantum, StorageClassFactory
from lsst.pipe.base import ButlerQuantumContext


- def makeQuantum(task, butler, dataIds):
+ def makeQuantum(task, butler, dataId, ioDataIds):
"""Create a Quantum for a particular data ID(s).
Parameters
@@ -39,7 +40,10 @@ def makeQuantum(task, butler, dataIds):
The task whose processing the quantum represents.
butler : `lsst.daf.butler.Butler`
The collection the quantum refers to.
- dataIds : `collections.abc.Mapping` [`str`]
+ dataId: any data ID type
+ The data ID of the quantum. Must have the same dimensions as
+ ``task``'s connections class.
+ ioDataIds : `collections.abc.Mapping` [`str`]
A mapping keyed by input/output names. Values must be data IDs for
single connections and sequences of data IDs for multiple connections.
@@ -48,20 +52,20 @@
quantum : `lsst.daf.butler.Quantum`
A quantum for ``task``, when called with ``dataIds``.
"""
- quantum = Quantum(taskClass=type(task))
+ quantum = Quantum(taskClass=type(task), dataId=dataId)
connections = task.config.ConnectionsClass(config=task.config)

try:
- for name in connections.inputs:
+ for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
connection = connections.__getattribute__(name)
- _checkDataIdMultiplicity(name, dataIds[name], connection.multiple)
- ids = _normalizeDataIds(dataIds[name])
+ _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
+ ids = _normalizeDataIds(ioDataIds[name])
for id in ids:
quantum.addPredictedInput(_refFromConnection(butler, connection, id))
for name in connections.outputs:
connection = connections.__getattribute__(name)
- _checkDataIdMultiplicity(name, dataIds[name], connection.multiple)
- ids = _normalizeDataIds(dataIds[name])
+ _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
+ ids = _normalizeDataIds(ioDataIds[name])
for id in ids:
quantum.addOutput(_refFromConnection(butler, connection, id))
return quantum
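
For reference, the new call pattern: makeQuantum now takes the quantum's own data ID ahead of the per-connection mapping, and prerequisite inputs are collected together with regular inputs. A minimal usage sketch (not part of the diff), assuming a test butler and the PatchTask defined in tests/test_testUtils.py below, with butlerTests standing for lsst.daf.butler.tests as in the documentation snippet above:

# Sketch only: `butler` and PatchTask are assumed to be set up as in the test suite below.
import lsst.daf.butler.tests as butlerTests
from lsst.pipe.base.testUtils import makeQuantum, runTestQuantum

task = PatchTask()
dataId = butlerTests.expandUniqueId(butler, {"tract": 42})   # the quantum's own data ID
quantum = makeQuantum(
    task, butler, dataId,
    {
        "a": [dict(dataId, patch=patch) for patch in (0, 1)],  # multiple=True: sequence of data IDs
        "b": dataId,                                           # prerequisite, multiple=False: one data ID
        "out": [dict(dataId, patch=patch) for patch in (0, 1)],
    })
run = runTestQuantum(task, butler, quantum, mockRun=True)      # returns a mock standing in for task.run
run.assert_called_once()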
@@ -139,7 +143,14 @@ def _refFromConnection(butler, connection, dataId, **kwargs):
"""
universe = butler.registry.dimensions
dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)
- datasetType = connection.makeDatasetType(universe)
+
+ # skypix is a PipelineTask alias for "some spatial index", Butler doesn't
+ # understand it. Code copied from TaskDatasetTypes.fromTaskDef
+ if "skypix" in connection.dimensions:
+ datasetType = butler.registry.getDatasetType(connection.name)
+ else:
+ datasetType = connection.makeDatasetType(universe)

try:
butler.registry.getDatasetType(datasetType.name)
except KeyError:
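A note on the skypix branch added to _refFromConnection above: a connection may declare the placeholder dimension "skypix", which the Butler registry does not understand, so the dataset type has to come from the registry entry, which carries a concrete pixelization such as htm7. A rough sketch of that situation (not part of the diff; `repo` and `butler` assumed to be set up as in PipelineTaskTestSuite.setUpClass below):

# The dataset type is registered with a concrete pixelization, as the tests below do:
butlerTests.addDatasetType(repo, "PixA", {"htm7"}, "StructuredData")

# A connection declared with dimensions={"skypix"} cannot be resolved through
# connection.makeDatasetType(universe), since "skypix" is only a PipelineTask
# alias; _refFromConnection therefore asks the registry instead:
datasetType = butler.registry.getDatasetType("PixA")  # carries the real htm7 dimensions
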
tests/test_testUtils.py (72 changes: 57 additions & 15 deletions)
@@ -70,7 +70,7 @@ class PatchConnections(PipelineTaskConnections, dimensions={"skymap", "tract"}):
multiple=True,
dimensions={"skymap", "tract", "patch"},
)
- b = connectionTypes.Input(
+ b = connectionTypes.PrerequisiteInput(
name="PatchB",
storageClass="StructuredData",
multiple=False,
@@ -87,7 +87,20 @@ def __init__(self, *, config=None):
super().__init__(config=config)

if not config.doUseB:
- self.inputs.remove("b")
+ self.prerequisiteInputs.remove("b")


+ class SkyPixConnections(PipelineTaskConnections, dimensions={"skypix"}):
+ a = connectionTypes.Input(
+ name="PixA",
+ storageClass="StructuredData",
+ dimensions={"skypix"},
+ )
+ out = connectionTypes.Output(
+ name="PixOut",
+ storageClass="StructuredData",
+ dimensions={"skypix"},
+ )


class VisitConfig(PipelineTaskConfig, pipelineConnections=VisitConnections):
@@ -98,6 +111,10 @@ class PatchConfig(PipelineTaskConfig, pipelineConnections=PatchConnections):
doUseB = lsst.pex.config.Field(default=True, dtype=bool, doc="")


+ class SkyPixConfig(PipelineTaskConfig, pipelineConnections=SkyPixConnections):
+ pass


class VisitTask(PipelineTask):
ConfigClass = VisitConfig
_DefaultName = "visit"
@@ -120,6 +137,14 @@ def run(self, a, b=None):
return Struct(out=out)


+ class SkyPixTask(PipelineTask):
+ ConfigClass = SkyPixConfig
+ _DefaultName = "skypix"
+
+ def run(self, a):
+ return Struct(out=a)


class PipelineTaskTestSuite(lsst.utils.tests.TestCase):
@classmethod
def setUpClass(cls):
@@ -144,6 +169,8 @@ def setUpClass(cls):
for typeName in {"PatchA", "PatchOut"}:
butlerTests.addDatasetType(cls.repo, typeName, {"skymap", "tract", "patch"}, "StructuredData")
butlerTests.addDatasetType(cls.repo, "PatchB", {"skymap", "tract"}, "StructuredData")
for typeName in {"PixA", "PixOut"}:
butlerTests.addDatasetType(cls.repo, typeName, {"htm7"}, "StructuredData")

@classmethod
def tearDownClass(cls):
@@ -219,7 +246,7 @@ def testMakeQuantumNoSuchDatatype(self):
self._makeVisitTestData(dataId)

with self.assertRaises(ValueError):
- makeQuantum(task, self.butler, {key: dataId for key in {"a", "b", "outA", "outB"}})
+ makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outA", "outB"}})

def testMakeQuantumInvalidDimension(self):
config = VisitConfig()
@@ -235,7 +262,7 @@ def testMakeQuantumInvalidDimension(self):
self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataIdV)

with self.assertRaises(ValueError):
- makeQuantum(task, self.butler, {
+ makeQuantum(task, self.butler, dataIdV, {
"a": dataIdP,
"b": dataIdV,
"outA": dataIdV,
@@ -249,7 +276,7 @@ def testMakeQuantumMissingMultiple(self):
self._makePatchTestData(dataId)

with self.assertRaises(ValueError):
- makeQuantum(task, self.butler, {
+ makeQuantum(task, self.butler, dataId, {
"a": dict(dataId, patch=0),
"b": dataId,
"out": [dict(dataId, patch=patch) for patch in {0, 1}],
@@ -262,7 +289,7 @@ def testMakeQuantumExtraMultiple(self):
self._makePatchTestData(dataId)

with self.assertRaises(ValueError):
- makeQuantum(task, self.butler, {
+ makeQuantum(task, self.butler, dataId, {
"a": [dict(dataId, patch=patch) for patch in {0, 1}],
"b": [dataId],
"out": [dict(dataId, patch=patch) for patch in {0, 1}],
@@ -275,9 +302,9 @@ def testMakeQuantumMissingDataId(self):
self._makeVisitTestData(dataId)

with self.assertRaises(ValueError):
- makeQuantum(task, self.butler, {key: dataId for key in {"a", "outA", "outB"}})
+ makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "outA", "outB"}})
with self.assertRaises(ValueError):
- makeQuantum(task, self.butler, {key: dataId for key in {"a", "b", "outB"}})
+ makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outB"}})

def testMakeQuantumCorruptedDataId(self):
task = VisitTask()
@@ -286,16 +313,17 @@ def testMakeQuantumCorruptedDataId(self):
self._makeVisitTestData(dataId)

with self.assertRaises(ValueError):
- # third argument should be a mapping keyed by component name
- makeQuantum(task, self.butler, dataId)
+ # fourth argument should be a mapping keyed by component name
+ makeQuantum(task, self.butler, dataId, dataId)

def testRunTestQuantumVisitWithRun(self):
task = VisitTask()

dataId = butlerTests.expandUniqueId(self.butler, {"visit": 102})
data = self._makeVisitTestData(dataId)

- quantum = makeQuantum(task, self.butler, {key: dataId for key in {"a", "b", "outA", "outB"}})
+ quantum = makeQuantum(task, self.butler, dataId,
+ {key: dataId for key in {"a", "b", "outA", "outB"}})
runTestQuantum(task, self.butler, quantum, mockRun=False)

# Can we use runTestQuantum to verify that task.run got called with correct inputs/outputs?
@@ -312,7 +340,7 @@ def testRunTestQuantumPatchWithRun(self):
dataId = butlerTests.expandUniqueId(self.butler, {"tract": 42})
data = self._makePatchTestData(dataId)

- quantum = makeQuantum(task, self.butler, {
+ quantum = makeQuantum(task, self.butler, dataId, {
"a": [dataset[0] for dataset in data["PatchA"]],
"b": dataId,
"out": [dataset[0] for dataset in data["PatchA"]],
@@ -333,7 +361,8 @@ def testRunTestQuantumVisitMockRun(self):
dataId = butlerTests.expandUniqueId(self.butler, {"visit": 102})
data = self._makeVisitTestData(dataId)

- quantum = makeQuantum(task, self.butler, {key: dataId for key in {"a", "b", "outA", "outB"}})
+ quantum = makeQuantum(task, self.butler, dataId,
+ {key: dataId for key in {"a", "b", "outA", "outB"}})
run = runTestQuantum(task, self.butler, quantum, mockRun=True)

# Can we use the mock to verify that task.run got called with the correct inputs?
@@ -346,7 +375,7 @@ def testRunTestQuantumPatchMockRun(self):
dataId = butlerTests.expandUniqueId(self.butler, {"tract": 42})
data = self._makePatchTestData(dataId)

- quantum = makeQuantum(task, self.butler, {
+ quantum = makeQuantum(task, self.butler, dataId, {
# Use lists, not sets, to ensure order agrees with test assertion
"a": [dataset[0] for dataset in data["PatchA"]],
"b": dataId,
@@ -368,7 +397,7 @@ def testRunTestQuantumPatchOptionalInput(self):
dataId = butlerTests.expandUniqueId(self.butler, {"tract": 42})
data = self._makePatchTestData(dataId)

- quantum = makeQuantum(task, self.butler, {
+ quantum = makeQuantum(task, self.butler, dataId, {
# Use lists, not sets, to ensure order agrees with test assertion
"a": [dataset[0] for dataset in data["PatchA"]],
"out": [dataset[0] for dataset in data["PatchA"]],
@@ -432,6 +461,19 @@ def run(a, b):
with self.assertRaises(AssertionError):
assertValidOutput(task, result)

+ def testSkypixHandling(self):
+ task = SkyPixTask()
+
+ dataId = {"htm7": 157227}  # connection declares skypix, but Butler uses htm7
+ data = butlerTests.MetricsExample(data=[1, 2, 3])
+ self.butler.put(data, "PixA", dataId)
+
+ quantum = makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "out"}})
+ run = runTestQuantum(task, self.butler, quantum, mockRun=True)
+
+ # PixA dataset should have been retrieved by runTestQuantum
+ run.assert_called_once_with(a=data)


class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
pass
