Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-37995: Add registry dataset types to QuantumGraph #307

Merged
merged 1 commit into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/lsst/pipe/base/graph/_versionDeserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import networkx as nx
from lsst.daf.butler import (
DatasetRef,
DatasetType,
DimensionConfig,
DimensionRecord,
DimensionUniverse,
Expand Down Expand Up @@ -514,6 +515,11 @@ def readHeaderInfo(self, rawHeader: bytes) -> SimpleNamespace:
infoMappings.globalInitOutputRefs = [
DatasetRef.from_json(json_ref, universe=universe) for json_ref in json_refs
]
infoMappings.registryDatasetTypes = []
if (json_refs := infoMap.get("RegistryDatasetTypes")) is not None:
infoMappings.registryDatasetTypes = [
DatasetType.from_json(json_ref, universe=universe) for json_ref in json_refs
]
self.infoMappings = infoMappings
return infoMappings

Expand Down Expand Up @@ -643,6 +649,7 @@ def constructGraph(
newGraph._initInputRefs = initInputRefs
newGraph._initOutputRefs = initOutputRefs
newGraph._globalInitOutputRefs = self.infoMappings.globalInitOutputRefs
newGraph._registryDatasetTypes = self.infoMappings.registryDatasetTypes
newGraph._universe = universe
return newGraph

Expand Down
50 changes: 48 additions & 2 deletions python/lsst/pipe/base/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ class QuantumGraph:
objects include task configurations and package versions. Typically
they have an empty DataId, but there is no real restriction on what
can appear here.
registryDatasetTypes : iterable [ `DatasetType` ], optional
Dataset types which are used by this graph, their definitions must
match registry. If registry does not define dataset type yet, then
it should match one that will be created later.

Raises
------
Expand All @@ -144,6 +148,7 @@ def __init__(
initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
globalInitOutputs: Optional[Iterable[DatasetRef]] = None,
registryDatasetTypes: Optional[Iterable[DatasetType]] = None,
):
self._buildGraphs(
quanta,
Expand All @@ -153,6 +158,7 @@ def __init__(
initInputs=initInputs,
initOutputs=initOutputs,
globalInitOutputs=globalInitOutputs,
registryDatasetTypes=registryDatasetTypes,
)

def _buildGraphs(
Expand All @@ -167,6 +173,7 @@ def _buildGraphs(
initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
globalInitOutputs: Optional[Iterable[DatasetRef]] = None,
registryDatasetTypes: Optional[Iterable[DatasetType]] = None,
) -> None:
"""Builds the graph that is used to store the relation between tasks,
and the graph that holds the relations between quanta
Expand Down Expand Up @@ -310,12 +317,15 @@ def _buildGraphs(
self._initInputRefs: Dict[TaskDef, List[DatasetRef]] = {}
self._initOutputRefs: Dict[TaskDef, List[DatasetRef]] = {}
self._globalInitOutputRefs: List[DatasetRef] = []
self._registryDatasetTypes: List[DatasetType] = []
if initInputs is not None:
self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
if initOutputs is not None:
self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}
if globalInitOutputs is not None:
self._globalInitOutputRefs = list(globalInitOutputs)
if registryDatasetTypes is not None:
self._registryDatasetTypes = list(registryDatasetTypes)

@property
def taskGraph(self) -> nx.DiGraph:
Expand Down Expand Up @@ -413,13 +423,17 @@ def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T:
# convert to standard dict to prevent accidental key insertion
quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())

# This should not change set of tasks in a graph, so we can keep the
# same registryDatasetTypes as in the original graph.
# TODO: Do we need to copy initInputs/initOutputs?
newInst._buildGraphs(
quantumDict,
_quantumToNodeId={n.quantum: n.nodeId for n in self},
metadata=self._metadata,
pruneRefs=refs,
universe=self._universe,
globalInitOutputs=self._globalInitOutputRefs,
registryDatasetTypes=self._registryDatasetTypes,
)
return newInst

Expand Down Expand Up @@ -682,21 +696,39 @@ def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
quantumMap = defaultdict(set)

dataset_type_names: set[str] = set()
node: QuantumNode
for node in quantumSubgraph:
quantumMap[node.taskDef].add(node.quantum)
dataset_type_names.update(
dstype.name
for dstype in chain(
node.quantum.inputs.keys(), node.quantum.outputs.keys(), node.quantum.initInputs.keys()
)
)

# May need to trim dataset types from registryDatasetTypes.
for taskDef in quantumMap:
if refs := self.initOutputRefs(taskDef):
dataset_type_names.update(ref.datasetType.name for ref in refs)
dataset_type_names.update(ref.datasetType.name for ref in self._globalInitOutputRefs)
registryDatasetTypes = [
dstype for dstype in self._registryDatasetTypes if dstype.name in dataset_type_names
]

# convert to standard dict to prevent accidental key insertion
quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())
# Create an empty graph, and then populate it with custom mapping
newInst = type(self)({}, universe=self._universe)
# TODO: Do we need to copy initInputs/initOutputs?
newInst._buildGraphs(
quantumDict,
_quantumToNodeId={n.quantum: n.nodeId for n in nodes},
_buildId=self._buildId,
metadata=self._metadata,
universe=self._universe,
globalInitOutputs=self._globalInitOutputRefs,
registryDatasetTypes=registryDatasetTypes,
)
return newInst

Expand Down Expand Up @@ -862,6 +894,17 @@ def globalInitOutputRefs(self) -> List[DatasetRef]:
"""
return self._globalInitOutputRefs

def registryDatasetTypes(self) -> List[DatasetType]:
"""Return dataset types used by this graph, their definitions match
dataset types from registry.

Returns
-------
refs : `list` [ `DatasetType` ]
Dataset types for this graph.
"""
return self._registryDatasetTypes

@classmethod
def loadUri(
cls,
Expand Down Expand Up @@ -1115,6 +1158,9 @@ def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple
if self._globalInitOutputRefs:
headerData["GlobalInitOutputRefs"] = [ref.to_json() for ref in self._globalInitOutputRefs]

if self._registryDatasetTypes:
headerData["RegistryDatasetTypes"] = [dstype.to_json() for dstype in self._registryDatasetTypes]

# dump the headerData to json
header_encode = lzma.compress(json.dumps(headerData).encode())

Expand Down Expand Up @@ -1252,8 +1298,8 @@ def __contains__(self, node: QuantumNode) -> bool:
def __getstate__(self) -> dict:
"""Stores a compact form of the graph as a list of graph nodes, and a
tuple of task labels and task configs. The full graph can be
reconstructed with this information, and it preseves the ordering of
the graph ndoes.
reconstructed with this information, and it preserves the ordering of
the graph nodes.
"""
universe: Optional[DimensionUniverse] = None
for node in self:
Expand Down
67 changes: 64 additions & 3 deletions python/lsst/pipe/base/graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# -------------------------------
import itertools
import logging
from collections import ChainMap
from collections import ChainMap, defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Collection, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union
Expand Down Expand Up @@ -1236,13 +1236,18 @@ def resolveDatasetRefs(
refDict.update(idMaker.resolveDict(refDict))

def makeQuantumGraph(
self, metadata: Optional[Mapping[str, Any]] = None, datastore: Optional[Datastore] = None
self,
registry: Registry,
metadata: Optional[Mapping[str, Any]] = None,
datastore: Optional[Datastore] = None,
) -> QuantumGraph:
"""Create a `QuantumGraph` from the quanta already present in
the scaffolding data structure.

Parameters
---------
registry : `lsst.daf.butler.Registry`
Registry for the data repository; used for all data ID queries.
metadata : Optional Mapping of `str` to primitives
This is an optional parameter of extra data to carry with the
graph. Entries in this mapping should be able to be serialized in
Expand Down Expand Up @@ -1291,9 +1296,63 @@ def _make_refs(dataset_dict: _DatasetDict) -> Iterable[DatasetRef]:
initInputs=taskInitInputs,
initOutputs=taskInitOutputs,
globalInitOutputs=globalInitOutputs,
registryDatasetTypes=self._get_registry_dataset_types(registry),
)
return graph

def _get_registry_dataset_types(self, registry: Registry) -> Iterable[DatasetType]:
"""Make a list of all dataset types used by a graph as defined in
registry.
"""
chain = [
self.initInputs,
self.initIntermediates,
self.initOutputs,
self.inputs,
self.intermediates,
self.outputs,
self.prerequisites,
]
if self.globalInitOutputs is not None:
chain.append(self.globalInitOutputs)

# Collect names of all dataset types.
all_names: set[str] = set(dstype.name for dstype in itertools.chain(*chain))
dataset_types = {ds.name: ds for ds in registry.queryDatasetTypes(all_names)}

# Check for types that do not exist in registry yet:
# - inputs must exist
# - intermediates and outputs may not exist, but there must not be
# more than one definition (e.g. differing in storage class)
# - prerequisites may not exist, treat it the same as outputs here
for dstype in itertools.chain(self.initInputs, self.inputs):
if dstype.name not in dataset_types:
raise MissingDatasetTypeError(f"Registry is missing an input dataset type {dstype}")

new_outputs: dict[str, set[DatasetType]] = defaultdict(set)
chain = [
self.initIntermediates,
self.initOutputs,
self.intermediates,
self.outputs,
self.prerequisites,
]
if self.globalInitOutputs is not None:
chain.append(self.globalInitOutputs)
for dstype in itertools.chain(*chain):
if dstype.name not in dataset_types:
new_outputs[dstype.name].add(dstype)
for name, dstypes in new_outputs.items():
if len(dstypes) > 1:
raise ValueError(
"Pipeline contains multiple definitions for a dataset type "
andy-slac marked this conversation as resolved.
Show resolved Hide resolved
f"which is not defined in registry yet: {dstypes}"
)
elif len(dstypes) == 1:
dataset_types[name] = dstypes.pop()

return dataset_types.values()


# ------------------------
# Exported definitions --
Expand Down Expand Up @@ -1439,4 +1498,6 @@ def makeGraph(
constrainedByAllDatasets=condition,
resolveRefs=resolveRefs,
)
return scaffolding.makeQuantumGraph(metadata=metadata, datastore=self.datastore)
return scaffolding.makeQuantumGraph(
registry=self.registry, metadata=metadata, datastore=self.datastore
)
34 changes: 32 additions & 2 deletions tests/test_quantumGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def _makeDatasetType(connection):
tasks = []
initInputs = {}
initOutputs = {}
dataset_types = set()
for task, label in (
(Dummy1PipelineTask, "R"),
(Dummy2PipelineTask, "S"),
Expand All @@ -177,14 +178,18 @@ def _makeDatasetType(connection):
initInputDSType = _makeDatasetType(connections.initInput)
initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
initInputs[taskDef] = initRefs
dataset_types.add(initInputDSType)
andy-slac marked this conversation as resolved.
Show resolved Hide resolved
else:
initRefs = None
if connections.initOutputs:
initOutputDSType = _makeDatasetType(connections.initOutput)
initRefs = [DatasetRef(initOutputDSType, DataCoordinate.makeEmpty(universe))]
initOutputs[taskDef] = initRefs
dataset_types.add(initOutputDSType)
inputDSType = _makeDatasetType(connections.input)
dataset_types.add(inputDSType)
outputDSType = _makeDatasetType(connections.output)
dataset_types.add(outputDSType)
for a, b in ((1, 2), (3, 4)):
inputRefs = [
DatasetRef(inputDSType, DataCoordinate.standardize({"A": a, "B": b}, universe=universe))
Expand All @@ -206,6 +211,7 @@ def _makeDatasetType(connection):
self.tasks = tasks
self.quantumMap = quantumMap
self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
dataset_types.add(self.packagesDSType)
globalInitOutputs = [DatasetRef(self.packagesDSType, DataCoordinate.makeEmpty(universe))]
self.qGraph = QuantumGraph(
quantumMap,
Expand All @@ -214,8 +220,10 @@ def _makeDatasetType(connection):
initInputs=initInputs,
initOutputs=initOutputs,
globalInitOutputs=globalInitOutputs,
registryDatasetTypes=dataset_types,
)
self.universe = universe
self.num_dataset_types = len(dataset_types)

def testTaskGraph(self):
for taskDef in self.quantumMap.keys():
Expand Down Expand Up @@ -299,12 +307,17 @@ def testAllDatasetTypes(self):

def testSubset(self):
allNodes = list(self.qGraph)
subset = self.qGraph.subset(allNodes[0])
firstNode = allNodes[0]
subset = self.qGraph.subset(firstNode)
self.assertEqual(len(subset), 1)
subsetList = list(subset)
self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
self.assertEqual(firstNode.quantum, subsetList[0].quantum)
self.assertEqual(self.qGraph._buildId, subset._buildId)
self.assertEqual(len(subset.globalInitOutputRefs()), 1)
# Depending on which task was first the list can contain different
# number of datasets. The first task can be either Dummy1 or Dummy4.
num_types = {"R": 4, "U": 3}
self.assertEqual(len(subset.registryDatasetTypes()), num_types[firstNode.taskDef.label])

def testSubsetToConnected(self):
# False because there are two quantum chains for two distinct sets of
Expand Down Expand Up @@ -397,6 +410,7 @@ def testSaveLoad(self):
self.assertEqual(len(restoreSub), 1)
self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)
self.assertEqual(len(restoreSub.registryDatasetTypes()), self.num_dataset_types)
# Check that InitInput and InitOutput refs are restored correctly.
for taskDef in restore.iterTaskGraph():
if taskDef.label in ("S", "T"):
Expand Down Expand Up @@ -474,6 +488,22 @@ def testSaveLoadUri(self):
with self.assertRaises(TypeError):
self.qGraph.saveUri("test.notgraph")

def testSaveLoadNoRegistryDatasetTypes(self):
"""Test for reading quantum that is missing registry dataset types.

This test depends on internals of QuantumGraph implementation, in
particular that empty list of registry dataset types is not stored,
which makes save file identical to the "old" format.
"""
# Reset the list, this is safe as QuantumGraph itself does not use it.
self.qGraph._registryDatasetTypes = []
with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
self.qGraph.save(tmpFile)
tmpFile.seek(0)
restore = QuantumGraph.load(tmpFile, self.universe)
self.assertEqual(self.qGraph, restore)
self.assertEqual(restore.registryDatasetTypes(), [])

def testContains(self):
firstNode = next(iter(self.qGraph))
self.assertIn(firstNode, self.qGraph)
Expand Down