Skip to content

Commit

Permalink
Merge pull request #307 from lsst/tickets/DM-37995
Browse files Browse the repository at this point in the history
DM-37995: Add registry dataset types to QuantumGraph
  • Loading branch information
andy-slac committed Feb 28, 2023
2 parents 80dcb53 + 4c80bb9 commit ef3da71
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 7 deletions.
7 changes: 7 additions & 0 deletions python/lsst/pipe/base/graph/_versionDeserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import networkx as nx
from lsst.daf.butler import (
DatasetRef,
DatasetType,
DimensionConfig,
DimensionRecord,
DimensionUniverse,
Expand Down Expand Up @@ -514,6 +515,11 @@ def readHeaderInfo(self, rawHeader: bytes) -> SimpleNamespace:
infoMappings.globalInitOutputRefs = [
DatasetRef.from_json(json_ref, universe=universe) for json_ref in json_refs
]
infoMappings.registryDatasetTypes = []
if (json_refs := infoMap.get("RegistryDatasetTypes")) is not None:
infoMappings.registryDatasetTypes = [
DatasetType.from_json(json_ref, universe=universe) for json_ref in json_refs
]
self.infoMappings = infoMappings
return infoMappings

Expand Down Expand Up @@ -643,6 +649,7 @@ def constructGraph(
newGraph._initInputRefs = initInputRefs
newGraph._initOutputRefs = initOutputRefs
newGraph._globalInitOutputRefs = self.infoMappings.globalInitOutputRefs
newGraph._registryDatasetTypes = self.infoMappings.registryDatasetTypes
newGraph._universe = universe
return newGraph

Expand Down
50 changes: 48 additions & 2 deletions python/lsst/pipe/base/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ class QuantumGraph:
objects include task configurations and package versions. Typically
they have an empty DataId, but there is no real restriction on what
can appear here.
registryDatasetTypes : iterable [ `DatasetType` ], optional
Dataset types which are used by this graph, their definitions must
match registry. If registry does not define dataset type yet, then
it should match one that will be created later.
Raises
------
Expand All @@ -144,6 +148,7 @@ def __init__(
initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
globalInitOutputs: Optional[Iterable[DatasetRef]] = None,
registryDatasetTypes: Optional[Iterable[DatasetType]] = None,
):
self._buildGraphs(
quanta,
Expand All @@ -153,6 +158,7 @@ def __init__(
initInputs=initInputs,
initOutputs=initOutputs,
globalInitOutputs=globalInitOutputs,
registryDatasetTypes=registryDatasetTypes,
)

def _buildGraphs(
Expand All @@ -167,6 +173,7 @@ def _buildGraphs(
initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
globalInitOutputs: Optional[Iterable[DatasetRef]] = None,
registryDatasetTypes: Optional[Iterable[DatasetType]] = None,
) -> None:
"""Builds the graph that is used to store the relation between tasks,
and the graph that holds the relations between quanta
Expand Down Expand Up @@ -310,12 +317,15 @@ def _buildGraphs(
self._initInputRefs: Dict[TaskDef, List[DatasetRef]] = {}
self._initOutputRefs: Dict[TaskDef, List[DatasetRef]] = {}
self._globalInitOutputRefs: List[DatasetRef] = []
self._registryDatasetTypes: List[DatasetType] = []
if initInputs is not None:
self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
if initOutputs is not None:
self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}
if globalInitOutputs is not None:
self._globalInitOutputRefs = list(globalInitOutputs)
if registryDatasetTypes is not None:
self._registryDatasetTypes = list(registryDatasetTypes)

@property
def taskGraph(self) -> nx.DiGraph:
Expand Down Expand Up @@ -413,13 +423,17 @@ def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T:
# convert to standard dict to prevent accidental key insertion
quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())

# This should not change set of tasks in a graph, so we can keep the
# same registryDatasetTypes as in the original graph.
# TODO: Do we need to copy initInputs/initOutputs?
newInst._buildGraphs(
quantumDict,
_quantumToNodeId={n.quantum: n.nodeId for n in self},
metadata=self._metadata,
pruneRefs=refs,
universe=self._universe,
globalInitOutputs=self._globalInitOutputRefs,
registryDatasetTypes=self._registryDatasetTypes,
)
return newInst

Expand Down Expand Up @@ -682,21 +696,39 @@ def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
quantumMap = defaultdict(set)

dataset_type_names: set[str] = set()
node: QuantumNode
for node in quantumSubgraph:
quantumMap[node.taskDef].add(node.quantum)
dataset_type_names.update(
dstype.name
for dstype in chain(
node.quantum.inputs.keys(), node.quantum.outputs.keys(), node.quantum.initInputs.keys()
)
)

# May need to trim dataset types from registryDatasetTypes.
for taskDef in quantumMap:
if refs := self.initOutputRefs(taskDef):
dataset_type_names.update(ref.datasetType.name for ref in refs)
dataset_type_names.update(ref.datasetType.name for ref in self._globalInitOutputRefs)
registryDatasetTypes = [
dstype for dstype in self._registryDatasetTypes if dstype.name in dataset_type_names
]

# convert to standard dict to prevent accidental key insertion
quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())
# Create an empty graph, and then populate it with custom mapping
newInst = type(self)({}, universe=self._universe)
# TODO: Do we need to copy initInputs/initOutputs?
newInst._buildGraphs(
quantumDict,
_quantumToNodeId={n.quantum: n.nodeId for n in nodes},
_buildId=self._buildId,
metadata=self._metadata,
universe=self._universe,
globalInitOutputs=self._globalInitOutputRefs,
registryDatasetTypes=registryDatasetTypes,
)
return newInst

Expand Down Expand Up @@ -862,6 +894,17 @@ def globalInitOutputRefs(self) -> List[DatasetRef]:
"""
return self._globalInitOutputRefs

def registryDatasetTypes(self) -> List[DatasetType]:
    """Return the dataset types used by this graph.

    The definitions are expected to match the dataset types from registry.

    Returns
    -------
    datasetTypes : `list` [ `DatasetType` ]
        Dataset types for this graph. This is the list held by the graph
        itself, not a copy.
    """
    return self._registryDatasetTypes

@classmethod
def loadUri(
cls,
Expand Down Expand Up @@ -1115,6 +1158,9 @@ def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple
if self._globalInitOutputRefs:
headerData["GlobalInitOutputRefs"] = [ref.to_json() for ref in self._globalInitOutputRefs]

if self._registryDatasetTypes:
headerData["RegistryDatasetTypes"] = [dstype.to_json() for dstype in self._registryDatasetTypes]

# dump the headerData to json
header_encode = lzma.compress(json.dumps(headerData).encode())

Expand Down Expand Up @@ -1252,8 +1298,8 @@ def __contains__(self, node: QuantumNode) -> bool:
def __getstate__(self) -> dict:
"""Stores a compact form of the graph as a list of graph nodes, and a
tuple of task labels and task configs. The full graph can be
reconstructed with this information, and it preseves the ordering of
the graph ndoes.
reconstructed with this information, and it preserves the ordering of
the graph nodes.
"""
universe: Optional[DimensionUniverse] = None
for node in self:
Expand Down
67 changes: 64 additions & 3 deletions python/lsst/pipe/base/graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# -------------------------------
import itertools
import logging
from collections import ChainMap
from collections import ChainMap, defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Collection, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union
Expand Down Expand Up @@ -1236,13 +1236,18 @@ def resolveDatasetRefs(
refDict.update(idMaker.resolveDict(refDict))

def makeQuantumGraph(
self, metadata: Optional[Mapping[str, Any]] = None, datastore: Optional[Datastore] = None
self,
registry: Registry,
metadata: Optional[Mapping[str, Any]] = None,
datastore: Optional[Datastore] = None,
) -> QuantumGraph:
"""Create a `QuantumGraph` from the quanta already present in
the scaffolding data structure.
Parameters
---------
registry : `lsst.daf.butler.Registry`
Registry for the data repository; used for all data ID queries.
metadata : Optional Mapping of `str` to primitives
This is an optional parameter of extra data to carry with the
graph. Entries in this mapping should be able to be serialized in
Expand Down Expand Up @@ -1291,9 +1296,63 @@ def _make_refs(dataset_dict: _DatasetDict) -> Iterable[DatasetRef]:
initInputs=taskInitInputs,
initOutputs=taskInitOutputs,
globalInitOutputs=globalInitOutputs,
registryDatasetTypes=self._get_registry_dataset_types(registry),
)
return graph

def _get_registry_dataset_types(self, registry: Registry) -> Iterable[DatasetType]:
    """Make a list of all dataset types used by a graph as defined in
    registry.

    Definitions returned by ``registry.queryDatasetTypes`` are used where
    available; for dataset types that registry does not return, the
    definition found in this scaffolding is used instead.

    Parameters
    ----------
    registry : `Registry`
        Registry whose ``queryDatasetTypes`` method is called with the names
        of all dataset types known to this scaffolding.

    Returns
    -------
    datasetTypes : iterable [ `DatasetType` ]
        One definition per dataset type name.

    Raises
    ------
    MissingDatasetTypeError
        Raised if an init-input or input dataset type is not returned by
        registry.
    ValueError
        Raised if a dataset type that registry does not return has more
        than one distinct definition in this scaffolding.
    """
    # Inputs must already exist in registry.
    must_exist = [self.initInputs, self.inputs]
    # Intermediates and outputs may not exist yet; prerequisites are treated
    # the same way as outputs here.
    may_be_missing = [
        self.initIntermediates,
        self.initOutputs,
        self.intermediates,
        self.outputs,
        self.prerequisites,
    ]
    if self.globalInitOutputs is not None:
        may_be_missing.append(self.globalInitOutputs)

    # Collect names of all dataset types and look them up in one query.
    all_names: set[str] = {
        dstype.name for dstype in itertools.chain(*must_exist, *may_be_missing)
    }
    dataset_types = {ds.name: ds for ds in registry.queryDatasetTypes(all_names)}

    for dstype in itertools.chain(*must_exist):
        if dstype.name not in dataset_types:
            raise MissingDatasetTypeError(f"Registry is missing an input dataset type {dstype}")

    # Group the not-yet-registered definitions by name so that conflicting
    # definitions (e.g. differing in storage class) can be detected.
    new_outputs: dict[str, set[DatasetType]] = defaultdict(set)
    for dstype in itertools.chain(*may_be_missing):
        if dstype.name not in dataset_types:
            new_outputs[dstype.name].add(dstype)
    for name, dstypes in new_outputs.items():
        if len(dstypes) > 1:
            raise ValueError(
                "Pipeline contains multiple definitions for a dataset type "
                f"which is not defined in registry yet: {dstypes}"
            )
        # Entries are only created by ``add`` above, so exactly one
        # definition remains at this point.
        (dataset_types[name],) = dstypes

    return dataset_types.values()


# ------------------------
# Exported definitions --
Expand Down Expand Up @@ -1439,4 +1498,6 @@ def makeGraph(
constrainedByAllDatasets=condition,
resolveRefs=resolveRefs,
)
return scaffolding.makeQuantumGraph(metadata=metadata, datastore=self.datastore)
return scaffolding.makeQuantumGraph(
registry=self.registry, metadata=metadata, datastore=self.datastore
)
34 changes: 32 additions & 2 deletions tests/test_quantumGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def _makeDatasetType(connection):
tasks = []
initInputs = {}
initOutputs = {}
dataset_types = set()
for task, label in (
(Dummy1PipelineTask, "R"),
(Dummy2PipelineTask, "S"),
Expand All @@ -177,14 +178,18 @@ def _makeDatasetType(connection):
initInputDSType = _makeDatasetType(connections.initInput)
initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
initInputs[taskDef] = initRefs
dataset_types.add(initInputDSType)
else:
initRefs = None
if connections.initOutputs:
initOutputDSType = _makeDatasetType(connections.initOutput)
initRefs = [DatasetRef(initOutputDSType, DataCoordinate.makeEmpty(universe))]
initOutputs[taskDef] = initRefs
dataset_types.add(initOutputDSType)
inputDSType = _makeDatasetType(connections.input)
dataset_types.add(inputDSType)
outputDSType = _makeDatasetType(connections.output)
dataset_types.add(outputDSType)
for a, b in ((1, 2), (3, 4)):
inputRefs = [
DatasetRef(inputDSType, DataCoordinate.standardize({"A": a, "B": b}, universe=universe))
Expand All @@ -206,6 +211,7 @@ def _makeDatasetType(connection):
self.tasks = tasks
self.quantumMap = quantumMap
self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
dataset_types.add(self.packagesDSType)
globalInitOutputs = [DatasetRef(self.packagesDSType, DataCoordinate.makeEmpty(universe))]
self.qGraph = QuantumGraph(
quantumMap,
Expand All @@ -214,8 +220,10 @@ def _makeDatasetType(connection):
initInputs=initInputs,
initOutputs=initOutputs,
globalInitOutputs=globalInitOutputs,
registryDatasetTypes=dataset_types,
)
self.universe = universe
self.num_dataset_types = len(dataset_types)

def testTaskGraph(self):
for taskDef in self.quantumMap.keys():
Expand Down Expand Up @@ -299,12 +307,17 @@ def testAllDatasetTypes(self):

def testSubset(self):
    """Check that a one-node subset keeps the node's quantum, the build ID,
    the global init-output refs and the matching registry dataset types.
    """
    allNodes = list(self.qGraph)
    firstNode = allNodes[0]
    subset = self.qGraph.subset(firstNode)
    self.assertEqual(len(subset), 1)
    subsetList = list(subset)
    self.assertEqual(firstNode.quantum, subsetList[0].quantum)
    self.assertEqual(self.qGraph._buildId, subset._buildId)
    self.assertEqual(len(subset.globalInitOutputRefs()), 1)
    # Depending on which task was first the list can contain different
    # number of datasets. The first task can be either Dummy1 or Dummy4.
    num_types = {"R": 4, "U": 3}
    self.assertEqual(len(subset.registryDatasetTypes()), num_types[firstNode.taskDef.label])

def testSubsetToConnected(self):
# False because there are two quantum chains for two distinct sets of
Expand Down Expand Up @@ -397,6 +410,7 @@ def testSaveLoad(self):
self.assertEqual(len(restoreSub), 1)
self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)
self.assertEqual(len(restoreSub.registryDatasetTypes()), self.num_dataset_types)
# Check that InitInput and InitOutput refs are restored correctly.
for taskDef in restore.iterTaskGraph():
if taskDef.label in ("S", "T"):
Expand Down Expand Up @@ -474,6 +488,22 @@ def testSaveLoadUri(self):
with self.assertRaises(TypeError):
self.qGraph.saveUri("test.notgraph")

def testSaveLoadNoRegistryDatasetTypes(self):
"""Test for reading quantum graph that is missing registry dataset types.

This test depends on internals of QuantumGraph implementation, in
particular that empty list of registry dataset types is not stored,
which makes save file identical to the "old" format.
"""
# Reset the list, this is safe as QuantumGraph itself does not use it.
self.qGraph._registryDatasetTypes = []
with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
self.qGraph.save(tmpFile)
tmpFile.seek(0)
restore = QuantumGraph.load(tmpFile, self.universe)
self.assertEqual(self.qGraph, restore)
# A graph saved without the header entry must load with an empty list.
self.assertEqual(restore.registryDatasetTypes(), [])

def testContains(self):
firstNode = next(iter(self.qGraph))
self.assertIn(firstNode, self.qGraph)
Expand Down

0 comments on commit ef3da71

Please sign in to comment.