Skip to content

Commit

Permalink
Merge pull request #269 from lsst/tickets/DM-35752
Browse files Browse the repository at this point in the history
DM-35752: Allow None dataset ref to be passed to ButlerQC
  • Loading branch information
andy-slac committed Aug 1, 2022
2 parents f8932a4 + 7562f1f commit 9602fcb
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 10 deletions.
1 change: 1 addition & 0 deletions doc/changes/DM-35752.api.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`ButlerQuantumContext.get` method can accept `None` as a reference and returns `None` as a result object.
20 changes: 12 additions & 8 deletions python/lsst/pipe/base/butlerQuantumContext.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

__all__ = ("ButlerQuantumContext",)

from typing import Any, List, Sequence, Union
from typing import Any, List, Optional, Sequence, Union

from lsst.daf.butler import Butler, DatasetRef, Quantum
from lsst.utils.introspection import get_full_type_name
Expand Down Expand Up @@ -88,13 +88,14 @@ def __init__(self, butler: Butler, quantum: Quantum):
self.allOutputs.add((ref.datasetType, ref.dataId))
self.__butler = butler

def _get(self, ref: Union[DeferredDatasetRef, DatasetRef]) -> Any:
def _get(self, ref: Optional[Union[DeferredDatasetRef, DatasetRef]]) -> Any:
# Butler methods below will check for unresolved DatasetRefs and
# raise appropriately, so no need for us to do that here.
if isinstance(ref, DeferredDatasetRef):
self._checkMembership(ref.datasetRef, self.allInputs)
return self.__butler.getDirectDeferred(ref.datasetRef)

elif ref is None:
return None
else:
self._checkMembership(ref, self.allInputs)
return self.__butler.getDirect(ref)
Expand All @@ -107,10 +108,11 @@ def get(
self,
dataset: Union[
InputQuantizedConnection,
List[DatasetRef],
List[DeferredDatasetRef],
List[Optional[DatasetRef]],
List[Optional[DeferredDatasetRef]],
DatasetRef,
DeferredDatasetRef,
None,
],
) -> Any:
"""Fetches data from the butler
Expand All @@ -122,7 +124,9 @@ def get(
describes all the inputs of a quantum, a list of
`~lsst.daf.butler.DatasetRef`, or a single
`~lsst.daf.butler.DatasetRef`. The function will get and return
the corresponding datasets from the butler.
the corresponding datasets from the butler. If `None` is passed in
place of a `~lsst.daf.butler.DatasetRef` then the corresponding
returned object will be `None`.
Returns
-------
Expand Down Expand Up @@ -192,12 +196,12 @@ def get(
# Mypy is not sure of the type of x because of the union
# of lists so complains. Ignoring it is more efficient
# than adding an isinstance assert.
retrieved.append(self._get(x)) # type: ignore
retrieved.append(self._get(x))
periodic.log("Retrieved %d out of %d datasets", i + 1, n_datasets)
if periodic.num_issued > 0:
_LOG.verbose("Completed retrieval of %d datasets", n_datasets)
return retrieved
elif isinstance(dataset, DatasetRef) or isinstance(dataset, DeferredDatasetRef):
elif isinstance(dataset, DatasetRef) or isinstance(dataset, DeferredDatasetRef) or dataset is None:
return self._get(dataset)
else:
raise TypeError(
Expand Down
48 changes: 46 additions & 2 deletions tests/test_pipelineTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _makeDSRefVisit(self, dstype, visitId, universe):
),
)

def _makeQuanta(self, config):
def _makeQuanta(self, config, nquanta=100):
"""Create set of Quanta"""
universe = DimensionUniverse()
connections = config.connections.ConnectionsClass(config=config)
Expand All @@ -122,7 +122,7 @@ def _makeQuanta(self, config):
dstype1 = connections.output.makeDatasetType(universe)

quanta = []
for visit in range(100):
for visit in range(nquanta):
inputRef = self._makeDSRefVisit(dstype0, visit, universe)
outputRef = self._makeDSRefVisit(dstype1, visit, universe)
quantum = Quantum(
Expand Down Expand Up @@ -244,6 +244,50 @@ def testChain2(self):
ref = quantum.outputs[outputName][0]
self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

def testButlerQC(self):
"""Test for ButlerQuantumContext."""

butler = ButlerMock()
task = AddTask(config=AddConfig())
connections = task.config.connections.ConnectionsClass(config=task.config)

# make one quantum
(quantum,) = self._makeQuanta(task.config, 1)

# add input data to butler
dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
ref = quantum.inputs[dstype0.name][0]
butler.put(100, pipeBase.Struct(datasetType=dstype0.name, dataId=ref.dataId))

butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)

# Pass ref as single argument or a list.
obj = butlerQC.get(ref)
self.assertEqual(obj, 100)
obj = butlerQC.get([ref])
self.assertEqual(obj, [100])

# Pass None instead of a ref.
obj = butlerQC.get(None)
self.assertIsNone(obj)
obj = butlerQC.get([None])
self.assertEqual(obj, [None])

# COmbine a ref and None.
obj = butlerQC.get([ref, None])
self.assertEqual(obj, [100, None])

# Use refs from a QuantizedConnection.
inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
obj = butlerQC.get(inputRefs)
self.assertEqual(obj, {"input": 100})

# Add few None values to a QuantizedConnection.
inputRefs.input = [None, ref]
inputRefs.input2 = None
obj = butlerQC.get(inputRefs)
self.assertEqual(obj, {"input": [None, 100], "input2": None})


class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
pass
Expand Down

0 comments on commit 9602fcb

Please sign in to comment.