DM-33492: Enable store of resolved DatasetRefs in QuantumGraph #262

Merged
1 commit merged on Aug 11, 2022
29 changes: 28 additions & 1 deletion python/lsst/pipe/base/graph/_versionDeserializers.py
@@ -31,10 +31,23 @@
from collections import defaultdict
from dataclasses import dataclass
from types import SimpleNamespace
from typing import TYPE_CHECKING, Callable, ClassVar, DefaultDict, Dict, Optional, Set, Tuple, Type
from typing import (
TYPE_CHECKING,
Callable,
ClassVar,
DefaultDict,
Dict,
List,
Optional,
Set,
Tuple,
Type,
cast,
)

import networkx as nx
from lsst.daf.butler import (
DatasetRef,
DimensionConfig,
DimensionRecord,
DimensionUniverse,
@@ -520,6 +533,8 @@ def constructGraph(
datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True)
taskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
recontitutedDimensions: Dict[int, Tuple[str, DimensionRecord]] = {}
initInputRefs: Dict[TaskDef, List[DatasetRef]] = {}
initOutputRefs: Dict[TaskDef, List[DatasetRef]] = {}

if universe is not None:
if not universe.isCompatibleWith(self.infoMappings.universe):
@@ -566,6 +581,16 @@
)
loadedTaskDef[nodeTaskLabel] = recreatedTaskDef

# initInputRefs and initOutputRefs are optional
if (refs := taskDefDump.get("initInputRefs")) is not None:
initInputRefs[recreatedTaskDef] = [
cast(DatasetRef, DatasetRef.from_json(ref, universe=universe)) for ref in refs
]
if (refs := taskDefDump.get("initOutputRefs")) is not None:
initOutputRefs[recreatedTaskDef] = [
cast(DatasetRef, DatasetRef.from_json(ref, universe=universe)) for ref in refs
]

# rebuild the mappings that associate dataset type names with
# TaskDefs
for _, input in self.infoMappings.taskDefMap[nodeTaskLabel]["inputs"]:
@@ -613,6 +638,8 @@
newGraph._taskToQuantumNode = dict(taskToQuantumNode.items())
newGraph._taskGraph = datasetDict.makeNetworkXGraph()
newGraph._connectedQuanta = graph
newGraph._initInputRefs = initInputRefs
newGraph._initOutputRefs = initOutputRefs
return newGraph


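A minimal sketch of the JSON round trip these changes rely on, using only the `DatasetRef` serialization API that appears in this diff (`to_json` / `from_json`); the helper names are hypothetical:

```python
from lsst.daf.butler import DatasetRef, DimensionUniverse

def dump_refs(refs):
    # Serialize each DatasetRef (resolved or not) to a JSON string, as
    # _buildSaveObject does for initInputRefs/initOutputRefs below.
    return [ref.to_json() for ref in refs]

def load_refs(json_refs, universe: DimensionUniverse):
    # Recreate the refs in the loading side's dimension universe, as
    # constructGraph does above.
    return [DatasetRef.from_json(j, universe=universe) for j in json_refs]
```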
71 changes: 69 additions & 2 deletions python/lsst/pipe/base/graph/graph.py
@@ -112,6 +112,16 @@ class QuantumGraph:
metadata : Optional Mapping of `str` to primitives
This is an optional parameter of extra data to carry with the graph.
Entries in this mapping should be able to be serialized in JSON.
pruneRefs : iterable [ `DatasetRef` ], optional
Set of dataset refs to exclude from a graph.
initInputs : `Mapping`, optional
Maps tasks to their InitInput dataset refs. Dataset refs can be either
resolved or non-resolved. Presently the same dataset refs are included
in each `Quantum` for the same task.
initOutputs : `Mapping`, optional
Maps tasks to their InitOutput dataset refs. Dataset refs can be either
resolved or non-resolved. For resolved intermediate refs, the dataset
ID must match ``initInputs`` and Quantum ``initInputs``.

Raises
------
@@ -126,8 +136,17 @@ def __init__(
metadata: Optional[Mapping[str, Any]] = None,
pruneRefs: Optional[Iterable[DatasetRef]] = None,
universe: Optional[DimensionUniverse] = None,
initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
):
self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs, universe=universe)
self._buildGraphs(
quanta,
metadata=metadata,
pruneRefs=pruneRefs,
universe=universe,
initInputs=initInputs,
initOutputs=initOutputs,
)

def _buildGraphs(
self,
@@ -138,6 +157,8 @@ def _buildGraphs(
metadata: Optional[Mapping[str, Any]] = None,
pruneRefs: Optional[Iterable[DatasetRef]] = None,
universe: Optional[DimensionUniverse] = None,
initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
) -> None:
"""Builds the graph that is used to store the relation between tasks,
and the graph that holds the relations between quanta
@@ -274,6 +295,13 @@ def _buildGraphs(
# insertion
self._taskToQuantumNode = dict(self._taskToQuantumNode.items())

self._initInputRefs: Dict[TaskDef, List[DatasetRef]] = {}
self._initOutputRefs: Dict[TaskDef, List[DatasetRef]] = {}
if initInputs is not None:
self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
if initOutputs is not None:
self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}

@property
def taskGraph(self) -> nx.DiGraph:
"""Return a graph representing the relations between the tasks inside
@@ -757,6 +785,39 @@ def metadata(self) -> Optional[MappingProxyType[str, Any]]:
return None
return MappingProxyType(self._metadata)

def initInputRefs(self, taskDef: TaskDef) -> Optional[List[DatasetRef]]:
"""Return DatasetRefs for a given task InitInputs.

Parameters
----------
taskDef : `TaskDef`
Task definition structure.

Returns
-------
refs : `list` [ `DatasetRef` ] or `None`
DatasetRefs for the task InitInputs, or `None` if none were stored.
The returned references may be either resolved or unresolved.
"""
return self._initInputRefs.get(taskDef)

def initOutputRefs(self, taskDef: TaskDef) -> Optional[List[DatasetRef]]:
"""Return DatasetRefs for a given task InitOutputs.

Parameters
----------
taskDef : `TaskDef`
Task definition structure.

Returns
-------
refs : `list` [ `DatasetRef` ] or `None`
DatasetRefs for the task InitOutputs, or `None` if none were stored.
The returned references may be either resolved or unresolved. A
resolved reference will match a Quantum's ``initInputs`` if this is an
intermediate dataset type.
"""
return self._initOutputRefs.get(taskDef)

@classmethod
def loadUri(
cls,
@@ -924,7 +985,7 @@ def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple
for taskDef in self.taskGraph:
# compressing has very little impact on saving or load time, but
# a large impact on on-disk size, so it is worth doing
taskDescription = {}
taskDescription: Dict[str, Any] = {}
# save the fully qualified name.
taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
# save the config as a text stream that will be un-persisted on the
@@ -933,6 +994,10 @@
taskDef.config.saveToStream(stream)
taskDescription["config"] = stream.getvalue()
taskDescription["label"] = taskDef.label
if (refs := self._initInputRefs.get(taskDef)) is not None:
taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
if (refs := self._initOutputRefs.get(taskDef)) is not None:
taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]

inputs = []
outputs = []
@@ -1162,6 +1227,8 @@ def __setstate__(self, state: dict) -> None:
self._taskToQuantumNode = qgraph._taskToQuantumNode
self._taskGraph = qgraph._taskGraph
self._connectedQuanta = qgraph._connectedQuanta
self._initInputRefs = qgraph._initInputRefs
self._initOutputRefs = qgraph._initOutputRefs

def __eq__(self, other: object) -> bool:
if not isinstance(other, QuantumGraph):
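A hedged usage sketch for the new accessors added above; `qgraph` is an assumption standing in for a `QuantumGraph` built or loaded elsewhere:

```python
# Iterate over tasks and inspect their stored init refs; both accessors
# return None for a task with no stored references.
for task_def in qgraph.taskGraph:
    init_inputs = qgraph.initInputRefs(task_def)    # list[DatasetRef] or None
    init_outputs = qgraph.initOutputRefs(task_def)  # list[DatasetRef] or None
    for ref in init_outputs or []:
        # ref.id is a dataset ID for resolved refs and None otherwise.
        print(task_def.label, ref.datasetType.name, ref.id)
```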
79 changes: 77 additions & 2 deletions python/lsst/pipe/base/graphBuilder.py
@@ -33,12 +33,13 @@
from collections import ChainMap
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Collection, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Union
from typing import Any, Collection, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union

from lsst.daf.butler import (
CollectionSearch,
CollectionType,
DataCoordinate,
DatasetIdGenEnum,
DatasetRef,
DatasetType,
Datastore,
@@ -485,6 +486,34 @@ def makeQuantumSet(
return outputs


class _DatasetIdMaker:
"""Helper class which generates random dataset UUIDs for unresolved
datasets.
"""

def __init__(self, registry: Registry, run: str):
self.datasetIdFactory = registry.datasetIdFactory
self.run = run
# Dataset IDs generated so far
self.resolved: Dict[Tuple[DatasetType, DataCoordinate], DatasetRef] = {}

def resolveRef(self, ref: DatasetRef) -> DatasetRef:
if ref.id is not None:
return ref
key = ref.datasetType, ref.dataId
if (resolved := self.resolved.get(key)) is None:
datasetId = self.datasetIdFactory.makeDatasetId(
self.run, ref.datasetType, ref.dataId, DatasetIdGenEnum.UNIQUE
)
resolved = ref.resolved(datasetId, self.run)
self.resolved[key] = resolved
return resolved

def resolveDict(self, refs: Dict[DataCoordinate, DatasetRef]) -> Dict[DataCoordinate, DatasetRef]:
"""Resolve all unresolved references in the provided dictionary."""
return {dataId: self.resolveRef(ref) for dataId, ref in refs.items()}
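
Illustrative use of the helper above; `registry` and `unresolved` are hypothetical stand-ins for a butler registry and a `DatasetRef` whose `id` is `None`:

```python
maker = _DatasetIdMaker(registry, run="u/someone/test-run")
first = maker.resolveRef(unresolved)
second = maker.resolveRef(unresolved)
# The (DatasetType, DataCoordinate) cache yields one dataset ID per
# logical dataset, keeping intermediates consistent across quanta.
assert first.id is not None and first.id == second.id
```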


@dataclass
class _PipelineScaffolding:
"""A helper data structure that organizes the information involved in
@@ -711,6 +740,7 @@ def connectDataIds(
# quanta and then connecting them to each other.
n = -1
for n, commonDataId in enumerate(commonDataIds):
_LOG.debug("Next DataID = %s", commonDataId)
# Create DatasetRefs for all DatasetTypes from this result row,
# noting that we might have created some already.
# We remember both those that already existed and those that we
@@ -727,6 +757,7 @@
ref = refs.get(datasetDataId)
if ref is None:
ref = DatasetRef(datasetType, datasetDataId)
_LOG.debug("Made new ref = %s", ref)
refs[datasetDataId] = ref
refsForRow[datasetType.name] = ref
# Create _QuantumScaffolding objects for all tasks from this
@@ -788,6 +819,7 @@ def resolveDatasetRefs(
skipExistingIn: Any = None,
clobberOutputs: bool = True,
constrainedByAllDatasets: bool = True,
resolveRefs: bool = False,
) -> None:
"""Perform follow up queries for each dataset data ID produced in
`fillDataIds`.
@@ -823,6 +855,10 @@
constrainedByAllDatasets : `bool`, optional
Indicates if the commonDataIds were generated with a constraint on
all dataset types.
resolveRefs : `bool`, optional
If `True` then resolve all input references and generate random
dataset IDs for all output and intermediate datasets. A true value
requires the ``run`` collection to be specified.

Raises
------
@@ -845,6 +881,11 @@
collectionTypes=CollectionType.RUN,
)

idMaker: Optional[_DatasetIdMaker] = None
if resolveRefs:
assert run is not None, "run cannot be None when resolveRefs is True"

Review comment (Contributor):
Is assert really the right kind of exception to raise here?

Reply (Contributor Author):
I think it's fine here; it is internal code only called from makeGraph(). makeGraph() has an explicit check that raises ValueError; here the assert is mostly to keep mypy happy.
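
A self-contained sketch of the pattern discussed here, with simplified signatures (not the real ones): the public entry point raises a user-facing `ValueError`, while the internal helper keeps an `assert` that documents the invariant and narrows `Optional[str]` for mypy.

```python
from typing import Optional

def makeGraph(run: Optional[str] = None, resolveRefs: bool = False) -> None:
    # Public entry point: user-facing validation.
    if resolveRefs and run is None:
        raise ValueError("`resolveRefs` requires `run` parameter.")
    resolveDatasetRefs(run=run, resolveRefs=resolveRefs)

def resolveDatasetRefs(run: Optional[str] = None, resolveRefs: bool = False) -> None:
    # Internal helper: the caller already validated the invariant; the
    # assert records it and narrows Optional[str] to str for mypy.
    if resolveRefs:
        assert run is not None, "run cannot be None when resolveRefs is True"
```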

idMaker = _DatasetIdMaker(registry, run)

# Look up [init] intermediate and output datasets in the output
# collection, if there is an output collection.
if run is not None or skipCollections is not None:
@@ -879,6 +920,11 @@
f"output RUN collection '{run}' with data ID"
f" {resolvedRef.dataId}."
)
# If we are going to resolve all outputs then we have
# to remember existing ones to avoid generating new
# dataset IDs for them.
if resolveRefs:
refs[resolvedRef.dataId] = resolvedRef

# And check skipExistingIn too; the case where the RUN collection is
# in it is handled above
@@ -1017,6 +1063,19 @@
quantum.prerequisites[datasetType].update(
{ref.dataId: ref for ref in prereq_refs if ref is not None}
)

# Resolve all quantum inputs and outputs.
if idMaker:
for datasetDict in (quantum.inputs, quantum.outputs):
for refDict in datasetDict.values():
refDict.update(idMaker.resolveDict(refDict))

# Resolve task initInputs and initOutputs.
if idMaker:
for datasetDict in (task.initInputs, task.initOutputs):
for refDict in datasetDict.values():
refDict.update(idMaker.resolveDict(refDict))

# Actually remove any quanta that we decided to skip above.
if dataIdsSucceeded:
if skipCollections is not None:
@@ -1088,8 +1147,16 @@ def _make_refs(dataset_dict: _DatasetDict) -> Iterable[DatasetRef]:
qset = task.makeQuantumSet(unresolvedRefs=self.unfoundRefs, datastore_records=datastore_records)
graphInput[task.taskDef] = qset

taskInitInputs = {task.taskDef: task.initInputs.unpackSingleRefs().values() for task in self.tasks}
taskInitOutputs = {task.taskDef: task.initOutputs.unpackSingleRefs().values() for task in self.tasks}

graph = QuantumGraph(
graphInput, metadata=metadata, pruneRefs=self.unfoundRefs, universe=self.dimensions.universe
graphInput,
metadata=metadata,
pruneRefs=self.unfoundRefs,
universe=self.dimensions.universe,
initInputs=taskInitInputs,
initOutputs=taskInitOutputs,
)
return graph

@@ -1158,6 +1225,7 @@ def makeGraph(
userQuery: Optional[str],
datasetQueryConstraint: DatasetQueryConstraintVariant = DatasetQueryConstraintVariant.ALL,
metadata: Optional[Mapping[str, Any]] = None,
resolveRefs: bool = False,
) -> QuantumGraph:
"""Create execution graph for a pipeline.

@@ -1183,6 +1251,10 @@
This is an optional parameter of extra data to carry with the
graph. Entries in this mapping should be able to be serialized in
JSON.
resolveRefs : `bool`, optional
If `True` then resolve all input references and generate random
dataset IDs for all output and intermediate datasets. A true value
requires the ``run`` collection to be specified.

Returns
-------
@@ -1198,6 +1270,8 @@
Other exceptions types may be raised by underlying registry
classes.
"""
if resolveRefs and run is None:
raise ValueError("`resolveRefs` requires `run` parameter.")
scaffolding = _PipelineScaffolding(pipeline, registry=self.registry)
if not collections and (scaffolding.initInputs or scaffolding.inputs or scaffolding.prerequisites):
raise ValueError("Pipeline requires input datasets but no input collections provided.")
@@ -1225,5 +1299,6 @@
skipExistingIn=self.skipExistingIn,
clobberOutputs=self.clobberOutputs,
constrainedByAllDatasets=condition,
resolveRefs=resolveRefs,
)
return scaffolding.makeQuantumGraph(metadata=metadata, datastore=self.datastore)
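
A hedged end-to-end sketch of the new flag; `butler` and `pipeline` are assumptions standing in for an existing data repository and pipeline definition, and the collection names are made up:

```python
from lsst.pipe.base import GraphBuilder

builder = GraphBuilder(butler.registry)
qgraph = builder.makeGraph(
    pipeline,
    collections=["HSC/defaults"],   # hypothetical input collections
    run="u/someone/resolved-run",   # required when resolveRefs=True
    userQuery="",
    resolveRefs=True,
)
```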
5 changes: 5 additions & 0 deletions python/lsst/pipe/base/tests/simpleQGraph.py
@@ -360,6 +360,7 @@ def makeSimpleQGraph(
datasetTypes: Optional[Dict[Optional[str], List[str]]] = None,
datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
makeDatastoreRecords: bool = False,
resolveRefs: bool = False,
) -> Tuple[Butler, QuantumGraph]:
"""Make simple QuantumGraph for tests.

@@ -405,6 +406,9 @@
`DatasetQueryConstraintVariant.ALL`.
makeDatastoreRecords : `bool`, optional
If `True` then add datastore records to generated quanta.
resolveRefs : `bool`, optional
If `True` then resolve all input references and generate random dataset
IDs for all output and intermediate datasets.

Returns
-------
Expand Down Expand Up @@ -446,6 +450,7 @@ def makeSimpleQGraph(
run=run or butler.run,
userQuery=userQuery,
datasetQueryConstraint=datasetQueryConstraint,
resolveRefs=resolveRefs,
)

return butler, qgraph
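
And a hypothetical test-side use of the new `makeSimpleQGraph` keyword; creating the test butler under a temporary-directory root is an assumption:

```python
import tempfile

from lsst.pipe.base.tests.simpleQGraph import makeSimpleQGraph

with tempfile.TemporaryDirectory() as root:
    butler, qgraph = makeSimpleQGraph(root=root, resolveRefs=True)
    # With resolveRefs=True every generated output ref carries a dataset ID.
    for node in qgraph:
        for refs in node.quantum.outputs.values():
            assert all(ref.id is not None for ref in refs)
```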