Skip to content

Commit

Permalink
Merge pull request #265 from lsst/tickets/DM-35681
Browse files Browse the repository at this point in the history
DM-35681: ensure QuantumGraph is passed a DimensionUniverse at construction
  • Loading branch information
TallJimbo committed Jul 23, 2022
2 parents 913bdc4 + 737f6b3 commit e718c22
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
language_version: python3.8
language_version: python3.10
- repo: https://github.com/pycqa/isort
rev: 5.10.1
hooks:
Expand Down
3 changes: 3 additions & 0 deletions doc/changes/DM-35681.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Ensure QuantumGraphs are given a DimensionUniverse at construction.

This fixes a mostly-spurious dimension universe inconsistency warning when reading QuantumGraphs, introduced on DM-35082.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ version = { attr = "lsst_versions.get_lsst_version" }

[tool.black]
line-length = 110
target-version = ["py38"]
target-version = ["py310"]

[tool.isort]
profile = "black"
Expand Down
2 changes: 2 additions & 0 deletions python/lsst/pipe/base/graph/_loadHelpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ def load(
f"Nodes {remainder} were requested, but could not be found in the input graph"
)
_readBytes = self._readBytes
if universe is None:
universe = self.headerInfo.universe
return self.deserializer.constructGraph(nodeSet, _readBytes, universe)

def _readBytes(self, start: int, stop: int) -> bytes:
Expand Down
2 changes: 2 additions & 0 deletions python/lsst/pipe/base/graph/_versionDeserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def constructGraph(
_quantumToNodeId=quantumToNodeId,
_buildId=self.returnValue._buildId,
metadata=self.returnValue.metadata,
universe=universe,
)
return qGraph

Expand Down Expand Up @@ -390,6 +391,7 @@ def constructGraph(
_quantumToNodeId=quantumToNodeId,
_buildId=self.returnValue._buildId,
metadata=self.returnValue.metadata,
universe=universe,
)
return qGraph

Expand Down
24 changes: 20 additions & 4 deletions python/lsst/pipe/base/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,6 @@ def _buildGraphs(
"""Builds the graph that is used to store the relation between tasks,
and the graph that holds the relations between quanta
"""
if universe is None:
universe = DimensionUniverse()
self._universe = universe
self._metadata = metadata
self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
# Data structures used to identify relations between components;
Expand Down Expand Up @@ -173,6 +170,15 @@ def _buildGraphs(
# a newly created QuantumNode to the appropriate input/output
# field.
for quantum in quantumSet:
if quantum.dataId is not None:
if universe is None:
universe = quantum.dataId.universe
elif universe != quantum.dataId.universe:
raise RuntimeError(
"Mismatched dimension universes in QuantumGraph construction: "
f"{universe} != {quantum.dataId.universe}. "
)

if _quantumToNodeId:
if (nodeId := _quantumToNodeId.get(quantum)) is None:
raise ValueError(
Expand Down Expand Up @@ -249,6 +255,14 @@ def _buildGraphs(
f"after graph pruning; {', '.join(culprits)} caused over-pruning"
)

# Dimension universe
if universe is None:
raise RuntimeError(
"Dimension universe or at least one quantum with a data ID "
"must be provided when constructing a QuantumGraph."
)
self._universe = universe

# Graph of quanta relations
self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()
self._count = len(self._connectedQuanta)
Expand Down Expand Up @@ -360,6 +374,7 @@ def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T:
_quantumToNodeId={n.quantum: n.nodeId for n in self},
metadata=self._metadata,
pruneRefs=refs,
universe=self._universe,
)
return newInst

Expand Down Expand Up @@ -613,12 +628,13 @@ def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
# convert to standard dict to prevent accidental key insertion
quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())
# Create an empty graph, and then populate it with custom mapping
newInst = type(self)({})
newInst = type(self)({}, universe=self._universe)
newInst._buildGraphs(
quantumDict,
_quantumToNodeId={n.quantum: n.nodeId for n in nodes},
_buildId=self._buildId,
metadata=self._metadata,
universe=self._universe,
)
return newInst

Expand Down
4 changes: 3 additions & 1 deletion python/lsst/pipe/base/graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,9 @@ def _make_refs(dataset_dict: _DatasetDict) -> Iterable[DatasetRef]:
qset = task.makeQuantumSet(unresolvedRefs=self.unfoundRefs, datastore_records=datastore_records)
graphInput[task.taskDef] = qset

graph = QuantumGraph(graphInput, metadata=metadata, pruneRefs=self.unfoundRefs)
graph = QuantumGraph(
graphInput, metadata=metadata, pruneRefs=self.unfoundRefs, universe=self.dimensions.universe
)
return graph


Expand Down

0 comments on commit e718c22

Please sign in to comment.